Browse Source

ref: type sentry.db.models.base (#57135)

<!-- Describe your PR here. -->
anthony sottile 1 year ago
parent
commit
5f9916dd61

+ 0 - 1
pyproject.toml

@@ -372,7 +372,6 @@ module = [
     "sentry.auth.system",
     "sentry.auth.view",
     "sentry.db.mixin",
-    "sentry.db.models.base",
     "sentry.db.models.manager.base",
     "sentry.db.models.paranoia",
     "sentry.db.models.query",

+ 1 - 1
src/sentry/backup/dependencies.py

@@ -98,7 +98,7 @@ class NormalizedModelName:
         return self.__model_name
 
 
-def get_model_name(model: Type[models.base.Model]) -> NormalizedModelName:
+def get_model_name(model: type[models.Model] | models.Model) -> NormalizedModelName:
     return NormalizedModelName(f"{model._meta.app_label}.{model._meta.object_name}")
 
 

+ 1 - 1
src/sentry/backup/exports.py

@@ -126,7 +126,7 @@ def _export(
                 continue
 
             possible_relocation_scopes = model.get_possible_relocation_scopes()
-            includable = possible_relocation_scopes & allowed_relocation_scopes  # type: ignore
+            includable = possible_relocation_scopes & allowed_relocation_scopes
             if not includable or model._meta.proxy:
                 continue
 

+ 31 - 18
src/sentry/db/models/base.py

@@ -1,11 +1,12 @@
 from __future__ import annotations
 
-from typing import Any, Callable, Iterable, Mapping, Optional, Tuple, Type, TypeVar
+from typing import Any, Callable, Iterable, Mapping, TypeVar
 
 from django.apps.config import AppConfig
 from django.db import models
 from django.db.models import signals
 from django.utils import timezone
+from typing_extensions import Self
 
 from sentry.backup.dependencies import ImportKind, PrimaryKeyMap, dependencies, get_model_name
 from sentry.backup.helpers import ImportFlags
@@ -13,7 +14,7 @@ from sentry.backup.scopes import ImportScope, RelocationScope
 from sentry.silo import SiloLimit, SiloMode
 
 from .fields.bounded import BoundedBigAutoField
-from .manager import BaseManager, M
+from .manager import BaseManager
 from .manager.base import create_silo_limited_copy
 from .query import update
 
@@ -48,7 +49,7 @@ class BaseModel(models.Model):
 
     __relocation_scope__: RelocationScope | set[RelocationScope]
 
-    objects = BaseManager[M]()  # type: ignore
+    objects: BaseManager[Self] = BaseManager()
 
     update = update
 
@@ -67,20 +68,28 @@ class BaseModel(models.Model):
 
     def __reduce__(
         self,
-    ) -> Tuple[Callable[[int], models.Model], Tuple[Tuple[str, str]], Mapping[str, Any]]:
-        (model_unpickle, stuff, _) = super().__reduce__()
+    ) -> tuple[Callable[[int], models.Model], tuple[tuple[str, str]], Mapping[str, Any]]:
+        reduced = super().__reduce__()
+        assert isinstance(reduced, tuple), reduced
+        (model_unpickle, stuff, _) = reduced
         return (model_unpickle, stuff, self.__getstate__())
 
     def __setstate__(self, state: Mapping[str, Any]) -> None:
         self.__dict__.update(state)
 
+    def _get_relational_field(self, field_name: str) -> models.ForeignKey:
+        ret = self._meta.get_field(field_name)
+        if not isinstance(ret, models.ForeignKey):
+            raise TypeError(f"expected {field_name=} to be ForeignKey")
+        return ret
+
     def set_cached_field_value(self, field_name: str, value: Any) -> None:
         # Explicitly set a field's cached value.
         # This only works for relational fields, and is useful when
         # you already have the value and can therefore use this
         # to populate Django's cache before accessing the attribute
         # and triggering a duplicate, unnecessary query.
-        self._meta.get_field(field_name).set_cached_value(self, value)
+        self._get_relational_field(field_name).set_cached_value(self, value)
 
     def get_cached_field_value(self, field_name: str) -> Any:
         # Get a relational field's cached value.
@@ -88,18 +97,18 @@ class BaseModel(models.Model):
         # for when you would like to inspect the cache.
         # In production, you should guard `model.field` with an
         # `if model.is_field_cached`.
-        name = self._meta.get_field(field_name).get_cache_name()
-        return self._state.fields_cache.get(name, None)
+        name = self._get_relational_field(field_name).get_cache_name()
+        return self._state.fields_cache.get(name, None)  # type: ignore[attr-defined]  # typeddjango/django-stubs#1743
 
     def delete_cached_field_value(self, field_name: str) -> None:
-        name = self._meta.get_field(field_name).get_cache_name()
-        if name in self._state.fields_cache:
-            del self._state.fields_cache[name]
+        name = self._get_relational_field(field_name).get_cache_name()
+        if name in self._state.fields_cache:  # type: ignore[operator]  # typeddjango/django-stubs#1743
+            del self._state.fields_cache[name]  # type: ignore[attr-defined]  # typeddjango/django-stubs#1743
 
     def is_field_cached(self, field_name: str) -> bool:
         # Ask if a relational field has a cached value.
-        name = self._meta.get_field(field_name).get_cache_name()
-        return name in self._state.fields_cache
+        name = self._get_relational_field(field_name).get_cache_name()
+        return name in self._state.fields_cache  # type: ignore[operator]  # typeddjango/django-stubs#1743
 
     def get_relocation_scope(self) -> RelocationScope:
         """
@@ -114,7 +123,7 @@ class BaseModel(models.Model):
         return self.__relocation_scope__
 
     @classmethod
-    def get_possible_relocation_scopes(cls) -> RelocationScope:
+    def get_possible_relocation_scopes(cls) -> set[RelocationScope]:
         """
         Retrieves the `RelocationScope` for a `Model` subclass. It always returns a set, to account for models that support multiple scopes on a situational, per-instance basis.
         """
@@ -127,7 +136,7 @@ class BaseModel(models.Model):
 
     def normalize_before_relocation_import(
         self, pk_map: PrimaryKeyMap, _s: ImportScope, _f: ImportFlags
-    ) -> Optional[int]:
+    ) -> int | None:
         """
         A helper function that normalizes a deserialized model. Note that this modifies the model in
         place, so it should generally be done immediately prior to a companion
@@ -165,7 +174,7 @@ class BaseModel(models.Model):
 
     def write_relocation_import(
         self, _s: ImportScope, _f: ImportFlags
-    ) -> Optional[Tuple[int, ImportKind]]:
+    ) -> tuple[int, ImportKind] | None:
         """
         Writes a deserialized model to the database. If this write is successful, this method will
         return a tuple of the new `pk` and the `ImportKind` (ie, whether we created a new model or
@@ -246,7 +255,7 @@ signals.post_save.connect(__model_post_save)
 signals.class_prepared.connect(__model_class_prepared)
 
 
-def get_model_if_available(app_config: AppConfig, model_name: str) -> Type[models.Model] | None:
+def get_model_if_available(app_config: AppConfig, model_name: str) -> type[models.Model] | None:
     """Get a named model class if it exists and is available in this silo mode."""
     try:
         model = app_config.get_model(model_name)
@@ -303,10 +312,14 @@ class ModelSiloLimit(SiloLimit):
 
         return handle
 
-    def __call__(self, model_class: Type[ModelClass]) -> Type[ModelClass]:
+    def _check_type(self, model_class: object) -> None:
+        # split out so mypy doesn't erroneously narrow
         if not (isinstance(model_class, type) and issubclass(model_class, models.Model)):
             raise TypeError("`@ModelSiloLimit ` must decorate a Model class")
 
+    def __call__(self, model_class: type[ModelClass]) -> type[ModelClass]:
+        self._check_type(model_class)
+
         setattr(
             model_class,
             "objects",

+ 1 - 1
src/sentry/models/user.py

@@ -335,7 +335,7 @@ class User(BaseModel, AbstractBaseUser):
                 organization_id=organization_id, from_user_id=from_user.id, to_user_id=to_user.id
             )
 
-        model_list = (
+        model_list: tuple[type[BaseModel], ...] = (
             Authenticator,
             Identity,
             UserAvatar,