Browse Source

chore(subscriptions): Add Typing to the SubscriptionProcessor (#46614)

Add typing to the subscription processor file. In a follow-up PR I'll
add typing to [this
file](https://github.com/getsentry/sentry/blob/master/src/sentry/incidents/tasks.py)
(which calls it) so we can make performance improvements with less risk.
Colleen O'Rourke 1 year ago
parent
commit
57ddf46dc3
2 changed files with 95 additions and 46 deletions
  1. 1 0
      mypy.ini
  2. 94 46
      src/sentry/incidents/subscription_processor.py

+ 1 - 0
mypy.ini

@@ -71,6 +71,7 @@ files = fixtures/mypy-stubs,
         src/sentry/grouping/strategies/template.py,
         src/sentry/grouping/strategies/utils.py,
         src/sentry/incidents/charts.py,
+        src/sentry/incidents/subscription_processor.py,
         src/sentry/ingest/billing_metrics_consumer.py,
         src/sentry/ingest/transaction_clusterer/,
         src/sentry/integrations/base.py,

+ 94 - 46
src/sentry/incidents/subscription_processor.py

@@ -1,8 +1,10 @@
+from __future__ import annotations
+
 import logging
 import operator
 from copy import deepcopy
-from datetime import timedelta
-from typing import Optional
+from datetime import datetime, timedelta
+from typing import Any, Dict, List, Optional, Sequence, Tuple, TypedDict, TypeVar, cast
 
 from django.conf import settings
 from django.db import transaction
@@ -37,9 +39,11 @@ from sentry.snuba.entity_subscription import (
     get_entity_key_from_query_builder,
     get_entity_subscription_from_snuba_query,
 )
+from sentry.snuba.models import QuerySubscription
 from sentry.snuba.tasks import build_query_builder
 from sentry.utils import metrics, redis
 from sentry.utils.dates import to_datetime, to_timestamp
+from sentry.utils.redis import RetryingRedisCluster
 
 logger = logging.getLogger(__name__)
 REDIS_TTL = int(timedelta(days=7).total_seconds())
@@ -55,6 +59,17 @@ ALERT_RULE_TRIGGER_STAT_KEYS = ("alert_triggered", "resolve_triggered")
 #  functionality, then maybe we should move this to constants
 CRASH_RATE_ALERT_MINIMUM_THRESHOLD: Optional[int] = None
 
+T = TypeVar("T")
+
+
+class SubscriptionUpdate(TypedDict):
+    subscription_id: int
+    values: Dict[str, List[Any]]
+    timestamp: datetime
+    interval: int
+    partition: int
+    offset: int
+
 
 class SubscriptionProcessor:
     """
@@ -70,7 +85,7 @@ class SubscriptionProcessor:
         AlertRuleThresholdType.BELOW: (operator.lt, operator.gt),
     }
 
-    def __init__(self, subscription):
+    def __init__(self, subscription: QuerySubscription) -> None:
         self.subscription = subscription
         try:
             self.alert_rule = AlertRule.objects.get_for_subscription(subscription)
@@ -89,7 +104,7 @@ class SubscriptionProcessor:
         self.orig_trigger_resolve_counts = deepcopy(self.trigger_resolve_counts)
 
     @property
-    def active_incident(self):
+    def active_incident(self) -> Incident:
         if not hasattr(self, "_active_incident"):
             self._active_incident = Incident.objects.get_active_incident(
                 self.alert_rule, self.subscription.project
@@ -97,11 +112,11 @@ class SubscriptionProcessor:
         return self._active_incident
 
     @active_incident.setter
-    def active_incident(self, active_incident):
+    def active_incident(self, active_incident: Incident) -> None:
         self._active_incident = active_incident
 
     @property
-    def incident_triggers(self):
+    def incident_triggers(self) -> Dict[int, IncidentTrigger]:
         if not hasattr(self, "_incident_triggers"):
             incident = self.active_incident
             incident_triggers = {}
@@ -114,7 +129,7 @@ class SubscriptionProcessor:
             self._incident_triggers = incident_triggers
         return self._incident_triggers
 
-    def check_trigger_status(self, trigger, status):
+    def check_trigger_status(self, trigger: AlertRuleTrigger, status: IncidentStatus) -> bool:
         """
         Determines whether a trigger is currently at the specified status
         :param trigger: An `AlertRuleTrigger`
@@ -124,7 +139,7 @@ class SubscriptionProcessor:
         incident_trigger = self.incident_triggers.get(trigger.id)
         return incident_trigger is not None and incident_trigger.status == status.value
 
-    def reset_trigger_counts(self):
+    def reset_trigger_counts(self) -> None:
         """
         Helper method that clears both the trigger alert and the trigger resolve counts
         """
@@ -134,7 +149,7 @@ class SubscriptionProcessor:
             self.trigger_resolve_counts[trigger_id] = 0
         self.update_alert_rule_stats()
 
-    def calculate_resolve_threshold(self, trigger):
+    def calculate_resolve_threshold(self, trigger: IncidentTrigger) -> float:
         """
         Determine the resolve threshold for a trigger. First checks whether an
         explicit resolve threshold has been set on the rule, and whether this trigger is
@@ -148,7 +163,8 @@ class SubscriptionProcessor:
             len(self.triggers) == 1
             or trigger.label == WARNING_TRIGGER_LABEL
         ):
-            return self.alert_rule.resolve_threshold
+            resolve_threshold: float = self.alert_rule.resolve_threshold
+            return resolve_threshold
 
         # Since we only support gt/lt thresholds we have an off-by-one with auto
         # resolve. If we have an alert threshold of > 0, and no resolve threshold, then
@@ -163,9 +179,12 @@ class SubscriptionProcessor:
         else:
             resolve_add = -0.000001
 
-        return trigger.alert_threshold + resolve_add
+        threshold: float = trigger.alert_threshold + resolve_add
+        return threshold
 
-    def get_comparison_aggregation_value(self, subscription_update, aggregation_value):
+    def get_comparison_aggregation_value(
+        self, subscription_update: SubscriptionUpdate, aggregation_value: float
+    ) -> Optional[float]:
         # For comparison alerts run a query over the comparison period and use it to calculate the
         # % change.
         delta = timedelta(seconds=self.alert_rule.comparison_delta)
@@ -204,15 +223,18 @@ class SubscriptionProcessor:
 
         except Exception:
             logger.exception("Failed to run comparison query")
-            return
+            return None
 
         if not comparison_aggregate:
             metrics.incr("incidents.alert_rules.skipping_update_comparison_value_invalid")
-            return
+            return None
 
-        return (aggregation_value / comparison_aggregate) * 100
+        result: float = (aggregation_value / comparison_aggregate) * 100
+        return result
 
-    def get_crash_rate_alert_aggregation_value(self, subscription_update):
+    def get_crash_rate_alert_aggregation_value(
+        self, subscription_update: SubscriptionUpdate
+    ) -> Optional[float]:
         """
         Handles validation and extraction of Crash Rate Alerts subscription updates values.
         The subscription update looks like
@@ -234,7 +256,7 @@ class SubscriptionProcessor:
         if aggregation_value is None:
             self.reset_trigger_counts()
             metrics.incr("incidents.alert_rules.ignore_update_no_session_data")
-            return
+            return None
 
         try:
             total_count = subscription_update["values"]["data"][0][
@@ -247,7 +269,7 @@ class SubscriptionProcessor:
                     metrics.incr(
                         "incidents.alert_rules.ignore_update_count_lower_than_min_threshold"
                     )
-                    return
+                    return None
         except KeyError:
             # If for whatever reason total session count was not sent in the update,
             # ignore the minimum threshold comparison and continue along with processing the
@@ -259,10 +281,12 @@ class SubscriptionProcessor:
         # The subscription aggregation for crash rate alerts uses the Discover percentage
         # function, which would technically return a ratio of sessions_crashed/sessions and
         # so we need to calculate the crash free percentage out of that returned value
-        aggregation_value = round((1 - aggregation_value) * 100, 3)
-        return aggregation_value
+        aggregation_value_result: int = round((1 - aggregation_value) * 100, 3)
+        return aggregation_value_result
 
-    def get_crash_rate_alert_metrics_aggregation_value(self, subscription_update):
+    def get_crash_rate_alert_metrics_aggregation_value(
+        self, subscription_update: SubscriptionUpdate
+    ) -> Optional[float]:
         """Handle both update formats. Once all subscriptions have been updated
         to v2, we can remove v1 and replace this function with current v2.
         """
@@ -281,7 +305,9 @@ class SubscriptionProcessor:
         )
         return result
 
-    def _get_crash_rate_alert_metrics_aggregation_value_v1(self, subscription_update):
+    def _get_crash_rate_alert_metrics_aggregation_value_v1(
+        self, subscription_update: SubscriptionUpdate
+    ) -> Optional[float]:
         """
         Handles validation and extraction of Crash Rate Alerts subscription updates values over
         metrics dataset.
@@ -315,20 +341,22 @@ class SubscriptionProcessor:
         if total_session_count == 0:
             self.reset_trigger_counts()
             metrics.incr("incidents.alert_rules.ignore_update_no_session_data")
-            return
+            return None
 
         if CRASH_RATE_ALERT_MINIMUM_THRESHOLD is not None:
             min_threshold = int(CRASH_RATE_ALERT_MINIMUM_THRESHOLD)
             if total_session_count < min_threshold:
                 self.reset_trigger_counts()
                 metrics.incr("incidents.alert_rules.ignore_update_count_lower_than_min_threshold")
-                return
+                return None
 
         aggregation_value = round((1 - crash_count / total_session_count) * 100, 3)
 
         return aggregation_value
 
-    def _get_crash_rate_alert_metrics_aggregation_value_v2(self, subscription_update):
+    def _get_crash_rate_alert_metrics_aggregation_value_v2(
+        self, subscription_update: SubscriptionUpdate
+    ) -> Optional[float]:
         """
         Handles validation and extraction of Crash Rate Alerts subscription updates values over
         metrics dataset.
@@ -351,20 +379,20 @@ class SubscriptionProcessor:
         if total_session_count == 0:
             self.reset_trigger_counts()
             metrics.incr("incidents.alert_rules.ignore_update_no_session_data")
-            return
+            return None
 
         if CRASH_RATE_ALERT_MINIMUM_THRESHOLD is not None:
             min_threshold = int(CRASH_RATE_ALERT_MINIMUM_THRESHOLD)
             if total_session_count < min_threshold:
                 self.reset_trigger_counts()
                 metrics.incr("incidents.alert_rules.ignore_update_count_lower_than_min_threshold")
-                return
+                return None
 
-        aggregation_value = round((1 - crash_count / total_session_count) * 100, 3)
+        aggregation_value: int = round((1 - crash_count / total_session_count) * 100, 3)
 
         return aggregation_value
 
-    def get_aggregation_value(self, subscription_update):
+    def get_aggregation_value(self, subscription_update: SubscriptionUpdate) -> Optional[float]:
         if self.subscription.snuba_query.dataset == Dataset.Sessions.value:
             aggregation_value = self.get_crash_rate_alert_aggregation_value(subscription_update)
         elif self.subscription.snuba_query.dataset == Dataset.Metrics.value:
@@ -386,7 +414,7 @@ class SubscriptionProcessor:
                 )
         return aggregation_value
 
-    def process_update(self, subscription_update):
+    def process_update(self, subscription_update: SubscriptionUpdate) -> None:
         dataset = self.subscription.snuba_query.dataset
         try:
             # Check that the project exists
@@ -437,6 +465,7 @@ class SubscriptionProcessor:
             )
 
         aggregation_value = self.get_aggregation_value(subscription_update)
