123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132 |
- package server
- import (
- "net/http"
- "heckel.io/ntfy/v2/util"
- )
// contextKey is the type of every key this package stores in a request
// context. Because it is a dedicated unexported type rather than a string or
// other built-in type, these keys cannot collide with context keys set by
// other packages.
type contextKey int

// Request-context keys. They are consecutive values starting at 2586.
// NOTE(review): the reason for the 2586 offset is not evident from this file.
const (
	// contextRateVisitor holds the *visitor whose request limit was checked
	// (see limitRequestsWithTopic).
	contextRateVisitor contextKey = 2586 + iota
	// contextTopic holds the topic resolved from the request path.
	contextTopic
	// contextMatrixPushKey is not set anywhere in this part of the file.
	contextMatrixPushKey
)
- func (s *Server) limitRequests(next handleFunc) handleFunc {
- return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
- if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
- return next(w, r, v)
- } else if !v.RequestAllowed() {
- return errHTTPTooManyRequestsLimitRequests
- }
- return next(w, r, v)
- }
- }
- // limitRequestsWithTopic limits requests with a topic and stores the rate-limiting-subscriber and topic into request.Context
- func (s *Server) limitRequestsWithTopic(next handleFunc) handleFunc {
- return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
- t, err := s.topicFromPath(r.URL.Path)
- if err != nil {
- return err
- }
- vrate := v
- if rateVisitor := t.RateVisitor(); rateVisitor != nil {
- vrate = rateVisitor
- }
- r = withContext(r, map[contextKey]any{
- contextRateVisitor: vrate,
- contextTopic: t,
- })
- if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
- return next(w, r, v)
- } else if !vrate.RequestAllowed() {
- return errHTTPTooManyRequestsLimitRequests
- }
- return next(w, r, v)
- }
- }
- func (s *Server) ensureWebEnabled(next handleFunc) handleFunc {
- return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
- if s.config.WebRoot == "" {
- return errHTTPNotFound
- }
- return next(w, r, v)
- }
- }
- func (s *Server) ensureWebPushEnabled(next handleFunc) handleFunc {
- return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
- if s.config.WebRoot == "" || s.config.WebPushPublicKey == "" {
- return errHTTPNotFound
- }
- return next(w, r, v)
- }
- }
- func (s *Server) ensureUserManager(next handleFunc) handleFunc {
- return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
- if s.userManager == nil {
- return errHTTPNotFound
- }
- return next(w, r, v)
- }
- }
- func (s *Server) ensureUser(next handleFunc) handleFunc {
- return s.ensureUserManager(func(w http.ResponseWriter, r *http.Request, v *visitor) error {
- if v.User() == nil {
- return errHTTPUnauthorized
- }
- return next(w, r, v)
- })
- }
- func (s *Server) ensureAdmin(next handleFunc) handleFunc {
- return s.ensureUserManager(func(w http.ResponseWriter, r *http.Request, v *visitor) error {
- if !v.User().IsAdmin() {
- return errHTTPUnauthorized
- }
- return next(w, r, v)
- })
- }
- func (s *Server) ensureCallsEnabled(next handleFunc) handleFunc {
- return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
- if s.config.TwilioAccount == "" || s.userManager == nil {
- return errHTTPNotFound
- }
- return next(w, r, v)
- }
- }
- func (s *Server) ensurePaymentsEnabled(next handleFunc) handleFunc {
- return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
- if s.config.StripeSecretKey == "" || s.stripe == nil {
- return errHTTPNotFound
- }
- return next(w, r, v)
- }
- }
- func (s *Server) ensureStripeCustomer(next handleFunc) handleFunc {
- return s.ensureUser(func(w http.ResponseWriter, r *http.Request, v *visitor) error {
- if v.User().Billing.StripeCustomerID == "" {
- return errHTTPBadRequestNotAPaidUser
- }
- return next(w, r, v)
- })
- }
- func (s *Server) withAccountSync(next handleFunc) handleFunc {
- return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
- err := next(w, r, v)
- if err == nil {
- s.publishSyncEventAsync(v)
- }
- return err
- }
- }
|