Intermediate changes

commit 4106eea04f by robot-piglet, 9 months ago

+ 1 - 1
contrib/python/anyio/.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: anyio
-Version: 4.3.0
+Version: 4.4.0
 Summary: High level compatibility layer for multiple asynchronous event loop implementations
 Author-email: Alex Grönholm <alex.gronholm@nextday.fi>
 License: MIT
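
The only change here is the version bump of the vendored anyio from 4.3.0 to 4.4.0. A quick, illustrative way to confirm which release is actually active (standard library only, not part of this commit):

    from importlib.metadata import version

    # After this update the resolved distribution metadata should say 4.4.0
    print(version("anyio"))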

+ 56 - 23
contrib/python/anyio/anyio/_backends/_asyncio.py

@@ -7,6 +7,7 @@ import math
 import socket
 import sys
 import threading
+import weakref
 from asyncio import (
     AbstractEventLoop,
     CancelledError,
@@ -488,7 +489,7 @@ class CancelScope(BaseCancelScope):
             if task is not current and (task is self._host_task or _task_started(task)):
                 waiter = task._fut_waiter  # type: ignore[attr-defined]
                 if not isinstance(waiter, asyncio.Future) or not waiter.done():
-                    self._cancel_calls += 1
+                    origin._cancel_calls += 1
                     if sys.version_info >= (3, 9):
                         task.cancel(f"Cancelled by cancel scope {id(origin):x}")
                     else:
@@ -596,14 +597,14 @@ class TaskState:
     itself because there are no guarantees about its implementation.
     """
 
-    __slots__ = "parent_id", "cancel_scope"
+    __slots__ = "parent_id", "cancel_scope", "__weakref__"
 
     def __init__(self, parent_id: int | None, cancel_scope: CancelScope | None):
         self.parent_id = parent_id
         self.cancel_scope = cancel_scope
 
 
-_task_states = WeakKeyDictionary()  # type: WeakKeyDictionary[asyncio.Task, TaskState]
+_task_states: WeakKeyDictionary[asyncio.Task, TaskState] = WeakKeyDictionary()
 
 
 #
@@ -620,9 +621,10 @@ class _AsyncioTaskStatus(abc.TaskStatus):
         try:
             self._future.set_result(value)
         except asyncio.InvalidStateError:
-            raise RuntimeError(
-                "called 'started' twice on the same task status"
-            ) from None
+            if not self._future.cancelled():
+                raise RuntimeError(
+                    "called 'started' twice on the same task status"
+                ) from None
 
         task = cast(asyncio.Task, current_task())
         _task_states[task].parent_id = self._parent_id
@@ -713,6 +715,12 @@ class TaskGroup(abc.TaskGroup):
                 exc = e
 
             if exc is not None:
+                # The future can only be in the cancelled state if the host task was
+                # cancelled, so return immediately instead of adding one more
+                # CancelledError to the exceptions list
+                if task_status_future is not None and task_status_future.cancelled():
+                    return
+
                 if task_status_future is None or task_status_future.done():
                     if not isinstance(exc, CancelledError):
                         self._exceptions.append(exc)
@@ -1047,6 +1055,7 @@ class StreamProtocol(asyncio.Protocol):
     read_event: asyncio.Event
     write_event: asyncio.Event
     exception: Exception | None = None
+    is_at_eof: bool = False
 
     def connection_made(self, transport: asyncio.BaseTransport) -> None:
         self.read_queue = deque()
@@ -1068,6 +1077,7 @@ class StreamProtocol(asyncio.Protocol):
         self.read_event.set()
 
     def eof_received(self) -> bool | None:
+        self.is_at_eof = True
         self.read_event.set()
         return True
 
@@ -1123,15 +1133,16 @@ class SocketStream(abc.SocketStream):
 
     async def receive(self, max_bytes: int = 65536) -> bytes:
         with self._receive_guard:
-            await AsyncIOBackend.checkpoint()
-
             if (
                 not self._protocol.read_event.is_set()
                 and not self._transport.is_closing()
+                and not self._protocol.is_at_eof
             ):
                 self._transport.resume_reading()
                 await self._protocol.read_event.wait()
                 self._transport.pause_reading()
+            else:
+                await AsyncIOBackend.checkpoint()
 
             try:
                 chunk = self._protocol.read_queue.popleft()
@@ -1651,7 +1662,7 @@ class Event(BaseEvent):
             await self._event.wait()
 
     def statistics(self) -> EventStatistics:
-        return EventStatistics(len(self._event._waiters))  # type: ignore[attr-defined]
+        return EventStatistics(len(self._event._waiters))
 
 
 class CapacityLimiter(BaseCapacityLimiter):
@@ -1751,7 +1762,7 @@ class CapacityLimiter(BaseCapacityLimiter):
             self._borrowers.remove(borrower)
         except KeyError:
             raise RuntimeError(
-                "this borrower isn't holding any of this CapacityLimiter's " "tokens"
+                "this borrower isn't holding any of this CapacityLimiter's tokens"
             ) from None
 
         # Notify the next task in line if this limiter has free capacity now
@@ -1823,14 +1834,36 @@ class _SignalReceiver:
 #
 
 
-def _create_task_info(task: asyncio.Task) -> TaskInfo:
-    task_state = _task_states.get(task)
-    if task_state is None:
-        parent_id = None
-    else:
-        parent_id = task_state.parent_id
+class AsyncIOTaskInfo(TaskInfo):
+    def __init__(self, task: asyncio.Task):
+        task_state = _task_states.get(task)
+        if task_state is None:
+            parent_id = None
+        else:
+            parent_id = task_state.parent_id
 
-    return TaskInfo(id(task), parent_id, task.get_name(), task.get_coro())
+        super().__init__(id(task), parent_id, task.get_name(), task.get_coro())
+        self._task = weakref.ref(task)
+
+    def has_pending_cancellation(self) -> bool:
+        if not (task := self._task()):
+            # If the task isn't around anymore, it won't have a pending cancellation
+            return False
+
+        if sys.version_info >= (3, 11):
+            if task.cancelling():
+                return True
+        elif (
+            isinstance(task._fut_waiter, asyncio.Future)
+            and task._fut_waiter.cancelled()
+        ):
+            return True
+
+        if task_state := _task_states.get(task):
+            if cancel_scope := task_state.cancel_scope:
+                return cancel_scope.cancel_called or cancel_scope._parent_cancelled()
+
+        return False
 
 
 class TestRunner(abc.TestRunner):
@@ -1887,13 +1920,13 @@ class TestRunner(abc.TestRunner):
                     "Multiple exceptions occurred in asynchronous callbacks", exceptions
                 )
 
-    @staticmethod
     async def _run_tests_and_fixtures(
+        self,
         receive_stream: MemoryObjectReceiveStream[
             tuple[Awaitable[T_Retval], asyncio.Future[T_Retval]]
         ],
     ) -> None:
-        with receive_stream:
+        with receive_stream, self._send_stream:
             async for coro, future in receive_stream:
                 try:
                     retval = await coro
@@ -1990,7 +2023,7 @@ class AsyncIOBackend(AsyncBackend):
             finally:
                 del _task_states[task]
 
-        debug = options.get("debug", False)
+        debug = options.get("debug", None)
         loop_factory = options.get("loop_factory", None)
         if loop_factory is None and options.get("use_uvloop", False):
             import uvloop
@@ -2448,11 +2481,11 @@ class AsyncIOBackend(AsyncBackend):
 
     @classmethod
     def get_current_task(cls) -> TaskInfo:
-        return _create_task_info(current_task())  # type: ignore[arg-type]
+        return AsyncIOTaskInfo(current_task())  # type: ignore[arg-type]
 
     @classmethod
-    def get_running_tasks(cls) -> list[TaskInfo]:
-        return [_create_task_info(task) for task in all_tasks() if not task.done()]
+    def get_running_tasks(cls) -> Sequence[TaskInfo]:
+        return [AsyncIOTaskInfo(task) for task in all_tasks() if not task.done()]
 
     @classmethod
     async def wait_all_tasks_blocked(cls) -> None:
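
A notable addition in this file is AsyncIOTaskInfo.has_pending_cancellation(), which inspects the task and its current cancel scope. A minimal usage sketch, not part of this commit, assuming anyio 4.4+:

    import anyio

    async def main() -> None:
        with anyio.CancelScope() as scope:
            scope.cancel()
            # The enclosing scope has been cancelled, so the current task reports
            # a pending cancellation even before the next checkpoint delivers it.
            print(anyio.get_current_task().has_pending_cancellation())  # True

    anyio.run(main)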

+ 25 - 17
contrib/python/anyio/anyio/_backends/_trio.py

@@ -5,6 +5,7 @@ import math
 import socket
 import sys
 import types
+import weakref
 from collections.abc import AsyncIterator, Iterable
 from concurrent.futures import Future
 from dataclasses import dataclass
@@ -839,6 +840,24 @@ class TestRunner(abc.TestRunner):
         self._call_in_runner_task(test_func, **kwargs)
 
 
+class TrioTaskInfo(TaskInfo):
+    def __init__(self, task: trio.lowlevel.Task):
+        parent_id = None
+        if task.parent_nursery and task.parent_nursery.parent_task:
+            parent_id = id(task.parent_nursery.parent_task)
+
+        super().__init__(id(task), parent_id, task.name, task.coro)
+        self._task = weakref.proxy(task)
+
+    def has_pending_cancellation(self) -> bool:
+        try:
+            return self._task._cancel_status.effectively_cancelled
+        except ReferenceError:
+            # If the task is no longer around, it surely doesn't have a cancellation
+            # pending
+            return False
+
+
 class TrioBackend(AsyncBackend):
     @classmethod
     def run(
@@ -1040,15 +1059,13 @@ class TrioBackend(AsyncBackend):
     @overload
     async def create_unix_datagram_socket(
         cls, raw_socket: socket.socket, remote_path: None
-    ) -> abc.UNIXDatagramSocket:
-        ...
+    ) -> abc.UNIXDatagramSocket: ...
 
     @classmethod
     @overload
     async def create_unix_datagram_socket(
         cls, raw_socket: socket.socket, remote_path: str | bytes
-    ) -> abc.ConnectedUNIXDatagramSocket:
-        ...
+    ) -> abc.ConnectedUNIXDatagramSocket: ...
 
     @classmethod
     async def create_unix_datagram_socket(
@@ -1127,28 +1144,19 @@ class TrioBackend(AsyncBackend):
     @classmethod
     def get_current_task(cls) -> TaskInfo:
         task = current_task()
-
-        parent_id = None
-        if task.parent_nursery and task.parent_nursery.parent_task:
-            parent_id = id(task.parent_nursery.parent_task)
-
-        return TaskInfo(id(task), parent_id, task.name, task.coro)
+        return TrioTaskInfo(task)
 
     @classmethod
-    def get_running_tasks(cls) -> list[TaskInfo]:
+    def get_running_tasks(cls) -> Sequence[TaskInfo]:
         root_task = current_root_task()
         assert root_task
-        task_infos = [TaskInfo(id(root_task), None, root_task.name, root_task.coro)]
+        task_infos = [TrioTaskInfo(root_task)]
         nurseries = root_task.child_nurseries
         while nurseries:
             new_nurseries: list[trio.Nursery] = []
             for nursery in nurseries:
                 for task in nursery.child_tasks:
-                    task_infos.append(
-                        TaskInfo(
-                            id(task), id(nursery.parent_task), task.name, task.coro
-                        )
-                    )
+                    task_infos.append(TrioTaskInfo(task))
                     new_nurseries.extend(task.child_nurseries)
 
             nurseries = new_nurseries
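
Both backends now hand out backend-specific TaskInfo subclasses (AsyncIOTaskInfo, TrioTaskInfo) that hold only weak references to the underlying tasks. A hedged sketch of how the public helpers surface them, not part of this commit:

    import anyio

    async def main() -> None:
        async with anyio.create_task_group() as tg:
            tg.start_soon(anyio.sleep, 1)
            await anyio.wait_all_tasks_blocked()
            for info in anyio.get_running_tasks():
                # The info objects never keep finished tasks alive.
                print(info.id, info.name, info.has_pending_cancellation())
            tg.cancel_scope.cancel()

    anyio.run(main)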

+ 9 - 6
contrib/python/anyio/anyio/_core/_eventloop.py

@@ -25,6 +25,7 @@ T_Retval = TypeVar("T_Retval")
 PosArgsT = TypeVarTuple("PosArgsT")
 
 threadlocals = threading.local()
+loaded_backends: dict[str, type[AsyncBackend]] = {}
 
 
 def run(
@@ -150,14 +151,16 @@ def claim_worker_thread(
         del threadlocals.current_token
 
 
-def get_async_backend(asynclib_name: str | None = None) -> AsyncBackend:
+def get_async_backend(asynclib_name: str | None = None) -> type[AsyncBackend]:
     if asynclib_name is None:
         asynclib_name = sniffio.current_async_library()
 
-    modulename = "anyio._backends._" + asynclib_name
+    # We use our own dict instead of sys.modules to get the already imported back-end
+    # class because the appropriate modules in sys.modules could potentially be only
+    # partially initialized
     try:
-        module = sys.modules[modulename]
+        return loaded_backends[asynclib_name]
     except KeyError:
-        module = import_module(modulename)
-
-    return getattr(module, "backend_class")
+        module = import_module(f"anyio._backends._{asynclib_name}")
+        loaded_backends[asynclib_name] = module.backend_class
+        return module.backend_class
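
The comment above is the key point: a module is registered in sys.modules before its top-level code has finished executing, so probing sys.modules can observe a backend module that lacks its attributes. A standalone illustration of that pitfall (not anyio code, names are made up):

    import sys
    import types

    mod = types.ModuleType("demo_backend")
    sys.modules["demo_backend"] = mod  # registered, but its body has not run yet
    print(hasattr(sys.modules["demo_backend"], "backend_class"))  # False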

+ 8 - 16
contrib/python/anyio/anyio/_core/_fileio.py

@@ -100,12 +100,10 @@ class AsyncFile(AsyncResource, Generic[AnyStr]):
         return await to_thread.run_sync(self._fp.readinto1, b)
 
     @overload
-    async def write(self: AsyncFile[bytes], b: ReadableBuffer) -> int:
-        ...
+    async def write(self: AsyncFile[bytes], b: ReadableBuffer) -> int: ...
 
     @overload
-    async def write(self: AsyncFile[str], b: str) -> int:
-        ...
+    async def write(self: AsyncFile[str], b: str) -> int: ...
 
     async def write(self, b: ReadableBuffer | str) -> int:
         return await to_thread.run_sync(self._fp.write, b)
@@ -113,12 +111,10 @@ class AsyncFile(AsyncResource, Generic[AnyStr]):
     @overload
     async def writelines(
         self: AsyncFile[bytes], lines: Iterable[ReadableBuffer]
-    ) -> None:
-        ...
+    ) -> None: ...
 
     @overload
-    async def writelines(self: AsyncFile[str], lines: Iterable[str]) -> None:
-        ...
+    async def writelines(self: AsyncFile[str], lines: Iterable[str]) -> None: ...
 
     async def writelines(self, lines: Iterable[ReadableBuffer] | Iterable[str]) -> None:
         return await to_thread.run_sync(self._fp.writelines, lines)
@@ -146,8 +142,7 @@ async def open_file(
     newline: str | None = ...,
     closefd: bool = ...,
     opener: Callable[[str, int], int] | None = ...,
-) -> AsyncFile[bytes]:
-    ...
+) -> AsyncFile[bytes]: ...
 
 
 @overload
@@ -160,8 +155,7 @@ async def open_file(
     newline: str | None = ...,
     closefd: bool = ...,
     opener: Callable[[str, int], int] | None = ...,
-) -> AsyncFile[str]:
-    ...
+) -> AsyncFile[str]: ...
 
 
 async def open_file(
@@ -476,8 +470,7 @@ class Path:
         encoding: str | None = ...,
         errors: str | None = ...,
         newline: str | None = ...,
-    ) -> AsyncFile[bytes]:
-        ...
+    ) -> AsyncFile[bytes]: ...
 
     @overload
     async def open(
@@ -487,8 +480,7 @@ class Path:
         encoding: str | None = ...,
         errors: str | None = ...,
         newline: str | None = ...,
-    ) -> AsyncFile[str]:
-        ...
+    ) -> AsyncFile[str]: ...
 
     async def open(
         self,
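
The hunks in this file, and several of the files below, only reformat @overload stubs to the single-line "..." body style. A minimal illustration of the pattern with made-up names:

    from __future__ import annotations

    from typing import overload

    @overload
    def decode(data: bytes) -> str: ...
    @overload
    def decode(data: str) -> str: ...
    def decode(data: bytes | str) -> str:
        # Only the @overload stubs use the one-line ... body; the implementation is normal.
        return data.decode() if isinstance(data, bytes) else data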

+ 5 - 10
contrib/python/anyio/anyio/_core/_sockets.py

@@ -53,8 +53,7 @@ async def connect_tcp(
     tls_standard_compatible: bool = ...,
     tls_hostname: str,
     happy_eyeballs_delay: float = ...,
-) -> TLSStream:
-    ...
+) -> TLSStream: ...
 
 
 # ssl_context given
@@ -68,8 +67,7 @@ async def connect_tcp(
     tls_standard_compatible: bool = ...,
     tls_hostname: str | None = ...,
     happy_eyeballs_delay: float = ...,
-) -> TLSStream:
-    ...
+) -> TLSStream: ...
 
 
 # tls=True
@@ -84,8 +82,7 @@ async def connect_tcp(
     tls_standard_compatible: bool = ...,
     tls_hostname: str | None = ...,
     happy_eyeballs_delay: float = ...,
-) -> TLSStream:
-    ...
+) -> TLSStream: ...
 
 
 # tls=False
@@ -100,8 +97,7 @@ async def connect_tcp(
     tls_standard_compatible: bool = ...,
     tls_hostname: str | None = ...,
     happy_eyeballs_delay: float = ...,
-) -> SocketStream:
-    ...
+) -> SocketStream: ...
 
 
 # No TLS arguments
@@ -112,8 +108,7 @@ async def connect_tcp(
     *,
     local_host: IPAddressType | None = ...,
     happy_eyeballs_delay: float = ...,
-) -> SocketStream:
-    ...
+) -> SocketStream: ...
 
 
 async def connect_tcp(
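
These overloads encode the return type in the TLS arguments: passing tls=True, an ssl_context, or a tls_hostname yields a TLSStream, otherwise a plain SocketStream. A hedged usage sketch of the non-TLS case, not part of this commit:

    import anyio

    async def main() -> None:
        stream = await anyio.connect_tcp("example.com", 80)  # typed as SocketStream
        await stream.send(b"HEAD / HTTP/1.0\r\nHost: example.com\r\n\r\n")
        print(await stream.receive())
        await stream.aclose()

    anyio.run(main)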

+ 8 - 4
contrib/python/anyio/anyio/_core/_testing.py

@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from collections.abc import Awaitable, Generator
-from typing import Any
+from typing import Any, cast
 
 from ._eventloop import get_async_backend
 
@@ -45,8 +45,12 @@ class TaskInfo:
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}(id={self.id!r}, name={self.name!r})"
 
-    def _unwrap(self) -> TaskInfo:
-        return self
+    def has_pending_cancellation(self) -> bool:
+        """
+        Return ``True`` if the task has a cancellation pending, ``False`` otherwise.
+
+        """
+        return False
 
 
 def get_current_task() -> TaskInfo:
@@ -66,7 +70,7 @@ def get_running_tasks() -> list[TaskInfo]:
     :return: a list of task info objects
 
     """
-    return get_async_backend().get_running_tasks()
+    return cast("list[TaskInfo]", get_async_backend().get_running_tasks())
 
 
 async def wait_all_tasks_blocked() -> None:

+ 5 - 5
contrib/python/anyio/anyio/_core/_typedattr.py

@@ -50,12 +50,10 @@ class TypedAttributeProvider:
         return {}
 
     @overload
-    def extra(self, attribute: T_Attr) -> T_Attr:
-        ...
+    def extra(self, attribute: T_Attr) -> T_Attr: ...
 
     @overload
-    def extra(self, attribute: T_Attr, default: T_Default) -> T_Attr | T_Default:
-        ...
+    def extra(self, attribute: T_Attr, default: T_Default) -> T_Attr | T_Default: ...
 
     @final
     def extra(self, attribute: Any, default: object = undefined) -> object:
@@ -73,9 +71,11 @@ class TypedAttributeProvider:
 
         """
         try:
-            return self.extra_attributes[attribute]()
+            getter = self.extra_attributes[attribute]
         except KeyError:
             if default is undefined:
                 raise TypedAttributeLookupError("Attribute not found") from None
             else:
                 return default
+
+        return getter()
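
The reordering above keeps only the dict lookup inside the try block, so a KeyError raised inside the getter callback itself is no longer mistaken for a missing attribute. The same pattern in isolation, with illustrative names:

    from typing import Any, Callable

    def lookup(table: dict[str, Callable[[], Any]], key: str) -> Any:
        try:
            getter = table[key]  # KeyError here means the attribute is missing
        except KeyError:
            raise LookupError("attribute not found") from None
        return getter()  # a KeyError raised by the getter now propagates unchanged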

+ 3 - 5
contrib/python/anyio/anyio/abc/_eventloop.py

@@ -303,15 +303,13 @@ class AsyncBackend(metaclass=ABCMeta):
     @overload
     async def create_unix_datagram_socket(
         cls, raw_socket: socket, remote_path: None
-    ) -> UNIXDatagramSocket:
-        ...
+    ) -> UNIXDatagramSocket: ...
 
     @classmethod
     @overload
     async def create_unix_datagram_socket(
         cls, raw_socket: socket, remote_path: str | bytes
-    ) -> ConnectedUNIXDatagramSocket:
-        ...
+    ) -> ConnectedUNIXDatagramSocket: ...
 
     @classmethod
     @abstractmethod
@@ -378,7 +376,7 @@ class AsyncBackend(metaclass=ABCMeta):
 
     @classmethod
     @abstractmethod
-    def get_running_tasks(cls) -> list[TaskInfo]:
+    def get_running_tasks(cls) -> Sequence[TaskInfo]:
         pass
 
     @classmethod

+ 2 - 0
contrib/python/anyio/anyio/abc/_resources.py

@@ -15,6 +15,8 @@ class AsyncResource(metaclass=ABCMeta):
     and calls :meth:`aclose` on exit.
     """
 
+    __slots__ = ()
+
     async def __aenter__(self: T) -> T:
         return self
 

Some files were not shown because too many files changed in this diff