_interceptor.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638
  1. # Copyright 2017 gRPC authors.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Implementation of gRPC Python interceptors."""
  15. import collections
  16. import sys
  17. import types
  18. from typing import Any, Callable, Optional, Sequence, Tuple, Union
  19. import grpc
  20. from ._typing import DeserializingFunction
  21. from ._typing import DoneCallbackType
  22. from ._typing import MetadataType
  23. from ._typing import RequestIterableType
  24. from ._typing import SerializingFunction
  25. class _ServicePipeline(object):
  26. interceptors: Tuple[grpc.ServerInterceptor]
  27. def __init__(self, interceptors: Sequence[grpc.ServerInterceptor]):
  28. self.interceptors = tuple(interceptors)
  29. def _continuation(self, thunk: Callable, index: int) -> Callable:
  30. return lambda context: self._intercept_at(thunk, index, context)
  31. def _intercept_at(
  32. self, thunk: Callable, index: int,
  33. context: grpc.HandlerCallDetails) -> grpc.RpcMethodHandler:
  34. if index < len(self.interceptors):
  35. interceptor = self.interceptors[index]
  36. thunk = self._continuation(thunk, index + 1)
  37. return interceptor.intercept_service(thunk, context)
  38. else:
  39. return thunk(context)
  40. def execute(self, thunk: Callable,
  41. context: grpc.HandlerCallDetails) -> grpc.RpcMethodHandler:
  42. return self._intercept_at(thunk, 0, context)
  43. def service_pipeline(
  44. interceptors: Optional[Sequence[grpc.ServerInterceptor]]
  45. ) -> Optional[_ServicePipeline]:
  46. return _ServicePipeline(interceptors) if interceptors else None
  47. class _ClientCallDetails(
  48. collections.namedtuple('_ClientCallDetails',
  49. ('method', 'timeout', 'metadata', 'credentials',
  50. 'wait_for_ready', 'compression')),
  51. grpc.ClientCallDetails):
  52. pass
  53. def _unwrap_client_call_details(
  54. call_details: grpc.ClientCallDetails,
  55. default_details: grpc.ClientCallDetails
  56. ) -> Tuple[str, float, MetadataType, grpc.CallCredentials, bool,
  57. grpc.Compression]:
  58. try:
  59. method = call_details.method # pytype: disable=attribute-error
  60. except AttributeError:
  61. method = default_details.method # pytype: disable=attribute-error
  62. try:
  63. timeout = call_details.timeout # pytype: disable=attribute-error
  64. except AttributeError:
  65. timeout = default_details.timeout # pytype: disable=attribute-error
  66. try:
  67. metadata = call_details.metadata # pytype: disable=attribute-error
  68. except AttributeError:
  69. metadata = default_details.metadata # pytype: disable=attribute-error
  70. try:
  71. credentials = call_details.credentials # pytype: disable=attribute-error
  72. except AttributeError:
  73. credentials = default_details.credentials # pytype: disable=attribute-error
  74. try:
  75. wait_for_ready = call_details.wait_for_ready # pytype: disable=attribute-error
  76. except AttributeError:
  77. wait_for_ready = default_details.wait_for_ready # pytype: disable=attribute-error
  78. try:
  79. compression = call_details.compression # pytype: disable=attribute-error
  80. except AttributeError:
  81. compression = default_details.compression # pytype: disable=attribute-error
  82. return method, timeout, metadata, credentials, wait_for_ready, compression
  83. class _FailureOutcome(grpc.RpcError, grpc.Future, grpc.Call): # pylint: disable=too-many-ancestors
  84. _exception: Exception
  85. _traceback: types.TracebackType
  86. def __init__(self, exception: Exception, traceback: types.TracebackType):
  87. super(_FailureOutcome, self).__init__()
  88. self._exception = exception
  89. self._traceback = traceback
  90. def initial_metadata(self) -> Optional[MetadataType]:
  91. return None
  92. def trailing_metadata(self) -> Optional[MetadataType]:
  93. return None
  94. def code(self) -> Optional[grpc.StatusCode]:
  95. return grpc.StatusCode.INTERNAL
  96. def details(self) -> Optional[str]:
  97. return 'Exception raised while intercepting the RPC'
  98. def cancel(self) -> bool:
  99. return False
  100. def cancelled(self) -> bool:
  101. return False
  102. def is_active(self) -> bool:
  103. return False
  104. def time_remaining(self) -> Optional[float]:
  105. return None
  106. def running(self) -> bool:
  107. return False
  108. def done(self) -> bool:
  109. return True
  110. def result(self, ignored_timeout: Optional[float] = None):
  111. raise self._exception
  112. def exception(
  113. self,
  114. ignored_timeout: Optional[float] = None) -> Optional[Exception]:
  115. return self._exception
  116. def traceback(
  117. self,
  118. ignored_timeout: Optional[float] = None
  119. ) -> Optional[types.TracebackType]:
  120. return self._traceback
  121. def add_callback(self, unused_callback) -> bool:
  122. return False
  123. def add_done_callback(self, fn: DoneCallbackType) -> None:
  124. fn(self)
  125. def __iter__(self):
  126. return self
  127. def __next__(self):
  128. raise self._exception
  129. def next(self):
  130. return self.__next__()
  131. class _UnaryOutcome(grpc.Call, grpc.Future):
  132. _response: Any
  133. _call: grpc.Call
  134. def __init__(self, response: Any, call: grpc.Call):
  135. self._response = response
  136. self._call = call
  137. def initial_metadata(self) -> Optional[MetadataType]:
  138. return self._call.initial_metadata()
  139. def trailing_metadata(self) -> Optional[MetadataType]:
  140. return self._call.trailing_metadata()
  141. def code(self) -> Optional[grpc.StatusCode]:
  142. return self._call.code()
  143. def details(self) -> Optional[str]:
  144. return self._call.details()
  145. def is_active(self) -> bool:
  146. return self._call.is_active()
  147. def time_remaining(self) -> Optional[float]:
  148. return self._call.time_remaining()
  149. def cancel(self) -> bool:
  150. return self._call.cancel()
  151. def add_callback(self, callback) -> bool:
  152. return self._call.add_callback(callback)
  153. def cancelled(self) -> bool:
  154. return False
  155. def running(self) -> bool:
  156. return False
  157. def done(self) -> bool:
  158. return True
  159. def result(self, ignored_timeout: Optional[float] = None):
  160. return self._response
  161. def exception(self, ignored_timeout: Optional[float] = None):
  162. return None
  163. def traceback(self, ignored_timeout: Optional[float] = None):
  164. return None
  165. def add_done_callback(self, fn: DoneCallbackType) -> None:
  166. fn(self)
  167. class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
  168. _thunk: Callable
  169. _method: str
  170. _interceptor: grpc.UnaryUnaryClientInterceptor
  171. def __init__(self, thunk: Callable, method: str,
  172. interceptor: grpc.UnaryUnaryClientInterceptor):
  173. self._thunk = thunk
  174. self._method = method
  175. self._interceptor = interceptor
  176. def __call__(self,
  177. request: Any,
  178. timeout: Optional[float] = None,
  179. metadata: Optional[MetadataType] = None,
  180. credentials: Optional[grpc.CallCredentials] = None,
  181. wait_for_ready: Optional[bool] = None,
  182. compression: Optional[grpc.Compression] = None) -> Any:
  183. response, ignored_call = self._with_call(request,
  184. timeout=timeout,
  185. metadata=metadata,
  186. credentials=credentials,
  187. wait_for_ready=wait_for_ready,
  188. compression=compression)
  189. return response
  190. def _with_call(
  191. self,
  192. request: Any,
  193. timeout: Optional[float] = None,
  194. metadata: Optional[MetadataType] = None,
  195. credentials: Optional[grpc.CallCredentials] = None,
  196. wait_for_ready: Optional[bool] = None,
  197. compression: Optional[grpc.Compression] = None
  198. ) -> Tuple[Any, grpc.Call]:
  199. client_call_details = _ClientCallDetails(self._method, timeout,
  200. metadata, credentials,
  201. wait_for_ready, compression)
  202. def continuation(new_details, request):
  203. (new_method, new_timeout, new_metadata, new_credentials,
  204. new_wait_for_ready,
  205. new_compression) = (_unwrap_client_call_details(
  206. new_details, client_call_details))
  207. try:
  208. response, call = self._thunk(new_method).with_call(
  209. request,
  210. timeout=new_timeout,
  211. metadata=new_metadata,
  212. credentials=new_credentials,
  213. wait_for_ready=new_wait_for_ready,
  214. compression=new_compression)
  215. return _UnaryOutcome(response, call)
  216. except grpc.RpcError as rpc_error:
  217. return rpc_error
  218. except Exception as exception: # pylint:disable=broad-except
  219. return _FailureOutcome(exception, sys.exc_info()[2])
  220. call = self._interceptor.intercept_unary_unary(continuation,
  221. client_call_details,
  222. request)
  223. return call.result(), call
  224. def with_call(
  225. self,
  226. request: Any,
  227. timeout: Optional[float] = None,
  228. metadata: Optional[MetadataType] = None,
  229. credentials: Optional[grpc.CallCredentials] = None,
  230. wait_for_ready: Optional[bool] = None,
  231. compression: Optional[grpc.Compression] = None
  232. ) -> Tuple[Any, grpc.Call]:
  233. return self._with_call(request,
  234. timeout=timeout,
  235. metadata=metadata,
  236. credentials=credentials,
  237. wait_for_ready=wait_for_ready,
  238. compression=compression)
  239. def future(self,
  240. request: Any,
  241. timeout: Optional[float] = None,
  242. metadata: Optional[MetadataType] = None,
  243. credentials: Optional[grpc.CallCredentials] = None,
  244. wait_for_ready: Optional[bool] = None,
  245. compression: Optional[grpc.Compression] = None) -> Any:
  246. client_call_details = _ClientCallDetails(self._method, timeout,
  247. metadata, credentials,
  248. wait_for_ready, compression)
  249. def continuation(new_details, request):
  250. (new_method, new_timeout, new_metadata, new_credentials,
  251. new_wait_for_ready,
  252. new_compression) = (_unwrap_client_call_details(
  253. new_details, client_call_details))
  254. return self._thunk(new_method).future(
  255. request,
  256. timeout=new_timeout,
  257. metadata=new_metadata,
  258. credentials=new_credentials,
  259. wait_for_ready=new_wait_for_ready,
  260. compression=new_compression)
  261. try:
  262. return self._interceptor.intercept_unary_unary(
  263. continuation, client_call_details, request)
  264. except Exception as exception: # pylint:disable=broad-except
  265. return _FailureOutcome(exception, sys.exc_info()[2])
  266. class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
  267. _thunk: Callable
  268. _method: str
  269. _interceptor: grpc.UnaryStreamClientInterceptor
  270. def __init__(self, thunk: Callable, method: str,
  271. interceptor: grpc.UnaryStreamClientInterceptor):
  272. self._thunk = thunk
  273. self._method = method
  274. self._interceptor = interceptor
  275. def __call__(self,
  276. request: Any,
  277. timeout: Optional[float] = None,
  278. metadata: Optional[MetadataType] = None,
  279. credentials: Optional[grpc.CallCredentials] = None,
  280. wait_for_ready: Optional[bool] = None,
  281. compression: Optional[grpc.Compression] = None):
  282. client_call_details = _ClientCallDetails(self._method, timeout,
  283. metadata, credentials,
  284. wait_for_ready, compression)
  285. def continuation(new_details, request):
  286. (new_method, new_timeout, new_metadata, new_credentials,
  287. new_wait_for_ready,
  288. new_compression) = (_unwrap_client_call_details(
  289. new_details, client_call_details))
  290. return self._thunk(new_method)(request,
  291. timeout=new_timeout,
  292. metadata=new_metadata,
  293. credentials=new_credentials,
  294. wait_for_ready=new_wait_for_ready,
  295. compression=new_compression)
  296. try:
  297. return self._interceptor.intercept_unary_stream(
  298. continuation, client_call_details, request)
  299. except Exception as exception: # pylint:disable=broad-except
  300. return _FailureOutcome(exception, sys.exc_info()[2])
  301. class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
  302. _thunk: Callable
  303. _method: str
  304. _interceptor: grpc.StreamUnaryClientInterceptor
  305. def __init__(self, thunk: Callable, method: str,
  306. interceptor: grpc.StreamUnaryClientInterceptor):
  307. self._thunk = thunk
  308. self._method = method
  309. self._interceptor = interceptor
  310. def __call__(self,
  311. request_iterator: RequestIterableType,
  312. timeout: Optional[float] = None,
  313. metadata: Optional[MetadataType] = None,
  314. credentials: Optional[grpc.CallCredentials] = None,
  315. wait_for_ready: Optional[bool] = None,
  316. compression: Optional[grpc.Compression] = None) -> Any:
  317. response, ignored_call = self._with_call(request_iterator,
  318. timeout=timeout,
  319. metadata=metadata,
  320. credentials=credentials,
  321. wait_for_ready=wait_for_ready,
  322. compression=compression)
  323. return response
  324. def _with_call(
  325. self,
  326. request_iterator: RequestIterableType,
  327. timeout: Optional[float] = None,
  328. metadata: Optional[MetadataType] = None,
  329. credentials: Optional[grpc.CallCredentials] = None,
  330. wait_for_ready: Optional[bool] = None,
  331. compression: Optional[grpc.Compression] = None
  332. ) -> Tuple[Any, grpc.Call]:
  333. client_call_details = _ClientCallDetails(self._method, timeout,
  334. metadata, credentials,
  335. wait_for_ready, compression)
  336. def continuation(new_details, request_iterator):
  337. (new_method, new_timeout, new_metadata, new_credentials,
  338. new_wait_for_ready,
  339. new_compression) = (_unwrap_client_call_details(
  340. new_details, client_call_details))
  341. try:
  342. response, call = self._thunk(new_method).with_call(
  343. request_iterator,
  344. timeout=new_timeout,
  345. metadata=new_metadata,
  346. credentials=new_credentials,
  347. wait_for_ready=new_wait_for_ready,
  348. compression=new_compression)
  349. return _UnaryOutcome(response, call)
  350. except grpc.RpcError as rpc_error:
  351. return rpc_error
  352. except Exception as exception: # pylint:disable=broad-except
  353. return _FailureOutcome(exception, sys.exc_info()[2])
  354. call = self._interceptor.intercept_stream_unary(continuation,
  355. client_call_details,
  356. request_iterator)
  357. return call.result(), call
  358. def with_call(
  359. self,
  360. request_iterator: RequestIterableType,
  361. timeout: Optional[float] = None,
  362. metadata: Optional[MetadataType] = None,
  363. credentials: Optional[grpc.CallCredentials] = None,
  364. wait_for_ready: Optional[bool] = None,
  365. compression: Optional[grpc.Compression] = None
  366. ) -> Tuple[Any, grpc.Call]:
  367. return self._with_call(request_iterator,
  368. timeout=timeout,
  369. metadata=metadata,
  370. credentials=credentials,
  371. wait_for_ready=wait_for_ready,
  372. compression=compression)
  373. def future(self,
  374. request_iterator: RequestIterableType,
  375. timeout: Optional[float] = None,
  376. metadata: Optional[MetadataType] = None,
  377. credentials: Optional[grpc.CallCredentials] = None,
  378. wait_for_ready: Optional[bool] = None,
  379. compression: Optional[grpc.Compression] = None) -> Any:
  380. client_call_details = _ClientCallDetails(self._method, timeout,
  381. metadata, credentials,
  382. wait_for_ready, compression)
  383. def continuation(new_details, request_iterator):
  384. (new_method, new_timeout, new_metadata, new_credentials,
  385. new_wait_for_ready,
  386. new_compression) = (_unwrap_client_call_details(
  387. new_details, client_call_details))
  388. return self._thunk(new_method).future(
  389. request_iterator,
  390. timeout=new_timeout,
  391. metadata=new_metadata,
  392. credentials=new_credentials,
  393. wait_for_ready=new_wait_for_ready,
  394. compression=new_compression)
  395. try:
  396. return self._interceptor.intercept_stream_unary(
  397. continuation, client_call_details, request_iterator)
  398. except Exception as exception: # pylint:disable=broad-except
  399. return _FailureOutcome(exception, sys.exc_info()[2])
  400. class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
  401. _thunk: Callable
  402. _method: str
  403. _interceptor: grpc.StreamStreamClientInterceptor
  404. def __init__(self, thunk: Callable, method: str,
  405. interceptor: grpc.StreamStreamClientInterceptor):
  406. self._thunk = thunk
  407. self._method = method
  408. self._interceptor = interceptor
  409. def __call__(self,
  410. request_iterator: RequestIterableType,
  411. timeout: Optional[float] = None,
  412. metadata: Optional[MetadataType] = None,
  413. credentials: Optional[grpc.CallCredentials] = None,
  414. wait_for_ready: Optional[bool] = None,
  415. compression: Optional[grpc.Compression] = None):
  416. client_call_details = _ClientCallDetails(self._method, timeout,
  417. metadata, credentials,
  418. wait_for_ready, compression)
  419. def continuation(new_details, request_iterator):
  420. (new_method, new_timeout, new_metadata, new_credentials,
  421. new_wait_for_ready,
  422. new_compression) = (_unwrap_client_call_details(
  423. new_details, client_call_details))
  424. return self._thunk(new_method)(request_iterator,
  425. timeout=new_timeout,
  426. metadata=new_metadata,
  427. credentials=new_credentials,
  428. wait_for_ready=new_wait_for_ready,
  429. compression=new_compression)
  430. try:
  431. return self._interceptor.intercept_stream_stream(
  432. continuation, client_call_details, request_iterator)
  433. except Exception as exception: # pylint:disable=broad-except
  434. return _FailureOutcome(exception, sys.exc_info()[2])
  435. class _Channel(grpc.Channel):
  436. _channel: grpc.Channel
  437. _interceptor: Union[grpc.UnaryUnaryClientInterceptor,
  438. grpc.UnaryStreamClientInterceptor,
  439. grpc.StreamStreamClientInterceptor,
  440. grpc.StreamUnaryClientInterceptor]
  441. def __init__(self, channel: grpc.Channel,
  442. interceptor: Union[grpc.UnaryUnaryClientInterceptor,
  443. grpc.UnaryStreamClientInterceptor,
  444. grpc.StreamStreamClientInterceptor,
  445. grpc.StreamUnaryClientInterceptor]):
  446. self._channel = channel
  447. self._interceptor = interceptor
  448. def subscribe(self,
  449. callback: Callable,
  450. try_to_connect: Optional[bool] = False):
  451. self._channel.subscribe(callback, try_to_connect=try_to_connect)
  452. def unsubscribe(self, callback: Callable):
  453. self._channel.unsubscribe(callback)
  454. def unary_unary(
  455. self,
  456. method: str,
  457. request_serializer: Optional[SerializingFunction] = None,
  458. response_deserializer: Optional[DeserializingFunction] = None
  459. ) -> grpc.UnaryUnaryMultiCallable:
  460. thunk = lambda m: self._channel.unary_unary(m, request_serializer,
  461. response_deserializer)
  462. if isinstance(self._interceptor, grpc.UnaryUnaryClientInterceptor):
  463. return _UnaryUnaryMultiCallable(thunk, method, self._interceptor)
  464. else:
  465. return thunk(method)
  466. def unary_stream(
  467. self,
  468. method: str,
  469. request_serializer: Optional[SerializingFunction] = None,
  470. response_deserializer: Optional[DeserializingFunction] = None
  471. ) -> grpc.UnaryStreamMultiCallable:
  472. thunk = lambda m: self._channel.unary_stream(m, request_serializer,
  473. response_deserializer)
  474. if isinstance(self._interceptor, grpc.UnaryStreamClientInterceptor):
  475. return _UnaryStreamMultiCallable(thunk, method, self._interceptor)
  476. else:
  477. return thunk(method)
  478. def stream_unary(
  479. self,
  480. method: str,
  481. request_serializer: Optional[SerializingFunction] = None,
  482. response_deserializer: Optional[DeserializingFunction] = None
  483. ) -> grpc.StreamUnaryMultiCallable:
  484. thunk = lambda m: self._channel.stream_unary(m, request_serializer,
  485. response_deserializer)
  486. if isinstance(self._interceptor, grpc.StreamUnaryClientInterceptor):
  487. return _StreamUnaryMultiCallable(thunk, method, self._interceptor)
  488. else:
  489. return thunk(method)
  490. def stream_stream(
  491. self,
  492. method: str,
  493. request_serializer: Optional[SerializingFunction] = None,
  494. response_deserializer: Optional[DeserializingFunction] = None
  495. ) -> grpc.StreamStreamMultiCallable:
  496. thunk = lambda m: self._channel.stream_stream(m, request_serializer,
  497. response_deserializer)
  498. if isinstance(self._interceptor, grpc.StreamStreamClientInterceptor):
  499. return _StreamStreamMultiCallable(thunk, method, self._interceptor)
  500. else:
  501. return thunk(method)
  502. def _close(self):
  503. self._channel.close()
  504. def __enter__(self):
  505. return self
  506. def __exit__(self, exc_type, exc_val, exc_tb):
  507. self._close()
  508. return False
  509. def close(self):
  510. self._channel.close()
  511. def intercept_channel(
  512. channel: grpc.Channel,
  513. *interceptors: Optional[Sequence[Union[grpc.UnaryUnaryClientInterceptor,
  514. grpc.UnaryStreamClientInterceptor,
  515. grpc.StreamStreamClientInterceptor,
  516. grpc.StreamUnaryClientInterceptor]]]
  517. ) -> grpc.Channel:
  518. for interceptor in reversed(list(interceptors)):
  519. if not isinstance(interceptor, grpc.UnaryUnaryClientInterceptor) and \
  520. not isinstance(interceptor, grpc.UnaryStreamClientInterceptor) and \
  521. not isinstance(interceptor, grpc.StreamUnaryClientInterceptor) and \
  522. not isinstance(interceptor, grpc.StreamStreamClientInterceptor):
  523. raise TypeError('interceptor must be '
  524. 'grpc.UnaryUnaryClientInterceptor or '
  525. 'grpc.UnaryStreamClientInterceptor or '
  526. 'grpc.StreamUnaryClientInterceptor or '
  527. 'grpc.StreamStreamClientInterceptor or ')
  528. channel = _Channel(channel, interceptor)
  529. return channel