123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249 |
- import logging
- from django.db import models
- from django.db.models.expressions import OuterRef, Subquery
- from apps.organizations_ext.models import Organization
- from .client import list_prices, list_products, list_subscriptions
- from .utils import unix_to_datetime
# Module-level logger shared by the Stripe sync class methods below.
logger: logging.Logger = logging.getLogger(name=__name__)
class StripeModel(models.Model):
    """Abstract base for locally mirrored Stripe objects.

    The Stripe object id is stored directly as the primary key, which lets the
    sync code upsert rows with ``unique_fields=["stripe_id"]``.
    """

    class Meta:
        abstract = True

    stripe_id = models.CharField(max_length=30, primary_key=True)
class StripeProduct(StripeModel):
    """A Stripe product mirrored locally.

    Only products whose Stripe metadata contains an ``events`` key are synced;
    that value is stored in ``events`` and ``is_public`` mirrors the
    ``is_public`` metadata flag (``"true"`` means public).
    """

    name = models.CharField()
    description = models.TextField()
    default_price = models.ForeignKey(
        "StripePrice", on_delete=models.CASCADE, blank=True, null=True
    )
    events = models.PositiveBigIntegerField()
    is_public = models.BooleanField()

    def __str__(self):
        return f"{self.name} {self.stripe_id}"

    @classmethod
    async def sync_from_stripe(cls):
        """Mirror Stripe products and their default prices into the local tables.

        For every page yielded by ``list_products()``:

        1. products with an ``events`` metadata key are upserted;
        2. their default prices are upserted (only when ``unit_amount`` is set);
        3. each upserted product is pointed at its upserted default price.

        After all pages, local products that were not seen are deleted.

        NOTE(review): the final delete cascades (``on_delete=CASCADE``) to the
        products' prices and subscriptions, and if Stripe yields no eligible
        products at all, every local product is deleted. Confirm this is intended.
        """
        seen_product_ids = set()
        async for products_page in list_products():
            logger.info(f"Found {len(products_page)} products in Stripe")
            eligible = [
                product for product in products_page if "events" in product.metadata
            ]
            products = [
                StripeProduct(
                    stripe_id=product.id,
                    name=product.name,
                    description=product.description if product.description else "",
                    events=product.metadata["events"],
                    is_public=product.metadata.get("is_public") == "true",
                )
                for product in eligible
            ]
            prices = [
                StripePrice(
                    stripe_id=product.default_price.id,
                    price=product.default_price.unit_amount / 100,
                    nickname=product.default_price.nickname or "",
                    product_id=product.id,
                )
                for product in eligible
                if product.default_price
                and product.default_price.unit_amount is not None
            ]

            # Products first: StripePrice.product is a FK to the product row.
            product_updated = await StripeProduct.objects.abulk_create(
                products,
                update_conflicts=True,
                update_fields=["name", "description", "events", "is_public"],
                unique_fields=["stripe_id"],
            )
            logger.info(f"Created/updated {len(product_updated)} products in Django")
            price_updated = await StripePrice.objects.abulk_create(
                prices,
                update_conflicts=True,
                update_fields=["price", "nickname", "product_id"],
                unique_fields=["stripe_id"],
            )
            logger.info(f"Created/updated {len(price_updated)} prices in Django")

            # Prices exist now, so products can reference them. Each product
            # contributes at most one (default) price, so a dict lookup replaces
            # the former product x price scan and per-row saves.
            default_price_by_product = {
                price.product_id: price.stripe_id for price in price_updated
            }
            products_to_link = []
            for product in product_updated:
                price_id = default_price_by_product.get(product.stripe_id)
                if price_id is not None and product.default_price_id != price_id:
                    product.default_price_id = price_id
                    products_to_link.append(product)
            if products_to_link:
                await StripeProduct.objects.abulk_update(
                    products_to_link, ["default_price"]
                )

            seen_product_ids.update(product.stripe_id for product in product_updated)

        deleted_total, deleted_per_model = await StripeProduct.objects.exclude(
            stripe_id__in=seen_product_ids
        ).adelete()
        # adelete() totals include cascaded rows (prices, subscriptions); report
        # only the products themselves.
        deleted_products = deleted_per_model.get(StripeProduct._meta.label, 0)
        if deleted_products:
            logger.info(f"Deleted {deleted_products} products in Django")
        if deleted_total > deleted_products:
            logger.info(
                f"Deleted {deleted_total - deleted_products} related rows via cascade"
            )
