12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072 |
- # -*- coding: utf-8 -*-
- """Rewrite assertion AST to produce nice error messages"""
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import warnings
- warnings.filterwarnings("ignore", category=DeprecationWarning, module="_pytest.assertion.rewrite")
- import ast
- import errno
- import imp
- import itertools
- import marshal
- import os
- import re
- import string
- import struct
- import sys
- import types
- import atomicwrites
- import py
- import six
- from _pytest._io.saferepr import saferepr
- from _pytest.assertion import util
- from _pytest.assertion.util import ( # noqa: F401
- format_explanation as _format_explanation,
- )
- from _pytest.compat import spec_from_file_location
- from _pytest.pathlib import fnmatch_ex
- from _pytest.pathlib import PurePath
# pytest caches rewritten pycs in __pycache__.
if hasattr(imp, "get_tag"):
    PYTEST_TAG = imp.get_tag() + "-PYTEST"
else:
    # Interpreters whose imp module lacks get_tag() (Python 2): build an
    # equivalent "<impl>-<major><minor>" tag by hand.
    if hasattr(sys, "pypy_version_info"):
        impl = "pypy"
    elif sys.platform == "java":
        impl = "jython"
    else:
        impl = "cpython"
    ver = sys.version_info
    PYTEST_TAG = "%s-%s%s-PYTEST" % (impl, ver[0], ver[1])
    del ver, impl

# ".pyc" normally, ".pyo" when running with -O (where __debug__ is False).
PYC_EXT = ".py" + ("c" if __debug__ else "o")
# Suffix appended to a module's basename to form its cached pyc file name.
PYC_TAIL = "." + PYTEST_TAG + PYC_EXT

# Python 2 assumes ASCII source when no coding declaration is present;
# _rewrite_test emulates that check.
ASCII_IS_DEFAULT_ENCODING = sys.version_info[0] < 3

if sys.version_info >= (3, 5):
    ast_Call = ast.Call
else:

    def ast_Call(a, b, c):
        """Build an ast.Call on Python < 3.5, whose Call node also takes the
        ``starargs`` and ``kwargs`` fields (passed as None here)."""
        return ast.Call(a, b, c, None, None)
