Browse Source

feat(discover): Add failure_count function snql (#27530)

Added failure_count to the SnQL fields, along with some
general patterns for adding further functions.
Shruthi 3 years ago
parent
commit
6426972103

+ 9 - 3
src/sentry/search/events/builder.py

@@ -19,13 +19,18 @@ class QueryBuilder(QueryFilter):
         query: Optional[str] = None,
         selected_columns: Optional[List[str]] = None,
         orderby: Optional[List[str]] = None,
+        use_aggregate_conditions: bool = False,
         limit: int = 50,
     ):
         super().__init__(dataset, params)
 
         self.limit = Limit(limit)
 
-        self.where = self.resolve_where(query)
+        parsed_terms = self.parse_query(query)
+        self.where = self.resolve_where(parsed_terms)
+        self.having = self.resolve_having(
+            parsed_terms, use_aggregate_conditions=use_aggregate_conditions
+        )
 
         # params depends on get_filter since there may be projects in the query
         self.where += self.resolve_params()
@@ -35,12 +40,12 @@ class QueryBuilder(QueryFilter):
 
     @property
     def select(self) -> Optional[List[SelectType]]:
-        return [*self.aggregates, *self.columns]
+        return self.columns
 
     @property
     def groupby(self) -> Optional[List[SelectType]]:
         if self.aggregates:
-            return self.columns
+            return [c for c in self.columns if c not in self.aggregates]
         else:
             return []
 
@@ -50,6 +55,7 @@ class QueryBuilder(QueryFilter):
             match=Entity(self.dataset.value),
             select=self.select,
             where=self.where,
+            having=self.having,
             groupby=self.groupby,
             orderby=self.orderby,
             limit=self.limit,

+ 147 - 2
src/sentry/search/events/fields.py

@@ -2,7 +2,7 @@ import re
 from collections import defaultdict, namedtuple
 from copy import deepcopy
 from datetime import datetime
-from typing import Callable, List, Mapping, Optional, Union
+from typing import Callable, List, Mapping, Match, Optional, Tuple, Union
 
 import sentry_sdk
 from sentry_relay.consts import SPAN_STATUS_NAME_TO_CODE
@@ -2104,6 +2104,31 @@ for alias, name in FUNCTION_ALIASES.items():
 FUNCTION_ALIAS_PATTERN = re.compile(r"^({}).*".format("|".join(list(FUNCTIONS.keys()))))
 
 
+class SnQLFunction(DiscoverFunction):
+    """A Discover function definition that also carries an SnQL aggregate resolver.
+
+    The resolver is passed through the ``snql_aggregate`` keyword argument; all
+    other positional and keyword arguments are forwarded to ``DiscoverFunction``.
+    """
+
+    def __init__(self, *args, **kwargs):
+        # Pulled out of kwargs so the base initialiser never sees it, and stored
+        # before delegating -- presumably because the base initialiser runs
+        # validate(), which reads the attribute (TODO confirm).
+        resolver = kwargs.pop("snql_aggregate", None)
+        self.snql_aggregate = resolver
+        super().__init__(*args, **kwargs)
+
+    def validate(self):
+        """Assert that the definition is usable; raises AssertionError otherwise.
+
+        Checks, in order: every optional argument has a default, a resolver was
+        supplied, argument names are unique, and the default result type passes
+        ``validate_result_type``.
+        """
+        # every optional argument must be able to fall back to a default
+        for i, arg in enumerate(self.optional_args):
+            assert (
+                arg.has_default
+            ), f"{self.name}: optional argument at index {i} does not have default"
+
+        # an SnQL function without a resolver cannot be converted
+        assert self.snql_aggregate is not None
+
+        # argument names must be unique
+        seen_names = set()
+        for arg in self.args:
+            assert (
+                arg.name not in seen_names
+            ), f"{self.name}: argument {arg.name} specified more than once"
+            seen_names.add(arg.name)
+
+        self.validate_result_type(self.default_result_type)
+
+
+
+
 class QueryFields(QueryBase):
     """Field logic for a snql query"""
 
