Browse Source

Port dj rest auth things to work with allauth

David Burke 9 months ago
parent
commit
2f4a7c6359

+ 10 - 7
apps/users/api.py

@@ -23,24 +23,27 @@ GET /organizations/burke-software/users/ (Not implemented)
 """
 """
 
 
 
 
-def get_user_queryset(user_id: int):
-    return User.objects.filter(id=user_id)
+def get_user_queryset(user_id: int, add_details=False):
+    qs = User.objects.filter(id=user_id)
+    if add_details:
+        qs = qs.prefetch_related("socialaccount_set")
+    return qs
 
 
 
 
-@router.get("/users/", response=list[UserSchema])
+@router.get("/users/", response=list[UserSchema], by_alias=True)
 @paginate
 @paginate
 async def list_users(request: AuthHttpRequest, response: HttpResponse):
 async def list_users(request: AuthHttpRequest, response: HttpResponse):
     """
     """
     Exists in Sentry OSS, unsure what the use case is
     Exists in Sentry OSS, unsure what the use case is
     We make it only list the current user
     We make it only list the current user
     """
     """
-    return get_user_queryset(user_id=request.auth.user_id)
+    return get_user_queryset(user_id=request.auth.user_id, add_details=True)
 
 
 
 
-@router.get("/users/{slug:user_id}/", response=UserSchema)
+@router.get("/users/{slug:user_id}/", response=UserSchema, by_alias=True)
 async def get_user(request: AuthHttpRequest, user_id: MeID):
 async def get_user(request: AuthHttpRequest, user_id: MeID):
     user_id = request.auth.user_id
     user_id = request.auth.user_id
-    return await aget_object_or_404(get_user_queryset(user_id))
+    return await aget_object_or_404(get_user_queryset(user_id, add_details=True))
 
 
 
 
 @router.delete("/users/{slug:user_id}/", response={204: None})
 @router.delete("/users/{slug:user_id}/", response={204: None})
@@ -72,7 +75,7 @@ async def update_user(request: AuthHttpRequest, user_id: MeID, payload: UserIn):
     if user_id != request.auth.user_id and user_id != "me":
     if user_id != request.auth.user_id and user_id != "me":
         raise Http404
         raise Http404
     user_id = request.auth.user_id
     user_id = request.auth.user_id
-    user = await aget_object_or_404(get_user_queryset(user_id))
+    user = await aget_object_or_404(get_user_queryset(user_id, add_details=True))
 
 
     for attr, value in payload.dict().items():
     for attr, value in payload.dict().items():
         setattr(user, attr, value)
         setattr(user, attr, value)

+ 31 - 0
apps/users/migrations/0011_alter_user_email.py

@@ -0,0 +1,31 @@
+# Generated by Django 5.0.6 on 2024-05-31 17:11
+
+from django.contrib.postgres.operations import CreateCollation
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+    dependencies = [
+        ("users", "0010_allauth_oidc_from_env_var"),
+    ]
+
+    operations = [
+        migrations.AlterField(
+            model_name="user",
+            name="email",
+            field=models.EmailField(max_length=254),
+        ),
+        CreateCollation(
+            "case_insensitive",
+            provider="icu",
+            locale="und-u-ks-level2",
+            deterministic=False,
+        ),
+        migrations.AlterField(
+            model_name="user",
+            name="email",
+            field=models.EmailField(
+                db_collation="case_insensitive", max_length=254, unique=True
+            ),
+        ),
+    ]

+ 1 - 1
apps/users/models.py

@@ -67,7 +67,7 @@ class UserManager(BaseUserManager):
 
 
 
 
 class User(AbstractBaseUser, PermissionsMixin):
 class User(AbstractBaseUser, PermissionsMixin):