class AssertionRewritingHook(object):
    """PEP302 Import hook which rewrites asserts.

    find_module() decides whether a module should be rewritten and, if so,
    prepares its rewritten code object (reusing or writing a cached pyc in
    the module's ``__pycache__`` directory); load_module() then executes
    that code in a module object.
    """

    def __init__(self, config):
        self.config = config
        try:
            self.fnpats = config.getini("python_files")
        except ValueError:
            # config.getini raised ValueError (presumably "python_files" is
            # not a registered ini option); use the default test-file
            # patterns instead.
            self.fnpats = ["test_*.py", "*_test.py"]
        self.session = None
        # module name -> (code object, pyc path); filled by find_module()
        # and consumed (popped) by load_module()
        self.modules = {}
        # names find_module() has accepted for rewriting
        self._rewritten_names = set()
        # names passed to mark_rewrite(); their submodules match as well
        self._must_rewrite = set()
        # flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
        # which might result in infinite recursion (#3506)
        self._writing_pyc = False
        # module basenames that _early_rewrite_bailout() never skips
        self._basenames_to_check_rewrite = {"conftest"}
        # memo for _is_marked_for_rewrite(); cleared by mark_rewrite()
        self._marked_for_rewrite_cache = {}
        # whether the session's initial paths were already added to
        # _basenames_to_check_rewrite
        self._session_paths_checked = False

    def set_session(self, session):
        """Attach the test *session*.

        Its initial paths are (re)scanned by _early_rewrite_bailout() on the
        next lookup.
        """
        self.session = session
        self._session_paths_checked = False

    def _imp_find_module(self, name, path=None):
        """Indirection so we can mock calls to find_module originated from the hook during testing"""
        return imp.find_module(name, path)

    def find_module(self, name, path=None):
        """PEP302 finder entry point.

        Return ``self`` (acting as the loader) when module *name* should have
        its asserts rewritten; its code object and pyc path are then stored
        in ``self.modules`` for load_module(). Return None to leave the
        import to the normal import machinery.
        """
        if self._writing_pyc:
            # Re-entered while _write_pyc() is running (see __init__, #3506).
            return None
        state = self.config._assertstate
        if self._early_rewrite_bailout(name, state):
            return None
        state.trace("find_module called for: %s" % name)
        names = name.rsplit(".", 1)
        lastname = names[-1]
        pth = None
        if path is not None:
            # Starting with Python 3.3, path is a _NamespacePath(), which
            # causes problems if not converted to list.
            path = list(path)
            if len(path) == 1:
                pth = path[0]
        if pth is None:
            try:
                fd, fn, desc = self._imp_find_module(lastname, path)
            except ImportError:
                return None
            if fd is not None:
                fd.close()

            tp = desc[2]
            if tp == imp.PY_COMPILED:
                # Found only a compiled file: map it back to its source path.
                if hasattr(imp, "source_from_cache"):
                    try:
                        fn = imp.source_from_cache(fn)
                    except ValueError:
                        # Python 3 doesn't like orphaned but still-importable
                        # .pyc files.
                        fn = fn[:-1]
                else:
                    fn = fn[:-1]
            elif tp != imp.PY_SOURCE:
                # Don't know what this is.
                return None
        else:
            # A single search directory: assume the module is <dir>/<last>.py
            fn = os.path.join(pth, name.rpartition(".")[2] + ".py")

        fn_pypath = py.path.local(fn)
        if not self._should_rewrite(name, fn_pypath, state):
            return None

        self._rewritten_names.add(name)

        # The requested module looks like a test file, so rewrite it. This is
        # the most magical part of the process: load the source, rewrite the
        # asserts, and load the rewritten source. We also cache the rewritten
        # module code in a special pyc. We must be aware of the possibility of
        # concurrent pytest processes rewriting and loading pycs. To avoid
        # tricky race conditions, we maintain the following invariant: The
        # cached pyc is always a complete, valid pyc. Operations on it must be
        # atomic. POSIX's atomic rename comes in handy.
        write = not sys.dont_write_bytecode
        cache_dir = os.path.join(fn_pypath.dirname, "__pycache__")
        if write:
            try:
                os.mkdir(cache_dir)
            except OSError:
                # errno of the OSError currently being handled
                e = sys.exc_info()[1].errno
                if e == errno.EEXIST:
                    # Either the __pycache__ directory already exists (the
                    # common case) or it's blocked by a non-dir node. In the
                    # latter case, we'll ignore it in _write_pyc.
                    pass
                elif e in [errno.ENOENT, errno.ENOTDIR]:
                    # One of the path components was not a directory, likely
                    # because we're in a zip file.
                    write = False
                elif e in [errno.EACCES, errno.EROFS, errno.EPERM]:
                    state.trace("read only directory: %r" % fn_pypath.dirname)
                    write = False
                else:
                    raise
        # "<module basename without .py>" + PYC_TAIL
        cache_name = fn_pypath.basename[:-3] + PYC_TAIL
        pyc = os.path.join(cache_dir, cache_name)
        # Notice that even if we're in a read-only directory, I'm going
        # to check for a cached pyc. This may not be optimal...
        co = _read_pyc(fn_pypath, pyc, state.trace)
        if co is None:
            state.trace("rewriting %r" % (fn,))
            source_stat, co = _rewrite_test(self.config, fn_pypath)
            if co is None:
                # Probably a SyntaxError in the test.
                return None
            if write:
                # Set the re-entrancy guard checked at the top of this method.
                self._writing_pyc = True
                try:
                    _write_pyc(state, co, source_stat, pyc)
                finally:
                    self._writing_pyc = False
        else:
            state.trace("found cached rewritten pyc for %r" % (fn,))
        self.modules[name] = co, pyc
        return self

    def _early_rewrite_bailout(self, name, state):
        """
        This is a fast way to get out of rewriting modules. Profiling has
        shown that the call to imp.find_module (inside of the find_module
        from this class) is a major slowdown, so, this method tries to
        filter what we're sure won't be rewritten before getting to it.

        Return True when *name* can be skipped, False when find_module()
        must perform the full check.
        """
        if self.session is not None and not self._session_paths_checked:
            self._session_paths_checked = True
            for path in self.session._initialpaths:
                # Make something as c:/projects/my_project/path.py ->
                #     ['c:', 'projects', 'my_project', 'path.py']
                parts = str(path).split(os.path.sep)
                # add 'path' to basenames to be checked.
                self._basenames_to_check_rewrite.add(os.path.splitext(parts[-1])[0])

        # Note: conftest already by default in _basenames_to_check_rewrite.
        parts = name.split(".")
        if parts[-1] in self._basenames_to_check_rewrite:
            return False

        # For matching the name it must be as if it was a filename.
        path = PurePath(os.path.sep.join(parts) + ".py")

        for pat in self.fnpats:
            # if the pattern contains subdirectories ("tests/**.py" for example) we can't bail out based
            # on the name alone because we need to match against the full path
            if os.path.dirname(pat):
                return False
            if fnmatch_ex(pat, path):
                return False

        if self._is_marked_for_rewrite(name, state):
            return False

        state.trace("early skip of rewriting module: %s" % (name,))
        return True

    def _should_rewrite(self, name, fn_pypath, state):
        """Return True if module *name*, located at *fn_pypath*, should have
        its asserts rewritten.

        That is the case for conftest.py files, files that are initial
        session paths (given on the command line), files matching the
        python_files patterns, and modules registered via mark_rewrite().
        """
        # always rewrite conftest files
        fn = str(fn_pypath)
        if fn_pypath.basename == "conftest.py":
            state.trace("rewriting conftest file: %r" % (fn,))
            return True

        if self.session is not None:
            if self.session.isinitpath(fn):
                state.trace("matched test file (was specified on cmdline): %r" % (fn,))
                return True

        # modules not passed explicitly on the command line are only
        # rewritten if they match the naming convention for test files
        for pat in self.fnpats:
            if fn_pypath.fnmatch(pat):
                state.trace("matched test file %r" % (fn,))
                return True

        return self._is_marked_for_rewrite(name, state)

    def _is_marked_for_rewrite(self, name, state):
        """Return True if *name* equals, or is a submodule of, a name passed
        to mark_rewrite(). Results are memoized in _marked_for_rewrite_cache.
        """
        try:
            return self._marked_for_rewrite_cache[name]
        except KeyError:
            for marked in self._must_rewrite:
                if name == marked or name.startswith(marked + "."):
                    state.trace("matched marked file %r (from %r)" % (name, marked))
                    self._marked_for_rewrite_cache[name] = True
                    return True

            self._marked_for_rewrite_cache[name] = False
            return False

    def mark_rewrite(self, *names):
        """Mark import names as needing to be rewritten.

        The named module or package as well as any nested modules will
        be rewritten on import.
        """
        # Modules already in sys.modules that this hook did not rewrite can
        # no longer be rewritten: warn, unless their docstring opts out via
        # the PYTEST_DONT_REWRITE marker.
        already_imported = (
            set(names).intersection(sys.modules).difference(self._rewritten_names)
        )
        for name in already_imported:
            if not AssertionRewriter.is_rewrite_disabled(
                sys.modules[name].__doc__ or ""
            ):
                self._warn_already_imported(name)
        self._must_rewrite.update(names)
        # Cached False answers may now be wrong for the new names.
        self._marked_for_rewrite_cache.clear()

    def _warn_already_imported(self, name):
        """Issue a PytestAssertRewriteWarning saying that module *name* was
        imported before it could be rewritten."""
        from _pytest.warning_types import PytestAssertRewriteWarning
        from _pytest.warnings import _issue_warning_captured

        _issue_warning_captured(
            PytestAssertRewriteWarning(
                "Module already imported so cannot be rewritten: %s" % name
            ),
            self.config.hook,
            stacklevel=5,
        )

    def load_module(self, name):
        """PEP302 loader entry point.

        Execute the rewritten code prepared by find_module() for *name* and
        return the module from sys.modules. If execution fails, the module
        is removed from sys.modules and the exception is re-raised.
        """
        co, pyc = self.modules.pop(name)
        if name in sys.modules:
            # If there is an existing module object named 'fullname' in
            # sys.modules, the loader must use that existing module. (Otherwise,
            # the reload() builtin will not work correctly.)
            mod = sys.modules[name]
        else:
            # I wish I could just call imp.load_compiled here, but __file__ has to
            # be set properly. In Python 3.2+, this all would be handled correctly
            # by load_compiled.
            mod = sys.modules[name] = imp.new_module(name)
        try:
            mod.__file__ = co.co_filename
            # Normally, this attribute is 3.2+.
            mod.__cached__ = pyc
            mod.__loader__ = self
            # Normally, this attribute is 3.4+
            mod.__spec__ = spec_from_file_location(name, co.co_filename, loader=self)
            exec(co, mod.__dict__)
        except:  # noqa
            # Deliberately bare: on any failure, drop the half-initialized
            # module from sys.modules, then re-raise.
            if name in sys.modules:
                del sys.modules[name]
            raise
        # Returns the sys.modules entry rather than ``mod`` (presumably so a
        # module that replaced itself in sys.modules is honored).
        return sys.modules[name]

    def is_package(self, name):
        """Return True if imp.find_module() reports *name* as a package
        directory, False otherwise (including when it cannot be found)."""
        # NOTE(review): imp.find_module does not handle dotted names, so a
        # nested package name likely yields False here -- confirm intended.
        try:
            fd, fn, desc = self._imp_find_module(name)
        except ImportError:
            return False
        if fd is not None:
            fd.close()
        tp = desc[2]
        return tp == imp.PKG_DIRECTORY

    def get_data(self, pathname):
        """Optional PEP302 get_data API.

        Return the raw bytes of the file at *pathname*.
        """
        with open(pathname, "rb") as f:
            return f.read()