+
         if self.subscription.snuba_query.dataset == Dataset.Sessions.value:
             try:
                 # Temporarily logging results from session updates for comparison with data from metric
@@ -497,7 +526,7 @@ class SubscriptionProcessor:
         # before the next one then we might alert twice.
         self.update_alert_rule_stats()
 
-    def calculate_event_date_from_update_date(self, update_date):
+    def calculate_event_date_from_update_date(self, update_date: datetime) -> datetime:
         """
         Calculates the date that an event actually happened based on the date that we
         received the update. This takes into account time window and threshold period.
@@ -516,7 +545,9 @@ class SubscriptionProcessor:
             )
         )
 
-    def trigger_alert_threshold(self, trigger, metric_value):
+    def trigger_alert_threshold(
+        self, trigger: AlertRuleTrigger, metric_value: float
+    ) -> IncidentTrigger:
         """
         Called when a subscription update exceeds the value defined in the
         `trigger.alert_threshold`, and the trigger hasn't already been activated.
@@ -565,7 +596,7 @@ class SubscriptionProcessor:
             self.trigger_alert_counts[trigger.id] = 0
             return incident_trigger
 
-    def check_triggers_resolved(self):
+    def check_triggers_resolved(self) -> bool:
         """
         Determines whether all triggers associated with the active incident are
         resolved. A trigger is considered resolved if it is in the
@@ -577,7 +608,9 @@ class SubscriptionProcessor:
                 return False
         return True
 
-    def trigger_resolve_threshold(self, trigger, metric_value):
+    def trigger_resolve_threshold(
+        self, trigger: AlertRuleTrigger, metric_value: float
+    ) -> IncidentTrigger:
         """
         Called when a subscription update exceeds the trigger resolve threshold and the
         trigger is currently ACTIVE.
