util.go 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. package server
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "heckel.io/ntfy/v2/util"
  7. "io"
  8. "mime"
  9. "net/http"
  10. "net/netip"
  11. "regexp"
  12. "strings"
  13. )
  14. var (
  15. mimeDecoder mime.WordDecoder
  16. priorityHeaderIgnoreRegex = regexp.MustCompile(`^u=\d,\s*(i|\d)$|^u=\d$`)
  17. )
  18. func readBoolParam(r *http.Request, defaultValue bool, names ...string) bool {
  19. value := strings.ToLower(readParam(r, names...))
  20. if value == "" {
  21. return defaultValue
  22. }
  23. return toBool(value)
  24. }
  25. func isBoolValue(value string) bool {
  26. return value == "1" || value == "yes" || value == "true" || value == "0" || value == "no" || value == "false"
  27. }
  28. func toBool(value string) bool {
  29. return value == "1" || value == "yes" || value == "true"
  30. }
  31. func readCommaSeparatedParam(r *http.Request, names ...string) (params []string) {
  32. paramStr := readParam(r, names...)
  33. if paramStr != "" {
  34. params = make([]string, 0)
  35. for _, s := range util.SplitNoEmpty(paramStr, ",") {
  36. params = append(params, strings.TrimSpace(s))
  37. }
  38. }
  39. return params
  40. }
  41. func readParam(r *http.Request, names ...string) string {
  42. value := readHeaderParam(r, names...)
  43. if value != "" {
  44. return value
  45. }
  46. return readQueryParam(r, names...)
  47. }
  48. func readHeaderParam(r *http.Request, names ...string) string {
  49. for _, name := range names {
  50. value := strings.TrimSpace(maybeDecodeHeader(name, r.Header.Get(name)))
  51. if value != "" {
  52. return value
  53. }
  54. }
  55. return ""
  56. }
  57. func readQueryParam(r *http.Request, names ...string) string {
  58. for _, name := range names {
  59. value := r.URL.Query().Get(strings.ToLower(name))
  60. if value != "" {
  61. return strings.TrimSpace(value)
  62. }
  63. }
  64. return ""
  65. }
  66. func extractIPAddress(r *http.Request, behindProxy bool) netip.Addr {
  67. remoteAddr := r.RemoteAddr
  68. addrPort, err := netip.ParseAddrPort(remoteAddr)
  69. ip := addrPort.Addr()
  70. if err != nil {
  71. // This should not happen in real life; only in tests. So, using falling back to 0.0.0.0 if address unspecified
  72. ip, err = netip.ParseAddr(remoteAddr)
  73. if err != nil {
  74. ip = netip.IPv4Unspecified()
  75. if remoteAddr != "@" || !behindProxy { // RemoteAddr is @ when unix socket is used
  76. logr(r).Err(err).Warn("unable to parse IP (%s), new visitor with unspecified IP (0.0.0.0) created", remoteAddr)
  77. }
  78. }
  79. }
  80. if behindProxy && strings.TrimSpace(r.Header.Get("X-Forwarded-For")) != "" {
  81. // X-Forwarded-For can contain multiple addresses (see #328). If we are behind a proxy,
  82. // only the right-most address can be trusted (as this is the one added by our proxy server).
  83. // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For for details.
  84. ips := util.SplitNoEmpty(r.Header.Get("X-Forwarded-For"), ",")
  85. realIP, err := netip.ParseAddr(strings.TrimSpace(util.LastString(ips, remoteAddr)))
  86. if err != nil {
  87. logr(r).Err(err).Error("invalid IP address %s received in X-Forwarded-For header", ip)
  88. // Fall back to regular remote address if X-Forwarded-For is damaged
  89. } else {
  90. ip = realIP
  91. }
  92. }
  93. return ip
  94. }
  95. func readJSONWithLimit[T any](r io.ReadCloser, limit int, allowEmpty bool) (*T, error) {
  96. obj, err := util.UnmarshalJSONWithLimit[T](r, limit, allowEmpty)
  97. if errors.Is(err, util.ErrUnmarshalJSON) {
  98. return nil, errHTTPBadRequestJSONInvalid
  99. } else if errors.Is(err, util.ErrTooLargeJSON) {
  100. return nil, errHTTPEntityTooLargeJSONBody
  101. } else if err != nil {
  102. return nil, err
  103. }
  104. return obj, nil
  105. }
  106. func withContext(r *http.Request, ctx map[contextKey]any) *http.Request {
  107. c := r.Context()
  108. for k, v := range ctx {
  109. c = context.WithValue(c, k, v)
  110. }
  111. return r.WithContext(c)
  112. }
  113. func fromContext[T any](r *http.Request, key contextKey) (T, error) {
  114. t, ok := r.Context().Value(key).(T)
  115. if !ok {
  116. return t, fmt.Errorf("cannot find key %v in request context", key)
  117. }
  118. return t, nil
  119. }
  120. // maybeDecodeHeader decodes the given header value if it is MIME encoded, e.g. "=?utf-8?q?Hello_World?=",
  121. // or returns the original header value if it is not MIME encoded. It also calls maybeIgnoreSpecialHeader
  122. // to ignore new HTTP "Priority" header.
  123. func maybeDecodeHeader(name, value string) string {
  124. decoded, err := mimeDecoder.DecodeHeader(value)
  125. if err != nil {
  126. return maybeIgnoreSpecialHeader(name, value)
  127. }
  128. return maybeIgnoreSpecialHeader(name, decoded)
  129. }
  130. // maybeIgnoreSpecialHeader ignores new HTTP "Priority" header (see https://datatracker.ietf.org/doc/html/draft-ietf-httpbis-priority)
  131. //
  132. // Cloudflare (and potentially other providers) add this to requests when forwarding to the backend (ntfy),
  133. // so we just ignore it. If the "Priority" header is set to "u=*, i" or "u=*" (by Cloudflare), the header will be ignored.
  134. // Returning an empty string will allow the rest of the logic to continue searching for another header (x-priority, prio, p),
  135. // or in the Query parameters.
  136. func maybeIgnoreSpecialHeader(name, value string) string {
  137. if strings.ToLower(name) == "priority" && priorityHeaderIgnoreRegex.MatchString(strings.TrimSpace(value)) {
  138. return ""
  139. }
  140. return value
  141. }