Browse Source

typing(discover): Adding some of the files for QueryBuilder to mypy (#29851)

- This adds a few files to mypy; search/events/filter.py and
  search/events/fields.py are not included yet
William Mak 3 years ago
parent
commit
f5ecc4ae0a

+ 3 - 0
mypy.ini

@@ -56,6 +56,9 @@ files = src/sentry/api/bases/external_actor.py,
         src/sentry/snuba/outcomes.py,
         src/sentry/snuba/query_subscription_consumer.py,
         src/sentry/spans/**/*.py,
+        src/sentry/search/events/base.py,
+        src/sentry/search/events/builder.py,
+        src/sentry/search/events/types.py,
         src/sentry/tasks/app_store_connect.py,
         src/sentry/tasks/low_priority_symbolication.py,
         src/sentry/tasks/store.py,

+ 3 - 3
src/sentry/search/events/base.py

@@ -1,4 +1,4 @@
-from typing import Dict, List, Mapping, Optional, Set
+from typing import Dict, List, Mapping, Optional, Set, cast
 
 from django.utils.functional import cached_property
 from snuba_sdk.aliased_expression import AliasedExpression
@@ -37,9 +37,9 @@ class QueryBase:
 
         self.resolve_column_name = resolve_column(self.dataset)
 
-    @cached_property
+    @cached_property  # type: ignore
     def project_slugs(self) -> Mapping[str, int]:
-        project_ids = self.params.get("project_id", [])
+        project_ids = cast(List[int], self.params.get("project_id", []))
 
         if len(project_ids) > 0:
             project_slugs = Project.objects.filter(id__in=project_ids)

+ 22 - 19
src/sentry/search/events/builder.py

@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Set, Tuple, cast
 
 from snuba_sdk.aliased_expression import AliasedExpression
 from snuba_sdk.column import Column
@@ -16,7 +16,7 @@ from sentry.search.events.types import ParamsType, SelectType, WhereType
 from sentry.utils.snuba import Dataset
 
 
-class QueryBuilder(QueryFilter):
+class QueryBuilder(QueryFilter):  # type: ignore
     """Builds a snql query"""
 
     def __init__(
@@ -35,7 +35,7 @@ class QueryBuilder(QueryFilter):
         limit: Optional[int] = 50,
         offset: Optional[int] = 0,
         limitby: Optional[Tuple[str, int]] = None,
-        turbo: Optional[bool] = False,
+        turbo: bool = False,
         sample_rate: Optional[float] = None,
     ):
         super().__init__(dataset, params, auto_fields, functions_acl)
@@ -86,7 +86,7 @@ class QueryBuilder(QueryFilter):
         else:
             return []
 
-    def validate_aggregate_arguments(self):
+    def validate_aggregate_arguments(self) -> None:
         for column in self.columns:
             if column in self.aggregates:
                 continue
@@ -149,7 +149,7 @@ class QueryBuilder(QueryFilter):
         )
 
 
