Browse Source

ref(hc): Type-checking on sentry.services.hybrid_cloud (#53822)

Ryan Skonnord 1 year ago
parent
commit
39af05a1c3

+ 0 - 10
pyproject.toml

@@ -732,13 +732,6 @@ module = [
     "sentry.sentry_metrics.configuration",
     "sentry.sentry_metrics.configuration",
     "sentry.sentry_metrics.consumers.indexer.slicing_router",
     "sentry.sentry_metrics.consumers.indexer.slicing_router",
     "sentry.sentry_metrics.indexer.postgres.postgres_v2",
     "sentry.sentry_metrics.indexer.postgres.postgres_v2",
-    "sentry.services.hybrid_cloud.actor",
-    "sentry.services.hybrid_cloud.auth.impl",
-    "sentry.services.hybrid_cloud.integration.impl",
-    "sentry.services.hybrid_cloud.log.impl",
-    "sentry.services.hybrid_cloud.notifications.impl",
-    "sentry.services.hybrid_cloud.organizationmember_mapping.impl",
-    "sentry.services.hybrid_cloud.user.impl",
     "sentry.services.smtp",
     "sentry.services.smtp",
     "sentry.shared_integrations.client.base",
     "sentry.shared_integrations.client.base",
     "sentry.shared_integrations.client.internal",
     "sentry.shared_integrations.client.internal",
@@ -778,7 +771,6 @@ module = [
     "sentry.tasks.codeowners.update_code_owners_schema",
     "sentry.tasks.codeowners.update_code_owners_schema",
     "sentry.tasks.commit_context",
     "sentry.tasks.commit_context",
     "sentry.tasks.commits",
     "sentry.tasks.commits",
-    "sentry.tasks.deletion.hybrid_cloud",
     "sentry.tasks.deliver_from_outbox",
     "sentry.tasks.deliver_from_outbox",
     "sentry.tasks.derive_code_mappings",
     "sentry.tasks.derive_code_mappings",
     "sentry.tasks.digests",
     "sentry.tasks.digests",
@@ -947,8 +939,6 @@ module = [
     "tests.sentry.digests.test_notifications",
     "tests.sentry.digests.test_notifications",
     "tests.sentry.eventstore.test_base",
     "tests.sentry.eventstore.test_base",
     "tests.sentry.grouping.test_result",
     "tests.sentry.grouping.test_result",
-    "tests.sentry.hybrid_cloud.test_integration",
-    "tests.sentry.hybrid_cloud.test_organization",
     "tests.sentry.identity.test_oauth2",
     "tests.sentry.identity.test_oauth2",
     "tests.sentry.incidents.action_handlers",
     "tests.sentry.incidents.action_handlers",
     "tests.sentry.incidents.action_handlers.test_sentry_app",
     "tests.sentry.incidents.action_handlers.test_sentry_app",

+ 11 - 2
src/sentry/api/paginator.py

@@ -2,6 +2,7 @@ import bisect
 import functools
 import functools
 import math
 import math
 from datetime import datetime, timezone
 from datetime import datetime, timezone
+from typing import Any
 from urllib.parse import quote
 from urllib.parse import quote
 
 
 from django.core.exceptions import EmptyResultSet, ObjectDoesNotExist
 from django.core.exceptions import EmptyResultSet, ObjectDoesNotExist
@@ -9,6 +10,7 @@ from django.db import connections
 from django.db.models.functions import Lower
 from django.db.models.functions import Lower
 
 
 from sentry.utils.cursors import Cursor, CursorResult, build_cursor
 from sentry.utils.cursors import Cursor, CursorResult, build_cursor
+from sentry.utils.pagination_factory import PaginatorLike
 
 
 quote_name = connections["default"].ops.quote_name
 quote_name = connections["default"].ops.quote_name
 
 
@@ -219,7 +221,7 @@ class DateTimePaginator(BasePaginator):
 # TODO(dcramer): previous cursors are too complex at the moment for many things
 # TODO(dcramer): previous cursors are too complex at the moment for many things
 # and are only useful for polling situations. The OffsetPaginator ignores them
 # and are only useful for polling situations. The OffsetPaginator ignores them
 # entirely and uses standard paging
 # entirely and uses standard paging
-class OffsetPaginator:
+class OffsetPaginator(PaginatorLike):
     def __init__(
     def __init__(
         self, queryset, order_by=None, max_limit=MAX_LIMIT, max_offset=None, on_results=None
         self, queryset, order_by=None, max_limit=MAX_LIMIT, max_offset=None, on_results=None
     ):
     ):
@@ -233,7 +235,14 @@ class OffsetPaginator:
         self.max_offset = max_offset
         self.max_offset = max_offset
         self.on_results = on_results
         self.on_results = on_results
 
 
-    def get_result(self, limit=100, cursor=None):
+    def get_result(
+        self,
+        limit: int = 100,
+        cursor: Any = None,
+        count_hits: Any = False,
+        known_hits: Any = None,
+        max_hits: Any = None,
+    ):
         # offset is page #
         # offset is page #
         # value is page limit
         # value is page limit
         if cursor is None:
         if cursor is None:

+ 3 - 0
src/sentry/integrations/msteams/utils.py

@@ -104,6 +104,9 @@ def send_incident_alert_notification(
 ) -> None:
 ) -> None:
     from .card_builder import build_incident_attachment
     from .card_builder import build_incident_attachment
 
 
+    if action.target_identifier is None:
+        raise ValueError("Can't send without `target_identifier`")
+
     attachment = build_incident_attachment(incident, new_status, metric_value)
     attachment = build_incident_attachment(incident, new_status, metric_value)
     integration_service.send_msteams_incident_alert_notification(
     integration_service.send_msteams_incident_alert_notification(
         integration_id=action.integration_id,
         integration_id=action.integration_id,

+ 2 - 2
src/sentry/models/project.py

@@ -4,7 +4,7 @@ import logging
 import warnings
 import warnings
 from collections import defaultdict
 from collections import defaultdict
 from itertools import chain
 from itertools import chain
-from typing import TYPE_CHECKING, Collection, Iterable, Mapping, Sequence
+from typing import TYPE_CHECKING, Collection, Iterable, Mapping
 from uuid import uuid1
 from uuid import uuid1
 
 
 import sentry_sdk
 import sentry_sdk
@@ -78,7 +78,7 @@ class ProjectManager(BaseManager):
             teams__organizationmember__user_id__in=user_ids,
             teams__organizationmember__user_id__in=user_ids,
         )
         )
 
 
-    def get_for_team_ids(self, team_ids: Sequence[int]) -> QuerySet:
+    def get_for_team_ids(self, team_ids: Collection[int]) -> QuerySet:
         """Returns the QuerySet of all projects that a set of Teams have access to."""
         """Returns the QuerySet of all projects that a set of Teams have access to."""
         return self.filter(status=ObjectStatus.ACTIVE, teams__in=team_ids)
         return self.filter(status=ObjectStatus.ACTIVE, teams__in=team_ids)
 
 

+ 1 - 1
src/sentry/notifications/manager.py

@@ -476,7 +476,7 @@ class NotificationsManager(BaseManager["NotificationSetting"]):  # noqa: F821
         ).delete()
         ).delete()
 
 
     def disable_settings_for_users(
     def disable_settings_for_users(
-        self, provider: ExternalProviders, users: Sequence[User]
+        self, provider: ExternalProviders, users: Iterable[User]
     ) -> None:
     ) -> None:
         """
         """
         Given a list of users, overwrite all of their parent-independent
         Given a list of users, overwrite all of their parent-independent

+ 1 - 1
src/sentry/services/hybrid_cloud/actor.py

@@ -21,7 +21,7 @@ class ActorType(str, Enum):
     TEAM = "Team"
     TEAM = "Team"
 
 
 
 
-ActorTarget = Union["RpcActor", "User", "RpcUser", "Team", "RpcTeam"]
+ActorTarget = Union["Actor", "RpcActor", "User", "RpcUser", "Team", "RpcTeam"]
 
 
 
 
 class RpcActor(RpcModel):
 class RpcActor(RpcModel):

+ 3 - 3
src/sentry/services/hybrid_cloud/app/impl.py

@@ -138,7 +138,7 @@ class DatabaseBackedAppService(AppService):
             SentryAppInstallation, SentryAppInstallationFilterArgs, RpcSentryAppInstallation, None
             SentryAppInstallation, SentryAppInstallationFilterArgs, RpcSentryAppInstallation, None
         ]
         ]
     ):
     ):
