async_case.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. import asyncio
  2. import contextvars
  3. import inspect
  4. import warnings
  5. from .case import TestCase
# Module-level marker.  NOTE(review): presumably this is the flag unittest
# looks for in a frame's globals to omit this module's own frames from
# reported failure tracebacks -- confirm against unittest.result.
__unittest = True
  7. class IsolatedAsyncioTestCase(TestCase):
  8. # Names intentionally have a long prefix
  9. # to reduce a chance of clashing with user-defined attributes
  10. # from inherited test case
  11. #
  12. # The class doesn't call loop.run_until_complete(self.setUp()) and family
  13. # but uses a different approach:
  14. # 1. create a long-running task that reads self.setUp()
  15. # awaitable from queue along with a future
  16. # 2. await the awaitable object passing in and set the result
  17. # into the future object
  18. # 3. Outer code puts the awaitable and the future object into a queue
  19. # with waiting for the future
  20. # The trick is necessary because every run_until_complete() call
  21. # creates a new task with embedded ContextVar context.
  22. # To share contextvars between setUp(), test and tearDown() we need to execute
  23. # them inside the same task.
  24. # Note: the test case modifies event loop policy if the policy was not instantiated
  25. # yet.
  26. # asyncio.get_event_loop_policy() creates a default policy on demand but never
  27. # returns None
  28. # I believe this is not an issue in user level tests but python itself for testing
  29. # should reset a policy in every test module
  30. # by calling asyncio.set_event_loop_policy(None) in tearDownModule()
  31. def __init__(self, methodName='runTest'):
  32. super().__init__(methodName)
  33. self._asyncioRunner = None
  34. self._asyncioTestContext = contextvars.copy_context()
  35. async def asyncSetUp(self):
  36. pass
  37. async def asyncTearDown(self):
  38. pass
  39. def addAsyncCleanup(self, func, /, *args, **kwargs):
  40. # A trivial trampoline to addCleanup()
  41. # the function exists because it has a different semantics
  42. # and signature:
  43. # addCleanup() accepts regular functions
  44. # but addAsyncCleanup() accepts coroutines
  45. #
  46. # We intentionally don't add inspect.iscoroutinefunction() check
  47. # for func argument because there is no way
  48. # to check for async function reliably:
  49. # 1. It can be "async def func()" itself
  50. # 2. Class can implement "async def __call__()" method
  51. # 3. Regular "def func()" that returns awaitable object
  52. self.addCleanup(*(func, *args), **kwargs)
  53. async def enterAsyncContext(self, cm):
  54. """Enters the supplied asynchronous context manager.
  55. If successful, also adds its __aexit__ method as a cleanup
  56. function and returns the result of the __aenter__ method.
  57. """
  58. # We look up the special methods on the type to match the with
  59. # statement.
  60. cls = type(cm)
  61. try:
  62. enter = cls.__aenter__
  63. exit = cls.__aexit__
  64. except AttributeError:
  65. raise TypeError(f"'{cls.__module__}.{cls.__qualname__}' object does "
  66. f"not support the asynchronous context manager protocol"
  67. ) from None
  68. result = await enter(cm)
  69. self.addAsyncCleanup(exit, cm, None, None, None)
  70. return result
  71. def _callSetUp(self):
  72. # Force loop to be initialized and set as the current loop
  73. # so that setUp functions can use get_event_loop() and get the
  74. # correct loop instance.
  75. self._asyncioRunner.get_loop()
  76. self._asyncioTestContext.run(self.setUp)
  77. self._callAsync(self.asyncSetUp)
  78. def _callTestMethod(self, method):
  79. if self._callMaybeAsync(method) is not None:
  80. warnings.warn(f'It is deprecated to return a value that is not None from a '
  81. f'test case ({method})', DeprecationWarning, stacklevel=4)
  82. def _callTearDown(self):
  83. self._callAsync(self.asyncTearDown)
  84. self._asyncioTestContext.run(self.tearDown)
  85. def _callCleanup(self, function, *args, **kwargs):
  86. self._callMaybeAsync(function, *args, **kwargs)
  87. def _callAsync(self, func, /, *args, **kwargs):
  88. assert self._asyncioRunner is not None, 'asyncio runner is not initialized'
  89. assert inspect.iscoroutinefunction(func), f'{func!r} is not an async function'
  90. return self._asyncioRunner.run(
  91. func(*args, **kwargs),
  92. context=self._asyncioTestContext
  93. )
  94. def _callMaybeAsync(self, func, /, *args, **kwargs):
  95. assert self._asyncioRunner is not None, 'asyncio runner is not initialized'
  96. if inspect.iscoroutinefunction(func):
  97. return self._asyncioRunner.run(
  98. func(*args, **kwargs),
  99. context=self._asyncioTestContext,
  100. )
  101. else:
  102. return self._asyncioTestContext.run(func, *args, **kwargs)
  103. def _setupAsyncioRunner(self):
  104. assert self._asyncioRunner is None, 'asyncio runner is already initialized'
  105. runner = asyncio.Runner(debug=True)
  106. self._asyncioRunner = runner
  107. def _tearDownAsyncioRunner(self):
  108. runner = self._asyncioRunner
  109. runner.close()
  110. def run(self, result=None):
  111. self._setupAsyncioRunner()
  112. try:
  113. return super().run(result)
  114. finally:
  115. self._tearDownAsyncioRunner()
  116. def debug(self):
  117. self._setupAsyncioRunner()
  118. super().debug()
  119. self._tearDownAsyncioRunner()
  120. def __del__(self):
  121. if self._asyncioRunner is not None:
  122. self._tearDownAsyncioRunner()