# taskgroups.py (9.2 KB)
  1. # Adapted with permission from the EdgeDB project;
  2. # license: PSFL.
  3. __all__ = ("TaskGroup",)
  4. from . import events
  5. from . import exceptions
  6. from . import tasks
  7. class TaskGroup:
  8. """Asynchronous context manager for managing groups of tasks.
  9. Example use:
  10. async with asyncio.TaskGroup() as group:
  11. task1 = group.create_task(some_coroutine(...))
  12. task2 = group.create_task(other_coroutine(...))
  13. print("Both tasks have completed now.")
  14. All tasks are awaited when the context manager exits.
  15. Any exceptions other than `asyncio.CancelledError` raised within
  16. a task will cancel all remaining tasks and wait for them to exit.
  17. The exceptions are then combined and raised as an `ExceptionGroup`.
  18. """
  19. def __init__(self):
  20. self._entered = False
  21. self._exiting = False
  22. self._aborting = False
  23. self._loop = None
  24. self._parent_task = None
  25. self._parent_cancel_requested = False
  26. self._tasks = set()
  27. self._errors = []
  28. self._base_error = None
  29. self._on_completed_fut = None
  30. def __repr__(self):
  31. info = ['']
  32. if self._tasks:
  33. info.append(f'tasks={len(self._tasks)}')
  34. if self._errors:
  35. info.append(f'errors={len(self._errors)}')
  36. if self._aborting:
  37. info.append('cancelling')
  38. elif self._entered:
  39. info.append('entered')
  40. info_str = ' '.join(info)
  41. return f'<TaskGroup{info_str}>'
  42. async def __aenter__(self):
  43. if self._entered:
  44. raise RuntimeError(
  45. f"TaskGroup {self!r} has already been entered")
  46. if self._loop is None:
  47. self._loop = events.get_running_loop()
  48. self._parent_task = tasks.current_task(self._loop)
  49. if self._parent_task is None:
  50. raise RuntimeError(
  51. f'TaskGroup {self!r} cannot determine the parent task')
  52. self._entered = True
  53. return self
  54. async def __aexit__(self, et, exc, tb):
  55. tb = None
  56. try:
  57. return await self._aexit(et, exc)
  58. finally:
  59. # Exceptions are heavy objects that can have object
  60. # cycles (bad for GC); let's not keep a reference to
  61. # a bunch of them. It would be nicer to use a try/finally
  62. # in __aexit__ directly but that introduced some diff noise
  63. self._parent_task = None
  64. self._errors = None
  65. self._base_error = None
  66. exc = None
  67. async def _aexit(self, et, exc):
  68. self._exiting = True
  69. if (exc is not None and
  70. self._is_base_error(exc) and
  71. self._base_error is None):
  72. self._base_error = exc
  73. propagate_cancellation_error = \
  74. exc if et is exceptions.CancelledError else None
  75. if self._parent_cancel_requested:
  76. # If this flag is set we *must* call uncancel().
  77. if self._parent_task.uncancel() == 0:
  78. # If there are no pending cancellations left,
  79. # don't propagate CancelledError.
  80. propagate_cancellation_error = None
  81. if et is not None:
  82. if not self._aborting:
  83. # Our parent task is being cancelled:
  84. #
  85. # async with TaskGroup() as g:
  86. # g.create_task(...)
  87. # await ... # <- CancelledError
  88. #
  89. # or there's an exception in "async with":
  90. #
  91. # async with TaskGroup() as g:
  92. # g.create_task(...)
  93. # 1 / 0
  94. #
  95. self._abort()
  96. # We use while-loop here because "self._on_completed_fut"
  97. # can be cancelled multiple times if our parent task
  98. # is being cancelled repeatedly (or even once, when
  99. # our own cancellation is already in progress)
  100. while self._tasks:
  101. if self._on_completed_fut is None:
  102. self._on_completed_fut = self._loop.create_future()
  103. try:
  104. await self._on_completed_fut
  105. except exceptions.CancelledError as ex:
  106. if not self._aborting:
  107. # Our parent task is being cancelled:
  108. #
  109. # async def wrapper():
  110. # async with TaskGroup() as g:
  111. # g.create_task(foo)
  112. #
  113. # "wrapper" is being cancelled while "foo" is
  114. # still running.
  115. propagate_cancellation_error = ex
  116. self._abort()
  117. self._on_completed_fut = None
  118. assert not self._tasks
  119. if self._base_error is not None:
  120. try:
  121. raise self._base_error
  122. finally:
  123. exc = None
  124. # Propagate CancelledError if there is one, except if there
  125. # are other errors -- those have priority.
  126. try:
  127. if propagate_cancellation_error and not self._errors:
  128. try:
  129. raise propagate_cancellation_error
  130. finally:
  131. exc = None
  132. finally:
  133. propagate_cancellation_error = None
  134. if et is not None and et is not exceptions.CancelledError:
  135. self._errors.append(exc)
  136. if self._errors:
  137. try:
  138. raise BaseExceptionGroup(
  139. 'unhandled errors in a TaskGroup',
  140. self._errors,
  141. ) from None
  142. finally:
  143. exc = None
  144. def create_task(self, coro, *, name=None, context=None):
  145. """Create a new task in this group and return it.
  146. Similar to `asyncio.create_task`.
  147. """
  148. if not self._entered:
  149. raise RuntimeError(f"TaskGroup {self!r} has not been entered")
  150. if self._exiting and not self._tasks:
  151. raise RuntimeError(f"TaskGroup {self!r} is finished")
  152. if self._aborting:
  153. raise RuntimeError(f"TaskGroup {self!r} is shutting down")
  154. if context is None:
  155. task = self._loop.create_task(coro)
  156. else:
  157. task = self._loop.create_task(coro, context=context)
  158. tasks._set_task_name(task, name)
  159. # optimization: Immediately call the done callback if the task is
  160. # already done (e.g. if the coro was able to complete eagerly),
  161. # and skip scheduling a done callback
  162. if task.done():
  163. self._on_task_done(task)
  164. else:
  165. self._tasks.add(task)
  166. task.add_done_callback(self._on_task_done)
  167. return task
  168. # Since Python 3.8 Tasks propagate all exceptions correctly,
  169. # except for KeyboardInterrupt and SystemExit which are
  170. # still considered special.
  171. def _is_base_error(self, exc: BaseException) -> bool:
  172. assert isinstance(exc, BaseException)
  173. return isinstance(exc, (SystemExit, KeyboardInterrupt))
  174. def _abort(self):
  175. self._aborting = True
  176. for t in self._tasks:
  177. if not t.done():
  178. t.cancel()
  179. def _on_task_done(self, task):
  180. self._tasks.discard(task)
  181. if self._on_completed_fut is not None and not self._tasks:
  182. if not self._on_completed_fut.done():
  183. self._on_completed_fut.set_result(True)
  184. if task.cancelled():
  185. return
  186. exc = task.exception()
  187. if exc is None:
  188. return
  189. self._errors.append(exc)
  190. if self._is_base_error(exc) and self._base_error is None:
  191. self._base_error = exc
  192. if self._parent_task.done():
  193. # Not sure if this case is possible, but we want to handle
  194. # it anyways.
  195. self._loop.call_exception_handler({
  196. 'message': f'Task {task!r} has errored out but its parent '
  197. f'task {self._parent_task} is already completed',
  198. 'exception': exc,
  199. 'task': task,
  200. })
  201. return
  202. if not self._aborting and not self._parent_cancel_requested:
  203. # If parent task *is not* being cancelled, it means that we want
  204. # to manually cancel it to abort whatever is being run right now
  205. # in the TaskGroup. But we want to mark parent task as
  206. # "not cancelled" later in __aexit__. Example situation that
  207. # we need to handle:
  208. #
  209. # async def foo():
  210. # try:
  211. # async with TaskGroup() as g:
  212. # g.create_task(crash_soon())
  213. # await something # <- this needs to be canceled
  214. # # by the TaskGroup, e.g.
  215. # # foo() needs to be cancelled
  216. # except Exception:
  217. # # Ignore any exceptions raised in the TaskGroup
  218. # pass
  219. # await something_else # this line has to be called
  220. # # after TaskGroup is finished.
  221. self._abort()
  222. self._parent_cancel_requested = True
  223. self._parent_task.cancel()