rewrite.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. from __future__ import absolute_import
  2. from __future__ import print_function
  3. import ast
  4. import py
  5. from _pytest.assertion import rewrite
  6. try:
  7. import importlib.util
  8. except ImportError:
  9. pass
  10. try:
  11. from pathlib import Path
  12. except ImportError:
  13. pass
  14. from __res import importer
  15. import sys
  16. import six
  def _get_state(config):
      """Return the assertion state object stored on *config*.

      When *config* has an ``_assertstate`` attribute, that attribute is
      returned.  Otherwise the state is looked up in ``config._store`` under
      the key ``rewrite.assertstate_key``.
      """
      if not hasattr(config, '_assertstate'):
          return config._store[rewrite.assertstate_key]
      return config._assertstate
  class AssertionRewritingHook(rewrite.AssertionRewritingHook):
      """Assertion-rewriting import hook backed by ``importer``.

      Extends ``rewrite.AssertionRewritingHook``.  Package checks and source
      lookups are delegated to ``importer``, and rewritten code objects are
      kept in ``self.modules`` between the "find" and the "load/exec" steps.
      """

      def __init__(self, *args, **kwargs):
          """Create the hook; all arguments are forwarded to the base class."""
          # Maps module name -> ``(code_object, None)``.  Entries are added by
          # ``_find_module`` and removed again by ``load_module``.
          self.modules = {}
          super(AssertionRewritingHook, self).__init__(*args, **kwargs)

      def find_module(self, name, path=None):
          """Return this hook as the loader for *name*, or None.

          None is returned whenever ``_find_module`` yields no code object.
          """
          co = self._find_module(name, path)
          if co is not None:
              return self
          return None

      def _find_module(self, name, path=None):
          """Rewrite the module *name* and return its code object, or None.

          None is returned when the module should not be rewritten, when it is
          a package (or the package check raises ``ImportError``), or when
          ``_rewrite_test`` produces no code.  On success the code object is
          also stored in ``self.modules`` under *name*.

          *path* is the search path handed over by the import machinery; it
          may be None.
          """
          state = _get_state(self.config)
          if not self._should_rewrite(name, None, state):
              return None
          state.trace("find_module called for: %s" % name)

          try:
              if self.is_package(name):
                  return None
          except ImportError:
              return None

          # ``_rewritten_names`` is either set-like (has ``add``) or a mapping.
          if hasattr(self._rewritten_names, 'add'):
              self._rewritten_names.add(name)
          else:
              # The import system passes ``path=None`` for top-level modules,
              # so ``path[0]`` alone would raise TypeError.  Fall back to the
              # module's own file name as reported by ``importer``.
              location = path[0] if path else importer.get_filename(name)
              self._rewritten_names[name] = Path(location)

          state.trace("rewriting %s" % name)
          co = _rewrite_test(self.config, name)
          if co is None:
              # Probably a SyntaxError in the test.
              return None
          self.modules[name] = co, None
          return co

      def find_spec(self, name, path=None, target=None):
          """Return a module spec for *name* with this hook as loader, or None.

          The spec location is the ``co_filename`` of the rewritten code
          object.  *target* is accepted for interface compatibility and is
          not used.
          """
          co = self._find_module(name, path)
          if co is not None:
              return importlib.util.spec_from_file_location(
                  name,
                  co.co_filename,
                  loader=self,
              )
          return None

      def _should_rewrite(self, name, fn, state):
          """Tell whether assertions in the module *name* should be rewritten.

          Names starting with ``"__tests__."`` or ending with ``".conftest"``
          always qualify; anything else is decided by
          ``self._is_marked_for_rewrite``.  *fn* is not used here.
          """
          if name.startswith("__tests__.") or name.endswith(".conftest"):
              return True

          return self._is_marked_for_rewrite(name, state)

      def is_package(self, name):
          """Return ``importer.is_package(name)``."""
          return importer.is_package(name)

      def get_source(self, name):
          """Return ``importer.get_source(name)``."""
          return importer.get_source(name)

      if six.PY3:
          def load_module(self, module):
              """Execute the stored rewritten code inside *module*.

              The code object is removed from ``self.modules`` (``KeyError``
              if there is none for this module).  If anything fails while the
              module attributes are set or the code runs, the module is
              dropped from ``sys.modules`` and the error is re-raised.

              Returns ``sys.modules[module.__name__]``.
              """
              co, _ = self.modules.pop(module.__name__)
              try:
                  module.__file__ = co.co_filename
                  module.__cached__ = None
                  module.__loader__ = self
                  module.__spec__ = importlib.util.spec_from_file_location(
                      module.__name__, co.co_filename, loader=self
                  )
                  exec(co, module.__dict__)  # noqa
              except BaseException:
                  # Clean up after any failure (including interrupts) and let
                  # the original error propagate unchanged.
                  if module.__name__ in sys.modules:
                      del sys.modules[module.__name__]
                  raise
              return sys.modules[module.__name__]

          def exec_module(self, module):
              """Run *module* from stored rewritten code, else defer to the base class."""
              if module.__name__ in self.modules:
                  self.load_module(module)
              else:
                  super(AssertionRewritingHook, self).exec_module(module)
  def _rewrite_test(config, name):
      """Build the assertion-rewritten code object for the module *name*.

      Source text and file name are requested from ``importer``.  The parsed
      tree is handed to ``rewrite.rewrite_asserts`` and then compiled.

      Returns the compiled code object, or None when ``importer`` has no
      source for *name* or when parsing or compiling raises ``SyntaxError``
      (both failures are reported through the state's ``trace``).
      """
      tracer = _get_state(config)

      text = importer.get_source(name)
      if text is None:
          return None

      filename = importer.get_filename(name)
      try:
          module_ast = ast.parse(text, filename=filename)
      except SyntaxError:
          # Stay quiet here; the error will surface again in the real import.
          tracer.trace("failed to parse: %r" % (filename,))
          return None

      # NOTE(review): this positional call fits a
      # ``rewrite_asserts(mod, module_path, config)`` signature -- confirm
      # against the pytest version this file is used with.
      rewrite.rewrite_asserts(module_ast, py.path.local(filename), config)

      try:
          return compile(module_ast, filename, "exec", dont_inherit=True)
      except SyntaxError:
          # Might be caused by a bug in the assertion rewriting itself, but
          # there is no known fast way to tell the two cases apart.
          tracer.trace("failed to compile: %r" % (filename,))
          return None