-    email = models.EmailField(unique=True)
+    email = models.EmailField(unique=True, db_collation="case_insensitive")
     name = models.CharField(_("name"), max_length=255, blank=True)
     name = models.CharField(_("name"), max_length=255, blank=True)
     is_staff = models.BooleanField(
     is_staff = models.BooleanField(
         _("staff status"),
         _("staff status"),

+ 37 - 9
apps/users/schema.py

@@ -1,36 +1,64 @@
-from ninja import ModelSchema
+from datetime import datetime
+from typing import Optional
+
+from allauth.socialaccount.models import SocialAccount
+from ninja import Field, ModelSchema
 
 
 from glitchtip.schema import CamelSchema
 from glitchtip.schema import CamelSchema
 
 
 from .models import User
 from .models import User
 
 
 
 
+class SocialAccountSchema(CamelSchema, ModelSchema):
+    email: Optional[str]
+    username: Optional[str]
+
+    class Meta:
+        model = SocialAccount
+        fields = (
+            "id",
+            "provider",
+            "uid",
+            "last_login",
+            "date_joined",
+        )
+
+    @staticmethod
+    def resolve_email(obj):
+        if obj.extra_data:
+            if "email" in obj.extra_data:
+                return obj.extra_data.get("email")
+            return obj.extra_data.get("userPrincipalName")  # MS oauth uses this
+
+    @staticmethod
+    def resolve_username(obj):
+        if obj.extra_data:
+            return obj.extra_data.get("username")
+
+
 class UserIn(CamelSchema, ModelSchema):
 class UserIn(CamelSchema, ModelSchema):
     class Meta:
     class Meta:
         model = User
         model = User
         fields = [
         fields = [
-            # "username",
-            # "emails",
-            # "identities",
             "name",
             "name",
-            # "email",
             "options",
             "options",
         ]
         ]
 
 
 
 
 class UserSchema(CamelSchema, ModelSchema):
 class UserSchema(CamelSchema, ModelSchema):
+    username: str = Field(validation_alias="email")
+    created: datetime = Field(serialization_alias="dateJoined")
+    has_password_auth: bool = Field(validation_alias="has_usable_password")
+    identities: list[SocialAccountSchema] = Field(validation_alias="socialaccount_set")
+
     class Meta(UserIn.Meta):
     class Meta(UserIn.Meta):
         fields = [
         fields = [
-            # "username",
             "last_login",
             "last_login",
             "is_superuser",
             "is_superuser",
             # "emails",
             # "emails",
-            # "identities",
             "id",
             "id",
             "is_active",
             "is_active",
             "name",
             "name",
-            # "dateJoined",
-            # "hasPasswordAuth",
             "email",
             "email",
             "options",
             "options",
         ]
         ]

+ 6 - 2
apps/users/tests/test_api.py

@@ -159,14 +159,18 @@ class UsersTestCase(GlitchTipTestCase):
 
 
     def test_emails_create_dupe_email(self):
     def test_emails_create_dupe_email(self):
         url = reverse("user-emails-list", args=["me"])
         url = reverse("user-emails-list", args=["me"])
-        email_address = baker.make("account.EmailAddress", user=self.user)
+        email_address = baker.make(
+            "account.EmailAddress",
+            user=self.user,
+            email="something@example.com",
+        )
         data = {"email": email_address.email}
         data = {"email": email_address.email}
         res = self.client.post(url, data)
         res = self.client.post(url, data)
         self.assertContains(res, "this account", status_code=400)
         self.assertContains(res, "this account", status_code=400)
 
 
     def test_emails_create_dupe_email_other_user(self):
     def test_emails_create_dupe_email_other_user(self):
         url = reverse("user-emails-list", args=["me"])
         url = reverse("user-emails-list", args=["me"])
-        email_address = baker.make("account.EmailAddress")
+        email_address = baker.make("account.EmailAddress", email="a@example.com")
         data = {"email": email_address.email}
         data = {"email": email_address.email}
         res = self.client.post(url, data)
         res = self.client.post(url, data)
         self.assertContains(res, "another account", status_code=400)
         self.assertContains(res, "another account", status_code=400)

+ 8 - 7
glitchtip/social.py

@@ -153,6 +153,14 @@ class SocialLoginSerializer(BaseSocialLoginSerializer):
                 )
                 )
             else:
             else:
                 login = self.get_social_login(adapter, app, social_token, token)
                 login = self.get_social_login(adapter, app, social_token, token)
