|
@@ -1,11 +1,12 @@
|
|
|
from __future__ import annotations
|
|
|
|
|
|
-from typing import Any, Callable, Iterable, Mapping, Optional, Tuple, Type, TypeVar
|
|
|
+from typing import Any, Callable, Iterable, Mapping, TypeVar
|
|
|
|
|
|
from django.apps.config import AppConfig
|
|
|
from django.db import models
|
|
|
from django.db.models import signals
|
|
|
from django.utils import timezone
|
|
|
+from typing_extensions import Self
|
|
|
|
|
|
from sentry.backup.dependencies import ImportKind, PrimaryKeyMap, dependencies, get_model_name
|
|
|
from sentry.backup.helpers import ImportFlags
|
|
@@ -13,7 +14,7 @@ from sentry.backup.scopes import ImportScope, RelocationScope
|
|
|
from sentry.silo import SiloLimit, SiloMode
|
|
|
|
|
|
from .fields.bounded import BoundedBigAutoField
|
|
|
-from .manager import BaseManager, M
|
|
|
+from .manager import BaseManager
|
|
|
from .manager.base import create_silo_limited_copy
|
|
|
from .query import update
|
|
|
|
|
@@ -48,7 +49,7 @@ class BaseModel(models.Model):
|
|
|
|
|
|
__relocation_scope__: RelocationScope | set[RelocationScope]
|
|
|
|
|
|
- objects = BaseManager[M]() # type: ignore
|
|
|
+ objects: BaseManager[Self] = BaseManager()
|
|
|
|
|
|
update = update
|
|
|
|
|
@@ -67,20 +68,28 @@ class BaseModel(models.Model):
|
|
|
|
|
|
def __reduce__(
|
|
|
self,
|
|
|
- ) -> Tuple[Callable[[int], models.Model], Tuple[Tuple[str, str]], Mapping[str, Any]]:
|
|
|
- (model_unpickle, stuff, _) = super().__reduce__()
|
|
|
+ ) -> tuple[Callable[[int], models.Model], tuple[tuple[str, str]], Mapping[str, Any]]:
|
|
|
+ reduced = super().__reduce__()
|
|
|
+ assert isinstance(reduced, tuple), reduced
|
|
|
+ (model_unpickle, stuff, _) = reduced
|
|
|
return (model_unpickle, stuff, self.__getstate__())
|
|
|
|
|
|
def __setstate__(self, state: Mapping[str, Any]) -> None:
|
|
|
self.__dict__.update(state)
|
|
|
|
|
|
+ def _get_relational_field(self, field_name: str) -> models.ForeignKey:
|
|
|
+ ret = self._meta.get_field(field_name)
|
|
|
+ if not isinstance(ret, models.ForeignKey):
|
|
|
+ raise TypeError(f"expected {field_name=} to be ForeignKey")
|
|
|
+ return ret
|
|
|
+
|
|
|
def set_cached_field_value(self, field_name: str, value: Any) -> None:
|
|
|
# Explicitly set a field's cached value.
|
|
|
# This only works for relational fields, and is useful when
|
|
|
# you already have the value and can therefore use this
|
|
|
# to populate Django's cache before accessing the attribute
|
|
|
# and triggering a duplicate, unnecessary query.
|
|
|
- self._meta.get_field(field_name).set_cached_value(self, value)
|
|
|
+ self._get_relational_field(field_name).set_cached_value(self, value)
|
|
|
|
|
|
def get_cached_field_value(self, field_name: str) -> Any:
|
|
|
# Get a relational field's cached value.
|
|
@@ -88,18 +97,18 @@ class BaseModel(models.Model):
|
|
|
# for when you would like to inspect the cache.
|
|
|
# In production, you should guard `model.field` with an
|
|
|
# `if model.is_field_cached`.
|
|
|
- name = self._meta.get_field(field_name).get_cache_name()
|
|
|
- return self._state.fields_cache.get(name, None)
|
|
|
+ name = self._get_relational_field(field_name).get_cache_name()
|
|
|
+ return self._state.fields_cache.get(name, None) # type: ignore[attr-defined] # typeddjango/django-stubs#1743
|
|
|
|
|
|
def delete_cached_field_value(self, field_name: str) -> None:
|
|
|
- name = self._meta.get_field(field_name).get_cache_name()
|
|
|
- if name in self._state.fields_cache:
|
|
|
- del self._state.fields_cache[name]
|
|
|
+ name = self._get_relational_field(field_name).get_cache_name()
|
|
|
+ if name in self._state.fields_cache: # type: ignore[operator] # typeddjango/django-stubs#1743
|
|
|
+ del self._state.fields_cache[name] # type: ignore[attr-defined] # typeddjango/django-stubs#1743
|
|
|
|
|
|
def is_field_cached(self, field_name: str) -> bool:
|
|
|
# Ask if a relational field has a cached value.
|
|
|
- name = self._meta.get_field(field_name).get_cache_name()
|
|
|
- return name in self._state.fields_cache
|
|
|
+ name = self._get_relational_field(field_name).get_cache_name()
|
|
|
+ return name in self._state.fields_cache # type: ignore[operator] # typeddjango/django-stubs#1743
|
|
|
|
|
|
def get_relocation_scope(self) -> RelocationScope:
|
|
|
"""
|
|
@@ -114,7 +123,7 @@ class BaseModel(models.Model):
|
|
|
return self.__relocation_scope__
|
|
|
|
|
|
@classmethod
|
|
|
- def get_possible_relocation_scopes(cls) -> RelocationScope:
|
|
|
+ def get_possible_relocation_scopes(cls) -> set[RelocationScope]:
|
|
|
"""
|
|
|
Retrieves the `RelocationScope` for a `Model` subclass. It always returns a set, to account for models that support multiple scopes on a situational, per-instance basis.
|
|
|
"""
|
|
@@ -127,7 +136,7 @@ class BaseModel(models.Model):
|
|
|
|
|
|
def normalize_before_relocation_import(
|
|
|
self, pk_map: PrimaryKeyMap, _s: ImportScope, _f: ImportFlags
|
|
|
- ) -> Optional[int]:
|
|
|
+ ) -> int | None:
|
|
|
"""
|
|
|
A helper function that normalizes a deserialized model. Note that this modifies the model in
|
|
|
place, so it should generally be done immediately prior to a companion
|
|
@@ -165,7 +174,7 @@ class BaseModel(models.Model):
|
|
|
|
|
|
def write_relocation_import(
|
|
|
self, _s: ImportScope, _f: ImportFlags
|
|
|
- ) -> Optional[Tuple[int, ImportKind]]:
|
|
|
+ ) -> tuple[int, ImportKind] | None:
|
|
|
"""
|
|
|
Writes a deserialized model to the database. If this write is successful, this method will
|
|
|
return a tuple of the new `pk` and the `ImportKind` (ie, whether we created a new model or
|
|
@@ -246,7 +255,7 @@ signals.post_save.connect(__model_post_save)
|
|
|
signals.class_prepared.connect(__model_class_prepared)
|
|
|
|
|
|
|
|
|
-def get_model_if_available(app_config: AppConfig, model_name: str) -> Type[models.Model] | None:
|
|
|
+def get_model_if_available(app_config: AppConfig, model_name: str) -> type[models.Model] | None:
|
|
|
"""Get a named model class if it exists and is available in this silo mode."""
|
|
|
try:
|
|
|
model = app_config.get_model(model_name)
|
|
@@ -303,10 +312,14 @@ class ModelSiloLimit(SiloLimit):
|
|
|
|
|
|
return handle
|
|
|
|
|
|
- def __call__(self, model_class: Type[ModelClass]) -> Type[ModelClass]:
|
|
|
+ def _check_type(self, model_class: object) -> None:
|
|
|
+ # split out so mypy doesn't erroneously narrow
|
|
|
if not (isinstance(model_class, type) and issubclass(model_class, models.Model)):
|
|
|
raise TypeError("`@ModelSiloLimit ` must decorate a Model class")
|
|
|
|
|
|
+ def __call__(self, model_class: type[ModelClass]) -> type[ModelClass]:
|
|
|
+ self._check_type(model_class)
|
|
|
+
|
|
|
setattr(
|
|
|
model_class,
|
|
|
"objects",
|