def _write_pyc(state, co, source_stat, pyc):
    """Atomically write the rewritten code object *co* to the file *pyc*.

    The layout is the interpreter's magic number, then the source file's
    mtime and size packed as two little-endian unsigned 32-bit integers,
    then the marshalled code object. _read_pyc() validates the same header.

    Returns True when the file was written. Any EnvironmentError
    (permission denied, __pycache__ being a file, ...) is traced via
    *state* and reported by returning False; it is never raised.
    """
    # These pycs are never loaded by the builtin import system, so the
    # (C)Python layout is not strictly required, but there is little reason
    # to deviate from it.
    try:
        with atomicwrites.atomic_write(pyc, mode="wb", overwrite=True) as out:
            out.write(imp.get_magic())
            # The header only has room for 32-bit mtime and size (#4903).
            header = struct.pack(
                "<LL",  # two unsigned 32-bit ints, little-endian
                int(source_stat.mtime) & 0xFFFFFFFF,
                source_stat.size & 0xFFFFFFFF,
            )
            out.write(header)
            out.write(marshal.dumps(co))
    except EnvironmentError as err:
        state.trace("error writing pyc file at %s: errno=%s" % (pyc, err.errno))
        return False
    return True
# Line endings as byte strings.
RN = b"\r\n"
N = b"\n"
# Source-encoding declaration ("coding cookie"); _rewrite_test looks for it
# on the first two lines of a module.
cookie_re = re.compile(r"^[ \t\f]*#.*coding[:=][ \t]*[-\w.]+")
# UTF-8 byte-order mark as a native str; only compared against raw source
# when ASCII_IS_DEFAULT_ENCODING (Python 2), see _rewrite_test.
BOM_UTF8 = "\xef\xbb\xbf"
def _rewrite_test(config, fn):
    """Try to read and rewrite *fn* and return the code object.

    *fn* is the py.path.local built by the import hook (its ``stat()``,
    ``read("rb")`` and ``strpath`` are used). Returns ``(stat, code)`` on
    success, where *stat* is the source's stat result (later used for the
    pyc header by _write_pyc). Returns ``(None, None)`` when the file cannot
    be read, parsed or compiled, and on Python 2 when it contains non-ASCII
    bytes without a coding declaration, so that the regular import reports
    the problem.
    """
    state = config._assertstate
    try:
        stat = fn.stat()
        source = fn.read("rb")
    except EnvironmentError:
        return None, None
    if ASCII_IS_DEFAULT_ENCODING:
        # ASCII is the default encoding in Python 2. Without a coding
        # declaration, Python 2 will complain about any bytes in the file
        # outside the ASCII range. Sadly, this behavior does not extend to
        # compile() or ast.parse(), which prefer to interpret the bytes as
        # latin-1. (At least they properly handle explicit coding cookies.) To
        # preserve this error behavior, we could force ast.parse() to use ASCII
        # as the encoding by inserting a coding cookie. Unfortunately, that
        # messes up line numbers. Thus, we have to check ourselves if anything
        # is outside the ASCII range in the case no encoding is explicitly
        # declared. For more context, see issue #269. Yay for Python 3 which
        # gets this right.
        # Offsets of the first two newlines: a coding cookie is only
        # honored on the first two lines (PEP 263).
        end1 = source.find("\n")
        end2 = source.find("\n", end1 + 1)
        if (
            not source.startswith(BOM_UTF8)
            and cookie_re.match(source[0:end1]) is None
            and cookie_re.match(source[end1 + 1 : end2]) is None
        ):
            # state._indecode marks an ASCII check in progress; if this is
            # re-entered meanwhile, give up instead of recursing.
            if hasattr(state, "_indecode"):
                # encodings imported us again, so don't rewrite.
                return None, None
            state._indecode = True
            try:
                try:
                    source.decode("ascii")
                except UnicodeDecodeError:
                    # Let it fail in real import.
                    return None, None
            finally:
                del state._indecode
    try:
        tree = ast.parse(source, filename=fn.strpath)
    except SyntaxError:
        # Let this pop up again in the real import.
        state.trace("failed to parse: %r" % (fn,))
        return None, None
    rewrite_asserts(tree, fn, config)
    try:
        co = compile(tree, fn.strpath, "exec", dont_inherit=True)
    except SyntaxError:
        # It's possible that this error is from some bug in the
        # assertion rewriting, but I don't know of a fast way to tell.
        state.trace("failed to compile: %r" % (fn,))
        return None, None
    return stat, co
