# Released under the MIT License. See LICENSE for details.
#
"""Thread pool functionality."""
from __future__ import annotations
import time
import logging
import threading
from typing import TYPE_CHECKING, ParamSpec
from concurrent.futures import ThreadPoolExecutor
from efro.util import strip_exception_tracebacks
if TYPE_CHECKING:
from typing import Any, Callable
from concurrent.futures import Future
# ParamSpec lets submit_no_wait() forward *args/**kwargs while preserving
# the wrapped callable's signature for type checkers.
P = ParamSpec('P')

# Module-level logger per stdlib convention.
logger = logging.getLogger(__name__)
[docs]
class ThreadPoolExecutorEx(ThreadPoolExecutor):
    """A ThreadPoolExecutor with additional functionality added."""

    def __init__(
        self,
        max_workers: int | None = None,
        thread_name_prefix: str = '',
        initializer: Callable[[], None] | None = None,
        max_no_wait_count: int | None = None,
    ) -> None:
        super().__init__(
            max_workers=max_workers,
            thread_name_prefix=thread_name_prefix,
            initializer=initializer,
        )
        # Count of submit_no_wait() calls whose futures have not yet
        # completed. Public so outside code can sample it.
        self.no_wait_count = 0

        # Backlog limit before submit_no_wait() starts blocking;
        # defaults to twice the worker count, or 50 when no worker
        # count was provided.
        if max_no_wait_count is not None:
            self._max_no_wait_count = max_no_wait_count
        elif max_workers is not None:
            self._max_no_wait_count = max_workers * 2
        else:
            self._max_no_wait_count = 50

        self._last_no_wait_warn_time: float | None = None
        self._no_wait_count_lock = threading.Lock()

    def submit_no_wait(
        self, call: Callable[P, Any], *args: P.args, **keywds: P.kwargs
    ) -> None:
        """Fire-and-forget a call onto the threadpool.

        Exceptions raised by the callable are caught and logged via
        ``logger.exception()``, so callers need no error handling of
        their own for this sort of work. If the pool has reached its
        max queued no-wait call count, this call logs a warning and
        blocks until the backlog drains.
        """
        # Backpressure: if too many no-wait calls are outstanding, warn
        # (at most once per ten seconds) and spin until the count drops.
        # These reads deliberately skip the lock; slight inexactness is
        # acceptable here. In general the aim is to stay under this
        # limit, but it guards against runaway queues during network
        # outages and the like.
        if self.no_wait_count > self._max_no_wait_count:
            curtime = time.monotonic()
            last_warn = self._last_no_wait_warn_time
            if last_warn is None or curtime - last_warn > 10.0:
                logger.warning(
                    'ThreadPoolExecutorEx hit max no-wait limit of %s;'
                    ' blocking.',
                    self._max_no_wait_count,
                )
                self._last_no_wait_warn_time = curtime
            while self.no_wait_count > self._max_no_wait_count:
                time.sleep(0.01)

        future = self.submit(call, *args, **keywds)
        with self._no_wait_count_lock:
            self.no_wait_count += 1
        future.add_done_callback(self._no_wait_done)

    def _no_wait_done(self, fut: Future) -> None:
        # Bookkeeping first so the count drops regardless of how result
        # handling goes.
        with self._no_wait_count_lock:
            self.no_wait_count -= 1
        try:
            fut.result()
        except Exception as exc:
            logger.exception('Error in work submitted via submit_no_wait().')

            # We're done with this exception, so strip its traceback to
            # avoid reference cycles.
            strip_exception_tracebacks(exc)
