test_gdch_credentials.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. # Copyright 2022 Google LLC
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import copy
  15. import datetime
  16. import json
  17. import os
  18. import mock
  19. import pytest # type: ignore
  20. import requests
  21. from google.auth import exceptions
  22. from google.auth import jwt
  23. import google.auth.transport.requests
  24. from google.oauth2 import gdch_credentials
  25. from google.oauth2.gdch_credentials import ServiceAccountCredentials
  26. import yatest.common as yc
  27. class TestServiceAccountCredentials(object):
  28. AUDIENCE = "https://service-identity.<Domain>/authenticate"
  29. PROJECT = "project_foo"
  30. PRIVATE_KEY_ID = "key_foo"
  31. NAME = "service_identity_name"
  32. CA_CERT_PATH = "/path/to/ca/cert"
  33. TOKEN_URI = "https://service-identity.<Domain>/authenticate"
  34. JSON_PATH = os.path.join(
  35. os.path.dirname(yc.source_path(__file__)), "..", "data", "gdch_service_account.json"
  36. )
  37. with open(JSON_PATH, "rb") as fh:
  38. INFO = json.load(fh)
  39. def test_with_gdch_audience(self):
  40. mock_signer = mock.Mock()
  41. creds = ServiceAccountCredentials._from_signer_and_info(mock_signer, self.INFO)
  42. assert creds._signer == mock_signer
  43. assert creds._service_identity_name == self.NAME
  44. assert creds._audience is None
  45. assert creds._token_uri == self.TOKEN_URI
  46. assert creds._ca_cert_path == self.CA_CERT_PATH
  47. new_creds = creds.with_gdch_audience(self.AUDIENCE)
  48. assert new_creds._signer == mock_signer
  49. assert new_creds._service_identity_name == self.NAME
  50. assert new_creds._audience == self.AUDIENCE
  51. assert new_creds._token_uri == self.TOKEN_URI
  52. assert new_creds._ca_cert_path == self.CA_CERT_PATH
  53. def test__create_jwt(self):
  54. creds = ServiceAccountCredentials.from_service_account_file(self.JSON_PATH)
  55. with mock.patch("google.auth._helpers.utcnow") as utcnow:
  56. utcnow.return_value = datetime.datetime.now()
  57. jwt_token = creds._create_jwt()
  58. header, payload, _, _ = jwt._unverified_decode(jwt_token)
  59. expected_iss_sub_value = (
  60. "system:serviceaccount:project_foo:service_identity_name"
  61. )
  62. assert isinstance(jwt_token, str)
  63. assert header["alg"] == "ES256"
  64. assert header["kid"] == self.PRIVATE_KEY_ID
  65. assert payload["iss"] == expected_iss_sub_value
  66. assert payload["sub"] == expected_iss_sub_value
  67. assert payload["aud"] == self.AUDIENCE
  68. assert payload["exp"] == (payload["iat"] + 3600)
  69. @mock.patch(
  70. "google.oauth2.gdch_credentials.ServiceAccountCredentials._create_jwt",
  71. autospec=True,
  72. )
  73. @mock.patch("google.oauth2._client._token_endpoint_request", autospec=True)
  74. def test_refresh(self, token_endpoint_request, create_jwt):
  75. creds = ServiceAccountCredentials.from_service_account_info(self.INFO)
  76. creds = creds.with_gdch_audience(self.AUDIENCE)
  77. req = google.auth.transport.requests.Request()
  78. mock_jwt_token = "jwt token"
  79. create_jwt.return_value = mock_jwt_token
  80. sts_token = "STS token"
  81. token_endpoint_request.return_value = {
  82. "access_token": sts_token,
  83. "issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
  84. "token_type": "Bearer",
  85. "expires_in": 3600,
  86. }
  87. creds.refresh(req)
  88. token_endpoint_request.assert_called_with(
  89. req,
  90. self.TOKEN_URI,
  91. {
  92. "grant_type": gdch_credentials.TOKEN_EXCHANGE_TYPE,
  93. "audience": self.AUDIENCE,
  94. "requested_token_type": gdch_credentials.ACCESS_TOKEN_TOKEN_TYPE,
  95. "subject_token": mock_jwt_token,
  96. "subject_token_type": gdch_credentials.SERVICE_ACCOUNT_TOKEN_TYPE,
  97. },
  98. access_token=None,
  99. use_json=True,
  100. verify=self.CA_CERT_PATH,
  101. )
  102. assert creds.token == sts_token
  103. def test_refresh_wrong_requests_object(self):
  104. creds = ServiceAccountCredentials.from_service_account_info(self.INFO)
  105. creds = creds.with_gdch_audience(self.AUDIENCE)
  106. req = requests.Request()
  107. with pytest.raises(exceptions.RefreshError) as excinfo:
  108. creds.refresh(req)
  109. assert excinfo.match(
  110. "request must be a google.auth.transport.requests.Request object"
  111. )
  112. def test__from_signer_and_info_wrong_format_version(self):
  113. with pytest.raises(ValueError) as excinfo:
  114. ServiceAccountCredentials._from_signer_and_info(
  115. mock.Mock(), {"format_version": "2"}
  116. )
  117. assert excinfo.match("Only format version 1 is supported")
  118. def test_from_service_account_info_miss_field(self):
  119. for field in [
  120. "format_version",
  121. "private_key_id",
  122. "private_key",
  123. "name",
  124. "project",
  125. "token_uri",
  126. ]:
  127. info_with_missing_field = copy.deepcopy(self.INFO)
  128. del info_with_missing_field[field]
  129. with pytest.raises(ValueError) as excinfo:
  130. ServiceAccountCredentials.from_service_account_info(
  131. info_with_missing_field
  132. )
  133. assert excinfo.match("missing fields")
  134. @mock.patch("google.auth._service_account_info.from_filename")
  135. def test_from_service_account_file(self, from_filename):
  136. mock_signer = mock.Mock()
  137. from_filename.return_value = (self.INFO, mock_signer)
  138. creds = ServiceAccountCredentials.from_service_account_file(self.JSON_PATH)
  139. from_filename.assert_called_with(
  140. self.JSON_PATH,
  141. require=[
  142. "format_version",
  143. "private_key_id",
  144. "private_key",
  145. "name",
  146. "project",
  147. "token_uri",
  148. ],
  149. use_rsa_signer=False,
  150. )
  151. assert creds._signer == mock_signer
  152. assert creds._service_identity_name == self.NAME
  153. assert creds._audience is None
  154. assert creds._token_uri == self.TOKEN_URI
  155. assert creds._ca_cert_path == self.CA_CERT_PATH