_channel.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492
  1. # Copyright 2019 gRPC authors.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Invocation-side implementation of gRPC Asyncio Python."""
  15. import asyncio
  16. import sys
  17. from typing import Any, Iterable, List, Optional, Sequence
  18. import grpc
  19. from grpc import _common
  20. from grpc import _compression
  21. from grpc import _grpcio_metadata
  22. from grpc._cython import cygrpc
  23. from . import _base_call
  24. from . import _base_channel
  25. from ._call import StreamStreamCall
  26. from ._call import StreamUnaryCall
  27. from ._call import UnaryStreamCall
  28. from ._call import UnaryUnaryCall
  29. from ._interceptor import ClientInterceptor
  30. from ._interceptor import InterceptedStreamStreamCall
  31. from ._interceptor import InterceptedStreamUnaryCall
  32. from ._interceptor import InterceptedUnaryStreamCall
  33. from ._interceptor import InterceptedUnaryUnaryCall
  34. from ._interceptor import StreamStreamClientInterceptor
  35. from ._interceptor import StreamUnaryClientInterceptor
  36. from ._interceptor import UnaryStreamClientInterceptor
  37. from ._interceptor import UnaryUnaryClientInterceptor
  38. from ._metadata import Metadata
  39. from ._typing import ChannelArgumentType
  40. from ._typing import DeserializingFunction
  41. from ._typing import RequestIterableType
  42. from ._typing import SerializingFunction
  43. from ._utils import _timeout_to_deadline
  44. _USER_AGENT = 'grpc-python-asyncio/{}'.format(_grpcio_metadata.__version__)
  45. if sys.version_info[1] < 7:
  46. def _all_tasks() -> Iterable[asyncio.Task]:
  47. return asyncio.Task.all_tasks()
  48. else:
  49. def _all_tasks() -> Iterable[asyncio.Task]:
  50. return asyncio.all_tasks()
  51. def _augment_channel_arguments(base_options: ChannelArgumentType,
  52. compression: Optional[grpc.Compression]):
  53. compression_channel_argument = _compression.create_channel_option(
  54. compression)
  55. user_agent_channel_argument = ((
  56. cygrpc.ChannelArgKey.primary_user_agent_string,
  57. _USER_AGENT,
  58. ),)
  59. return tuple(base_options
  60. ) + compression_channel_argument + user_agent_channel_argument
  61. class _BaseMultiCallable:
  62. """Base class of all multi callable objects.
  63. Handles the initialization logic and stores common attributes.
  64. """
  65. _loop: asyncio.AbstractEventLoop
  66. _channel: cygrpc.AioChannel
  67. _method: bytes
  68. _request_serializer: SerializingFunction
  69. _response_deserializer: DeserializingFunction
  70. _interceptors: Optional[Sequence[ClientInterceptor]]
  71. _references: List[Any]
  72. _loop: asyncio.AbstractEventLoop
  73. # pylint: disable=too-many-arguments
  74. def __init__(
  75. self,
  76. channel: cygrpc.AioChannel,
  77. method: bytes,
  78. request_serializer: SerializingFunction,
  79. response_deserializer: DeserializingFunction,
  80. interceptors: Optional[Sequence[ClientInterceptor]],
  81. references: List[Any],
  82. loop: asyncio.AbstractEventLoop,
  83. ) -> None:
  84. self._loop = loop
  85. self._channel = channel
  86. self._method = method
  87. self._request_serializer = request_serializer
  88. self._response_deserializer = response_deserializer
  89. self._interceptors = interceptors
  90. self._references = references
  91. @staticmethod
  92. def _init_metadata(
  93. metadata: Optional[Metadata] = None,
  94. compression: Optional[grpc.Compression] = None) -> Metadata:
  95. """Based on the provided values for <metadata> or <compression> initialise the final
  96. metadata, as it should be used for the current call.
  97. """
  98. metadata = metadata or Metadata()
  99. if compression:
  100. metadata = Metadata(
  101. *_compression.augment_metadata(metadata, compression))
  102. return metadata
  103. class UnaryUnaryMultiCallable(_BaseMultiCallable,
  104. _base_channel.UnaryUnaryMultiCallable):
  105. def __call__(
  106. self,
  107. request: Any,
  108. *,
  109. timeout: Optional[float] = None,
  110. metadata: Optional[Metadata] = None,
  111. credentials: Optional[grpc.CallCredentials] = None,
  112. wait_for_ready: Optional[bool] = None,
  113. compression: Optional[grpc.Compression] = None
  114. ) -> _base_call.UnaryUnaryCall:
  115. metadata = self._init_metadata(metadata, compression)
  116. if not self._interceptors:
  117. call = UnaryUnaryCall(request, _timeout_to_deadline(timeout),
  118. metadata, credentials, wait_for_ready,
  119. self._channel, self._method,
  120. self._request_serializer,
  121. self._response_deserializer, self._loop)
  122. else:
  123. call = InterceptedUnaryUnaryCall(
  124. self._interceptors, request, timeout, metadata, credentials,
  125. wait_for_ready, self._channel, self._method,
  126. self._request_serializer, self._response_deserializer,
  127. self._loop)
  128. return call
  129. class UnaryStreamMultiCallable(_BaseMultiCallable,
  130. _base_channel.UnaryStreamMultiCallable):
  131. def __call__(
  132. self,
  133. request: Any,
  134. *,
  135. timeout: Optional[float] = None,
  136. metadata: Optional[Metadata] = None,
  137. credentials: Optional[grpc.CallCredentials] = None,
  138. wait_for_ready: Optional[bool] = None,
  139. compression: Optional[grpc.Compression] = None
  140. ) -> _base_call.UnaryStreamCall:
  141. metadata = self._init_metadata(metadata, compression)
  142. deadline = _timeout_to_deadline(timeout)
  143. if not self._interceptors:
  144. call = UnaryStreamCall(request, deadline, metadata, credentials,
  145. wait_for_ready, self._channel, self._method,
  146. self._request_serializer,
  147. self._response_deserializer, self._loop)
  148. else:
  149. call = InterceptedUnaryStreamCall(
  150. self._interceptors, request, deadline, metadata, credentials,
  151. wait_for_ready, self._channel, self._method,
  152. self._request_serializer, self._response_deserializer,
  153. self._loop)
  154. return call
  155. class StreamUnaryMultiCallable(_BaseMultiCallable,
  156. _base_channel.StreamUnaryMultiCallable):
  157. def __call__(
  158. self,
  159. request_iterator: Optional[RequestIterableType] = None,
  160. timeout: Optional[float] = None,
  161. metadata: Optional[Metadata] = None,
  162. credentials: Optional[grpc.CallCredentials] = None,
  163. wait_for_ready: Optional[bool] = None,
  164. compression: Optional[grpc.Compression] = None
  165. ) -> _base_call.StreamUnaryCall:
  166. metadata = self._init_metadata(metadata, compression)
  167. deadline = _timeout_to_deadline(timeout)
  168. if not self._interceptors:
  169. call = StreamUnaryCall(request_iterator, deadline, metadata,
  170. credentials, wait_for_ready, self._channel,
  171. self._method, self._request_serializer,
  172. self._response_deserializer, self._loop)
  173. else:
  174. call = InterceptedStreamUnaryCall(
  175. self._interceptors, request_iterator, deadline, metadata,
  176. credentials, wait_for_ready, self._channel, self._method,
  177. self._request_serializer, self._response_deserializer,
  178. self._loop)
  179. return call
  180. class StreamStreamMultiCallable(_BaseMultiCallable,
  181. _base_channel.StreamStreamMultiCallable):
  182. def __call__(
  183. self,
  184. request_iterator: Optional[RequestIterableType] = None,
  185. timeout: Optional[float] = None,
  186. metadata: Optional[Metadata] = None,
  187. credentials: Optional[grpc.CallCredentials] = None,
  188. wait_for_ready: Optional[bool] = None,
  189. compression: Optional[grpc.Compression] = None
  190. ) -> _base_call.StreamStreamCall:
  191. metadata = self._init_metadata(metadata, compression)
  192. deadline = _timeout_to_deadline(timeout)
  193. if not self._interceptors:
  194. call = StreamStreamCall(request_iterator, deadline, metadata,
  195. credentials, wait_for_ready, self._channel,
  196. self._method, self._request_serializer,
  197. self._response_deserializer, self._loop)
  198. else:
  199. call = InterceptedStreamStreamCall(
  200. self._interceptors, request_iterator, deadline, metadata,
  201. credentials, wait_for_ready, self._channel, self._method,
  202. self._request_serializer, self._response_deserializer,
  203. self._loop)
  204. return call
  205. class Channel(_base_channel.Channel):
  206. _loop: asyncio.AbstractEventLoop
  207. _channel: cygrpc.AioChannel
  208. _unary_unary_interceptors: List[UnaryUnaryClientInterceptor]
  209. _unary_stream_interceptors: List[UnaryStreamClientInterceptor]
  210. _stream_unary_interceptors: List[StreamUnaryClientInterceptor]
  211. _stream_stream_interceptors: List[StreamStreamClientInterceptor]
  212. def __init__(self, target: str, options: ChannelArgumentType,
  213. credentials: Optional[grpc.ChannelCredentials],
  214. compression: Optional[grpc.Compression],
  215. interceptors: Optional[Sequence[ClientInterceptor]]):
  216. """Constructor.
  217. Args:
  218. target: The target to which to connect.
  219. options: Configuration options for the channel.
  220. credentials: A cygrpc.ChannelCredentials or None.
  221. compression: An optional value indicating the compression method to be
  222. used over the lifetime of the channel.
  223. interceptors: An optional list of interceptors that would be used for
  224. intercepting any RPC executed with that channel.
  225. """
  226. self._unary_unary_interceptors = []
  227. self._unary_stream_interceptors = []
  228. self._stream_unary_interceptors = []
  229. self._stream_stream_interceptors = []
  230. if interceptors is not None:
  231. for interceptor in interceptors:
  232. if isinstance(interceptor, UnaryUnaryClientInterceptor):
  233. self._unary_unary_interceptors.append(interceptor)
  234. elif isinstance(interceptor, UnaryStreamClientInterceptor):
  235. self._unary_stream_interceptors.append(interceptor)
  236. elif isinstance(interceptor, StreamUnaryClientInterceptor):
  237. self._stream_unary_interceptors.append(interceptor)
  238. elif isinstance(interceptor, StreamStreamClientInterceptor):
  239. self._stream_stream_interceptors.append(interceptor)
  240. else:
  241. raise ValueError(
  242. "Interceptor {} must be ".format(interceptor) +
  243. "{} or ".format(UnaryUnaryClientInterceptor.__name__) +
  244. "{} or ".format(UnaryStreamClientInterceptor.__name__) +
  245. "{} or ".format(StreamUnaryClientInterceptor.__name__) +
  246. "{}. ".format(StreamStreamClientInterceptor.__name__))
  247. self._loop = cygrpc.get_working_loop()
  248. self._channel = cygrpc.AioChannel(
  249. _common.encode(target),
  250. _augment_channel_arguments(options, compression), credentials,
  251. self._loop)
  252. async def __aenter__(self):
  253. return self
  254. async def __aexit__(self, exc_type, exc_val, exc_tb):
  255. await self._close(None)
  256. async def _close(self, grace): # pylint: disable=too-many-branches
  257. if self._channel.closed():
  258. return
  259. # No new calls will be accepted by the Cython channel.
  260. self._channel.closing()
  261. # Iterate through running tasks
  262. tasks = _all_tasks()
  263. calls = []
  264. call_tasks = []
  265. for task in tasks:
  266. try:
  267. stack = task.get_stack(limit=1)
  268. except AttributeError as attribute_error:
  269. # NOTE(lidiz) tl;dr: If the Task is created with a CPython
  270. # object, it will trigger AttributeError.
  271. #
  272. # In the global finalizer, the event loop schedules
  273. # a CPython PyAsyncGenAThrow object.
  274. # https://github.com/python/cpython/blob/00e45877e33d32bb61aa13a2033e3bba370bda4d/Lib/asyncio/base_events.py#L484
  275. #
  276. # However, the PyAsyncGenAThrow object is written in C and
  277. # failed to include the normal Python frame objects. Hence,
  278. # this exception is a false negative, and it is safe to ignore
  279. # the failure. It is fixed by https://github.com/python/cpython/pull/18669,
  280. # but not available until 3.9 or 3.8.3. So, we have to keep it
  281. # for a while.
  282. # TODO(lidiz) drop this hack after 3.8 deprecation
  283. if 'frame' in str(attribute_error):
  284. continue
  285. else:
  286. raise
  287. # If the Task is created by a C-extension, the stack will be empty.
  288. if not stack:
  289. continue
  290. # Locate ones created by `aio.Call`.
  291. frame = stack[0]
  292. candidate = frame.f_locals.get('self')
  293. if candidate:
  294. if isinstance(candidate, _base_call.Call):
  295. if hasattr(candidate, '_channel'):
  296. # For intercepted Call object
  297. if candidate._channel is not self._channel:
  298. continue
  299. elif hasattr(candidate, '_cython_call'):
  300. # For normal Call object
  301. if candidate._cython_call._channel is not self._channel:
  302. continue
  303. else:
  304. # Unidentified Call object
  305. raise cygrpc.InternalError(
  306. f'Unrecognized call object: {candidate}')
  307. calls.append(candidate)
  308. call_tasks.append(task)
  309. # If needed, try to wait for them to finish.
  310. # Call objects are not always awaitables.
  311. if grace and call_tasks:
  312. await asyncio.wait(call_tasks, timeout=grace)
  313. # Time to cancel existing calls.
  314. for call in calls:
  315. call.cancel()
  316. # Destroy the channel
  317. self._channel.close()
  318. async def close(self, grace: Optional[float] = None):
  319. await self._close(grace)
  320. def __del__(self):
  321. if hasattr(self, '_channel'):
  322. if not self._channel.closed():
  323. self._channel.close()
  324. def get_state(self,
  325. try_to_connect: bool = False) -> grpc.ChannelConnectivity:
  326. result = self._channel.check_connectivity_state(try_to_connect)
  327. return _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[result]
  328. async def wait_for_state_change(
  329. self,
  330. last_observed_state: grpc.ChannelConnectivity,
  331. ) -> None:
  332. assert await self._channel.watch_connectivity_state(
  333. last_observed_state.value[0], None)
  334. async def channel_ready(self) -> None:
  335. state = self.get_state(try_to_connect=True)
  336. while state != grpc.ChannelConnectivity.READY:
  337. await self.wait_for_state_change(state)
  338. state = self.get_state(try_to_connect=True)
  339. def unary_unary(
  340. self,
  341. method: str,
  342. request_serializer: Optional[SerializingFunction] = None,
  343. response_deserializer: Optional[DeserializingFunction] = None
  344. ) -> UnaryUnaryMultiCallable:
  345. return UnaryUnaryMultiCallable(self._channel, _common.encode(method),
  346. request_serializer,
  347. response_deserializer,
  348. self._unary_unary_interceptors, [self],
  349. self._loop)
  350. def unary_stream(
  351. self,
  352. method: str,
  353. request_serializer: Optional[SerializingFunction] = None,
  354. response_deserializer: Optional[DeserializingFunction] = None
  355. ) -> UnaryStreamMultiCallable:
  356. return UnaryStreamMultiCallable(self._channel, _common.encode(method),
  357. request_serializer,
  358. response_deserializer,
  359. self._unary_stream_interceptors, [self],
  360. self._loop)
  361. def stream_unary(
  362. self,
  363. method: str,
  364. request_serializer: Optional[SerializingFunction] = None,
  365. response_deserializer: Optional[DeserializingFunction] = None
  366. ) -> StreamUnaryMultiCallable:
  367. return StreamUnaryMultiCallable(self._channel, _common.encode(method),
  368. request_serializer,
  369. response_deserializer,
  370. self._stream_unary_interceptors, [self],
  371. self._loop)
  372. def stream_stream(
  373. self,
  374. method: str,
  375. request_serializer: Optional[SerializingFunction] = None,
  376. response_deserializer: Optional[DeserializingFunction] = None
  377. ) -> StreamStreamMultiCallable:
  378. return StreamStreamMultiCallable(self._channel, _common.encode(method),
  379. request_serializer,
  380. response_deserializer,
  381. self._stream_stream_interceptors,
  382. [self], self._loop)
  383. def insecure_channel(
  384. target: str,
  385. options: Optional[ChannelArgumentType] = None,
  386. compression: Optional[grpc.Compression] = None,
  387. interceptors: Optional[Sequence[ClientInterceptor]] = None):
  388. """Creates an insecure asynchronous Channel to a server.
  389. Args:
  390. target: The server address
  391. options: An optional list of key-value pairs (:term:`channel_arguments`
  392. in gRPC Core runtime) to configure the channel.
  393. compression: An optional value indicating the compression method to be
  394. used over the lifetime of the channel.
  395. interceptors: An optional sequence of interceptors that will be executed for
  396. any call executed with this channel.
  397. Returns:
  398. A Channel.
  399. """
  400. return Channel(target, () if options is None else options, None,
  401. compression, interceptors)
  402. def secure_channel(target: str,
  403. credentials: grpc.ChannelCredentials,
  404. options: Optional[ChannelArgumentType] = None,
  405. compression: Optional[grpc.Compression] = None,
  406. interceptors: Optional[Sequence[ClientInterceptor]] = None):
  407. """Creates a secure asynchronous Channel to a server.
  408. Args:
  409. target: The server address.
  410. credentials: A ChannelCredentials instance.
  411. options: An optional list of key-value pairs (:term:`channel_arguments`
  412. in gRPC Core runtime) to configure the channel.
  413. compression: An optional value indicating the compression method to be
  414. used over the lifetime of the channel.
  415. interceptors: An optional sequence of interceptors that will be executed for
  416. any call executed with this channel.
  417. Returns:
  418. An aio.Channel.
  419. """
  420. return Channel(target, () if options is None else options,
  421. credentials._credentials, compression, interceptors)