# Released under the MIT License. See LICENSE for details.
#
"""Remote procedure call functionality over WebSockets."""
from __future__ import annotations
import time
import asyncio
import logging
from typing import TYPE_CHECKING, Protocol
from efro.error import CommunicationError
if TYPE_CHECKING:
from typing import Awaitable, Callable, Literal
logger = logging.getLogger(__name__)
[docs]
class WebSocketTransport(Protocol):
"""Minimal interface for a WebSocket connection.
This allows the RPC layer to work with any WebSocket library
(websockets, aiohttp, etc.) via a thin adapter.
"""
[docs]
async def send(self, data: bytes) -> None:
"""Send binary data."""
[docs]
async def recv(self) -> bytes:
"""Receive binary data."""
[docs]
async def close(self) -> None:
"""Close the connection."""
class _InFlightMessage:
"""Represents a message that is out on the wire."""
def __init__(self, message_id: int) -> None:
self._response: bytes | None = None
self._got_response = asyncio.Event()
self.wait_task = asyncio.create_task(
self._wait(), name=f'rpcws in-flight-msg {message_id} wait'
)
async def _wait(self) -> bytes:
await self._got_response.wait()
assert self._response is not None
return self._response
def set_response(self, data: bytes) -> None:
"""Set response data."""
assert self._response is None
self._response = data
self._got_response.set()
# Packet type bytes prepended to each WebSocket message.
_TYPE_MESSAGE: int = 0
_TYPE_RESPONSE: int = 1
_BYTE_ORDER: Literal['big'] = 'big'
[docs]
class RPCWSEndpoint:
"""Facilitates asynchronous multiplexed remote procedure calls
over WebSockets.
Similar to RPCEndpoint but leverages WebSocket framing for message
boundaries and WebSocket ping/pong for keepalive, resulting in a
simpler implementation.
"""
# How long we should wait before giving up on a message by default.
# Note this includes processing time on the other end.
DEFAULT_MESSAGE_TIMEOUT = 60.0
def __init__(
self,
handle_raw_message_call: Callable[[bytes], Awaitable[bytes]],
transport: WebSocketTransport,
label: str,
*,
debug_print: bool = False,
debug_print_call: Callable[[str], None] | None = None,
) -> None:
self._handle_raw_message_call = handle_raw_message_call
self._transport = transport
self._label = label
self.debug_print = debug_print
if debug_print_call is None:
debug_print_call = print
self.debug_print_call: Callable[[str], None] = debug_print_call
self._closing = False
self._did_wait_closed = False
self._event_loop = asyncio.get_running_loop()
self._run_called = False
self._create_time = time.monotonic()
self._tasks: list[asyncio.Task] = []
# (Start near the end to make sure our looping logic is sound).
self._next_message_id = 65530
self._in_flight_messages: dict[int, _InFlightMessage] = {}
if self.debug_print:
self.debug_print_call(f'{self._label}: connected at {self._tm()}.')
[docs]
async def run(self) -> None:
"""Run the endpoint until the connection is lost or closed."""
if self._run_called:
raise RuntimeError('Run can be called only once per endpoint.')
self._run_called = True
try:
await self._run_read_loop()
except asyncio.CancelledError:
logger.warning(
'RPCWSEndpoint.run cancelled; want to try and avoid this.'
)
raise
except CommunicationError:
if self.debug_print:
self.debug_print_call(f'{self._label}: connection ended.')
except Exception:
logger.exception(
'Unexpected error in rpcws %s read loop (age=%.1f).',
self._label,
time.monotonic() - self._create_time,
)
finally:
try:
self.close()
await self.wait_closed()
except Exception:
logger.exception('Error closing %s.', self._label)
if self.debug_print:
self.debug_print_call(f'{self._label}: finished.')
[docs]
def send_message(
self,
message: bytes,
timeout: float | None = None,
close_on_error: bool = True,
) -> Awaitable[bytes]:
"""Send a message to the peer and return a response.
If timeout is not provided, the default will be used.
Raises a CommunicationError if the round trip is not completed
for any reason.
By default, the entire endpoint will go down in the case of
errors. This allows messages to be treated as 'reliable' with
respect to a given endpoint. Pass close_on_error=False to
override this for a particular message.
"""
if self._closing:
raise CommunicationError('Endpoint is closed.')
# message_id is a 16 bit looping value.
message_id = self._next_message_id
self._next_message_id = (self._next_message_id + 1) % 65536
# Make an entry so we know this message is out there.
assert message_id not in self._in_flight_messages
msgobj = self._in_flight_messages[message_id] = _InFlightMessage(
message_id
)
# Also add its task to our list so we properly cancel it if we
# die.
self._prune_tasks()
self._tasks.append(msgobj.wait_task)
if timeout is None:
timeout = self.DEFAULT_MESSAGE_TIMEOUT
assert timeout is not None
return self._send_message(
message, timeout, close_on_error, msgobj.wait_task, message_id
)
async def _send_message(
self,
message: bytes,
timeout: float,
close_on_error: bool,
bytes_awaitable: asyncio.Task[bytes],
message_id: int,
) -> bytes:
# pylint: disable=too-many-positional-arguments
# Build the wire frame: type(1b) + message_id(2b) + payload.
frame = (
_TYPE_MESSAGE.to_bytes(1, _BYTE_ORDER)
+ message_id.to_bytes(2, _BYTE_ORDER)
+ message
)
try:
await self._transport.send(frame)
except Exception as exc:
bytes_awaitable.cancel()
del self._in_flight_messages[message_id]
if close_on_error:
self.close()
raise CommunicationError() from exc
if self.debug_print:
self.debug_print_call(
f'{self._label}: sent message {message_id}'
f' of size {len(message)} at {self._tm()}.'
)
try:
return await asyncio.wait_for(bytes_awaitable, timeout=timeout)
except asyncio.CancelledError as exc:
current_task = asyncio.current_task()
if current_task is not None and current_task.cancelling() > 0:
raise
if self.debug_print:
self.debug_print_call(
f'{self._label}: message {message_id} was cancelled.'
)
if close_on_error:
self.close()
raise CommunicationError() from exc
except asyncio.TimeoutError as exc:
if self.debug_print:
self.debug_print_call(
f'{self._label}: message {message_id} timed out.'
)
bytes_awaitable.cancel()
del self._in_flight_messages[message_id]
if close_on_error:
self.close()
raise CommunicationError() from exc
[docs]
def close(self) -> None:
"""Begin closing the endpoint."""
if self._closing:
return
if self.debug_print:
self.debug_print_call(f'{self._label}: closing...')
self._closing = True
# Kill all of our in-flight tasks.
for task in self._get_live_tasks():
task.cancel()
# We don't need this anymore and it may create a dependency loop.
del self._handle_raw_message_call
[docs]
def is_closing(self) -> bool:
"""Have we begun the process of closing?"""
return self._closing
[docs]
async def wait_closed(self) -> None:
"""Wait for the endpoint to finish closing.
This is called by run() so generally does not need to be
explicitly called.
"""
if self._did_wait_closed:
return
self._did_wait_closed = True
if not self._closing:
raise RuntimeError('Must be called after close()')
live_tasks = self._get_live_tasks()
self._tasks = []
if live_tasks:
results = await asyncio.gather(*live_tasks, return_exceptions=True)
for result in results:
if isinstance(result, Exception):
logger.warning(
'Got unexpected error cleaning up %s task: %s',
self._label,
result,
)
# Close the underlying transport.
try:
await asyncio.wait_for(self._transport.close(), timeout=10.0)
except Exception:
pass
async def _run_read_loop(self) -> None:
"""Read incoming WebSocket messages and dispatch them."""
while not self._closing:
try:
raw = await self._transport.recv()
except Exception as exc:
# If we're closing, the recv error is expected.
if self.is_closing():
return
raise CommunicationError() from exc
if len(raw) < 3:
raise CommunicationError('Invalid rpcws frame.')
ptype = raw[0]
message_id = int.from_bytes(raw[1:3], _BYTE_ORDER)
payload = raw[3:]
if ptype == _TYPE_MESSAGE:
if self.debug_print:
self.debug_print_call(
f'{self._label}: received message {message_id}'
f' of size {len(payload)} at {self._tm()}.'
)
self._prune_tasks()
self._tasks.append(
asyncio.create_task(
self._handle_raw_message(
message_id=message_id, message=payload
),
name='rpcws message handle',
)
)
elif ptype == _TYPE_RESPONSE:
if self.debug_print:
self.debug_print_call(
f'{self._label}: received response {message_id}'
f' of size {len(payload)} at {self._tm()}.'
)
msgobj = self._in_flight_messages.get(message_id)
if msgobj is None:
if self.debug_print:
self.debug_print_call(
f'{self._label}: got response for nonexistent'
f' message id {message_id};'
f' perhaps it timed out?'
)
else:
msgobj.set_response(payload)
else:
raise CommunicationError(f'Invalid rpcws packet type: {ptype}.')
async def _handle_raw_message(
self, message_id: int, message: bytes
) -> None:
try:
response = await self._handle_raw_message_call(message)
except Exception:
logger.exception('Error handling raw rpcws message')
return
# Send back the response.
frame = (
_TYPE_RESPONSE.to_bytes(1, _BYTE_ORDER)
+ message_id.to_bytes(2, _BYTE_ORDER)
+ response
)
try:
await self._transport.send(frame)
except Exception:
if not self._closing:
logger.warning(
'Error sending rpcws response for message %d.',
message_id,
)
def _tm(self) -> str:
"""Simple readable time value for debugging."""
tval = time.monotonic() % 100.0
return f'{tval:.2f}'
def _prune_tasks(self) -> None:
self._tasks = self._get_live_tasks()
def _get_live_tasks(self) -> list[asyncio.Task]:
return [t for t in self._tasks if not t.done()]
# Docs-generation hack; import some stuff that we likely only forward-declared
# in our actual source code so that docs tools can find it.
from typing import (Coroutine, Any, Literal, Callable,
Generator, Awaitable, Sequence, Self)
import asyncio
from concurrent.futures import Future
from pathlib import Path
from enum import Enum