12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217 |
- """Rewrite assertion AST to produce nice error messages."""
- import ast
- import errno
- import functools
- import importlib.abc
- import importlib.machinery
- import importlib.util
- import io
- import itertools
- import marshal
- import os
- import struct
- import sys
- import tokenize
- import types
- from collections import defaultdict
- from pathlib import Path
- from pathlib import PurePath
- from typing import Callable
- from typing import Dict
- from typing import IO
- from typing import Iterable
- from typing import Iterator
- from typing import List
- from typing import Optional
- from typing import Sequence
- from typing import Set
- from typing import Tuple
- from typing import TYPE_CHECKING
- from typing import Union
- from _pytest._io.saferepr import DEFAULT_REPR_MAX_SIZE
- from _pytest._io.saferepr import saferepr
- from _pytest._version import version
- from _pytest.assertion import util
- from _pytest.assertion.util import ( # noqa: F401
- format_explanation as _format_explanation,
- )
- from _pytest.config import Config
- from _pytest.main import Session
- from _pytest.pathlib import absolutepath
- from _pytest.pathlib import fnmatch_ex
- from _pytest.stash import StashKey
- if TYPE_CHECKING:
- from _pytest.assertion import AssertionState
# Pick the AST node classes matching the running interpreter: from Python 3.8
# on, literals are all ast.Constant and the walrus operator is ast.NamedExpr;
# older interpreters only have the legacy node classes.
if sys.version_info < (3, 8):
    namedExpr = ast.Expr
    astNameConstant = ast.NameConstant
    astStr = ast.Str
    astNum = ast.Num
else:
    namedExpr = ast.NamedExpr
    astNameConstant = astStr = astNum = ast.Constant
class Sentinel:
    """Marker type; each instance is a unique placeholder value compared by identity."""
# Key under which the plugin's AssertionState is kept in ``config.stash``.
assertstate_key = StashKey["AssertionState"]()

# pytest caches rewritten pycs in pycache dirs
PYTEST_TAG = f"{sys.implementation.cache_tag}-pytest-{version}"
# ".pyc" normally, ".pyo" when the interpreter runs with -O (__debug__ is False).
PYC_EXT = ".pyc" if __debug__ else ".pyo"
PYC_TAIL = "." + PYTEST_TAG + PYC_EXT