- def _read_pyc(source, pyc, trace=lambda x: None):
- """Possibly read a pytest pyc containing rewritten code.
- Return rewritten code if successful or None if not.
- """
- try:
- fp = open(pyc, "rb")
- except IOError:
- return None
- with fp:
- try:
- mtime = int(source.mtime())
- size = source.size()
- data = fp.read(12)
- except EnvironmentError as e:
- trace("_read_pyc(%s): EnvironmentError %s" % (source, e))
- return None
- # Check for invalid or out of date pyc file.
- if (
- len(data) != 12
- or data[:4] != imp.get_magic()
- or struct.unpack("<LL", data[4:]) != (mtime & 0xFFFFFFFF, size & 0xFFFFFFFF)
- ):
- trace("_read_pyc(%s): invalid or out of date pyc" % source)
- return None
- try:
- co = marshal.load(fp)
- except Exception as e:
- trace("_read_pyc(%s): marshal.load error %s" % (source, e))
- return None
- if not isinstance(co, types.CodeType):
- trace("_read_pyc(%s): not a code object" % source)
- return None
- return co
def rewrite_asserts(mod, module_path=None, config=None):
    """Rewrite the assert statements of the ast.Module *mod* in place.

    *module_path* and *config* are passed on to AssertionRewriter.
    """
    rewriter = AssertionRewriter(module_path, config)
    rewriter.run(mod)
def _saferepr(obj):
    """Get a safe repr of an object for assertion error messages.

    The assertion formatting (util.format_explanation()) requires
    newlines to be escaped since they are a special character for it.
    Normally assertion.util.format_explanation() does this but for a
    custom repr it is possible to contain one of the special escape
    sequences, especially '\\n{' and '\\n}' are likely to be present in
    JSON reprs.
    """
    r = saferepr(obj)
    # only occurs in python2.x, repr must return text in python3+
    if isinstance(r, bytes):
        # Represent unprintable bytes as `\x##`. Always emit two hex digits,
        # as Python's own escapes do: with a single digit b"\x01a" would
        # render as "\x1a", which reads as a different byte.
        r = u"".join(
            u"\\x{:02x}".format(ord(c)) if c not in string.printable else c.decode()
            for c in r
        )
    return r.replace(u"\n", u"\\n")
def _format_assertmsg(obj):
    """Format the custom assertion message given.

    In strings, every newline becomes '\\n~' so util.format_explanation()
    keeps it instead of escaping it, and '%' is doubled because the message
    is later used as part of a %-format template. Other objects are first
    turned into text with saferepr(); the escaped '\\n' sequences in that
    repr also become '\\n~'. Byte strings are handled with byte patterns.
    """
    # reprlib appears to have a bug which means that if a string
    # contains a newline it gets escaped, however if an object has a
    # .__repr__() which contains newlines it does not get escaped.
    # However in either case we want to preserve the newline.
    pairs = [(u"\n", u"\n~"), (u"%", u"%%")]
    if not isinstance(obj, six.string_types):
        obj = saferepr(obj)
        pairs.append((u"\\n", u"\n~"))

    if isinstance(obj, bytes):
        pairs = [(old.encode(), new.encode()) for old, new in pairs]

    # Order matters: the literal newline pass must run before the "\\n"
    # pass, or newlines produced by the latter would get a second "~".
    for old, new in pairs:
        obj = obj.replace(old, new)

    return obj
- def _should_repr_global_name(obj):
- if callable(obj):
- return False
- try:
- return not hasattr(obj, "__name__")
- except Exception:
- return True
- def _format_boolop(explanations, is_or):
- explanation = "(" + (is_or and " or " or " and ").join(explanations) + ")"
- if isinstance(explanation, six.text_type):
- return explanation.replace(u"%", u"%%")
- else:
- return explanation.replace(b"%", b"%%")
def _call_reprcompare(ops, results, expls, each_obj):
    """Pick the explanation for a (possibly chained) comparison.

    Walks *results* alongside *expls* and stops at the first comparison
    whose result is falsy, or whose truth test raises. If a custom
    ``util._reprcompare`` hook is installed and returns a non-None text for
    that operator and its two operands (``each_obj[i]`` and
    ``each_obj[i + 1]``), that text is returned. Otherwise the matching
    entry of *expls* is returned.
    """
    for index, res, expl in zip(range(len(ops)), results, expls):
        try:
            failed = not res
        except Exception:
            failed = True
        if failed:
            break
    hook = util._reprcompare
    if hook is None:
        return expl
    custom = hook(ops[index], each_obj[index], each_obj[index + 1])
    return expl if custom is None else custom
# %-templates rendering a unary operator applied to an operand explanation.
unary_map = {
    ast.Not: "not %s",
    ast.Invert: "~%s",
    ast.USub: "-%s",
    ast.UAdd: "+%s",
}

