_resolver.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. # -*- test-case-name: twisted.internet.test.test_resolver -*-
  2. # Copyright (c) Twisted Matrix Laboratories.
  3. # See LICENSE for details.
  4. """
  5. IPv6-aware hostname resolution.
  6. @see: L{IHostnameResolver}
  7. """
  8. from __future__ import division, absolute_import
  9. __metaclass__ = type
  10. from socket import (getaddrinfo, AF_INET, AF_INET6, AF_UNSPEC, SOCK_STREAM,
  11. SOCK_DGRAM, gaierror)
  12. from zope.interface import implementer
  13. from twisted.internet.interfaces import (IHostnameResolver, IHostResolution,
  14. IResolverSimple, IResolutionReceiver)
  15. from twisted.internet.error import DNSLookupError
  16. from twisted.internet.defer import Deferred
  17. from twisted.internet.threads import deferToThreadPool
  18. from twisted.internet.address import IPv4Address, IPv6Address
  19. from twisted.python.compat import nativeString
  20. from twisted.internet._idna import _idnaBytes
  21. from twisted.logger import Logger
  22. @implementer(IHostResolution)
  23. class HostResolution(object):
  24. """
  25. The in-progress resolution of a given hostname.
  26. """
  27. def __init__(self, name):
  28. """
  29. Create a L{HostResolution} with the given name.
  30. """
  31. self.name = name
  32. _any = frozenset([IPv4Address, IPv6Address])
  33. _typesToAF = {
  34. frozenset([IPv4Address]): AF_INET,
  35. frozenset([IPv6Address]): AF_INET6,
  36. _any: AF_UNSPEC,
  37. }
  38. _afToType = {
  39. AF_INET: IPv4Address,
  40. AF_INET6: IPv6Address,
  41. }
  42. _transportToSocket = {
  43. 'TCP': SOCK_STREAM,
  44. 'UDP': SOCK_DGRAM,
  45. }
  46. _socktypeToType = {
  47. SOCK_STREAM: 'TCP',
  48. SOCK_DGRAM: 'UDP',
  49. }
  50. @implementer(IHostnameResolver)
  51. class GAIResolver(object):
  52. """
  53. L{IHostnameResolver} implementation that resolves hostnames by calling
  54. L{getaddrinfo} in a thread.
  55. """
  56. def __init__(self, reactor, getThreadPool=None, getaddrinfo=getaddrinfo):
  57. """
  58. Create a L{GAIResolver}.
  59. @param reactor: the reactor to schedule result-delivery on
  60. @type reactor: L{IReactorThreads}
  61. @param getThreadPool: a function to retrieve the thread pool to use for
  62. scheduling name resolutions. If not supplied, the use the given
  63. C{reactor}'s thread pool.
  64. @type getThreadPool: 0-argument callable returning a
  65. L{twisted.python.threadpool.ThreadPool}
  66. @param getaddrinfo: a reference to the L{getaddrinfo} to use - mainly
  67. parameterized for testing.
  68. @type getaddrinfo: callable with the same signature as L{getaddrinfo}
  69. """
  70. self._reactor = reactor
  71. self._getThreadPool = (reactor.getThreadPool if getThreadPool is None
  72. else getThreadPool)
  73. self._getaddrinfo = getaddrinfo
  74. def resolveHostName(self, resolutionReceiver, hostName, portNumber=0,
  75. addressTypes=None, transportSemantics='TCP'):
  76. """
  77. See L{IHostnameResolver.resolveHostName}
  78. @param resolutionReceiver: see interface
  79. @param hostName: see interface
  80. @param portNumber: see interface
  81. @param addressTypes: see interface
  82. @param transportSemantics: see interface
  83. @return: see interface
  84. """
  85. pool = self._getThreadPool()
  86. addressFamily = _typesToAF[_any if addressTypes is None
  87. else frozenset(addressTypes)]
  88. socketType = _transportToSocket[transportSemantics]
  89. def get():
  90. try:
  91. return self._getaddrinfo(hostName, portNumber, addressFamily,
  92. socketType)
  93. except gaierror:
  94. return []
  95. d = deferToThreadPool(self._reactor, pool, get)
  96. resolution = HostResolution(hostName)
  97. resolutionReceiver.resolutionBegan(resolution)
  98. @d.addCallback
  99. def deliverResults(result):
  100. for family, socktype, proto, cannoname, sockaddr in result:
  101. addrType = _afToType[family]
  102. resolutionReceiver.addressResolved(
  103. addrType(_socktypeToType.get(socktype, 'TCP'), *sockaddr)
  104. )
  105. resolutionReceiver.resolutionComplete()
  106. return resolution
  107. @implementer(IHostnameResolver)
  108. class SimpleResolverComplexifier(object):
  109. """
  110. A converter from L{IResolverSimple} to L{IHostnameResolver}.
  111. """
  112. _log = Logger()
  113. def __init__(self, simpleResolver):
  114. """
  115. Construct a L{SimpleResolverComplexifier} with an L{IResolverSimple}.
  116. """
  117. self._simpleResolver = simpleResolver
  118. def resolveHostName(self, resolutionReceiver, hostName, portNumber=0,
  119. addressTypes=None, transportSemantics='TCP'):
  120. """
  121. See L{IHostnameResolver.resolveHostName}
  122. @param resolutionReceiver: see interface
  123. @param hostName: see interface
  124. @param portNumber: see interface
  125. @param addressTypes: see interface
  126. @param transportSemantics: see interface
  127. @return: see interface
  128. """
  129. # If it's str, we need to make sure that it's just ASCII.
  130. try:
  131. hostName = hostName.encode('ascii')
  132. except UnicodeEncodeError:
  133. # If it's not just ASCII, IDNA it. We don't want to give a Unicode
  134. # string with non-ASCII in it to Python 3, as if anyone passes that
  135. # to a Python 3 stdlib function, it will probably use the wrong
  136. # IDNA version and break absolutely everything
  137. hostName = _idnaBytes(hostName)
  138. # Make sure it's passed down as a native str, to maintain the interface
  139. hostName = nativeString(hostName)
  140. resolution = HostResolution(hostName)
  141. resolutionReceiver.resolutionBegan(resolution)
  142. onAddress = self._simpleResolver.getHostByName(hostName)
  143. def addressReceived(address):
  144. resolutionReceiver.addressResolved(IPv4Address('TCP', address,
  145. portNumber))
  146. def errorReceived(error):
  147. if not error.check(DNSLookupError):
  148. self._log.failure("while looking up {name} with {resolver}",
  149. error, name=hostName,
  150. resolver=self._simpleResolver)
  151. onAddress.addCallbacks(addressReceived, errorReceived)
  152. def finish(result):
  153. resolutionReceiver.resolutionComplete()
  154. onAddress.addCallback(finish)
  155. return resolution
  156. @implementer(IResolutionReceiver)
  157. class FirstOneWins(object):
  158. """
  159. An L{IResolutionReceiver} which fires a L{Deferred} with its first result.
  160. """
  161. def __init__(self, deferred):
  162. """
  163. @param deferred: The L{Deferred} to fire when the first resolution
  164. result arrives.
  165. """
  166. self._deferred = deferred
  167. self._resolved = False
  168. def resolutionBegan(self, resolution):
  169. """
  170. See L{IResolutionReceiver.resolutionBegan}
  171. @param resolution: See L{IResolutionReceiver.resolutionBegan}
  172. """
  173. self._resolution = resolution
  174. def addressResolved(self, address):
  175. """
  176. See L{IResolutionReceiver.addressResolved}
  177. @param address: See L{IResolutionReceiver.addressResolved}
  178. """
  179. if self._resolved:
  180. return
  181. self._resolved = True
  182. self._deferred.callback(address.host)
  183. def resolutionComplete(self):
  184. """
  185. See L{IResolutionReceiver.resolutionComplete}
  186. """
  187. if self._resolved:
  188. return
  189. self._deferred.errback(DNSLookupError(self._resolution.name))
  190. @implementer(IResolverSimple)
  191. class ComplexResolverSimplifier(object):
  192. """
  193. A converter from L{IHostnameResolver} to L{IResolverSimple}
  194. """
  195. def __init__(self, nameResolver):
  196. """
  197. Create a L{ComplexResolverSimplifier} with an L{IHostnameResolver}.
  198. @param nameResolver: The L{IHostnameResolver} to use.
  199. """
  200. self._nameResolver = nameResolver
  201. def getHostByName(self, name, timeouts=()):
  202. """
  203. See L{IResolverSimple.getHostByName}
  204. @param name: see L{IResolverSimple.getHostByName}
  205. @param timeouts: see L{IResolverSimple.getHostByName}
  206. @return: see L{IResolverSimple.getHostByName}
  207. """
  208. result = Deferred()
  209. self._nameResolver.resolveHostName(FirstOneWins(result), name, 0,
  210. [IPv4Address])
  211. return result