+                # In allauth 0.53, we need this here instead
+                account_exists = (
+                    get_user_model()
+                    .objects.filter(
+                        email=login.user.email,
+                    )
+                    .exists()
+                )
             ret = complete_social_login(request, login)
             ret = complete_social_login(request, login)
         except HTTPError:
         except HTTPError:
             raise serializers.ValidationError(_("Incorrect value"))
             raise serializers.ValidationError(_("Incorrect value"))
@@ -162,13 +170,6 @@ class SocialLoginSerializer(BaseSocialLoginSerializer):
 
 
         if not login.is_existing:
         if not login.is_existing:
             if allauth_account_settings.UNIQUE_EMAIL:
             if allauth_account_settings.UNIQUE_EMAIL:
-                account_exists = (
-                    get_user_model()
-                    .objects.filter(
-                        email=login.user.email,
-                    )
-                    .exists()
-                )
                 if account_exists:
                 if account_exists:
                     raise serializers.ValidationError(
                     raise serializers.ValidationError(
                         _("User is already registered with this e-mail address."),
                         _("User is already registered with this e-mail address."),

+ 6 - 6
poetry.lock

@@ -841,20 +841,20 @@ dev = ["attribution (==1.6.2)", "black (==23.3.0)", "flit (==3.8.0)", "mypy (==1
 
 
 [[package]]
 [[package]]
 name = "dj-rest-auth"
 name = "dj-rest-auth"
-version = "5.0.2"
+version = "6.0.0"
 description = "Authentication and Registration in Django Rest Framework"
 description = "Authentication and Registration in Django Rest Framework"
 optional = false
 optional = false
-python-versions = ">=3.6"
+python-versions = ">=3.8"
 files = [
 files = [
-    {file = "dj-rest-auth-5.0.2.tar.gz", hash = "sha256:aad7d912476169e9991547bf98645344d3939be2d7052098048d819524c115d9"},
+    {file = "dj-rest-auth-6.0.0.tar.gz", hash = "sha256:760b45f3a07cd6182e6a20fe07d0c55230c5f950167df724d7914d0dd8c50133"},
 ]
 ]
 
 
 [package.dependencies]
 [package.dependencies]
-Django = ">=3.2"
+Django = ">=3.2,<6.0"
 djangorestframework = ">=3.13.0"
 djangorestframework = ">=3.13.0"
 
 
 [package.extras]
 [package.extras]
-with-social = ["django-allauth (>=0.56.0,<0.58.0)"]
+with-social = ["django-allauth (>=0.56.0,<0.62.0)"]
 
 
 [[package]]
 [[package]]
 name = "dj-stripe"
 name = "dj-stripe"
@@ -4277,4 +4277,4 @@ testing = ["coverage (>=5.0.3)", "zope.event", "zope.testing"]
 [metadata]
 [metadata]
 lock-version = "2.0"
 lock-version = "2.0"
 python-versions = "^3.10"
 python-versions = "^3.10"
-content-hash = "e139021c47e2dc32959675513d72261641de0f63dcc11b1cc545583f8ed29900"
+content-hash = "c353b30172c059ce7b8f04f8fdb8482def04eeda77b27bac9c9f0ab9bb027970"

+ 1 - 1
pyproject.toml

@@ -24,7 +24,7 @@ celery = {version = "~5.3.0", extras = ["redis"]}
 django-csp = "^3.6"
 django-csp = "^3.6"
 dj-stripe = "~2.8.0"
 dj-stripe = "~2.8.0"
 django-anymail = "^10.2"
 django-anymail = "^10.2"
-dj-rest-auth = "~5.0.0"
+dj-rest-auth = "~6.0.0"
 user-agents = "^2.1"
 user-agents = "^2.1"
 django-ipware = "^7.0.0"
 django-ipware = "^7.0.0"
 anonymizeip = "^1.0.0"
 anonymizeip = "^1.0.0"