# Special marker that denotes we have just left a scope definition
_SCOPE_END_MARKER = Sentinel()
class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader):
    """PEP302/PEP451 import hook which rewrites asserts.

    As a finder, ``find_spec`` decides whether a module should be rewritten
    and, if so, returns a spec whose loader is this object. As a loader,
    ``exec_module`` executes the rewritten code (taken from a pytest-specific
    pyc cache when valid) in the module's namespace.
    """

    def __init__(self, config: Config) -> None:
        self.config = config
        # Glob patterns identifying test files; when getini("python_files")
        # raises ValueError (presumably the option is not registered), use the
        # default patterns instead.
        try:
            self.fnpats = config.getini("python_files")
        except ValueError:
            self.fnpats = ["test_*.py", "*_test.py"]
        self.session: Optional[Session] = None
        # Module name -> source path of every module executed by this hook.
        self._rewritten_names: Dict[str, Path] = {}
        # Module/package names registered via mark_rewrite().
        self._must_rewrite: Set[str] = set()
        # flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
        # which might result in infinite recursion (#3506)
        self._writing_pyc = False
        # Last dotted-name components that always bypass the early bailout;
        # extended with the file stems of the session's initial paths.
        self._basenames_to_check_rewrite = {"conftest"}
        # Memo for _is_marked_for_rewrite(); cleared by mark_rewrite().
        self._marked_for_rewrite_cache: Dict[str, bool] = {}
        # Whether the current session's initial paths were already scanned.
        self._session_paths_checked = False

    def set_session(self, session: Optional[Session]) -> None:
        """Attach *session* (or detach with None).

        Also resets the flag so that the next _early_rewrite_bailout() call
        rescans the session's initial paths.
        """
        self.session = session
        self._session_paths_checked = False

    # Indirection so we can mock calls to find_spec originated from the hook during testing
    _find_spec = importlib.machinery.PathFinder.find_spec

    def find_spec(
        self,
        name: str,
        path: Optional[Sequence[Union[str, bytes]]] = None,
        target: Optional[types.ModuleType] = None,
    ) -> Optional[importlib.machinery.ModuleSpec]:
        """Return a spec loading *name* through this hook, or None to decline.

        Declines (returns None) while a pyc is being written, when the
        name-based early bailout applies, when PathFinder finds no usable
        source file, or when _should_rewrite() rejects the file. *target* is
        accepted for the finder protocol but not used.
        """
        if self._writing_pyc:
            return None
        state = self.config.stash[assertstate_key]
        if self._early_rewrite_bailout(name, state):
            return None
        # NOTE(review): the trace text says "find_module" although this is find_spec.
        state.trace("find_module called for: %s" % name)

        # Type ignored because mypy is confused about the `self` binding here.
        spec = self._find_spec(name, path)  # type: ignore
        if (
            # the import machinery could not find a file to import
            spec is None
            # this is a namespace package (without `__init__.py`)
            # there's nothing to rewrite there
            or spec.origin is None
            # we can only rewrite source files
            or not isinstance(spec.loader, importlib.machinery.SourceFileLoader)
            # if the file doesn't exist, we can't rewrite it
            or not os.path.exists(spec.origin)
        ):
            return None
        else:
            fn = spec.origin

        if not self._should_rewrite(name, fn, state):
            return None

        # Same file and package search path, but loaded by this hook.
        return importlib.util.spec_from_file_location(
            name,
            fn,
            loader=self,
            submodule_search_locations=spec.submodule_search_locations,
        )

    def create_module(
        self, spec: importlib.machinery.ModuleSpec
    ) -> Optional[types.ModuleType]:
        """Return None so the import system creates the module object itself."""
        return None  # default behaviour is fine

    def exec_module(self, module: types.ModuleType) -> None:
        """Execute the assert-rewritten code of *module* in its namespace.

        A cached rewritten pyc in get_cache_dir(fn) is used when _read_pyc()
        accepts it. Otherwise the source is rewritten, and the result is
        cached unless sys.dont_write_bytecode is set or the cache directory
        cannot be created.
        """
        assert module.__spec__ is not None
        assert module.__spec__.origin is not None
        fn = Path(module.__spec__.origin)
        state = self.config.stash[assertstate_key]

        # Recorded before any rewriting; mark_rewrite() and
        # get_resource_reader() consult this mapping.
        self._rewritten_names[module.__name__] = fn

        # The requested module looks like a test file, so rewrite it. This is
        # the most magical part of the process: load the source, rewrite the
        # asserts, and load the rewritten source. We also cache the rewritten
        # module code in a special pyc. We must be aware of the possibility of
        # concurrent pytest processes rewriting and loading pycs. To avoid
        # tricky race conditions, we maintain the following invariant: The
        # cached pyc is always a complete, valid pyc. Operations on it must be
        # atomic. POSIX's atomic rename comes in handy.
        write = not sys.dont_write_bytecode
        cache_dir = get_cache_dir(fn)
        if write:
            ok = try_makedirs(cache_dir)
            if not ok:
                write = False
                state.trace(f"read only directory: {cache_dir}")

        # Strips the last three characters of the file name (assumed ".py").
        cache_name = fn.name[:-3] + PYC_TAIL
        pyc = cache_dir / cache_name
        # Notice that even if we're in a read-only directory, I'm going
        # to check for a cached pyc. This may not be optimal...
        co = _read_pyc(fn, pyc, state.trace)
        if co is None:
            state.trace(f"rewriting {fn!r}")
            source_stat, co = _rewrite_test(fn, self.config)
            if write:
                # find_spec() returns None while this flag is set (#3506).
                self._writing_pyc = True
                try:
                    _write_pyc(state, co, source_stat, pyc)
                finally:
                    self._writing_pyc = False
        else:
            state.trace(f"found cached rewritten pyc for {fn}")
        exec(co, module.__dict__)

    def _early_rewrite_bailout(self, name: str, state: "AssertionState") -> bool:
        """A fast way to get out of rewriting modules.

        Profiling has shown that the call to PathFinder.find_spec (inside of
        the find_spec from this class) is a major slowdown, so, this method
        tries to filter what we're sure won't be rewritten before getting to
        it.

        Returns True when *name* can be skipped without a filesystem lookup.
        """
        # Scan the session's initial paths once per session.
        if self.session is not None and not self._session_paths_checked:
            self._session_paths_checked = True
            for initial_path in self.session._initialpaths:
                # Make something as c:/projects/my_project/path.py ->
                #     ['c:', 'projects', 'my_project', 'path.py']
                parts = str(initial_path).split(os.sep)
                # add 'path' to basenames to be checked.
                self._basenames_to_check_rewrite.add(os.path.splitext(parts[-1])[0])

        # Note: conftest already by default in _basenames_to_check_rewrite.
        parts = name.split(".")
        if parts[-1] in self._basenames_to_check_rewrite:
            return False

        # For matching the name it must be as if it was a filename.
        path = PurePath(*parts).with_suffix(".py")

        for pat in self.fnpats:
            # if the pattern contains subdirectories ("tests/**.py" for example) we can't bail out based
            # on the name alone because we need to match against the full path
            if os.path.dirname(pat):
                return False
            if fnmatch_ex(pat, path):
                return False

        if self._is_marked_for_rewrite(name, state):
            return False

        state.trace(f"early skip of rewriting module: {name}")
        return True

    def _should_rewrite(self, name: str, fn: str, state: "AssertionState") -> bool:
        """Return True if the file *fn* of module *name* must be rewritten.

        That is the case for conftest.py files, for files that are initial
        paths of the current session, for files matching a ``python_files``
        pattern, and for modules registered through mark_rewrite().
        """
        # always rewrite conftest files
        if os.path.basename(fn) == "conftest.py":
            state.trace(f"rewriting conftest file: {fn!r}")
            return True

        if self.session is not None:
            if self.session.isinitpath(absolutepath(fn)):
                state.trace(f"matched test file (was specified on cmdline): {fn!r}")
                return True

        # modules not passed explicitly on the command line are only
        # rewritten if they match the naming convention for test files
        fn_path = PurePath(fn)
        for pat in self.fnpats:
            if fnmatch_ex(pat, fn_path):
                state.trace(f"matched test file {fn!r}")
                return True

        return self._is_marked_for_rewrite(name, state)

    def _is_marked_for_rewrite(self, name: str, state: "AssertionState") -> bool:
        """Return True if *name* equals, or is a submodule of, a marked name.

        The answer is memoized per name in _marked_for_rewrite_cache.
        """
        try:
            return self._marked_for_rewrite_cache[name]
        except KeyError:
            for marked in self._must_rewrite:
                if name == marked or name.startswith(marked + "."):
                    state.trace(f"matched marked file {name!r} (from {marked!r})")
                    self._marked_for_rewrite_cache[name] = True
                    return True

            self._marked_for_rewrite_cache[name] = False
            return False

    def mark_rewrite(self, *names: str) -> None:
        """Mark import names as needing to be rewritten.

        The named module or package as well as any nested modules will
        be rewritten on import.
        """
        # Names already in sys.modules that this hook did not execute.
        already_imported = (
            set(names).intersection(sys.modules).difference(self._rewritten_names)
        )
        for name in already_imported:
            mod = sys.modules[name]
            # Warn unless the module opted out via PYTEST_DONT_REWRITE in its
            # docstring, or its loader is an instance of this hook class.
            if not AssertionRewriter.is_rewrite_disabled(
                mod.__doc__ or ""
            ) and not isinstance(mod.__loader__, type(self)):
                self._warn_already_imported(name)
        self._must_rewrite.update(names)
        # New marks can change earlier answers, so drop the memo.
        self._marked_for_rewrite_cache.clear()

    def _warn_already_imported(self, name: str) -> None:
        """Issue a PytestAssertRewriteWarning saying *name* cannot be rewritten."""
        from _pytest.warning_types import PytestAssertRewriteWarning

        self.config.issue_config_time_warning(
            PytestAssertRewriteWarning(
                "Module already imported so cannot be rewritten: %s" % name
            ),
            stacklevel=5,
        )

    def get_data(self, pathname: Union[str, bytes]) -> bytes:
        """Optional PEP302 get_data API.

        Returns the raw bytes of the file at *pathname*.
        """
        with open(pathname, "rb") as f:
            return f.read()

    # Resource-reader support only exists on Python 3.10+; the base class
    # moved to importlib.resources.abc in 3.12.
    if sys.version_info >= (3, 10):
        if sys.version_info >= (3, 12):
            from importlib.resources.abc import TraversableResources
        else:
            from importlib.abc import TraversableResources

        def get_resource_reader(self, name: str) -> TraversableResources:  # type: ignore
            """Return a FileReader built from the source path of module *name*.

            Raises KeyError if *name* was not executed by this hook.
            """
            if sys.version_info < (3, 11):
                from importlib.readers import FileReader
            else:
                from importlib.resources.readers import FileReader

            # FileReader only needs an object exposing a ``path`` attribute.
            return FileReader(  # type:ignore[no-any-return]
                types.SimpleNamespace(path=self._rewritten_names[name])
            )
