disttrial.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512
  1. # -*- test-case-name: twisted.trial._dist.test.test_disttrial -*-
  2. # Copyright (c) Twisted Matrix Laboratories.
  3. # See LICENSE for details.
  4. """
  5. This module contains the trial distributed runner, the management class
  6. responsible for coordinating all of trial's behavior at the highest level.
  7. @since: 12.3
  8. """
  9. import os
  10. import sys
  11. from functools import partial
  12. from os.path import isabs
  13. from typing import (
  14. Any,
  15. Awaitable,
  16. Callable,
  17. Iterable,
  18. List,
  19. Optional,
  20. Sequence,
  21. TextIO,
  22. Union,
  23. cast,
  24. )
  25. from unittest import TestCase, TestSuite
  26. from attrs import define, field, frozen
  27. from attrs.converters import default_if_none
  28. from twisted.internet.defer import Deferred, DeferredList, gatherResults
  29. from twisted.internet.interfaces import IReactorCore, IReactorProcess
  30. from twisted.logger import Logger
  31. from twisted.python.failure import Failure
  32. from twisted.python.filepath import FilePath
  33. from twisted.python.lockfile import FilesystemLock
  34. from twisted.python.modules import theSystemPath
  35. from .._asyncrunner import _iterateTests
  36. from ..itrial import IReporter, ITestCase
  37. from ..reporter import UncleanWarningsReporterWrapper
  38. from ..runner import TestHolder
  39. from ..util import _unusedTestDirectory, openTestLog
  40. from . import _WORKER_AMP_STDIN, _WORKER_AMP_STDOUT
  41. from .distreporter import DistReporter
  42. from .functional import countingCalls, discardResult, iterateWhile, takeWhile
  43. from .worker import LocalWorker, LocalWorkerAMP, WorkerAction
  44. class IDistTrialReactor(IReactorCore, IReactorProcess):
  45. """
  46. The reactor interfaces required by disttrial.
  47. """
  48. def _defaultReactor() -> IDistTrialReactor:
  49. """
  50. Get the default reactor, ensuring it is suitable for use with disttrial.
  51. """
  52. import twisted.internet.reactor as defaultReactor
  53. if all(
  54. [
  55. IReactorCore.providedBy(defaultReactor),
  56. IReactorProcess.providedBy(defaultReactor),
  57. ]
  58. ):
  59. # If it provides each of the interfaces then it provides the
  60. # intersection interface. cast it to make it easier to talk about
  61. # later on.
  62. return cast(IDistTrialReactor, defaultReactor)
  63. raise TypeError("Reactor does not provide the right interfaces")
  64. @frozen
  65. class WorkerPoolConfig:
  66. """
  67. Configuration parameters for a pool of test-running workers.
  68. @ivar numWorkers: The number of workers in the pool.
  69. @ivar workingDirectory: A directory in which working directories for each
  70. of the workers will be created.
  71. @ivar workerArguments: Extra arguments to pass the worker process in its
  72. argv.
  73. @ivar logFile: The basename of the overall test log file.
  74. """
  75. numWorkers: int
  76. workingDirectory: FilePath[Any]
  77. workerArguments: Sequence[str]
  78. logFile: str
  79. @define
  80. class StartedWorkerPool:
  81. """
  82. A pool of workers which have already been started.
  83. @ivar workingDirectory: A directory holding the working directories for
  84. each of the workers.
  85. @ivar testDirLock: An object representing the cooperative lock this pool
  86. holds on its working directory.
  87. @ivar testLog: The open overall test log file.
  88. @ivar workers: Objects corresponding to the worker child processes and
  89. adapting between process-related interfaces and C{IProtocol}.
  90. @ivar ampWorkers: AMP protocol instances corresponding to the worker child
  91. processes.
  92. """
  93. workingDirectory: FilePath[Any]
  94. testDirLock: FilesystemLock
  95. testLog: TextIO
  96. workers: List[LocalWorker]
  97. ampWorkers: List[LocalWorkerAMP]
  98. _logger = Logger()
  99. async def run(self, workerAction: WorkerAction[Any]) -> None:
  100. """
  101. Run an action on all of the workers in the pool.
  102. """
  103. await gatherResults(
  104. discardResult(workerAction(worker)) for worker in self.ampWorkers
  105. )
  106. return None
  107. async def join(self) -> None:
  108. """
  109. Shut down all of the workers in the pool.
  110. The pool is unusable after this method is called.
  111. """
  112. results = await DeferredList(
  113. [Deferred.fromCoroutine(worker.exit()) for worker in self.workers],
  114. consumeErrors=True,
  115. )
  116. for n, (succeeded, failure) in enumerate(results):
  117. if not succeeded:
  118. self._logger.failure(f"joining disttrial worker #{n} failed", failure)
  119. del self.workers[:]
  120. del self.ampWorkers[:]
  121. self.testLog.close()
  122. self.testDirLock.unlock()
  123. @frozen
  124. class WorkerPool:
  125. """
  126. Manage a fixed-size collection of child processes which can run tests.
  127. @ivar _config: Configuration for the precise way in which the pool is run.
  128. """
  129. _config: WorkerPoolConfig
  130. def _createLocalWorkers(
  131. self,
  132. protocols: Iterable[LocalWorkerAMP],
  133. workingDirectory: FilePath[Any],
  134. logFile: TextIO,
  135. ) -> List[LocalWorker]:
  136. """
  137. Create local worker protocol instances and return them.
  138. @param protocols: The process/protocol adapters to use for the created
  139. workers.
  140. @param workingDirectory: The base path in which we should run the
  141. workers.
  142. @param logFile: The test log, for workers to write to.
  143. @return: A list of C{quantity} C{LocalWorker} instances.
  144. """
  145. return [
  146. LocalWorker(protocol, workingDirectory.child(str(x)), logFile)
  147. for x, protocol in enumerate(protocols)
  148. ]
  149. def _launchWorkerProcesses(self, spawner, protocols, arguments):
  150. """
  151. Spawn processes from a list of process protocols.
  152. @param spawner: A C{IReactorProcess.spawnProcess} implementation.
  153. @param protocols: An iterable of C{ProcessProtocol} instances.
  154. @param arguments: Extra arguments passed to the processes.
  155. """
  156. workertrialPath = theSystemPath["twisted.trial._dist.workertrial"].filePath.path
  157. childFDs = {
  158. 0: "w",
  159. 1: "r",
  160. 2: "r",
  161. _WORKER_AMP_STDIN: "w",
  162. _WORKER_AMP_STDOUT: "r",
  163. }
  164. environ = os.environ.copy()
  165. # Add an environment variable containing the raw sys.path, to be used
  166. # by subprocesses to try to make it identical to the parent's.
  167. environ["PYTHONPATH"] = os.pathsep.join(sys.path)
  168. for worker in protocols:
  169. args = [sys.executable, workertrialPath]
  170. args.extend(arguments)
  171. spawner(worker, sys.executable, args=args, childFDs=childFDs, env=environ)
  172. async def start(self, reactor: IReactorProcess) -> StartedWorkerPool:
  173. """
  174. Launch all of the workers for this pool.
  175. @return: A started pool object that can run jobs using the workers.
  176. """
  177. testDir, testDirLock = _unusedTestDirectory(
  178. self._config.workingDirectory,
  179. )
  180. if isabs(self._config.logFile):
  181. # Open a log file wherever the user asked.
  182. testLogPath = FilePath(self._config.logFile)
  183. else:
  184. # Open a log file in the chosen working directory (not necessarily
  185. # the same as our configured working directory, if that path was
  186. # in use).
  187. testLogPath = testDir.preauthChild(self._config.logFile)
  188. testLog = openTestLog(testLogPath)
  189. ampWorkers = [LocalWorkerAMP() for x in range(self._config.numWorkers)]
  190. workers = self._createLocalWorkers(
  191. ampWorkers,
  192. testDir,
  193. testLog,
  194. )
  195. self._launchWorkerProcesses(
  196. reactor.spawnProcess,
  197. workers,
  198. self._config.workerArguments,
  199. )
  200. return StartedWorkerPool(
  201. testDir,
  202. testDirLock,
  203. testLog,
  204. workers,
  205. ampWorkers,
  206. )
  207. def shouldContinue(untilFailure: bool, result: IReporter) -> bool:
  208. """
  209. Determine whether the test suite should be iterated again.
  210. @param untilFailure: C{True} if the suite is supposed to run until
  211. failure.
  212. @param result: The test result of the test suite iteration which just
  213. completed.
  214. """
  215. return untilFailure and result.wasSuccessful()
  216. async def runTests(
  217. pool: StartedWorkerPool,
  218. testCases: Iterable[ITestCase],
  219. result: DistReporter,
  220. driveWorker: Callable[
  221. [DistReporter, Sequence[ITestCase], LocalWorkerAMP], Awaitable[None]
  222. ],
  223. ) -> None:
  224. try:
  225. # Run the tests using the worker pool.
  226. await pool.run(partial(driveWorker, result, testCases))
  227. except Exception:
  228. # Exceptions from test code are handled somewhere else. An
  229. # exception here is a bug in the runner itself. The only
  230. # convenient place to put it is in the result, though.
  231. result.original.addError(TestHolder("<runTests>"), Failure())
  232. @define
  233. class DistTrialRunner:
  234. """
  235. A specialized runner for distributed trial. The runner launches a number of
  236. local worker processes which will run tests.
  237. @ivar _maxWorkers: the number of workers to be spawned.
  238. @ivar _exitFirst: ``True`` to stop the run as soon as a test case fails.
  239. ``False`` to run through the whole suite and report all of the results
  240. at the end.
  241. @ivar stream: stream which the reporter will use.
  242. @ivar _reporterFactory: the reporter class to be used.
  243. """
  244. _distReporterFactory = DistReporter
  245. _logger = Logger()
  246. # accepts a `realtime` keyword argument which we can't annotate, so punt
  247. # on the argument annotation
  248. _reporterFactory: Callable[..., IReporter]
  249. _maxWorkers: int
  250. _workerArguments: List[str]
  251. _exitFirst: bool = False
  252. _reactor: IDistTrialReactor = field(
  253. # mypy doesn't understand the converter
  254. default=None,
  255. converter=default_if_none(factory=_defaultReactor), # type: ignore [misc]
  256. )
  257. # mypy doesn't understand the converter
  258. stream: TextIO = field(default=None, converter=default_if_none(sys.stdout)) # type: ignore [misc]
  259. _tracebackFormat: str = "default"
  260. _realTimeErrors: bool = False
  261. _uncleanWarnings: bool = False
  262. _logfile: str = "test.log"
  263. _workingDirectory: str = "_trial_temp"
  264. _workerPoolFactory: Callable[[WorkerPoolConfig], WorkerPool] = WorkerPool
  265. def _makeResult(self) -> DistReporter:
  266. """
  267. Make reporter factory, and wrap it with a L{DistReporter}.
  268. """
  269. reporter = self._reporterFactory(
  270. self.stream, self._tracebackFormat, realtime=self._realTimeErrors
  271. )
  272. if self._uncleanWarnings:
  273. reporter = UncleanWarningsReporterWrapper(reporter)
  274. return self._distReporterFactory(reporter)
  275. def writeResults(self, result):
  276. """
  277. Write test run final outcome to result.
  278. @param result: A C{TestResult} which will print errors and the summary.
  279. """
  280. result.done()
  281. async def _driveWorker(
  282. self,
  283. result: DistReporter,
  284. testCases: Sequence[ITestCase],
  285. worker: LocalWorkerAMP,
  286. ) -> None:
  287. """
  288. Drive a L{LocalWorkerAMP} instance, iterating the tests and calling
  289. C{run} for every one of them.
  290. @param worker: The L{LocalWorkerAMP} to drive.
  291. @param result: The global L{DistReporter} instance.
  292. @param testCases: The global list of tests to iterate.
  293. @return: A coroutine that completes after all of the tests have
  294. completed.
  295. """
  296. async def task(case):
  297. try:
  298. await worker.run(case, result)
  299. except Exception:
  300. result.original.addError(case, Failure())
  301. for case in testCases:
  302. await task(case)
  303. async def runAsync(
  304. self,
  305. suite: Union[TestCase, TestSuite],
  306. untilFailure: bool = False,
  307. ) -> DistReporter:
  308. """
  309. Spawn local worker processes and load tests. After that, run them.
  310. @param suite: A test or suite to be run.
  311. @param untilFailure: If C{True}, continue to run the tests until they
  312. fail.
  313. @return: A coroutine that completes with the test result.
  314. """
  315. # Realize a concrete set of tests to run.
  316. testCases = list(_iterateTests(suite))
  317. # Create a worker pool to use to execute them.
  318. poolStarter = self._workerPoolFactory(
  319. WorkerPoolConfig(
  320. # Don't make it larger than is useful or allowed.
  321. min(len(testCases), self._maxWorkers),
  322. FilePath(self._workingDirectory),
  323. self._workerArguments,
  324. self._logfile,
  325. ),
  326. )
  327. # Announce that we're beginning. countTestCases result is preferred
  328. # (over len(testCases)) because testCases may contain synthetic cases
  329. # for error reporting purposes.
  330. self.stream.write(f"Running {suite.countTestCases()} tests.\n")
  331. # Start the worker pool.
  332. startedPool = await poolStarter.start(self._reactor)
  333. # The condition that will determine whether the test run repeats.
  334. condition = partial(shouldContinue, untilFailure)
  335. # A function that will run the whole suite once.
  336. @countingCalls
  337. async def runAndReport(n: int) -> DistReporter:
  338. if untilFailure:
  339. # If and only if we're running the suite more than once,
  340. # provide a report about which run this is.
  341. self.stream.write(f"Test Pass {n + 1}\n")
  342. result = self._makeResult()
  343. if self._exitFirst:
  344. # Keep giving out tests as long as the result object has only
  345. # seen success.
  346. casesCondition = lambda _: result.original.wasSuccessful()
  347. else:
  348. casesCondition = lambda _: True
  349. await runTests(
  350. startedPool,
  351. takeWhile(casesCondition, testCases),
  352. result,
  353. self._driveWorker,
  354. )
  355. self.writeResults(result)
  356. return result
  357. try:
  358. # Start submitting tests to workers in the pool. Perhaps repeat
  359. # the whole test suite more than once, if appropriate for our
  360. # configuration.
  361. return await iterateWhile(condition, runAndReport)
  362. finally:
  363. # Shut down the worker pool.
  364. await startedPool.join()
  365. def _run(self, test: Union[TestCase, TestSuite], untilFailure: bool) -> IReporter:
  366. result: Union[Failure, DistReporter, None] = None
  367. reactorStopping: bool = False
  368. testsInProgress: Deferred[object]
  369. def capture(r: Union[Failure, DistReporter]) -> None:
  370. nonlocal result
  371. result = r
  372. def maybeStopTests() -> Optional[Deferred[object]]:
  373. nonlocal reactorStopping
  374. reactorStopping = True
  375. if result is None:
  376. testsInProgress.cancel()
  377. return testsInProgress
  378. return None
  379. def maybeStopReactor(result: object) -> object:
  380. if not reactorStopping:
  381. self._reactor.stop()
  382. return result
  383. self._reactor.addSystemEventTrigger("before", "shutdown", maybeStopTests)
  384. testsInProgress = (
  385. Deferred.fromCoroutine(self.runAsync(test, untilFailure))
  386. .addBoth(capture)
  387. .addBoth(maybeStopReactor)
  388. )
  389. self._reactor.run()
  390. if isinstance(result, Failure):
  391. result.raiseException()
  392. # mypy can't see that raiseException raises an exception so we can
  393. # only get here if result is not a Failure, so tell mypy result is
  394. # certainly a DistReporter at this point.
  395. assert isinstance(result, DistReporter), f"{result} is not DistReporter"
  396. # Unwrap the DistReporter to give the caller some regular IReporter
  397. # object. DistReporter isn't type annotated correctly so fix it here.
  398. return cast(IReporter, result.original)
  399. def run(self, test: Union[TestCase, TestSuite]) -> IReporter:
  400. """
  401. Run a reactor and a test suite.
  402. @param test: The test or suite to run.
  403. """
  404. return self._run(test, untilFailure=False)
  405. def runUntilFailure(self, test: Union[TestCase, TestSuite]) -> IReporter:
  406. """
  407. Run the tests with local worker processes until they fail.
  408. @param test: The test or suite to run.
  409. """
  410. return self._run(test, untilFailure=True)