123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311 |
- # Copyright 2024 Google LLC
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import asyncio
- from typing import AsyncGenerator
- from aioresponses import aioresponses # type: ignore
- from mock import Mock, patch
- import pytest # type: ignore
- from google.auth.aio.credentials import AnonymousCredentials
- from google.auth.aio.transport import (
- _DEFAULT_TIMEOUT_SECONDS,
- DEFAULT_MAX_RETRY_ATTEMPTS,
- DEFAULT_RETRYABLE_STATUS_CODES,
- Request,
- Response,
- sessions,
- )
- from google.auth.exceptions import InvalidType, TimeoutError, TransportError
@pytest.fixture
async def simple_async_task():
    """Minimal unit of async work for the timeout-guard tests.

    The coroutine body does nothing except resolve to ``True``.

    NOTE(review): the tests hand the injected fixture value straight to the
    guard's wrapper and interpolate it into an expected error message, so it
    looks like they receive the un-awaited coroutine object rather than
    ``True`` - confirm against the pytest-asyncio mode this suite runs in.
    """
    outcome = True
    return outcome
class MockRequest(Request):
    """In-memory ``Request`` stand-in that performs no network I/O.

    Each invocation is counted in ``call_count``. The call then either raises
    the configured ``side_effect`` (when one is set and truthy) or returns the
    canned ``response``.
    """

    def __init__(self, response=None, side_effect=None):
        """Store the canned outcome and reset the bookkeeping state.

        Args:
            response: Object returned by every call when no side effect is set.
            side_effect: Exception (class or instance) raised by every call.
        """
        self._response, self._side_effect = response, side_effect
        self.call_count = 0
        self._closed = False

    async def __call__(
        self,
        url,
        method="GET",
        body=None,
        headers=None,
        timeout=_DEFAULT_TIMEOUT_SECONDS,
        **kwargs,
    ):
        """Record the call, then raise the side effect or return the response.

        All request arguments are accepted for signature compatibility but
        are otherwise ignored.
        """
        self.call_count += 1
        if not self._side_effect:
            return self._response
        raise self._side_effect

    async def close(self):
        """Flag the request object as closed; returns ``None``."""
        self._closed = True
class MockResponse(Response):
    """Canned ``Response`` whose status, headers and body are fixed up front."""

    def __init__(self, status_code, headers=None, content=None):
        """Remember the canned values.

        Args:
            status_code: Value exposed through the ``status_code`` property.
            headers: Value exposed through the ``headers`` property.
            content: Object returned verbatim by ``content()``; ``read()``
                iterates it with ``async for``.
        """
        self._close = False
        self._status_code = status_code
        self._headers = headers
        self._content = content

    @property
    def status_code(self):
        """The status code supplied at construction."""
        return self._status_code

    @property
    def headers(self):
        """The headers supplied at construction."""
        return self._headers

    async def read(self) -> bytes:
        """Drain ``content(1024)`` and return the chunks joined into one ``bytes``."""
        stream = await self.content(1024)
        pieces = []
        async for piece in stream:
            pieces.append(piece)
        return b"".join(pieces)

    async def content(self, chunk_size=None) -> AsyncGenerator:
        """Return the stored content object unchanged; ``chunk_size`` is ignored."""
        return self._content

    async def close(self) -> None:
        """Record that the response was closed by setting ``_close``."""
        self._close = True
