server_middleware.go 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. package server
  2. import (
  3. "net/http"
  4. "heckel.io/ntfy/util"
  5. )
  // contextKey is the type of the keys this package stores in a request's
  // context. Because the type is unexported, its keys can never compare equal
  // to context keys declared in any other package.
  type contextKey int

  // Context keys used by this package. They are numbered consecutively
  // starting at 2586: contextRateVisitor = 2586, contextTopic = 2587,
  // contextMatrixPushKey = 2588.
  const (
  	// contextRateVisitor holds the visitor whose budget is charged for rate limiting.
  	contextRateVisitor contextKey = 2586 + iota
  	// contextTopic holds the topic resolved from the request path.
  	contextTopic
  	// contextMatrixPushKey is not read or written in the visible code; TODO confirm its use.
  	contextMatrixPushKey
  )
  // limitRequests wraps next with per-visitor request rate limiting.
  //
  // Requests from an IP listed in s.config.VisitorRequestExemptIPAddrs always
  // reach next. For every other request, v.RequestAllowed() must report true;
  // otherwise errHTTPTooManyRequestsLimitRequests is returned and next is not
  // called.
  func (s *Server) limitRequests(next handleFunc) handleFunc {
  	return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  		// Check the exemption first, so RequestAllowed is consulted only for
  		// non-exempt visitors (presumably because it charges the visitor's
  		// request budget; TODO confirm).
  		exempt := util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip)
  		if !exempt && !v.RequestAllowed() {
  			return errHTTPTooManyRequestsLimitRequests
  		}
  		return next(w, r, v)
  	}
  }
  22. // limitRequestsWithTopic limits requests with a topic and stores the rate-limiting-subscriber and topic into request.Context
  23. func (s *Server) limitRequestsWithTopic(next handleFunc) handleFunc {
  24. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  25. t, err := s.topicFromPath(r.URL.Path)
  26. if err != nil {
  27. return err
  28. }
  29. vrate := v
  30. if rateVisitor := t.RateVisitor(); rateVisitor != nil {
  31. vrate = rateVisitor
  32. }
  33. r = withContext(r, map[contextKey]any{
  34. contextRateVisitor: vrate,
  35. contextTopic: t,
  36. })
  37. if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
  38. return next(w, r, v)
  39. } else if !vrate.RequestAllowed() {
  40. return errHTTPTooManyRequestsLimitRequests
  41. }
  42. return next(w, r, v)
  43. }
  44. }
  // ensureWebEnabled lets requests through to next only when
  // s.config.EnableWeb is true. Otherwise every request fails with
  // errHTTPNotFound.
  func (s *Server) ensureWebEnabled(next handleFunc) handleFunc {
  	return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  		if s.config.EnableWeb {
  			return next(w, r, v)
  		}
  		return errHTTPNotFound
  	}
  }
  // ensureUserManager lets requests through to next only when the server has
  // a user manager (s.userManager != nil). Otherwise every request fails with
  // errHTTPNotFound.
  func (s *Server) ensureUserManager(next handleFunc) handleFunc {
  	return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  		if s.userManager != nil {
  			return next(w, r, v)
  		}
  		return errHTTPNotFound
  	}
  }
  // ensureUser requires a user manager (via ensureUserManager, which fails
  // with errHTTPNotFound) and a visitor associated with a user (v.User() !=
  // nil). Requests without a user fail with errHTTPUnauthorized.
  func (s *Server) ensureUser(next handleFunc) handleFunc {
  	requireUser := func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  		if v.User() != nil {
  			return next(w, r, v)
  		}
  		return errHTTPUnauthorized
  	}
  	return s.ensureUserManager(requireUser)
  }
  // ensurePaymentsEnabled lets requests through to next only when Stripe is
  // configured: both s.config.StripeSecretKey is non-empty and s.stripe is
  // non-nil. Otherwise every request fails with errHTTPNotFound.
  func (s *Server) ensurePaymentsEnabled(next handleFunc) handleFunc {
  	return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  		stripeConfigured := s.config.StripeSecretKey != "" && s.stripe != nil
  		if !stripeConfigured {
  			return errHTTPNotFound
  		}
  		return next(w, r, v)
  	}
  }
  // ensureStripeCustomer first applies ensureUser, so v.User() is non-nil by
  // the time the inner check runs. It then requires the user to have a
  // non-empty Billing.StripeCustomerID; users without one are rejected with
  // errHTTPBadRequestNotAPaidUser.
  //
  // NOTE(review): this dereferences v.User().Billing and so assumes it is
  // always populated for a loaded user; TODO confirm.
  func (s *Server) ensureStripeCustomer(next handleFunc) handleFunc {
  	requireCustomer := func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  		if customerID := v.User().Billing.StripeCustomerID; customerID == "" {
  			return errHTTPBadRequestNotAPaidUser
  		}
  		return next(w, r, v)
  	}
  	return s.ensureUser(requireCustomer)
  }
  // withAccountSync runs next and, only if it succeeds (returns a nil error),
  // calls s.publishSyncEventAsync(v) afterwards. Any error from next is
  // returned unchanged, and no sync event is published in that case.
  func (s *Server) withAccountSync(next handleFunc) handleFunc {
  	return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  		if err := next(w, r, v); err != nil {
  			return err
  		}
  		s.publishSyncEventAsync(v)
  		return nil
  	}
  }
  93. }