@@ -2129,6 +2154,82 @@ class QueryFields(QueryBase):
             TEAM_KEY_TRANSACTION_ALIAS: self._resolve_unimplemented_alias,
         }
 
+        self.function_converter: Mapping[str, SnQLFunction] = {
+            function.name: function
+            for function in [
+                SnQLFunction(
+                    "failure_count",
+                    snql_aggregate=lambda _, alias: Function(
+                        "countIf",
+                        [
+                            Function(
+                                "notIn",
+                                [
+                                    self.column("transaction.status"),
+                                    (
+                                        SPAN_STATUS_NAME_TO_CODE["ok"],
+                                        SPAN_STATUS_NAME_TO_CODE["cancelled"],
+                                        SPAN_STATUS_NAME_TO_CODE["unknown"],
+                                    ),
+                                ],
+                            )
+                        ],
+                        alias,
+                    ),
+                    default_result_type="integer",
+                ),
+                # TODO: implement these
+                SnQLFunction("percentile", snql_aggregate=self._resolve_unimplemented_function),
+                SnQLFunction("p50", snql_aggregate=self._resolve_unimplemented_function),
+                SnQLFunction("p75", snql_aggregate=self._resolve_unimplemented_function),
+                SnQLFunction("p95", snql_aggregate=self._resolve_unimplemented_function),
+                SnQLFunction("p99", snql_aggregate=self._resolve_unimplemented_function),
+                SnQLFunction("p100", snql_aggregate=self._resolve_unimplemented_function),
+                SnQLFunction("eps", snql_aggregate=self._resolve_unimplemented_function),
+                SnQLFunction("epm", snql_aggregate=self._resolve_unimplemented_function),
+                SnQLFunction("last_seen", snql_aggregate=self._resolve_unimplemented_function),
+                SnQLFunction("latest_event", snql_aggregate=self._resolve_unimplemented_function),
+                SnQLFunction("apdex", snql_aggregate=self._resolve_unimplemented_function),
+                SnQLFunction(
+                    "count_miserable", snql_aggregate=self._resolve_unimplemented_function
+                ),
+                SnQLFunction("user_misery", snql_aggregate=self._resolve_unimplemented_function),
+                SnQLFunction("failure_rate", snql_aggregate=self._resolve_unimplemented_function),
+                SnQLFunction("array_join", snql_aggregate=self._resolve_unimplemented_function),
+                SnQLFunction("histogram", snql_aggregate=self._resolve_unimplemented_function),
+                SnQLFunction("count_unique", snql_aggregate=self._resolve_unimplemented_function),
+                SnQLFunction("count", snql_aggregate=self._resolve_unimplemented_function),
+                SnQLFunction("count_at_least", snql_aggregate=self._resolve_unimplemented_function),
+                SnQLFunction("min", snql_aggregate=self._resolve_unimplemented_function),
+                SnQLFunction("max", snql_aggregate=self._resolve_unimplemented_function),
+                SnQLFunction("avg", snql_aggregate=self._resolve_unimplemented_function),
+                SnQLFunction("var", snql_aggregate=self._resolve_unimplemented_function),
+                SnQLFunction("stddev", snql_aggregate=self._resolve_unimplemented_function),
+                SnQLFunction("cov", snql_aggregate=self._resolve_unimplemented_function),
+                SnQLFunction("corr", snql_aggregate=self._resolve_unimplemented_function),
+                SnQLFunction("sum", snql_aggregate=self._resolve_unimplemented_function),
+                SnQLFunction("any", snql_aggregate=self._resolve_unimplemented_function),
+                SnQLFunction("absolute_delta", snql_aggregate=self._resolve_unimplemented_function),
+                SnQLFunction(
+                    "percentile_range", snql_aggregate=self._resolve_unimplemented_function
+                ),
+                SnQLFunction("avg_range", snql_aggregate=self._resolve_unimplemented_function),
+                SnQLFunction("variance_range", snql_aggregate=self._resolve_unimplemented_function),
+                SnQLFunction("count_range", snql_aggregate=self._resolve_unimplemented_function),
+                SnQLFunction("percentage", snql_aggregate=self._resolve_unimplemented_function),
+                SnQLFunction("t_test", snql_aggregate=self._resolve_unimplemented_function),
+                SnQLFunction("minus", snql_aggregate=self._resolve_unimplemented_function),
+                SnQLFunction(
+                    "absolute_correlation", snql_aggregate=self._resolve_unimplemented_function
+                ),
+                SnQLFunction("count_if", snql_aggregate=self._resolve_unimplemented_function),
+                SnQLFunction(
+                    "compare_numeric_aggregate", snql_aggregate=self._resolve_unimplemented_function
+                ),
+                SnQLFunction("to_other", snql_aggregate=self._resolve_unimplemented_function),
+            ]
+        }
+
     def resolve_select(self, selected_columns: Optional[List[str]]) -> List[SelectType]:
         if selected_columns is None:
             return []
