udp.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381
  1. # Copyright (c) Twisted Matrix Laboratories.
  2. # See LICENSE for details.
  3. """
  4. UDP support for IOCP reactor
  5. """
  6. from __future__ import annotations
  7. import errno
  8. import socket
  9. import struct
  10. import warnings
  11. from typing import TYPE_CHECKING, Optional
  12. from zope.interface import implementer
  13. from twisted.internet import address, defer, error, interfaces
  14. from twisted.internet._multicast import MulticastMixin
  15. from twisted.internet.abstract import isIPAddress, isIPv6Address
  16. from twisted.internet.iocpreactor import abstract, iocpsupport as _iocp
  17. from twisted.internet.iocpreactor.const import (
  18. ERROR_CONNECTION_REFUSED,
  19. ERROR_IO_PENDING,
  20. ERROR_PORT_UNREACHABLE,
  21. )
  22. from twisted.internet.iocpreactor.interfaces import IReadWriteHandle
  23. from twisted.internet.protocol import AbstractDatagramProtocol
  24. from twisted.python import log
  25. if TYPE_CHECKING:
  26. from twisted.internet.iocpreactor.reactor import IOCPReactor
  27. @implementer(
  28. IReadWriteHandle,
  29. interfaces.IListeningPort,
  30. interfaces.IUDPTransport,
  31. interfaces.ISystemHandle,
  32. )
  33. class Port(abstract.FileHandle):
  34. """
  35. UDP port, listening for packets.
  36. @ivar addressFamily: L{socket.AF_INET} or L{socket.AF_INET6}, depending on
  37. whether this port is listening on an IPv4 address or an IPv6 address.
  38. """
  39. reactor: IOCPReactor
  40. addressFamily = socket.AF_INET
  41. socketType = socket.SOCK_DGRAM
  42. dynamicReadBuffers = False
  43. # Actual port number being listened on, only set to a non-None
  44. # value when we are actually listening.
  45. _realPortNumber: Optional[int] = None
  46. def __init__(
  47. self,
  48. port: int,
  49. proto: AbstractDatagramProtocol,
  50. interface: str = "",
  51. maxPacketSize: int = 8192,
  52. reactor: IOCPReactor | None = None,
  53. ) -> None:
  54. """
  55. Initialize with a numeric port to listen on.
  56. """
  57. self.port = port
  58. self.protocol = proto
  59. self.readBufferSize = maxPacketSize
  60. self.interface = interface
  61. self.setLogStr()
  62. self._connectedAddr = None
  63. self._setAddressFamily()
  64. abstract.FileHandle.__init__(self, reactor)
  65. skt = socket.socket(self.addressFamily, self.socketType)
  66. addrLen = _iocp.maxAddrLen(skt.fileno())
  67. self.addressBuffer = bytearray(addrLen)
  68. # WSARecvFrom takes an int
  69. self.addressLengthBuffer = bytearray(struct.calcsize("i"))
  70. def _setAddressFamily(self):
  71. """
  72. Resolve address family for the socket.
  73. """
  74. if isIPv6Address(self.interface):
  75. self.addressFamily = socket.AF_INET6
  76. elif isIPAddress(self.interface):
  77. self.addressFamily = socket.AF_INET
  78. elif self.interface:
  79. raise error.InvalidAddressError(
  80. self.interface, "not an IPv4 or IPv6 address"
  81. )
  82. def __repr__(self) -> str:
  83. if self._realPortNumber is not None:
  84. return f"<{self.protocol.__class__} on {self._realPortNumber}>"
  85. else:
  86. return f"<{self.protocol.__class__} not connected>"
  87. def getHandle(self):
  88. """
  89. Return a socket object.
  90. """
  91. return self.socket
  92. def startListening(self):
  93. """
  94. Create and bind my socket, and begin listening on it.
  95. This is called on unserialization, and must be called after creating a
  96. server to begin listening on the specified port.
  97. """
  98. self._bindSocket()
  99. self._connectToProtocol()
  100. def createSocket(self) -> socket.socket:
  101. return self.reactor.createSocket(self.addressFamily, self.socketType)
  102. def _bindSocket(self):
  103. try:
  104. skt = self.createSocket()
  105. skt.bind((self.interface, self.port))
  106. except OSError as le:
  107. raise error.CannotListenError(self.interface, self.port, le)
  108. # Make sure that if we listened on port 0, we update that to
  109. # reflect what the OS actually assigned us.
  110. self._realPortNumber = skt.getsockname()[1]
  111. log.msg(
  112. "%s starting on %s"
  113. % (self._getLogPrefix(self.protocol), self._realPortNumber)
  114. )
  115. self.connected = True
  116. self.socket = skt
  117. self.getFileHandle = self.socket.fileno
  118. def _connectToProtocol(self):
  119. self.protocol.makeConnection(self)
  120. self.startReading()
  121. self.reactor.addActiveHandle(self)
  122. def cbRead(self, rc, data, evt):
  123. if self.reading:
  124. self.handleRead(rc, data, evt)
  125. self.doRead()
  126. def handleRead(self, rc, data, evt):
  127. if rc in (
  128. errno.WSAECONNREFUSED,
  129. errno.WSAECONNRESET,
  130. ERROR_CONNECTION_REFUSED,
  131. ERROR_PORT_UNREACHABLE,
  132. ):
  133. if self._connectedAddr:
  134. self.protocol.connectionRefused()
  135. elif rc:
  136. log.msg(
  137. "error in recvfrom -- %s (%s)"
  138. % (errno.errorcode.get(rc, "unknown error"), rc)
  139. )
  140. else:
  141. try:
  142. self.protocol.datagramReceived(
  143. bytes(evt.buff[:data]), _iocp.makesockaddr(evt.addr_buff)
  144. )
  145. except BaseException:
  146. log.err()
  147. def doRead(self):
  148. evt = _iocp.Event(self.cbRead, self)
  149. evt.buff = buff = self._readBuffers[0]
  150. evt.addr_buff = addr_buff = self.addressBuffer
  151. evt.addr_len_buff = addr_len_buff = self.addressLengthBuffer
  152. rc, data = _iocp.recvfrom(
  153. self.getFileHandle(), buff, addr_buff, addr_len_buff, evt
  154. )
  155. if rc and rc != ERROR_IO_PENDING:
  156. # If the error was not 0 or IO_PENDING then that means recvfrom() hit a
  157. # failure condition. In this situation recvfrom() gives us our response
  158. # right away and we don't need to wait for Windows to call the callback
  159. # on our event. In fact, windows will not call it for us so we must call it
  160. # ourselves manually
  161. self.reactor.callLater(0, self.cbRead, rc, data, evt)
  162. def write(self, datagram, addr=None):
  163. """
  164. Write a datagram.
  165. @param addr: should be a tuple (ip, port), can be None in connected
  166. mode.
  167. """
  168. if self._connectedAddr:
  169. assert addr in (None, self._connectedAddr)
  170. try:
  171. return self.socket.send(datagram)
  172. except OSError as se:
  173. no = se.args[0]
  174. if no == errno.WSAEINTR:
  175. return self.write(datagram)
  176. elif no == errno.WSAEMSGSIZE:
  177. raise error.MessageLengthError("message too long")
  178. elif no in (
  179. errno.WSAECONNREFUSED,
  180. errno.WSAECONNRESET,
  181. ERROR_CONNECTION_REFUSED,
  182. ERROR_PORT_UNREACHABLE,
  183. ):
  184. self.protocol.connectionRefused()
  185. else:
  186. raise
  187. else:
  188. assert addr != None
  189. if (
  190. not isIPAddress(addr[0])
  191. and not isIPv6Address(addr[0])
  192. and addr[0] != "<broadcast>"
  193. ):
  194. raise error.InvalidAddressError(
  195. addr[0], "write() only accepts IP addresses, not hostnames"
  196. )
  197. if isIPAddress(addr[0]) and self.addressFamily == socket.AF_INET6:
  198. raise error.InvalidAddressError(
  199. addr[0], "IPv6 port write() called with IPv4 address"
  200. )
  201. if isIPv6Address(addr[0]) and self.addressFamily == socket.AF_INET:
  202. raise error.InvalidAddressError(
  203. addr[0], "IPv4 port write() called with IPv6 address"
  204. )
  205. try:
  206. return self.socket.sendto(datagram, addr)
  207. except OSError as se:
  208. no = se.args[0]
  209. if no == errno.WSAEINTR:
  210. return self.write(datagram, addr)
  211. elif no == errno.WSAEMSGSIZE:
  212. raise error.MessageLengthError("message too long")
  213. elif no in (
  214. errno.WSAECONNREFUSED,
  215. errno.WSAECONNRESET,
  216. ERROR_CONNECTION_REFUSED,
  217. ERROR_PORT_UNREACHABLE,
  218. ):
  219. # in non-connected UDP ECONNREFUSED is platform dependent,
  220. # I think and the info is not necessarily useful.
  221. # Nevertheless maybe we should call connectionRefused? XXX
  222. return
  223. else:
  224. raise
  225. def writeSequence(self, seq, addr):
  226. self.write(b"".join(seq), addr)
  227. def connect(self, host, port):
  228. """
  229. 'Connect' to remote server.
  230. """
  231. if self._connectedAddr:
  232. raise RuntimeError(
  233. "already connected, reconnecting is not currently supported "
  234. "(talk to itamar if you want this)"
  235. )
  236. if not isIPAddress(host) and not isIPv6Address(host):
  237. raise error.InvalidAddressError(host, "not an IPv4 or IPv6 address.")
  238. self._connectedAddr = (host, port)
  239. self.socket.connect((host, port))
  240. def _loseConnection(self):
  241. self.stopReading()
  242. self.reactor.removeActiveHandle(self)
  243. if self.connected: # actually means if we are *listening*
  244. self.reactor.callLater(0, self.connectionLost)
  245. def stopListening(self):
  246. if self.connected:
  247. result = self.d = defer.Deferred()
  248. else:
  249. result = None
  250. self._loseConnection()
  251. return result
  252. def loseConnection(self):
  253. warnings.warn(
  254. "Please use stopListening() to disconnect port",
  255. DeprecationWarning,
  256. stacklevel=2,
  257. )
  258. self.stopListening()
  259. def connectionLost(self, reason=None):
  260. """
  261. Cleans up my socket.
  262. """
  263. log.msg("(UDP Port %s Closed)" % self._realPortNumber)
  264. self._realPortNumber = None
  265. abstract.FileHandle.connectionLost(self, reason)
  266. self.protocol.doStop()
  267. self.socket.close()
  268. del self.socket
  269. del self.getFileHandle
  270. if hasattr(self, "d"):
  271. self.d.callback(None)
  272. del self.d
  273. def setLogStr(self):
  274. """
  275. Initialize the C{logstr} attribute to be used by C{logPrefix}.
  276. """
  277. logPrefix = self._getLogPrefix(self.protocol)
  278. self.logstr = "%s (UDP)" % logPrefix
  279. def logPrefix(self):
  280. """
  281. Returns the name of my class, to prefix log entries with.
  282. """
  283. return self.logstr
  284. def getHost(self):
  285. """
  286. Return the local address of the UDP connection
  287. @returns: the local address of the UDP connection
  288. @rtype: L{IPv4Address} or L{IPv6Address}
  289. """
  290. addr = self.socket.getsockname()
  291. if self.addressFamily == socket.AF_INET:
  292. return address.IPv4Address("UDP", *addr)
  293. elif self.addressFamily == socket.AF_INET6:
  294. return address.IPv6Address("UDP", *(addr[:2]))
  295. def setBroadcastAllowed(self, enabled):
  296. """
  297. Set whether this port may broadcast. This is disabled by default.
  298. @param enabled: Whether the port may broadcast.
  299. @type enabled: L{bool}
  300. """
  301. self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, enabled)
  302. def getBroadcastAllowed(self):
  303. """
  304. Checks if broadcast is currently allowed on this port.
  305. @return: Whether this port may broadcast.
  306. @rtype: L{bool}
  307. """
  308. return bool(self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST))
  309. @implementer(interfaces.IMulticastTransport)
  310. class MulticastPort(MulticastMixin, Port):
  311. """
  312. UDP Port that supports multicasting.
  313. """
  314. def __init__(
  315. self,
  316. port: int,
  317. proto: AbstractDatagramProtocol,
  318. interface: str = "",
  319. maxPacketSize: int = 8192,
  320. reactor: IOCPReactor | None = None,
  321. listenMultiple: bool = False,
  322. ) -> None:
  323. Port.__init__(self, port, proto, interface, maxPacketSize, reactor)
  324. self.listenMultiple = listenMultiple
  325. def createSocket(self) -> socket.socket:
  326. skt = Port.createSocket(self)
  327. if self.listenMultiple:
  328. skt.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  329. if hasattr(socket, "SO_REUSEPORT"):
  330. skt.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
  331. return skt