-class TimeseriesQueryBuilder(QueryFilter):
+class TimeseriesQueryBuilder(QueryFilter):  # type: ignore
     time_column = Column("time")
 
     def __init__(
@@ -185,7 +185,8 @@ class TimeseriesQueryBuilder(QueryFilter):
     def select(self) -> List[SelectType]:
         if not self.aggregates:
             raise InvalidSearchQuery("Cannot query a timeseries without a Y-Axis")
-        return self.aggregates
+        # Casting for now since QueryFields/QueryFilter are only partially typed
+        return cast(List[SelectType], self.aggregates)
 
     def get_snql_query(self) -> Query:
         return Query(
@@ -247,14 +248,16 @@ class TopEventsQueryBuilder(TimeseriesQueryBuilder):
         params: ParamsType,
         granularity: int,
         top_events: List[Dict[str, Any]],
-        other: Optional[bool] = False,
+        other: bool = False,
         query: Optional[str] = None,
         selected_columns: Optional[List[str]] = None,
         timeseries_columns: Optional[List[str]] = None,
         equations: Optional[List[str]] = None,
         limit: Optional[int] = 10000,
     ):
-        timeseries_equations, timeseries_functions = categorize_columns(timeseries_columns)
+        timeseries_equations, timeseries_functions = categorize_columns(
+            timeseries_columns if timeseries_columns is not None else []
+        )
         super().__init__(
             dataset,
             params,
@@ -265,7 +268,7 @@ class TopEventsQueryBuilder(TimeseriesQueryBuilder):
             limit=limit,
         )
 
-        self.fields = selected_columns
+        self.fields: List[str] = selected_columns if selected_columns is not None else []
 
         if (conditions := self.resolve_top_event_conditions(top_events, other)) is not None:
             self.where.append(conditions)
@@ -290,7 +293,7 @@ class TopEventsQueryBuilder(TimeseriesQueryBuilder):
         return sorted(translated)
 
     def resolve_top_event_conditions(
-        self, top_events: Optional[Dict[str, Any]], other: bool
+        self, top_events: List[Dict[str, Any]], other: bool
     ) -> Optional[WhereType]:
         """Given a list of top events construct the conditions"""
         conditions = []
@@ -314,7 +317,7 @@ class TopEventsQueryBuilder(TimeseriesQueryBuilder):
 
             resolved_field = self.resolve_column(field)
 
-            values = set()
+            values: Set[Any] = set()
             for event in top_events:
                 if field in event:
                     alias = field
@@ -328,8 +331,8 @@ class TopEventsQueryBuilder(TimeseriesQueryBuilder):
                     continue
                 else:
                     values.add(event.get(alias))
-            values = list(values)
-            if values:
+            values_list = list(values)
+            if values_list:
                 if field == "timestamp" or field.startswith("timestamp.to_"):
                     if not other:
                         # timestamp fields needs special handling, creating a big OR instead
@@ -337,20 +340,20 @@ class TopEventsQueryBuilder(TimeseriesQueryBuilder):
                     else:
                         # Needs to be a big AND when negated
                         function, operator = And, Op.NEQ
-                    if len(values) > 1:
+                    if len(values_list) > 1:
                         conditions.append(
                             function(
                                 conditions=[
                                     Condition(resolved_field, operator, value)
-                                    for value in sorted(values)
+                                    for value in sorted(values_list)
                                 ]
                             )
                         )
                     else:
-                        conditions.append(Condition(resolved_field, operator, values[0]))
-                elif None in values:
+                        conditions.append(Condition(resolved_field, operator, values_list[0]))
+                elif None in values_list:
                     # one of the values was null, but we can't do an in with null values, so split into two conditions
-                    non_none_values = [value for value in values if value is not None]
+                    non_none_values = [value for value in values_list if value is not None]
                     null_condition = Condition(
                         Function("isNull", [resolved_field]), Op.EQ if not other else Op.NEQ, 1
                     )
@@ -366,7 +369,7 @@ class TopEventsQueryBuilder(TimeseriesQueryBuilder):
                         conditions.append(null_condition)
                 else:
                     conditions.append(
-                        Condition(resolved_field, Op.IN if not other else Op.NOT_IN, values)
+                        Condition(resolved_field, Op.IN if not other else Op.NOT_IN, values_list)
                     )
         if len(conditions) > 1:
             final_function = And if not other else Or

+ 8 - 6
src/sentry/search/events/fields.py

@@ -2279,12 +2279,12 @@ def normalize_percentile_alias(args: Mapping[str, str]) -> str:
 
 
 class SnQLFunction(DiscoverFunction):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, **kwargs) -> None:
         self.snql_aggregate = kwargs.pop("snql_aggregate", None)
         self.snql_column = kwargs.pop("snql_column", None)
         super().__init__(*args, **kwargs)
 
