Browse Source

ref: make ParamsType a TypedDict (#65787)

it isn't actually a well defined dict -- but this helps with unpacking
it inside the events code

there's definitely some problematic dict stuffing happening here -- it's
wild that this sometimes has `project_id` (list of int???), `projects`
(list of Project), and project_objects (list of Project) -- but
hopefully that can be cleaned up later with an actual type

<!-- Describe your PR here. -->
anthony sottile 1 year ago
parent
commit
c10f7b8211

+ 4 - 4
src/sentry/api/bases/organization_events.py

@@ -28,7 +28,7 @@ from sentry.models.project import Project
 from sentry.models.team import Team
 from sentry.search.events.constants import DURATION_UNITS, SIZE_UNITS
 from sentry.search.events.fields import get_function_alias
-from sentry.search.events.types import SnubaParams
+from sentry.search.events.types import ParamsType, SnubaParams
 from sentry.snuba import discover
 from sentry.snuba.metrics.extraction import MetricSpecType
 from sentry.snuba.utils import DATASET_LABELS, DATASET_OPTIONS, get_dataset
@@ -125,7 +125,7 @@ class OrganizationEventsEndpointBase(OrganizationEndpoint):
 
     def get_snuba_params(
         self, request: Request, organization: Organization, check_global_views: bool = True
-    ) -> dict[str, Any]:
+    ) -> ParamsType:
         with sentry_sdk.start_span(op="discover.endpoint", description="filter_params"):
             if (
                 len(self.get_field_list(organization, request))
@@ -136,7 +136,7 @@ class OrganizationEventsEndpointBase(OrganizationEndpoint):
                     detail=f"You can view up to {MAX_FIELDS} fields at a time. Please delete some and try again."
                 )
 
-            params: dict[str, Any] = self.get_filter_params(request, organization)
+            params: ParamsType = self.get_filter_params(request, organization)
             params = self.quantize_date_params(request, params)
             params["user_id"] = request.user.id if request.user else None
             params["team_id"] = self.get_team_ids(request, organization)
@@ -350,7 +350,7 @@ class OrganizationEventsV2EndpointBase(OrganizationEventsEndpointBase):
         ],
         top_events: int = 0,
         query_column: str = "count()",
-        params: dict[str, Any] | None = None,
+        params: ParamsType | None = None,
         query: str | None = None,
         allow_partial_buckets: bool = False,
         zerofill_results: bool = True,

+ 1 - 1
src/sentry/api/endpoints/organization_events_spans_performance.py

@@ -87,7 +87,7 @@ SPAN_PERFORMANCE_COLUMNS: dict[str, SpanPerformanceColumn] = {
 class OrganizationEventsSpansEndpointBase(OrganizationEventsV2EndpointBase):
     def get_snuba_params(
         self, request: Request, organization: Organization, check_global_views: bool = True
-    ) -> dict[str, Any]:
+    ) -> ParamsType:
         params = super().get_snuba_params(request, organization, check_global_views)
 
         if len(params.get("project_id", [])) != 1:

+ 8 - 5
src/sentry/api/endpoints/organization_spans_aggregation.py

@@ -369,15 +369,18 @@ class OrganizationSpansAggregationEndpoint(OrganizationEventsEndpointBase):
 
             return Response(data=aggregated_tree)
 
-        conditions = [["transaction", "=", transaction]]
+        conditions: list[list[object]] = [["transaction", "=", transaction]]
         if http_method is not None:
             conditions.append(["http.method", "=", http_method])
 
         environments = params.get("environment", None)
