_auth.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. # Copyright 2016 gRPC authors.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """GRPCAuthMetadataPlugins for standard authentication."""
  15. import inspect
  16. from typing import Any, Optional
  17. import grpc
  18. def _sign_request(callback: grpc.AuthMetadataPluginCallback,
  19. token: Optional[str], error: Optional[Exception]):
  20. metadata = (('authorization', 'Bearer {}'.format(token)),)
  21. callback(metadata, error)
  22. class GoogleCallCredentials(grpc.AuthMetadataPlugin):
  23. """Metadata wrapper for GoogleCredentials from the oauth2client library."""
  24. _is_jwt: bool
  25. _credentials: Any
  26. # TODO(xuanwn): Give credentials an actual type.
  27. def __init__(self, credentials: Any):
  28. self._credentials = credentials
  29. # Hack to determine if these are JWT creds and we need to pass
  30. # additional_claims when getting a token
  31. self._is_jwt = 'additional_claims' in inspect.getfullargspec(
  32. credentials.get_access_token).args
  33. def __call__(self, context: grpc.AuthMetadataContext,
  34. callback: grpc.AuthMetadataPluginCallback):
  35. try:
  36. if self._is_jwt:
  37. access_token = self._credentials.get_access_token(
  38. additional_claims={
  39. 'aud':
  40. context.
  41. service_url # pytype: disable=attribute-error
  42. }).access_token
  43. else:
  44. access_token = self._credentials.get_access_token().access_token
  45. except Exception as exception: # pylint: disable=broad-except
  46. _sign_request(callback, None, exception)
  47. else:
  48. _sign_request(callback, access_token, None)
  49. class AccessTokenAuthMetadataPlugin(grpc.AuthMetadataPlugin):
  50. """Metadata wrapper for raw access token credentials."""
  51. _access_token: str
  52. def __init__(self, access_token: str):
  53. self._access_token = access_token
  54. def __call__(self, context: grpc.AuthMetadataContext,
  55. callback: grpc.AuthMetadataPluginCallback):
  56. _sign_request(callback, self._access_token, None)