server_payments.go 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469
  1. package server
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "github.com/stripe/stripe-go/v74"
  8. portalsession "github.com/stripe/stripe-go/v74/billingportal/session"
  9. "github.com/stripe/stripe-go/v74/checkout/session"
  10. "github.com/stripe/stripe-go/v74/customer"
  11. "github.com/stripe/stripe-go/v74/price"
  12. "github.com/stripe/stripe-go/v74/subscription"
  13. "github.com/stripe/stripe-go/v74/webhook"
  14. "heckel.io/ntfy/log"
  15. "heckel.io/ntfy/user"
  16. "heckel.io/ntfy/util"
  17. "io"
  18. "net/http"
  19. "net/netip"
  20. "time"
  21. )
  // Sentinel errors returned by the billing handlers.
  var (
  	// errNotAPaidTier is returned when the requested tier has no Stripe price ID, i.e. it cannot be paid for.
  	errNotAPaidTier = errors.New("tier does not have billing price identifier")

  	// errMultipleBillingSubscriptions is returned when Stripe reports that the user's customer already has
  	// a subscription, and creating another one would result in more than one subscription.
  	errMultipleBillingSubscriptions = errors.New("cannot have multiple billing subscriptions")

  	// errNoBillingSubscription is returned when an operation requires an existing subscription, but the
  	// user has no Stripe subscription ID on record.
  	errNoBillingSubscription = errors.New("user does not have an active billing subscription")
  )
  // Payments in ntfy are handled via Stripe.
  //
  // Nearly all payment-related logic lives in this file. The following processes
  // handle payments:
  //
  // - Checkout:
  // Creating a Stripe subscription (and, if needed, a Stripe customer) via the Stripe Checkout flow.
  // This flow is only used if the ntfy user does not already have a Stripe subscription; an existing
  // Stripe customer (e.g. from a previously cancelled subscription) is reused. It requires redirecting
  // the user to the Stripe checkout page. It is implemented in handleAccountBillingSubscriptionCreate
  // and the success callback handleAccountBillingSubscriptionCreateSuccess.
  // - Update subscription:
  // Switching an existing subscription to a different tier's price (upgrade/downgrade) is handled by
  // handleAccountBillingSubscriptionUpdate, which also creates prorations.
  // - Cancel subscription (at period end):
  // Users can cancel their Stripe subscription via the web app. This only marks the subscription to be
  // cancelled at the end of the current billing period; Stripe then cancels it. Users cannot cancel the
  // subscription immediately.
  // - Webhooks:
  // Whenever a subscription changes (updated, deleted), Stripe sends us a request via a webhook.
  // This is used to keep the local user database fields up to date. Stripe is the source of truth:
  // what Stripe says is mirrored and not questioned.
  48. // handleBillingTiersGet returns all available paid tiers, and the free tier. This is to populate the upgrade dialog
  49. // in the UI. Note that this endpoint does NOT have a user context (no v.user!).
  50. func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
  51. tiers, err := s.userManager.Tiers()
  52. if err != nil {
  53. return err
  54. }
  55. freeTier := defaultVisitorLimits(s.config)
  56. response := []*apiAccountBillingTier{
  57. {
  58. // This is a bit of a hack: This is the "Free" tier. It has no tier code, name or price.
  59. Limits: &apiAccountLimits{
  60. Messages: freeTier.MessagesLimit,
  61. MessagesExpiryDuration: int64(freeTier.MessagesExpiryDuration.Seconds()),
  62. Emails: freeTier.EmailsLimit,
  63. Reservations: freeTier.ReservationsLimit,
  64. AttachmentTotalSize: freeTier.AttachmentTotalSizeLimit,
  65. AttachmentFileSize: freeTier.AttachmentFileSizeLimit,
  66. AttachmentExpiryDuration: int64(freeTier.AttachmentExpiryDuration.Seconds()),
  67. },
  68. },
  69. }
  70. prices, err := s.priceCache.Value()
  71. if err != nil {
  72. return err
  73. }
  74. for _, tier := range tiers {
  75. priceStr, ok := prices[tier.StripePriceID]
  76. if tier.StripePriceID == "" || !ok {
  77. continue
  78. }
  79. response = append(response, &apiAccountBillingTier{
  80. Code: tier.Code,
  81. Name: tier.Name,
  82. Price: priceStr,
  83. Limits: &apiAccountLimits{
  84. Messages: tier.MessagesLimit,
  85. MessagesExpiryDuration: int64(tier.MessagesExpiryDuration.Seconds()),
  86. Emails: tier.EmailsLimit,
  87. Reservations: tier.ReservationsLimit,
  88. AttachmentTotalSize: tier.AttachmentTotalSizeLimit,
  89. AttachmentFileSize: tier.AttachmentFileSizeLimit,
  90. AttachmentExpiryDuration: int64(tier.AttachmentExpiryDuration.Seconds()),
  91. },
  92. })
  93. }
  94. return s.writeJSON(w, response)
  95. }
  96. // handleAccountBillingSubscriptionCreate creates a Stripe checkout flow to create a user subscription. The tier
  97. // will be updated by a subsequent webhook from Stripe, once the subscription becomes active.
  98. func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
  99. if v.user.Billing.StripeSubscriptionID != "" {
  100. return errHTTPBadRequestBillingSubscriptionExists
  101. }
  102. req, err := readJSONWithLimit[apiAccountBillingSubscriptionChangeRequest](r.Body, jsonBodyBytesLimit)
  103. if err != nil {
  104. return err
  105. }
  106. tier, err := s.userManager.Tier(req.Tier)
  107. if err != nil {
  108. return err
  109. } else if tier.StripePriceID == "" {
  110. return errNotAPaidTier
  111. }
  112. log.Info("Stripe: No existing subscription, creating checkout flow")
  113. var stripeCustomerID *string
  114. if v.user.Billing.StripeCustomerID != "" {
  115. stripeCustomerID = &v.user.Billing.StripeCustomerID
  116. stripeCustomer, err := s.stripe.GetCustomer(v.user.Billing.StripeCustomerID)
  117. if err != nil {
  118. return err
  119. } else if stripeCustomer.Subscriptions != nil && len(stripeCustomer.Subscriptions.Data) > 0 {
  120. return errMultipleBillingSubscriptions
  121. }
  122. }
  123. successURL := s.config.BaseURL + apiAccountBillingSubscriptionCheckoutSuccessTemplate
  124. params := &stripe.CheckoutSessionParams{
  125. Customer: stripeCustomerID, // A user may have previously deleted their subscription
  126. ClientReferenceID: &v.user.Name,
  127. SuccessURL: &successURL,
  128. Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)),
  129. AllowPromotionCodes: stripe.Bool(true),
  130. LineItems: []*stripe.CheckoutSessionLineItemParams{
  131. {
  132. Price: stripe.String(tier.StripePriceID),
  133. Quantity: stripe.Int64(1),
  134. },
  135. },
  136. /*AutomaticTax: &stripe.CheckoutSessionAutomaticTaxParams{
  137. Enabled: stripe.Bool(true),
  138. },*/
  139. }
  140. sess, err := s.stripe.NewCheckoutSession(params)
  141. if err != nil {
  142. return err
  143. }
  144. response := &apiAccountBillingSubscriptionCreateResponse{
  145. RedirectURL: sess.URL,
  146. }
  147. return s.writeJSON(w, response)
  148. }
  149. // handleAccountBillingSubscriptionCreateSuccess is called after the Stripe checkout session has succeeded. We use
  150. // the session ID in the URL to retrieve the Stripe subscription and update the local database. This is the first
  151. // and only time we can map the local username with the Stripe customer ID.
  152. func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, _ *visitor) error {
  153. // We don't have a v.user in this endpoint, only a userManager!
  154. matches := apiAccountBillingSubscriptionCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path)
  155. if len(matches) != 2 {
  156. return errHTTPInternalErrorInvalidPath
  157. }
  158. sessionID := matches[1]
  159. sess, err := s.stripe.GetSession(sessionID) // FIXME How do we rate limit this?
  160. if err != nil {
  161. return err
  162. } else if sess.Customer == nil || sess.Subscription == nil || sess.ClientReferenceID == "" {
  163. return wrapErrHTTP(errHTTPBadRequestBillingRequestInvalid, "customer or subscription not found")
  164. }
  165. sub, err := s.stripe.GetSubscription(sess.Subscription.ID)
  166. if err != nil {
  167. return err
  168. } else if sub.Items == nil || len(sub.Items.Data) != 1 || sub.Items.Data[0].Price == nil {
  169. return wrapErrHTTP(errHTTPBadRequestBillingRequestInvalid, "more than one line item in existing subscription")
  170. }
  171. tier, err := s.userManager.TierByStripePrice(sub.Items.Data[0].Price.ID)
  172. if err != nil {
  173. return err
  174. }
  175. u, err := s.userManager.User(sess.ClientReferenceID)
  176. if err != nil {
  177. return err
  178. }
  179. if err := s.updateSubscriptionAndTier(u, tier, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt); err != nil {
  180. return err
  181. }
  182. http.Redirect(w, r, s.config.BaseURL+accountPath, http.StatusSeeOther)
  183. return nil
  184. }
  185. // handleAccountBillingSubscriptionUpdate updates an existing Stripe subscription to a new price, and updates
  186. // a user's tier accordingly. This endpoint only works if there is an existing subscription.
  187. func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r *http.Request, v *visitor) error {
  188. if v.user.Billing.StripeSubscriptionID == "" {
  189. return errNoBillingSubscription
  190. }
  191. req, err := readJSONWithLimit[apiAccountBillingSubscriptionChangeRequest](r.Body, jsonBodyBytesLimit)
  192. if err != nil {
  193. return err
  194. }
  195. tier, err := s.userManager.Tier(req.Tier)
  196. if err != nil {
  197. return err
  198. }
  199. log.Info("Stripe: Changing tier and subscription to %s", tier.Code)
  200. sub, err := s.stripe.GetSubscription(v.user.Billing.StripeSubscriptionID)
  201. if err != nil {
  202. return err
  203. }
  204. params := &stripe.SubscriptionParams{
  205. CancelAtPeriodEnd: stripe.Bool(false),
  206. ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorCreateProrations)),
  207. Items: []*stripe.SubscriptionItemsParams{
  208. {
  209. ID: stripe.String(sub.Items.Data[0].ID),
  210. Price: stripe.String(tier.StripePriceID),
  211. },
  212. },
  213. }
  214. _, err = s.stripe.UpdateSubscription(sub.ID, params)
  215. if err != nil {
  216. return err
  217. }
  218. return s.writeJSON(w, newSuccessResponse())
  219. }
  220. // handleAccountBillingSubscriptionDelete facilitates downgrading a paid user to a tier-less user,
  221. // and cancelling the Stripe subscription entirely
  222. func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
  223. if v.user.Billing.StripeSubscriptionID != "" {
  224. params := &stripe.SubscriptionParams{
  225. CancelAtPeriodEnd: stripe.Bool(true),
  226. }
  227. _, err := s.stripe.UpdateSubscription(v.user.Billing.StripeSubscriptionID, params)
  228. if err != nil {
  229. return err
  230. }
  231. }
  232. return s.writeJSON(w, newSuccessResponse())
  233. }
  234. // handleAccountBillingPortalSessionCreate creates a session to the customer billing portal, and returns the
  235. // redirect URL. The billing portal allows customers to change their payment methods, and cancel the subscription.
  236. func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
  237. if v.user.Billing.StripeCustomerID == "" {
  238. return errHTTPBadRequestNotAPaidUser
  239. }
  240. params := &stripe.BillingPortalSessionParams{
  241. Customer: stripe.String(v.user.Billing.StripeCustomerID),
  242. ReturnURL: stripe.String(s.config.BaseURL),
  243. }
  244. ps, err := s.stripe.NewPortalSession(params)
  245. if err != nil {
  246. return err
  247. }
  248. response := &apiAccountBillingPortalRedirectResponse{
  249. RedirectURL: ps.URL,
  250. }
  251. return s.writeJSON(w, response)
  252. }
  253. // handleAccountBillingWebhook handles incoming Stripe webhooks. It mainly keeps the local user database in sync
  254. // with the Stripe view of the world. This endpoint is authorized via the Stripe webhook secret. Note that the
  255. // visitor (v) in this endpoint is the Stripe API, so we don't have v.user available.
  256. func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Request, _ *visitor) error {
  257. stripeSignature := r.Header.Get("Stripe-Signature")
  258. if stripeSignature == "" {
  259. return errHTTPBadRequestBillingRequestInvalid
  260. }
  261. body, err := util.Peek(r.Body, jsonBodyBytesLimit)
  262. if err != nil {
  263. return err
  264. } else if body.LimitReached {
  265. return errHTTPEntityTooLargeJSONBody
  266. }
  267. event, err := s.stripe.ConstructWebhookEvent(body.PeekedBytes, stripeSignature, s.config.StripeWebhookKey)
  268. if err != nil {
  269. return err
  270. } else if event.Data == nil || event.Data.Raw == nil {
  271. return errHTTPBadRequestBillingRequestInvalid
  272. }
  273. log.Info("Stripe: webhook event %s received", event.Type)
  274. switch event.Type {
  275. case "customer.subscription.updated":
  276. return s.handleAccountBillingWebhookSubscriptionUpdated(event.Data.Raw)
  277. case "customer.subscription.deleted":
  278. return s.handleAccountBillingWebhookSubscriptionDeleted(event.Data.Raw)
  279. default:
  280. return nil
  281. }
  282. }
  283. func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMessage) error {
  284. r, err := util.UnmarshalJSON[apiStripeSubscriptionUpdatedEvent](io.NopCloser(bytes.NewReader(event)))
  285. if err != nil {
  286. return err
  287. } else if r.ID == "" || r.Customer == "" || r.Status == "" || r.CurrentPeriodEnd == 0 || r.Items == nil || len(r.Items.Data) != 1 || r.Items.Data[0].Price == nil || r.Items.Data[0].Price.ID == "" {
  288. return errHTTPBadRequestBillingRequestInvalid
  289. }
  290. subscriptionID, priceID := r.ID, r.Items.Data[0].Price.ID
  291. log.Info("Stripe: customer %s: Updating subscription to status %s, with price %s", r.Customer, r.Status, priceID)
  292. u, err := s.userManager.UserByStripeCustomer(r.Customer)
  293. if err != nil {
  294. return err
  295. }
  296. tier, err := s.userManager.TierByStripePrice(priceID)
  297. if err != nil {
  298. return err
  299. }
  300. if err := s.updateSubscriptionAndTier(u, tier, r.Customer, subscriptionID, r.Status, r.CurrentPeriodEnd, r.CancelAt); err != nil {
  301. return err
  302. }
  303. s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
  304. return nil
  305. }
  306. func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMessage) error {
  307. r, err := util.UnmarshalJSON[apiStripeSubscriptionDeletedEvent](io.NopCloser(bytes.NewReader(event)))
  308. if err != nil {
  309. return err
  310. } else if r.Customer == "" {
  311. return errHTTPBadRequestBillingRequestInvalid
  312. }
  313. log.Info("Stripe: customer %s: subscription deleted, downgrading to unpaid tier", r.Customer)
  314. u, err := s.userManager.UserByStripeCustomer(r.Customer)
  315. if err != nil {
  316. return err
  317. }
  318. if err := s.updateSubscriptionAndTier(u, nil, r.Customer, "", "", 0, 0); err != nil {
  319. return err
  320. }
  321. s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
  322. return nil
  323. }
  324. func (s *Server) updateSubscriptionAndTier(u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error {
  325. // Remove excess reservations (if too many for tier), and mark associated messages deleted
  326. reservations, err := s.userManager.Reservations(u.Name)
  327. if err != nil {
  328. return err
  329. }
  330. reservationsLimit := visitorDefaultReservationsLimit
  331. if tier != nil {
  332. reservationsLimit = tier.ReservationsLimit
  333. }
  334. if int64(len(reservations)) > reservationsLimit {
  335. topics := make([]string, 0)
  336. for i := int64(len(reservations)) - 1; i >= reservationsLimit; i-- {
  337. topics = append(topics, reservations[i].Topic)
  338. }
  339. if err := s.userManager.RemoveReservations(u.Name, topics...); err != nil {
  340. return err
  341. }
  342. if err := s.messageCache.ExpireMessages(topics...); err != nil {
  343. return err
  344. }
  345. }
  346. // Change or remove tier
  347. if tier == nil {
  348. if err := s.userManager.ResetTier(u.Name); err != nil {
  349. return err
  350. }
  351. } else {
  352. if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil {
  353. return err
  354. }
  355. }
  356. // Update billing fields
  357. billing := &user.Billing{
  358. StripeCustomerID: customerID,
  359. StripeSubscriptionID: subscriptionID,
  360. StripeSubscriptionStatus: stripe.SubscriptionStatus(status),
  361. StripeSubscriptionPaidUntil: time.Unix(paidUntil, 0),
  362. StripeSubscriptionCancelAt: time.Unix(cancelAt, 0),
  363. }
  364. if err := s.userManager.ChangeBilling(u.Name, billing); err != nil {
  365. return err
  366. }
  367. return nil
  368. }
  369. // fetchStripePrices contacts the Stripe API to retrieve all prices. This is used by the server to cache the prices
  370. // in memory, and ultimately for the web app to display the price table.
  371. func (s *Server) fetchStripePrices() (map[string]string, error) {
  372. log.Debug("Caching prices from Stripe API")
  373. priceMap := make(map[string]string)
  374. prices, err := s.stripe.ListPrices(&stripe.PriceListParams{Active: stripe.Bool(true)})
  375. if err != nil {
  376. log.Warn("Fetching Stripe prices failed: %s", err.Error())
  377. return nil, err
  378. }
  379. for _, p := range prices {
  380. if p.UnitAmount%100 == 0 {
  381. priceMap[p.ID] = fmt.Sprintf("$%d", p.UnitAmount/100)
  382. } else {
  383. priceMap[p.ID] = fmt.Sprintf("$%.2f", float64(p.UnitAmount)/100)
  384. }
  385. log.Trace("- Caching price %s = %v", p.ID, priceMap[p.ID])
  386. }
  387. return priceMap, nil
  388. }
  389. // stripeAPI is a small interface to facilitate mocking of the Stripe API
  390. type stripeAPI interface {
  391. NewCheckoutSession(params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error)
  392. NewPortalSession(params *stripe.BillingPortalSessionParams) (*stripe.BillingPortalSession, error)
  393. ListPrices(params *stripe.PriceListParams) ([]*stripe.Price, error)
  394. GetCustomer(id string) (*stripe.Customer, error)
  395. GetSession(id string) (*stripe.CheckoutSession, error)
  396. GetSubscription(id string) (*stripe.Subscription, error)
  397. UpdateSubscription(id string, params *stripe.SubscriptionParams) (*stripe.Subscription, error)
  398. CancelSubscription(id string) (*stripe.Subscription, error)
  399. ConstructWebhookEvent(payload []byte, header string, secret string) (stripe.Event, error)
  400. }
  401. // realStripeAPI is a thin shim around the Stripe functions to facilitate mocking
  402. type realStripeAPI struct{}
  403. var _ stripeAPI = (*realStripeAPI)(nil)
  404. func newStripeAPI() stripeAPI {
  405. return &realStripeAPI{}
  406. }
  407. func (s *realStripeAPI) NewCheckoutSession(params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) {
  408. return session.New(params)
  409. }
  410. func (s *realStripeAPI) NewPortalSession(params *stripe.BillingPortalSessionParams) (*stripe.BillingPortalSession, error) {
  411. return portalsession.New(params)
  412. }
  413. func (s *realStripeAPI) ListPrices(params *stripe.PriceListParams) ([]*stripe.Price, error) {
  414. prices := make([]*stripe.Price, 0)
  415. iter := price.List(params)
  416. for iter.Next() {
  417. prices = append(prices, iter.Price())
  418. }
  419. if iter.Err() != nil {
  420. return nil, iter.Err()
  421. }
  422. return prices, nil
  423. }
  424. func (s *realStripeAPI) GetCustomer(id string) (*stripe.Customer, error) {
  425. return customer.Get(id, nil)
  426. }
  427. func (s *realStripeAPI) GetSession(id string) (*stripe.CheckoutSession, error) {
  428. return session.Get(id, nil)
  429. }
  430. func (s *realStripeAPI) GetSubscription(id string) (*stripe.Subscription, error) {
  431. return subscription.Get(id, nil)
  432. }
  433. func (s *realStripeAPI) UpdateSubscription(id string, params *stripe.SubscriptionParams) (*stripe.Subscription, error) {
  434. return subscription.Update(id, params)
  435. }
  436. func (s *realStripeAPI) CancelSubscription(id string) (*stripe.Subscription, error) {
  437. return subscription.Cancel(id, nil)
  438. }
  439. func (s *realStripeAPI) ConstructWebhookEvent(payload []byte, header string, secret string) (stripe.Event, error) {
  440. return webhook.ConstructEvent(payload, header, secret)
  441. }