123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109 |
- # -*- test-case-name: twisted.protocols.haproxy.test.test_wrapper -*-
- # Copyright (c) Twisted Matrix Laboratories.
- # See LICENSE for details.
- """
- Protocol wrapper that provides HAProxy PROXY protocol support.
- """
- from typing import Optional, Union
- from twisted.internet import interfaces
- from twisted.internet.endpoints import _WrapperServerEndpoint
- from twisted.protocols import policies
- from . import _info
- from ._exceptions import InvalidProxyHeader
- from ._v1parser import V1Parser
- from ._v2parser import V2Parser
- class HAProxyProtocolWrapper(policies.ProtocolWrapper):
- """
- A Protocol wrapper that provides HAProxy support.
- This protocol reads the PROXY stream header, v1 or v2, parses the provided
- connection data, and modifies the behavior of getPeer and getHost to return
- the data provided by the PROXY header.
- """
- def __init__(
- self, factory: policies.WrappingFactory, wrappedProtocol: interfaces.IProtocol
- ):
- super().__init__(factory, wrappedProtocol)
- self._proxyInfo: Optional[_info.ProxyInfo] = None
- self._parser: Union[V2Parser, V1Parser, None] = None
- def dataReceived(self, data: bytes) -> None:
- if self._proxyInfo is not None:
- return self.wrappedProtocol.dataReceived(data)
- parser = self._parser
- if parser is None:
- if (
- len(data) >= 16
- and data[:12] == V2Parser.PREFIX
- and ord(data[12:13]) & 0b11110000 == 0x20
- ):
- self._parser = parser = V2Parser()
- elif len(data) >= 8 and data[:5] == V1Parser.PROXYSTR:
- self._parser = parser = V1Parser()
- else:
- self.loseConnection()
- return None
- try:
- self._proxyInfo, remaining = parser.feed(data)
- if remaining:
- self.wrappedProtocol.dataReceived(remaining)
- except InvalidProxyHeader:
- self.loseConnection()
- def getPeer(self) -> interfaces.IAddress:
- if self._proxyInfo and self._proxyInfo.source:
- return self._proxyInfo.source
- assert self.transport
- return self.transport.getPeer()
- def getHost(self) -> interfaces.IAddress:
- if self._proxyInfo and self._proxyInfo.destination:
- return self._proxyInfo.destination
- assert self.transport
- return self.transport.getHost()
- class HAProxyWrappingFactory(policies.WrappingFactory):
- """
- A Factory wrapper that adds PROXY protocol support to connections.
- """
- protocol = HAProxyProtocolWrapper
- def logPrefix(self) -> str:
- """
- Annotate the wrapped factory's log prefix with some text indicating
- the PROXY protocol is in use.
- @rtype: C{str}
- """
- if interfaces.ILoggingContext.providedBy(self.wrappedFactory):
- logPrefix = self.wrappedFactory.logPrefix()
- else:
- logPrefix = self.wrappedFactory.__class__.__name__
- return f"{logPrefix} (PROXY)"
- def proxyEndpoint(
- wrappedEndpoint: interfaces.IStreamServerEndpoint,
- ) -> _WrapperServerEndpoint:
- """
- Wrap an endpoint with PROXY protocol support, so that the transport's
- C{getHost} and C{getPeer} methods reflect the attributes of the proxied
- connection rather than the underlying connection.
- @param wrappedEndpoint: The underlying listening endpoint.
- @type wrappedEndpoint: L{IStreamServerEndpoint}
- @return: a new listening endpoint that speaks the PROXY protocol.
- @rtype: L{IStreamServerEndpoint}
- """
- return _WrapperServerEndpoint(wrappedEndpoint, HAProxyWrappingFactory)
|