@@ -605,7 +638,9 @@ class SubscriptionProcessor:
 
             return incident_trigger
 
-    def handle_trigger_actions(self, incident_triggers, metric_value):
+    def handle_trigger_actions(
+        self, incident_triggers: List[IncidentTrigger], metric_value: float
+    ) -> None:
         actions = deduplicate_trigger_actions(triggers=deepcopy(incident_triggers))
         # Grab the first trigger to get incident id (they are all the same)
         # All triggers should either be firing or resolving, so doesn't matter which we grab.
@@ -622,7 +657,7 @@ class SubscriptionProcessor:
                 ).delay
             )
 
-    def handle_incident_severity_update(self):
+    def handle_incident_severity_update(self) -> None:
         if self.active_incident:
             active_incident_triggers = IncidentTrigger.objects.filter(
                 incident=self.active_incident, status=TriggerStatus.ACTIVE.value
@@ -643,7 +678,7 @@ class SubscriptionProcessor:
                     status_method=IncidentStatusMethod.RULE_TRIGGERED,
                 )
 
-    def update_alert_rule_stats(self):
+    def update_alert_rule_stats(self) -> None:
         """
         Updates stats about the alert rule, if they're changed.
         :return:
@@ -668,7 +703,7 @@ class SubscriptionProcessor:
         )
 
 
-def build_alert_rule_stat_keys(alert_rule, subscription):
+def build_alert_rule_stat_keys(alert_rule: AlertRule, subscription: QuerySubscription) -> List[str]:
     """
     Builds keys for fetching stats about alert rules
     :return: A list containing the alert rule stat keys
