Просмотр исходного кода

chore(vsts): Merge OAuth2RefreshMixin into VstsApiClient (#52574)

Co-authored-by: getsantry[bot] <66042841+getsantry[bot]@users.noreply.github.com>
Alberto Leal 1 год назад
Родитель
Сommit
7e1ac08b58

+ 0 - 1
pyproject.toml

@@ -521,7 +521,6 @@ module = [
     "sentry.integrations.bitbucket_server.integration",
     "sentry.integrations.bitbucket_server.repository",
     "sentry.integrations.bitbucket_server.webhook",
-    "sentry.integrations.client",
     "sentry.integrations.custom_scm.integration",
     "sentry.integrations.custom_scm.repository",
     "sentry.integrations.example.integration",

+ 0 - 19
src/sentry/integrations/client.py

@@ -1,29 +1,10 @@
-from time import time
-
-from sentry.exceptions import InvalidIdentity
-from sentry.services.hybrid_cloud.util import control_silo_function
 from sentry.shared_integrations.client import BaseApiClient
 
 
 class ApiClient(BaseApiClient):
     integration_type = "integration"
-
     metrics_prefix = "integrations"
-
     log_path = "sentry.integrations.client"
 
     # Used in metrics and logging.
     integration_name = "undefined"
-
-
-class OAuth2RefreshMixin:
-    @control_silo_function
-    def check_auth(self, *args, **kwargs):
-        """
-        Checks if auth is expired and if so refreshes it
-        """
-        time_expires = self.identity.data.get("expires")
-        if time_expires is None:
-            raise InvalidIdentity("OAuth2ApiClient requires identity with specified expired time")
-        if int(time_expires) <= int(time()):
-            self.identity.get_provider().refresh_identity(self.identity, *args, **kwargs)

+ 18 - 3
src/sentry/integrations/vsts/client.py

@@ -1,11 +1,13 @@
 from __future__ import annotations
 
+from time import time
 from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Sequence, Union
 
 from requests import PreparedRequest
 from rest_framework.response import Response
 
-from sentry.integrations.client import ApiClient, OAuth2RefreshMixin
+from sentry.exceptions import InvalidIdentity
+from sentry.integrations.client import ApiClient
 from sentry.models import Identity
 from sentry.services.hybrid_cloud.util import control_silo_function
 from sentry.shared_integrations.client.base import BaseApiResponseX
@@ -146,7 +148,7 @@ class VstsSetupApiClient(ApiClient, VstsApiMixin):
         return self._request(method, path, headers=headers, data=data, params=params)
 
 
-class VstsApiClient(IntegrationProxyClient, OAuth2RefreshMixin, VstsApiMixin):
+class VstsApiClient(IntegrationProxyClient, VstsApiMixin):
     integration_name = "vsts"
     _identity: Identity | None = None
 
@@ -181,12 +183,25 @@ class VstsApiClient(IntegrationProxyClient, OAuth2RefreshMixin, VstsApiMixin):
 
         return self._request(method, *args, headers=headers, **kwargs)
 
+    @control_silo_function
+    def _refresh_auth_if_expired(self):
+        """
+        Checks if auth is expired and if so refreshes it
+        """
+        time_expires = self.identity.data.get("expires")
+        if time_expires is None:
+            raise InvalidIdentity("VstsApiClient requires identity with specified expired time")
+        if int(time_expires) <= int(time()):
+            self.identity.get_provider().refresh_identity(
+                self.identity, redirect_url=self.oauth_redirect_url
+            )
+
     @control_silo_function
     def authorize_request(
         self,
         prepared_request: PreparedRequest,
     ) -> PreparedRequest:
-        self.check_auth(redirect_url=self.oauth_redirect_url)
+        self._refresh_auth_if_expired()
         access_token = self.identity.data["access_token"]
         headers = prepare_auth_header(
             access_token=access_token,

+ 1 - 72
tests/sentry/integrations/test_client.py

@@ -1,5 +1,4 @@
 import errno
-from time import time
 from unittest import mock
 
 import pytest
@@ -12,10 +11,8 @@ from sentry_sdk.tracing import Transaction
 from urllib3.exceptions import InvalidChunkLength
 from urllib3.response import HTTPResponse
 
-from sentry.identity import register
 from sentry.identity.oauth2 import OAuth2Provider
-from sentry.integrations.client import ApiClient, OAuth2RefreshMixin
-from sentry.models import Identity, IdentityProvider
+from sentry.integrations.client import ApiClient
 from sentry.shared_integrations.exceptions import (
     ApiConnectionResetError,
     ApiHostError,
@@ -24,7 +21,6 @@ from sentry.shared_integrations.exceptions import (
 from sentry.shared_integrations.exceptions.base import ApiError
 from sentry.shared_integrations.response.base import BaseApiResponse
 from sentry.testutils import TestCase
-from sentry.testutils.silo import control_silo_test
 
 
 class ApiClientTest(TestCase):
@@ -319,70 +315,3 @@ class OAuthProvider(OAuth2Provider):
 
     def get_refresh_token_url(self):
         return "https://example.com"
-
-
-class OAuth2ApiClient(ApiClient, OAuth2RefreshMixin):
-    def __init__(self, identity, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.identity = identity
-
-
-@control_silo_test(stable=True)
-class OAuth2ApiClientTest(TestCase):
-    def setUp(self):
-        self.user = self.create_user()
-        self.organization = self.create_organization()
-        self.access_token = "1234567890"
-        self.identity_provider_model = IdentityProvider.objects.create(type="oauth")
-        register(OAuthProvider)
-
-    @responses.activate
-    def test_check_auth(self):
-        new_auth = {
-            "access_token": "1234567890",
-            "refresh_token": "0987654321",
-            "expires_in": 45678988239,
-        }
-        responses.add(responses.POST, "https://example.com", json=new_auth)
-        identity = Identity.objects.create(
-            idp=self.identity_provider_model,
-            user=self.user,
-            external_id="oauth_base",
-            data={
-                "access_token": "access_token",
-                "refresh_token": "refresh_token",
-                "expires": int(time()) - 3600,
-            },
-        )
-
-        client = OAuth2ApiClient(identity)
-        client.check_auth()
-
-        assert client.identity.data["access_token"] == new_auth["access_token"]
-        assert client.identity.data["refresh_token"] == new_auth["refresh_token"]
-        assert client.identity.data["expires"] > int(time())
-
-    @responses.activate
-    def test_check_auth_no_refresh(self):
-        new_auth = {
-            "access_token": "1234567890",
-            "refresh_token": "0987654321",
-            "expires_in": 45678988239,
-        }
-        old_auth = {
-            "access_token": "access_token",
-            "refresh_token": "refresh_token",
-            "expires": int(time()) + 3600,
-        }
-        responses.add(responses.POST, "https://example.com", json=new_auth)
-        identity = Identity.objects.create(
-            idp=self.identity_provider_model,
-            user=self.user,
-            external_id="oauth_base",
-            data=old_auth,
-        )
-
-        client = OAuth2ApiClient(identity)
-        client.check_auth()
-
-        assert client.identity.data == old_auth

+ 35 - 0
tests/sentry/integrations/vsts/test_client.py

@@ -64,6 +64,41 @@ class VstsApiClientTest(VstsIntegrationTestCase):
         ]
         assert identity.data["access_token"] == "new-access-token"
         assert identity.data["refresh_token"] == "new-refresh-token"
+        assert identity.data["expires"] > int(time())
+
+    @responses.activate
+    def test_does_not_refresh_valid_tokens(self):
+        self.assert_installation()
+        responses.reset()
+        integration = Integration.objects.get(provider="vsts")
+
+        # Make the Identity have a non-expired token
+        idp = IdentityProvider.objects.get(external_id=self.vsts_account_id)
+        identity = Identity.objects.get(idp_id=idp.id)
+        expires = int(time()) + int(123456789)
+        identity.data["expires"] = expires
+        access_token = identity.data["access_token"]
+        refresh_token = identity.data["refresh_token"]
+        identity.save()
+
+        # New values VSTS will return on refresh
+        self.access_token = "new-access-token"
+        self.refresh_token = "new-refresh-token"
+        self._stub_vsts()
+
+        # Make a request
+
+        integration.get_installation(
+            integration.organizationintegration_set.first().organization_id
+        ).get_client(base_url=self.vsts_base_url).get_projects()
+        assert len(responses.calls) == 1
+        assert (
+            responses.calls[0].request.url
+            == "https://myvstsaccount.visualstudio.com/_apis/projects?stateFilter=WellFormed&%24skip=0&%24top=100"
+        )
+        assert identity.data["access_token"] == access_token != self.access_token
+        assert identity.data["refresh_token"] == refresh_token != self.refresh_token
+        assert identity.data["expires"] == expires
 
     def test_project_pagination(self):
         def request_callback(request):