# -*- test-case-name: twisted.internet.test.test_resolver -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.

"""
IPv6-aware hostname resolution.

@see: L{IHostnameResolver}
"""


from socket import (
    AF_INET,
    AF_INET6,
    AF_UNSPEC,
    SOCK_DGRAM,
    SOCK_STREAM,
    AddressFamily,
    SocketKind,
    gaierror,
    getaddrinfo,
)
from typing import (
    TYPE_CHECKING,
    Callable,
    List,
    NoReturn,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
)

from zope.interface import implementer

from twisted.internet._idna import _idnaBytes
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.defer import Deferred
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import (
    IAddress,
    IHostnameResolver,
    IHostResolution,
    IReactorThreads,
    IResolutionReceiver,
    IResolverSimple,
)
from twisted.internet.threads import deferToThreadPool
from twisted.logger import Logger
from twisted.python.compat import nativeString

if TYPE_CHECKING:
    from twisted.python.threadpool import ThreadPool


@implementer(IHostResolution)
class HostResolution:
    """
    The in-progress resolution of a given hostname.
    """

    def __init__(self, name: str):
        """
        Create a L{HostResolution} with the given name.
        """
        self.name = name

    def cancel(self) -> NoReturn:
        # IHostResolution.cancel
        # Cancellation is not supported by this implementation.
        raise NotImplementedError()


# Both address types: the default when no addressTypes are requested.
_any = frozenset([IPv4Address, IPv6Address])

# Requested set of address classes -> address family passed to getaddrinfo.
# Any other combination (for example an empty collection) is not a key here.
_typesToAF = {
    frozenset([IPv4Address]): AF_INET,
    frozenset([IPv6Address]): AF_INET6,
    _any: AF_UNSPEC,
}

# Address family reported by getaddrinfo -> address class to construct.
_afToType = {
    AF_INET: IPv4Address,
    AF_INET6: IPv6Address,
}

# Transport semantics name -> socket type passed to getaddrinfo.
_transportToSocket = {
    "TCP": SOCK_STREAM,
    "UDP": SOCK_DGRAM,
}

# Socket type reported by getaddrinfo -> transport name for the address.
_socktypeToType = {
    SOCK_STREAM: "TCP",
    SOCK_DGRAM: "UDP",
}


# Shape of the value returned by socket.getaddrinfo: a list of
# (family, type, proto, canonname, sockaddr) tuples, where sockaddr is a
# 2-tuple for IPv4 and a 4-tuple for IPv6.
_GETADDRINFO_RESULT = List[
    Tuple[
        AddressFamily,
        SocketKind,
        int,
        str,
        Union[Tuple[str, int], Tuple[str, int, int, int]],
    ]
]


@implementer(IHostnameResolver)
class GAIResolver:
    """
    L{IHostnameResolver} implementation that resolves hostnames by calling
    L{getaddrinfo} in a thread.
    """

    def __init__(
        self,
        reactor: IReactorThreads,
        getThreadPool: Optional[Callable[[], "ThreadPool"]] = None,
        getaddrinfo: Callable[[str, int, int, int], _GETADDRINFO_RESULT] = getaddrinfo,
    ):
        """
        Create a L{GAIResolver}.

        @param reactor: the reactor to schedule result-delivery on
        @type reactor: L{IReactorThreads}

        @param getThreadPool: a function to retrieve the thread pool to use for
            scheduling name resolutions.  If not supplied, use the given
            C{reactor}'s thread pool.
        @type getThreadPool: 0-argument callable returning a
            L{twisted.python.threadpool.ThreadPool}

        @param getaddrinfo: a reference to the L{getaddrinfo} to use - mainly
            parameterized for testing.
        @type getaddrinfo: callable with the same signature as L{getaddrinfo}
        """
        self._reactor = reactor
        self._getThreadPool = (
            reactor.getThreadPool if getThreadPool is None else getThreadPool
        )
        self._getaddrinfo = getaddrinfo

    def resolveHostName(
        self,
        resolutionReceiver: IResolutionReceiver,
        hostName: str,
        portNumber: int = 0,
        addressTypes: Optional[Sequence[Type[IAddress]]] = None,
        transportSemantics: str = "TCP",
    ) -> IHostResolution:
        """
        See L{IHostnameResolver.resolveHostName}

        @param resolutionReceiver: see interface

        @param hostName: see interface

        @param portNumber: see interface

        @param addressTypes: see interface; only C{[IPv4Address]},
            C{[IPv6Address]}, both, or L{None} (meaning both) are supported.

        @param transportSemantics: see interface; C{"TCP"} or C{"UDP"}.

        @return: see interface

        @raise KeyError: synchronously, if C{addressTypes} or
            C{transportSemantics} is not one of the supported values.
        """
        pool = self._getThreadPool()
        addressFamily = _typesToAF[
            _any if addressTypes is None else frozenset(addressTypes)
        ]
        socketType = _transportToSocket[transportSemantics]

        def get() -> _GETADDRINFO_RESULT:
            # Runs in a pool thread.  Failures that only mean "this name
            # cannot be resolved" become an empty result, so the receiver
            # still gets resolutionComplete() below.  Letting them escape
            # would fail the Deferred, which has no errback, and the receiver
            # would never be told the resolution finished.
            try:
                return self._getaddrinfo(
                    hostName, portNumber, addressFamily, socketType
                )
            except (gaierror, UnicodeError):
                # socket.getaddrinfo IDNA-encodes str host names and raises
                # UnicodeError (not gaierror) for malformed names, e.g. an
                # empty label or a label longer than 63 characters.
                return []

        d = deferToThreadPool(self._reactor, pool, get)
        resolution = HostResolution(hostName)
        resolutionReceiver.resolutionBegan(resolution)

        @d.addCallback
        def deliverResults(result: _GETADDRINFO_RESULT) -> None:
            # Back on the reactor thread: report each address, then finish.
            for family, socktype, _proto, _canonname, sockaddr in result:
                addrType = _afToType[family]
                resolutionReceiver.addressResolved(
                    addrType(_socktypeToType.get(socktype, "TCP"), *sockaddr)
                )
            resolutionReceiver.resolutionComplete()

        return resolution