-        if environments and len(environments) > 1:
-            conditions.append(["environment", "IN", environments])
-        elif environments and len(environments) == 1:
-            conditions.append(["environment", "=", environments[0]])
+        if environments:
+            if isinstance(environments, str):
+                conditions.append(["environment", "=", environments])
+            elif len(environments) == 1:
+                conditions.append(["environment", "=", environments[0]])
+            elif len(environments) > 1:
+                conditions.append(["environment", "IN", environments])
 
         events = eventstore.backend.get_events(
             filter=eventstore.Filter(

+ 4 - 4
src/sentry/api/helpers/events.py

@@ -1,7 +1,6 @@
 from __future__ import annotations
 
-from collections.abc import Mapping
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING
 
 from rest_framework.request import Request
 from rest_framework.response import Response
@@ -11,6 +10,7 @@ from sentry.api.serializers import serialize
 from sentry.eventstore.models import Event
 from sentry.issues.grouptype import GroupCategory
 from sentry.search.events.builder import QueryBuilder
+from sentry.search.events.types import ParamsType
 from sentry.snuba.dataset import Dataset
 from sentry.utils.validators import normalize_event_id
 
@@ -21,7 +21,7 @@ if TYPE_CHECKING:
 def get_direct_hit_response(
     request: Request,
     query: str | None,
-    snuba_params: Mapping[str, Any],
+    snuba_params: ParamsType,
     referrer: str,
     group: Group,
 ) -> Response | None:
@@ -52,7 +52,7 @@ def get_direct_hit_response(
 
 
 def get_query_builder_for_group(
-    query: str, snuba_params: Mapping[str, Any], group: Group, limit: int, offset: int
+    query: str, snuba_params: ParamsType, group: Group, limit: int, offset: int
 ) -> QueryBuilder:
     dataset = Dataset.IssuePlatform
     if group.issue_category == GroupCategory.ERROR:

+ 3 - 3
src/sentry/api/serializers/rest_framework/dashboard.py

@@ -24,7 +24,7 @@ from sentry.models.dashboard_widget import (
 from sentry.relay.config.metric_extraction import get_current_widget_specs, widget_exceeds_max_specs
 from sentry.search.events.builder import UnresolvedQuery
 from sentry.search.events.fields import is_function
-from sentry.search.events.types import QueryBuilderConfig
+from sentry.search.events.types import ParamsType, QueryBuilderConfig
 from sentry.snuba.dataset import Dataset
 from sentry.tasks.on_demand_metrics import (
     _get_widget_on_demand_specs,
@@ -195,12 +195,12 @@ class DashboardWidgetQuerySerializer(CamelSnakeSerializer[Dashboard]):
             # Subtract one because the equation is injected to fields
             orderby = f"{orderby_prefix}equation[{len(equations) - 1}]"
 
-        params = {
+        params: ParamsType = {
             "start": datetime.now() - timedelta(days=1),
             "end": datetime.now(),
             "project_id": [p.id for p in self.context["projects"]],
             "organization_id": self.context["organization"].id,
-            "environment": self.context.get("environment"),
+            "environment": self.context.get("environment", []),
         }
 
         try:

+ 12 - 4
src/sentry/api/utils.py

@@ -90,7 +90,7 @@ def default_start_end_dates(
 
 @overload
 def get_date_range_from_params(
-    params: dict[str, Any],
+    params: Mapping[str, Any],
     optional: Literal[False] = ...,
     default_stats_period: datetime.timedelta = ...,
 ) -> tuple[datetime.datetime, datetime.datetime]:
@@ -99,7 +99,7 @@ def get_date_range_from_params(
 
 @overload
 def get_date_range_from_params(
-    params: dict[str, Any],
+    params: Mapping[str, Any],
     optional: bool = ...,
     default_stats_period: datetime.timedelta = ...,
 ) -> tuple[None, None] | tuple[datetime.datetime, datetime.datetime]:
@@ -107,7 +107,7 @@ def get_date_range_from_params(
 
 
 def get_date_range_from_params(
-    params: dict[str, Any],
+    params: Mapping[str, Any],
     optional: bool = False,
     default_stats_period: datetime.timedelta = MAX_STATS_PERIOD,
 ) -> tuple[None, None] | tuple[datetime.datetime, datetime.datetime]:
@@ -133,7 +133,15 @@ def get_date_range_from_params(
     :return: A length 2 tuple containing start/end or raises an `InvalidParams`
     exception
     """
-    mutable_params = params.copy()
+    mutable_params = {
+        k: params[k]
+        for k in (
+            *("timeframe", "timeframeStart", "timeframeEnd"),
+            *("statsPeriod", "statsPeriodStart", "statsPeriodEnd"),
+            *("start", "end"),
+        )
+        if k in params
+    }
     timeframe = mutable_params.get("timeframe")
     timeframe_start = mutable_params.get("timeframeStart")
     timeframe_end = mutable_params.get("timeframeEnd")

+ 2 - 1
src/sentry/data_export/processors/discover.py

@@ -7,6 +7,7 @@ from sentry.models.environment import Environment
 from sentry.models.group import Group
 from sentry.models.project import Project
 from sentry.search.events.fields import get_function_alias
+from sentry.search.events.types import ParamsType
 from sentry.snuba import discover
 from sentry.snuba.utils import get_dataset
 
@@ -24,7 +25,7 @@ class DiscoverProcessor:
         self.projects = self.get_projects(organization_id, discover_query)
         self.environments = self.get_environments(organization_id, discover_query)
         self.start, self.end = get_date_range_from_params(discover_query)
-        self.params = {
+        self.params: ParamsType = {
             "organization_id": organization_id,
             "project_id": [project.id for project in self.projects],
             "start": self.start,

+ 2 - 2
src/sentry/relay/config/metric_extraction.py

@@ -30,7 +30,7 @@ from sentry.models.transaction_threshold import (
 )
 from sentry.search.events import fields
 from sentry.search.events.builder import QueryBuilder
-from sentry.search.events.types import QueryBuilderConfig
+from sentry.search.events.types import ParamsType, QueryBuilderConfig
 from sentry.snuba.dataset import Dataset
 from sentry.snuba.metrics.extraction import (
     MetricSpec,
@@ -590,7 +590,7 @@ def _is_widget_query_low_cardinality(widget_query: DashboardWidgetQuery, project
 
     New queries will be checked upon creation and not allowed at that time.
     """
-    params: dict[str, Any] = {
+    params: ParamsType = {
         "statsPeriod": "30m",
         "project_objects": [project],
         "organization_id": project.organization_id,  # Organization id has to be specified to not violate allocation policy.

+ 1 - 1
src/sentry/search/events/builder/discover.py

@@ -149,7 +149,7 @@ class BaseQueryBuilder:
                         organization_id=organization.id, name__in=params["environment"]
                     )
                 )
-                if "" in cast(list[str], params["environment"]):
+                if "" in params["environment"]:
                     environments.append(None)
             elif isinstance(params["environment"], str):
                 environments = list(

+ 15 - 1
src/sentry/search/events/types.py

@@ -18,8 +18,22 @@ from sentry.models.team import Team
 from sentry.services.hybrid_cloud.user import RpcUser
 
 WhereType = Union[Condition, BooleanCondition]
+
+
 # Replaced by SnubaParams
-ParamsType = Mapping[str, Union[Sequence[int], int, str, datetime]]
+class ParamsType(TypedDict, total=False):
+    project_id: list[int]
+    projects: list[Project]
+    project_objects: list[Project]
+    start: datetime
+    end: datetime
+    environment: str | list[str]
+    organization_id: int
+    use_case_id: str
+    environment_objects: list[Environment]
+    statsPeriod: str
+
+
 SelectType = Union[AliasedExpression, Column, Function, CurriedFunction]
 
 NormalizedArg = Optional[Union[str, float]]

Some files were not shown because too many files changed in this diff