# Source symbol of each binary/comparison operator, used in explanations.
binop_map = {
    ast.BitOr: "|",
    ast.BitXor: "^",
    ast.BitAnd: "&",
    ast.LShift: "<<",
    ast.RShift: ">>",
    ast.Add: "+",
    ast.Sub: "-",
    ast.Mult: "*",
    ast.Div: "/",
    ast.FloorDiv: "//",
    ast.Mod: "%%",  # escaped for string formatting
    ast.Eq: "==",
    ast.NotEq: "!=",
    ast.Lt: "<",
    ast.LtE: "<=",
    ast.Gt: ">",
    ast.GtE: ">=",
    ast.Pow: "**",
    ast.Is: "is",
    ast.IsNot: "is not",
    ast.In: "in",
    ast.NotIn: "not in",
}

# ast.MatMult only exists on Python 3.5+.
if hasattr(ast, "MatMult"):
    binop_map[ast.MatMult] = "@"

# ast.NameConstant only exists on Python 3.4+; before that True/False/None
# were plain names.
if not hasattr(ast, "NameConstant"):

    def _NameConstant(c):
        return ast.Name(str(c), ast.Load())


else:
    _NameConstant = ast.NameConstant
def set_location(node, lineno, col_offset):
    """Set node location information recursively.

    Every node in the subtree rooted at *node* whose ``_attributes``
    include them gets *lineno* and/or *col_offset*. Returns *node*.
    """
    for sub in ast.walk(node):
        fields = sub._attributes
        if "lineno" in fields:
            sub.lineno = lineno
        if "col_offset" in fields:
            sub.col_offset = col_offset
    return node
class AssertionRewriter(ast.NodeVisitor):
    """Assertion rewriting implementation.

    The main entrypoint is to call .run() with an ast.Module instance,
    this will then find all the assert statements and rewrite them to
    provide intermediate values and a detailed assertion error.  See
    http://pybites.blogspot.be/2011/07/behind-scenes-of-pytests-new-assertion.html
    for an overview of how this works.

    .run() iterates over all the statements in an ast.Module and calls
    .visit() for each ast.Assert statement it finds.  .visit_Assert() then
    creates the new ast statements that replace the original assert
    statement: it rewrites the test of the assertion to provide
    intermediate values and replaces the assert with an if statement
    which raises an AssertionError with a detailed explanation when the
    expression is false.

    To do so .visit_Assert() uses the visitor pattern to visit all the
    AST nodes of the ast.Assert.test field, each visit call returning
    an AST node and the corresponding explanation string.  Meanwhile,
    state is kept in several instance attributes:

    :statements: All the AST statements which will replace the assert
       statement.

    :variables: This is populated by .variable() with each variable
       used by the statements so that they can all be set to None at
       the end of the statements.

    :variable_counter: Counter to create new unique variables needed
       by statements.  Variables are created using .variable() and
       have the form of "@py_assert0".  The same counter also numbers
       the "pyN" placeholders of .explanation_param() and the
       "@py_formatN" names of .pop_format_context().

    :on_failure: The AST statements which will be executed if the
       assertion test fails.  This is the code which will construct
       the failure message and raise the AssertionError.

    :explanation_specifiers: A dict filled by .explanation_param()
       with %-formatting placeholders and their corresponding
       expressions to use in the building of an assertion message.
       This is used by .pop_format_context() to build a message.

    :stack: A stack of the explanation_specifiers dicts maintained by
       .push_format_context() and .pop_format_context() which allows
       building another %-formatted string while already building one.

    This state is reset on every new assert statement visited and used
    by the other visitors.
    """

    def __init__(self, module_path, config):
        super(AssertionRewriter, self).__init__()
        # Path of the module being rewritten, or None; it is used for the
        # warnings emitted from visit_Assert (no assert-None check when None).
        self.module_path = module_path
        # pytest config object (not read by the methods visible here).
        self.config = config
    def run(self, mod):
        """Find all assert statements in *mod* and rewrite them.

        *mod* (an ast.Module) is modified in place. Imports of the builtins
        module (as ``@py_builtins``) and of this module (as ``@pytest_ar``)
        are inserted after the docstring and any ``from __future__``
        imports. Every ast.Assert found in a statement list is then replaced
        by the statements returned from visit_Assert(). Nothing is changed
        if the module body is empty or its docstring contains the
        PYTEST_DONT_REWRITE marker.
        """
        if not mod.body:
            # Nothing to do.
            return
        # Insert some special imports at the top of the module but after any
        # docstrings and __future__ imports.
        aliases = [
            ast.alias(six.moves.builtins.__name__, "@py_builtins"),
            ast.alias("_pytest.assertion.rewrite", "@pytest_ar"),
        ]
        # ``mod.docstring`` is used when the attribute exists; otherwise the
        # docstring is detected as the first statement in the loop below.
        doc = getattr(mod, "docstring", None)
        expect_docstring = doc is None
        if doc is not None and self.is_rewrite_disabled(doc):
            return
        pos = 0
        lineno = 1
        for item in mod.body:
            if (
                expect_docstring
                and isinstance(item, ast.Expr)
                and isinstance(item.value, ast.Str)
            ):
                doc = item.value.s
                if self.is_rewrite_disabled(doc):
                    return
                expect_docstring = False
            elif (
                not isinstance(item, ast.ImportFrom)
                or item.level > 0
                or item.module != "__future__"
            ):
                # First statement that is neither the docstring nor an
                # absolute ``from __future__ import``: insert before it.
                lineno = item.lineno
                break
            pos += 1
        else:
            # No break: the body is only docstring/__future__ imports, so
            # the new imports go at the end with the last item's line.
            lineno = item.lineno
        imports = [
            ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases
        ]
        mod.body[pos:pos] = imports
        # Collect asserts.
        nodes = [mod]
        while nodes:
            node = nodes.pop()
            for name, field in ast.iter_fields(node):
                if isinstance(field, list):
                    new = []
                    for i, child in enumerate(field):
                        if isinstance(child, ast.Assert):
                            # Transform assert.
                            new.extend(self.visit(child))
                        else:
                            new.append(child)
                            if isinstance(child, ast.AST):
                                nodes.append(child)
                    setattr(node, name, new)
                elif (
                    isinstance(field, ast.AST)
                    # Don't recurse into expressions as they can't contain
                    # asserts.
                    and not isinstance(field, ast.expr)
                ):
                    nodes.append(field)
