models.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. from __future__ import annotations
  2. import re
  3. from typing import Any
  4. from django.apps import apps
  5. from django.conf import settings
  6. from django.db import models
  7. from sentry.db.models import control_silo_only_model
  8. from .fields import JSONField
  9. from .utils import setting
  # Dotted "app_label.ModelName" reference to the user model; it is split on
  # "." and resolved through the app registry by UserSocialAuth.user_model().
  AUTH_USER_MODEL = settings.AUTH_USER_MODEL

  # Column-length limits, each read through the project ``setting`` helper
  # with 255 as the second argument (presumably the fallback when the
  # SOCIAL_AUTH_* setting is not configured -- confirm in .utils.setting).
  UID_LENGTH = setting("SOCIAL_AUTH_UID_LENGTH", 255)
  NONCE_SERVER_URL_LENGTH = setting(
      "SOCIAL_AUTH_NONCE_SERVER_URL_LENGTH",
      255,
  )
  ASSOCIATION_SERVER_URL_LENGTH = setting(
      "SOCIAL_AUTH_ASSOCIATION_SERVER_URL_LENGTH",
      255,
  )
  ASSOCIATION_HANDLE_LENGTH = setting(
      "SOCIAL_AUTH_ASSOCIATION_HANDLE_LENGTH",
      255,
  )
  # Matches runs of characters that UserSocialAuth.clean_username strips:
  # anything that is not a word character (``\w``, Unicode-aware) or one of
  # ".", "@", "+", "_", "-".
  # The hyphen must come last inside the class so it is a literal. The former
  # pattern ``[^\w.@+-_]`` parsed ``+-_`` as the range "+".."_" (U+002B to
  # U+005F), which silently let through ",", "/", ":", ";", "<", "=", ">",
  # "?", "[", "\", "]" and "^".
  CLEAN_USERNAME_REGEX = re.compile(r"[^\w.@+_-]+", re.UNICODE)
  @control_silo_only_model
  class UserSocialAuth(models.Model):
      """Social Auth association model.

      Each row links a local ``user`` to an account identified by ``uid`` on an
      external ``provider``. Provider payload such as the ``access_token`` and
      ``refresh_token`` entries used below is kept in ``extra_data``.

      The classmethods are helpers that operate on the user model named by
      ``AUTH_USER_MODEL`` (see ``user_model``).
      """

      id = models.AutoField(primary_key=True)
      user = models.ForeignKey(
          AUTH_USER_MODEL, related_name="social_auth", on_delete=models.CASCADE
      )
      provider = models.CharField(max_length=32)
      uid = models.CharField(max_length=UID_LENGTH)
      # NOTE(review): the default is the *string* "{}"; presumably the project
      # JSONField decodes it into a dict -- confirm before relying on
      # extra_data being a dict on unsaved instances.
      extra_data: models.Field[dict[str, Any], dict[str, Any]] = JSONField(default="{}")

      class Meta:
          """Meta data"""

          # A user may link the same provider account only once.
          unique_together = ("provider", "uid", "user")
          app_label = "social_auth"

      def __str__(self):
          """Return associated user unicode representation"""
          return f"{self.user} - {self.provider.title()}"

      def get_backend(self):
          """Return the backend for this association, as resolved by ``utils.get_backend``."""
          # Imported lazily, as in the other instance helpers of this class.
          from .utils import get_backend

          return get_backend(instance=self)

      @property
      def tokens(self):
          """Return the tokens computed by ``utils.tokens`` for this association.

          Recomputed on every access. Presumably a mapping: ``revoke_token``
          looks up ``"access_token"`` in it. Confirm against ``utils.tokens``.
          """
          from .utils import tokens

          return tokens(instance=self)

      def expiration_datetime(self):
          """Return the value computed by ``utils.expiration_datetime`` for this association."""
          from .utils import expiration_datetime

          return expiration_datetime(instance=self)

      def revoke_token(self, drop_token=True):
          """Attempt to revoke the permissions granted to the provider.

          Does nothing when ``tokens`` has no ``"access_token"``. Otherwise the
          backend's ``revoke_token(access_token, uid)`` is called.

          :param drop_token: when true and the backend reports success, the
              ``access_token`` entry is removed from ``extra_data`` and the row
              is saved.
          """
          # ``tokens`` is a property that recomputes on each access; read it once.
          current_tokens = self.tokens
          if "access_token" not in current_tokens:
              return
          success = self.get_backend().revoke_token(current_tokens["access_token"], self.uid)
          if success and drop_token:
              self.extra_data.pop("access_token", None)
              self.save()

      def refresh_token(self):
          """Exchange the stored refresh token for a new access token.

          This is a no-op in three cases:

          - ``extra_data`` holds no ``refresh_token``;
          - the backend has no ``refresh_token`` method;
          - the backend response lacks an ``access_token``.

          Otherwise the new ``access_token`` is stored, along with a new
          ``refresh_token`` if the response carries one, and the row is saved.
          """
          refresh_token = self.extra_data.get("refresh_token")
          if not refresh_token:
              return
          backend = self.get_backend()
          if not hasattr(backend, "refresh_token"):
              return
          response = backend.refresh_token(refresh_token, self.provider)
          new_access_token = response.get("access_token")
          # We have not got a new access token, so don't lose the existing one.
          if not new_access_token:
              return
          self.extra_data["access_token"] = new_access_token
          # A new refresh token might be given as well.
          new_refresh_token = response.get("refresh_token")
          if new_refresh_token:
              self.extra_data["refresh_token"] = new_refresh_token
          self.save()

      @classmethod
      def clean_username(cls, value):
          """Return ``value`` with every run matched by ``CLEAN_USERNAME_REGEX`` removed."""
          return CLEAN_USERNAME_REGEX.sub("", value)

      @classmethod
      def user_username(cls, user):
          """Return the value of ``user``'s identifying field.

          The field is named by ``USERNAME_FIELD``, falling back to ``"username"``.
          """
          # Custom user models (Django 1.5+) name their identifying field via
          # USERNAME_FIELD; it need not be an actual "username" field.
          if hasattr(user, "USERNAME_FIELD"):
              field_name = user.USERNAME_FIELD
          else:
              field_name = "username"
          return getattr(user, field_name)

      @classmethod
      def username_field(cls, values):
          """Rename the ``"username"`` key of ``values`` to the model's USERNAME_FIELD.

          ``values`` is modified in place and also returned. It is returned
          unchanged in either of these cases:

          - it has no ``"username"`` key, e.g. ``simple_user_exists(email=...)``
            (this used to raise ``KeyError``);
          - the user model defines no ``USERNAME_FIELD``.
          """
          user_model = cls.user_model()
          # "username" is only this module's internal key; custom user models
          # (Django 1.5+) may store it under a different field name.
          if hasattr(user_model, "USERNAME_FIELD") and "username" in values:
              values[user_model.USERNAME_FIELD] = values.pop("username")
          return values

      @classmethod
      def simple_user_exists(cls, *args, **kwargs):
          """
          Return True/False if a User instance exists with the given arguments.
          Arguments are directly passed to filter() manager method.
          TODO: consider how to ensure case-insensitive email matching
          """
          kwargs = cls.username_field(kwargs)
          return cls.user_model().objects.filter(*args, **kwargs).exists()

      @classmethod
      def create_user(cls, *args, **kwargs):
          """Create a user through the user model manager's ``create_user``.

          A ``username`` keyword is first renamed to the model's USERNAME_FIELD.
          """
          kwargs = cls.username_field(kwargs)
          return cls.user_model().objects.create_user(*args, **kwargs)

      @classmethod
      def get_user(cls, pk):
          """Return the user with primary key ``pk``, or ``None`` if no such user exists."""
          user_model = cls.user_model()
          try:
              return user_model.objects.get(pk=pk)
          except user_model.DoesNotExist:
              return None

      @classmethod
      def get_user_by_email(cls, email):
          """Return the single user whose email matches ``email`` case-insensitively.

          Uses ``.get()``, so the user model's ``DoesNotExist`` or
          ``MultipleObjectsReturned`` propagates when zero or several users match.
          """
          # Real-world email addresses are effectively case-insensitive.
          return cls.user_model().objects.get(email__iexact=email)

      @classmethod
      def resolve_user_or_id(cls, user_or_id):
          """Return ``user_or_id`` itself if it is a user instance, else load that user by pk."""
          user_model = cls.user_model()
          if isinstance(user_or_id, user_model):
              return user_or_id
          return user_model.objects.get(pk=user_or_id)

      @classmethod
      def get_social_auth_for_user(cls, user):
          """Return every association of ``user`` (the ``social_auth`` reverse relation)."""
          return user.social_auth.all()

      @classmethod
      def create_social_auth(cls, user, uid, provider):
          """Create and return an association of ``user`` with ``provider``/``uid``.

          ``uid`` is coerced to ``str`` because the column is a CharField.
          """
          if not isinstance(uid, str):
              uid = str(uid)
          return cls.objects.create(user_id=user.id, uid=uid, provider=provider)

      @classmethod
      def get_social_auth(cls, provider, uid, user):
          """Return the association for ``provider``/``uid``/``user``, or ``None`` if absent."""
          try:
              return cls.objects.get(provider=provider, uid=uid, user_id=user.id)
          except cls.DoesNotExist:
              return None

      @classmethod
      def username_max_length(cls):
          """Return ``max_length`` of the user model's username field."""
          return cls._field_length("USERNAME_FIELD", "username")

      @classmethod
      def email_max_length(cls):
          """Return ``max_length`` of the user model's email field."""
          return cls._field_length("EMAIL_FIELD", "email")

      @classmethod
      def _field_length(cls, setting_name, default_name):
          """Return ``max_length`` of the user-model field named by attribute ``setting_name``.

          Falls back to the field ``default_name`` when the user model lacks
          that attribute.
          """
          model = cls.user_model()
          field_name = getattr(model, setting_name, default_name)
          return model._meta.get_field(field_name).max_length

      @classmethod
      def user_model(cls):
          """Return the model class named by ``AUTH_USER_MODEL`` ("app_label.ModelName")."""
          app_label, model_name = AUTH_USER_MODEL.split(".")
          return apps.get_model(app_label, model_name)