Browse Source

fix(hc): Change RPC schemas to be represented as OpenAPI vs (#70876)

Mulligan on #68345 that resolves the import issue that required a
revert.
Mark Story 10 months ago
parent
commit
bad3cf781f

+ 26 - 4
.github/workflows/openapi-diff.yml

@@ -19,10 +19,6 @@ jobs:
     timeout-minutes: 90
     steps:
       - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1
-        with:
-          # Avoid codecov error message related to SHA resolution:
-          # https://github.com/codecov/codecov-bash/blob/7100762afbc822b91806a6574658129fe0d23a7d/codecov#L891
-          fetch-depth: '2'
 
       - name: Check for python file changes
         uses: dorny/paths-filter@0bc4621a3135347011ad047f9ecf449bf72ce2bd # v3.0.0
@@ -58,3 +54,29 @@ jobs:
         if: steps.changes.outputs.api_docs == 'true'
         run: |
           npx json-diff@0.5.4 --color sentry-api-schema/openapi-derefed.json tests/apidocs/openapi-derefed.json
+
+      # Do RPC compatibility checks. Combined with public API audits
+      # to avoid workflow/job setup overhead.
+      - name: Dump RPC schema for new version
+        id: newschema
+        if: steps.changes.outputs.api_docs == 'true'
+        run: |
+          # -p keeps this step from failing if the directory already exists
+          mkdir -p schemas
+          sentry rpcschema --partial > schemas/new_schema.json
+
+      - name: Dump RPC schema for master
+        id: oldschema
+        if: steps.changes.outputs.api_docs == 'true'
+        run: |
+          git fetch origin
+          git switch master
+          sentry rpcschema --partial > schemas/old_schema.json
+
+      - name: Invoke oasdiff on RPC schemas
+        uses: oasdiff/oasdiff-action/breaking@fc826b9f0d21b85b085842521c7a8cc445412c08 # v0.0.19
+        if: steps.changes.outputs.api_docs == 'true'
+        with:
+          base: ./schemas/old_schema.json
+          revision: ./schemas/new_schema.json
+          format: 'text'
+          fail-on-diff: true

+ 1 - 0
requirements-dev-frozen.txt

@@ -106,6 +106,7 @@ nodeenv==1.8.0
 oauthlib==3.1.0
 openai==1.3.5
 openapi-core==0.18.2
+openapi-pydantic==0.4.0
 openapi-schema-validator==0.6.2
 openapi-spec-validator==0.7.1
 orjson==3.10.0

+ 1 - 0
requirements-dev.txt

@@ -7,6 +7,7 @@ docker>=6
 time-machine>=2.13.0
 honcho>=1.1.0
 openapi-core>=0.18.2
+openapi-pydantic>=0.4.0
 pytest>=8
 pytest-cov>=4.0.0
 pytest-django>=4.8.0

+ 126 - 0
src/sentry/runner/commands/rpcschema.py

@@ -0,0 +1,126 @@
+import sys
+import traceback
+from collections.abc import Iterable
+from dataclasses import dataclass
+from typing import Any
+
+import click
+from django.urls import reverse
+
+from sentry.runner.decorators import configuration
+from sentry.utils import json
+
+
+@click.command("rpcschema")
+@click.option(
+    "--partial",
+    is_flag=True,
+    default=False,
+    help="Ignore RPC methods that produce errors.",
+)
+@click.option(
+    "--diagnose",
+    is_flag=True,
+    default=False,
+    help="List RPC methods that produce errors and suppress all other output.",
+)
+@configuration
+def rpcschema(diagnose: bool, partial: bool) -> None:
+    """Dump the signatures of all RPC methods as an OpenAPI schema (JSON) on stdout."""
+    # Deferred imports because openapi_pydantic is only installed as a dev dependency
+    try:
+        from openapi_pydantic import OpenAPI
+        from openapi_pydantic.util import PydanticSchema, construct_open_api_with_schema_class
+    except ImportError as e:
+        # Report on stderr and exit non-zero. Callers redirect stdout into a schema
+        # file, so echoing this message to stdout and returning normally would leave
+        # them with a non-JSON "schema" and a successful exit status.
+        raise click.ClickException(
+            "Could not import openapi_pydantic. You must install requirements-dev to run this command"
+        ) from e
+
+    from sentry.services.hybrid_cloud.rpc import (
+        RpcMethodSignature,
+        list_all_service_method_signatures,
+    )
+
+    @dataclass
+    class RpcSchemaEntry:
+        """One RPC method signature, rendered as an OpenAPI path entry."""
+
+        sig: RpcMethodSignature
+
+        @property
+        def api_path(self) -> str:
+            """URL path of the RPC endpoint for this service key and method name."""
+            return reverse(
+                "sentry-api-0-rpc-service", args=(self.sig.service_key, self.sig.method_name)
+            )
+
+        def build_api_entry(self) -> dict[str, Any]:
+            """Build the POST operation whose request/response bodies use the signature's models."""
+            param_schema, return_schema = self.sig.get_schemas()
+            return {
+                "post": {
+                    "description": "Execute an RPC",
+                    "requestBody": {
+                        "content": {
+                            "application/json": {
+                                "schema": PydanticSchema(schema_class=param_schema)
+                            }
+                        },
+                    },
+                    "responses": {
+                        "200": {
+                            "description": "Success",
+                            "content": {
+                                "application/json": {
+                                    "schema": PydanticSchema(schema_class=return_schema)
+                                }
+                            },
+                        }
+                    },
+                }
+            }
+
+    def create_spec(signatures: Iterable[RpcMethodSignature]) -> dict[str, Any]:
+        """Build the OpenAPI document for the given signatures and return it as a dict."""
+        entries = [RpcSchemaEntry(sig) for sig in signatures]
+        path_dict = {entry.api_path: entry.build_api_entry() for entry in entries}
+
+        spec = OpenAPI.parse_obj(
+            dict(
+                info=dict(
+                    title="Sentry Internal RPC APIs",
+                    version="0.0.1",
+                ),
+                servers=[dict(url="https://sentry.io/")],  # TODO: Generify with setting value
+                paths=path_dict,
+            )
+        )
+        spec = construct_open_api_with_schema_class(spec)
+        return spec.dict(by_alias=True, exclude_none=True)
+
+    def create_partial_spec(
+        signatures: Iterable[RpcMethodSignature],
+    ) -> tuple[dict[str, Any], list[str]]:
+        """Build a spec from only the signatures that can be rendered on their own.
+
+        Returns the spec plus one "<signature>: <last line of error>" report for
+        every signature that raised while being rendered in isolation.
+        """
+        stable_signatures: list[RpcMethodSignature] = []
+        error_reports: list[str] = []
+        for sig in signatures:
+            try:
+                # Trial run on a single signature; the result is discarded.
+                create_spec([sig])
+            except Exception as e:
+                last_line = str(e).split("\n")[-1].strip()
+                error_reports.append(f"{sig!s}: {last_line}")
+                if not diagnose:
+                    # Outside of --diagnose the summary list is never printed, so
+                    # print the traceback instead of dropping the error silently.
+                    traceback.print_exc()
+            else:
+                stable_signatures.append(sig)
+
+        return create_spec(stable_signatures), error_reports
+
+    all_signatures = list_all_service_method_signatures()
+
+    if diagnose or partial:
+        spec, error_reports = create_partial_spec(all_signatures)
+        if diagnose:
+            click.echo(f"Error count: {len(error_reports)}")  # noqa
+            for bad_sig in error_reports:
+                click.echo("- " + bad_sig)  # noqa
+    else:
+        spec = create_spec(all_signatures)
+
+    if not diagnose:
+        json.dump(spec, sys.stdout)

+ 1 - 0
src/sentry/runner/main.py

@@ -57,6 +57,7 @@ for cmd in map(
         "sentry.runner.commands.plugins.plugins",
         "sentry.runner.commands.queues.queues",
         "sentry.runner.commands.repair.repair",
+        "sentry.runner.commands.rpcschema.rpcschema",
         "sentry.runner.commands.run.run",
         "sentry.runner.commands.start.start",
         "sentry.runner.commands.tsdb.tsdb",

+ 44 - 6
src/sentry/services/hybrid_cloud/rpc.py

@@ -5,6 +5,7 @@ import hashlib
 import hmac
 import inspect
 import logging
+import pkgutil
 from abc import abstractmethod
 from collections.abc import Callable, Iterable, Iterator, Mapping, MutableMapping, Sequence
 from contextlib import contextmanager
@@ -67,8 +68,26 @@ class RpcMethodSignature(SerializableFunctionSignature):
             self.base_service_cls.__name__, self.base_function.__name__, message
         )
 
+    @property
+    def service_key(self) -> str:
+        """The ``key`` attribute of the base service class this method belongs to."""
+        return self.base_service_cls.key
+
+    @property
+    def service_name(self) -> str:
+        """The class name (``__name__``) of the base service class."""
+        return self.base_service_cls.__name__
+
+    @property
+    def method_name(self) -> str:
+        """The name (``__name__``) of the underlying base function."""
+        return self.base_function.__name__
+
+    def __repr__(self) -> str:
+        """Debug representation: own type name with the quoted service and method names."""
+        return f"{type(self).__name__}({self.service_name!r}, {self.method_name!r})"
+
+    def __str__(self) -> str:
+        """Short display form: ``<service class name>.<method name>``."""
+        return f"{self.service_name}.{self.method_name}"
+
     def get_name_segments(self) -> Sequence[str]:
-        return (self.base_service_cls.__name__, self.base_function.__name__)
+        return self.service_name, self.method_name
 
     def _extract_region_resolution(self) -> RegionResolutionStrategy | None:
         region_resolution = getattr(self.base_function, _REGION_RESOLUTION_ATTR, None)
@@ -123,6 +142,9 @@ class DelegatingRpcService(DelegatedBySiloMode["RpcService"]):
     def local_mode(self) -> SiloMode:
         return self._base_service_cls.local_mode
 
+    def __repr__(self) -> str:
+        """Debug representation: own type name with the base service class name."""
+        return f"{type(self).__name__}({self._base_service_cls.__name__})"
+
     def deserialize_rpc_arguments(
         self, method_name: str, serial_arguments: ArgumentDict
     ) -> pydantic.BaseModel:
@@ -133,6 +155,9 @@ class DelegatingRpcService(DelegatedBySiloMode["RpcService"]):
         signature = self._signatures[method_name]
         return signature.deserialize_return_value(serial_response)
 
+    def get_all_signatures(self) -> Iterable[RpcMethodSignature]:
+        """Return every signature held in this object's ``_signatures`` mapping."""
+        return self._signatures.values()
+
 
 def rpc_method(method: Callable[..., _T]) -> Callable[..., _T]:
     """Decorate methods to be exposed as part of the RPC interface.
@@ -200,7 +225,7 @@ class RpcService(abc.ABC):
                 raise RpcServiceSetupException(
                     cls.key, None, "`local_mode` class attribute (SiloMode) is required"
                 )
-        cls._signatures = cls._create_signatures()
+        cls._signatures = {sig.method_name: sig for sig in cls._create_signatures()}
 
     @classmethod
     def _get_all_rpc_methods(cls) -> Iterator[Callable[..., Any]]:
@@ -236,8 +261,7 @@ class RpcService(abc.ABC):
         raise NotImplementedError
 
     @classmethod
-    def _create_signatures(cls) -> Mapping[str, RpcMethodSignature]:
-        model_table = {}
+    def _create_signatures(cls) -> Iterable[RpcMethodSignature]:
         for base_method in cls._get_all_rpc_methods():
             try:
                 signature = RpcMethodSignature(cls, base_method)
@@ -246,8 +270,7 @@ class RpcService(abc.ABC):
                     cls.key, base_method.__name__, "Error on parameter model"
                 ) from e
             else:
-                model_table[base_method.__name__] = signature
-        return model_table
+                yield signature
 
     @classmethod
     def _get_and_validate_local_implementation(cls) -> RpcService:
@@ -350,6 +373,21 @@ class RpcService(abc.ABC):
         return service  # type: ignore[return-value]
 
 
+def list_all_service_method_signatures() -> Iterable[RpcMethodSignature]:
+    """List signatures of all RPC methods in the global registry.
+
+    This is a generator: the service modules are imported, and the registry is
+    read, only once iteration begins.
+    """
+
+    import importlib
+
+    from sentry.services import hybrid_cloud as hybrid_cloud_service_pkg
+
+    # Forcibly import all service packages to ensure the global registry is fully populated
+    for _, name, _ in pkgutil.walk_packages(
+        hybrid_cloud_service_pkg.__path__, prefix=f"{hybrid_cloud_service_pkg.__name__}."
+    ):
+        # import_module is the documented replacement for calling __import__ directly;
+        # only the import side effect matters here, the returned module is unused.
+        importlib.import_module(name)
+
+    for service_obj in _global_service_registry.values():
+        yield from service_obj.get_all_signatures()
+
+
 class RpcResolutionException(Exception):
     """Indicate that an RPC service or method name could not be resolved."""
 

+ 10 - 0
src/sentry/services/hybrid_cloud/sig.py

@@ -138,3 +138,13 @@ class SerializableFunctionSignature:
     def deserialize_return_value(self, value: Any) -> Any:
         parsed = self._return_model.parse_obj({self._RETURN_MODEL_ATTR: value})
         return getattr(parsed, self._RETURN_MODEL_ATTR)
+
+    def get_schemas(self) -> tuple[type[pydantic.BaseModel], type[pydantic.BaseModel]]:
+        """Access the schema representations directly.
+
+        This generally should be needed only for reflective operations such as
+        checking for cross-version compatibility. Routine operations on the parameter
+        and return values should be done through the "serialize" and "deserialize"
+        methods.
+
+        Returns a ``(parameter_model, return_model)`` pair: the model stored on
+        this signature for its arguments, followed by the model stored for its
+        return value.
+        """
+        return self._parameter_model, self._return_model

+ 30 - 0
tests/sentry/hybridcloud/test_sig.py

@@ -0,0 +1,30 @@
+import pydantic
+
+from sentry.services.hybrid_cloud.sig import SerializableFunctionSignature
+from sentry.testutils.cases import TestCase
+
+
+class SerializableFunctionSignatureTest(TestCase):
+    """Round-trip checks for ``SerializableFunctionSignature``."""
+
+    def test_signature(self):
+        """Serialize and deserialize the arguments and return value of a model-typed function."""
+
+        class AnObject(pydantic.BaseModel):
+            a: int
+            b: str
+
+        def a_function(arg1: AnObject, arg2: AnObject) -> AnObject:
+            return AnObject(a=arg1.a + arg2.a, b=".".join((arg1.b, arg2.b)))
+
+        sig = SerializableFunctionSignature(a_function)
+
+        # Arguments go out as plain nested dicts keyed by parameter name.
+        call_args = {"arg1": AnObject(a=1, b="foo"), "arg2": AnObject(a=2, b="bar")}
+        wire_args = sig.serialize_arguments(call_args)
+        assert wire_args == {"arg1": {"a": 1, "b": "foo"}, "arg2": {"a": 2, "b": "bar"}}
+
+        # ...and come back as a pydantic model with one attribute per parameter.
+        parsed_args = sig.deserialize_arguments(wire_args)
+        assert isinstance(parsed_args, pydantic.BaseModel)
+        assert set(parsed_args.__dict__.keys()) == {"arg1", "arg2"}
+        expectations = (
+            ("arg1", AnObject(a=1, b="foo")),
+            ("arg2", AnObject(a=2, b="bar")),
+        )
+        for attr_name, expected in expectations:
+            assert hasattr(parsed_args, attr_name)
+            assert getattr(parsed_args, attr_name) == expected
+
+        # A serialized return value is rebuilt into the annotated return type.
+        parsed_result = sig.deserialize_return_value({"a": 3, "b": "qux"})
+        assert parsed_result == AnObject(a=3, b="qux")