loopback.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385
  1. # -*- test-case-name: twisted.test.test_loopback -*-
  2. # Copyright (c) Twisted Matrix Laboratories.
  3. # See LICENSE for details.
  4. """
  5. Testing support for protocols -- loopback between client and server.
  6. """
  7. from __future__ import division, absolute_import
  8. # system imports
  9. import tempfile
  10. from zope.interface import implementer
  11. # Twisted Imports
  12. from twisted.protocols import policies
  13. from twisted.internet import interfaces, protocol, main, defer
  14. from twisted.internet.task import deferLater
  15. from twisted.python import failure
  16. from twisted.internet.interfaces import IAddress
  17. class _LoopbackQueue(object):
  18. """
  19. Trivial wrapper around a list to give it an interface like a queue, which
  20. the addition of also sending notifications by way of a Deferred whenever
  21. the list has something added to it.
  22. """
  23. _notificationDeferred = None
  24. disconnect = False
  25. def __init__(self):
  26. self._queue = []
  27. def put(self, v):
  28. self._queue.append(v)
  29. if self._notificationDeferred is not None:
  30. d, self._notificationDeferred = self._notificationDeferred, None
  31. d.callback(None)
  32. def __nonzero__(self):
  33. return bool(self._queue)
  34. __bool__ = __nonzero__
  35. def get(self):
  36. return self._queue.pop(0)
  37. @implementer(IAddress)
  38. class _LoopbackAddress(object):
  39. pass
  40. @implementer(interfaces.ITransport, interfaces.IConsumer)
  41. class _LoopbackTransport(object):
  42. disconnecting = False
  43. producer = None
  44. # ITransport
  45. def __init__(self, q):
  46. self.q = q
  47. def write(self, data):
  48. if not isinstance(data, bytes):
  49. raise TypeError("Can only write bytes to ITransport")
  50. self.q.put(data)
  51. def writeSequence(self, iovec):
  52. self.q.put(b''.join(iovec))
  53. def loseConnection(self):
  54. self.q.disconnect = True
  55. self.q.put(None)
  56. def abortConnection(self):
  57. """
  58. Abort the connection. Same as L{loseConnection}.
  59. """
  60. self.loseConnection()
  61. def getPeer(self):
  62. return _LoopbackAddress()
  63. def getHost(self):
  64. return _LoopbackAddress()
  65. # IConsumer
  66. def registerProducer(self, producer, streaming):
  67. assert self.producer is None
  68. self.producer = producer
  69. self.streamingProducer = streaming
  70. self._pollProducer()
  71. def unregisterProducer(self):
  72. assert self.producer is not None
  73. self.producer = None
  74. def _pollProducer(self):
  75. if self.producer is not None and not self.streamingProducer:
  76. self.producer.resumeProducing()
  77. def identityPumpPolicy(queue, target):
  78. """
  79. L{identityPumpPolicy} is a policy which delivers each chunk of data written
  80. to the given queue as-is to the target.
  81. This isn't a particularly realistic policy.
  82. @see: L{loopbackAsync}
  83. """
  84. while queue:
  85. bytes = queue.get()
  86. if bytes is None:
  87. break
  88. target.dataReceived(bytes)
  89. def collapsingPumpPolicy(queue, target):
  90. """
  91. L{collapsingPumpPolicy} is a policy which collapses all outstanding chunks
  92. into a single string and delivers it to the target.
  93. @see: L{loopbackAsync}
  94. """
  95. bytes = []
  96. while queue:
  97. chunk = queue.get()
  98. if chunk is None:
  99. break
  100. bytes.append(chunk)
  101. if bytes:
  102. target.dataReceived(b''.join(bytes))
  103. def loopbackAsync(server, client, pumpPolicy=identityPumpPolicy):
  104. """
  105. Establish a connection between C{server} and C{client} then transfer data
  106. between them until the connection is closed. This is often useful for
  107. testing a protocol.
  108. @param server: The protocol instance representing the server-side of this
  109. connection.
  110. @param client: The protocol instance representing the client-side of this
  111. connection.
  112. @param pumpPolicy: When either C{server} or C{client} writes to its
  113. transport, the string passed in is added to a queue of data for the
  114. other protocol. Eventually, C{pumpPolicy} will be called with one such
  115. queue and the corresponding protocol object. The pump policy callable
  116. is responsible for emptying the queue and passing the strings it
  117. contains to the given protocol's C{dataReceived} method. The signature
  118. of C{pumpPolicy} is C{(queue, protocol)}. C{queue} is an object with a
  119. C{get} method which will return the next string written to the
  120. transport, or L{None} if the transport has been disconnected, and which
  121. evaluates to C{True} if and only if there are more items to be
  122. retrieved via C{get}.
  123. @return: A L{Deferred} which fires when the connection has been closed and
  124. both sides have received notification of this.
  125. """
  126. serverToClient = _LoopbackQueue()
  127. clientToServer = _LoopbackQueue()
  128. server.makeConnection(_LoopbackTransport(serverToClient))
  129. client.makeConnection(_LoopbackTransport(clientToServer))
  130. return _loopbackAsyncBody(
  131. server, serverToClient, client, clientToServer, pumpPolicy)
  132. def _loopbackAsyncBody(server, serverToClient, client, clientToServer,
  133. pumpPolicy):
  134. """
  135. Transfer bytes from the output queue of each protocol to the input of the other.
  136. @param server: The protocol instance representing the server-side of this
  137. connection.
  138. @param serverToClient: The L{_LoopbackQueue} holding the server's output.
  139. @param client: The protocol instance representing the client-side of this
  140. connection.
  141. @param clientToServer: The L{_LoopbackQueue} holding the client's output.
  142. @param pumpPolicy: See L{loopbackAsync}.
  143. @return: A L{Deferred} which fires when the connection has been closed and
  144. both sides have received notification of this.
  145. """
  146. def pump(source, q, target):
  147. sent = False
  148. if q:
  149. pumpPolicy(q, target)
  150. sent = True
  151. if sent and not q:
  152. # A write buffer has now been emptied. Give any producer on that
  153. # side an opportunity to produce more data.
  154. source.transport._pollProducer()
  155. return sent
  156. while 1:
  157. disconnect = clientSent = serverSent = False
  158. # Deliver the data which has been written.
  159. serverSent = pump(server, serverToClient, client)
  160. clientSent = pump(client, clientToServer, server)
  161. if not clientSent and not serverSent:
  162. # Neither side wrote any data. Wait for some new data to be added
  163. # before trying to do anything further.
  164. d = defer.Deferred()
  165. clientToServer._notificationDeferred = d
  166. serverToClient._notificationDeferred = d
  167. d.addCallback(
  168. _loopbackAsyncContinue,
  169. server, serverToClient, client, clientToServer, pumpPolicy)
  170. return d
  171. if serverToClient.disconnect:
  172. # The server wants to drop the connection. Flush any remaining
  173. # data it has.
  174. disconnect = True
  175. pump(server, serverToClient, client)
  176. elif clientToServer.disconnect:
  177. # The client wants to drop the connection. Flush any remaining
  178. # data it has.
  179. disconnect = True
  180. pump(client, clientToServer, server)
  181. if disconnect:
  182. # Someone wanted to disconnect, so okay, the connection is gone.
  183. server.connectionLost(failure.Failure(main.CONNECTION_DONE))
  184. client.connectionLost(failure.Failure(main.CONNECTION_DONE))
  185. return defer.succeed(None)
  186. def _loopbackAsyncContinue(ignored, server, serverToClient, client,
  187. clientToServer, pumpPolicy):
  188. # Clear the Deferred from each message queue, since it has already fired
  189. # and cannot be used again.
  190. clientToServer._notificationDeferred = None
  191. serverToClient._notificationDeferred = None
  192. # Schedule some more byte-pushing to happen. This isn't done
  193. # synchronously because no actual transport can re-enter dataReceived as
  194. # a result of calling write, and doing this synchronously could result
  195. # in that.
  196. from twisted.internet import reactor
  197. return deferLater(
  198. reactor, 0,
  199. _loopbackAsyncBody,
  200. server, serverToClient, client, clientToServer, pumpPolicy)
  201. @implementer(interfaces.ITransport, interfaces.IConsumer)
  202. class LoopbackRelay:
  203. buffer = b''
  204. shouldLose = 0
  205. disconnecting = 0
  206. producer = None
  207. def __init__(self, target, logFile=None):
  208. self.target = target
  209. self.logFile = logFile
  210. def write(self, data):
  211. self.buffer = self.buffer + data
  212. if self.logFile:
  213. self.logFile.write("loopback writing %s\n" % repr(data))
  214. def writeSequence(self, iovec):
  215. self.write(b"".join(iovec))
  216. def clearBuffer(self):
  217. if self.shouldLose == -1:
  218. return
  219. if self.producer:
  220. self.producer.resumeProducing()
  221. if self.buffer:
  222. if self.logFile:
  223. self.logFile.write("loopback receiving %s\n" % repr(self.buffer))
  224. buffer = self.buffer
  225. self.buffer = b''
  226. self.target.dataReceived(buffer)
  227. if self.shouldLose == 1:
  228. self.shouldLose = -1
  229. self.target.connectionLost(failure.Failure(main.CONNECTION_DONE))
  230. def loseConnection(self):
  231. if self.shouldLose != -1:
  232. self.shouldLose = 1
  233. def getHost(self):
  234. return 'loopback'
  235. def getPeer(self):
  236. return 'loopback'
  237. def registerProducer(self, producer, streaming):
  238. self.producer = producer
  239. def unregisterProducer(self):
  240. self.producer = None
  241. def logPrefix(self):
  242. return 'Loopback(%r)' % (self.target.__class__.__name__,)
  243. class LoopbackClientFactory(protocol.ClientFactory):
  244. def __init__(self, protocol):
  245. self.disconnected = 0
  246. self.deferred = defer.Deferred()
  247. self.protocol = protocol
  248. def buildProtocol(self, addr):
  249. return self.protocol
  250. def clientConnectionLost(self, connector, reason):
  251. self.disconnected = 1
  252. self.deferred.callback(None)
  253. class _FireOnClose(policies.ProtocolWrapper):
  254. def __init__(self, protocol, factory):
  255. policies.ProtocolWrapper.__init__(self, protocol, factory)
  256. self.deferred = defer.Deferred()
  257. def connectionLost(self, reason):
  258. policies.ProtocolWrapper.connectionLost(self, reason)
  259. self.deferred.callback(None)
  260. def loopbackTCP(server, client, port=0, noisy=True):
  261. """Run session between server and client protocol instances over TCP."""
  262. from twisted.internet import reactor
  263. f = policies.WrappingFactory(protocol.Factory())
  264. serverWrapper = _FireOnClose(f, server)
  265. f.noisy = noisy
  266. f.buildProtocol = lambda addr: serverWrapper
  267. serverPort = reactor.listenTCP(port, f, interface='127.0.0.1')
  268. clientF = LoopbackClientFactory(client)
  269. clientF.noisy = noisy
  270. reactor.connectTCP('127.0.0.1', serverPort.getHost().port, clientF)
  271. d = clientF.deferred
  272. d.addCallback(lambda x: serverWrapper.deferred)
  273. d.addCallback(lambda x: serverPort.stopListening())
  274. return d
  275. def loopbackUNIX(server, client, noisy=True):
  276. """Run session between server and client protocol instances over UNIX socket."""
  277. path = tempfile.mktemp()
  278. from twisted.internet import reactor
  279. f = policies.WrappingFactory(protocol.Factory())
  280. serverWrapper = _FireOnClose(f, server)
  281. f.noisy = noisy
  282. f.buildProtocol = lambda addr: serverWrapper
  283. serverPort = reactor.listenUNIX(path, f)
  284. clientF = LoopbackClientFactory(client)
  285. clientF.noisy = noisy
  286. reactor.connectUNIX(path, clientF)
  287. d = clientF.deferred
  288. d.addCallback(lambda x: serverWrapper.deferred)
  289. d.addCallback(lambda x: serverPort.stopListening())
  290. return d