123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134 |
- from __future__ import absolute_import
- from __future__ import print_function
- import ast
- import py
- from _pytest.assertion import rewrite
- try:
- import importlib.util
- except ImportError:
- pass
- try:
- from pathlib import Path
- except ImportError:
- pass
- from __res import importer
- import sys
- import six
- def _get_state(config):
- if hasattr(config, '_assertstate'):
- return config._assertstate
- return config._store[rewrite.assertstate_key]
class AssertionRewritingHook(rewrite.AssertionRewritingHook):
    """pytest assertion-rewriting hook that reads sources via ``importer``.

    Modules named ``__tests__.*`` or ``*.conftest`` are always rewritten.
    Any other module is rewritten only if the base pytest hook has it marked
    for rewrite (``_is_marked_for_rewrite``). Source text and filenames come
    from ``importer`` rather than from the file system.
    """

    def __init__(self, *args, **kwargs):
        # name -> (code object, None). Filled by _find_module, consumed by
        # load_module.
        self.modules = {}
        super(AssertionRewritingHook, self).__init__(*args, **kwargs)

    def find_module(self, name, path=None):
        """Legacy finder entry point.

        Return this hook (acting as the loader) if *name* was rewritten,
        otherwise None.
        """
        co = self._find_module(name, path)
        if co is not None:
            return self
        return None

    def _find_module(self, name, path=None):
        """Rewrite module *name* if it qualifies and return its code object.

        :param name: dotted module name being imported.
        :param path: parent package search locations, or None for a
            top-level module.
        :returns: the rewritten code object, which is also cached in
            ``self.modules``. Returns None when:

            * the module should not be rewritten;
            * the module is a package;
            * ``is_package`` raises ImportError;
            * the source is unavailable or fails to parse/compile.
        """
        state = _get_state(self.config)
        if not self._should_rewrite(name, None, state):
            return None
        state.trace("find_module called for: %s" % name)

        try:
            if self.is_package(name):
                return None
        except ImportError:
            return None

        if hasattr(self._rewritten_names, 'add'):
            # The base hook tracks rewritten modules as a set of names.
            self._rewritten_names.add(name)
        elif path:
            # The base hook tracks rewritten modules as a name -> Path mapping.
            self._rewritten_names[name] = Path(path[0])
        else:
            # Top-level module: the import system passes path=None, so there
            # is no parent search location to subscript (path[0] would raise
            # TypeError). Record the module's own file instead.
            self._rewritten_names[name] = Path(importer.get_filename(name))

        state.trace("rewriting %s" % name)
        co = _rewrite_test(self.config, name)
        if co is None:
            # Probably a SyntaxError in the test.
            return None
        self.modules[name] = co, None
        return co

    def find_spec(self, name, path=None, target=None):
        """PEP 451 finder entry point.

        Return a module spec that uses this hook as its loader, or None when
        *name* is not rewritten by this hook.
        """
        co = self._find_module(name, path)
        if co is None:
            return None
        return importlib.util.spec_from_file_location(
            name,
            co.co_filename,
            loader=self,
        )

    def _should_rewrite(self, name, fn, state):
        """Decide whether *name* must be rewritten.

        Test modules (``__tests__.*``) and conftest modules are always
        rewritten; everything else defers to the base hook's marked-for-rewrite
        check. *fn* is accepted for signature compatibility and is not used.
        """
        if name.startswith("__tests__.") or name.endswith(".conftest"):
            return True
        return self._is_marked_for_rewrite(name, state)

    def is_package(self, name):
        """Delegate to ``importer.is_package``."""
        return importer.is_package(name)

    def get_source(self, name):
        """Delegate to ``importer.get_source``."""
        return importer.get_source(name)

    if six.PY3:
        def load_module(self, module):
            """Execute the cached rewritten code object inside *module*.

            Pops the code object stored by ``_find_module`` and sets the
            module's ``__file__``, ``__cached__``, ``__loader__`` and
            ``__spec__``. It then runs the code and returns the
            ``sys.modules`` entry for the module. On any failure the
            ``sys.modules`` entry is removed and the exception is re-raised.
            """
            co, _ = self.modules.pop(module.__name__)
            try:
                module.__file__ = co.co_filename
                module.__cached__ = None
                module.__loader__ = self
                module.__spec__ = importlib.util.spec_from_file_location(
                    module.__name__, co.co_filename, loader=self)
                exec(co, module.__dict__)  # noqa
            except BaseException:
                # Don't leave a half-initialised module behind in sys.modules.
                sys.modules.pop(module.__name__, None)
                raise
            return sys.modules[module.__name__]

        def exec_module(self, module):
            """PEP 451 loader entry point.

            Modules rewritten by this hook (present in ``self.modules``) are
            executed via ``load_module``. Anything else is delegated to the
            base pytest hook.
            """
            if module.__name__ in self.modules:
                self.load_module(module)
            else:
                super(AssertionRewritingHook, self).exec_module(module)
def _rewrite_test(config, name):
    """Read module *name* from ``importer``, rewrite its asserts and compile it.

    :param config: pytest config object. It is used to fetch the assertion
        state (for tracing) and is passed through to
        ``rewrite.rewrite_asserts``.
    :param name: dotted module name to look up in ``importer``.
    :returns: the compiled code object, or None in these cases:

        * ``importer`` has no source for *name*;
        * the source fails to parse;
        * the rewritten tree fails to compile.
    """
    state = _get_state(config)

    source = importer.get_source(name)
    if source is None:
        return None

    # Used as the filename for parsing/compiling and as the module path
    # handed to the assertion rewriter.
    path = importer.get_filename(name)

    try:
        tree = ast.parse(source, filename=path)
    except SyntaxError:
        # Let this pop up again in the real import.
        state.trace("failed to parse: %r" % (path,))
        return None

    # NOTE(review): this uses the (mod, module_path, config) call form. Later
    # pytest releases declare rewrite_asserts(mod, source, module_path, config).
    # Confirm it matches the pytest version in use.
    rewrite.rewrite_asserts(tree, py.path.local(path), config)
    try:
        co = compile(tree, path, "exec", dont_inherit=True)
    except SyntaxError:
        # It's possible that this error is from some bug in the
        # assertion rewriting, but I don't know of a fast way to tell.
        state.trace("failed to compile: %r" % (path,))
        return None
    return co
|