123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215 |
- import hmac
- import logging
- import time
- from django.conf import settings
- from django.http import (
- HttpRequest,
- HttpResponse,
- HttpResponseForbidden,
- HttpResponseServerError,
- )
- from django.views.decorators.csrf import csrf_exempt
- from django.views.decorators.http import require_POST
- from pydantic import ValidationError
- from apps.organizations_ext.models import Organization
- from .client import stripe_get
- from .models import StripePrice, StripeProduct, StripeSubscription
- from .schema import Customer, Price, Product, StripeEvent, Subscription
- from .utils import unix_to_datetime
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
async def update_product(product: Product):
    """Create or refresh the local ``StripeProduct`` row for a Stripe product.

    Products whose metadata has no ``"events"`` key are ignored. The stored
    ``is_public`` flag is true only when the metadata value is exactly the
    string ``"true"``.
    """
    meta = product.metadata
    if "events" in meta:
        product_fields = {
            "name": product.name,
            "description": product.description,
            "events": meta["events"],
            "is_public": meta.get("is_public") == "true",
        }
        await StripeProduct.objects.aupdate_or_create(
            stripe_id=product.id, defaults=product_fields
        )
async def update_price(price: Price):
    """Create or refresh the local ``StripePrice`` row for a Stripe price.

    The price is skipped when ``unit_amount`` is falsy, or when no
    ``StripeProduct`` with ``stripe_id == price.product`` exists. The stored
    ``price`` is ``unit_amount / 100``. This presumably converts minor units
    (e.g. cents) to major units; TODO confirm.
    """
    amount = price.unit_amount
    # NOTE(review): a zero-amount (free) price is skipped too, because 0 is
    # falsy. Confirm that is intended.
    if not amount:
        return
    product_known = await StripeProduct.objects.filter(
        stripe_id=price.product
    ).aexists()
    if not product_known:
        return
    await StripePrice.objects.aupdate_or_create(
        stripe_id=price.id,
        defaults={
            "product_id": price.product,
            "nickname": price.nickname,
            "price": amount / 100,
        },
    )
async def update_subscription(subscription: Subscription):
    """Mirror a Stripe subscription into the local ``StripeSubscription`` table.

    The owning organization comes from the Stripe customer's metadata. The
    ``organization_id`` key is used if present; otherwise ``djstripe_subscriber``
    is used. The event is ignored when any of the following holds:

    - No integer organization id can be read (a warning is logged).
    - The id is 0.
    - No matching ``Organization`` exists.
    - The subscription's first item has no ``plan.product``.

    When the stored subscription is active, the organization's
    ``stripe_primary_subscription`` is replaced by the result of
    ``StripeSubscription.get_primary_subscription``, if that differs.
    """
    customer_obj = Customer.model_validate_json(
        await stripe_get(f"customers/{subscription.customer}")
    )
    customer_metadata = customer_obj.metadata or {}

    # Resolve the fallback key only when it is needed. A ``.get(key, d[other])``
    # default is evaluated eagerly, so it would raise KeyError whenever the
    # fallback key is absent, even if ``organization_id`` is present.
    raw_organization_id = customer_metadata.get("organization_id")
    if raw_organization_id is None:
        raw_organization_id = customer_metadata.get("djstripe_subscriber")
    try:
        organization_id = int(raw_organization_id)
    except (TypeError, ValueError):
        # TypeError: neither key is present. ValueError: the value is not numeric.
        logger.warning(
            "Stripe customer %s has no usable organization id in its metadata.",
            subscription.customer,
        )
        return
    if not organization_id:
        return
    organization = await Organization.objects.filter(id=organization_id).afirst()
    if not organization:
        return

    # Product of the first item's plan. Guard each step explicitly: an empty
    # item list or a null ``plan`` must skip the event rather than store a
    # bogus product id.
    items = subscription.items.data
    plan = (items[0].get("plan") or {}) if items else {}
    product_id = plan.get("product")
    if not product_id:
        return

    stripe_subscription, _ = await StripeSubscription.objects.aupdate_or_create(
        stripe_id=subscription.id,
        defaults={
            "created": unix_to_datetime(subscription.created),
            "current_period_start": unix_to_datetime(subscription.current_period_start),
            "current_period_end": unix_to_datetime(subscription.current_period_end),
            "product_id": product_id,
            "organization_id": organization.id,
            "is_active": subscription.status == "active",
        },
    )
    if not stripe_subscription.is_active:
        # NOTE(review): a subscription that just became inactive does not
        # refresh the organization's primary subscription here. Confirm this
        # is intended.
        return

    primary_subscription = await StripeSubscription.get_primary_subscription(
        organization
    )
    if (
        primary_subscription
        and primary_subscription.stripe_id
        != organization.stripe_primary_subscription_id
    ):
        organization.stripe_primary_subscription = primary_subscription
        await organization.asave(update_fields=["stripe_primary_subscription"])
