Source code for bacommontools.streamws

# Released under the MIT License. See LICENSE for details.
#
"""WebSocket-based stream consumer for bacloud (Phase 2).

A stream-mode bacloud kickoff lands at a basn node, which injects a
``StreamWS`` into the response pointing at its own
``/streamcall/<call_id>`` WebSocket endpoint. We open that WS, print
``StreamOutput`` frames live as they arrive, and return the terminal
``StreamFinal`` so the caller can splice it back into bacloud's
existing response-handling flow.

On a non-terminal close (network blip, abnormal close, expired
token) we reconnect — refreshing the token via ``POST
/streamcall/<call_id>/refresh-token`` first if the close code says
the token is expired (4001). Reconnects use exponential backoff up
to a configurable wall-clock budget (default 60s, override via
``BACLOUD_RECONNECT_BUDGET_SECONDS``); past the budget we surface
``CleanError``. Token-bad / call-id-mismatch / no-token closes
(4002/4003/4004) are fatal — no retry.

v0 reconnect doesn't ask basn to replay the cursor: a reconnecting
client may miss frames that landed during the disconnect window. In
practice the stream still completes (the basn-side subscription
keeps polling regardless of WS attachments), and bacloud renders the
terminal ``StreamFinal`` correctly. Cursor-aware resume is
Phase 3 territory.

Test-only env vars:

- ``BACLOUD_TEST_FORCE_DROP_AFTER_SECONDS=N`` — force the WS closed
  N seconds after open; the reconnect path then runs as it would on
  a real drop.
- ``BACLOUD_TEST_BREAK_RECONNECT=1`` — point the reconnect URL at a
  guaranteed-unreachable host (``127.0.0.1:1``); reconnects fail
  until the budget expires.
"""

from __future__ import annotations

import asyncio
import os
import sys
import time
from typing import TYPE_CHECKING

from efro.error import CleanError
from efro.dataclassio import dataclass_from_json
from bacommon.bacloud import (
    BACLOUD_VERSION,
    ResponseData,
    StreamFinal,
    StreamFrame,
    StreamOutput,
)

if TYPE_CHECKING:
    import urllib.request

    from bacommon.bacloud import StreamWS


# Fallback reconnect budget (seconds), used when
# ``BACLOUD_RECONNECT_BUDGET_SECONDS`` is unset or is not a valid float.
_DEFAULT_RECONNECT_BUDGET_SECONDS = 60.0
# Delay (seconds) slept before the first reconnect attempt; the delay
# doubles after each failed attempt.
_RECONNECT_BACKOFF_MIN = 0.5
# Ceiling (seconds) for the doubling reconnect delay.
_RECONNECT_BACKOFF_MAX = 10.0
# A guaranteed-unreachable address used by the
# ``BACLOUD_TEST_BREAK_RECONNECT`` test hook.
_BROKEN_RECONNECT_HOST = '127.0.0.1:1'


def consume_via_ws(
    response: ResponseData, *, bearer: str | None, host: str
) -> ResponseData:
    """Run a WebSocket stream to completion and hand back its final frame.

    The result is a fresh ``ResponseData`` whose ``stream_frames`` holds
    only the terminal ``StreamFinal``, so bacloud's existing
    ``stream_frames`` loop drops straight through to its normal terminal
    handling (message/error/end_command).

    ``host`` is the hostname the bacloud client resolved for the kickoff
    (the basn that received it); it is used to build the WS URL when the
    producer did not pin one.

    The caller is responsible for checking that ``response.stream_ws``
    is not ``None`` before calling.

    Raises :class:`~efro.error.CleanError` when the WS fails in a way
    that cannot be recovered (bad token, reconnect budget exhausted,
    etc.).
    """
    stream_ws = response.stream_ws
    assert stream_ws is not None

    # The whole consume/reconnect cycle runs on its own event loop.
    final_frame = asyncio.run(
        _consume_with_reconnect(stream_ws, bearer, host)
    )
    return ResponseData(stream_frames=[final_frame])