@@ -2147,7 +2248,7 @@ class QueryFields(QueryBase):
     def resolve_field(self, field: str) -> SelectType:
         match = is_function(field)
         if match:
-            raise NotImplementedError(f"{field} not implemented in snql field parsing yet")
+            return self.resolve_function(field, match)
 
         if self.is_field_alias(field):
             return self.resolve_field_alias(field)
@@ -2212,6 +2313,40 @@ class QueryFields(QueryBase):
             raise NotImplementedError(f"{alias} not implemented in snql field parsing yet")
         return converter(alias)
 
+    def is_function(self, function: str) -> bool:
+        """Return True when ``function`` is a key of ``self.function_converter``."""
+        registered = self.function_converter
+        return function in registered
+
+    def resolve_function(self, function: str, match: Optional[Match[str]] = None) -> SelectType:
+        """Resolve a function expression into its SnQL aggregate and record it.
+
+        :param function: The raw function string, e.g. ``failure_count()``.
+        :param match: An existing regex match for ``function``; when omitted the
+            string is matched here with the module-level ``is_function``.
+        :returns: The object produced by the function's ``snql_aggregate``
+            resolver. The same object is appended to ``self.aggregates``.
+        :raises InvalidSearchQuery: If ``function`` does not match, or (via
+            ``parse_function``) names an unknown function.
+        :raises NotImplementedError: If ``function`` is listed under
+            ``params["aliases"]``, or the resolver is still a stub.
+        """
+        if match is None:
+            match = is_function(function)
+
+        if not match:
+            raise InvalidSearchQuery(f"Invalid characters in field {function}")
+
+        if function in self.params.get("aliases", {}):
+            raise NotImplementedError("Aggregate aliases not implemented in snql field parsing yet")
+
+        name, arguments, alias = self.parse_function(match)
+        # parse_function has already rejected names missing from the converter table
+        snql_function = self.function_converter[name]
+
+        # Invoke the resolver exactly once so the recorded aggregate and the
+        # returned value are the same object, not two separately built copies.
+        resolved = snql_function.snql_aggregate(arguments, alias)
+        self.aggregates.append(resolved)
+        return resolved
+
+    def parse_function(self, match: Match[str]) -> Tuple[str, List[str], str]:
+        """Break a matched function expression into ``(name, arguments, alias)``.
+
+        Reads the ``function``, ``columns`` and ``alias`` groups of ``match``.
+
+        :raises InvalidSearchQuery: If the function name is not registered
+            (per ``self.is_function``).
+        """
+        function = match.group("function")
+        if not self.is_function(function):
+            raise InvalidSearchQuery(f"{function} is not a valid function")
+
+        raw_columns = match.group("columns")
+        parsed_arguments = parse_arguments(function, raw_columns)
+
+        explicit_alias = match.group("alias")
+        if explicit_alias is not None:
+            return (function, parsed_arguments, explicit_alias)
+
+        # no alias group captured: derive one from the name and parsed arguments
+        derived_alias = get_function_alias_with_columns(function, parsed_arguments)
+        return (function, parsed_arguments, derived_alias)
+
+
+    # Field Aliases
     def _resolve_issue_id_alias(self, _: str) -> SelectType:
         """The state of having no issues is represented differently on transactions vs
         other events. On the transactions table, it is represented by 0 whereas it is
@@ -2411,3 +2546,13 @@ class QueryFields(QueryBase):
         Can be deleted once all field aliases have been implemented.
         """
         raise NotImplementedError(f"{alias} not implemented in snql field parsing yet")
