|
@@ -4,8 +4,11 @@ import logging
|
|
|
import threading
|
|
|
import weakref
|
|
|
from contextlib import contextmanager
|
|
|
+from enum import IntEnum, auto
|
|
|
from typing import (
|
|
|
Any,
|
|
|
+ Callable,
|
|
|
+ Dict,
|
|
|
Generator,
|
|
|
Generic,
|
|
|
Iterable,
|
|
@@ -14,6 +17,7 @@ from typing import (
|
|
|
Optional,
|
|
|
Sequence,
|
|
|
Tuple,
|
|
|
+ Type,
|
|
|
)
|
|
|
|
|
|
from django.conf import settings
|
|
@@ -36,6 +40,15 @@ _local_cache_generation = 0
|
|
|
_local_cache_enabled = False
|
|
|
|
|
|
|
|
|
+class ModelManagerTriggerCondition(IntEnum):
|
|
|
+ QUERY = auto()
|
|
|
+ SAVE = auto()
|
|
|
+ DELETE = auto()
|
|
|
+
|
|
|
+
|
|
|
+ModelManagerTriggerAction = Callable[[Type[Model]], None]
|
|
|
+
|
|
|
+
|
|
|
class BaseManager(DjangoBaseManager.from_queryset(BaseQuerySet), Generic[M]): # type: ignore
|
|
|
lookup_handlers = {"iexact": lambda x: x.upper()}
|
|
|
use_for_related_fields = True
|
|
@@ -53,6 +66,10 @@ class BaseManager(DjangoBaseManager.from_queryset(BaseQuerySet), Generic[M]): #
|
|
|
self.cache_ttl = kwargs.pop("cache_ttl", 60 * 5)
|
|
|
self._cache_version: Optional[str] = kwargs.pop("cache_version", None)
|
|
|
self.__local_cache = threading.local()
|
|
|
+
|
|
|
+ self._triggers: Dict[
|
|
|
+ object, Tuple[ModelManagerTriggerCondition, ModelManagerTriggerAction]
|
|
|
+ ] = {}
|
|
|
super().__init__(*args, **kwargs)
|
|
|
|
|
|
@staticmethod
|
|
@@ -195,6 +212,8 @@ class BaseManager(DjangoBaseManager.from_queryset(BaseQuerySet), Generic[M]): #
|
|
|
|
|
|
self.__cache_state(instance)
|
|
|
|
|
|
+ self._execute_triggers(ModelManagerTriggerCondition.SAVE)
|
|
|
+
|
|
|
def __post_delete(self, instance: M, **kwargs: Any) -> None:
|
|
|
"""
|
|
|
Drops instance from all cache storages.
|
|
@@ -213,6 +232,8 @@ class BaseManager(DjangoBaseManager.from_queryset(BaseQuerySet), Generic[M]): #
|
|
|
key=self.__get_lookup_cache_key(**{pk_name: instance.pk}), version=self.cache_version
|
|
|
)
|
|
|
|
|
|
+ self._execute_triggers(ModelManagerTriggerCondition.DELETE)
|
|
|
+
|
|
|
def __get_lookup_cache_key(self, **kwargs: Any) -> str:
|
|
|
return make_key(self.model, "modelcache", kwargs)
|
|
|
|
|
@@ -448,10 +469,44 @@ class BaseManager(DjangoBaseManager.from_queryset(BaseQuerySet), Generic[M]): #
|
|
|
Returns a new QuerySet object. Subclasses can override this method to
|
|
|
easily customize the behavior of the Manager.
|
|
|
"""
|
|
|
+
|
|
|
+ # TODO: This is a quick-and-dirty place to put the trigger hook that won't
|
|
|
+ # work for all model classes, because some custom managers override
|
|
|
+ # get_queryset without a `super` call.
|
|
|
+ self._execute_triggers(ModelManagerTriggerCondition.QUERY)
|
|
|
+
|
|
|
if hasattr(self, "_hints"):
|
|
|
return self._queryset_class(self.model, using=self._db, hints=self._hints)
|
|
|
return self._queryset_class(self.model, using=self._db)
|
|
|
|
|
|
+ @contextmanager
|
|
|
+ def register_trigger(
|
|
|
+ self, condition: ModelManagerTriggerCondition, action: ModelManagerTriggerAction
|
|
|
+ ) -> Generator[None, None, None]:
|
|
|
+ """Register a callback for when an operation is executed inside the context.
|
|
|
+
|
|
|
+ There is no guarantee whether the action will be called before or after the
|
|
|
+ triggering operation is executed, nor whether it will or will not be called
|
|
|
+ if the triggering operation raises an exception.
|
|
|
+
|
|
|
+ Both the registration of the trigger and the execution of the action are NOT
|
|
|
+ THREADSAFE. This is intended for offline use in single-threaded contexts such
|
|
|
+ as pytest. We must add synchronization if we intend to adapt it for
|
|
|
+ production use.
|
|
|
+ """
|
|
|
+
|
|
|
+ key = object()
|
|
|
+ self._triggers[key] = (condition, action)
|
|
|
+ try:
|
|
|
+ yield
|
|
|
+ finally:
|
|
|
+ del self._triggers[key]
|
|
|
+
|
|
|
+ def _execute_triggers(self, condition: ModelManagerTriggerCondition) -> None:
|
|
|
+ for (next_condition, next_action) in self._triggers.values():
|
|
|
+ if condition == next_condition:
|
|
|
+ next_action(self.model)
|
|
|
+
|
|
|
def create_silo_limited_copy(
|
|
|
self, limit: SiloLimit, read_modes: Iterable[SiloMode]
|
|
|
) -> BaseManager[M]:
|