srvconnect.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. # -*- test-case-name: twisted.names.test.test_srvconnect -*-
  2. # Copyright (c) Twisted Matrix Laboratories.
  3. # See LICENSE for details.
  4. from __future__ import absolute_import, division
  5. import random
  6. from zope.interface import implementer
  7. from twisted.internet import error, interfaces
  8. from twisted.names import client, dns
  9. from twisted.names.error import DNSNameError
  10. from twisted.python.compat import nativeString
  11. class _SRVConnector_ClientFactoryWrapper:
  12. def __init__(self, connector, wrappedFactory):
  13. self.__connector = connector
  14. self.__wrappedFactory = wrappedFactory
  15. def startedConnecting(self, connector):
  16. self.__wrappedFactory.startedConnecting(self.__connector)
  17. def clientConnectionFailed(self, connector, reason):
  18. self.__connector.connectionFailed(reason)
  19. def clientConnectionLost(self, connector, reason):
  20. self.__connector.connectionLost(reason)
  21. def __getattr__(self, key):
  22. return getattr(self.__wrappedFactory, key)
  23. @implementer(interfaces.IConnector)
  24. class SRVConnector:
  25. """
  26. A connector that looks up DNS SRV records.
  27. RFC 2782 details how SRV records should be interpreted and selected
  28. for subsequent connection attempts. The algorithm for using the records'
  29. priority and weight is implemented in L{pickServer}.
  30. @ivar servers: List of candidate server records for future connection
  31. attempts.
  32. @type servers: L{list} of L{dns.Record_SRV}
  33. @ivar orderedServers: List of server records that have already been tried
  34. in this round of connection attempts.
  35. @type orderedServers: L{list} of L{dns.Record_SRV}
  36. """
  37. stopAfterDNS = 0
  38. def __init__(self, reactor, service, domain, factory,
  39. protocol='tcp', connectFuncName='connectTCP',
  40. connectFuncArgs=(),
  41. connectFuncKwArgs={},
  42. defaultPort=None,
  43. ):
  44. """
  45. @param domain: The domain to connect to. If passed as a text
  46. string, it will be encoded using C{idna} encoding.
  47. @type domain: L{bytes} or L{str}
  48. @param defaultPort: Optional default port number to be used when SRV
  49. lookup fails and the service name is unknown. This should be the
  50. port number associated with the service name as defined by the IANA
  51. registry.
  52. @type defaultPort: L{int}
  53. """
  54. self.reactor = reactor
  55. self.service = service
  56. self.domain = None if domain is None else dns.domainString(domain)
  57. self.factory = factory
  58. self.protocol = protocol
  59. self.connectFuncName = connectFuncName
  60. self.connectFuncArgs = connectFuncArgs
  61. self.connectFuncKwArgs = connectFuncKwArgs
  62. self._defaultPort = defaultPort
  63. self.connector = None
  64. self.servers = None
  65. # list of servers already used in this round:
  66. self.orderedServers = None
  67. def connect(self):
  68. """Start connection to remote server."""
  69. self.factory.doStart()
  70. self.factory.startedConnecting(self)
  71. if not self.servers:
  72. if self.domain is None:
  73. self.connectionFailed(
  74. error.DNSLookupError("Domain is not defined."),
  75. )
  76. return
  77. d = client.lookupService('_%s._%s.%s' % (
  78. nativeString(self.service),
  79. nativeString(self.protocol),
  80. nativeString(self.domain)),
  81. )
  82. d.addCallbacks(self._cbGotServers, self._ebGotServers)
  83. d.addCallback(lambda x, self=self: self._reallyConnect())
  84. if self._defaultPort:
  85. d.addErrback(self._ebServiceUnknown)
  86. d.addErrback(self.connectionFailed)
  87. elif self.connector is None:
  88. self._reallyConnect()
  89. else:
  90. self.connector.connect()
  91. def _ebGotServers(self, failure):
  92. failure.trap(DNSNameError)
  93. # Some DNS servers reply with NXDOMAIN when in fact there are
  94. # just no SRV records for that domain. Act as if we just got an
  95. # empty response and use fallback.
  96. self.servers = []
  97. self.orderedServers = []
  98. def _cbGotServers(self, result):
  99. answers, auth, add = result
  100. if len(answers) == 1 and answers[0].type == dns.SRV \
  101. and answers[0].payload \
  102. and answers[0].payload.target == dns.Name(b'.'):
  103. # decidedly not available
  104. raise error.DNSLookupError("Service %s not available for domain %s."
  105. % (repr(self.service), repr(self.domain)))
  106. self.servers = []
  107. self.orderedServers = []
  108. for a in answers:
  109. if a.type != dns.SRV or not a.payload:
  110. continue
  111. self.orderedServers.append(a.payload)
  112. def _ebServiceUnknown(self, failure):
  113. """
  114. Connect to the default port when the service name is unknown.
  115. If no SRV records were found, the service name will be passed as the
  116. port. If resolving the name fails with
  117. L{error.ServiceNameUnknownError}, a final attempt is done using the
  118. default port.
  119. """
  120. failure.trap(error.ServiceNameUnknownError)
  121. self.servers = [dns.Record_SRV(0, 0, self._defaultPort, self.domain)]
  122. self.orderedServers = []
  123. self.connect()
  124. def pickServer(self):
  125. """
  126. Pick the next server.
  127. This selects the next server from the list of SRV records according
  128. to their priority and weight values, as set out by the default
  129. algorithm specified in RFC 2782.
  130. At the beginning of a round, L{servers} is populated with
  131. L{orderedServers}, and the latter is made empty. L{servers}
  132. is the list of candidates, and L{orderedServers} is the list of servers
  133. that have already been tried.
  134. First, all records are ordered by priority and weight in ascending
  135. order. Then for each priority level, a running sum is calculated
  136. over the sorted list of records for that priority. Then a random value
  137. between 0 and the final sum is compared to each record in order. The
  138. first record that is greater than or equal to that random value is
  139. chosen and removed from the list of candidates for this round.
  140. @return: A tuple of target hostname and port from the chosen DNS SRV
  141. record.
  142. @rtype: L{tuple} of native L{str} and L{int}
  143. """
  144. assert self.servers is not None
  145. assert self.orderedServers is not None
  146. if not self.servers and not self.orderedServers:
  147. # no SRV record, fall back..
  148. return nativeString(self.domain), self.service
  149. if not self.servers and self.orderedServers:
  150. # start new round
  151. self.servers = self.orderedServers
  152. self.orderedServers = []
  153. assert self.servers
  154. self.servers.sort(key=lambda record: (record.priority, record.weight))
  155. minPriority = self.servers[0].priority
  156. index = 0
  157. weightSum = 0
  158. weightIndex = []
  159. for x in self.servers:
  160. if x.priority == minPriority:
  161. weightSum += x.weight
  162. weightIndex.append((index, weightSum))
  163. index += 1
  164. rand = random.randint(0, weightSum)
  165. for index, weight in weightIndex:
  166. if weight >= rand:
  167. chosen = self.servers[index]
  168. del self.servers[index]
  169. self.orderedServers.append(chosen)
  170. return str(chosen.target), chosen.port
  171. raise RuntimeError(
  172. 'Impossible %s pickServer result.' % (self.__class__.__name__,))
  173. def _reallyConnect(self):
  174. if self.stopAfterDNS:
  175. self.stopAfterDNS = 0
  176. return
  177. self.host, self.port = self.pickServer()
  178. assert self.host is not None, 'Must have a host to connect to.'
  179. assert self.port is not None, 'Must have a port to connect to.'
  180. connectFunc = getattr(self.reactor, self.connectFuncName)
  181. self.connector = connectFunc(
  182. self.host, self.port,
  183. _SRVConnector_ClientFactoryWrapper(self, self.factory),
  184. *self.connectFuncArgs, **self.connectFuncKwArgs)
  185. def stopConnecting(self):
  186. """Stop attempting to connect."""
  187. if self.connector:
  188. self.connector.stopConnecting()
  189. else:
  190. self.stopAfterDNS = 1
  191. def disconnect(self):
  192. """Disconnect whatever our are state is."""
  193. if self.connector is not None:
  194. self.connector.disconnect()
  195. else:
  196. self.stopConnecting()
  197. def getDestination(self):
  198. assert self.connector
  199. return self.connector.getDestination()
  200. def connectionFailed(self, reason):
  201. self.factory.clientConnectionFailed(self, reason)
  202. self.factory.doStop()
  203. def connectionLost(self, reason):
  204. self.factory.clientConnectionLost(self, reason)
  205. self.factory.doStop()