123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623 |
- # Copyright (c) Twisted Matrix Laboratories.
- # See LICENSE for details.
- """
- TCP support for IOCP reactor
- """
- from __future__ import annotations
- import errno
- import socket
- import struct
- from typing import TYPE_CHECKING, Optional, Union
- from zope.interface import classImplements, implementer
- from twisted.internet import address, defer, error, interfaces, main
- from twisted.internet.abstract import _LogOwner, isIPv6Address
- from twisted.internet.address import IPv4Address, IPv6Address
- from twisted.internet.interfaces import IProtocol
- from twisted.internet.iocpreactor import abstract, iocpsupport as _iocp
- from twisted.internet.iocpreactor.const import (
- ERROR_CONNECTION_REFUSED,
- ERROR_IO_PENDING,
- ERROR_NETWORK_UNREACHABLE,
- SO_UPDATE_ACCEPT_CONTEXT,
- SO_UPDATE_CONNECT_CONTEXT,
- )
- from twisted.internet.iocpreactor.interfaces import IReadWriteHandle
- from twisted.internet.protocol import Protocol
- from twisted.internet.tcp import (
- Connector as TCPConnector,
- _AbortingMixin,
- _BaseBaseClient,
- _BaseTCPClient,
- _getsockname,
- _resolveIPv6,
- _SocketCloser,
- )
- from twisted.python import failure, log, reflect
- try:
- from twisted.internet._newtls import startTLS as __startTLS
- except ImportError:
- _startTLS = None
- else:
- _startTLS = __startTLS
- if TYPE_CHECKING:
- # Circular import only to describe a type.
- from twisted.internet.iocpreactor.reactor import IOCPReactor
# Windows status codes that ConnectEx reports on failure, translated to the
# corresponding errno values before Client.cbConnect hands them to
# error.getConnectError. Codes not listed here are passed through unchanged.
# XXX: find out what ConnectEx reports for a timeout.
connectExErrors: dict[int, int] = {
    # Remote end actively refused the connection.
    ERROR_CONNECTION_REFUSED: errno.WSAECONNREFUSED,  # type: ignore[attr-defined]
    # No route to the destination network.
    ERROR_NETWORK_UNREACHABLE: errno.WSAENETUNREACH,  # type: ignore[attr-defined]
}
@implementer(IReadWriteHandle, interfaces.ITCPTransport, interfaces.ISystemHandle)
class Connection(abstract.FileHandle, _SocketCloser, _AbortingMixin):
    """
    A connected TCP socket driven by the IOCP reactor.

    @ivar TLS: C{False} to indicate the connection is in normal TCP mode,
        C{True} to indicate that TLS has been started and that operations must
        be routed through the L{TLSMemoryBIOProtocol} instance.
    """

    TLS = False

    def __init__(self, sock, proto, reactor=None):
        """
        @param sock: The connected socket this transport wraps.
        @param proto: The protocol that receives data and connection events.
        @param reactor: Passed through to L{abstract.FileHandle}.
        """
        abstract.FileHandle.__init__(self, reactor)
        self.socket = sock
        self.getFileHandle = sock.fileno
        self.protocol = proto

    def getHandle(self):
        """
        @return: The underlying socket object.
        """
        return self.socket

    def dataReceived(self, rbuffer):
        """
        Deliver data read from the socket to the protocol.

        @param rbuffer: Data received.
        @type rbuffer: L{bytes} or L{bytearray}

        @raise TypeError: If C{rbuffer} is neither L{bytes} nor L{bytearray}.
        """
        if isinstance(rbuffer, bytearray):
            # XXX: some day, we'll have protocols that can handle raw buffers
            rbuffer = bytes(rbuffer)
        elif not isinstance(rbuffer, bytes):
            # Concatenating the type object itself onto a str would raise an
            # unrelated TypeError ("can only concatenate str ..."), hiding
            # this message; use the type's name instead.
            raise TypeError(
                "data must be bytes or bytearray, not " + type(rbuffer).__name__
            )
        self.protocol.dataReceived(rbuffer)

    def readFromHandle(self, bufflist, evt):
        """
        Start an overlapped receive into C{bufflist} using C{_iocp.recv}.
        """
        return _iocp.recv(self.getFileHandle(), bufflist, evt)

    def writeToHandle(self, buff, evt):
        """
        Send C{buff} to current file handle using C{_iocp.send}. The buffer
        sent is limited to a size of C{self.SEND_LIMIT}.
        """
        writeView = memoryview(buff)
        return _iocp.send(
            self.getFileHandle(), writeView[0 : self.SEND_LIMIT].tobytes(), evt
        )

    def _closeWriteConnection(self):
        """
        Shut down the sending side of the socket and, if the protocol supports
        half-close, tell it via C{writeConnectionLost}.
        """
        try:
            self.socket.shutdown(socket.SHUT_WR)
        except OSError:
            # Best effort: the socket may already be closed or reset.
            pass
        p = interfaces.IHalfCloseableProtocol(self.protocol, None)
        if p:
            try:
                p.writeConnectionLost()
            except BaseException:
                f = failure.Failure()
                log.err()
                self.connectionLost(f)

    def readConnectionLost(self, reason):
        """
        Handle the remote end closing its sending side.

        Half-closeable protocols are told via C{readConnectionLost} and the
        connection stays open; any other protocol loses the whole connection.
        """
        p = interfaces.IHalfCloseableProtocol(self.protocol, None)
        if p:
            try:
                p.readConnectionLost()
            except BaseException:
                log.err()
                self.connectionLost(failure.Failure())
        else:
            self.connectionLost(reason)

    def connectionLost(self, reason):
        """
        Close the socket, drop references to it and the protocol, and notify
        the protocol. Does nothing if already disconnected.
        """
        if self.disconnected:
            return
        abstract.FileHandle.connectionLost(self, reason)
        # An aborted connection is closed uncleanly; anything else is clean.
        isClean = reason is None or not reason.check(error.ConnectionAborted)
        self._closeSocket(isClean)
        protocol = self.protocol
        del self.protocol
        del self.socket
        del self.getFileHandle
        protocol.connectionLost(reason)

    def logPrefix(self):
        """
        Return the prefix to log with when I own the logging thread.

        C{logstr} is set by subclasses (L{Server.__init__} and
        L{Client.cbConnect}).
        """
        return self.logstr

    def getTcpNoDelay(self):
        """
        @return: Whether C{TCP_NODELAY} is enabled on the socket.
        """
        return bool(self.socket.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY))

    def setTcpNoDelay(self, enabled):
        """
        Enable or disable C{TCP_NODELAY} on the socket.
        """
        self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, enabled)

    def getTcpKeepAlive(self):
        """
        @return: Whether C{SO_KEEPALIVE} is enabled on the socket.
        """
        return bool(self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE))

    def setTcpKeepAlive(self, enabled):
        """
        Enable or disable C{SO_KEEPALIVE} on the socket.
        """
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, enabled)

    if _startTLS is not None:

        def startTLS(self, contextFactory, normal=True):
            """
            @see: L{ITLSTransport.startTLS}
            """
            _startTLS(self, contextFactory, normal, abstract.FileHandle)

    def write(self, data):
        """
        Write some data, either directly to the underlying handle or, if TLS
        has been started, to the L{TLSMemoryBIOProtocol} for it to encrypt and
        send.

        @see: L{twisted.internet.interfaces.ITransport.write}
        """
        if self.disconnected:
            return
        if self.TLS:
            self.protocol.write(data)
        else:
            abstract.FileHandle.write(self, data)

    def writeSequence(self, iovec):
        """
        Write some data, either directly to the underlying handle or, if TLS
        has been started, to the L{TLSMemoryBIOProtocol} for it to encrypt and
        send.

        @see: L{twisted.internet.interfaces.ITransport.writeSequence}
        """
        if self.disconnected:
            return
        if self.TLS:
            self.protocol.writeSequence(iovec)
        else:
            abstract.FileHandle.writeSequence(self, iovec)

    def loseConnection(self, reason=None):
        """
        Close the underlying handle or, if TLS has been started, first shut it
        down.

        @see: L{twisted.internet.interfaces.ITransport.loseConnection}
        """
        if self.TLS:
            if self.connected and not self.disconnecting:
                self.protocol.loseConnection()
        else:
            abstract.FileHandle.loseConnection(self, reason)

    def registerProducer(self, producer, streaming):
        """
        Register a producer.

        If TLS is enabled, the TLS connection handles this.
        """
        if self.TLS:
            # Registering a producer before we're connected shouldn't be a
            # problem. If we end up with a write(), that's already handled in
            # the write() code above, and there are no other potential
            # side-effects.
            self.protocol.registerProducer(producer, streaming)
        else:
            abstract.FileHandle.registerProducer(self, producer, streaming)

    def unregisterProducer(self):
        """
        Unregister a producer.

        If TLS is enabled, the TLS connection handles this.
        """
        if self.TLS:
            self.protocol.unregisterProducer()
        else:
            abstract.FileHandle.unregisterProducer(self)

    def getHost(self):
        """
        ITCPTransport.getHost placeholder; returns C{None} here. Subclasses
        supply the real local address.
        """
        pass

    def getPeer(self):
        """
        ITCPTransport.getPeer placeholder; returns C{None} here. Subclasses
        supply the real remote address.
        """
        pass


