Browse Source

ref(replays): refactor count to use helper function (#53679)

refactors our replay_count endpoint to separate the specific endpoint
logic from the counting, so that way we can use this function internally
as well. slight changes to make type friendly, but behavior should be
the same i think.

---------

Co-authored-by: getsantry[bot] <66042841+getsantry[bot]@users.noreply.github.com>
Josh Ferge 1 year ago
parent
commit
7a42a3ea56

+ 0 - 1
pyproject.toml

@@ -717,7 +717,6 @@ module = [
     "sentry.release_health.metrics_sessions_v2",
     "sentry.release_health.sessions",
     "sentry.release_health.tasks",
-    "sentry.replays.endpoints.organization_replay_count",
     "sentry.replays.endpoints.project_replay_clicks_index",
     "sentry.replays.endpoints.project_replay_recording_segment_details",
     "sentry.replays.endpoints.project_replay_recording_segment_index",

+ 5 - 101
src/sentry/replays/endpoints/organization_replay_count.py

@@ -1,8 +1,5 @@
 from __future__ import annotations
 
-from collections import defaultdict
-from typing import Any
-
 from rest_framework import status
 from rest_framework.response import Response
 from snuba_sdk import Request
@@ -11,13 +8,9 @@ from sentry import features
 from sentry.api.base import region_silo_endpoint
 from sentry.api.bases import NoProjects
 from sentry.api.bases.organization_events import OrganizationEventsV2EndpointBase
-from sentry.api.event_search import parse_search_query
 from sentry.exceptions import InvalidSearchQuery
 from sentry.models import Organization
-from sentry.replays.query import query_replays_count
-from sentry.search.events.builder import QueryBuilder
-from sentry.search.events.types import ParamsType, SnubaParams
-from sentry.snuba.dataset import Dataset
+from sentry.replays.usecases.replay_counts import get_replay_counts
 from sentry.types.ratelimit import RateLimit, RateLimitCategory
 
 MAX_REPLAY_COUNT = 51
@@ -58,99 +51,10 @@ class OrganizationReplayCountEndpoint(OrganizationEventsV2EndpointBase):
             return Response({})
 
         try:
-            replay_ids_mapping = get_replay_id_mappings(request, params, snuba_params)
+            replay_counts = get_replay_counts(
+                snuba_params, params, request.GET.get("query"), request.GET.get("returnIds")
+            )
         except (InvalidSearchQuery, ValueError) as e:
             return Response({"detail": str(e)}, status=status.HTTP_400_BAD_REQUEST)
 
-        replay_results = query_replays_count(
-            project_ids=[p.id for p in snuba_params.projects],
-            start=snuba_params.start,
-            end=snuba_params.end,
-            replay_ids=list(replay_ids_mapping.keys()),
-            tenant_ids={"organization_id": organization.id},
-        )
-
-        if request.GET.get("returnIds"):
-            return self.respond(get_replay_ids(replay_results, replay_ids_mapping))
-        else:
-            return self.respond(get_counts(replay_results, replay_ids_mapping))
-
-
-def get_counts(replay_results: Any, replay_ids_mapping: dict[str, list[str]]) -> dict[str, int]:
-    ret: dict[str, int] = defaultdict(int)
-    for row in replay_results["data"]:
-        identifiers = replay_ids_mapping[row["replay_id"]]
-        for identifier in identifiers:
-            ret[identifier] = min(ret[identifier] + 1, MAX_REPLAY_COUNT)
-    return ret
-
-
-def get_replay_ids(
-    replay_results: Any, replay_ids_mapping: dict[str, list[str]]
-) -> dict[str, list[str]]:
-    ret: dict[str, list[str]] = defaultdict(list)
-    for row in replay_results["data"]:
-        identifiers = replay_ids_mapping[row["replay_id"]]
-        for identifier in identifiers:
-            if len(ret[identifier]) < MAX_REPLAY_COUNT:
-                ret[identifier].append(row["replay_id"])
-    return ret
-
-
-def get_replay_id_mappings(
-    request: Request, params: ParamsType, snuba_params: SnubaParams
-) -> dict[str, list[str]]:
-
-    select_column, value = get_select_column(request.GET.get("query"))
-    query = request.GET.get("query")
-
-    query = query + FILTER_HAS_A_REPLAY
-
-    if select_column == "replay_id":
-        # just return a mapping of replay_id:replay_id instead of hitting discover
-        # if we want to validate list of replay_ids existence
-        return {v: [v] for v in value}
-
-    builder = QueryBuilder(
-        dataset=Dataset.Discover,
-        params=params,
-        snuba_params=snuba_params,
-        selected_columns=["group_uniq_array(100,replayId)", select_column],
-        query=request.GET.get("query"),
-        limit=25,
-        offset=0,
-        functions_acl=["group_uniq_array"],
-    )
-
-    discover_results = builder.run_query(
-        referrer="api.organization-issue-replay-count", use_cache=True
-    )
-
-    replay_id_to_issue_map = defaultdict(list)
-
-    for row in discover_results["data"]:
-        for replay_id in row["group_uniq_array_100_replayId"]:
-            replay_id_to_issue_map[replay_id].append(row[select_column])
-
-    return replay_id_to_issue_map
-
-
-def get_select_column(query: str) -> tuple[str, list[Any]]:
-    parsed_query = parse_search_query(query)
-
-    select_column_conditions = [
-        cond for cond in parsed_query if cond.key.name in ["issue.id", "transaction", "replay_id"]
-    ]
-
-    if len(select_column_conditions) > 1:
-        raise ValueError("Must provide only one of: issue.id, transaction, replay_id")
-
-    if len(select_column_conditions) == 0:
-        raise ValueError("Must provide at least one issue.id, transaction, or replay_id")
-
-    condition = select_column_conditions[0]
-
-    if len(condition.value.raw_value) > MAX_VALS_PROVIDED[condition.key.name]:
-        raise ValueError("Too many values provided")
-
-    return condition.key.name, condition.value.raw_value
+        return self.respond(replay_counts)

+ 125 - 0
src/sentry/replays/usecases/replay_counts.py

@@ -0,0 +1,125 @@
+from __future__ import annotations
+
+from collections import defaultdict
+from collections.abc import Sequence
+from typing import Any
+
+from sentry.api.event_search import parse_search_query
+from sentry.replays.query import query_replays_count
+from sentry.search.events.builder import QueryBuilder
+from sentry.search.events.types import ParamsType, SnubaParams
+from sentry.snuba.dataset import Dataset
+
+MAX_REPLAY_COUNT = 51
+MAX_VALS_PROVIDED = {
+    "issue.id": 25,
+    "transaction": 25,
+    "replay_id": 100,
+}
+
+FILTER_HAS_A_REPLAY = " AND !replayId:''"
+
+
+def get_replay_counts(
+    snuba_params: SnubaParams, params: ParamsType, query, return_ids
+) -> dict[str, Any]:
+    if snuba_params.start is None or snuba_params.end is None or snuba_params.organization is None:
+        raise ValueError("Must provide start and end")
+
+    replay_ids_mapping = get_replay_id_mappings(query, params, snuba_params)
+
+    replay_results = query_replays_count(
+        project_ids=[p.id for p in snuba_params.projects],
+        start=snuba_params.start,
+        end=snuba_params.end,
+        replay_ids=list(replay_ids_mapping.keys()),
+        tenant_ids={"organization_id": snuba_params.organization.id},
+    )
+
+    if return_ids:
+        return get_replay_ids(replay_results, replay_ids_mapping)
+    else:
+        return get_counts(replay_results, replay_ids_mapping)
+
+
+def get_replay_id_mappings(query, params, snuba_params) -> dict[str, list[str]]:
+
+    select_column, value = get_select_column(query)
+    # query = query + FILTER_HAS_A_REPLAY
+    # XXX(jferge): TODO: the above line didn't actually do anything in the previous code.
+    # will re-enable this line after the refactor is merged.
+
+    if select_column == "replay_id":
+        # just return a mapping of replay_id:replay_id instead of hitting discover
+        # if we want to validate list of replay_ids existence
+        return {v: [v] for v in value}
+
+    builder = QueryBuilder(
+        dataset=Dataset.Discover,
+        params=params,
+        snuba_params=snuba_params,
+        selected_columns=["group_uniq_array(100,replayId)", select_column],
+        query=query,
+        limit=25,
+        offset=0,
+        functions_acl=["group_uniq_array"],
+    )
+
+    discover_results = builder.run_query(
+        referrer="api.organization-issue-replay-count", use_cache=True
+    )
+
+    replay_id_to_issue_map = defaultdict(list)
+
+    for row in discover_results["data"]:
+        for replay_id in row["group_uniq_array_100_replayId"]:
+            replay_id_to_issue_map[replay_id].append(row[select_column])
+
+    return replay_id_to_issue_map
+
+
+def get_counts(replay_results: Any, replay_ids_mapping: dict[str, list[str]]) -> dict[str, int]:
+    ret: dict[str, int] = defaultdict(int)
+    for row in replay_results["data"]:
+        identifiers = replay_ids_mapping[row["replay_id"]]
+        for identifier in identifiers:
+            ret[identifier] = min(ret[identifier] + 1, MAX_REPLAY_COUNT)
+    return ret
+
+
+def get_replay_ids(
+    replay_results: Any, replay_ids_mapping: dict[str, list[str]]
+) -> dict[str, list[str]]:
+    ret: dict[str, list[str]] = defaultdict(list)
+    for row in replay_results["data"]:
+        identifiers = replay_ids_mapping[row["replay_id"]]
+        for identifier in identifiers:
+            if len(ret[identifier]) < MAX_REPLAY_COUNT:
+                ret[identifier].append(row["replay_id"])
+    return ret
+
+
+def get_select_column(query: str) -> tuple[str, Sequence[Any]]:
+    parsed_query = parse_search_query(query)
+
+    select_column_conditions = [
+        cond for cond in parsed_query if cond.key.name in ["issue.id", "transaction", "replay_id"]
+    ]
+
+    if len(select_column_conditions) > 1:
+        raise ValueError("Must provide only one of: issue.id, transaction, replay_id")
+
+    if len(select_column_conditions) == 0:
+        raise ValueError("Must provide at least one issue.id, transaction, or replay_id")
+
+    condition = select_column_conditions[0]
+
+    if not isinstance(condition.value.raw_value, Sequence) or isinstance(
+        condition.value.raw_value, str
+    ):
+        raise ValueError("Condition value must be a list of strings")
+
+    if len(condition.value.raw_value) > MAX_VALS_PROVIDED[condition.key.name]:
+        raise ValueError("Too many values provided")
+
+    return condition.key.name, condition.value.raw_value