social.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. from allauth.account.adapter import DefaultAccountAdapter
  2. from allauth.account import app_settings as allauth_settings
  3. from allauth.account.auth_backends import AuthenticationBackend
  4. from allauth.socialaccount.adapter import DefaultSocialAccountAdapter, get_adapter
  5. from allauth.socialaccount.helpers import complete_social_login
  6. from allauth.socialaccount.providers.github.views import GitHubOAuth2Adapter
  7. from allauth.socialaccount.providers.gitlab.views import GitLabOAuth2Adapter
  8. from allauth.socialaccount.providers.microsoft.views import MicrosoftGraphOAuth2Adapter
  9. from allauth.socialaccount.providers.oauth2.client import OAuth2Client
  10. from dj_rest_auth.registration.serializers import (
  11. SocialLoginSerializer as BaseSocialLoginSerializer,
  12. )
  13. from dj_rest_auth.registration.views import SocialConnectView, SocialLoginView
  14. from django.conf import settings
  15. from django.contrib.auth import get_backends, get_user_model
  16. from django.http import HttpResponseBadRequest
  17. from django.utils.translation import gettext_lazy as _
  18. from django_rest_mfa.helpers import has_mfa
  19. from requests.exceptions import HTTPError
  20. from rest_framework import serializers
  21. from rest_framework.response import Response
  22. from users.utils import is_user_registration_open
  23. from .constants import SOCIAL_ADAPTER_MAP
