Browse Source

typing(discover): Adding some of the files for QueryBuilder to mypy (#29851)

- This adds a few files to mypy; search/events/filter and
  search/events/fields are not included yet
William Mak 3 years ago
parent
commit
f5ecc4ae0a

+ 3 - 0
mypy.ini

@@ -56,6 +56,9 @@ files = src/sentry/api/bases/external_actor.py,
         src/sentry/snuba/outcomes.py,
         src/sentry/snuba/query_subscription_consumer.py,
         src/sentry/spans/**/*.py,
+        src/sentry/search/events/base.py,
+        src/sentry/search/events/builder.py,
+        src/sentry/search/events/types.py,
         src/sentry/tasks/app_store_connect.py,
         src/sentry/tasks/low_priority_symbolication.py,
         src/sentry/tasks/store.py,

+ 3 - 3
src/sentry/search/events/base.py

@@ -1,4 +1,4 @@
-from typing import Dict, List, Mapping, Optional, Set
+from typing import Dict, List, Mapping, Optional, Set, cast
 
 from django.utils.functional import cached_property
 from snuba_sdk.aliased_expression import AliasedExpression
@@ -37,9 +37,9 @@ class QueryBase:
 
         self.resolve_column_name = resolve_column(self.dataset)
 
-    @cached_property
+    @cached_property  # type: ignore
     def project_slugs(self) -> Mapping[str, int]:
-        project_ids = self.params.get("project_id", [])
+        project_ids = cast(List[int], self.params.get("project_id", []))
 
         if len(project_ids) > 0:
             project_slugs = Project.objects.filter(id__in=project_ids)

+ 22 - 19
src/sentry/search/events/builder.py

@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Set, Tuple, cast
 
 from snuba_sdk.aliased_expression import AliasedExpression
 from snuba_sdk.column import Column
@@ -16,7 +16,7 @@ from sentry.search.events.types import ParamsType, SelectType, WhereType
 from sentry.utils.snuba import Dataset
 
 
-class QueryBuilder(QueryFilter):
+class QueryBuilder(QueryFilter):  # type: ignore
     """Builds a snql query"""
 
     def __init__(
@@ -35,7 +35,7 @@ class QueryBuilder(QueryFilter):
         limit: Optional[int] = 50,
         offset: Optional[int] = 0,
         limitby: Optional[Tuple[str, int]] = None,
-        turbo: Optional[bool] = False,
+        turbo: bool = False,
         sample_rate: Optional[float] = None,
     ):
         super().__init__(dataset, params, auto_fields, functions_acl)
@@ -86,7 +86,7 @@ class QueryBuilder(QueryFilter):
         else:
             return []
 
-    def validate_aggregate_arguments(self):
+    def validate_aggregate_arguments(self) -> None:
         for column in self.columns:
             if column in self.aggregates:
                 continue
@@ -149,7 +149,7 @@ class QueryBuilder(QueryFilter):
         )
 
 