-        def base_query(self, ids_only: bool = False) -> QuerySet:
+        def base_query(self, ids_only: bool = False) -> QuerySet[SentryAppInstallation]:
             if ids_only:
             if ids_only:
                 return SentryAppInstallation.objects
                 return SentryAppInstallation.objects
             return SentryAppInstallation.objects.select_related("sentry_app")
             return SentryAppInstallation.objects.select_related("sentry_app")
@@ -154,8 +154,8 @@ class DatabaseBackedAppService(AppService):
             raise NotImplementedError("Serialization not supported for AppService")
             raise NotImplementedError("Serialization not supported for AppService")
 
 
         def apply_filters(
         def apply_filters(
-            self, query: QuerySet, filters: SentryAppInstallationFilterArgs
-        ) -> QuerySet:
+            self, query: QuerySet[SentryAppInstallation], filters: SentryAppInstallationFilterArgs
+        ) -> QuerySet[SentryAppInstallation]:
             # filters["status"] = SentryAppInstallationStatus.INSTALLED
             # filters["status"] = SentryAppInstallationStatus.INSTALLED
             if "installation_ids" in filters:
             if "installation_ids" in filters:
                 query = query.filter(id__in=filters["installation_ids"])
                 query = query.filter(id__in=filters["installation_ids"])