- @staticmethod
- def is_rewrite_disabled(docstring):
- return "PYTEST_DONT_REWRITE" in docstring
- def variable(self):
- """Get a new variable."""
- # Use a character invalid in python identifiers to avoid clashing.
- name = "@py_assert" + str(next(self.variable_counter))
- self.variables.append(name)
- return name
- def assign(self, expr):
- """Give *expr* a name."""
- name = self.variable()
- self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr))
- return ast.Name(name, ast.Load())
- def display(self, expr):
- """Call saferepr on the expression."""
- return self.helper("_saferepr", expr)
- def helper(self, name, *args):
- """Call a helper in this module."""
- py_name = ast.Name("@pytest_ar", ast.Load())
- attr = ast.Attribute(py_name, name, ast.Load())
- return ast_Call(attr, list(args), [])
- def builtin(self, name):
- """Return the builtin called *name*."""
- builtin_name = ast.Name("@py_builtins", ast.Load())
- return ast.Attribute(builtin_name, name, ast.Load())
- def explanation_param(self, expr):
- """Return a new named %-formatting placeholder for expr.
- This creates a %-formatting placeholder for expr in the
- current formatting context, e.g. ``%(py0)s``. The placeholder
- and expr are placed in the current format context so that it
- can be used on the next call to .pop_format_context().
- """
- specifier = "py" + str(next(self.variable_counter))
- self.explanation_specifiers[specifier] = expr
- return "%(" + specifier + ")s"
- def push_format_context(self):
- """Create a new formatting context.
- The format context is used for when an explanation wants to
- have a variable value formatted in the assertion message. In
- this case the value required can be added using
- .explanation_param(). Finally .pop_format_context() is used
- to format a string of %-formatted values as added by
- .explanation_param().
- """
- self.explanation_specifiers = {}
- self.stack.append(self.explanation_specifiers)
- def pop_format_context(self, expl_expr):
- """Format the %-formatted string with current format context.
- The expl_expr should be an ast.Str instance constructed from
- the %-placeholders created by .explanation_param(). This will
- add the required code to format said string to .on_failure and
- return the ast.Name instance of the formatted string.
- """
- current = self.stack.pop()
- if self.stack:
- self.explanation_specifiers = self.stack[-1]
- keys = [ast.Str(key) for key in current.keys()]
- format_dict = ast.Dict(keys, list(current.values()))
- form = ast.BinOp(expl_expr, ast.Mod(), format_dict)
- name = "@py_format" + str(next(self.variable_counter))
- self.on_failure.append(ast.Assign([ast.Name(name, ast.Store())], form))
- return ast.Name(name, ast.Load())
- def generic_visit(self, node):
- """Handle expressions we don't have custom code for."""
- assert isinstance(node, ast.expr)
- res = self.assign(node)
- return res, self.explanation_param(self.display(res))
    def visit_Assert(self, assert_):
        """Return the AST statements to replace the ast.Assert instance.

        This rewrites the test of an assertion to provide
        intermediate values and replace it with an if statement which
        raises an assertion error with a detailed explanation in case
        the expression is false.
        """
        if isinstance(assert_.test, ast.Tuple) and len(assert_.test.elts) >= 1:
            # e.g. ``assert (x, "msg")``: a non-empty tuple is always truthy.
            from _pytest.warning_types import PytestAssertRewriteWarning
            import warnings

            warnings.warn_explicit(
                PytestAssertRewriteWarning(
                    "assertion is always true, perhaps remove parentheses?"
                ),
                category=None,
                filename=str(self.module_path),
                lineno=assert_.lineno,
            )

        # Reset the per-assert state (see the class docstring).
        self.statements = []
        self.variables = []
        self.variable_counter = itertools.count()
        self.stack = []
        self.on_failure = []
        self.push_format_context()
        # Rewrite assert into a bunch of statements.
        top_condition, explanation = self.visit(assert_.test)
        # If in a test module, check if directly asserting None, in order to warn [Issue #3191]
        if self.module_path is not None:
            self.statements.append(
                self.warn_about_none_ast(
                    top_condition, module_path=self.module_path, lineno=assert_.lineno
                )
            )
        # Create failure message.
        # ``body`` is the same list as self.on_failure: pop_format_context()
        # below still appends the message-building code to it, and raise_
        # is appended last.
        body = self.on_failure
        negation = ast.UnaryOp(ast.Not(), top_condition)
        self.statements.append(ast.If(negation, body, []))
        if assert_.msg:
            # The custom message comes first; the explanation follows on its
            # own line.
            assertmsg = self.helper("_format_assertmsg", assert_.msg)
            explanation = "\n>assert " + explanation
        else:
            assertmsg = ast.Str("")
            explanation = "assert " + explanation
        template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation))
        msg = self.pop_format_context(template)
        fmt = self.helper("_format_explanation", msg)
        err_name = ast.Name("AssertionError", ast.Load())
        exc = ast_Call(err_name, [fmt], [])
        if sys.version_info[0] >= 3:
            raise_ = ast.Raise(exc, None)
        else:
            raise_ = ast.Raise(exc, None, None)
        body.append(raise_)
        # Clear temporary variables by setting them to None.
        if self.variables:
            variables = [ast.Name(name, ast.Store()) for name in self.variables]
            clear = ast.Assign(variables, _NameConstant(None))
            self.statements.append(clear)
        # Fix line numbers.
        for stmt in self.statements:
            set_location(stmt, assert_.lineno, assert_.col_offset)
        return self.statements
    def warn_about_none_ast(self, node, module_path, lineno):
        """
        Returns an AST issuing a warning if the value of node is `None`.
        This is used to warn the user when asserting a function that asserts
        internally already.
        See issue #3191 for more details.

        *module_path* must provide ``strpath`` (a py.path.local); it and
        *lineno* are baked into the generated warn_explicit() call.
        """
        # Using parse because it is different between py2 and py3.
        AST_NONE = ast.parse("None").body[0].value
        val_is_none = ast.Compare(node, [ast.Is()], [AST_NONE])
        # The template is parsed as module source, so its statements must
        # start at column 0.
        send_warning = ast.parse(
            """
from _pytest.warning_types import PytestAssertRewriteWarning
from warnings import warn_explicit
warn_explicit(
    PytestAssertRewriteWarning('asserting the value None, please use "assert is None"'),
    category=None,
    filename={filename!r},
    lineno={lineno},
)
            """.format(
                filename=module_path.strpath, lineno=lineno
            )
        ).body
        return ast.If(val_is_none, send_warning, [])
