social.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. from allauth.account import app_settings as allauth_account_settings
  2. from allauth.account.adapter import DefaultAccountAdapter
  3. from allauth.account.auth_backends import AuthenticationBackend
  4. from allauth.socialaccount.adapter import DefaultSocialAccountAdapter, get_adapter
  5. from allauth.socialaccount.helpers import complete_social_login
  6. from allauth.socialaccount.providers.oauth2.client import OAuth2Client, OAuth2Error
  7. from allauth.socialaccount.providers.openid_connect.views import OpenIDConnectAdapter
  8. from dj_rest_auth.registration.serializers import (
  9. SocialLoginSerializer as BaseSocialLoginSerializer,
  10. )
  11. from dj_rest_auth.registration.views import SocialConnectView, SocialLoginView
  12. from django.conf import settings
  13. from django.contrib.auth import get_backends, get_user_model
  14. from django.http import HttpResponseBadRequest
  15. from django.utils.translation import gettext_lazy as _
  16. from django_rest_mfa.helpers import has_mfa
  17. from requests.exceptions import HTTPError
  18. from rest_framework import serializers
  19. from rest_framework.response import Response
  20. from apps.users.utils import is_user_registration_open
  21. from .constants import SOCIAL_ADAPTER_MAP
  # Base URL of this deployment as a plain string; GLITCHTIP_URL looks like a
  # urllib.parse result (it exposes geturl()) — TODO confirm. Used to build the
  # OAuth callback URLs handed to social adapters.
  DOMAIN = settings.GLITCHTIP_URL.geturl()
  class MFAAccountAdapter(DefaultAccountAdapter):
      """Account adapter that skips the real login for users who have MFA.

      Such users are only flagged with ``user.mfa = True`` so later steps can
      read the flag instead of re-checking MFA status in the database.
      """

      def login(self, request, user):
          """Log ``user`` in, unless they have MFA configured.

          When ``user`` has no ``backend`` attribute yet, one is assigned using
          the same selection as the parent adapter (this hack is copied from
          it): the first allauth ``AuthenticationBackend`` is preferred,
          otherwise the first configured backend that has ``get_user``.
          """
          if not hasattr(user, "backend"):
              candidates = list(get_backends())
              chosen = next(
                  (c for c in candidates if isinstance(c, AuthenticationBackend)),
                  None,
              )
              if chosen is None:
                  # Pick the first valid one
                  chosen = next(
                      (c for c in candidates if hasattr(c, "get_user")), None
                  )
              user.backend = f"{chosen.__module__}.{chosen.__class__.__name__}"

          if has_mfa(request, user):
              # Store for later, to avoid multiple DB checks
              user.mfa = True
              return
          super().login(request, user)

      def get_login_redirect_url(self, request):
          """Return the parent's login redirect URL, or ``None`` if it asserts.

          The parent raises ``AssertionError`` when nobody is logged in, which
          is presumably the case after ``login`` skipped an MFA user.
          """
          try:
              redirect_url = super().get_login_redirect_url(request)
          except AssertionError:
              return None
          return redirect_url
  class CustomSocialAccountAdapter(DefaultSocialAccountAdapter):
      """Social account adapter whose signup gate follows site registration."""

      def is_open_for_signup(self, request, sociallogin):
          """Report whether new users may sign up via a social provider.

          ``request`` and ``sociallogin`` are ignored; the answer comes solely
          from ``is_user_registration_open()``.
          """
          registration_open = is_user_registration_open()
          return registration_open
  class SocialLoginSerializer(BaseSocialLoginSerializer):
      """dj-rest-auth social login serializer with GlitchTip-specific rules.

      Compared with the upstream serializer it:

      * accepts an optional, write-only ``tags`` field;
      * instantiates the OpenID Connect adapter with the provider id captured
        from the resolved URL;
      * refuses to create a new user when the social account adapter's
        ``is_open_for_signup()`` returns False.
      """

      tags = serializers.CharField(
          allow_blank=True, required=False, allow_null=True, write_only=True
      )

      # Overriding to add check for django-allauth's is_open_for_signup() at end, to prevent
      # creation of new user on first-time social auth login
      # https://github.com/iMerica/dj-rest-auth/blob/master/dj_rest_auth/registration/serializers.py#L79
      def validate(self, attrs):
          """Authenticate from ``access_token`` or ``code`` and set ``attrs["user"]``.

          Returns:
              ``attrs`` with ``"user"`` set to the logged-in account's user.

          Raises:
              serializers.ValidationError: the view or its adapter/client class
                  is missing, neither credential was given, the code exchange
                  or provider call failed, the e-mail is already registered,
                  or registration is closed.
          """
          # Local import: only needed to translate a duplicate-user race on save
          # into the same validation error the UNIQUE_EMAIL pre-check raises.
          from django.db import IntegrityError  # pylint: disable=import-outside-toplevel

          view = self.context.get("view")
          request = self._get_request()
          if not view:
              raise serializers.ValidationError(
                  _("View is not defined, pass it as a context variable"),
              )
          adapter_class = getattr(view, "adapter_class", None)
          if not adapter_class:
              raise serializers.ValidationError(_("Define adapter_class in view"))

          # The OIDC provider has a dynamic provider id. Fetch it from the request.
          if adapter_class == OpenIDConnectAdapter:
              provider_id = request.resolver_match.captured_kwargs.get("provider")
              adapter = adapter_class(request, provider_id)
          else:
              adapter = adapter_class(request)
          app = adapter.get_provider().app

          access_token = attrs.get("access_token")
          code = attrs.get("code")
          # Read up front so it is always bound; it is only consumed on the
          # access_token path (including the Google branch further down).
          id_token = attrs.get("id_token")

          if access_token:
              # Case 1: the client already holds an access token.
              tokens_to_parse = {"access_token": access_token}
              token = access_token
              if id_token:
                  tokens_to_parse["id_token"] = id_token
          elif code:
              # Case 2: exchange the authorization code for tokens.
              self.set_callback_url(view=view, adapter_class=adapter_class)
              self.client_class = getattr(view, "client_class", None)
              if not self.client_class:
                  raise serializers.ValidationError(
                      _("Define client_class in view"),
                  )
              provider = adapter.get_provider()
              scope = provider.get_scope(request)
              client = self.client_class(
                  request,
                  app.client_id,
                  app.secret,
                  adapter.access_token_method,
                  adapter.access_token_url,
                  self.callback_url,
                  scope,
                  scope_delimiter=adapter.scope_delimiter,
                  headers=adapter.headers,
                  basic_auth=adapter.basic_auth,
              )
              try:
                  token = client.get_access_token(code)
              except OAuth2Error as ex:
                  raise serializers.ValidationError(
                      _("Failed to exchange code for access token")
                  ) from ex
              access_token = token["access_token"]
              tokens_to_parse = {"access_token": access_token}
              # Forward whichever optional token fields the provider returned.
              for key in ["refresh_token", "id_token", adapter.expires_in_key]:
                  if key in token:
                      tokens_to_parse[key] = token[key]
          else:
              raise serializers.ValidationError(
                  _("Incorrect input. access_token or code is required."),
              )

          social_token = adapter.parse_token(tokens_to_parse)
          social_token.app = app
          try:
              if adapter.provider_id == "google" and not code:
                  login = self.get_social_login(
                      adapter, app, social_token, response={"id_token": id_token}
                  )
              else:
                  login = self.get_social_login(adapter, app, social_token, token)
              ret = complete_social_login(request, login)
          except HTTPError as ex:
              raise serializers.ValidationError(_("Incorrect value")) from ex
          if isinstance(ret, HttpResponseBadRequest):
              raise serializers.ValidationError(ret.content)

          if not login.is_existing:
              # Refuse to create a second account for an already-registered e-mail.
              if allauth_account_settings.UNIQUE_EMAIL:
                  account_exists = (
                      get_user_model()
                      .objects.filter(
                          email=login.user.email,
                      )
                      .exists()
                  )
                  if account_exists:
                      raise serializers.ValidationError(
                          _("User is already registered with this e-mail address."),
                      )
              # Added check for open signup
              if not get_adapter(request).is_open_for_signup(request, login):
                  raise serializers.ValidationError(_("User registration is closed."))
              login.lookup()
              try:
                  login.save(request, connect=True)
              except IntegrityError as ex:
                  # A concurrent signup may create the user after the pre-check.
                  raise serializers.ValidationError(
                      _("User is already registered with this e-mail address."),
                  ) from ex
              self.post_signup(login, attrs)

          attrs["user"] = login.account.user
          return attrs
  class GenericMFAMixin:
      """Resolve the social adapter and OAuth callback URL from URL kwargs.

      The ``provider`` URL kwarg selects an entry of ``SOCIAL_ADAPTER_MAP``.
      Providers missing from the map fall back to its ``"openid_connect"``
      entry.
      """

      client_class = OAuth2Client  # Only OAuth2 client is supported

      @property
      def callback_url(self):
          """Return ``<DOMAIN>/auth/<provider id>`` for the current request."""
          # adapter_class is always evaluated (as before), even when the URL
          # carries a provider kwarg, because it also tags the adapter class.
          fallback_id = self.adapter_class.provider_id
          # Set dynamic OIDC provider ID
          provider_id = self.kwargs.get("provider", fallback_id)
          return "/".join((DOMAIN, "auth", provider_id))

      @property
      def adapter_class(self):
          """Return the adapter class for the requested provider.

          NOTE(review): ``provider_id`` is assigned on the shared adapter class
          itself, so it persists after the request and is visible to any
          concurrent request using the same class — confirm this is intended.
          """
          requested = self.kwargs.get("provider")
          fallback = SOCIAL_ADAPTER_MAP["openid_connect"]
          resolved = SOCIAL_ADAPTER_MAP.get(requested, fallback)
          # Set dynamic OIDC provider ID
          resolved.provider_id = requested
          return resolved
  class GlitchTipSocialConnectView(GenericMFAMixin, SocialConnectView):
      """dj-rest-auth connect view with GenericMFAMixin's adapter/callback lookup."""
  class MFASocialLoginView(GenericMFAMixin, SocialLoginView):
      """Social login view that answers with an MFA challenge when required.

      Users flagged with ``mfa`` by the account adapter are not logged in here;
      instead the response lists the key types they can authenticate with.
      """

      serializer_class = SocialLoginSerializer

      def process_login(self):
          """Store signup analytics tags, then log in unless MFA is pending."""
          tags = self.serializer.validated_data.get("tags")
          # Tags are recorded only while the user has no analytics yet.
          if tags and self.user.analytics is None:
              self.user.set_register_analytics_tags(tags)
              self.user.save(update_fields=["analytics"])
          if not getattr(self.user, "mfa", False):
              super().process_login()

      def get_response(self):
          """Return the MFA challenge for flagged users, else the normal response."""
          if getattr(self.user, "mfa", False):
              # order_by() clears any default model ordering: ordering columns
              # are added to SELECT DISTINCT and could yield duplicate key types.
              user_key_types = list(
                  self.user.userkey_set.order_by()
                  .values_list("key_type", flat=True)
                  .distinct()
              )
              return Response({"requires_mfa": True, "valid_auth": user_key_types})
          return super().get_response()