Browse Source

ref(types): Add types to the pipeline module (#32863)

Co-authored-by: getsantry[bot] <66042841+getsantry[bot]@users.noreply.github.com>
Marcos Gaeta 2 years ago
parent
commit
4f4c335b2c

+ 1 - 0
.github/CODEOWNERS

@@ -218,6 +218,7 @@ build-utils/        @getsentry/owners-js-build
 /src/sentry/mail/                             @getsentry/ecosystem
 /src/sentry/mediators/                        @getsentry/ecosystem
 /src/sentry/notifications/                    @getsentry/ecosystem @getsentry/growth
+/src/sentry/pipeline/                         @getsentry/ecosystem
 /src/sentry/plugins/                          @getsentry/ecosystem
 /src/sentry/ratelimits/                       @getsentry/ecosystem
 /src/sentry/shared_integrations/              @getsentry/ecosystem

+ 1 - 0
mypy.ini

@@ -55,6 +55,7 @@ files = src/sentry/analytics/,
         src/sentry/models/options/,
         src/sentry/models/rulefirehistory.py,
         src/sentry/notifications/,
+        src/sentry/pipeline/,
         src/sentry/processing/realtime_metrics/,
         src/sentry/profiles/,
         src/sentry/ratelimits/,

+ 5 - 3
src/sentry/api/serializers/models/user_identity_config.py

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from dataclasses import dataclass
 from datetime import datetime
 from enum import Enum
@@ -41,7 +43,7 @@ class UserIdentityProvider:
     name: str
 
     @classmethod
-    def adapt(cls, provider: Provider) -> "UserIdentityProvider":
+    def adapt(cls, provider: Provider) -> UserIdentityProvider:
         return cls(provider.key, provider.name)
 
 
@@ -59,7 +61,7 @@ class UserIdentityConfig:
     date_synced: Optional[datetime] = None
 
     @classmethod
-    def wrap(cls, identity: IdentityType, status: Status) -> "UserIdentityConfig":
+    def wrap(cls, identity: IdentityType, status: Status) -> UserIdentityConfig:
         def base(**kwargs):
             return cls(
                 category=_IDENTITY_CATEGORY_KEYS[type(identity)],
@@ -108,7 +110,7 @@ class UserIdentityConfig:
 
 @register(UserIdentityConfig)
 class UserIdentityConfigSerializer(Serializer):
-    def serialize(self, obj: UserIdentityConfig, attrs, user):
+    def serialize(self, obj: UserIdentityConfig, attrs, user, **kwargs):
         return {
             "category": obj.category,
             "id": str(obj.id),

+ 1 - 1
src/sentry/integrations/base.py

@@ -110,7 +110,7 @@ class IntegrationFeatures(Enum):
     DEPLOYMENT = "deployment"
 
 
-class IntegrationProvider(PipelineProvider, abc.ABC):  # type: ignore
+class IntegrationProvider(PipelineProvider, abc.ABC):
     """
     An integration provider describes a third party that can be registered within Sentry.
 

+ 3 - 2
src/sentry/integrations/github/integration.py

@@ -249,7 +249,7 @@ class GitHubIntegrationProvider(IntegrationProvider):  # type: ignore
         )
 
 
-class GitHubInstallationRedirect(PipelineView):  # type: ignore
+class GitHubInstallationRedirect(PipelineView):
     def get_app_url(self) -> str:
         name = options.get("github-app.name")
         return f"https://github.com/apps/{slugify(name)}"
@@ -261,7 +261,8 @@ class GitHubInstallationRedirect(PipelineView):  # type: ignore
         if "installation_id" in request.GET:
             organization = self.get_active_organization(request)
 
-            # We want to wait until the scheduled deletions finish or else the post install to migrate repos do not work.
+            # We want to wait until the scheduled deletions finish or else the
+            # post install to migrate repos do not work.
             integration_pending_deletion_exists = OrganizationIntegration.objects.filter(
                 integration__provider=GitHubIntegrationProvider.key,
                 organization=organization,

+ 11 - 11
src/sentry/integrations/vsts/integration.py

@@ -15,7 +15,7 @@ from sentry import features, http
 from sentry.auth.exceptions import IdentityNotValid
 from sentry.constants import ObjectStatus
 from sentry.identity.pipeline import IdentityProviderPipeline
-from sentry.identity.vsts import get_user_info
+from sentry.identity.vsts.provider import get_user_info
 from sentry.integrations import (
     FeatureDescription,
     IntegrationFeatures,
@@ -484,27 +484,27 @@ class VstsIntegrationProvider(IntegrationProvider):  # type: ignore
         )
 
 
-class AccountConfigView(PipelineView):  # type: ignore
+class AccountConfigView(PipelineView):
     def dispatch(self, request: Request, pipeline: Pipeline) -> Response:
-        if "account" in request.POST:
-            account_id = request.POST.get("account")
-            accounts = pipeline.fetch_state(key="accounts")
-            account = self.get_account_from_id(account_id, accounts)
+        account_id = request.POST.get("account")
+        if account_id is not None:
+            state_accounts: Sequence[Mapping[str, Any]] | None = pipeline.fetch_state(
+                key="accounts"
+            )
+            account = self.get_account_from_id(account_id, state_accounts or [])
             if account is not None:
-                state = pipeline.fetch_state(key="identity")
-                access_token = state["data"]["access_token"]
                 pipeline.bind_state("account", account)
                 return pipeline.next_step()
 
-        state = pipeline.fetch_state(key="identity")
-        access_token = state["data"]["access_token"]
+        state: Mapping[str, Any] | None = pipeline.fetch_state(key="identity")
+        access_token = (state or {}).get("data", {}).get("access_token")
         user = get_user_info(access_token)
 
         accounts = self.get_accounts(access_token, user["uuid"])
         logger.info(
             "vsts.get_accounts",
             extra={
-                "organization_id": pipeline.organization.id,
+                "organization_id": pipeline.organization.id if pipeline.organization else None,
                 "user_id": request.user.id,
                 "accounts": accounts,
             },

+ 2 - 2
src/sentry/pipeline/__init__.py

@@ -1,5 +1,5 @@
-from .base import Pipeline
-from .provider import PipelineProvider
+from .provider import PipelineProvider  # isort:skip
+from .base import Pipeline  # isort:skip
 from .store import PipelineSessionStore
 from .types import PipelineAnalyticsEntry
 from .views.base import PipelineView

+ 48 - 35
src/sentry/pipeline/base.py

@@ -1,22 +1,27 @@
+from __future__ import annotations
+
+import abc
 import logging
 from types import LambdaType
-from typing import Any, Dict, Optional, Sequence
+from typing import Any, Mapping, Sequence, Type
 
 from django.http.response import HttpResponseBase
 from django.views import View
 from rest_framework.request import Request
 
 from sentry import analytics
+from sentry.db.models import Model
 from sentry.models import Organization
 from sentry.utils.hashlib import md5_text
 from sentry.web.helpers import render_to_response
 
+from . import PipelineProvider
 from .constants import INTEGRATION_EXPIRATION_TTL
 from .store import PipelineSessionStore
 from .types import PipelineAnalyticsEntry, PipelineRequestState
 
 
-class Pipeline:
+class Pipeline(abc.ABC):
     """
     Pipeline provides a mechanism to guide the user through a request
     'pipeline', where each view may be completed by calling the ``next_step``
@@ -42,13 +47,13 @@ class Pipeline:
     using the ``update_config`` method.
     """
 
-    pipeline_name = None
-    provider_manager = None
-    provider_model_cls = None
+    pipeline_name: str
+    provider_manager: Any
+    provider_model_cls: Type[Model]
     session_store_cls = PipelineSessionStore
 
     @classmethod
-    def get_for_request(cls, request):
+    def get_for_request(cls, request: Request) -> Pipeline | None:
         req_state = cls.unpack_state(request)
         if not req_state:
             return None
@@ -63,7 +68,7 @@ class Pipeline:
         )
 
     @classmethod
-    def unpack_state(cls, request) -> Optional[PipelineRequestState]:
+    def unpack_state(cls, request: Request) -> PipelineRequestState | None:
         state = cls.session_store_cls(request, cls.pipeline_name, ttl=INTEGRATION_EXPIRATION_TTL)
         if not state.is_valid():
             return None
@@ -80,12 +85,18 @@ class Pipeline:
 
         return PipelineRequestState(state, provider_model, organization, provider_key)
 
-    def get_provider(self, provider_key: str):
-        return self.provider_manager.get(provider_key)
+    def get_provider(self, provider_key: str) -> PipelineProvider:
+        provider: PipelineProvider = self.provider_manager.get(provider_key)
+        return provider
 
     def __init__(
-        self, request: Request, provider_key, organization=None, provider_model=None, config=None
-    ):
+        self,
+        request: Request,
+        provider_key: str,
+        organization: Organization | None = None,
+        provider_model: Model | None = None,
+        config: Mapping[str, Any] | None = None,
+    ) -> None:
         self.request = request
         self.organization = organization
         self.state = self.session_store_cls(
@@ -114,17 +125,20 @@ class Pipeline:
         providers should inherit, or customize the provider method called to
         retrieve the views.
         """
-        return self.provider.get_pipeline_views()
+        views: Sequence[View] = self.provider.get_pipeline_views()
+        return views
 
     def is_valid(self) -> bool:
-        return self.state.is_valid() and self.state.signature == self.signature
+        _is_valid: bool = self.state.is_valid() and self.state.signature == self.signature
+        return _is_valid
 
     def initialize(self) -> None:
         self.state.regenerate(self.get_initial_state())
 
-    def get_initial_state(self) -> Dict[str, Any]:
+    def get_initial_state(self) -> Mapping[str, Any]:
+        user: Any = self.request.user
         return {
-            "uid": self.request.user.id if self.request.user.is_authenticated else None,
+            "uid": user.id if user.is_authenticated else None,
             "provider_model_id": self.provider_model.id if self.provider_model else None,
             "provider_key": self.provider.key,
             "org_id": self.organization.id if self.organization else None,
@@ -134,10 +148,10 @@ class Pipeline:
             "data": {},
         }
 
-    def clear_session(self):
+    def clear_session(self) -> None:
         self.state.clear()
 
-    def current_step(self):
+    def current_step(self) -> HttpResponseBase:
         """
         Render the current step.
         """
@@ -154,8 +168,9 @@ class Pipeline:
 
         return self.dispatch_to(step)
 
-    def dispatch_to(self, step: View):
-        """Dispatch to a view expected by this pipeline.
+    def dispatch_to(self, step: View) -> HttpResponseBase:
+        """
+        Dispatch to a view expected by this pipeline.
 
         A subclass may override this if its views take other parameters.
         """
@@ -177,22 +192,21 @@ class Pipeline:
             request=self.request,
         )
 
-    def render_warning(self, message):
-        """For situations when we want to display an error without triggering an issue"""
+    def render_warning(self, message: str) -> HttpResponseBase:
+        """For situations when we want to display an error without triggering an issue."""
         context = {"error": message}
         return render_to_response("sentry/pipeline-provider-error.html", context, self.request)
 
-    def next_step(self, step_size=1):
-        """
-        Render the next step.
-        """
+    def next_step(self, step_size: int = 1) -> HttpResponseBase:
+        """Render the next step."""
         self.state.step_index += step_size
 
         analytics_entry = self.get_analytics_entry()
         if analytics_entry and self.organization:
+            user: Any = self.request.user
             analytics.record(
                 analytics_entry.event_type,
-                user_id=self.request.user.id,
+                user_id=user.id,
                 organization_id=self.organization.id,
                 integration=self.provider.key,
                 step_index=self.state.step_index,
@@ -201,27 +215,26 @@ class Pipeline:
 
         return self.current_step()
 
-    def get_analytics_entry(self) -> Optional[PipelineAnalyticsEntry]:
+    def get_analytics_entry(self) -> PipelineAnalyticsEntry | None:
         """Return analytics attributes for this pipeline."""
         return None
 
-    def finish_pipeline(self):
-        """
-        Called when the pipeline completes the final step.
-        """
-        raise NotImplementedError
+    @abc.abstractmethod
+    def finish_pipeline(self) -> HttpResponseBase:
+        """Called when the pipeline completes the final step."""
+        pass
 
-    def bind_state(self, key, value):
+    def bind_state(self, key: str, value: Any) -> None:
         data = self.state.data or {}
         data[key] = value
 
         self.state.data = data
 
-    def fetch_state(self, key=None):
+    def fetch_state(self, key: str | None = None) -> Any | None:
         data = self.state.data
         if not data:
             return None
         return data if key is None else data.get(key)
 
-    def get_logger(self):
+    def get_logger(self) -> logging.Logger:
         return logging.getLogger(f"sentry.integration.{self.provider.key}")

+ 6 - 5
src/sentry/pipeline/provider.py

@@ -1,9 +1,10 @@
 from __future__ import annotations
 
 import abc
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Mapping, Sequence
 
 if TYPE_CHECKING:
+    from sentry.pipeline import Pipeline
     from sentry.pipeline.views.base import PipelineView
 
 
@@ -13,8 +14,8 @@ class PipelineProvider(abc.ABC):
     views that the Pipeline will traverse through.
     """
 
-    def __init__(self):
-        self.config = {}
+    def __init__(self) -> None:
+        self.config: dict[str, Any] = {}
 
     @property
     @abc.abstractmethod
@@ -40,7 +41,7 @@ class PipelineProvider(abc.ABC):
         """
         pass
 
-    def update_config(self, config):
+    def update_config(self, config: Mapping[str, Any]) -> None:
         """
         Use update_config to allow additional provider configuration be assigned to
         the provider instance. This is useful for example when nesting
@@ -48,7 +49,7 @@ class PipelineProvider(abc.ABC):
         """
         self.config.update(config)
 
-    def set_pipeline(self, pipeline):
+    def set_pipeline(self, pipeline: Pipeline) -> None:
         """
         Used by the pipeline to give the provider access to the executing pipeline.
         """

+ 1 - 1
src/sentry/pipeline/store.py

@@ -1,7 +1,7 @@
 from sentry.utils.session_store import RedisSessionStore, redis_property
 
 
-class PipelineSessionStore(RedisSessionStore):
+class PipelineSessionStore(RedisSessionStore):  # type: ignore
     uid = redis_property("uid")
     provider_model_id = redis_property("provider_model_id")
     provider_key = redis_property("provider_key")

Some files were not shown because too many files changed in this diff