- def visit_Name(self, name):
- # Display the repr of the name if it's a local variable or
- # _should_repr_global_name() thinks it's acceptable.
- locs = ast_Call(self.builtin("locals"), [], [])
- inlocs = ast.Compare(ast.Str(name.id), [ast.In()], [locs])
- dorepr = self.helper("_should_repr_global_name", name)
- test = ast.BoolOp(ast.Or(), [inlocs, dorepr])
- expr = ast.IfExp(test, self.display(name), ast.Str(name.id))
- return name, self.explanation_param(expr)
    def visit_BoolOp(self, boolop):
        """Rewrite ``a and b ...`` / ``a or b ...`` while keeping short-circuiting.

        Each operand after the first is evaluated inside an ``if`` nested
        under the previous one, so it only runs when the earlier operands did
        not decide the result. self.on_failure gets the same nesting, so only
        the explanations of operands that actually ran are collected. Returns
        a Name loading the result variable and a placeholder for the combined
        explanation.
        """
        res_var = self.variable()
        # runtime list collecting the explanation of every evaluated operand
        expl_list = self.assign(ast.List([], ast.Load()))
        app = ast.Attribute(expl_list, "append", ast.Load())
        is_or = int(isinstance(boolop.op, ast.Or))
        body = save = self.statements
        fail_save = self.on_failure
        levels = len(boolop.values) - 1
        self.push_format_context()
        # Process each operand, short-circuting if needed.
        for i, v in enumerate(boolop.values):
            if i:
                fail_inner = []
                # cond is set in a prior loop iteration below
                self.on_failure.append(ast.If(cond, fail_inner, []))  # noqa
                self.on_failure = fail_inner
            self.push_format_context()
            res, expl = self.visit(v)
            body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
            expl_format = self.pop_format_context(ast.Str(expl))
            call = ast_Call(app, [expl_format], [])
            self.on_failure.append(ast.Expr(call))
            if i < levels:
                # Continue only while undecided: for ``and`` while the value
                # is truthy, for ``or`` while it is falsy.
                cond = res
                if is_or:
                    cond = ast.UnaryOp(ast.Not(), cond)
                inner = []
                self.statements.append(ast.If(cond, inner, []))
                self.statements = body = inner
        self.statements = save
        self.on_failure = fail_save
        expl_template = self.helper("_format_boolop", expl_list, ast.Num(is_or))
        expl = self.pop_format_context(expl_template)
        return ast.Name(res_var, ast.Load()), self.explanation_param(expl)
def visit_UnaryOp(self, unary):
    """Rewrite a unary operation (``not x``, ``-x``, ...).

    The explanation template is looked up in ``unary_map`` by operator class
    and %-formatted with the operand's explanation. The rewritten operation
    is stored through ``self.assign``.
    """
    template = unary_map[unary.op.__class__]
    operand_node, operand_text = self.visit(unary.operand)
    result = self.assign(ast.UnaryOp(unary.op, operand_node))
    return result, template % (operand_text,)
def visit_BinOp(self, binop):
    """Rewrite a binary operation and explain it as ``(left SYMBOL right)``.

    The operator symbol comes from ``binop_map``. The left operand is visited
    before the right one, and the rebuilt ``BinOp`` is stored through
    ``self.assign``.
    """
    operator_symbol = binop_map[binop.op.__class__]
    lhs_node, lhs_text = self.visit(binop.left)
    rhs_node, rhs_text = self.visit(binop.right)
    result = self.assign(ast.BinOp(lhs_node, binop.op, rhs_node))
    return result, "(%s %s %s)" % (lhs_text, operator_symbol, rhs_text)
def visit_Call_35(self, call):
    """Rewrite an ``ast.Call`` node on Python 3.5 and later.

    The callee, the positional arguments and the keyword arguments are
    visited in that order. The rebuilt call is stored through ``self.assign``.
    The explanation has the form ``"R\\n{R = func(args...)\\n}"``, where ``R``
    is the placeholder for the call's displayed result.
    """
    func_node, func_text = self.visit(call.func)
    pieces = []
    positional = []
    for argument in call.args:
        arg_node, arg_text = self.visit(argument)
        pieces.append(arg_text)
        positional.append(arg_node)
    keywords = []
    for kw in call.keywords:
        value_node, value_text = self.visit(kw.value)
        keywords.append(ast.keyword(kw.arg, value_node))
        # ``**kwargs`` arguments appear as keywords whose ``arg`` is None.
        prefix = (kw.arg + "=") if kw.arg else "**"
        pieces.append(prefix + value_text)
    call_text = "%s(%s)" % (func_text, ", ".join(pieces))
    result = self.assign(ast.Call(func_node, positional, keywords))
    result_text = self.explanation_param(self.display(result))
    return result, "%s\n{%s = %s\n}" % (result_text, result_text, call_text)
