123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492 |
- # Copyright 2019 gRPC authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Invocation-side implementation of gRPC Asyncio Python."""
- import asyncio
- import sys
- from typing import Any, Iterable, List, Optional, Sequence
- import grpc
- from grpc import _common
- from grpc import _compression
- from grpc import _grpcio_metadata
- from grpc._cython import cygrpc
- from . import _base_call
- from . import _base_channel
- from ._call import StreamStreamCall
- from ._call import StreamUnaryCall
- from ._call import UnaryStreamCall
- from ._call import UnaryUnaryCall
- from ._interceptor import ClientInterceptor
- from ._interceptor import InterceptedStreamStreamCall
- from ._interceptor import InterceptedStreamUnaryCall
- from ._interceptor import InterceptedUnaryStreamCall
- from ._interceptor import InterceptedUnaryUnaryCall
- from ._interceptor import StreamStreamClientInterceptor
- from ._interceptor import StreamUnaryClientInterceptor
- from ._interceptor import UnaryStreamClientInterceptor
- from ._interceptor import UnaryUnaryClientInterceptor
- from ._metadata import Metadata
- from ._typing import ChannelArgumentType
- from ._typing import DeserializingFunction
- from ._typing import RequestIterableType
- from ._typing import SerializingFunction
- from ._utils import _timeout_to_deadline
# Value used for the primary user-agent channel argument; the suffix is the
# version string exposed by `_grpcio_metadata`.
_USER_AGENT = f'grpc-python-asyncio/{_grpcio_metadata.__version__}'
# Compare the whole (major, minor) tuple: looking at the minor version alone
# would misclassify any interpreter whose major version is not 3.
if sys.version_info < (3, 7):

    def _all_tasks() -> Iterable[asyncio.Task]:
        """Return asyncio's tasks through the pre-3.7 `Task.all_tasks` API."""
        return asyncio.Task.all_tasks()

else:

    def _all_tasks() -> Iterable[asyncio.Task]:
        """Return asyncio's tasks through `asyncio.all_tasks` (3.7+)."""
        return asyncio.all_tasks()
def _augment_channel_arguments(base_options: ChannelArgumentType,
                               compression: Optional[grpc.Compression]):
    """Extend the caller's channel options with library-provided entries.

    Args:
      base_options: Channel arguments supplied by the caller.
      compression: Optional compression setting, handed to
        `_compression.create_channel_option`.

    Returns:
      A tuple made of `base_options`, followed by the compression channel
      option, followed by the primary user-agent entry.
    """
    compression_option = _compression.create_channel_option(compression)
    user_agent_entry = (
        cygrpc.ChannelArgKey.primary_user_agent_string,
        _USER_AGENT,
    )
    return tuple(base_options) + compression_option + (user_agent_entry,)
class _BaseMultiCallable:
    """Base class of all multi callable objects.

    Handles the initialization logic and stores common attributes.
    """
    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    _method: bytes
    _request_serializer: SerializingFunction
    _response_deserializer: DeserializingFunction
    _interceptors: Optional[Sequence[ClientInterceptor]]
    _references: List[Any]

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        interceptors: Optional[Sequence[ClientInterceptor]],
        references: List[Any],
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        """Store the collaborators shared by every multi callable.

        Args:
          channel: The Cython channel that calls are created on.
          method: The encoded RPC method name.
          request_serializer: Function handed to call objects for requests.
          response_deserializer: Function handed to call objects for
            responses.
          interceptors: Interceptors for this kind of RPC; when falsy the
            subclasses create a plain (non-intercepted) call.
          references: Objects stored on the instance; this class never reads
            them back (presumably kept to hold them alive -- confirm with
            callers).
          loop: The event loop passed on to call objects.
        """
        self._loop = loop
        self._channel = channel
        self._method = method
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer
        self._interceptors = interceptors
        self._references = references

    @staticmethod
    def _init_metadata(
            metadata: Optional[Metadata] = None,
            compression: Optional[grpc.Compression] = None) -> Metadata:
        """Build the final metadata to be used for the current call.

        Args:
          metadata: Caller-supplied metadata; any falsy value (including
            None) is replaced by a fresh, empty `Metadata`.
          compression: Optional per-call compression. When truthy, the
            metadata is rebuilt from the result of
            `_compression.augment_metadata(metadata, compression)`.

        Returns:
          The `Metadata` instance to pass to the call object.
        """
        metadata = metadata or Metadata()
        if compression:
            metadata = Metadata(
                *_compression.augment_metadata(metadata, compression))
        return metadata
