# taskgroups.py
  1. # Adapted with permission from the EdgeDB project;
  2. # license: PSFL.
  3. __all__ = ("TaskGroup",)
  4. from . import events
  5. from . import exceptions
  6. from . import tasks
  7. class TaskGroup:
  8. """Asynchronous context manager for managing groups of tasks.
  9. Example use:
  10. async with asyncio.TaskGroup() as group:
  11. task1 = group.create_task(some_coroutine(...))
  12. task2 = group.create_task(other_coroutine(...))
  13. print("Both tasks have completed now.")
  14. All tasks are awaited when the context manager exits.
  15. Any exceptions other than `asyncio.CancelledError` raised within
  16. a task will cancel all remaining tasks and wait for them to exit.
  17. The exceptions are then combined and raised as an `ExceptionGroup`.
  18. """
  19. def __init__(self):
  20. self._entered = False
  21. self._exiting = False
  22. self._aborting = False
  23. self._loop = None
  24. self._parent_task = None
  25. self._parent_cancel_requested = False
  26. self._tasks = set()
  27. self._errors = []
  28. self._base_error = None
  29. self._on_completed_fut = None
  30. def __repr__(self):
  31. info = ['']
  32. if self._tasks:
  33. info.append(f'tasks={len(self._tasks)}')
  34. if self._errors:
  35. info.append(f'errors={len(self._errors)}')
  36. if self._aborting:
  37. info.append('cancelling')
  38. elif self._entered:
  39. info.append('entered')
  40. info_str = ' '.join(info)
  41. return f'<TaskGroup{info_str}>'
  42. async def __aenter__(self):
  43. if self._entered:
  44. raise RuntimeError(
  45. f"TaskGroup {self!r} has already been entered")
  46. if self._loop is None:
  47. self._loop = events.get_running_loop()
  48. self._parent_task = tasks.current_task(self._loop)
  49. if self._parent_task is None:
  50. raise RuntimeError(
  51. f'TaskGroup {self!r} cannot determine the parent task')
  52. self._entered = True
  53. return self
  54. async def __aexit__(self, et, exc, tb):
  55. self._exiting = True
  56. if (exc is not None and
  57. self._is_base_error(exc) and
  58. self._base_error is None):
  59. self._base_error = exc
  60. propagate_cancellation_error = \
  61. exc if et is exceptions.CancelledError else None
  62. if self._parent_cancel_requested:
  63. # If this flag is set we *must* call uncancel().
  64. if self._parent_task.uncancel() == 0:
  65. # If there are no pending cancellations left,
  66. # don't propagate CancelledError.
  67. propagate_cancellation_error = None
  68. if et is not None:
  69. if not self._aborting:
  70. # Our parent task is being cancelled:
  71. #
  72. # async with TaskGroup() as g:
  73. # g.create_task(...)
  74. # await ... # <- CancelledError
  75. #
  76. # or there's an exception in "async with":
  77. #
  78. # async with TaskGroup() as g:
  79. # g.create_task(...)
  80. # 1 / 0
  81. #
  82. self._abort()
  83. # We use while-loop here because "self._on_completed_fut"
  84. # can be cancelled multiple times if our parent task
  85. # is being cancelled repeatedly (or even once, when
  86. # our own cancellation is already in progress)
  87. while self._tasks:
  88. if self._on_completed_fut is None:
  89. self._on_completed_fut = self._loop.create_future()
  90. try:
  91. await self._on_completed_fut
  92. except exceptions.CancelledError as ex:
  93. if not self._aborting:
  94. # Our parent task is being cancelled:
  95. #
  96. # async def wrapper():
  97. # async with TaskGroup() as g:
  98. # g.create_task(foo)
  99. #
  100. # "wrapper" is being cancelled while "foo" is
  101. # still running.
  102. propagate_cancellation_error = ex
  103. self._abort()
  104. self._on_completed_fut = None
  105. assert not self._tasks
  106. if self._base_error is not None:
  107. raise self._base_error
  108. # Propagate CancelledError if there is one, except if there
  109. # are other errors -- those have priority.
  110. if propagate_cancellation_error and not self._errors:
  111. raise propagate_cancellation_error
  112. if et is not None and et is not exceptions.CancelledError:
  113. self._errors.append(exc)
  114. if self._errors:
  115. # Exceptions are heavy objects that can have object
  116. # cycles (bad for GC); let's not keep a reference to
  117. # a bunch of them.
  118. try:
  119. me = BaseExceptionGroup('unhandled errors in a TaskGroup', self._errors)
  120. raise me from None
  121. finally:
  122. self._errors = None
  123. def create_task(self, coro, *, name=None, context=None):
  124. """Create a new task in this group and return it.
  125. Similar to `asyncio.create_task`.
  126. """
  127. if not self._entered:
  128. raise RuntimeError(f"TaskGroup {self!r} has not been entered")
  129. if self._exiting and not self._tasks:
  130. raise RuntimeError(f"TaskGroup {self!r} is finished")
  131. if self._aborting:
  132. raise RuntimeError(f"TaskGroup {self!r} is shutting down")
  133. if context is None:
  134. task = self._loop.create_task(coro)
  135. else:
  136. task = self._loop.create_task(coro, context=context)
  137. tasks._set_task_name(task, name)
  138. # optimization: Immediately call the done callback if the task is
  139. # already done (e.g. if the coro was able to complete eagerly),
  140. # and skip scheduling a done callback
  141. if task.done():
  142. self._on_task_done(task)
  143. else:
  144. self._tasks.add(task)
  145. task.add_done_callback(self._on_task_done)
  146. return task
  147. # Since Python 3.8 Tasks propagate all exceptions correctly,
  148. # except for KeyboardInterrupt and SystemExit which are
  149. # still considered special.
  150. def _is_base_error(self, exc: BaseException) -> bool:
  151. assert isinstance(exc, BaseException)
  152. return isinstance(exc, (SystemExit, KeyboardInterrupt))
  153. def _abort(self):
  154. self._aborting = True
  155. for t in self._tasks:
  156. if not t.done():
  157. t.cancel()
  158. def _on_task_done(self, task):
  159. self._tasks.discard(task)
  160. if self._on_completed_fut is not None and not self._tasks:
  161. if not self._on_completed_fut.done():
  162. self._on_completed_fut.set_result(True)
  163. if task.cancelled():
  164. return
  165. exc = task.exception()
  166. if exc is None:
  167. return
  168. self._errors.append(exc)
  169. if self._is_base_error(exc) and self._base_error is None:
  170. self._base_error = exc
  171. if self._parent_task.done():
  172. # Not sure if this case is possible, but we want to handle
  173. # it anyways.
  174. self._loop.call_exception_handler({
  175. 'message': f'Task {task!r} has errored out but its parent '
  176. f'task {self._parent_task} is already completed',
  177. 'exception': exc,
  178. 'task': task,
  179. })
  180. return
  181. if not self._aborting and not self._parent_cancel_requested:
  182. # If parent task *is not* being cancelled, it means that we want
  183. # to manually cancel it to abort whatever is being run right now
  184. # in the TaskGroup. But we want to mark parent task as
  185. # "not cancelled" later in __aexit__. Example situation that
  186. # we need to handle:
  187. #
  188. # async def foo():
  189. # try:
  190. # async with TaskGroup() as g:
  191. # g.create_task(crash_soon())
  192. # await something # <- this needs to be canceled
  193. # # by the TaskGroup, e.g.
  194. # # foo() needs to be cancelled
  195. # except Exception:
  196. # # Ignore any exceptions raised in the TaskGroup
  197. # pass
  198. # await something_else # this line has to be called
  199. # # after TaskGroup is finished.
  200. self._abort()
  201. self._parent_cancel_requested = True
  202. self._parent_task.cancel()