# Released under the MIT License. See LICENSE for details.
#
"""Functionality for sending and responding to messages.
Supports static typing for message types and possible return types.
"""
from __future__ import annotations
import types
import inspect
import logging
from typing import TYPE_CHECKING
from efro.message._message import (
Message,
Response,
EmptySysResponse,
UnregisteredMessageIDError,
)
if TYPE_CHECKING:
from typing import Any, Callable, Awaitable
from efro.message._protocol import MessageProtocol
from efro.message._message import SysResponse
class MessageReceiver:
    """Facilitates receiving & responding to messages from a remote source.

    This is instantiated at the class level with unbound methods registered
    as handlers for different message types in the protocol.

    Example:
      class MyClass:
          receiver = MyMessageReceiver()

          # MyMessageReceiver fills out handler() overloads to ensure all
          # registered handlers have valid types/return-types.

          @receiver.handler
          def handle_some_message_type(self, msg: SomeMsg) -> SomeResponse:
              # Deal with this message type here.

      # This will trigger the registered handler being called.
      obj = MyClass()
      obj.receiver.handle_raw_message(some_raw_data)

    Note that handler methods must take exactly the arguments
    ('self', 'msg'); register_handler() rejects any other signature.

    Any unhandled Exception occurring during message handling will result in
    an efro.error.RemoteError being raised on the sending end.
    """

    # Sync receivers only allow handle_raw_message(); async ones only
    # allow handle_raw_message_async() (both assert on this flag).
    is_async = False

    def __init__(self, protocol: MessageProtocol) -> None:
        """Create a receiver with no handlers or filters registered.

        Args:
            protocol: Protocol used to decode incoming messages and
                encode outgoing responses.
        """
        self.protocol = protocol

        # Handler call for each message type, as set by register_handler().
        self._handlers: dict[type[Message], Callable] = {}

        # Optional filter hooks; each may be set at most once.
        self._decode_filter_call: (
            Callable[[Any, dict, Message], None] | None
        ) = None
        self._encode_filter_call: (
            Callable[[Any, Message | None, Response | SysResponse, dict], None]
            | None
        ) = None

    # noinspection PyProtectedMember
    def register_handler(
        self, call: Callable[[Any, Message], Response | None]
    ) -> None:
        """Register a handler call.

        The message type handled by the call is determined by its
        type annotation.

        Args:
            call: Unbound method with the exact argument list
                ('self', 'msg'). Its 'msg' annotation selects the message
                type handled and its return annotation must list exactly
                the response types that message type expects.

        Raises:
            ValueError: If the call's argument names are not
                ['self', 'msg'].
            TypeError: If annotations are not plain types, the message
                type is not registered with the protocol, the message type
                already has a handler, or the annotated response types do
                not match those expected by the message type.
        """
        # TODO: can use types.GenericAlias in 3.9.
        # (hmm though now that we're there, it seems a drop-in
        # replace gives us errors. Should re-test in 3.11 as it seems
        # that typing_extensions handles it differently in that case)
        from typing import _GenericAlias  # type: ignore
        from typing import get_type_hints, get_args

        sig = inspect.getfullargspec(call)

        # The provided callable should be a method taking one 'msg' arg.
        expectedsig = ['self', 'msg']
        if sig.args != expectedsig:
            raise ValueError(
                f'Expected callable signature of {expectedsig};'
                f' got {sig.args}'
            )

        # Check annotation types to determine what message types we handle.
        # Return-type annotation can be a Union, but we probably don't
        # have it available at runtime. Explicitly pull it in.
        # UPDATE: we've updated our pylint filter to where we should
        # have all annotations available.
        # anns = get_type_hints(call, localns={'Union': Union})
        anns = get_type_hints(call)

        msgtype = anns.get('msg')
        if not isinstance(msgtype, type):
            raise TypeError(
                f'expected a type for "msg" annotation; got {type(msgtype)}.'
            )
        assert issubclass(msgtype, Message)

        ret = anns.get('return')
        responsetypes: tuple[type[Any] | None, ...]

        # Return types can be a single type or a union of types.
        if isinstance(ret, (_GenericAlias, types.UnionType)):
            targs = get_args(ret)
            if not all(isinstance(a, (type, type(None))) for a in targs):
                raise TypeError(
                    f'expected only types for "return" annotation;'
                    f' got {targs}.'
                )
            responsetypes = targs
        else:
            if not isinstance(ret, (type, type(None))):
                raise TypeError(
                    f'expected one or more types for'
                    f' "return" annotation; got a {type(ret)}.'
                )
            # This seems like maybe a mypy bug. Appeared after adding
            # types.UnionType above.
            responsetypes = (ret,)

        # This will contain NoneType for empty return cases, but
        # we expect it to be None.
        # noinspection PyPep8
        responsetypes = tuple(
            None if r is type(None) else r for r in responsetypes
        )

        # Make sure our protocol has this message type registered and our
        # return types exactly match. (Technically we could return a subset
        # of the supported types; can allow this in the future if it makes
        # sense).
        registered_types = self.protocol.message_ids_by_type.keys()

        if msgtype not in registered_types:
            raise TypeError(
                f'Message type {msgtype} is not registered'
                f' in this Protocol.'
            )

        if msgtype in self._handlers:
            raise TypeError(
                f'Message type {msgtype} already has a registered handler.'
            )

        # Make sure the responses exactly match what the message expects.
        if set(responsetypes) != set(msgtype.get_response_types()):
            raise TypeError(
                f'Provided response types {responsetypes} do not'
                f' match the set expected by message type {msgtype}: '
                f'({msgtype.get_response_types()})'
            )

        # Ok; we're good!
        self._handlers[msgtype] = call

    def decode_filter_method(
        self, call: Callable[[Any, dict, Message], None]
    ) -> Callable[[Any, dict, Message], None]:
        """Function decorator for defining a decode filter.

        Decode filters can be used to extract extra data from incoming
        message dicts. This version will work for both handle_raw_message()
        and handle_raw_message_async().

        The filter is invoked as call(bound_obj, msg_dict, msg_decoded)
        after each incoming message is decoded. Only one decode filter
        may be registered per receiver. The call is returned unchanged.
        """
        assert self._decode_filter_call is None
        self._decode_filter_call = call
        return call

    def encode_filter_method(
        self,
        call: Callable[
            [Any, Message | None, Response | SysResponse, dict], None
        ],
    ) -> Callable[[Any, Message | None, Response | SysResponse, dict], None]:
        """Function decorator for defining an encode filter.

        Encode filters can be used to add extra data to the message
        dict before it is encoded to a string and sent out.

        The filter is invoked as
        call(bound_obj, message, response, response_dict) for both user
        responses and error responses; 'message' is None when an error
        response is encoded without a decoded message available. Only one
        encode filter may be registered per receiver. The call is returned
        unchanged.
        """
        assert self._encode_filter_call is None
        self._encode_filter_call = call
        return call

    def validate(self, log_only: bool = False) -> None:
        """Check for handler completeness, valid types, etc.

        Every type in the protocol's message_ids_by_type that is not a
        Response subclass must have a registered handler.

        Args:
            log_only: If True, missing handlers are reported via
                logging.error() instead of raising.

        Raises:
            TypeError: If a message type has no handler and log_only
                is False.
        """
        for msgtype in self.protocol.message_ids_by_type.keys():
            if issubclass(msgtype, Response):
                continue
            if msgtype not in self._handlers:
                msg = (
                    f'Protocol message type {msgtype} is not handled'
                    f' by receiver type {type(self)}.'
                )
                if log_only:
                    logging.error(msg)
                else:
                    raise TypeError(msg)

    def _decode_incoming_message_base(
        self, bound_obj: Any, msg: str
    ) -> tuple[Any, dict, Message]:
        """Decode a raw message string and run the decode filter, if any.

        Returns the (bound_obj, msg_dict, msg_decoded) tuple.
        """
        # Decode the incoming message.
        msg_dict = self.protocol.decode_dict(msg)
        msg_decoded = self.protocol.message_from_dict(msg_dict)
        assert isinstance(msg_decoded, Message)
        if self._decode_filter_call is not None:
            self._decode_filter_call(bound_obj, msg_dict, msg_decoded)
        return bound_obj, msg_dict, msg_decoded

    def _decode_incoming_message(self, bound_obj: Any, msg: str) -> Message:
        """Decode a raw message string, returning only the Message."""
        bound_obj, _msg_dict, msg_decoded = self._decode_incoming_message_base(
            bound_obj=bound_obj, msg=msg
        )
        return msg_decoded

    def encode_user_response(
        self, bound_obj: Any, message: Message, response: Response | None
    ) -> str:
        """Encode a response provided by the user for sending.

        Args:
            bound_obj: Object the receiver is bound to; passed through to
                the encode filter.
            message: The message being responded to.
            response: The handler's return value. None is sent as an
                EmptySysResponse.

        Returns:
            The encoded response string.
        """
        assert isinstance(response, Response | None)
        # (user should never explicitly return error-responses)
        assert (
            response is None or type(response) in message.get_response_types()
        )

        # A return value of None equals EmptySysResponse.
        out_response: Response | SysResponse
        if response is None:
            out_response = EmptySysResponse()
        else:
            out_response = response

        response_dict = self.protocol.response_to_dict(out_response)
        if self._encode_filter_call is not None:
            self._encode_filter_call(
                bound_obj, message, out_response, response_dict
            )
        return self.protocol.encode_dict(response_dict)

    def encode_error_response(
        self, bound_obj: Any, message: Message | None, exc: Exception
    ) -> tuple[str, bool]:
        """Given an error, return sysresponse str and whether to log.

        Args:
            bound_obj: Object the receiver is bound to; passed through to
                the encode filter.
            message: The decoded message being handled, or None if it is
                not available.
            exc: The exception to translate into a response.

        Returns:
            A tuple of (encoded response string, whether the protocol
            asks for the error to be logged).
        """
        response, dolog = self.protocol.error_to_response(exc)
        response_dict = self.protocol.response_to_dict(response)
        if self._encode_filter_call is not None:
            self._encode_filter_call(
                bound_obj, message, response, response_dict
            )
        return self.protocol.encode_dict(response_dict), dolog

    def handle_raw_message(
        self, bound_obj: Any, msg: str, raise_unregistered: bool = False
    ) -> str:
        """Decode, handle, and return a response for a message.

        If 'raise_unregistered' is True, will raise an
        efro.message.UnregisteredMessageIDError for messages not handled by
        the protocol. In all other cases local errors will translate to
        error responses returned to the sender.
        """
        assert not self.is_async, "can't call sync handler on async receiver"
        # Stays None if decoding itself fails; the error path uses this
        # to pick which log message to emit.
        msg_decoded: Message | None = None
        try:
            msg_decoded = self._decode_incoming_message(bound_obj, msg)
            msgtype = type(msg_decoded)
            handler = self._handlers.get(msgtype)
            if handler is None:
                raise RuntimeError(f'Got unhandled message type: {msgtype}.')
            response = handler(bound_obj, msg_decoded)
            assert isinstance(response, Response | None)
            return self.encode_user_response(bound_obj, msg_decoded, response)

        except Exception as exc:
            if raise_unregistered and isinstance(
                exc, UnregisteredMessageIDError
            ):
                raise
            rstr, dolog = self.encode_error_response(
                bound_obj, msg_decoded, exc
            )
            if dolog:
                if msg_decoded is not None:
                    msgtype = type(msg_decoded)
                    logging.exception(
                        'Error handling %s.%s message.',
                        msgtype.__module__,
                        msgtype.__qualname__,
                    )
                else:
                    logging.exception(
                        'Error handling raw efro.message'
                        ' (likely a message format incompatibility): %s.',
                        msg,
                    )
            return rstr

    def handle_raw_message_async(
        self, bound_obj: Any, msg: str, raise_unregistered: bool = False
    ) -> Awaitable[str]:
        """Should be called when the receiver gets a message.

        The return value is an awaitable producing the raw response to
        the message.

        If 'raise_unregistered' is True, an UnregisteredMessageIDError
        raised while decoding or dispatching is re-raised synchronously
        instead of being translated to an error response.
        """
        # Note: This call is synchronous so that the first part of it can
        # happen synchronously. If the whole call were async we wouldn't be
        # able to guarantee that messages handlers would be called in the
        # order the messages were received.

        assert self.is_async, "Can't call async handler on sync receiver."
        msg_decoded: Message | None = None
        try:
            msg_decoded = self._decode_incoming_message(bound_obj, msg)
            msgtype = type(msg_decoded)
            handler = self._handlers.get(msgtype)
            if handler is None:
                raise RuntimeError(f'Got unhandled message type: {msgtype}.')
            handler_awaitable = handler(bound_obj, msg_decoded)

        except Exception as exc:
            if raise_unregistered and isinstance(
                exc, UnregisteredMessageIDError
            ):
                raise
            return self._handle_raw_message_async_error(
                bound_obj, msg, msg_decoded, exc
            )

        # Return an awaitable to handle the rest asynchronously.
        return self._handle_raw_message_async(
            bound_obj, msg, msg_decoded, handler_awaitable
        )

    async def _handle_raw_message_async_error(
        self,
        bound_obj: Any,
        msg_raw: str,
        msg_decoded: Message | None,
        exc: Exception,
    ) -> str:
        """Encode an error response for 'exc', logging if requested.

        Returns the encoded error response string.
        """
        rstr, dolog = self.encode_error_response(bound_obj, msg_decoded, exc)
        if dolog:
            if msg_decoded is not None:
                msgtype = type(msg_decoded)
                logging.exception(
                    'Error handling %s.%s message.',
                    msgtype.__module__,
                    msgtype.__qualname__,
                    # We need to explicitly provide the exception here,
                    # otherwise it shows up as None. I assume related to
                    # the fact that we're an async function.
                    exc_info=exc,
                )
            else:
                logging.exception(
                    'Error handling raw async efro.message'
                    ' (likely a message format incompatibility): %s.',
                    msg_raw,
                    # We need to explicitly provide the exception here,
                    # otherwise it shows up as None. I assume related to
                    # the fact that we're an async function.
                    exc_info=exc,
                )
        return rstr

    async def _handle_raw_message_async(
        self,
        bound_obj: Any,
        msg_raw: str,
        msg_decoded: Message,
        handler_awaitable: Awaitable[Response | None],
    ) -> str:
        """Await a handler's result and encode it as a response.

        Any Exception raised while awaiting or encoding is converted to
        an encoded error response instead of propagating.
        """
        try:
            response = await handler_awaitable
            assert isinstance(response, Response | None)
            return self.encode_user_response(bound_obj, msg_decoded, response)

        except Exception as exc:
            return await self._handle_raw_message_async_error(
                bound_obj, msg_raw, msg_decoded, exc
            )
class BoundMessageReceiver:
    """Base bound receiver class.

    Pairs a MessageReceiver with the specific object it is bound to.
    """

    def __init__(
        self,
        obj: Any,
        receiver: MessageReceiver,
    ) -> None:
        """Bind 'receiver' to 'obj' (which must not be None)."""
        assert obj is not None
        self._obj = obj
        self._receiver = receiver

    @property
    def protocol(self) -> MessageProtocol:
        """Protocol associated with this receiver."""
        return self._receiver.protocol

    def encode_error_response(self, exc: Exception) -> str:
        """Given an error, return a response ready to send.

        This should be used for any errors that happen outside of
        standard handle_raw_message calls. Any errors within those
        calls will be automatically returned as encoded strings.

        Only the encoded string is returned; the 'should log' flag
        produced by the underlying receiver is discarded.
        """
        # Passing None for Message here; we would only have that available
        # for things going wrong in the handler (which this is not for).
        return self._receiver.encode_error_response(self._obj, None, exc)[0]