util.go 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. package server
  2. import (
  3. "context"
  4. "fmt"
  5. "heckel.io/ntfy/util"
  6. "io"
  7. "mime"
  8. "net/http"
  9. "net/netip"
  10. "regexp"
  11. "strings"
  12. )
  // mimeDecoder is the shared decoder used to decode MIME encoded-word header
  // values (e.g. "=?utf-8?q?Hello_World?="). It is used with its zero value.
  var mimeDecoder mime.WordDecoder

  // priorityHeaderIgnoreRegex matches values of the HTTP "Priority" header of the
  // form "u=<digit>" or "u=<digit>, i" / "u=<digit>, <digit>". Values matching
  // this pattern are ignored when reading the "Priority" header.
  var priorityHeaderIgnoreRegex = regexp.MustCompile(`^u=\d,\s*(i|\d)$|^u=\d$`)
  17. func readBoolParam(r *http.Request, defaultValue bool, names ...string) bool {
  18. value := strings.ToLower(readParam(r, names...))
  19. if value == "" {
  20. return defaultValue
  21. }
  22. return toBool(value)
  23. }
  // isBoolValue reports whether value is one of the recognized boolean
  // spellings: "1", "yes", "true", "0", "no" or "false". The comparison is
  // exact (case-sensitive, no trimming).
  func isBoolValue(value string) bool {
  	switch value {
  	case "1", "yes", "true", "0", "no", "false":
  		return true
  	default:
  		return false
  	}
  }
  // toBool returns true if value is exactly "1", "yes" or "true"; every other
  // value (including the empty string) yields false.
  func toBool(value string) bool {
  	switch value {
  	case "1", "yes", "true":
  		return true
  	default:
  		return false
  	}
  }
  30. func readCommaSeparatedParam(r *http.Request, names ...string) (params []string) {
  31. paramStr := readParam(r, names...)
  32. if paramStr != "" {
  33. params = make([]string, 0)
  34. for _, s := range util.SplitNoEmpty(paramStr, ",") {
  35. params = append(params, strings.TrimSpace(s))
  36. }
  37. }
  38. return params
  39. }
  40. func readParam(r *http.Request, names ...string) string {
  41. value := readHeaderParam(r, names...)
  42. if value != "" {
  43. return value
  44. }
  45. return readQueryParam(r, names...)
  46. }
  47. func readHeaderParam(r *http.Request, names ...string) string {
  48. for _, name := range names {
  49. value := strings.TrimSpace(maybeDecodeHeader(name, r.Header.Get(name)))
  50. if value != "" {
  51. return value
  52. }
  53. }
  54. return ""
  55. }
  // readQueryParam returns the first non-empty URL query parameter value among
  // the given names, in the order they are passed. Names are lower-cased before
  // the lookup, and values are whitespace-trimmed before being checked, so a
  // whitespace-only value is treated as "not set" and the next name is tried
  // (this mirrors how readHeaderParam treats header values). An empty string is
  // returned if none of the names yields a value.
  func readQueryParam(r *http.Request, names ...string) string {
  	if len(names) == 0 {
  		return ""
  	}
  	// Parse the query string once; URL.Query() re-parses the raw query on
  	// every call.
  	query := r.URL.Query()
  	for _, name := range names {
  		if value := strings.TrimSpace(query.Get(strings.ToLower(name))); value != "" {
  			return value
  		}
  	}
  	return ""
  }
  65. func extractIPAddress(r *http.Request, behindProxy bool) netip.Addr {
  66. remoteAddr := r.RemoteAddr
  67. addrPort, err := netip.ParseAddrPort(remoteAddr)
  68. ip := addrPort.Addr()
  69. if err != nil {
  70. // This should not happen in real life; only in tests. So, using falling back to 0.0.0.0 if address unspecified
  71. ip, err = netip.ParseAddr(remoteAddr)
  72. if err != nil {
  73. ip = netip.IPv4Unspecified()
  74. if remoteAddr != "@" || !behindProxy { // RemoteAddr is @ when unix socket is used
  75. logr(r).Err(err).Warn("unable to parse IP (%s), new visitor with unspecified IP (0.0.0.0) created", remoteAddr)
  76. }
  77. }
  78. }
  79. if behindProxy && strings.TrimSpace(r.Header.Get("X-Forwarded-For")) != "" {
  80. // X-Forwarded-For can contain multiple addresses (see #328). If we are behind a proxy,
  81. // only the right-most address can be trusted (as this is the one added by our proxy server).
  82. // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For for details.
  83. ips := util.SplitNoEmpty(r.Header.Get("X-Forwarded-For"), ",")
  84. realIP, err := netip.ParseAddr(strings.TrimSpace(util.LastString(ips, remoteAddr)))
  85. if err != nil {
  86. logr(r).Err(err).Error("invalid IP address %s received in X-Forwarded-For header", ip)
  87. // Fall back to regular remote address if X-Forwarded-For is damaged
  88. } else {
  89. ip = realIP
  90. }
  91. }
  92. return ip
  93. }
  94. func readJSONWithLimit[T any](r io.ReadCloser, limit int, allowEmpty bool) (*T, error) {
  95. obj, err := util.UnmarshalJSONWithLimit[T](r, limit, allowEmpty)
  96. if err == util.ErrUnmarshalJSON {
  97. return nil, errHTTPBadRequestJSONInvalid
  98. } else if err == util.ErrTooLargeJSON {
  99. return nil, errHTTPEntityTooLargeJSONBody
  100. } else if err != nil {
  101. return nil, err
  102. }
  103. return obj, nil
  104. }
  105. func withContext(r *http.Request, ctx map[contextKey]any) *http.Request {
  106. c := r.Context()
  107. for k, v := range ctx {
  108. c = context.WithValue(c, k, v)
  109. }
  110. return r.WithContext(c)
  111. }
  112. func fromContext[T any](r *http.Request, key contextKey) (T, error) {
  113. t, ok := r.Context().Value(key).(T)
  114. if !ok {
  115. return t, fmt.Errorf("cannot find key %v in request context", key)
  116. }
  117. return t, nil
  118. }
  119. // maybeDecodeHeader decodes the given header value if it is MIME encoded, e.g. "=?utf-8?q?Hello_World?=",
  120. // or returns the original header value if it is not MIME encoded. It also calls maybeIgnoreSpecialHeader
  121. // to ignore new HTTP "Priority" header.
  122. func maybeDecodeHeader(name, value string) string {
  123. decoded, err := mimeDecoder.DecodeHeader(value)
  124. if err != nil {
  125. return maybeIgnoreSpecialHeader(name, value)
  126. }
  127. return maybeIgnoreSpecialHeader(name, decoded)
  128. }
  129. // maybeIgnoreSpecialHeader ignores new HTTP "Priority" header (see https://datatracker.ietf.org/doc/html/draft-ietf-httpbis-priority)
  130. //
  131. // Cloudflare (and potentially other providers) add this to requests when forwarding to the backend (ntfy),
  132. // so we just ignore it. If the "Priority" header is set to "u=*, i" or "u=*" (by Cloudflare), the header will be ignored.
  133. // Returning an empty string will allow the rest of the logic to continue searching for another header (x-priority, prio, p),
  134. // or in the Query parameters.
  135. func maybeIgnoreSpecialHeader(name, value string) string {
  136. if strings.ToLower(name) == "priority" && priorityHeaderIgnoreRegex.MatchString(strings.TrimSpace(value)) {
  137. return ""
  138. }
  139. return value
  140. }