123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121 |
- # Copyright 2015 gRPC authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import collections
- import logging
- import threading
- from typing import Callable, Optional, Type
- import grpc
- from grpc import _common
- from grpc._cython import cygrpc
- from grpc._typing import MetadataType
# Module-level logger; used below to report exceptions raised by user plugins.
_LOGGER = logging.getLogger(name=__name__)
class _AuthMetadataContext(
        collections.namedtuple('AuthMetadataContext',
                               ['service_url', 'method_name']),
        grpc.AuthMetadataContext):
    """Immutable grpc.AuthMetadataContext backed by a namedtuple.

    Attributes:
      service_url: URL of the service being called, as described for
        grpc.AuthMetadataContext.
      method_name: Name of the method being called, as described for
        grpc.AuthMetadataContext.
    """
class _CallbackState(object):
    """Per-invocation bookkeeping shared by a plugin call and its callback.

    Attributes:
      called: True once the metadata callback has been invoked.
      exception: The exception raised by the plugin, or None if it has not
        raised.
      lock: Guards reads and writes of ``called`` and ``exception``.
    """

    def __init__(self):
        self.called = False
        self.exception = None
        self.lock = threading.Lock()
class _AuthMetadataPluginCallback(grpc.AuthMetadataPluginCallback):
    """Single-use callback handed to a user-supplied AuthMetadataPlugin.

    Forwards the plugin's result to the lower-level ``callback`` as
    ``callback(metadata, status_code, details)``. A non-None ``error`` is
    reported as an INTERNAL status. The shared ``_CallbackState`` lets the
    callback reject a second invocation, and reject any invocation made after
    the plugin itself has raised.
    """

    _state: _CallbackState
    _callback: Callable

    def __init__(self, state: _CallbackState, callback: Callable):
        self._state = state
        self._callback = callback

    def __call__(self, metadata: MetadataType,
                 error: Optional[BaseException]):
        """Deliver the plugin's result.

        Args:
          metadata: Metadata to attach to the call; ignored when ``error``
            is not None.
          error: An exception instance signalling failure, or None on
            success. (Previously mis-annotated as an exception *type*.)

        Raises:
          RuntimeError: If the plugin already raised an exception, or if
            this callback has already been invoked.
        """
        with self._state.lock:
            # The plugin-raised check takes precedence over the
            # double-invocation check.
            if self._state.exception is not None:
                raise RuntimeError(
                    'AuthMetadataPluginCallback raised exception "{}"!'.format(
                        self._state.exception))
            if self._state.called:
                raise RuntimeError(
                    'AuthMetadataPluginCallback invoked more than once!')
            self._state.called = True
        # The downstream callback runs outside the lock.
        if error is None:
            self._callback(metadata, cygrpc.StatusCode.ok, None)
        else:
            self._callback(None, cygrpc.StatusCode.internal,
                           _common.encode(str(error)))
class _Plugin(object):
    """Adapter that lets the cygrpc layer invoke a grpc.AuthMetadataPlugin.

    Each invocation builds an _AuthMetadataContext from the decoded service
    URL and method name and hands the plugin a fresh single-use
    _AuthMetadataPluginCallback. If the plugin raises before it has used
    that callback, the failure is reported to ``callback`` as an INTERNAL
    status.
    """

    _metadata_plugin: grpc.AuthMetadataPlugin

    def __init__(self, metadata_plugin: grpc.AuthMetadataPlugin):
        self._metadata_plugin = metadata_plugin
        # Remains None when contextvars is unavailable.
        self._stored_ctx = None
        try:
            import contextvars  # pylint: disable=wrong-import-position
        except ImportError:
            # Support versions predating contextvars.
            return
        # The plugin may be invoked on a thread created by Core, which will not
        # have the context propagated. This context is stored and installed in
        # the thread invoking the plugin.
        # NOTE(review): _stored_ctx is not read anywhere in this file;
        # presumably the cygrpc layer consumes it. Confirm before removing.
        self._stored_ctx = contextvars.copy_context()

    def __call__(self, service_url: str, method_name: str, callback: Callable):
        """Run the wrapped plugin for one call.

        Args:
          service_url: Service URL; passed through ``_common.decode``.
          method_name: Method name; passed through ``_common.decode``.
          callback: Invoked as ``callback(metadata, status_code, details)``.
        """
        auth_context = _AuthMetadataContext(
            service_url=_common.decode(service_url),
            method_name=_common.decode(method_name))
        state = _CallbackState()
        try:
            self._metadata_plugin(
                auth_context, _AuthMetadataPluginCallback(state, callback))
        except Exception as exception:  # pylint: disable=broad-except
            # NOTE(review): the message names the callback, but it is the
            # plugin itself that raised here.
            _LOGGER.exception(
                'AuthMetadataPluginCallback "%s" raised exception!',
                self._metadata_plugin)
            with state.lock:
                # Once recorded, any later use of the plugin's callback
                # raises instead of reporting a second result.
                state.exception = exception
                callback_already_ran = state.called
            # Only report the failure if the plugin had not already delivered
            # a result through its callback.
            if not callback_already_ran:
                callback(None, cygrpc.StatusCode.internal,
                         _common.encode(str(exception)))
def metadata_plugin_call_credentials(
        metadata_plugin: grpc.AuthMetadataPlugin,
        name: Optional[str]) -> grpc.CallCredentials:
    """Wrap ``metadata_plugin`` as grpc.CallCredentials.

    Args:
      metadata_plugin: The grpc.AuthMetadataPlugin to wrap.
      name: Name for the credentials. When None, the plugin's ``__name__`` is
        used, falling back to the name of the plugin's class.

    Returns:
      A grpc.CallCredentials wrapping a cygrpc.MetadataPluginCallCredentials
      built from a _Plugin adapter and the encoded name.
    """
    if name is not None:
        effective_name = name
    else:
        try:
            effective_name = metadata_plugin.__name__
        except AttributeError:
            # e.g. an instance of a class defining __call__ has no __name__.
            effective_name = metadata_plugin.__class__.__name__
    wrapped_plugin = _Plugin(metadata_plugin)
    encoded_name = _common.encode(effective_name)
    return grpc.CallCredentials(
        cygrpc.MetadataPluginCallCredentials(wrapped_plugin, encoded_name))
|