# Only advertise ITLSTransport when the TLS implementation could be imported.
if _startTLS is not None:
    classImplements(Connection, interfaces.ITLSTransport)
class Client(_BaseBaseClient, _BaseTCPClient, Connection):
    """
    Client-side TCP transport that connects with ConnectEx through the IOCP
    reactor.

    @ivar _tlsClientDefault: Always C{True}, indicating that this is a client
        connection, and by default when TLS is negotiated this class will act
        as a TLS client.
    """

    addressFamily = socket.AF_INET
    socketType = socket.SOCK_STREAM
    _tlsClientDefault = True
    _commonConnection = Connection

    def __init__(self, host, port, bindAddress, connector, reactor):
        """
        @param host: The remote host to connect to.
        @param port: The remote port to connect to.
        @param bindAddress: Local address to bind before connecting, or
            C{None} to bind to any interface on an OS-chosen port.
        @param connector: The connector that owns this transport.
        @param reactor: The IOCP reactor used to create and drive the socket.
        """
        # ConnectEx documentation says socket _has_ to be bound, so a missing
        # bind address becomes "any interface, any port".
        bindAddress = ("", 0) if bindAddress is None else bindAddress
        # createInternetSocket needs self.reactor, so it must be set before
        # the base initializer runs.
        self.reactor = reactor
        _BaseTCPClient.__init__(self, host, port, bindAddress, connector, reactor)

    def createInternetSocket(self):
        """
        Create a socket registered with the IOCP reactor.

        @see: L{_BaseTCPClient}
        """
        family, kind = self.addressFamily, self.socketType
        return self.reactor.createSocket(family, kind)

    def _collectSocketDetails(self):
        """
        Clean up potentially circular references to the socket and to its
        C{getFileHandle} method.

        @see: L{_BaseBaseClient}
        """
        del self.socket
        del self.getFileHandle

    def _stopReadingAndWriting(self):
        """
        Remove the active handle from the reactor.

        @see: L{_BaseBaseClient}
        """
        self.reactor.removeActiveHandle(self)

    def cbConnect(self, rc, data, evt):
        """
        Completion callback for the ConnectEx started by L{doConnect}.

        @param rc: Completion status; non-zero means the connection failed.
        @param data: Unused.
        @param evt: The completion event for the connect operation.
        """
        if rc:
            code = connectExErrors.get(rc, rc)
            description = errno.errorcode.get(code, "Unknown error")
            self.failIfNotConnected(error.getConnectError((code, description)))
            return

        # Update the socket's connect context now that ConnectEx is done.
        self.socket.setsockopt(
            socket.SOL_SOCKET,
            SO_UPDATE_CONNECT_CONTEXT,
            struct.pack("P", self.socket.fileno()),
        )
        self.protocol = self.connector.buildProtocol(self.getPeer())
        self.connected = True
        self.logstr = self._getLogPrefix(self.protocol) + ",client"

        if self.protocol is not None:
            self.protocol.makeConnection(self)
            self.startReading()
            return

        # Factory.buildProtocol may return None. Substitute a do-nothing
        # Protocol so the rest of the transport (connectionLost, for example)
        # always has something to call, then close the connection promptly.
        self.protocol = Protocol()
        self.loseConnection()

    def doConnect(self):
        """
        Start an asynchronous ConnectEx to C{self.realAddress}.

        Synchronous failures are reported straight to L{cbConnect}. A
        pending result is delivered to L{cbConnect} via the completion event.
        """
        # The connector attribute is gone if connector.stopConnecting was
        # called from factory.startedConnecting.
        if not hasattr(self, "connector"):
            return
        assert _iocp.have_connectex
        self.reactor.addActiveHandle(self)
        event = _iocp.Event(self.cbConnect, self)
        status = _iocp.connect(self.socket.fileno(), self.realAddress, event)
        if status and status != ERROR_IO_PENDING:
            self.cbConnect(status, 0, event)
