tmpdir.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. """Support for providing temporary directories to test functions."""
  2. import dataclasses
  3. import os
  4. import re
  5. import tempfile
  6. from pathlib import Path
  7. from shutil import rmtree
  8. from typing import Any
  9. from typing import Dict
  10. from typing import Generator
  11. from typing import Optional
  12. from typing import TYPE_CHECKING
  13. from typing import Union
  14. from _pytest.nodes import Item
  15. from _pytest.reports import CollectReport
  16. from _pytest.stash import StashKey
  17. if TYPE_CHECKING:
  18. from typing_extensions import Literal
  19. RetentionType = Literal["all", "failed", "none"]
  20. from _pytest.config.argparsing import Parser
  21. from .pathlib import LOCK_TIMEOUT
  22. from .pathlib import make_numbered_dir
  23. from .pathlib import make_numbered_dir_with_cleanup
  24. from .pathlib import rm_rf
  25. from .pathlib import cleanup_dead_symlinks
  26. from _pytest.compat import final, get_user_id
  27. from _pytest.config import Config
  28. from _pytest.config import ExitCode
  29. from _pytest.config import hookimpl
  30. from _pytest.deprecated import check_ispytest
  31. from _pytest.fixtures import fixture
  32. from _pytest.fixtures import FixtureRequest
  33. from _pytest.monkeypatch import MonkeyPatch
  34. tmppath_result_key = StashKey[Dict[str, bool]]()
  35. @final
  36. @dataclasses.dataclass
  37. class TempPathFactory:
  38. """Factory for temporary directories under the common base temp directory.
  39. The base directory can be configured using the ``--basetemp`` option.
  40. """
  41. _given_basetemp: Optional[Path]
  42. # pluggy TagTracerSub, not currently exposed, so Any.
  43. _trace: Any
  44. _basetemp: Optional[Path]
  45. _retention_count: int
  46. _retention_policy: "RetentionType"
  47. def __init__(
  48. self,
  49. given_basetemp: Optional[Path],
  50. retention_count: int,
  51. retention_policy: "RetentionType",
  52. trace,
  53. basetemp: Optional[Path] = None,
  54. *,
  55. _ispytest: bool = False,
  56. ) -> None:
  57. check_ispytest(_ispytest)
  58. if given_basetemp is None:
  59. self._given_basetemp = None
  60. else:
  61. # Use os.path.abspath() to get absolute path instead of resolve() as it
  62. # does not work the same in all platforms (see #4427).
  63. # Path.absolute() exists, but it is not public (see https://bugs.python.org/issue25012).
  64. self._given_basetemp = Path(os.path.abspath(str(given_basetemp)))
  65. self._trace = trace
  66. self._retention_count = retention_count
  67. self._retention_policy = retention_policy
  68. self._basetemp = basetemp
  69. @classmethod
  70. def from_config(
  71. cls,
  72. config: Config,
  73. *,
  74. _ispytest: bool = False,
  75. ) -> "TempPathFactory":
  76. """Create a factory according to pytest configuration.
  77. :meta private:
  78. """
  79. check_ispytest(_ispytest)
  80. count = int(config.getini("tmp_path_retention_count"))
  81. if count < 0:
  82. raise ValueError(
  83. f"tmp_path_retention_count must be >= 0. Current input: {count}."
  84. )
  85. policy = config.getini("tmp_path_retention_policy")
  86. if policy not in ("all", "failed", "none"):
  87. raise ValueError(
  88. f"tmp_path_retention_policy must be either all, failed, none. Current input: {policy}."
  89. )
  90. return cls(
  91. given_basetemp=config.option.basetemp,
  92. trace=config.trace.get("tmpdir"),
  93. retention_count=count,
  94. retention_policy=policy,
  95. _ispytest=True,
  96. )
  97. def _ensure_relative_to_basetemp(self, basename: str) -> str:
  98. basename = os.path.normpath(basename)
  99. if (self.getbasetemp() / basename).resolve().parent != self.getbasetemp():
  100. raise ValueError(f"{basename} is not a normalized and relative path")
  101. return basename
  102. def mktemp(self, basename: str, numbered: bool = True) -> Path:
  103. """Create a new temporary directory managed by the factory.
  104. :param basename:
  105. Directory base name, must be a relative path.
  106. :param numbered:
  107. If ``True``, ensure the directory is unique by adding a numbered
  108. suffix greater than any existing one: ``basename="foo-"`` and ``numbered=True``
  109. means that this function will create directories named ``"foo-0"``,
  110. ``"foo-1"``, ``"foo-2"`` and so on.
  111. :returns:
  112. The path to the new directory.
  113. """
  114. basename = self._ensure_relative_to_basetemp(basename)
  115. if not numbered:
  116. p = self.getbasetemp().joinpath(basename)
  117. p.mkdir(mode=0o700)
  118. else:
  119. p = make_numbered_dir(root=self.getbasetemp(), prefix=basename, mode=0o700)
  120. self._trace("mktemp", p)
  121. return p
  122. def getbasetemp(self) -> Path:
  123. """Return the base temporary directory, creating it if needed.
  124. :returns:
  125. The base temporary directory.
  126. """
  127. if self._basetemp is not None:
  128. return self._basetemp
  129. if self._given_basetemp is not None:
  130. basetemp = self._given_basetemp
  131. if basetemp.exists():
  132. rm_rf(basetemp)
  133. basetemp.mkdir(mode=0o700)
  134. basetemp = basetemp.resolve()
  135. else:
  136. from_env = os.environ.get("PYTEST_DEBUG_TEMPROOT")
  137. temproot = Path(from_env or tempfile.gettempdir()).resolve()
  138. user = get_user() or "unknown"
  139. # use a sub-directory in the temproot to speed-up
  140. # make_numbered_dir() call
  141. rootdir = temproot.joinpath(f"pytest-of-{user}")
  142. try:
  143. rootdir.mkdir(mode=0o700, exist_ok=True)
  144. except OSError:
  145. # getuser() likely returned illegal characters for the platform, use unknown back off mechanism
  146. rootdir = temproot.joinpath("pytest-of-unknown")
  147. rootdir.mkdir(mode=0o700, exist_ok=True)
  148. # Because we use exist_ok=True with a predictable name, make sure
  149. # we are the owners, to prevent any funny business (on unix, where
  150. # temproot is usually shared).
  151. # Also, to keep things private, fixup any world-readable temp
  152. # rootdir's permissions. Historically 0o755 was used, so we can't
  153. # just error out on this, at least for a while.
  154. uid = get_user_id()
  155. if uid is not None:
  156. rootdir_stat = rootdir.stat()
  157. if rootdir_stat.st_uid != uid:
  158. raise OSError(
  159. f"The temporary directory {rootdir} is not owned by the current user. "
  160. "Fix this and try again."
  161. )
  162. if (rootdir_stat.st_mode & 0o077) != 0:
  163. os.chmod(rootdir, rootdir_stat.st_mode & ~0o077)
  164. keep = self._retention_count
  165. if self._retention_policy == "none":
  166. keep = 0
  167. basetemp = make_numbered_dir_with_cleanup(
  168. prefix="pytest-",
  169. root=rootdir,
  170. keep=keep,
  171. lock_timeout=LOCK_TIMEOUT,
  172. mode=0o700,
  173. )
  174. assert basetemp is not None, basetemp
  175. self._basetemp = basetemp
  176. self._trace("new basetemp", basetemp)
  177. return basetemp
  178. def get_user() -> Optional[str]:
  179. """Return the current user name, or None if getuser() does not work
  180. in the current environment (see #1010)."""
  181. try:
  182. # In some exotic environments, getpass may not be importable.
  183. import getpass
  184. return getpass.getuser()
  185. except (ImportError, KeyError):
  186. return None
  187. def pytest_configure(config: Config) -> None:
  188. """Create a TempPathFactory and attach it to the config object.
  189. This is to comply with existing plugins which expect the handler to be
  190. available at pytest_configure time, but ideally should be moved entirely
  191. to the tmp_path_factory session fixture.
  192. """
  193. mp = MonkeyPatch()
  194. config.add_cleanup(mp.undo)
  195. _tmp_path_factory = TempPathFactory.from_config(config, _ispytest=True)
  196. mp.setattr(config, "_tmp_path_factory", _tmp_path_factory, raising=False)
  197. def pytest_addoption(parser: Parser) -> None:
  198. parser.addini(
  199. "tmp_path_retention_count",
  200. help="How many sessions should we keep the `tmp_path` directories, according to `tmp_path_retention_policy`.",
  201. default=3,
  202. )
  203. parser.addini(
  204. "tmp_path_retention_policy",
  205. help="Controls which directories created by the `tmp_path` fixture are kept around, based on test outcome. "
  206. "(all/failed/none)",
  207. default="all",
  208. )
  209. @fixture(scope="session")
  210. def tmp_path_factory(request: FixtureRequest) -> TempPathFactory:
  211. """Return a :class:`pytest.TempPathFactory` instance for the test session."""
  212. # Set dynamically by pytest_configure() above.
  213. return request.config._tmp_path_factory # type: ignore
  214. def _mk_tmp(request: FixtureRequest, factory: TempPathFactory) -> Path:
  215. name = request.node.name
  216. name = re.sub(r"[\W]", "_", name)
  217. MAXVAL = 30
  218. name = name[:MAXVAL]
  219. return factory.mktemp(name, numbered=True)
  220. @fixture
  221. def tmp_path(
  222. request: FixtureRequest, tmp_path_factory: TempPathFactory
  223. ) -> Generator[Path, None, None]:
  224. """Return a temporary directory path object which is unique to each test
  225. function invocation, created as a sub directory of the base temporary
  226. directory.
  227. By default, a new base temporary directory is created each test session,
  228. and old bases are removed after 3 sessions, to aid in debugging.
  229. This behavior can be configured with :confval:`tmp_path_retention_count` and
  230. :confval:`tmp_path_retention_policy`.
  231. If ``--basetemp`` is used then it is cleared each session. See :ref:`base
  232. temporary directory`.
  233. The returned object is a :class:`pathlib.Path` object.
  234. """
  235. path = _mk_tmp(request, tmp_path_factory)
  236. yield path
  237. # Remove the tmpdir if the policy is "failed" and the test passed.
  238. tmp_path_factory: TempPathFactory = request.session.config._tmp_path_factory # type: ignore
  239. policy = tmp_path_factory._retention_policy
  240. result_dict = request.node.stash[tmppath_result_key]
  241. if policy == "failed" and result_dict.get("call", True):
  242. # We do a "best effort" to remove files, but it might not be possible due to some leaked resource,
  243. # permissions, etc, in which case we ignore it.
  244. rmtree(path, ignore_errors=True)
  245. del request.node.stash[tmppath_result_key]
  246. def pytest_sessionfinish(session, exitstatus: Union[int, ExitCode]):
  247. """After each session, remove base directory if all the tests passed,
  248. the policy is "failed", and the basetemp is not specified by a user.
  249. """
  250. tmp_path_factory: TempPathFactory = session.config._tmp_path_factory
  251. basetemp = tmp_path_factory._basetemp
  252. if basetemp is None:
  253. return
  254. policy = tmp_path_factory._retention_policy
  255. if (
  256. exitstatus == 0
  257. and policy == "failed"
  258. and tmp_path_factory._given_basetemp is None
  259. ):
  260. if basetemp.is_dir():
  261. # We do a "best effort" to remove files, but it might not be possible due to some leaked resource,
  262. # permissions, etc, in which case we ignore it.
  263. rmtree(basetemp, ignore_errors=True)
  264. # Remove dead symlinks.
  265. if basetemp.is_dir():
  266. cleanup_dead_symlinks(basetemp)
  267. @hookimpl(tryfirst=True, hookwrapper=True)
  268. def pytest_runtest_makereport(item: Item, call):
  269. outcome = yield
  270. result: CollectReport = outcome.get_result()
  271. empty: Dict[str, bool] = {}
  272. item.stash.setdefault(tmppath_result_key, empty)[result.when] = result.passed