Browse Source

ref(pipeline): Abstract Properties (#33348)

Marcos Gaeta 2 years ago
parent
commit
5af3960dcd

+ 11 - 6
src/sentry/auth/provider.py

@@ -1,5 +1,7 @@
+import abc
 import logging
 from collections import namedtuple
+from typing import Any
 
 from django.utils.encoding import force_text, python_2_unicode_compatible
 
@@ -22,21 +24,24 @@ class MigratingIdentityId(namedtuple("MigratingIdentityId", ["id", "legacy_id"])
         return force_text(self.id)
 
 
-class Provider(PipelineProvider):
+class Provider(PipelineProvider, abc.ABC):
     """
     A provider indicates how authenticate should happen for a given service,
     including its configuration and basic identity management.
     """
 
-    name = None
-
     # All auth providers by default require the sso-basic feature
     required_feature = "organizations:sso-basic"
 
-    def __init__(self, key, **config):
-        self.key = key
+    def __init__(self, key: str, **config: Any) -> None:
+        super().__init__()
+        self._key = key
         self.config = config
-        self.logger = logging.getLogger(f"sentry.auth.{key}")
+        self.logger = logging.getLogger(f"sentry.auth.{self.key}")
+
+    @property
+    def key(self) -> str:
+        return self._key
 
     def get_configure_view(self):
         """

+ 19 - 12
src/sentry/auth/providers/oauth2.py

@@ -1,5 +1,7 @@
+import abc
 import logging
 from time import time
+from typing import Any, Mapping
 from urllib.parse import parse_qsl, urlencode
 from uuid import uuid4
 
@@ -118,7 +120,7 @@ class OAuth2Callback(AuthView):
         return helper.next_step()
 
 
-class OAuth2Provider(Provider):
+class OAuth2Provider(Provider, abc.ABC):
     client_id = None
     client_secret = None
 
@@ -134,8 +136,9 @@ class OAuth2Provider(Provider):
             OAuth2Callback(client_id=self.get_client_id(), client_secret=self.get_client_secret()),
         ]
 
-    def get_refresh_token_url(self):
-        raise NotImplementedError
+    @abc.abstractmethod
+    def get_refresh_token_url(self) -> str:
+        pass
 
     def get_refresh_token_params(self, refresh_token):
         return {
@@ -153,15 +156,19 @@ class OAuth2Provider(Provider):
             data["refresh_token"] = payload["refresh_token"]
         return data
 
-    def build_identity(self, state):
-        # data = state['data']
-        # return {
-        #     'id': '',
-        #     'email': '',
-        #     'name': '',
-        #     'data': self.get_oauth_data(data),
-        # }
-        raise NotImplementedError
+    @abc.abstractmethod
+    def build_identity(self, state: Mapping[str, Any]) -> Mapping[str, Any]:
+        """
+        Example implementation:
+        data = state['data']
+        return {
+            'id': '',
+            'email': '',
+            'name': '',
+            'data': self.get_oauth_data(data),
+        }
+        """
+        pass
 
     def update_identity(self, new_data, current_data):
         # we want to maintain things like refresh_token that might not

+ 5 - 3
src/sentry/auth/providers/saml2/provider.py

@@ -1,3 +1,4 @@
+import abc
 from urllib.parse import urlparse
 
 from django.contrib import messages
@@ -87,7 +88,7 @@ class SAML2LoginView(AuthView):
 # (sentry) (the typical case) and the Identity Provider. In the second case,
 # the auth assertion is directly posted to the ACS URL. Because the user will
 # not have initiated their SSO flow we must provide a endpoint similar to
-# auth_provider_login, but with support for initing the auth flow.
+# auth_provider_login, but with support for initializing the auth flow.
 class SAML2AcceptACSView(BaseView):
     @method_decorator(csrf_exempt)
     def dispatch(self, request: Request, organization_slug):
@@ -203,7 +204,7 @@ class Attributes:
     LAST_NAME = "last_name"
 
 
-class SAML2Provider(Provider):
+class SAML2Provider(Provider, abc.ABC):
     """
     Base SAML2 Authentication provider. SAML style authentication plugins
     should implement this.
@@ -256,6 +257,7 @@ class SAML2Provider(Provider):
     def get_setup_pipeline(self):
         return self.get_saml_setup_pipeline() + self.get_auth_pipeline()
 
+    @abc.abstractmethod
     def get_saml_setup_pipeline(self):
         """
         Return a list of AuthViews to setup the SAML provider.
@@ -263,7 +265,7 @@ class SAML2Provider(Provider):
         The setup AuthView(s) must bind the `idp` parameter into the helper
         state.
         """
-        raise NotImplementedError
+        pass
 
     def attribute_mapping(self):
         """

+ 3 - 7
src/sentry/identity/base.py

@@ -1,20 +1,16 @@
+import abc
 import logging
 
 from sentry.pipeline import PipelineProvider
 
 
-class Provider(PipelineProvider):
+class Provider(PipelineProvider, abc.ABC):
     """
     A provider indicates how identity authenticate should happen for a given service.
     """
 
-    # The unique identifier of the provider
-    key = None
-
-    # A human readable name for this provider
-    name = None
-
     def __init__(self, **config):
+        super().__init__()
         self.config = config
         self.logger = logging.getLogger(f"sentry.identity.{self.key}")
 

+ 3 - 0
src/sentry/identity/providers/dummy.py

@@ -28,3 +28,6 @@ class DummyProvider(Provider):
 
     def build_identity(self, state):
         return {"id": state["email"], "email": state["email"], "name": "Dummy"}
+
+    def refresh_identity(self, auth_identity, *args, **kwargs):
+        pass

+ 2 - 9
src/sentry/integrations/base.py

@@ -1,3 +1,4 @@
+import abc
 import logging
 import sys
 from collections import namedtuple
@@ -109,7 +110,7 @@ class IntegrationFeatures(Enum):
     DEPLOYMENT = "deployment"
 
 
-class IntegrationProvider(PipelineProvider):  # type: ignore
+class IntegrationProvider(PipelineProvider, abc.ABC):  # type: ignore
     """
     An integration provider describes a third party that can be registered within Sentry.
 
@@ -123,11 +124,6 @@ class IntegrationProvider(PipelineProvider):  # type: ignore
     it provides (such as extensions provided).
     """
 
-    # a unique identifier (e.g. 'slack').
-    # Used to lookup sibling classes and the ``key`` used when creating
-    # Integration objects.
-    key: Optional[str] = None
-
     # a unique identifier to use when creating the ``Integration`` object.
     # Only needed when you want to create the above object with something other
     # than ``key``. See: VstsExtensionIntegrationProvider.
@@ -137,9 +133,6 @@ class IntegrationProvider(PipelineProvider):  # type: ignore
     # Integrations page.
     visible = True
 
-    # a human readable name (e.g. 'Slack')
-    name: Optional[str] = None
-
     # an IntegrationMetadata object, used to provide extra details in the
     # configuration interface of the integration.
     metadata: Optional[IntegrationMetadata] = None

+ 28 - 3
src/sentry/pipeline/provider.py

@@ -1,4 +1,13 @@
-class PipelineProvider:
+from __future__ import annotations
+
+import abc
+from typing import TYPE_CHECKING, Sequence
+
+if TYPE_CHECKING:
+    from sentry.pipeline.views.base import PipelineView
+
+
+class PipelineProvider(abc.ABC):
     """
     A class implementing the PipelineProvider interface provides the pipeline
     views that the Pipeline will traverse through.
@@ -7,13 +16,29 @@ class PipelineProvider:
     def __init__(self):
         self.config = {}
 
-    def get_pipeline_views(self):
+    @property
+    @abc.abstractmethod
+    def key(self) -> str:
+        """
+        A unique identifier (e.g. 'slack'). Used to lookup sibling classes and
+        the `key` used when creating Integration objects.
+        """
+        pass
+
+    @property
+    @abc.abstractmethod
+    def name(self) -> str:
+        """A human readable name (e.g. 'Slack')."""
+        pass
+
+    @abc.abstractmethod
+    def get_pipeline_views(self) -> Sequence[PipelineView]:
         """
         Returns a list of instantiated views which implement the PipelineView
         interface. Each view will be dispatched in order.
         >>> return [OAuthInitView(), OAuthCallbackView()]
         """
-        raise NotImplementedError
+        pass
 
     def update_config(self, config):
         """

+ 3 - 3
tests/sentry/auth/providers/google/test_provider.py

@@ -8,9 +8,9 @@ from sentry.testutils import TestCase
 
 class GoogleOAuth2ProviderTest(TestCase):
     def setUp(self):
-        self.org = self.create_organization(owner=self.user)
-        self.user = self.create_user("foo@example.com")
-        self.auth_provider = AuthProvider.objects.create(provider="google", organization=self.org)
+        self.auth_provider = AuthProvider.objects.create(
+            provider="google", organization=self.organization
+        )
         super().setUp()
 
     def test_refresh_identity_without_refresh_token(self):

+ 20 - 12
tests/sentry/auth/providers/test_oauth2.py

@@ -1,3 +1,5 @@
+from typing import Any, Mapping
+
 import pytest
 from exam import fixture
 
@@ -7,25 +9,31 @@ from sentry.models import AuthIdentity, AuthProvider
 from sentry.testutils import TestCase
 
 
-class OAuth2ProviderTest(TestCase):
-    def setUp(self):
-        self.org = self.create_organization(owner=self.user)
-        self.user = self.create_user("foo@example.com")
-        super().setUp()
+class DummyOAuth2Provider(OAuth2Provider):
+    name = "dummy"
 
-    @fixture
-    def auth_provider(self):
-        return AuthProvider.objects.create(provider="oauth2", organization=self.org)
+    def get_refresh_token_url(self) -> str:
+        pass
+
+    def build_identity(self, state: Mapping[str, Any]) -> Mapping[str, Any]:
+        pass
 
+    def build_config(self, state):
+        pass
+
+
+class OAuth2ProviderTest(TestCase):
     @fixture
-    def provider(self):
-        return OAuth2Provider(key=self.auth_provider.provider)
+    def auth_provider(self):
+        return AuthProvider.objects.create(provider="oauth2", organization=self.organization)
 
     def test_refresh_identity_without_refresh_token(self):
         auth_identity = AuthIdentity.objects.create(
-            auth_provider=self.auth_provider, user=self.user, data={"access_token": "access_token"}
+            auth_provider=self.auth_provider,
+            user=self.user,
+            data={"access_token": "access_token"},
         )
 
-        provider = OAuth2Provider(key=self.auth_provider.provider)
+        provider = DummyOAuth2Provider(key=self.auth_provider.provider)
         with pytest.raises(IdentityNotValid):
             provider.refresh_identity(auth_identity)

+ 12 - 4
tests/sentry/auth/providers/test_saml2.py

@@ -17,11 +17,19 @@ dummy_provider_config = {
 }
 
 
+class DummySAML2Provider(SAML2Provider):
+    name = "dummy"
+
+    def get_saml_setup_pipeline(self):
+        pass
+
+
 class SAML2ProviderTest(TestCase):
     def setUp(self):
-        self.org = self.create_organization()
-        self.auth_provider = AuthProvider.objects.create(provider="saml2", organization=self.org)
-        self.provider = SAML2Provider(key=self.auth_provider.provider)
+        self.auth_provider = AuthProvider.objects.create(
+            provider="saml2", organization=self.organization
+        )
+        self.provider = DummySAML2Provider(key=self.auth_provider.provider)
         super().setUp()
 
     def test_build_config_adds_attributes(self):
@@ -29,7 +37,7 @@ class SAML2ProviderTest(TestCase):
 
         assert "attribute_mapping" in config
 
-    def test_buld_config_with_provider_attributes(self):
+    def test_build_config_with_provider_attributes(self):
         with mock.patch.object(self.provider, "attribute_mapping") as attribute_mapping:
             config = self.provider.build_config({})
 

Some files were not shown because too many files changed in this diff