# ---- Threadpool introspection ----
#
# These accessors let monitoring code sample a pool's live state.
# They reach into ``concurrent.futures.thread``'s private attributes
# (``_work_queue`` / ``_threads`` / ``_idle_semaphore``) because the
# public API doesn't expose queue depth or busy-count. Each access
# is wrapped in a try/except and falls back to ``None`` with a
# one-shot WARNING log if the internals change shape — that warning
# is the canary that this code needs an update for the current
# CPython version. Free functions (rather than methods on
# ``ThreadPoolExecutorEx``) so they work on any
# ``concurrent.futures.ThreadPoolExecutor``, including the asyncio
# loop's default executor (which is a plain ``ThreadPoolExecutor``).
#: Tracks per-(executor-id, attr) keys we've already warned about,
#: so a single broken introspection produces one log line per
#: process rather than spamming. ``id(executor)`` keys lets multiple
#: pools coexist with independent warning state.
# Registry of (executor-id, attr) keys already warned about; access is
# serialized by the lock below since warnings may fire from any thread.
_g_introspection_warned: set[tuple[int, str]] = set()
_g_introspection_warned_lock = threading.Lock()
def _warn_introspection_broken(executor: ThreadPoolExecutor, what: str) -> None:
    """Emit a one-shot warning that introspecting ``executor`` failed.

    Warns at most once per (executor, attr) pair for the life of the
    process; subsequent calls with the same key are no-ops.
    """
    warn_key = (id(executor), what)
    with _g_introspection_warned_lock:
        already_warned = warn_key in _g_introspection_warned
        if not already_warned:
            _g_introspection_warned.add(warn_key)
    if already_warned:
        return
    logger.warning(
        'Threadpool introspection broken: %s on %r.'
        ' CPython internals may have changed shape;'
        ' efro/threadpool.py needs an update.',
        what,
        type(executor).__name__,
    )
[docs]
def queue_depth(executor: ThreadPoolExecutor) -> int | None:
    """Return the current count of pending work items in ``executor``.

    Best-effort: reads the underlying ``ThreadPoolExecutor``'s
    ``_work_queue`` (a :class:`queue.SimpleQueue`) and calls its public
    ``qsize()``. Returns ``None`` and logs a one-shot warning
    (per-executor, per-attr) if the shape doesn't match expectations.
    """
    try:
        work_queue = getattr(executor, '_work_queue', None)
        # hasattr(None, ...) is False, so this also covers the
        # missing-attribute case.
        if hasattr(work_queue, 'qsize'):
            return int(work_queue.qsize())
        _warn_introspection_broken(executor, '_work_queue.qsize')
        return None
    except Exception:  # pylint: disable=broad-exception-caught
        logger.exception('queue_depth() introspection failed.')
        return None
[docs]
def live_thread_count(executor: ThreadPoolExecutor) -> int | None:
    """Return the count of worker threads currently alive in ``executor``.

    ``ThreadPoolExecutor`` spawns workers on demand and keeps them alive
    until ``shutdown()``, so this is a high-water rather than
    instantaneous-busy count. Pair with :func:`busy_workers` or
    :func:`queue_depth` to interpret saturation. Returns ``None`` and
    logs a one-shot warning on internals breakage.
    """
    try:
        workers = getattr(executor, '_threads', None)
        if isinstance(workers, set):
            return len(workers)
        _warn_introspection_broken(executor, '_threads not a set')
        return None
    except Exception:  # pylint: disable=broad-exception-caught
        logger.exception('live_thread_count() introspection failed.')
        return None
[docs]
def busy_workers(executor: ThreadPoolExecutor) -> int | None:
    """Return the count of worker threads currently executing work.

    Computed as ``live_thread_count - idle_workers``, where
    ``idle_workers`` reads ``_idle_semaphore._value`` (the Semaphore's
    remaining permits — workers ``release()`` when they go idle and
    ``acquire()`` when they pick up new work). Touching
    ``Semaphore._value`` is the most fragile of these accessors; on any
    breakage we return ``None`` and log a one-shot warning rather than
    guessing.
    """
    live = live_thread_count(executor)
    if live is None:
        return None
    try:
        sem = getattr(executor, '_idle_semaphore', None)
        if sem is None:
            _warn_introspection_broken(executor, '_idle_semaphore missing')
            return None
        idle = getattr(sem, '_value', None)
        if not isinstance(idle, int):
            _warn_introspection_broken(executor, '_idle_semaphore._value')
            return None
        # Under transient races between sampling and workers
        # transitioning, idle could exceed live by an off-by-one; floor
        # at zero so we never report negative.
        busy = live - idle
        return busy if busy > 0 else 0
    except Exception:  # pylint: disable=broad-exception-caught
        logger.exception('busy_workers() introspection failed.')
        return None
# Docs-generation hack; import some stuff that we likely only forward-declared
# in our actual source code so that docs tools can find it.
from typing import (Coroutine, Any, Literal, Callable,
Generator, Awaitable, Sequence, Self)
import asyncio
from concurrent.futures import Future
from pathlib import Path
from enum import Enum