+
+    def _resolve_unimplemented_function(self, _: List[str], alias: str) -> SelectType:
+        """Interim stub for functions that have not been implemented in SnQL yet.
+
+        Ignores the arguments and always raises NotImplementedError naming ``alias``.
+        Can be deleted once all functions have been implemented.
+        """
+        raise NotImplementedError(f"{alias} not implemented in snql field parsing yet")

+ 56 - 7
src/sentry/search/events/filter.py

@@ -1025,7 +1025,7 @@ class QueryFilter(QueryFields):
             ERROR_UNHANDLED_ALIAS: self._error_unhandled_filter_converter,
         }
 
-    def resolve_where(self, query: Optional[str]) -> List[WhereType]:
+    def parse_query(self, query: Optional[str]) -> Optional[Sequence[SearchFilter]]:
         if query is None:
             return []
 
@@ -1034,12 +1034,38 @@ class QueryFilter(QueryFields):
         except ParseError as e:
             raise InvalidSearchQuery(f"Parse error: {e.expr.name} (column {e.column():d})")
 
-        conditions = [
-            self.format_search_filter(term)
-            for term in parsed_terms
-            if isinstance(term, SearchFilter)
-        ]
-        return [condition for condition in conditions if condition]
+        return parsed_terms
+
+    def resolve_where(self, parsed_terms: Optional[Sequence[SearchFilter]]) -> List[WhereType]:
+        """Build the WHERE conditions from the parsed query terms.
+
+        Only ``SearchFilter`` terms are considered; each is passed through
+        ``format_search_filter`` and falsy results are dropped. Returns an empty
+        list when there are no terms.
+        """
+        if not parsed_terms:
+            return []
+
+        # lazily format each search filter, in term order
+        formatted = (
+            self.format_search_filter(term)
+            for term in parsed_terms
+            if isinstance(term, SearchFilter)
+        )
+        return [condition for condition in formatted if condition]
+
+
+    def resolve_having(
+        self, parsed_terms: Optional[Sequence[SearchFilter]], use_aggregate_conditions: bool = False
+    ) -> List[WhereType]:
+        """Build the HAVING conditions from the parsed query terms.
+
+        Returns an empty list when there are no terms or when
+        ``use_aggregate_conditions`` is falsy. Otherwise every ``AggregateFilter``
+        term is converted with ``convert_aggregate_filter_to_condition`` and
+        falsy results are dropped.
+        """
+        if not parsed_terms or not use_aggregate_conditions:
+            return []
+
+        # lazily convert each aggregate filter, in term order
+        converted = (
+            self.convert_aggregate_filter_to_condition(term)
+            for term in parsed_terms
+            if isinstance(term, AggregateFilter)
+        )
+        return [condition for condition in converted if condition]
 
     def resolve_params(self) -> List[WhereType]:
         """Keys included as url params take precedent if same key is included in search
@@ -1102,6 +1128,29 @@ class QueryFilter(QueryFields):
         )
         return converted_filter if converted_filter else None
 
+    def convert_aggregate_filter_to_condition(
+        self, aggregate_filter: AggregateFilter
+    ) -> Optional[WhereType]:
+        """Translate a parsed aggregate filter into a ``Condition``.
+
+        :raises NotImplementedError: If the filter key is listed under
+            ``params["aliases"]``.
+        """
+        name = aggregate_filter.key.name
+        value = aggregate_filter.value.value
+
+        if name in self.params.get("aliases", {}):
+            raise NotImplementedError("Aggregate aliases not implemented in snql field parsing yet")
+
+        # datetime operands become integer timestamps, unless the key is "timestamp"
+        if isinstance(value, datetime) and name != "timestamp":
+            value = int(to_timestamp(value))
+
+        operator = aggregate_filter.operator
+        if operator in {"=", "!="} and value == "":
+            # An empty-string operand turns =/!= into a null / not-null check.
+            # NOTE(review): this branch hands the raw key name to Condition
+            # rather than a resolved function -- confirm Condition accepts that.
+            null_check = Op.IS_NULL if operator == "=" else Op.IS_NOT_NULL
+            return Condition(name, null_check)
+
+        resolved = self.resolve_function(name)
+        return Condition(resolved, Op(operator), value)
+
+
     def convert_search_filter_to_condition(
         self,
         search_filter: SearchFilter,

+ 10 - 1
src/sentry/snuba/discover.py

@@ -198,8 +198,17 @@ def wip_snql_query(
     Replacement API for query using snql, this function is still a work in
     progress and is not ready for use in production
     """