def visit_Starred(self, starred):
    """Rewrite a ``*value`` argument (these appear in calls from Python 3.5).

    The wrapped value is visited. A new ``ast.Starred`` that reuses the
    original context object is returned together with the explanation
    prefixed by ``*``.
    """
    inner_node, inner_text = self.visit(starred.value)
    rebuilt = ast.Starred(inner_node, starred.ctx)
    return rebuilt, "*" + inner_text
def visit_Call_legacy(self, call):
    """Rewrite an ``ast.Call`` node on Python 3.4 and below.

    On these versions ``*args`` and ``**kwargs`` live in the separate
    ``starargs`` and ``kwargs`` fields of the node. They are visited after
    the positional and keyword arguments and passed back to the 5-argument
    ``ast.Call`` constructor. The explanation has the form
    ``"R\\n{R = func(args...)\\n}"``.
    """
    func_node, func_text = self.visit(call.func)
    pieces = []
    positional = []
    for argument in call.args:
        arg_node, arg_text = self.visit(argument)
        positional.append(arg_node)
        pieces.append(arg_text)
    keywords = []
    for kw in call.keywords:
        value_node, value_text = self.visit(kw.value)
        keywords.append(ast.keyword(kw.arg, value_node))
        pieces.append(kw.arg + "=" + value_text)
    star_node = None
    if call.starargs:
        star_node, star_text = self.visit(call.starargs)
        pieces.append("*" + star_text)
    double_star_node = None
    if call.kwargs:
        double_star_node, double_star_text = self.visit(call.kwargs)
        pieces.append("**" + double_star_text)
    call_text = "%s(%s)" % (func_text, ", ".join(pieces))
    rebuilt = ast.Call(func_node, positional, keywords, star_node, double_star_node)
    result = self.assign(rebuilt)
    result_text = self.explanation_param(self.display(result))
    return result, "%s\n{%s = %s\n}" % (result_text, result_text, call_text)
# The ast.Call signature changed in Python 3.5: star-arguments moved out of
# the ``starargs``/``kwargs`` fields and into ``args`` (as ``ast.Starred``)
# and ``keywords`` (with ``arg=None``). Expose whichever implementation
# matches the running interpreter under the name ``visit_Call``.
visit_Call = visit_Call_35 if sys.version_info >= (3, 5) else visit_Call_legacy
def visit_Attribute(self, attr):
    """Rewrite an attribute read such as ``obj.name``.

    Non-``Load`` contexts (stores, deletes) are passed to
    ``self.generic_visit`` unchanged. For loads, the owner expression is
    visited and the attribute access is re-emitted through ``self.assign``.
    The explanation has the form ``"R\\n{R = owner.name\\n}"``.
    """
    if isinstance(attr.ctx, ast.Load):
        owner_node, owner_text = self.visit(attr.value)
        result = self.assign(ast.Attribute(owner_node, attr.attr, ast.Load()))
        result_text = self.explanation_param(self.display(result))
        template = "%s\n{%s = %s.%s\n}"
        return result, template % (result_text, result_text, owner_text, attr.attr)
    return self.generic_visit(attr)
def visit_Compare(self, comp):
    """Rewrite a (possibly chained) comparison such as ``a < b <= c``.

    Each adjacent pair of operands becomes its own single-operator
    ``ast.Compare``, whose result is assigned to a fresh variable. The
    operator symbols, result variables, per-pair explanation strings and
    operand values are passed as tuples to the ``_call_reprcompare`` helper,
    which builds the failure explanation.

    NOTE(review): unlike native Python chained comparisons, every operand is
    visited and every pairwise comparison statement is emitted unconditionally.
    There is no short-circuit after a false link.

    Returns a ``(node, explanation)`` pair. For a single comparison the node
    loads its result variable; for a chain it is an ``and`` over all of them.
    """
    self.push_format_context()
    left_res, left_expl = self.visit(comp.left)
    if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
        # Parenthesize nested comparisons/bool-ops in the explanation text.
        left_expl = "({})".format(left_expl)
    # One result variable per operator in the chain.
    res_variables = [self.variable() for _ in comp.ops]
    load_names = [ast.Name(v, ast.Load()) for v in res_variables]
    store_names = [ast.Name(v, ast.Store()) for v in res_variables]
    expls = []
    syms = []
    results = [left_res]
    for i, (op, next_operand) in enumerate(zip(comp.ops, comp.comparators)):
        next_res, next_expl = self.visit(next_operand)
        if isinstance(next_operand, (ast.Compare, ast.BoolOp)):
            next_expl = "({})".format(next_expl)
        results.append(next_res)
        sym = binop_map[op.__class__]
        syms.append(ast.Str(sym))
        expl = "%s %s %s" % (left_expl, sym, next_expl)
        expls.append(ast.Str(expl))
        res_expr = ast.Compare(left_res, [op], [next_res])
        self.statements.append(ast.Assign([store_names[i]], res_expr))
        # The right operand of this link is the left operand of the next one.
        left_res, left_expl = next_res, next_expl
    # Use pytest.assertion.util._reprcompare if that's available.
    expl_call = self.helper(
        "_call_reprcompare",
        ast.Tuple(syms, ast.Load()),
        ast.Tuple(load_names, ast.Load()),
        ast.Tuple(expls, ast.Load()),
        ast.Tuple(results, ast.Load()),
    )
    if len(comp.ops) > 1:
        res = ast.BoolOp(ast.And(), load_names)
    else:
        res = load_names[0]
    return res, self.explanation_param(self.pop_format_context(expl_call))
|