-class TimeseriesQueryBuilder(QueryFilter):
+class TimeseriesQueryBuilder(QueryFilter):  # type: ignore
     time_column = Column("time")
 
     def __init__(
@@ -185,7 +185,8 @@ class TimeseriesQueryBuilder(QueryFilter):
     def select(self) -> List[SelectType]:
         if not self.aggregates:
             raise InvalidSearchQuery("Cannot query a timeseries without a Y-Axis")
-        return self.aggregates
+        # Casting for now since QueryFields/QueryFilter are only partially typed
+        return cast(List[SelectType], self.aggregates)
 
     def get_snql_query(self) -> Query:
         return Query(
@@ -247,14 +248,16 @@ class TopEventsQueryBuilder(TimeseriesQueryBuilder):
         params: ParamsType,
         granularity: int,
         top_events: List[Dict[str, Any]],
-        other: Optional[bool] = False,
+        other: bool = False,
         query: Optional[str] = None,
         selected_columns: Optional[List[str]] = None,
         timeseries_columns: Optional[List[str]] = None,
         equations: Optional[List[str]] = None,
         limit: Optional[int] = 10000,
     ):
-        timeseries_equations, timeseries_functions = categorize_columns(timeseries_columns)
+        timeseries_equations, timeseries_functions = categorize_columns(
+            timeseries_columns if timeseries_columns is not None else []
+        )
         super().__init__(
             dataset,
             params,
@@ -265,7 +268,7 @@ class TopEventsQueryBuilder(TimeseriesQueryBuilder):
             limit=limit,
         )
 
-        self.fields = selected_columns
+        self.fields: List[str] = selected_columns if selected_columns is not None else []
 
         if (conditions := self.resolve_top_event_conditions(top_events, other)) is not None:
             self.where.append(conditions)
@@ -290,7 +293,7 @@ class TopEventsQueryBuilder(TimeseriesQueryBuilder):
         return sorted(translated)
 
     def resolve_top_event_conditions(
-        self, top_events: Optional[Dict[str, Any]], other: bool
+        self, top_events: List[Dict[str, Any]], other: bool
     ) -> Optional[WhereType]:
         """Given a list of top events construct the conditions"""
         conditions = []
@@ -314,7 +317,7 @@ class TopEventsQueryBuilder(TimeseriesQueryBuilder):
 
             resolved_field = self.resolve_column(field)
 
-            values = set()
+            values: Set[Any] = set()
             for event in top_events:
                 if field in event:
                     alias = field
@@ -328,8 +331,8 @@ class TopEventsQueryBuilder(TimeseriesQueryBuilder):
                     continue
                 else:
                     values.add(event.get(alias))
-            values = list(values)
-            if values:
+            values_list = list(values)
+            if values_list:
                 if field == "timestamp" or field.startswith("timestamp.to_"):
                     if not other:
                         # timestamp fields needs special handling, creating a big OR instead
@@ -337,20 +340,20 @@ class TopEventsQueryBuilder(TimeseriesQueryBuilder):
                     else:
                         # Needs to be a big AND when negated
                         function, operator = And, Op.NEQ
-                    if len(values) > 1:
+                    if len(values_list) > 1:
                         conditions.append(
                             function(
                                 conditions=[
                                     Condition(resolved_field, operator, value)
-                                    for value in sorted(values)
+                                    for value in sorted(values_list)
                                 ]
                             )
                         )
                     else:
-                        conditions.append(Condition(resolved_field, operator, values[0]))
-                elif None in values:
+                        conditions.append(Condition(resolved_field, operator, values_list[0]))
+                elif None in values_list:
                     # one of the values was null, but we can't do an in with null values, so split into two conditions
-                    non_none_values = [value for value in values if value is not None]
+                    non_none_values = [value for value in values_list if value is not None]
                     null_condition = Condition(
                         Function("isNull", [resolved_field]), Op.EQ if not other else Op.NEQ, 1
                     )
@@ -366,7 +369,7 @@ class TopEventsQueryBuilder(TimeseriesQueryBuilder):
                         conditions.append(null_condition)
                 else:
                     conditions.append(
-                        Condition(resolved_field, Op.IN if not other else Op.NOT_IN, values)
+                        Condition(resolved_field, Op.IN if not other else Op.NOT_IN, values_list)
                     )
         if len(conditions) > 1:
             final_function = And if not other else Or

+ 8 - 6
src/sentry/search/events/fields.py

@@ -2279,12 +2279,12 @@ def normalize_percentile_alias(args: Mapping[str, str]) -> str:
 
 
 class SnQLFunction(DiscoverFunction):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, **kwargs) -> None:
         self.snql_aggregate = kwargs.pop("snql_aggregate", None)
         self.snql_column = kwargs.pop("snql_column", None)
         super().__init__(*args, **kwargs)
 