+ 26 - 34
src/sentry/services/hybrid_cloud/auth/impl.py

@@ -40,7 +40,6 @@ from sentry.services.hybrid_cloud.organization import (
     RpcOrganizationMemberSummary,
     RpcOrganizationMemberSummary,
     organization_service,
     organization_service,
 )
 )
-from sentry.services.hybrid_cloud.user import RpcUser
 from sentry.services.hybrid_cloud.user.service import user_service
 from sentry.services.hybrid_cloud.user.service import user_service
 from sentry.silo import unguarded_write
 from sentry.silo import unguarded_write
 from sentry.silo.base import SiloMode
 from sentry.silo.base import SiloMode
@@ -122,15 +121,17 @@ class DatabaseBackedAuthService(AuthService):
     def get_org_auth_config(
     def get_org_auth_config(
         self, *, organization_ids: List[int]
         self, *, organization_ids: List[int]
     ) -> List[RpcOrganizationAuthConfig]:
     ) -> List[RpcOrganizationAuthConfig]:
-        aps: Mapping[int, AuthProvider] = {
+        aps = {
             ap.organization_id: ap
             ap.organization_id: ap
             for ap in AuthProvider.objects.filter(organization_id__in=organization_ids)
             for ap in AuthProvider.objects.filter(organization_id__in=organization_ids)
         }
         }
