Browse Source

feat(server): Add scripts for applying and checking mode limit decorators (#36227)

Ryan Skonnord 2 years ago
parent
commit
95a913ccfb
2 changed files with 348 additions and 0 deletions
  1. 157 0
      scripts/servermode/add_mode_limits.py
  2. 191 0
      scripts/servermode/audit_mode_limits.py

+ 157 - 0
scripts/servermode/add_mode_limits.py

@@ -0,0 +1,157 @@
+#!.venv/bin/python
+
+from __future__ import annotations
+
+import os
+import re
+import sys
+from dataclasses import dataclass
+from enum import Enum, auto
+from typing import Iterable
+
+from sentry.utils import json
+
+"""
+Instructions for use:
+
+1. Commit or stash any Git changes in progress.
+2. Scroll down to "Fill these predicates in..." and write what you want to do.
+3. From the Sentry project root, do
+     ./scripts/servermode/audit_mode_limits.py | ./scripts/servermode/add_mode_limits.py
+4. Do `git status` or `git diff` to observe the results. Commit if you're happy.
+"""
+
+
+class ClassCategory(Enum):
+    MODEL = auto()
+    VIEW = auto()
+
+
+@dataclass
+class LimitedClass:
+    package: str
+    name: str
+    category: ClassCategory
+    is_decorated: bool
+
+
+def parse_audit(audit) -> Iterable[LimitedClass]:
+    def split_qualname(value):
+        dot_index = value.rindex(".")
+        package = value[:dot_index]
+        name = value[dot_index + 1 :]
+        return package, name
+
+    def parse_group(category, dec_group):
+        is_decorated = dec_group["decorator"] is not None
+        for value in dec_group["values"]:
+            package, name = split_qualname(value)
+            yield LimitedClass(package, name, category, is_decorated)
+
+    for dec_group in audit["models"]["decorators"]:
+        yield from parse_group(ClassCategory.MODEL, dec_group)
+    for dec_group in audit["views"]["decorators"]:
+        yield from parse_group(ClassCategory.VIEW, dec_group)
+
+
+def read_audit():
+    pipe_input = sys.stdin.read()
+    brace_index = pipe_input.index("{")
+    pipe_input = pipe_input[brace_index:]  # strip leading junk
+    server_mode_audit = json.loads(pipe_input)
+    return list(parse_audit(server_mode_audit))
+
+
+def find_source_paths():
+    for (dirpath, dirnames, filenames) in os.walk("./src/sentry"):
+        for filename in filenames:
+            if filename.endswith(".py"):
+                yield os.path.join(dirpath, filename)
+
+
+def find_class_declarations():
+    for src_path in find_source_paths():
+        with open(src_path) as f:
+            src_code = f.read()
+        for match in re.findall(r"\nclass\s+(\w+)\(", src_code):
+            yield src_path, match
+
+
+def insert_import(src_code: str, import_stmt: str) -> str:
+    future_import = None
+    for future_import in re.finditer(r"from\s+__future__\s+.*\n+", src_code):
+        pass  # iterate to last match
+    if future_import:
+        start, end = future_import.span()
+        return src_code[:end] + import_stmt + "\n" + src_code[end:]
+    else:
+        return import_stmt + "\n" + src_code
+
+
+def apply_decorators(
+    decorator_name: str,
+    import_stmt: str,
+    target_classes: Iterable[LimitedClass],
+) -> None:
+    target_names = {c.name for c in target_classes if not c.is_decorated}
+    for src_path, class_name in find_class_declarations():
+        if class_name in target_names:
+            with open(src_path) as f:
+                src_code = f.read()
+            new_code = re.sub(
+                rf"\nclass\s+{class_name}\(",
+                f"\n@{decorator_name}\nclass {class_name}(",
+                src_code,
+            )
+            new_code = insert_import(new_code, import_stmt)
+            with open(src_path, mode="w") as f:
+                f.write(new_code)
+
+
+def main():
+    classes = read_audit()
+
+    def filter_classes(category, predicate):
+        return (c for c in classes if c.category == category and predicate(c))
+
+    ####################################################################
+    # Fill these predicates in with the logic you want to apply
+
+    def control_model_predicate(c: LimitedClass) -> bool:
+        return False
+
+    def customer_model_predicate(c: LimitedClass) -> bool:
+        return False
+
+    def control_endpoint_predicate(c: LimitedClass) -> bool:
+        return False
+
+    def customer_endpoint_predicate(c: LimitedClass) -> bool:
+        return False
+
+    ####################################################################
+
+    apply_decorators(
+        "control_silo_model",
+        "from sentry.db.models import control_silo_model",
+        filter_classes(ClassCategory.MODEL, control_model_predicate),
+    )
+    apply_decorators(
+        "customer_silo_model",
+        "from sentry.db.models import customer_silo_model",
+        filter_classes(ClassCategory.MODEL, customer_model_predicate),
+    )
+    apply_decorators(
+        "control_silo_endpoint",
+        "from sentry.api.base import control_silo_endpoint",
+        filter_classes(ClassCategory.VIEW, control_endpoint_predicate),
+    )
+    apply_decorators(
+        "customer_silo_endpoint",
+        "from sentry.api.base import customer_silo_endpoint",
+        filter_classes(ClassCategory.VIEW, customer_endpoint_predicate),
+    )
+
+
+if __name__ == "__main__":
+    main()

+ 191 - 0
scripts/servermode/audit_mode_limits.py

@@ -0,0 +1,191 @@
+#!/usr/bin/env sentry exec
+
+from __future__ import annotations
+
+import abc
+import json  # noqa - I want the `indent` param
+import sys
+from collections import defaultdict
+
+import django.apps
+import django.urls
+
+
+def audit_mode_limits(format="json"):
+    """Lists which classes have had server mode decorators applied."""
+
+    from sentry.runner import configure
+
+    configure()
+    model_table = create_model_table()
+    view_table = create_view_table()
+
+    if format == "json":
+        json_repr = {
+            "models": ModelPresentation().as_json_repr(model_table),
+            "views": ViewPresentation().as_json_repr(view_table),
+        }
+        json.dump(json_repr, sys.stdout, indent=4)
+    elif format == "markdown":
+        ModelPresentation().print_markdown(model_table)
+        ViewPresentation().print_markdown(view_table)
+    else:
+        raise ValueError
+
+
+def create_model_table():
+    table = defaultdict(list)
+    for model_class in django.apps.apps.get_models():
+        if model_class._meta.app_label != "sentry":
+            continue
+        limit = getattr(model_class._meta, "_ModelAvailableOn__mode_limit", None)
+        key = (limit.modes, limit.read_only) if limit else None
+        table[key].append(model_class)
+    return table
+
+
+def create_view_table():
+    from sentry.api.base import Endpoint
+
+    def is_endpoint(view_function, bindings):
+        view_class = getattr(view_function, "view_class", None)
+        return view_class and issubclass(view_class, Endpoint)
+
+    def get_view_classes():
+        url_mappings = list(django.urls.get_resolver().reverse_dict.items())
+        for (view_function, bindings) in url_mappings:
+            if is_endpoint(view_function, bindings):
+                yield view_function.view_class
+
+    table = defaultdict(list)
+    for view_class in get_view_classes():
+        limit = getattr(view_class, "__mode_limit", None)
+        key = limit.modes if limit else None
+        table[key].append(view_class)
+
+    return table
+
+
+class ConsolePresentation(abc.ABC):
+    @property
+    @abc.abstractmethod
+    def table_label(self):
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def order(self, group):
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def get_group_label(self, key):
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def get_key_repr(self, key):
+        raise NotImplementedError
+
+    @staticmethod
+    def format_mode_set(modes):
+        if modes is None:
+            return None
+        return sorted(str(x) for x in modes)
+
+    @staticmethod
+    def format_value(value):
+        return f"{value.__module__}.{value.__name__}"
+
+    def normalize_table(self, table):
+        return {
+            key: sorted({self.format_value(value) for value in group})
+            for (key, group) in (sorted(table.items(), key=self.order))
+        }
+
+    def as_json_repr(self, table):
+        table = self.normalize_table(table)
+        return {
+            "total_count": sum(len(group) for group in table.values()),
+            "decorators": [
+                {
+                    "decorator": self.get_key_repr(group_key),
+                    "count": len(group),
+                    "values": group,
+                }
+                for group_key, group in table.items()
+            ],
+        }
+
+    def print_markdown(self, table):
+        table = self.normalize_table(table)
+
+        total_count = sum(len(group) for group in table.values())
+        table_header = f"{self.table_label} ({total_count})"
+        print("\n" + table_header)  # noqa
+        print("=" * len(table_header), end="\n\n")  # noqa
+
+        for (group_key, group) in table.items():
+            group_label = self.get_group_label(group_key)
+            group_header = f"{group_label} ({len(group)})"
+            print(group_header)  # noqa
+            print("-" * len(group_header), end="\n\n")  # noqa
+
+            for value in group:
+                print("  - " + value)  # noqa
+            print()  # noqa
+
+
+class ModelPresentation(ConsolePresentation):
+    @property
+    def table_label(self):
+        return "MODELS"
+
+    def order(self, group):
+        group_key, _model_group = group
+        if group_key is None:
+            return ()
+        write_modes, read_modes = group_key
+        return (
+            len(write_modes),
+            len(read_modes),
+            self.format_mode_set(write_modes),
+            self.format_mode_set(read_modes),
+        )
+
+    def get_key_repr(self, key):
+        if key is None:
+            return None
+        write_modes, read_modes = key
+        return {
+            "write_modes": self.format_mode_set(write_modes),
+            "read_modes": self.format_mode_set(read_modes),
+        }
+
+    def get_group_label(self, key):
+        if key is None:
+            return "No decorator"
+        write_modes, read_modes = key
+        if read_modes:
+            return (
+                f"{self.format_mode_set(write_modes)}, read_only={self.format_mode_set(read_modes)}"
+            )
+        else:
+            return self.format_mode_set(write_modes)
+
+
+class ViewPresentation(ConsolePresentation):
+    @property
+    def table_label(self):
+        return "VIEWS"
+
+    def order(self, group):
+        mode_set, _view_group = group
+        return len(mode_set or ()), self.format_mode_set(mode_set)
+
+    def get_group_label(self, key):
+        return self.format_mode_set(key) if key else "No decorator"
+
+    def get_key_repr(self, key):
+        return self.format_mode_set(key)
+
+
+if __name__ == "__main__":
+    audit_mode_limits()