Source code for efro.threadpool

# Released under the MIT License. See LICENSE for details.
#
"""Thread pool functionality."""

from __future__ import annotations  # Docs-generation hack.

import time
import logging
import functools
import threading
from collections import Counter
from typing import TYPE_CHECKING, ParamSpec, TypeVar, override
from concurrent.futures import ThreadPoolExecutor

from efro.util import strip_exception_tracebacks

if TYPE_CHECKING:
    from typing import Any, Callable
    from concurrent.futures import Future

P = ParamSpec('P')
T = TypeVar('T')

logger = logging.getLogger(__name__)


[docs] class ThreadPoolExecutorEx(ThreadPoolExecutor): """A ThreadPoolExecutor with extra diagnostics. Intended for **efficiency**: parallelizing pieces of a single task so it finishes faster. It is **not** a queue for long-running or blocking work -- a worker tied up on a slow task can't run anything else, so long/blocking work starves the pool and delays everything queued behind it (and ``submit_no_wait`` callers can pile up a backlog). To make misuse easy to spot, submitted work is timed: a task that waits too long in the queue before starting, or runs too long, logs a (rate-limited) warning naming the callable. ``submit_no_wait`` also logs when its backlog exceeds a soft limit. None of these block. """ def __init__( self, max_workers: int | None = None, thread_name_prefix: str = '', initializer: Callable[[], None] | None = None, max_no_wait_count: int | None = None, *, allow_submit_no_wait: bool = True, queue_wait_warn_seconds: float = 10.0, run_duration_warn_seconds: float = 5.0, log_throttle_seconds: float = 10.0, ) -> None: super().__init__( max_workers=max_workers, thread_name_prefix=thread_name_prefix, initializer=initializer, ) self.no_wait_count = 0 #: Whether submit_no_wait() may be used on this pool. Pass #: False on hosts with no background CPU (e.g. Cloud Run #: request-based billing), where fire-and-forget work would #: stall between requests; submit_no_wait() then raises #: RuntimeError so offending call sites surface loudly. #: Callers with legitimate fire-and-forget needs can branch #: on this attr to do the work synchronously instead. self.allow_submit_no_wait = allow_submit_no_wait self._max_no_wait_count = ( max_no_wait_count if max_no_wait_count is not None else 50 if max_workers is None else max_workers * 2 ) #: Warn if a submitted task waits longer than this in the queue #: before a worker picks it up (pool saturation), or runs longer #: than the run threshold (likely misuse -- long work on an #: efficiency pool). self._queue_wait_warn_seconds = queue_wait_warn_seconds self._run_duration_warn_seconds = run_duration_warn_seconds #: Min seconds between repeats of any one throttled log line, so a #: bad spell can't saturate the logs. Keyed by a short kind string. self._log_throttle_seconds = log_throttle_seconds self._last_log_times: dict[str, float] = {} self._no_wait_count_lock = threading.Lock() #: Count of in-flight no-wait calls keyed by callable name, so an #: over-soft-limit log can name the spike's likely source. #: Guarded by ``_no_wait_count_lock``. self._no_wait_calls: Counter[str] = Counter()
[docs] def submit_no_wait( self, call: Callable[P, Any], *args: P.args, **keywds: P.kwargs ) -> None: """Submit fire-and-forget work to the threadpool. Any exceptions raised by the callable are automatically caught and logged via ``logger.exception()``, so callers do not need their own error handling for fire-and-forget work. This call **never blocks**. If the pool's queued no-wait backlog exceeds its soft limit we log an error (naming the most common in-flight callables, to help pinpoint the source of a spike) but still submit. A blocking backpressure was used here previously, but it could self-deadlock when ``submit_no_wait`` was called from one of this pool's own workers — the backlog only drains via those same workers, so blocking them stalled it forever. A loud, non-blocking warning gives the same visibility without that risk. Raises RuntimeError if the pool was created with ``allow_submit_no_wait=False`` (hosts with no background CPU). """ if not self.allow_submit_no_wait: raise RuntimeError( 'submit_no_wait() is disabled for this threadpool' ' (no background processing available on this host).' ' Do the work synchronously instead; see the' ' allow_submit_no_wait attr.' ) key = _callable_name(call) with self._no_wait_count_lock: self.no_wait_count += 1 self._no_wait_calls[key] += 1 count = self.no_wait_count # Over the soft limit: log (rate-limited) but never block. if count > self._max_no_wait_count and self._should_log('backlog'): logger.error( 'ThreadPoolExecutorEx no-wait backlog (%d) exceeds soft' ' limit (%d); not blocking. Top in-flight no-wait' ' callables: %s.', count, self._max_no_wait_count, self._top_no_wait_calls(), ) fut = self.submit(call, *args, **keywds) fut.add_done_callback(functools.partial(self._no_wait_done, key=key))
def _top_no_wait_calls(self, count: int = 5) -> str: """Compact ``name=N, ...`` of the top in-flight no-wait calls.""" with self._no_wait_count_lock: top = self._no_wait_calls.most_common(count) return ', '.join(f'{name}={num}' for name, num in top) or '(none)'
[docs] @override def submit( self, fn: Callable[P, T], /, *args: P.args, **kwargs: P.kwargs ) -> Future[T]: """Submit work, timing its queue-wait and run duration. See the class docstring: this pool is for short parallel work, not long-running or blocking tasks. Submitted callables are timed and a slow wait-to-start or a slow run logs a rate-limited warning naming the callable, so misuse is easy to spot. """ return super().submit(self._wrap_timed(fn), *args, **kwargs)
def _wrap_timed(self, fn: Callable[P, T]) -> Callable[P, T]: """Wrap ``fn`` to warn on excessive queue-wait / run duration.""" enqueue_time = time.monotonic() name = _callable_name(fn) def _timed(*args: P.args, **kwargs: P.kwargs) -> T: start = time.monotonic() wait = start - enqueue_time if wait > self._queue_wait_warn_seconds and self._should_log( 'queue_wait' ): logger.warning( 'ThreadPoolExecutorEx: %s waited %.1fs in the queue' ' before starting (over %.0fs). This pool is for short' ' parallel work; long/blocking tasks or floods saturate' ' it and delay everything queued behind them.', name, wait, self._queue_wait_warn_seconds, ) try: return fn(*args, **kwargs) finally: duration = time.monotonic() - start if ( duration > self._run_duration_warn_seconds and self._should_log('run_duration') ): logger.warning( 'ThreadPoolExecutorEx: %s ran %.1fs (over %.0fs).' ' This pool is for short parallel work to speed a' ' task up, not long-running or blocking work -- that' ' ties up a worker and starves the pool. Move long' ' work elsewhere.', name, duration, self._run_duration_warn_seconds, ) return _timed def _should_log(self, kind: str) -> bool: """Return True at most once per throttle window for ``kind``. Lock-free and thus slightly inexact under races (an occasional extra line), which is fine for diagnostics. """ now = time.monotonic() last = self._last_log_times.get(kind) if last is None or now - last > self._log_throttle_seconds: self._last_log_times[kind] = now return True return False
[docs] def submit_no_wait_or_run( self, call: Callable[P, Any], *args: P.args, **keywds: P.kwargs ) -> None: """Fire-and-forget ``call`` off-thread, or run it inline. Equivalent to :meth:`submit_no_wait` on pools that allow it, but on pools that don't (``allow_submit_no_wait=False`` — hosts with no background CPU, e.g. Cloud Run request-based billing / BEEF) it runs ``call`` synchronously instead of raising. Either way the work is **best-effort**: exceptions are caught and logged, never propagated (the inline path mirrors ``submit_no_wait``'s off-thread handling). This is the one-call form of the common "``if pool.allow_submit_no_wait: submit_no_wait(...) else: <run inline>``" branch. Use it for cheap, latency-shaving side effects (e.g. cache writes) where the inline fallback's brief synchronous cost is acceptable; for heavier work, branch explicitly so you notice when you're blocking a request. """ if self.allow_submit_no_wait: self.submit_no_wait(call, *args, **keywds) return # No background CPU on this pool -- run inline, best-effort. try: call(*args, **keywds) except Exception as exc: logger.exception('Error in work run via submit_no_wait_or_run().') # Terminal consumer of this exception; strip to avoid cycles. strip_exception_tracebacks(exc)
def _no_wait_done(self, fut: Future, *, key: str) -> None: with self._no_wait_count_lock: self.no_wait_count -= 1 self._no_wait_calls[key] -= 1 # Keep the Counter from accumulating stale zero entries. if self._no_wait_calls[key] <= 0: del self._no_wait_calls[key] try: fut.result() except Exception as exc: logger.exception('Error in work submitted via submit_no_wait().') # We're done with this exception, so strip its traceback to # avoid reference cycles. strip_exception_tracebacks(exc)
def _callable_name(call: Callable[..., Any]) -> str: """Best-effort human-readable name for a submitted callable. Unwraps :class:`functools.partial` chains to the underlying function so diagnostics name the real target, not ``functools.partial``. """ target: Any = call while isinstance(target, functools.partial): target = target.func return ( getattr(target, '__qualname__', None) or getattr(target, '__name__', None) or type(target).__name__ ) # ---- Threadpool introspection ---- # # These accessors let monitoring code sample a pool's live state. # They reach into ``concurrent.futures.thread``'s private attributes # (``_work_queue`` / ``_threads`` / ``_idle_semaphore``) because the # public API doesn't expose queue depth or busy-count. Each access # is wrapped in a try/except and falls back to ``None`` with a # one-shot WARNING log if the internals change shape — that warning # is the canary that this code needs an update for the current # CPython version. Free functions (rather than methods on # ``ThreadPoolExecutorEx``) so they work on any # ``concurrent.futures.ThreadPoolExecutor``, including the asyncio # loop's default executor (which is a plain ``ThreadPoolExecutor``). #: Tracks per-(executor-id, attr) keys we've already warned about, #: so a single broken introspection produces one log line per #: process rather than spamming. ``id(executor)`` keys lets multiple #: pools coexist with independent warning state. _g_introspection_warned: set[tuple[int, str]] = set() _g_introspection_warned_lock = threading.Lock() def _warn_introspection_broken(executor: ThreadPoolExecutor, what: str) -> None: key = (id(executor), what) with _g_introspection_warned_lock: if key in _g_introspection_warned: return _g_introspection_warned.add(key) logger.warning( 'Threadpool introspection broken: %s on %r.' ' CPython internals may have changed shape;' ' efro/threadpool.py needs an update.', what, type(executor).__name__, )
[docs] def queue_depth(executor: ThreadPoolExecutor) -> int | None: """Return the current count of pending work items in ``executor``. Best-effort: reads the underlying ``ThreadPoolExecutor``'s ``_work_queue`` (a :class:`queue.SimpleQueue`) and calls its public ``qsize()``. Returns ``None`` and logs a one-shot warning (per-executor, per-attr) if the shape doesn't match expectations. """ try: wq = getattr(executor, '_work_queue', None) if wq is None or not hasattr(wq, 'qsize'): _warn_introspection_broken(executor, '_work_queue.qsize') return None return int(wq.qsize()) except Exception: # pylint: disable=broad-exception-caught logger.exception('queue_depth() introspection failed.') return None
[docs] def live_thread_count(executor: ThreadPoolExecutor) -> int | None: """Return the count of worker threads currently alive in ``executor``. ``ThreadPoolExecutor`` spawns workers on demand and keeps them alive until ``shutdown()``, so this is a high-water rather than instantaneous-busy count. Pair with :func:`busy_workers` or :func:`queue_depth` to interpret saturation. Returns ``None`` and logs a one-shot warning on internals breakage. """ try: threads = getattr(executor, '_threads', None) if not isinstance(threads, set): _warn_introspection_broken(executor, '_threads not a set') return None return len(threads) except Exception: # pylint: disable=broad-exception-caught logger.exception('live_thread_count() introspection failed.') return None
[docs] def busy_workers(executor: ThreadPoolExecutor) -> int | None: """Return the count of worker threads currently executing work. Computed as ``live_thread_count - idle_workers``, where ``idle_workers`` reads ``_idle_semaphore._value`` (the Semaphore's remaining permits — workers ``release()`` when they go idle and ``acquire()`` when they pick up new work). Touching ``Semaphore._value`` is the most fragile of these accessors; on any breakage we return ``None`` and log a one-shot warning rather than guessing. """ live = live_thread_count(executor) if live is None: return None try: idle_sem = getattr(executor, '_idle_semaphore', None) if idle_sem is None: _warn_introspection_broken(executor, '_idle_semaphore missing') return None idle_count = getattr(idle_sem, '_value', None) if not isinstance(idle_count, int): _warn_introspection_broken(executor, '_idle_semaphore._value') return None # Clamp: under transient races between sampling and workers # transitioning, idle could exceed live by an off-by-one. # Floor at zero so we never report negative. return max(0, live - idle_count) except Exception: # pylint: disable=broad-exception-caught logger.exception('busy_workers() introspection failed.') return None
# Docs-generation hack; import some stuff that we likely only forward-declared # in our actual source code so that docs tools can find it. from typing import (Coroutine, Any, Literal, Callable, Generator, Awaitable, Sequence, Self) import asyncio from concurrent.futures import Future from pathlib import Path from enum import Enum