-        qs: Mapping[int, int] = {
+        qs = {
             row["organization_id"]: row["id__count"]
             row["organization_id"]: row["id__count"]
-            for row in ApiKey.objects.filter(organization_id__in=organization_ids)
-            .values("organization_id")
-            .annotate(Count("id"))
+            for row in (
+                ApiKey.objects.filter(organization_id__in=organization_ids)
+                .values("organization_id")
+                .annotate(Count("id"))
+            )
         }
         }
         return [
         return [
             RpcOrganizationAuthConfig(
             RpcOrganizationAuthConfig(
@@ -141,29 +142,21 @@ class DatabaseBackedAuthService(AuthService):
             for oid in organization_ids
             for oid in organization_ids
         ]
         ]
 
 
-    def _load_auth_user(self, user: User) -> RpcUser | None:
-        rpc_user: RpcUser | None = None
-        if user is not None:
-            return user_service.get_user(user_id=user.id)
-        return rpc_user
-
     def authenticate_with(
     def authenticate_with(
         self, *, request: AuthenticationRequest, authenticator_types: List[RpcAuthenticatorType]
         self, *, request: AuthenticationRequest, authenticator_types: List[RpcAuthenticatorType]
     ) -> AuthenticationContext:
     ) -> AuthenticationContext:
         fake_request = FakeAuthenticationRequest(request)
         fake_request = FakeAuthenticationRequest(request)
-        user: User | None = None
-        token: Any = None
 
 
         for authenticator_type in authenticator_types:
         for authenticator_type in authenticator_types:
-            t = authenticator_type.as_authenticator().authenticate(fake_request)
+            t = authenticator_type.as_authenticator().authenticate(fake_request)  # type: ignore[arg-type]
             if t is not None:
             if t is not None:
                 user, token = t
                 user, token = t
-                break
+                return AuthenticationContext(
+                    auth=AuthenticatedToken.from_token(token),
+                    user=user_service.get_user(user_id=user.id),
+                )
 
 
-        return AuthenticationContext(
-            auth=AuthenticatedToken.from_token(token) if token else None,
-            user=self._load_auth_user(user),
-        )
+        return AuthenticationContext(auth=None, user=None)
 
 
     def token_has_org_access(self, *, token: AuthenticatedToken, organization_id: int) -> bool:
     def token_has_org_access(self, *, token: AuthenticatedToken, organization_id: int) -> bool:
         return SentryAppInstallationToken.objects.has_organization_access(token, organization_id)
         return SentryAppInstallationToken.objects.has_organization_access(token, organization_id)
@@ -171,17 +164,17 @@ class DatabaseBackedAuthService(AuthService):
     def authenticate(self, *, request: AuthenticationRequest) -> MiddlewareAuthenticationResponse:
     def authenticate(self, *, request: AuthenticationRequest) -> MiddlewareAuthenticationResponse:
         fake_request = FakeAuthenticationRequest(request)
         fake_request = FakeAuthenticationRequest(request)
         handler = RequestAuthenticationMiddleware(placeholder_get_response)
         handler = RequestAuthenticationMiddleware(placeholder_get_response)
-        expired_user: User | None = None
+        expired_user = None
         try:
         try:
             # Hahaha.  Yes.  You're reading this right.  I'm calling, the middleware, from the service method, that is
             # Hahaha.  Yes.  You're reading this right.  I'm calling, the middleware, from the service method, that is
             # called, from slightly different, middleware.
             # called, from slightly different, middleware.
-            handler.process_request(fake_request)
+            handler.process_request(fake_request)  # type: ignore[arg-type]
         except AuthUserPasswordExpired as e:
         except AuthUserPasswordExpired as e:
             expired_user = e.user
             expired_user = e.user
         except Exception as e:
         except Exception as e:
             raise Exception("Unexpected error processing handler") from e
             raise Exception("Unexpected error processing handler") from e
 
 
-        auth: AuthenticatedToken | None = None
+        auth = None
         if fake_request.auth is not None:
         if fake_request.auth is not None:
             auth = AuthenticatedToken.from_token(fake_request.auth)
             auth = AuthenticatedToken.from_token(fake_request.auth)
 
 
@@ -192,11 +185,11 @@ class DatabaseBackedAuthService(AuthService):
         )
         )
 
 
         if expired_user is not None:
         if expired_user is not None:
-            result.user = self._load_auth_user(expired_user)
+            result.user = user_service.get_user(user_id=expired_user.id)
             result.expired = True
             result.expired = True
         elif fake_request.user is not None and not fake_request.user.is_anonymous:
         elif fake_request.user is not None and not fake_request.user.is_anonymous:
             with transaction.atomic(using=router.db_for_read(User)):
             with transaction.atomic(using=router.db_for_read(User)):
