rewrite.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072
  1. # -*- coding: utf-8 -*-
  2. """Rewrite assertion AST to produce nice error messages"""
  3. from __future__ import absolute_import
  4. from __future__ import division
  5. from __future__ import print_function
  6. import warnings
  7. warnings.filterwarnings("ignore", category=DeprecationWarning, module="_pytest.assertion.rewrite")
  8. import ast
  9. import errno
  10. import imp
  11. import itertools
  12. import marshal
  13. import os
  14. import re
  15. import string
  16. import struct
  17. import sys
  18. import types
  19. import atomicwrites
  20. import py
  21. import six
  22. from _pytest._io.saferepr import saferepr
  23. from _pytest.assertion import util
  24. from _pytest.assertion.util import ( # noqa: F401
  25. format_explanation as _format_explanation,
  26. )
  27. from _pytest.compat import spec_from_file_location
  28. from _pytest.pathlib import fnmatch_ex
  29. from _pytest.pathlib import PurePath
  30. # pytest caches rewritten pycs in __pycache__.
  31. if hasattr(imp, "get_tag"):
  32. PYTEST_TAG = imp.get_tag() + "-PYTEST"
  33. else:
  34. if hasattr(sys, "pypy_version_info"):
  35. impl = "pypy"
  36. elif sys.platform == "java":
  37. impl = "jython"
  38. else:
  39. impl = "cpython"
  40. ver = sys.version_info
  41. PYTEST_TAG = "%s-%s%s-PYTEST" % (impl, ver[0], ver[1])
  42. del ver, impl
  43. PYC_EXT = ".py" + (__debug__ and "c" or "o")
  44. PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
  45. ASCII_IS_DEFAULT_ENCODING = sys.version_info[0] < 3
  46. if sys.version_info >= (3, 5):
  47. ast_Call = ast.Call
  48. else:
  49. def ast_Call(a, b, c):
  50. return ast.Call(a, b, c, None, None)
  51. class AssertionRewritingHook(object):
  52. """PEP302 Import hook which rewrites asserts."""
  53. def __init__(self, config):
  54. self.config = config
  55. try:
  56. self.fnpats = config.getini("python_files")
  57. except ValueError:
  58. self.fnpats = ["test_*.py", "*_test.py"]
  59. self.session = None
  60. self.modules = {}
  61. self._rewritten_names = set()
  62. self._must_rewrite = set()
  63. # flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
  64. # which might result in infinite recursion (#3506)
  65. self._writing_pyc = False
  66. self._basenames_to_check_rewrite = {"conftest"}
  67. self._marked_for_rewrite_cache = {}
  68. self._session_paths_checked = False
  69. def set_session(self, session):
  70. self.session = session
  71. self._session_paths_checked = False
  72. def _imp_find_module(self, name, path=None):
  73. """Indirection so we can mock calls to find_module originated from the hook during testing"""
  74. return imp.find_module(name, path)
  75. def find_module(self, name, path=None):
  76. if self._writing_pyc:
  77. return None
  78. state = self.config._assertstate
  79. if self._early_rewrite_bailout(name, state):
  80. return None
  81. state.trace("find_module called for: %s" % name)
  82. names = name.rsplit(".", 1)
  83. lastname = names[-1]
  84. pth = None
  85. if path is not None:
  86. # Starting with Python 3.3, path is a _NamespacePath(), which
  87. # causes problems if not converted to list.
  88. path = list(path)
  89. if len(path) == 1:
  90. pth = path[0]
  91. if pth is None:
  92. try:
  93. fd, fn, desc = self._imp_find_module(lastname, path)
  94. except ImportError:
  95. return None
  96. if fd is not None:
  97. fd.close()
  98. tp = desc[2]
  99. if tp == imp.PY_COMPILED:
  100. if hasattr(imp, "source_from_cache"):
  101. try:
  102. fn = imp.source_from_cache(fn)
  103. except ValueError:
  104. # Python 3 doesn't like orphaned but still-importable
  105. # .pyc files.
  106. fn = fn[:-1]
  107. else:
  108. fn = fn[:-1]
  109. elif tp != imp.PY_SOURCE:
  110. # Don't know what this is.
  111. return None
  112. else:
  113. fn = os.path.join(pth, name.rpartition(".")[2] + ".py")
  114. fn_pypath = py.path.local(fn)
  115. if not self._should_rewrite(name, fn_pypath, state):
  116. return None
  117. self._rewritten_names.add(name)
  118. # The requested module looks like a test file, so rewrite it. This is
  119. # the most magical part of the process: load the source, rewrite the
  120. # asserts, and load the rewritten source. We also cache the rewritten
  121. # module code in a special pyc. We must be aware of the possibility of
  122. # concurrent pytest processes rewriting and loading pycs. To avoid
  123. # tricky race conditions, we maintain the following invariant: The
  124. # cached pyc is always a complete, valid pyc. Operations on it must be
  125. # atomic. POSIX's atomic rename comes in handy.
  126. write = not sys.dont_write_bytecode
  127. cache_dir = os.path.join(fn_pypath.dirname, "__pycache__")
  128. if write:
  129. try:
  130. os.mkdir(cache_dir)
  131. except OSError:
  132. e = sys.exc_info()[1].errno
  133. if e == errno.EEXIST:
  134. # Either the __pycache__ directory already exists (the
  135. # common case) or it's blocked by a non-dir node. In the
  136. # latter case, we'll ignore it in _write_pyc.
  137. pass
  138. elif e in [errno.ENOENT, errno.ENOTDIR]:
  139. # One of the path components was not a directory, likely
  140. # because we're in a zip file.
  141. write = False
  142. elif e in [errno.EACCES, errno.EROFS, errno.EPERM]:
  143. state.trace("read only directory: %r" % fn_pypath.dirname)
  144. write = False
  145. else:
  146. raise
  147. cache_name = fn_pypath.basename[:-3] + PYC_TAIL
  148. pyc = os.path.join(cache_dir, cache_name)
  149. # Notice that even if we're in a read-only directory, I'm going
  150. # to check for a cached pyc. This may not be optimal...
  151. co = _read_pyc(fn_pypath, pyc, state.trace)
  152. if co is None:
  153. state.trace("rewriting %r" % (fn,))
  154. source_stat, co = _rewrite_test(self.config, fn_pypath)
  155. if co is None:
  156. # Probably a SyntaxError in the test.
  157. return None
  158. if write:
  159. self._writing_pyc = True
  160. try:
  161. _write_pyc(state, co, source_stat, pyc)
  162. finally:
  163. self._writing_pyc = False
  164. else:
  165. state.trace("found cached rewritten pyc for %r" % (fn,))
  166. self.modules[name] = co, pyc
  167. return self
  168. def _early_rewrite_bailout(self, name, state):
  169. """
  170. This is a fast way to get out of rewriting modules. Profiling has
  171. shown that the call to imp.find_module (inside of the find_module
  172. from this class) is a major slowdown, so, this method tries to
  173. filter what we're sure won't be rewritten before getting to it.
  174. """
  175. if self.session is not None and not self._session_paths_checked:
  176. self._session_paths_checked = True
  177. for path in self.session._initialpaths:
  178. # Make something as c:/projects/my_project/path.py ->
  179. # ['c:', 'projects', 'my_project', 'path.py']
  180. parts = str(path).split(os.path.sep)
  181. # add 'path' to basenames to be checked.
  182. self._basenames_to_check_rewrite.add(os.path.splitext(parts[-1])[0])
  183. # Note: conftest already by default in _basenames_to_check_rewrite.
  184. parts = name.split(".")
  185. if parts[-1] in self._basenames_to_check_rewrite:
  186. return False
  187. # For matching the name it must be as if it was a filename.
  188. path = PurePath(os.path.sep.join(parts) + ".py")
  189. for pat in self.fnpats:
  190. # if the pattern contains subdirectories ("tests/**.py" for example) we can't bail out based
  191. # on the name alone because we need to match against the full path
  192. if os.path.dirname(pat):
  193. return False
  194. if fnmatch_ex(pat, path):
  195. return False
  196. if self._is_marked_for_rewrite(name, state):
  197. return False
  198. state.trace("early skip of rewriting module: %s" % (name,))
  199. return True
  200. def _should_rewrite(self, name, fn_pypath, state):
  201. # always rewrite conftest files
  202. fn = str(fn_pypath)
  203. if fn_pypath.basename == "conftest.py":
  204. state.trace("rewriting conftest file: %r" % (fn,))
  205. return True
  206. if self.session is not None:
  207. if self.session.isinitpath(fn):
  208. state.trace("matched test file (was specified on cmdline): %r" % (fn,))
  209. return True
  210. # modules not passed explicitly on the command line are only
  211. # rewritten if they match the naming convention for test files
  212. for pat in self.fnpats:
  213. if fn_pypath.fnmatch(pat):
  214. state.trace("matched test file %r" % (fn,))
  215. return True
  216. return self._is_marked_for_rewrite(name, state)
  217. def _is_marked_for_rewrite(self, name, state):
  218. try:
  219. return self._marked_for_rewrite_cache[name]
  220. except KeyError:
  221. for marked in self._must_rewrite:
  222. if name == marked or name.startswith(marked + "."):
  223. state.trace("matched marked file %r (from %r)" % (name, marked))
  224. self._marked_for_rewrite_cache[name] = True
  225. return True
  226. self._marked_for_rewrite_cache[name] = False
  227. return False
  228. def mark_rewrite(self, *names):
  229. """Mark import names as needing to be rewritten.
  230. The named module or package as well as any nested modules will
  231. be rewritten on import.
  232. """
  233. already_imported = (
  234. set(names).intersection(sys.modules).difference(self._rewritten_names)
  235. )
  236. for name in already_imported:
  237. if not AssertionRewriter.is_rewrite_disabled(
  238. sys.modules[name].__doc__ or ""
  239. ):
  240. self._warn_already_imported(name)
  241. self._must_rewrite.update(names)
  242. self._marked_for_rewrite_cache.clear()
  243. def _warn_already_imported(self, name):
  244. from _pytest.warning_types import PytestAssertRewriteWarning
  245. from _pytest.warnings import _issue_warning_captured
  246. _issue_warning_captured(
  247. PytestAssertRewriteWarning(
  248. "Module already imported so cannot be rewritten: %s" % name
  249. ),
  250. self.config.hook,
  251. stacklevel=5,
  252. )
  253. def load_module(self, name):
  254. co, pyc = self.modules.pop(name)
  255. if name in sys.modules:
  256. # If there is an existing module object named 'fullname' in
  257. # sys.modules, the loader must use that existing module. (Otherwise,
  258. # the reload() builtin will not work correctly.)
  259. mod = sys.modules[name]
  260. else:
  261. # I wish I could just call imp.load_compiled here, but __file__ has to
  262. # be set properly. In Python 3.2+, this all would be handled correctly
  263. # by load_compiled.
  264. mod = sys.modules[name] = imp.new_module(name)
  265. try:
  266. mod.__file__ = co.co_filename
  267. # Normally, this attribute is 3.2+.
  268. mod.__cached__ = pyc
  269. mod.__loader__ = self
  270. # Normally, this attribute is 3.4+
  271. mod.__spec__ = spec_from_file_location(name, co.co_filename, loader=self)
  272. exec(co, mod.__dict__)
  273. except: # noqa
  274. if name in sys.modules:
  275. del sys.modules[name]
  276. raise
  277. return sys.modules[name]
  278. def is_package(self, name):
  279. try:
  280. fd, fn, desc = self._imp_find_module(name)
  281. except ImportError:
  282. return False
  283. if fd is not None:
  284. fd.close()
  285. tp = desc[2]
  286. return tp == imp.PKG_DIRECTORY
  287. def get_data(self, pathname):
  288. """Optional PEP302 get_data API.
  289. """
  290. with open(pathname, "rb") as f:
  291. return f.read()
  292. def _write_pyc(state, co, source_stat, pyc):
  293. # Technically, we don't have to have the same pyc format as
  294. # (C)Python, since these "pycs" should never be seen by builtin
  295. # import. However, there's little reason deviate, and I hope
  296. # sometime to be able to use imp.load_compiled to load them. (See
  297. # the comment in load_module above.)
  298. try:
  299. with atomicwrites.atomic_write(pyc, mode="wb", overwrite=True) as fp:
  300. fp.write(imp.get_magic())
  301. # as of now, bytecode header expects 32-bit numbers for size and mtime (#4903)
  302. mtime = int(source_stat.mtime) & 0xFFFFFFFF
  303. size = source_stat.size & 0xFFFFFFFF
  304. # "<LL" stands for 2 unsigned longs, little-ending
  305. fp.write(struct.pack("<LL", mtime, size))
  306. fp.write(marshal.dumps(co))
  307. except EnvironmentError as e:
  308. state.trace("error writing pyc file at %s: errno=%s" % (pyc, e.errno))
  309. # we ignore any failure to write the cache file
  310. # there are many reasons, permission-denied, __pycache__ being a
  311. # file etc.
  312. return False
  313. return True
  314. RN = "\r\n".encode("utf-8")
  315. N = "\n".encode("utf-8")
  316. cookie_re = re.compile(r"^[ \t\f]*#.*coding[:=][ \t]*[-\w.]+")
  317. BOM_UTF8 = "\xef\xbb\xbf"
  318. def _rewrite_test(config, fn):
  319. """Try to read and rewrite *fn* and return the code object."""
  320. state = config._assertstate
  321. try:
  322. stat = fn.stat()
  323. source = fn.read("rb")
  324. except EnvironmentError:
  325. return None, None
  326. if ASCII_IS_DEFAULT_ENCODING:
  327. # ASCII is the default encoding in Python 2. Without a coding
  328. # declaration, Python 2 will complain about any bytes in the file
  329. # outside the ASCII range. Sadly, this behavior does not extend to
  330. # compile() or ast.parse(), which prefer to interpret the bytes as
  331. # latin-1. (At least they properly handle explicit coding cookies.) To
  332. # preserve this error behavior, we could force ast.parse() to use ASCII
  333. # as the encoding by inserting a coding cookie. Unfortunately, that
  334. # messes up line numbers. Thus, we have to check ourselves if anything
  335. # is outside the ASCII range in the case no encoding is explicitly
  336. # declared. For more context, see issue #269. Yay for Python 3 which
  337. # gets this right.
  338. end1 = source.find("\n")
  339. end2 = source.find("\n", end1 + 1)
  340. if (
  341. not source.startswith(BOM_UTF8)
  342. and cookie_re.match(source[0:end1]) is None
  343. and cookie_re.match(source[end1 + 1 : end2]) is None
  344. ):
  345. if hasattr(state, "_indecode"):
  346. # encodings imported us again, so don't rewrite.
  347. return None, None
  348. state._indecode = True
  349. try:
  350. try:
  351. source.decode("ascii")
  352. except UnicodeDecodeError:
  353. # Let it fail in real import.
  354. return None, None
  355. finally:
  356. del state._indecode
  357. try:
  358. tree = ast.parse(source, filename=fn.strpath)
  359. except SyntaxError:
  360. # Let this pop up again in the real import.
  361. state.trace("failed to parse: %r" % (fn,))
  362. return None, None
  363. rewrite_asserts(tree, fn, config)
  364. try:
  365. co = compile(tree, fn.strpath, "exec", dont_inherit=True)
  366. except SyntaxError:
  367. # It's possible that this error is from some bug in the
  368. # assertion rewriting, but I don't know of a fast way to tell.
  369. state.trace("failed to compile: %r" % (fn,))
  370. return None, None
  371. return stat, co
  372. def _read_pyc(source, pyc, trace=lambda x: None):
  373. """Possibly read a pytest pyc containing rewritten code.
  374. Return rewritten code if successful or None if not.
  375. """
  376. try:
  377. fp = open(pyc, "rb")
  378. except IOError:
  379. return None
  380. with fp:
  381. try:
  382. mtime = int(source.mtime())
  383. size = source.size()
  384. data = fp.read(12)
  385. except EnvironmentError as e:
  386. trace("_read_pyc(%s): EnvironmentError %s" % (source, e))
  387. return None
  388. # Check for invalid or out of date pyc file.
  389. if (
  390. len(data) != 12
  391. or data[:4] != imp.get_magic()
  392. or struct.unpack("<LL", data[4:]) != (mtime & 0xFFFFFFFF, size & 0xFFFFFFFF)
  393. ):
  394. trace("_read_pyc(%s): invalid or out of date pyc" % source)
  395. return None
  396. try:
  397. co = marshal.load(fp)
  398. except Exception as e:
  399. trace("_read_pyc(%s): marshal.load error %s" % (source, e))
  400. return None
  401. if not isinstance(co, types.CodeType):
  402. trace("_read_pyc(%s): not a code object" % source)
  403. return None
  404. return co
  405. def rewrite_asserts(mod, module_path=None, config=None):
  406. """Rewrite the assert statements in mod."""
  407. AssertionRewriter(module_path, config).run(mod)
  408. def _saferepr(obj):
  409. """Get a safe repr of an object for assertion error messages.
  410. The assertion formatting (util.format_explanation()) requires
  411. newlines to be escaped since they are a special character for it.
  412. Normally assertion.util.format_explanation() does this but for a
  413. custom repr it is possible to contain one of the special escape
  414. sequences, especially '\n{' and '\n}' are likely to be present in
  415. JSON reprs.
  416. """
  417. r = saferepr(obj)
  418. # only occurs in python2.x, repr must return text in python3+
  419. if isinstance(r, bytes):
  420. # Represent unprintable bytes as `\x##`
  421. r = u"".join(
  422. u"\\x{:x}".format(ord(c)) if c not in string.printable else c.decode()
  423. for c in r
  424. )
  425. return r.replace(u"\n", u"\\n")
  426. def _format_assertmsg(obj):
  427. """Format the custom assertion message given.
  428. For strings this simply replaces newlines with '\n~' so that
  429. util.format_explanation() will preserve them instead of escaping
  430. newlines. For other objects saferepr() is used first.
  431. """
  432. # reprlib appears to have a bug which means that if a string
  433. # contains a newline it gets escaped, however if an object has a
  434. # .__repr__() which contains newlines it does not get escaped.
  435. # However in either case we want to preserve the newline.
  436. replaces = [(u"\n", u"\n~"), (u"%", u"%%")]
  437. if not isinstance(obj, six.string_types):
  438. obj = saferepr(obj)
  439. replaces.append((u"\\n", u"\n~"))
  440. if isinstance(obj, bytes):
  441. replaces = [(r1.encode(), r2.encode()) for r1, r2 in replaces]
  442. for r1, r2 in replaces:
  443. obj = obj.replace(r1, r2)
  444. return obj
  445. def _should_repr_global_name(obj):
  446. if callable(obj):
  447. return False
  448. try:
  449. return not hasattr(obj, "__name__")
  450. except Exception:
  451. return True
  452. def _format_boolop(explanations, is_or):
  453. explanation = "(" + (is_or and " or " or " and ").join(explanations) + ")"
  454. if isinstance(explanation, six.text_type):
  455. return explanation.replace(u"%", u"%%")
  456. else:
  457. return explanation.replace(b"%", b"%%")
  458. def _call_reprcompare(ops, results, expls, each_obj):
  459. for i, res, expl in zip(range(len(ops)), results, expls):
  460. try:
  461. done = not res
  462. except Exception:
  463. done = True
  464. if done:
  465. break
  466. if util._reprcompare is not None:
  467. custom = util._reprcompare(ops[i], each_obj[i], each_obj[i + 1])
  468. if custom is not None:
  469. return custom
  470. return expl
  471. unary_map = {ast.Not: "not %s", ast.Invert: "~%s", ast.USub: "-%s", ast.UAdd: "+%s"}
  472. binop_map = {
  473. ast.BitOr: "|",
  474. ast.BitXor: "^",
  475. ast.BitAnd: "&",
  476. ast.LShift: "<<",
  477. ast.RShift: ">>",
  478. ast.Add: "+",
  479. ast.Sub: "-",
  480. ast.Mult: "*",
  481. ast.Div: "/",
  482. ast.FloorDiv: "//",
  483. ast.Mod: "%%", # escaped for string formatting
  484. ast.Eq: "==",
  485. ast.NotEq: "!=",
  486. ast.Lt: "<",
  487. ast.LtE: "<=",
  488. ast.Gt: ">",
  489. ast.GtE: ">=",
  490. ast.Pow: "**",
  491. ast.Is: "is",
  492. ast.IsNot: "is not",
  493. ast.In: "in",
  494. ast.NotIn: "not in",
  495. }
  496. # Python 3.5+ compatibility
  497. try:
  498. binop_map[ast.MatMult] = "@"
  499. except AttributeError:
  500. pass
  501. # Python 3.4+ compatibility
  502. if hasattr(ast, "NameConstant"):
  503. _NameConstant = ast.NameConstant
  504. else:
  505. def _NameConstant(c):
  506. return ast.Name(str(c), ast.Load())
  507. def set_location(node, lineno, col_offset):
  508. """Set node location information recursively."""
  509. def _fix(node, lineno, col_offset):
  510. if "lineno" in node._attributes:
  511. node.lineno = lineno
  512. if "col_offset" in node._attributes:
  513. node.col_offset = col_offset
  514. for child in ast.iter_child_nodes(node):
  515. _fix(child, lineno, col_offset)
  516. _fix(node, lineno, col_offset)
  517. return node
  518. class AssertionRewriter(ast.NodeVisitor):
  519. """Assertion rewriting implementation.
  520. The main entrypoint is to call .run() with an ast.Module instance,
  521. this will then find all the assert statements and rewrite them to
  522. provide intermediate values and a detailed assertion error. See
  523. http://pybites.blogspot.be/2011/07/behind-scenes-of-pytests-new-assertion.html
  524. for an overview of how this works.
  525. The entry point here is .run() which will iterate over all the
  526. statements in an ast.Module and for each ast.Assert statement it
  527. finds call .visit() with it. Then .visit_Assert() takes over and
  528. is responsible for creating new ast statements to replace the
  529. original assert statement: it rewrites the test of an assertion
  530. to provide intermediate values and replace it with an if statement
  531. which raises an assertion error with a detailed explanation in
  532. case the expression is false.
  533. For this .visit_Assert() uses the visitor pattern to visit all the
  534. AST nodes of the ast.Assert.test field, each visit call returning
  535. an AST node and the corresponding explanation string. During this
  536. state is kept in several instance attributes:
  537. :statements: All the AST statements which will replace the assert
  538. statement.
  539. :variables: This is populated by .variable() with each variable
  540. used by the statements so that they can all be set to None at
  541. the end of the statements.
  542. :variable_counter: Counter to create new unique variables needed
  543. by statements. Variables are created using .variable() and
  544. have the form of "@py_assert0".
  545. :on_failure: The AST statements which will be executed if the
  546. assertion test fails. This is the code which will construct
  547. the failure message and raises the AssertionError.
  548. :explanation_specifiers: A dict filled by .explanation_param()
  549. with %-formatting placeholders and their corresponding
  550. expressions to use in the building of an assertion message.
  551. This is used by .pop_format_context() to build a message.
  552. :stack: A stack of the explanation_specifiers dicts maintained by
  553. .push_format_context() and .pop_format_context() which allows
  554. to build another %-formatted string while already building one.
  555. This state is reset on every new assert statement visited and used
  556. by the other visitors.
  557. """
  558. def __init__(self, module_path, config):
  559. super(AssertionRewriter, self).__init__()
  560. self.module_path = module_path
  561. self.config = config
    def run(self, mod):
        """Find all assert statements in *mod* and rewrite them.

        *mod* is modified in place: two helper imports are inserted after the
        module docstring and any ``from __future__`` imports, and every
        ``assert`` found in a statement list is replaced by the statements
        returned from ``self.visit`` (i.e. ``visit_Assert``). Nothing is
        done for an empty module or one whose docstring contains
        PYTEST_DONT_REWRITE.
        """
        if not mod.body:
            # Nothing to do.
            return
        # Insert some special imports at the top of the module but after any
        # docstrings and __future__ imports.
        # The "@" in the alias names cannot appear in Python source, so user
        # code can never clash with them; helper() and builtin() refer to them.
        aliases = [
            ast.alias(six.moves.builtins.__name__, "@py_builtins"),
            ast.alias("_pytest.assertion.rewrite", "@pytest_ar"),
        ]
        # Some AST versions expose the docstring as mod.docstring; when that
        # attribute is absent (None), look for it as the first statement below.
        doc = getattr(mod, "docstring", None)
        expect_docstring = doc is None
        if doc is not None and self.is_rewrite_disabled(doc):
            return
        pos = 0
        lineno = 1
        for item in mod.body:
            if (
                expect_docstring
                and isinstance(item, ast.Expr)
                and isinstance(item.value, ast.Str)
            ):
                doc = item.value.s
                if self.is_rewrite_disabled(doc):
                    return
                expect_docstring = False
            elif (
                not isinstance(item, ast.ImportFrom)
                or item.level > 0
                or item.module != "__future__"
            ):
                # First statement that is neither the docstring nor a
                # __future__ import: insert the helper imports before it.
                lineno = item.lineno
                break
            pos += 1
        else:
            # The body holds only a docstring and/or __future__ imports.
            lineno = item.lineno
        imports = [
            ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases
        ]
        mod.body[pos:pos] = imports
        # Collect asserts.
        # Iterative depth-first walk using an explicit stack of nodes.
        nodes = [mod]
        while nodes:
            node = nodes.pop()
            for name, field in ast.iter_fields(node):
                if isinstance(field, list):
                    # Rebuild statement/child lists, splicing in the
                    # replacement statements for each assert.
                    new = []
                    for i, child in enumerate(field):
                        if isinstance(child, ast.Assert):
                            # Transform assert.
                            new.extend(self.visit(child))
                        else:
                            new.append(child)
                            if isinstance(child, ast.AST):
                                nodes.append(child)
                    setattr(node, name, new)
                elif (
                    isinstance(field, ast.AST)
                    # Don't recurse into expressions as they can't contain
                    # asserts.
                    and not isinstance(field, ast.expr)
                ):
                    nodes.append(field)
  626. @staticmethod
  627. def is_rewrite_disabled(docstring):
  628. return "PYTEST_DONT_REWRITE" in docstring
  629. def variable(self):
  630. """Get a new variable."""
  631. # Use a character invalid in python identifiers to avoid clashing.
  632. name = "@py_assert" + str(next(self.variable_counter))
  633. self.variables.append(name)
  634. return name
  635. def assign(self, expr):
  636. """Give *expr* a name."""
  637. name = self.variable()
  638. self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr))
  639. return ast.Name(name, ast.Load())
  640. def display(self, expr):
  641. """Call saferepr on the expression."""
  642. return self.helper("_saferepr", expr)
  643. def helper(self, name, *args):
  644. """Call a helper in this module."""
  645. py_name = ast.Name("@pytest_ar", ast.Load())
  646. attr = ast.Attribute(py_name, name, ast.Load())
  647. return ast_Call(attr, list(args), [])
  648. def builtin(self, name):
  649. """Return the builtin called *name*."""
  650. builtin_name = ast.Name("@py_builtins", ast.Load())
  651. return ast.Attribute(builtin_name, name, ast.Load())
  652. def explanation_param(self, expr):
  653. """Return a new named %-formatting placeholder for expr.
  654. This creates a %-formatting placeholder for expr in the
  655. current formatting context, e.g. ``%(py0)s``. The placeholder
  656. and expr are placed in the current format context so that it
  657. can be used on the next call to .pop_format_context().
  658. """
  659. specifier = "py" + str(next(self.variable_counter))
  660. self.explanation_specifiers[specifier] = expr
  661. return "%(" + specifier + ")s"
  662. def push_format_context(self):
  663. """Create a new formatting context.
  664. The format context is used for when an explanation wants to
  665. have a variable value formatted in the assertion message. In
  666. this case the value required can be added using
  667. .explanation_param(). Finally .pop_format_context() is used
  668. to format a string of %-formatted values as added by
  669. .explanation_param().
  670. """
  671. self.explanation_specifiers = {}
  672. self.stack.append(self.explanation_specifiers)
  673. def pop_format_context(self, expl_expr):
  674. """Format the %-formatted string with current format context.
  675. The expl_expr should be an ast.Str instance constructed from
  676. the %-placeholders created by .explanation_param(). This will
  677. add the required code to format said string to .on_failure and
  678. return the ast.Name instance of the formatted string.
  679. """
  680. current = self.stack.pop()
  681. if self.stack:
  682. self.explanation_specifiers = self.stack[-1]
  683. keys = [ast.Str(key) for key in current.keys()]
  684. format_dict = ast.Dict(keys, list(current.values()))
  685. form = ast.BinOp(expl_expr, ast.Mod(), format_dict)
  686. name = "@py_format" + str(next(self.variable_counter))
  687. self.on_failure.append(ast.Assign([ast.Name(name, ast.Store())], form))
  688. return ast.Name(name, ast.Load())
  689. def generic_visit(self, node):
  690. """Handle expressions we don't have custom code for."""
  691. assert isinstance(node, ast.expr)
  692. res = self.assign(node)
  693. return res, self.explanation_param(self.display(res))
    def visit_Assert(self, assert_):
        """Return the AST statements to replace the ast.Assert instance.

        This rewrites the test of an assertion to provide
        intermediate values and replace it with an if statement which
        raises an assertion error with a detailed explanation in case
        the expression is false.

        All per-assert state (statements, variables, counter, format-context
        stack, on_failure) is reset here before visiting ``assert_.test``.
        """
        # ``assert (x, "msg")`` tests a non-empty tuple, which is always true.
        if isinstance(assert_.test, ast.Tuple) and len(assert_.test.elts) >= 1:
            from _pytest.warning_types import PytestAssertRewriteWarning
            import warnings

            warnings.warn_explicit(
                PytestAssertRewriteWarning(
                    "assertion is always true, perhaps remove parentheses?"
                ),
                category=None,
                filename=str(self.module_path),
                lineno=assert_.lineno,
            )

        self.statements = []
        self.variables = []
        self.variable_counter = itertools.count()
        self.stack = []
        self.on_failure = []
        self.push_format_context()
        # Rewrite assert into a bunch of statements.
        top_condition, explanation = self.visit(assert_.test)
        # If in a test module, check if directly asserting None, in order to warn [Issue #3191]
        if self.module_path is not None:
            self.statements.append(
                self.warn_about_none_ast(
                    top_condition, module_path=self.module_path, lineno=assert_.lineno
                )
            )
        # Create failure message.
        # ``body`` aliases self.on_failure, so statements appended to it later
        # (the formatting assignments and the raise) end up inside this ``if``.
        body = self.on_failure
        negation = ast.UnaryOp(ast.Not(), top_condition)
        self.statements.append(ast.If(negation, body, []))
        if assert_.msg:
            assertmsg = self.helper("_format_assertmsg", assert_.msg)
            explanation = "\n>assert " + explanation
        else:
            assertmsg = ast.Str("")
            explanation = "assert " + explanation
        template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation))
        msg = self.pop_format_context(template)
        fmt = self.helper("_format_explanation", msg)
        err_name = ast.Name("AssertionError", ast.Load())
        exc = ast_Call(err_name, [fmt], [])
        # ast.Raise has two fields on Python 3 and three on Python 2.
        if sys.version_info[0] >= 3:
            raise_ = ast.Raise(exc, None)
        else:
            raise_ = ast.Raise(exc, None, None)
        body.append(raise_)
        # Clear temporary variables by setting them to None.
        if self.variables:
            variables = [ast.Name(name, ast.Store()) for name in self.variables]
            clear = ast.Assign(variables, _NameConstant(None))
            self.statements.append(clear)
        # Fix line numbers.
        # Every generated node reports the original assert's location.
        for stmt in self.statements:
            set_location(stmt, assert_.lineno, assert_.col_offset)
        return self.statements
  756. def warn_about_none_ast(self, node, module_path, lineno):
  757. """
  758. Returns an AST issuing a warning if the value of node is `None`.
  759. This is used to warn the user when asserting a function that asserts
  760. internally already.
  761. See issue #3191 for more details.
  762. """
  763. # Using parse because it is different between py2 and py3.
  764. AST_NONE = ast.parse("None").body[0].value
  765. val_is_none = ast.Compare(node, [ast.Is()], [AST_NONE])
  766. send_warning = ast.parse(
  767. """
  768. from _pytest.warning_types import PytestAssertRewriteWarning
  769. from warnings import warn_explicit
  770. warn_explicit(
  771. PytestAssertRewriteWarning('asserting the value None, please use "assert is None"'),
  772. category=None,
  773. filename={filename!r},
  774. lineno={lineno},
  775. )
  776. """.format(
  777. filename=module_path.strpath, lineno=lineno
  778. )
  779. ).body
  780. return ast.If(val_is_none, send_warning, [])
  781. def visit_Name(self, name):
  782. # Display the repr of the name if it's a local variable or
  783. # _should_repr_global_name() thinks it's acceptable.
  784. locs = ast_Call(self.builtin("locals"), [], [])
  785. inlocs = ast.Compare(ast.Str(name.id), [ast.In()], [locs])
  786. dorepr = self.helper("_should_repr_global_name", name)
  787. test = ast.BoolOp(ast.Or(), [inlocs, dorepr])
  788. expr = ast.IfExp(test, self.display(name), ast.Str(name.id))
  789. return name, self.explanation_param(expr)
    def visit_BoolOp(self, boolop):
        """Rewrite an ``and``/``or`` chain while keeping short-circuiting.

        Each operand after the first is evaluated inside an ``if`` on the
        previous operand's result, so later operands only run when the chain
        has not short-circuited.  Every evaluated operand is assigned to the
        same fresh variable, which is returned as the result expression.
        Failure code appends the formatted explanation of each reached
        operand to a runtime list, which ``_format_boolop`` then combines.
        """
        res_var = self.variable()
        # Runtime list that collects one formatted explanation per operand.
        expl_list = self.assign(ast.List([], ast.Load()))
        app = ast.Attribute(expl_list, "append", ast.Load())
        is_or = int(isinstance(boolop.op, ast.Or))
        # Remember the current statement/failure lists; both are re-pointed
        # at nested ``if`` bodies below and restored after the loop.
        body = save = self.statements
        fail_save = self.on_failure
        levels = len(boolop.values) - 1
        self.push_format_context()
        # Process each operand, short-circuiting if needed.
        for i, v in enumerate(boolop.values):
            if i:
                # Only explain this operand on failure when the previous
                # condition allowed evaluation to reach it.
                fail_inner = []
                # cond is set in a prior loop iteration below
                self.on_failure.append(ast.If(cond, fail_inner, []))  # noqa
                self.on_failure = fail_inner
            self.push_format_context()
            res, expl = self.visit(v)
            body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
            expl_format = self.pop_format_context(ast.Str(expl))
            call = ast_Call(app, [expl_format], [])
            self.on_failure.append(ast.Expr(call))
            if i < levels:
                # Go on to the next operand only if the chain continues:
                # ``and`` continues on a truthy result, ``or`` on a falsy one.
                cond = res
                if is_or:
                    cond = ast.UnaryOp(ast.Not(), cond)
                inner = []
                self.statements.append(ast.If(cond, inner, []))
                self.statements = body = inner
        self.statements = save
        self.on_failure = fail_save
        expl_template = self.helper("_format_boolop", expl_list, ast.Num(is_or))
        expl = self.pop_format_context(expl_template)
        return ast.Name(res_var, ast.Load()), self.explanation_param(expl)
  824. def visit_UnaryOp(self, unary):
  825. pattern = unary_map[unary.op.__class__]
  826. operand_res, operand_expl = self.visit(unary.operand)
  827. res = self.assign(ast.UnaryOp(unary.op, operand_res))
  828. return res, pattern % (operand_expl,)
  829. def visit_BinOp(self, binop):
  830. symbol = binop_map[binop.op.__class__]
  831. left_expr, left_expl = self.visit(binop.left)
  832. right_expr, right_expl = self.visit(binop.right)
  833. explanation = "(%s %s %s)" % (left_expl, symbol, right_expl)
  834. res = self.assign(ast.BinOp(left_expr, binop.op, right_expr))
  835. return res, explanation
  836. def visit_Call_35(self, call):
  837. """
  838. visit `ast.Call` nodes on Python3.5 and after
  839. """
  840. new_func, func_expl = self.visit(call.func)
  841. arg_expls = []
  842. new_args = []
  843. new_kwargs = []
  844. for arg in call.args:
  845. res, expl = self.visit(arg)
  846. arg_expls.append(expl)
  847. new_args.append(res)
  848. for keyword in call.keywords:
  849. res, expl = self.visit(keyword.value)
  850. new_kwargs.append(ast.keyword(keyword.arg, res))
  851. if keyword.arg:
  852. arg_expls.append(keyword.arg + "=" + expl)
  853. else: # **args have `arg` keywords with an .arg of None
  854. arg_expls.append("**" + expl)
  855. expl = "%s(%s)" % (func_expl, ", ".join(arg_expls))
  856. new_call = ast.Call(new_func, new_args, new_kwargs)
  857. res = self.assign(new_call)
  858. res_expl = self.explanation_param(self.display(res))
  859. outer_expl = "%s\n{%s = %s\n}" % (res_expl, res_expl, expl)
  860. return res, outer_expl
  861. def visit_Starred(self, starred):
  862. # From Python 3.5, a Starred node can appear in a function call
  863. res, expl = self.visit(starred.value)
  864. new_starred = ast.Starred(res, starred.ctx)
  865. return new_starred, "*" + expl
  866. def visit_Call_legacy(self, call):
  867. """
  868. visit `ast.Call nodes on 3.4 and below`
  869. """
  870. new_func, func_expl = self.visit(call.func)
  871. arg_expls = []
  872. new_args = []
  873. new_kwargs = []
  874. new_star = new_kwarg = None
  875. for arg in call.args:
  876. res, expl = self.visit(arg)
  877. new_args.append(res)
  878. arg_expls.append(expl)
  879. for keyword in call.keywords:
  880. res, expl = self.visit(keyword.value)
  881. new_kwargs.append(ast.keyword(keyword.arg, res))
  882. arg_expls.append(keyword.arg + "=" + expl)
  883. if call.starargs:
  884. new_star, expl = self.visit(call.starargs)
  885. arg_expls.append("*" + expl)
  886. if call.kwargs:
  887. new_kwarg, expl = self.visit(call.kwargs)
  888. arg_expls.append("**" + expl)
  889. expl = "%s(%s)" % (func_expl, ", ".join(arg_expls))
  890. new_call = ast.Call(new_func, new_args, new_kwargs, new_star, new_kwarg)
  891. res = self.assign(new_call)
  892. res_expl = self.explanation_param(self.display(res))
  893. outer_expl = "%s\n{%s = %s\n}" % (res_expl, res_expl, expl)
  894. return res, outer_expl
  895. # ast.Call signature changed on 3.5,
  896. # conditionally change which methods is named
  897. # visit_Call depending on Python version
  898. if sys.version_info >= (3, 5):
  899. visit_Call = visit_Call_35
  900. else:
  901. visit_Call = visit_Call_legacy
  902. def visit_Attribute(self, attr):
  903. if not isinstance(attr.ctx, ast.Load):
  904. return self.generic_visit(attr)
  905. value, value_expl = self.visit(attr.value)
  906. res = self.assign(ast.Attribute(value, attr.attr, ast.Load()))
  907. res_expl = self.explanation_param(self.display(res))
  908. pat = "%s\n{%s = %s.%s\n}"
  909. expl = pat % (res_expl, res_expl, value_expl, attr.attr)
  910. return res, expl
  911. def visit_Compare(self, comp):
  912. self.push_format_context()
  913. left_res, left_expl = self.visit(comp.left)
  914. if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
  915. left_expl = "({})".format(left_expl)
  916. res_variables = [self.variable() for i in range(len(comp.ops))]
  917. load_names = [ast.Name(v, ast.Load()) for v in res_variables]
  918. store_names = [ast.Name(v, ast.Store()) for v in res_variables]
  919. it = zip(range(len(comp.ops)), comp.ops, comp.comparators)
  920. expls = []
  921. syms = []
  922. results = [left_res]
  923. for i, op, next_operand in it:
  924. next_res, next_expl = self.visit(next_operand)
  925. if isinstance(next_operand, (ast.Compare, ast.BoolOp)):
  926. next_expl = "({})".format(next_expl)
  927. results.append(next_res)
  928. sym = binop_map[op.__class__]
  929. syms.append(ast.Str(sym))
  930. expl = "%s %s %s" % (left_expl, sym, next_expl)
  931. expls.append(ast.Str(expl))
  932. res_expr = ast.Compare(left_res, [op], [next_res])
  933. self.statements.append(ast.Assign([store_names[i]], res_expr))
  934. left_res, left_expl = next_res, next_expl
  935. # Use pytest.assertion.util._reprcompare if that's available.
  936. expl_call = self.helper(
  937. "_call_reprcompare",
  938. ast.Tuple(syms, ast.Load()),
  939. ast.Tuple(load_names, ast.Load()),
  940. ast.Tuple(expls, ast.Load()),
  941. ast.Tuple(results, ast.Load()),
  942. )
  943. if len(comp.ops) > 1:
  944. res = ast.BoolOp(ast.And(), load_names)
  945. else:
  946. res = load_names[0]
  947. return res, self.explanation_param(self.pop_format_context(expl_call))