class UnaryUnaryMultiCallable(_BaseMultiCallable,
                              _base_channel.UnaryUnaryMultiCallable):
    """Builds call objects for unary-request / unary-response RPCs."""

    def __call__(
        self,
        request: Any,
        *,
        timeout: Optional[float] = None,
        metadata: Optional[Metadata] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None
    ) -> _base_call.UnaryUnaryCall:
        """Create the call object for a single invocation.

        Args:
          request: The request value handed to the call object.
          timeout: Optional timeout; see the branch comments for how it is
            forwarded.
          metadata: Optional metadata, normalised through `_init_metadata`.
          credentials: Forwarded unchanged to the call object.
          wait_for_ready: Forwarded unchanged to the call object.
          compression: Optional compression, folded into the metadata by
            `_init_metadata`.

        Returns:
          An `InterceptedUnaryUnaryCall` when interceptors are configured,
          otherwise a `UnaryUnaryCall`.
        """
        call_metadata = self._init_metadata(metadata, compression)

        if self._interceptors:
            # NOTE(review): this branch forwards the raw `timeout`, not a
            # deadline, unlike the non-intercepted branch below -- confirm
            # this matches InterceptedUnaryUnaryCall's signature.
            return InterceptedUnaryUnaryCall(
                self._interceptors, request, timeout, call_metadata,
                credentials, wait_for_ready, self._channel, self._method,
                self._request_serializer, self._response_deserializer,
                self._loop)

        return UnaryUnaryCall(request, _timeout_to_deadline(timeout),
                              call_metadata, credentials, wait_for_ready,
                              self._channel, self._method,
                              self._request_serializer,
                              self._response_deserializer, self._loop)
class UnaryStreamMultiCallable(_BaseMultiCallable,
                               _base_channel.UnaryStreamMultiCallable):
    """Builds call objects for unary-request / streaming-response RPCs."""

    def __call__(
        self,
        request: Any,
        *,
        timeout: Optional[float] = None,
        metadata: Optional[Metadata] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None
    ) -> _base_call.UnaryStreamCall:
        """Create the call object for a single invocation.

        Args:
          request: The request value handed to the call object.
          timeout: Optional timeout, converted with `_timeout_to_deadline`.
          metadata: Optional metadata, normalised through `_init_metadata`.
          credentials: Forwarded unchanged to the call object.
          wait_for_ready: Forwarded unchanged to the call object.
          compression: Optional compression, folded into the metadata by
            `_init_metadata`.

        Returns:
          An `InterceptedUnaryStreamCall` when interceptors are configured,
          otherwise a `UnaryStreamCall`.
        """
        call_metadata = self._init_metadata(metadata, compression)
        call_deadline = _timeout_to_deadline(timeout)

        if self._interceptors:
            return InterceptedUnaryStreamCall(
                self._interceptors, request, call_deadline, call_metadata,
                credentials, wait_for_ready, self._channel, self._method,
                self._request_serializer, self._response_deserializer,
                self._loop)

        return UnaryStreamCall(request, call_deadline, call_metadata,
                               credentials, wait_for_ready, self._channel,
                               self._method, self._request_serializer,
                               self._response_deserializer, self._loop)
