Browse Source

ref: fix QuerySet soundness of event_frequency (#75150)

<!-- Describe your PR here. -->
anthony sottile 7 months ago
parent
commit
2e4a0d493c
1 changed files with 23 additions and 11 deletions
  1. 23 11
      src/sentry/rules/conditions/event_frequency.py

+ 23 - 11
src/sentry/rules/conditions/event_frequency.py

@@ -6,7 +6,7 @@ import logging
 from collections import defaultdict
 from collections.abc import Callable, Mapping
 from datetime import datetime, timedelta
-from typing import Any, Literal, NotRequired
+from typing import Any, Literal, NotRequired, TypedDict
 
 from django import forms
 from django.core.cache import cache
@@ -120,6 +120,13 @@ class EventFrequencyForm(forms.Form):
         return cleaned_data
 
 
+class _QSTypedDict(TypedDict):
+    id: int
+    type: int
+    project_id: int
+    project__organization_id: int
+
+
 class BaseEventFrequencyCondition(EventCondition, abc.ABC):
     intervals = STANDARD_INTERVALS
     form_cls = EventFrequencyForm
@@ -377,7 +384,8 @@ class BaseEventFrequencyCondition(EventCondition, abc.ABC):
         return batch_totals
 
     def get_error_and_generic_group_ids(
-        self, groups: list[QuerySet]
+        self,
+        groups: QuerySet[Group, _QSTypedDict],
     ) -> tuple[list[int], list[int]]:
         """
         Separate group ids into error group ids and generic group ids
@@ -386,14 +394,18 @@ class BaseEventFrequencyCondition(EventCondition, abc.ABC):
         error_issue_ids = []
 
         for group in groups:
-            issue_type = get_group_type_by_type_id(group.get("type"))
+            issue_type = get_group_type_by_type_id(group["type"])
             if GroupCategory(issue_type.category) == GroupCategory.ERROR:
-                error_issue_ids.append(group.get("id"))
+                error_issue_ids.append(group["id"])
             else:
-                generic_issue_ids.append(group.get("id"))
+                generic_issue_ids.append(group["id"])
         return (error_issue_ids, generic_issue_ids)
 
-    def get_value_from_groups(self, groups: list[QuerySet] | None, value: str) -> int | None:
+    def get_value_from_groups(
+        self,
+        groups: QuerySet[Group, _QSTypedDict] | None,
+        value: Literal["id", "project_id", "project__organization_id"],
+    ) -> int | None:
         result = None
         if groups:
             group = groups[0]
@@ -426,7 +438,7 @@ class EventFrequencyCondition(BaseEventFrequencyCondition):
     ) -> dict[int, int]:
         batch_sums: dict[int, int] = defaultdict(int)
         groups = Group.objects.filter(id__in=group_ids).values(
-            "id", "type", "project__organization_id"
+            "id", "type", "project_id", "project__organization_id"
         )
         error_issue_ids, generic_issue_ids = self.get_error_and_generic_group_ids(groups)
         organization_id = self.get_value_from_groups(groups, "project__organization_id")
@@ -489,7 +501,7 @@ class EventUniqueUserFrequencyCondition(BaseEventFrequencyCondition):
     ) -> dict[int, int]:
         batch_totals: dict[int, int] = defaultdict(int)
         groups = Group.objects.filter(id__in=group_ids).values(
-            "id", "type", "project__organization_id"
+            "id", "type", "project_id", "project__organization_id"
         )
         error_issue_ids, generic_issue_ids = self.get_error_and_generic_group_ids(groups)
         organization_id = self.get_value_from_groups(groups, "project__organization_id")
@@ -664,7 +676,7 @@ class EventFrequencyPercentCondition(BaseEventFrequencyCondition):
         project_id = self.get_value_from_groups(groups, "project_id")
 
         if not project_id:
-            return {group.get("id"): 0 for group in groups}
+            return {group["id"]: 0 for group in groups}
 
         session_count_last_hour = self.get_session_count(project_id, environment_id, start, end)
         avg_sessions_in_interval = self.get_session_interval(
@@ -672,13 +684,13 @@ class EventFrequencyPercentCondition(BaseEventFrequencyCondition):
         )
 
         if not avg_sessions_in_interval:
-            return {group.get("id"): 0 for group in groups}
+            return {group["id"]: 0 for group in groups}
 
         error_issue_ids, generic_issue_ids = self.get_error_and_generic_group_ids(groups)
         organization_id = self.get_value_from_groups(groups, "project__organization_id")
 
         if not (error_issue_ids and organization_id):
-            return {group.get("id"): 0 for group in groups}
+            return {group["id"]: 0 for group in groups}
 
         error_issue_count = self.get_chunked_result(
             tsdb_function=self.tsdb.get_sums,