test_sessions.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311
  1. # Copyright 2024 Google LLC
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import asyncio
  15. from typing import AsyncGenerator
  16. from aioresponses import aioresponses # type: ignore
  17. from mock import Mock, patch
  18. import pytest # type: ignore
  19. from google.auth.aio.credentials import AnonymousCredentials
  20. from google.auth.aio.transport import (
  21. _DEFAULT_TIMEOUT_SECONDS,
  22. DEFAULT_MAX_RETRY_ATTEMPTS,
  23. DEFAULT_RETRYABLE_STATUS_CODES,
  24. Request,
  25. Response,
  26. sessions,
  27. )
  28. from google.auth.exceptions import InvalidType, TimeoutError, TransportError
  29. @pytest.fixture
  30. async def simple_async_task():
  31. return True
  32. class MockRequest(Request):
  33. def __init__(self, response=None, side_effect=None):
  34. self._closed = False
  35. self._response = response
  36. self._side_effect = side_effect
  37. self.call_count = 0
  38. async def __call__(
  39. self,
  40. url,
  41. method="GET",
  42. body=None,
  43. headers=None,
  44. timeout=_DEFAULT_TIMEOUT_SECONDS,
  45. **kwargs,
  46. ):
  47. self.call_count += 1
  48. if self._side_effect:
  49. raise self._side_effect
  50. return self._response
  51. async def close(self):
  52. self._closed = True
  53. return None
  54. class MockResponse(Response):
  55. def __init__(self, status_code, headers=None, content=None):
  56. self._status_code = status_code
  57. self._headers = headers
  58. self._content = content
  59. self._close = False
  60. @property
  61. def status_code(self):
  62. return self._status_code
  63. @property
  64. def headers(self):
  65. return self._headers
  66. async def read(self) -> bytes:
  67. content = await self.content(1024)
  68. return b"".join([chunk async for chunk in content])
  69. async def content(self, chunk_size=None) -> AsyncGenerator:
  70. return self._content
  71. async def close(self) -> None:
  72. self._close = True
  73. class TestTimeoutGuard(object):
  74. default_timeout = 1
  75. def make_timeout_guard(self, timeout):
  76. return sessions.timeout_guard(timeout)
  77. @pytest.mark.asyncio
  78. async def test_timeout_with_simple_async_task_within_bounds(
  79. self, simple_async_task
  80. ):
  81. task = False
  82. with patch("time.monotonic", side_effect=[0, 0.25, 0.75]):
  83. with patch("asyncio.wait_for", lambda coro, _: coro):
  84. async with self.make_timeout_guard(
  85. timeout=self.default_timeout
  86. ) as with_timeout:
  87. task = await with_timeout(simple_async_task)
  88. # Task succeeds.
  89. assert task is True
  90. @pytest.mark.asyncio
  91. async def test_timeout_with_simple_async_task_out_of_bounds(
  92. self, simple_async_task
  93. ):
  94. task = False
  95. with patch("time.monotonic", side_effect=[0, 1, 1]):
  96. with pytest.raises(TimeoutError) as exc:
  97. async with self.make_timeout_guard(
  98. timeout=self.default_timeout
  99. ) as with_timeout:
  100. task = await with_timeout(simple_async_task)
  101. # Task does not succeed and the context manager times out i.e. no remaining time left.
  102. assert task is False
  103. assert exc.match(
  104. f"Context manager exceeded the configured timeout of {self.default_timeout}s."
  105. )
  106. @pytest.mark.asyncio
  107. async def test_timeout_with_async_task_timing_out_before_context(
  108. self, simple_async_task
  109. ):
  110. task = False
  111. with pytest.raises(TimeoutError) as exc:
  112. async with self.make_timeout_guard(
  113. timeout=self.default_timeout
  114. ) as with_timeout:
  115. with patch("asyncio.wait_for", side_effect=asyncio.TimeoutError):
  116. task = await with_timeout(simple_async_task)
  117. # Task does not complete i.e. the operation times out.
  118. assert task is False
  119. assert exc.match(
  120. f"The operation {simple_async_task} exceeded the configured timeout of {self.default_timeout}s."
  121. )
  122. class TestAsyncAuthorizedSession(object):
  123. TEST_URL = "http://example.com/"
  124. credentials = AnonymousCredentials()
  125. @pytest.fixture
  126. async def mocked_content(self):
  127. content = [b"Cavefish ", b"have ", b"no ", b"sight."]
  128. for chunk in content:
  129. yield chunk
  130. @pytest.mark.asyncio
  131. async def test_constructor_with_default_auth_request(self):
  132. with patch("google.auth.aio.transport.sessions.AIOHTTP_INSTALLED", True):
  133. authed_session = sessions.AsyncAuthorizedSession(self.credentials)
  134. assert authed_session._credentials == self.credentials
  135. await authed_session.close()
  136. @pytest.mark.asyncio
  137. async def test_constructor_with_provided_auth_request(self):
  138. auth_request = MockRequest()
  139. authed_session = sessions.AsyncAuthorizedSession(
  140. self.credentials, auth_request=auth_request
  141. )
  142. assert authed_session._auth_request is auth_request
  143. await authed_session.close()
  144. @pytest.mark.asyncio
  145. async def test_constructor_raises_no_auth_request_error(self):
  146. with patch("google.auth.aio.transport.sessions.AIOHTTP_INSTALLED", False):
  147. with pytest.raises(TransportError) as exc:
  148. sessions.AsyncAuthorizedSession(self.credentials)
  149. exc.match(
  150. "`auth_request` must either be configured or the external package `aiohttp` must be installed to use the default value."
  151. )
  152. @pytest.mark.asyncio
  153. async def test_constructor_raises_incorrect_credentials_error(self):
  154. credentials = Mock()
  155. with pytest.raises(InvalidType) as exc:
  156. sessions.AsyncAuthorizedSession(credentials)
  157. exc.match(
  158. f"The configured credentials of type {type(credentials)} are invalid and must be of type `google.auth.aio.credentials.Credentials`"
  159. )
  160. @pytest.mark.asyncio
  161. async def test_request_default_auth_request_success(self):
  162. with aioresponses() as m:
  163. mocked_chunks = [b"Cavefish ", b"have ", b"no ", b"sight."]
  164. mocked_response = b"".join(mocked_chunks)
  165. m.get(self.TEST_URL, status=200, body=mocked_response)
  166. authed_session = sessions.AsyncAuthorizedSession(self.credentials)
  167. response = await authed_session.request("GET", self.TEST_URL)
  168. assert response.status_code == 200
  169. assert response.headers == {"Content-Type": "application/json"}
  170. assert await response.read() == b"Cavefish have no sight."
  171. await response.close()
  172. await authed_session.close()
  173. @pytest.mark.asyncio
  174. async def test_request_provided_auth_request_success(self, mocked_content):
  175. mocked_response = MockResponse(
  176. status_code=200,
  177. headers={"Content-Type": "application/json"},
  178. content=mocked_content,
  179. )
  180. auth_request = MockRequest(mocked_response)
  181. authed_session = sessions.AsyncAuthorizedSession(self.credentials, auth_request)
  182. response = await authed_session.request("GET", self.TEST_URL)
  183. assert response.status_code == 200
  184. assert response.headers == {"Content-Type": "application/json"}
  185. assert await response.read() == b"Cavefish have no sight."
  186. await response.close()
  187. assert response._close
  188. await authed_session.close()
  189. @pytest.mark.asyncio
  190. async def test_request_raises_timeout_error(self):
  191. auth_request = MockRequest(side_effect=asyncio.TimeoutError)
  192. authed_session = sessions.AsyncAuthorizedSession(self.credentials, auth_request)
  193. with pytest.raises(TimeoutError):
  194. await authed_session.request("GET", self.TEST_URL)
  195. @pytest.mark.asyncio
  196. async def test_request_raises_transport_error(self):
  197. auth_request = MockRequest(side_effect=TransportError)
  198. authed_session = sessions.AsyncAuthorizedSession(self.credentials, auth_request)
  199. with pytest.raises(TransportError):
  200. await authed_session.request("GET", self.TEST_URL)
  201. @pytest.mark.asyncio
  202. async def test_request_max_allowed_time_exceeded_error(self):
  203. auth_request = MockRequest(side_effect=TransportError)
  204. authed_session = sessions.AsyncAuthorizedSession(self.credentials, auth_request)
  205. with patch("time.monotonic", side_effect=[0, 1, 1]):
  206. with pytest.raises(TimeoutError):
  207. await authed_session.request("GET", self.TEST_URL, max_allowed_time=1)
  208. @pytest.mark.parametrize("retry_status", DEFAULT_RETRYABLE_STATUS_CODES)
  209. @pytest.mark.asyncio
  210. async def test_request_max_retries(self, retry_status):
  211. mocked_response = MockResponse(status_code=retry_status)
  212. auth_request = MockRequest(mocked_response)
  213. with patch("asyncio.sleep", return_value=None):
  214. authed_session = sessions.AsyncAuthorizedSession(
  215. self.credentials, auth_request
  216. )
  217. await authed_session.request("GET", self.TEST_URL)
  218. assert auth_request.call_count == DEFAULT_MAX_RETRY_ATTEMPTS
  219. @pytest.mark.asyncio
  220. async def test_http_get_method_success(self):
  221. expected_payload = b"content is retrieved."
  222. authed_session = sessions.AsyncAuthorizedSession(self.credentials)
  223. with aioresponses() as m:
  224. m.get(self.TEST_URL, status=200, body=expected_payload)
  225. response = await authed_session.get(self.TEST_URL)
  226. assert await response.read() == expected_payload
  227. response = await authed_session.close()
  228. @pytest.mark.asyncio
  229. async def test_http_post_method_success(self):
  230. expected_payload = b"content is posted."
  231. authed_session = sessions.AsyncAuthorizedSession(self.credentials)
  232. with aioresponses() as m:
  233. m.post(self.TEST_URL, status=200, body=expected_payload)
  234. response = await authed_session.post(self.TEST_URL)
  235. assert await response.read() == expected_payload
  236. response = await authed_session.close()
  237. @pytest.mark.asyncio
  238. async def test_http_put_method_success(self):
  239. expected_payload = b"content is retrieved."
  240. authed_session = sessions.AsyncAuthorizedSession(self.credentials)
  241. with aioresponses() as m:
  242. m.put(self.TEST_URL, status=200, body=expected_payload)
  243. response = await authed_session.put(self.TEST_URL)
  244. assert await response.read() == expected_payload
  245. response = await authed_session.close()
  246. @pytest.mark.asyncio
  247. async def test_http_patch_method_success(self):
  248. expected_payload = b"content is retrieved."
  249. authed_session = sessions.AsyncAuthorizedSession(self.credentials)
  250. with aioresponses() as m:
  251. m.patch(self.TEST_URL, status=200, body=expected_payload)
  252. response = await authed_session.patch(self.TEST_URL)
  253. assert await response.read() == expected_payload
  254. response = await authed_session.close()
  255. @pytest.mark.asyncio
  256. async def test_http_delete_method_success(self):
  257. expected_payload = b"content is deleted."
  258. authed_session = sessions.AsyncAuthorizedSession(self.credentials)
  259. with aioresponses() as m:
  260. m.delete(self.TEST_URL, status=200, body=expected_payload)
  261. response = await authed_session.delete(self.TEST_URL)
  262. assert await response.read() == expected_payload
  263. response = await authed_session.close()