testing.py 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072
  1. # -*- test-case-name: twisted.internet.test.test_testing -*-
  2. # Copyright (c) Twisted Matrix Laboratories.
  3. # See LICENSE for details.
  4. """
  5. Assorted functionality which is commonly useful when writing unit tests.
  6. """
  7. from __future__ import annotations
  8. from io import BytesIO
  9. from socket import AF_INET, AF_INET6
  10. from time import time
  11. from typing import (
  12. Any,
  13. Callable,
  14. Coroutine,
  15. Generator,
  16. Iterator,
  17. Sequence,
  18. TypeVar,
  19. Union,
  20. overload,
  21. )
  22. from zope.interface import implementedBy, implementer
  23. from zope.interface.verify import verifyClass
  24. from typing_extensions import ParamSpec, Self
  25. from twisted.internet import address, error, protocol, task
  26. from twisted.internet.abstract import _dataMustBeBytes, isIPv6Address
  27. from twisted.internet.address import IPv4Address, IPv6Address, UNIXAddress
  28. from twisted.internet.defer import Deferred, ensureDeferred, succeed
  29. from twisted.internet.error import UnsupportedAddressFamily
  30. from twisted.internet.interfaces import (
  31. IConnector,
  32. IConsumer,
  33. IListeningPort,
  34. IProtocol,
  35. IPushProducer,
  36. IReactorCore,
  37. IReactorFDSet,
  38. IReactorSocket,
  39. IReactorSSL,
  40. IReactorTCP,
  41. IReactorUNIX,
  42. ITransport,
  43. )
  44. from twisted.internet.task import Clock
  45. from twisted.logger import ILogObserver, LogEvent, LogPublisher
  46. from twisted.protocols import basic
  47. from twisted.python import failure
  48. from twisted.trial.unittest import TestCase
  49. __all__ = [
  50. "AccumulatingProtocol",
  51. "LineSendingProtocol",
  52. "FakeDatagramTransport",
  53. "StringTransport",
  54. "StringTransportWithDisconnection",
  55. "StringIOWithoutClosing",
  56. "_FakeConnector",
  57. "_FakePort",
  58. "MemoryReactor",
  59. "MemoryReactorClock",
  60. "RaisingMemoryReactor",
  61. "NonStreamingProducer",
  62. "waitUntilAllDisconnected",
  63. "EventLoggingObserver",
  64. ]
  65. _P = ParamSpec("_P")
  66. class AccumulatingProtocol(protocol.Protocol):
  67. """
  68. L{AccumulatingProtocol} is an L{IProtocol} implementation which collects
  69. the data delivered to it and can fire a Deferred when it is connected or
  70. disconnected.
  71. @ivar made: A flag indicating whether C{connectionMade} has been called.
  72. @ivar data: Bytes giving all the data passed to C{dataReceived}.
  73. @ivar closed: A flag indicated whether C{connectionLost} has been called.
  74. @ivar closedReason: The value of the I{reason} parameter passed to
  75. C{connectionLost}.
  76. @ivar closedDeferred: If set to a L{Deferred}, this will be fired when
  77. C{connectionLost} is called.
  78. """
  79. made = closed = 0
  80. closedReason = None
  81. closedDeferred = None
  82. data = b""
  83. factory = None
  84. def connectionMade(self):
  85. self.made = 1
  86. if self.factory is not None and self.factory.protocolConnectionMade is not None:
  87. d = self.factory.protocolConnectionMade
  88. self.factory.protocolConnectionMade = None
  89. d.callback(self)
  90. def dataReceived(self, data):
  91. self.data += data
  92. def connectionLost(self, reason):
  93. self.closed = 1
  94. self.closedReason = reason
  95. if self.closedDeferred is not None:
  96. d, self.closedDeferred = self.closedDeferred, None
  97. d.callback(None)
  98. class LineSendingProtocol(basic.LineReceiver):
  99. lostConn = False
  100. def __init__(self, lines, start=True):
  101. self.lines = lines[:]
  102. self.response = []
  103. self.start = start
  104. def connectionMade(self):
  105. if self.start:
  106. for line in self.lines:
  107. self.sendLine(line)
  108. def lineReceived(self, line):
  109. if not self.start:
  110. for line in self.lines:
  111. self.sendLine(line)
  112. self.lines = []
  113. self.response.append(line)
  114. def connectionLost(self, reason):
  115. self.lostConn = True
  116. class FakeDatagramTransport:
  117. noAddr = object()
  118. def __init__(self):
  119. self.written = []
  120. def write(self, packet, addr=noAddr):
  121. self.written.append((packet, addr))
  122. @implementer(ITransport, IConsumer, IPushProducer)
  123. class StringTransport:
  124. """
  125. A transport implementation which buffers data in memory and keeps track of
  126. its other state without providing any behavior.
  127. L{StringTransport} has a number of attributes which are not part of any of
  128. the interfaces it claims to implement. These attributes are provided for
  129. testing purposes. Implementation code should not use any of these
  130. attributes; they are not provided by other transports.
  131. @ivar disconnecting: A C{bool} which is C{False} until L{loseConnection} is
  132. called, then C{True}.
  133. @ivar disconnected: A C{bool} which is C{False} until L{abortConnection} is
  134. called, then C{True}.
  135. @ivar producer: If a producer is currently registered, C{producer} is a
  136. reference to it. Otherwise, L{None}.
  137. @ivar streaming: If a producer is currently registered, C{streaming} refers
  138. to the value of the second parameter passed to C{registerProducer}.
  139. @ivar hostAddr: L{None} or an object which will be returned as the host
  140. address of this transport. If L{None}, a nasty tuple will be returned
  141. instead.
  142. @ivar peerAddr: L{None} or an object which will be returned as the peer
  143. address of this transport. If L{None}, a nasty tuple will be returned
  144. instead.
  145. @ivar producerState: The state of this L{StringTransport} in its capacity
  146. as an L{IPushProducer}. One of C{'producing'}, C{'paused'}, or
  147. C{'stopped'}.
  148. @ivar io: A L{io.BytesIO} which holds the data which has been written to
  149. this transport since the last call to L{clear}. Use L{value} instead
  150. of accessing this directly.
  151. @ivar _lenient: By default L{StringTransport} enforces that
  152. L{resumeProducing} is not called after the connection is lost. This is
  153. to ensure that any code that does call L{resumeProducing} after the
  154. connection is lost is not blindly expecting L{resumeProducing} to have
  155. any impact.
  156. However, if your test case is calling L{resumeProducing} after
  157. connection close on purpose, and you know it won't block expecting
  158. further data to show up, this flag may safely be set to L{True}.
  159. Defaults to L{False}.
  160. @type lenient: L{bool}
  161. """
  162. disconnecting = False
  163. disconnected = False
  164. producer = None
  165. streaming = None
  166. hostAddr = None
  167. peerAddr = None
  168. producerState = "producing"
  169. def __init__(self, hostAddress=None, peerAddress=None, lenient=False):
  170. self.clear()
  171. if hostAddress is not None:
  172. self.hostAddr = hostAddress
  173. if peerAddress is not None:
  174. self.peerAddr = peerAddress
  175. self.connected = True
  176. self._lenient = lenient
  177. def clear(self):
  178. """
  179. Discard all data written to this transport so far.
  180. This is not a transport method. It is intended for tests. Do not use
  181. it in implementation code.
  182. """
  183. self.io = BytesIO()
  184. def value(self):
  185. """
  186. Retrieve all data which has been buffered by this transport.
  187. This is not a transport method. It is intended for tests. Do not use
  188. it in implementation code.
  189. @return: A C{bytes} giving all data written to this transport since the
  190. last call to L{clear}.
  191. @rtype: C{bytes}
  192. """
  193. return self.io.getvalue()
  194. # ITransport
  195. def write(self, data):
  196. _dataMustBeBytes(data)
  197. self.io.write(data)
  198. def writeSequence(self, data):
  199. self.io.write(b"".join(data))
  200. def loseConnection(self):
  201. """
  202. Close the connection. Does nothing besides toggle the C{disconnecting}
  203. instance variable to C{True}.
  204. """
  205. self.disconnecting = True
  206. def abortConnection(self):
  207. """
  208. Abort the connection. Same as C{loseConnection}, but also toggles the
  209. C{aborted} instance variable to C{True}.
  210. """
  211. self.disconnected = True
  212. self.loseConnection()
  213. def getPeer(self):
  214. if self.peerAddr is None:
  215. return address.IPv4Address("TCP", "192.168.1.1", 54321)
  216. return self.peerAddr
  217. def getHost(self):
  218. if self.hostAddr is None:
  219. return address.IPv4Address("TCP", "10.0.0.1", 12345)
  220. return self.hostAddr
  221. # IConsumer
  222. def registerProducer(self, producer, streaming):
  223. if self.producer is not None:
  224. raise RuntimeError("Cannot register two producers")
  225. self.producer = producer
  226. self.streaming = streaming
  227. def unregisterProducer(self):
  228. if self.producer is None:
  229. raise RuntimeError("Cannot unregister a producer unless one is registered")
  230. self.producer = None
  231. self.streaming = None
  232. # IPushProducer
  233. def _checkState(self):
  234. if self.disconnecting and not self._lenient:
  235. raise RuntimeError("Cannot resume producing after loseConnection")
  236. if self.producerState == "stopped":
  237. raise RuntimeError("Cannot resume a stopped producer")
  238. def pauseProducing(self):
  239. self._checkState()
  240. self.producerState = "paused"
  241. def stopProducing(self):
  242. self.producerState = "stopped"
  243. def resumeProducing(self):
  244. self._checkState()
  245. self.producerState = "producing"
  246. class StringTransportWithDisconnection(StringTransport):
  247. """
  248. A L{StringTransport} which on disconnection will trigger the connection
  249. lost on the attached protocol.
  250. """
  251. protocol: IProtocol
  252. def loseConnection(self):
  253. if self.connected:
  254. self.connected = False
  255. self.protocol.connectionLost(failure.Failure(error.ConnectionDone("Bye.")))
  256. class StringIOWithoutClosing(BytesIO):
  257. """
  258. A BytesIO that can't be closed.
  259. """
  260. def close(self):
  261. """
  262. Do nothing.
  263. """
  264. @implementer(IListeningPort)
  265. class _FakePort:
  266. """
  267. A fake L{IListeningPort} to be used in tests.
  268. @ivar _hostAddress: The L{IAddress} this L{IListeningPort} is pretending
  269. to be listening on.
  270. """
  271. def __init__(self, hostAddress):
  272. """
  273. @param hostAddress: An L{IAddress} this L{IListeningPort} should
  274. pretend to be listening on.
  275. """
  276. self._hostAddress = hostAddress
  277. def startListening(self):
  278. """
  279. Fake L{IListeningPort.startListening} that doesn't do anything.
  280. """
  281. def stopListening(self):
  282. """
  283. Fake L{IListeningPort.stopListening} that doesn't do anything.
  284. """
  285. def getHost(self):
  286. """
  287. Fake L{IListeningPort.getHost} that returns our L{IAddress}.
  288. """
  289. return self._hostAddress
  290. @implementer(IConnector)
  291. class _FakeConnector:
  292. """
  293. A fake L{IConnector} that allows us to inspect if it has been told to stop
  294. connecting.
  295. @ivar stoppedConnecting: has this connector's
  296. L{_FakeConnector.stopConnecting} method been invoked yet?
  297. @ivar _address: An L{IAddress} provider that represents our destination.
  298. """
  299. _disconnected = False
  300. stoppedConnecting = False
  301. def __init__(self, address):
  302. """
  303. @param address: An L{IAddress} provider that represents this
  304. connector's destination.
  305. """
  306. self._address = address
  307. def stopConnecting(self):
  308. """
  309. Implement L{IConnector.stopConnecting} and set
  310. L{_FakeConnector.stoppedConnecting} to C{True}
  311. """
  312. self.stoppedConnecting = True
  313. def disconnect(self):
  314. """
  315. Implement L{IConnector.disconnect} as a no-op.
  316. """
  317. self._disconnected = True
  318. def connect(self):
  319. """
  320. Implement L{IConnector.connect} as a no-op.
  321. """
  322. def getDestination(self):
  323. """
  324. Implement L{IConnector.getDestination} to return the C{address} passed
  325. to C{__init__}.
  326. """
  327. return self._address
  328. @implementer(
  329. IReactorCore, IReactorTCP, IReactorSSL, IReactorUNIX, IReactorSocket, IReactorFDSet
  330. )
  331. class MemoryReactor:
  332. """
  333. A fake reactor to be used in tests. This reactor doesn't actually do
  334. much that's useful yet. It accepts TCP connection setup attempts, but
  335. they will never succeed.
  336. @ivar hasInstalled: Keeps track of whether this reactor has been installed.
  337. @type hasInstalled: L{bool}
  338. @ivar running: Keeps track of whether this reactor is running.
  339. @type running: L{bool}
  340. @ivar hasStopped: Keeps track of whether this reactor has been stopped.
  341. @type hasStopped: L{bool}
  342. @ivar hasCrashed: Keeps track of whether this reactor has crashed.
  343. @type hasCrashed: L{bool}
  344. @ivar whenRunningHooks: Keeps track of hooks registered with
  345. C{callWhenRunning}.
  346. @type whenRunningHooks: L{list}
  347. @ivar triggers: Keeps track of hooks registered with
  348. C{addSystemEventTrigger}.
  349. @type triggers: L{dict}
  350. @ivar tcpClients: Keeps track of connection attempts (ie, calls to
  351. C{connectTCP}).
  352. @type tcpClients: L{list}
  353. @ivar tcpServers: Keeps track of server listen attempts (ie, calls to
  354. C{listenTCP}).
  355. @type tcpServers: L{list}
  356. @ivar sslClients: Keeps track of connection attempts (ie, calls to
  357. C{connectSSL}).
  358. @type sslClients: L{list}
  359. @ivar sslServers: Keeps track of server listen attempts (ie, calls to
  360. C{listenSSL}).
  361. @type sslServers: L{list}
  362. @ivar unixClients: Keeps track of connection attempts (ie, calls to
  363. C{connectUNIX}).
  364. @type unixClients: L{list}
  365. @ivar unixServers: Keeps track of server listen attempts (ie, calls to
  366. C{listenUNIX}).
  367. @type unixServers: L{list}
  368. @ivar adoptedPorts: Keeps track of server listen attempts (ie, calls to
  369. C{adoptStreamPort}).
  370. @ivar adoptedStreamConnections: Keeps track of stream-oriented
  371. connections added using C{adoptStreamConnection}.
  372. """
  373. def __init__(self):
  374. """
  375. Initialize the tracking lists.
  376. """
  377. self.hasInstalled = False
  378. self.running = False
  379. self.hasRun = True
  380. self.hasStopped = True
  381. self.hasCrashed = True
  382. self.whenRunningHooks = []
  383. self.triggers = {}
  384. self.tcpClients = []
  385. self.tcpServers = []
  386. self.sslClients = []
  387. self.sslServers = []
  388. self.unixClients = []
  389. self.unixServers = []
  390. self.adoptedPorts = []
  391. self.adoptedStreamConnections = []
  392. self.connectors = []
  393. self.readers = set()
  394. self.writers = set()
  395. def install(self):
  396. """
  397. Fake install callable to emulate reactor module installation.
  398. """
  399. self.hasInstalled = True
  400. def resolve(self, name, timeout=10):
  401. """
  402. Not implemented; raises L{NotImplementedError}.
  403. """
  404. raise NotImplementedError()
  405. def run(self):
  406. """
  407. Fake L{IReactorCore.run}.
  408. Sets C{self.running} to L{True}, runs all of the hooks passed to
  409. C{self.callWhenRunning}, then calls C{self.stop} to simulate a request
  410. to stop the reactor.
  411. Sets C{self.hasRun} to L{True}.
  412. """
  413. assert self.running is False
  414. self.running = True
  415. self.hasRun = True
  416. for f, args, kwargs in self.whenRunningHooks:
  417. f(*args, **kwargs)
  418. self.stop()
  419. # That we stopped means we can return, phew.
  420. def stop(self):
  421. """
  422. Fake L{IReactorCore.run}.
  423. Sets C{self.running} to L{False}.
  424. Sets C{self.hasStopped} to L{True}.
  425. """
  426. self.running = False
  427. self.hasStopped = True
  428. def crash(self):
  429. """
  430. Fake L{IReactorCore.crash}.
  431. Sets C{self.running} to L{None}, because that feels crashy.
  432. Sets C{self.hasCrashed} to L{True}.
  433. """
  434. self.running = None
  435. self.hasCrashed = True
  436. def iterate(self, delay=0):
  437. """
  438. Not implemented; raises L{NotImplementedError}.
  439. """
  440. raise NotImplementedError()
  441. def fireSystemEvent(self, eventType):
  442. """
  443. Not implemented; raises L{NotImplementedError}.
  444. """
  445. raise NotImplementedError()
  446. def addSystemEventTrigger(
  447. self,
  448. phase: str,
  449. eventType: str,
  450. callable: Callable[_P, object],
  451. *args: _P.args,
  452. **kw: _P.kwargs,
  453. ) -> None:
  454. """
  455. Fake L{IReactorCore.run}.
  456. Keep track of trigger by appending it to
  457. self.triggers[phase][eventType].
  458. """
  459. phaseTriggers = self.triggers.setdefault(phase, {})
  460. eventTypeTriggers = phaseTriggers.setdefault(eventType, [])
  461. eventTypeTriggers.append((callable, args, kw))
  462. def removeSystemEventTrigger(self, triggerID):
  463. """
  464. Not implemented; raises L{NotImplementedError}.
  465. """
  466. raise NotImplementedError()
  467. def callWhenRunning(
  468. self, callable: Callable[_P, object], *args: _P.args, **kw: _P.kwargs
  469. ) -> None:
  470. """
  471. Fake L{IReactorCore.callWhenRunning}.
  472. Keeps a list of invocations to make in C{self.whenRunningHooks}.
  473. """
  474. self.whenRunningHooks.append((callable, args, kw))
  475. def adoptStreamPort(self, fileno, addressFamily, factory):
  476. """
  477. Fake L{IReactorSocket.adoptStreamPort}, that logs the call and returns
  478. an L{IListeningPort}.
  479. """
  480. if addressFamily == AF_INET:
  481. addr = IPv4Address("TCP", "0.0.0.0", 1234)
  482. elif addressFamily == AF_INET6:
  483. addr = IPv6Address("TCP", "::", 1234)
  484. else:
  485. raise UnsupportedAddressFamily()
  486. self.adoptedPorts.append((fileno, addressFamily, factory))
  487. return _FakePort(addr)
  488. def adoptStreamConnection(self, fileDescriptor, addressFamily, factory):
  489. """
  490. Record the given stream connection in C{adoptedStreamConnections}.
  491. @see:
  492. L{twisted.internet.interfaces.IReactorSocket.adoptStreamConnection}
  493. """
  494. self.adoptedStreamConnections.append((fileDescriptor, addressFamily, factory))
  495. def adoptDatagramPort(self, fileno, addressFamily, protocol, maxPacketSize=8192):
  496. """
  497. Fake L{IReactorSocket.adoptDatagramPort}, that logs the call and
  498. returns a fake L{IListeningPort}.
  499. @see: L{twisted.internet.interfaces.IReactorSocket.adoptDatagramPort}
  500. """
  501. if addressFamily == AF_INET:
  502. addr = IPv4Address("UDP", "0.0.0.0", 1234)
  503. elif addressFamily == AF_INET6:
  504. addr = IPv6Address("UDP", "::", 1234)
  505. else:
  506. raise UnsupportedAddressFamily()
  507. self.adoptedPorts.append((fileno, addressFamily, protocol, maxPacketSize))
  508. return _FakePort(addr)
  509. def listenTCP(self, port, factory, backlog=50, interface=""):
  510. """
  511. Fake L{IReactorTCP.listenTCP}, that logs the call and
  512. returns an L{IListeningPort}.
  513. """
  514. self.tcpServers.append((port, factory, backlog, interface))
  515. if isIPv6Address(interface):
  516. address = IPv6Address("TCP", interface, port)
  517. else:
  518. address = IPv4Address("TCP", "0.0.0.0", port)
  519. return _FakePort(address)
  520. def connectTCP(self, host, port, factory, timeout=30, bindAddress=None):
  521. """
  522. Fake L{IReactorTCP.connectTCP}, that logs the call and
  523. returns an L{IConnector}.
  524. """
  525. self.tcpClients.append((host, port, factory, timeout, bindAddress))
  526. if isIPv6Address(host):
  527. conn = _FakeConnector(IPv6Address("TCP", host, port))
  528. else:
  529. conn = _FakeConnector(IPv4Address("TCP", host, port))
  530. factory.startedConnecting(conn)
  531. self.connectors.append(conn)
  532. return conn
  533. def listenSSL(self, port, factory, contextFactory, backlog=50, interface=""):
  534. """
  535. Fake L{IReactorSSL.listenSSL}, that logs the call and
  536. returns an L{IListeningPort}.
  537. """
  538. self.sslServers.append((port, factory, contextFactory, backlog, interface))
  539. return _FakePort(IPv4Address("TCP", "0.0.0.0", port))
  540. def connectSSL(
  541. self, host, port, factory, contextFactory, timeout=30, bindAddress=None
  542. ):
  543. """
  544. Fake L{IReactorSSL.connectSSL}, that logs the call and returns an
  545. L{IConnector}.
  546. """
  547. self.sslClients.append(
  548. (host, port, factory, contextFactory, timeout, bindAddress)
  549. )
  550. conn = _FakeConnector(IPv4Address("TCP", host, port))
  551. factory.startedConnecting(conn)
  552. self.connectors.append(conn)
  553. return conn
  554. def listenUNIX(self, address, factory, backlog=50, mode=0o666, wantPID=0):
  555. """
  556. Fake L{IReactorUNIX.listenUNIX}, that logs the call and returns an
  557. L{IListeningPort}.
  558. """
  559. self.unixServers.append((address, factory, backlog, mode, wantPID))
  560. return _FakePort(UNIXAddress(address))
  561. def connectUNIX(self, address, factory, timeout=30, checkPID=0):
  562. """
  563. Fake L{IReactorUNIX.connectUNIX}, that logs the call and returns an
  564. L{IConnector}.
  565. """
  566. self.unixClients.append((address, factory, timeout, checkPID))
  567. conn = _FakeConnector(UNIXAddress(address))
  568. factory.startedConnecting(conn)
  569. self.connectors.append(conn)
  570. return conn
  571. def addReader(self, reader):
  572. """
  573. Fake L{IReactorFDSet.addReader} which adds the reader to a local set.
  574. """
  575. self.readers.add(reader)
  576. def removeReader(self, reader):
  577. """
  578. Fake L{IReactorFDSet.removeReader} which removes the reader from a
  579. local set.
  580. """
  581. self.readers.discard(reader)
  582. def addWriter(self, writer):
  583. """
  584. Fake L{IReactorFDSet.addWriter} which adds the writer to a local set.
  585. """
  586. self.writers.add(writer)
  587. def removeWriter(self, writer):
  588. """
  589. Fake L{IReactorFDSet.removeWriter} which removes the writer from a
  590. local set.
  591. """
  592. self.writers.discard(writer)
  593. def getReaders(self):
  594. """
  595. Fake L{IReactorFDSet.getReaders} which returns a list of readers from
  596. the local set.
  597. """
  598. return list(self.readers)
  599. def getWriters(self):
  600. """
  601. Fake L{IReactorFDSet.getWriters} which returns a list of writers from
  602. the local set.
  603. """
  604. return list(self.writers)
  605. def removeAll(self):
  606. """
  607. Fake L{IReactorFDSet.removeAll} which removed all readers and writers
  608. from the local sets.
  609. """
  610. self.readers.clear()
  611. self.writers.clear()
  612. for iface in implementedBy(MemoryReactor):
  613. verifyClass(iface, MemoryReactor)
  614. class MemoryReactorClock(MemoryReactor, Clock):
  615. def __init__(self):
  616. MemoryReactor.__init__(self)
  617. Clock.__init__(self)
  618. @implementer(IReactorTCP, IReactorSSL, IReactorUNIX, IReactorSocket)
  619. class RaisingMemoryReactor:
  620. """
  621. A fake reactor to be used in tests. It accepts TCP connection setup
  622. attempts, but they will fail.
  623. @ivar _listenException: An instance of an L{Exception}
  624. @ivar _connectException: An instance of an L{Exception}
  625. """
  626. def __init__(self, listenException=None, connectException=None):
  627. """
  628. @param listenException: An instance of an L{Exception} to raise
  629. when any C{listen} method is called.
  630. @param connectException: An instance of an L{Exception} to raise
  631. when any C{connect} method is called.
  632. """
  633. self._listenException = listenException
  634. self._connectException = connectException
  635. def adoptStreamPort(self, fileno, addressFamily, factory):
  636. """
  637. Fake L{IReactorSocket.adoptStreamPort}, that raises
  638. L{_listenException}.
  639. """
  640. raise self._listenException
  641. def listenTCP(self, port, factory, backlog=50, interface=""):
  642. """
  643. Fake L{IReactorTCP.listenTCP}, that raises L{_listenException}.
  644. """
  645. raise self._listenException
  646. def connectTCP(self, host, port, factory, timeout=30, bindAddress=None):
  647. """
  648. Fake L{IReactorTCP.connectTCP}, that raises L{_connectException}.
  649. """
  650. raise self._connectException
  651. def listenSSL(self, port, factory, contextFactory, backlog=50, interface=""):
  652. """
  653. Fake L{IReactorSSL.listenSSL}, that raises L{_listenException}.
  654. """
  655. raise self._listenException
  656. def connectSSL(
  657. self, host, port, factory, contextFactory, timeout=30, bindAddress=None
  658. ):
  659. """
  660. Fake L{IReactorSSL.connectSSL}, that raises L{_connectException}.
  661. """
  662. raise self._connectException
  663. def listenUNIX(self, address, factory, backlog=50, mode=0o666, wantPID=0):
  664. """
  665. Fake L{IReactorUNIX.listenUNIX}, that raises L{_listenException}.
  666. """
  667. raise self._listenException
  668. def connectUNIX(self, address, factory, timeout=30, checkPID=0):
  669. """
  670. Fake L{IReactorUNIX.connectUNIX}, that raises L{_connectException}.
  671. """
  672. raise self._connectException
  673. def adoptDatagramPort(self, fileDescriptor, addressFamily, protocol, maxPacketSize):
  674. """
  675. Fake L{IReactorSocket.adoptDatagramPort}, that raises
  676. L{_connectException}.
  677. """
  678. raise self._connectException
  679. def adoptStreamConnection(self, fileDescriptor, addressFamily, factory):
  680. """
  681. Fake L{IReactorSocket.adoptStreamConnection}, that raises
  682. L{_connectException}.
  683. """
  684. raise self._connectException
  685. class NonStreamingProducer:
  686. """
  687. A pull producer which writes 10 times only.
  688. """
  689. counter = 0
  690. stopped = False
  691. def __init__(self, consumer):
  692. self.consumer = consumer
  693. self.result = Deferred()
  694. def resumeProducing(self):
  695. """
  696. Write the counter value once.
  697. """
  698. if self.consumer is None or self.counter >= 10:
  699. raise RuntimeError("BUG: resume after unregister/stop.")
  700. else:
  701. self.consumer.write(b"%d" % (self.counter,))
  702. self.counter += 1
  703. if self.counter == 10:
  704. self.consumer.unregisterProducer()
  705. self._done()
  706. def pauseProducing(self):
  707. """
  708. An implementation of C{IPushProducer.pauseProducing}. This should never
  709. be called on a pull producer, so this just raises an error.
  710. """
  711. raise RuntimeError("BUG: pause should never be called.")
  712. def _done(self):
  713. """
  714. Fire a L{Deferred} so that users can wait for this to complete.
  715. """
  716. self.consumer = None
  717. d = self.result
  718. del self.result
  719. d.callback(None)
  720. def stopProducing(self):
  721. """
  722. Stop all production.
  723. """
  724. self.stopped = True
  725. self._done()
  726. def waitUntilAllDisconnected(reactor, protocols):
  727. """
  728. Take a list of disconnecting protocols, callback a L{Deferred} when they're
  729. all done.
  730. This is a hack to make some older tests less flaky, as
  731. L{ITransport.loseConnection} is not atomic on all reactors (for example,
  732. the CoreFoundation, which sometimes takes a reactor turn for CFSocket to
  733. realise). New tests should either not use real sockets in testing, or take
  734. the advice in
  735. I{https://jml.io/pages/how-to-disconnect-in-twisted-really.html} to heart.
  736. @param reactor: The reactor to schedule the checks on.
  737. @type reactor: L{IReactorTime}
  738. @param protocols: The protocols to wait for disconnecting.
  739. @type protocols: A L{list} of L{IProtocol}s.
  740. """
  741. lc = None
  742. def _check():
  743. if True not in [x.transport.connected for x in protocols]:
  744. lc.stop()
  745. lc = task.LoopingCall(_check)
  746. lc.clock = reactor
  747. return lc.start(0.01, now=True)
  @implementer(ILogObserver)
  class EventLoggingObserver(Sequence[LogEvent]):
      """
      An L{ILogObserver} that records every event it receives in a list.

      Unlike L{LimitedHistoryLogObserver}, the recorded events are meant to
      be inspected directly by tests. The observer is itself a read-only
      sequence of the captured events, so it supports C{len}, iteration,
      indexing and slicing.

      @ivar _events: The captured events, in the order they were observed.
      @type _events: L{list}
      """

      def __init__(self) -> None:
          self._events: list[LogEvent] = []

      def __call__(self, event: LogEvent) -> None:
          """
          Record C{event} at the end of the captured list.

          @see: L{ILogObserver}
          """
          self._events.append(event)

      def __len__(self) -> int:
          return len(self._events)

      def __iter__(self) -> Iterator[LogEvent]:
          return iter(self._events)

      @overload
      def __getitem__(self, index: int) -> LogEvent:
          ...

      @overload
      def __getitem__(self, index: slice) -> Sequence[LogEvent]:
          ...

      def __getitem__(self, index: int | slice) -> LogEvent | Sequence[LogEvent]:
          # Delegate straight to the list; slices therefore yield a list.
          return self._events[index]

      @classmethod
      def createWithCleanup(cls, testInstance: TestCase, publisher: LogPublisher) -> Self:
          """
          Build an observer, subscribe it to C{publisher}, and register a
          cleanup on C{testInstance} that unsubscribes it again.

          @param testInstance: The test case in which the observer is used.
          @type testInstance: L{twisted.trial.unittest.TestCase}

          @param publisher: The log publisher to observe.
          @type publisher: twisted.logger.LogPublisher

          @return: The new observer, already subscribed to C{publisher}.
          @rtype: L{twisted.test.proto_helpers.EventLoggingObserver}
          """
          observer = cls()
          publisher.addObserver(observer)

          def stopObserving():
              return publisher.removeObserver(observer)

          testInstance.addCleanup(stopObserving)
          return observer
  # Generic result type shared by the benchmark helpers below.
  _T = TypeVar("_T")


  def _benchmarkWithReactor(
      test_target: Callable[
          [],
          Union[
              Coroutine[Deferred[Any], Any, _T],
              Generator[Deferred[Any], Any, _T],
              Deferred[_T],
          ],
      ]
  ) -> Callable[[Any], None]:  # pragma: no cover
      """
      Decorator for running a benchmark test that spins the reactor.

      This is designed to decorate test functions executed by pytest with
      pytest-benchmark. The returned function takes the C{benchmark} fixture
      and times complete reactor runs. Each run waits for the result of
      C{test_target}, wrapped with L{ensureDeferred}.

      Before any real benchmarking, the test checks that one run of
      C{test_target} takes at least five times as long as an empty reactor
      run. If it does not, the reactor's start-up/shut-down overhead would
      dominate the measurement, so L{RuntimeError} is raised instead.

      @param test_target: A zero-argument callable returning a coroutine,
          generator or L{Deferred}.

      @return: A test function accepting the C{benchmark} fixture.
      """
      # perf_counter is monotonic and high-resolution. Wall-clock time() can
      # be coarse enough on some platforms to report a zero elapsed time.
      from time import perf_counter

      def deferredWrapper():
          return ensureDeferred(test_target())

      def benchmark_test(benchmark: Any) -> None:
          # Spinning up and spinning down the reactor adds quite a lot of
          # overhead to the benchmarked function. So, make sure that the
          # overhead isn't making the benchmark meaningless before we bother
          # with any real benchmarking.
          start = perf_counter()
          _runReactor(lambda: succeed(None))
          justReactorElapsed = perf_counter() - start

          start = perf_counter()
          _runReactor(deferredWrapper)
          benchmarkElapsed = perf_counter() - start

          # Multiply instead of dividing, so a zero baseline cannot raise
          # ZeroDivisionError.
          if benchmarkElapsed < 5 * justReactorElapsed:
              raise RuntimeError(  # pragma: no cover
                  "The function you are benchmarking is fast enough that its "
                  "run time is being swamped by the startup/shutdown of the "
                  "reactor. Consider adding a for loop to the benchmark "
                  "function so it does the work a number of times."
              )

          benchmark(_runReactor, deferredWrapper)

      # Deliberately not decorated with functools.wraps. pytest would follow
      # __wrapped__ to test_target's signature and stop injecting `benchmark`.
      return benchmark_test
  def _runReactor(callback: Callable[[], Deferred[_T]]) -> None:  # pragma: no cover
      """
      Run the global reactor until the L{Deferred} returned by C{callback}
      fires, then stop it (via C{_stopReactor}) so that it can be started
      again later. The reactor may already have been run and stopped before.

      If that L{Deferred} fails, the first recorded failure is re-raised once
      C{reactor.run} has returned.

      @param callback: A zero-argument callable returning a L{Deferred}.
      """
      # Import the reactor only when needed. Tests that import
      # twisted.internet.testing may expect no reactor to be installed yet,
      # so a top-level import would be an unwanted side effect.
      from twisted.internet import reactor

      failures: list[failure.Failure] = []

      def scheduleStop(_):
          # The callLater result is returned so the Deferred's final result
          # stays the scheduled call object.
          return reactor.callLater(0, _stopReactor, reactor)  # type: ignore[attr-defined]

      outcome = callback()
      # Record failures; the errback's None return lets scheduleStop run next.
      outcome.addErrback(failures.append)
      outcome.addBoth(scheduleStop)
      reactor.run(installSignalHandlers=False)  # type: ignore[attr-defined]

      if failures:  # pragma: no cover
          # Make sure the test fails in a visible way.
          failures[0].raiseException()
  def _stopReactor(reactor):  # pragma: no cover
      """
      Stop C{reactor} and reset its run-state flags so that a later
      C{reactor.run()} can start it again.
      """
      reactor.stop()
      # Give the shutdown hooks a chance to execute.
      reactor.iterate()

      runStateFlags = ("_startedBefore", "_started", "_justStopped", "running")

      # We are about to overwrite private reactor state. First check that
      # every attribute we expect actually exists.
      for flagName in runStateFlags:
          assert hasattr(reactor, flagName)

      for flagName in runStateFlags:
          setattr(reactor, flagName, False)

      # Running the reactor consumed the startup event triggers, so
      # register the startup hook again.
      reactor.addSystemEventTrigger("during", "startup", reactor._reallyStartRunning)
  864. reactor.running = False
  865. # Start running has consumed the startup events, so we need
  866. # to restore them.
  867. reactor.addSystemEventTrigger("during", "startup", reactor._reallyStartRunning)