_interceptor.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562
  1. # Copyright 2017 gRPC authors.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Implementation of gRPC Python interceptors."""
  15. import collections
  16. import sys
  17. import grpc
  18. class _ServicePipeline(object):
  19. def __init__(self, interceptors):
  20. self.interceptors = tuple(interceptors)
  21. def _continuation(self, thunk, index):
  22. return lambda context: self._intercept_at(thunk, index, context)
  23. def _intercept_at(self, thunk, index, context):
  24. if index < len(self.interceptors):
  25. interceptor = self.interceptors[index]
  26. thunk = self._continuation(thunk, index + 1)
  27. return interceptor.intercept_service(thunk, context)
  28. else:
  29. return thunk(context)
  30. def execute(self, thunk, context):
  31. return self._intercept_at(thunk, 0, context)
  32. def service_pipeline(interceptors):
  33. return _ServicePipeline(interceptors) if interceptors else None
  34. class _ClientCallDetails(
  35. collections.namedtuple('_ClientCallDetails',
  36. ('method', 'timeout', 'metadata', 'credentials',
  37. 'wait_for_ready', 'compression')),
  38. grpc.ClientCallDetails):
  39. pass
  40. def _unwrap_client_call_details(call_details, default_details):
  41. try:
  42. method = call_details.method
  43. except AttributeError:
  44. method = default_details.method
  45. try:
  46. timeout = call_details.timeout
  47. except AttributeError:
  48. timeout = default_details.timeout
  49. try:
  50. metadata = call_details.metadata
  51. except AttributeError:
  52. metadata = default_details.metadata
  53. try:
  54. credentials = call_details.credentials
  55. except AttributeError:
  56. credentials = default_details.credentials
  57. try:
  58. wait_for_ready = call_details.wait_for_ready
  59. except AttributeError:
  60. wait_for_ready = default_details.wait_for_ready
  61. try:
  62. compression = call_details.compression
  63. except AttributeError:
  64. compression = default_details.compression
  65. return method, timeout, metadata, credentials, wait_for_ready, compression
  66. class _FailureOutcome(grpc.RpcError, grpc.Future, grpc.Call): # pylint: disable=too-many-ancestors
  67. def __init__(self, exception, traceback):
  68. super(_FailureOutcome, self).__init__()
  69. self._exception = exception
  70. self._traceback = traceback
  71. def initial_metadata(self):
  72. return None
  73. def trailing_metadata(self):
  74. return None
  75. def code(self):
  76. return grpc.StatusCode.INTERNAL
  77. def details(self):
  78. return 'Exception raised while intercepting the RPC'
  79. def cancel(self):
  80. return False
  81. def cancelled(self):
  82. return False
  83. def is_active(self):
  84. return False
  85. def time_remaining(self):
  86. return None
  87. def running(self):
  88. return False
  89. def done(self):
  90. return True
  91. def result(self, ignored_timeout=None):
  92. raise self._exception
  93. def exception(self, ignored_timeout=None):
  94. return self._exception
  95. def traceback(self, ignored_timeout=None):
  96. return self._traceback
  97. def add_callback(self, unused_callback):
  98. return False
  99. def add_done_callback(self, fn):
  100. fn(self)
  101. def __iter__(self):
  102. return self
  103. def __next__(self):
  104. raise self._exception
  105. def next(self):
  106. return self.__next__()
  107. class _UnaryOutcome(grpc.Call, grpc.Future):
  108. def __init__(self, response, call):
  109. self._response = response
  110. self._call = call
  111. def initial_metadata(self):
  112. return self._call.initial_metadata()
  113. def trailing_metadata(self):
  114. return self._call.trailing_metadata()
  115. def code(self):
  116. return self._call.code()
  117. def details(self):
  118. return self._call.details()
  119. def is_active(self):
  120. return self._call.is_active()
  121. def time_remaining(self):
  122. return self._call.time_remaining()
  123. def cancel(self):
  124. return self._call.cancel()
  125. def add_callback(self, callback):
  126. return self._call.add_callback(callback)
  127. def cancelled(self):
  128. return False
  129. def running(self):
  130. return False
  131. def done(self):
  132. return True
  133. def result(self, ignored_timeout=None):
  134. return self._response
  135. def exception(self, ignored_timeout=None):
  136. return None
  137. def traceback(self, ignored_timeout=None):
  138. return None
  139. def add_done_callback(self, fn):
  140. fn(self)
  141. class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
  142. def __init__(self, thunk, method, interceptor):
  143. self._thunk = thunk
  144. self._method = method
  145. self._interceptor = interceptor
  146. def __call__(self,
  147. request,
  148. timeout=None,
  149. metadata=None,
  150. credentials=None,
  151. wait_for_ready=None,
  152. compression=None):
  153. response, ignored_call = self._with_call(request,
  154. timeout=timeout,
  155. metadata=metadata,
  156. credentials=credentials,
  157. wait_for_ready=wait_for_ready,
  158. compression=compression)
  159. return response
  160. def _with_call(self,
  161. request,
  162. timeout=None,
  163. metadata=None,
  164. credentials=None,
  165. wait_for_ready=None,
  166. compression=None):
  167. client_call_details = _ClientCallDetails(self._method, timeout,
  168. metadata, credentials,
  169. wait_for_ready, compression)
  170. def continuation(new_details, request):
  171. (new_method, new_timeout, new_metadata, new_credentials,
  172. new_wait_for_ready,
  173. new_compression) = (_unwrap_client_call_details(
  174. new_details, client_call_details))
  175. try:
  176. response, call = self._thunk(new_method).with_call(
  177. request,
  178. timeout=new_timeout,
  179. metadata=new_metadata,
  180. credentials=new_credentials,
  181. wait_for_ready=new_wait_for_ready,
  182. compression=new_compression)
  183. return _UnaryOutcome(response, call)
  184. except grpc.RpcError as rpc_error:
  185. return rpc_error
  186. except Exception as exception: # pylint:disable=broad-except
  187. return _FailureOutcome(exception, sys.exc_info()[2])
  188. call = self._interceptor.intercept_unary_unary(continuation,
  189. client_call_details,
  190. request)
  191. return call.result(), call
  192. def with_call(self,
  193. request,
  194. timeout=None,
  195. metadata=None,
  196. credentials=None,
  197. wait_for_ready=None,
  198. compression=None):
  199. return self._with_call(request,
  200. timeout=timeout,
  201. metadata=metadata,
  202. credentials=credentials,
  203. wait_for_ready=wait_for_ready,
  204. compression=compression)
  205. def future(self,
  206. request,
  207. timeout=None,
  208. metadata=None,
  209. credentials=None,
  210. wait_for_ready=None,
  211. compression=None):
  212. client_call_details = _ClientCallDetails(self._method, timeout,
  213. metadata, credentials,
  214. wait_for_ready, compression)
  215. def continuation(new_details, request):
  216. (new_method, new_timeout, new_metadata, new_credentials,
  217. new_wait_for_ready,
  218. new_compression) = (_unwrap_client_call_details(
  219. new_details, client_call_details))
  220. return self._thunk(new_method).future(
  221. request,
  222. timeout=new_timeout,
  223. metadata=new_metadata,
  224. credentials=new_credentials,
  225. wait_for_ready=new_wait_for_ready,
  226. compression=new_compression)
  227. try:
  228. return self._interceptor.intercept_unary_unary(
  229. continuation, client_call_details, request)
  230. except Exception as exception: # pylint:disable=broad-except
  231. return _FailureOutcome(exception, sys.exc_info()[2])
  232. class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
  233. def __init__(self, thunk, method, interceptor):
  234. self._thunk = thunk
  235. self._method = method
  236. self._interceptor = interceptor
  237. def __call__(self,
  238. request,
  239. timeout=None,
  240. metadata=None,
  241. credentials=None,
  242. wait_for_ready=None,
  243. compression=None):
  244. client_call_details = _ClientCallDetails(self._method, timeout,
  245. metadata, credentials,
  246. wait_for_ready, compression)
  247. def continuation(new_details, request):
  248. (new_method, new_timeout, new_metadata, new_credentials,
  249. new_wait_for_ready,
  250. new_compression) = (_unwrap_client_call_details(
  251. new_details, client_call_details))
  252. return self._thunk(new_method)(request,
  253. timeout=new_timeout,
  254. metadata=new_metadata,
  255. credentials=new_credentials,
  256. wait_for_ready=new_wait_for_ready,
  257. compression=new_compression)
  258. try:
  259. return self._interceptor.intercept_unary_stream(
  260. continuation, client_call_details, request)
  261. except Exception as exception: # pylint:disable=broad-except
  262. return _FailureOutcome(exception, sys.exc_info()[2])
  263. class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
  264. def __init__(self, thunk, method, interceptor):
  265. self._thunk = thunk
  266. self._method = method
  267. self._interceptor = interceptor
  268. def __call__(self,
  269. request_iterator,
  270. timeout=None,
  271. metadata=None,
  272. credentials=None,
  273. wait_for_ready=None,
  274. compression=None):
  275. response, ignored_call = self._with_call(request_iterator,
  276. timeout=timeout,
  277. metadata=metadata,
  278. credentials=credentials,
  279. wait_for_ready=wait_for_ready,
  280. compression=compression)
  281. return response
  282. def _with_call(self,
  283. request_iterator,
  284. timeout=None,
  285. metadata=None,
  286. credentials=None,
  287. wait_for_ready=None,
  288. compression=None):
  289. client_call_details = _ClientCallDetails(self._method, timeout,
  290. metadata, credentials,
  291. wait_for_ready, compression)
  292. def continuation(new_details, request_iterator):
  293. (new_method, new_timeout, new_metadata, new_credentials,
  294. new_wait_for_ready,
  295. new_compression) = (_unwrap_client_call_details(
  296. new_details, client_call_details))
  297. try:
  298. response, call = self._thunk(new_method).with_call(
  299. request_iterator,
  300. timeout=new_timeout,
  301. metadata=new_metadata,
  302. credentials=new_credentials,
  303. wait_for_ready=new_wait_for_ready,
  304. compression=new_compression)
  305. return _UnaryOutcome(response, call)
  306. except grpc.RpcError as rpc_error:
  307. return rpc_error
  308. except Exception as exception: # pylint:disable=broad-except
  309. return _FailureOutcome(exception, sys.exc_info()[2])
  310. call = self._interceptor.intercept_stream_unary(continuation,
  311. client_call_details,
  312. request_iterator)
  313. return call.result(), call
  314. def with_call(self,
  315. request_iterator,
  316. timeout=None,
  317. metadata=None,
  318. credentials=None,
  319. wait_for_ready=None,
  320. compression=None):
  321. return self._with_call(request_iterator,
  322. timeout=timeout,
  323. metadata=metadata,
  324. credentials=credentials,
  325. wait_for_ready=wait_for_ready,
  326. compression=compression)
  327. def future(self,
  328. request_iterator,
  329. timeout=None,
  330. metadata=None,
  331. credentials=None,
  332. wait_for_ready=None,
  333. compression=None):
  334. client_call_details = _ClientCallDetails(self._method, timeout,
  335. metadata, credentials,
  336. wait_for_ready, compression)
  337. def continuation(new_details, request_iterator):
  338. (new_method, new_timeout, new_metadata, new_credentials,
  339. new_wait_for_ready,
  340. new_compression) = (_unwrap_client_call_details(
  341. new_details, client_call_details))
  342. return self._thunk(new_method).future(
  343. request_iterator,
  344. timeout=new_timeout,
  345. metadata=new_metadata,
  346. credentials=new_credentials,
  347. wait_for_ready=new_wait_for_ready,
  348. compression=new_compression)
  349. try:
  350. return self._interceptor.intercept_stream_unary(
  351. continuation, client_call_details, request_iterator)
  352. except Exception as exception: # pylint:disable=broad-except
  353. return _FailureOutcome(exception, sys.exc_info()[2])
  354. class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
  355. def __init__(self, thunk, method, interceptor):
  356. self._thunk = thunk
  357. self._method = method
  358. self._interceptor = interceptor
  359. def __call__(self,
  360. request_iterator,
  361. timeout=None,
  362. metadata=None,
  363. credentials=None,
  364. wait_for_ready=None,
  365. compression=None):
  366. client_call_details = _ClientCallDetails(self._method, timeout,
  367. metadata, credentials,
  368. wait_for_ready, compression)
  369. def continuation(new_details, request_iterator):
  370. (new_method, new_timeout, new_metadata, new_credentials,
  371. new_wait_for_ready,
  372. new_compression) = (_unwrap_client_call_details(
  373. new_details, client_call_details))
  374. return self._thunk(new_method)(request_iterator,
  375. timeout=new_timeout,
  376. metadata=new_metadata,
  377. credentials=new_credentials,
  378. wait_for_ready=new_wait_for_ready,
  379. compression=new_compression)
  380. try:
  381. return self._interceptor.intercept_stream_stream(
  382. continuation, client_call_details, request_iterator)
  383. except Exception as exception: # pylint:disable=broad-except
  384. return _FailureOutcome(exception, sys.exc_info()[2])
  385. class _Channel(grpc.Channel):
  386. def __init__(self, channel, interceptor):
  387. self._channel = channel
  388. self._interceptor = interceptor
  389. def subscribe(self, callback, try_to_connect=False):
  390. self._channel.subscribe(callback, try_to_connect=try_to_connect)
  391. def unsubscribe(self, callback):
  392. self._channel.unsubscribe(callback)
  393. def unary_unary(self,
  394. method,
  395. request_serializer=None,
  396. response_deserializer=None):
  397. thunk = lambda m: self._channel.unary_unary(m, request_serializer,
  398. response_deserializer)
  399. if isinstance(self._interceptor, grpc.UnaryUnaryClientInterceptor):
  400. return _UnaryUnaryMultiCallable(thunk, method, self._interceptor)
  401. else:
  402. return thunk(method)
  403. def unary_stream(self,
  404. method,
  405. request_serializer=None,
  406. response_deserializer=None):
  407. thunk = lambda m: self._channel.unary_stream(m, request_serializer,
  408. response_deserializer)
  409. if isinstance(self._interceptor, grpc.UnaryStreamClientInterceptor):
  410. return _UnaryStreamMultiCallable(thunk, method, self._interceptor)
  411. else:
  412. return thunk(method)
  413. def stream_unary(self,
  414. method,
  415. request_serializer=None,
  416. response_deserializer=None):
  417. thunk = lambda m: self._channel.stream_unary(m, request_serializer,
  418. response_deserializer)
  419. if isinstance(self._interceptor, grpc.StreamUnaryClientInterceptor):
  420. return _StreamUnaryMultiCallable(thunk, method, self._interceptor)
  421. else:
  422. return thunk(method)
  423. def stream_stream(self,
  424. method,
  425. request_serializer=None,
  426. response_deserializer=None):
  427. thunk = lambda m: self._channel.stream_stream(m, request_serializer,
  428. response_deserializer)
  429. if isinstance(self._interceptor, grpc.StreamStreamClientInterceptor):
  430. return _StreamStreamMultiCallable(thunk, method, self._interceptor)
  431. else:
  432. return thunk(method)
  433. def _close(self):
  434. self._channel.close()
  435. def __enter__(self):
  436. return self
  437. def __exit__(self, exc_type, exc_val, exc_tb):
  438. self._close()
  439. return False
  440. def close(self):
  441. self._channel.close()
  442. def intercept_channel(channel, *interceptors):
  443. for interceptor in reversed(list(interceptors)):
  444. if not isinstance(interceptor, grpc.UnaryUnaryClientInterceptor) and \
  445. not isinstance(interceptor, grpc.UnaryStreamClientInterceptor) and \
  446. not isinstance(interceptor, grpc.StreamUnaryClientInterceptor) and \
  447. not isinstance(interceptor, grpc.StreamStreamClientInterceptor):
  448. raise TypeError('interceptor must be '
  449. 'grpc.UnaryUnaryClientInterceptor or '
  450. 'grpc.UnaryStreamClientInterceptor or '
  451. 'grpc.StreamUnaryClientInterceptor or '
  452. 'grpc.StreamStreamClientInterceptor or ')
  453. channel = _Channel(channel, interceptor)
  454. return channel