@csrf_exempt
@require_POST
async def stripe_webhook_view(request: HttpRequest):
    """
    Receive a Stripe webhook, authenticate it, and route it to a handler.

    The ``Stripe-Signature`` header is checked against the raw request body
    by ``verify_stripe_signature``. Responses:

    * 403 when the header is missing or the signature does not verify.
    * 403 when verification raises ``ValueError`` (malformed header).
    * 500 when verification raises any other exception.
    * 200 otherwise. This includes bodies that fail ``StripeEvent``
      validation and event types that have no handler.
    """
    raw_body = request.body
    signature_header = request.META.get("HTTP_STRIPE_SIGNATURE")
    if not signature_header:
        logger.warning("Stripe webhook received without signature header.")
        return HttpResponseForbidden("Missing signature header")

    try:
        is_authentic = verify_stripe_signature(raw_body, signature_header)
    except ValueError as e:
        logger.error(f"Error during signature verification: {e}")
        return HttpResponseForbidden("Invalid payload")
    except Exception as e:
        # Anything other than a malformed header is treated as a server fault.
        logger.exception(f"Unexpected error verifying signature: {e}")
        return HttpResponseServerError("Internal Server Error")
    if not is_authentic:
        logger.warning("Stripe webhook signature verification failed.")
        return HttpResponseForbidden("Invalid signature")

    try:
        event = StripeEvent.model_validate_json(raw_body)
    except ValidationError:
        # Acknowledge with 200 anyway.
        logger.warning("Invalid JSON payload in Stripe webhook.")
        return HttpResponse(status=200)

    # Event type -> coroutine that receives ``event.data.object``.
    handlers = {
        "product.updated": update_product,
        "product.created": update_product,
        "customer.subscription.updated": update_subscription,
        "customer.subscription.created": update_subscription,
        "price.updated": update_price,
        "price.created": update_price,
    }
    handler = handlers.get(event.type)
    if handler is None:
        logger.info(f"Unhandled Stripe event type: {event.type}")
    else:
        await handler(event.data.object)
    return HttpResponse(status=200)
def verify_stripe_signature(payload, sig_header):
    """Verifies the Stripe webhook signature.

    The header is a comma-separated list of ``key=value`` elements: one ``t``
    (Unix timestamp) plus ``v1`` signature(s). Per Stripe's webhook-signature
    docs, the header can carry several ``v1`` entries. The payload is accepted
    when any of them equals the hex HMAC-SHA256 of ``b"<t>." + payload``,
    keyed with ``settings.STRIPE_WEBHOOK_SECRET``.

    Args:
        payload: The raw request body (bytes). A str is UTF-8 encoded first.
        sig_header: The value of the Stripe-Signature header.

    Returns:
        True if a signature matches and the timestamp is at most
        ``settings.STRIPE_WEBHOOK_TOLERANCE`` seconds old (default 300).
        False otherwise. False is also returned when the secret is not
        configured or an unexpected error occurs.

    Raises:
        ValueError: if the signature header is malformed. This covers an
            element without ``=``, a missing, zero, or non-integer ``t``,
            and the absence of any non-empty ``v1`` signature.
    """
    webhook_secret = getattr(settings, "STRIPE_WEBHOOK_SECRET", None)
    if not webhook_secret:
        logger.error("STRIPE_WEBHOOK_SECRET not configured in settings.")
        # Fail closed: without a secret, nothing can be authenticated.
        return False
    try:
        timestamp, signatures = _parse_stripe_signature_header(sig_header)

        # Replay protection: reject events signed too long ago.
        tolerance = getattr(settings, "STRIPE_WEBHOOK_TOLERANCE", 300)  # seconds
        if (time.time() - timestamp) > tolerance:
            logger.warning(
                "Stripe Webhook timestamp outside of tolerance: %s", timestamp
            )
            return False

        if isinstance(payload, str):
            payload = payload.encode("utf-8")
        # Sign the exact bytes "<t>.<raw body>"; there is no need to round-trip
        # the body through str.
        signed_payload = str(timestamp).encode("ascii") + b"." + payload
        expected_signature = (
            hmac.new(
                webhook_secret.encode("utf-8"),
                signed_payload,
                digestmod="sha256",
            )
            .hexdigest()
            .encode("ascii")
        )
        # Constant-time comparison against every candidate. Comparing bytes
        # avoids compare_digest's TypeError on non-ASCII str input.
        return any(
            hmac.compare_digest(candidate.encode("utf-8"), expected_signature)
            for candidate in signatures
        )
    except ValueError:
        raise
    except Exception as e:
        # Fail closed on anything unexpected (e.g. a non-str secret setting).
        logger.exception("Error in verify_stripe_signature: %s", e)
        return False


def _parse_stripe_signature_header(sig_header):
    """Split a Stripe-Signature header into ``(timestamp, v1_signatures)``.

    Only the ``v1`` scheme is collected. Stripe's docs describe ``v0`` as a
    fake test-mode signature, so ``v0`` entries and unknown keys are ignored.
    If ``t`` appears more than once, the last value wins.

    Raises:
        ValueError: if an element lacks ``=``, ``t`` is missing, zero, or not
            an integer, or no non-empty ``v1`` signature is present.
    """
    timestamp_raw = None
    signatures = []
    for part in sig_header.split(","):
        key, sep, value = part.strip().partition("=")
        if not sep:
            raise ValueError(f"Malformed Stripe-Signature element: {part!r}")
        key, value = key.strip(), value.strip()
        if key == "t":
            timestamp_raw = value
        elif key == "v1" and value:
            signatures.append(value)
    if timestamp_raw is None or not signatures:
        raise ValueError("Missing timestamp or signature")
    timestamp = int(timestamp_raw)  # raises ValueError for a non-integer "t"
    if not timestamp:
        raise ValueError("Missing timestamp or signature")
    return timestamp, signatures
|