|
@@ -7,6 +7,7 @@ import math
|
|
|
import socket
|
|
|
import sys
|
|
|
import threading
|
|
|
+import weakref
|
|
|
from asyncio import (
|
|
|
AbstractEventLoop,
|
|
|
CancelledError,
|
|
@@ -488,7 +489,7 @@ class CancelScope(BaseCancelScope):
|
|
|
if task is not current and (task is self._host_task or _task_started(task)):
|
|
|
waiter = task._fut_waiter # type: ignore[attr-defined]
|
|
|
if not isinstance(waiter, asyncio.Future) or not waiter.done():
|
|
|
- self._cancel_calls += 1
|
|
|
+ origin._cancel_calls += 1
|
|
|
if sys.version_info >= (3, 9):
|
|
|
task.cancel(f"Cancelled by cancel scope {id(origin):x}")
|
|
|
else:
|
|
@@ -596,14 +597,14 @@ class TaskState:
|
|
|
itself because there are no guarantees about its implementation.
|
|
|
"""
|
|
|
|
|
|
- __slots__ = "parent_id", "cancel_scope"
|
|
|
+ __slots__ = "parent_id", "cancel_scope", "__weakref__"
|
|
|
|
|
|
def __init__(self, parent_id: int | None, cancel_scope: CancelScope | None):
|
|
|
self.parent_id = parent_id
|
|
|
self.cancel_scope = cancel_scope
|
|
|
|
|
|
|
|
|
-_task_states = WeakKeyDictionary() # type: WeakKeyDictionary[asyncio.Task, TaskState]
|
|
|
+_task_states: WeakKeyDictionary[asyncio.Task, TaskState] = WeakKeyDictionary()
|
|
|
|
|
|
|
|
|
#
|
|
@@ -620,9 +621,10 @@ class _AsyncioTaskStatus(abc.TaskStatus):
|
|
|
try:
|
|
|
self._future.set_result(value)
|
|
|
except asyncio.InvalidStateError:
|
|
|
- raise RuntimeError(
|
|
|
- "called 'started' twice on the same task status"
|
|
|
- ) from None
|
|
|
+ if not self._future.cancelled():
|
|
|
+ raise RuntimeError(
|
|
|
+ "called 'started' twice on the same task status"
|
|
|
+ ) from None
|
|
|
|
|
|
task = cast(asyncio.Task, current_task())
|
|
|
_task_states[task].parent_id = self._parent_id
|
|
@@ -713,6 +715,12 @@ class TaskGroup(abc.TaskGroup):
|
|
|
exc = e
|
|
|
|
|
|
if exc is not None:
|
|
|
+ # The future can only be in the cancelled state if the host task was
|
|
|
+ # cancelled, so return immediately instead of adding one more
|
|
|
+ # CancelledError to the exceptions list
|
|
|
+ if task_status_future is not None and task_status_future.cancelled():
|
|
|
+ return
|
|
|
+
|
|
|
if task_status_future is None or task_status_future.done():
|
|
|
if not isinstance(exc, CancelledError):
|
|
|
self._exceptions.append(exc)
|
|
@@ -1047,6 +1055,7 @@ class StreamProtocol(asyncio.Protocol):
|
|
|
read_event: asyncio.Event
|
|
|
write_event: asyncio.Event
|
|
|
exception: Exception | None = None
|
|
|
+ is_at_eof: bool = False
|
|
|
|
|
|
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
|
|
self.read_queue = deque()
|
|
@@ -1068,6 +1077,7 @@ class StreamProtocol(asyncio.Protocol):
|
|
|
self.read_event.set()
|
|
|
|
|
|
def eof_received(self) -> bool | None:
|
|
|
+ self.is_at_eof = True
|
|
|
self.read_event.set()
|
|
|
return True
|
|
|
|
|
@@ -1123,15 +1133,16 @@ class SocketStream(abc.SocketStream):
|
|
|
|
|
|
async def receive(self, max_bytes: int = 65536) -> bytes:
|
|
|
with self._receive_guard:
|
|
|
- await AsyncIOBackend.checkpoint()
|
|
|
-
|
|
|
if (
|
|
|
not self._protocol.read_event.is_set()
|
|
|
and not self._transport.is_closing()
|
|
|
+ and not self._protocol.is_at_eof
|
|
|
):
|
|
|
self._transport.resume_reading()
|
|
|
await self._protocol.read_event.wait()
|
|
|
self._transport.pause_reading()
|
|
|
+ else:
|
|
|
+ await AsyncIOBackend.checkpoint()
|
|
|
|
|
|
try:
|
|
|
chunk = self._protocol.read_queue.popleft()
|
|
@@ -1651,7 +1662,7 @@ class Event(BaseEvent):
|
|
|
await self._event.wait()
|
|
|
|
|
|
def statistics(self) -> EventStatistics:
|
|
|
- return EventStatistics(len(self._event._waiters)) # type: ignore[attr-defined]
|
|
|
+ return EventStatistics(len(self._event._waiters))
|
|
|
|
|
|
|
|
|
class CapacityLimiter(BaseCapacityLimiter):
|
|
@@ -1751,7 +1762,7 @@ class CapacityLimiter(BaseCapacityLimiter):
|
|
|
self._borrowers.remove(borrower)
|
|
|
except KeyError:
|
|
|
raise RuntimeError(
|
|
|
- "this borrower isn't holding any of this CapacityLimiter's " "tokens"
|
|
|
+ "this borrower isn't holding any of this CapacityLimiter's tokens"
|
|
|
) from None
|
|
|
|
|
|
# Notify the next task in line if this limiter has free capacity now
|
|
@@ -1823,14 +1834,36 @@ class _SignalReceiver:
|
|
|
#
|
|
|
|
|
|
|
|
|
-def _create_task_info(task: asyncio.Task) -> TaskInfo:
|
|
|
- task_state = _task_states.get(task)
|
|
|
- if task_state is None:
|
|
|
- parent_id = None
|
|
|
- else:
|
|
|
- parent_id = task_state.parent_id
|
|
|
+class AsyncIOTaskInfo(TaskInfo):
|
|
|
+ def __init__(self, task: asyncio.Task):
|
|
|
+ task_state = _task_states.get(task)
|
|
|
+ if task_state is None:
|
|
|
+ parent_id = None
|
|
|
+ else:
|
|
|
+ parent_id = task_state.parent_id
|
|
|
|
|
|
- return TaskInfo(id(task), parent_id, task.get_name(), task.get_coro())
|
|
|
+ super().__init__(id(task), parent_id, task.get_name(), task.get_coro())
|
|
|
+ self._task = weakref.ref(task)
|
|
|
+
|
|
|
+ def has_pending_cancellation(self) -> bool:
|
|
|
+ if not (task := self._task()):
|
|
|
+ # If the task isn't around anymore, it won't have a pending cancellation
|
|
|
+ return False
|
|
|
+
|
|
|
+ if sys.version_info >= (3, 11):
|
|
|
+ if task.cancelling():
|
|
|
+ return True
|
|
|
+ elif (
|
|
|
+ isinstance(task._fut_waiter, asyncio.Future)
|
|
|
+ and task._fut_waiter.cancelled()
|
|
|
+ ):
|
|
|
+ return True
|
|
|
+
|
|
|
+ if task_state := _task_states.get(task):
|
|
|
+ if cancel_scope := task_state.cancel_scope:
|
|
|
+ return cancel_scope.cancel_called or cancel_scope._parent_cancelled()
|
|
|
+
|
|
|
+ return False
|
|
|
|
|
|
|
|
|
class TestRunner(abc.TestRunner):
|
|
@@ -1887,13 +1920,13 @@ class TestRunner(abc.TestRunner):
|
|
|
"Multiple exceptions occurred in asynchronous callbacks", exceptions
|
|
|
)
|
|
|
|
|
|
- @staticmethod
|
|
|
async def _run_tests_and_fixtures(
|
|
|
+ self,
|
|
|
receive_stream: MemoryObjectReceiveStream[
|
|
|
tuple[Awaitable[T_Retval], asyncio.Future[T_Retval]]
|
|
|
],
|
|
|
) -> None:
|
|
|
- with receive_stream:
|
|
|
+ with receive_stream, self._send_stream:
|
|
|
async for coro, future in receive_stream:
|
|
|
try:
|
|
|
retval = await coro
|
|
@@ -1990,7 +2023,7 @@ class AsyncIOBackend(AsyncBackend):
|
|
|
finally:
|
|
|
del _task_states[task]
|
|
|
|
|
|
- debug = options.get("debug", False)
|
|
|
+ debug = options.get("debug", None)
|
|
|
loop_factory = options.get("loop_factory", None)
|
|
|
if loop_factory is None and options.get("use_uvloop", False):
|
|
|
import uvloop
|
|
@@ -2448,11 +2481,11 @@ class AsyncIOBackend(AsyncBackend):
|
|
|
|
|
|
@classmethod
|
|
|
def get_current_task(cls) -> TaskInfo:
|
|
|
- return _create_task_info(current_task()) # type: ignore[arg-type]
|
|
|
+ return AsyncIOTaskInfo(current_task()) # type: ignore[arg-type]
|
|
|
|
|
|
@classmethod
|
|
|
- def get_running_tasks(cls) -> list[TaskInfo]:
|
|
|
- return [_create_task_info(task) for task in all_tasks() if not task.done()]
|
|
|
+ def get_running_tasks(cls) -> Sequence[TaskInfo]:
|
|
|
+ return [AsyncIOTaskInfo(task) for task in all_tasks() if not task.done()]
|
|
|
|
|
|
@classmethod
|
|
|
async def wait_all_tasks_blocked(cls) -> None:
|