|
@@ -58,6 +58,10 @@ def _model_silo_limit(t: type[Model]) -> ModelSiloLimit:
|
|
|
return silo_limit
|
|
|
|
|
|
|
|
|
+class AncestorAlreadySiloDecoratedException(Exception):
|
|
|
+ pass
|
|
|
+
|
|
|
+
|
|
|
class SiloModeTestDecorator:
|
|
|
"""Decorate a test case that is expected to work in a given silo mode.
|
|
|
|
|
@@ -199,6 +203,11 @@ class SiloModeTestDecorator:
|
|
|
if not (is_test_case_class or is_function):
|
|
|
raise ValueError("@SiloModeTest must decorate a function or TestCase class")
|
|
|
|
|
|
+ if is_test_case_class:
|
|
|
+ self._validate_that_no_ancestor_is_silo_decorated(decorated_obj)
|
|
|
+ # _silo_modes is used to mark the class as silo decorated in the above validation
|
|
|
+ decorated_obj._silo_modes = self.silo_modes
|
|
|
+
|
|
|
# Only run non monolith tests when they are marked stable or we are explicitly running for that mode.
|
|
|
if not (stable or settings.FORCE_SILOED_TESTS):
|
|
|
# In this case, simply force the current silo mode (monolith)
|
|
@@ -209,6 +218,20 @@ class SiloModeTestDecorator:
|
|
|
|
|
|
return self._mark_parameterized_by_silo_mode(decorated_obj, regions)
|
|
|
|
|
|
+ def _validate_that_no_ancestor_is_silo_decorated(self, object_to_validate: Any):
|
|
|
+ class_queue = [object_to_validate]
|
|
|
+
|
|
|
+ # Do a breadth-first traversal of all base classes to ensure that the
|
|
|
+ # object does not inherit from a class which has already been decorated,
|
|
|
+ # even in multi-inheritance scenarios.
|
|
|
+ while len(class_queue) > 0:
|
|
|
+ current_class = class_queue.pop(0)
|
|
|
+ if getattr(current_class, "_silo_modes", None):
|
|
|
+ raise AncestorAlreadySiloDecoratedException(
|
|
|
+ f"Cannot decorate class '{object_to_validate.__name__}', which inherits from a silo decorated class"
|
|
|
+ )
|
|
|
+ class_queue.extend(current_class.__bases__)
|
|
|
+
|
|
|
|
|
|
all_silo_test = SiloModeTestDecorator(SiloMode.CONTROL, SiloMode.REGION)
|
|
|
"""
|