class StripePrice(StripeModel):
    """A Stripe price mirrored locally; every price belongs to a StripeProduct."""

    price = models.DecimalField(max_digits=10, decimal_places=2)
    nickname = models.CharField(max_length=255)
    product = models.ForeignKey(StripeProduct, on_delete=models.CASCADE)

    def __str__(self):
        return f"{self.nickname} {self.price} {self.stripe_id}"

    @classmethod
    async def sync_from_stripe(cls):
        """Upsert Stripe prices, page by page, into the local table.

        A price is skipped when it has no ``unit_amount`` or when its product
        has not already been synced locally (the product FK would fail).
        The stored amount is ``unit_amount / 100``.
        """
        async for page in list_prices():
            referenced_products = {item.product for item in page}

            # Which of the referenced products already exist locally?
            synced_products = set()
            existing = StripeProduct.objects.filter(
                stripe_id__in=referenced_products
            ).values_list("stripe_id", flat=True)
            async for stripe_id in existing:
                synced_products.add(stripe_id)

            upserts = []
            for item in page:
                if item.unit_amount is None:
                    continue
                if item.product not in synced_products:
                    continue
                upserts.append(
                    StripePrice(
                        stripe_id=item.id,
                        price=item.unit_amount / 100,
                        nickname=item.nickname or "",
                        product_id=item.product,
                    )
                )

            await StripePrice.objects.abulk_create(
                upserts,
                update_conflicts=True,
                update_fields=["price", "nickname", "product_id"],
                unique_fields=["stripe_id"],
            )
