Browse Source

refs(api): Add type hints to cursors.py and any consuming files (#34391)

This adds type hints to cursors. I'll probably follow up by adding type hints to the
paginators as well.
Dan Fuller 2 years ago
parent
commit
b085e7f07d

+ 1 - 0
mypy.ini

@@ -89,6 +89,7 @@ files = src/sentry/analytics/,
         src/sentry/utils/avatar.py,
         src/sentry/utils/codecs.py,
         src/sentry/utils/committers.py,
+        src/sentry/utils/cursors.py,
         src/sentry/utils/dates.py,
         src/sentry/utils/email/,
         src/sentry/utils/jwt.py,

+ 4 - 2
src/sentry/api/helpers/group_index/index.py

@@ -185,11 +185,13 @@ def prep_search(
     request: Request,
     project: "Project",
     extra_query_kwargs: Optional[Mapping[str, Any]] = None,
-) -> Tuple[CursorResult, Mapping[str, Any]]:
+) -> Tuple[CursorResult[Group], Mapping[str, Any]]:
     try:
         environment = cls._get_environment_from_request(request, project.organization_id)
     except Environment.DoesNotExist:
-        result = CursorResult([], None, None, hits=0, max_hits=SEARCH_MAX_HITS)
+        result = CursorResult[Group](
+            [], Cursor(0, 0, 0), Cursor(0, 0, 0), hits=0, max_hits=SEARCH_MAX_HITS
+        )
         query_kwargs: MutableMapping[str, Any] = {}
     else:
         environments = [environment] if environment is not None else environment

+ 11 - 6
src/sentry/rules/history/backends/postgres.py

@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from datetime import datetime, timedelta
-from typing import TYPE_CHECKING, Sequence, TypedDict
+from typing import TYPE_CHECKING, Sequence, TypedDict, cast
 
 import pytz
 from django.db.models import Count, Max
@@ -10,10 +10,11 @@ from django.db.models.functions import TruncHour
 from sentry.api.paginator import OffsetPaginator
 from sentry.models import Group, RuleFireHistory
 from sentry.rules.history.base import RuleGroupHistory, RuleHistoryBackend, TimeSeriesValue
+from sentry.utils.cursors import CursorResult
 
 if TYPE_CHECKING:
     from sentry.models import Rule
-    from sentry.utils.cursors import Cursor, CursorResult
+    from sentry.utils.cursors import Cursor
 
 
 class _Result(TypedDict):
@@ -40,7 +41,7 @@ class PostgresRuleHistoryBackend(RuleHistoryBackend):
         end: datetime,
         cursor: Cursor | None = None,
         per_page: int = 25,