- def _write_pyc_fp(
- fp: IO[bytes], source_stat: os.stat_result, co: types.CodeType
- ) -> None:
- # Technically, we don't have to have the same pyc format as
- # (C)Python, since these "pycs" should never be seen by builtin
- # import. However, there's little reason to deviate.
- fp.write(importlib.util.MAGIC_NUMBER)
- # https://www.python.org/dev/peps/pep-0552/
- flags = b"\x00\x00\x00\x00"
- fp.write(flags)
- # as of now, bytecode header expects 32-bit numbers for size and mtime (#4903)
- mtime = int(source_stat.st_mtime) & 0xFFFFFFFF
- size = source_stat.st_size & 0xFFFFFFFF
- # "<LL" stands for 2 unsigned longs, little-endian.
- fp.write(struct.pack("<LL", mtime, size))
- fp.write(marshal.dumps(co))
- def _write_pyc(
- state: "AssertionState",
- co: types.CodeType,
- source_stat: os.stat_result,
- pyc: Path,
- ) -> bool:
- proc_pyc = f"{pyc}.{os.getpid()}"
- try:
- with open(proc_pyc, "wb") as fp:
- _write_pyc_fp(fp, source_stat, co)
- except OSError as e:
- state.trace(f"error writing pyc file at {proc_pyc}: errno={e.errno}")
- return False
- try:
- os.replace(proc_pyc, pyc)
- except OSError as e:
- state.trace(f"error writing pyc file at {pyc}: {e}")
- # we ignore any failure to write the cache file
- # there are many reasons, permission-denied, pycache dir being a
- # file etc.
- return False
- return True
def _rewrite_test(fn: Path, config: Config) -> Tuple[os.stat_result, types.CodeType]:
    """Parse *fn*, rewrite its asserts and compile the result.

    Returns ``(stat, code)``. The stat is taken before the source is read.
    exec_module() later uses it to stamp the header of the cached pyc.
    """
    source_stat = os.stat(fn)
    filename = str(fn)
    source = fn.read_bytes()
    module_ast = ast.parse(source, filename=filename)
    rewrite_asserts(module_ast, source, filename, config)
    code = compile(module_ast, filename, "exec", dont_inherit=True)
    return source_stat, code
- def _read_pyc(
- source: Path, pyc: Path, trace: Callable[[str], None] = lambda x: None
- ) -> Optional[types.CodeType]:
- """Possibly read a pytest pyc containing rewritten code.
- Return rewritten code if successful or None if not.
- """
- try:
- fp = open(pyc, "rb")
- except OSError:
- return None
- with fp:
- try:
- stat_result = os.stat(source)
- mtime = int(stat_result.st_mtime)
- size = stat_result.st_size
- data = fp.read(16)
- except OSError as e:
- trace(f"_read_pyc({source}): OSError {e}")
- return None
- # Check for invalid or out of date pyc file.
- if len(data) != (16):
- trace("_read_pyc(%s): invalid pyc (too short)" % source)
- return None
- if data[:4] != importlib.util.MAGIC_NUMBER:
- trace("_read_pyc(%s): invalid pyc (bad magic number)" % source)
- return None
- if data[4:8] != b"\x00\x00\x00\x00":
- trace("_read_pyc(%s): invalid pyc (unsupported flags)" % source)
- return None
- mtime_data = data[8:12]
- if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF:
- trace("_read_pyc(%s): out of date" % source)
- return None
- size_data = data[12:16]
- if int.from_bytes(size_data, "little") != size & 0xFFFFFFFF:
- trace("_read_pyc(%s): invalid pyc (incorrect size)" % source)
- return None
- try:
- co = marshal.load(fp)
- except Exception as e:
- trace(f"_read_pyc({source}): marshal.load error {e}")
- return None
- if not isinstance(co, types.CodeType):
- trace("_read_pyc(%s): not a code object" % source)
- return None
- return co
def rewrite_asserts(
    mod: ast.Module,
    source: bytes,
    module_path: Optional[str] = None,
    config: Optional[Config] = None,
) -> None:
    """Rewrite the assert statements in mod.

    *mod* is modified in place. *source* is the raw module source; the
    rewriter uses it to recover the original text of each assertion.
    """
    rewriter = AssertionRewriter(module_path, config, source)
    rewriter.run(mod)
def _saferepr(obj: object) -> str:
    r"""Return a safe, possibly truncated repr of *obj* for assertion messages.

    Newlines are escaped as ``\n`` because util.format_explanation() treats
    them as special characters. It normally escapes them itself, but a
    custom repr may contain one of its special escape sequences. '\n{' and
    '\n}' are especially likely to be present in JSON reprs.

    The truncation limit comes from _get_maxsize_for_saferepr() and depends on
    the verbosity of the active config.
    """
    limit = _get_maxsize_for_saferepr(util._config)
    text = saferepr(obj, maxsize=limit)
    return text.replace("\n", "\\n")
def _get_maxsize_for_saferepr(config: Optional[Config]) -> Optional[int]:
    """Get `maxsize` configuration for saferepr based on the given config object.

    The limit depends on the verbosity:
    - below 1, or no config: DEFAULT_REPR_MAX_SIZE;
    - exactly 1: ten times that;
    - 2 or more: None (presumably meaning no truncation).
    """
    verbosity = 0 if config is None else config.getoption("verbose")
    if verbosity >= 2:
        return None
    return DEFAULT_REPR_MAX_SIZE * 10 if verbosity >= 1 else DEFAULT_REPR_MAX_SIZE
- def _format_assertmsg(obj: object) -> str:
- r"""Format the custom assertion message given.
- For strings this simply replaces newlines with '\n~' so that
- util.format_explanation() will preserve them instead of escaping
- newlines. For other objects saferepr() is used first.
- """
- # reprlib appears to have a bug which means that if a string
- # contains a newline it gets escaped, however if an object has a
- # .__repr__() which contains newlines it does not get escaped.
- # However in either case we want to preserve the newline.
- replaces = [("\n", "\n~"), ("%", "%%")]
- if not isinstance(obj, str):
- obj = saferepr(obj)
- replaces.append(("\\n", "\n~"))
- for r1, r2 in replaces:
- obj = obj.replace(r1, r2)
- return obj
- def _should_repr_global_name(obj: object) -> bool:
- if callable(obj):
- return False
- try:
- return not hasattr(obj, "__name__")
- except Exception:
- return True
- def _format_boolop(explanations: Iterable[str], is_or: bool) -> str:
- explanation = "(" + (is_or and " or " or " and ").join(explanations) + ")"
- return explanation.replace("%", "%%")
def _call_reprcompare(
    ops: Sequence[str],
    results: Sequence[bool],
    expls: Sequence[str],
    each_obj: Sequence[object],
) -> str:
    """Explain a (possibly chained) comparison.

    Finds the first comparison whose result is falsy, or whose truth value
    cannot be computed. If the loop runs out, the last one is used. A
    non-None answer from ``util._reprcompare`` for that operator and its
    operands is returned. Otherwise the default explanation for that
    comparison is returned.
    """
    for i, (_, res, expl) in enumerate(zip(ops, results, expls)):
        try:
            failed = not res
        except Exception:
            # The result has no usable truth value; treat it as the failure.
            failed = True
        if failed:
            break

    hook = util._reprcompare
    if hook is not None:
        custom = hook(ops[i], each_obj[i], each_obj[i + 1])
        if custom is not None:
            return custom
    return expl
