# Source code for efro.message._receiver

# Released under the MIT License. See LICENSE for details.
#
"""Functionality for sending and responding to messages.
Supports static typing for message types and possible return types.
"""

from __future__ import annotations

import types
import inspect
import logging
from typing import TYPE_CHECKING

from efro.message._message import (
    Message,
    Response,
    EmptySysResponse,
    UnregisteredMessageIDError,
)

if TYPE_CHECKING:
    from typing import Any, Callable, Awaitable

    from efro.message._protocol import MessageProtocol
    from efro.message._message import SysResponse


class MessageReceiver:
    """Facilitates receiving & responding to messages from a remote source.

    This is instantiated at the class level with unbound methods registered
    as handlers for different message types in the protocol.

    Example:

    class MyClass:
        receiver = MyMessageReceiver()

        # MyMessageReceiver fills out handler() overloads to ensure all
        # registered handlers have valid types/return-types.

        @receiver.handler
        def handle_some_message_type(self, msg: SomeMsg) -> SomeResponse:
            # Deal with this message type here.

    # This will trigger the registered handler being called.
    obj = MyClass()
    obj.receiver.handle_raw_message(some_raw_data)

    Note that handler methods must take exactly the arguments (self, msg);
    register_handler() rejects any other signature.

    Any unhandled Exception occurring during message handling will result in
    an efro.error.RemoteError being raised on the sending end.
    """

    # Selects the handling path: handle_raw_message() requires this to be
    # False and handle_raw_message_async() requires it to be True.
    is_async = False

    def __init__(self, protocol: MessageProtocol) -> None:
        self.protocol = protocol
        # Registered handler calls keyed by the message type they handle.
        self._handlers: dict[type[Message], Callable] = {}
        # Optional filter run on each decoded incoming message dict.
        self._decode_filter_call: (
            Callable[[Any, dict, Message], None] | None
        ) = None
        # Optional filter run on each outgoing response dict before encoding.
        self._encode_filter_call: (
            Callable[[Any, Message | None, Response | SysResponse, dict], None]
            | None
        ) = None

    # noinspection PyProtectedMember
    def register_handler(
        self, call: Callable[[Any, Message], Response | None]
    ) -> None:
        """Register a handler call.

        The message type handled by the call is determined by its type
        annotation.

        The call must be a function taking exactly (self, msg), where
        'msg' is annotated with a Message subclass registered in this
        receiver's protocol, and whose return annotation (a single type
        or a union of types; None meaning an empty response) exactly
        matches the message type's response types.

        Raises ValueError for a bad signature, and TypeError for bad or
        mismatched annotations, message types not registered in the
        protocol, or message types that already have a handler.
        """
        # TODO: can use types.GenericAlias in 3.9.
        # (hmm though now that we're there, it seems a drop-in
        # replace gives us errors. Should re-test in 3.11 as it seems
        # that typing_extensions handles it differently in that case)
        from typing import _GenericAlias  # type: ignore
        from typing import get_type_hints, get_args

        sig = inspect.getfullargspec(call)

        # The provided callable should be a method taking one 'msg' arg.
        expectedsig = ['self', 'msg']
        if sig.args != expectedsig:
            raise ValueError(
                f'Expected callable signature of {expectedsig};'
                f' got {sig.args}'
            )

        # Check annotation types to determine what message types we handle.
        # Return-type annotation can be a Union, but we probably don't
        # have it available at runtime. Explicitly pull it in.
        # UPDATE: we've updated our pylint filter to where we should
        # have all annotations available.
        # anns = get_type_hints(call, localns={'Union': Union})
        anns = get_type_hints(call)

        msgtype = anns.get('msg')
        if not isinstance(msgtype, type):
            raise TypeError(
                f'expected a type for "msg" annotation; got {type(msgtype)}.'
            )
        # Raise rather than assert so this check survives 'python -O'.
        if not issubclass(msgtype, Message):
            raise TypeError(
                f'expected a Message subclass for "msg" annotation;'
                f' got {msgtype}.'
            )

        ret = anns.get('return')
        responsetypes: tuple[type[Any] | None, ...]

        # Return types can be a single type or a union of types.
        if isinstance(ret, (_GenericAlias, types.UnionType)):
            targs = get_args(ret)
            if not all(isinstance(a, (type, type(None))) for a in targs):
                raise TypeError(
                    f'expected only types for "return" annotation;'
                    f' got {targs}.'
                )
            responsetypes = targs
        else:
            if not isinstance(ret, (type, type(None))):
                raise TypeError(
                    f'expected one or more types for'
                    f' "return" annotation; got a {type(ret)}.'
                )
            # This seems like maybe a mypy bug. Appeared after adding
            # types.UnionType above.
            responsetypes = (ret,)

        # This will contain NoneType for empty return cases, but
        # we expect it to be None.
        # noinspection PyPep8
        responsetypes = tuple(
            None if r is type(None) else r for r in responsetypes
        )

        # Make sure our protocol has this message type registered and our
        # return types exactly match. (Technically we could return a subset
        # of the supported types; can allow this in the future if it makes
        # sense).
        registered_types = self.protocol.message_ids_by_type.keys()

        if msgtype not in registered_types:
            raise TypeError(
                f'Message type {msgtype} is not registered'
                f' in this Protocol.'
            )

        if msgtype in self._handlers:
            raise TypeError(
                f'Message type {msgtype} already has a registered handler.'
            )

        # Make sure the responses exactly match what the message expects.
        if set(responsetypes) != set(msgtype.get_response_types()):
            raise TypeError(
                f'Provided response types {responsetypes} do not'
                f' match the set expected by message type {msgtype}: '
                f'({msgtype.get_response_types()})'
            )

        # Ok; we're good!
        self._handlers[msgtype] = call

    def decode_filter_method(
        self, call: Callable[[Any, dict, Message], None]
    ) -> Callable[[Any, dict, Message], None]:
        """Function decorator for defining a decode filter.

        Decode filters can be used to extract extra data from incoming
        message dicts. This version will work for both handle_raw_message()
        and handle_raw_message_async().

        The filter is called with (bound_obj, msg_dict, decoded_message).
        Only one decode filter may be registered per receiver. The call
        is returned unchanged.
        """
        assert self._decode_filter_call is None
        self._decode_filter_call = call
        return call

    def encode_filter_method(
        self,
        call: Callable[
            [Any, Message | None, Response | SysResponse, dict], None
        ],
    ) -> Callable[[Any, Message | None, Response | SysResponse, dict], None]:
        """Function decorator for defining an encode filter.

        Encode filters can be used to add extra data to the message
        dict before it is encoded to a string and sent out.

        The filter is called with (bound_obj, message, response,
        response_dict); 'message' is None when no decoded message is
        available. Only one encode filter may be registered per
        receiver. The call is returned unchanged.
        """
        assert self._encode_filter_call is None
        self._encode_filter_call = call
        return call

    def validate(self, log_only: bool = False) -> None:
        """Check for handler completeness, valid types, etc.

        Every non-Response message type in the protocol must have a
        registered handler. Missing handlers raise TypeError, or are
        logged as errors when 'log_only' is True.
        """
        for msgtype in self.protocol.message_ids_by_type.keys():
            if issubclass(msgtype, Response):
                continue
            if msgtype not in self._handlers:
                msg = (
                    f'Protocol message type {msgtype} is not handled'
                    f' by receiver type {type(self)}.'
                )
                if log_only:
                    logging.error(msg)
                else:
                    raise TypeError(msg)

    def _decode_incoming_message_base(
        self, bound_obj: Any, msg: str
    ) -> tuple[Any, dict, Message]:
        """Decode a raw message, run the decode filter, return all parts."""
        # Decode the incoming message.
        msg_dict = self.protocol.decode_dict(msg)
        msg_decoded = self.protocol.message_from_dict(msg_dict)
        assert isinstance(msg_decoded, Message)
        if self._decode_filter_call is not None:
            self._decode_filter_call(bound_obj, msg_dict, msg_decoded)
        return bound_obj, msg_dict, msg_decoded

    def _decode_incoming_message(self, bound_obj: Any, msg: str) -> Message:
        """Decode a raw message and return only the Message object."""
        bound_obj, _msg_dict, msg_decoded = self._decode_incoming_message_base(
            bound_obj=bound_obj, msg=msg
        )
        return msg_decoded

    def encode_user_response(
        self, bound_obj: Any, message: Message, response: Response | None
    ) -> str:
        """Encode a response provided by the user for sending.

        A None response is sent as an EmptySysResponse. The encode
        filter (if any) is run on the response dict before encoding.
        """

        assert isinstance(response, Response | None)
        # (user should never explicitly return error-responses)
        assert (
            response is None or type(response) in message.get_response_types()
        )

        # A return value of None equals EmptySysResponse.
        out_response: Response | SysResponse
        if response is None:
            out_response = EmptySysResponse()
        else:
            out_response = response

        response_dict = self.protocol.response_to_dict(out_response)
        if self._encode_filter_call is not None:
            self._encode_filter_call(
                bound_obj, message, out_response, response_dict
            )
        return self.protocol.encode_dict(response_dict)

    def encode_error_response(
        self, bound_obj: Any, message: Message | None, exc: Exception
    ) -> tuple[str, bool]:
        """Given an error, return sysresponse str and whether to log.

        The protocol converts the exception into a response and decides
        whether it should be logged; the encode filter (if any) is run
        on the response dict before encoding.
        """
        response, dolog = self.protocol.error_to_response(exc)
        response_dict = self.protocol.response_to_dict(response)
        if self._encode_filter_call is not None:
            self._encode_filter_call(
                bound_obj, message, response, response_dict
            )
        return self.protocol.encode_dict(response_dict), dolog

    def handle_raw_message(
        self, bound_obj: Any, msg: str, raise_unregistered: bool = False
    ) -> str:
        """Decode, handle, and return a response for a message.

        If 'raise_unregistered' is True, will raise an
        efro.message.UnregisteredMessageIDError for messages not handled by
        the protocol. In all other cases local errors will translate to
        error responses returned to the sender.
        """
        assert not self.is_async, "can't call sync handler on async receiver"
        msg_decoded: Message | None = None
        try:
            msg_decoded = self._decode_incoming_message(bound_obj, msg)
            msgtype = type(msg_decoded)
            handler = self._handlers.get(msgtype)
            if handler is None:
                raise RuntimeError(f'Got unhandled message type: {msgtype}.')
            response = handler(bound_obj, msg_decoded)
            assert isinstance(response, Response | None)
            return self.encode_user_response(bound_obj, msg_decoded, response)

        except Exception as exc:
            if raise_unregistered and isinstance(
                exc, UnregisteredMessageIDError
            ):
                raise
            rstr, dolog = self.encode_error_response(
                bound_obj, msg_decoded, exc
            )
            if dolog:
                if msg_decoded is not None:
                    msgtype = type(msg_decoded)
                    logging.exception(
                        'Error handling %s.%s message.',
                        msgtype.__module__,
                        msgtype.__qualname__,
                    )
                else:
                    # Decoding itself failed; log the raw payload instead.
                    logging.exception(
                        'Error handling raw efro.message'
                        ' (likely a message format incompatibility): %s.',
                        msg,
                    )
            return rstr

    def handle_raw_message_async(
        self, bound_obj: Any, msg: str, raise_unregistered: bool = False
    ) -> Awaitable[str]:
        """Should be called when the receiver gets a message.

        The return value is an awaitable resolving to the raw response
        to the message. If 'raise_unregistered' is True, an
        efro.message.UnregisteredMessageIDError raised while decoding or
        dispatching is raised immediately instead of being converted
        into an error response.
        """

        # Note: This call is synchronous so that the first part of it can
        # happen synchronously. If the whole call were async we wouldn't be
        # able to guarantee that message handlers would be called in the
        # order the messages were received.

        assert self.is_async, "Can't call async handler on sync receiver."
        msg_decoded: Message | None = None
        try:
            msg_decoded = self._decode_incoming_message(bound_obj, msg)
            msgtype = type(msg_decoded)
            handler = self._handlers.get(msgtype)
            if handler is None:
                raise RuntimeError(f'Got unhandled message type: {msgtype}.')
            handler_awaitable = handler(bound_obj, msg_decoded)

        except Exception as exc:
            if raise_unregistered and isinstance(
                exc, UnregisteredMessageIDError
            ):
                raise
            return self._handle_raw_message_async_error(
                bound_obj, msg, msg_decoded, exc
            )

        # Return an awaitable to handle the rest asynchronously.
        return self._handle_raw_message_async(
            bound_obj, msg, msg_decoded, handler_awaitable
        )

    async def _handle_raw_message_async_error(
        self,
        bound_obj: Any,
        msg_raw: str,
        msg_decoded: Message | None,
        exc: Exception,
    ) -> str:
        """Encode (and possibly log) an error from async handling."""
        rstr, dolog = self.encode_error_response(bound_obj, msg_decoded, exc)
        if dolog:
            if msg_decoded is not None:
                msgtype = type(msg_decoded)
                logging.exception(
                    'Error handling %s.%s message.',
                    msgtype.__module__,
                    msgtype.__qualname__,
                    # We need to explicitly provide the exception here,
                    # otherwise it shows up as None. I assume related to
                    # the fact that we're an async function.
                    exc_info=exc,
                )
            else:
                logging.exception(
                    'Error handling raw async efro.message'
                    ' (likely a message format incompatibility): %s.',
                    msg_raw,
                    # We need to explicitly provide the exception here,
                    # otherwise it shows up as None. I assume related to
                    # the fact that we're an async function.
                    exc_info=exc,
                )
        return rstr

    async def _handle_raw_message_async(
        self,
        bound_obj: Any,
        msg_raw: str,
        msg_decoded: Message,
        handler_awaitable: Awaitable[Response | None],
    ) -> str:
        """Await a handler and encode its response.

        Any exception raised while awaiting or encoding is converted
        into an encoded error response.
        """
        try:
            response = await handler_awaitable
            assert isinstance(response, Response | None)
            return self.encode_user_response(bound_obj, msg_decoded, response)
        except Exception as exc:
            return await self._handle_raw_message_async_error(
                bound_obj, msg_raw, msg_decoded, exc
            )


class BoundMessageReceiver:
    """Base bound receiver class.

    Pairs a MessageReceiver with the object its handlers are bound to.
    """

    def __init__(
        self,
        obj: Any,
        receiver: MessageReceiver,
    ) -> None:
        assert obj is not None
        self._obj = obj
        self._receiver = receiver

    @property
    def protocol(self) -> MessageProtocol:
        """Protocol associated with this receiver."""
        return self._receiver.protocol

    def encode_error_response(self, exc: Exception) -> str:
        """Given an error, return a response ready to send.

        This should be used for any errors that happen outside of
        standard handle_raw_message calls. Any errors within those
        calls will be automatically returned as encoded strings.
        """
        # Passing None for Message here; we would only have that available
        # for things going wrong in the handler (which this is not for).
        return self._receiver.encode_error_response(self._obj, None, exc)[0]