social.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. from django.conf import settings
  2. from django.contrib.auth import get_backends
  3. from rest_framework.response import Response
  4. from dj_rest_auth.registration.views import SocialConnectView, SocialLoginView
  5. from django_rest_mfa.helpers import has_mfa
  6. from allauth.account.adapter import DefaultAccountAdapter
  7. from allauth.account.auth_backends import AuthenticationBackend
  8. from allauth.socialaccount.providers.gitlab.views import GitLabOAuth2Adapter
  9. from allauth.socialaccount.providers.google.views import GoogleOAuth2Adapter
  10. from allauth.socialaccount.providers.microsoft.views import MicrosoftGraphOAuth2Adapter
  11. from allauth.socialaccount.providers.github.views import GitHubOAuth2Adapter
  12. from allauth.socialaccount.providers.oauth2.client import OAuth2Client
# Base URL of this instance, obtained by calling ``geturl()`` on
# ``settings.GLITCHTIP_URL``; the GitHub views below append a path to it to
# build their OAuth callback URL.
# NOTE(review): presumably GLITCHTIP_URL is a urllib.parse result, making this
# a plain string - confirm in settings.
DOMAIN = settings.GLITCHTIP_URL.geturl()
  14. class MFAAccountAdapter(DefaultAccountAdapter):
  15. """
  16. If user requires MFA, do not actually log in
  17. """
  18. def login(self, request, user):
  19. """Extend to check for MFA status, backend hack is copied from super method"""
  20. if not hasattr(user, "backend"):
  21. backends = get_backends()
  22. backend = None
  23. for b in backends:
  24. if isinstance(b, AuthenticationBackend):
  25. # prefer our own backend
  26. backend = b
  27. break
  28. elif not backend and hasattr(b, "get_user"):
  29. # Pick the first valid one
  30. backend = b
  31. backend_path = ".".join([backend.__module__, backend.__class__.__name__])
  32. user.backend = backend_path
  33. if has_mfa(request, user):
  34. user.mfa = True # Store for later, to avoid multiple DB checks
  35. else:
  36. super().login(request, user)
  37. def get_login_redirect_url(self, request):
  38. """Ignore login redirect when not logged in"""
  39. try:
  40. return super().get_login_redirect_url(request)
  41. except AssertionError:
  42. pass
  43. class MFASocialLoginView(SocialLoginView):
  44. def process_login(self):
  45. if not getattr(self.user, "mfa", False):
  46. super().process_login()
  47. def get_response(self):
  48. if getattr(self.user, "mfa", False):
  49. return Response({"requires_mfa": True})
  50. return super().get_response()
class GitlabConnect(SocialConnectView):
    # SocialConnectView bound to the GitLab OAuth2 adapter.
    # NOTE(review): presumably attaches a GitLab account to an already
    # authenticated user - confirm against dj-rest-auth.
    adapter_class = GitLabOAuth2Adapter
class GitlabLogin(MFASocialLoginView):
    # GitLab login view; inherits the MFA handling of MFASocialLoginView
    # (no login and a ``requires_mfa`` response when the user has MFA).
    adapter_class = GitLabOAuth2Adapter
class GithubConnect(SocialConnectView):
    # SocialConnectView bound to the GitHub OAuth2 adapter.
    # NOTE(review): presumably attaches a GitHub account to an already
    # authenticated user - confirm against dj-rest-auth.
    adapter_class = GitHubOAuth2Adapter
    client_class = OAuth2Client
    # Callback URL is the instance base URL plus a fixed path.
    # NOTE(review): likely has to match the redirect URI registered with the
    # GitHub OAuth application - confirm.
    callback_url = DOMAIN + "/auth/github"
class GithubLogin(MFASocialLoginView):
    # GitHub login view; inherits the MFA handling of MFASocialLoginView
    # (no login and a ``requires_mfa`` response when the user has MFA).
    adapter_class = GitHubOAuth2Adapter
    client_class = OAuth2Client
    # Same callback URL as GithubConnect: instance base URL plus a fixed path.
    callback_url = DOMAIN + "/auth/github"
class GoogleConnect(SocialConnectView):
    # SocialConnectView bound to the Google OAuth2 adapter.
    # NOTE(review): presumably attaches a Google account to an already
    # authenticated user - confirm against dj-rest-auth.
    adapter_class = GoogleOAuth2Adapter
class GoogleLogin(MFASocialLoginView):
    # Google login view; inherits the MFA handling of MFASocialLoginView
    # (no login and a ``requires_mfa`` response when the user has MFA).
    adapter_class = GoogleOAuth2Adapter
class MicrosoftConnect(SocialConnectView):
    # SocialConnectView bound to the Microsoft Graph OAuth2 adapter.
    # NOTE(review): presumably attaches a Microsoft account to an already
    # authenticated user - confirm against dj-rest-auth.
    adapter_class = MicrosoftGraphOAuth2Adapter
class MicrosoftLogin(MFASocialLoginView):
    # Microsoft login view; inherits the MFA handling of MFASocialLoginView
    # (no login and a ``requires_mfa`` response when the user has MFA).
    adapter_class = MicrosoftGraphOAuth2Adapter