-    builder = QueryBuilder(Dataset.Discover, params, query, selected_columns, orderby, limit)
+    builder = QueryBuilder(
+        Dataset.Discover,
+        params,
+        query=query,
+        selected_columns=selected_columns,
+        orderby=orderby,
+        use_aggregate_conditions=use_aggregate_conditions,
+        limit=limit,
+    )
     snql_query = builder.get_snql_query()
+
     results = raw_snql_query(snql_query, referrer)
     return results
 

+ 48 - 0
tests/sentry/snuba/test_discover.py

@@ -521,6 +521,54 @@ class QueryIntegrationTest(SnubaTestCase, TestCase):
                 r[1] for r in expected_project_threshold_config
             ]
 
+    def test_failure_count_function(self):
+        """failure_count() gives the same rows through discover.query and wip_snql_query."""
+        project = self.create_project()
+
+        def store_transaction(name, status=None):
+            # One transaction event five minutes old; the trace status is only
+            # overridden when one is given.
+            event = load_data("transaction", timestamp=before_now(minutes=5))
+            event["transaction"] = name
+            if status is not None:
+                event["contexts"]["trace"]["status"] = status
+            self.store_event(event, project_id=project.id)
+
+        store_transaction("/failure_count/success")
+        store_transaction("/failure_count/unknown", status="unknown_error")
+
+        for i in range(6):
+            store_transaction(f"/failure_count/{i}", status="unauthenticated")
+
+        # a second event for "/failure_count/0", so that transaction has two
+        store_transaction("/failure_count/0", status="unauthenticated")
+
+        # (query, expected number of rows, use_aggregate_conditions)
+        cases = [
+            ("", 8, True),
+            ("failure_count():>0", 6, True),
+            ("failure_count():>0", 8, False),
+        ]
+
+        for query, expected_length, use_aggregate_conditions in cases:
+            for query_fn in [discover.query, discover.wip_snql_query]:
+                result = query_fn(
+                    selected_columns=["transaction", "failure_count()"],
+                    query=query,
+                    orderby="transaction",
+                    params={
+                        "start": before_now(minutes=10),
+                        "end": before_now(minutes=2),
+                        "project_id": [project.id],
+                    },
+                    use_aggregate_conditions=use_aggregate_conditions,
+                )
+                rows = result["data"]
+
+                assert len(rows) == expected_length
+                assert rows[0]["failure_count"] == 2
+                assert rows[1]["failure_count"] == 1
+
+
     def test_transaction_status(self):
         data = load_data("transaction", timestamp=before_now(minutes=1))
         data["transaction"] = "/test_transaction/success"