
fix(eap): Numeric tags (#80387)

- Fixes how the search grammar handles numeric tags; previously it always treated them as strings
- Changes argument_type to argument_types (a list), since most functions accept both duration and number
William Mak · 4 months ago · commit 968d7cede2
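
As a rough illustration of the grammar change (a hedged sketch, not part of the patch; the exact filter shape may vary), a bracketed numeric tag key should now parse as a single typed SearchKey instead of falling back to the string tag path:

```python
# Hedged sketch: assumes parse_search_query is importable as shown in the diff below.
from sentry.api.event_search import parse_search_query

# With the new explicit_number_tag_key rule, the whole bracketed key is parsed
# as one SearchKey named "tags[foo,number]" rather than a plain string tag.
filters = parse_search_query("tags[foo,number]:5")
print(filters[0].key.name)  # expected: "tags[foo,number]"
```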

+ 35 - 8
src/sentry/api/event_search.py

@@ -137,15 +137,16 @@ text_filter = negation? text_key sep operator? search_value
 key                    = ~r"[a-zA-Z0-9_.-]+"
 quoted_key             = '"' ~r"[a-zA-Z0-9_.:-]+" '"'
 explicit_tag_key       = "tags" open_bracket search_key closed_bracket
-explicit_typed_tag_key = "tags" open_bracket search_key spaces comma spaces search_type closed_bracket
+explicit_string_tag_key = "tags" open_bracket search_key spaces comma spaces "string" closed_bracket
+explicit_number_tag_key = "tags" open_bracket search_key spaces comma spaces "number" closed_bracket
 aggregate_key          = key open_paren spaces function_args? spaces closed_paren
 function_args          = aggregate_param (spaces comma spaces !comma aggregate_param?)*
 aggregate_param        = quoted_aggregate_param / raw_aggregate_param
 raw_aggregate_param    = ~r"[^()\t\n, \"]+"
 quoted_aggregate_param = '"' ('\\"' / ~r'[^\t\n\"]')* '"'
-search_key             = key / quoted_key
+search_key             = explicit_number_tag_key / key / quoted_key
 search_type            = "number" / "string"
-text_key               = explicit_tag_key / explicit_typed_tag_key / search_key
+text_key               = explicit_tag_key / explicit_string_tag_key / search_key
 value                  = ~r"[^()\t\n ]*"
 quoted_value           = '"' ('\\"' / ~r'[^"]')* '"'
 in_value               = (&in_value_termination in_value_char)+
@@ -571,7 +572,14 @@ class SearchConfig:
 class SearchVisitor(NodeVisitor):
     unwrapped_exceptions = (InvalidSearchQuery,)
 
-    def __init__(self, config=None, params=None, builder=None, get_field_type=None):
+    def __init__(
+        self,
+        config=None,
+        params=None,
+        builder=None,
+        get_field_type=None,
+        get_function_result_type=None,
+    ):
         super().__init__()
 
         if config is None:
@@ -595,6 +603,10 @@ class SearchVisitor(NodeVisitor):
             self.get_field_type = self.builder.get_field_type
         else:
             self.get_field_type = get_field_type
+        if get_function_result_type is None:
+            self.get_function_result_type = self.builder.get_function_result_type
+        else:
+            self.get_function_result_type = get_function_result_type
 
     @cached_property
     def key_mappings_lookup(self):
@@ -1053,8 +1065,11 @@ class SearchVisitor(NodeVisitor):
     def visit_explicit_tag_key(self, node, children):
         return SearchKey(f"tags[{children[2].name}]")
 
-    def visit_explicit_typed_tag_key(self, node, children):
-        return SearchKey(f"tags[{children[2].name},{children[6]}]")
+    def visit_explicit_string_tag_key(self, node, children):
+        return SearchKey(f"tags[{children[2].name},string]")
+
+    def visit_explicit_number_tag_key(self, node, children):
+        return SearchKey(f"tags[{children[2].name},number]")
 
     def visit_search_type(self, node, children):
         return node.text
@@ -1095,6 +1110,8 @@ class SearchVisitor(NodeVisitor):
             or key in self.config.blocked_keys
         ):
             raise InvalidSearchQuery(f"Invalid key for this search: {key}")
+        if isinstance(key, SearchKey):
+            return key
         return SearchKey(self.key_mappings_lookup.get(key, key))
 
     def visit_text_key(self, node, children):
@@ -1246,7 +1263,13 @@ QueryToken = Union[SearchFilter, QueryOp, ParenExpression]
 
 
 def parse_search_query(
-    query, config=None, params=None, builder=None, config_overrides=None, get_field_type=None
+    query,
+    config=None,
+    params=None,
+    builder=None,
+    config_overrides=None,
+    get_field_type=None,
+    get_function_result_type=None,
 ) -> list[
     SearchFilter
 ]:  # TODO: use the `Sequence[QueryToken]` type and update the code that fails type checking.
@@ -1270,5 +1293,9 @@ def parse_search_query(
         config = SearchConfig.create_from(config, **config_overrides)
 
     return SearchVisitor(
-        config, params=params, builder=builder, get_field_type=get_field_type
+        config,
+        params=params,
+        builder=builder,
+        get_field_type=get_field_type,
+        get_function_result_type=get_function_result_type,
     ).visit(tree)
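
A hedged usage sketch of the new `get_function_result_type` keyword argument (the callbacks below are illustrative stand-ins; in this patch the SearchResolver simply passes its own `get_field_type` for both, as the spans.py hunk further down shows):

```python
# Hedged sketch: both callbacks map a public alias to a search type so the
# visitor can decide how to interpret the right-hand side of a filter.
def get_field_type(field: str) -> str | None:
    # e.g. treat explicitly typed numeric tags as numbers
    return "number" if field.endswith(",number]") else None

def get_function_result_type(function_alias: str) -> str | None:
    # e.g. assume percentile-style aggregates produce durations
    return "duration" if function_alias.startswith("p") else None

filters = parse_search_query(
    "avg(span.duration):>100",
    get_field_type=get_field_type,
    get_function_result_type=get_function_result_type,
)
```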

+ 36 - 11
src/sentry/search/eap/columns.py

@@ -82,7 +82,7 @@ class ResolvedColumn:
 
 @dataclass
 class ArgumentDefinition:
-    argument_type: str | None = None
+    argument_types: list[str] | None = None
     # The public alias for the default arg, the SearchResolver will resolve this value
     default_arg: str | None = None
     # Whether this argument is completely ignored, used for `count()`
@@ -291,54 +291,79 @@ SPAN_FUNCTION_DEFINITIONS = {
     "sum": FunctionDefinition(
         internal_function=Function.FUNCTION_SUM,
         search_type="duration",
-        arguments=[ArgumentDefinition(argument_type="duration", default_arg="span.duration")],
+        arguments=[
+            ArgumentDefinition(argument_types=["duration", "number"], default_arg="span.duration")
+        ],
     ),
     "avg": FunctionDefinition(
         internal_function=Function.FUNCTION_AVERAGE,
         search_type="duration",
-        arguments=[ArgumentDefinition(argument_type="duration", default_arg="span.duration")],
+        arguments=[
+            ArgumentDefinition(argument_types=["duration", "number"], default_arg="span.duration")
+        ],
     ),
     "count": FunctionDefinition(
         internal_function=Function.FUNCTION_COUNT,
         search_type="number",
-        arguments=[ArgumentDefinition(argument_type="duration", default_arg="span.duration")],
+        arguments=[
+            ArgumentDefinition(argument_types=["duration", "number"], default_arg="span.duration")
+        ],
     ),
     "p50": FunctionDefinition(
         internal_function=Function.FUNCTION_P50,
         search_type="duration",
-        arguments=[ArgumentDefinition(argument_type="duration", default_arg="span.duration")],
+        arguments=[
+            ArgumentDefinition(argument_types=["duration", "number"], default_arg="span.duration")
+        ],
     ),
     "p90": FunctionDefinition(
         internal_function=Function.FUNCTION_P90,
         search_type="duration",
-        arguments=[ArgumentDefinition(argument_type="duration", default_arg="span.duration")],
+        arguments=[
+            ArgumentDefinition(argument_types=["duration", "number"], default_arg="span.duration")
+        ],
     ),
     "p95": FunctionDefinition(
         internal_function=Function.FUNCTION_P95,
         search_type="duration",
-        arguments=[ArgumentDefinition(argument_type="duration", default_arg="span.duration")],
+        arguments=[
+            ArgumentDefinition(argument_types=["duration", "number"], default_arg="span.duration")
+        ],
     ),
     "p99": FunctionDefinition(
         internal_function=Function.FUNCTION_P99,
         search_type="duration",
-        arguments=[ArgumentDefinition(argument_type="duration", default_arg="span.duration")],
+        arguments=[
+            ArgumentDefinition(argument_types=["duration", "number"], default_arg="span.duration")
+        ],
+    ),
+    "p100": FunctionDefinition(
+        internal_function=Function.FUNCTION_MAX,
+        search_type="duration",
+        arguments=[
+            ArgumentDefinition(argument_types=["duration", "number"], default_arg="span.duration")
+        ],
     ),
     "max": FunctionDefinition(
         internal_function=Function.FUNCTION_MAX,
         search_type="duration",
-        arguments=[ArgumentDefinition(argument_type="duration", default_arg="span.duration")],
+        arguments=[
+            ArgumentDefinition(argument_types=["duration", "number"], default_arg="span.duration")
+        ],
     ),
     "min": FunctionDefinition(
         internal_function=Function.FUNCTION_MIN,
         search_type="duration",
-        arguments=[ArgumentDefinition(argument_type="duration", default_arg="span.duration")],
+        arguments=[
+            ArgumentDefinition(argument_types=["duration", "number"], default_arg="span.duration")
+        ],
     ),
     "count_unique": FunctionDefinition(
         internal_function=Function.FUNCTION_UNIQ,
         search_type="number",
         arguments=[
             ArgumentDefinition(
-                argument_type="string",
+                argument_types=["string"],
             )
         ],
     ),

+ 6 - 4
src/sentry/search/eap/spans.py

@@ -66,6 +66,7 @@ class SearchResolver:
                 querystring,
                 params=self.params.filter_params,
                 get_field_type=self.get_field_type,
+                get_function_result_type=self.get_field_type,
             )
         except ParseError as e:
             if e.expr is not None:
@@ -196,7 +197,8 @@ class SearchResolver:
                 else:
                     raise NotImplementedError("Can't filter on aggregates yet")
             else:
-                raise NotImplementedError()
+                if self.config.use_aggregate_conditions:
+                    raise NotImplementedError("Can't filter on aggregates yet")
 
         if len(parsed_terms) > 1:
             return TraceItemFilter(and_filter=AndFilter(filters=parsed_terms))
@@ -428,11 +430,11 @@ class SearchResolver:
                 )
 
             if (
-                argument.argument_type is not None
-                and parsed_argument.search_type != argument.argument_type
+                argument.argument_types is not None
+                and parsed_argument.search_type not in argument.argument_types
             ):
                 raise InvalidSearchQuery(
-                    f"{argument} is invalid for {function}, its a {parsed_argument.search_type} type field but {function} expects a {argument.argument_type} type field"
+                    f"{argument} is invalid for {function}: it's a {parsed_argument.search_type} type field, but {function} expects a field that is one of these types: {argument.argument_types}"
                 )
             parsed_columns.append(parsed_argument)
 

+ 143 - 5
tests/snuba/api/endpoints/test_organization_events_span_indexed.py

@@ -530,12 +530,9 @@ class OrganizationEventsSpanIndexedEndpointTest(OrganizationEventsEndpointTestBa
         assert response.data["data"] == [{"foo": "", "count()": 1}]
 
 
-@pytest.mark.xfail(
-    reason="Snuba is not stable for the EAP dataset, xfailing since its prone to failure"
-)
 class OrganizationEventsEAPSpanEndpointTest(OrganizationEventsSpanIndexedEndpointTest):
     is_eap = True
-    use_rpc = True
+    use_rpc = False
 
     def test_simple(self):
         self.store_spans(
@@ -1005,7 +1002,7 @@ class OrganizationEventsEAPSpanEndpointTest(OrganizationEventsSpanIndexedEndpoin
         # How to read this; these results mean that the extrapolated count is
         # 500k, with a lower estimated bound of ~200k, and an upper bound of 800k
         assert lower_limit == pytest.approx(190_000, abs=5000)
-        assert extrapolated == pytest.approx(500_000)
+        assert extrapolated == pytest.approx(500_000, abs=5000)
         assert upper_limit == pytest.approx(810_000, abs=5000)
 
     def test_skip_aggregate_conditions_option(self):
@@ -1051,3 +1048,144 @@ class OrganizationEventsEAPRPCSpanEndpointTest(OrganizationEventsEAPSpanEndpoint
 
     is_eap = True
     use_rpc = True
+
+    def test_span_duration(self):
+        spans = [
+            self.create_span(
+                {"description": "bar", "sentry_tags": {"status": "invalid_argument"}},
+                start_ts=self.ten_mins_ago,
+            ),
+            self.create_span(
+                {"description": "foo", "sentry_tags": {"status": "success"}},
+                start_ts=self.ten_mins_ago,
+            ),
+        ]
+        self.store_spans(spans, is_eap=self.is_eap)
+        response = self.do_request(
+            {
+                "field": ["span.duration", "description"],
+                "query": "",
+                "orderby": "description",
+                "project": self.project.id,
+                "dataset": self.dataset,
+            }
+        )
+
+        assert response.status_code == 200, response.content
+        data = response.data["data"]
+        meta = response.data["meta"]
+        assert len(data) == 2
+        assert data == [
+            {
+                "span.duration": 1000.0,
+                "description": "bar",
+                "project.name": self.project.slug,
+                "id": spans[0]["span_id"],
+            },
+            {
+                "span.duration": 1000.0,
+                "description": "foo",
+                "project.name": self.project.slug,
+                "id": spans[1]["span_id"],
+            },
+        ]
+        assert meta["dataset"] == self.dataset
+
+    @pytest.mark.xfail(reason="extrapolation not implemented yet")
+    def test_aggregate_numeric_attr_weighted(self):
+        super().test_aggregate_numeric_attr_weighted()
+
+    @pytest.mark.xfail(reason="RPC failing because of aliasing")
+    def test_numeric_attr_without_space(self):
+        super().test_numeric_attr_without_space()
+
+    @pytest.mark.xfail(reason="RPC failing because of aliasing")
+    def test_numeric_attr_with_spaces(self):
+        super().test_numeric_attr_with_spaces()
+
+    @pytest.mark.xfail(reason="RPC failing because of aliasing")
+    def test_numeric_attr_filtering(self):
+        super().test_numeric_attr_filtering()
+
+    @pytest.mark.xfail(reason="RPC failing because of aliasing")
+    def test_numeric_attr_orderby(self):
+        super().test_numeric_attr_orderby()
+
+    def test_aggregate_numeric_attr(self):
+        self.store_spans(
+            [
+                self.create_span(
+                    {
+                        "description": "foo",
+                        "sentry_tags": {"status": "success"},
+                        "tags": {"bar": "bar1"},
+                    },
+                    start_ts=self.ten_mins_ago,
+                ),
+                self.create_span(
+                    {
+                        "description": "foo",
+                        "sentry_tags": {"status": "success"},
+                        "tags": {"bar": "bar2"},
+                    },
+                    measurements={"foo": {"value": 5}},
+                    start_ts=self.ten_mins_ago,
+                ),
+            ],
+            is_eap=self.is_eap,
+        )
+
+        response = self.do_request(
+            {
+                "field": [
+                    "description",
+                    "count_unique(bar)",
+                    "count_unique(tags[bar])",
+                    "count_unique(tags[bar,string])",
+                    "count()",
+                    "count(span.duration)",
+                    "count(tags[foo,     number])",
+                    "sum(tags[foo,number])",
+                    "avg(tags[foo,number])",
+                    "p50(tags[foo,number])",
+                    # TODO: bring p75 back once it's added to the rpc
+                    # "p75(tags[foo,number])",
+                    "p95(tags[foo,number])",
+                    "p99(tags[foo,number])",
+                    "p100(tags[foo,number])",
+                    "min(tags[foo,number])",
+                    "max(tags[foo,number])",
+                ],
+                "query": "",
+                "orderby": "description",
+                "project": self.project.id,
+                "dataset": self.dataset,
+            }
+        )
+
+        assert response.status_code == 200, response.content
+        assert len(response.data["data"]) == 1
+        data = response.data["data"]
+        assert data[0] == {
+            "description": "foo",
+            "count_unique(bar)": 2,
+            "count_unique(tags[bar])": 2,
+            "count_unique(tags[bar,string])": 2,
+            "count()": 2,
+            "count(span.duration)": 2,
+            "count(tags[foo,     number])": 1,
+            "sum(tags[foo,number])": 5.0,
+            "avg(tags[foo,number])": 5.0,
+            "p50(tags[foo,number])": 5.0,
+            # TODO: bring p75 back once it's added to the rpc
+            # "p75(tags[foo,number])": 5.0,
+            "p95(tags[foo,number])": 5.0,
+            "p99(tags[foo,number])": 5.0,
+            "p100(tags[foo,number])": 5.0,
+            "min(tags[foo,number])": 5.0,
+            "max(tags[foo,number])": 5.0,
+        }
+
+    @pytest.mark.xfail(reason="extrapolation not implemented yet")
+    def test_margin_of_error(self):
+        super().test_margin_of_error()