123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240 |
- # Adapted with permission from the EdgeDB project;
- # license: PSFL.
- __all__ = ("TaskGroup",)
- from . import events
- from . import exceptions
- from . import tasks
- class TaskGroup:
- """Asynchronous context manager for managing groups of tasks.
- Example use:
- async with asyncio.TaskGroup() as group:
- task1 = group.create_task(some_coroutine(...))
- task2 = group.create_task(other_coroutine(...))
- print("Both tasks have completed now.")
- All tasks are awaited when the context manager exits.
- Any exceptions other than `asyncio.CancelledError` raised within
- a task will cancel all remaining tasks and wait for them to exit.
- The exceptions are then combined and raised as an `ExceptionGroup`.
- """
- def __init__(self):
- self._entered = False
- self._exiting = False
- self._aborting = False
- self._loop = None
- self._parent_task = None
- self._parent_cancel_requested = False
- self._tasks = set()
- self._errors = []
- self._base_error = None
- self._on_completed_fut = None
- def __repr__(self):
- info = ['']
- if self._tasks:
- info.append(f'tasks={len(self._tasks)}')
- if self._errors:
- info.append(f'errors={len(self._errors)}')
- if self._aborting:
- info.append('cancelling')
- elif self._entered:
- info.append('entered')
- info_str = ' '.join(info)
- return f'<TaskGroup{info_str}>'
- async def __aenter__(self):
- if self._entered:
- raise RuntimeError(
- f"TaskGroup {self!r} has already been entered")
- if self._loop is None:
- self._loop = events.get_running_loop()
- self._parent_task = tasks.current_task(self._loop)
- if self._parent_task is None:
- raise RuntimeError(
- f'TaskGroup {self!r} cannot determine the parent task')
- self._entered = True
- return self
- async def __aexit__(self, et, exc, tb):
- self._exiting = True
- if (exc is not None and
- self._is_base_error(exc) and
- self._base_error is None):
- self._base_error = exc
- propagate_cancellation_error = \
- exc if et is exceptions.CancelledError else None
- if self._parent_cancel_requested:
- # If this flag is set we *must* call uncancel().
- if self._parent_task.uncancel() == 0:
- # If there are no pending cancellations left,
- # don't propagate CancelledError.
- propagate_cancellation_error = None
- if et is not None:
- if not self._aborting:
- # Our parent task is being cancelled:
- #
- # async with TaskGroup() as g:
- # g.create_task(...)
- # await ... # <- CancelledError
- #
- # or there's an exception in "async with":
- #
- # async with TaskGroup() as g:
- # g.create_task(...)
- # 1 / 0
- #
- self._abort()
- # We use while-loop here because "self._on_completed_fut"
- # can be cancelled multiple times if our parent task
- # is being cancelled repeatedly (or even once, when
- # our own cancellation is already in progress)
- while self._tasks:
- if self._on_completed_fut is None:
- self._on_completed_fut = self._loop.create_future()
- try:
- await self._on_completed_fut
- except exceptions.CancelledError as ex:
- if not self._aborting:
- # Our parent task is being cancelled:
- #
- # async def wrapper():
- # async with TaskGroup() as g:
- # g.create_task(foo)
- #
- # "wrapper" is being cancelled while "foo" is
- # still running.
- propagate_cancellation_error = ex
- self._abort()
- self._on_completed_fut = None
- assert not self._tasks
- if self._base_error is not None:
- raise self._base_error
- # Propagate CancelledError if there is one, except if there
- # are other errors -- those have priority.
- if propagate_cancellation_error and not self._errors:
- raise propagate_cancellation_error
- if et is not None and et is not exceptions.CancelledError:
- self._errors.append(exc)
- if self._errors:
- # Exceptions are heavy objects that can have object
- # cycles (bad for GC); let's not keep a reference to
- # a bunch of them.
- try:
- me = BaseExceptionGroup('unhandled errors in a TaskGroup', self._errors)
- raise me from None
- finally:
- self._errors = None
- def create_task(self, coro, *, name=None, context=None):
- """Create a new task in this group and return it.
- Similar to `asyncio.create_task`.
- """
- if not self._entered:
- raise RuntimeError(f"TaskGroup {self!r} has not been entered")
- if self._exiting and not self._tasks:
- raise RuntimeError(f"TaskGroup {self!r} is finished")
- if self._aborting:
- raise RuntimeError(f"TaskGroup {self!r} is shutting down")
- if context is None:
- task = self._loop.create_task(coro)
- else:
- task = self._loop.create_task(coro, context=context)
- tasks._set_task_name(task, name)
- # optimization: Immediately call the done callback if the task is
- # already done (e.g. if the coro was able to complete eagerly),
- # and skip scheduling a done callback
- if task.done():
- self._on_task_done(task)
- else:
- self._tasks.add(task)
- task.add_done_callback(self._on_task_done)
- return task
- # Since Python 3.8 Tasks propagate all exceptions correctly,
- # except for KeyboardInterrupt and SystemExit which are
- # still considered special.
- def _is_base_error(self, exc: BaseException) -> bool:
- assert isinstance(exc, BaseException)
- return isinstance(exc, (SystemExit, KeyboardInterrupt))
- def _abort(self):
- self._aborting = True
- for t in self._tasks:
- if not t.done():
- t.cancel()
- def _on_task_done(self, task):
- self._tasks.discard(task)
- if self._on_completed_fut is not None and not self._tasks:
- if not self._on_completed_fut.done():
- self._on_completed_fut.set_result(True)
- if task.cancelled():
- return
- exc = task.exception()
- if exc is None:
- return
- self._errors.append(exc)
- if self._is_base_error(exc) and self._base_error is None:
- self._base_error = exc
- if self._parent_task.done():
- # Not sure if this case is possible, but we want to handle
- # it anyways.
- self._loop.call_exception_handler({
- 'message': f'Task {task!r} has errored out but its parent '
- f'task {self._parent_task} is already completed',
- 'exception': exc,
- 'task': task,
- })
- return
- if not self._aborting and not self._parent_cancel_requested:
- # If parent task *is not* being cancelled, it means that we want
- # to manually cancel it to abort whatever is being run right now
- # in the TaskGroup. But we want to mark parent task as
- # "not cancelled" later in __aexit__. Example situation that
- # we need to handle:
- #
- # async def foo():
- # try:
- # async with TaskGroup() as g:
- # g.create_task(crash_soon())
- # await something # <- this needs to be canceled
- # # by the TaskGroup, e.g.
- # # foo() needs to be cancelled
- # except Exception:
- # # Ignore any exceptions raised in the TaskGroup
- # pass
- # await something_else # this line has to be called
- # # after TaskGroup is finished.
- self._abort()
- self._parent_cancel_requested = True
- self._parent_task.cancel()
|