-    def validate(self):
+    def validate(self) -> None:
         # assert that all optional args have defaults available
         for i, arg in enumerate(self.optional_args):
             assert (
@@ -2989,7 +2989,7 @@ class QueryFields(QueryBase):
         return function in self.function_converter
 
     def resolve_function(
-        self, function: str, match: Optional[Match[str]] = None, resolve_only=False
+        self, function: str, match: Optional[Match[str]] = None, resolve_only: bool = False
     ) -> SelectType:
         """Given a public function, resolve to the corresponding Snql function
 
@@ -3007,7 +3007,7 @@ class QueryFields(QueryBase):
         if function in self.params.get("aliases", {}):
             raise NotImplementedError("Aggregate aliases not implemented in snql field parsing yet")
 
-        name, combinator_name, arguments, alias = self.parse_function(match)
+        name, combinator_name, parsed_arguments, alias = self.parse_function(match)
         snql_function = self.function_converter[name]
 
         combinator = snql_function.find_combinator(combinator_name)
@@ -3022,7 +3022,9 @@ class QueryFields(QueryBase):
 
         combinator_applied = False
 
-        arguments = snql_function.format_as_arguments(name, arguments, self.params, combinator)
+        arguments = snql_function.format_as_arguments(
+            name, parsed_arguments, self.params, combinator
+        )
 
         self.function_alias_map[alias] = FunctionDetails(function, snql_function, arguments.copy())
 
@@ -3384,7 +3386,7 @@ class QueryFields(QueryBase):
         self,
         args: Mapping[str, Union[str, Column, SelectType, int, float]],
         alias: str,
-        fixed_percentile: float = None,
+        fixed_percentile: Optional[float] = None,
     ) -> SelectType:
         return (
             Function(

+ 19 - 16
src/sentry/search/events/filter.py

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from datetime import datetime
 from functools import reduce
 from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
@@ -67,7 +69,7 @@ def is_condition(term):
     return isinstance(term, (tuple, list)) and len(term) == 3 and term[1] in OPERATOR_TO_FUNCTION
 
 
-def translate_transaction_status(val):
+def translate_transaction_status(val: str) -> str:
     if val not in SPAN_STATUS_NAME_TO_CODE:
         raise InvalidSearchQuery(
             f"Invalid value {val} for transaction.status condition. Accepted "
@@ -503,7 +505,7 @@ def _semver_build_filter_converter(
     return ["release", "IN", versions]
 
 
-def handle_operator_negation(operator):
+def handle_operator_negation(operator: str) -> Tuple[str, bool]:
     negated = False
     if operator == "!=":
         negated = True
@@ -1178,7 +1180,9 @@ class QueryFilter(QueryFields):
 
         return where, having
 
-    def _combine_conditions(self, lhs, rhs, operator):
+    def _combine_conditions(
+        self, lhs: List[WhereType], rhs: List[WhereType], operator: And | Or
+    ) -> List[WhereType]:
         combined_conditions = [
             conditions[0] if len(conditions) == 1 else And(conditions=conditions)
             for conditions in [lhs, rhs]
@@ -1441,9 +1445,9 @@ class QueryFilter(QueryFields):
         # conditions added to env_conditions can be OR'ed
         env_conditions = []
         value = search_filter.value.value
-        values = set(value if isinstance(value, (list, tuple)) else [value])
+        values_set = set(value if isinstance(value, (list, tuple)) else [value])
         # sorted for consistency
-        values = sorted(f"{value}" for value in values)
+        values = sorted(f"{value}" for value in values_set)
         environment = self.column("environment")
         # the "no environment" environment is null in snuba
         if "" in values:
@@ -1564,12 +1568,11 @@ class QueryFilter(QueryFields):
                 self.resolve_field(search_filter.key.name),
                 Op.IS_NULL if search_filter.operator == "=" else Op.IS_NOT_NULL,
             )
-        if search_filter.is_in_filter:
-            internal_value = [
-                translate_transaction_status(val) for val in search_filter.value.raw_value
-            ]
-        else:
-            internal_value = translate_transaction_status(search_filter.value.raw_value)
+        internal_value = (
+            [translate_transaction_status(val) for val in search_filter.value.raw_value]
+            if search_filter.is_in_filter
+            else translate_transaction_status(search_filter.value.raw_value)
+        )
         return Condition(
             self.resolve_field(search_filter.key.name),
             Op(search_filter.operator),
@@ -1661,8 +1664,8 @@ class QueryFilter(QueryFields):
             raise ValueError("organization_id is a required param")
 
         organization_id: int = self.params["organization_id"]
-        project_ids: Optional[list[int]] = self.params.get("project_id")
-        environments: Optional[list[Environment]] = self.params.get("environment_objects", [])
+        project_ids: Optional[List[int]] = self.params.get("project_id")
+        environments: Optional[List[Environment]] = self.params.get("environment_objects", [])
         qs = (
             Release.objects.filter_by_stage(
                 organization_id,
@@ -1729,7 +1732,7 @@ class QueryFilter(QueryFields):
             raise ValueError("organization_id is a required param")
 
         organization_id: int = self.params["organization_id"]
-        project_ids: Optional[list[int]] = self.params.get("project_id")
+        project_ids: Optional[List[int]] = self.params.get("project_id")
         # We explicitly use `raw_value` here to avoid converting wildcards to shell values
         version: str = search_filter.value.raw_value
         operator: str = search_filter.operator
@@ -1787,7 +1790,7 @@ class QueryFilter(QueryFields):
             raise ValueError("organization_id is a required param")
 
         organization_id: int = self.params["organization_id"]
-        project_ids: Optional[list[int]] = self.params.get("project_id")
+        project_ids: Optional[List[int]] = self.params.get("project_id")
         package: str = search_filter.value.raw_value
 
         versions = list(
@@ -1813,7 +1816,7 @@ class QueryFilter(QueryFields):
             raise ValueError("organization_id is a required param")
 
         organization_id: int = self.params["organization_id"]
-        project_ids: Optional[list[int]] = self.params.get("project_id")
+        project_ids: Optional[List[int]] = self.params.get("project_id")
         build: str = search_filter.value.raw_value
 
         operator, negated = handle_operator_negation(search_filter.operator)