|
@@ -1,12 +1,14 @@
|
|
|
+from __future__ import annotations
|
|
|
+
|
|
|
import ssl
|
|
|
import sys
|
|
|
-from types import TracebackType
|
|
|
-from typing import AsyncIterable, AsyncIterator, Iterable, List, Optional, Type
|
|
|
+import types
|
|
|
+import typing
|
|
|
|
|
|
from .._backends.auto import AutoBackend
|
|
|
from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend
|
|
|
from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol
|
|
|
-from .._models import Origin, Request, Response
|
|
|
+from .._models import Origin, Proxy, Request, Response
|
|
|
from .._synchronization import AsyncEvent, AsyncShieldCancellation, AsyncThreadLock
|
|
|
from .connection import AsyncHTTPConnection
|
|
|
from .interfaces import AsyncConnectionInterface, AsyncRequestInterface
|
|
@@ -15,12 +17,10 @@ from .interfaces import AsyncConnectionInterface, AsyncRequestInterface
|
|
|
class AsyncPoolRequest:
|
|
|
def __init__(self, request: Request) -> None:
|
|
|
self.request = request
|
|
|
- self.connection: Optional[AsyncConnectionInterface] = None
|
|
|
+ self.connection: AsyncConnectionInterface | None = None
|
|
|
self._connection_acquired = AsyncEvent()
|
|
|
|
|
|
- def assign_to_connection(
|
|
|
- self, connection: Optional[AsyncConnectionInterface]
|
|
|
- ) -> None:
|
|
|
+ def assign_to_connection(self, connection: AsyncConnectionInterface | None) -> None:
|
|
|
self.connection = connection
|
|
|
self._connection_acquired.set()
|
|
|
|
|
@@ -29,7 +29,7 @@ class AsyncPoolRequest:
|
|
|
self._connection_acquired = AsyncEvent()
|
|
|
|
|
|
async def wait_for_connection(
|
|
|
- self, timeout: Optional[float] = None
|
|
|
+ self, timeout: float | None = None
|
|
|
) -> AsyncConnectionInterface:
|
|
|
if self.connection is None:
|
|
|
await self._connection_acquired.wait(timeout=timeout)
|
|
@@ -47,17 +47,18 @@ class AsyncConnectionPool(AsyncRequestInterface):
|
|
|
|
|
|
def __init__(
|
|
|
self,
|
|
|
- ssl_context: Optional[ssl.SSLContext] = None,
|
|
|
- max_connections: Optional[int] = 10,
|
|
|
- max_keepalive_connections: Optional[int] = None,
|
|
|
- keepalive_expiry: Optional[float] = None,
|
|
|
+ ssl_context: ssl.SSLContext | None = None,
|
|
|
+ proxy: Proxy | None = None,
|
|
|
+ max_connections: int | None = 10,
|
|
|
+ max_keepalive_connections: int | None = None,
|
|
|
+ keepalive_expiry: float | None = None,
|
|
|
http1: bool = True,
|
|
|
http2: bool = False,
|
|
|
retries: int = 0,
|
|
|
- local_address: Optional[str] = None,
|
|
|
- uds: Optional[str] = None,
|
|
|
- network_backend: Optional[AsyncNetworkBackend] = None,
|
|
|
- socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
|
|
|
+ local_address: str | None = None,
|
|
|
+ uds: str | None = None,
|
|
|
+ network_backend: AsyncNetworkBackend | None = None,
|
|
|
+ socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
|
|
|
) -> None:
|
|
|
"""
|
|
|
A connection pool for making HTTP requests.
|
|
@@ -89,7 +90,7 @@ class AsyncConnectionPool(AsyncRequestInterface):
|
|
|
in the TCP socket when the connection was established.
|
|
|
"""
|
|
|
self._ssl_context = ssl_context
|
|
|
-
|
|
|
+ self._proxy = proxy
|
|
|
self._max_connections = (
|
|
|
sys.maxsize if max_connections is None else max_connections
|
|
|
)
|
|
@@ -116,8 +117,8 @@ class AsyncConnectionPool(AsyncRequestInterface):
|
|
|
|
|
|
# The mutable state on a connection pool is the queue of incoming requests,
|
|
|
# and the set of connections that are servicing those requests.
|
|
|
- self._connections: List[AsyncConnectionInterface] = []
|
|
|
- self._requests: List[AsyncPoolRequest] = []
|
|
|
+ self._connections: list[AsyncConnectionInterface] = []
|
|
|
+ self._requests: list[AsyncPoolRequest] = []
|
|
|
|
|
|
# We only mutate the state of the connection pool within an 'optional_thread_lock'
|
|
|
# context. This holds a threading lock unless we're running in async mode,
|
|
@@ -125,6 +126,45 @@ class AsyncConnectionPool(AsyncRequestInterface):
|
|
|
self._optional_thread_lock = AsyncThreadLock()
|
|
|
|
|
|
def create_connection(self, origin: Origin) -> AsyncConnectionInterface:
|
|
|
+ if self._proxy is not None:
|
|
|
+ if self._proxy.url.scheme in (b"socks5", b"socks5h"):
|
|
|
+ from .socks_proxy import AsyncSocks5Connection
|
|
|
+
|
|
|
+ return AsyncSocks5Connection(
|
|
|
+ proxy_origin=self._proxy.url.origin,
|
|
|
+ proxy_auth=self._proxy.auth,
|
|
|
+ remote_origin=origin,
|
|
|
+ ssl_context=self._ssl_context,
|
|
|
+ keepalive_expiry=self._keepalive_expiry,
|
|
|
+ http1=self._http1,
|
|
|
+ http2=self._http2,
|
|
|
+ network_backend=self._network_backend,
|
|
|
+ )
|
|
|
+ elif origin.scheme == b"http":
|
|
|
+ from .http_proxy import AsyncForwardHTTPConnection
|
|
|
+
|
|
|
+ return AsyncForwardHTTPConnection(
|
|
|
+ proxy_origin=self._proxy.url.origin,
|
|
|
+ proxy_headers=self._proxy.headers,
|
|
|
+ proxy_ssl_context=self._proxy.ssl_context,
|
|
|
+ remote_origin=origin,
|
|
|
+ keepalive_expiry=self._keepalive_expiry,
|
|
|
+ network_backend=self._network_backend,
|
|
|
+ )
|
|
|
+ from .http_proxy import AsyncTunnelHTTPConnection
|
|
|
+
|
|
|
+ return AsyncTunnelHTTPConnection(
|
|
|
+ proxy_origin=self._proxy.url.origin,
|
|
|
+ proxy_headers=self._proxy.headers,
|
|
|
+ proxy_ssl_context=self._proxy.ssl_context,
|
|
|
+ remote_origin=origin,
|
|
|
+ ssl_context=self._ssl_context,
|
|
|
+ keepalive_expiry=self._keepalive_expiry,
|
|
|
+ http1=self._http1,
|
|
|
+ http2=self._http2,
|
|
|
+ network_backend=self._network_backend,
|
|
|
+ )
|
|
|
+
|
|
|
return AsyncHTTPConnection(
|
|
|
origin=origin,
|
|
|
ssl_context=self._ssl_context,
|
|
@@ -139,7 +179,7 @@ class AsyncConnectionPool(AsyncRequestInterface):
|
|
|
)
|
|
|
|
|
|
@property
|
|
|
- def connections(self) -> List[AsyncConnectionInterface]:
|
|
|
+ def connections(self) -> list[AsyncConnectionInterface]:
|
|
|
"""
|
|
|
Return a list of the connections currently in the pool.
|
|
|
|
|
@@ -217,7 +257,7 @@ class AsyncConnectionPool(AsyncRequestInterface):
|
|
|
|
|
|
# Return the response. Note that in this case we still have to manage
|
|
|
# the point at which the response is closed.
|
|
|
- assert isinstance(response.stream, AsyncIterable)
|
|
|
+ assert isinstance(response.stream, typing.AsyncIterable)
|
|
|
return Response(
|
|
|
status=response.status,
|
|
|
headers=response.headers,
|
|
@@ -227,7 +267,7 @@ class AsyncConnectionPool(AsyncRequestInterface):
|
|
|
extensions=response.extensions,
|
|
|
)
|
|
|
|
|
|
- def _assign_requests_to_connections(self) -> List[AsyncConnectionInterface]:
|
|
|
+ def _assign_requests_to_connections(self) -> list[AsyncConnectionInterface]:
|
|
|
"""
|
|
|
Manage the state of the connection pool, assigning incoming
|
|
|
requests to connections as available.
|
|
@@ -298,7 +338,7 @@ class AsyncConnectionPool(AsyncRequestInterface):
|
|
|
|
|
|
return closing_connections
|
|
|
|
|
|
- async def _close_connections(self, closing: List[AsyncConnectionInterface]) -> None:
|
|
|
+ async def _close_connections(self, closing: list[AsyncConnectionInterface]) -> None:
|
|
|
# Close connections which have been removed from the pool.
|
|
|
with AsyncShieldCancellation():
|
|
|
for connection in closing:
|
|
@@ -312,14 +352,14 @@ class AsyncConnectionPool(AsyncRequestInterface):
|
|
|
self._connections = []
|
|
|
await self._close_connections(closing_connections)
|
|
|
|
|
|
- async def __aenter__(self) -> "AsyncConnectionPool":
|
|
|
+ async def __aenter__(self) -> AsyncConnectionPool:
|
|
|
return self
|
|
|
|
|
|
async def __aexit__(
|
|
|
self,
|
|
|
- exc_type: Optional[Type[BaseException]] = None,
|
|
|
- exc_value: Optional[BaseException] = None,
|
|
|
- traceback: Optional[TracebackType] = None,
|
|
|
+ exc_type: type[BaseException] | None = None,
|
|
|
+ exc_value: BaseException | None = None,
|
|
|
+ traceback: types.TracebackType | None = None,
|
|
|
) -> None:
|
|
|
await self.aclose()
|
|
|
|
|
@@ -349,7 +389,7 @@ class AsyncConnectionPool(AsyncRequestInterface):
|
|
|
class PoolByteStream:
|
|
|
def __init__(
|
|
|
self,
|
|
|
- stream: AsyncIterable[bytes],
|
|
|
+ stream: typing.AsyncIterable[bytes],
|
|
|
pool_request: AsyncPoolRequest,
|
|
|
pool: AsyncConnectionPool,
|
|
|
) -> None:
|
|
@@ -358,7 +398,7 @@ class PoolByteStream:
|
|
|
self._pool = pool
|
|
|
self._closed = False
|
|
|
|
|
|
- async def __aiter__(self) -> AsyncIterator[bytes]:
|
|
|
+ async def __aiter__(self) -> typing.AsyncIterator[bytes]:
|
|
|
try:
|
|
|
async for part in self._stream:
|
|
|
yield part
|