class Server(Connection):
    """
    Serverside socket-stream connection class.

    I am a serverside network connection transport; a socket which came from an
    accept() on a server.

    @ivar _tlsClientDefault: Always C{False}, indicating that this is a server
        connection, and by default when TLS is negotiated this class will act
        as a TLS server.
    """

    _tlsClientDefault = False

    def __init__(
        self,
        sock: socket.socket,
        protocol: IProtocol,
        clientAddr: Union[IPv4Address, IPv6Address],
        serverAddr: Union[IPv4Address, IPv6Address],
        sessionno: int,
        reactor: IOCPReactor,
    ) -> None:
        """
        Wrap an accepted socket, mark it connected and start reading.

        @param sock: The accepted, connected socket.
        @param protocol: The protocol that will handle this connection.
        @param clientAddr: The address of the remote (client) end.
        @param serverAddr: The local address the connection was accepted on.
        @param sessionno: A per-port sequence number used in the log prefix
            and in C{repr}.
        @param reactor: The IOCP reactor that owns the socket.
        """
        Connection.__init__(self, sock, protocol, reactor)
        self.serverAddr = serverAddr
        self.clientAddr = clientAddr
        self.sessionno = sessionno
        logPrefix = self._getLogPrefix(self.protocol)
        self.logstr = f"{logPrefix},{sessionno},{self.clientAddr.host}"
        self.repstr: str = (
            f"<{self.protocol.__class__.__name__} #{self.sessionno}"
            f" on {self.serverAddr.port}>"
        )
        self.connected = True
        self.startReading()

    def __repr__(self) -> str:
        """
        A string representation of this connection.
        """
        return self.repstr

    def getHost(self):
        """
        Returns an IPv4Address or IPv6Address.

        This indicates the server's (local) address.
        """
        return self.serverAddr

    def getPeer(self):
        """
        Returns an IPv4Address or IPv6Address.

        This indicates the client's (remote) address.
        """
        return self.clientAddr