-    ) -> CursorResult:
+    ) -> CursorResult[Group]:
         qs = (
             RuleFireHistory.objects.filter(
                 rule=rule,
@@ -51,9 +52,13 @@ class PostgresRuleHistoryBackend(RuleHistoryBackend):
             .values("group")
             .annotate(count=Count("id"), last_triggered=Max("date_added"))
         )
-        return OffsetPaginator(
-            qs, order_by=("-count", "-last_triggered"), on_results=convert_results
-        ).get_result(per_page, cursor)
+        # TODO: Add types to paginators and remove this
+        return cast(
+            CursorResult[Group],
+            OffsetPaginator(
+                qs, order_by=("-count", "-last_triggered"), on_results=convert_results
+            ).get_result(per_page, cursor),
+        )
 
     def fetch_rule_hourly_stats(
         self, rule: Rule, start: datetime, end: datetime

+ 1 - 1
src/sentry/rules/history/base.py

@@ -39,7 +39,7 @@ class RuleHistoryBackend(Service):
 
     def fetch_rule_groups_paginated(
         self, rule: Rule, start: datetime, end: datetime, cursor: Cursor, per_page: int
-    ) -> CursorResult:
+    ) -> CursorResult[Group]:
         """
         Fetches groups that triggered a rule within a given timeframe, ordered by number of
         times each group fired.

+ 2 - 2
src/sentry/search/base.py

@@ -9,7 +9,7 @@ ANY = object()
 
 if TYPE_CHECKING:
     from sentry.api.event_search import SearchFilter
-    from sentry.models import Environment, Project
+    from sentry.models import Environment, Group, Project
     from sentry.utils.cursors import Cursor, CursorResult
 
 
@@ -34,5 +34,5 @@ class SearchBackend(Service):
         date_from: Optional[datetime] = None,
         date_to: Optional[datetime] = None,
         max_hits: Optional[int] = None,
-    ) -> CursorResult:
+    ) -> CursorResult[Group]:
         raise NotImplementedError

+ 1 - 1
src/sentry/search/snuba/backend.py

@@ -357,7 +357,7 @@ class SnubaSearchBackendBase(SearchBackend, metaclass=ABCMeta):
         date_from: Optional[datetime] = None,
         date_to: Optional[datetime] = None,
         max_hits: Optional[int] = None,
-    ) -> CursorResult:
+    ) -> CursorResult[Group]:
         search_filters = search_filters if search_filters is not None else []
 
         # ensure projects are from same org

+ 19 - 13
src/sentry/search/snuba/executors.py

@@ -6,7 +6,7 @@ from abc import ABCMeta, abstractmethod
 from dataclasses import replace
 from datetime import datetime, timedelta
 from hashlib import md5
-from typing import Any, List, Mapping, Sequence, Set, Tuple
+from typing import Any, List, Mapping, Sequence, Set, Tuple, cast
 
 import sentry_sdk
 from django.utils import timezone
@@ -77,8 +77,9 @@ class AbstractQueryExecutor(metaclass=ABCMeta):
         raise NotImplementedError
 
     @property
-    def empty_result(self) -> CursorResult:
-        return Paginator(Group.objects.none()).get_result()
+    def empty_result(self) -> CursorResult[Group]:
+        # TODO: Add types to paginators and remove this
+        return cast(CursorResult[Group], Paginator(Group.objects.none()).get_result())
 
     @property
     @abstractmethod
@@ -105,14 +106,14 @@ class AbstractQueryExecutor(metaclass=ABCMeta):
         environments: Optional[Sequence[Environment]],
         sort_by: str,
         limit: int,
-        cursor: Cursor,
+        cursor: Cursor | None,
         count_hits: bool,
         paginator_options: Optional[Mapping[str, Any]],
         search_filters: Optional[Sequence[SearchFilter]],
         date_from: Optional[datetime],
         date_to: Optional[datetime],
         max_hits: Optional[int] = None,
-    ) -> CursorResult:
+    ) -> CursorResult[Group]:
         """This function runs your actual query and returns the results
         We usually return a paginator object, which contains the results and the number of hits"""
         raise NotImplementedError
@@ -329,14 +330,14 @@ class PostgresSnubaQueryExecutor(AbstractQueryExecutor):
         environments: Optional[Sequence[Environment]],
         sort_by: str,
         limit: int,
-        cursor: Cursor,
+        cursor: Cursor | None,
         count_hits: bool,
         paginator_options: Optional[Mapping[str, Any]],
         search_filters: Optional[Sequence[SearchFilter]],
         date_from: Optional[datetime],
         date_to: Optional[datetime],
         max_hits: Optional[int] = None,
-    ) -> CursorResult:
+    ) -> CursorResult[Group]:
 
         now = timezone.now()
         end = None
@@ -367,7 +368,12 @@ class PostgresSnubaQueryExecutor(AbstractQueryExecutor):
                 group_queryset = group_queryset.order_by("-last_seen")
                 paginator = DateTimePaginator(group_queryset, "-last_seen", **paginator_options)
                 # When its a simple django-only search, we count_hits like normal
-                return paginator.get_result(limit, cursor, count_hits=count_hits, max_hits=max_hits)
+
+                # TODO: Add types to paginators and remove this
+                return cast(
+                    CursorResult[Group],
+                    paginator.get_result(limit, cursor, count_hits=count_hits, max_hits=max_hits),
+                )
 
         # TODO: Presumably we only want to search back to the project's max
         # retention date, which may be closer than 90 days in the past, but
@@ -572,7 +578,7 @@ class PostgresSnubaQueryExecutor(AbstractQueryExecutor):
         environments: Sequence[Environment],
         sort_by: str,
         limit: int,
-        cursor: Cursor,
+        cursor: Cursor | None,
         count_hits: bool,
         paginator_options: Mapping[str, Any],
         search_filters: Sequence[SearchFilter],
@@ -649,8 +655,6 @@ class PostgresSnubaQueryExecutor(AbstractQueryExecutor):
                 hits = int(hit_ratio * snuba_total)
                 return hits
 
-        return None
-
 
 class InvalidQueryForExecutor(Exception):
     pass
@@ -771,7 +775,7 @@ class CdcPostgresSnubaQueryExecutor(PostgresSnubaQueryExecutor):
         date_from: Optional[datetime],
         date_to: Optional[datetime],
         max_hits: Optional[int] = None,
-    ) -> CursorResult:
+    ) -> CursorResult[Group]:
 
         if not validate_cdc_search_filters(search_filters):
             raise InvalidQueryForExecutor("Search filters invalid for this query executor")
@@ -868,4 +872,6 @@ class CdcPostgresSnubaQueryExecutor(PostgresSnubaQueryExecutor):
         # probably not be noticeable to the user, so holding off for now to reduce complexity.
         groups = group_queryset.in_bulk(paginator_results.results)
         paginator_results.results = [groups[k] for k in paginator_results.results if k in groups]
-        return paginator_results
+
+        # TODO: Add types to paginators and remove this
+        return cast(CursorResult[Group], paginator_results)

+ 76 - 39
src/sentry/utils/cursors.py

@@ -1,23 +1,44 @@
-from collections.abc import Sequence
+from __future__ import annotations
+
+from typing import Any, Callable, Iterator, Protocol, Sequence, Tuple, TypeVar, Union
+
+from sentry.utils.json import JSONData
+
+T = TypeVar("T")
+CursorValue = Union[float, int, str]
+
+
+class KeyCallable(Protocol):
+    def __call__(self, value: T, for_prev: bool = ...) -> CursorValue:
+        ...
+
+
+OnResultCallable = Callable[[Sequence[T]], JSONData]
 
 
 class Cursor:
-    def __init__(self, value, offset=0, is_prev=False, has_results=None):
-        self.value = value
+    def __init__(
+        self,
+        value: CursorValue,
+        offset: int = 0,
+        is_prev: bool | int = False,
+        has_results: bool | None = None,
+    ):
+        self.value: CursorValue = value
         self.offset = int(offset)
         self.is_prev = bool(is_prev)
         self.has_results = has_results
 
-    def __str__(self):
+    def __str__(self) -> str:
         return f"{self.value}:{self.offset}:{int(self.is_prev)}"
 
-    def __eq__(self, other):
+    def __eq__(self, other: object) -> bool:
         return all(
             getattr(self, attr) == getattr(other, attr)
             for attr in ("value", "offset", "is_prev", "has_results")
         )
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return "<{}: value={} offset={} is_prev={}>".format(
             type(self).__name__,
             self.value,
@@ -25,68 +46,75 @@ class Cursor:
             int(self.is_prev),
         )
 
-    def __bool__(self):
+    def __bool__(self) -> bool:
         return bool(self.has_results)
 
     # python2 compatibility
     __nonzero__ = __bool__
 
     @classmethod
-    def from_string(cls, value):
-        bits = value.split(":")
+    def from_string(cls, cursor_str: str) -> Cursor:
+        bits = cursor_str.split(":")
         if len(bits) != 3:
             raise ValueError
         try:
             value = float(bits[0]) if "." in bits[0] else int(bits[0])
-            bits = value, int(bits[1]), int(bits[2])
+            return Cursor(value, int(bits[1]), int(bits[2]))
         except (TypeError, ValueError):
             raise ValueError
-        return cls(*bits)
 
 
 class SCIMCursor(Cursor):
     @classmethod
-    def from_string(cls, value):
+    def from_string(cls, value: str) -> SCIMCursor:
         # SCIM cursors are 1 indexed
-        return cls(0, int(value) - 1, 0)
+        return SCIMCursor(0, int(value) - 1, 0)
 
 
 class StringCursor(Cursor):
     @classmethod
-    def from_string(cls, value):
-        bits = value.rsplit(":", 2)
+    def from_string(cls, cursor_str: str) -> StringCursor:
+        bits = cursor_str.rsplit(":", 2)
         if len(bits) != 3:
             raise ValueError
         try:
             value = bits[0]
-            bits = value, int(bits[1]), int(bits[2])
+            return StringCursor(value, int(bits[1]), int(bits[2]))
         except (TypeError, ValueError):
             raise ValueError
-        return cls(*bits)
 
 
-class CursorResult(Sequence):
-    def __init__(self, results, next, prev, hits=None, max_hits=None):
+class CursorResult(Sequence[T]):
+    def __init__(
+        self,
+        results: Sequence[T],
+        next: Cursor,
+        prev: Cursor,
+        hits: int | None = None,
+        max_hits: int | None = None,
+    ):
         self.results = results
         self.next = next
         self.prev = prev
         self.hits = hits
         self.max_hits = max_hits
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.results)
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[T]:
         return iter(self.results)
 
-    def __getitem__(self, key):
-        return self.results[key]
-
-    def __repr__(self):
+    def __repr__(self) -> str:
         return f"<{type(self).__name__}: results={len(self.results)}>"
 
+    def __getitem__(self, key: Any) -> Any:
+        return self.results[key]
+
 
-def _build_next_values(cursor, results, key, limit, is_desc):
+def _build_next_values(
+    cursor: Cursor, results: Sequence[T], key: KeyCallable, limit: int, is_desc: bool
+) -> Tuple[CursorValue, int, bool]:
     value = cursor.value
     offset = cursor.offset
     is_prev = cursor.is_prev
@@ -114,7 +142,7 @@ def _build_next_values(cursor, results, key, limit, is_desc):
     # value has not changed, page forward by adjusting the offset
     if next_value == value:
         next_offset = offset + limit
-        return (next_value, next_offset, has_next)
+        return next_value, next_offset, has_next
 
     # We have an absolute value to page from. If any of the items in
     # the current result set come *after* or *before* (depending on the
@@ -134,18 +162,20 @@ def _build_next_values(cursor, results, key, limit, is_desc):
     for result in result_iter:
         result_value = key(result)
 
-        is_larger = result_value >= next_value
-        is_smaller = result_value <= next_value
+        is_larger = result_value >= next_value  # type: ignore
+        is_smaller = result_value <= next_value  # type: ignore
 
         if (is_desc and is_smaller) or (not is_desc and is_larger):
             next_offset += 1
         else:
             break
 
-    return (next_value, next_offset, has_next)
+    return next_value, next_offset, has_next
 
 
-def _build_prev_values(cursor, results, key, limit, is_desc):
+def _build_prev_values(
+    cursor: Cursor, results: Sequence[T], key: KeyCallable, limit: int, is_desc: bool
+) -> Tuple[CursorValue, int, bool]:
     value = cursor.value
     offset = cursor.offset
     is_prev = cursor.is_prev
@@ -157,7 +187,7 @@ def _build_prev_values(cursor, results, key, limit, is_desc):
     else:
         # It's likely that there's a previous page if they passed us either
         # offset values
-        has_prev = value or offset
+        has_prev = bool(value or offset)
 
     # If the cursor contains previous results, the first item is the item that
     # indicates if we have more items later, and is *not* the first item in the
@@ -174,12 +204,12 @@ def _build_prev_values(cursor, results, key, limit, is_desc):
     prev_offset = offset if is_prev else 0
 
     if not (is_prev and num_results):
-        return (prev_value, prev_offset, has_prev)
+        return prev_value, prev_offset, has_prev
 
     # Value has not changed, page back by adjusting the offset
     if prev_value == value:
         prev_offset = offset + limit
-        return (prev_value, prev_offset, has_prev)
+        return prev_value, prev_offset, has_prev
 
     # Just as in the next cursor builder, we may need to add an offset
     # if any of the results at the beginning are *before* or *after*
@@ -201,8 +231,8 @@ def _build_prev_values(cursor, results, key, limit, is_desc):
     for result in result_iter:
         result_value = key(result, for_prev=True)
 
-        is_larger = result_value >= prev_value
-        is_smaller = result_value <= prev_value
+        is_larger = result_value >= prev_value  # type: ignore
+        is_smaller = result_value <= prev_value  # type: ignore
 
         # Note that the checks are reversed here as a prev query has
         # it's ordering reversed.
@@ -211,12 +241,19 @@ def _build_prev_values(cursor, results, key, limit, is_desc):
         else:
             break
 
-    return (prev_value, prev_offset, has_prev)
+    return prev_value, prev_offset, has_prev
 
 
 def build_cursor(
-    results, key, limit=100, is_desc=False, cursor=None, hits=None, max_hits=None, on_results=None
-):
+    results: Sequence[T],
+    key: KeyCallable,
+    limit: int = 100,
+    is_desc: bool = False,
+    cursor: Cursor | None = None,
+    hits: int | None = None,
+    max_hits: int | None = None,
+    on_results: None | OnResultCallable[T] = None,
+) -> CursorResult[T | JSONData]:
     if cursor is None:
         cursor = Cursor(0, 0, 0)