-                result.user = self._load_auth_user(fake_request.user)
+                result.user = user_service.get_user(user_id=fake_request.user.id)
                 transaction.set_rollback(True, using=router.db_for_read(User))
                 transaction.set_rollback(True, using=router.db_for_read(User))
             if SiloMode.single_process_silo_mode():
             if SiloMode.single_process_silo_mode():
                 connections.close_all()
                 connections.close_all()
@@ -214,16 +207,15 @@ class DatabaseBackedAuthService(AuthService):
         sso_state = _query_sso_state(
         sso_state = _query_sso_state(
             organization_id=organization_id, is_super_user=is_superuser, member=org_member
             organization_id=organization_id, is_super_user=is_superuser, member=org_member
         )
         )
-        permissions: List[str] = list()
-        # "permissions" is a bit of a misnomer -- these are all admin level permissions, and the intent is that if you
-        # have them, you can only use them when you are acting, as a superuser.  This is intentional.
+
         if is_superuser:
         if is_superuser:
-            permissions.extend(get_permissions_for_user(user_id))
+            # "permissions" is a bit of a misnomer -- these are all admin level permissions, and the intent is that if you
+            # have them, you can only use them when you are acting, as a superuser.  This is intentional.
+            permissions = list(get_permissions_for_user(user_id))
+        else:
+            permissions = []
 
 
-        return RpcAuthState(
-            sso_state=sso_state,
-            permissions=permissions,
-        )
+        return RpcAuthState(sso_state=sso_state, permissions=permissions)
 
 
     def get_org_ids_with_scim(
     def get_org_ids_with_scim(
         self,
         self,
@@ -262,7 +254,7 @@ class DatabaseBackedAuthService(AuthService):
     def disable_provider(self, *, provider_id: int) -> None:
     def disable_provider(self, *, provider_id: int) -> None:
         with outbox_context(transaction.atomic(router.db_for_write(AuthProvider))):
         with outbox_context(transaction.atomic(router.db_for_write(AuthProvider))):
             try:
             try:
-                auth_provider: AuthProvider = AuthProvider.objects.get(id=provider_id)
+                auth_provider = AuthProvider.objects.get(id=provider_id)
             except AuthProvider.DoesNotExist:
             except AuthProvider.DoesNotExist:
                 return
                 return
 
 

+ 15 - 9
src/sentry/services/hybrid_cloud/filter_query.py

@@ -1,9 +1,12 @@
 from __future__ import annotations
 from __future__ import annotations
 
 
 import abc
 import abc
-from typing import TYPE_CHECKING, Any, Callable, Generic, List, Optional, TypeVar
+from enum import Enum
+from typing import TYPE_CHECKING, Any, Callable, Generic, List, Optional, TypeVar, Union
 
 
-from django.db.models import QuerySet
+from django.db.models import Model, QuerySet
+
+from sentry.services.hybrid_cloud import RpcModel
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
     from sentry.api.serializers import Serializer
     from sentry.api.serializers import Serializer
@@ -12,9 +15,9 @@ if TYPE_CHECKING:
 
 
 
 
 FILTER_ARGS = TypeVar("FILTER_ARGS")  # A typedict
 FILTER_ARGS = TypeVar("FILTER_ARGS")  # A typedict
-RPC_RESPONSE = TypeVar("RPC_RESPONSE")
-SERIALIZER_ENUM = TypeVar("SERIALIZER_ENUM")
-BASE_MODEL = TypeVar("BASE_MODEL")
+RPC_RESPONSE = TypeVar("RPC_RESPONSE", bound=RpcModel)
+SERIALIZER_ENUM = TypeVar("SERIALIZER_ENUM", bound=Union[Enum, None])
+BASE_MODEL = TypeVar("BASE_MODEL", bound=Model)
 
 
 # In the future, this ought to be a pass through type that does not get double serializer, and which cannot be
 # In the future, this ought to be a pass through type that does not get double serializer, and which cannot be
 # inspected by code.
 # inspected by code.
@@ -45,7 +48,7 @@ class FilterQueryDatabaseImpl(
     # Required Overrides
     # Required Overrides
 
 
     @abc.abstractmethod
     @abc.abstractmethod
-    def base_query(self, ids_only: bool = False) -> QuerySet:
+    def base_query(self, ids_only: bool = False) -> QuerySet[BASE_MODEL]:
         # This should return a QuerySet for the model in question along with any other required data
         # This should return a QuerySet for the model in question along with any other required data
         # that is not a filter
         # that is not a filter
         pass
         pass
@@ -63,7 +66,9 @@ class FilterQueryDatabaseImpl(
         pass
         pass
 
 
     @abc.abstractmethod
     @abc.abstractmethod
-    def apply_filters(self, query: QuerySet, filters: FILTER_ARGS) -> QuerySet:
+    def apply_filters(
+        self, query: QuerySet[BASE_MODEL], filters: FILTER_ARGS
+    ) -> QuerySet[BASE_MODEL]:
         pass
         pass
 
 
     @abc.abstractmethod
     @abc.abstractmethod
@@ -75,7 +80,7 @@ class FilterQueryDatabaseImpl(
     def _filter_has_any_key_validator(self, *keys: str) -> Callable[[FILTER_ARGS], Optional[str]]:
     def _filter_has_any_key_validator(self, *keys: str) -> Callable[[FILTER_ARGS], Optional[str]]:
         def validator(d: FILTER_ARGS) -> Optional[str]:
         def validator(d: FILTER_ARGS) -> Optional[str]:
             for k in keys:
             for k in keys:
-                if k in d:  # type: ignore # We assume FILTER_ARGS is a dict
+                if k in d:  # type: ignore[operator]  # We assume FILTER_ARGS is a dict
                     return None
                     return None
 
 
             return f"Filter must contain at least one of: {keys}"
             return f"Filter must contain at least one of: {keys}"
@@ -108,8 +113,9 @@ class FilterQueryDatabaseImpl(
         if as_user is None and auth_context:
         if as_user is None and auth_context:
             as_user = auth_context.user
             as_user = auth_context.user
 
 
+        result = self._query_many(filter=filter)
         return serialize(
         return serialize(
-            self._query_many(filter=filter),
+            list(result),
             user=as_user,
             user=as_user,
             serializer=self.serialize_api(serializer),
             serializer=self.serialize_api(serializer),
         )
         )

+ 4 - 2
src/sentry/services/hybrid_cloud/identity/impl.py

@@ -92,7 +92,9 @@ class DatabaseBackedIdentityService(IdentityService):
     class _IdentityFilterQuery(
     class _IdentityFilterQuery(
         FilterQueryDatabaseImpl[Identity, IdentityFilterArgs, RpcIdentity, None]
         FilterQueryDatabaseImpl[Identity, IdentityFilterArgs, RpcIdentity, None]
     ):
     ):
-        def apply_filters(self, query: QuerySet, filters: IdentityFilterArgs) -> QuerySet:
+        def apply_filters(
+            self, query: QuerySet[Identity], filters: IdentityFilterArgs
+        ) -> QuerySet[Identity]:
             if "id" in filters:
             if "id" in filters:
                 query = query.filter(id=filters["id"])
                 query = query.filter(id=filters["id"])
             if "user_id" in filters:
             if "user_id" in filters:
@@ -107,7 +109,7 @@ class DatabaseBackedIdentityService(IdentityService):
                 query = query.filter(idp__type=filters["provider_type"])
                 query = query.filter(idp__type=filters["provider_type"])
             return query
             return query
 
 
-        def base_query(self, ids_only: bool = False) -> QuerySet:
+        def base_query(self, ids_only: bool = False) -> QuerySet[Identity]:
             return Identity.objects
             return Identity.objects
 
 
         def filter_arg_validator(self) -> Callable[[IdentityFilterArgs], Optional[str]]:
         def filter_arg_validator(self) -> Callable[[IdentityFilterArgs], Optional[str]]:

Some files were not shown because too many files changed in this diff