Browse Source

fix(hybrid-cloud): Add instance methods to rpc for UserSocialAuth (#53888)

This PR adds the `revoke_token` and `refresh_token` instance methods to
the RPC and changes the erroring instance to perform an RPC call
instead. Since these credentials are project-scoped (and therefore
region scoped) we don't have to do any proxying to prevent race
conditions the way we have to do with the rest of the integrations.

**todo**
- [x] Add tests that would fail without this check
Leander Rodrigues 1 year ago
parent
commit
bbfe5e1c90

+ 20 - 0
src/sentry/services/hybrid_cloud/usersocialauth/impl.py

@@ -28,6 +28,26 @@ class DatabaseBackedUserSocialAuthService(UserSocialAuthService):
             return None
         return auths[0]
 
+    def revoke_token(
+        self, *, filter: UserSocialAuthFilterArgs, drop_token: bool = True
+    ) -> List[RpcUserSocialAuth]:
+        """
+        Calls UserSocialAuth.revoke_token() on all matching results, returning the modified RpcUserSocialAuths.
+        """
+        db_auths = self._FQ._query_many(filter=filter)
+        for db_auth in db_auths:
+            db_auth.revoke_token(drop_token=drop_token)
+        return self.get_many(filter=filter)
+
+    def refresh_token(self, *, filter: UserSocialAuthFilterArgs) -> List[RpcUserSocialAuth]:
+        """
+        Calls UserSocialAuth.refresh_token() on all matching results, returning the modified RpcUserSocialAuths.
+        """
+        db_auths = self._FQ._query_many(filter=filter)
+        for db_auth in db_auths:
+            db_auth.refresh_token()
+        return self.get_many(filter=filter)
+
     def link_auth(self, *, usa: RpcUserSocialAuth, organization: RpcOrganization) -> bool:
         try:
             integration, _created = Integration.objects.get_or_create(

+ 18 - 0
src/sentry/services/hybrid_cloud/usersocialauth/service.py

@@ -40,6 +40,24 @@ class UserSocialAuthService(RpcService):
         """
         pass
 
+    @rpc_method
+    @abstractmethod
+    def revoke_token(
+        self, *, filter: UserSocialAuthFilterArgs, drop_token: bool = True
+    ) -> List[RpcUserSocialAuth]:
+        """
+        Calls UserSocialAuth.revoke_token() on all matching results, returning the modified RpcUserSocialAuths.
+        """
+        pass
+
+    @rpc_method
+    @abstractmethod
+    def refresh_token(self, *, filter: UserSocialAuthFilterArgs) -> List[RpcUserSocialAuth]:
+        """
+        Calls UserSocialAuth.refresh_token() on all matching results, returning the modified RpcUserSocialAuths.
+        """
+        pass
+
     @rpc_method
     @abstractmethod
     def link_auth(self, *, usa: RpcUserSocialAuth, organization: RpcOrganization) -> bool:

+ 1 - 1
src/sentry/testutils/factories.py

@@ -753,7 +753,7 @@ class Factories:
         extra_data: Mapping[str, Any] | None = None,
     ):
         if not provider:
-            provider = "example"
+            provider = "asana"
         if not uid:
             uid = "abc-123"
         usa = UserSocialAuth(user=user, provider=provider, uid=uid, extra_data=extra_data)

+ 2 - 1
src/sentry_plugins/client.py

@@ -1,3 +1,4 @@
+from sentry.services.hybrid_cloud.usersocialauth.service import usersocialauth_service
 from sentry.shared_integrations.client import BaseApiClient, BaseInternalApiClient
 from sentry.shared_integrations.exceptions import ApiUnauthorized
 
@@ -56,7 +57,7 @@ class AuthApiClient(ApiClient):
         self.logger.info(
             "token.refresh", extra={"auth_id": self.auth.id, "provider": self.auth.provider}
         )
-        self.auth.refresh_token()
+        usersocialauth_service.refresh_token(filter={"id": self.auth.id})
         kwargs = self.bind_auth(**kwargs)
         return ApiClient._request(self, method, path, **kwargs)
 

+ 28 - 7
tests/sentry_plugins/test_client.py

@@ -1,8 +1,9 @@
-from unittest.mock import Mock
+from unittest.mock import patch
 
 import pytest
 import responses
 
+from sentry.services.hybrid_cloud.usersocialauth.serial import serialize_usersocialauth
 from sentry.shared_integrations.exceptions import (
     ApiError,
     ApiHostError,
@@ -74,10 +75,10 @@ class AuthApiClientTest(TestCase):
     def test_with_authorization(self):
         responses.add(responses.GET, "http://example.com", json={})
 
-        auth = Mock()
-        auth.tokens = {"access_token": "access-token"}
+        auth = self.create_usersocialauth(extra_data={"access_token": "access-token"})
+        rpc_auth = serialize_usersocialauth(auth=auth)
 
-        resp = AuthApiClient(auth=auth).get("http://example.com")
+        resp = AuthApiClient(auth=rpc_auth).get("http://example.com")
         assert isinstance(resp, BaseApiResponse)
         assert resp.status_code == 200
 
@@ -88,16 +89,36 @@ class AuthApiClientTest(TestCase):
     def test_with_authorization_and_no_auth(self):
         responses.add(responses.GET, "http://example.com", json={})
 
-        auth = Mock()
-        auth.tokens = {"access_token": "access-token"}
+        auth = self.create_usersocialauth(extra_data={"access_token": "access-token"})
+        rpc_auth = serialize_usersocialauth(auth=auth)
+        resp = AuthApiClient(auth=rpc_auth).get("http://example.com", auth=None)
 
-        resp = AuthApiClient(auth=auth).get("http://example.com", auth=None)
         assert isinstance(resp, BaseApiResponse)
         assert resp.status_code == 200
 
         request = responses.calls[-1].request
         assert not request.headers.get("Authorization")
 
+    @responses.activate
+    def test_with_authorized_token_refresh(self):
+        # First attempt
+        responses.add(responses.GET, "http://example.com", json={}, status=401)
+        # After refresh
+        responses.add(responses.GET, "http://example.com", json={}, status=200)
+
+        auth = self.create_usersocialauth(extra_data={"access_token": "access-token"})
+        rpc_auth = serialize_usersocialauth(auth=auth)
+
+        with patch("social_auth.models.UserSocialAuth.refresh_token") as mock_refresh_token:
+            resp = AuthApiClient(auth=rpc_auth).get("http://example.com")
+            assert mock_refresh_token.called
+
+        assert isinstance(resp, BaseApiResponse)
+        assert resp.status_code == 200
+
+        request = responses.calls[-1].request
+        assert request.headers.get("Authorization") == "Bearer access-token"
+
     @responses.activate
     def test_invalid_host(self):
         with pytest.raises(ApiHostError):