server_middleware.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. package server
  2. import (
  3. "net/http"
  4. "heckel.io/ntfy/v2/util"
  5. )
  6. type contextKey int
  7. const (
  8. contextRateVisitor contextKey = iota + 2586
  9. contextTopic
  10. contextMatrixPushKey
  11. )
  12. func (s *Server) limitRequests(next handleFunc) handleFunc {
  13. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  14. if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
  15. return next(w, r, v)
  16. } else if !v.RequestAllowed() {
  17. return errHTTPTooManyRequestsLimitRequests
  18. }
  19. return next(w, r, v)
  20. }
  21. }
  22. // limitRequestsWithTopic limits requests with a topic and stores the rate-limiting-subscriber and topic into request.Context
  23. func (s *Server) limitRequestsWithTopic(next handleFunc) handleFunc {
  24. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  25. t, err := s.topicFromPath(r.URL.Path)
  26. if err != nil {
  27. return err
  28. }
  29. vrate := v
  30. if rateVisitor := t.RateVisitor(); rateVisitor != nil {
  31. vrate = rateVisitor
  32. }
  33. r = withContext(r, map[contextKey]any{
  34. contextRateVisitor: vrate,
  35. contextTopic: t,
  36. })
  37. if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
  38. return next(w, r, v)
  39. } else if !vrate.RequestAllowed() {
  40. return errHTTPTooManyRequestsLimitRequests
  41. }
  42. return next(w, r, v)
  43. }
  44. }
  45. func (s *Server) ensureWebEnabled(next handleFunc) handleFunc {
  46. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  47. if s.config.WebRoot == "" {
  48. return errHTTPNotFound
  49. }
  50. return next(w, r, v)
  51. }
  52. }
  53. func (s *Server) ensureWebPushEnabled(next handleFunc) handleFunc {
  54. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  55. if s.config.WebRoot == "" || s.config.WebPushPublicKey == "" {
  56. return errHTTPNotFound
  57. }
  58. return next(w, r, v)
  59. }
  60. }
  61. func (s *Server) ensureUserManager(next handleFunc) handleFunc {
  62. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  63. if s.userManager == nil {
  64. return errHTTPNotFound
  65. }
  66. return next(w, r, v)
  67. }
  68. }
  69. func (s *Server) ensureUser(next handleFunc) handleFunc {
  70. return s.ensureUserManager(func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  71. if v.User() == nil {
  72. return errHTTPUnauthorized
  73. }
  74. return next(w, r, v)
  75. })
  76. }
  77. func (s *Server) ensureAdmin(next handleFunc) handleFunc {
  78. return s.ensureUserManager(func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  79. if !v.User().IsAdmin() {
  80. return errHTTPUnauthorized
  81. }
  82. return next(w, r, v)
  83. })
  84. }
  85. func (s *Server) ensureCallsEnabled(next handleFunc) handleFunc {
  86. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  87. if s.config.TwilioAccount == "" || s.userManager == nil {
  88. return errHTTPNotFound
  89. }
  90. return next(w, r, v)
  91. }
  92. }
  93. func (s *Server) ensurePaymentsEnabled(next handleFunc) handleFunc {
  94. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  95. if s.config.StripeSecretKey == "" || s.stripe == nil {
  96. return errHTTPNotFound
  97. }
  98. return next(w, r, v)
  99. }
  100. }
  101. func (s *Server) ensureStripeCustomer(next handleFunc) handleFunc {
  102. return s.ensureUser(func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  103. if v.User().Billing.StripeCustomerID == "" {
  104. return errHTTPBadRequestNotAPaidUser
  105. }
  106. return next(w, r, v)
  107. })
  108. }
  109. func (s *Server) withAccountSync(next handleFunc) handleFunc {
  110. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  111. err := next(w, r, v)
  112. if err == nil {
  113. s.publishSyncEventAsync(v)
  114. }
  115. return err
  116. }
  117. }