rewrite.py 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217
  1. """Rewrite assertion AST to produce nice error messages."""
  2. import ast
  3. import errno
  4. import functools
  5. import importlib.abc
  6. import importlib.machinery
  7. import importlib.util
  8. import io
  9. import itertools
  10. import marshal
  11. import os
  12. import struct
  13. import sys
  14. import tokenize
  15. import types
  16. from collections import defaultdict
  17. from pathlib import Path
  18. from pathlib import PurePath
  19. from typing import Callable
  20. from typing import Dict
  21. from typing import IO
  22. from typing import Iterable
  23. from typing import Iterator
  24. from typing import List
  25. from typing import Optional
  26. from typing import Sequence
  27. from typing import Set
  28. from typing import Tuple
  29. from typing import TYPE_CHECKING
  30. from typing import Union
  31. from _pytest._io.saferepr import DEFAULT_REPR_MAX_SIZE
  32. from _pytest._io.saferepr import saferepr
  33. from _pytest._version import version
  34. from _pytest.assertion import util
  35. from _pytest.assertion.util import ( # noqa: F401
  36. format_explanation as _format_explanation,
  37. )
  38. from _pytest.config import Config
  39. from _pytest.main import Session
  40. from _pytest.pathlib import absolutepath
  41. from _pytest.pathlib import fnmatch_ex
  42. from _pytest.stash import StashKey
  43. if TYPE_CHECKING:
  44. from _pytest.assertion import AssertionState
  45. if sys.version_info >= (3, 8):
  46. namedExpr = ast.NamedExpr
  47. astNameConstant = ast.Constant
  48. astStr = ast.Constant
  49. astNum = ast.Constant
  50. else:
  51. namedExpr = ast.Expr
  52. astNameConstant = ast.NameConstant
  53. astStr = ast.Str
  54. astNum = ast.Num
  55. class Sentinel:
  56. pass
  57. assertstate_key = StashKey["AssertionState"]()
  58. # pytest caches rewritten pycs in pycache dirs
  59. PYTEST_TAG = f"{sys.implementation.cache_tag}-pytest-{version}"
  60. PYC_EXT = ".py" + (__debug__ and "c" or "o")
  61. PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
  62. # Special marker that denotes we have just left a scope definition
  63. _SCOPE_END_MARKER = Sentinel()
  64. class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader):
  65. """PEP302/PEP451 import hook which rewrites asserts."""
  66. def __init__(self, config: Config) -> None:
  67. self.config = config
  68. try:
  69. self.fnpats = config.getini("python_files")
  70. except ValueError:
  71. self.fnpats = ["test_*.py", "*_test.py"]
  72. self.session: Optional[Session] = None
  73. self._rewritten_names: Dict[str, Path] = {}
  74. self._must_rewrite: Set[str] = set()
  75. # flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
  76. # which might result in infinite recursion (#3506)
  77. self._writing_pyc = False
  78. self._basenames_to_check_rewrite = {"conftest"}
  79. self._marked_for_rewrite_cache: Dict[str, bool] = {}
  80. self._session_paths_checked = False
  81. def set_session(self, session: Optional[Session]) -> None:
  82. self.session = session
  83. self._session_paths_checked = False
  84. # Indirection so we can mock calls to find_spec originated from the hook during testing
  85. _find_spec = importlib.machinery.PathFinder.find_spec
  86. def find_spec(
  87. self,
  88. name: str,
  89. path: Optional[Sequence[Union[str, bytes]]] = None,
  90. target: Optional[types.ModuleType] = None,
  91. ) -> Optional[importlib.machinery.ModuleSpec]:
  92. if self._writing_pyc:
  93. return None
  94. state = self.config.stash[assertstate_key]
  95. if self._early_rewrite_bailout(name, state):
  96. return None
  97. state.trace("find_module called for: %s" % name)
  98. # Type ignored because mypy is confused about the `self` binding here.
  99. spec = self._find_spec(name, path) # type: ignore
  100. if (
  101. # the import machinery could not find a file to import
  102. spec is None
  103. # this is a namespace package (without `__init__.py`)
  104. # there's nothing to rewrite there
  105. or spec.origin is None
  106. # we can only rewrite source files
  107. or not isinstance(spec.loader, importlib.machinery.SourceFileLoader)
  108. # if the file doesn't exist, we can't rewrite it
  109. or not os.path.exists(spec.origin)
  110. ):
  111. return None
  112. else:
  113. fn = spec.origin
  114. if not self._should_rewrite(name, fn, state):
  115. return None
  116. return importlib.util.spec_from_file_location(
  117. name,
  118. fn,
  119. loader=self,
  120. submodule_search_locations=spec.submodule_search_locations,
  121. )
  122. def create_module(
  123. self, spec: importlib.machinery.ModuleSpec
  124. ) -> Optional[types.ModuleType]:
  125. return None # default behaviour is fine
  126. def exec_module(self, module: types.ModuleType) -> None:
  127. assert module.__spec__ is not None
  128. assert module.__spec__.origin is not None
  129. fn = Path(module.__spec__.origin)
  130. state = self.config.stash[assertstate_key]
  131. self._rewritten_names[module.__name__] = fn
  132. # The requested module looks like a test file, so rewrite it. This is
  133. # the most magical part of the process: load the source, rewrite the
  134. # asserts, and load the rewritten source. We also cache the rewritten
  135. # module code in a special pyc. We must be aware of the possibility of
  136. # concurrent pytest processes rewriting and loading pycs. To avoid
  137. # tricky race conditions, we maintain the following invariant: The
  138. # cached pyc is always a complete, valid pyc. Operations on it must be
  139. # atomic. POSIX's atomic rename comes in handy.
  140. write = not sys.dont_write_bytecode
  141. cache_dir = get_cache_dir(fn)
  142. if write:
  143. ok = try_makedirs(cache_dir)
  144. if not ok:
  145. write = False
  146. state.trace(f"read only directory: {cache_dir}")
  147. cache_name = fn.name[:-3] + PYC_TAIL
  148. pyc = cache_dir / cache_name
  149. # Notice that even if we're in a read-only directory, I'm going
  150. # to check for a cached pyc. This may not be optimal...
  151. co = _read_pyc(fn, pyc, state.trace)
  152. if co is None:
  153. state.trace(f"rewriting {fn!r}")
  154. source_stat, co = _rewrite_test(fn, self.config)
  155. if write:
  156. self._writing_pyc = True
  157. try:
  158. _write_pyc(state, co, source_stat, pyc)
  159. finally:
  160. self._writing_pyc = False
  161. else:
  162. state.trace(f"found cached rewritten pyc for {fn}")
  163. exec(co, module.__dict__)
  164. def _early_rewrite_bailout(self, name: str, state: "AssertionState") -> bool:
  165. """A fast way to get out of rewriting modules.
  166. Profiling has shown that the call to PathFinder.find_spec (inside of
  167. the find_spec from this class) is a major slowdown, so, this method
  168. tries to filter what we're sure won't be rewritten before getting to
  169. it.
  170. """
  171. if self.session is not None and not self._session_paths_checked:
  172. self._session_paths_checked = True
  173. for initial_path in self.session._initialpaths:
  174. # Make something as c:/projects/my_project/path.py ->
  175. # ['c:', 'projects', 'my_project', 'path.py']
  176. parts = str(initial_path).split(os.sep)
  177. # add 'path' to basenames to be checked.
  178. self._basenames_to_check_rewrite.add(os.path.splitext(parts[-1])[0])
  179. # Note: conftest already by default in _basenames_to_check_rewrite.
  180. parts = name.split(".")
  181. if parts[-1] in self._basenames_to_check_rewrite:
  182. return False
  183. # For matching the name it must be as if it was a filename.
  184. path = PurePath(*parts).with_suffix(".py")
  185. for pat in self.fnpats:
  186. # if the pattern contains subdirectories ("tests/**.py" for example) we can't bail out based
  187. # on the name alone because we need to match against the full path
  188. if os.path.dirname(pat):
  189. return False
  190. if fnmatch_ex(pat, path):
  191. return False
  192. if self._is_marked_for_rewrite(name, state):
  193. return False
  194. state.trace(f"early skip of rewriting module: {name}")
  195. return True
  196. def _should_rewrite(self, name: str, fn: str, state: "AssertionState") -> bool:
  197. # always rewrite conftest files
  198. if os.path.basename(fn) == "conftest.py":
  199. state.trace(f"rewriting conftest file: {fn!r}")
  200. return True
  201. if self.session is not None:
  202. if self.session.isinitpath(absolutepath(fn)):
  203. state.trace(f"matched test file (was specified on cmdline): {fn!r}")
  204. return True
  205. # modules not passed explicitly on the command line are only
  206. # rewritten if they match the naming convention for test files
  207. fn_path = PurePath(fn)
  208. for pat in self.fnpats:
  209. if fnmatch_ex(pat, fn_path):
  210. state.trace(f"matched test file {fn!r}")
  211. return True
  212. return self._is_marked_for_rewrite(name, state)
  213. def _is_marked_for_rewrite(self, name: str, state: "AssertionState") -> bool:
  214. try:
  215. return self._marked_for_rewrite_cache[name]
  216. except KeyError:
  217. for marked in self._must_rewrite:
  218. if name == marked or name.startswith(marked + "."):
  219. state.trace(f"matched marked file {name!r} (from {marked!r})")
  220. self._marked_for_rewrite_cache[name] = True
  221. return True
  222. self._marked_for_rewrite_cache[name] = False
  223. return False
  224. def mark_rewrite(self, *names: str) -> None:
  225. """Mark import names as needing to be rewritten.
  226. The named module or package as well as any nested modules will
  227. be rewritten on import.
  228. """
  229. already_imported = (
  230. set(names).intersection(sys.modules).difference(self._rewritten_names)
  231. )
  232. for name in already_imported:
  233. mod = sys.modules[name]
  234. if not AssertionRewriter.is_rewrite_disabled(
  235. mod.__doc__ or ""
  236. ) and not isinstance(mod.__loader__, type(self)):
  237. self._warn_already_imported(name)
  238. self._must_rewrite.update(names)
  239. self._marked_for_rewrite_cache.clear()
  240. def _warn_already_imported(self, name: str) -> None:
  241. from _pytest.warning_types import PytestAssertRewriteWarning
  242. self.config.issue_config_time_warning(
  243. PytestAssertRewriteWarning(
  244. "Module already imported so cannot be rewritten: %s" % name
  245. ),
  246. stacklevel=5,
  247. )
  248. def get_data(self, pathname: Union[str, bytes]) -> bytes:
  249. """Optional PEP302 get_data API."""
  250. with open(pathname, "rb") as f:
  251. return f.read()
  252. if sys.version_info >= (3, 10):
  253. if sys.version_info >= (3, 12):
  254. from importlib.resources.abc import TraversableResources
  255. else:
  256. from importlib.abc import TraversableResources
  257. def get_resource_reader(self, name: str) -> TraversableResources: # type: ignore
  258. if sys.version_info < (3, 11):
  259. from importlib.readers import FileReader
  260. else:
  261. from importlib.resources.readers import FileReader
  262. return FileReader( # type:ignore[no-any-return]
  263. types.SimpleNamespace(path=self._rewritten_names[name])
  264. )
  265. def _write_pyc_fp(
  266. fp: IO[bytes], source_stat: os.stat_result, co: types.CodeType
  267. ) -> None:
  268. # Technically, we don't have to have the same pyc format as
  269. # (C)Python, since these "pycs" should never be seen by builtin
  270. # import. However, there's little reason to deviate.
  271. fp.write(importlib.util.MAGIC_NUMBER)
  272. # https://www.python.org/dev/peps/pep-0552/
  273. flags = b"\x00\x00\x00\x00"
  274. fp.write(flags)
  275. # as of now, bytecode header expects 32-bit numbers for size and mtime (#4903)
  276. mtime = int(source_stat.st_mtime) & 0xFFFFFFFF
  277. size = source_stat.st_size & 0xFFFFFFFF
  278. # "<LL" stands for 2 unsigned longs, little-endian.
  279. fp.write(struct.pack("<LL", mtime, size))
  280. fp.write(marshal.dumps(co))
  281. def _write_pyc(
  282. state: "AssertionState",
  283. co: types.CodeType,
  284. source_stat: os.stat_result,
  285. pyc: Path,
  286. ) -> bool:
  287. proc_pyc = f"{pyc}.{os.getpid()}"
  288. try:
  289. with open(proc_pyc, "wb") as fp:
  290. _write_pyc_fp(fp, source_stat, co)
  291. except OSError as e:
  292. state.trace(f"error writing pyc file at {proc_pyc}: errno={e.errno}")
  293. return False
  294. try:
  295. os.replace(proc_pyc, pyc)
  296. except OSError as e:
  297. state.trace(f"error writing pyc file at {pyc}: {e}")
  298. # we ignore any failure to write the cache file
  299. # there are many reasons, permission-denied, pycache dir being a
  300. # file etc.
  301. return False
  302. return True
  303. def _rewrite_test(fn: Path, config: Config) -> Tuple[os.stat_result, types.CodeType]:
  304. """Read and rewrite *fn* and return the code object."""
  305. stat = os.stat(fn)
  306. source = fn.read_bytes()
  307. strfn = str(fn)
  308. tree = ast.parse(source, filename=strfn)
  309. rewrite_asserts(tree, source, strfn, config)
  310. co = compile(tree, strfn, "exec", dont_inherit=True)
  311. return stat, co
  312. def _read_pyc(
  313. source: Path, pyc: Path, trace: Callable[[str], None] = lambda x: None
  314. ) -> Optional[types.CodeType]:
  315. """Possibly read a pytest pyc containing rewritten code.
  316. Return rewritten code if successful or None if not.
  317. """
  318. try:
  319. fp = open(pyc, "rb")
  320. except OSError:
  321. return None
  322. with fp:
  323. try:
  324. stat_result = os.stat(source)
  325. mtime = int(stat_result.st_mtime)
  326. size = stat_result.st_size
  327. data = fp.read(16)
  328. except OSError as e:
  329. trace(f"_read_pyc({source}): OSError {e}")
  330. return None
  331. # Check for invalid or out of date pyc file.
  332. if len(data) != (16):
  333. trace("_read_pyc(%s): invalid pyc (too short)" % source)
  334. return None
  335. if data[:4] != importlib.util.MAGIC_NUMBER:
  336. trace("_read_pyc(%s): invalid pyc (bad magic number)" % source)
  337. return None
  338. if data[4:8] != b"\x00\x00\x00\x00":
  339. trace("_read_pyc(%s): invalid pyc (unsupported flags)" % source)
  340. return None
  341. mtime_data = data[8:12]
  342. if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF:
  343. trace("_read_pyc(%s): out of date" % source)
  344. return None
  345. size_data = data[12:16]
  346. if int.from_bytes(size_data, "little") != size & 0xFFFFFFFF:
  347. trace("_read_pyc(%s): invalid pyc (incorrect size)" % source)
  348. return None
  349. try:
  350. co = marshal.load(fp)
  351. except Exception as e:
  352. trace(f"_read_pyc({source}): marshal.load error {e}")
  353. return None
  354. if not isinstance(co, types.CodeType):
  355. trace("_read_pyc(%s): not a code object" % source)
  356. return None
  357. return co
  358. def rewrite_asserts(
  359. mod: ast.Module,
  360. source: bytes,
  361. module_path: Optional[str] = None,
  362. config: Optional[Config] = None,
  363. ) -> None:
  364. """Rewrite the assert statements in mod."""
  365. AssertionRewriter(module_path, config, source).run(mod)
  366. def _saferepr(obj: object) -> str:
  367. r"""Get a safe repr of an object for assertion error messages.
  368. The assertion formatting (util.format_explanation()) requires
  369. newlines to be escaped since they are a special character for it.
  370. Normally assertion.util.format_explanation() does this but for a
  371. custom repr it is possible to contain one of the special escape
  372. sequences, especially '\n{' and '\n}' are likely to be present in
  373. JSON reprs.
  374. """
  375. maxsize = _get_maxsize_for_saferepr(util._config)
  376. return saferepr(obj, maxsize=maxsize).replace("\n", "\\n")
  377. def _get_maxsize_for_saferepr(config: Optional[Config]) -> Optional[int]:
  378. """Get `maxsize` configuration for saferepr based on the given config object."""
  379. verbosity = config.getoption("verbose") if config is not None else 0
  380. if verbosity >= 2:
  381. return None
  382. if verbosity >= 1:
  383. return DEFAULT_REPR_MAX_SIZE * 10
  384. return DEFAULT_REPR_MAX_SIZE
  385. def _format_assertmsg(obj: object) -> str:
  386. r"""Format the custom assertion message given.
  387. For strings this simply replaces newlines with '\n~' so that
  388. util.format_explanation() will preserve them instead of escaping
  389. newlines. For other objects saferepr() is used first.
  390. """
  391. # reprlib appears to have a bug which means that if a string
  392. # contains a newline it gets escaped, however if an object has a
  393. # .__repr__() which contains newlines it does not get escaped.
  394. # However in either case we want to preserve the newline.
  395. replaces = [("\n", "\n~"), ("%", "%%")]
  396. if not isinstance(obj, str):
  397. obj = saferepr(obj)
  398. replaces.append(("\\n", "\n~"))
  399. for r1, r2 in replaces:
  400. obj = obj.replace(r1, r2)
  401. return obj
  402. def _should_repr_global_name(obj: object) -> bool:
  403. if callable(obj):
  404. return False
  405. try:
  406. return not hasattr(obj, "__name__")
  407. except Exception:
  408. return True
  409. def _format_boolop(explanations: Iterable[str], is_or: bool) -> str:
  410. explanation = "(" + (is_or and " or " or " and ").join(explanations) + ")"
  411. return explanation.replace("%", "%%")
  412. def _call_reprcompare(
  413. ops: Sequence[str],
  414. results: Sequence[bool],
  415. expls: Sequence[str],
  416. each_obj: Sequence[object],
  417. ) -> str:
  418. for i, res, expl in zip(range(len(ops)), results, expls):
  419. try:
  420. done = not res
  421. except Exception:
  422. done = True
  423. if done:
  424. break
  425. if util._reprcompare is not None:
  426. custom = util._reprcompare(ops[i], each_obj[i], each_obj[i + 1])
  427. if custom is not None:
  428. return custom
  429. return expl
  430. def _call_assertion_pass(lineno: int, orig: str, expl: str) -> None:
  431. if util._assertion_pass is not None:
  432. util._assertion_pass(lineno, orig, expl)
  433. def _check_if_assertion_pass_impl() -> bool:
  434. """Check if any plugins implement the pytest_assertion_pass hook
  435. in order not to generate explanation unnecessarily (might be expensive)."""
  436. return True if util._assertion_pass else False
  437. UNARY_MAP = {ast.Not: "not %s", ast.Invert: "~%s", ast.USub: "-%s", ast.UAdd: "+%s"}
  438. BINOP_MAP = {
  439. ast.BitOr: "|",
  440. ast.BitXor: "^",
  441. ast.BitAnd: "&",
  442. ast.LShift: "<<",
  443. ast.RShift: ">>",
  444. ast.Add: "+",
  445. ast.Sub: "-",
  446. ast.Mult: "*",
  447. ast.Div: "/",
  448. ast.FloorDiv: "//",
  449. ast.Mod: "%%", # escaped for string formatting
  450. ast.Eq: "==",
  451. ast.NotEq: "!=",
  452. ast.Lt: "<",
  453. ast.LtE: "<=",
  454. ast.Gt: ">",
  455. ast.GtE: ">=",
  456. ast.Pow: "**",
  457. ast.Is: "is",
  458. ast.IsNot: "is not",
  459. ast.In: "in",
  460. ast.NotIn: "not in",
  461. ast.MatMult: "@",
  462. }
  463. def traverse_node(node: ast.AST) -> Iterator[ast.AST]:
  464. """Recursively yield node and all its children in depth-first order."""
  465. yield node
  466. for child in ast.iter_child_nodes(node):
  467. yield from traverse_node(child)
  468. @functools.lru_cache(maxsize=1)
  469. def _get_assertion_exprs(src: bytes) -> Dict[int, str]:
  470. """Return a mapping from {lineno: "assertion test expression"}."""
  471. ret: Dict[int, str] = {}
  472. depth = 0
  473. lines: List[str] = []
  474. assert_lineno: Optional[int] = None
  475. seen_lines: Set[int] = set()
  476. def _write_and_reset() -> None:
  477. nonlocal depth, lines, assert_lineno, seen_lines
  478. assert assert_lineno is not None
  479. ret[assert_lineno] = "".join(lines).rstrip().rstrip("\\")
  480. depth = 0
  481. lines = []
  482. assert_lineno = None
  483. seen_lines = set()
  484. tokens = tokenize.tokenize(io.BytesIO(src).readline)
  485. for tp, source, (lineno, offset), _, line in tokens:
  486. if tp == tokenize.NAME and source == "assert":
  487. assert_lineno = lineno
  488. elif assert_lineno is not None:
  489. # keep track of depth for the assert-message `,` lookup
  490. if tp == tokenize.OP and source in "([{":
  491. depth += 1
  492. elif tp == tokenize.OP and source in ")]}":
  493. depth -= 1
  494. if not lines:
  495. lines.append(line[offset:])
  496. seen_lines.add(lineno)
  497. # a non-nested comma separates the expression from the message
  498. elif depth == 0 and tp == tokenize.OP and source == ",":
  499. # one line assert with message
  500. if lineno in seen_lines and len(lines) == 1:
  501. offset_in_trimmed = offset + len(lines[-1]) - len(line)
  502. lines[-1] = lines[-1][:offset_in_trimmed]
  503. # multi-line assert with message
  504. elif lineno in seen_lines:
  505. lines[-1] = lines[-1][:offset]
  506. # multi line assert with escapd newline before message
  507. else:
  508. lines.append(line[:offset])
  509. _write_and_reset()
  510. elif tp in {tokenize.NEWLINE, tokenize.ENDMARKER}:
  511. _write_and_reset()
  512. elif lines and lineno not in seen_lines:
  513. lines.append(line)
  514. seen_lines.add(lineno)
  515. return ret
  516. def _get_ast_constant_value(value: astStr) -> object:
  517. if sys.version_info >= (3, 8):
  518. return value.value
  519. else:
  520. return value.s
  521. class AssertionRewriter(ast.NodeVisitor):
  522. """Assertion rewriting implementation.
  523. The main entrypoint is to call .run() with an ast.Module instance,
  524. this will then find all the assert statements and rewrite them to
  525. provide intermediate values and a detailed assertion error. See
  526. http://pybites.blogspot.be/2011/07/behind-scenes-of-pytests-new-assertion.html
  527. for an overview of how this works.
  528. The entry point here is .run() which will iterate over all the
  529. statements in an ast.Module and for each ast.Assert statement it
  530. finds call .visit() with it. Then .visit_Assert() takes over and
  531. is responsible for creating new ast statements to replace the
  532. original assert statement: it rewrites the test of an assertion
  533. to provide intermediate values and replace it with an if statement
  534. which raises an assertion error with a detailed explanation in
  535. case the expression is false and calls pytest_assertion_pass hook
  536. if expression is true.
  537. For this .visit_Assert() uses the visitor pattern to visit all the
  538. AST nodes of the ast.Assert.test field, each visit call returning
  539. an AST node and the corresponding explanation string. During this
  540. state is kept in several instance attributes:
  541. :statements: All the AST statements which will replace the assert
  542. statement.
  543. :variables: This is populated by .variable() with each variable
  544. used by the statements so that they can all be set to None at
  545. the end of the statements.
  546. :variable_counter: Counter to create new unique variables needed
  547. by statements. Variables are created using .variable() and
  548. have the form of "@py_assert0".
  549. :expl_stmts: The AST statements which will be executed to get
  550. data from the assertion. This is the code which will construct
  551. the detailed assertion message that is used in the AssertionError
  552. or for the pytest_assertion_pass hook.
  553. :explanation_specifiers: A dict filled by .explanation_param()
  554. with %-formatting placeholders and their corresponding
  555. expressions to use in the building of an assertion message.
  556. This is used by .pop_format_context() to build a message.
  557. :stack: A stack of the explanation_specifiers dicts maintained by
  558. .push_format_context() and .pop_format_context() which allows
  559. to build another %-formatted string while already building one.
  560. :scope: A tuple containing the current scope used for variables_overwrite.
  561. :variables_overwrite: A dict filled with references to variables
  562. that change value within an assert. This happens when a variable is
  563. reassigned with the walrus operator
  564. This state, except the variables_overwrite, is reset on every new assert
  565. statement visited and used by the other visitors.
  566. """
  567. def __init__(
  568. self, module_path: Optional[str], config: Optional[Config], source: bytes
  569. ) -> None:
  570. super().__init__()
  571. self.module_path = module_path
  572. self.config = config
  573. if config is not None:
  574. self.enable_assertion_pass_hook = config.getini(
  575. "enable_assertion_pass_hook"
  576. )
  577. else:
  578. self.enable_assertion_pass_hook = False
  579. self.source = source
  580. self.scope: tuple[ast.AST, ...] = ()
  581. self.variables_overwrite: defaultdict[
  582. tuple[ast.AST, ...], Dict[str, str]
  583. ] = defaultdict(dict)
  584. def run(self, mod: ast.Module) -> None:
  585. """Find all assert statements in *mod* and rewrite them."""
  586. if not mod.body:
  587. # Nothing to do.
  588. return
  589. # We'll insert some special imports at the top of the module, but after any
  590. # docstrings and __future__ imports, so first figure out where that is.
  591. doc = getattr(mod, "docstring", None)
  592. expect_docstring = doc is None
  593. if doc is not None and self.is_rewrite_disabled(doc):
  594. return
  595. pos = 0
  596. item = None
  597. for item in mod.body:
  598. if (
  599. expect_docstring
  600. and isinstance(item, ast.Expr)
  601. and isinstance(item.value, astStr)
  602. and isinstance(_get_ast_constant_value(item.value), str)
  603. ):
  604. doc = _get_ast_constant_value(item.value)
  605. assert isinstance(doc, str)
  606. if self.is_rewrite_disabled(doc):
  607. return
  608. expect_docstring = False
  609. elif (
  610. isinstance(item, ast.ImportFrom)
  611. and item.level == 0
  612. and item.module == "__future__"
  613. ):
  614. pass
  615. else:
  616. break
  617. pos += 1
  618. # Special case: for a decorated function, set the lineno to that of the
  619. # first decorator, not the `def`. Issue #4984.
  620. if isinstance(item, ast.FunctionDef) and item.decorator_list:
  621. lineno = item.decorator_list[0].lineno
  622. else:
  623. lineno = item.lineno
  624. # Now actually insert the special imports.
  625. if sys.version_info >= (3, 10):
  626. aliases = [
  627. ast.alias("builtins", "@py_builtins", lineno=lineno, col_offset=0),
  628. ast.alias(
  629. "_pytest.assertion.rewrite",
  630. "@pytest_ar",
  631. lineno=lineno,
  632. col_offset=0,
  633. ),
  634. ]
  635. else:
  636. aliases = [
  637. ast.alias("builtins", "@py_builtins"),
  638. ast.alias("_pytest.assertion.rewrite", "@pytest_ar"),
  639. ]
  640. imports = [
  641. ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases
  642. ]
  643. mod.body[pos:pos] = imports
  644. # Collect asserts.
  645. self.scope = (mod,)
  646. nodes: List[Union[ast.AST, Sentinel]] = [mod]
  647. while nodes:
  648. node = nodes.pop()
  649. if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
  650. self.scope = tuple((*self.scope, node))
  651. nodes.append(_SCOPE_END_MARKER)
  652. if node == _SCOPE_END_MARKER:
  653. self.scope = self.scope[:-1]
  654. continue
  655. assert isinstance(node, ast.AST)
  656. for name, field in ast.iter_fields(node):
  657. if isinstance(field, list):
  658. new: List[ast.AST] = []
  659. for i, child in enumerate(field):
  660. if isinstance(child, ast.Assert):
  661. # Transform assert.
  662. new.extend(self.visit(child))
  663. else:
  664. new.append(child)
  665. if isinstance(child, ast.AST):
  666. nodes.append(child)
  667. setattr(node, name, new)
  668. elif (
  669. isinstance(field, ast.AST)
  670. # Don't recurse into expressions as they can't contain
  671. # asserts.
  672. and not isinstance(field, ast.expr)
  673. ):
  674. nodes.append(field)
  675. @staticmethod
  676. def is_rewrite_disabled(docstring: str) -> bool:
  677. return "PYTEST_DONT_REWRITE" in docstring
  678. def variable(self) -> str:
  679. """Get a new variable."""
  680. # Use a character invalid in python identifiers to avoid clashing.
  681. name = "@py_assert" + str(next(self.variable_counter))
  682. self.variables.append(name)
  683. return name
  684. def assign(self, expr: ast.expr) -> ast.Name:
  685. """Give *expr* a name."""
  686. name = self.variable()
  687. self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr))
  688. return ast.Name(name, ast.Load())
  689. def display(self, expr: ast.expr) -> ast.expr:
  690. """Call saferepr on the expression."""
  691. return self.helper("_saferepr", expr)
  692. def helper(self, name: str, *args: ast.expr) -> ast.expr:
  693. """Call a helper in this module."""
  694. py_name = ast.Name("@pytest_ar", ast.Load())
  695. attr = ast.Attribute(py_name, name, ast.Load())
  696. return ast.Call(attr, list(args), [])
  697. def builtin(self, name: str) -> ast.Attribute:
  698. """Return the builtin called *name*."""
  699. builtin_name = ast.Name("@py_builtins", ast.Load())
  700. return ast.Attribute(builtin_name, name, ast.Load())
  701. def explanation_param(self, expr: ast.expr) -> str:
  702. """Return a new named %-formatting placeholder for expr.
  703. This creates a %-formatting placeholder for expr in the
  704. current formatting context, e.g. ``%(py0)s``. The placeholder
  705. and expr are placed in the current format context so that it
  706. can be used on the next call to .pop_format_context().
  707. """
  708. specifier = "py" + str(next(self.variable_counter))
  709. self.explanation_specifiers[specifier] = expr
  710. return "%(" + specifier + ")s"
  711. def push_format_context(self) -> None:
  712. """Create a new formatting context.
  713. The format context is used for when an explanation wants to
  714. have a variable value formatted in the assertion message. In
  715. this case the value required can be added using
  716. .explanation_param(). Finally .pop_format_context() is used
  717. to format a string of %-formatted values as added by
  718. .explanation_param().
  719. """
  720. self.explanation_specifiers: Dict[str, ast.expr] = {}
  721. self.stack.append(self.explanation_specifiers)
  722. def pop_format_context(self, expl_expr: ast.expr) -> ast.Name:
  723. """Format the %-formatted string with current format context.
  724. The expl_expr should be an str ast.expr instance constructed from
  725. the %-placeholders created by .explanation_param(). This will
  726. add the required code to format said string to .expl_stmts and
  727. return the ast.Name instance of the formatted string.
  728. """
  729. current = self.stack.pop()
  730. if self.stack:
  731. self.explanation_specifiers = self.stack[-1]
  732. keys = [astStr(key) for key in current.keys()]
  733. format_dict = ast.Dict(keys, list(current.values()))
  734. form = ast.BinOp(expl_expr, ast.Mod(), format_dict)
  735. name = "@py_format" + str(next(self.variable_counter))
  736. if self.enable_assertion_pass_hook:
  737. self.format_variables.append(name)
  738. self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form))
  739. return ast.Name(name, ast.Load())
  740. def generic_visit(self, node: ast.AST) -> Tuple[ast.Name, str]:
  741. """Handle expressions we don't have custom code for."""
  742. assert isinstance(node, ast.expr)
  743. res = self.assign(node)
  744. return res, self.explanation_param(self.display(res))
    def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]:
        """Return the AST statements to replace the ast.Assert instance.

        This rewrites the test of an assertion to provide
        intermediate values and replace it with an if statement which
        raises an assertion error with a detailed explanation in case
        the expression is false.

        The per-assertion state (``statements``, ``variables``,
        ``variable_counter``, ``stack``, ``expl_stmts`` and, when the
        assertion-pass hook is enabled, ``format_variables``) is reset here.
        It is then filled in by the ``visit_*`` helpers while the test
        expression is visited.

        Returns the list of statements that replaces the original ``assert``.
        All generated nodes get *assert_*'s source location.
        """
        if isinstance(assert_.test, ast.Tuple) and len(assert_.test.elts) >= 1:
            # A non-empty tuple is always truthy, so such an assert can never
            # fail (typically ``assert (cond, "msg")``); warn the user.
            from _pytest.warning_types import PytestAssertRewriteWarning
            import warnings

            # TODO: This assert should not be needed.
            assert self.module_path is not None
            warnings.warn_explicit(
                PytestAssertRewriteWarning(
                    "assertion is always true, perhaps remove parentheses?"
                ),
                category=None,
                filename=self.module_path,
                lineno=assert_.lineno,
            )

        # Fresh per-assertion state used by the visit_* helpers.
        self.statements: List[ast.stmt] = []
        self.variables: List[str] = []
        self.variable_counter = itertools.count()

        if self.enable_assertion_pass_hook:
            # Names of @py_format* temporaries, cleared after the assertion.
            self.format_variables: List[str] = []

        self.stack: List[Dict[str, ast.expr]] = []
        # Statements that build the explanation; they only run where they are
        # placed below (failure branch, and the pass-hook branch if enabled).
        self.expl_stmts: List[ast.stmt] = []
        self.push_format_context()
        # Rewrite assert into a bunch of statements.
        top_condition, explanation = self.visit(assert_.test)

        # The generated ``if`` body runs when the assertion test is false.
        negation = ast.UnaryOp(ast.Not(), top_condition)

        if self.enable_assertion_pass_hook:  # Experimental pytest_assertion_pass hook
            # The formatted explanation is shared by the fail and pass branches.
            msg = self.pop_format_context(astStr(explanation))

            # Failed
            if assert_.msg:
                assertmsg = self.helper("_format_assertmsg", assert_.msg)
                gluestr = "\n>assert "
            else:
                assertmsg = astStr("")
                gluestr = "assert "
            err_explanation = ast.BinOp(astStr(gluestr), ast.Add(), msg)
            err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation)
            err_name = ast.Name("AssertionError", ast.Load())
            fmt = self.helper("_format_explanation", err_msg)
            exc = ast.Call(err_name, [fmt], [])
            raise_ = ast.Raise(exc, None)
            statements_fail = []
            statements_fail.extend(self.expl_stmts)
            statements_fail.append(raise_)

            # Passed
            fmt_pass = self.helper("_format_explanation", msg)
            # Original source text of the assertion expression, by line number.
            orig = _get_assertion_exprs(self.source)[assert_.lineno]
            hook_call_pass = ast.Expr(
                self.helper(
                    "_call_assertion_pass",
                    astNum(assert_.lineno),
                    astStr(orig),
                    fmt_pass,
                )
            )
            # If any hooks implement assert_pass hook
            # (the explanation is only built in that case).
            hook_impl_test = ast.If(
                self.helper("_check_if_assertion_pass_impl"),
                self.expl_stmts + [hook_call_pass],
                [],
            )
            statements_pass = [hook_impl_test]

            # Test for assertion condition
            main_test = ast.If(negation, statements_fail, statements_pass)
            self.statements.append(main_test)
            if self.format_variables:
                # Reset the @py_format* temporaries to None after the check.
                variables = [
                    ast.Name(name, ast.Store()) for name in self.format_variables
                ]
                clear_format = ast.Assign(variables, astNameConstant(None))
                self.statements.append(clear_format)

        else:  # Original assertion rewriting
            # Create failure message.
            # ``body`` aliases self.expl_stmts, so statements that
            # pop_format_context appends below also land in the failure branch.
            body = self.expl_stmts
            self.statements.append(ast.If(negation, body, []))
            if assert_.msg:
                assertmsg = self.helper("_format_assertmsg", assert_.msg)
                explanation = "\n>assert " + explanation
            else:
                assertmsg = astStr("")
                explanation = "assert " + explanation
            template = ast.BinOp(assertmsg, ast.Add(), astStr(explanation))
            msg = self.pop_format_context(template)
            fmt = self.helper("_format_explanation", msg)
            err_name = ast.Name("AssertionError", ast.Load())
            exc = ast.Call(err_name, [fmt], [])
            raise_ = ast.Raise(exc, None)

            body.append(raise_)

        # Clear temporary variables by setting them to None.
        if self.variables:
            variables = [ast.Name(name, ast.Store()) for name in self.variables]
            clear = ast.Assign(variables, astNameConstant(None))
            self.statements.append(clear)
        # Fix locations (line numbers/column offsets).
        for stmt in self.statements:
            for node in traverse_node(stmt):
                ast.copy_location(node, assert_)
        return self.statements
  848. def visit_NamedExpr(self, name: namedExpr) -> Tuple[namedExpr, str]:
  849. # This method handles the 'walrus operator' repr of the target
  850. # name if it's a local variable or _should_repr_global_name()
  851. # thinks it's acceptable.
  852. locs = ast.Call(self.builtin("locals"), [], [])
  853. target_id = name.target.id # type: ignore[attr-defined]
  854. inlocs = ast.Compare(astStr(target_id), [ast.In()], [locs])
  855. dorepr = self.helper("_should_repr_global_name", name)
  856. test = ast.BoolOp(ast.Or(), [inlocs, dorepr])
  857. expr = ast.IfExp(test, self.display(name), astStr(target_id))
  858. return name, self.explanation_param(expr)
  859. def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]:
  860. # Display the repr of the name if it's a local variable or
  861. # _should_repr_global_name() thinks it's acceptable.
  862. locs = ast.Call(self.builtin("locals"), [], [])
  863. inlocs = ast.Compare(astStr(name.id), [ast.In()], [locs])
  864. dorepr = self.helper("_should_repr_global_name", name)
  865. test = ast.BoolOp(ast.Or(), [inlocs, dorepr])
  866. expr = ast.IfExp(test, self.display(name), astStr(name.id))
  867. return name, self.explanation_param(expr)
    def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]:
        """Rewrite an ``and``/``or`` expression, keeping short-circuit order.

        Each operand after the first is evaluated inside an ``if`` that only
        runs when the previous operand did not decide the result: truthy for
        ``and``, falsy for ``or``. Every evaluated operand is assigned to the
        same result temporary, which therefore holds the value of the last
        operand that ran.

        The explanation of each evaluated operand is appended, in the
        explanation statements, to a runtime list. That list is then rendered
        by the ``_format_boolop`` helper.

        Returns a load of the result temporary and a placeholder for the
        combined explanation.
        """
        res_var = self.variable()
        # Runtime list collecting one formatted explanation per evaluated operand.
        expl_list = self.assign(ast.List([], ast.Load()))
        app = ast.Attribute(expl_list, "append", ast.Load())
        is_or = int(isinstance(boolop.op, ast.Or))
        # Remember the outer statement lists: the loop below descends into
        # nested ``if`` bodies, and these are restored after the loop.
        body = save = self.statements
        fail_save = self.expl_stmts
        levels = len(boolop.values) - 1
        self.push_format_context()
        # Process each operand, short-circuiting if needed.
        for i, v in enumerate(boolop.values):
            if i:
                # Explain this operand only under the same condition that
                # guarded its evaluation.
                fail_inner: List[ast.stmt] = []
                # cond is set in a prior loop iteration below
                self.expl_stmts.append(ast.If(cond, fail_inner, []))  # noqa
                self.expl_stmts = fail_inner
                # Check if the left operand is a namedExpr and the value has already been visited
                if (
                    isinstance(v, ast.Compare)
                    and isinstance(v.left, namedExpr)
                    and v.left.target.id
                    in [
                        ast_expr.id
                        for ast_expr in boolop.values[:i]
                        if hasattr(ast_expr, "id")
                    ]
                ):
                    # The walrus node is recorded in variables_overwrite under
                    # its original target name. Its target is then redirected
                    # to a fresh temporary (same object, so the recorded node
                    # sees the rename), letting later visits in this scope
                    # substitute it for plain references to that name.
                    pytest_temp = self.variable()
                    self.variables_overwrite[self.scope][
                        v.left.target.id
                    ] = v.left  # type:ignore[assignment]
                    v.left.target.id = pytest_temp
            self.push_format_context()
            res, expl = self.visit(v)
            # Store this operand's value in the shared result temporary.
            body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
            expl_format = self.pop_format_context(astStr(expl))
            call = ast.Call(app, [expl_format], [])
            self.expl_stmts.append(ast.Expr(call))
            if i < levels:
                # Guard for the next operand: continue with ``and`` when this
                # value is truthy, with ``or`` when it is falsy.
                cond: ast.expr = res
                if is_or:
                    cond = ast.UnaryOp(ast.Not(), cond)
                inner: List[ast.stmt] = []
                self.statements.append(ast.If(cond, inner, []))
                self.statements = body = inner
        self.statements = save
        self.expl_stmts = fail_save
        expl_template = self.helper("_format_boolop", expl_list, astNum(is_or))
        expl = self.pop_format_context(expl_template)
        return ast.Name(res_var, ast.Load()), self.explanation_param(expl)
  918. def visit_UnaryOp(self, unary: ast.UnaryOp) -> Tuple[ast.Name, str]:
  919. pattern = UNARY_MAP[unary.op.__class__]
  920. operand_res, operand_expl = self.visit(unary.operand)
  921. res = self.assign(ast.UnaryOp(unary.op, operand_res))
  922. return res, pattern % (operand_expl,)
  923. def visit_BinOp(self, binop: ast.BinOp) -> Tuple[ast.Name, str]:
  924. symbol = BINOP_MAP[binop.op.__class__]
  925. left_expr, left_expl = self.visit(binop.left)
  926. right_expr, right_expl = self.visit(binop.right)
  927. explanation = f"({left_expl} {symbol} {right_expl})"
  928. res = self.assign(ast.BinOp(left_expr, binop.op, right_expr))
  929. return res, explanation
  930. def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]:
  931. new_func, func_expl = self.visit(call.func)
  932. arg_expls = []
  933. new_args = []
  934. new_kwargs = []
  935. for arg in call.args:
  936. if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite.get(
  937. self.scope, {}
  938. ):
  939. arg = self.variables_overwrite[self.scope][
  940. arg.id
  941. ] # type:ignore[assignment]
  942. res, expl = self.visit(arg)
  943. arg_expls.append(expl)
  944. new_args.append(res)
  945. for keyword in call.keywords:
  946. if isinstance(
  947. keyword.value, ast.Name
  948. ) and keyword.value.id in self.variables_overwrite.get(self.scope, {}):
  949. keyword.value = self.variables_overwrite[self.scope][
  950. keyword.value.id
  951. ] # type:ignore[assignment]
  952. res, expl = self.visit(keyword.value)
  953. new_kwargs.append(ast.keyword(keyword.arg, res))
  954. if keyword.arg:
  955. arg_expls.append(keyword.arg + "=" + expl)
  956. else: # **args have `arg` keywords with an .arg of None
  957. arg_expls.append("**" + expl)
  958. expl = "{}({})".format(func_expl, ", ".join(arg_expls))
  959. new_call = ast.Call(new_func, new_args, new_kwargs)
  960. res = self.assign(new_call)
  961. res_expl = self.explanation_param(self.display(res))
  962. outer_expl = f"{res_expl}\n{{{res_expl} = {expl}\n}}"
  963. return res, outer_expl
  964. def visit_Starred(self, starred: ast.Starred) -> Tuple[ast.Starred, str]:
  965. # A Starred node can appear in a function call.
  966. res, expl = self.visit(starred.value)
  967. new_starred = ast.Starred(res, starred.ctx)
  968. return new_starred, "*" + expl
  969. def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]:
  970. if not isinstance(attr.ctx, ast.Load):
  971. return self.generic_visit(attr)
  972. value, value_expl = self.visit(attr.value)
  973. res = self.assign(ast.Attribute(value, attr.attr, ast.Load()))
  974. res_expl = self.explanation_param(self.display(res))
  975. pat = "%s\n{%s = %s.%s\n}"
  976. expl = pat % (res_expl, res_expl, value_expl, attr.attr)
  977. return res, expl
    def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
        """Rewrite a (possibly chained) comparison such as ``a == b`` or ``a < b < c``.

        Each adjacent pair ``left op right`` is assigned to its own
        temporary. The explanation is produced at runtime by the
        ``_call_reprcompare`` helper. That helper receives four tuples: the
        operator symbols, the per-pair result temporaries, the per-pair
        textual explanations, and the operand values.

        Returns the result expression, which is the single pair's temporary,
        or an ``and`` of all pair temporaries for a chain. Also returns a
        placeholder for the explanation.
        """
        self.push_format_context()
        # We first check if we have overwritten a variable in the previous assert
        if isinstance(
            comp.left, ast.Name
        ) and comp.left.id in self.variables_overwrite.get(self.scope, {}):
            # Substitute the node recorded for this name in the current scope.
            comp.left = self.variables_overwrite[self.scope][
                comp.left.id
            ]  # type:ignore[assignment]
        if isinstance(comp.left, namedExpr):
            # Record the walrus under its target name for later substitution.
            self.variables_overwrite[self.scope][
                comp.left.target.id
            ] = comp.left  # type:ignore[assignment]
        left_res, left_expl = self.visit(comp.left)
        if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
            # Parenthesize nested comparisons / boolean ops in the explanation.
            left_expl = f"({left_expl})"
        # One result temporary per operator in the chain.
        res_variables = [self.variable() for i in range(len(comp.ops))]
        load_names = [ast.Name(v, ast.Load()) for v in res_variables]
        store_names = [ast.Name(v, ast.Store()) for v in res_variables]
        it = zip(range(len(comp.ops)), comp.ops, comp.comparators)
        expls = []
        syms = []
        results = [left_res]
        for i, op, next_operand in it:
            if (
                isinstance(next_operand, namedExpr)
                and isinstance(left_res, ast.Name)
                and next_operand.target.id == left_res.id
            ):
                # The right operand is a walrus rebinding the left operand's
                # name: point its target at a fresh temporary and record the
                # node under the original name for later substitution.
                next_operand.target.id = self.variable()
                self.variables_overwrite[self.scope][
                    left_res.id
                ] = next_operand  # type:ignore[assignment]
            next_res, next_expl = self.visit(next_operand)
            if isinstance(next_operand, (ast.Compare, ast.BoolOp)):
                next_expl = f"({next_expl})"
            results.append(next_res)
            sym = BINOP_MAP[op.__class__]
            syms.append(astStr(sym))
            expl = f"{left_expl} {sym} {next_expl}"
            expls.append(astStr(expl))
            # NOTE(review): each pair is assigned unconditionally, so unlike a
            # native chained comparison, later operands/pairs are evaluated
            # even when an earlier pair is false.
            res_expr = ast.Compare(left_res, [op], [next_res])
            self.statements.append(ast.Assign([store_names[i]], res_expr))
            # The right operand becomes the left side of the next pair.
            left_res, left_expl = next_res, next_expl
        # Use pytest.assertion.util._reprcompare if that's available.
        expl_call = self.helper(
            "_call_reprcompare",
            ast.Tuple(syms, ast.Load()),
            ast.Tuple(load_names, ast.Load()),
            ast.Tuple(expls, ast.Load()),
            ast.Tuple(results, ast.Load()),
        )
        if len(comp.ops) > 1:
            res: ast.expr = ast.BoolOp(ast.And(), load_names)
        else:
            res = load_names[0]
        return res, self.explanation_param(self.pop_format_context(expl_call))
  1035. def try_makedirs(cache_dir: Path) -> bool:
  1036. """Attempt to create the given directory and sub-directories exist.
  1037. Returns True if successful or if it already exists.
  1038. """
  1039. try:
  1040. os.makedirs(cache_dir, exist_ok=True)
  1041. except (FileNotFoundError, NotADirectoryError, FileExistsError):
  1042. # One of the path components was not a directory:
  1043. # - we're in a zip file
  1044. # - it is a file
  1045. return False
  1046. except PermissionError:
  1047. return False
  1048. except OSError as e:
  1049. # as of now, EROFS doesn't have an equivalent OSError-subclass
  1050. if e.errno == errno.EROFS:
  1051. return False
  1052. raise
  1053. return True
  1054. def get_cache_dir(file_path: Path) -> Path:
  1055. """Return the cache directory to write .pyc files for the given .py file path."""
  1056. if sys.version_info >= (3, 8) and sys.pycache_prefix:
  1057. # given:
  1058. # prefix = '/tmp/pycs'
  1059. # path = '/home/user/proj/test_app.py'
  1060. # we want:
  1061. # '/tmp/pycs/home/user/proj'
  1062. return Path(sys.pycache_prefix) / Path(*file_path.parts[1:-1])
  1063. else:
  1064. # classic pycache directory
  1065. return file_path.parent / "__pycache__"