Browse Source

feat(sessions): Filter by session.status in sessions API (#33906)

Add the ability to filter by session.status in the metrics
implementation of the sessions API.

Examples:
session.status:healthy
session.status:[crashed,abnormal]
!session.status:healthy

Not in this PR: Parsing complex queries with OR
or AND expressions containing session.status.
Joris Bayer 2 years ago
parent
commit
5abc4e9ee4

+ 7 - 1
src/sentry/api/endpoints/organization_sessions.py

@@ -44,7 +44,13 @@ class OrganizationSessionsEndpoint(OrganizationEventsEndpointBase):
         else:
             allowed_resolution = AllowedResolution.one_hour
 
-        return QueryDefinition(query_params, params, allowed_resolution=allowed_resolution)
+        return QueryDefinition(
+            query_params,
+            params,
+            allowed_resolution=allowed_resolution,
+            # FIXME: This won't work with duplex backend
+            allow_session_status_query=release_health.is_metrics_based(),
+        )
 
     @contextmanager
     def handle_query_errors(self):

+ 278 - 129
src/sentry/release_health/metrics_sessions_v2.py

@@ -10,18 +10,31 @@ from dataclasses import dataclass, replace
 from enum import Enum
 from typing import (
     Any,
+    FrozenSet,
     Iterable,
     List,
     Mapping,
     MutableMapping,
     Optional,
     Sequence,
+    Tuple,
+    Type,
     TypedDict,
     Union,
     cast,
 )
 
-from snuba_sdk import Condition, Direction, Granularity, Limit
+from snuba_sdk import (
+    BooleanCondition,
+    Column,
+    Condition,
+    Direction,
+    Function,
+    Granularity,
+    Limit,
+    Op,
+)
+from snuba_sdk.conditions import ConditionGroup
 from snuba_sdk.legacy import json_to_snql
 
 from sentry.api.utils import InvalidParams as UtilsInvalidParams
@@ -35,6 +48,8 @@ from sentry.snuba.dataset import EntityKey
 from sentry.snuba.metrics.datasource import get_series
 from sentry.snuba.metrics.naming_layer.public import SessionMetricKey
 from sentry.snuba.metrics.query import MetricField, OrderBy
+from sentry.snuba.metrics.query import QueryDefinition as MetricsQuery
+from sentry.snuba.metrics.utils import OrderByNotSupportedOverCompositeEntityException
 from sentry.snuba.sessions_v2 import (
     SNUBA_LIMIT,
     InvalidParams,
@@ -46,6 +61,8 @@ from sentry.snuba.sessions_v2 import (
 
 logger = logging.getLogger(__name__)
 
+Scalar = Union[int, float, None]
+
 #: Group key as featured in output format
 GroupKeyDict = TypedDict(
     "GroupKeyDict",
@@ -68,6 +85,13 @@ class SessionStatus(Enum):
     HEALTHY = "healthy"
 
 
+ALL_STATUSES = frozenset(iter(SessionStatus))
+
+
+#: Used to filter results by session.status
+StatusFilter = Optional[FrozenSet[SessionStatus]]
+
+
 @dataclass(frozen=True)
 class GroupKey:
     """Hashable version of group key dict"""
@@ -75,7 +99,7 @@ class GroupKey:
     project: Optional[int] = None
     release: Optional[str] = None
     environment: Optional[str] = None
-    session_status: Optional[str] = None
+    session_status: Optional[SessionStatus] = None
 
     @staticmethod
     def from_input_dict(dct: MetricsGroupKeyDict) -> "GroupKey":
@@ -95,7 +119,7 @@ class GroupKey:
         if self.environment:
             dct["environment"] = self.environment
         if self.session_status:
-            dct["session.status"] = self.session_status
+            dct["session.status"] = self.session_status.value
 
         return dct
 
@@ -113,28 +137,35 @@ GroupedData = Mapping[GroupKey, Any]
 
 
 class Field(ABC):
-
-    name: str
+    def __init__(
+        self,
+        name: str,
+        raw_groupby: Sequence[str],
+        status_filter: StatusFilter,
+    ):
+        self.name = name
+        self._raw_groupby = raw_groupby
+        self._status_filter = status_filter
+        self.metric_fields = self._get_metric_fields(raw_groupby, status_filter)
 
     @abstractmethod
-    def get_session_status(
-        self, metric_field: MetricField, raw_groupby: Sequence[str]
-    ) -> Optional[str]:
+    def _get_session_status(self, metric_field: MetricField) -> Optional[SessionStatus]:
         ...
 
     @abstractmethod
-    def get_metric_fields(self, raw_groupby: Sequence[str]) -> Sequence[MetricField]:
+    def _get_metric_fields(
+        self, raw_groupby: Sequence[str], status_filter: StatusFilter
+    ) -> Sequence[MetricField]:
         ...
 
     def extract_values(
         self,
-        raw_groupby: Sequence[str],
         input_groups: GroupedData,
         output_groups: GroupedData,
     ) -> None:
-        is_groupby_status = "session.status" in raw_groupby
-        for metric_field in self.get_metric_fields(raw_groupby):
-            session_status = self.get_session_status(metric_field, raw_groupby)
+        is_groupby_status = "session.status" in self._raw_groupby
+        for metric_field in self.metric_fields:
+            session_status = self._get_session_status(metric_field)
             if is_groupby_status and session_status is None:
                 # We fetched this only to be consistent with the sort order
                 # in the original implementation, don't add it to output data
@@ -145,26 +176,31 @@ class Field(ABC):
                 else metric_field.metric_name
             )
             for input_group_key, group in input_groups.items():
-                if session_status:
+                if session_status and not self._status_filter:
                     self.ensure_status_groups(input_group_key, output_groups)
                 group_key = replace(input_group_key, session_status=session_status)
                 for subgroup in ("totals", "series"):
+                    target = output_groups[group_key][subgroup]
+                    previous_value = target[self.name]
                     value = group[subgroup][field_name]
                     if isinstance(value, list):
-                        value = [self.normalize(x) for x in value]
+                        value = [
+                            self.accumulate(prev, self.normalize(x))
+                            for prev, x in zip(previous_value, value)
+                        ]
                     else:
-                        value = self.normalize(value)
-                    output_groups[group_key][subgroup][self.name] = value
+                        value = self.accumulate(previous_value, self.normalize(value))
+                    target[self.name] = value
 
     def ensure_status_groups(self, input_group_key: GroupKey, output_groups: GroupedData) -> None:
         # To be consistent with original sessions implementation,
         # always create defaults for all session status groups
         for session_status in SessionStatus:
-            group_key = replace(input_group_key, session_status=session_status.value)
+            group_key = replace(input_group_key, session_status=session_status)
             output_groups[group_key]  # creates entry in defaultdict
 
-    def get_groupby(self, raw_groupby: Sequence[str]) -> Iterable[str]:
-        for groupby in raw_groupby:
+    def get_groupby(self) -> Iterable[str]:
+        for groupby in self._raw_groupby:
             if groupby == "session.status":
                 continue
             elif groupby == "project":
@@ -172,114 +208,136 @@ class Field(ABC):
             else:
                 yield groupby
 
-    def normalize(self, value: Union[int, float, None]) -> Union[int, float, None]:
-        return cast(Union[int, float, None], finite_or_none(value))
+    def normalize(self, value: Scalar) -> Scalar:
+        return cast(Scalar, finite_or_none(value))
 
-
-class IntegerField(Field):
-    def normalize(self, value: Union[int, float, None]) -> Union[int, float, None]:
-        value = super().normalize(value)
-        if isinstance(value, float):
-            return int(value)
-        return value
+    def accumulate(self, old_value: Scalar, new_value: Scalar) -> Scalar:
+        """Combine two numbers for the same target.
+        Default is to take the new value."""
+        return new_value
 
 
-class SessionsField(IntegerField):
+class CountField(Field):
+    """Base class for sum(sessions) and count_unique(user)"""
 
-    name = "sum(session)"
+    status_to_metric_field: Mapping[Optional[SessionStatus], MetricField] = {}
 
-    metric_field_to_session_status = {
-        MetricField(None, SessionMetricKey.HEALTHY.value): "healthy",
-        MetricField(None, SessionMetricKey.ABNORMAL.value): "abnormal",
-        MetricField(None, SessionMetricKey.CRASHED.value): "crashed",
-        MetricField(None, SessionMetricKey.ERRORED.value): "errored",
-        MetricField(None, SessionMetricKey.ALL.value): None,
-    }
+    def get_all_field(self) -> MetricField:
+        return self.status_to_metric_field[None]
 
-    def get_session_status(
-        self, metric_field: MetricField, raw_groupby: Sequence[str]
-    ) -> Optional[str]:
-        return self.metric_field_to_session_status[metric_field]
+    def _get_metric_fields(
+        self, raw_groupby: Sequence[str], status_filter: StatusFilter
+    ) -> Sequence[MetricField]:
+        if status_filter:
+            # Restrict fields to the included ones
+            return [self.status_to_metric_field[status] for status in status_filter]
 
-    def get_metric_fields(self, raw_groupby: Sequence[str]) -> Sequence[MetricField]:
         if "session.status" in raw_groupby:
             return [
                 # Always also get ALL, because this is what we sort by
                 # in the sessions implementation, with which we want to be consistent
-                MetricField(None, SessionMetricKey.ALL.value),
+                self.get_all_field(),
                 # These are the fields we actually need:
-                MetricField(None, SessionMetricKey.HEALTHY.value),
-                MetricField(None, SessionMetricKey.ABNORMAL.value),
-                MetricField(None, SessionMetricKey.CRASHED.value),
-                MetricField(None, SessionMetricKey.ERRORED.value),
+                self.status_to_metric_field[SessionStatus.HEALTHY],
+                self.status_to_metric_field[SessionStatus.ABNORMAL],
+                self.status_to_metric_field[SessionStatus.CRASHED],
+                self.status_to_metric_field[SessionStatus.ERRORED],
             ]
-        return [MetricField(None, SessionMetricKey.ALL.value)]
+        return [self.get_all_field()]
+
+    def _get_session_status(self, metric_field: MetricField) -> Optional[SessionStatus]:
+        if "session.status" in self._raw_groupby:
+            reverse_lookup = {v: k for k, v in self.status_to_metric_field.items()}
+            return reverse_lookup[metric_field]
+        return None
 
+    def normalize(self, value: Scalar) -> Scalar:
+        value = super().normalize(value)
+        # In the sessions API, sum() and count_unique() return integers
+        if isinstance(value, float):
+            return int(value)
+        return value
 
-class UsersField(IntegerField):
-    name = "count_unique(user)"
 
-    metric_field_to_session_status = {
-        MetricField(None, SessionMetricKey.HEALTHY_USER.value): "healthy",
-        MetricField(None, SessionMetricKey.ABNORMAL_USER.value): "abnormal",
-        MetricField(None, SessionMetricKey.CRASHED_USER.value): "crashed",
-        MetricField(None, SessionMetricKey.ERRORED_USER.value): "errored",
-        MetricField(None, SessionMetricKey.ALL_USER.value): None,
+class SumSessionField(CountField):
+    status_to_metric_field = {
+        SessionStatus.HEALTHY: MetricField(None, SessionMetricKey.HEALTHY.value),
+        SessionStatus.ABNORMAL: MetricField(None, SessionMetricKey.ABNORMAL.value),
+        SessionStatus.CRASHED: MetricField(None, SessionMetricKey.CRASHED.value),
+        SessionStatus.ERRORED: MetricField(None, SessionMetricKey.ERRORED.value),
+        None: MetricField(None, SessionMetricKey.ALL.value),
     }
 
-    def get_session_status(
-        self, metric_field: MetricField, raw_groupby: Sequence[str]
-    ) -> Optional[str]:
-        return self.metric_field_to_session_status[metric_field]
+    def accumulate(self, old_value: Scalar, new_value: Scalar) -> Scalar:
+        # This is only needed for a single specific scenario:
+        # When we filter by more than one session.status (e.g. crashed and abnormal),
+        # but do *not* group by session.status, we want to sum up the values from the different metrics,
+        # e.g. session.crashed + session.abnormal
+        assert isinstance(old_value, int)
+        assert isinstance(new_value, int)
+        return old_value + new_value
 
-    def get_metric_fields(self, raw_groupby: Sequence[str]) -> Sequence[MetricField]:
-        if "session.status" in raw_groupby:
-            return [
-                # Always also get ALL, because this is what we sort by
-                # in the sessions implementation, with which we want to be consistent
-                MetricField(None, SessionMetricKey.ALL_USER.value),
-                # These are the fields we actually need:
-                MetricField(None, SessionMetricKey.HEALTHY_USER.value),
-                MetricField(None, SessionMetricKey.ABNORMAL_USER.value),
-                MetricField(None, SessionMetricKey.CRASHED_USER.value),
-                MetricField(None, SessionMetricKey.ERRORED_USER.value),
-            ]
-        return [MetricField(None, SessionMetricKey.ALL_USER.value)]
+
+class CountUniqueUser(CountField):
+    def __init__(
+        self,
+        name: str,
+        raw_groupby: Sequence[str],
+        status_filter: StatusFilter,
+    ):
+        # We cannot do set arithmetic outside of the metrics API:
+        if status_filter and len(status_filter) > 1 and "session.status" not in raw_groupby:
+            raise InvalidParams(
+                "Cannot filter count_unique by multiple session.status unless it is in groupBy"
+            )
+
+        super().__init__(name, raw_groupby, status_filter)
+
+    status_to_metric_field = {
+        SessionStatus.HEALTHY: MetricField(None, SessionMetricKey.HEALTHY_USER.value),
+        SessionStatus.ABNORMAL: MetricField(None, SessionMetricKey.ABNORMAL_USER.value),
+        SessionStatus.CRASHED: MetricField(None, SessionMetricKey.CRASHED_USER.value),
+        SessionStatus.ERRORED: MetricField(None, SessionMetricKey.ERRORED_USER.value),
+        None: MetricField(None, SessionMetricKey.ALL_USER.value),
+    }
 
 
 class DurationField(Field):
-    def __init__(self, name: SessionsQueryFunction):
-        self.name = name
+    def __init__(self, name: str, raw_groupby: Sequence[str], status_filter: StatusFilter):
         self.op = name[:3]  # That this works is just a lucky coincidence
+        super().__init__(name, raw_groupby, status_filter)
 
-    def get_session_status(
-        self, metric_field: MetricField, raw_groupby: Sequence[str]
-    ) -> Optional[str]:
+    def _get_session_status(self, metric_field: MetricField) -> Optional[SessionStatus]:
         assert metric_field == MetricField(self.op, SessionMetricKey.DURATION.value)
-        if "session.status" in raw_groupby:
-            return "healthy"
+        if "session.status" in self._raw_groupby:
+            return SessionStatus.HEALTHY
         return None
 
-    def get_metric_fields(self, raw_groupby: Sequence[str]) -> Sequence[MetricField]:
-        return [MetricField(self.op, SessionMetricKey.DURATION.value)]
+    def _get_metric_fields(
+        self, raw_groupby: Sequence[str], status_filter: StatusFilter
+    ) -> Sequence[MetricField]:
+        if status_filter is None or SessionStatus.HEALTHY in status_filter:
+            return [MetricField(self.op, SessionMetricKey.DURATION.value)]
 
-    def normalize(self, value: Union[int, float, None]) -> Union[int, float, None]:
+        return []  # TODO: test if we can handle zero fields
+
+    def normalize(self, value: Scalar) -> Scalar:
         value = finite_or_none(value)
         if value is not None:
             value *= 1000
         return value
 
 
-COLUMN_MAP = {
-    SessionsField.name: SessionsField(),
-    UsersField.name: UsersField(),
-    "avg(session.duration)": DurationField("avg(session.duration)"),
-    "p50(session.duration)": DurationField("p50(session.duration)"),
-    "p75(session.duration)": DurationField("p75(session.duration)"),
-    "p90(session.duration)": DurationField("p90(session.duration)"),
-    "p95(session.duration)": DurationField("p95(session.duration)"),
-    "p99(session.duration)": DurationField("p99(session.duration)"),
-    "max(session.duration)": DurationField("max(session.duration)"),
+FIELD_MAP: Mapping[SessionsQueryFunction, Type[Field]] = {
+    "sum(session)": SumSessionField,
+    "count_unique(user)": CountUniqueUser,
+    "avg(session.duration)": DurationField,
+    "p50(session.duration)": DurationField,
+    "p75(session.duration)": DurationField,
+    "p90(session.duration)": DurationField,
+    "p95(session.duration)": DurationField,
+    "p99(session.duration)": DurationField,
+    "max(session.duration)": DurationField,
 }
 
 
@@ -293,35 +351,49 @@ def run_sessions_query(
     # backend runs
     query = deepcopy(query)
 
-    from sentry.snuba.metrics.query import QueryDefinition as MetricsQuery
+    intervals = get_timestamps(query)
+
+    conditions = _get_filter_conditions(query.conditions)
+    where, status_filter = _extract_status_filter_from_conditions(conditions)
+    if status_filter == frozenset():
+        # There was a condition that cannot be met, such as 'session.status:foo',
+        # no need to query metrics, just return empty groups.
+        return _empty_result(query)
+
+    fields = {
+        field_name: FIELD_MAP[field_name](field_name, query.raw_groupby, status_filter)
+        for field_name in query.raw_fields
+    }
+
+    # Remove fields that do not query anything:
+    fields = {field_name: field for field_name, field in fields.items() if field.metric_fields}
 
-    fields = [COLUMN_MAP[field_name] for field_name in query.raw_fields]
+    if not fields:
+        return _empty_result(query)
 
     filter_keys = query.filter_keys.copy()
     project_ids = filter_keys.pop("project_id")
     assert not filter_keys
 
-    orderby = _parse_orderby(query)
+    orderby = _parse_orderby(query, fields)
     if orderby is None:
         # We only return the top-N groups, based on the first field that is being
         # queried, assuming that those are the most relevant to the user.
         # In a future iteration we might expose an `orderBy` query parameter.
-        primary_metric_field = _get_primary_field(fields, query.raw_groupby)
+        primary_metric_field = _get_primary_field(list(fields.values()), query.raw_groupby)
         orderby = OrderBy(primary_metric_field, Direction.DESC)
 
-    max_groups = SNUBA_LIMIT // len(get_timestamps(query))
+    max_groups = SNUBA_LIMIT // len(intervals)
 
     metrics_query = MetricsQuery(
         org_id,
         project_ids,
-        list({column for field in fields for column in field.get_metric_fields(query.raw_groupby)}),
+        list({column for field in fields.values() for column in field.metric_fields}),
         query.start,
         query.end,
         Granularity(query.rollup),
-        where=_get_filter_conditions(query.conditions),
-        groupby=list(
-            {column for field in fields for column in field.get_groupby(query.raw_groupby)}
-        ),
+        where=where,
+        groupby=list({column for field in fields.values() for column in field.get_groupby()}),
         orderby=orderby,
         limit=Limit(max_groups),
     )
@@ -329,24 +401,28 @@ def run_sessions_query(
     # TODO: Stop passing project IDs everywhere
     projects = Project.objects.get_many_from_cache(project_ids)
     try:
-        results = get_series(projects, metrics_query)
+        metrics_results = get_series(projects, metrics_query)
+    except OrderByNotSupportedOverCompositeEntityException:
+        raise InvalidParams(f"Cannot order by {query.raw_orderby[0]} with the current filters")
     except UtilsInvalidParams as e:
         raise InvalidParams(e)
 
-    input_groups = {GroupKey.from_input_dict(group["by"]): group for group in results["groups"]}
+    input_groups = {
+        GroupKey.from_input_dict(group["by"]): group for group in metrics_results["groups"]
+    }
 
     output_groups: MutableMapping[GroupKey, Group] = defaultdict(
         lambda: {
             "totals": {field: default_for(field) for field in query.raw_fields},
             "series": {
-                field: len(results["intervals"]) * [default_for(field)]
+                field: len(metrics_results["intervals"]) * [default_for(field)]
                 for field in query.raw_fields
             },
         }
     )
 
-    for field in fields:
-        field.extract_values(query.raw_groupby, input_groups, output_groups)
+    for field in fields.values():
+        field.extract_values(input_groups, output_groups)
 
     if not output_groups:
         # Generate default groups to be consistent with original sessions_v2
@@ -357,23 +433,33 @@ def run_sessions_query(
         elif ["session.status"] == query.raw_groupby:
             for status in SessionStatus:
                 # Create entry in default dict:
-                output_groups[GroupKey(session_status=status.value)]
-
-    # Convert group keys back to dictionaries:
-    results["groups"] = [
-        {"by": group_key.to_output_dict(), **group} for group_key, group in output_groups.items()  # type: ignore
-    ]
+                output_groups[GroupKey(session_status=status)]
+
+    return {
+        "groups": [
+            # Convert group keys back to dictionaries:
+            {"by": group_key.to_output_dict(), **group}  # type: ignore
+            for group_key, group in output_groups.items()
+        ],
+        "start": isoformat_z(metrics_results["start"]),
+        "end": isoformat_z(metrics_results["end"]),
+        "intervals": [isoformat_z(ts) for ts in metrics_results["intervals"]],
+        "query": query.query,
+    }
 
-    # Finally, serialize timestamps:
-    results["start"] = isoformat_z(results["start"])
-    results["end"] = isoformat_z(results["end"])
-    results["intervals"] = [isoformat_z(ts) for ts in results["intervals"]]
-    results["query"] = results.get("query", "")
 
-    return cast(SessionsQueryResult, results)
+def _empty_result(query: QueryDefinition) -> SessionsQueryResult:
+    intervals = get_timestamps(query)
+    return {
+        "groups": [],
+        "start": intervals[0],
+        "end": intervals[-1],
+        "intervals": intervals,
+        "query": query.query,
+    }
 
 
-def _get_filter_conditions(conditions: Sequence[Condition]) -> Any:
+def _get_filter_conditions(conditions: Any) -> ConditionGroup:
     """Translate given conditions to snql"""
     dummy_entity = EntityKey.MetricsSets.value
     return json_to_snql(
@@ -381,7 +467,70 @@ def _get_filter_conditions(conditions: Sequence[Condition]) -> Any:
     ).where
 
 
-def _parse_orderby(query: QueryDefinition) -> Optional[OrderBy]:
+def _extract_status_filter_from_conditions(
+    conditions: ConditionGroup,
+) -> Tuple[ConditionGroup, StatusFilter]:
+    """Split conditions into metrics conditions and a filter on session.status"""
+    if not conditions:
+        return conditions, None
+    where, status_filters = zip(*map(_transform_single_condition, conditions))
+    where = [condition for condition in where if condition is not None]
+    status_filters = [f for f in status_filters if f is not None]
+    if status_filters:
+        status_filters = frozenset.intersection(*status_filters)
+    else:
+        status_filters = None
+    return where, status_filters
+
+
+def _transform_single_condition(
+    condition: Union[Condition, BooleanCondition]
+) -> Tuple[Optional[Union[Condition, BooleanCondition]], StatusFilter]:
+    if isinstance(condition, Condition):
+        if condition.lhs == Function("ifNull", parameters=[Column("session.status"), ""]):
+            # HACK: metrics tags are never null. We should really
+            # write our own parser for this.
+            condition = replace(condition, lhs=Column("session.status"))
+
+        if condition.lhs == Column("session.status"):
+            if condition.op == Op.EQ:
+                return None, _parse_session_status(condition.rhs)
+            if condition.op == Op.NEQ:
+                return None, ALL_STATUSES - _parse_session_status(condition.rhs)
+            if condition.op == Op.IN:
+                return None, frozenset.union(
+                    *[_parse_session_status(status) for status in condition.rhs]
+                )
+            if condition.op == Op.NOT_IN:
+                return None, ALL_STATUSES - frozenset.union(
+                    *[_parse_session_status(status) for status in condition.rhs]
+                )
+            raise InvalidParams("Unable to resolve session.status filter")
+
+    if "session.status" in str(condition):
+        # Anything not handled by the code above cannot be parsed for now,
+        # for two reasons:
+        # 1) Queries like session.status:healthy OR release:foo are hard to
+        #    translate, because they would require different conditions on the separate
+        #    metric fields.
+        # 2) AND and OR conditions come in the form `Condition(Function("or", [...]), Op.EQ, 1)`
+        #    where [...] can again contain any condition encoded as a Function. For this, we would
+        #    have to replicate the translation code above.
+        raise InvalidParams("Unable to parse condition with session.status")
+
+    return condition, None
+
+
+def _parse_session_status(status: Any) -> FrozenSet[SessionStatus]:
+    try:
+        return frozenset([SessionStatus(status)])
+    except ValueError:
+        return frozenset()
+
+
+def _parse_orderby(
+    query: QueryDefinition, fields: Mapping[SessionsQueryFunction, Field]
+) -> Optional[OrderBy]:
     orderbys = query.raw_orderby
     if orderbys == []:
         return None
@@ -400,14 +549,14 @@ def _parse_orderby(query: QueryDefinition) -> Optional[OrderBy]:
     assert query.raw_fields
     if orderby not in query.raw_fields:
         raise InvalidParams("'orderBy' must be one of the provided 'fields'")
-    field = COLUMN_MAP[orderby]
 
-    metric_fields = field.get_metric_fields(query.raw_groupby)
+    field = fields[orderby]
 
-    # Because we excluded groupBy session status, we should have a one-to-one mapping now
-    assert len(metric_fields) == 1
+    if len(field.metric_fields) != 1:
+        # This can still happen when we filter by session.status
+        raise InvalidParams(f"Cannot order by {field.name} with the current filters")
 
-    return OrderBy(metric_fields[0], direction)
+    return OrderBy(field.metric_fields[0], direction)
 
 
 def _get_primary_field(fields: Sequence[Field], raw_groupby: Sequence[str]) -> MetricField:
@@ -415,7 +564,7 @@ def _get_primary_field(fields: Sequence[Field], raw_groupby: Sequence[str]) -> M
     primary_metric_field = None
     for i, field in enumerate(fields):
         if i == 0 or field.name == "sum(session)":
-            primary_metric_field = field.get_metric_fields(raw_groupby)[0]
+            primary_metric_field = field.metric_fields[0]
 
     assert primary_metric_field
     return primary_metric_field

+ 2 - 1
src/sentry/snuba/metrics/fields/base.py

@@ -75,6 +75,7 @@ from sentry.snuba.metrics.utils import (
     MetricOperationType,
     MetricType,
     NotSupportedOverCompositeEntityException,
+    OrderByNotSupportedOverCompositeEntityException,
     combine_dictionary_of_list_values,
 )
 from sentry.utils.snuba import raw_snql_query
@@ -704,7 +705,7 @@ class CompositeEntityDerivedMetric(DerivedMetricExpression):
         projects: Sequence[Project],
         query_definition: QueryDefinition,
     ) -> List[OrderBy]:
-        raise NotSupportedOverCompositeEntityException(
+        raise OrderByNotSupportedOverCompositeEntityException(
             f"It is not possible to orderBy field "
             f"{get_public_name_from_mri(self.metric_mri)} as it does not "
             f"have a direct mapping to a query alias"

+ 5 - 0
src/sentry/snuba/metrics/utils.py

@@ -26,6 +26,7 @@ __all__ = (
     "MetricDoesNotExistException",
     "MetricDoesNotExistInIndexer",
     "NotSupportedOverCompositeEntityException",
+    "OrderByNotSupportedOverCompositeEntityException",
     "MetricEntity",
     "UNALLOWED_TAGS",
     "combine_dictionary_of_list_values",
@@ -206,3 +207,7 @@ class DerivedMetricParseException(DerivedMetricException):
 
 class NotSupportedOverCompositeEntityException(DerivedMetricException):
     ...
+
+
+class OrderByNotSupportedOverCompositeEntityException(NotSupportedOverCompositeEntityException):
+    ...

+ 18 - 4
src/sentry/snuba/sessions_v2.py

@@ -224,8 +224,9 @@ CONDITION_COLUMNS = ["project", "environment", "release"]
 FILTER_KEY_COLUMNS = ["project_id"]
 
 
-def resolve_column(col):
-    if col in CONDITION_COLUMNS:
+def resolve_column(col, extra_columns=None):
+    condition_columns = CONDITION_COLUMNS + (extra_columns or [])
+    if col in condition_columns:
         return col
     raise InvalidField(f'Invalid query field: "{col}"')
 
@@ -253,7 +254,13 @@ class QueryDefinition:
     `fields` and `groupby` definitions as [`ColumnDefinition`] objects.
     """
 
-    def __init__(self, query, params, allowed_resolution: AllowedResolution):
+    def __init__(
+        self,
+        query,
+        params,
+        allowed_resolution: AllowedResolution,
+        allow_session_status_query: bool = False,
+    ):
         self.query = query.get("query", "")
         self.raw_fields = raw_fields = query.getlist("field", [])
         self.raw_groupby = raw_groupby = query.getlist("groupBy", [])
@@ -303,7 +310,14 @@ class QueryDefinition:
 
         # this makes sure that literals in complex queries are properly quoted,
         # and unknown fields are raised as errors
-        conditions = [resolve_condition(c, resolve_column) for c in snuba_filter.conditions]
+        if allow_session_status_query:
+            # NOTE: "''" is added because we use the event search parser, which
+            # resolves "session.status" to ifNull(..., "''")
+            column_resolver = lambda col: resolve_column(col, ["session.status", "''"])
+        else:
+            column_resolver = resolve_column
+
+        conditions = [resolve_condition(c, column_resolver) for c in snuba_filter.conditions]
         filter_keys = {
             resolve_filter_key(key): value for key, value in snuba_filter.filter_keys.items()
         }

+ 87 - 0
tests/sentry/release_health/test_metrics_sessions_v2.py

@@ -2,11 +2,18 @@ from itertools import chain, combinations
 from typing import Iterable, List
 from unittest.mock import patch
 
+import pytest
 from django.urls import reverse
+from snuba_sdk import Column, Condition, Function, Op
 
 from sentry.release_health.duplex import compare_results
 from sentry.release_health.metrics import MetricsReleaseHealthBackend
+from sentry.release_health.metrics_sessions_v2 import (
+    SessionStatus,
+    _extract_status_filter_from_conditions,
+)
 from sentry.release_health.sessions import SessionsReleaseHealthBackend
+from sentry.snuba.sessions_v2 import InvalidParams
 from sentry.testutils.cases import APITestCase, SnubaTestCase
 from tests.snuba.api.endpoints.test_organization_sessions import result_sorted
 
@@ -85,3 +92,83 @@ class MetricsSessionsV2Test(APITestCase, SnubaTestCase):
 def _session_groupby_powerset() -> Iterable[str]:
     keys = ["project", "release", "environment", "session.status"]
     return chain.from_iterable((combinations(keys, size)) for size in range(len(keys) + 1))
+
+
@pytest.mark.parametrize(
    "conditions, expected_conditions, expected_status_filter",
    [
        (
            # IN list of valid statuses becomes the corresponding set.
            [
                Condition(Column("release"), Op.EQ, "foo"),
                Condition(Column("session.status"), Op.IN, ["abnormal", "errored"]),
            ],
            [Condition(Column("release"), Op.EQ, "foo")],
            {SessionStatus.ABNORMAL, SessionStatus.ERRORED},
        ),
        (
            # An unknown status value matches nothing -> empty filter set.
            [
                Condition(Column("release"), Op.EQ, "foo"),
                Condition(Column("session.status"), Op.EQ, "bogus"),
            ],
            [Condition(Column("release"), Op.EQ, "foo")],
            frozenset(),
        ),
        (
            # Negation keeps the complement of the given status.
            [
                Condition(Column("release"), Op.EQ, "foo"),
                Condition(Column("session.status"), Op.NEQ, "abnormal"),
            ],
            [Condition(Column("release"), Op.EQ, "foo")],
            {SessionStatus.HEALTHY, SessionStatus.ERRORED, SessionStatus.CRASHED},
        ),
        (
            # NOT_IN ignores unknown values and keeps the complement of the rest.
            [
                Condition(Column("release"), Op.EQ, "foo"),
                Condition(Column("session.status"), Op.NOT_IN, ["abnormal", "bogus"]),
            ],
            [Condition(Column("release"), Op.EQ, "foo")],
            {SessionStatus.HEALTHY, SessionStatus.ERRORED, SessionStatus.CRASHED},
        ),
        (
            # Two contradictory equality filters intersect to the empty set.
            [
                Condition(Column("session.status"), Op.EQ, "abnormal"),
                Condition(Column("session.status"), Op.EQ, "errored"),
            ],
            [],
            frozenset(),
        ),
    ],
)
def test_transform_conditions(conditions, expected_conditions, expected_status_filter):
    """session.status conditions are removed from the condition list and
    translated into a set of SessionStatus values.

    Note: parameter renamed from ``input`` to avoid shadowing the builtin.
    """
    remaining_conditions, status_filter = _extract_status_filter_from_conditions(conditions)
    assert remaining_conditions == expected_conditions
    assert status_filter == expected_status_filter
+
+
@pytest.mark.parametrize("conditions", [[Condition(Column("release"), Op.EQ, "foo")]])
def test_transform_conditions_nochange(conditions):
    """Conditions without session.status pass through untouched and the
    status filter is ``None`` (meaning: no status filtering requested).

    Note: parameter renamed from ``input`` to avoid shadowing the builtin.
    """
    remaining_conditions, status_filter = _extract_status_filter_from_conditions(conditions)
    assert conditions == remaining_conditions
    assert status_filter is None
+
+
@pytest.mark.parametrize(
    "conditions",
    [
        # session.status nested inside a boolean OR cannot be extracted
        # into a plain status filter, so it must be rejected.
        [
            Condition(
                Function(
                    "or",
                    [
                        Function("equals", ["release", "foo"]),
                        Function("equals", ["session.status", "foo"]),
                    ],
                ),
                Op.EQ,
                1,
            )
        ],
    ],
)
def test_transform_conditions_illegal(conditions):
    """Complex boolean conditions containing session.status raise InvalidParams.

    Notes: parameter renamed from ``input`` to avoid shadowing the builtin;
    ``pytest.raises`` used in the idiomatic context-manager form.
    """
    with pytest.raises(InvalidParams):
        _extract_status_filter_from_conditions(conditions)

+ 97 - 1
tests/snuba/api/endpoints/test_organization_sessions.py

@@ -969,7 +969,6 @@ class OrganizationSessionsEndpointMetricsTest(
 
     @freeze_time(MOCK_DATETIME)
     def test_orderby(self):
-
         response = self.do_request(
             {
                 "project": [-1],
@@ -1120,3 +1119,100 @@ class OrganizationSessionsEndpointMetricsTest(
                 ).status_code
                 == 200
             )
+
+    @freeze_time(MOCK_DATETIME)
+    def test_filter_by_session_status(self):
+        default_request = {
+            "project": [-1],
+            "statsPeriod": "1d",
+            "interval": "1d",
+        }
+
+        def req(**kwargs):
+            return self.do_request(dict(default_request, **kwargs))
+
+        response = req(field=["sum(session)"], query="session.status:bogus")
+        assert response.status_code == 200, response.content
+        assert result_sorted(response.data)["groups"] == []
+
+        response = req(field=["sum(session)"], query="!session.status:healthy")
+        assert response.status_code == 200, response.content
+        assert result_sorted(response.data)["groups"] == [
+            {"by": {}, "series": {"sum(session)": [3]}, "totals": {"sum(session)": 3}}
+        ]
+
+        # sum(session) filtered by multiple statuses adds them
+        response = req(field=["sum(session)"], query="session.status:[healthy, errored]")
+        assert response.status_code == 200, response.content
+        assert result_sorted(response.data)["groups"] == [
+            {"by": {}, "series": {"sum(session)": [8]}, "totals": {"sum(session)": 8}}
+        ]
+
+        response = req(
+            field=["sum(session)"],
+            query="session.status:[healthy, errored]",
+            groupBy="session.status",
+        )
+        assert response.status_code == 200, response.content
+        assert result_sorted(response.data)["groups"] == [
+            {
+                "by": {"session.status": "errored"},
+                "totals": {"sum(session)": 2},
+                "series": {"sum(session)": [2]},
+            },
+            {
+                "by": {"session.status": "healthy"},
+                "totals": {"sum(session)": 6},
+                "series": {"sum(session)": [6]},
+            },
+        ]
+
+        response = req(field=["sum(session)"], query="session.status:healthy release:foo@1.1.0")
+        assert response.status_code == 200, response.content
+        assert result_sorted(response.data)["groups"] == [
+            {"by": {}, "series": {"sum(session)": [1]}, "totals": {"sum(session)": 1}}
+        ]
+
+        response = req(field=["sum(session)"], query="session.status:healthy OR release:foo@1.1.0")
+        assert response.status_code == 400, response.data
+        assert response.data == {"detail": "Unable to parse condition with session.status"}
+
+        # count_unique(user) does not work with multiple session statuses selected
+        response = req(field=["count_unique(user)"], query="session.status:[healthy, errored]")
+        assert response.status_code == 400, response.data
+        assert response.data == {
+            "detail": "Cannot filter count_unique by multiple session.status unless it is in groupBy"
+        }
+
+        response = req(field=["p95(session.duration)"], query="session.status:abnormal")
+        assert response.status_code == 200, response.content
+        assert result_sorted(response.data)["groups"] == []
+
+    @freeze_time(MOCK_DATETIME)
+    def test_filter_by_session_status_with_orderby(self):
+        default_request = {
+            "project": [-1],
+            "statsPeriod": "1d",
+            "interval": "1d",
+        }
+
+        def req(**kwargs):
+            return self.do_request(dict(default_request, **kwargs))
+
+        response = req(
+            field=["sum(session)"],
+            query="session.status:[abnormal,crashed]",
+            groupBy="release",
+            orderBy="sum(session)",
+        )
+        assert response.status_code == 400, response.content
+        assert response.data == {"detail": "Cannot order by sum(session) with the current filters"}
+
+        response = req(
+            field=["sum(session)"],
+            query="session.status:healthy",
+            groupBy="release",
+            orderBy="sum(session)",
+        )
+        assert response.status_code == 400, response.content
+        assert response.data == {"detail": "Cannot order by sum(session) with the current filters"}