class StreamUnaryMultiCallable(_BaseMultiCallable,
                               _base_channel.StreamUnaryMultiCallable):
    """Builds call objects for streaming-request / unary-response RPCs."""

    def __call__(
        self,
        request_iterator: Optional[RequestIterableType] = None,
        timeout: Optional[float] = None,
        metadata: Optional[Metadata] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None
    ) -> _base_call.StreamUnaryCall:
        """Create the call object for a single invocation.

        Args:
          request_iterator: Optional iterable of requests, handed to the
            call object as-is.
          timeout: Optional timeout, converted with `_timeout_to_deadline`.
          metadata: Optional metadata, normalised through `_init_metadata`.
          credentials: Forwarded unchanged to the call object.
          wait_for_ready: Forwarded unchanged to the call object.
          compression: Optional compression, folded into the metadata by
            `_init_metadata`.

        Returns:
          An `InterceptedStreamUnaryCall` when interceptors are configured,
          otherwise a `StreamUnaryCall`.
        """
        call_metadata = self._init_metadata(metadata, compression)
        call_deadline = _timeout_to_deadline(timeout)

        if self._interceptors:
            return InterceptedStreamUnaryCall(
                self._interceptors, request_iterator, call_deadline,
                call_metadata, credentials, wait_for_ready, self._channel,
                self._method, self._request_serializer,
                self._response_deserializer, self._loop)

        return StreamUnaryCall(request_iterator, call_deadline, call_metadata,
                               credentials, wait_for_ready, self._channel,
                               self._method, self._request_serializer,
                               self._response_deserializer, self._loop)
class StreamStreamMultiCallable(_BaseMultiCallable,
                                _base_channel.StreamStreamMultiCallable):
    """Builds call objects for streaming-request / streaming-response RPCs."""

    def __call__(
        self,
        request_iterator: Optional[RequestIterableType] = None,
        timeout: Optional[float] = None,
        metadata: Optional[Metadata] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None
    ) -> _base_call.StreamStreamCall:
        """Create the call object for a single invocation.

        Args:
          request_iterator: Optional iterable of requests, handed to the
            call object as-is.
          timeout: Optional timeout, converted with `_timeout_to_deadline`.
          metadata: Optional metadata, normalised through `_init_metadata`.
          credentials: Forwarded unchanged to the call object.
          wait_for_ready: Forwarded unchanged to the call object.
          compression: Optional compression, folded into the metadata by
            `_init_metadata`.

        Returns:
          An `InterceptedStreamStreamCall` when interceptors are configured,
          otherwise a `StreamStreamCall`.
        """
        call_metadata = self._init_metadata(metadata, compression)
        call_deadline = _timeout_to_deadline(timeout)

        if self._interceptors:
            return InterceptedStreamStreamCall(
                self._interceptors, request_iterator, call_deadline,
                call_metadata, credentials, wait_for_ready, self._channel,
                self._method, self._request_serializer,
                self._response_deserializer, self._loop)

        return StreamStreamCall(request_iterator, call_deadline,
                                call_metadata, credentials, wait_for_ready,
                                self._channel, self._method,
                                self._request_serializer,
                                self._response_deserializer, self._loop)
