test_credentials.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
# Copyright 2016 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import datetime
  15. import mock
  16. import pytest # type: ignore
  17. from google.auth import _helpers
  18. from google.auth import credentials
  19. class CredentialsImpl(credentials.Credentials):
  20. def refresh(self, request):
  21. self.token = request
  22. self.expiry = (
  23. datetime.datetime.utcnow()
  24. + _helpers.REFRESH_THRESHOLD
  25. + datetime.timedelta(seconds=5)
  26. )
  27. def with_quota_project(self, quota_project_id):
  28. raise NotImplementedError()
  29. class CredentialsImplWithMetrics(credentials.Credentials):
  30. def refresh(self, request):
  31. self.token = request
  32. def _metric_header_for_usage(self):
  33. return "foo"
  34. def test_credentials_constructor():
  35. credentials = CredentialsImpl()
  36. assert not credentials.token
  37. assert not credentials.expiry
  38. assert not credentials.expired
  39. assert not credentials.valid
  40. assert credentials.universe_domain == "googleapis.com"
  41. assert not credentials._use_non_blocking_refresh
  42. def test_with_non_blocking_refresh():
  43. c = CredentialsImpl()
  44. c.with_non_blocking_refresh()
  45. assert c._use_non_blocking_refresh
  46. def test_expired_and_valid():
  47. credentials = CredentialsImpl()
  48. credentials.token = "token"
  49. assert credentials.valid
  50. assert not credentials.expired
  51. # Set the expiration to one second more than now plus the clock skew
  52. # accomodation. These credentials should be valid.
  53. credentials.expiry = (
  54. _helpers.utcnow() + _helpers.REFRESH_THRESHOLD + datetime.timedelta(seconds=1)
  55. )
  56. assert credentials.valid
  57. assert not credentials.expired
  58. # Set the credentials expiration to now. Because of the clock skew
  59. # accomodation, these credentials should report as expired.
  60. credentials.expiry = _helpers.utcnow()
  61. assert not credentials.valid
  62. assert credentials.expired
  63. def test_before_request():
  64. credentials = CredentialsImpl()
  65. request = "token"
  66. headers = {}
  67. # First call should call refresh, setting the token.
  68. credentials.before_request(request, "http://example.com", "GET", headers)
  69. assert credentials.valid
  70. assert credentials.token == "token"
  71. assert headers["authorization"] == "Bearer token"
  72. assert "x-allowed-locations" not in headers
  73. request = "token2"
  74. headers = {}
  75. # Second call shouldn't call refresh.
  76. credentials.before_request(request, "http://example.com", "GET", headers)
  77. assert credentials.valid
  78. assert credentials.token == "token"
  79. assert headers["authorization"] == "Bearer token"
  80. assert "x-allowed-locations" not in headers
  81. def test_before_request_with_trust_boundary():
  82. DUMMY_BOUNDARY = "0xA30"
  83. credentials = CredentialsImpl()
  84. credentials._trust_boundary = {"locations": [], "encoded_locations": DUMMY_BOUNDARY}
  85. request = "token"
  86. headers = {}
  87. # First call should call refresh, setting the token.
  88. credentials.before_request(request, "http://example.com", "GET", headers)
  89. assert credentials.valid
  90. assert credentials.token == "token"
  91. assert headers["authorization"] == "Bearer token"
  92. assert headers["x-allowed-locations"] == DUMMY_BOUNDARY
  93. request = "token2"
  94. headers = {}
  95. # Second call shouldn't call refresh.
  96. credentials.before_request(request, "http://example.com", "GET", headers)
  97. assert credentials.valid
  98. assert credentials.token == "token"
  99. assert headers["authorization"] == "Bearer token"
  100. assert headers["x-allowed-locations"] == DUMMY_BOUNDARY
  101. def test_before_request_metrics():
  102. credentials = CredentialsImplWithMetrics()
  103. request = "token"
  104. headers = {}
  105. credentials.before_request(request, "http://example.com", "GET", headers)
  106. assert headers["x-goog-api-client"] == "foo"
  107. def test_anonymous_credentials_ctor():
  108. anon = credentials.AnonymousCredentials()
  109. assert anon.token is None
  110. assert anon.expiry is None
  111. assert not anon.expired
  112. assert anon.valid
  113. def test_anonymous_credentials_refresh():
  114. anon = credentials.AnonymousCredentials()
  115. request = object()
  116. with pytest.raises(ValueError):
  117. anon.refresh(request)
  118. def test_anonymous_credentials_apply_default():
  119. anon = credentials.AnonymousCredentials()
  120. headers = {}
  121. anon.apply(headers)
  122. assert headers == {}
  123. with pytest.raises(ValueError):
  124. anon.apply(headers, token="TOKEN")
  125. def test_anonymous_credentials_before_request():
  126. anon = credentials.AnonymousCredentials()
  127. request = object()
  128. method = "GET"
  129. url = "https://example.com/api/endpoint"
  130. headers = {}
  131. anon.before_request(request, method, url, headers)
  132. assert headers == {}
  133. class ReadOnlyScopedCredentialsImpl(credentials.ReadOnlyScoped, CredentialsImpl):
  134. @property
  135. def requires_scopes(self):
  136. return super(ReadOnlyScopedCredentialsImpl, self).requires_scopes
  137. def test_readonly_scoped_credentials_constructor():
  138. credentials = ReadOnlyScopedCredentialsImpl()
  139. assert credentials._scopes is None
  140. def test_readonly_scoped_credentials_scopes():
  141. credentials = ReadOnlyScopedCredentialsImpl()
  142. credentials._scopes = ["one", "two"]
  143. assert credentials.scopes == ["one", "two"]
  144. assert credentials.has_scopes(["one"])
  145. assert credentials.has_scopes(["two"])
  146. assert credentials.has_scopes(["one", "two"])
  147. assert not credentials.has_scopes(["three"])
  148. def test_readonly_scoped_credentials_requires_scopes():
  149. credentials = ReadOnlyScopedCredentialsImpl()
  150. assert not credentials.requires_scopes
  151. class RequiresScopedCredentialsImpl(credentials.Scoped, CredentialsImpl):
  152. def __init__(self, scopes=None, default_scopes=None):
  153. super(RequiresScopedCredentialsImpl, self).__init__()
  154. self._scopes = scopes
  155. self._default_scopes = default_scopes
  156. @property
  157. def requires_scopes(self):
  158. return not self.scopes
  159. def with_scopes(self, scopes, default_scopes=None):
  160. return RequiresScopedCredentialsImpl(
  161. scopes=scopes, default_scopes=default_scopes
  162. )
  163. def test_create_scoped_if_required_scoped():
  164. unscoped_credentials = RequiresScopedCredentialsImpl()
  165. scoped_credentials = credentials.with_scopes_if_required(
  166. unscoped_credentials, ["one", "two"]
  167. )
  168. assert scoped_credentials is not unscoped_credentials
  169. assert not scoped_credentials.requires_scopes
  170. assert scoped_credentials.has_scopes(["one", "two"])
  171. def test_create_scoped_if_required_not_scopes():
  172. unscoped_credentials = CredentialsImpl()
  173. scoped_credentials = credentials.with_scopes_if_required(
  174. unscoped_credentials, ["one", "two"]
  175. )
  176. assert scoped_credentials is unscoped_credentials
  177. def test_nonblocking_refresh_fresh_credentials():
  178. c = CredentialsImpl()
  179. c._refresh_worker = mock.MagicMock()
  180. request = "token"
  181. c.refresh(request)
  182. assert c.token_state == credentials.TokenState.FRESH
  183. c.with_non_blocking_refresh()
  184. c.before_request(request, "http://example.com", "GET", {})
  185. def test_nonblocking_refresh_invalid_credentials():
  186. c = CredentialsImpl()
  187. c.with_non_blocking_refresh()
  188. request = "token"
  189. headers = {}
  190. assert c.token_state == credentials.TokenState.INVALID
  191. c.before_request(request, "http://example.com", "GET", headers)
  192. assert c.token_state == credentials.TokenState.FRESH
  193. assert c.valid
  194. assert c.token == "token"
  195. assert headers["authorization"] == "Bearer token"
  196. assert "x-identity-trust-boundary" not in headers
  197. def test_nonblocking_refresh_stale_credentials():
  198. c = CredentialsImpl()
  199. c.with_non_blocking_refresh()
  200. request = "token"
  201. headers = {}
  202. # Invalid credentials MUST require a blocking refresh.
  203. c.before_request(request, "http://example.com", "GET", headers)
  204. assert c.token_state == credentials.TokenState.FRESH
  205. assert not c._refresh_worker._worker
  206. c.expiry = (
  207. datetime.datetime.utcnow()
  208. + _helpers.REFRESH_THRESHOLD
  209. - datetime.timedelta(seconds=1)
  210. )
  211. # STALE credentials SHOULD spawn a non-blocking worker
  212. assert c.token_state == credentials.TokenState.STALE
  213. c.before_request(request, "http://example.com", "GET", headers)
  214. assert c._refresh_worker._worker is not None
  215. assert c.token_state == credentials.TokenState.FRESH
  216. assert c.valid
  217. assert c.token == "token"
  218. assert headers["authorization"] == "Bearer token"
  219. assert "x-identity-trust-boundary" not in headers
  220. def test_nonblocking_refresh_failed_credentials():
  221. c = CredentialsImpl()
  222. c.with_non_blocking_refresh()
  223. request = "token"
  224. headers = {}
  225. # Invalid credentials MUST require a blocking refresh.
  226. c.before_request(request, "http://example.com", "GET", headers)
  227. assert c.token_state == credentials.TokenState.FRESH
  228. assert not c._refresh_worker._worker
  229. c.expiry = (
  230. datetime.datetime.utcnow()
  231. + _helpers.REFRESH_THRESHOLD
  232. - datetime.timedelta(seconds=1)
  233. )
  234. # STALE credentials SHOULD spawn a non-blocking worker
  235. assert c.token_state == credentials.TokenState.STALE
  236. c._refresh_worker._worker = mock.MagicMock()
  237. c._refresh_worker._worker._error_info = "Some Error"
  238. c.before_request(request, "http://example.com", "GET", headers)
  239. assert c._refresh_worker._worker is not None
  240. assert c.token_state == credentials.TokenState.FRESH
  241. assert c.valid
  242. assert c.token == "token"
  243. assert headers["authorization"] == "Bearer token"
  244. assert "x-identity-trust-boundary" not in headers
  245. def test_token_state_no_expiry():
  246. c = CredentialsImpl()
  247. request = "token"
  248. c.refresh(request)
  249. c.expiry = None
  250. assert c.token_state == credentials.TokenState.FRESH
  251. c.before_request(request, "http://example.com", "GET", {})