def _call_assertion_pass(lineno: int, orig: str, expl: str) -> None:
    """Forward a passing assertion to ``util._assertion_pass`` when it is set."""
    callback = util._assertion_pass
    if callback is None:
        return
    callback(lineno, orig, expl)
def _check_if_assertion_pass_impl() -> bool:
    """Check if any plugins implement the pytest_assertion_pass hook
    in order not to generate explanation unnecessarily (might be expensive)."""
    return bool(util._assertion_pass)
# Text templates for unary operators, keyed by AST operator class; the "%s"
# presumably receives the operand's explanation.
UNARY_MAP = {ast.Not: "not %s", ast.Invert: "~%s", ast.USub: "-%s", ast.UAdd: "+%s"}

# Source spelling of binary and comparison operators, keyed by AST operator
# class, presumably used when rendering operations in assertion explanations.
BINOP_MAP = {
    ast.BitOr: "|",
    ast.BitXor: "^",
    ast.BitAnd: "&",
    ast.LShift: "<<",
    ast.RShift: ">>",
    ast.Add: "+",
    ast.Sub: "-",
    ast.Mult: "*",
    ast.Div: "/",
    ast.FloorDiv: "//",
    ast.Mod: "%%",  # escaped for string formatting
    ast.Eq: "==",
    ast.NotEq: "!=",
    ast.Lt: "<",
    ast.LtE: "<=",
    ast.Gt: ">",
    ast.GtE: ">=",
    ast.Pow: "**",
    ast.Is: "is",
    ast.IsNot: "is not",
    ast.In: "in",
    ast.NotIn: "not in",
    ast.MatMult: "@",
}
def traverse_node(node: ast.AST) -> Iterator[ast.AST]:
    """Lazily yield *node* and then each of its descendants, depth-first pre-order."""
    yield node
    for child in ast.iter_child_nodes(node):
        for descendant in traverse_node(child):
            yield descendant
@functools.lru_cache(maxsize=1)
def _get_assertion_exprs(src: bytes) -> Dict[int, str]:
    """Return a mapping from {lineno: "assertion test expression"}.

    *src* is the raw module source. It is tokenized, and for each ``assert``
    keyword the source text of its test expression is collected, possibly
    across several physical lines. The text is keyed by the line number of
    the ``assert`` token. Three things are stripped from it:
    - any message after a comma at bracket depth 0;
    - trailing whitespace;
    - a trailing line-continuation backslash.

    ``maxsize=1`` suits the caller, which looks up one assertion at a time
    with the same module source.
    """
    ret: Dict[int, str] = {}

    # Per-assert state, reset by _write_and_reset():
    depth = 0  # bracket nesting depth inside the current assert
    lines: List[str] = []  # collected source fragments of the expression
    assert_lineno: Optional[int] = None  # line of the current `assert`, if any
    seen_lines: Set[int] = set()  # physical lines already added to `lines`

    def _write_and_reset() -> None:
        # Store the collected expression text and clear the per-assert state.
        nonlocal depth, lines, assert_lineno, seen_lines
        assert assert_lineno is not None
        ret[assert_lineno] = "".join(lines).rstrip().rstrip("\\")
        depth = 0
        lines = []
        assert_lineno = None
        seen_lines = set()

    tokens = tokenize.tokenize(io.BytesIO(src).readline)
    for tp, source, (lineno, offset), _, line in tokens:
        if tp == tokenize.NAME and source == "assert":
            assert_lineno = lineno
        elif assert_lineno is not None:
            # keep track of depth for the assert-message `,` lookup
            if tp == tokenize.OP and source in "([{":
                depth += 1
            elif tp == tokenize.OP and source in ")]}":
                depth -= 1

            if not lines:
                # First token after `assert`: keep its line from this column on.
                lines.append(line[offset:])
                seen_lines.add(lineno)
            # a non-nested comma separates the expression from the message
            elif depth == 0 and tp == tokenize.OP and source == ",":
                # one line assert with message
                if lineno in seen_lines and len(lines) == 1:
                    # lines[-1] is a suffix of `line`, so shift the column
                    # into that trimmed fragment before cutting.
                    offset_in_trimmed = offset + len(lines[-1]) - len(line)
                    lines[-1] = lines[-1][:offset_in_trimmed]
                # multi-line assert with message
                elif lineno in seen_lines:
                    lines[-1] = lines[-1][:offset]
                # multi line assert with escaped newline before message
                else:
                    lines.append(line[:offset])
                _write_and_reset()
            elif tp in {tokenize.NEWLINE, tokenize.ENDMARKER}:
                # End of the assert statement without a message.
                _write_and_reset()
            elif lines and lineno not in seen_lines:
                # Continuation line of the expression: keep it whole.
                lines.append(line)
                seen_lines.add(lineno)

    return ret
def _get_ast_constant_value(value: astStr) -> object:
    """Return the Python value held by a string-literal AST node.

    ``ast.Constant`` (Python 3.8+) keeps it in ``.value``; the legacy
    ``ast.Str`` node uses ``.s``.
    """
    if sys.version_info < (3, 8):
        return value.s
    return value.value