class TestTimeoutGuard(object):
    """Tests for ``sessions.timeout_guard`` with the clock and ``asyncio.wait_for`` patched."""

    # Timeout, in seconds, given to every guard built by these tests.
    default_timeout = 1

    def make_timeout_guard(self, timeout):
        """Build the context manager under test for the given ``timeout``."""
        guard = sessions.timeout_guard(timeout)
        return guard

    @pytest.mark.asyncio
    async def test_timeout_with_simple_async_task_within_bounds(
        self, simple_async_task
    ):
        """The wrapped task's result is returned when the fake clock stays under the timeout."""
        outcome = False
        # Fake clock readings (0, 0.25, 0.75) all stay below ``default_timeout``.
        fake_clock = patch("time.monotonic", side_effect=[0, 0.25, 0.75])
        # ``wait_for`` is swapped for a pass-through that returns the awaitable as-is.
        passthrough_wait = patch("asyncio.wait_for", lambda coro, _: coro)
        with fake_clock, passthrough_wait:
            async with self.make_timeout_guard(
                timeout=self.default_timeout
            ) as guarded:
                outcome = await guarded(simple_async_task)

        # Task succeeds.
        assert outcome is True

    @pytest.mark.asyncio
    async def test_timeout_with_simple_async_task_out_of_bounds(
        self, simple_async_task
    ):
        """A ``TimeoutError`` naming the context manager is raised once the fake clock reaches the timeout."""
        outcome = False
        expected_message = f"Context manager exceeded the configured timeout of {self.default_timeout}s."
        with patch("time.monotonic", side_effect=[0, 1, 1]), pytest.raises(
            TimeoutError
        ) as raised:
            async with self.make_timeout_guard(
                timeout=self.default_timeout
            ) as guarded:
                outcome = await guarded(simple_async_task)

        # Task does not succeed and the context manager times out i.e. no remaining time left.
        assert outcome is False
        assert raised.match(expected_message)

    @pytest.mark.asyncio
    async def test_timeout_with_async_task_timing_out_before_context(
        self, simple_async_task
    ):
        """An ``asyncio.TimeoutError`` from ``wait_for`` surfaces as a ``TimeoutError`` naming the operation."""
        outcome = False
        expected_message = f"The operation {simple_async_task} exceeded the configured timeout of {self.default_timeout}s."
        with pytest.raises(TimeoutError) as raised:
            async with self.make_timeout_guard(
                timeout=self.default_timeout
            ) as guarded:
                # ``wait_for`` is patched only after the guard has been entered.
                with patch("asyncio.wait_for", side_effect=asyncio.TimeoutError):
                    outcome = await guarded(simple_async_task)

        # Task does not complete i.e. the operation times out.
        assert outcome is False
        assert raised.match(expected_message)