-    def validate(self):
+    def validate(self) -> None:
         # assert that all optional args have defaults available
         for i, arg in enumerate(self.optional_args):
             assert (
@@ -2989,7 +2989,7 @@ class QueryFields(QueryBase):
         return function in self.function_converter
 
     def resolve_function(
-        self, function: str, match: Optional[Match[str]] = None, resolve_only=False
+        self, function: str, match: Optional[Match[str]] = None, resolve_only: bool = False
     ) -> SelectType:
         """Given a public function, resolve to the corresponding Snql function
 
@@ -3007,7 +3007,7 @@ class QueryFields(QueryBase):
         if function in self.params.get("aliases", {}):
             raise NotImplementedError("Aggregate aliases not implemented in snql field parsing yet")
 
-        name, combinator_name, arguments, alias = self.parse_function(match)
+        name, combinator_name, parsed_arguments, alias = self.parse_function(match)
         snql_function = self.function_converter[name]
 
         combinator = snql_function.find_combinator(combinator_name)
@@ -3022,7 +3022,9 @@ class QueryFields(QueryBase):
 
         combinator_applied = False
 
-        arguments = snql_function.format_as_arguments(name, arguments, self.params, combinator)
+        arguments = snql_function.format_as_arguments(
+            name, parsed_arguments, self.params, combinator
+        )
 
         self.function_alias_map[alias] = FunctionDetails(function, snql_function, arguments.copy())
 
@@ -3384,7 +3386,7 @@ class QueryFields(QueryBase):
         self,
         args: Mapping[str, Union[str, Column, SelectType, int, float]],
         alias: str,
-        fixed_percentile: float = None,
+        fixed_percentile: Optional[float] = None,
     ) -> SelectType:
         return (
             Function(

+ 19 - 16
src/sentry/search/events/filter.py

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from datetime import datetime
 from datetime import datetime
 from functools import reduce
 from functools import reduce
 from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
 from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
@@ -67,7 +69,7 @@ def is_condition(term):
     return isinstance(term, (tuple, list)) and len(term) == 3 and term[1] in OPERATOR_TO_FUNCTION
     return isinstance(term, (tuple, list)) and len(term) == 3 and term[1] in OPERATOR_TO_FUNCTION
 
 
 
 
-def translate_transaction_status(val):
+def translate_transaction_status(val: str) -> str:
     if val not in SPAN_STATUS_NAME_TO_CODE:
     if val not in SPAN_STATUS_NAME_TO_CODE:
         raise InvalidSearchQuery(
         raise InvalidSearchQuery(
             f"Invalid value {val} for transaction.status condition. Accepted "
             f"Invalid value {val} for transaction.status condition. Accepted "
@@ -503,7 +505,7 @@ def _semver_build_filter_converter(
     return ["release", "IN", versions]
     return ["release", "IN", versions]
 
 
 
 
-def handle_operator_negation(operator):
+def handle_operator_negation(operator: str) -> Tuple[str, bool]:
     negated = False
     negated = False
     if operator == "!=":
     if operator == "!=":
         negated = True
         negated = True
@@ -1178,7 +1180,9 @@ class QueryFilter(QueryFields):
 
 
         return where, having
         return where, having
 
 
-    def _combine_conditions(self, lhs, rhs, operator):
+    def _combine_conditions(
+        self, lhs: List[WhereType], rhs: List[WhereType], operator: And | Or
+    ) -> List[WhereType]:
         combined_conditions = [
         combined_conditions = [
             conditions[0] if len(conditions) == 1 else And(conditions=conditions)
             conditions[0] if len(conditions) == 1 else And(conditions=conditions)
             for conditions in [lhs, rhs]
             for conditions in [lhs, rhs]
@@ -1441,9 +1445,9 @@ class QueryFilter(QueryFields):
         # conditions added to env_conditions can be OR'ed
         # conditions added to env_conditions can be OR'ed
         env_conditions = []
         env_conditions = []
         value = search_filter.value.value
         value = search_filter.value.value
-        values = set(value if isinstance(value, (list, tuple)) else [value])
+        values_set = set(value if isinstance(value, (list, tuple)) else [value])
         # sorted for consistency
         # sorted for consistency
-        values = sorted(f"{value}" for value in values)
+        values = sorted(f"{value}" for value in values_set)
         environment = self.column("environment")
         environment = self.column("environment")
         # the "no environment" environment is null in snuba
         # the "no environment" environment is null in snuba
         if "" in values:
         if "" in values:
@@ -1564,12 +1568,11 @@ class QueryFilter(QueryFields):
                 self.resolve_field(search_filter.key.name),
                 self.resolve_field(search_filter.key.name),
                 Op.IS_NULL if search_filter.operator == "=" else Op.IS_NOT_NULL,
                 Op.IS_NULL if search_filter.operator == "=" else Op.IS_NOT_NULL,
             )
             )
-        if search_filter.is_in_filter:
-            internal_value = [
-                translate_transaction_status(val) for val in search_filter.value.raw_value
-            ]
-        else:
-            internal_value = translate_transaction_status(search_filter.value.raw_value)
+        internal_value = (
+            [translate_transaction_status(val) for val in search_filter.value.raw_value]
+            if search_filter.is_in_filter
+            else translate_transaction_status(search_filter.value.raw_value)
+        )
         return Condition(
         return Condition(
             self.resolve_field(search_filter.key.name),
             self.resolve_field(search_filter.key.name),
             Op(search_filter.operator),
             Op(search_filter.operator),
@@ -1661,8 +1664,8 @@ class QueryFilter(QueryFields):
             raise ValueError("organization_id is a required param")
             raise ValueError("organization_id is a required param")
 
 
         organization_id: int = self.params["organization_id"]
         organization_id: int = self.params["organization_id"]
-        project_ids: Optional[list[int]] = self.params.get("project_id")
-        environments: Optional[list[Environment]] = self.params.get("environment_objects", [])
+        project_ids: Optional[List[int]] = self.params.get("project_id")
+        environments: Optional[List[Environment]] = self.params.get("environment_objects", [])
         qs = (
         qs = (
             Release.objects.filter_by_stage(
             Release.objects.filter_by_stage(
                 organization_id,
                 organization_id,
@@ -1729,7 +1732,7 @@ class QueryFilter(QueryFields):
             raise ValueError("organization_id is a required param")
             raise ValueError("organization_id is a required param")
 
 
         organization_id: int = self.params["organization_id"]
         organization_id: int = self.params["organization_id"]
-        project_ids: Optional[list[int]] = self.params.get("project_id")
+        project_ids: Optional[List[int]] = self.params.get("project_id")
         # We explicitly use `raw_value` here to avoid converting wildcards to shell values
         # We explicitly use `raw_value` here to avoid converting wildcards to shell values
         version: str = search_filter.value.raw_value
         version: str = search_filter.value.raw_value
         operator: str = search_filter.operator
         operator: str = search_filter.operator
@@ -1787,7 +1790,7 @@ class QueryFilter(QueryFields):
             raise ValueError("organization_id is a required param")
             raise ValueError("organization_id is a required param")
 
 
         organization_id: int = self.params["organization_id"]
         organization_id: int = self.params["organization_id"]
-        project_ids: Optional[list[int]] = self.params.get("project_id")
+        project_ids: Optional[List[int]] = self.params.get("project_id")
         package: str = search_filter.value.raw_value
         package: str = search_filter.value.raw_value
 
 
         versions = list(
         versions = list(
@@ -1813,7 +1816,7 @@ class QueryFilter(QueryFields):
             raise ValueError("organization_id is a required param")
             raise ValueError("organization_id is a required param")
 
 
         organization_id: int = self.params["organization_id"]
         organization_id: int = self.params["organization_id"]
-        project_ids: Optional[list[int]] = self.params.get("project_id")
+        project_ids: Optional[List[int]] = self.params.get("project_id")
         build: str = search_filter.value.raw_value
         build: str = search_filter.value.raw_value
 
 
         operator, negated = handle_operator_negation(search_filter.operator)
         operator, negated = handle_operator_negation(search_filter.operator)