_warnings.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. # From scikit-image: https://github.com/scikit-image/scikit-image/blob/c2f8c4ab123ebe5f7b827bc495625a32bb225c10/skimage/_shared/_warnings.py
  2. # Licensed under modified BSD license
  3. from __future__ import annotations
  4. __all__ = ["all_warnings", "expected_warnings"]
  5. import inspect
  6. import os
  7. import re
  8. import sys
  9. import warnings
  10. from contextlib import contextmanager
  11. from unittest import mock
  12. @contextmanager
  13. def all_warnings():
  14. """
  15. Context for use in testing to ensure that all warnings are raised.
  16. Examples
  17. --------
  18. >>> import warnings
  19. >>> def foo():
  20. ... warnings.warn(RuntimeWarning("bar"))
  21. We raise the warning once, while the warning filter is set to "once".
  22. Hereafter, the warning is invisible, even with custom filters:
  23. >>> with warnings.catch_warnings():
  24. ... warnings.simplefilter('once')
  25. ... foo()
  26. We can now run ``foo()`` without a warning being raised:
  27. >>> from numpy.testing import assert_warns # doctest: +SKIP
  28. >>> foo() # doctest: +SKIP
  29. To catch the warning, we call in the help of ``all_warnings``:
  30. >>> with all_warnings(): # doctest: +SKIP
  31. ... assert_warns(RuntimeWarning, foo)
  32. """
  33. # Whenever a warning is triggered, Python adds a __warningregistry__
  34. # member to the *calling* module. The exercise here is to find
  35. # and eradicate all those breadcrumbs that were left lying around.
  36. #
  37. # We proceed by first searching all parent calling frames and explicitly
  38. # clearing their warning registries (necessary for the doctests above to
  39. # pass). Then, we search for all submodules of skimage and clear theirs
  40. # as well (necessary for the skimage test suite to pass).
  41. frame = inspect.currentframe()
  42. if frame:
  43. for f in inspect.getouterframes(frame):
  44. f[0].f_locals["__warningregistry__"] = {}
  45. del frame
  46. for _, mod in list(sys.modules.items()):
  47. try:
  48. mod.__warningregistry__.clear()
  49. except AttributeError:
  50. pass
  51. with warnings.catch_warnings(record=True) as w, mock.patch.dict(
  52. os.environ, {"TRAITLETS_ALL_DEPRECATIONS": "1"}
  53. ):
  54. warnings.simplefilter("always")
  55. yield w
  56. @contextmanager
  57. def expected_warnings(matching):
  58. r"""Context for use in testing to catch known warnings matching regexes
  59. Parameters
  60. ----------
  61. matching : list of strings or compiled regexes
  62. Regexes for the desired warning to catch
  63. Examples
  64. --------
  65. >>> from skimage import data, img_as_ubyte, img_as_float # doctest: +SKIP
  66. >>> with expected_warnings(["precision loss"]): # doctest: +SKIP
  67. ... d = img_as_ubyte(img_as_float(data.coins())) # doctest: +SKIP
  68. Notes
  69. -----
  70. Uses `all_warnings` to ensure all warnings are raised.
  71. Upon exiting, it checks the recorded warnings for the desired matching
  72. pattern(s).
  73. Raises a ValueError if any match was not found or an unexpected
  74. warning was raised.
  75. Allows for three types of behaviors: "and", "or", and "optional" matches.
  76. This is done to accommodate different build environments or loop conditions
  77. that may produce different warnings. The behaviors can be combined.
  78. If you pass multiple patterns, you get an orderless "and", where all of the
  79. warnings must be raised.
  80. If you use the "|" operator in a pattern, you can catch one of several warnings.
  81. Finally, you can use "|\A\Z" in a pattern to signify it as optional.
  82. """
  83. with all_warnings() as w:
  84. # enter context
  85. yield w
  86. # exited user context, check the recorded warnings
  87. remaining = [m for m in matching if r"\A\Z" not in m.split("|")]
  88. for warn in w:
  89. found = False
  90. for match in matching:
  91. if re.search(match, str(warn.message)) is not None:
  92. found = True
  93. if match in remaining:
  94. remaining.remove(match)
  95. if not found:
  96. raise ValueError("Unexpected warning: %s" % str(warn.message))
  97. if len(remaining) > 0:
  98. msg = "No warning raised matching:\n%s" % "\n".join(remaining)
  99. raise ValueError(msg)