def _reconnect_budget_seconds() -> float: raw = os.environ.get('BACLOUD_RECONNECT_BUDGET_SECONDS') if raw is None: return _DEFAULT_RECONNECT_BUDGET_SECONDS try: return float(raw) except ValueError: return _DEFAULT_RECONNECT_BUDGET_SECONDS def _force_drop_after_seconds() -> float | None: raw = os.environ.get('BACLOUD_TEST_FORCE_DROP_AFTER_SECONDS') if raw is None: return None try: return float(raw) except ValueError: return None def _refresh_url_for(ws_url: str) -> str: """Compute the refresh-token endpoint URL from the WS URL.""" https = ws_url.replace('wss://', 'https://').replace('ws://', 'http://') return f'{https}/refresh-token' def _ws_url_for_reconnect(ws_url: str) -> str: """Apply the ``BACLOUD_TEST_BREAK_RECONNECT`` hook if set.""" if os.environ.get('BACLOUD_TEST_BREAK_RECONNECT') == '1': # Strip the host but preserve the path. The path includes # ``/streamcall/<call_id>``; we want websockets to connect # to a definitely-unreachable host on that path. from urllib.parse import urlparse, urlunparse parsed = urlparse(ws_url) return urlunparse(parsed._replace(netloc=_BROKEN_RECONNECT_HOST)) return ws_url def _resolve_ws_url(sw: 'StreamWS', host: str) -> str: """Determine the WS URL the client should connect to. When ``sw.basn_url`` is set the producer pinned the stream to a specific basn (Phase 3 case); we honor that. Otherwise we construct the URL from the bacloud client's own kickoff host, so the LB routes us to a healthy basn anywhere in the fleet. """ if sw.basn_url is not None: return sw.basn_url return f'wss://{host}/streamcall/{sw.call_id}' async def _consume_with_reconnect( sw: StreamWS, bearer: str | None, host: str ) -> StreamFinal: """Open the WS (with reconnect on transient failure).""" import websockets # The ws_token field is now a securedata.Archive nested in the # response. 
We pass it on the WS handshake as an HTTP header # value, which means we encode it as base64-of-canonical-JSON # — HTTP headers don't carry raw JSON cleanly, and basn does # the inverse decode on receipt. current_token = _encode_archive_for_header(sw.ws_token) base_ws_url = _resolve_ws_url(sw, host) deadline = time.monotonic() + _reconnect_budget_seconds() backoff = _RECONNECT_BACKOFF_MIN is_first_connection = True # Force-drop is meant to simulate a single mid-stream drop and # then let reconnect succeed naturally; firing it on every # reconnect would just stall the test forever. force_drop_seconds = _force_drop_after_seconds() while True: url = ( base_ws_url if is_first_connection else _ws_url_for_reconnect(base_ws_url) ) try: terminal = await _consume_once( url=url, token=current_token, bearer=bearer, websockets_module=websockets, force_drop_seconds=force_drop_seconds, ) except _NeedsTokenRefresh: try: current_token = await _refresh_token( base_ws_url, current_token, bearer ) except _RefreshFailed as exc: raise CleanError( f'Stream WS token refresh failed: {exc}' ) from exc print( '[bacloud] WS token refreshed; reconnecting...', file=sys.stderr, ) is_first_connection = False force_drop_seconds = None backoff = _RECONNECT_BACKOFF_MIN continue except _FatalAuth as exc: raise CleanError(f'Stream WS auth failed: {exc}') from exc except _Reconnectable as exc: if time.monotonic() >= deadline: raise CleanError( f'Stream WS reconnect budget exhausted: {exc}' ) from exc print( f'[bacloud] WS dropped ({exc}); ' f'reconnecting in {backoff:.1f}s...', file=sys.stderr, ) await asyncio.sleep(backoff) backoff = min(backoff * 2, _RECONNECT_BACKOFF_MAX) is_first_connection = False force_drop_seconds = None continue else: return terminal async def _consume_once( *, url: str, token: str, bearer: str | None, websockets_module: object, force_drop_seconds: float | None, ) -> StreamFinal: """One WS-open-to-close cycle. 
Raises classification exceptions.""" websockets = websockets_module # for readability from websockets.exceptions import ( ConnectionClosed, InvalidStatus, WebSocketException, ) headers: list[tuple[str, str]] = [('X-WS-Token', token)] if bearer is not None: headers.append(('Authorization', f'Bearer {bearer}')) headers.append(('User-Agent', f'bacloud/{BACLOUD_VERSION}')) drop_task: asyncio.Task[None] | None = None try: async with websockets.connect( # type: ignore[attr-defined] url, additional_headers=headers ) as ws: if force_drop_seconds is not None: drop_task = asyncio.create_task( _force_drop_at(ws, force_drop_seconds) ) async for raw in ws: if isinstance(raw, bytes): raw = raw.decode('utf-8') frame = dataclass_from_json(StreamFrame, raw) if isinstance(frame, StreamOutput): print(frame.text, end='', flush=True) elif isinstance(frame, StreamFinal): return frame # Loop ended without a StreamFinal — treat as # reconnectable; basn's subscription is still # alive server-side (or has cleanly ended without # us seeing the terminal frame). raise _Reconnectable('WS closed without terminal frame') except InvalidStatus as exc: # Handshake-time HTTP error — basn rejected the upgrade # before we got an app-level close code. Treat as fatal: # likely a versioning / routing problem. 
raise _FatalAuth(f'handshake rejected: {exc}') from exc except ConnectionClosed as exc: if exc.code == 4001: # token expired raise _NeedsTokenRefresh(str(exc)) from exc if exc.code in (4002, 4003, 4004): # token bad / mismatch / missing raise _FatalAuth(f'code={exc.code} reason={exc.reason!r}') from exc raise _Reconnectable( f'closed: code={exc.code} reason={exc.reason!r}' ) from exc except WebSocketException as exc: raise _Reconnectable(f'protocol error: {exc}') from exc except OSError as exc: raise _Reconnectable(f'connect failed: {exc}') from exc finally: if drop_task is not None: drop_task.cancel() async def _force_drop_at(ws: object, after_seconds: float) -> None: """Test-only: close the WS after ``after_seconds``.""" try: await asyncio.sleep(after_seconds) print( f'[bacloud] BACLOUD_TEST_FORCE_DROP_AFTER_SECONDS=:' f' force-closing WS after {after_seconds}s', file=sys.stderr, ) # Close the underlying transport, which surfaces in the # consumer loop as a ConnectionClosed (typically code 1006). await ws.close() # type: ignore[attr-defined] except asyncio.CancelledError: pass async def _refresh_token( basn_url: str, current_token: str, bearer: str | None ) -> str: """POST to refresh-token. Returns the new token string.""" import json import urllib.error import urllib.request url = _refresh_url_for(basn_url) req = urllib.request.Request(url, method='POST') req.add_header('X-WS-Token', current_token) if bearer is not None: req.add_header('Authorization', f'Bearer {bearer}') req.add_header('User-Agent', f'bacloud/{BACLOUD_VERSION}') # urllib is sync; run in a thread to avoid blocking the loop. 
loop = asyncio.get_running_loop() try: body = await loop.run_in_executor( None, lambda: _http_post(req).decode('utf-8') ) except urllib.error.HTTPError as exc: body_str = exc.read().decode(errors='replace') raise _RefreshFailed(f'HTTP {exc.code} from {url}: {body_str}') from exc except urllib.error.URLError as exc: raise _RefreshFailed(f'connect failed to {url}: {exc.reason}') from exc try: data = json.loads(body) return str(data['ws_token']) except (ValueError, KeyError) as exc: raise _RefreshFailed(f'unparseable response: {body!r}') from exc def _encode_archive_for_header(archive: object) -> str: """Encode a :class:`bacommon.securedata.Archive` for an HTTP header. Header value is base64-of-canonical-JSON. basn's :func:`_decode_token_header` is the inverse. """ import base64 from efro.dataclassio import dataclass_to_json return ( base64.urlsafe_b64encode(dataclass_to_json(archive).encode()) .rstrip(b'=') .decode('ascii') ) def _http_post(req: urllib.request.Request) -> bytes: """Sync HTTP POST helper for use under ``run_in_executor``.""" import urllib.request with urllib.request.urlopen(req, timeout=15) as resp: return resp.read() # type: ignore[no-any-return] class _Reconnectable(Exception): """Internal: WS dropped on a recoverable signal; retry with backoff.""" class _NeedsTokenRefresh(Exception): """Internal: WS closed with 4001 (expired); refresh & retry.""" class _FatalAuth(Exception): """Internal: WS closed with an unrecoverable auth code; give up.""" class _RefreshFailed(Exception): """Internal: refresh-token endpoint failed.""" # Docs-generation hack; import some stuff that we likely only forward-declared # in our actual source code so that docs tools can find it. from typing import (Coroutine, Any, Literal, Callable, Generator, Awaitable, Sequence, Self) import asyncio from concurrent.futures import Future from pathlib import Path from enum import Enum