- class AssertionRewriter(ast.NodeVisitor):
- """Assertion rewriting implementation.
- The main entrypoint is to call .run() with an ast.Module instance,
- this will then find all the assert statements and rewrite them to
- provide intermediate values and a detailed assertion error. See
- http://pybites.blogspot.be/2011/07/behind-scenes-of-pytests-new-assertion.html
- for an overview of how this works.
- The entry point here is .run() which will iterate over all the
- statements in an ast.Module and for each ast.Assert statement it
- finds call .visit() with it. Then .visit_Assert() takes over and
- is responsible for creating new ast statements to replace the
- original assert statement: it rewrites the test of an assertion
- to provide intermediate values and replace it with an if statement
- which raises an assertion error with a detailed explanation in
- case the expression is false and calls pytest_assertion_pass hook
- if expression is true.
- For this .visit_Assert() uses the visitor pattern to visit all the
- AST nodes of the ast.Assert.test field, each visit call returning
- an AST node and the corresponding explanation string. During this
- state is kept in several instance attributes:
- :statements: All the AST statements which will replace the assert
- statement.
- :variables: This is populated by .variable() with each variable
- used by the statements so that they can all be set to None at
- the end of the statements.
- :variable_counter: Counter to create new unique variables needed
- by statements. Variables are created using .variable() and
- have the form of "@py_assert0".
- :expl_stmts: The AST statements which will be executed to get
- data from the assertion. This is the code which will construct
- the detailed assertion message that is used in the AssertionError
- or for the pytest_assertion_pass hook.
- :explanation_specifiers: A dict filled by .explanation_param()
- with %-formatting placeholders and their corresponding
- expressions to use in the building of an assertion message.
- This is used by .pop_format_context() to build a message.
- :stack: A stack of the explanation_specifiers dicts maintained by
- .push_format_context() and .pop_format_context() which allows
- to build another %-formatted string while already building one.
- :scope: A tuple containing the current scope used for variables_overwrite.
- :variables_overwrite: A dict filled with references to variables
- that change value within an assert. This happens when a variable is
- reassigned with the walrus operator
- This state, except the variables_overwrite, is reset on every new assert
- statement visited and used by the other visitors.
- """
- def __init__(
- self, module_path: Optional[str], config: Optional[Config], source: bytes
- ) -> None:
- super().__init__()
- self.module_path = module_path
- self.config = config
- if config is not None:
- self.enable_assertion_pass_hook = config.getini(
- "enable_assertion_pass_hook"
- )
- else:
- self.enable_assertion_pass_hook = False
- self.source = source
- self.scope: tuple[ast.AST, ...] = ()
- self.variables_overwrite: defaultdict[
- tuple[ast.AST, ...], Dict[str, str]
- ] = defaultdict(dict)
- def run(self, mod: ast.Module) -> None:
- """Find all assert statements in *mod* and rewrite them."""
- if not mod.body:
- # Nothing to do.
- return
- # We'll insert some special imports at the top of the module, but after any
- # docstrings and __future__ imports, so first figure out where that is.
- doc = getattr(mod, "docstring", None)
- expect_docstring = doc is None
- if doc is not None and self.is_rewrite_disabled(doc):
- return
- pos = 0
- item = None
- for item in mod.body:
- if (
- expect_docstring
- and isinstance(item, ast.Expr)
- and isinstance(item.value, astStr)
- and isinstance(_get_ast_constant_value(item.value), str)
- ):
- doc = _get_ast_constant_value(item.value)
- assert isinstance(doc, str)
- if self.is_rewrite_disabled(doc):
- return
- expect_docstring = False
- elif (
- isinstance(item, ast.ImportFrom)
- and item.level == 0
- and item.module == "__future__"
- ):
- pass
- else:
- break
- pos += 1
- # Special case: for a decorated function, set the lineno to that of the
- # first decorator, not the `def`. Issue #4984.
- if isinstance(item, ast.FunctionDef) and item.decorator_list:
- lineno = item.decorator_list[0].lineno
- else:
- lineno = item.lineno
- # Now actually insert the special imports.
- if sys.version_info >= (3, 10):
- aliases = [
- ast.alias("builtins", "@py_builtins", lineno=lineno, col_offset=0),
- ast.alias(
- "_pytest.assertion.rewrite",
- "@pytest_ar",
- lineno=lineno,
- col_offset=0,
- ),
- ]
- else:
- aliases = [
- ast.alias("builtins", "@py_builtins"),
- ast.alias("_pytest.assertion.rewrite", "@pytest_ar"),
- ]
- imports = [
- ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases
- ]
- mod.body[pos:pos] = imports
- # Collect asserts.
- self.scope = (mod,)
- nodes: List[Union[ast.AST, Sentinel]] = [mod]
- while nodes:
- node = nodes.pop()
- if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
- self.scope = tuple((*self.scope, node))
- nodes.append(_SCOPE_END_MARKER)
- if node == _SCOPE_END_MARKER:
- self.scope = self.scope[:-1]
- continue
- assert isinstance(node, ast.AST)
- for name, field in ast.iter_fields(node):
- if isinstance(field, list):
- new: List[ast.AST] = []
- for i, child in enumerate(field):
- if isinstance(child, ast.Assert):
- # Transform assert.
- new.extend(self.visit(child))
- else:
- new.append(child)
- if isinstance(child, ast.AST):
- nodes.append(child)
- setattr(node, name, new)
- elif (
- isinstance(field, ast.AST)
- # Don't recurse into expressions as they can't contain
- # asserts.
- and not isinstance(field, ast.expr)
- ):
- nodes.append(field)
- @staticmethod
- def is_rewrite_disabled(docstring: str) -> bool:
- return "PYTEST_DONT_REWRITE" in docstring
- def variable(self) -> str:
- """Get a new variable."""
- # Use a character invalid in python identifiers to avoid clashing.
- name = "@py_assert" + str(next(self.variable_counter))
- self.variables.append(name)
- return name
- def assign(self, expr: ast.expr) -> ast.Name:
- """Give *expr* a name."""
- name = self.variable()
- self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr))
- return ast.Name(name, ast.Load())
- def display(self, expr: ast.expr) -> ast.expr:
- """Call saferepr on the expression."""
- return self.helper("_saferepr", expr)
- def helper(self, name: str, *args: ast.expr) -> ast.expr:
- """Call a helper in this module."""
- py_name = ast.Name("@pytest_ar", ast.Load())
- attr = ast.Attribute(py_name, name, ast.Load())
- return ast.Call(attr, list(args), [])
- def builtin(self, name: str) -> ast.Attribute:
- """Return the builtin called *name*."""
- builtin_name = ast.Name("@py_builtins", ast.Load())
- return ast.Attribute(builtin_name, name, ast.Load())
- def explanation_param(self, expr: ast.expr) -> str:
- """Return a new named %-formatting placeholder for expr.
- This creates a %-formatting placeholder for expr in the
- current formatting context, e.g. ``%(py0)s``. The placeholder
- and expr are placed in the current format context so that it
- can be used on the next call to .pop_format_context().
- """
- specifier = "py" + str(next(self.variable_counter))
- self.explanation_specifiers[specifier] = expr
- return "%(" + specifier + ")s"
- def push_format_context(self) -> None:
- """Create a new formatting context.
- The format context is used for when an explanation wants to
- have a variable value formatted in the assertion message. In
- this case the value required can be added using
- .explanation_param(). Finally .pop_format_context() is used
- to format a string of %-formatted values as added by
- .explanation_param().
- """
- self.explanation_specifiers: Dict[str, ast.expr] = {}
- self.stack.append(self.explanation_specifiers)
- def pop_format_context(self, expl_expr: ast.expr) -> ast.Name:
- """Format the %-formatted string with current format context.
- The expl_expr should be an str ast.expr instance constructed from
- the %-placeholders created by .explanation_param(). This will
- add the required code to format said string to .expl_stmts and
- return the ast.Name instance of the formatted string.
- """
- current = self.stack.pop()
- if self.stack:
- self.explanation_specifiers = self.stack[-1]
- keys = [astStr(key) for key in current.keys()]
- format_dict = ast.Dict(keys, list(current.values()))
- form = ast.BinOp(expl_expr, ast.Mod(), format_dict)
- name = "@py_format" + str(next(self.variable_counter))
- if self.enable_assertion_pass_hook:
- self.format_variables.append(name)
- self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form))
- return ast.Name(name, ast.Load())
- def generic_visit(self, node: ast.AST) -> Tuple[ast.Name, str]:
- """Handle expressions we don't have custom code for."""
- assert isinstance(node, ast.expr)
- res = self.assign(node)
- return res, self.explanation_param(self.display(res))
def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]:
    """Return the AST statements to replace the ast.Assert instance.

    This rewrites the test of an assertion to provide
    intermediate values and replace it with an if statement which
    raises an assertion error with a detailed explanation in case
    the expression is false.

    Per-assert state is reset on every call: ``statements``, ``variables``,
    ``variable_counter``, ``stack`` and ``expl_stmts``, plus
    ``format_variables`` when the pass hook is enabled.

    The shape of the generated code depends on
    ``self.enable_assertion_pass_hook``:

    * enabled: ``if not <test>: <expl_stmts>; raise AssertionError(...)``
      else ``if <_check_if_assertion_pass_impl()>: <expl_stmts>;
      <_call_assertion_pass(...)>``. This is followed by an assignment that
      resets every ``@py_format*`` name to None.
    * disabled: ``if not <test>: <expl_stmts>; raise AssertionError(...)``.

    In both cases the names in ``self.variables`` are then reset to None.
    Every node yielded by ``traverse_node`` for the generated statements is
    given the assert statement's location.
    """
    if isinstance(assert_.test, ast.Tuple) and len(assert_.test.elts) >= 1:
        # A non-empty tuple is always truthy, so this assert can never fail.
        from _pytest.warning_types import PytestAssertRewriteWarning
        import warnings

        # TODO: This assert should not be needed.
        assert self.module_path is not None
        warnings.warn_explicit(
            PytestAssertRewriteWarning(
                "assertion is always true, perhaps remove parentheses?"
            ),
            category=None,
            filename=self.module_path,
            lineno=assert_.lineno,
        )

    # Fresh per-assert state.
    self.statements: List[ast.stmt] = []
    self.variables: List[str] = []
    self.variable_counter = itertools.count()

    if self.enable_assertion_pass_hook:
        self.format_variables: List[str] = []

    self.stack: List[Dict[str, ast.expr]] = []
    self.expl_stmts: List[ast.stmt] = []
    self.push_format_context()
    # Rewrite assert into a bunch of statements.
    top_condition, explanation = self.visit(assert_.test)

    negation = ast.UnaryOp(ast.Not(), top_condition)

    if self.enable_assertion_pass_hook:  # Experimental pytest_assertion_pass hook
        # Formatted once here; the resulting name is used by both branches.
        msg = self.pop_format_context(astStr(explanation))

        # Failed
        if assert_.msg:
            assertmsg = self.helper("_format_assertmsg", assert_.msg)
            gluestr = "\n>assert "
        else:
            assertmsg = astStr("")
            gluestr = "assert "
        err_explanation = ast.BinOp(astStr(gluestr), ast.Add(), msg)
        err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation)
        err_name = ast.Name("AssertionError", ast.Load())
        fmt = self.helper("_format_explanation", err_msg)
        exc = ast.Call(err_name, [fmt], [])
        raise_ = ast.Raise(exc, None)
        statements_fail = []
        statements_fail.extend(self.expl_stmts)
        statements_fail.append(raise_)

        # Passed
        fmt_pass = self.helper("_format_explanation", msg)
        orig = _get_assertion_exprs(self.source)[assert_.lineno]
        hook_call_pass = ast.Expr(
            self.helper(
                "_call_assertion_pass",
                astNum(assert_.lineno),
                astStr(orig),
                fmt_pass,
            )
        )
        # If any hooks implement assert_pass hook
        # (the explanation statements are only run inside this guard).
        hook_impl_test = ast.If(
            self.helper("_check_if_assertion_pass_impl"),
            self.expl_stmts + [hook_call_pass],
            [],
        )
        statements_pass = [hook_impl_test]

        # Test for assertion condition
        main_test = ast.If(negation, statements_fail, statements_pass)
        self.statements.append(main_test)
        if self.format_variables:
            # Reset every @py_format* name produced by pop_format_context.
            variables = [
                ast.Name(name, ast.Store()) for name in self.format_variables
            ]
            clear_format = ast.Assign(variables, astNameConstant(None))
            self.statements.append(clear_format)

    else:  # Original assertion rewriting
        # Create failure message.
        # NOTE: body is the same list object as self.expl_stmts. The
        # pop_format_context() call below appends its @py_format assignment
        # to that list, and raise_ is appended last, so the If node created
        # here ends up containing both.
        body = self.expl_stmts
        self.statements.append(ast.If(negation, body, []))
        if assert_.msg:
            assertmsg = self.helper("_format_assertmsg", assert_.msg)
            explanation = "\n>assert " + explanation
        else:
            assertmsg = astStr("")
            explanation = "assert " + explanation
        template = ast.BinOp(assertmsg, ast.Add(), astStr(explanation))
        msg = self.pop_format_context(template)
        fmt = self.helper("_format_explanation", msg)
        err_name = ast.Name("AssertionError", ast.Load())
        exc = ast.Call(err_name, [fmt], [])
        raise_ = ast.Raise(exc, None)

        body.append(raise_)

    # Clear temporary variables by setting them to None.
    if self.variables:
        variables = [ast.Name(name, ast.Store()) for name in self.variables]
        clear = ast.Assign(variables, astNameConstant(None))
        self.statements.append(clear)
    # Fix locations (line numbers/column offsets).
    for stmt in self.statements:
        for node in traverse_node(stmt):
            ast.copy_location(node, assert_)
    return self.statements