class Channel(_base_channel.Channel):
    """Asyncio channel backed by a `cygrpc.AioChannel`.

    Interceptors passed to the constructor are sorted by RPC kind and handed
    to the multi callables created by `unary_unary`, `unary_stream`,
    `stream_unary` and `stream_stream`.
    """
    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    _unary_unary_interceptors: List[UnaryUnaryClientInterceptor]
    _unary_stream_interceptors: List[UnaryStreamClientInterceptor]
    _stream_unary_interceptors: List[StreamUnaryClientInterceptor]
    _stream_stream_interceptors: List[StreamStreamClientInterceptor]

    def __init__(self, target: str, options: ChannelArgumentType,
                 credentials: Optional[grpc.ChannelCredentials],
                 compression: Optional[grpc.Compression],
                 interceptors: Optional[Sequence[ClientInterceptor]]):
        """Constructor.

        Args:
          target: The target to which to connect.
          options: Configuration options for the channel.
          credentials: A cygrpc.ChannelCredentials or None.
          compression: An optional value indicating the compression method to be
            used over the lifetime of the channel.
          interceptors: An optional list of interceptors that would be used for
            intercepting any RPC executed with that channel.

        Raises:
          ValueError: If an interceptor is not an instance of one of the four
            supported client interceptor classes.
        """
        self._unary_unary_interceptors = []
        self._unary_stream_interceptors = []
        self._stream_unary_interceptors = []
        self._stream_stream_interceptors = []

        if interceptors is not None:
            for interceptor in interceptors:
                # The first matching class wins, so the order of this chain
                # matters for interceptors implementing several interfaces.
                if isinstance(interceptor, UnaryUnaryClientInterceptor):
                    self._unary_unary_interceptors.append(interceptor)
                elif isinstance(interceptor, UnaryStreamClientInterceptor):
                    self._unary_stream_interceptors.append(interceptor)
                elif isinstance(interceptor, StreamUnaryClientInterceptor):
                    self._stream_unary_interceptors.append(interceptor)
                elif isinstance(interceptor, StreamStreamClientInterceptor):
                    self._stream_stream_interceptors.append(interceptor)
                else:
                    raise ValueError(
                        f"Interceptor {interceptor} must be "
                        f"{UnaryUnaryClientInterceptor.__name__} or "
                        f"{UnaryStreamClientInterceptor.__name__} or "
                        f"{StreamUnaryClientInterceptor.__name__} or "
                        f"{StreamStreamClientInterceptor.__name__}. ")

        self._loop = cygrpc.get_working_loop()
        self._channel = cygrpc.AioChannel(
            _common.encode(target),
            _augment_channel_arguments(options, compression), credentials,
            self._loop)

    async def __aenter__(self):
        """Enter the async context manager and return this channel."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the channel without a grace period on context exit."""
        await self._close(None)

    async def _close(self, grace: Optional[float]) -> None:  # pylint: disable=too-many-branches
        """Cancel this channel's in-flight calls and close the Cython channel.

        Args:
          grace: When truthy and matching call tasks exist, the time in
            seconds given to `asyncio.wait` for those tasks before the
            calls are cancelled.

        Raises:
          cygrpc.InternalError: If a task belongs to a `_base_call.Call`
            that exposes neither `_channel` nor `_cython_call`.
        """
        if self._channel.closed():
            return

        # No new calls will be accepted by the Cython channel.
        self._channel.closing()

        # Iterate through running tasks
        tasks = _all_tasks()
        calls = []
        call_tasks = []
        for task in tasks:
            try:
                stack = task.get_stack(limit=1)
            except AttributeError as attribute_error:
                # NOTE(lidiz) tl;dr: If the Task is created with a CPython
                # object, it will trigger AttributeError.
                #
                # In the global finalizer, the event loop schedules
                # a CPython PyAsyncGenAThrow object.
                # https://github.com/python/cpython/blob/00e45877e33d32bb61aa13a2033e3bba370bda4d/Lib/asyncio/base_events.py#L484
                #
                # However, the PyAsyncGenAThrow object is written in C and
                # failed to include the normal Python frame objects. Hence,
                # this exception is a false alarm, and it is safe to ignore
                # the failure. It is fixed by https://github.com/python/cpython/pull/18669,
                # but not available until 3.9 or 3.8.3. So, we have to keep it
                # for a while.
                # TODO(lidiz) drop this hack after 3.8 deprecation
                if 'frame' in str(attribute_error):
                    continue
                raise

            # If the Task is created by a C-extension, the stack will be empty.
            if not stack:
                continue

            # Locate ones created by `aio.Call`.
            candidate = stack[0].f_locals.get('self')
            if not candidate or not isinstance(candidate, _base_call.Call):
                continue

            if hasattr(candidate, '_channel'):
                # For intercepted Call object
                if candidate._channel is not self._channel:
                    continue
            elif hasattr(candidate, '_cython_call'):
                # For normal Call object
                if candidate._cython_call._channel is not self._channel:
                    continue
            else:
                # Unidentified Call object
                raise cygrpc.InternalError(
                    f'Unrecognized call object: {candidate}')

            calls.append(candidate)
            call_tasks.append(task)

        # If needed, try to wait for them to finish.
        # Call objects are not always awaitables.
        if grace and call_tasks:
            await asyncio.wait(call_tasks, timeout=grace)

        # Time to cancel existing calls.
        for call in calls:
            call.cancel()

        # Destroy the channel
        self._channel.close()

    async def close(self, grace: Optional[float] = None):
        """Close the channel, optionally waiting `grace` seconds for calls."""
        await self._close(grace)

    def __del__(self):
        """Close the Cython channel if it exists and is still open."""
        # `_channel` is missing when `__init__` raised before creating it.
        if hasattr(self, '_channel') and not self._channel.closed():
            self._channel.close()

    def get_state(self,
                  try_to_connect: bool = False) -> grpc.ChannelConnectivity:
        """Return the current connectivity state of the channel.

        Args:
          try_to_connect: Forwarded to the Cython channel's
            `check_connectivity_state`.
        """
        result = self._channel.check_connectivity_state(try_to_connect)
        return _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[result]

    async def wait_for_state_change(
        self,
        last_observed_state: grpc.ChannelConnectivity,
    ) -> None:
        """Wait until the channel leaves `last_observed_state`.

        Raises:
          AssertionError: If the Cython channel reports a falsy result for
            the watch (only when assertions are enabled).
        """
        # Await outside the assert: `python -O` strips assert statements
        # together with the expression they evaluate, which would silently
        # skip the wait and turn `channel_ready` into a busy loop.
        state_changed = await self._channel.watch_connectivity_state(
            last_observed_state.value[0], None)
        assert state_changed

    async def channel_ready(self) -> None:
        """Return once the channel's state is READY."""
        state = self.get_state(try_to_connect=True)
        while state != grpc.ChannelConnectivity.READY:
            await self.wait_for_state_change(state)
            state = self.get_state(try_to_connect=True)

    def unary_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None
    ) -> UnaryUnaryMultiCallable:
        """Create a unary-unary multi callable for `method`."""
        return UnaryUnaryMultiCallable(self._channel, _common.encode(method),
                                       request_serializer,
                                       response_deserializer,
                                       self._unary_unary_interceptors, [self],
                                       self._loop)

    def unary_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None
    ) -> UnaryStreamMultiCallable:
        """Create a unary-stream multi callable for `method`."""
        return UnaryStreamMultiCallable(self._channel, _common.encode(method),
                                        request_serializer,
                                        response_deserializer,
                                        self._unary_stream_interceptors, [self],
                                        self._loop)

    def stream_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None
    ) -> StreamUnaryMultiCallable:
        """Create a stream-unary multi callable for `method`."""
        return StreamUnaryMultiCallable(self._channel, _common.encode(method),
                                        request_serializer,
                                        response_deserializer,
                                        self._stream_unary_interceptors, [self],
                                        self._loop)

    def stream_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None
    ) -> StreamStreamMultiCallable:
        """Create a stream-stream multi callable for `method`."""
        return StreamStreamMultiCallable(self._channel, _common.encode(method),
                                         request_serializer,
                                         response_deserializer,
                                         self._stream_stream_interceptors,
                                         [self], self._loop)
