views.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. import hmac
  2. import logging
  3. import time
  4. from django.conf import settings
  5. from django.http import (
  6. HttpRequest,
  7. HttpResponse,
  8. HttpResponseForbidden,
  9. HttpResponseServerError,
  10. )
  11. from django.views.decorators.csrf import csrf_exempt
  12. from django.views.decorators.http import require_POST
  13. from pydantic import ValidationError
  14. from apps.organizations_ext.models import Organization
  15. from .client import stripe_get
  16. from .models import StripePrice, StripeProduct, StripeSubscription
  17. from .schema import Customer, Price, Product, StripeEvent, Subscription
  18. from .utils import unix_to_datetime
# Module-level logger named after this module (``__name__``).
logger = logging.getLogger(__name__)
  20. async def update_product(product: Product):
  21. metadata = product.metadata
  22. if "events" not in metadata:
  23. return
  24. await StripeProduct.objects.aupdate_or_create(
  25. stripe_id=product.id,
  26. defaults={
  27. "name": product.name,
  28. "description": product.description,
  29. "events": metadata["events"],
  30. "is_public": metadata.get("is_public") == "true",
  31. },
  32. )
  33. async def update_price(price: Price):
  34. if (
  35. not price.unit_amount
  36. or not await StripeProduct.objects.filter(stripe_id=price.product).aexists()
  37. ):
  38. return
  39. await StripePrice.objects.aupdate_or_create(
  40. stripe_id=price.id,
  41. defaults={
  42. "product_id": price.product,
  43. "nickname": price.nickname,
  44. "price": price.unit_amount / 100,
  45. },
  46. )
  47. async def update_subscription(subscription: Subscription):
  48. customer_obj = Customer.model_validate_json(
  49. await stripe_get(f"customers/{subscription.customer}")
  50. )
  51. customer_metadata = customer_obj.metadata
  52. organization_id = int(
  53. customer_metadata.get(
  54. "organization_id", customer_metadata["djstripe_subscriber"]
  55. )
  56. )
  57. if not organization_id:
  58. return
  59. organization = await Organization.objects.filter(id=organization_id).afirst()
  60. if not organization:
  61. return
  62. if (
  63. product_id := subscription.items.data
  64. and subscription.items.data[0].get("plan", {}).get("product")
  65. ) is None:
  66. return
  67. stripe_subscription, created = await StripeSubscription.objects.aupdate_or_create(
  68. stripe_id=subscription.id,
  69. defaults={
  70. "created": unix_to_datetime(subscription.created),
  71. "current_period_start": unix_to_datetime(subscription.current_period_start),
  72. "current_period_end": unix_to_datetime(subscription.current_period_end),
  73. "product_id": product_id,
  74. "organization_id": organization.id,
  75. "is_active": subscription.status == "active",
  76. },
  77. )
  78. if stripe_subscription.is_active:
  79. primary_subscription = await StripeSubscription.get_primary_subscription(
  80. organization
  81. )
  82. if (
  83. primary_subscription
  84. and primary_subscription.stripe_id
  85. != organization.stripe_primary_subscription_id
  86. ):
  87. organization.stripe_primary_subscription = primary_subscription
  88. await organization.asave(update_fields=["stripe_primary_subscription"])
  89. @csrf_exempt
  90. @require_POST
  91. async def stripe_webhook_view(request: HttpRequest):
  92. """
  93. Handles Stripe webhook events.
  94. This view verifies the webhook signature using the raw request body and the
  95. Stripe webhook secret. It then processes the event based on its type.
  96. """
  97. payload = request.body
  98. sig_header = request.META.get("HTTP_STRIPE_SIGNATURE")
  99. if not sig_header:
  100. logger.warning("Stripe webhook received without signature header.")
  101. return HttpResponseForbidden("Missing signature header")
  102. try:
  103. if not verify_stripe_signature(payload, sig_header):
  104. logger.warning("Stripe webhook signature verification failed.")
  105. return HttpResponseForbidden("Invalid signature")
  106. except ValueError as e:
  107. logger.error(f"Error during signature verification: {e}")
  108. return HttpResponseForbidden("Invalid payload")
  109. except Exception as e:
  110. logger.exception(
  111. f"Unexpected error verifying signature: {e}"
  112. ) # Catch unexpected exceptions
  113. return HttpResponseServerError("Internal Server Error")
  114. try:
  115. event = StripeEvent.model_validate_json(payload)
  116. except ValidationError:
  117. logger.warning("Invalid JSON payload in Stripe webhook.")
  118. return HttpResponse(status=200)
  119. if event.type in ["product.updated", "product.created"]:
  120. await update_product(event.data.object)
  121. elif event.type in [
  122. "customer.subscription.updated",
  123. "customer.subscription.created",
  124. ]:
  125. await update_subscription(event.data.object)
  126. elif event.type in ["price.updated", "price.created"]:
  127. await update_price(event.data.object)
  128. else:
  129. logger.info(f"Unhandled Stripe event type: {event.type}")
  130. return HttpResponse(status=200)
  131. def verify_stripe_signature(payload, sig_header):
  132. """Verifies the Stripe webhook signature.
  133. Args:
  134. payload: The raw request body (bytes).
  135. sig_header: The value of the Stripe-Signature header.
  136. Returns:
  137. True if the signature is valid, False otherwise.
  138. Raises:
  139. ValueError: if the signature header is malformed.
  140. """
  141. webhook_secret = getattr(settings, "STRIPE_WEBHOOK_SECRET", None)
  142. if not webhook_secret:
  143. logger.error("STRIPE_WEBHOOK_SECRET not configured in settings.")
  144. # Return False/raise exception based on desired behavior (security vs. failing fast).
  145. # Returning False is generally safer.
  146. return False
  147. try:
  148. parts = {}
  149. for part in sig_header.split(","):
  150. key, value = part.strip().split("=", 1)
  151. parts[key.strip()] = value.strip()
  152. timestamp = int(parts.get("t"))
  153. signature = parts.get("v1") # Or 'v0' depending on your webhook setting
  154. if not timestamp or not signature:
  155. raise ValueError("Missing timestamp or signature")
  156. # Check timestamp tolerance (prevent replay attacks)
  157. tolerance = getattr(
  158. settings, "STRIPE_WEBHOOK_TOLERANCE", 300
  159. ) # Default: 5 minutes
  160. if (time.time() - timestamp) > tolerance:
  161. logger.warning(
  162. f"Stripe Webhook timestamp outside of tolerance: {timestamp}"
  163. )
  164. return False
  165. # Construct the signed payload string.
  166. signed_payload = (
  167. f"{timestamp}.{payload.decode('utf-8')}" # Important: decode to string
  168. )
  169. # Compute the expected signature.
  170. expected_signature = hmac.new(
  171. webhook_secret.encode("utf-8"),
  172. signed_payload.encode("utf-8"),
  173. digestmod="sha256",
  174. ).hexdigest()
  175. # Compare signatures. Use hmac.compare_digest for constant-time comparison.
  176. return hmac.compare_digest(signature, expected_signature)
  177. except ValueError:
  178. raise
  179. except Exception as e:
  180. logger.exception(f"Error in verify_stripe_signature: {e}")
  181. return False