def visit_NamedExpr(self, name: namedExpr) -> Tuple[namedExpr, str]:
    """Explain a walrus expression (``target := value``).

    At runtime, the explanation shows ``self.display(name)`` in either of two
    cases: the target name is found in ``locals()``, or the
    ``_should_repr_global_name`` helper returns true. Otherwise it shows the
    plain target name. The node itself is returned unchanged as the result.
    """
    locals_call = ast.Call(self.builtin("locals"), [], [])
    target_id = name.target.id  # type: ignore[attr-defined]
    show_value = ast.BoolOp(
        ast.Or(),
        [
            ast.Compare(astStr(target_id), [ast.In()], [locals_call]),
            self.helper("_should_repr_global_name", name),
        ],
    )
    shown = ast.IfExp(show_value, self.display(name), astStr(target_id))
    return name, self.explanation_param(shown)
def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]:
    """Explain a plain name reference.

    The explanation uses ``self.display(name)`` in either of two cases:
    ``name.id`` is in ``locals()`` at runtime, or the
    ``_should_repr_global_name`` helper approves it. Otherwise it falls back
    to the bare identifier. The Name node is returned unchanged.
    """
    local_names = ast.Call(self.builtin("locals"), [], [])
    is_local = ast.Compare(astStr(name.id), [ast.In()], [local_names])
    condition = ast.BoolOp(
        ast.Or(), [is_local, self.helper("_should_repr_global_name", name)]
    )
    explanation = self.explanation_param(
        ast.IfExp(condition, self.display(name), astStr(name.id))
    )
    return name, explanation
