social.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. from allauth.account.adapter import DefaultAccountAdapter
  2. from allauth.account.auth_backends import AuthenticationBackend
  3. from allauth.socialaccount.providers.github.views import GitHubOAuth2Adapter
  4. from allauth.socialaccount.providers.gitlab.views import GitLabOAuth2Adapter
  5. from allauth.socialaccount.providers.microsoft.views import MicrosoftGraphOAuth2Adapter
  6. from allauth.socialaccount.providers.oauth2.client import OAuth2Client
  7. from dj_rest_auth.registration.serializers import (
  8. SocialLoginSerializer as BaseSocialLoginSerializer,
  9. )
  10. from dj_rest_auth.registration.views import SocialConnectView, SocialLoginView
  11. from django.conf import settings
  12. from django.contrib.auth import get_backends
  13. from django_rest_mfa.helpers import has_mfa
  14. from rest_framework import serializers
  15. from rest_framework.response import Response
  16. from .constants import SOCIAL_ADAPTER_MAP
# Base URL of this instance as a plain string. ``settings.GLITCHTIP_URL`` is
# presumably a urllib.parse result (it exposes ``geturl()``) — TODO confirm.
# Used as the prefix of ``GenericMFAMixin.callback_url``.
DOMAIN = settings.GLITCHTIP_URL.geturl()
  18. class MFAAccountAdapter(DefaultAccountAdapter):
  19. """
  20. If user requires MFA, do not actually log in
  21. """
  22. def login(self, request, user):
  23. """Extend to check for MFA status, backend hack is copied from super method"""
  24. if not hasattr(user, "backend"):
  25. backends = get_backends()
  26. backend = None
  27. for b in backends: # pylint: disable=invalid-name
  28. if isinstance(b, AuthenticationBackend):
  29. # prefer our own backend
  30. backend = b
  31. break
  32. elif not backend and hasattr(b, "get_user"):
  33. # Pick the first valid one
  34. backend = b
  35. backend_path = ".".join([backend.__module__, backend.__class__.__name__])
  36. user.backend = backend_path
  37. if has_mfa(request, user):
  38. user.mfa = True # Store for later, to avoid multiple DB checks
  39. else:
  40. super().login(request, user)
  41. def get_login_redirect_url(self, request):
  42. """Ignore login redirect when not logged in"""
  43. try:
  44. return super().get_login_redirect_url(request)
  45. except AssertionError:
  46. pass
  47. class SocialLoginSerializer(BaseSocialLoginSerializer):
  48. tags = serializers.CharField(
  49. allow_blank=True, required=False, allow_null=True, write_only=True
  50. )
  51. class GenericMFAMixin:
  52. client_class = OAuth2Client # Needed for Github. Would this ever break a provider?
  53. @property
  54. def callback_url(self):
  55. provider_id = self.adapter_class.provider_id
  56. return DOMAIN + "/auth/" + provider_id
  57. @property
  58. def adapter_class(self):
  59. provider = self.kwargs.get("provider")
  60. return SOCIAL_ADAPTER_MAP[provider]
  61. class GlitchTipSocialConnectView(GenericMFAMixin, SocialConnectView):
  62. pass
  63. class MFASocialLoginView(GenericMFAMixin, SocialLoginView):
  64. serializer_class = SocialLoginSerializer
  65. def process_login(self):
  66. tags = self.serializer.validated_data.get("tags")
  67. if tags and self.user.analytics is None:
  68. self.user.set_register_analytics_tags(tags)
  69. self.user.save(update_fields=["analytics"])
  70. if not getattr(self.user, "mfa", False):
  71. super().process_login()
  72. def get_response(self):
  73. if getattr(self.user, "mfa", False):
  74. user_key_types = (
  75. self.user.userkey_set.all()
  76. .values_list("key_type", flat=True)
  77. .distinct()
  78. )
  79. return Response({"requires_mfa": True, "valid_auth": user_key_types})
  80. return super().get_response()