def insecure_channel(
        target: str,
        options: Optional[ChannelArgumentType] = None,
        compression: Optional[grpc.Compression] = None,
        interceptors: Optional[Sequence[ClientInterceptor]] = None):
    """Build an asynchronous Channel that uses no channel credentials.

    Args:
      target: The server address
      options: An optional list of key-value pairs (:term:`channel_arguments`
        in gRPC Core runtime) to configure the channel. None is treated as
        an empty tuple.
      compression: An optional value indicating the compression method to be
        used over the lifetime of the channel.
      interceptors: An optional sequence of interceptors that will be executed
        for any call executed with this channel.

    Returns:
      A Channel.
    """
    channel_options = options if options is not None else ()
    return Channel(target, channel_options, None, compression, interceptors)
def secure_channel(target: str,
                   credentials: grpc.ChannelCredentials,
                   options: Optional[ChannelArgumentType] = None,
                   compression: Optional[grpc.Compression] = None,
                   interceptors: Optional[Sequence[ClientInterceptor]] = None):
    """Build an asynchronous Channel that uses the given credentials.

    Args:
      target: The server address.
      credentials: A ChannelCredentials instance; its `_credentials`
        attribute is what gets passed to the Channel.
      options: An optional list of key-value pairs (:term:`channel_arguments`
        in gRPC Core runtime) to configure the channel. None is treated as
        an empty tuple.
      compression: An optional value indicating the compression method to be
        used over the lifetime of the channel.
      interceptors: An optional sequence of interceptors that will be executed
        for any call executed with this channel.

    Returns:
      An aio.Channel.
    """
    channel_options = options if options is not None else ()
    return Channel(target, channel_options, credentials._credentials,
                   compression, interceptors)
|