models.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. import logging
  2. from django.db import models
  3. from django.db.models.expressions import OuterRef, Subquery
  4. from apps.organizations_ext.models import Organization
  5. from .client import list_prices, list_products, list_subscriptions
  6. from .utils import unix_to_datetime
  7. logger = logging.getLogger(__name__)
  8. class StripeModel(models.Model):
  9. stripe_id = models.CharField(primary_key=True, max_length=30)
  10. class Meta:
  11. abstract = True
  12. class StripeProduct(StripeModel):
  13. name = models.CharField()
  14. description = models.TextField()
  15. default_price = models.ForeignKey(
  16. "StripePrice", on_delete=models.CASCADE, blank=True, null=True
  17. )
  18. events = models.PositiveBigIntegerField()
  19. is_public = models.BooleanField()
  20. def __str__(self):
  21. return f"{self.name} {self.stripe_id}"
  22. @classmethod
  23. async def sync_from_stripe(cls):
  24. stripe_ids = set()
  25. async for products_page in list_products():
  26. logger.info(f"Found {len(products_page)} products in Stripe")
  27. products_page = [
  28. product for product in products_page if "events" in product.metadata
  29. ]
  30. products = [
  31. StripeProduct(
  32. stripe_id=product.id,
  33. name=product.name,
  34. description=product.description if product.description else "",
  35. events=product.metadata["events"],
  36. is_public=product.metadata.get("is_public") == "true",
  37. )
  38. for product in products_page
  39. ]
  40. prices = [
  41. StripePrice(
  42. stripe_id=product.default_price.id,
  43. price=product.default_price.unit_amount / 100,
  44. nickname=product.default_price.nickname or "",
  45. product_id=product.id,
  46. )
  47. for product in products_page
  48. if product.default_price
  49. and product.default_price.unit_amount is not None
  50. ]
  51. product_updated = await StripeProduct.objects.abulk_create(
  52. products,
  53. update_conflicts=True,
  54. update_fields=["name", "description", "events", "is_public"],
  55. unique_fields=["stripe_id"],
  56. )
  57. logger.info(f"Created/updated {len(product_updated)} products in Django")
  58. price_updated = await StripePrice.objects.abulk_create(
  59. prices,
  60. update_conflicts=True,
  61. update_fields=["price", "nickname", "product_id"],
  62. unique_fields=["stripe_id"],
  63. )
  64. logger.info(f"Created/updated {len(price_updated)} prices in Django")
  65. for product in product_updated:
  66. for price in price_updated:
  67. if (
  68. price.product_id == product.stripe_id
  69. and product.default_price_id != price.stripe_id
  70. ):
  71. product.default_price_id = price.stripe_id
  72. await product.asave(update_fields=["default_price_id"])
  73. for obj in product_updated:
  74. stripe_ids.add(obj.stripe_id)
  75. result = await StripeProduct.objects.exclude(stripe_id__in=stripe_ids).adelete()
  76. if result[0]:
  77. logger.info(f"Deleted {result[0]} products in Django")
  78. class StripePrice(StripeModel):
  79. price = models.DecimalField(max_digits=10, decimal_places=2)
  80. nickname = models.CharField(max_length=255)
  81. product = models.ForeignKey(StripeProduct, on_delete=models.CASCADE)
  82. def __str__(self):
  83. return f"{self.nickname} {self.price} {self.stripe_id}"
  84. @classmethod
  85. async def sync_from_stripe(cls):
  86. async for prices_page in list_prices():
  87. product_ids = {price.product for price in prices_page}
  88. products = StripeProduct.objects.filter(stripe_id__in=product_ids)
  89. known_product_ids = set()
  90. async for product in products:
  91. known_product_ids.add(product.stripe_id)
  92. prices = [
  93. StripePrice(
  94. stripe_id=price.id,
  95. price=price.unit_amount / 100,
  96. nickname=price.nickname or "",
  97. product_id=price.product,
  98. )
  99. for price in prices_page
  100. if price.unit_amount is not None and price.product in known_product_ids
  101. ]
  102. await StripePrice.objects.abulk_create(
  103. prices,
  104. update_conflicts=True,
  105. update_fields=["price", "nickname", "product_id"],
  106. unique_fields=["stripe_id"],
  107. )
  108. class StripeSubscription(StripeModel):
  109. is_active = models.BooleanField()
  110. created = models.DateTimeField()
  111. current_period_start = models.DateTimeField()
  112. current_period_end = models.DateTimeField()
  113. product = models.ForeignKey(StripeProduct, on_delete=models.CASCADE)
  114. organization = models.ForeignKey(
  115. "organizations_ext.Organization", on_delete=models.SET_NULL, null=True
  116. )
  117. def __str__(self):
  118. return f"{self.stripe_id}"
  119. @classmethod
  120. async def get_primary_subscription(cls, organization: Organization):
  121. return (
  122. await cls.objects.filter(organization=organization, is_active=True)
  123. .order_by("-product__events", "-created")
  124. .afirst()
  125. )
  126. @classmethod
  127. async def set_primary_subscriptions_for_organizations(
  128. cls, organization_ids: set[int]
  129. ):
  130. # This subquery finds the primary subscription ID for each organization.
  131. primary_subscription_subquery = (
  132. cls.objects.filter(organization_id=OuterRef("pk"), is_active=True)
  133. .order_by("-product__events", "-created")
  134. .values("pk")[:1]
  135. )
  136. org_updates = []
  137. async for org in Organization.objects.filter(id__in=organization_ids).annotate(
  138. primary_subscription_id=Subquery(primary_subscription_subquery)
  139. ):
  140. if (
  141. org.primary_subscription_id
  142. and org.primary_subscription_id != org.stripe_primary_subscription_id
  143. ):
  144. org.stripe_primary_subscription_id = org.primary_subscription_id
  145. org_updates.append(org)
  146. if org_updates:
  147. await Organization.objects.abulk_update(
  148. org_updates, ["stripe_primary_subscription_id"]
  149. )
  150. @classmethod
  151. async def sync_from_stripe(cls):
  152. organization_ids = set()
  153. active_organization_ids = set()
  154. async for subscriptions in list_subscriptions():
  155. logger.info(f"Found {len(subscriptions)} subcriptions in Stripe")
  156. subscription_objects = []
  157. for subscription in subscriptions:
  158. org_metadata = subscription.customer.metadata
  159. try:
  160. organization_id = int(
  161. org_metadata.get(
  162. "organization_id", org_metadata.get("djstripe_subscriber")
  163. )
  164. )
  165. except (ValueError, KeyError):
  166. continue # Skip if no organization ID in metadata
  167. items = subscription.items.data
  168. if not items or not items[0].get("price", {}).get("product"):
  169. continue # Skip
  170. product_id = items[0]["price"]["product"]
  171. # If unseen organization id, check if it exists
  172. if organization_id not in organization_ids:
  173. organization_ids.add(organization_id)
  174. organization = await Organization.objects.filter(
  175. id=organization_id
  176. ).afirst()
  177. if organization:
  178. active_organization_ids.add(organization_id)
  179. if not organization.stripe_customer_id:
  180. organization.stripe_customer_id = subscription.customer.id
  181. await organization.asave(
  182. update_fields=["stripe_customer_id"]
  183. )
  184. # Only save subscriptions with organizations that exist
  185. if organization_id in active_organization_ids:
  186. subscription_objects.append(
  187. StripeSubscription(
  188. stripe_id=subscription.id,
  189. created=unix_to_datetime(subscription.created),
  190. current_period_start=unix_to_datetime(
  191. subscription.current_period_start
  192. ),
  193. current_period_end=unix_to_datetime(
  194. subscription.current_period_end
  195. ),
  196. product_id=product_id,
  197. organization_id=organization_id,
  198. is_active=subscription.status == "active",
  199. )
  200. )
  201. stripe_subscriptions = await StripeSubscription.objects.abulk_create(
  202. subscription_objects,
  203. update_conflicts=True,
  204. update_fields=[
  205. "created",
  206. "current_period_start",
  207. "current_period_end",
  208. "product_id",
  209. "organization_id",
  210. "is_active",
  211. ],
  212. unique_fields=["stripe_id"],
  213. )
  214. logger.info(
  215. f"Created/updated {len(stripe_subscriptions)} subscriptions in Django"
  216. )
  217. await cls.set_primary_subscriptions_for_organizations(active_organization_ids)