class Connector(TCPConnector):
    """
    TCP connector whose transports are IOCP L{Client} instances.
    """

    def _makeTransport(self):
        """
        Build the L{Client} transport that will attempt this connection.
        """
        return Client(
            self.host,
            self.port,
            self.bindAddress,
            self,
            self.reactor,
        )
@implementer(interfaces.IListeningPort)
class Port(_SocketCloser, _LogOwner):
    """
    A listening TCP port for the IOCP reactor.

    Connections are accepted asynchronously with C{_iocp.accept} (AcceptEx).
    Each accepted socket is wrapped in a L{Server} transport connected to a
    protocol built by C{factory}.
    """

    connected = False
    disconnected = False
    disconnecting = False
    addressFamily = socket.AF_INET
    socketType = socket.SOCK_STREAM
    _addressType = address.IPv4Address
    sessionno = 0

    # Actual port number being listened on, only set to a non-None
    # value when we are actually listening.
    _realPortNumber: Optional[int] = None

    # A string describing the connections which will be created by this port.
    # Normally this is C{"TCP"}, since this is a TCP port, but when the TLS
    # implementation re-uses this class it overrides the value with C{"TLS"}.
    # Only used for logging.
    _type = "TCP"

    def __init__(self, port, factory, backlog=50, interface="", reactor=None):
        """
        @param port: The port number to listen on; C{0} lets the OS choose.
        @param factory: Builds a protocol for each accepted connection.
        @param backlog: The queue size passed to C{listen}.
        @param interface: The local address to bind. An IPv6 literal switches
            the port to C{AF_INET6} and IPv6 addresses.
        @param reactor: The reactor used to create sockets and track this port
            as an active handle.
        """
        self.port = port
        self.factory = factory
        self.backlog = backlog
        self.interface = interface
        self.reactor = reactor
        if isIPv6Address(interface):
            self.addressFamily = socket.AF_INET6
            self._addressType = address.IPv6Address

    def __repr__(self) -> str:
        if self._realPortNumber is not None:
            return "<{} of {} on {}>".format(
                self.__class__,
                self.factory.__class__,
                self._realPortNumber,
            )
        else:
            return "<{} of {} (not listening)>".format(
                self.__class__,
                self.factory.__class__,
            )

    def startListening(self):
        """
        Bind to the configured interface and port, start the factory, listen,
        and post the first accept.

        @raise error.CannotListenError: If creating the socket, resolving the
            address or binding fails with an L{OSError}.
        """
        skt = None
        try:
            skt = self.reactor.createSocket(self.addressFamily, self.socketType)
            # TODO: resolve self.interface if necessary
            if self.addressFamily == socket.AF_INET6:
                addr = _resolveIPv6(self.interface, self.port)
            else:
                addr = (self.interface, self.port)
            skt.bind(addr)
        except OSError as le:
            # Don't leak the freshly created socket when resolving or binding
            # fails; nobody else holds a reference to it.
            if skt is not None:
                skt.close()
            raise error.CannotListenError(self.interface, self.port, le)

        self.addrLen = _iocp.maxAddrLen(skt.fileno())

        # Make sure that if we listened on port 0, we update that to
        # reflect what the OS actually assigned us.
        self._realPortNumber = skt.getsockname()[1]

        log.msg(
            "%s starting on %s"
            % (self._getLogPrefix(self.factory), self._realPortNumber)
        )

        self.factory.doStart()
        skt.listen(self.backlog)
        self.connected = True
        self.disconnected = False
        self.reactor.addActiveHandle(self)
        self.socket = skt
        self.getFileHandle = self.socket.fileno
        self.doAccept()

    def loseConnection(self, connDone=failure.Failure(main.CONNECTION_DONE)):
        """
        Stop accepting connections on this port.

        This will shut down my socket and call self.connectionLost() on the
        next reactor iteration.

        @return: If the port is listening, a L{defer.Deferred} which fires
            when the port is actually closed; otherwise C{None}.
        """
        self.disconnecting = True
        if self.connected:
            self.deferred = defer.Deferred()
            self.reactor.callLater(0, self.connectionLost, connDone)
            return self.deferred

    stopListening = loseConnection

    def _logConnectionLostMsg(self):
        """
        Log message for closing port
        """
        log.msg(f"({self._type} Port {self._realPortNumber} Closed)")

    def connectionLost(self, reason):
        """
        Cleans up the socket, stops the factory and fires the Deferred from
        L{loseConnection}, if any.

        If C{factory.doStop} raises, the Deferred is errbacked. If there is no
        Deferred, the exception propagates.
        """
        self._logConnectionLostMsg()
        self._realPortNumber = None
        d = None
        if hasattr(self, "deferred"):
            d = self.deferred
            del self.deferred

        self.disconnected = True
        self.reactor.removeActiveHandle(self)
        self.connected = False
        self._closeSocket(True)
        del self.socket
        del self.getFileHandle

        try:
            self.factory.doStop()
        except BaseException:
            self.disconnecting = False
            if d is not None:
                d.errback(failure.Failure())
            else:
                raise
        else:
            self.disconnecting = False
            if d is not None:
                d.callback(None)

    def logPrefix(self):
        """
        Returns the fully-qualified name of my factory's class, to prefix log
        entries with.
        """
        return reflect.qual(self.factory.__class__)

    def getHost(self):
        """
        Returns an IPv4Address or IPv6Address.

        This indicates the server's address.
        """
        return self._addressType("TCP", *_getsockname(self.socket))

    def cbAccept(self, rc, data, evt):
        """
        Completion callback for an accept. Handle the result, then post the
        next accept unless the port is shutting down.
        """
        self.handleAccept(rc, evt)
        if not (self.disconnecting or self.disconnected):
            self.doAccept()

    def handleAccept(self, rc, evt):
        """
        Process the completion of one accept operation.

        @param rc: Completion status; zero means success.
        @param evt: The event created by L{doAccept}, carrying the target
            socket in C{newskt} and the address buffer in C{buff}.
        @return: C{True} if a connection was accepted (even if the factory
            declined it), C{False} otherwise.
        """
        if self.disconnecting or self.disconnected:
            # The port is going away; release the socket that was handed to
            # the accept instead of leaving it for the garbage collector.
            evt.newskt.close()
            return False

        # possible errors:
        # (WSAEMFILE, WSAENOBUFS, WSAENFILE, WSAENOMEM, WSAECONNABORTED)
        if rc:
            log.msg(
                "Could not accept new connection -- %s (%s)"
                % (errno.errorcode.get(rc, "unknown error"), rc)
            )
            # The accept failed, so no connection was established on the
            # pre-created socket; release it deterministically.
            evt.newskt.close()
            return False

        # Inherit the properties from the listening port socket as
        # documented in the `Remarks` section of AcceptEx.
        # https://docs.microsoft.com/en-us/windows/win32/api/mswsock/nf-mswsock-acceptex
        # In this way we can call getsockname and getpeername on the
        # accepted socket.
        evt.newskt.setsockopt(
            socket.SOL_SOCKET,
            SO_UPDATE_ACCEPT_CONTEXT,
            struct.pack("P", self.socket.fileno()),
        )
        family, lAddr, rAddr = _iocp.get_accept_addrs(evt.newskt.fileno(), evt.buff)
        assert family == self.addressFamily

        # Build an IPv6 address that includes the scopeID, if necessary
        if "%" in lAddr[0]:
            scope = int(lAddr[0].split("%")[1])
            lAddr = (lAddr[0], lAddr[1], 0, scope)
        if "%" in rAddr[0]:
            scope = int(rAddr[0].split("%")[1])
            rAddr = (rAddr[0], rAddr[1], 0, scope)

        protocol = self.factory.buildProtocol(self._addressType("TCP", *rAddr))
        if protocol is None:
            # The factory declined the connection.
            evt.newskt.close()
        else:
            s = self.sessionno
            self.sessionno = s + 1
            transport = Server(
                evt.newskt,
                protocol,
                self._addressType("TCP", *rAddr),
                self._addressType("TCP", *lAddr),
                s,
                self.reactor,
            )
            protocol.makeConnection(transport)
        return True

    def doAccept(self):
        """
        Post an asynchronous accept on the listening socket.

        A fresh socket is created as the accept target. An immediate,
        non-pending failure is handled synchronously.
        """
        evt = _iocp.Event(self.cbAccept, self)

        # see AcceptEx documentation
        evt.buff = buff = bytearray(2 * (self.addrLen + 16))

        evt.newskt = newskt = self.reactor.createSocket(
            self.addressFamily, self.socketType
        )
        rc = _iocp.accept(self.socket.fileno(), newskt.fileno(), buff, evt)

        if rc and rc != ERROR_IO_PENDING:
            self.handleAccept(rc, evt)
|