# Copyright (c) Twisted Matrix Laboratories. # See LICENSE for details. """ TCP support for IOCP reactor """ from __future__ import annotations import errno import socket import struct from typing import TYPE_CHECKING, Optional, Union from zope.interface import classImplements, implementer from twisted.internet import address, defer, error, interfaces, main from twisted.internet.abstract import _LogOwner, isIPv6Address from twisted.internet.address import IPv4Address, IPv6Address from twisted.internet.interfaces import IProtocol from twisted.internet.iocpreactor import abstract, iocpsupport as _iocp from twisted.internet.iocpreactor.const import ( ERROR_CONNECTION_REFUSED, ERROR_IO_PENDING, ERROR_NETWORK_UNREACHABLE, SO_UPDATE_ACCEPT_CONTEXT, SO_UPDATE_CONNECT_CONTEXT, ) from twisted.internet.iocpreactor.interfaces import IReadWriteHandle from twisted.internet.protocol import Protocol from twisted.internet.tcp import ( Connector as TCPConnector, _AbortingMixin, _BaseBaseClient, _BaseTCPClient, _getsockname, _resolveIPv6, _SocketCloser, ) from twisted.python import failure, log, reflect try: from twisted.internet._newtls import startTLS as __startTLS except ImportError: _startTLS = None else: _startTLS = __startTLS if TYPE_CHECKING: # Circular import only to describe a type. from twisted.internet.iocpreactor.reactor import IOCPReactor # ConnectEx returns these. XXX: find out what it does for timeout connectExErrors = { ERROR_CONNECTION_REFUSED: errno.WSAECONNREFUSED, # type: ignore[attr-defined] ERROR_NETWORK_UNREACHABLE: errno.WSAENETUNREACH, # type: ignore[attr-defined] } @implementer(IReadWriteHandle, interfaces.ITCPTransport, interfaces.ISystemHandle) class Connection(abstract.FileHandle, _SocketCloser, _AbortingMixin): """ @ivar TLS: C{False} to indicate the connection is in normal TCP mode, C{True} to indicate that TLS has been started and that operations must be routed through the L{TLSMemoryBIOProtocol} instance. """ TLS = False def __init__(self, sock, proto, reactor=None): abstract.FileHandle.__init__(self, reactor) self.socket = sock self.getFileHandle = sock.fileno self.protocol = proto def getHandle(self): return self.socket def dataReceived(self, rbuffer): """ @param rbuffer: Data received. @type rbuffer: L{bytes} or L{bytearray} """ if isinstance(rbuffer, bytes): pass elif isinstance(rbuffer, bytearray): # XXX: some day, we'll have protocols that can handle raw buffers rbuffer = bytes(rbuffer) else: raise TypeError("data must be bytes or bytearray, not " + type(rbuffer)) self.protocol.dataReceived(rbuffer) def readFromHandle(self, bufflist, evt): return _iocp.recv(self.getFileHandle(), bufflist, evt) def writeToHandle(self, buff, evt): """ Send C{buff} to current file handle using C{_iocp.send}. The buffer sent is limited to a size of C{self.SEND_LIMIT}. """ writeView = memoryview(buff) return _iocp.send( self.getFileHandle(), writeView[0 : self.SEND_LIMIT].tobytes(), evt ) def _closeWriteConnection(self): try: self.socket.shutdown(1) except OSError: pass p = interfaces.IHalfCloseableProtocol(self.protocol, None) if p: try: p.writeConnectionLost() except BaseException: f = failure.Failure() log.err() self.connectionLost(f) def readConnectionLost(self, reason): p = interfaces.IHalfCloseableProtocol(self.protocol, None) if p: try: p.readConnectionLost() except BaseException: log.err() self.connectionLost(failure.Failure()) else: self.connectionLost(reason) def connectionLost(self, reason): if self.disconnected: return abstract.FileHandle.connectionLost(self, reason) isClean = reason is None or not reason.check(error.ConnectionAborted) self._closeSocket(isClean) protocol = self.protocol del self.protocol del self.socket del self.getFileHandle protocol.connectionLost(reason) def logPrefix(self): """ Return the prefix to log with when I own the logging thread. """ return self.logstr def getTcpNoDelay(self): return bool(self.socket.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)) def setTcpNoDelay(self, enabled): self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, enabled) def getTcpKeepAlive(self): return bool(self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE)) def setTcpKeepAlive(self, enabled): self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, enabled) if _startTLS is not None: def startTLS(self, contextFactory, normal=True): """ @see: L{ITLSTransport.startTLS} """ _startTLS(self, contextFactory, normal, abstract.FileHandle) def write(self, data): """ Write some data, either directly to the underlying handle or, if TLS has been started, to the L{TLSMemoryBIOProtocol} for it to encrypt and send. @see: L{twisted.internet.interfaces.ITransport.write} """ if self.disconnected: return if self.TLS: self.protocol.write(data) else: abstract.FileHandle.write(self, data) def writeSequence(self, iovec): """ Write some data, either directly to the underlying handle or, if TLS has been started, to the L{TLSMemoryBIOProtocol} for it to encrypt and send. @see: L{twisted.internet.interfaces.ITransport.writeSequence} """ if self.disconnected: return if self.TLS: self.protocol.writeSequence(iovec) else: abstract.FileHandle.writeSequence(self, iovec) def loseConnection(self, reason=None): """ Close the underlying handle or, if TLS has been started, first shut it down. @see: L{twisted.internet.interfaces.ITransport.loseConnection} """ if self.TLS: if self.connected and not self.disconnecting: self.protocol.loseConnection() else: abstract.FileHandle.loseConnection(self, reason) def registerProducer(self, producer, streaming): """ Register a producer. If TLS is enabled, the TLS connection handles this. """ if self.TLS: # Registering a producer before we're connected shouldn't be a # problem. If we end up with a write(), that's already handled in # the write() code above, and there are no other potential # side-effects. self.protocol.registerProducer(producer, streaming) else: abstract.FileHandle.registerProducer(self, producer, streaming) def unregisterProducer(self): """ Unregister a producer. If TLS is enabled, the TLS connection handles this. """ if self.TLS: self.protocol.unregisterProducer() else: abstract.FileHandle.unregisterProducer(self) def getHost(self): # ITCPTransport.getHost pass def getPeer(self): # ITCPTransport.getPeer pass if _startTLS is not None: classImplements(Connection, interfaces.ITLSTransport) class Client(_BaseBaseClient, _BaseTCPClient, Connection): """ @ivar _tlsClientDefault: Always C{True}, indicating that this is a client connection, and by default when TLS is negotiated this class will act as a TLS client. """ addressFamily = socket.AF_INET socketType = socket.SOCK_STREAM _tlsClientDefault = True _commonConnection = Connection def __init__(self, host, port, bindAddress, connector, reactor): # ConnectEx documentation says socket _has_ to be bound if bindAddress is None: bindAddress = ("", 0) self.reactor = reactor # createInternetSocket needs this _BaseTCPClient.__init__(self, host, port, bindAddress, connector, reactor) def createInternetSocket(self): """ Create a socket registered with the IOCP reactor. @see: L{_BaseTCPClient} """ return self.reactor.createSocket(self.addressFamily, self.socketType) def _collectSocketDetails(self): """ Clean up potentially circular references to the socket and to its C{getFileHandle} method. @see: L{_BaseBaseClient} """ del self.socket, self.getFileHandle def _stopReadingAndWriting(self): """ Remove the active handle from the reactor. @see: L{_BaseBaseClient} """ self.reactor.removeActiveHandle(self) def cbConnect(self, rc, data, evt): if rc: rc = connectExErrors.get(rc, rc) self.failIfNotConnected( error.getConnectError((rc, errno.errorcode.get(rc, "Unknown error"))) ) else: self.socket.setsockopt( socket.SOL_SOCKET, SO_UPDATE_CONNECT_CONTEXT, struct.pack("P", self.socket.fileno()), ) self.protocol = self.connector.buildProtocol(self.getPeer()) self.connected = True logPrefix = self._getLogPrefix(self.protocol) self.logstr = logPrefix + ",client" if self.protocol is None: # Factory.buildProtocol is allowed to return None. In that # case, make up a protocol to satisfy the rest of the # implementation; connectionLost is going to be called on # something, for example. This is easier than adding special # case support for a None protocol throughout the rest of the # transport implementation. self.protocol = Protocol() # But dispose of the connection quickly. self.loseConnection() else: self.protocol.makeConnection(self) self.startReading() def doConnect(self): if not hasattr(self, "connector"): # this happens if we connector.stopConnecting in # factory.startedConnecting return assert _iocp.have_connectex self.reactor.addActiveHandle(self) evt = _iocp.Event(self.cbConnect, self) rc = _iocp.connect(self.socket.fileno(), self.realAddress, evt) if rc and rc != ERROR_IO_PENDING: self.cbConnect(rc, 0, evt) class Server(Connection): """ Serverside socket-stream connection class. I am a serverside network connection transport; a socket which came from an accept() on a server. @ivar _tlsClientDefault: Always C{False}, indicating that this is a server connection, and by default when TLS is negotiated this class will act as a TLS server. """ _tlsClientDefault = False def __init__( self, sock: socket.socket, protocol: IProtocol, clientAddr: Union[IPv4Address, IPv6Address], serverAddr: Union[IPv4Address, IPv6Address], sessionno: int, reactor: IOCPReactor, ): """ Server(sock, protocol, client, server, sessionno) Initialize me with a socket, a protocol, a descriptor for my peer (a tuple of host, port describing the other end of the connection), an instance of Port, and a session number. """ Connection.__init__(self, sock, protocol, reactor) self.serverAddr = serverAddr self.clientAddr = clientAddr self.sessionno = sessionno logPrefix = self._getLogPrefix(self.protocol) self.logstr = f"{logPrefix},{sessionno},{self.clientAddr.host}" self.repstr: str = "<{} #{} on {}>".format( self.protocol.__class__.__name__, self.sessionno, self.serverAddr.port, ) self.connected = True self.startReading() def __repr__(self) -> str: """ A string representation of this connection. """ return self.repstr def getHost(self): """ Returns an IPv4Address. This indicates the server's address. """ return self.serverAddr def getPeer(self): """ Returns an IPv4Address. This indicates the client's address. """ return self.clientAddr class Connector(TCPConnector): def _makeTransport(self): return Client(self.host, self.port, self.bindAddress, self, self.reactor) @implementer(interfaces.IListeningPort) class Port(_SocketCloser, _LogOwner): connected = False disconnected = False disconnecting = False addressFamily = socket.AF_INET socketType = socket.SOCK_STREAM _addressType = address.IPv4Address sessionno = 0 # Actual port number being listened on, only set to a non-None # value when we are actually listening. _realPortNumber: Optional[int] = None # A string describing the connections which will be created by this port. # Normally this is C{"TCP"}, since this is a TCP port, but when the TLS # implementation re-uses this class it overrides the value with C{"TLS"}. # Only used for logging. _type = "TCP" def __init__(self, port, factory, backlog=50, interface="", reactor=None): self.port = port self.factory = factory self.backlog = backlog self.interface = interface self.reactor = reactor if isIPv6Address(interface): self.addressFamily = socket.AF_INET6 self._addressType = address.IPv6Address def __repr__(self) -> str: if self._realPortNumber is not None: return "<{} of {} on {}>".format( self.__class__, self.factory.__class__, self._realPortNumber, ) else: return "<{} of {} (not listening)>".format( self.__class__, self.factory.__class__, ) def startListening(self): try: skt = self.reactor.createSocket(self.addressFamily, self.socketType) # TODO: resolve self.interface if necessary if self.addressFamily == socket.AF_INET6: addr = _resolveIPv6(self.interface, self.port) else: addr = (self.interface, self.port) skt.bind(addr) except OSError as le: raise error.CannotListenError(self.interface, self.port, le) self.addrLen = _iocp.maxAddrLen(skt.fileno()) # Make sure that if we listened on port 0, we update that to # reflect what the OS actually assigned us. self._realPortNumber = skt.getsockname()[1] log.msg( "%s starting on %s" % (self._getLogPrefix(self.factory), self._realPortNumber) ) self.factory.doStart() skt.listen(self.backlog) self.connected = True self.disconnected = False self.reactor.addActiveHandle(self) self.socket = skt self.getFileHandle = self.socket.fileno self.doAccept() def loseConnection(self, connDone=failure.Failure(main.CONNECTION_DONE)): """ Stop accepting connections on this port. This will shut down my socket and call self.connectionLost(). It returns a deferred which will fire successfully when the port is actually closed. """ self.disconnecting = True if self.connected: self.deferred = defer.Deferred() self.reactor.callLater(0, self.connectionLost, connDone) return self.deferred stopListening = loseConnection def _logConnectionLostMsg(self): """ Log message for closing port """ log.msg(f"({self._type} Port {self._realPortNumber} Closed)") def connectionLost(self, reason): """ Cleans up the socket. """ self._logConnectionLostMsg() self._realPortNumber = None d = None if hasattr(self, "deferred"): d = self.deferred del self.deferred self.disconnected = True self.reactor.removeActiveHandle(self) self.connected = False self._closeSocket(True) del self.socket del self.getFileHandle try: self.factory.doStop() except BaseException: self.disconnecting = False if d is not None: d.errback(failure.Failure()) else: raise else: self.disconnecting = False if d is not None: d.callback(None) def logPrefix(self): """ Returns the name of my class, to prefix log entries with. """ return reflect.qual(self.factory.__class__) def getHost(self): """ Returns an IPv4Address or IPv6Address. This indicates the server's address. """ return self._addressType("TCP", *_getsockname(self.socket)) def cbAccept(self, rc, data, evt): self.handleAccept(rc, evt) if not (self.disconnecting or self.disconnected): self.doAccept() def handleAccept(self, rc, evt): if self.disconnecting or self.disconnected: return False # possible errors: # (WSAEMFILE, WSAENOBUFS, WSAENFILE, WSAENOMEM, WSAECONNABORTED) if rc: log.msg( "Could not accept new connection -- %s (%s)" % (errno.errorcode.get(rc, "unknown error"), rc) ) return False else: # Inherit the properties from the listening port socket as # documented in the `Remarks` section of AcceptEx. # https://docs.microsoft.com/en-us/windows/win32/api/mswsock/nf-mswsock-acceptex # In this way we can call getsockname and getpeername on the # accepted socket. evt.newskt.setsockopt( socket.SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT, struct.pack("P", self.socket.fileno()), ) family, lAddr, rAddr = _iocp.get_accept_addrs(evt.newskt.fileno(), evt.buff) assert family == self.addressFamily # Build an IPv6 address that includes the scopeID, if necessary if "%" in lAddr[0]: scope = int(lAddr[0].split("%")[1]) lAddr = (lAddr[0], lAddr[1], 0, scope) if "%" in rAddr[0]: scope = int(rAddr[0].split("%")[1]) rAddr = (rAddr[0], rAddr[1], 0, scope) protocol = self.factory.buildProtocol(self._addressType("TCP", *rAddr)) if protocol is None: evt.newskt.close() else: s = self.sessionno self.sessionno = s + 1 transport = Server( evt.newskt, protocol, self._addressType("TCP", *rAddr), self._addressType("TCP", *lAddr), s, self.reactor, ) protocol.makeConnection(transport) return True def doAccept(self): evt = _iocp.Event(self.cbAccept, self) # see AcceptEx documentation evt.buff = buff = bytearray(2 * (self.addrLen + 16)) evt.newskt = newskt = self.reactor.createSocket( self.addressFamily, self.socketType ) rc = _iocp.accept(self.socket.fileno(), newskt.fileno(), buff, evt) if rc and rc != ERROR_IO_PENDING: self.handleAccept(rc, evt)