# Copyright (c) Twisted Matrix Laboratories. # See LICENSE for details. """ UDP support for IOCP reactor """ from __future__ import annotations import errno import socket import struct import warnings from typing import TYPE_CHECKING, Optional from zope.interface import implementer from twisted.internet import address, defer, error, interfaces from twisted.internet._multicast import MulticastMixin from twisted.internet.abstract import isIPAddress, isIPv6Address from twisted.internet.iocpreactor import abstract, iocpsupport as _iocp from twisted.internet.iocpreactor.const import ( ERROR_CONNECTION_REFUSED, ERROR_IO_PENDING, ERROR_PORT_UNREACHABLE, ) from twisted.internet.iocpreactor.interfaces import IReadWriteHandle from twisted.internet.protocol import AbstractDatagramProtocol from twisted.python import log if TYPE_CHECKING: from twisted.internet.iocpreactor.reactor import IOCPReactor @implementer( IReadWriteHandle, interfaces.IListeningPort, interfaces.IUDPTransport, interfaces.ISystemHandle, ) class Port(abstract.FileHandle): """ UDP port, listening for packets. @ivar addressFamily: L{socket.AF_INET} or L{socket.AF_INET6}, depending on whether this port is listening on an IPv4 address or an IPv6 address. """ reactor: IOCPReactor addressFamily = socket.AF_INET socketType = socket.SOCK_DGRAM dynamicReadBuffers = False # Actual port number being listened on, only set to a non-None # value when we are actually listening. _realPortNumber: Optional[int] = None def __init__( self, port: int, proto: AbstractDatagramProtocol, interface: str = "", maxPacketSize: int = 8192, reactor: IOCPReactor | None = None, ) -> None: """ Initialize with a numeric port to listen on. """ self.port = port self.protocol = proto self.readBufferSize = maxPacketSize self.interface = interface self.setLogStr() self._connectedAddr = None self._setAddressFamily() abstract.FileHandle.__init__(self, reactor) skt = socket.socket(self.addressFamily, self.socketType) addrLen = _iocp.maxAddrLen(skt.fileno()) self.addressBuffer = bytearray(addrLen) # WSARecvFrom takes an int self.addressLengthBuffer = bytearray(struct.calcsize("i")) def _setAddressFamily(self): """ Resolve address family for the socket. """ if isIPv6Address(self.interface): self.addressFamily = socket.AF_INET6 elif isIPAddress(self.interface): self.addressFamily = socket.AF_INET elif self.interface: raise error.InvalidAddressError( self.interface, "not an IPv4 or IPv6 address" ) def __repr__(self) -> str: if self._realPortNumber is not None: return f"<{self.protocol.__class__} on {self._realPortNumber}>" else: return f"<{self.protocol.__class__} not connected>" def getHandle(self): """ Return a socket object. """ return self.socket def startListening(self): """ Create and bind my socket, and begin listening on it. This is called on unserialization, and must be called after creating a server to begin listening on the specified port. """ self._bindSocket() self._connectToProtocol() def createSocket(self) -> socket.socket: return self.reactor.createSocket(self.addressFamily, self.socketType) def _bindSocket(self): try: skt = self.createSocket() skt.bind((self.interface, self.port)) except OSError as le: raise error.CannotListenError(self.interface, self.port, le) # Make sure that if we listened on port 0, we update that to # reflect what the OS actually assigned us. self._realPortNumber = skt.getsockname()[1] log.msg( "%s starting on %s" % (self._getLogPrefix(self.protocol), self._realPortNumber) ) self.connected = True self.socket = skt self.getFileHandle = self.socket.fileno def _connectToProtocol(self): self.protocol.makeConnection(self) self.startReading() self.reactor.addActiveHandle(self) def cbRead(self, rc, data, evt): if self.reading: self.handleRead(rc, data, evt) self.doRead() def handleRead(self, rc, data, evt): if rc in ( errno.WSAECONNREFUSED, errno.WSAECONNRESET, ERROR_CONNECTION_REFUSED, ERROR_PORT_UNREACHABLE, ): if self._connectedAddr: self.protocol.connectionRefused() elif rc: log.msg( "error in recvfrom -- %s (%s)" % (errno.errorcode.get(rc, "unknown error"), rc) ) else: try: self.protocol.datagramReceived( bytes(evt.buff[:data]), _iocp.makesockaddr(evt.addr_buff) ) except BaseException: log.err() def doRead(self): evt = _iocp.Event(self.cbRead, self) evt.buff = buff = self._readBuffers[0] evt.addr_buff = addr_buff = self.addressBuffer evt.addr_len_buff = addr_len_buff = self.addressLengthBuffer rc, data = _iocp.recvfrom( self.getFileHandle(), buff, addr_buff, addr_len_buff, evt ) if rc and rc != ERROR_IO_PENDING: # If the error was not 0 or IO_PENDING then that means recvfrom() hit a # failure condition. In this situation recvfrom() gives us our response # right away and we don't need to wait for Windows to call the callback # on our event. In fact, windows will not call it for us so we must call it # ourselves manually self.reactor.callLater(0, self.cbRead, rc, data, evt) def write(self, datagram, addr=None): """ Write a datagram. @param addr: should be a tuple (ip, port), can be None in connected mode. """ if self._connectedAddr: assert addr in (None, self._connectedAddr) try: return self.socket.send(datagram) except OSError as se: no = se.args[0] if no == errno.WSAEINTR: return self.write(datagram) elif no == errno.WSAEMSGSIZE: raise error.MessageLengthError("message too long") elif no in ( errno.WSAECONNREFUSED, errno.WSAECONNRESET, ERROR_CONNECTION_REFUSED, ERROR_PORT_UNREACHABLE, ): self.protocol.connectionRefused() else: raise else: assert addr != None if ( not isIPAddress(addr[0]) and not isIPv6Address(addr[0]) and addr[0] != "" ): raise error.InvalidAddressError( addr[0], "write() only accepts IP addresses, not hostnames" ) if isIPAddress(addr[0]) and self.addressFamily == socket.AF_INET6: raise error.InvalidAddressError( addr[0], "IPv6 port write() called with IPv4 address" ) if isIPv6Address(addr[0]) and self.addressFamily == socket.AF_INET: raise error.InvalidAddressError( addr[0], "IPv4 port write() called with IPv6 address" ) try: return self.socket.sendto(datagram, addr) except OSError as se: no = se.args[0] if no == errno.WSAEINTR: return self.write(datagram, addr) elif no == errno.WSAEMSGSIZE: raise error.MessageLengthError("message too long") elif no in ( errno.WSAECONNREFUSED, errno.WSAECONNRESET, ERROR_CONNECTION_REFUSED, ERROR_PORT_UNREACHABLE, ): # in non-connected UDP ECONNREFUSED is platform dependent, # I think and the info is not necessarily useful. # Nevertheless maybe we should call connectionRefused? XXX return else: raise def writeSequence(self, seq, addr): self.write(b"".join(seq), addr) def connect(self, host, port): """ 'Connect' to remote server. """ if self._connectedAddr: raise RuntimeError( "already connected, reconnecting is not currently supported " "(talk to itamar if you want this)" ) if not isIPAddress(host) and not isIPv6Address(host): raise error.InvalidAddressError(host, "not an IPv4 or IPv6 address.") self._connectedAddr = (host, port) self.socket.connect((host, port)) def _loseConnection(self): self.stopReading() self.reactor.removeActiveHandle(self) if self.connected: # actually means if we are *listening* self.reactor.callLater(0, self.connectionLost) def stopListening(self): if self.connected: result = self.d = defer.Deferred() else: result = None self._loseConnection() return result def loseConnection(self): warnings.warn( "Please use stopListening() to disconnect port", DeprecationWarning, stacklevel=2, ) self.stopListening() def connectionLost(self, reason=None): """ Cleans up my socket. """ log.msg("(UDP Port %s Closed)" % self._realPortNumber) self._realPortNumber = None abstract.FileHandle.connectionLost(self, reason) self.protocol.doStop() self.socket.close() del self.socket del self.getFileHandle if hasattr(self, "d"): self.d.callback(None) del self.d def setLogStr(self): """ Initialize the C{logstr} attribute to be used by C{logPrefix}. """ logPrefix = self._getLogPrefix(self.protocol) self.logstr = "%s (UDP)" % logPrefix def logPrefix(self): """ Returns the name of my class, to prefix log entries with. """ return self.logstr def getHost(self): """ Return the local address of the UDP connection @returns: the local address of the UDP connection @rtype: L{IPv4Address} or L{IPv6Address} """ addr = self.socket.getsockname() if self.addressFamily == socket.AF_INET: return address.IPv4Address("UDP", *addr) elif self.addressFamily == socket.AF_INET6: return address.IPv6Address("UDP", *(addr[:2])) def setBroadcastAllowed(self, enabled): """ Set whether this port may broadcast. This is disabled by default. @param enabled: Whether the port may broadcast. @type enabled: L{bool} """ self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, enabled) def getBroadcastAllowed(self): """ Checks if broadcast is currently allowed on this port. @return: Whether this port may broadcast. @rtype: L{bool} """ return bool(self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST)) @implementer(interfaces.IMulticastTransport) class MulticastPort(MulticastMixin, Port): """ UDP Port that supports multicasting. """ def __init__( self, port: int, proto: AbstractDatagramProtocol, interface: str = "", maxPacketSize: int = 8192, reactor: IOCPReactor | None = None, listenMultiple: bool = False, ) -> None: Port.__init__(self, port, proto, interface, maxPacketSize, reactor) self.listenMultiple = listenMultiple def createSocket(self) -> socket.socket: skt = Port.createSocket(self) if self.listenMultiple: skt.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) if hasattr(socket, "SO_REUSEPORT"): skt.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) return skt