class TestAsyncAuthorizedSession(object):
    """Tests for ``sessions.AsyncAuthorizedSession``.

    Covers construction, ``request()`` success and failure paths, retry
    behaviour, and the per-verb convenience methods.
    """

    TEST_URL = "http://example.com/"
    credentials = AnonymousCredentials()

    @pytest.fixture
    async def mocked_content(self):
        """Async generator yielding the chunks of ``b"Cavefish have no sight."``."""
        content = [b"Cavefish ", b"have ", b"no ", b"sight."]
        for chunk in content:
            yield chunk

    async def _assert_http_method_success(self, method, expected_payload):
        """Send ``method`` through a default session and check the body round-trips.

        Shared body of the per-verb tests below.

        Args:
            method: Lower-case verb name. It must be an attribute of both the
                ``aioresponses`` mock and the session (e.g. ``"get"``).
            expected_payload: Bytes registered as the mocked response body and
                expected back from ``response.read()``.
        """
        authed_session = sessions.AsyncAuthorizedSession(self.credentials)
        try:
            with aioresponses() as mocked:
                getattr(mocked, method)(
                    self.TEST_URL, status=200, body=expected_payload
                )
                response = await getattr(authed_session, method)(self.TEST_URL)
                assert await response.read() == expected_payload
        finally:
            # Close the session even when the assertion above fails.
            await authed_session.close()

    @pytest.mark.asyncio
    async def test_constructor_with_default_auth_request(self):
        """Construction succeeds without an ``auth_request`` when aiohttp is flagged as installed."""
        with patch("google.auth.aio.transport.sessions.AIOHTTP_INSTALLED", True):
            authed_session = sessions.AsyncAuthorizedSession(self.credentials)
            assert authed_session._credentials == self.credentials
            await authed_session.close()

    @pytest.mark.asyncio
    async def test_constructor_with_provided_auth_request(self):
        """An explicitly supplied ``auth_request`` is stored on the session as-is."""
        auth_request = MockRequest()
        authed_session = sessions.AsyncAuthorizedSession(
            self.credentials, auth_request=auth_request
        )
        assert authed_session._auth_request is auth_request
        await authed_session.close()

    @pytest.mark.asyncio
    async def test_constructor_raises_no_auth_request_error(self):
        """Without aiohttp and without an ``auth_request`` the constructor raises ``TransportError``."""
        with patch("google.auth.aio.transport.sessions.AIOHTTP_INSTALLED", False):
            with pytest.raises(TransportError) as exc:
                sessions.AsyncAuthorizedSession(self.credentials)

        assert exc.match(
            "`auth_request` must either be configured or the external package `aiohttp` must be installed to use the default value."
        )

    @pytest.mark.asyncio
    async def test_constructor_raises_incorrect_credentials_error(self):
        """Passing a ``Mock`` as credentials is rejected with ``InvalidType``."""
        credentials = Mock()
        with pytest.raises(InvalidType) as exc:
            sessions.AsyncAuthorizedSession(credentials)

        assert exc.match(
            f"The configured credentials of type {type(credentials)} are invalid and must be of type `google.auth.aio.credentials.Credentials`"
        )

    @pytest.mark.asyncio
    async def test_request_default_auth_request_success(self):
        """``request()`` via the default auth request returns the mocked status, headers and body."""
        with aioresponses() as m:
            mocked_chunks = [b"Cavefish ", b"have ", b"no ", b"sight."]
            mocked_response = b"".join(mocked_chunks)
            m.get(self.TEST_URL, status=200, body=mocked_response)
            authed_session = sessions.AsyncAuthorizedSession(self.credentials)
            response = await authed_session.request("GET", self.TEST_URL)
            assert response.status_code == 200
            assert response.headers == {"Content-Type": "application/json"}
            assert await response.read() == b"Cavefish have no sight."
            await response.close()

        await authed_session.close()

    @pytest.mark.asyncio
    async def test_request_provided_auth_request_success(self, mocked_content):
        """``request()`` via a supplied ``MockRequest`` hands back the canned ``MockResponse``."""
        mocked_response = MockResponse(
            status_code=200,
            headers={"Content-Type": "application/json"},
            content=mocked_content,
        )
        auth_request = MockRequest(mocked_response)
        authed_session = sessions.AsyncAuthorizedSession(self.credentials, auth_request)
        response = await authed_session.request("GET", self.TEST_URL)
        assert response.status_code == 200
        assert response.headers == {"Content-Type": "application/json"}
        assert await response.read() == b"Cavefish have no sight."
        await response.close()
        assert response._close
        await authed_session.close()

    @pytest.mark.asyncio
    async def test_request_raises_timeout_error(self):
        """An ``asyncio.TimeoutError`` from the auth request surfaces as ``TimeoutError``."""
        auth_request = MockRequest(side_effect=asyncio.TimeoutError)
        authed_session = sessions.AsyncAuthorizedSession(self.credentials, auth_request)
        with pytest.raises(TimeoutError):
            await authed_session.request("GET", self.TEST_URL)

    @pytest.mark.asyncio
    async def test_request_raises_transport_error(self):
        """A ``TransportError`` from the auth request reaches the caller."""
        auth_request = MockRequest(side_effect=TransportError)
        authed_session = sessions.AsyncAuthorizedSession(self.credentials, auth_request)
        with pytest.raises(TransportError):
            await authed_session.request("GET", self.TEST_URL)

    @pytest.mark.asyncio
    async def test_request_max_allowed_time_exceeded_error(self):
        """With the fake clock reaching ``max_allowed_time``, ``request()`` raises ``TimeoutError``."""
        auth_request = MockRequest(side_effect=TransportError)
        authed_session = sessions.AsyncAuthorizedSession(self.credentials, auth_request)
        with patch("time.monotonic", side_effect=[0, 1, 1]):
            with pytest.raises(TimeoutError):
                await authed_session.request("GET", self.TEST_URL, max_allowed_time=1)

    @pytest.mark.parametrize("retry_status", DEFAULT_RETRYABLE_STATUS_CODES)
    @pytest.mark.asyncio
    async def test_request_max_retries(self, retry_status):
        """Each retryable status is attempted ``DEFAULT_MAX_RETRY_ATTEMPTS`` times."""
        mocked_response = MockResponse(status_code=retry_status)
        auth_request = MockRequest(mocked_response)
        # ``asyncio.sleep`` is patched out so the test does not actually wait.
        with patch("asyncio.sleep", return_value=None):
            authed_session = sessions.AsyncAuthorizedSession(
                self.credentials, auth_request
            )
            await authed_session.request("GET", self.TEST_URL)
            assert auth_request.call_count == DEFAULT_MAX_RETRY_ATTEMPTS

    @pytest.mark.asyncio
    async def test_http_get_method_success(self):
        """``get()`` returns the mocked GET body."""
        await self._assert_http_method_success("get", b"content is retrieved.")

    @pytest.mark.asyncio
    async def test_http_post_method_success(self):
        """``post()`` returns the mocked POST body."""
        await self._assert_http_method_success("post", b"content is posted.")

    @pytest.mark.asyncio
    async def test_http_put_method_success(self):
        """``put()`` returns the mocked PUT body."""
        await self._assert_http_method_success("put", b"content is retrieved.")

    @pytest.mark.asyncio
    async def test_http_patch_method_success(self):
        """``patch()`` returns the mocked PATCH body."""
        await self._assert_http_method_success("patch", b"content is retrieved.")

    @pytest.mark.asyncio
    async def test_http_delete_method_success(self):
        """``delete()`` returns the mocked DELETE body."""
        await self._assert_http_method_success("delete", b"content is deleted.")
|