@@ -677,7 +712,9 @@ def build_alert_rule_stat_keys(alert_rule, subscription):
     return [ALERT_RULE_BASE_STAT_KEY % (key_base, stat_key) for stat_key in ALERT_RULE_STAT_KEYS]
 
 
-def build_trigger_stat_keys(alert_rule, subscription, triggers):
+def build_trigger_stat_keys(
+    alert_rule: AlertRule, subscription: QuerySubscription, triggers: List[AlertRuleTrigger]
+) -> List[str]:
     """
     Builds keys for fetching stats about triggers
     :return: A list containing the alert rule trigger stat keys
@@ -691,23 +728,27 @@ def build_trigger_stat_keys(alert_rule, subscription, triggers):
     ]
 
 
-def build_alert_rule_trigger_stat_key(alert_rule_id, project_id, trigger_id, stat_key):
+def build_alert_rule_trigger_stat_key(
+    alert_rule_id: int, project_id: int, trigger_id: str, stat_key: str
+) -> str:
     key_base = ALERT_RULE_BASE_KEY % (alert_rule_id, project_id)
     return ALERT_RULE_BASE_TRIGGER_STAT_KEY % (key_base, trigger_id, stat_key)
 
 
-def partition(iterable, n):
+def partition(iterable: Sequence[T], n: int) -> Sequence[Sequence[T]]:
     """
     Partitions an iterable into tuples of size n. Expects the iterable length to be a
     multiple of n.
-    partition('ABCDEF', 3) --> [('ABC', 'DEF')]
+    partition('ABCDEF', 3) --> [('A', 'B', 'C'), ('D', 'E', 'F')]
     """
     assert len(iterable) % n == 0
     args = [iter(iterable)] * n
-    return zip(*args)
+    return cast(Sequence[Sequence[T]], zip(*args))
 
 
-def get_alert_rule_stats(alert_rule, subscription, triggers):
+def get_alert_rule_stats(
+    alert_rule: AlertRule, subscription: QuerySubscription, triggers: List[AlertRuleTrigger]
+) -> Tuple[datetime, Dict[str, int], Dict[str, int]]:
     """
     Fetches stats about the alert rule, specific to the current subscription
     :return: A tuple containing the stats about the alert rule and subscription.
@@ -727,6 +768,7 @@ def get_alert_rule_stats(alert_rule, subscription, triggers):
     trigger_results = results[1:]
     trigger_alert_counts = {}
     trigger_resolve_counts = {}
+
     for trigger, trigger_result in zip(
         triggers, partition(trigger_results, len(ALERT_RULE_TRIGGER_STAT_KEYS))
     ):
@@ -736,7 +778,13 @@ def get_alert_rule_stats(alert_rule, subscription, triggers):
     return last_update, trigger_alert_counts, trigger_resolve_counts
 
 
-def update_alert_rule_stats(alert_rule, subscription, last_update, alert_counts, resolve_counts):
+def update_alert_rule_stats(
+    alert_rule: AlertRule,
+    subscription: QuerySubscription,
+    last_update: datetime,
+    alert_counts: Dict[str, int],
+    resolve_counts: Dict[str, int],
+) -> None:
     """
     Updates stats about the alert rule, subscription and triggers if they've changed.
     """
@@ -758,6 +806,6 @@ def update_alert_rule_stats(alert_rule, subscription, last_update, alert_counts,
     pipeline.execute()
 
 
-def get_redis_client():
+def get_redis_client() -> RetryingRedisCluster:
     cluster_key = getattr(settings, "SENTRY_INCIDENT_RULES_REDIS_CLUSTER", "default")
     return redis.redis_clusters.get(cluster_key)