test_service_account.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433
  1. # Copyright 2016 Google LLC
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import datetime
  15. import json
  16. import os
  17. import mock
  18. from google.auth import _helpers
  19. from google.auth import crypt
  20. from google.auth import jwt
  21. from google.auth import transport
  22. from google.oauth2 import service_account
  23. import yatest.common
  24. DATA_DIR = os.path.join(yatest.common.test_source_path(), "data")
  25. with open(os.path.join(DATA_DIR, "privatekey.pem"), "rb") as fh:
  26. PRIVATE_KEY_BYTES = fh.read()
  27. with open(os.path.join(DATA_DIR, "public_cert.pem"), "rb") as fh:
  28. PUBLIC_CERT_BYTES = fh.read()
  29. with open(os.path.join(DATA_DIR, "other_cert.pem"), "rb") as fh:
  30. OTHER_CERT_BYTES = fh.read()
  31. SERVICE_ACCOUNT_JSON_FILE = os.path.join(DATA_DIR, "service_account.json")
  32. with open(SERVICE_ACCOUNT_JSON_FILE, "r") as fh:
  33. SERVICE_ACCOUNT_INFO = json.load(fh)
  34. SIGNER = crypt.RSASigner.from_string(PRIVATE_KEY_BYTES, "1")
  35. class TestCredentials(object):
  36. SERVICE_ACCOUNT_EMAIL = "service-account@example.com"
  37. TOKEN_URI = "https://example.com/oauth2/token"
  38. @classmethod
  39. def make_credentials(cls):
  40. return service_account.Credentials(
  41. SIGNER, cls.SERVICE_ACCOUNT_EMAIL, cls.TOKEN_URI
  42. )
  43. def test_from_service_account_info(self):
  44. credentials = service_account.Credentials.from_service_account_info(
  45. SERVICE_ACCOUNT_INFO
  46. )
  47. assert credentials._signer.key_id == SERVICE_ACCOUNT_INFO["private_key_id"]
  48. assert credentials.service_account_email == SERVICE_ACCOUNT_INFO["client_email"]
  49. assert credentials._token_uri == SERVICE_ACCOUNT_INFO["token_uri"]
  50. def test_from_service_account_info_args(self):
  51. info = SERVICE_ACCOUNT_INFO.copy()
  52. scopes = ["email", "profile"]
  53. subject = "subject"
  54. additional_claims = {"meta": "data"}
  55. credentials = service_account.Credentials.from_service_account_info(
  56. info, scopes=scopes, subject=subject, additional_claims=additional_claims
  57. )
  58. assert credentials.service_account_email == info["client_email"]
  59. assert credentials.project_id == info["project_id"]
  60. assert credentials._signer.key_id == info["private_key_id"]
  61. assert credentials._token_uri == info["token_uri"]
  62. assert credentials._scopes == scopes
  63. assert credentials._subject == subject
  64. assert credentials._additional_claims == additional_claims
  65. def test_from_service_account_file(self):
  66. info = SERVICE_ACCOUNT_INFO.copy()
  67. credentials = service_account.Credentials.from_service_account_file(
  68. SERVICE_ACCOUNT_JSON_FILE
  69. )
  70. assert credentials.service_account_email == info["client_email"]
  71. assert credentials.project_id == info["project_id"]
  72. assert credentials._signer.key_id == info["private_key_id"]
  73. assert credentials._token_uri == info["token_uri"]
  74. def test_from_service_account_file_args(self):
  75. info = SERVICE_ACCOUNT_INFO.copy()
  76. scopes = ["email", "profile"]
  77. subject = "subject"
  78. additional_claims = {"meta": "data"}
  79. credentials = service_account.Credentials.from_service_account_file(
  80. SERVICE_ACCOUNT_JSON_FILE,
  81. subject=subject,
  82. scopes=scopes,
  83. additional_claims=additional_claims,
  84. )
  85. assert credentials.service_account_email == info["client_email"]
  86. assert credentials.project_id == info["project_id"]
  87. assert credentials._signer.key_id == info["private_key_id"]
  88. assert credentials._token_uri == info["token_uri"]
  89. assert credentials._scopes == scopes
  90. assert credentials._subject == subject
  91. assert credentials._additional_claims == additional_claims
  92. def test_default_state(self):
  93. credentials = self.make_credentials()
  94. assert not credentials.valid
  95. # Expiration hasn't been set yet
  96. assert not credentials.expired
  97. # Scopes haven't been specified yet
  98. assert credentials.requires_scopes
  99. def test_sign_bytes(self):
  100. credentials = self.make_credentials()
  101. to_sign = b"123"
  102. signature = credentials.sign_bytes(to_sign)
  103. assert crypt.verify_signature(to_sign, signature, PUBLIC_CERT_BYTES)
  104. def test_signer(self):
  105. credentials = self.make_credentials()
  106. assert isinstance(credentials.signer, crypt.Signer)
  107. def test_signer_email(self):
  108. credentials = self.make_credentials()
  109. assert credentials.signer_email == self.SERVICE_ACCOUNT_EMAIL
  110. def test_create_scoped(self):
  111. credentials = self.make_credentials()
  112. scopes = ["email", "profile"]
  113. credentials = credentials.with_scopes(scopes)
  114. assert credentials._scopes == scopes
  115. def test_with_claims(self):
  116. credentials = self.make_credentials()
  117. new_credentials = credentials.with_claims({"meep": "moop"})
  118. assert new_credentials._additional_claims == {"meep": "moop"}
  119. def test_with_quota_project(self):
  120. credentials = self.make_credentials()
  121. new_credentials = credentials.with_quota_project("new-project-456")
  122. assert new_credentials.quota_project_id == "new-project-456"
  123. hdrs = {}
  124. new_credentials.apply(hdrs, token="tok")
  125. assert "x-goog-user-project" in hdrs
  126. def test__make_authorization_grant_assertion(self):
  127. credentials = self.make_credentials()
  128. token = credentials._make_authorization_grant_assertion()
  129. payload = jwt.decode(token, PUBLIC_CERT_BYTES)
  130. assert payload["iss"] == self.SERVICE_ACCOUNT_EMAIL
  131. assert payload["aud"] == service_account._GOOGLE_OAUTH2_TOKEN_ENDPOINT
  132. def test__make_authorization_grant_assertion_scoped(self):
  133. credentials = self.make_credentials()
  134. scopes = ["email", "profile"]
  135. credentials = credentials.with_scopes(scopes)
  136. token = credentials._make_authorization_grant_assertion()
  137. payload = jwt.decode(token, PUBLIC_CERT_BYTES)
  138. assert payload["scope"] == "email profile"
  139. def test__make_authorization_grant_assertion_subject(self):
  140. credentials = self.make_credentials()
  141. subject = "user@example.com"
  142. credentials = credentials.with_subject(subject)
  143. token = credentials._make_authorization_grant_assertion()
  144. payload = jwt.decode(token, PUBLIC_CERT_BYTES)
  145. assert payload["sub"] == subject
  146. def test_apply_with_quota_project_id(self):
  147. credentials = service_account.Credentials(
  148. SIGNER,
  149. self.SERVICE_ACCOUNT_EMAIL,
  150. self.TOKEN_URI,
  151. quota_project_id="quota-project-123",
  152. )
  153. headers = {}
  154. credentials.apply(headers, token="token")
  155. assert headers["x-goog-user-project"] == "quota-project-123"
  156. assert "token" in headers["authorization"]
  157. def test_apply_with_no_quota_project_id(self):
  158. credentials = service_account.Credentials(
  159. SIGNER, self.SERVICE_ACCOUNT_EMAIL, self.TOKEN_URI
  160. )
  161. headers = {}
  162. credentials.apply(headers, token="token")
  163. assert "x-goog-user-project" not in headers
  164. assert "token" in headers["authorization"]
  165. @mock.patch("google.auth.jwt.Credentials", instance=True, autospec=True)
  166. def test__create_self_signed_jwt(self, jwt):
  167. credentials = service_account.Credentials(
  168. SIGNER, self.SERVICE_ACCOUNT_EMAIL, self.TOKEN_URI
  169. )
  170. audience = "https://pubsub.googleapis.com"
  171. credentials._create_self_signed_jwt(audience)
  172. jwt.from_signing_credentials.assert_called_once_with(credentials, audience)
  173. @mock.patch("google.auth.jwt.Credentials", instance=True, autospec=True)
  174. def test__create_self_signed_jwt_with_user_scopes(self, jwt):
  175. credentials = service_account.Credentials(
  176. SIGNER, self.SERVICE_ACCOUNT_EMAIL, self.TOKEN_URI, scopes=["foo"]
  177. )
  178. audience = "https://pubsub.googleapis.com"
  179. credentials._create_self_signed_jwt(audience)
  180. # JWT should not be created if there are user-defined scopes
  181. jwt.from_signing_credentials.assert_not_called()
  182. @mock.patch("google.oauth2._client.jwt_grant", autospec=True)
  183. def test_refresh_success(self, jwt_grant):
  184. credentials = self.make_credentials()
  185. token = "token"
  186. jwt_grant.return_value = (
  187. token,
  188. _helpers.utcnow() + datetime.timedelta(seconds=500),
  189. {},
  190. )
  191. request = mock.create_autospec(transport.Request, instance=True)
  192. # Refresh credentials
  193. credentials.refresh(request)
  194. # Check jwt grant call.
  195. assert jwt_grant.called
  196. called_request, token_uri, assertion = jwt_grant.call_args[0]
  197. assert called_request == request
  198. assert token_uri == credentials._token_uri
  199. assert jwt.decode(assertion, PUBLIC_CERT_BYTES)
  200. # No further assertion done on the token, as there are separate tests
  201. # for checking the authorization grant assertion.
  202. # Check that the credentials have the token.
  203. assert credentials.token == token
  204. # Check that the credentials are valid (have a token and are not
  205. # expired)
  206. assert credentials.valid
  207. @mock.patch("google.oauth2._client.jwt_grant", autospec=True)
  208. def test_before_request_refreshes(self, jwt_grant):
  209. credentials = self.make_credentials()
  210. token = "token"
  211. jwt_grant.return_value = (
  212. token,
  213. _helpers.utcnow() + datetime.timedelta(seconds=500),
  214. None,
  215. )
  216. request = mock.create_autospec(transport.Request, instance=True)
  217. # Credentials should start as invalid
  218. assert not credentials.valid
  219. # before_request should cause a refresh
  220. credentials.before_request(request, "GET", "http://example.com?a=1#3", {})
  221. # The refresh endpoint should've been called.
  222. assert jwt_grant.called
  223. # Credentials should now be valid.
  224. assert credentials.valid
  225. @mock.patch("google.auth.jwt.Credentials._make_jwt")
  226. def test_refresh_with_jwt_credentials(self, make_jwt):
  227. credentials = self.make_credentials()
  228. credentials._create_self_signed_jwt("https://pubsub.googleapis.com")
  229. request = mock.create_autospec(transport.Request, instance=True)
  230. token = "token"
  231. expiry = _helpers.utcnow() + datetime.timedelta(seconds=500)
  232. make_jwt.return_value = (token, expiry)
  233. # Credentials should start as invalid
  234. assert not credentials.valid
  235. # before_request should cause a refresh
  236. credentials.before_request(request, "GET", "http://example.com?a=1#3", {})
  237. # Credentials should now be valid.
  238. assert credentials.valid
  239. # Assert make_jwt was called
  240. assert make_jwt.called_once()
  241. assert credentials.token == token
  242. assert credentials.expiry == expiry
  243. class TestIDTokenCredentials(object):
  244. SERVICE_ACCOUNT_EMAIL = "service-account@example.com"
  245. TOKEN_URI = "https://example.com/oauth2/token"
  246. TARGET_AUDIENCE = "https://example.com"
  247. @classmethod
  248. def make_credentials(cls):
  249. return service_account.IDTokenCredentials(
  250. SIGNER, cls.SERVICE_ACCOUNT_EMAIL, cls.TOKEN_URI, cls.TARGET_AUDIENCE
  251. )
  252. def test_from_service_account_info(self):
  253. credentials = service_account.IDTokenCredentials.from_service_account_info(
  254. SERVICE_ACCOUNT_INFO, target_audience=self.TARGET_AUDIENCE
  255. )
  256. assert credentials._signer.key_id == SERVICE_ACCOUNT_INFO["private_key_id"]
  257. assert credentials.service_account_email == SERVICE_ACCOUNT_INFO["client_email"]
  258. assert credentials._token_uri == SERVICE_ACCOUNT_INFO["token_uri"]
  259. assert credentials._target_audience == self.TARGET_AUDIENCE
  260. def test_from_service_account_file(self):
  261. info = SERVICE_ACCOUNT_INFO.copy()
  262. credentials = service_account.IDTokenCredentials.from_service_account_file(
  263. SERVICE_ACCOUNT_JSON_FILE, target_audience=self.TARGET_AUDIENCE
  264. )
  265. assert credentials.service_account_email == info["client_email"]
  266. assert credentials._signer.key_id == info["private_key_id"]
  267. assert credentials._token_uri == info["token_uri"]
  268. assert credentials._target_audience == self.TARGET_AUDIENCE
  269. def test_default_state(self):
  270. credentials = self.make_credentials()
  271. assert not credentials.valid
  272. # Expiration hasn't been set yet
  273. assert not credentials.expired
  274. def test_sign_bytes(self):
  275. credentials = self.make_credentials()
  276. to_sign = b"123"
  277. signature = credentials.sign_bytes(to_sign)
  278. assert crypt.verify_signature(to_sign, signature, PUBLIC_CERT_BYTES)
  279. def test_signer(self):
  280. credentials = self.make_credentials()
  281. assert isinstance(credentials.signer, crypt.Signer)
  282. def test_signer_email(self):
  283. credentials = self.make_credentials()
  284. assert credentials.signer_email == self.SERVICE_ACCOUNT_EMAIL
  285. def test_with_target_audience(self):
  286. credentials = self.make_credentials()
  287. new_credentials = credentials.with_target_audience("https://new.example.com")
  288. assert new_credentials._target_audience == "https://new.example.com"
  289. def test_with_quota_project(self):
  290. credentials = self.make_credentials()
  291. new_credentials = credentials.with_quota_project("project-foo")
  292. assert new_credentials._quota_project_id == "project-foo"
  293. def test__make_authorization_grant_assertion(self):
  294. credentials = self.make_credentials()
  295. token = credentials._make_authorization_grant_assertion()
  296. payload = jwt.decode(token, PUBLIC_CERT_BYTES)
  297. assert payload["iss"] == self.SERVICE_ACCOUNT_EMAIL
  298. assert payload["aud"] == service_account._GOOGLE_OAUTH2_TOKEN_ENDPOINT
  299. assert payload["target_audience"] == self.TARGET_AUDIENCE
  300. @mock.patch("google.oauth2._client.id_token_jwt_grant", autospec=True)
  301. def test_refresh_success(self, id_token_jwt_grant):
  302. credentials = self.make_credentials()
  303. token = "token"
  304. id_token_jwt_grant.return_value = (
  305. token,
  306. _helpers.utcnow() + datetime.timedelta(seconds=500),
  307. {},
  308. )
  309. request = mock.create_autospec(transport.Request, instance=True)
  310. # Refresh credentials
  311. credentials.refresh(request)
  312. # Check jwt grant call.
  313. assert id_token_jwt_grant.called
  314. called_request, token_uri, assertion = id_token_jwt_grant.call_args[0]
  315. assert called_request == request
  316. assert token_uri == credentials._token_uri
  317. assert jwt.decode(assertion, PUBLIC_CERT_BYTES)
  318. # No further assertion done on the token, as there are separate tests
  319. # for checking the authorization grant assertion.
  320. # Check that the credentials have the token.
  321. assert credentials.token == token
  322. # Check that the credentials are valid (have a token and are not
  323. # expired)
  324. assert credentials.valid
  325. @mock.patch("google.oauth2._client.id_token_jwt_grant", autospec=True)
  326. def test_before_request_refreshes(self, id_token_jwt_grant):
  327. credentials = self.make_credentials()
  328. token = "token"
  329. id_token_jwt_grant.return_value = (
  330. token,
  331. _helpers.utcnow() + datetime.timedelta(seconds=500),
  332. None,
  333. )
  334. request = mock.create_autospec(transport.Request, instance=True)
  335. # Credentials should start as invalid
  336. assert not credentials.valid
  337. # before_request should cause a refresh
  338. credentials.before_request(request, "GET", "http://example.com?a=1#3", {})
  339. # The refresh endpoint should've been called.
  340. assert id_token_jwt_grant.called
  341. # Credentials should now be valid.
  342. assert credentials.valid