def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]:
    """Rewrite an ``and``/``or`` chain while preserving short-circuiting.

    Each operand's value is assigned to the same temporary (``res_var``).
    Operands after the first are only evaluated inside a nested ``if`` on the
    previous operand's value: ``if res`` for ``and``, ``if not res`` for
    ``or``. The explanation statements are nested in the same way. Each
    evaluated operand appends its formatted explanation to a runtime list,
    and that list is finally passed to the ``_format_boolop`` helper.

    Returns a Load of the result temporary and its explanation placeholder.
    """
    res_var = self.variable()
    expl_list = self.assign(ast.List([], ast.Load()))
    app = ast.Attribute(expl_list, "append", ast.Load())
    # 1 for ``or``, 0 for ``and``; also passed to _format_boolop below.
    is_or = int(isinstance(boolop.op, ast.Or))
    # The statement/explanation lists are swapped for nested ones inside the
    # loop; keep the outer ones so they can be restored afterwards.
    body = save = self.statements
    fail_save = self.expl_stmts
    levels = len(boolop.values) - 1
    self.push_format_context()
    # Process each operand, short-circuiting if needed.
    for i, v in enumerate(boolop.values):
        if i:
            fail_inner: List[ast.stmt] = []
            # cond is set in a prior loop iteration below
            self.expl_stmts.append(ast.If(cond, fail_inner, []))  # noqa
            self.expl_stmts = fail_inner
            # Check if the left operand is a namedExpr and the value has already been visited
            # (i.e. its target name equals the .id of an earlier bare-name operand).
            if (
                isinstance(v, ast.Compare)
                and isinstance(v.left, namedExpr)
                and v.left.target.id
                in [
                    ast_expr.id
                    for ast_expr in boolop.values[:i]
                    if hasattr(ast_expr, "id")
                ]
            ):
                # Record the walrus under the original target name in
                # variables_overwrite, then retarget it to a fresh temporary.
                pytest_temp = self.variable()
                self.variables_overwrite[self.scope][
                    v.left.target.id
                ] = v.left  # type:ignore[assignment]
                v.left.target.id = pytest_temp
        self.push_format_context()
        res, expl = self.visit(v)
        body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
        expl_format = self.pop_format_context(astStr(expl))
        call = ast.Call(app, [expl_format], [])
        self.expl_stmts.append(ast.Expr(call))
        if i < levels:
            # The next operand is only evaluated if this one does not
            # already decide the result.
            cond: ast.expr = res
            if is_or:
                cond = ast.UnaryOp(ast.Not(), cond)
            inner: List[ast.stmt] = []
            self.statements.append(ast.If(cond, inner, []))
            self.statements = body = inner
    self.statements = save
    self.expl_stmts = fail_save
    expl_template = self.helper("_format_boolop", expl_list, astNum(is_or))
    expl = self.pop_format_context(expl_template)
    return ast.Name(res_var, ast.Load()), self.explanation_param(expl)
def visit_UnaryOp(self, unary: ast.UnaryOp) -> Tuple[ast.Name, str]:
    """Rewrite a unary operation, explaining it via its UNARY_MAP template."""
    # Look the template up first, before any operand code is emitted.
    template = UNARY_MAP[unary.op.__class__]
    inner_name, inner_text = self.visit(unary.operand)
    result = self.assign(ast.UnaryOp(unary.op, inner_name))
    return result, template % (inner_text,)
def visit_BinOp(self, binop: ast.BinOp) -> Tuple[ast.Name, str]:
    """Rewrite a binary operation as ``(<left> <symbol> <right>)``.

    Both operands are visited left to right. The operation on their results
    is stored in a temporary via ``self.assign``.
    """
    symbol = BINOP_MAP[binop.op.__class__]
    lhs, lhs_text = self.visit(binop.left)
    rhs, rhs_text = self.visit(binop.right)
    result = self.assign(ast.BinOp(lhs, binop.op, rhs))
    return result, f"({lhs_text} {symbol} {rhs_text})"
def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]:
    """Rewrite a call so its callee, arguments and result are all explained.

    Some argument nodes are plain names whose id is recorded in
    ``self.variables_overwrite`` for the current scope. Those are replaced by
    the recorded node before being visited. For keyword arguments the
    replacement is also written back to ``keyword.value``.

    The explanation has the form ``<res>\\n{<res> = func(args)\\n}``.
    """

    def substitute(node: ast.expr) -> ast.expr:
        # Swap in the node recorded for an overwritten name, if any.
        if isinstance(node, ast.Name):
            replacements = self.variables_overwrite.get(self.scope, {})
            if node.id in replacements:
                return replacements[node.id]  # type:ignore[return-value]
        return node

    new_func, func_expl = self.visit(call.func)
    arg_expls = []
    new_args = []
    new_kwargs = []
    for positional in call.args:
        arg_res, arg_expl = self.visit(substitute(positional))
        arg_expls.append(arg_expl)
        new_args.append(arg_res)
    for keyword in call.keywords:
        keyword.value = substitute(keyword.value)
        kw_res, kw_expl = self.visit(keyword.value)
        new_kwargs.append(ast.keyword(keyword.arg, kw_res))
        # ``**mapping`` entries are keywords whose ``.arg`` is None.
        arg_expls.append(f"{keyword.arg}={kw_expl}" if keyword.arg else f"**{kw_expl}")
    call_expl = f"{func_expl}({', '.join(arg_expls)})"
    res = self.assign(ast.Call(new_func, new_args, new_kwargs))
    res_expl = self.explanation_param(self.display(res))
    return res, f"{res_expl}\n{{{res_expl} = {call_expl}\n}}"
