12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788 |
- import threading
- import traceback
- import warnings
- from types import TracebackType
- from typing import Any
- from typing import Callable
- from typing import Generator
- from typing import Optional
- from typing import Type
- import pytest
- # Copied from cpython/Lib/test/support/threading_helper.py, with modifications.
class catch_threading_exception:
    """Context manager that records an exception raised inside a thread.

    While the context is active, ``threading.excepthook`` is replaced by a
    hook that stores the ``threading.ExceptHookArgs`` it receives in
    ``self.args``. The previous hook is put back when the context exits.

    Keeping ``exc_value`` alive through a custom hook can create a reference
    cycle, so the cycle is broken explicitly on exit. Keeping ``thread`` alive
    through a custom hook can resurrect it if it is being finalized, so the
    stored object is cleared on exit as well.

    Usage:
        with threading_helper.catch_threading_exception() as cm:
            # code spawning a thread which raises an exception
            ...
            # check the thread exception: use cm.args
            ...
        # cm.args attribute no longer exists at this point
        # (to break a reference cycle)
    """

    def __init__(self) -> None:
        """Start with no saved hook and no captured exception."""
        self._old_hook: Optional[Callable[["threading.ExceptHookArgs"], Any]] = None
        self.args: Optional["threading.ExceptHookArgs"] = None

    def _hook(self, args: "threading.ExceptHookArgs") -> None:
        """Replacement for ``threading.excepthook``: remember what it was given."""
        self.args = args

    def __enter__(self) -> "catch_threading_exception":
        """Save the current ``threading.excepthook`` and install ``self._hook``."""
        # The right-hand side is evaluated first, so the hook being replaced
        # is the one that gets saved.
        self._old_hook, threading.excepthook = threading.excepthook, self._hook
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        """Restore the saved hook and drop the captured exception data."""
        saved_hook = self._old_hook
        assert saved_hook is not None
        threading.excepthook = saved_hook
        self._old_hook = None
        # Remove the attribute entirely (rather than resetting it to None) so
        # the captured exception/thread objects are released.
        del self.args
def thread_exception_runtest_hook() -> Generator[None, None, None]:
    """Shared generator body for the runtest hook wrappers in this module.

    Installs ``catch_threading_exception`` around the single ``yield``. If a
    thread exception was captured by the time control returns, it is
    formatted and emitted as a ``PytestUnhandledThreadExceptionWarning``.
    """
    with catch_threading_exception() as cm:
        yield
        # Everything below must run inside the ``with``: ``cm.args`` is
        # deleted as soon as the context manager exits.
        if not cm.args:
            return
        if cm.args.thread is not None:
            thread_name = cm.args.thread.name
        else:
            thread_name = "<unknown>"
        header = f"Exception in thread {thread_name}\n\n"
        formatted_lines = traceback.format_exception(
            cm.args.exc_type,
            cm.args.exc_value,
            cm.args.exc_traceback,
        )
        msg = header + "".join(formatted_lines)
        warnings.warn(pytest.PytestUnhandledThreadExceptionWarning(msg))
@pytest.hookimpl(hookwrapper=True, trylast=True)
def pytest_runtest_setup() -> Generator[None, None, None]:
    """Hook wrapper for the setup phase.

    Delegates to ``thread_exception_runtest_hook`` so that a thread exception
    captured while the wrapped phase runs is reported as a warning.
    """
    watcher = thread_exception_runtest_hook()
    yield from watcher
@pytest.hookimpl(hookwrapper=True, tryfirst=True)
def pytest_runtest_call() -> Generator[None, None, None]:
    """Hook wrapper for the call phase.

    Delegates to ``thread_exception_runtest_hook`` so that a thread exception
    captured while the wrapped phase runs is reported as a warning.
    """
    watcher = thread_exception_runtest_hook()
    yield from watcher
@pytest.hookimpl(hookwrapper=True, tryfirst=True)
def pytest_runtest_teardown() -> Generator[None, None, None]:
    """Hook wrapper for the teardown phase.

    Delegates to ``thread_exception_runtest_hook`` so that a thread exception
    captured while the wrapped phase runs is reported as a warning.
    """
    watcher = thread_exception_runtest_hook()
    yield from watcher
|