test_credentials.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345
  1. # Copyright 2016 Google LLC
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import datetime
  15. import mock
  16. import pytest # type: ignore
  17. from google.auth import _helpers
  18. from google.auth import credentials
  19. class CredentialsImpl(credentials.Credentials):
  20. def refresh(self, request):
  21. self.token = request
  22. self.expiry = (
  23. datetime.datetime.utcnow()
  24. + _helpers.REFRESH_THRESHOLD
  25. + datetime.timedelta(seconds=5)
  26. )
  27. def with_quota_project(self, quota_project_id):
  28. raise NotImplementedError()
  29. class CredentialsImplWithMetrics(credentials.Credentials):
  30. def refresh(self, request):
  31. self.token = request
  32. def _metric_header_for_usage(self):
  33. return "foo"
  34. def test_credentials_constructor():
  35. credentials = CredentialsImpl()
  36. assert not credentials.token
  37. assert not credentials.expiry
  38. assert not credentials.expired
  39. assert not credentials.valid
  40. assert credentials.universe_domain == "googleapis.com"
  41. assert not credentials._use_non_blocking_refresh
  42. def test_credentials_get_cred_info():
  43. credentials = CredentialsImpl()
  44. assert not credentials.get_cred_info()
  45. def test_with_non_blocking_refresh():
  46. c = CredentialsImpl()
  47. c.with_non_blocking_refresh()
  48. assert c._use_non_blocking_refresh
  49. def test_expired_and_valid():
  50. credentials = CredentialsImpl()
  51. credentials.token = "token"
  52. assert credentials.valid
  53. assert not credentials.expired
  54. # Set the expiration to one second more than now plus the clock skew
  55. # accomodation. These credentials should be valid.
  56. credentials.expiry = (
  57. _helpers.utcnow() + _helpers.REFRESH_THRESHOLD + datetime.timedelta(seconds=1)
  58. )
  59. assert credentials.valid
  60. assert not credentials.expired
  61. # Set the credentials expiration to now. Because of the clock skew
  62. # accomodation, these credentials should report as expired.
  63. credentials.expiry = _helpers.utcnow()
  64. assert not credentials.valid
  65. assert credentials.expired
  66. def test_before_request():
  67. credentials = CredentialsImpl()
  68. request = "token"
  69. headers = {}
  70. # First call should call refresh, setting the token.
  71. credentials.before_request(request, "http://example.com", "GET", headers)
  72. assert credentials.valid
  73. assert credentials.token == "token"
  74. assert headers["authorization"] == "Bearer token"
  75. assert "x-allowed-locations" not in headers
  76. request = "token2"
  77. headers = {}
  78. # Second call shouldn't call refresh.
  79. credentials.before_request(request, "http://example.com", "GET", headers)
  80. assert credentials.valid
  81. assert credentials.token == "token"
  82. assert headers["authorization"] == "Bearer token"
  83. assert "x-allowed-locations" not in headers
  84. def test_before_request_with_trust_boundary():
  85. DUMMY_BOUNDARY = "0xA30"
  86. credentials = CredentialsImpl()
  87. credentials._trust_boundary = {"locations": [], "encoded_locations": DUMMY_BOUNDARY}
  88. request = "token"
  89. headers = {}
  90. # First call should call refresh, setting the token.
  91. credentials.before_request(request, "http://example.com", "GET", headers)
  92. assert credentials.valid
  93. assert credentials.token == "token"
  94. assert headers["authorization"] == "Bearer token"
  95. assert headers["x-allowed-locations"] == DUMMY_BOUNDARY
  96. request = "token2"
  97. headers = {}
  98. # Second call shouldn't call refresh.
  99. credentials.before_request(request, "http://example.com", "GET", headers)
  100. assert credentials.valid
  101. assert credentials.token == "token"
  102. assert headers["authorization"] == "Bearer token"
  103. assert headers["x-allowed-locations"] == DUMMY_BOUNDARY
  104. def test_before_request_metrics():
  105. credentials = CredentialsImplWithMetrics()
  106. request = "token"
  107. headers = {}
  108. credentials.before_request(request, "http://example.com", "GET", headers)
  109. assert headers["x-goog-api-client"] == "foo"
  110. def test_anonymous_credentials_ctor():
  111. anon = credentials.AnonymousCredentials()
  112. assert anon.token is None
  113. assert anon.expiry is None
  114. assert not anon.expired
  115. assert anon.valid
  116. def test_anonymous_credentials_refresh():
  117. anon = credentials.AnonymousCredentials()
  118. request = object()
  119. with pytest.raises(ValueError):
  120. anon.refresh(request)
  121. def test_anonymous_credentials_apply_default():
  122. anon = credentials.AnonymousCredentials()
  123. headers = {}
  124. anon.apply(headers)
  125. assert headers == {}
  126. with pytest.raises(ValueError):
  127. anon.apply(headers, token="TOKEN")
  128. def test_anonymous_credentials_before_request():
  129. anon = credentials.AnonymousCredentials()
  130. request = object()
  131. method = "GET"
  132. url = "https://example.com/api/endpoint"
  133. headers = {}
  134. anon.before_request(request, method, url, headers)
  135. assert headers == {}
  136. class ReadOnlyScopedCredentialsImpl(credentials.ReadOnlyScoped, CredentialsImpl):
  137. @property
  138. def requires_scopes(self):
  139. return super(ReadOnlyScopedCredentialsImpl, self).requires_scopes
  140. def test_readonly_scoped_credentials_constructor():
  141. credentials = ReadOnlyScopedCredentialsImpl()
  142. assert credentials._scopes is None
  143. def test_readonly_scoped_credentials_scopes():
  144. credentials = ReadOnlyScopedCredentialsImpl()
  145. credentials._scopes = ["one", "two"]
  146. assert credentials.scopes == ["one", "two"]
  147. assert credentials.has_scopes(["one"])
  148. assert credentials.has_scopes(["two"])
  149. assert credentials.has_scopes(["one", "two"])
  150. assert not credentials.has_scopes(["three"])
  151. def test_readonly_scoped_credentials_requires_scopes():
  152. credentials = ReadOnlyScopedCredentialsImpl()
  153. assert not credentials.requires_scopes
  154. class RequiresScopedCredentialsImpl(credentials.Scoped, CredentialsImpl):
  155. def __init__(self, scopes=None, default_scopes=None):
  156. super(RequiresScopedCredentialsImpl, self).__init__()
  157. self._scopes = scopes
  158. self._default_scopes = default_scopes
  159. @property
  160. def requires_scopes(self):
  161. return not self.scopes
  162. def with_scopes(self, scopes, default_scopes=None):
  163. return RequiresScopedCredentialsImpl(
  164. scopes=scopes, default_scopes=default_scopes
  165. )
  166. def test_create_scoped_if_required_scoped():
  167. unscoped_credentials = RequiresScopedCredentialsImpl()
  168. scoped_credentials = credentials.with_scopes_if_required(
  169. unscoped_credentials, ["one", "two"]
  170. )
  171. assert scoped_credentials is not unscoped_credentials
  172. assert not scoped_credentials.requires_scopes
  173. assert scoped_credentials.has_scopes(["one", "two"])
  174. def test_create_scoped_if_required_not_scopes():
  175. unscoped_credentials = CredentialsImpl()
  176. scoped_credentials = credentials.with_scopes_if_required(
  177. unscoped_credentials, ["one", "two"]
  178. )
  179. assert scoped_credentials is unscoped_credentials
  180. def test_nonblocking_refresh_fresh_credentials():
  181. c = CredentialsImpl()
  182. c._refresh_worker = mock.MagicMock()
  183. request = "token"
  184. c.refresh(request)
  185. assert c.token_state == credentials.TokenState.FRESH
  186. c.with_non_blocking_refresh()
  187. c.before_request(request, "http://example.com", "GET", {})
  188. def test_nonblocking_refresh_invalid_credentials():
  189. c = CredentialsImpl()
  190. c.with_non_blocking_refresh()
  191. request = "token"
  192. headers = {}
  193. assert c.token_state == credentials.TokenState.INVALID
  194. c.before_request(request, "http://example.com", "GET", headers)
  195. assert c.token_state == credentials.TokenState.FRESH
  196. assert c.valid
  197. assert c.token == "token"
  198. assert headers["authorization"] == "Bearer token"
  199. assert "x-identity-trust-boundary" not in headers
  200. def test_nonblocking_refresh_stale_credentials():
  201. c = CredentialsImpl()
  202. c.with_non_blocking_refresh()
  203. request = "token"
  204. headers = {}
  205. # Invalid credentials MUST require a blocking refresh.
  206. c.before_request(request, "http://example.com", "GET", headers)
  207. assert c.token_state == credentials.TokenState.FRESH
  208. assert not c._refresh_worker._worker
  209. c.expiry = (
  210. datetime.datetime.utcnow()
  211. + _helpers.REFRESH_THRESHOLD
  212. - datetime.timedelta(seconds=1)
  213. )
  214. # STALE credentials SHOULD spawn a non-blocking worker
  215. assert c.token_state == credentials.TokenState.STALE
  216. c.before_request(request, "http://example.com", "GET", headers)
  217. assert c._refresh_worker._worker is not None
  218. assert c.token_state == credentials.TokenState.FRESH
  219. assert c.valid
  220. assert c.token == "token"
  221. assert headers["authorization"] == "Bearer token"
  222. assert "x-identity-trust-boundary" not in headers
  223. def test_nonblocking_refresh_failed_credentials():
  224. c = CredentialsImpl()
  225. c.with_non_blocking_refresh()
  226. request = "token"
  227. headers = {}
  228. # Invalid credentials MUST require a blocking refresh.
  229. c.before_request(request, "http://example.com", "GET", headers)
  230. assert c.token_state == credentials.TokenState.FRESH
  231. assert not c._refresh_worker._worker
  232. c.expiry = (
  233. datetime.datetime.utcnow()
  234. + _helpers.REFRESH_THRESHOLD
  235. - datetime.timedelta(seconds=1)
  236. )
  237. # STALE credentials SHOULD spawn a non-blocking worker
  238. assert c.token_state == credentials.TokenState.STALE
  239. c._refresh_worker._worker = mock.MagicMock()
  240. c._refresh_worker._worker._error_info = "Some Error"
  241. c.before_request(request, "http://example.com", "GET", headers)
  242. assert c._refresh_worker._worker is not None
  243. assert c.token_state == credentials.TokenState.FRESH
  244. assert c.valid
  245. assert c.token == "token"
  246. assert headers["authorization"] == "Bearer token"
  247. assert "x-identity-trust-boundary" not in headers
  248. def test_token_state_no_expiry():
  249. c = CredentialsImpl()
  250. request = "token"
  251. c.refresh(request)
  252. c.expiry = None
  253. assert c.token_state == credentials.TokenState.FRESH
  254. c.before_request(request, "http://example.com", "GET", {})