# Base URL taken from settings.GLITCHTIP_URL (presumably a urllib parse
# result, since it exposes geturl() -- confirm in settings). It is used as the
# prefix of the social-auth callback URLs built in GenericMFAMixin.callback_url.
DOMAIN = settings.GLITCHTIP_URL.geturl()
  25. class MFAAccountAdapter(DefaultAccountAdapter):
  26. """
  27. If user requires MFA, do not actually log in
  28. """
  29. def login(self, request, user):
  30. """Extend to check for MFA status, backend hack is copied from super method"""
  31. if not hasattr(user, "backend"):
  32. backends = get_backends()
  33. backend = None
  34. for b in backends: # pylint: disable=invalid-name
  35. if isinstance(b, AuthenticationBackend):
  36. # prefer our own backend
  37. backend = b
  38. break
  39. elif not backend and hasattr(b, "get_user"):
  40. # Pick the first valid one
  41. backend = b
  42. backend_path = ".".join([backend.__module__, backend.__class__.__name__])
  43. user.backend = backend_path
  44. if has_mfa(request, user):
  45. user.mfa = True # Store for later, to avoid multiple DB checks
  46. else:
  47. super().login(request, user)
  48. def get_login_redirect_url(self, request):
  49. """Ignore login redirect when not logged in"""
  50. try:
  51. return super().get_login_redirect_url(request)
  52. except AssertionError:
  53. pass
  54. class CustomSocialAccountAdapter(DefaultSocialAccountAdapter):
  55. def is_open_for_signup(self, request, sociallogin):
  56. return is_user_registration_open()
  57. class SocialLoginSerializer(BaseSocialLoginSerializer):
  58. tags = serializers.CharField(
  59. allow_blank=True, required=False, allow_null=True, write_only=True
  60. )
  61. # Overriding to add check for django-allauth's is_open_for_signup() at end, to prevent
  62. # creation of new user on first-time social auth login
  63. # https://github.com/iMerica/dj-rest-auth/blob/master/dj_rest_auth/registration/serializers.py#L79
  64. def validate(self, attrs):
  65. view = self.context.get("view")
  66. request = self._get_request()
  67. if not view:
  68. raise serializers.ValidationError(
  69. _("View is not defined, pass it as a context variable"),
  70. )
  71. adapter_class = getattr(view, "adapter_class", None)
  72. if not adapter_class:
  73. raise serializers.ValidationError(_("Define adapter_class in view"))
  74. adapter = adapter_class(request)
  75. app = adapter.get_provider().get_app(request)
  76. access_token = attrs.get("access_token")
  77. code = attrs.get("code")
  78. if access_token:
  79. tokens_to_parse = {"access_token": access_token}
  80. token = access_token
  81. id_token = attrs.get("id_token")
  82. if id_token:
  83. tokens_to_parse["id_token"] = id_token
  84. elif code:
  85. self.set_callback_url(view=view, adapter_class=adapter_class)
  86. self.client_class = getattr(view, "client_class", None)
  87. if not self.client_class:
  88. raise serializers.ValidationError(
  89. _("Define client_class in view"),
  90. )
  91. provider = adapter.get_provider()
  92. scope = provider.get_scope(request)
  93. client = self.client_class(
  94. request,
  95. app.client_id,
  96. app.secret,
  97. adapter.access_token_method,
  98. adapter.access_token_url,
  99. self.callback_url,
  100. scope,
  101. scope_delimiter=adapter.scope_delimiter,
  102. headers=adapter.headers,
  103. basic_auth=adapter.basic_auth,
  104. )
  105. token = client.get_access_token(code)
  106. access_token = token["access_token"]
  107. tokens_to_parse = {"access_token": access_token}
  108. for key in ["refresh_token", "id_token", adapter.expires_in_key]:
  109. if key in token:
  110. tokens_to_parse[key] = token[key]
  111. else:
  112. raise serializers.ValidationError(
  113. _("Incorrect input. access_token or code is required."),
  114. )
  115. social_token = adapter.parse_token(tokens_to_parse)
  116. social_token.app = app
  117. try:
  118. login = self.get_social_login(adapter, app, social_token, token)
  119. ret = complete_social_login(request, login)
  120. except HTTPError:
  121. raise serializers.ValidationError(_("Incorrect value"))
  122. if isinstance(ret, HttpResponseBadRequest):
  123. raise serializers.ValidationError(ret.content)
  124. if not login.is_existing:
  125. if allauth_settings.UNIQUE_EMAIL:
  126. account_exists = (
  127. get_user_model()
  128. .objects.filter(
  129. email=login.user.email,
  130. )
  131. .exists()
  132. )
  133. if account_exists:
  134. raise serializers.ValidationError(
  135. _("User is already registered with this e-mail address."),
  136. )
  137. # Added check for open signup
  138. if not get_adapter(request).is_open_for_signup(request, login):
  139. raise serializers.ValidationError(_("User registration is closed."))
  140. else:
  141. login.lookup()
  142. login.save(request, connect=True)
  143. attrs["user"] = login.account.user
  144. return attrs
  145. class GenericMFAMixin:
  146. client_class = OAuth2Client # Needed for Github. Would this ever break a provider?
  147. @property
  148. def callback_url(self):
  149. provider_id = self.adapter_class.provider_id
  150. return DOMAIN + "/auth/" + provider_id
  151. @property
  152. def adapter_class(self):
  153. provider = self.kwargs.get("provider")
  154. return SOCIAL_ADAPTER_MAP[provider]
  155. class GlitchTipSocialConnectView(GenericMFAMixin, SocialConnectView):
  156. pass
  157. class MFASocialLoginView(GenericMFAMixin, SocialLoginView):
  158. serializer_class = SocialLoginSerializer
  159. def process_login(self):
  160. tags = self.serializer.validated_data.get("tags")
  161. if tags and self.user.analytics is None:
  162. self.user.set_register_analytics_tags(tags)
  163. self.user.save(update_fields=["analytics"])
  164. if not getattr(self.user, "mfa", False):
  165. super().process_login()
  166. def get_response(self):
  167. if getattr(self.user, "mfa", False):
  168. user_key_types = (
  169. self.user.userkey_set.all()
  170. .values_list("key_type", flat=True)
  171. .distinct()
  172. )
  173. return Response({"requires_mfa": True, "valid_auth": user_key_types})
  174. return super().get_response()