def visit_Starred(self, starred: ast.Starred) -> Tuple[ast.Starred, str]:
    """Rewrite a ``*value`` node, as can appear among a call's arguments.

    The wrapped value is visited and re-wrapped in a new Starred node that
    keeps the original context. Its explanation gets a leading ``*``.
    """
    inner, inner_expl = self.visit(starred.value)
    return ast.Starred(inner, starred.ctx), f"*{inner_expl}"
def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]:
    """Rewrite an attribute load as ``<res>\\n{<res> = <value>.<attr>\\n}``.

    Attribute nodes whose context is not Load are handed to
    ``generic_visit`` instead.
    """
    if isinstance(attr.ctx, ast.Load):
        owner, owner_expl = self.visit(attr.value)
        res = self.assign(ast.Attribute(owner, attr.attr, ast.Load()))
        res_expl = self.explanation_param(self.display(res))
        return res, f"{res_expl}\n{{{res_expl} = {owner_expl}.{attr.attr}\n}}"
    return self.generic_visit(attr)
def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
    """Rewrite a (possibly chained) comparison.

    ``a op1 b op2 c`` is split into the pairwise comparisons ``a op1 b`` and
    ``b op2 c``, each assigned to its own temporary. The result is that
    single temporary, or an ``and`` of all of them.

    The explanation comes from a call to the ``_call_reprcompare`` helper.
    It receives the operator symbols, the pairwise results, the per-pair
    explanation strings and the operand values.

    Walrus operands are recorded in ``self.variables_overwrite`` so that
    names rebound by ``:=`` can be substituted later.
    """
    self.push_format_context()
    # We first check if we have overwritten a variable in the previous assert
    if isinstance(
        comp.left, ast.Name
    ) and comp.left.id in self.variables_overwrite.get(self.scope, {}):
        comp.left = self.variables_overwrite[self.scope][
            comp.left.id
        ]  # type:ignore[assignment]
    if isinstance(comp.left, namedExpr):
        self.variables_overwrite[self.scope][
            comp.left.target.id
        ] = comp.left  # type:ignore[assignment]
    left_res, left_expl = self.visit(comp.left)
    # Parenthesize nested comparisons/bool-ops so the explanation stays unambiguous.
    if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
        left_expl = f"({left_expl})"
    # One result temporary per comparison operator.
    res_variables = [self.variable() for i in range(len(comp.ops))]
    load_names = [ast.Name(v, ast.Load()) for v in res_variables]
    store_names = [ast.Name(v, ast.Store()) for v in res_variables]
    it = zip(range(len(comp.ops)), comp.ops, comp.comparators)
    expls = []
    syms = []
    results = [left_res]
    # NOTE(review): each pairwise comparison is appended as an unconditional
    # statement. Unlike a native chained comparison, later pairs are
    # therefore evaluated even when an earlier pair is false.
    for i, op, next_operand in it:
        # A walrus on the right that rebinds the name just used on the left:
        # move its target to a fresh temporary and remember the NamedExpr
        # under the original name.
        if (
            isinstance(next_operand, namedExpr)
            and isinstance(left_res, ast.Name)
            and next_operand.target.id == left_res.id
        ):
            next_operand.target.id = self.variable()
            self.variables_overwrite[self.scope][
                left_res.id
            ] = next_operand  # type:ignore[assignment]
        next_res, next_expl = self.visit(next_operand)
        if isinstance(next_operand, (ast.Compare, ast.BoolOp)):
            next_expl = f"({next_expl})"
        results.append(next_res)
        sym = BINOP_MAP[op.__class__]
        syms.append(astStr(sym))
        expl = f"{left_expl} {sym} {next_expl}"
        expls.append(astStr(expl))
        res_expr = ast.Compare(left_res, [op], [next_res])
        self.statements.append(ast.Assign([store_names[i]], res_expr))
        # This right operand becomes the left operand of the next pair.
        left_res, left_expl = next_res, next_expl
    # Use pytest.assertion.util._reprcompare if that's available.
    expl_call = self.helper(
        "_call_reprcompare",
        ast.Tuple(syms, ast.Load()),
        ast.Tuple(load_names, ast.Load()),
        ast.Tuple(expls, ast.Load()),
        ast.Tuple(results, ast.Load()),
    )
    if len(comp.ops) > 1:
        res: ast.expr = ast.BoolOp(ast.And(), load_names)
    else:
        res = load_names[0]
    return res, self.explanation_param(self.pop_format_context(expl_call))
def try_makedirs(cache_dir: Path) -> bool:
    """Attempt to create the given directory and any missing parent directories.

    Returns True if the directory was created or already exists.

    Returns False when it cannot be created for a reason that only means
    "no cache directory is available here":

    * a path component is not a directory (e.g. inside a zip file, or the
      path is a file);
    * permission is denied;
    * the filesystem is read-only (EROFS, or ENOSYS as reported by
      squashfuse).

    Any other OSError is re-raised.
    """
    try:
        os.makedirs(cache_dir, exist_ok=True)
    except (FileNotFoundError, NotADirectoryError, FileExistsError):
        # One of the path components was not a directory:
        # - we're in a zip file
        # - it is a file
        return False
    except PermissionError:
        return False
    except OSError as e:
        # As of now, EROFS doesn't have an equivalent OSError subclass.
        #
        # squashfuse_ll reports a read-only filesystem as ENOSYS
        # ("OSError: [Errno 38] Function not implemented"), so treat it
        # the same way instead of crashing.
        if e.errno in {errno.EROFS, errno.ENOSYS}:
            return False
        raise
    return True
def get_cache_dir(file_path: Path) -> Path:
    """Compute the directory where the .pyc for the given .py file should live.

    If ``sys.pycache_prefix`` is set (Python 3.8+), the source file's
    directory, minus its root/anchor, is mirrored underneath that prefix.
    Otherwise the conventional ``__pycache__`` directory next to the source
    file is used.
    """
    prefix = sys.pycache_prefix if sys.version_info >= (3, 8) else None
    if not prefix:
        # classic pycache directory
        return file_path.parent / "__pycache__"
    # e.g. prefix='/tmp/pycs', file_path='/home/user/proj/test_app.py'
    #      -> '/tmp/pycs/home/user/proj'
    parent_parts = file_path.parts[1:-1]
    return Path(prefix).joinpath(*parent_parts)
|