_wrapper.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. # -*- test-case-name: twisted.protocols.haproxy.test.test_wrapper -*-
  2. # Copyright (c) Twisted Matrix Laboratories.
  3. # See LICENSE for details.
  4. """
  5. Protocol wrapper that provides HAProxy PROXY protocol support.
  6. """
  7. from typing import Optional, Union
  8. from twisted.internet import interfaces
  9. from twisted.internet.endpoints import _WrapperServerEndpoint
  10. from twisted.protocols import policies
  11. from . import _info
  12. from ._exceptions import InvalidProxyHeader
  13. from ._v1parser import V1Parser
  14. from ._v2parser import V2Parser
  15. class HAProxyProtocolWrapper(policies.ProtocolWrapper):
  16. """
  17. A Protocol wrapper that provides HAProxy support.
  18. This protocol reads the PROXY stream header, v1 or v2, parses the provided
  19. connection data, and modifies the behavior of getPeer and getHost to return
  20. the data provided by the PROXY header.
  21. """
  22. def __init__(
  23. self, factory: policies.WrappingFactory, wrappedProtocol: interfaces.IProtocol
  24. ):
  25. super().__init__(factory, wrappedProtocol)
  26. self._proxyInfo: Optional[_info.ProxyInfo] = None
  27. self._parser: Union[V2Parser, V1Parser, None] = None
  28. def dataReceived(self, data: bytes) -> None:
  29. if self._proxyInfo is not None:
  30. return self.wrappedProtocol.dataReceived(data)
  31. parser = self._parser
  32. if parser is None:
  33. if (
  34. len(data) >= 16
  35. and data[:12] == V2Parser.PREFIX
  36. and ord(data[12:13]) & 0b11110000 == 0x20
  37. ):
  38. self._parser = parser = V2Parser()
  39. elif len(data) >= 8 and data[:5] == V1Parser.PROXYSTR:
  40. self._parser = parser = V1Parser()
  41. else:
  42. self.loseConnection()
  43. return None
  44. try:
  45. self._proxyInfo, remaining = parser.feed(data)
  46. if remaining:
  47. self.wrappedProtocol.dataReceived(remaining)
  48. except InvalidProxyHeader:
  49. self.loseConnection()
  50. def getPeer(self) -> interfaces.IAddress:
  51. if self._proxyInfo and self._proxyInfo.source:
  52. return self._proxyInfo.source
  53. assert self.transport
  54. return self.transport.getPeer()
  55. def getHost(self) -> interfaces.IAddress:
  56. if self._proxyInfo and self._proxyInfo.destination:
  57. return self._proxyInfo.destination
  58. assert self.transport
  59. return self.transport.getHost()
  60. class HAProxyWrappingFactory(policies.WrappingFactory):
  61. """
  62. A Factory wrapper that adds PROXY protocol support to connections.
  63. """
  64. protocol = HAProxyProtocolWrapper
  65. def logPrefix(self) -> str:
  66. """
  67. Annotate the wrapped factory's log prefix with some text indicating
  68. the PROXY protocol is in use.
  69. @rtype: C{str}
  70. """
  71. if interfaces.ILoggingContext.providedBy(self.wrappedFactory):
  72. logPrefix = self.wrappedFactory.logPrefix()
  73. else:
  74. logPrefix = self.wrappedFactory.__class__.__name__
  75. return f"{logPrefix} (PROXY)"
  76. def proxyEndpoint(
  77. wrappedEndpoint: interfaces.IStreamServerEndpoint,
  78. ) -> _WrapperServerEndpoint:
  79. """
  80. Wrap an endpoint with PROXY protocol support, so that the transport's
  81. C{getHost} and C{getPeer} methods reflect the attributes of the proxied
  82. connection rather than the underlying connection.
  83. @param wrappedEndpoint: The underlying listening endpoint.
  84. @type wrappedEndpoint: L{IStreamServerEndpoint}
  85. @return: a new listening endpoint that speaks the PROXY protocol.
  86. @rtype: L{IStreamServerEndpoint}
  87. """
  88. return _WrapperServerEndpoint(wrappedEndpoint, HAProxyWrappingFactory)