|
@@ -1,2807 +0,0 @@
|
|
|
-from __future__ import annotations
|
|
|
-
|
|
|
-import array
|
|
|
-import asyncio
|
|
|
-import concurrent.futures
|
|
|
-import math
|
|
|
-import os
|
|
|
-import socket
|
|
|
-import sys
|
|
|
-import threading
|
|
|
-import weakref
|
|
|
-from asyncio import (
|
|
|
- AbstractEventLoop,
|
|
|
- CancelledError,
|
|
|
- all_tasks,
|
|
|
- create_task,
|
|
|
- current_task,
|
|
|
- get_running_loop,
|
|
|
- sleep,
|
|
|
-)
|
|
|
-from asyncio.base_events import _run_until_complete_cb # type: ignore[attr-defined]
|
|
|
-from collections import OrderedDict, deque
|
|
|
-from collections.abc import (
|
|
|
- AsyncGenerator,
|
|
|
- AsyncIterator,
|
|
|
- Awaitable,
|
|
|
- Callable,
|
|
|
- Collection,
|
|
|
- Coroutine,
|
|
|
- Iterable,
|
|
|
- Sequence,
|
|
|
-)
|
|
|
-from concurrent.futures import Future
|
|
|
-from contextlib import AbstractContextManager, suppress
|
|
|
-from contextvars import Context, copy_context
|
|
|
-from dataclasses import dataclass
|
|
|
-from functools import partial, wraps
|
|
|
-from inspect import (
|
|
|
- CORO_RUNNING,
|
|
|
- CORO_SUSPENDED,
|
|
|
- getcoroutinestate,
|
|
|
- iscoroutine,
|
|
|
-)
|
|
|
-from io import IOBase
|
|
|
-from os import PathLike
|
|
|
-from queue import Queue
|
|
|
-from signal import Signals
|
|
|
-from socket import AddressFamily, SocketKind
|
|
|
-from threading import Thread
|
|
|
-from types import CodeType, TracebackType
|
|
|
-from typing import (
|
|
|
- IO,
|
|
|
- TYPE_CHECKING,
|
|
|
- Any,
|
|
|
- Optional,
|
|
|
- TypeVar,
|
|
|
- cast,
|
|
|
-)
|
|
|
-from weakref import WeakKeyDictionary
|
|
|
-
|
|
|
-import sniffio
|
|
|
-
|
|
|
-from .. import (
|
|
|
- CapacityLimiterStatistics,
|
|
|
- EventStatistics,
|
|
|
- LockStatistics,
|
|
|
- TaskInfo,
|
|
|
- abc,
|
|
|
-)
|
|
|
-from .._core._eventloop import claim_worker_thread, threadlocals
|
|
|
-from .._core._exceptions import (
|
|
|
- BrokenResourceError,
|
|
|
- BusyResourceError,
|
|
|
- ClosedResourceError,
|
|
|
- EndOfStream,
|
|
|
- WouldBlock,
|
|
|
- iterate_exceptions,
|
|
|
-)
|
|
|
-from .._core._sockets import convert_ipv6_sockaddr
|
|
|
-from .._core._streams import create_memory_object_stream
|
|
|
-from .._core._synchronization import (
|
|
|
- CapacityLimiter as BaseCapacityLimiter,
|
|
|
-)
|
|
|
-from .._core._synchronization import Event as BaseEvent
|
|
|
-from .._core._synchronization import Lock as BaseLock
|
|
|
-from .._core._synchronization import (
|
|
|
- ResourceGuard,
|
|
|
- SemaphoreStatistics,
|
|
|
-)
|
|
|
-from .._core._synchronization import Semaphore as BaseSemaphore
|
|
|
-from .._core._tasks import CancelScope as BaseCancelScope
|
|
|
-from ..abc import (
|
|
|
- AsyncBackend,
|
|
|
- IPSockAddrType,
|
|
|
- SocketListener,
|
|
|
- UDPPacketType,
|
|
|
- UNIXDatagramPacketType,
|
|
|
-)
|
|
|
-from ..abc._eventloop import StrOrBytesPath
|
|
|
-from ..lowlevel import RunVar
|
|
|
-from ..streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
|
|
-
|
|
|
-if TYPE_CHECKING:
|
|
|
- from _typeshed import FileDescriptorLike
|
|
|
-else:
|
|
|
- FileDescriptorLike = object
|
|
|
-
|
|
|
-if sys.version_info >= (3, 10):
|
|
|
- from typing import ParamSpec
|
|
|
-else:
|
|
|
- from typing_extensions import ParamSpec
|
|
|
-
|
|
|
-if sys.version_info >= (3, 11):
|
|
|
- from asyncio import Runner
|
|
|
- from typing import TypeVarTuple, Unpack
|
|
|
-else:
|
|
|
- import contextvars
|
|
|
- import enum
|
|
|
- import signal
|
|
|
- from asyncio import coroutines, events, exceptions, tasks
|
|
|
-
|
|
|
- from exceptiongroup import BaseExceptionGroup
|
|
|
- from typing_extensions import TypeVarTuple, Unpack
|
|
|
-
|
|
|
- class _State(enum.Enum):
|
|
|
- CREATED = "created"
|
|
|
- INITIALIZED = "initialized"
|
|
|
- CLOSED = "closed"
|
|
|
-
|
|
|
- class Runner:
|
|
|
- # Copied from CPython 3.11
|
|
|
- def __init__(
|
|
|
- self,
|
|
|
- *,
|
|
|
- debug: bool | None = None,
|
|
|
- loop_factory: Callable[[], AbstractEventLoop] | None = None,
|
|
|
- ):
|
|
|
- self._state = _State.CREATED
|
|
|
- self._debug = debug
|
|
|
- self._loop_factory = loop_factory
|
|
|
- self._loop: AbstractEventLoop | None = None
|
|
|
- self._context = None
|
|
|
- self._interrupt_count = 0
|
|
|
- self._set_event_loop = False
|
|
|
-
|
|
|
- def __enter__(self) -> Runner:
|
|
|
- self._lazy_init()
|
|
|
- return self
|
|
|
-
|
|
|
- def __exit__(
|
|
|
- self,
|
|
|
- exc_type: type[BaseException],
|
|
|
- exc_val: BaseException,
|
|
|
- exc_tb: TracebackType,
|
|
|
- ) -> None:
|
|
|
- self.close()
|
|
|
-
|
|
|
- def close(self) -> None:
|
|
|
- """Shutdown and close event loop."""
|
|
|
- if self._state is not _State.INITIALIZED:
|
|
|
- return
|
|
|
- try:
|
|
|
- loop = self._loop
|
|
|
- _cancel_all_tasks(loop)
|
|
|
- loop.run_until_complete(loop.shutdown_asyncgens())
|
|
|
- if hasattr(loop, "shutdown_default_executor"):
|
|
|
- loop.run_until_complete(loop.shutdown_default_executor())
|
|
|
- else:
|
|
|
- loop.run_until_complete(_shutdown_default_executor(loop))
|
|
|
- finally:
|
|
|
- if self._set_event_loop:
|
|
|
- events.set_event_loop(None)
|
|
|
- loop.close()
|
|
|
- self._loop = None
|
|
|
- self._state = _State.CLOSED
|
|
|
-
|
|
|
- def get_loop(self) -> AbstractEventLoop:
|
|
|
- """Return embedded event loop."""
|
|
|
- self._lazy_init()
|
|
|
- return self._loop
|
|
|
-
|
|
|
- def run(self, coro: Coroutine[T_Retval], *, context=None) -> T_Retval:
|
|
|
- """Run a coroutine inside the embedded event loop."""
|
|
|
- if not coroutines.iscoroutine(coro):
|
|
|
- raise ValueError(f"a coroutine was expected, got {coro!r}")
|
|
|
-
|
|
|
- if events._get_running_loop() is not None:
|
|
|
- # fail fast with short traceback
|
|
|
- raise RuntimeError(
|
|
|
- "Runner.run() cannot be called from a running event loop"
|
|
|
- )
|
|
|
-
|
|
|
- self._lazy_init()
|
|
|
-
|
|
|
- if context is None:
|
|
|
- context = self._context
|
|
|
- task = context.run(self._loop.create_task, coro)
|
|
|
-
|
|
|
- if (
|
|
|
- threading.current_thread() is threading.main_thread()
|
|
|
- and signal.getsignal(signal.SIGINT) is signal.default_int_handler
|
|
|
- ):
|
|
|
- sigint_handler = partial(self._on_sigint, main_task=task)
|
|
|
- try:
|
|
|
- signal.signal(signal.SIGINT, sigint_handler)
|
|
|
- except ValueError:
|
|
|
- # `signal.signal` may throw if `threading.main_thread` does
|
|
|
- # not support signals (e.g. embedded interpreter with signals
|
|
|
- # not registered - see gh-91880)
|
|
|
- sigint_handler = None
|
|
|
- else:
|
|
|
- sigint_handler = None
|
|
|
-
|
|
|
- self._interrupt_count = 0
|
|
|
- try:
|
|
|
- return self._loop.run_until_complete(task)
|
|
|
- except exceptions.CancelledError:
|
|
|
- if self._interrupt_count > 0:
|
|
|
- uncancel = getattr(task, "uncancel", None)
|
|
|
- if uncancel is not None and uncancel() == 0:
|
|
|
- raise KeyboardInterrupt()
|
|
|
- raise # CancelledError
|
|
|
- finally:
|
|
|
- if (
|
|
|
- sigint_handler is not None
|
|
|
- and signal.getsignal(signal.SIGINT) is sigint_handler
|
|
|
- ):
|
|
|
- signal.signal(signal.SIGINT, signal.default_int_handler)
|
|
|
-
|
|
|
- def _lazy_init(self) -> None:
|
|
|
- if self._state is _State.CLOSED:
|
|
|
- raise RuntimeError("Runner is closed")
|
|
|
- if self._state is _State.INITIALIZED:
|
|
|
- return
|
|
|
- if self._loop_factory is None:
|
|
|
- self._loop = events.new_event_loop()
|
|
|
- if not self._set_event_loop:
|
|
|
- # Call set_event_loop only once to avoid calling
|
|
|
- # attach_loop multiple times on child watchers
|
|
|
- events.set_event_loop(self._loop)
|
|
|
- self._set_event_loop = True
|
|
|
- else:
|
|
|
- self._loop = self._loop_factory()
|
|
|
- if self._debug is not None:
|
|
|
- self._loop.set_debug(self._debug)
|
|
|
- self._context = contextvars.copy_context()
|
|
|
- self._state = _State.INITIALIZED
|
|
|
-
|
|
|
- def _on_sigint(self, signum, frame, main_task: asyncio.Task) -> None:
|
|
|
- self._interrupt_count += 1
|
|
|
- if self._interrupt_count == 1 and not main_task.done():
|
|
|
- main_task.cancel()
|
|
|
- # wakeup loop if it is blocked by select() with long timeout
|
|
|
- self._loop.call_soon_threadsafe(lambda: None)
|
|
|
- return
|
|
|
- raise KeyboardInterrupt()
|
|
|
-
|
|
|
- def _cancel_all_tasks(loop: AbstractEventLoop) -> None:
|
|
|
- to_cancel = tasks.all_tasks(loop)
|
|
|
- if not to_cancel:
|
|
|
- return
|
|
|
-
|
|
|
- for task in to_cancel:
|
|
|
- task.cancel()
|
|
|
-
|
|
|
- loop.run_until_complete(tasks.gather(*to_cancel, return_exceptions=True))
|
|
|
-
|
|
|
- for task in to_cancel:
|
|
|
- if task.cancelled():
|
|
|
- continue
|
|
|
- if task.exception() is not None:
|
|
|
- loop.call_exception_handler(
|
|
|
- {
|
|
|
- "message": "unhandled exception during asyncio.run() shutdown",
|
|
|
- "exception": task.exception(),
|
|
|
- "task": task,
|
|
|
- }
|
|
|
- )
|
|
|
-
|
|
|
- async def _shutdown_default_executor(loop: AbstractEventLoop) -> None:
|
|
|
- """Schedule the shutdown of the default executor."""
|
|
|
-
|
|
|
- def _do_shutdown(future: asyncio.futures.Future) -> None:
|
|
|
- try:
|
|
|
- loop._default_executor.shutdown(wait=True) # type: ignore[attr-defined]
|
|
|
- loop.call_soon_threadsafe(future.set_result, None)
|
|
|
- except Exception as ex:
|
|
|
- loop.call_soon_threadsafe(future.set_exception, ex)
|
|
|
-
|
|
|
- loop._executor_shutdown_called = True
|
|
|
- if loop._default_executor is None:
|
|
|
- return
|
|
|
- future = loop.create_future()
|
|
|
- thread = threading.Thread(target=_do_shutdown, args=(future,))
|
|
|
- thread.start()
|
|
|
- try:
|
|
|
- await future
|
|
|
- finally:
|
|
|
- thread.join()
|
|
|
-
|
|
|
-
|
|
|
-T_Retval = TypeVar("T_Retval")
|
|
|
-T_contra = TypeVar("T_contra", contravariant=True)
|
|
|
-PosArgsT = TypeVarTuple("PosArgsT")
|
|
|
-P = ParamSpec("P")
|
|
|
-
|
|
|
-_root_task: RunVar[asyncio.Task | None] = RunVar("_root_task")
|
|
|
-
|
|
|
-
|
|
|
-def find_root_task() -> asyncio.Task:
|
|
|
- root_task = _root_task.get(None)
|
|
|
- if root_task is not None and not root_task.done():
|
|
|
- return root_task
|
|
|
-
|
|
|
- # Look for a task that has been started via run_until_complete()
|
|
|
- for task in all_tasks():
|
|
|
- if task._callbacks and not task.done():
|
|
|
- callbacks = [cb for cb, context in task._callbacks]
|
|
|
- for cb in callbacks:
|
|
|
- if (
|
|
|
- cb is _run_until_complete_cb
|
|
|
- or getattr(cb, "__module__", None) == "uvloop.loop"
|
|
|
- ):
|
|
|
- _root_task.set(task)
|
|
|
- return task
|
|
|
-
|
|
|
- # Look up the topmost task in the AnyIO task tree, if possible
|
|
|
- task = cast(asyncio.Task, current_task())
|
|
|
- state = _task_states.get(task)
|
|
|
- if state:
|
|
|
- cancel_scope = state.cancel_scope
|
|
|
- while cancel_scope and cancel_scope._parent_scope is not None:
|
|
|
- cancel_scope = cancel_scope._parent_scope
|
|
|
-
|
|
|
- if cancel_scope is not None:
|
|
|
- return cast(asyncio.Task, cancel_scope._host_task)
|
|
|
-
|
|
|
- return task
|
|
|
-
|
|
|
-
|
|
|
-def get_callable_name(func: Callable) -> str:
|
|
|
- module = getattr(func, "__module__", None)
|
|
|
- qualname = getattr(func, "__qualname__", None)
|
|
|
- return ".".join([x for x in (module, qualname) if x])
|
|
|
-
|
|
|
-
|
|
|
-#
|
|
|
-# Event loop
|
|
|
-#
|
|
|
-
|
|
|
-_run_vars: WeakKeyDictionary[asyncio.AbstractEventLoop, Any] = WeakKeyDictionary()
|
|
|
-
|
|
|
-
|
|
|
-def _task_started(task: asyncio.Task) -> bool:
|
|
|
- """Return ``True`` if the task has been started and has not finished."""
|
|
|
- # The task coro should never be None here, as we never add finished tasks to the
|
|
|
- # task list
|
|
|
- coro = task.get_coro()
|
|
|
- assert coro is not None
|
|
|
- try:
|
|
|
- return getcoroutinestate(coro) in (CORO_RUNNING, CORO_SUSPENDED)
|
|
|
- except AttributeError:
|
|
|
- # task coro is async_genenerator_asend https://bugs.python.org/issue37771
|
|
|
- raise Exception(f"Cannot determine if task {task} has started or not") from None
|
|
|
-
|
|
|
-
|
|
|
-#
|
|
|
-# Timeouts and cancellation
|
|
|
-#
|
|
|
-
|
|
|
-
|
|
|
-def is_anyio_cancellation(exc: CancelledError) -> bool:
|
|
|
- # Sometimes third party frameworks catch a CancelledError and raise a new one, so as
|
|
|
- # a workaround we have to look at the previous ones in __context__ too for a
|
|
|
- # matching cancel message
|
|
|
- while True:
|
|
|
- if (
|
|
|
- exc.args
|
|
|
- and isinstance(exc.args[0], str)
|
|
|
- and exc.args[0].startswith("Cancelled by cancel scope ")
|
|
|
- ):
|
|
|
- return True
|
|
|
-
|
|
|
- if isinstance(exc.__context__, CancelledError):
|
|
|
- exc = exc.__context__
|
|
|
- continue
|
|
|
-
|
|
|
- return False
|
|
|
-
|
|
|
-
|
|
|
-class CancelScope(BaseCancelScope):
|
|
|
- def __new__(
|
|
|
- cls, *, deadline: float = math.inf, shield: bool = False
|
|
|
- ) -> CancelScope:
|
|
|
- return object.__new__(cls)
|
|
|
-
|
|
|
- def __init__(self, deadline: float = math.inf, shield: bool = False):
|
|
|
- self._deadline = deadline
|
|
|
- self._shield = shield
|
|
|
- self._parent_scope: CancelScope | None = None
|
|
|
- self._child_scopes: set[CancelScope] = set()
|
|
|
- self._cancel_called = False
|
|
|
- self._cancelled_caught = False
|
|
|
- self._active = False
|
|
|
- self._timeout_handle: asyncio.TimerHandle | None = None
|
|
|
- self._cancel_handle: asyncio.Handle | None = None
|
|
|
- self._tasks: set[asyncio.Task] = set()
|
|
|
- self._host_task: asyncio.Task | None = None
|
|
|
- if sys.version_info >= (3, 11):
|
|
|
- self._pending_uncancellations: int | None = 0
|
|
|
- else:
|
|
|
- self._pending_uncancellations = None
|
|
|
-
|
|
|
- def __enter__(self) -> CancelScope:
|
|
|
- if self._active:
|
|
|
- raise RuntimeError(
|
|
|
- "Each CancelScope may only be used for a single 'with' block"
|
|
|
- )
|
|
|
-
|
|
|
- self._host_task = host_task = cast(asyncio.Task, current_task())
|
|
|
- self._tasks.add(host_task)
|
|
|
- try:
|
|
|
- task_state = _task_states[host_task]
|
|
|
- except KeyError:
|
|
|
- task_state = TaskState(None, self)
|
|
|
- _task_states[host_task] = task_state
|
|
|
- else:
|
|
|
- self._parent_scope = task_state.cancel_scope
|
|
|
- task_state.cancel_scope = self
|
|
|
- if self._parent_scope is not None:
|
|
|
- # If using an eager task factory, the parent scope may not even contain
|
|
|
- # the host task
|
|
|
- self._parent_scope._child_scopes.add(self)
|
|
|
- self._parent_scope._tasks.discard(host_task)
|
|
|
-
|
|
|
- self._timeout()
|
|
|
- self._active = True
|
|
|
-
|
|
|
- # Start cancelling the host task if the scope was cancelled before entering
|
|
|
- if self._cancel_called:
|
|
|
- self._deliver_cancellation(self)
|
|
|
-
|
|
|
- return self
|
|
|
-
|
|
|
- def __exit__(
|
|
|
- self,
|
|
|
- exc_type: type[BaseException] | None,
|
|
|
- exc_val: BaseException | None,
|
|
|
- exc_tb: TracebackType | None,
|
|
|
- ) -> bool:
|
|
|
- del exc_tb
|
|
|
-
|
|
|
- if not self._active:
|
|
|
- raise RuntimeError("This cancel scope is not active")
|
|
|
- if current_task() is not self._host_task:
|
|
|
- raise RuntimeError(
|
|
|
- "Attempted to exit cancel scope in a different task than it was "
|
|
|
- "entered in"
|
|
|
- )
|
|
|
-
|
|
|
- assert self._host_task is not None
|
|
|
- host_task_state = _task_states.get(self._host_task)
|
|
|
- if host_task_state is None or host_task_state.cancel_scope is not self:
|
|
|
- raise RuntimeError(
|
|
|
- "Attempted to exit a cancel scope that isn't the current tasks's "
|
|
|
- "current cancel scope"
|
|
|
- )
|
|
|
-
|
|
|
- try:
|
|
|
- self._active = False
|
|
|
- if self._timeout_handle:
|
|
|
- self._timeout_handle.cancel()
|
|
|
- self._timeout_handle = None
|
|
|
-
|
|
|
- self._tasks.remove(self._host_task)
|
|
|
- if self._parent_scope is not None:
|
|
|
- self._parent_scope._child_scopes.remove(self)
|
|
|
- self._parent_scope._tasks.add(self._host_task)
|
|
|
-
|
|
|
- host_task_state.cancel_scope = self._parent_scope
|
|
|
-
|
|
|
- # Restart the cancellation effort in the closest visible, cancelled parent
|
|
|
- # scope if necessary
|
|
|
- self._restart_cancellation_in_parent()
|
|
|
-
|
|
|
- # We only swallow the exception iff it was an AnyIO CancelledError, either
|
|
|
- # directly as exc_val or inside an exception group and there are no cancelled
|
|
|
- # parent cancel scopes visible to us here
|
|
|
- if self._cancel_called and not self._parent_cancellation_is_visible_to_us:
|
|
|
- # For each level-cancel() call made on the host task, call uncancel()
|
|
|
- while self._pending_uncancellations:
|
|
|
- self._host_task.uncancel()
|
|
|
- self._pending_uncancellations -= 1
|
|
|
-
|
|
|
- # Update cancelled_caught and check for exceptions we must not swallow
|
|
|
- cannot_swallow_exc_val = False
|
|
|
- if exc_val is not None:
|
|
|
- for exc in iterate_exceptions(exc_val):
|
|
|
- if isinstance(exc, CancelledError) and is_anyio_cancellation(
|
|
|
- exc
|
|
|
- ):
|
|
|
- self._cancelled_caught = True
|
|
|
- else:
|
|
|
- cannot_swallow_exc_val = True
|
|
|
-
|
|
|
- return self._cancelled_caught and not cannot_swallow_exc_val
|
|
|
- else:
|
|
|
- if self._pending_uncancellations:
|
|
|
- assert self._parent_scope is not None
|
|
|
- assert self._parent_scope._pending_uncancellations is not None
|
|
|
- self._parent_scope._pending_uncancellations += (
|
|
|
- self._pending_uncancellations
|
|
|
- )
|
|
|
- self._pending_uncancellations = 0
|
|
|
-
|
|
|
- return False
|
|
|
- finally:
|
|
|
- self._host_task = None
|
|
|
- del exc_val
|
|
|
-
|
|
|
- @property
|
|
|
- def _effectively_cancelled(self) -> bool:
|
|
|
- cancel_scope: CancelScope | None = self
|
|
|
- while cancel_scope is not None:
|
|
|
- if cancel_scope._cancel_called:
|
|
|
- return True
|
|
|
-
|
|
|
- if cancel_scope.shield:
|
|
|
- return False
|
|
|
-
|
|
|
- cancel_scope = cancel_scope._parent_scope
|
|
|
-
|
|
|
- return False
|
|
|
-
|
|
|
- @property
|
|
|
- def _parent_cancellation_is_visible_to_us(self) -> bool:
|
|
|
- return (
|
|
|
- self._parent_scope is not None
|
|
|
- and not self.shield
|
|
|
- and self._parent_scope._effectively_cancelled
|
|
|
- )
|
|
|
-
|
|
|
- def _timeout(self) -> None:
|
|
|
- if self._deadline != math.inf:
|
|
|
- loop = get_running_loop()
|
|
|
- if loop.time() >= self._deadline:
|
|
|
- self.cancel()
|
|
|
- else:
|
|
|
- self._timeout_handle = loop.call_at(self._deadline, self._timeout)
|
|
|
-
|
|
|
- def _deliver_cancellation(self, origin: CancelScope) -> bool:
|
|
|
- """
|
|
|
- Deliver cancellation to directly contained tasks and nested cancel scopes.
|
|
|
-
|
|
|
- Schedule another run at the end if we still have tasks eligible for
|
|
|
- cancellation.
|
|
|
-
|
|
|
- :param origin: the cancel scope that originated the cancellation
|
|
|
- :return: ``True`` if the delivery needs to be retried on the next cycle
|
|
|
-
|
|
|
- """
|
|
|
- should_retry = False
|
|
|
- current = current_task()
|
|
|
- for task in self._tasks:
|
|
|
- should_retry = True
|
|
|
- if task._must_cancel: # type: ignore[attr-defined]
|
|
|
- continue
|
|
|
-
|
|
|
- # The task is eligible for cancellation if it has started
|
|
|
- if task is not current and (task is self._host_task or _task_started(task)):
|
|
|
- waiter = task._fut_waiter # type: ignore[attr-defined]
|
|
|
- if not isinstance(waiter, asyncio.Future) or not waiter.done():
|
|
|
- task.cancel(f"Cancelled by cancel scope {id(origin):x}")
|
|
|
- if (
|
|
|
- task is origin._host_task
|
|
|
- and origin._pending_uncancellations is not None
|
|
|
- ):
|
|
|
- origin._pending_uncancellations += 1
|
|
|
-
|
|
|
- # Deliver cancellation to child scopes that aren't shielded or running their own
|
|
|
- # cancellation callbacks
|
|
|
- for scope in self._child_scopes:
|
|
|
- if not scope._shield and not scope.cancel_called:
|
|
|
- should_retry = scope._deliver_cancellation(origin) or should_retry
|
|
|
-
|
|
|
- # Schedule another callback if there are still tasks left
|
|
|
- if origin is self:
|
|
|
- if should_retry:
|
|
|
- self._cancel_handle = get_running_loop().call_soon(
|
|
|
- self._deliver_cancellation, origin
|
|
|
- )
|
|
|
- else:
|
|
|
- self._cancel_handle = None
|
|
|
-
|
|
|
- return should_retry
|
|
|
-
|
|
|
- def _restart_cancellation_in_parent(self) -> None:
|
|
|
- """
|
|
|
- Restart the cancellation effort in the closest directly cancelled parent scope.
|
|
|
-
|
|
|
- """
|
|
|
- scope = self._parent_scope
|
|
|
- while scope is not None:
|
|
|
- if scope._cancel_called:
|
|
|
- if scope._cancel_handle is None:
|
|
|
- scope._deliver_cancellation(scope)
|
|
|
-
|
|
|
- break
|
|
|
-
|
|
|
- # No point in looking beyond any shielded scope
|
|
|
- if scope._shield:
|
|
|
- break
|
|
|
-
|
|
|
- scope = scope._parent_scope
|
|
|
-
|
|
|
- def cancel(self) -> None:
|
|
|
- if not self._cancel_called:
|
|
|
- if self._timeout_handle:
|
|
|
- self._timeout_handle.cancel()
|
|
|
- self._timeout_handle = None
|
|
|
-
|
|
|
- self._cancel_called = True
|
|
|
- if self._host_task is not None:
|
|
|
- self._deliver_cancellation(self)
|
|
|
-
|
|
|
- @property
|
|
|
- def deadline(self) -> float:
|
|
|
- return self._deadline
|
|
|
-
|
|
|
- @deadline.setter
|
|
|
- def deadline(self, value: float) -> None:
|
|
|
- self._deadline = float(value)
|
|
|
- if self._timeout_handle is not None:
|
|
|
- self._timeout_handle.cancel()
|
|
|
- self._timeout_handle = None
|
|
|
-
|
|
|
- if self._active and not self._cancel_called:
|
|
|
- self._timeout()
|
|
|
-
|
|
|
- @property
|
|
|
- def cancel_called(self) -> bool:
|
|
|
- return self._cancel_called
|
|
|
-
|
|
|
- @property
|
|
|
- def cancelled_caught(self) -> bool:
|
|
|
- return self._cancelled_caught
|
|
|
-
|
|
|
- @property
|
|
|
- def shield(self) -> bool:
|
|
|
- return self._shield
|
|
|
-
|
|
|
- @shield.setter
|
|
|
- def shield(self, value: bool) -> None:
|
|
|
- if self._shield != value:
|
|
|
- self._shield = value
|
|
|
- if not value:
|
|
|
- self._restart_cancellation_in_parent()
|
|
|
-
|
|
|
-
|
|
|
-#
|
|
|
-# Task states
|
|
|
-#
|
|
|
-
|
|
|
-
|
|
|
-class TaskState:
|
|
|
- """
|
|
|
- Encapsulates auxiliary task information that cannot be added to the Task instance
|
|
|
- itself because there are no guarantees about its implementation.
|
|
|
- """
|
|
|
-
|
|
|
- __slots__ = "parent_id", "cancel_scope", "__weakref__"
|
|
|
-
|
|
|
- def __init__(self, parent_id: int | None, cancel_scope: CancelScope | None):
|
|
|
- self.parent_id = parent_id
|
|
|
- self.cancel_scope = cancel_scope
|
|
|
-
|
|
|
-
|
|
|
-_task_states: WeakKeyDictionary[asyncio.Task, TaskState] = WeakKeyDictionary()
|
|
|
-
|
|
|
-
|
|
|
-#
|
|
|
-# Task groups
|
|
|
-#
|
|
|
-
|
|
|
-
|
|
|
-class _AsyncioTaskStatus(abc.TaskStatus):
|
|
|
- def __init__(self, future: asyncio.Future, parent_id: int):
|
|
|
- self._future = future
|
|
|
- self._parent_id = parent_id
|
|
|
-
|
|
|
- def started(self, value: T_contra | None = None) -> None:
|
|
|
- try:
|
|
|
- self._future.set_result(value)
|
|
|
- except asyncio.InvalidStateError:
|
|
|
- if not self._future.cancelled():
|
|
|
- raise RuntimeError(
|
|
|
- "called 'started' twice on the same task status"
|
|
|
- ) from None
|
|
|
-
|
|
|
- task = cast(asyncio.Task, current_task())
|
|
|
- _task_states[task].parent_id = self._parent_id
|
|
|
-
|
|
|
-
|
|
|
-if sys.version_info >= (3, 12):
|
|
|
- _eager_task_factory_code: CodeType | None = asyncio.eager_task_factory.__code__
|
|
|
-else:
|
|
|
- _eager_task_factory_code = None
|
|
|
-
|
|
|
-
|
|
|
-class TaskGroup(abc.TaskGroup):
|
|
|
- def __init__(self) -> None:
|
|
|
- self.cancel_scope: CancelScope = CancelScope()
|
|
|
- self._active = False
|
|
|
- self._exceptions: list[BaseException] = []
|
|
|
- self._tasks: set[asyncio.Task] = set()
|
|
|
- self._on_completed_fut: asyncio.Future[None] | None = None
|
|
|
-
|
|
|
- async def __aenter__(self) -> TaskGroup:
|
|
|
- self.cancel_scope.__enter__()
|
|
|
- self._active = True
|
|
|
- return self
|
|
|
-
|
|
|
- async def __aexit__(
|
|
|
- self,
|
|
|
- exc_type: type[BaseException] | None,
|
|
|
- exc_val: BaseException | None,
|
|
|
- exc_tb: TracebackType | None,
|
|
|
- ) -> bool | None:
|
|
|
- try:
|
|
|
- if exc_val is not None:
|
|
|
- self.cancel_scope.cancel()
|
|
|
- if not isinstance(exc_val, CancelledError):
|
|
|
- self._exceptions.append(exc_val)
|
|
|
-
|
|
|
- loop = get_running_loop()
|
|
|
- try:
|
|
|
- if self._tasks:
|
|
|
- with CancelScope() as wait_scope:
|
|
|
- while self._tasks:
|
|
|
- self._on_completed_fut = loop.create_future()
|
|
|
-
|
|
|
- try:
|
|
|
- await self._on_completed_fut
|
|
|
- except CancelledError as exc:
|
|
|
- # Shield the scope against further cancellation attempts,
|
|
|
- # as they're not productive (#695)
|
|
|
- wait_scope.shield = True
|
|
|
- self.cancel_scope.cancel()
|
|
|
-
|
|
|
- # Set exc_val from the cancellation exception if it was
|
|
|
- # previously unset. However, we should not replace a native
|
|
|
- # cancellation exception with one raise by a cancel scope.
|
|
|
- if exc_val is None or (
|
|
|
- isinstance(exc_val, CancelledError)
|
|
|
- and not is_anyio_cancellation(exc)
|
|
|
- ):
|
|
|
- exc_val = exc
|
|
|
-
|
|
|
- self._on_completed_fut = None
|
|
|
- else:
|
|
|
- # If there are no child tasks to wait on, run at least one checkpoint
|
|
|
- # anyway
|
|
|
- await AsyncIOBackend.cancel_shielded_checkpoint()
|
|
|
-
|
|
|
- self._active = False
|
|
|
- if self._exceptions:
|
|
|
- raise BaseExceptionGroup(
|
|
|
- "unhandled errors in a TaskGroup", self._exceptions
|
|
|
- )
|
|
|
- elif exc_val:
|
|
|
- raise exc_val
|
|
|
- except BaseException as exc:
|
|
|
- if self.cancel_scope.__exit__(type(exc), exc, exc.__traceback__):
|
|
|
- return True
|
|
|
-
|
|
|
- raise
|
|
|
-
|
|
|
- return self.cancel_scope.__exit__(exc_type, exc_val, exc_tb)
|
|
|
- finally:
|
|
|
- del exc_val, exc_tb, self._exceptions
|
|
|
-
|
|
|
- def _spawn(
|
|
|
- self,
|
|
|
- func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
|
|
|
- args: tuple[Unpack[PosArgsT]],
|
|
|
- name: object,
|
|
|
- task_status_future: asyncio.Future | None = None,
|
|
|
- ) -> asyncio.Task:
|
|
|
- def task_done(_task: asyncio.Task) -> None:
|
|
|
- task_state = _task_states[_task]
|
|
|
- assert task_state.cancel_scope is not None
|
|
|
- assert _task in task_state.cancel_scope._tasks
|
|
|
- task_state.cancel_scope._tasks.remove(_task)
|
|
|
- self._tasks.remove(task)
|
|
|
- del _task_states[_task]
|
|
|
-
|
|
|
- if self._on_completed_fut is not None and not self._tasks:
|
|
|
- try:
|
|
|
- self._on_completed_fut.set_result(None)
|
|
|
- except asyncio.InvalidStateError:
|
|
|
- pass
|
|
|
-
|
|
|
- try:
|
|
|
- exc = _task.exception()
|
|
|
- except CancelledError as e:
|
|
|
- while isinstance(e.__context__, CancelledError):
|
|
|
- e = e.__context__
|
|
|
-
|
|
|
- exc = e
|
|
|
-
|
|
|
- if exc is not None:
|
|
|
- # The future can only be in the cancelled state if the host task was
|
|
|
- # cancelled, so return immediately instead of adding one more
|
|
|
- # CancelledError to the exceptions list
|
|
|
- if task_status_future is not None and task_status_future.cancelled():
|
|
|
- return
|
|
|
-
|
|
|
- if task_status_future is None or task_status_future.done():
|
|
|
- if not isinstance(exc, CancelledError):
|
|
|
- self._exceptions.append(exc)
|
|
|
-
|
|
|
- if not self.cancel_scope._effectively_cancelled:
|
|
|
- self.cancel_scope.cancel()
|
|
|
- else:
|
|
|
- task_status_future.set_exception(exc)
|
|
|
- elif task_status_future is not None and not task_status_future.done():
|
|
|
- task_status_future.set_exception(
|
|
|
- RuntimeError("Child exited without calling task_status.started()")
|
|
|
- )
|
|
|
-
|
|
|
- if not self._active:
|
|
|
- raise RuntimeError(
|
|
|
- "This task group is not active; no new tasks can be started."
|
|
|
- )
|
|
|
-
|
|
|
- kwargs = {}
|
|
|
- if task_status_future:
|
|
|
- parent_id = id(current_task())
|
|
|
- kwargs["task_status"] = _AsyncioTaskStatus(
|
|
|
- task_status_future, id(self.cancel_scope._host_task)
|
|
|
- )
|
|
|
- else:
|
|
|
- parent_id = id(self.cancel_scope._host_task)
|
|
|
-
|
|
|
- coro = func(*args, **kwargs)
|
|
|
- if not iscoroutine(coro):
|
|
|
- prefix = f"{func.__module__}." if hasattr(func, "__module__") else ""
|
|
|
- raise TypeError(
|
|
|
- f"Expected {prefix}{func.__qualname__}() to return a coroutine, but "
|
|
|
- f"the return value ({coro!r}) is not a coroutine object"
|
|
|
- )
|
|
|
-
|
|
|
- name = get_callable_name(func) if name is None else str(name)
|
|
|
- loop = asyncio.get_running_loop()
|
|
|
- if (
|
|
|
- (factory := loop.get_task_factory())
|
|
|
- and getattr(factory, "__code__", None) is _eager_task_factory_code
|
|
|
- and (closure := getattr(factory, "__closure__", None))
|
|
|
- ):
|
|
|
- custom_task_constructor = closure[0].cell_contents
|
|
|
- task = custom_task_constructor(coro, loop=loop, name=name)
|
|
|
- else:
|
|
|
- task = create_task(coro, name=name)
|
|
|
-
|
|
|
- # Make the spawned task inherit the task group's cancel scope
|
|
|
- _task_states[task] = TaskState(
|
|
|
- parent_id=parent_id, cancel_scope=self.cancel_scope
|
|
|
- )
|
|
|
- self.cancel_scope._tasks.add(task)
|
|
|
- self._tasks.add(task)
|
|
|
- task.add_done_callback(task_done)
|
|
|
- return task
|
|
|
-
|
|
|
- def start_soon(
|
|
|
- self,
|
|
|
- func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
|
|
|
- *args: Unpack[PosArgsT],
|
|
|
- name: object = None,
|
|
|
- ) -> None:
|
|
|
- self._spawn(func, args, name)
|
|
|
-
|
|
|
- async def start(
|
|
|
- self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
|
|
|
- ) -> Any:
|
|
|
- future: asyncio.Future = asyncio.Future()
|
|
|
- task = self._spawn(func, args, name, future)
|
|
|
-
|
|
|
- # If the task raises an exception after sending a start value without a switch
|
|
|
- # point between, the task group is cancelled and this method never proceeds to
|
|
|
- # process the completed future. That's why we have to have a shielded cancel
|
|
|
- # scope here.
|
|
|
- try:
|
|
|
- return await future
|
|
|
- except CancelledError:
|
|
|
- # Cancel the task and wait for it to exit before returning
|
|
|
- task.cancel()
|
|
|
- with CancelScope(shield=True), suppress(CancelledError):
|
|
|
- await task
|
|
|
-
|
|
|
- raise
|
|
|
-
|
|
|
-
|
|
|
-#
|
|
|
-# Threads
|
|
|
-#
|
|
|
-
|
|
|
-_Retval_Queue_Type = tuple[Optional[T_Retval], Optional[BaseException]]
|
|
|
-
|
|
|
-
|
|
|
-class WorkerThread(Thread):
|
|
|
- MAX_IDLE_TIME = 10 # seconds
|
|
|
-
|
|
|
- def __init__(
|
|
|
- self,
|
|
|
- root_task: asyncio.Task,
|
|
|
- workers: set[WorkerThread],
|
|
|
- idle_workers: deque[WorkerThread],
|
|
|
- ):
|
|
|
- super().__init__(name="AnyIO worker thread")
|
|
|
- self.root_task = root_task
|
|
|
- self.workers = workers
|
|
|
- self.idle_workers = idle_workers
|
|
|
- self.loop = root_task._loop
|
|
|
- self.queue: Queue[
|
|
|
- tuple[Context, Callable, tuple, asyncio.Future, CancelScope] | None
|
|
|
- ] = Queue(2)
|
|
|
- self.idle_since = AsyncIOBackend.current_time()
|
|
|
- self.stopping = False
|
|
|
-
|
|
|
- def _report_result(
|
|
|
- self, future: asyncio.Future, result: Any, exc: BaseException | None
|
|
|
- ) -> None:
|
|
|
- self.idle_since = AsyncIOBackend.current_time()
|
|
|
- if not self.stopping:
|
|
|
- self.idle_workers.append(self)
|
|
|
-
|
|
|
- if not future.cancelled():
|
|
|
- if exc is not None:
|
|
|
- if isinstance(exc, StopIteration):
|
|
|
- new_exc = RuntimeError("coroutine raised StopIteration")
|
|
|
- new_exc.__cause__ = exc
|
|
|
- exc = new_exc
|
|
|
-
|
|
|
- future.set_exception(exc)
|
|
|
- else:
|
|
|
- future.set_result(result)
|
|
|
-
|
|
|
- def run(self) -> None:
|
|
|
- with claim_worker_thread(AsyncIOBackend, self.loop):
|
|
|
- while True:
|
|
|
- item = self.queue.get()
|
|
|
- if item is None:
|
|
|
- # Shutdown command received
|
|
|
- return
|
|
|
-
|
|
|
- context, func, args, future, cancel_scope = item
|
|
|
- if not future.cancelled():
|
|
|
- result = None
|
|
|
- exception: BaseException | None = None
|
|
|
- threadlocals.current_cancel_scope = cancel_scope
|
|
|
- try:
|
|
|
- result = context.run(func, *args)
|
|
|
- except BaseException as exc:
|
|
|
- exception = exc
|
|
|
- finally:
|
|
|
- del threadlocals.current_cancel_scope
|
|
|
-
|
|
|
- if not self.loop.is_closed():
|
|
|
- self.loop.call_soon_threadsafe(
|
|
|
- self._report_result, future, result, exception
|
|
|
- )
|
|
|
-
|
|
|
- self.queue.task_done()
|
|
|
-
|
|
|
- def stop(self, f: asyncio.Task | None = None) -> None:
|
|
|
- self.stopping = True
|
|
|
- self.queue.put_nowait(None)
|
|
|
- self.workers.discard(self)
|
|
|
- try:
|
|
|
- self.idle_workers.remove(self)
|
|
|
- except ValueError:
|
|
|
- pass
|
|
|
-
|
|
|
-
|
|
|
-_threadpool_idle_workers: RunVar[deque[WorkerThread]] = RunVar(
|
|
|
- "_threadpool_idle_workers"
|
|
|
-)
|
|
|
-_threadpool_workers: RunVar[set[WorkerThread]] = RunVar("_threadpool_workers")
|
|
|
-
|
|
|
-
|
|
|
-class BlockingPortal(abc.BlockingPortal):
|
|
|
- def __new__(cls) -> BlockingPortal:
|
|
|
- return object.__new__(cls)
|
|
|
-
|
|
|
- def __init__(self) -> None:
|
|
|
- super().__init__()
|
|
|
- self._loop = get_running_loop()
|
|
|
-
|
|
|
- def _spawn_task_from_thread(
|
|
|
- self,
|
|
|
- func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
|
|
|
- args: tuple[Unpack[PosArgsT]],
|
|
|
- kwargs: dict[str, Any],
|
|
|
- name: object,
|
|
|
- future: Future[T_Retval],
|
|
|
- ) -> None:
|
|
|
- AsyncIOBackend.run_sync_from_thread(
|
|
|
- partial(self._task_group.start_soon, name=name),
|
|
|
- (self._call_func, func, args, kwargs, future),
|
|
|
- self._loop,
|
|
|
- )
|
|
|
-
|
|
|
-
|
|
|
-#
|
|
|
-# Subprocesses
|
|
|
-#
|
|
|
-
|
|
|
-
|
|
|
-@dataclass(eq=False)
|
|
|
-class StreamReaderWrapper(abc.ByteReceiveStream):
|
|
|
- _stream: asyncio.StreamReader
|
|
|
-
|
|
|
- async def receive(self, max_bytes: int = 65536) -> bytes:
|
|
|
- data = await self._stream.read(max_bytes)
|
|
|
- if data:
|
|
|
- return data
|
|
|
- else:
|
|
|
- raise EndOfStream
|
|
|
-
|
|
|
- async def aclose(self) -> None:
|
|
|
- self._stream.set_exception(ClosedResourceError())
|
|
|
- await AsyncIOBackend.checkpoint()
|
|
|
-
|
|
|
-
|
|
|
-@dataclass(eq=False)
|
|
|
-class StreamWriterWrapper(abc.ByteSendStream):
|
|
|
- _stream: asyncio.StreamWriter
|
|
|
-
|
|
|
- async def send(self, item: bytes) -> None:
|
|
|
- self._stream.write(item)
|
|
|
- await self._stream.drain()
|
|
|
-
|
|
|
- async def aclose(self) -> None:
|
|
|
- self._stream.close()
|
|
|
- await AsyncIOBackend.checkpoint()
|
|
|
-
|
|
|
-
|
|
|
-@dataclass(eq=False)
|
|
|
-class Process(abc.Process):
|
|
|
- _process: asyncio.subprocess.Process
|
|
|
- _stdin: StreamWriterWrapper | None
|
|
|
- _stdout: StreamReaderWrapper | None
|
|
|
- _stderr: StreamReaderWrapper | None
|
|
|
-
|
|
|
- async def aclose(self) -> None:
|
|
|
- with CancelScope(shield=True) as scope:
|
|
|
- if self._stdin:
|
|
|
- await self._stdin.aclose()
|
|
|
- if self._stdout:
|
|
|
- await self._stdout.aclose()
|
|
|
- if self._stderr:
|
|
|
- await self._stderr.aclose()
|
|
|
-
|
|
|
- scope.shield = False
|
|
|
- try:
|
|
|
- await self.wait()
|
|
|
- except BaseException:
|
|
|
- scope.shield = True
|
|
|
- self.kill()
|
|
|
- await self.wait()
|
|
|
- raise
|
|
|
-
|
|
|
- async def wait(self) -> int:
|
|
|
- return await self._process.wait()
|
|
|
-
|
|
|
- def terminate(self) -> None:
|
|
|
- self._process.terminate()
|
|
|
-
|
|
|
- def kill(self) -> None:
|
|
|
- self._process.kill()
|
|
|
-
|
|
|
- def send_signal(self, signal: int) -> None:
|
|
|
- self._process.send_signal(signal)
|
|
|
-
|
|
|
- @property
|
|
|
- def pid(self) -> int:
|
|
|
- return self._process.pid
|
|
|
-
|
|
|
- @property
|
|
|
- def returncode(self) -> int | None:
|
|
|
- return self._process.returncode
|
|
|
-
|
|
|
- @property
|
|
|
- def stdin(self) -> abc.ByteSendStream | None:
|
|
|
- return self._stdin
|
|
|
-
|
|
|
- @property
|
|
|
- def stdout(self) -> abc.ByteReceiveStream | None:
|
|
|
- return self._stdout
|
|
|
-
|
|
|
- @property
|
|
|
- def stderr(self) -> abc.ByteReceiveStream | None:
|
|
|
- return self._stderr
|
|
|
-
|
|
|
-
|
|
|
-def _forcibly_shutdown_process_pool_on_exit(
|
|
|
- workers: set[Process], _task: object
|
|
|
-) -> None:
|
|
|
- """
|
|
|
- Forcibly shuts down worker processes belonging to this event loop."""
|
|
|
- child_watcher: asyncio.AbstractChildWatcher | None = None
|
|
|
- if sys.version_info < (3, 12):
|
|
|
- try:
|
|
|
- child_watcher = asyncio.get_event_loop_policy().get_child_watcher()
|
|
|
- except NotImplementedError:
|
|
|
- pass
|
|
|
-
|
|
|
- # Close as much as possible (w/o async/await) to avoid warnings
|
|
|
- for process in workers:
|
|
|
- if process.returncode is None:
|
|
|
- continue
|
|
|
-
|
|
|
- process._stdin._stream._transport.close() # type: ignore[union-attr]
|
|
|
- process._stdout._stream._transport.close() # type: ignore[union-attr]
|
|
|
- process._stderr._stream._transport.close() # type: ignore[union-attr]
|
|
|
- process.kill()
|
|
|
- if child_watcher:
|
|
|
- child_watcher.remove_child_handler(process.pid)
|
|
|
-
|
|
|
-
|
|
|
-async def _shutdown_process_pool_on_exit(workers: set[abc.Process]) -> None:
|
|
|
- """
|
|
|
- Shuts down worker processes belonging to this event loop.
|
|
|
-
|
|
|
- NOTE: this only works when the event loop was started using asyncio.run() or
|
|
|
- anyio.run().
|
|
|
-
|
|
|
- """
|
|
|
- process: abc.Process
|
|
|
- try:
|
|
|
- await sleep(math.inf)
|
|
|
- except asyncio.CancelledError:
|
|
|
- for process in workers:
|
|
|
- if process.returncode is None:
|
|
|
- process.kill()
|
|
|
-
|
|
|
- for process in workers:
|
|
|
- await process.aclose()
|
|
|
-
|
|
|
-
|
|
|
-#
|
|
|
-# Sockets and networking
|
|
|
-#
|
|
|
-
|
|
|
-
|
|
|
-class StreamProtocol(asyncio.Protocol):
|
|
|
- read_queue: deque[bytes]
|
|
|
- read_event: asyncio.Event
|
|
|
- write_event: asyncio.Event
|
|
|
- exception: Exception | None = None
|
|
|
- is_at_eof: bool = False
|
|
|
-
|
|
|
- def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
|
|
- self.read_queue = deque()
|
|
|
- self.read_event = asyncio.Event()
|
|
|
- self.write_event = asyncio.Event()
|
|
|
- self.write_event.set()
|
|
|
- cast(asyncio.Transport, transport).set_write_buffer_limits(0)
|
|
|
-
|
|
|
- def connection_lost(self, exc: Exception | None) -> None:
|
|
|
- if exc:
|
|
|
- self.exception = BrokenResourceError()
|
|
|
- self.exception.__cause__ = exc
|
|
|
-
|
|
|
- self.read_event.set()
|
|
|
- self.write_event.set()
|
|
|
-
|
|
|
- def data_received(self, data: bytes) -> None:
|
|
|
- # ProactorEventloop sometimes sends bytearray instead of bytes
|
|
|
- self.read_queue.append(bytes(data))
|
|
|
- self.read_event.set()
|
|
|
-
|
|
|
- def eof_received(self) -> bool | None:
|
|
|
- self.is_at_eof = True
|
|
|
- self.read_event.set()
|
|
|
- return True
|
|
|
-
|
|
|
- def pause_writing(self) -> None:
|
|
|
- self.write_event = asyncio.Event()
|
|
|
-
|
|
|
- def resume_writing(self) -> None:
|
|
|
- self.write_event.set()
|
|
|
-
|
|
|
-
|
|
|
-class DatagramProtocol(asyncio.DatagramProtocol):
|
|
|
- read_queue: deque[tuple[bytes, IPSockAddrType]]
|
|
|
- read_event: asyncio.Event
|
|
|
- write_event: asyncio.Event
|
|
|
- exception: Exception | None = None
|
|
|
-
|
|
|
- def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
|
|
- self.read_queue = deque(maxlen=100) # arbitrary value
|
|
|
- self.read_event = asyncio.Event()
|
|
|
- self.write_event = asyncio.Event()
|
|
|
- self.write_event.set()
|
|
|
-
|
|
|
- def connection_lost(self, exc: Exception | None) -> None:
|
|
|
- self.read_event.set()
|
|
|
- self.write_event.set()
|
|
|
-
|
|
|
- def datagram_received(self, data: bytes, addr: IPSockAddrType) -> None:
|
|
|
- addr = convert_ipv6_sockaddr(addr)
|
|
|
- self.read_queue.append((data, addr))
|
|
|
- self.read_event.set()
|
|
|
-
|
|
|
- def error_received(self, exc: Exception) -> None:
|
|
|
- self.exception = exc
|
|
|
-
|
|
|
- def pause_writing(self) -> None:
|
|
|
- self.write_event.clear()
|
|
|
-
|
|
|
- def resume_writing(self) -> None:
|
|
|
- self.write_event.set()
|
|
|
-
|
|
|
-
|
|
|
-class SocketStream(abc.SocketStream):
|
|
|
- def __init__(self, transport: asyncio.Transport, protocol: StreamProtocol):
|
|
|
- self._transport = transport
|
|
|
- self._protocol = protocol
|
|
|
- self._receive_guard = ResourceGuard("reading from")
|
|
|
- self._send_guard = ResourceGuard("writing to")
|
|
|
- self._closed = False
|
|
|
-
|
|
|
- @property
|
|
|
- def _raw_socket(self) -> socket.socket:
|
|
|
- return self._transport.get_extra_info("socket")
|
|
|
-
|
|
|
- async def receive(self, max_bytes: int = 65536) -> bytes:
|
|
|
- with self._receive_guard:
|
|
|
- if (
|
|
|
- not self._protocol.read_event.is_set()
|
|
|
- and not self._transport.is_closing()
|
|
|
- and not self._protocol.is_at_eof
|
|
|
- ):
|
|
|
- self._transport.resume_reading()
|
|
|
- await self._protocol.read_event.wait()
|
|
|
- self._transport.pause_reading()
|
|
|
- else:
|
|
|
- await AsyncIOBackend.checkpoint()
|
|
|
-
|
|
|
- try:
|
|
|
- chunk = self._protocol.read_queue.popleft()
|
|
|
- except IndexError:
|
|
|
- if self._closed:
|
|
|
- raise ClosedResourceError from None
|
|
|
- elif self._protocol.exception:
|
|
|
- raise self._protocol.exception from None
|
|
|
- else:
|
|
|
- raise EndOfStream from None
|
|
|
-
|
|
|
- if len(chunk) > max_bytes:
|
|
|
- # Split the oversized chunk
|
|
|
- chunk, leftover = chunk[:max_bytes], chunk[max_bytes:]
|
|
|
- self._protocol.read_queue.appendleft(leftover)
|
|
|
-
|
|
|
- # If the read queue is empty, clear the flag so that the next call will
|
|
|
- # block until data is available
|
|
|
- if not self._protocol.read_queue:
|
|
|
- self._protocol.read_event.clear()
|
|
|
-
|
|
|
- return chunk
|
|
|
-
|
|
|
- async def send(self, item: bytes) -> None:
|
|
|
- with self._send_guard:
|
|
|
- await AsyncIOBackend.checkpoint()
|
|
|
-
|
|
|
- if self._closed:
|
|
|
- raise ClosedResourceError
|
|
|
- elif self._protocol.exception is not None:
|
|
|
- raise self._protocol.exception
|
|
|
-
|
|
|
- try:
|
|
|
- self._transport.write(item)
|
|
|
- except RuntimeError as exc:
|
|
|
- if self._transport.is_closing():
|
|
|
- raise BrokenResourceError from exc
|
|
|
- else:
|
|
|
- raise
|
|
|
-
|
|
|
- await self._protocol.write_event.wait()
|
|
|
-
|
|
|
- async def send_eof(self) -> None:
|
|
|
- try:
|
|
|
- self._transport.write_eof()
|
|
|
- except OSError:
|
|
|
- pass
|
|
|
-
|
|
|
- async def aclose(self) -> None:
|
|
|
- if not self._transport.is_closing():
|
|
|
- self._closed = True
|
|
|
- try:
|
|
|
- self._transport.write_eof()
|
|
|
- except OSError:
|
|
|
- pass
|
|
|
-
|
|
|
- self._transport.close()
|
|
|
- await sleep(0)
|
|
|
- self._transport.abort()
|
|
|
-
|
|
|
-
|
|
|
-class _RawSocketMixin:
|
|
|
- _receive_future: asyncio.Future | None = None
|
|
|
- _send_future: asyncio.Future | None = None
|
|
|
- _closing = False
|
|
|
-
|
|
|
- def __init__(self, raw_socket: socket.socket):
|
|
|
- self.__raw_socket = raw_socket
|
|
|
- self._receive_guard = ResourceGuard("reading from")
|
|
|
- self._send_guard = ResourceGuard("writing to")
|
|
|
-
|
|
|
- @property
|
|
|
- def _raw_socket(self) -> socket.socket:
|
|
|
- return self.__raw_socket
|
|
|
-
|
|
|
- def _wait_until_readable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
|
|
|
- def callback(f: object) -> None:
|
|
|
- del self._receive_future
|
|
|
- loop.remove_reader(self.__raw_socket)
|
|
|
-
|
|
|
- f = self._receive_future = asyncio.Future()
|
|
|
- loop.add_reader(self.__raw_socket, f.set_result, None)
|
|
|
- f.add_done_callback(callback)
|
|
|
- return f
|
|
|
-
|
|
|
- def _wait_until_writable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
|
|
|
- def callback(f: object) -> None:
|
|
|
- del self._send_future
|
|
|
- loop.remove_writer(self.__raw_socket)
|
|
|
-
|
|
|
- f = self._send_future = asyncio.Future()
|
|
|
- loop.add_writer(self.__raw_socket, f.set_result, None)
|
|
|
- f.add_done_callback(callback)
|
|
|
- return f
|
|
|
-
|
|
|
- async def aclose(self) -> None:
|
|
|
- if not self._closing:
|
|
|
- self._closing = True
|
|
|
- if self.__raw_socket.fileno() != -1:
|
|
|
- self.__raw_socket.close()
|
|
|
-
|
|
|
- if self._receive_future:
|
|
|
- self._receive_future.set_result(None)
|
|
|
- if self._send_future:
|
|
|
- self._send_future.set_result(None)
|
|
|
-
|
|
|
-
|
|
|
-class UNIXSocketStream(_RawSocketMixin, abc.UNIXSocketStream):
|
|
|
- async def send_eof(self) -> None:
|
|
|
- with self._send_guard:
|
|
|
- self._raw_socket.shutdown(socket.SHUT_WR)
|
|
|
-
|
|
|
- async def receive(self, max_bytes: int = 65536) -> bytes:
|
|
|
- loop = get_running_loop()
|
|
|
- await AsyncIOBackend.checkpoint()
|
|
|
- with self._receive_guard:
|
|
|
- while True:
|
|
|
- try:
|
|
|
- data = self._raw_socket.recv(max_bytes)
|
|
|
- except BlockingIOError:
|
|
|
- await self._wait_until_readable(loop)
|
|
|
- except OSError as exc:
|
|
|
- if self._closing:
|
|
|
- raise ClosedResourceError from None
|
|
|
- else:
|
|
|
- raise BrokenResourceError from exc
|
|
|
- else:
|
|
|
- if not data:
|
|
|
- raise EndOfStream
|
|
|
-
|
|
|
- return data
|
|
|
-
|
|
|
- async def send(self, item: bytes) -> None:
|
|
|
- loop = get_running_loop()
|
|
|
- await AsyncIOBackend.checkpoint()
|
|
|
- with self._send_guard:
|
|
|
- view = memoryview(item)
|
|
|
- while view:
|
|
|
- try:
|
|
|
- bytes_sent = self._raw_socket.send(view)
|
|
|
- except BlockingIOError:
|
|
|
- await self._wait_until_writable(loop)
|
|
|
- except OSError as exc:
|
|
|
- if self._closing:
|
|
|
- raise ClosedResourceError from None
|
|
|
- else:
|
|
|
- raise BrokenResourceError from exc
|
|
|
- else:
|
|
|
- view = view[bytes_sent:]
|
|
|
-
|
|
|
- async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]:
|
|
|
- if not isinstance(msglen, int) or msglen < 0:
|
|
|
- raise ValueError("msglen must be a non-negative integer")
|
|
|
- if not isinstance(maxfds, int) or maxfds < 1:
|
|
|
- raise ValueError("maxfds must be a positive integer")
|
|
|
-
|
|
|
- loop = get_running_loop()
|
|
|
- fds = array.array("i")
|
|
|
- await AsyncIOBackend.checkpoint()
|
|
|
- with self._receive_guard:
|
|
|
- while True:
|
|
|
- try:
|
|
|
- message, ancdata, flags, addr = self._raw_socket.recvmsg(
|
|
|
- msglen, socket.CMSG_LEN(maxfds * fds.itemsize)
|
|
|
- )
|
|
|
- except BlockingIOError:
|
|
|
- await self._wait_until_readable(loop)
|
|
|
- except OSError as exc:
|
|
|
- if self._closing:
|
|
|
- raise ClosedResourceError from None
|
|
|
- else:
|
|
|
- raise BrokenResourceError from exc
|
|
|
- else:
|
|
|
- if not message and not ancdata:
|
|
|
- raise EndOfStream
|
|
|
-
|
|
|
- break
|
|
|
-
|
|
|
- for cmsg_level, cmsg_type, cmsg_data in ancdata:
|
|
|
- if cmsg_level != socket.SOL_SOCKET or cmsg_type != socket.SCM_RIGHTS:
|
|
|
- raise RuntimeError(
|
|
|
- f"Received unexpected ancillary data; message = {message!r}, "
|
|
|
- f"cmsg_level = {cmsg_level}, cmsg_type = {cmsg_type}"
|
|
|
- )
|
|
|
-
|
|
|
- fds.frombytes(cmsg_data[: len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])
|
|
|
-
|
|
|
- return message, list(fds)
|
|
|
-
|
|
|
- async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None:
|
|
|
- if not message:
|
|
|
- raise ValueError("message must not be empty")
|
|
|
- if not fds:
|
|
|
- raise ValueError("fds must not be empty")
|
|
|
-
|
|
|
- loop = get_running_loop()
|
|
|
- filenos: list[int] = []
|
|
|
- for fd in fds:
|
|
|
- if isinstance(fd, int):
|
|
|
- filenos.append(fd)
|
|
|
- elif isinstance(fd, IOBase):
|
|
|
- filenos.append(fd.fileno())
|
|
|
-
|
|
|
- fdarray = array.array("i", filenos)
|
|
|
- await AsyncIOBackend.checkpoint()
|
|
|
- with self._send_guard:
|
|
|
- while True:
|
|
|
- try:
|
|
|
- # The ignore can be removed after mypy picks up
|
|
|
- # https://github.com/python/typeshed/pull/5545
|
|
|
- self._raw_socket.sendmsg(
|
|
|
- [message], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fdarray)]
|
|
|
- )
|
|
|
- break
|
|
|
- except BlockingIOError:
|
|
|
- await self._wait_until_writable(loop)
|
|
|
- except OSError as exc:
|
|
|
- if self._closing:
|
|
|
- raise ClosedResourceError from None
|
|
|
- else:
|
|
|
- raise BrokenResourceError from exc
|
|
|
-
|
|
|
-
|
|
|
-class TCPSocketListener(abc.SocketListener):
|
|
|
- _accept_scope: CancelScope | None = None
|
|
|
- _closed = False
|
|
|
-
|
|
|
- def __init__(self, raw_socket: socket.socket):
|
|
|
- self.__raw_socket = raw_socket
|
|
|
- self._loop = cast(asyncio.BaseEventLoop, get_running_loop())
|
|
|
- self._accept_guard = ResourceGuard("accepting connections from")
|
|
|
-
|
|
|
- @property
|
|
|
- def _raw_socket(self) -> socket.socket:
|
|
|
- return self.__raw_socket
|
|
|
-
|
|
|
- async def accept(self) -> abc.SocketStream:
|
|
|
- if self._closed:
|
|
|
- raise ClosedResourceError
|
|
|
-
|
|
|
- with self._accept_guard:
|
|
|
- await AsyncIOBackend.checkpoint()
|
|
|
- with CancelScope() as self._accept_scope:
|
|
|
- try:
|
|
|
- client_sock, _addr = await self._loop.sock_accept(self._raw_socket)
|
|
|
- except asyncio.CancelledError:
|
|
|
- # Workaround for https://bugs.python.org/issue41317
|
|
|
- try:
|
|
|
- self._loop.remove_reader(self._raw_socket)
|
|
|
- except (ValueError, NotImplementedError):
|
|
|
- pass
|
|
|
-
|
|
|
- if self._closed:
|
|
|
- raise ClosedResourceError from None
|
|
|
-
|
|
|
- raise
|
|
|
- finally:
|
|
|
- self._accept_scope = None
|
|
|
-
|
|
|
- client_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
|
|
- transport, protocol = await self._loop.connect_accepted_socket(
|
|
|
- StreamProtocol, client_sock
|
|
|
- )
|
|
|
- return SocketStream(transport, protocol)
|
|
|
-
|
|
|
- async def aclose(self) -> None:
|
|
|
- if self._closed:
|
|
|
- return
|
|
|
-
|
|
|
- self._closed = True
|
|
|
- if self._accept_scope:
|
|
|
- # Workaround for https://bugs.python.org/issue41317
|
|
|
- try:
|
|
|
- self._loop.remove_reader(self._raw_socket)
|
|
|
- except (ValueError, NotImplementedError):
|
|
|
- pass
|
|
|
-
|
|
|
- self._accept_scope.cancel()
|
|
|
- await sleep(0)
|
|
|
-
|
|
|
- self._raw_socket.close()
|
|
|
-
|
|
|
-
|
|
|
-class UNIXSocketListener(abc.SocketListener):
|
|
|
- def __init__(self, raw_socket: socket.socket):
|
|
|
- self.__raw_socket = raw_socket
|
|
|
- self._loop = get_running_loop()
|
|
|
- self._accept_guard = ResourceGuard("accepting connections from")
|
|
|
- self._closed = False
|
|
|
-
|
|
|
- async def accept(self) -> abc.SocketStream:
|
|
|
- await AsyncIOBackend.checkpoint()
|
|
|
- with self._accept_guard:
|
|
|
- while True:
|
|
|
- try:
|
|
|
- client_sock, _ = self.__raw_socket.accept()
|
|
|
- client_sock.setblocking(False)
|
|
|
- return UNIXSocketStream(client_sock)
|
|
|
- except BlockingIOError:
|
|
|
- f: asyncio.Future = asyncio.Future()
|
|
|
- self._loop.add_reader(self.__raw_socket, f.set_result, None)
|
|
|
- f.add_done_callback(
|
|
|
- lambda _: self._loop.remove_reader(self.__raw_socket)
|
|
|
- )
|
|
|
- await f
|
|
|
- except OSError as exc:
|
|
|
- if self._closed:
|
|
|
- raise ClosedResourceError from None
|
|
|
- else:
|
|
|
- raise BrokenResourceError from exc
|
|
|
-
|
|
|
- async def aclose(self) -> None:
|
|
|
- self._closed = True
|
|
|
- self.__raw_socket.close()
|
|
|
-
|
|
|
- @property
|
|
|
- def _raw_socket(self) -> socket.socket:
|
|
|
- return self.__raw_socket
|
|
|
-
|
|
|
-
|
|
|
-class UDPSocket(abc.UDPSocket):
|
|
|
- def __init__(
|
|
|
- self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
|
|
|
- ):
|
|
|
- self._transport = transport
|
|
|
- self._protocol = protocol
|
|
|
- self._receive_guard = ResourceGuard("reading from")
|
|
|
- self._send_guard = ResourceGuard("writing to")
|
|
|
- self._closed = False
|
|
|
-
|
|
|
- @property
|
|
|
- def _raw_socket(self) -> socket.socket:
|
|
|
- return self._transport.get_extra_info("socket")
|
|
|
-
|
|
|
- async def aclose(self) -> None:
|
|
|
- if not self._transport.is_closing():
|
|
|
- self._closed = True
|
|
|
- self._transport.close()
|
|
|
-
|
|
|
- async def receive(self) -> tuple[bytes, IPSockAddrType]:
|
|
|
- with self._receive_guard:
|
|
|
- await AsyncIOBackend.checkpoint()
|
|
|
-
|
|
|
- # If the buffer is empty, ask for more data
|
|
|
- if not self._protocol.read_queue and not self._transport.is_closing():
|
|
|
- self._protocol.read_event.clear()
|
|
|
- await self._protocol.read_event.wait()
|
|
|
-
|
|
|
- try:
|
|
|
- return self._protocol.read_queue.popleft()
|
|
|
- except IndexError:
|
|
|
- if self._closed:
|
|
|
- raise ClosedResourceError from None
|
|
|
- else:
|
|
|
- raise BrokenResourceError from None
|
|
|
-
|
|
|
- async def send(self, item: UDPPacketType) -> None:
|
|
|
- with self._send_guard:
|
|
|
- await AsyncIOBackend.checkpoint()
|
|
|
- await self._protocol.write_event.wait()
|
|
|
- if self._closed:
|
|
|
- raise ClosedResourceError
|
|
|
- elif self._transport.is_closing():
|
|
|
- raise BrokenResourceError
|
|
|
- else:
|
|
|
- self._transport.sendto(*item)
|
|
|
-
|
|
|
-
|
|
|
-class ConnectedUDPSocket(abc.ConnectedUDPSocket):
|
|
|
- def __init__(
|
|
|
- self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
|
|
|
- ):
|
|
|
- self._transport = transport
|
|
|
- self._protocol = protocol
|
|
|
- self._receive_guard = ResourceGuard("reading from")
|
|
|
- self._send_guard = ResourceGuard("writing to")
|
|
|
- self._closed = False
|
|
|
-
|
|
|
- @property
|
|
|
- def _raw_socket(self) -> socket.socket:
|
|
|
- return self._transport.get_extra_info("socket")
|
|
|
-
|
|
|
- async def aclose(self) -> None:
|
|
|
- if not self._transport.is_closing():
|
|
|
- self._closed = True
|
|
|
- self._transport.close()
|
|
|
-
|
|
|
- async def receive(self) -> bytes:
|
|
|
- with self._receive_guard:
|
|
|
- await AsyncIOBackend.checkpoint()
|
|
|
-
|
|
|
- # If the buffer is empty, ask for more data
|
|
|
- if not self._protocol.read_queue and not self._transport.is_closing():
|
|
|
- self._protocol.read_event.clear()
|
|
|
- await self._protocol.read_event.wait()
|
|
|
-
|
|
|
- try:
|
|
|
- packet = self._protocol.read_queue.popleft()
|
|
|
- except IndexError:
|
|
|
- if self._closed:
|
|
|
- raise ClosedResourceError from None
|
|
|
- else:
|
|
|
- raise BrokenResourceError from None
|
|
|
-
|
|
|
- return packet[0]
|
|
|
-
|
|
|
- async def send(self, item: bytes) -> None:
|
|
|
- with self._send_guard:
|
|
|
- await AsyncIOBackend.checkpoint()
|
|
|
- await self._protocol.write_event.wait()
|
|
|
- if self._closed:
|
|
|
- raise ClosedResourceError
|
|
|
- elif self._transport.is_closing():
|
|
|
- raise BrokenResourceError
|
|
|
- else:
|
|
|
- self._transport.sendto(item)
|
|
|
-
|
|
|
-
|
|
|
-class UNIXDatagramSocket(_RawSocketMixin, abc.UNIXDatagramSocket):
|
|
|
- async def receive(self) -> UNIXDatagramPacketType:
|
|
|
- loop = get_running_loop()
|
|
|
- await AsyncIOBackend.checkpoint()
|
|
|
- with self._receive_guard:
|
|
|
- while True:
|
|
|
- try:
|
|
|
- data = self._raw_socket.recvfrom(65536)
|
|
|
- except BlockingIOError:
|
|
|
- await self._wait_until_readable(loop)
|
|
|
- except OSError as exc:
|
|
|
- if self._closing:
|
|
|
- raise ClosedResourceError from None
|
|
|
- else:
|
|
|
- raise BrokenResourceError from exc
|
|
|
- else:
|
|
|
- return data
|
|
|
-
|
|
|
- async def send(self, item: UNIXDatagramPacketType) -> None:
|
|
|
- loop = get_running_loop()
|
|
|
- await AsyncIOBackend.checkpoint()
|
|
|
- with self._send_guard:
|
|
|
- while True:
|
|
|
- try:
|
|
|
- self._raw_socket.sendto(*item)
|
|
|
- except BlockingIOError:
|
|
|
- await self._wait_until_writable(loop)
|
|
|
- except OSError as exc:
|
|
|
- if self._closing:
|
|
|
- raise ClosedResourceError from None
|
|
|
- else:
|
|
|
- raise BrokenResourceError from exc
|
|
|
- else:
|
|
|
- return
|
|
|
-
|
|
|
-
|
|
|
-class ConnectedUNIXDatagramSocket(_RawSocketMixin, abc.ConnectedUNIXDatagramSocket):
|
|
|
- async def receive(self) -> bytes:
|
|
|
- loop = get_running_loop()
|
|
|
- await AsyncIOBackend.checkpoint()
|
|
|
- with self._receive_guard:
|
|
|
- while True:
|
|
|
- try:
|
|
|
- data = self._raw_socket.recv(65536)
|
|
|
- except BlockingIOError:
|
|
|
- await self._wait_until_readable(loop)
|
|
|
- except OSError as exc:
|
|
|
- if self._closing:
|
|
|
- raise ClosedResourceError from None
|
|
|
- else:
|
|
|
- raise BrokenResourceError from exc
|
|
|
- else:
|
|
|
- return data
|
|
|
-
|
|
|
- async def send(self, item: bytes) -> None:
|
|
|
- loop = get_running_loop()
|
|
|
- await AsyncIOBackend.checkpoint()
|
|
|
- with self._send_guard:
|
|
|
- while True:
|
|
|
- try:
|
|
|
- self._raw_socket.send(item)
|
|
|
- except BlockingIOError:
|
|
|
- await self._wait_until_writable(loop)
|
|
|
- except OSError as exc:
|
|
|
- if self._closing:
|
|
|
- raise ClosedResourceError from None
|
|
|
- else:
|
|
|
- raise BrokenResourceError from exc
|
|
|
- else:
|
|
|
- return
|
|
|
-
|
|
|
-
|
|
|
-_read_events: RunVar[dict[int, asyncio.Event]] = RunVar("read_events")
|
|
|
-_write_events: RunVar[dict[int, asyncio.Event]] = RunVar("write_events")
|
|
|
-
|
|
|
-
|
|
|
-#
|
|
|
-# Synchronization
|
|
|
-#
|
|
|
-
|
|
|
-
|
|
|
-class Event(BaseEvent):
|
|
|
- def __new__(cls) -> Event:
|
|
|
- return object.__new__(cls)
|
|
|
-
|
|
|
- def __init__(self) -> None:
|
|
|
- self._event = asyncio.Event()
|
|
|
-
|
|
|
- def set(self) -> None:
|
|
|
- self._event.set()
|
|
|
-
|
|
|
- def is_set(self) -> bool:
|
|
|
- return self._event.is_set()
|
|
|
-
|
|
|
- async def wait(self) -> None:
|
|
|
- if self.is_set():
|
|
|
- await AsyncIOBackend.checkpoint()
|
|
|
- else:
|
|
|
- await self._event.wait()
|
|
|
-
|
|
|
- def statistics(self) -> EventStatistics:
|
|
|
- return EventStatistics(len(self._event._waiters))
|
|
|
-
|
|
|
-
|
|
|
-class Lock(BaseLock):
|
|
|
- def __new__(cls, *, fast_acquire: bool = False) -> Lock:
|
|
|
- return object.__new__(cls)
|
|
|
-
|
|
|
- def __init__(self, *, fast_acquire: bool = False) -> None:
|
|
|
- self._fast_acquire = fast_acquire
|
|
|
- self._owner_task: asyncio.Task | None = None
|
|
|
- self._waiters: deque[tuple[asyncio.Task, asyncio.Future]] = deque()
|
|
|
-
|
|
|
- async def acquire(self) -> None:
|
|
|
- task = cast(asyncio.Task, current_task())
|
|
|
- if self._owner_task is None and not self._waiters:
|
|
|
- await AsyncIOBackend.checkpoint_if_cancelled()
|
|
|
- self._owner_task = task
|
|
|
-
|
|
|
- # Unless on the "fast path", yield control of the event loop so that other
|
|
|
- # tasks can run too
|
|
|
- if not self._fast_acquire:
|
|
|
- try:
|
|
|
- await AsyncIOBackend.cancel_shielded_checkpoint()
|
|
|
- except CancelledError:
|
|
|
- self.release()
|
|
|
- raise
|
|
|
-
|
|
|
- return
|
|
|
-
|
|
|
- if self._owner_task == task:
|
|
|
- raise RuntimeError("Attempted to acquire an already held Lock")
|
|
|
-
|
|
|
- fut: asyncio.Future[None] = asyncio.Future()
|
|
|
- item = task, fut
|
|
|
- self._waiters.append(item)
|
|
|
- try:
|
|
|
- await fut
|
|
|
- except CancelledError:
|
|
|
- self._waiters.remove(item)
|
|
|
- if self._owner_task is task:
|
|
|
- self.release()
|
|
|
-
|
|
|
- raise
|
|
|
-
|
|
|
- self._waiters.remove(item)
|
|
|
-
|
|
|
- def acquire_nowait(self) -> None:
|
|
|
- task = cast(asyncio.Task, current_task())
|
|
|
- if self._owner_task is None and not self._waiters:
|
|
|
- self._owner_task = task
|
|
|
- return
|
|
|
-
|
|
|
- if self._owner_task is task:
|
|
|
- raise RuntimeError("Attempted to acquire an already held Lock")
|
|
|
-
|
|
|
- raise WouldBlock
|
|
|
-
|
|
|
- def locked(self) -> bool:
|
|
|
- return self._owner_task is not None
|
|
|
-
|
|
|
- def release(self) -> None:
|
|
|
- if self._owner_task != current_task():
|
|
|
- raise RuntimeError("The current task is not holding this lock")
|
|
|
-
|
|
|
- for task, fut in self._waiters:
|
|
|
- if not fut.cancelled():
|
|
|
- self._owner_task = task
|
|
|
- fut.set_result(None)
|
|
|
- return
|
|
|
-
|
|
|
- self._owner_task = None
|
|
|
-
|
|
|
- def statistics(self) -> LockStatistics:
|
|
|
- task_info = AsyncIOTaskInfo(self._owner_task) if self._owner_task else None
|
|
|
- return LockStatistics(self.locked(), task_info, len(self._waiters))
|
|
|
-
|
|
|
-
|
|
|
-class Semaphore(BaseSemaphore):
|
|
|
- def __new__(
|
|
|
- cls,
|
|
|
- initial_value: int,
|
|
|
- *,
|
|
|
- max_value: int | None = None,
|
|
|
- fast_acquire: bool = False,
|
|
|
- ) -> Semaphore:
|
|
|
- return object.__new__(cls)
|
|
|
-
|
|
|
- def __init__(
|
|
|
- self,
|
|
|
- initial_value: int,
|
|
|
- *,
|
|
|
- max_value: int | None = None,
|
|
|
- fast_acquire: bool = False,
|
|
|
- ):
|
|
|
- super().__init__(initial_value, max_value=max_value)
|
|
|
- self._value = initial_value
|
|
|
- self._max_value = max_value
|
|
|
- self._fast_acquire = fast_acquire
|
|
|
- self._waiters: deque[asyncio.Future[None]] = deque()
|
|
|
-
|
|
|
- async def acquire(self) -> None:
|
|
|
- if self._value > 0 and not self._waiters:
|
|
|
- await AsyncIOBackend.checkpoint_if_cancelled()
|
|
|
- self._value -= 1
|
|
|
-
|
|
|
- # Unless on the "fast path", yield control of the event loop so that other
|
|
|
- # tasks can run too
|
|
|
- if not self._fast_acquire:
|
|
|
- try:
|
|
|
- await AsyncIOBackend.cancel_shielded_checkpoint()
|
|
|
- except CancelledError:
|
|
|
- self.release()
|
|
|
- raise
|
|
|
-
|
|
|
- return
|
|
|
-
|
|
|
- fut: asyncio.Future[None] = asyncio.Future()
|
|
|
- self._waiters.append(fut)
|
|
|
- try:
|
|
|
- await fut
|
|
|
- except CancelledError:
|
|
|
- try:
|
|
|
- self._waiters.remove(fut)
|
|
|
- except ValueError:
|
|
|
- self.release()
|
|
|
-
|
|
|
- raise
|
|
|
-
|
|
|
- def acquire_nowait(self) -> None:
|
|
|
- if self._value == 0:
|
|
|
- raise WouldBlock
|
|
|
-
|
|
|
- self._value -= 1
|
|
|
-
|
|
|
- def release(self) -> None:
|
|
|
- if self._max_value is not None and self._value == self._max_value:
|
|
|
- raise ValueError("semaphore released too many times")
|
|
|
-
|
|
|
- for fut in self._waiters:
|
|
|
- if not fut.cancelled():
|
|
|
- fut.set_result(None)
|
|
|
- self._waiters.remove(fut)
|
|
|
- return
|
|
|
-
|
|
|
- self._value += 1
|
|
|
-
|
|
|
- @property
|
|
|
- def value(self) -> int:
|
|
|
- return self._value
|
|
|
-
|
|
|
- @property
|
|
|
- def max_value(self) -> int | None:
|
|
|
- return self._max_value
|
|
|
-
|
|
|
- def statistics(self) -> SemaphoreStatistics:
|
|
|
- return SemaphoreStatistics(len(self._waiters))
|
|
|
-
|
|
|
-
|
|
|
-class CapacityLimiter(BaseCapacityLimiter):
|
|
|
- _total_tokens: float = 0
|
|
|
-
|
|
|
- def __new__(cls, total_tokens: float) -> CapacityLimiter:
|
|
|
- return object.__new__(cls)
|
|
|
-
|
|
|
- def __init__(self, total_tokens: float):
|
|
|
- self._borrowers: set[Any] = set()
|
|
|
- self._wait_queue: OrderedDict[Any, asyncio.Event] = OrderedDict()
|
|
|
- self.total_tokens = total_tokens
|
|
|
-
|
|
|
- async def __aenter__(self) -> None:
|
|
|
- await self.acquire()
|
|
|
-
|
|
|
- async def __aexit__(
|
|
|
- self,
|
|
|
- exc_type: type[BaseException] | None,
|
|
|
- exc_val: BaseException | None,
|
|
|
- exc_tb: TracebackType | None,
|
|
|
- ) -> None:
|
|
|
- self.release()
|
|
|
-
|
|
|
- @property
|
|
|
- def total_tokens(self) -> float:
|
|
|
- return self._total_tokens
|
|
|
-
|
|
|
- @total_tokens.setter
|
|
|
- def total_tokens(self, value: float) -> None:
|
|
|
- if not isinstance(value, int) and not math.isinf(value):
|
|
|
- raise TypeError("total_tokens must be an int or math.inf")
|
|
|
- if value < 1:
|
|
|
- raise ValueError("total_tokens must be >= 1")
|
|
|
-
|
|
|
- waiters_to_notify = max(value - self._total_tokens, 0)
|
|
|
- self._total_tokens = value
|
|
|
-
|
|
|
- # Notify waiting tasks that they have acquired the limiter
|
|
|
- while self._wait_queue and waiters_to_notify:
|
|
|
- event = self._wait_queue.popitem(last=False)[1]
|
|
|
- event.set()
|
|
|
- waiters_to_notify -= 1
|
|
|
-
|
|
|
- @property
|
|
|
- def borrowed_tokens(self) -> int:
|
|
|
- return len(self._borrowers)
|
|
|
-
|
|
|
- @property
|
|
|
- def available_tokens(self) -> float:
|
|
|
- return self._total_tokens - len(self._borrowers)
|
|
|
-
|
|
|
- def acquire_nowait(self) -> None:
|
|
|
- self.acquire_on_behalf_of_nowait(current_task())
|
|
|
-
|
|
|
- def acquire_on_behalf_of_nowait(self, borrower: object) -> None:
|
|
|
- if borrower in self._borrowers:
|
|
|
- raise RuntimeError(
|
|
|
- "this borrower is already holding one of this CapacityLimiter's "
|
|
|
- "tokens"
|
|
|
- )
|
|
|
-
|
|
|
- if self._wait_queue or len(self._borrowers) >= self._total_tokens:
|
|
|
- raise WouldBlock
|
|
|
-
|
|
|
- self._borrowers.add(borrower)
|
|
|
-
|
|
|
- async def acquire(self) -> None:
|
|
|
- return await self.acquire_on_behalf_of(current_task())
|
|
|
-
|
|
|
- async def acquire_on_behalf_of(self, borrower: object) -> None:
|
|
|
- await AsyncIOBackend.checkpoint_if_cancelled()
|
|
|
- try:
|
|
|
- self.acquire_on_behalf_of_nowait(borrower)
|
|
|
- except WouldBlock:
|
|
|
- event = asyncio.Event()
|
|
|
- self._wait_queue[borrower] = event
|
|
|
- try:
|
|
|
- await event.wait()
|
|
|
- except BaseException:
|
|
|
- self._wait_queue.pop(borrower, None)
|
|
|
- raise
|
|
|
-
|
|
|
- self._borrowers.add(borrower)
|
|
|
- else:
|
|
|
- try:
|
|
|
- await AsyncIOBackend.cancel_shielded_checkpoint()
|
|
|
- except BaseException:
|
|
|
- self.release()
|
|
|
- raise
|
|
|
-
|
|
|
- def release(self) -> None:
|
|
|
- self.release_on_behalf_of(current_task())
|
|
|
-
|
|
|
- def release_on_behalf_of(self, borrower: object) -> None:
|
|
|
- try:
|
|
|
- self._borrowers.remove(borrower)
|
|
|
- except KeyError:
|
|
|
- raise RuntimeError(
|
|
|
- "this borrower isn't holding any of this CapacityLimiter's tokens"
|
|
|
- ) from None
|
|
|
-
|
|
|
- # Notify the next task in line if this limiter has free capacity now
|
|
|
- if self._wait_queue and len(self._borrowers) < self._total_tokens:
|
|
|
- event = self._wait_queue.popitem(last=False)[1]
|
|
|
- event.set()
|
|
|
-
|
|
|
- def statistics(self) -> CapacityLimiterStatistics:
|
|
|
- return CapacityLimiterStatistics(
|
|
|
- self.borrowed_tokens,
|
|
|
- self.total_tokens,
|
|
|
- tuple(self._borrowers),
|
|
|
- len(self._wait_queue),
|
|
|
- )
|
|
|
-
|
|
|
-
|
|
|
-_default_thread_limiter: RunVar[CapacityLimiter] = RunVar("_default_thread_limiter")
|
|
|
-
|
|
|
-
|
|
|
-#
|
|
|
-# Operating system signals
|
|
|
-#
|
|
|
-
|
|
|
-
|
|
|
-class _SignalReceiver:
|
|
|
- def __init__(self, signals: tuple[Signals, ...]):
|
|
|
- self._signals = signals
|
|
|
- self._loop = get_running_loop()
|
|
|
- self._signal_queue: deque[Signals] = deque()
|
|
|
- self._future: asyncio.Future = asyncio.Future()
|
|
|
- self._handled_signals: set[Signals] = set()
|
|
|
-
|
|
|
- def _deliver(self, signum: Signals) -> None:
|
|
|
- self._signal_queue.append(signum)
|
|
|
- if not self._future.done():
|
|
|
- self._future.set_result(None)
|
|
|
-
|
|
|
- def __enter__(self) -> _SignalReceiver:
|
|
|
- for sig in set(self._signals):
|
|
|
- self._loop.add_signal_handler(sig, self._deliver, sig)
|
|
|
- self._handled_signals.add(sig)
|
|
|
-
|
|
|
- return self
|
|
|
-
|
|
|
- def __exit__(
|
|
|
- self,
|
|
|
- exc_type: type[BaseException] | None,
|
|
|
- exc_val: BaseException | None,
|
|
|
- exc_tb: TracebackType | None,
|
|
|
- ) -> None:
|
|
|
- for sig in self._handled_signals:
|
|
|
- self._loop.remove_signal_handler(sig)
|
|
|
-
|
|
|
- def __aiter__(self) -> _SignalReceiver:
|
|
|
- return self
|
|
|
-
|
|
|
- async def __anext__(self) -> Signals:
|
|
|
- await AsyncIOBackend.checkpoint()
|
|
|
- if not self._signal_queue:
|
|
|
- self._future = asyncio.Future()
|
|
|
- await self._future
|
|
|
-
|
|
|
- return self._signal_queue.popleft()
|
|
|
-
|
|
|
-
|
|
|
-#
|
|
|
-# Testing and debugging
|
|
|
-#
|
|
|
-
|
|
|
-
|
|
|
-class AsyncIOTaskInfo(TaskInfo):
|
|
|
- def __init__(self, task: asyncio.Task):
|
|
|
- task_state = _task_states.get(task)
|
|
|
- if task_state is None:
|
|
|
- parent_id = None
|
|
|
- else:
|
|
|
- parent_id = task_state.parent_id
|
|
|
-
|
|
|
- coro = task.get_coro()
|
|
|
- assert coro is not None, "created TaskInfo from a completed Task"
|
|
|
- super().__init__(id(task), parent_id, task.get_name(), coro)
|
|
|
- self._task = weakref.ref(task)
|
|
|
-
|
|
|
- def has_pending_cancellation(self) -> bool:
|
|
|
- if not (task := self._task()):
|
|
|
- # If the task isn't around anymore, it won't have a pending cancellation
|
|
|
- return False
|
|
|
-
|
|
|
- if task._must_cancel: # type: ignore[attr-defined]
|
|
|
- return True
|
|
|
- elif (
|
|
|
- isinstance(task._fut_waiter, asyncio.Future) # type: ignore[attr-defined]
|
|
|
- and task._fut_waiter.cancelled() # type: ignore[attr-defined]
|
|
|
- ):
|
|
|
- return True
|
|
|
-
|
|
|
- if task_state := _task_states.get(task):
|
|
|
- if cancel_scope := task_state.cancel_scope:
|
|
|
- return cancel_scope._effectively_cancelled
|
|
|
-
|
|
|
- return False
|
|
|
-
|
|
|
-
|
|
|
-class TestRunner(abc.TestRunner):
|
|
|
- _send_stream: MemoryObjectSendStream[tuple[Awaitable[Any], asyncio.Future[Any]]]
|
|
|
-
|
|
|
- def __init__(
|
|
|
- self,
|
|
|
- *,
|
|
|
- debug: bool | None = None,
|
|
|
- use_uvloop: bool = False,
|
|
|
- loop_factory: Callable[[], AbstractEventLoop] | None = None,
|
|
|
- ) -> None:
|
|
|
- if use_uvloop and loop_factory is None:
|
|
|
- import uvloop
|
|
|
-
|
|
|
- loop_factory = uvloop.new_event_loop
|
|
|
-
|
|
|
- self._runner = Runner(debug=debug, loop_factory=loop_factory)
|
|
|
- self._exceptions: list[BaseException] = []
|
|
|
- self._runner_task: asyncio.Task | None = None
|
|
|
-
|
|
|
- def __enter__(self) -> TestRunner:
|
|
|
- self._runner.__enter__()
|
|
|
- self.get_loop().set_exception_handler(self._exception_handler)
|
|
|
- return self
|
|
|
-
|
|
|
- def __exit__(
|
|
|
- self,
|
|
|
- exc_type: type[BaseException] | None,
|
|
|
- exc_val: BaseException | None,
|
|
|
- exc_tb: TracebackType | None,
|
|
|
- ) -> None:
|
|
|
- self._runner.__exit__(exc_type, exc_val, exc_tb)
|
|
|
-
|
|
|
- def get_loop(self) -> AbstractEventLoop:
|
|
|
- return self._runner.get_loop()
|
|
|
-
|
|
|
- def _exception_handler(
|
|
|
- self, loop: asyncio.AbstractEventLoop, context: dict[str, Any]
|
|
|
- ) -> None:
|
|
|
- if isinstance(context.get("exception"), Exception):
|
|
|
- self._exceptions.append(context["exception"])
|
|
|
- else:
|
|
|
- loop.default_exception_handler(context)
|
|
|
-
|
|
|
- def _raise_async_exceptions(self) -> None:
|
|
|
- # Re-raise any exceptions raised in asynchronous callbacks
|
|
|
- if self._exceptions:
|
|
|
- exceptions, self._exceptions = self._exceptions, []
|
|
|
- if len(exceptions) == 1:
|
|
|
- raise exceptions[0]
|
|
|
- elif exceptions:
|
|
|
- raise BaseExceptionGroup(
|
|
|
- "Multiple exceptions occurred in asynchronous callbacks", exceptions
|
|
|
- )
|
|
|
-
|
|
|
- async def _run_tests_and_fixtures(
|
|
|
- self,
|
|
|
- receive_stream: MemoryObjectReceiveStream[
|
|
|
- tuple[Awaitable[T_Retval], asyncio.Future[T_Retval]]
|
|
|
- ],
|
|
|
- ) -> None:
|
|
|
- from _pytest.outcomes import OutcomeException
|
|
|
-
|
|
|
- with receive_stream, self._send_stream:
|
|
|
- async for coro, future in receive_stream:
|
|
|
- try:
|
|
|
- retval = await coro
|
|
|
- except CancelledError as exc:
|
|
|
- if not future.cancelled():
|
|
|
- future.cancel(*exc.args)
|
|
|
-
|
|
|
- raise
|
|
|
- except BaseException as exc:
|
|
|
- if not future.cancelled():
|
|
|
- future.set_exception(exc)
|
|
|
-
|
|
|
- if not isinstance(exc, (Exception, OutcomeException)):
|
|
|
- raise
|
|
|
- else:
|
|
|
- if not future.cancelled():
|
|
|
- future.set_result(retval)
|
|
|
-
|
|
|
- async def _call_in_runner_task(
|
|
|
- self,
|
|
|
- func: Callable[P, Awaitable[T_Retval]],
|
|
|
- *args: P.args,
|
|
|
- **kwargs: P.kwargs,
|
|
|
- ) -> T_Retval:
|
|
|
- if not self._runner_task:
|
|
|
- self._send_stream, receive_stream = create_memory_object_stream[
|
|
|
- tuple[Awaitable[Any], asyncio.Future]
|
|
|
- ](1)
|
|
|
- self._runner_task = self.get_loop().create_task(
|
|
|
- self._run_tests_and_fixtures(receive_stream)
|
|
|
- )
|
|
|
-
|
|
|
- coro = func(*args, **kwargs)
|
|
|
- future: asyncio.Future[T_Retval] = self.get_loop().create_future()
|
|
|
- self._send_stream.send_nowait((coro, future))
|
|
|
- return await future
|
|
|
-
|
|
|
- def run_asyncgen_fixture(
|
|
|
- self,
|
|
|
- fixture_func: Callable[..., AsyncGenerator[T_Retval, Any]],
|
|
|
- kwargs: dict[str, Any],
|
|
|
- ) -> Iterable[T_Retval]:
|
|
|
- asyncgen = fixture_func(**kwargs)
|
|
|
- fixturevalue: T_Retval = self.get_loop().run_until_complete(
|
|
|
- self._call_in_runner_task(asyncgen.asend, None)
|
|
|
- )
|
|
|
- self._raise_async_exceptions()
|
|
|
-
|
|
|
- yield fixturevalue
|
|
|
-
|
|
|
- try:
|
|
|
- self.get_loop().run_until_complete(
|
|
|
- self._call_in_runner_task(asyncgen.asend, None)
|
|
|
- )
|
|
|
- except StopAsyncIteration:
|
|
|
- self._raise_async_exceptions()
|
|
|
- else:
|
|
|
- self.get_loop().run_until_complete(asyncgen.aclose())
|
|
|
- raise RuntimeError("Async generator fixture did not stop")
|
|
|
-
|
|
|
- def run_fixture(
|
|
|
- self,
|
|
|
- fixture_func: Callable[..., Coroutine[Any, Any, T_Retval]],
|
|
|
- kwargs: dict[str, Any],
|
|
|
- ) -> T_Retval:
|
|
|
- retval = self.get_loop().run_until_complete(
|
|
|
- self._call_in_runner_task(fixture_func, **kwargs)
|
|
|
- )
|
|
|
- self._raise_async_exceptions()
|
|
|
- return retval
|
|
|
-
|
|
|
- def run_test(
|
|
|
- self, test_func: Callable[..., Coroutine[Any, Any, Any]], kwargs: dict[str, Any]
|
|
|
- ) -> None:
|
|
|
- try:
|
|
|
- self.get_loop().run_until_complete(
|
|
|
- self._call_in_runner_task(test_func, **kwargs)
|
|
|
- )
|
|
|
- except Exception as exc:
|
|
|
- self._exceptions.append(exc)
|
|
|
-
|
|
|
- self._raise_async_exceptions()
|
|
|
-
|
|
|
-
|
|
|
-class AsyncIOBackend(AsyncBackend):
|
|
|
- @classmethod
|
|
|
- def run(
|
|
|
- cls,
|
|
|
- func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
|
|
|
- args: tuple[Unpack[PosArgsT]],
|
|
|
- kwargs: dict[str, Any],
|
|
|
- options: dict[str, Any],
|
|
|
- ) -> T_Retval:
|
|
|
- @wraps(func)
|
|
|
- async def wrapper() -> T_Retval:
|
|
|
- task = cast(asyncio.Task, current_task())
|
|
|
- task.set_name(get_callable_name(func))
|
|
|
- _task_states[task] = TaskState(None, None)
|
|
|
-
|
|
|
- try:
|
|
|
- return await func(*args)
|
|
|
- finally:
|
|
|
- del _task_states[task]
|
|
|
-
|
|
|
- debug = options.get("debug", None)
|
|
|
- loop_factory = options.get("loop_factory", None)
|
|
|
- if loop_factory is None and options.get("use_uvloop", False):
|
|
|
- import uvloop
|
|
|
-
|
|
|
- loop_factory = uvloop.new_event_loop
|
|
|
-
|
|
|
- with Runner(debug=debug, loop_factory=loop_factory) as runner:
|
|
|
- return runner.run(wrapper())
|
|
|
-
|
|
|
- @classmethod
|
|
|
- def current_token(cls) -> object:
|
|
|
- return get_running_loop()
|
|
|
-
|
|
|
- @classmethod
|
|
|
- def current_time(cls) -> float:
|
|
|
- return get_running_loop().time()
|
|
|
-
|
|
|
- @classmethod
|
|
|
- def cancelled_exception_class(cls) -> type[BaseException]:
|
|
|
- return CancelledError
|
|
|
-
|
|
|
- @classmethod
|
|
|
- async def checkpoint(cls) -> None:
|
|
|
- await sleep(0)
|
|
|
-
|
|
|
- @classmethod
|
|
|
- async def checkpoint_if_cancelled(cls) -> None:
|
|
|
- task = current_task()
|
|
|
- if task is None:
|
|
|
- return
|
|
|
-
|
|
|
- try:
|
|
|
- cancel_scope = _task_states[task].cancel_scope
|
|
|
- except KeyError:
|
|
|
- return
|
|
|
-
|
|
|
- while cancel_scope:
|
|
|
- if cancel_scope.cancel_called:
|
|
|
- await sleep(0)
|
|
|
- elif cancel_scope.shield:
|
|
|
- break
|
|
|
- else:
|
|
|
- cancel_scope = cancel_scope._parent_scope
|
|
|
-
|
|
|
- @classmethod
|
|
|
- async def cancel_shielded_checkpoint(cls) -> None:
|
|
|
- with CancelScope(shield=True):
|
|
|
- await sleep(0)
|
|
|
-
|
|
|
- @classmethod
|
|
|
- async def sleep(cls, delay: float) -> None:
|
|
|
- await sleep(delay)
|
|
|
-
|
|
|
- @classmethod
|
|
|
- def create_cancel_scope(
|
|
|
- cls, *, deadline: float = math.inf, shield: bool = False
|
|
|
- ) -> CancelScope:
|
|
|
- return CancelScope(deadline=deadline, shield=shield)
|
|
|
-
|
|
|
- @classmethod
|
|
|
- def current_effective_deadline(cls) -> float:
|
|
|
- if (task := current_task()) is None:
|
|
|
- return math.inf
|
|
|
-
|
|
|
- try:
|
|
|
- cancel_scope = _task_states[task].cancel_scope
|
|
|
- except KeyError:
|
|
|
- return math.inf
|
|
|
-
|
|
|
- deadline = math.inf
|
|
|
- while cancel_scope:
|
|
|
- deadline = min(deadline, cancel_scope.deadline)
|
|
|
- if cancel_scope._cancel_called:
|
|
|
- deadline = -math.inf
|
|
|
- break
|
|
|
- elif cancel_scope.shield:
|
|
|
- break
|
|
|
- else:
|
|
|
- cancel_scope = cancel_scope._parent_scope
|
|
|
-
|
|
|
- return deadline
|
|
|
-
|
|
|
- @classmethod
|
|
|
- def create_task_group(cls) -> abc.TaskGroup:
|
|
|
- return TaskGroup()
|
|
|
-
|
|
|
- @classmethod
|
|
|
- def create_event(cls) -> abc.Event:
|
|
|
- return Event()
|
|
|
-
|
|
|
- @classmethod
|
|
|
- def create_lock(cls, *, fast_acquire: bool) -> abc.Lock:
|
|
|
- return Lock(fast_acquire=fast_acquire)
|
|
|
-
|
|
|
- @classmethod
|
|
|
- def create_semaphore(
|
|
|
- cls,
|
|
|
- initial_value: int,
|
|
|
- *,
|
|
|
- max_value: int | None = None,
|
|
|
- fast_acquire: bool = False,
|
|
|
- ) -> abc.Semaphore:
|
|
|
- return Semaphore(initial_value, max_value=max_value, fast_acquire=fast_acquire)
|
|
|
-
|
|
|
- @classmethod
|
|
|
- def create_capacity_limiter(cls, total_tokens: float) -> abc.CapacityLimiter:
|
|
|
- return CapacityLimiter(total_tokens)
|
|
|
-
|
|
|
- @classmethod
|
|
|
- async def run_sync_in_worker_thread( # type: ignore[return]
|
|
|
- cls,
|
|
|
- func: Callable[[Unpack[PosArgsT]], T_Retval],
|
|
|
- args: tuple[Unpack[PosArgsT]],
|
|
|
- abandon_on_cancel: bool = False,
|
|
|
- limiter: abc.CapacityLimiter | None = None,
|
|
|
- ) -> T_Retval:
|
|
|
- await cls.checkpoint()
|
|
|
-
|
|
|
- # If this is the first run in this event loop thread, set up the necessary
|
|
|
- # variables
|
|
|
- try:
|
|
|
- idle_workers = _threadpool_idle_workers.get()
|
|
|
- workers = _threadpool_workers.get()
|
|
|
- except LookupError:
|
|
|
- idle_workers = deque()
|
|
|
- workers = set()
|
|
|
- _threadpool_idle_workers.set(idle_workers)
|
|
|
- _threadpool_workers.set(workers)
|
|
|
-
|
|
|
- async with limiter or cls.current_default_thread_limiter():
|
|
|
- with CancelScope(shield=not abandon_on_cancel) as scope:
|
|
|
- future = asyncio.Future[T_Retval]()
|
|
|
- root_task = find_root_task()
|
|
|
- if not idle_workers:
|
|
|
- worker = WorkerThread(root_task, workers, idle_workers)
|
|
|
- worker.start()
|
|
|
- workers.add(worker)
|
|
|
- root_task.add_done_callback(worker.stop)
|
|
|
- else:
|
|
|
- worker = idle_workers.pop()
|
|
|
-
|
|
|
- # Prune any other workers that have been idle for MAX_IDLE_TIME
|
|
|
- # seconds or longer
|
|
|
- now = cls.current_time()
|
|
|
- while idle_workers:
|
|
|
- if (
|
|
|
- now - idle_workers[0].idle_since
|
|
|
- < WorkerThread.MAX_IDLE_TIME
|
|
|
- ):
|
|
|
- break
|
|
|
-
|
|
|
- expired_worker = idle_workers.popleft()
|
|
|
- expired_worker.root_task.remove_done_callback(
|
|
|
- expired_worker.stop
|
|
|
- )
|
|
|
- expired_worker.stop()
|
|
|
-
|
|
|
- context = copy_context()
|
|
|
- context.run(sniffio.current_async_library_cvar.set, None)
|
|
|
- if abandon_on_cancel or scope._parent_scope is None:
|
|
|
- worker_scope = scope
|
|
|
- else:
|
|
|
- worker_scope = scope._parent_scope
|
|
|
-
|
|
|
- worker.queue.put_nowait((context, func, args, future, worker_scope))
|
|
|
- return await future
|
|
|
-
|
|
|
- @classmethod
|
|
|
- def check_cancelled(cls) -> None:
|
|
|
- scope: CancelScope | None = threadlocals.current_cancel_scope
|
|
|
- while scope is not None:
|
|
|
- if scope.cancel_called:
|
|
|
- raise CancelledError(f"Cancelled by cancel scope {id(scope):x}")
|
|
|
-
|
|
|
- if scope.shield:
|
|
|
- return
|
|
|
-
|
|
|
- scope = scope._parent_scope
|
|
|
-
|
|
|
- @classmethod
|
|
|
- def run_async_from_thread(
|
|
|
- cls,
|
|
|
- func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
|
|
|
- args: tuple[Unpack[PosArgsT]],
|
|
|
- token: object,
|
|
|
- ) -> T_Retval:
|
|
|
- async def task_wrapper(scope: CancelScope) -> T_Retval:
|
|
|
- __tracebackhide__ = True
|
|
|
- task = cast(asyncio.Task, current_task())
|
|
|
- _task_states[task] = TaskState(None, scope)
|
|
|
- scope._tasks.add(task)
|
|
|
- try:
|
|
|
- return await func(*args)
|
|
|
- except CancelledError as exc:
|
|
|
- raise concurrent.futures.CancelledError(str(exc)) from None
|
|
|
- finally:
|
|
|
- scope._tasks.discard(task)
|
|
|
-
|
|
|
- loop = cast(AbstractEventLoop, token)
|
|
|
- context = copy_context()
|
|
|
- context.run(sniffio.current_async_library_cvar.set, "asyncio")
|
|
|
- wrapper = task_wrapper(threadlocals.current_cancel_scope)
|
|
|
- f: concurrent.futures.Future[T_Retval] = context.run(
|
|
|
- asyncio.run_coroutine_threadsafe, wrapper, loop
|
|
|
- )
|
|
|
- return f.result()
|
|
|
-
|
|
|
- @classmethod
|
|
|
- def run_sync_from_thread(
|
|
|
- cls,
|
|
|
- func: Callable[[Unpack[PosArgsT]], T_Retval],
|
|
|
- args: tuple[Unpack[PosArgsT]],
|
|
|
- token: object,
|
|
|
- ) -> T_Retval:
|
|
|
- @wraps(func)
|
|
|
- def wrapper() -> None:
|
|
|
- try:
|
|
|
- sniffio.current_async_library_cvar.set("asyncio")
|
|
|
- f.set_result(func(*args))
|
|
|
- except BaseException as exc:
|
|
|
- f.set_exception(exc)
|
|
|
- if not isinstance(exc, Exception):
|
|
|
- raise
|
|
|
-
|
|
|
- f: concurrent.futures.Future[T_Retval] = Future()
|
|
|
- loop = cast(AbstractEventLoop, token)
|
|
|
- loop.call_soon_threadsafe(wrapper)
|
|
|
- return f.result()
|
|
|
-
|
|
|
- @classmethod
|
|
|
- def create_blocking_portal(cls) -> abc.BlockingPortal:
|
|
|
- return BlockingPortal()
|
|
|
-
|
|
|
- @classmethod
|
|
|
- async def open_process(
|
|
|
- cls,
|
|
|
- command: StrOrBytesPath | Sequence[StrOrBytesPath],
|
|
|
- *,
|
|
|
- stdin: int | IO[Any] | None,
|
|
|
- stdout: int | IO[Any] | None,
|
|
|
- stderr: int | IO[Any] | None,
|
|
|
- **kwargs: Any,
|
|
|
- ) -> Process:
|
|
|
- await cls.checkpoint()
|
|
|
- if isinstance(command, PathLike):
|
|
|
- command = os.fspath(command)
|
|
|
-
|
|
|
- if isinstance(command, (str, bytes)):
|
|
|
- process = await asyncio.create_subprocess_shell(
|
|
|
- command,
|
|
|
- stdin=stdin,
|
|
|
- stdout=stdout,
|
|
|
- stderr=stderr,
|
|
|
- **kwargs,
|
|
|
- )
|
|
|
- else:
|
|
|
- process = await asyncio.create_subprocess_exec(
|
|
|
- *command,
|
|
|
- stdin=stdin,
|
|
|
- stdout=stdout,
|
|
|
- stderr=stderr,
|
|
|
- **kwargs,
|
|
|
- )
|
|
|
-
|
|
|
- stdin_stream = StreamWriterWrapper(process.stdin) if process.stdin else None
|
|
|
- stdout_stream = StreamReaderWrapper(process.stdout) if process.stdout else None
|
|
|
- stderr_stream = StreamReaderWrapper(process.stderr) if process.stderr else None
|
|
|
- return Process(process, stdin_stream, stdout_stream, stderr_stream)
|
|
|
-
|
|
|
- @classmethod
|
|
|
- def setup_process_pool_exit_at_shutdown(cls, workers: set[abc.Process]) -> None:
|
|
|
- create_task(
|
|
|
- _shutdown_process_pool_on_exit(workers),
|
|
|
- name="AnyIO process pool shutdown task",
|
|
|
- )
|
|
|
- find_root_task().add_done_callback(
|
|
|
- partial(_forcibly_shutdown_process_pool_on_exit, workers) # type:ignore[arg-type]
|
|
|
- )
|
|
|
-
|
|
|
- @classmethod
|
|
|
- async def connect_tcp(
|
|
|
- cls, host: str, port: int, local_address: IPSockAddrType | None = None
|
|
|
- ) -> abc.SocketStream:
|
|
|
- transport, protocol = cast(
|
|
|
- tuple[asyncio.Transport, StreamProtocol],
|
|
|
- await get_running_loop().create_connection(
|
|
|
- StreamProtocol, host, port, local_addr=local_address
|
|
|
- ),
|
|
|
- )
|
|
|
- transport.pause_reading()
|
|
|
- return SocketStream(transport, protocol)
|
|
|
-
|
|
|
- @classmethod
|
|
|
- async def connect_unix(cls, path: str | bytes) -> abc.UNIXSocketStream:
|
|
|
- await cls.checkpoint()
|
|
|
- loop = get_running_loop()
|
|
|
- raw_socket = socket.socket(socket.AF_UNIX)
|
|
|
- raw_socket.setblocking(False)
|
|
|
- while True:
|
|
|
- try:
|
|
|
- raw_socket.connect(path)
|
|
|
- except BlockingIOError:
|
|
|
- f: asyncio.Future = asyncio.Future()
|
|
|
- loop.add_writer(raw_socket, f.set_result, None)
|
|
|
- f.add_done_callback(lambda _: loop.remove_writer(raw_socket))
|
|
|
- await f
|
|
|
- except BaseException:
|
|
|
- raw_socket.close()
|
|
|
- raise
|
|
|
- else:
|
|
|
- return UNIXSocketStream(raw_socket)
|
|
|
-
|
|
|
- @classmethod
|
|
|
- def create_tcp_listener(cls, sock: socket.socket) -> SocketListener:
|
|
|
- return TCPSocketListener(sock)
|
|
|
-
|
|
|
- @classmethod
|
|
|
- def create_unix_listener(cls, sock: socket.socket) -> SocketListener:
|
|
|
- return UNIXSocketListener(sock)
|
|
|
-
|
|
|
- @classmethod
|
|
|
- async def create_udp_socket(
|
|
|
- cls,
|
|
|
- family: AddressFamily,
|
|
|
- local_address: IPSockAddrType | None,
|
|
|
- remote_address: IPSockAddrType | None,
|
|
|
- reuse_port: bool,
|
|
|
- ) -> UDPSocket | ConnectedUDPSocket:
|
|
|
- transport, protocol = await get_running_loop().create_datagram_endpoint(
|
|
|
- DatagramProtocol,
|
|
|
- local_addr=local_address,
|
|
|
- remote_addr=remote_address,
|
|
|
- family=family,
|
|
|
- reuse_port=reuse_port,
|
|
|
- )
|
|
|
- if protocol.exception:
|
|
|
- transport.close()
|
|
|
- raise protocol.exception
|
|
|
-
|
|
|
- if not remote_address:
|
|
|
- return UDPSocket(transport, protocol)
|
|
|
- else:
|
|
|
- return ConnectedUDPSocket(transport, protocol)
|
|
|
-
|
|
|
- @classmethod
|
|
|
- async def create_unix_datagram_socket( # type: ignore[override]
|
|
|
- cls, raw_socket: socket.socket, remote_path: str | bytes | None
|
|
|
- ) -> abc.UNIXDatagramSocket | abc.ConnectedUNIXDatagramSocket:
|
|
|
- await cls.checkpoint()
|
|
|
- loop = get_running_loop()
|
|
|
-
|
|
|
- if remote_path:
|
|
|
- while True:
|
|
|
- try:
|
|
|
- raw_socket.connect(remote_path)
|
|
|
- except BlockingIOError:
|
|
|
- f: asyncio.Future = asyncio.Future()
|
|
|
- loop.add_writer(raw_socket, f.set_result, None)
|
|
|
- f.add_done_callback(lambda _: loop.remove_writer(raw_socket))
|
|
|
- await f
|
|
|
- except BaseException:
|
|
|
- raw_socket.close()
|
|
|
- raise
|
|
|
- else:
|
|
|
- return ConnectedUNIXDatagramSocket(raw_socket)
|
|
|
- else:
|
|
|
- return UNIXDatagramSocket(raw_socket)
|
|
|
-
|
|
|
- @classmethod
|
|
|
- async def getaddrinfo(
|
|
|
- cls,
|
|
|
- host: bytes | str | None,
|
|
|
- port: str | int | None,
|
|
|
- *,
|
|
|
- family: int | AddressFamily = 0,
|
|
|
- type: int | SocketKind = 0,
|
|
|
- proto: int = 0,
|
|
|
- flags: int = 0,
|
|
|
- ) -> list[
|
|
|
- tuple[
|
|
|
- AddressFamily,
|
|
|
- SocketKind,
|
|
|
- int,
|
|
|
- str,
|
|
|
- tuple[str, int] | tuple[str, int, int, int],
|
|
|
- ]
|
|
|
- ]:
|
|
|
- return await get_running_loop().getaddrinfo(
|
|
|
- host, port, family=family, type=type, proto=proto, flags=flags
|
|
|
- )
|
|
|
-
|
|
|
- @classmethod
|
|
|
- async def getnameinfo(
|
|
|
- cls, sockaddr: IPSockAddrType, flags: int = 0
|
|
|
- ) -> tuple[str, str]:
|
|
|
- return await get_running_loop().getnameinfo(sockaddr, flags)
|
|
|
-
|
|
|
- @classmethod
|
|
|
- async def wait_readable(cls, obj: FileDescriptorLike) -> None:
|
|
|
- await cls.checkpoint()
|
|
|
- try:
|
|
|
- read_events = _read_events.get()
|
|
|
- except LookupError:
|
|
|
- read_events = {}
|
|
|
- _read_events.set(read_events)
|
|
|
-
|
|
|
- if not isinstance(obj, int):
|
|
|
- obj = obj.fileno()
|
|
|
-
|
|
|
- if read_events.get(obj):
|
|
|
- raise BusyResourceError("reading from")
|
|
|
-
|
|
|
- loop = get_running_loop()
|
|
|
- event = asyncio.Event()
|
|
|
- try:
|
|
|
- loop.add_reader(obj, event.set)
|
|
|
- except NotImplementedError:
|
|
|
- from anyio._core._asyncio_selector_thread import get_selector
|
|
|
-
|
|
|
- selector = get_selector()
|
|
|
- selector.add_reader(obj, event.set)
|
|
|
- remove_reader = selector.remove_reader
|
|
|
- else:
|
|
|
- remove_reader = loop.remove_reader
|
|
|
-
|
|
|
- read_events[obj] = event
|
|
|
- try:
|
|
|
- await event.wait()
|
|
|
- finally:
|
|
|
- remove_reader(obj)
|
|
|
- del read_events[obj]
|
|
|
-
|
|
|
- @classmethod
|
|
|
- async def wait_writable(cls, obj: FileDescriptorLike) -> None:
|
|
|
- await cls.checkpoint()
|
|
|
- try:
|
|
|
- write_events = _write_events.get()
|
|
|
- except LookupError:
|
|
|
- write_events = {}
|
|
|
- _write_events.set(write_events)
|
|
|
-
|
|
|
- if not isinstance(obj, int):
|
|
|
- obj = obj.fileno()
|
|
|
-
|
|
|
- if write_events.get(obj):
|
|
|
- raise BusyResourceError("writing to")
|
|
|
-
|
|
|
- loop = get_running_loop()
|
|
|
- event = asyncio.Event()
|
|
|
- try:
|
|
|
- loop.add_writer(obj, event.set)
|
|
|
- except NotImplementedError:
|
|
|
- from anyio._core._asyncio_selector_thread import get_selector
|
|
|
-
|
|
|
- selector = get_selector()
|
|
|
- selector.add_writer(obj, event.set)
|
|
|
- remove_writer = selector.remove_writer
|
|
|
- else:
|
|
|
- remove_writer = loop.remove_writer
|
|
|
-
|
|
|
- write_events[obj] = event
|
|
|
- try:
|
|
|
- await event.wait()
|
|
|
- finally:
|
|
|
- del write_events[obj]
|
|
|
- remove_writer(obj)
|
|
|
-
|
|
|
- @classmethod
|
|
|
- def current_default_thread_limiter(cls) -> CapacityLimiter:
|
|
|
- try:
|
|
|
- return _default_thread_limiter.get()
|
|
|
- except LookupError:
|
|
|
- limiter = CapacityLimiter(40)
|
|
|
- _default_thread_limiter.set(limiter)
|
|
|
- return limiter
|
|
|
-
|
|
|
- @classmethod
|
|
|
- def open_signal_receiver(
|
|
|
- cls, *signals: Signals
|
|
|
- ) -> AbstractContextManager[AsyncIterator[Signals]]:
|
|
|
- return _SignalReceiver(signals)
|
|
|
-
|
|
|
- @classmethod
|
|
|
- def get_current_task(cls) -> TaskInfo:
|
|
|
- return AsyncIOTaskInfo(current_task()) # type: ignore[arg-type]
|
|
|
-
|
|
|
- @classmethod
|
|
|
- def get_running_tasks(cls) -> Sequence[TaskInfo]:
|
|
|
- return [AsyncIOTaskInfo(task) for task in all_tasks() if not task.done()]
|
|
|
-
|
|
|
- @classmethod
|
|
|
- async def wait_all_tasks_blocked(cls) -> None:
|
|
|
- await cls.checkpoint()
|
|
|
- this_task = current_task()
|
|
|
- while True:
|
|
|
- for task in all_tasks():
|
|
|
- if task is this_task:
|
|
|
- continue
|
|
|
-
|
|
|
- waiter = task._fut_waiter # type: ignore[attr-defined]
|
|
|
- if waiter is None or waiter.done():
|
|
|
- await sleep(0.1)
|
|
|
- break
|
|
|
- else:
|
|
|
- return
|
|
|
-
|
|
|
- @classmethod
|
|
|
- def create_test_runner(cls, options: dict[str, Any]) -> TestRunner:
|
|
|
- return TestRunner(**options)
|
|
|
-
|
|
|
-
|
|
|
-backend_class = AsyncIOBackend
|