@implementer(IHostnameResolver)
class SimpleResolverComplexifier:
    """
    A converter from L{IResolverSimple} to L{IHostnameResolver}.

    Every address this adapter reports is a TCP L{IPv4Address}; the
    C{addressTypes} and C{transportSemantics} arguments are not consulted.
    """

    _log = Logger()

    def __init__(self, simpleResolver: IResolverSimple):
        """
        Construct a L{SimpleResolverComplexifier} with an L{IResolverSimple}.
        """
        self._simpleResolver = simpleResolver

    def resolveHostName(
        self,
        resolutionReceiver: IResolutionReceiver,
        hostName: str,
        portNumber: int = 0,
        addressTypes: Optional[Sequence[Type[IAddress]]] = None,
        transportSemantics: str = "TCP",
    ) -> IHostResolution:
        """
        See L{IHostnameResolver.resolveHostName}

        @param resolutionReceiver: see interface

        @param hostName: see interface

        @param portNumber: see interface

        @param addressTypes: see interface (ignored)

        @param transportSemantics: see interface (ignored)

        @return: see interface
        """
        # If it's str, we need to make sure that it's just ASCII.
        try:
            hostName_bytes = hostName.encode("ascii")
        except UnicodeEncodeError:
            # If it's not just ASCII, IDNA it. We don't want to give a Unicode
            # string with non-ASCII in it to Python 3, as if anyone passes that
            # to a Python 3 stdlib function, it will probably use the wrong
            # IDNA version and break absolutely everything
            hostName_bytes = _idnaBytes(hostName)

        # Make sure it's passed down as a native str, to maintain the interface
        hostName = nativeString(hostName_bytes)

        resolution = HostResolution(hostName)
        resolutionReceiver.resolutionBegan(resolution)
        (
            self._simpleResolver.getHostByName(hostName)
            .addCallback(
                lambda address: resolutionReceiver.addressResolved(
                    IPv4Address("TCP", address, portNumber)
                )
            )
            # A DNSLookupError simply means "no addresses"; anything else is
            # logged.  Either way the chain recovers so that
            # resolutionComplete() below always runs.
            .addErrback(
                lambda error: None
                if error.check(DNSLookupError)
                else self._log.failure(
                    "while looking up {name} with {resolver}",
                    error,
                    name=hostName,
                    resolver=self._simpleResolver,
                )
            )
            .addCallback(lambda nothing: resolutionReceiver.resolutionComplete())
        )
        return resolution


@implementer(IResolutionReceiver)
class FirstOneWins:
    """
    An L{IResolutionReceiver} which fires a L{Deferred} with its first result.
    """

    def __init__(self, deferred: "Deferred[str]"):
        """
        @param deferred: The L{Deferred} to fire when the first resolution
            result arrives.
        """
        self._deferred = deferred
        self._resolved = False

    def resolutionBegan(self, resolution: IHostResolution) -> None:
        """
        See L{IResolutionReceiver.resolutionBegan}

        @param resolution: See L{IResolutionReceiver.resolutionBegan}
        """
        # Kept so resolutionComplete can name the host in DNSLookupError.
        self._resolution = resolution

    def addressResolved(self, address: IAddress) -> None:
        """
        See L{IResolutionReceiver.addressResolved}

        Only the first address fires the L{Deferred}; later ones are ignored.

        @param address: See L{IResolutionReceiver.addressResolved}
        """
        if self._resolved:
            return
        self._resolved = True
        # This is used by ComplexResolverSimplifier which specifies only results
        # of IPv4Address.
        assert isinstance(address, IPv4Address)
        self._deferred.callback(address.host)

    def resolutionComplete(self) -> None:
        """
        See L{IResolutionReceiver.resolutionComplete}

        If no address was delivered, fail the L{Deferred} with
        L{DNSLookupError}.
        """
        if self._resolved:
            return
        self._deferred.errback(DNSLookupError(self._resolution.name))


@implementer(IResolverSimple)
class ComplexResolverSimplifier:
    """
    A converter from L{IHostnameResolver} to L{IResolverSimple}
    """

    def __init__(self, nameResolver: IHostnameResolver):
        """
        Create a L{ComplexResolverSimplifier} with an L{IHostnameResolver}.

        @param nameResolver: The L{IHostnameResolver} to use.
        """
        self._nameResolver = nameResolver

    def getHostByName(self, name: str, timeouts: Sequence[int] = ()) -> "Deferred[str]":
        """
        See L{IResolverSimple.getHostByName}

        Only IPv4 addresses are requested, and the first one reported wins.

        @param name: see L{IResolverSimple.getHostByName}

        @param timeouts: see L{IResolverSimple.getHostByName}; not forwarded
            to the underlying L{IHostnameResolver}.

        @return: see L{IResolverSimple.getHostByName}
        """
        result: "Deferred[str]" = Deferred()
        self._nameResolver.resolveHostName(FirstOneWins(result), name, 0, [IPv4Address])
        return result