Browse Source

ref: fix typing in sentry.testutils.factories (#67046)

<!-- Describe your PR here. -->
anthony sottile 1 year ago
parent
commit
f6cb4753ae

+ 0 - 1
pyproject.toml

@@ -485,7 +485,6 @@ module = [
     "sentry.templatetags.sentry_plugins",
     "sentry.testutils.asserts",
     "sentry.testutils.cases",
-    "sentry.testutils.factories",
     "sentry.testutils.fixtures",
     "sentry.testutils.helpers.features",
     "sentry.testutils.helpers.notifications",

+ 1 - 1
src/sentry/models/integrations/integration.py

@@ -119,7 +119,7 @@ class Integration(DefaultFieldsModel):
         """
         Add an organization to this integration.
 
-        Returns False if the OrganizationIntegration was not created
+        Returns None if the OrganizationIntegration was not created
         """
         from sentry.models.integrations.organization_integration import OrganizationIntegration
 

+ 33 - 29
src/sentry/testutils/factories.py

@@ -4,6 +4,7 @@ import contextlib
 import io
 import os
 import random
+import zipfile
 from base64 import b64encode
 from binascii import hexlify
 from collections.abc import Mapping, Sequence
@@ -304,22 +305,22 @@ class Factories:
         if not name:
             name = petname.generate(2, " ", letters=10).title()
 
-        if region is None or SiloMode.get_current_mode() == SiloMode.MONOLITH:
-            region_name = get_local_region().name
-            org_creation_context = contextlib.nullcontext()
-        else:
-            if isinstance(region, Region):
-                region_name = region.name
+        with contextlib.ExitStack() as ctx:
+            if region is None or SiloMode.get_current_mode() == SiloMode.MONOLITH:
+                region_name = get_local_region().name
             else:
-                region_obj = get_region_by_name(region)  # Verify it exists
-                region_name = region_obj.name
-            org_creation_context = override_settings(
-                SILO_MODE=SiloMode.REGION, SENTRY_REGION=region_name
-            )
+                if isinstance(region, Region):
+                    region_name = region.name
+                else:
+                    region_obj = get_region_by_name(region)  # Verify it exists
+                    region_name = region_obj.name
+
+                ctx.enter_context(
+                    override_settings(SILO_MODE=SiloMode.REGION, SENTRY_REGION=region_name)
+                )
 
-        with org_creation_context:
             with outbox_context(flush=False):
-                org: Organization = Organization.objects.create(name=name, **kwargs)
+                org = Organization.objects.create(name=name, **kwargs)
 
             with assume_test_silo_mode(SiloMode.CONTROL):
                 # Organization mapping creation relies on having a matching org slug reservation
@@ -592,7 +593,7 @@ class Factories:
             ReleaseEnvironment.objects.create(
                 organization=project.organization, release=release, environment=environment
             )
-            for project in [project] + additional_projects:
+            for project in [project, *additional_projects]:
                 ReleaseProjectEnvironment.objects.create(
                     project=project,
                     release=release,
@@ -626,6 +627,7 @@ class Factories:
 
         return release
 
+    @staticmethod
     def create_group_release(project: Project, group: Group, release: Release) -> GroupRelease:
         return GroupRelease.objects.create(
             project_id=project.id,
@@ -662,13 +664,11 @@ class Factories:
     def create_artifact_bundle_zip(
         org=None, release=None, project=None, extra_files=None, fixture_path="artifact_bundle"
     ):
-        import zipfile
-
         bundle = io.BytesIO()
         bundle_dir = get_fixture_path(fixture_path)
-        with zipfile.ZipFile(bundle, "w", zipfile.ZIP_DEFLATED) as zipfile:
+        with zipfile.ZipFile(bundle, "w", zipfile.ZIP_DEFLATED) as zipf:
             for path, content in (extra_files or {}).items():
-                zipfile.writestr(path, content)
+                zipf.writestr(path, content)
             for path, _, files in os.walk(bundle_dir):
                 for filename in files:
                     fullpath = os.path.join(path, filename)
@@ -677,9 +677,9 @@ class Factories:
                         manifest = _patch_artifact_manifest(
                             fullpath, org, release, project, extra_files
                         )
-                        zipfile.writestr(relpath, manifest)
+                        zipf.writestr(relpath, manifest)
                     else:
-                        zipfile.write(fullpath, relpath)
+                        zipf.write(fullpath, relpath)
 
         return bundle.getvalue()
 
@@ -687,10 +687,10 @@ class Factories:
     @assume_test_silo_mode(SiloMode.REGION)
     def create_release_archive(cls, org, release: str, project=None, dist=None):
         bundle = cls.create_artifact_bundle_zip(org, release, project)
-        file_ = File.objects.create(name="release-artifacts.zip")
-        file_.putfile(ContentFile(bundle))
-        release = Release.objects.get(organization__slug=org, version=release)
-        return update_artifact_index(release, dist, file_)
+        file = File.objects.create(name="release-artifacts.zip")
+        file.putfile(ContentFile(bundle))
+        release_obj = Release.objects.get(organization__slug=org, version=release)
+        return update_artifact_index(release_obj, dist, file)
 
     @classmethod
     @assume_test_silo_mode(SiloMode.REGION)
@@ -854,7 +854,7 @@ class Factories:
         user: User,
         provider: str | None = None,
         uid: str | None = None,
-        extra_data: Mapping[str, Any] | None = None,
+        extra_data: dict[str, Any] | None = None,
     ):
         if not provider:
             provider = "asana"
@@ -885,7 +885,7 @@ class Factories:
                         type=group_type,
                         parent_span_ids=None,
                         cause_span_ids=None,
-                        offender_span_ids=None,
+                        offender_span_ids=[],
                         evidence_data={},
                         evidence_display=[],
                     )
@@ -1075,16 +1075,18 @@ class Factories:
     ) -> ApiToken:
         if internal_integration and install:
             raise ValueError("Only one of internal_integration or install arg can be provided")
-        if internal_integration is None and install is None:
+        elif internal_integration is None and install is None:
             raise ValueError("Must pass in either internal_integration or install arg")
 
-        if install is None:
+        if internal_integration is not None and install is None:
             # Fetch install from provided or created internal integration
             with assume_test_silo_mode(SiloMode.CONTROL):
                 install = SentryAppInstallation.objects.get(
                     sentry_app=internal_integration.id,
                     organization_id=internal_integration.owner_id,
                 )
+        elif install is None:
+            raise AssertionError("unreachable")
 
         return SentryAppInstallationTokenCreator(sentry_app_installation=install).run(
             user=user, request=request
@@ -1138,6 +1140,7 @@ class Factories:
             if not prevent_token_exchange and (
                 install.sentry_app.status != SentryAppStatus.INTERNAL
             ):
+                assert install.api_grant is not None
                 GrantExchanger.run(
                     install=rpc_install,
                     code=install.api_grant.code,
@@ -1678,6 +1681,7 @@ class Factories:
         organization_integration = integration.add_organization(
             organization_id=organization.id, user=user, default_auth_id=identity.id
         )
+        assert organization_integration is not None
         return integration, organization_integration, identity, identity_provider
 
     @staticmethod
@@ -1820,7 +1824,7 @@ class Factories:
         return UserOption.objects.create(*args, **kwargs)
 
     @staticmethod
-    def create_basic_auth_header(username: str, password: str = "") -> str:
+    def create_basic_auth_header(username: str, password: str = "") -> bytes:
         return b"Basic " + b64encode(f"{username}:{password}".encode())
 
     @staticmethod

+ 1 - 1
src/sentry/testutils/fixtures.py

@@ -560,7 +560,7 @@ class Fixtures:
     def create_organization_mapping(self, *args, **kwargs):
         return Factories.create_org_mapping(*args, **kwargs)
 
-    def create_basic_auth_header(self, *args, **kwargs):
+    def create_basic_auth_header(self, *args, **kwargs) -> bytes:
         return Factories.create_basic_auth_header(*args, **kwargs)
 
     def snooze_rule(self, *args, **kwargs):

+ 1 - 1
src/social_auth/models.py

@@ -30,7 +30,7 @@ class UserSocialAuth(models.Model):
     user = models.ForeignKey(AUTH_USER_MODEL, related_name="social_auth", on_delete=models.CASCADE)
     provider = models.CharField(max_length=32)
     uid = models.CharField(max_length=UID_LENGTH)
-    extra_data: models.Field[dict[str, Any], dict[str, Any]] = JSONField(default="{}")
+    extra_data: models.Field[dict[str, Any] | None, dict[str, Any]] = JSONField(default="{}")
 
     class Meta:
         """Meta data"""