Browse Source

feat(backup): Add export checkpointer (#80711)

This feature mirrors what we do for importing, where we periodically
"save our work", so that if we experience an ephemeral failure (timeout,
pod restart, OOM, etc), we can "pick up where we left off". For imports,
we do this by saving `ImportChunk`s to the database every time we import
a few models, which allows us to check what we've already imported to
avoid redoing work when retrying.

We use a similar strategy here for exporting. For every model kind, we
save a copy of the JSON of all instances of that model that we exported
to some durable media in specially-named "checkpoint" files. If there is
a failure midway through the export process, when we try again, we can
scan for these files to quickly re-use them, rather than doing very
expensive and resource-intensive database queries again. While this does
assume that the model state has stayed relatively consistent between
runs, this is already an assumption we make for exporting in general (we
can't export a "single snapshot in time" of the database at once
anyway).

A follow-up PR will implement a subclass of `ExportCheckpointer` for
GCP, which is what we will use to checkpoint large SaaS->SaaS
relocations.
Alex Zaslavsky 3 months ago
parent
commit
cebe40191b

+ 20 - 14
src/sentry/backup/crypto.py

@@ -54,8 +54,6 @@ class Encryptor(ABC):
     A `IO[bytes]`-wrapper that contains relevant information and methods to encrypt some an in-memory JSON-ifiable dict.
     """
 
-    __fp: IO[bytes]
-
     @abstractmethod
     def get_public_key_pem(self) -> bytes:
         pass
@@ -67,10 +65,10 @@ class LocalFileEncryptor(Encryptor):
     """
 
     def __init__(self, fp: IO[bytes]):
-        self.__fp = fp
+        self.__key = fp.read()
 
     def get_public_key_pem(self) -> bytes:
-        return self.__fp.read()
+        return self.__key
 
 
 class GCPKMSEncryptor(Encryptor):
@@ -82,7 +80,7 @@ class GCPKMSEncryptor(Encryptor):
     crypto_key_version: CryptoKeyVersion | None = None
 
     def __init__(self, fp: IO[bytes]):
-        self.__fp = fp
+        self.__key = fp.read()
 
     @classmethod
     def from_crypto_key_version(cls, crypto_key_version: CryptoKeyVersion) -> GCPKMSEncryptor:
@@ -93,7 +91,7 @@ class GCPKMSEncryptor(Encryptor):
     def get_public_key_pem(self) -> bytes:
         if self.crypto_key_version is None:
             # Read the user supplied configuration into the proper format.
-            gcp_kms_config_json = orjson.loads(self.__fp.read())
+            gcp_kms_config_json = orjson.loads(self.__key)
             try:
                 self.crypto_key_version = CryptoKeyVersion(**gcp_kms_config_json)
             except TypeError:
@@ -217,8 +215,6 @@ class Decryptor(ABC):
     tarball.
     """
 
-    __fp: IO[bytes]
-
     @abstractmethod
     def read(self) -> bytes:
         pass
@@ -234,14 +230,14 @@ class LocalFileDecryptor(Decryptor):
     """
 
     def __init__(self, fp: IO[bytes]):
-        self.__fp = fp
+        self.__key = fp.read()
 
     @classmethod
     def from_bytes(cls, b: bytes) -> LocalFileDecryptor:
         return cls(io.BytesIO(b))
 
     def read(self) -> bytes:
-        return self.__fp.read()
+        return self.__key
 
     def decrypt_data_encryption_key(self, unwrapped: UnwrappedEncryptedExportTarball) -> bytes:
         """
@@ -249,7 +245,7 @@ class LocalFileDecryptor(Decryptor):
         """
 
         # Compare the public and private key, to ensure that they are a match.
-        private_key_pem = self.__fp.read()
+        private_key_pem = self.__key
         private_key = serialization.load_pem_private_key(
             private_key_pem,
             password=None,
@@ -286,17 +282,17 @@ class GCPKMSDecryptor(Decryptor):
     """
 
     def __init__(self, fp: IO[bytes]):
-        self.__fp = fp
+        self.__key = fp.read()
 
     @classmethod
     def from_bytes(cls, b: bytes) -> GCPKMSDecryptor:
         return cls(io.BytesIO(b))
 
     def read(self) -> bytes:
-        return self.__fp.read()
+        return self.__key
 
     def decrypt_data_encryption_key(self, unwrapped: UnwrappedEncryptedExportTarball) -> bytes:
-        gcp_kms_config_bytes = self.__fp.read()
+        gcp_kms_config_bytes = self.__key
 
         # Read the user supplied configuration into the proper format.
         gcp_kms_config_json = orjson.loads(gcp_kms_config_bytes)
@@ -345,3 +341,13 @@ def decrypt_encrypted_tarball(tarball: IO[bytes], decryptor: Decryptor) -> bytes
     decrypted_dek = decryptor.decrypt_data_encryption_key(unwrapped)
     fernet = Fernet(decrypted_dek)
     return fernet.decrypt(unwrapped.encrypted_json_blob)
+
+
+class EncryptorDecryptorPair:
+    """
+    An Encryptor and Decryptor that use paired public and private keys, respectively.
+    """
+
+    def __init__(self, encryptor: Encryptor, decryptor: Decryptor):
+        self.encryptor = encryptor
+        self.decryptor = decryptor

+ 94 - 9
src/sentry/backup/exports.py

@@ -5,12 +5,15 @@ import io
 # We have to use the default JSON interface to enable pretty-printing on export. When loading JSON,
 # we still use the one from `sentry.utils`, imported as `sentry_json` below.
 import json as builtin_json  # noqa: S003
+from abc import ABC, abstractmethod
 from typing import IO
 
 import orjson
 
-from sentry.backup.crypto import Encryptor, create_encrypted_export_tarball
+from sentry.backup.crypto import Encryptor, EncryptorDecryptorPair, create_encrypted_export_tarball
 from sentry.backup.dependencies import (
+    ImportKind,
+    NormalizedModelName,
     PrimaryKeyMap,
     dependencies,
     get_model_name,
@@ -20,6 +23,7 @@ from sentry.backup.helpers import Filter, Printer
 from sentry.backup.scopes import ExportScope
 from sentry.backup.services.import_export.model import (
     RpcExportError,
+    RpcExportOk,
     RpcExportScope,
     RpcFilter,
     RpcPrimaryKeyMap,
@@ -41,6 +45,69 @@ class ExportingError(Exception):
         self.context = context
 
 
+class ExportCheckpointerError(Exception):
+    pass
+
+
+class ExportCheckpointer(ABC):
+    """
+    For very large exports, the exporting environment may fall over half-way through the process:
+    the thread running it may hit some timeout, or it may OOM, or fail for some other ephemeral
+    reason. To help in such situations, we'd like an API for saving "checkpoints" during the export.
+
+    This class provides per-model checkpointing support for exports. Since there is a topologically
+    sorted order of models being exported, as we move through this list, we can save the exported
+    JSON for each kind of model in order to some stable media (disk, GCP, etc). If there is a
+    failure late in the export process, when it is retried, the exporter can check if that
+    particular model already exists in the checkpointer's cache, thereby avoiding redoing the work
+    of pulling the models from the database, processing them, etc. This ensures that in most retry
+    situations, we can quickly "re-ingest" already-exported models in memory and pick up where we
+    left off.
+    """
+
+    def _parse_cached_json(self, json_data: bytes) -> RpcExportOk | None:
+        max_pk = 0
+        pk_map = PrimaryKeyMap()
+        models = orjson.loads(json_data)
+        for model in models:
+            model_name = model.get("model", None)
+            pk = model.get("pk", None)
+            if model_name is None or pk is None:
+                raise ExportCheckpointerError("Improperly formatted entry")
+
+            pk_map.insert(model_name, pk, pk, ImportKind.Inserted)
+            if pk > max_pk:
+                max_pk = pk
+
+        return RpcExportOk(
+            mapped_pks=RpcPrimaryKeyMap.into_rpc(pk_map), max_pk=max_pk, json_data=json_data
+        )
+
+    @abstractmethod
+    def get(self, model_name: NormalizedModelName) -> RpcExportOk | None:
+        pass
+
+    @abstractmethod
+    def add(self, model_name: NormalizedModelName, json_data: str) -> None:
+        pass
+
+
+class NoopExportCheckpointer(ExportCheckpointer):
+    """
+    A noop checkpointer - that is, it doesn't write or read any checkpoints, always returning None.
+    This means that no checkpointing ever occurs.
+    """
+
+    def __init__(self, crypto: EncryptorDecryptorPair | None, printer: Printer):
+        pass
+
+    def get(self, model_name: NormalizedModelName) -> RpcExportOk | None:
+        return None
+
+    def add(self, model_name: NormalizedModelName, json_data: str) -> None:
+        return None
+
+
 def _export(
     dest: IO[bytes],
     scope: ExportScope,
@@ -49,6 +116,7 @@ def _export(
     indent: int = 2,
     filter_by: Filter | None = None,
     printer: Printer,
+    checkpointer: ExportCheckpointer | None = None,
 ):
     """
     Exports core data for the Sentry installation.
@@ -68,6 +136,7 @@ def _export(
         printer.echo(errText, err=True)
         raise RuntimeError(errText)
 
+    cache = checkpointer if checkpointer is not None else NoopExportCheckpointer(None, printer)
     json_export = []
     pk_map = PrimaryKeyMap()
     allowed_relocation_scopes = scope.value
@@ -119,13 +188,18 @@ def _export(
 
         dep_models = {get_model_name(d) for d in model_relations.get_dependencies_for_relocation()}
         export_by_model = ImportExportService.get_exporter_for_model(model)
-        result = export_by_model(
-            export_model_name=str(model_name),
-            scope=RpcExportScope.into_rpc(scope),
-            from_pk=0,
-            filter_by=[RpcFilter.into_rpc(f) for f in filters],
-            pk_map=RpcPrimaryKeyMap.into_rpc(pk_map.partition(dep_models)),
-            indent=indent,
+        cached_result = cache.get(model_name)
+        result = (
+            cached_result
+            if cached_result is not None
+            else export_by_model(
+                export_model_name=str(model_name),
+                scope=RpcExportScope.into_rpc(scope),
+                from_pk=0,
+                filter_by=[RpcFilter.into_rpc(f) for f in filters],
+                pk_map=RpcPrimaryKeyMap.into_rpc(pk_map.partition(dep_models)),
+                indent=indent,
+            )
         )
 
         if isinstance(result, RpcExportError):
@@ -133,11 +207,14 @@ def _export(
             raise ExportingError(result)
 
         pk_map.extend(result.mapped_pks.from_rpc())
+        json_models = orjson.loads(result.json_data)
+        if cached_result is None:
+            cache.add(model_name, json_models)
 
         # TODO(getsentry/team-ospo#190): Since the structure of this data is very predictable (an
         # array of serialized model objects), we could probably avoid re-ingesting the JSON string
         # as a future optimization.
-        for json_model in orjson.loads(result.json_data):
+        for json_model in json_models:
             json_export.append(json_model)
 
     # If no `encryptor` argument was passed in, this is an unencrypted export, so we can just dump
@@ -158,6 +235,7 @@ def export_in_user_scope(
     user_filter: set[str] | None = None,
     indent: int = 2,
     printer: Printer,
+    checkpointer: ExportCheckpointer | None = None,
 ):
     """
     Perform an export in the `User` scope, meaning that only models with `RelocationScope.User` will
@@ -174,6 +252,7 @@ def export_in_user_scope(
         filter_by=Filter(User, "username", user_filter) if user_filter is not None else None,
         indent=indent,
         printer=printer,
+        checkpointer=checkpointer,
     )
 
 
@@ -184,6 +263,7 @@ def export_in_organization_scope(
     org_filter: set[str] | None = None,
     indent: int = 2,
     printer: Printer,
+    checkpointer: ExportCheckpointer | None = None,
 ):
     """
     Perform an export in the `Organization` scope, meaning that only models with
@@ -201,6 +281,7 @@ def export_in_organization_scope(
         filter_by=Filter(Organization, "slug", org_filter) if org_filter is not None else None,
         indent=indent,
         printer=printer,
+        checkpointer=checkpointer,
     )
 
 
@@ -210,6 +291,7 @@ def export_in_config_scope(
     encryptor: Encryptor | None = None,
     indent: int = 2,
     printer: Printer,
+    checkpointer: ExportCheckpointer | None = None,
 ):
     """
     Perform an export in the `Config` scope, meaning that only models directly related to the global
@@ -226,6 +308,7 @@ def export_in_config_scope(
         filter_by=Filter(User, "pk", import_export_service.get_all_globally_privileged_users()),
         indent=indent,
         printer=printer,
+        checkpointer=checkpointer,
     )
 
 
@@ -235,6 +318,7 @@ def export_in_global_scope(
     encryptor: Encryptor | None = None,
     indent: int = 2,
     printer: Printer,
+    checkpointer: ExportCheckpointer | None = None,
 ):
     """
     Perform an export in the `Global` scope, meaning that all models will be exported from the
@@ -246,4 +330,5 @@ def export_in_global_scope(
         encryptor=encryptor,
         indent=indent,
         printer=printer,
+        checkpointer=checkpointer,
     )

+ 42 - 8
src/sentry/testutils/helpers/backups.py

@@ -31,6 +31,7 @@ from sentry.backup.dependencies import (
     sorted_dependencies,
 )
 from sentry.backup.exports import (
+    ExportCheckpointer,
     export_in_config_scope,
     export_in_global_scope,
     export_in_organization_scope,
@@ -147,7 +148,12 @@ class ValidationError(Exception):
         self.info = info
 
 
-def export_to_file(path: Path, scope: ExportScope, filter_by: set[str] | None = None) -> Any:
+def export_to_file(
+    path: Path,
+    scope: ExportScope,
+    filter_by: set[str] | None = None,
+    checkpointer: ExportCheckpointer | None = None,
+) -> Any:
     """
     Helper function that exports the current state of the database to the specified file.
     """
@@ -157,13 +163,31 @@ def export_to_file(path: Path, scope: ExportScope, filter_by: set[str] | None =
         # These functions are just thin wrappers, but its best to exercise them directly anyway in
         # case that ever changes.
         if scope == ExportScope.Global:
-            export_in_global_scope(tmp_file, printer=NOOP_PRINTER)
+            export_in_global_scope(
+                tmp_file,
+                printer=NOOP_PRINTER,
+                checkpointer=checkpointer,
+            )
         elif scope == ExportScope.Config:
-            export_in_config_scope(tmp_file, printer=NOOP_PRINTER)
+            export_in_config_scope(
+                tmp_file,
+                printer=NOOP_PRINTER,
+                checkpointer=checkpointer,
+            )
         elif scope == ExportScope.Organization:
-            export_in_organization_scope(tmp_file, org_filter=filter_by, printer=NOOP_PRINTER)
+            export_in_organization_scope(
+                tmp_file,
+                org_filter=filter_by,
+                printer=NOOP_PRINTER,
+                checkpointer=checkpointer,
+            )
         elif scope == ExportScope.User:
-            export_in_user_scope(tmp_file, user_filter=filter_by, printer=NOOP_PRINTER)
+            export_in_user_scope(
+                tmp_file,
+                user_filter=filter_by,
+                printer=NOOP_PRINTER,
+                checkpointer=checkpointer,
+            )
         else:
             raise AssertionError(f"Unknown `ExportScope`: `{scope.name}`")
 
@@ -193,7 +217,9 @@ def export_to_encrypted_tarball(
     path: Path,
     scope: ExportScope,
     *,
+    rsa_key_pair: tuple[bytes, bytes],
     filter_by: set[str] | None = None,
+    checkpointer: ExportCheckpointer | None = None,
 ) -> Any:
     """
     Helper function that exports the current state of the database to the specified encrypted
@@ -201,7 +227,7 @@ def export_to_encrypted_tarball(
     """
 
     # Generate a public-private key pair.
-    (private_key_pem, public_key_pem) = generate_rsa_key_pair()
+    (private_key_pem, public_key_pem) = rsa_key_pair
     public_key_fp = io.BytesIO(public_key_pem)
 
     # Run the appropriate `export_in_...` command with encryption enabled.
@@ -211,11 +237,17 @@ def export_to_encrypted_tarball(
         # case that ever changes.
         if scope == ExportScope.Global:
             export_in_global_scope(
-                tmp_file, encryptor=LocalFileEncryptor(public_key_fp), printer=NOOP_PRINTER
+                tmp_file,
+                encryptor=LocalFileEncryptor(public_key_fp),
+                printer=NOOP_PRINTER,
+                checkpointer=checkpointer,
             )
         elif scope == ExportScope.Config:
             export_in_config_scope(
-                tmp_file, encryptor=LocalFileEncryptor(public_key_fp), printer=NOOP_PRINTER
+                tmp_file,
+                encryptor=LocalFileEncryptor(public_key_fp),
+                printer=NOOP_PRINTER,
+                checkpointer=checkpointer,
             )
         elif scope == ExportScope.Organization:
             export_in_organization_scope(
@@ -223,6 +255,7 @@ def export_to_encrypted_tarball(
                 encryptor=LocalFileEncryptor(public_key_fp),
                 org_filter=filter_by,
                 printer=NOOP_PRINTER,
+                checkpointer=checkpointer,
             )
         elif scope == ExportScope.User:
             export_in_user_scope(
@@ -230,6 +263,7 @@ def export_to_encrypted_tarball(
                 encryptor=LocalFileEncryptor(public_key_fp),
                 user_filter=filter_by,
                 printer=NOOP_PRINTER,
+                checkpointer=checkpointer,
             )
         else:
             raise AssertionError(f"Unknown `ExportScope`: `{scope.name}`")

+ 0 - 6
tests/sentry/backup/test_exhaustive.py

@@ -29,12 +29,6 @@ class ExhaustiveTests(BackupTransactionTestCase):
     Ensure that a database with all exportable models filled out still works.
     """
 
-    def export_to_tmp_file_and_clear_database(self, tmp_dir, reset_pks) -> Path:
-        tmp_path = Path(tmp_dir).joinpath(f"{self._testMethodName}.expect.json")
-        export_to_file(tmp_path, ExportScope.Global)
-        clear_database(reset_pks=reset_pks)
-        return tmp_path
-
     # Note: the "clean_pks" version of this test lives in
     # `test_sanitize.py::SanitizationExhaustiveTests`. Because these tests are slow, we want to
     # reduce duplication, so we only use that one in that particular location.

+ 322 - 12
tests/sentry/backup/test_exports.py

@@ -1,21 +1,37 @@
 from __future__ import annotations
 
 from copy import deepcopy
+from io import BytesIO
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import Any
 
+from orjson import JSONDecodeError, dumps
+
+from sentry.backup.crypto import (
+    DecryptionError,
+    EncryptorDecryptorPair,
+    LocalFileDecryptor,
+    LocalFileEncryptor,
+    create_encrypted_export_tarball,
+    decrypt_encrypted_tarball,
+)
 from sentry.backup.dependencies import NormalizedModelName, get_model, get_model_name
+from sentry.backup.exports import ExportCheckpointer, ExportCheckpointerError
+from sentry.backup.helpers import Printer
 from sentry.backup.scopes import ExportScope
+from sentry.backup.services.import_export.model import RpcExportOk
 from sentry.db import models
 from sentry.models.options.option import Option
 from sentry.models.organization import Organization
 from sentry.models.organizationmember import OrganizationMember
 from sentry.models.orgauthtoken import OrgAuthToken
 from sentry.testutils.helpers.backups import (
+    NOOP_PRINTER,
     BackupTransactionTestCase,
     export_to_encrypted_tarball,
     export_to_file,
+    generate_rsa_key_pair,
 )
 from sentry.testutils.helpers.datetime import freeze_time
 from sentry.users.models.email import Email
@@ -26,6 +42,65 @@ from sentry.users.models.userrole import UserRole, UserRoleUser
 from tests.sentry.backup import get_matching_exportable_models
 
 
+class FakeExportCheckpointer(ExportCheckpointer):
+    cache_hits: int = 0
+    cache_misses: int = 0
+    cache_writes: int = 0
+
+    def __init__(
+        self,
+        crypto: EncryptorDecryptorPair | None,
+        printer: Printer,
+        tmp_dir: str,
+        test_method_name: str,
+    ):
+        self.__crypto = crypto
+        self.__printer = printer
+        self.__tmp_dir = tmp_dir
+        self.__test_method_name = test_method_name
+
+    def _get_file_name(self, model_name: NormalizedModelName) -> Path:
+        if self.__crypto is None:
+            return Path(self.__tmp_dir).joinpath(
+                f"_{self.__test_method_name}.checkpoint.{str(model_name)}.json"
+            )
+        else:
+            return Path(self.__tmp_dir).joinpath(
+                f"_{self.__test_method_name}.checkpoint.{str(model_name)}.enc.tar"
+            )
+
+    def get(self, model_name: NormalizedModelName) -> RpcExportOk | None:
+        file_name = self._get_file_name(model_name)
+        try:
+            with open(file_name, "rb") as fp:
+                json_data = (
+                    decrypt_encrypted_tarball(fp, self.__crypto.decryptor)
+                    if self.__crypto is not None
+                    else fp.read()
+                )
+                parsed_json = self._parse_cached_json(json_data)
+                if parsed_json is None:
+                    self.cache_misses += 1
+                else:
+                    self.cache_hits += 1
+
+                return parsed_json
+        except (FileNotFoundError, DecryptionError, JSONDecodeError, ExportCheckpointerError):
+            self.cache_misses += 1
+            return None
+
+    def add(self, model_name: NormalizedModelName, json_export: Any) -> None:
+        file_name = self._get_file_name(model_name)
+        with open(file_name, "wb") as fp:
+            out_bytes = (
+                create_encrypted_export_tarball(json_export, self.__crypto.encryptor).getvalue()
+                if self.__crypto is not None
+                else dumps(json_export)
+            )
+            fp.write(out_bytes)
+            self.cache_writes += 1
+
+
 class ExportTestCase(BackupTransactionTestCase):
     @staticmethod
     def count(data: Any, model: type[models.base.BaseModel]) -> int:
@@ -48,23 +123,32 @@ class ExportTestCase(BackupTransactionTestCase):
 
     def export(
         self,
-        tmp_dir,
+        tmp_dir: str,
         *,
         scope: ExportScope,
         filter_by: set[str] | None = None,
+        checkpointer: ExportCheckpointer | None = None,
     ) -> Any:
         tmp_path = Path(tmp_dir).joinpath(f"{self._testMethodName}.json")
-        return export_to_file(tmp_path, scope=scope, filter_by=filter_by)
+        return export_to_file(tmp_path, scope=scope, filter_by=filter_by, checkpointer=checkpointer)
 
     def export_and_encrypt(
         self,
-        tmp_dir,
+        tmp_dir: str,
         *,
         scope: ExportScope,
+        rsa_key_pair: tuple[bytes, bytes],
         filter_by: set[str] | None = None,
+        checkpointer: ExportCheckpointer | None = None,
     ) -> Any:
         tmp_path = Path(tmp_dir).joinpath(f"{self._testMethodName}.enc.tar")
-        return export_to_encrypted_tarball(tmp_path, scope=scope, filter_by=filter_by)
+        return export_to_encrypted_tarball(
+            tmp_path,
+            scope=scope,
+            filter_by=filter_by,
+            checkpointer=checkpointer,
+            rsa_key_pair=rsa_key_pair,
+        )
 
 
 class ScopingTests(ExportTestCase):
@@ -101,17 +185,131 @@ class ScopingTests(ExportTestCase):
     def test_user_export_scoping(self):
         self.create_exhaustive_instance(is_superadmin=True)
         with TemporaryDirectory() as tmp_dir:
-            unencrypted = self.export(tmp_dir, scope=ExportScope.User)
+            unencrypted_checkpointer = FakeExportCheckpointer(
+                crypto=None,
+                printer=NOOP_PRINTER,
+                tmp_dir=tmp_dir,
+                test_method_name=self._testMethodName,
+            )
+            unencrypted = self.export(
+                tmp_dir, scope=ExportScope.User, checkpointer=unencrypted_checkpointer
+            )
             self.verify_model_inclusion(unencrypted, ExportScope.User)
-            assert unencrypted == self.export_and_encrypt(tmp_dir, scope=ExportScope.User)
+
+            first_pass_cache_writes = unencrypted_checkpointer.cache_writes
+            assert unencrypted_checkpointer.cache_hits == 0
+            assert unencrypted_checkpointer.cache_misses > 0
+            assert unencrypted_checkpointer.cache_misses == first_pass_cache_writes
+
+            # The following re-run of `self.export` should only use the checkpoint cache. After
+            # completion, we should have no new writes, and no new misses; only hits should be
+            # incremented.
+            self.export(tmp_dir, scope=ExportScope.User, checkpointer=unencrypted_checkpointer)
+            assert unencrypted_checkpointer.cache_hits == first_pass_cache_writes
+            assert unencrypted_checkpointer.cache_misses == first_pass_cache_writes
+            assert unencrypted_checkpointer.cache_writes == first_pass_cache_writes
+
+            rsa_key_pair = generate_rsa_key_pair()
+            (private_key_pem, public_key_pem) = rsa_key_pair
+            encrypted_checkpointer = FakeExportCheckpointer(
+                crypto=EncryptorDecryptorPair(
+                    encryptor=LocalFileEncryptor(BytesIO(public_key_pem)),
+                    decryptor=LocalFileDecryptor(BytesIO(private_key_pem)),
+                ),
+                printer=NOOP_PRINTER,
+                tmp_dir=tmp_dir,
+                test_method_name=self._testMethodName,
+            )
+            encrypted = self.export_and_encrypt(
+                tmp_dir,
+                scope=ExportScope.User,
+                rsa_key_pair=rsa_key_pair,
+                checkpointer=encrypted_checkpointer,
+            )
+            assert unencrypted == encrypted
+
+            assert encrypted_checkpointer.cache_hits == 0
+            assert encrypted_checkpointer.cache_misses == first_pass_cache_writes
+            assert encrypted_checkpointer.cache_writes == first_pass_cache_writes
+
+            # The following re-run of `self.export_and_encrypt` should only use the checkpoint
+            # cache. After completion, we should have no new writes, and no new misses; only hits
+            # should be incremented.
+            self.export_and_encrypt(
+                tmp_dir,
+                scope=ExportScope.User,
+                rsa_key_pair=rsa_key_pair,
+                checkpointer=encrypted_checkpointer,
+            )
+            assert encrypted_checkpointer.cache_hits == first_pass_cache_writes
+            assert encrypted_checkpointer.cache_misses == first_pass_cache_writes
+            assert encrypted_checkpointer.cache_writes == first_pass_cache_writes
 
     @freeze_time("2023-10-11 18:00:00")
     def test_organization_export_scoping(self):
         self.create_exhaustive_instance(is_superadmin=True)
         with TemporaryDirectory() as tmp_dir:
-            unencrypted = self.export(tmp_dir, scope=ExportScope.Organization)
+            unencrypted_checkpointer = FakeExportCheckpointer(
+                crypto=None,
+                printer=NOOP_PRINTER,
+                tmp_dir=tmp_dir,
+                test_method_name=self._testMethodName,
+            )
+            unencrypted = self.export(
+                tmp_dir, scope=ExportScope.Organization, checkpointer=unencrypted_checkpointer
+            )
             self.verify_model_inclusion(unencrypted, ExportScope.Organization)
-            assert unencrypted == self.export_and_encrypt(tmp_dir, scope=ExportScope.Organization)
+
+            first_pass_cache_writes = unencrypted_checkpointer.cache_writes
+            assert unencrypted_checkpointer.cache_hits == 0
+            assert unencrypted_checkpointer.cache_misses > 0
+            assert unencrypted_checkpointer.cache_misses == first_pass_cache_writes
+
+            # The following re-run of `self.export` should only use the checkpoint cache. After
+            # completion, we should have no new writes, and no new misses; only hits should be
+            # incremented.
+            self.export(
+                tmp_dir, scope=ExportScope.Organization, checkpointer=unencrypted_checkpointer
+            )
+            assert unencrypted_checkpointer.cache_hits == first_pass_cache_writes
+            assert unencrypted_checkpointer.cache_misses == first_pass_cache_writes
+            assert unencrypted_checkpointer.cache_writes == first_pass_cache_writes
+
+            rsa_key_pair = generate_rsa_key_pair()
+            (private_key_pem, public_key_pem) = rsa_key_pair
+            encrypted_checkpointer = FakeExportCheckpointer(
+                crypto=EncryptorDecryptorPair(
+                    encryptor=LocalFileEncryptor(BytesIO(public_key_pem)),
+                    decryptor=LocalFileDecryptor(BytesIO(private_key_pem)),
+                ),
+                printer=NOOP_PRINTER,
+                tmp_dir=tmp_dir,
+                test_method_name=self._testMethodName,
+            )
+            encrypted = self.export_and_encrypt(
+                tmp_dir,
+                scope=ExportScope.Organization,
+                rsa_key_pair=rsa_key_pair,
+                checkpointer=encrypted_checkpointer,
+            )
+            assert unencrypted == encrypted
+
+            assert encrypted_checkpointer.cache_hits == 0
+            assert encrypted_checkpointer.cache_misses == first_pass_cache_writes
+            assert encrypted_checkpointer.cache_writes == first_pass_cache_writes
+
+            # The following re-run of `self.export_and_encrypt` should only use the checkpoint
+            # cache. After completion, we should have no new writes, and no new misses; only hits
+            # should be incremented.
+            self.export_and_encrypt(
+                tmp_dir,
+                scope=ExportScope.Organization,
+                rsa_key_pair=rsa_key_pair,
+                checkpointer=encrypted_checkpointer,
+            )
+            assert encrypted_checkpointer.cache_hits == first_pass_cache_writes
+            assert encrypted_checkpointer.cache_misses == first_pass_cache_writes
+            assert encrypted_checkpointer.cache_writes == first_pass_cache_writes
 
     @freeze_time("2023-10-11 18:00:00")
     def test_config_export_scoping(self):
@@ -123,17 +321,129 @@ class ScopingTests(ExportTestCase):
         self.create_exhaustive_api_keys_for_user(staff)
         self.create_exhaustive_api_keys_for_user(superuser)
         with TemporaryDirectory() as tmp_dir:
-            unencrypted = self.export(tmp_dir, scope=ExportScope.Config)
+            unencrypted_checkpointer = FakeExportCheckpointer(
+                crypto=None,
+                printer=NOOP_PRINTER,
+                tmp_dir=tmp_dir,
+                test_method_name=self._testMethodName,
+            )
+            unencrypted = self.export(
+                tmp_dir, scope=ExportScope.Config, checkpointer=unencrypted_checkpointer
+            )
             self.verify_model_inclusion(unencrypted, ExportScope.Config)
-            assert unencrypted == self.export_and_encrypt(tmp_dir, scope=ExportScope.Config)
+
+            first_pass_cache_writes = unencrypted_checkpointer.cache_writes
+            assert unencrypted_checkpointer.cache_hits == 0
+            assert unencrypted_checkpointer.cache_misses > 0
+            assert unencrypted_checkpointer.cache_misses == first_pass_cache_writes
+
+            # The following re-run of `self.export` should only use the checkpoint cache. After
+            # completion, we should have no new writes, and no new misses; only hits should be
+            # incremented.
+            self.export(tmp_dir, scope=ExportScope.Config, checkpointer=unencrypted_checkpointer)
+            assert unencrypted_checkpointer.cache_hits == first_pass_cache_writes
+            assert unencrypted_checkpointer.cache_misses == first_pass_cache_writes
+            assert unencrypted_checkpointer.cache_writes == first_pass_cache_writes
+
+            rsa_key_pair = generate_rsa_key_pair()
+            (private_key_pem, public_key_pem) = rsa_key_pair
+            encrypted_checkpointer = FakeExportCheckpointer(
+                crypto=EncryptorDecryptorPair(
+                    encryptor=LocalFileEncryptor(BytesIO(public_key_pem)),
+                    decryptor=LocalFileDecryptor(BytesIO(private_key_pem)),
+                ),
+                printer=NOOP_PRINTER,
+                tmp_dir=tmp_dir,
+                test_method_name=self._testMethodName,
+            )
+            encrypted = self.export_and_encrypt(
+                tmp_dir,
+                scope=ExportScope.Config,
+                rsa_key_pair=rsa_key_pair,
+                checkpointer=encrypted_checkpointer,
+            )
+            assert unencrypted == encrypted
+
+            assert encrypted_checkpointer.cache_hits == 0
+            assert encrypted_checkpointer.cache_misses == first_pass_cache_writes
+            assert encrypted_checkpointer.cache_writes == first_pass_cache_writes
+
+            # The following re-run of `self.export_and_encrypt` should only use the checkpoint
+            # cache. After completion, we should have no new writes, and no new misses; only hits
+            # should be incremented.
+            self.export_and_encrypt(
+                tmp_dir,
+                scope=ExportScope.Config,
+                rsa_key_pair=rsa_key_pair,
+                checkpointer=encrypted_checkpointer,
+            )
+            assert encrypted_checkpointer.cache_hits == first_pass_cache_writes
+            assert encrypted_checkpointer.cache_misses == first_pass_cache_writes
+            assert encrypted_checkpointer.cache_writes == first_pass_cache_writes
 
     @freeze_time("2023-10-11 18:00:00")
     def test_global_export_scoping(self):
         self.create_exhaustive_instance(is_superadmin=True)
         with TemporaryDirectory() as tmp_dir:
-            unencrypted = self.export(tmp_dir, scope=ExportScope.Global)
+            unencrypted_checkpointer = FakeExportCheckpointer(
+                crypto=None,
+                printer=NOOP_PRINTER,
+                tmp_dir=tmp_dir,
+                test_method_name=self._testMethodName,
+            )
+            unencrypted = self.export(
+                tmp_dir, scope=ExportScope.Global, checkpointer=unencrypted_checkpointer
+            )
             self.verify_model_inclusion(unencrypted, ExportScope.Global)
-            assert unencrypted == self.export_and_encrypt(tmp_dir, scope=ExportScope.Global)
+
+            first_pass_cache_writes = unencrypted_checkpointer.cache_writes
+            assert unencrypted_checkpointer.cache_hits == 0
+            assert unencrypted_checkpointer.cache_misses > 0
+            assert unencrypted_checkpointer.cache_misses == first_pass_cache_writes
+
+            # The following re-run of `self.export` should only use the checkpoint cache. After
+            # completion, we should have no new writes, and no new misses; only hits should be
+            # incremented.
+            self.export(tmp_dir, scope=ExportScope.Global, checkpointer=unencrypted_checkpointer)
+            assert unencrypted_checkpointer.cache_hits == first_pass_cache_writes
+            assert unencrypted_checkpointer.cache_misses == first_pass_cache_writes
+            assert unencrypted_checkpointer.cache_writes == first_pass_cache_writes
+
+            rsa_key_pair = generate_rsa_key_pair()
+            (private_key_pem, public_key_pem) = rsa_key_pair
+            encrypted_checkpointer = FakeExportCheckpointer(
+                crypto=EncryptorDecryptorPair(
+                    encryptor=LocalFileEncryptor(BytesIO(public_key_pem)),
+                    decryptor=LocalFileDecryptor(BytesIO(private_key_pem)),
+                ),
+                printer=NOOP_PRINTER,
+                tmp_dir=tmp_dir,
+                test_method_name=self._testMethodName,
+            )
+            encrypted = self.export_and_encrypt(
+                tmp_dir,
+                scope=ExportScope.Global,
+                rsa_key_pair=rsa_key_pair,
+                checkpointer=encrypted_checkpointer,
+            )
+            assert unencrypted == encrypted
+
+            assert encrypted_checkpointer.cache_hits == 0
+            assert encrypted_checkpointer.cache_misses == first_pass_cache_writes
+            assert encrypted_checkpointer.cache_writes == first_pass_cache_writes
+
+            # The following re-run of `self.export_and_encrypt` should only use the checkpoint
+            # cache. After completion, we should have no new writes, and no new misses; only hits
+            # should be incremented.
+            self.export_and_encrypt(
+                tmp_dir,
+                scope=ExportScope.Global,
+                rsa_key_pair=rsa_key_pair,
+                checkpointer=encrypted_checkpointer,
+            )
+            assert encrypted_checkpointer.cache_hits == first_pass_cache_writes
+            assert encrypted_checkpointer.cache_misses == first_pass_cache_writes
+            assert encrypted_checkpointer.cache_writes == first_pass_cache_writes
 
 
 # Filters should work identically in both silo and monolith modes, so no need to repeat the tests