class StripeSubscription(StripeModel):
    """A Stripe subscription mirrored locally and linked to an organization."""

    is_active = models.BooleanField()
    created = models.DateTimeField()
    current_period_start = models.DateTimeField()
    current_period_end = models.DateTimeField()
    product = models.ForeignKey(StripeProduct, on_delete=models.CASCADE)
    # SET_NULL keeps the subscription row if its organization is deleted.
    organization = models.ForeignKey(
        "organizations_ext.Organization", on_delete=models.SET_NULL, null=True
    )

    def __str__(self):
        return f"{self.stripe_id}"

    @classmethod
    async def get_primary_subscription(cls, organization: Organization):
        """Return the organization's primary active subscription, or None.

        "Primary" is the active subscription whose product has the largest
        ``events`` value; ties go to the most recently created subscription.
        """
        return (
            await cls.objects.filter(organization=organization, is_active=True)
            .order_by("-product__events", "-created")
            .afirst()
        )

    @classmethod
    async def set_primary_subscriptions_for_organizations(
        cls, organization_ids: set[int]
    ):
        """Store each organization's primary subscription id on the organization.

        Uses the same ordering as ``get_primary_subscription`` but resolves all
        given organizations in one annotated query and writes only the changed
        ones with a single bulk update.

        NOTE(review): an organization with no active subscription keeps its
        previous ``stripe_primary_subscription_id``; the value is only ever
        overwritten, never cleared. Confirm this is intended.
        """
        # Correlated subquery: primary subscription id for the outer organization.
        primary_subscription_subquery = (
            cls.objects.filter(organization_id=OuterRef("pk"), is_active=True)
            .order_by("-product__events", "-created")
            .values("pk")[:1]
        )
        org_updates = []
        async for org in Organization.objects.filter(id__in=organization_ids).annotate(
            primary_subscription_id=Subquery(primary_subscription_subquery)
        ):
            if (
                org.primary_subscription_id
                and org.primary_subscription_id != org.stripe_primary_subscription_id
            ):
                org.stripe_primary_subscription_id = org.primary_subscription_id
                org_updates.append(org)
        if org_updates:
            await Organization.objects.abulk_update(
                org_updates, ["stripe_primary_subscription_id"]
            )

    @staticmethod
    def _parse_organization_id(metadata):
        """Return the organization id from Stripe customer metadata, or None.

        ``organization_id`` takes precedence and ``djstripe_subscriber`` is the
        fallback key. Returns None when neither key is present or the value is
        not an integer.
        """
        metadata = metadata or {}
        raw_id = metadata.get("organization_id", metadata.get("djstripe_subscriber"))
        try:
            return int(raw_id)
        except (TypeError, ValueError):
            # TypeError: neither key present (int(None)); ValueError: not numeric.
            return None

    @classmethod
    async def sync_from_stripe(cls):
        """Mirror Stripe subscriptions into the local table, page by page.

        A subscription is skipped when:

        - its customer's metadata has no usable organization id;
        - it has no first item with a price product;
        - the organization does not exist locally;
        - its product has not been synced locally (same rule that
          ``StripePrice.sync_from_stripe`` applies to prices).

        Organizations that exist but lack ``stripe_customer_id`` get it filled
        in from the subscription's customer. Finally, the primary subscription
        is recomputed for every organization that exists locally.
        """
        checked_organization_ids = set()
        active_organization_ids = set()
        async for subscriptions in list_subscriptions():
            logger.info(f"Found {len(subscriptions)} subscriptions in Stripe")
            # (subscription, organization_id, product_id) for the existing orgs.
            candidates = []
            for subscription in subscriptions:
                organization_id = cls._parse_organization_id(
                    subscription.customer.metadata
                )
                if organization_id is None:
                    continue  # No organization id in metadata
                items = subscription.items.data
                if not items or not items[0].get("price", {}).get("product"):
                    continue  # No priced product to attach the subscription to
                product_id = items[0]["price"]["product"]

                # Look each organization up only once per sync.
                if organization_id not in checked_organization_ids:
                    checked_organization_ids.add(organization_id)
                    organization = await Organization.objects.filter(
                        id=organization_id
                    ).afirst()
                    if organization:
                        active_organization_ids.add(organization_id)
                        if not organization.stripe_customer_id:
                            organization.stripe_customer_id = subscription.customer.id
                            await organization.asave(
                                update_fields=["stripe_customer_id"]
                            )

                # Only save subscriptions with organizations that exist
                if organization_id in active_organization_ids:
                    candidates.append((subscription, organization_id, product_id))

            # An unknown product would violate the product FK and abort the
            # whole sync, so resolve which products exist locally first.
            referenced_product_ids = {product_id for _, _, product_id in candidates}
            known_product_ids = {
                stripe_id
                async for stripe_id in StripeProduct.objects.filter(
                    stripe_id__in=referenced_product_ids
                ).values_list("stripe_id", flat=True)
            }

            subscription_objects = []
            for subscription, organization_id, product_id in candidates:
                if product_id not in known_product_ids:
                    logger.warning(
                        "Skipping subscription %s: product %s is not synced",
                        subscription.id,
                        product_id,
                    )
                    continue
                subscription_objects.append(
                    StripeSubscription(
                        stripe_id=subscription.id,
                        created=unix_to_datetime(subscription.created),
                        current_period_start=unix_to_datetime(
                            subscription.current_period_start
                        ),
                        current_period_end=unix_to_datetime(
                            subscription.current_period_end
                        ),
                        product_id=product_id,
                        organization_id=organization_id,
                        is_active=subscription.status == "active",
                    )
                )

            stripe_subscriptions = await StripeSubscription.objects.abulk_create(
                subscription_objects,
                update_conflicts=True,
                update_fields=[
                    "created",
                    "current_period_start",
                    "current_period_end",
                    "product_id",
                    "organization_id",
                    "is_active",
                ],
                unique_fields=["stripe_id"],
            )
            logger.info(
                f"Created/updated {len(stripe_subscriptions)} subscriptions in Django"
            )
        await cls.set_primary_subscriptions_for_organizations(active_organization_ids)
|