server.go 75 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033
  1. package server
  2. import (
  3. "bytes"
  4. "context"
  5. "crypto/sha256"
  6. "embed"
  7. "encoding/base64"
  8. "encoding/json"
  9. "errors"
  10. "fmt"
  11. "io"
  12. "net"
  13. "net/http"
  14. "net/http/pprof"
  15. "net/netip"
  16. "net/url"
  17. "os"
  18. "path"
  19. "path/filepath"
  20. "regexp"
  21. "sort"
  22. "strconv"
  23. "strings"
  24. "sync"
  25. "time"
  26. "unicode/utf8"
  27. "github.com/emersion/go-smtp"
  28. "github.com/gorilla/websocket"
  29. "github.com/prometheus/client_golang/prometheus/promhttp"
  30. "golang.org/x/sync/errgroup"
  31. "heckel.io/ntfy/log"
  32. "heckel.io/ntfy/user"
  33. "heckel.io/ntfy/util"
  34. )
// Server is the main server, providing the UI and API for ntfy
//
// A Server is created with New, started with Run and shut down with Stop. The
// listener fields are only populated by Run, and only for the listeners that
// are enabled in the configuration.
type Server struct {
	config            *Config                             // Server configuration, as passed to New
	httpServer        *http.Server                        // Set in Run if ListenHTTP is configured
	httpsServer       *http.Server                        // Set in Run if ListenHTTPS is configured
	httpMetricsServer *http.Server                        // Set in Run if MetricsListenHTTP is configured
	httpProfileServer *http.Server                        // Set in Run if ProfileListenHTTP is configured (serves pprof)
	unixListener      net.Listener                        // Created in Run if ListenUnix is configured
	smtpServer        *smtp.Server                        // Closed in Stop if set; NOTE(review): assigned outside this part of the file
	smtpServerBackend *smtpBackend                        // NOTE(review): assigned outside this part of the file
	smtpSender        mailer                              // Set in New if SMTPSenderAddr is configured, nil otherwise
	topics            map[string]*topic                   // Initialized in New from messageCache.Topics()
	visitors          map[string]*visitor                 // ip:<ip> or user:<user>
	firebaseClient    *firebaseClient                     // Set in New if FirebaseKeyFile is configured, nil otherwise
	messages          int64                               // Total number of messages (persisted if messageCache enabled)
	messagesHistory   []int64                             // Last n values of the messages counter, used to determine rate
	userManager       *user.Manager                       // Might be nil!
	messageCache      *messageCache                       // Database that stores the messages
	webPush           *webPushStore                       // Database that stores web push subscriptions
	fileCache         *fileCache                          // File system based cache that stores attachments
	stripe            stripeAPI                           // Stripe API, can be replaced with a mock
	priceCache        *util.LookupCache[map[string]int64] // Stripe price ID -> price as cents (USD implied!)
	metricsHandler    http.Handler                        // Handles /metrics if enable-metrics set, and listen-metrics-http not set
	closeChan         chan bool                           // Created in Run, closed in Stop
	mu                sync.RWMutex                        // Held by Run while setting up, and by Stop while tearing down, the listeners
}
// handleFunc extends the normal http.HandlerFunc to be able to easily return errors.
// In addition to the response writer and the request, it receives the visitor that
// issued the request. A non-nil return value is turned into an HTTP error response
// by the caller (see handle and handleError).
type handleFunc func(http.ResponseWriter, *http.Request, *visitor) error
// Routing patterns, well-known URL paths and embedded web assets. The patterns
// and paths are matched against the request path in handleInternal.
var (
	// If changed, don't forget to update Android App and auth_sqlite.go
	topicRegex             = regexp.MustCompile(`^[-_A-Za-z0-9]{1,64}$`)                               // No /!
	topicPathRegex         = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}$`)                              // Regex must match JS & Android app!
	externalTopicPathRegex = regexp.MustCompile(`^/[^/]+\.[^/]+/[-_A-Za-z0-9]{1,64}$`)                 // Extended topic path, for web-app, e.g. /example.com/mytopic
	jsonPathRegex          = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/json$`) // One or more comma-separated topics, followed by /json
	ssePathRegex           = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/sse$`)  // One or more comma-separated topics, followed by /sse
	rawPathRegex           = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/raw$`)  // One or more comma-separated topics, followed by /raw
	wsPathRegex            = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/ws$`)   // One or more comma-separated topics, followed by /ws
	authPathRegex          = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/auth$`) // One or more comma-separated topics, followed by /auth
	publishPathRegex       = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}/(publish|send|trigger)$`)      // Single topic, followed by /publish, /send or /trigger

	// Fixed paths of the web app and other non-API endpoints
	webConfigPath        = "/config.js"
	webManifestPath      = "/manifest.webmanifest"
	webRootHTMLPath      = "/app.html"
	webServiceWorkerPath = "/sw.js"
	accountPath          = "/account"
	matrixPushPath       = "/_matrix/push/v1/notify"
	metricsPath          = "/metrics"

	// Fixed paths of the /v1 API
	apiHealthPath                     = "/v1/health"
	apiStatsPath                      = "/v1/stats"
	apiWebPushPath                    = "/v1/webpush"
	apiTiersPath                      = "/v1/tiers"
	apiUsersPath                      = "/v1/users"
	apiUsersAccessPath                = "/v1/users/access"
	apiAccountPath                    = "/v1/account"
	apiAccountTokenPath               = "/v1/account/token"
	apiAccountPasswordPath            = "/v1/account/password"
	apiAccountSettingsPath            = "/v1/account/settings"
	apiAccountSubscriptionPath        = "/v1/account/subscription"
	apiAccountReservationPath         = "/v1/account/reservation"
	apiAccountPhonePath               = "/v1/account/phone"
	apiAccountPhoneVerifyPath         = "/v1/account/phone/verify"
	apiAccountBillingPortalPath       = "/v1/account/billing/portal"
	apiAccountBillingWebhookPath      = "/v1/account/billing/webhook"
	apiAccountBillingSubscriptionPath = "/v1/account/billing/subscription"

	// API paths that carry a parameter in their last path segment. Note that these two
	// regexes are not anchored at the start (no leading ^), unlike the ones above.
	apiAccountBillingSubscriptionCheckoutSuccessTemplate = "/v1/account/billing/subscription/success/{CHECKOUT_SESSION_ID}"
	apiAccountBillingSubscriptionCheckoutSuccessRegex    = regexp.MustCompile(`/v1/account/billing/subscription/success/(.+)$`)
	apiAccountReservationSingleRegex                     = regexp.MustCompile(`/v1/account/reservation/([-_A-Za-z0-9]{1,64})$`)

	// Static content, attachments and miscellaneous validation patterns
	staticRegex      = regexp.MustCompile(`^/static/.+`)
	docsRegex        = regexp.MustCompile(`^/docs(|/.*)$`)
	fileRegex        = regexp.MustCompile(`^/file/([-_A-Za-z0-9]{1,64})(?:\.[A-Za-z0-9]{1,16})?$`) // Captures the file ID; the extension is optional
	urlRegex         = regexp.MustCompile(`^https?://`)
	phoneNumberRegex = regexp.MustCompile(`^\+\d{1,100}$`)

	// Web app files, embedded into the binary from the "site" directory
	//go:embed site
	webFs       embed.FS
	webFsCached = &util.CachingEmbedFS{ModTime: time.Now(), FS: webFs}
	webSiteDir  = "/site"
	webAppIndex = "/app.html" // React app

	// Documentation files, embedded into the binary from the "docs" directory
	//go:embed docs
	docsStaticFs     embed.FS
	docsStaticCached = &util.CachingEmbedFS{ModTime: time.Now(), FS: docsStaticFs}
)
// Miscellaneous constants: special topic names, default message bodies and limits
const (
	firebaseControlTopic     = "~control"                // See Android if changed
	firebasePollTopic        = "~poll"                   // See iOS if changed
	emptyMessageBody         = "triggered"               // Used if message body is empty
	newMessageBody           = "New message"             // Used in poll requests as generic message
	defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment
	encodingBase64           = "base64"                  // Used mainly for binary UnifiedPush messages
	jsonBodyBytesLimit       = 16384                     // Max number of bytes for a JSON request body
	unifiedPushTopicPrefix   = "up"                      // Temporarily, we rate limit all "up*" topics based on the subscriber
	unifiedPushTopicLength   = 14                        // Length of UnifiedPush topics, including the "up" part
	messagesHistoryMax       = 10                        // Number of message count values to keep in memory
)
// WebSocket constants
//
// NOTE(review): these are used outside the visible part of this file; the
// descriptions below are derived from the names and should be confirmed there.
const (
	wsWriteWait  = 2 * time.Second  // Presumably the deadline for a single WebSocket write
	wsBufferSize = 1024             // Presumably the read/write buffer size (in bytes) of the WebSocket upgrader
	wsReadLimit  = 64               // We only ever receive PINGs
	wsPongWait   = 15 * time.Second // Presumably how long to wait for a PONG before giving up on the connection
)
  134. // New instantiates a new Server. It creates the cache and adds a Firebase
  135. // subscriber (if configured).
  136. func New(conf *Config) (*Server, error) {
  137. var mailer mailer
  138. if conf.SMTPSenderAddr != "" {
  139. mailer = &smtpSender{config: conf}
  140. }
  141. var stripe stripeAPI
  142. if conf.StripeSecretKey != "" {
  143. stripe = newStripeAPI()
  144. }
  145. messageCache, err := createMessageCache(conf)
  146. if err != nil {
  147. return nil, err
  148. }
  149. var webPush *webPushStore
  150. if conf.WebPushPublicKey != "" {
  151. webPush, err = newWebPushStore(conf.WebPushFile, conf.WebPushStartupQueries)
  152. if err != nil {
  153. return nil, err
  154. }
  155. }
  156. topics, err := messageCache.Topics()
  157. if err != nil {
  158. return nil, err
  159. }
  160. messages, err := messageCache.Stats()
  161. if err != nil {
  162. return nil, err
  163. }
  164. var fileCache *fileCache
  165. if conf.AttachmentCacheDir != "" {
  166. fileCache, err = newFileCache(conf.AttachmentCacheDir, conf.AttachmentTotalSizeLimit)
  167. if err != nil {
  168. return nil, err
  169. }
  170. }
  171. var userManager *user.Manager
  172. if conf.AuthFile != "" {
  173. userManager, err = user.NewManager(conf.AuthFile, conf.AuthStartupQueries, conf.AuthDefault, conf.AuthBcryptCost, conf.AuthStatsQueueWriterInterval)
  174. if err != nil {
  175. return nil, err
  176. }
  177. }
  178. var firebaseClient *firebaseClient
  179. if conf.FirebaseKeyFile != "" {
  180. sender, err := newFirebaseSender(conf.FirebaseKeyFile)
  181. if err != nil {
  182. return nil, err
  183. }
  184. // This awkward logic is required because Go is weird about nil types and interfaces.
  185. // See issue #641, and https://go.dev/play/p/uur1flrv1t3 for an example
  186. var auther user.Auther
  187. if userManager != nil {
  188. auther = userManager
  189. }
  190. firebaseClient = newFirebaseClient(sender, auther)
  191. }
  192. s := &Server{
  193. config: conf,
  194. messageCache: messageCache,
  195. webPush: webPush,
  196. fileCache: fileCache,
  197. firebaseClient: firebaseClient,
  198. smtpSender: mailer,
  199. topics: topics,
  200. userManager: userManager,
  201. messages: messages,
  202. messagesHistory: []int64{messages},
  203. visitors: make(map[string]*visitor),
  204. stripe: stripe,
  205. }
  206. s.priceCache = util.NewLookupCache(s.fetchStripePrices, conf.StripePriceCacheDuration)
  207. return s, nil
  208. }
  209. func createMessageCache(conf *Config) (*messageCache, error) {
  210. if conf.CacheDuration == 0 {
  211. return newNopCache()
  212. } else if conf.CacheFile != "" {
  213. return newSqliteCache(conf.CacheFile, conf.CacheStartupQueries, conf.CacheDuration, conf.CacheBatchSize, conf.CacheBatchTimeout, false)
  214. }
  215. return newMemCache()
  216. }
  217. // Run executes the main server. It listens on HTTP (+ HTTPS, if configured), and starts
  218. // a manager go routine to print stats and prune messages.
  219. func (s *Server) Run() error {
  220. var listenStr string
  221. if s.config.ListenHTTP != "" {
  222. listenStr += fmt.Sprintf(" %s[http]", s.config.ListenHTTP)
  223. }
  224. if s.config.ListenHTTPS != "" {
  225. listenStr += fmt.Sprintf(" %s[https]", s.config.ListenHTTPS)
  226. }
  227. if s.config.ListenUnix != "" {
  228. listenStr += fmt.Sprintf(" %s[unix]", s.config.ListenUnix)
  229. }
  230. if s.config.SMTPServerListen != "" {
  231. listenStr += fmt.Sprintf(" %s[smtp]", s.config.SMTPServerListen)
  232. }
  233. if s.config.MetricsListenHTTP != "" {
  234. listenStr += fmt.Sprintf(" %s[http/metrics]", s.config.MetricsListenHTTP)
  235. }
  236. if s.config.ProfileListenHTTP != "" {
  237. listenStr += fmt.Sprintf(" %s[http/profile]", s.config.ProfileListenHTTP)
  238. }
  239. log.Tag(tagStartup).Info("Listening on%s, ntfy %s, log level is %s", listenStr, s.config.Version, log.CurrentLevel().String())
  240. if log.IsFile() {
  241. fmt.Fprintf(os.Stderr, "Listening on%s, ntfy %s\n", listenStr, s.config.Version)
  242. fmt.Fprintf(os.Stderr, "Logs are written to %s\n", log.File())
  243. }
  244. mux := http.NewServeMux()
  245. mux.HandleFunc("/", s.handle)
  246. errChan := make(chan error)
  247. s.mu.Lock()
  248. s.closeChan = make(chan bool)
  249. if s.config.ListenHTTP != "" {
  250. s.httpServer = &http.Server{Addr: s.config.ListenHTTP, Handler: mux}
  251. go func() {
  252. errChan <- s.httpServer.ListenAndServe()
  253. }()
  254. }
  255. if s.config.ListenHTTPS != "" {
  256. s.httpsServer = &http.Server{Addr: s.config.ListenHTTPS, Handler: mux}
  257. go func() {
  258. errChan <- s.httpsServer.ListenAndServeTLS(s.config.CertFile, s.config.KeyFile)
  259. }()
  260. }
  261. if s.config.ListenUnix != "" {
  262. go func() {
  263. var err error
  264. s.mu.Lock()
  265. os.Remove(s.config.ListenUnix)
  266. s.unixListener, err = net.Listen("unix", s.config.ListenUnix)
  267. if err != nil {
  268. s.mu.Unlock()
  269. errChan <- err
  270. return
  271. }
  272. defer s.unixListener.Close()
  273. if s.config.ListenUnixMode > 0 {
  274. if err := os.Chmod(s.config.ListenUnix, s.config.ListenUnixMode); err != nil {
  275. s.mu.Unlock()
  276. errChan <- err
  277. return
  278. }
  279. }
  280. s.mu.Unlock()
  281. httpServer := &http.Server{Handler: mux}
  282. errChan <- httpServer.Serve(s.unixListener)
  283. }()
  284. }
  285. if s.config.MetricsListenHTTP != "" {
  286. initMetrics()
  287. s.httpMetricsServer = &http.Server{Addr: s.config.MetricsListenHTTP, Handler: promhttp.Handler()}
  288. go func() {
  289. errChan <- s.httpMetricsServer.ListenAndServe()
  290. }()
  291. } else if s.config.EnableMetrics {
  292. initMetrics()
  293. s.metricsHandler = promhttp.Handler()
  294. }
  295. if s.config.ProfileListenHTTP != "" {
  296. profileMux := http.NewServeMux()
  297. profileMux.HandleFunc("/debug/pprof/", pprof.Index)
  298. profileMux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
  299. profileMux.HandleFunc("/debug/pprof/profile", pprof.Profile)
  300. profileMux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
  301. profileMux.HandleFunc("/debug/pprof/trace", pprof.Trace)
  302. s.httpProfileServer = &http.Server{Addr: s.config.ProfileListenHTTP, Handler: profileMux}
  303. go func() {
  304. errChan <- s.httpProfileServer.ListenAndServe()
  305. }()
  306. }
  307. if s.config.SMTPServerListen != "" {
  308. go func() {
  309. errChan <- s.runSMTPServer()
  310. }()
  311. }
  312. s.mu.Unlock()
  313. go s.runManager()
  314. go s.runStatsResetter()
  315. go s.runDelayedSender()
  316. go s.runFirebaseKeepaliver()
  317. return <-errChan
  318. }
  319. // Stop stops HTTP (+HTTPS) server and all managers
  320. func (s *Server) Stop() {
  321. s.mu.Lock()
  322. defer s.mu.Unlock()
  323. if s.httpServer != nil {
  324. s.httpServer.Close()
  325. }
  326. if s.httpsServer != nil {
  327. s.httpsServer.Close()
  328. }
  329. if s.unixListener != nil {
  330. s.unixListener.Close()
  331. }
  332. if s.smtpServer != nil {
  333. s.smtpServer.Close()
  334. }
  335. s.closeDatabases()
  336. close(s.closeChan)
  337. }
  338. func (s *Server) closeDatabases() {
  339. if s.userManager != nil {
  340. s.userManager.Close()
  341. }
  342. s.messageCache.Close()
  343. if s.webPush != nil {
  344. s.webPush.Close()
  345. }
  346. }
  347. // handle is the main entry point for all HTTP requests
  348. func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
  349. v, err := s.maybeAuthenticate(r) // Note: Always returns v, even when error is returned
  350. if err != nil {
  351. s.handleError(w, r, v, err)
  352. return
  353. }
  354. ev := logvr(v, r)
  355. if ev.IsTrace() {
  356. ev.Field("http_request", renderHTTPRequest(r)).Trace("HTTP request started")
  357. } else if logvr(v, r).IsDebug() {
  358. ev.Debug("HTTP request started")
  359. }
  360. logvr(v, r).
  361. Timing(func() {
  362. if err := s.handleInternal(w, r, v); err != nil {
  363. s.handleError(w, r, v, err)
  364. return
  365. }
  366. if metricHTTPRequests != nil {
  367. metricHTTPRequests.WithLabelValues("200", "20000", r.Method).Inc()
  368. }
  369. }).
  370. Debug("HTTP request finished")
  371. }
  372. func (s *Server) handleError(w http.ResponseWriter, r *http.Request, v *visitor, err error) {
  373. httpErr, ok := err.(*errHTTP)
  374. if !ok {
  375. httpErr = errHTTPInternalError
  376. }
  377. if metricHTTPRequests != nil {
  378. metricHTTPRequests.WithLabelValues(fmt.Sprintf("%d", httpErr.HTTPCode), fmt.Sprintf("%d", httpErr.Code), r.Method).Inc()
  379. }
  380. isRateLimiting := util.Contains(rateLimitingErrorCodes, httpErr.HTTPCode)
  381. isNormalError := strings.Contains(err.Error(), "i/o timeout") || util.Contains(normalErrorCodes, httpErr.HTTPCode)
  382. ev := logvr(v, r).Err(err)
  383. if websocket.IsWebSocketUpgrade(r) {
  384. ev.Tag(tagWebsocket).Fields(websocketErrorContext(err))
  385. if isNormalError {
  386. ev.Debug("WebSocket error (this error is okay, it happens a lot): %s", err.Error())
  387. } else {
  388. ev.Info("WebSocket error: %s", err.Error())
  389. }
  390. return // Do not attempt to write to upgraded connection
  391. }
  392. if isNormalError {
  393. ev.Debug("Connection closed with HTTP %d (ntfy error %d)", httpErr.HTTPCode, httpErr.Code)
  394. } else {
  395. ev.Info("Connection closed with HTTP %d (ntfy error %d)", httpErr.HTTPCode, httpErr.Code)
  396. }
  397. if isRateLimiting && s.config.StripeSecretKey != "" {
  398. u := v.User()
  399. if u == nil || u.Tier == nil {
  400. httpErr = httpErr.Wrap("increase your limits with a paid plan, see %s", s.config.BaseURL)
  401. }
  402. }
  403. w.Header().Set("Content-Type", "application/json")
  404. w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
  405. w.WriteHeader(httpErr.HTTPCode)
  406. io.WriteString(w, httpErr.JSON()+"\n")
  407. }
  408. func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visitor) error {
  409. if r.Method == http.MethodGet && r.URL.Path == "/" && s.config.WebRoot == "/" {
  410. return s.ensureWebEnabled(s.handleRoot)(w, r, v)
  411. } else if r.Method == http.MethodHead && r.URL.Path == "/" {
  412. return s.ensureWebEnabled(s.handleEmpty)(w, r, v)
  413. } else if r.Method == http.MethodGet && r.URL.Path == apiHealthPath {
  414. return s.handleHealth(w, r, v)
  415. } else if r.Method == http.MethodGet && r.URL.Path == webConfigPath {
  416. return s.ensureWebEnabled(s.handleWebConfig)(w, r, v)
  417. } else if r.Method == http.MethodGet && r.URL.Path == webManifestPath {
  418. return s.ensureWebPushEnabled(s.handleWebManifest)(w, r, v)
  419. } else if r.Method == http.MethodGet && r.URL.Path == apiUsersPath {
  420. return s.ensureAdmin(s.handleUsersGet)(w, r, v)
  421. } else if r.Method == http.MethodPut && r.URL.Path == apiUsersPath {
  422. return s.ensureAdmin(s.handleUsersAdd)(w, r, v)
  423. } else if r.Method == http.MethodDelete && r.URL.Path == apiUsersPath {
  424. return s.ensureAdmin(s.handleUsersDelete)(w, r, v)
  425. } else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && r.URL.Path == apiUsersAccessPath {
  426. return s.ensureAdmin(s.handleAccessAllow)(w, r, v)
  427. } else if r.Method == http.MethodDelete && r.URL.Path == apiUsersAccessPath {
  428. return s.ensureAdmin(s.handleAccessReset)(w, r, v)
  429. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountPath {
  430. return s.ensureUserManager(s.handleAccountCreate)(w, r, v)
  431. } else if r.Method == http.MethodGet && r.URL.Path == apiAccountPath {
  432. return s.handleAccountGet(w, r, v) // Allowed by anonymous
  433. } else if r.Method == http.MethodDelete && r.URL.Path == apiAccountPath {
  434. return s.ensureUser(s.withAccountSync(s.handleAccountDelete))(w, r, v)
  435. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountPasswordPath {
  436. return s.ensureUser(s.handleAccountPasswordChange)(w, r, v)
  437. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountTokenPath {
  438. return s.ensureUser(s.withAccountSync(s.handleAccountTokenCreate))(w, r, v)
  439. } else if r.Method == http.MethodPatch && r.URL.Path == apiAccountTokenPath {
  440. return s.ensureUser(s.withAccountSync(s.handleAccountTokenUpdate))(w, r, v)
  441. } else if r.Method == http.MethodDelete && r.URL.Path == apiAccountTokenPath {
  442. return s.ensureUser(s.withAccountSync(s.handleAccountTokenDelete))(w, r, v)
  443. } else if r.Method == http.MethodPatch && r.URL.Path == apiAccountSettingsPath {
  444. return s.ensureUser(s.withAccountSync(s.handleAccountSettingsChange))(w, r, v)
  445. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountSubscriptionPath {
  446. return s.ensureUser(s.withAccountSync(s.handleAccountSubscriptionAdd))(w, r, v)
  447. } else if r.Method == http.MethodPatch && r.URL.Path == apiAccountSubscriptionPath {
  448. return s.ensureUser(s.withAccountSync(s.handleAccountSubscriptionChange))(w, r, v)
  449. } else if r.Method == http.MethodDelete && r.URL.Path == apiAccountSubscriptionPath {
  450. return s.ensureUser(s.withAccountSync(s.handleAccountSubscriptionDelete))(w, r, v)
  451. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountReservationPath {
  452. return s.ensureUser(s.withAccountSync(s.handleAccountReservationAdd))(w, r, v)
  453. } else if r.Method == http.MethodDelete && apiAccountReservationSingleRegex.MatchString(r.URL.Path) {
  454. return s.ensureUser(s.withAccountSync(s.handleAccountReservationDelete))(w, r, v)
  455. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountBillingSubscriptionPath {
  456. return s.ensurePaymentsEnabled(s.ensureUser(s.handleAccountBillingSubscriptionCreate))(w, r, v) // Account sync via incoming Stripe webhook
  457. } else if r.Method == http.MethodGet && apiAccountBillingSubscriptionCheckoutSuccessRegex.MatchString(r.URL.Path) {
  458. return s.ensurePaymentsEnabled(s.ensureUserManager(s.handleAccountBillingSubscriptionCreateSuccess))(w, r, v) // No user context!
  459. } else if r.Method == http.MethodPut && r.URL.Path == apiAccountBillingSubscriptionPath {
  460. return s.ensurePaymentsEnabled(s.ensureStripeCustomer(s.handleAccountBillingSubscriptionUpdate))(w, r, v) // Account sync via incoming Stripe webhook
  461. } else if r.Method == http.MethodDelete && r.URL.Path == apiAccountBillingSubscriptionPath {
  462. return s.ensurePaymentsEnabled(s.ensureStripeCustomer(s.handleAccountBillingSubscriptionDelete))(w, r, v) // Account sync via incoming Stripe webhook
  463. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountBillingPortalPath {
  464. return s.ensurePaymentsEnabled(s.ensureStripeCustomer(s.handleAccountBillingPortalSessionCreate))(w, r, v)
  465. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountBillingWebhookPath {
  466. return s.ensurePaymentsEnabled(s.ensureUserManager(s.handleAccountBillingWebhook))(w, r, v) // This request comes from Stripe!
  467. } else if r.Method == http.MethodPut && r.URL.Path == apiAccountPhoneVerifyPath {
  468. return s.ensureUser(s.ensureCallsEnabled(s.withAccountSync(s.handleAccountPhoneNumberVerify)))(w, r, v)
  469. } else if r.Method == http.MethodPut && r.URL.Path == apiAccountPhonePath {
  470. return s.ensureUser(s.ensureCallsEnabled(s.withAccountSync(s.handleAccountPhoneNumberAdd)))(w, r, v)
  471. } else if r.Method == http.MethodDelete && r.URL.Path == apiAccountPhonePath {
  472. return s.ensureUser(s.ensureCallsEnabled(s.withAccountSync(s.handleAccountPhoneNumberDelete)))(w, r, v)
  473. } else if r.Method == http.MethodPost && apiWebPushPath == r.URL.Path {
  474. return s.ensureWebPushEnabled(s.limitRequests(s.handleWebPushUpdate))(w, r, v)
  475. } else if r.Method == http.MethodDelete && apiWebPushPath == r.URL.Path {
  476. return s.ensureWebPushEnabled(s.limitRequests(s.handleWebPushDelete))(w, r, v)
  477. } else if r.Method == http.MethodGet && r.URL.Path == apiStatsPath {
  478. return s.handleStats(w, r, v)
  479. } else if r.Method == http.MethodGet && r.URL.Path == apiTiersPath {
  480. return s.ensurePaymentsEnabled(s.handleBillingTiersGet)(w, r, v)
  481. } else if r.Method == http.MethodGet && r.URL.Path == matrixPushPath {
  482. return s.handleMatrixDiscovery(w)
  483. } else if r.Method == http.MethodGet && r.URL.Path == metricsPath && s.metricsHandler != nil {
  484. return s.handleMetrics(w, r, v)
  485. } else if r.Method == http.MethodGet && (staticRegex.MatchString(r.URL.Path) || r.URL.Path == webServiceWorkerPath || r.URL.Path == webRootHTMLPath) {
  486. return s.ensureWebEnabled(s.handleStatic)(w, r, v)
  487. } else if r.Method == http.MethodGet && docsRegex.MatchString(r.URL.Path) {
  488. return s.ensureWebEnabled(s.handleDocs)(w, r, v)
  489. } else if (r.Method == http.MethodGet || r.Method == http.MethodHead) && fileRegex.MatchString(r.URL.Path) && s.config.AttachmentCacheDir != "" {
  490. return s.limitRequests(s.handleFile)(w, r, v)
  491. } else if r.Method == http.MethodOptions {
  492. return s.limitRequests(s.handleOptions)(w, r, v) // Should work even if the web app is not enabled, see #598
  493. } else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && r.URL.Path == "/" {
  494. return s.transformBodyJSON(s.limitRequestsWithTopic(s.authorizeTopicWrite(s.handlePublish)))(w, r, v)
  495. } else if r.Method == http.MethodPost && r.URL.Path == matrixPushPath {
  496. return s.transformMatrixJSON(s.limitRequestsWithTopic(s.authorizeTopicWrite(s.handlePublishMatrix)))(w, r, v)
  497. } else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicPathRegex.MatchString(r.URL.Path) {
  498. return s.limitRequestsWithTopic(s.authorizeTopicWrite(s.handlePublish))(w, r, v)
  499. } else if r.Method == http.MethodGet && publishPathRegex.MatchString(r.URL.Path) {
  500. return s.limitRequestsWithTopic(s.authorizeTopicWrite(s.handlePublish))(w, r, v)
  501. } else if r.Method == http.MethodGet && jsonPathRegex.MatchString(r.URL.Path) {
  502. return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeJSON))(w, r, v)
  503. } else if r.Method == http.MethodGet && ssePathRegex.MatchString(r.URL.Path) {
  504. return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeSSE))(w, r, v)
  505. } else if r.Method == http.MethodGet && rawPathRegex.MatchString(r.URL.Path) {
  506. return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeRaw))(w, r, v)
  507. } else if r.Method == http.MethodGet && wsPathRegex.MatchString(r.URL.Path) {
  508. return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeWS))(w, r, v)
  509. } else if r.Method == http.MethodGet && authPathRegex.MatchString(r.URL.Path) {
  510. return s.limitRequests(s.authorizeTopicRead(s.handleTopicAuth))(w, r, v)
  511. } else if r.Method == http.MethodGet && (topicPathRegex.MatchString(r.URL.Path) || externalTopicPathRegex.MatchString(r.URL.Path)) {
  512. return s.ensureWebEnabled(s.handleTopic)(w, r, v)
  513. }
  514. return errHTTPNotFound
  515. }
// handleRoot serves the web app entry point: it rewrites the request path to webAppIndex
// and hands the request to handleStatic.
func (s *Server) handleRoot(w http.ResponseWriter, r *http.Request, v *visitor) error {
	r.URL.Path = webAppIndex
	return s.handleStatic(w, r, v)
}
  520. func (s *Server) handleTopic(w http.ResponseWriter, r *http.Request, v *visitor) error {
  521. unifiedpush := readBoolParam(r, false, "x-unifiedpush", "unifiedpush", "up") // see PUT/POST too!
  522. if unifiedpush {
  523. w.Header().Set("Content-Type", "application/json")
  524. w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
  525. _, err := io.WriteString(w, `{"unifiedpush":{"version":1}}`+"\n")
  526. return err
  527. }
  528. r.URL.Path = webAppIndex
  529. return s.handleStatic(w, r, v)
  530. }
// handleEmpty is a no-op handler: it ignores the response writer, the request and the
// visitor, and always returns nil.
func (s *Server) handleEmpty(_ http.ResponseWriter, _ *http.Request, _ *visitor) error {
	return nil
}
  534. func (s *Server) handleTopicAuth(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
  535. return s.writeJSON(w, newSuccessResponse())
  536. }
  537. func (s *Server) handleHealth(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
  538. response := &apiHealthResponse{
  539. Healthy: true,
  540. }
  541. return s.writeJSON(w, response)
  542. }
  543. func (s *Server) handleWebConfig(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
  544. response := &apiConfigResponse{
  545. BaseURL: "", // Will translate to window.location.origin
  546. AppRoot: s.config.WebRoot,
  547. EnableLogin: s.config.EnableLogin,
  548. EnableSignup: s.config.EnableSignup,
  549. EnablePayments: s.config.StripeSecretKey != "",
  550. EnableCalls: s.config.TwilioAccount != "",
  551. EnableEmails: s.config.SMTPSenderFrom != "",
  552. EnableReservations: s.config.EnableReservations,
  553. EnableWebPush: s.config.WebPushPublicKey != "",
  554. BillingContact: s.config.BillingContact,
  555. WebPushPublicKey: s.config.WebPushPublicKey,
  556. DisallowedTopics: s.config.DisallowedTopics,
  557. }
  558. b, err := json.MarshalIndent(response, "", " ")
  559. if err != nil {
  560. return err
  561. }
  562. w.Header().Set("Content-Type", "text/javascript")
  563. _, err = io.WriteString(w, fmt.Sprintf("// Generated server configuration\nvar config = %s;\n", string(b)))
  564. return err
  565. }
  566. // handleWebManifest serves the web app manifest for the progressive web app (PWA)
  567. func (s *Server) handleWebManifest(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
  568. response := &webManifestResponse{
  569. Name: "ntfy web",
  570. Description: "ntfy lets you send push notifications via scripts from any computer or phone",
  571. ShortName: "ntfy",
  572. Scope: "/",
  573. StartURL: s.config.WebRoot,
  574. Display: "standalone",
  575. BackgroundColor: "#ffffff",
  576. ThemeColor: "#317f6f",
  577. Icons: []*webManifestIcon{
  578. {SRC: "/static/images/pwa-192x192.png", Sizes: "192x192", Type: "image/png"},
  579. {SRC: "/static/images/pwa-512x512.png", Sizes: "512x512", Type: "image/png"},
  580. },
  581. }
  582. return s.writeJSONWithContentType(w, response, "application/manifest+json")
  583. }
// handleMetrics returns Prometheus metrics. This endpoint is only called if enable-metrics is set,
// and listen-metrics-http is not set.
//
// The request is passed straight to s.metricsHandler; the visitor is ignored and the
// handler always returns nil.
func (s *Server) handleMetrics(w http.ResponseWriter, r *http.Request, _ *visitor) error {
	s.metricsHandler.ServeHTTP(w, r)
	return nil
}
// handleStatic returns all static resources (excluding the docs), including the web app.
//
// The request path is prefixed with webSiteDir (mutating r.URL.Path in place) and then
// served from webFsCached by an http.FileServer wrapped in util.Gzip. Errors are reported
// by the file server through the response itself; this handler always returns nil.
func (s *Server) handleStatic(w http.ResponseWriter, r *http.Request, _ *visitor) error {
	r.URL.Path = webSiteDir + r.URL.Path
	util.Gzip(http.FileServer(http.FS(webFsCached))).ServeHTTP(w, r)
	return nil
}
// handleDocs returns static resources related to the docs.
//
// Files are served from docsStaticCached by an http.FileServer wrapped in util.Gzip, using
// the request path unchanged. This handler always returns nil.
func (s *Server) handleDocs(w http.ResponseWriter, r *http.Request, _ *visitor) error {
	util.Gzip(http.FileServer(http.FS(docsStaticCached))).ServeHTTP(w, r)
	return nil
}
  601. // handleStats returns the publicly available server stats
  602. func (s *Server) handleStats(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
  603. s.mu.RLock()
  604. messages, n, rate := s.messages, len(s.messagesHistory), float64(0)
  605. if n > 1 {
  606. rate = float64(s.messagesHistory[n-1]-s.messagesHistory[0]) / (float64(n-1) * s.config.ManagerInterval.Seconds())
  607. }
  608. s.mu.RUnlock()
  609. response := &apiStatsResponse{
  610. Messages: messages,
  611. MessagesRate: rate,
  612. }
  613. return s.writeJSON(w, response)
  614. }
  615. // handleFile processes the download of attachment files. The method handles GET and HEAD requests against a file.
  616. // Before streaming the file to a client, it locates uploader (m.Sender or m.User) in the message cache, so it
  617. // can associate the download bandwidth with the uploader.
  618. func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor) error {
  619. if s.config.AttachmentCacheDir == "" {
  620. return errHTTPInternalError
  621. }
  622. matches := fileRegex.FindStringSubmatch(r.URL.Path)
  623. if len(matches) != 2 {
  624. return errHTTPInternalErrorInvalidPath
  625. }
  626. messageID := matches[1]
  627. file := filepath.Join(s.config.AttachmentCacheDir, messageID)
  628. stat, err := os.Stat(file)
  629. if err != nil {
  630. return errHTTPNotFound.Fields(log.Context{
  631. "message_id": messageID,
  632. "error_context": "filesystem",
  633. })
  634. }
  635. w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
  636. w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size()))
  637. if r.Method == http.MethodHead {
  638. return nil
  639. }
  640. // Find message in database, and associate bandwidth to the uploader user
  641. // This is an easy way to
  642. // - avoid abuse (e.g. 1 uploader, 1k downloaders)
  643. // - and also uses the higher bandwidth limits of a paying user
  644. m, err := s.messageCache.Message(messageID)
  645. if err == errMessageNotFound {
  646. if s.config.CacheBatchTimeout > 0 {
  647. // Strange edge case: If we immediately after upload request the file (the web app does this for images),
  648. // and messages are persisted asynchronously, retry fetching from the database
  649. m, err = util.Retry(func() (*message, error) {
  650. return s.messageCache.Message(messageID)
  651. }, s.config.CacheBatchTimeout, 100*time.Millisecond, 300*time.Millisecond, 600*time.Millisecond)
  652. }
  653. if err != nil {
  654. return errHTTPNotFound.Fields(log.Context{
  655. "message_id": messageID,
  656. "error_context": "message_cache",
  657. })
  658. }
  659. } else if err != nil {
  660. return err
  661. }
  662. bandwidthVisitor := v
  663. if s.userManager != nil && m.User != "" {
  664. u, err := s.userManager.UserByID(m.User)
  665. if err != nil {
  666. return err
  667. }
  668. bandwidthVisitor = s.visitor(v.IP(), u)
  669. } else if m.Sender.IsValid() {
  670. bandwidthVisitor = s.visitor(m.Sender, nil)
  671. }
  672. if !bandwidthVisitor.BandwidthAllowed(stat.Size()) {
  673. return errHTTPTooManyRequestsLimitAttachmentBandwidth.With(m)
  674. }
  675. // Actually send file
  676. f, err := os.Open(file)
  677. if err != nil {
  678. return err
  679. }
  680. defer f.Close()
  681. if m.Attachment.Name != "" {
  682. w.Header().Set("Content-Disposition", "attachment; filename="+strconv.Quote(m.Attachment.Name))
  683. }
  684. _, err = io.Copy(util.NewContentTypeWriter(w, r.URL.Path), f)
  685. return err
  686. }
  687. func (s *Server) handleMatrixDiscovery(w http.ResponseWriter) error {
  688. if s.config.BaseURL == "" {
  689. return errHTTPInternalErrorMissingBaseURL
  690. }
  691. return writeMatrixDiscoveryResponse(w)
  692. }
// handlePublishInternal implements the shared publish logic behind the publish handlers.
//
// It reads the topic and the rate visitor from the request context, peeks at the request
// body (up to the configured message limit), parses the publish parameters into a new
// message, applies the rate limits, fills the message from the body, and then publishes
// and/or caches it. On success the published message is returned; on failure the message
// is nil and the error describes what went wrong.
func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*message, error) {
	start := time.Now() // Used for the publish duration metric at the end
	t, err := fromContext[*topic](r, contextTopic)
	if err != nil {
		return nil, err
	}
	vrate, err := fromContext[*visitor](r, contextRateVisitor)
	if err != nil {
		return nil, err
	}
	body, err := util.Peek(r.Body, s.config.MessageLimit)
	if err != nil {
		return nil, err
	}
	m := newDefaultMessage(t.ID, "")
	cache, firebase, email, call, unifiedpush, e := s.parsePublishParams(r, m)
	if e != nil {
		return nil, e.With(t)
	}
	// Rate limiting: message, email and call limits are checked against the rate visitor
	// (vrate), which is not necessarily the publishing visitor (v).
	if unifiedpush && s.config.VisitorSubscriberRateLimiting && t.RateVisitor() == nil {
		// UnifiedPush clients must subscribe before publishing to allow proper subscriber-based rate limiting (see
		// Rate-Topics header). The 5xx response is because some app servers (in particular Mastodon) will remove
		// the subscription as invalid if any 400-499 code (except 429/408) is returned.
		// See https://github.com/mastodon/mastodon/blob/730bb3e211a84a2f30e3e2bbeae3f77149824a68/app/workers/web/push_notification_worker.rb#L35-L46
		return nil, errHTTPInsufficientStorageUnifiedPush.With(t)
	} else if !util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) && !vrate.MessageAllowed() {
		return nil, errHTTPTooManyRequestsLimitMessages.With(t)
	} else if email != "" && !vrate.EmailAllowed() {
		return nil, errHTTPTooManyRequestsLimitEmails.With(t)
	} else if call != "" {
		var httpErr *errHTTP
		call, httpErr = s.convertPhoneNumber(v.User(), call)
		if httpErr != nil {
			return nil, httpErr.With(t)
		} else if !vrate.CallAllowed() {
			return nil, errHTTPTooManyRequestsLimitCalls.With(t)
		}
	}
	// A poll ID turns the message into a poll request message, replacing the default message
	if m.PollID != "" {
		m = newPollRequestMessage(t.ID, m.PollID)
	}
	m.Sender = v.IP()
	m.User = v.MaybeUserID()
	if cache {
		m.Expires = time.Unix(m.Time, 0).Add(v.Limits().MessageExpiryDuration).Unix()
	}
	if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil {
		return nil, err
	}
	if m.Message == "" {
		m.Message = emptyMessageBody
	}
	// A message is "delayed" if its timestamp lies in the future; delayed messages are not
	// published to subscribers or forwarded here
	delayed := m.Time > time.Now().Unix()
	ev := logvrm(v, r, m).
		Tag(tagPublish).
		With(t).
		Fields(log.Context{
			"message_delayed":     delayed,
			"message_firebase":    firebase,
			"message_unifiedpush": unifiedpush,
			"message_email":       email,
			"message_call":        call,
		})
	if ev.IsTrace() {
		ev.Field("message_body", util.MaybeMarshalJSON(m)).Trace("Received message")
	} else if ev.IsDebug() {
		ev.Debug("Received message")
	}
	if !delayed {
		if err := t.Publish(v, m); err != nil {
			return nil, err
		}
		// Fan out to the optional integrations, each in its own goroutine
		if s.firebaseClient != nil && firebase {
			go s.sendToFirebase(v, m)
		}
		if s.smtpSender != nil && email != "" {
			go s.sendEmail(v, m, email)
		}
		if s.config.TwilioAccount != "" && call != "" {
			go s.callPhone(v, r, m, call)
		}
		if s.config.UpstreamBaseURL != "" && !unifiedpush { // UP messages are not sent to upstream
			go s.forwardPollRequest(v, m)
		}
		if s.config.WebPushPublicKey != "" {
			go s.publishToWebPushEndpoints(v, m)
		}
	} else {
		logvrm(v, r, m).Tag(tagPublish).Debug("Message delayed, will process later")
	}
	if cache {
		logvrm(v, r, m).Tag(tagPublish).Debug("Adding message to cache")
		if err := s.messageCache.AddMessage(m); err != nil {
			return nil, err
		}
	}
	// Stats are only enqueued for users that have a tier
	u := v.User()
	if s.userManager != nil && u != nil && u.Tier != nil {
		go s.userManager.EnqueueUserStats(u.ID, v.Stats())
	}
	s.mu.Lock()
	s.messages++
	s.mu.Unlock()
	if unifiedpush {
		minc(metricUnifiedPushPublishedSuccess)
	}
	mset(metricMessagePublishDurationMillis, time.Since(start).Milliseconds())
	return m, nil
}
  802. func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visitor) error {
  803. m, err := s.handlePublishInternal(r, v)
  804. if err != nil {
  805. minc(metricMessagesPublishedFailure)
  806. return err
  807. }
  808. minc(metricMessagesPublishedSuccess)
  809. return s.writeJSON(w, m)
  810. }
  811. func (s *Server) handlePublishMatrix(w http.ResponseWriter, r *http.Request, v *visitor) error {
  812. _, err := s.handlePublishInternal(r, v)
  813. if err != nil {
  814. minc(metricMessagesPublishedFailure)
  815. minc(metricMatrixPublishedFailure)
  816. if e, ok := err.(*errHTTP); ok && e.HTTPCode == errHTTPInsufficientStorageUnifiedPush.HTTPCode {
  817. topic, err := fromContext[*topic](r, contextTopic)
  818. if err != nil {
  819. return err
  820. }
  821. pushKey, err := fromContext[string](r, contextMatrixPushKey)
  822. if err != nil {
  823. return err
  824. }
  825. if time.Since(topic.LastAccess()) > matrixRejectPushKeyForUnifiedPushTopicWithoutRateVisitorAfter {
  826. return writeMatrixResponse(w, pushKey)
  827. }
  828. }
  829. return err
  830. }
  831. minc(metricMessagesPublishedSuccess)
  832. minc(metricMatrixPublishedSuccess)
  833. return writeMatrixSuccess(w)
  834. }
  835. func (s *Server) sendToFirebase(v *visitor, m *message) {
  836. logvm(v, m).Tag(tagFirebase).Debug("Publishing to Firebase")
  837. if err := s.firebaseClient.Send(v, m); err != nil {
  838. minc(metricFirebasePublishedFailure)
  839. if err == errFirebaseTemporarilyBanned {
  840. logvm(v, m).Tag(tagFirebase).Err(err).Debug("Unable to publish to Firebase: %v", err.Error())
  841. } else {
  842. logvm(v, m).Tag(tagFirebase).Err(err).Warn("Unable to publish to Firebase: %v", err.Error())
  843. }
  844. return
  845. }
  846. minc(metricFirebasePublishedSuccess)
  847. }
  848. func (s *Server) sendEmail(v *visitor, m *message, email string) {
  849. logvm(v, m).Tag(tagEmail).Field("email", email).Debug("Sending email to %s", email)
  850. if err := s.smtpSender.Send(v, m, email); err != nil {
  851. logvm(v, m).Tag(tagEmail).Field("email", email).Err(err).Warn("Unable to send email to %s: %v", email, err.Error())
  852. minc(metricEmailsPublishedFailure)
  853. return
  854. }
  855. minc(metricEmailsPublishedSuccess)
  856. }
  857. func (s *Server) forwardPollRequest(v *visitor, m *message) {
  858. topicURL := fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic)
  859. topicHash := fmt.Sprintf("%x", sha256.Sum256([]byte(topicURL)))
  860. forwardURL := fmt.Sprintf("%s/%s", s.config.UpstreamBaseURL, topicHash)
  861. logvm(v, m).Debug("Publishing poll request to %s", forwardURL)
  862. req, err := http.NewRequest("POST", forwardURL, strings.NewReader(""))
  863. if err != nil {
  864. logvm(v, m).Err(err).Warn("Unable to publish poll request")
  865. return
  866. }
  867. req.Header.Set("User-Agent", "ntfy/"+s.config.Version)
  868. req.Header.Set("X-Poll-ID", m.ID)
  869. if s.config.UpstreamAccessToken != "" {
  870. req.Header.Set("Authorization", util.BearerAuth(s.config.UpstreamAccessToken))
  871. }
  872. var httpClient = &http.Client{
  873. Timeout: time.Second * 10,
  874. }
  875. response, err := httpClient.Do(req)
  876. if err != nil {
  877. logvm(v, m).Err(err).Warn("Unable to publish poll request")
  878. return
  879. } else if response.StatusCode != http.StatusOK {
  880. if response.StatusCode == http.StatusTooManyRequests {
  881. logvm(v, m).Err(err).Warn("Unable to publish poll request, the upstream server %s responded with HTTP %s; you may solve this by sending fewer daily messages, or by configuring upstream-access-token (assuming you have an account with higher rate limits) ", s.config.UpstreamBaseURL, response.Status)
  882. } else {
  883. logvm(v, m).Err(err).Warn("Unable to publish poll request, the upstream server %s responded with HTTP %s", s.config.UpstreamBaseURL, response.Status)
  884. }
  885. return
  886. }
  887. }
// parsePublishParams reads the publish parameters from the request (via readParam and its
// siblings) and applies them to the message m.
//
// Fields set directly on m: Title, Click, Icon, Attachment (URL/Name), Message, Priority,
// Tags, Time (for delayed messages), Actions, ContentType and PollID.
//
// Returned values:
//   - cache: whether the message should be cached (default true)
//   - firebase: whether the message should be sent to Firebase (default true)
//   - email: email address to send the message to, or empty
//   - call: phone call parameter (a boolean value or a phone number), or empty
//   - unifiedpush: whether this is a UnifiedPush message (default false)
//   - err: a *errHTTP describing the first invalid parameter, or nil
//
// On error, all other return values are zero values; m may have been partially modified.
func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, firebase bool, email, call string, unifiedpush bool, err *errHTTP) {
	cache = readBoolParam(r, true, "x-cache", "cache")
	firebase = readBoolParam(r, true, "x-firebase", "firebase")
	m.Title = readParam(r, "x-title", "title", "t")
	m.Click = readParam(r, "x-click", "click")
	icon := readParam(r, "x-icon", "icon")
	filename := readParam(r, "x-filename", "filename", "file", "f")
	attach := readParam(r, "x-attach", "attach", "a")
	// An attachment is created if either an external URL or a filename was passed
	if attach != "" || filename != "" {
		m.Attachment = &attachment{}
	}
	if filename != "" {
		m.Attachment.Name = filename
	}
	if attach != "" {
		if !urlRegex.MatchString(attach) {
			return false, false, "", "", false, errHTTPBadRequestAttachmentURLInvalid
		}
		m.Attachment.URL = attach
		// Without an explicit filename, derive the name from the last path element of the
		// URL, and fall back to "attachment" if that yields nothing useful
		if m.Attachment.Name == "" {
			u, err := url.Parse(m.Attachment.URL)
			if err == nil {
				m.Attachment.Name = path.Base(u.Path)
				if m.Attachment.Name == "." || m.Attachment.Name == "/" {
					m.Attachment.Name = ""
				}
			}
		}
		if m.Attachment.Name == "" {
			m.Attachment.Name = "attachment"
		}
	}
	if icon != "" {
		if !urlRegex.MatchString(icon) {
			return false, false, "", "", false, errHTTPBadRequestIconURLInvalid
		}
		m.Icon = icon
	}
	email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e")
	if s.smtpSender == nil && email != "" {
		return false, false, "", "", false, errHTTPBadRequestEmailDisabled
	}
	// Calls require both a Twilio account and a user manager; the value must either be a
	// boolean value or match the phone number pattern
	call = readParam(r, "x-call", "call")
	if call != "" && (s.config.TwilioAccount == "" || s.userManager == nil) {
		return false, false, "", "", false, errHTTPBadRequestPhoneCallsDisabled
	} else if call != "" && !isBoolValue(call) && !phoneNumberRegex.MatchString(call) {
		return false, false, "", "", false, errHTTPBadRequestPhoneNumberInvalid
	}
	// A literal backslash-n sequence in the message parameter is turned into a real newline
	messageStr := strings.ReplaceAll(readParam(r, "x-message", "message", "m"), "\\n", "\n")
	if messageStr != "" {
		m.Message = messageStr
	}
	var e error
	m.Priority, e = util.ParsePriority(readParam(r, "x-priority", "priority", "prio", "p"))
	if e != nil {
		return false, false, "", "", false, errHTTPBadRequestPriorityInvalid
	}
	m.Tags = readCommaSeparatedParam(r, "x-tags", "tags", "tag", "ta")
	// Delayed delivery requires caching, cannot be combined with email or calls, and must
	// lie within the configured minimum and maximum delay
	delayStr := readParam(r, "x-delay", "delay", "x-at", "at", "x-in", "in")
	if delayStr != "" {
		if !cache {
			return false, false, "", "", false, errHTTPBadRequestDelayNoCache
		}
		if email != "" {
			return false, false, "", "", false, errHTTPBadRequestDelayNoEmail // we cannot store the email address (yet)
		}
		if call != "" {
			return false, false, "", "", false, errHTTPBadRequestDelayNoCall // we cannot store the phone number (yet)
		}
		delay, err := util.ParseFutureTime(delayStr, time.Now())
		if err != nil {
			return false, false, "", "", false, errHTTPBadRequestDelayCannotParse
		} else if delay.Unix() < time.Now().Add(s.config.MinDelay).Unix() {
			return false, false, "", "", false, errHTTPBadRequestDelayTooSmall
		} else if delay.Unix() > time.Now().Add(s.config.MaxDelay).Unix() {
			return false, false, "", "", false, errHTTPBadRequestDelayTooLarge
		}
		m.Time = delay.Unix()
	}
	actionsStr := readParam(r, "x-actions", "actions", "action")
	if actionsStr != "" {
		m.Actions, e = parseActions(actionsStr)
		if e != nil {
			return false, false, "", "", false, errHTTPBadRequestActionsInvalid.Wrap(e.Error())
		}
	}
	// Markdown is enabled either by the markdown flag or by a text/markdown content type
	contentType, markdown := readParam(r, "content-type", "content_type"), readBoolParam(r, false, "x-markdown", "markdown", "md")
	if markdown || strings.ToLower(contentType) == "text/markdown" {
		m.ContentType = "text/markdown"
	}
	// UnifiedPush messages are never sent to Firebase
	unifiedpush = readBoolParam(r, false, "x-unifiedpush", "unifiedpush", "up") // see GET too!
	if unifiedpush {
		firebase = false
		unifiedpush = true
	}
	// A poll ID overrides the UnifiedPush flag, caching and email
	m.PollID = readParam(r, "x-poll-id", "poll-id")
	if m.PollID != "" {
		unifiedpush = false
		cache = false
		email = ""
	}
	return cache, firebase, email, call, unifiedpush, nil
}
  991. // handlePublishBody consumes the PUT/POST body and decides whether the body is an attachment or the message.
  992. //
  993. // 1. curl -X POST -H "Poll: 1234" ntfy.sh/...
  994. // If a message is flagged as poll request, the body does not matter and is discarded
  995. // 2. curl -T somebinarydata.bin "ntfy.sh/mytopic?up=1"
  996. // If body is binary, encode as base64, if not do not encode
  997. // 3. curl -H "Attach: http://example.com/file.jpg" ntfy.sh/mytopic
  998. // Body must be a message, because we attached an external URL
  999. // 4. curl -T short.txt -H "Filename: short.txt" ntfy.sh/mytopic
  1000. // Body must be attachment, because we passed a filename
  1001. // 5. curl -T file.txt ntfy.sh/mytopic
  1002. // If file.txt is <= 4096 (message limit) and valid UTF-8, treat it as a message
  1003. // 6. curl -T file.txt ntfy.sh/mytopic
  1004. // If file.txt is > message limit, treat it as an attachment
  1005. func (s *Server) handlePublishBody(r *http.Request, v *visitor, m *message, body *util.PeekedReadCloser, unifiedpush bool) error {
  1006. if m.Event == pollRequestEvent { // Case 1
  1007. return s.handleBodyDiscard(body)
  1008. } else if unifiedpush {
  1009. return s.handleBodyAsMessageAutoDetect(m, body) // Case 2
  1010. } else if m.Attachment != nil && m.Attachment.URL != "" {
  1011. return s.handleBodyAsTextMessage(m, body) // Case 3
  1012. } else if m.Attachment != nil && m.Attachment.Name != "" {
  1013. return s.handleBodyAsAttachment(r, v, m, body) // Case 4
  1014. } else if !body.LimitReached && utf8.Valid(body.PeekedBytes) {
  1015. return s.handleBodyAsTextMessage(m, body) // Case 5
  1016. }
  1017. return s.handleBodyAsAttachment(r, v, m, body) // Case 6
  1018. }
// handleBodyDiscard reads the entire body into io.Discard and closes it. The error of the
// copy is returned; an error from Close is deliberately ignored.
func (s *Server) handleBodyDiscard(body *util.PeekedReadCloser) error {
	_, err := io.Copy(io.Discard, body)
	_ = body.Close()
	return err
}
  1024. func (s *Server) handleBodyAsMessageAutoDetect(m *message, body *util.PeekedReadCloser) error {
  1025. if utf8.Valid(body.PeekedBytes) {
  1026. m.Message = string(body.PeekedBytes) // Do not trim
  1027. } else {
  1028. m.Message = base64.StdEncoding.EncodeToString(body.PeekedBytes)
  1029. m.Encoding = encodingBase64
  1030. }
  1031. return nil
  1032. }
  1033. func (s *Server) handleBodyAsTextMessage(m *message, body *util.PeekedReadCloser) error {
  1034. if !utf8.Valid(body.PeekedBytes) {
  1035. return errHTTPBadRequestMessageNotUTF8.With(m)
  1036. }
  1037. if len(body.PeekedBytes) > 0 { // Empty body should not override message (publish via GET!)
  1038. m.Message = strings.TrimSpace(string(body.PeekedBytes)) // Truncates the message to the peek limit if required
  1039. }
  1040. if m.Attachment != nil && m.Attachment.Name != "" && m.Message == "" {
  1041. m.Message = fmt.Sprintf(defaultAttachmentMessage, m.Attachment.Name)
  1042. }
  1043. return nil
  1044. }
  1045. func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message, body *util.PeekedReadCloser) error {
  1046. if s.fileCache == nil || s.config.BaseURL == "" || s.config.AttachmentCacheDir == "" {
  1047. return errHTTPBadRequestAttachmentsDisallowed.With(m)
  1048. }
  1049. vinfo, err := v.Info()
  1050. if err != nil {
  1051. return err
  1052. }
  1053. attachmentExpiry := time.Now().Add(vinfo.Limits.AttachmentExpiryDuration).Unix()
  1054. if m.Time > attachmentExpiry {
  1055. return errHTTPBadRequestAttachmentsExpiryBeforeDelivery.With(m)
  1056. }
  1057. contentLengthStr := r.Header.Get("Content-Length")
  1058. if contentLengthStr != "" { // Early "do-not-trust" check, hard limit see below
  1059. contentLength, err := strconv.ParseInt(contentLengthStr, 10, 64)
  1060. if err == nil && (contentLength > vinfo.Stats.AttachmentTotalSizeRemaining || contentLength > vinfo.Limits.AttachmentFileSizeLimit) {
  1061. return errHTTPEntityTooLargeAttachment.With(m).Fields(log.Context{
  1062. "message_content_length": contentLength,
  1063. "attachment_total_size_remaining": vinfo.Stats.AttachmentTotalSizeRemaining,
  1064. "attachment_file_size_limit": vinfo.Limits.AttachmentFileSizeLimit,
  1065. })
  1066. }
  1067. }
  1068. if m.Attachment == nil {
  1069. m.Attachment = &attachment{}
  1070. }
  1071. var ext string
  1072. m.Attachment.Expires = attachmentExpiry
  1073. m.Attachment.Type, ext = util.DetectContentType(body.PeekedBytes, m.Attachment.Name)
  1074. m.Attachment.URL = fmt.Sprintf("%s/file/%s%s", s.config.BaseURL, m.ID, ext)
  1075. if m.Attachment.Name == "" {
  1076. m.Attachment.Name = fmt.Sprintf("attachment%s", ext)
  1077. }
  1078. if m.Message == "" {
  1079. m.Message = fmt.Sprintf(defaultAttachmentMessage, m.Attachment.Name)
  1080. }
  1081. limiters := []util.Limiter{
  1082. v.BandwidthLimiter(),
  1083. util.NewFixedLimiter(vinfo.Limits.AttachmentFileSizeLimit),
  1084. util.NewFixedLimiter(vinfo.Stats.AttachmentTotalSizeRemaining),
  1085. }
  1086. m.Attachment.Size, err = s.fileCache.Write(m.ID, body, limiters...)
  1087. if err == util.ErrLimitReached {
  1088. return errHTTPEntityTooLargeAttachment.With(m)
  1089. } else if err != nil {
  1090. return err
  1091. }
  1092. return nil
  1093. }
  1094. func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request, v *visitor) error {
  1095. encoder := func(msg *message) (string, error) {
  1096. var buf bytes.Buffer
  1097. if err := json.NewEncoder(&buf).Encode(&msg); err != nil {
  1098. return "", err
  1099. }
  1100. return buf.String(), nil
  1101. }
  1102. return s.handleSubscribeHTTP(w, r, v, "application/x-ndjson", encoder)
  1103. }
  1104. func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request, v *visitor) error {
  1105. encoder := func(msg *message) (string, error) {
  1106. var buf bytes.Buffer
  1107. if err := json.NewEncoder(&buf).Encode(&msg); err != nil {
  1108. return "", err
  1109. }
  1110. if msg.Event != messageEvent {
  1111. return fmt.Sprintf("event: %s\ndata: %s\n", msg.Event, buf.String()), nil // Browser's .onmessage() does not fire on this!
  1112. }
  1113. return fmt.Sprintf("data: %s\n", buf.String()), nil
  1114. }
  1115. return s.handleSubscribeHTTP(w, r, v, "text/event-stream", encoder)
  1116. }
  1117. func (s *Server) handleSubscribeRaw(w http.ResponseWriter, r *http.Request, v *visitor) error {
  1118. encoder := func(msg *message) (string, error) {
  1119. if msg.Event == messageEvent { // only handle default events
  1120. return strings.ReplaceAll(msg.Message, "\n", " ") + "\n", nil
  1121. }
  1122. return "\n", nil // "keepalive" and "open" events just send an empty line
  1123. }
  1124. return s.handleSubscribeHTTP(w, r, v, "text/plain", encoder)
  1125. }
// handleSubscribeHTTP implements the HTTP stream subscriptions shared by the JSON, SSE and
// raw handlers. The given encoder turns each message into its wire format, and contentType
// is sent as the Content-Type of the response (with a UTF-8 charset appended).
//
// In poll mode, the old messages are sent and the handler returns. Otherwise the handler
// subscribes to all topics from the path, sends an open message followed by the old
// messages, and then blocks, sending a keepalive message every keepalive interval, until
// the request context is done, the subscription is cancelled, or a write fails.
func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *visitor, contentType string, encoder messageEncoder) error {
	logvr(v, r).Tag(tagSubscribe).Debug("HTTP stream connection opened")
	defer logvr(v, r).Tag(tagSubscribe).Debug("HTTP stream connection closed")
	if !v.SubscriptionAllowed() {
		return errHTTPTooManyRequestsLimitSubscriptions
	}
	defer v.RemoveSubscription()
	topics, topicsStr, err := s.topicsFromPath(r.URL.Path)
	if err != nil {
		return err
	}
	poll, since, scheduled, filters, rateTopics, err := parseSubscribeParams(r)
	if err != nil {
		return err
	}
	var wlock sync.Mutex // Serializes writes to w, see sub below
	defer func() {
		// Hack: This is the fix for a horrible data race that I have not been able to figure out in quite some time.
		// It appears to be happening when the Go HTTP code reads from the socket when closing the request (i.e. AFTER
		// this function returns), and causes a data race with the ResponseWriter. Locking wlock here silences the
		// data race detector. See https://github.com/binwiederhier/ntfy/issues/338#issuecomment-1163425889.
		wlock.TryLock()
	}()
	// sub is the subscriber callback: it drops messages that do not pass the filters,
	// encodes the rest, and writes and flushes them while holding wlock
	sub := func(v *visitor, msg *message) error {
		if !filters.Pass(msg) {
			return nil
		}
		m, err := encoder(msg)
		if err != nil {
			return err
		}
		wlock.Lock()
		defer wlock.Unlock()
		if _, err := w.Write([]byte(m)); err != nil {
			return err
		}
		if fl, ok := w.(http.Flusher); ok {
			fl.Flush()
		}
		return nil
	}
	if err := s.maybeSetRateVisitors(r, v, topics, rateTopics); err != nil {
		return err
	}
	w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
	w.Header().Set("Content-Type", contentType+"; charset=utf-8")                    // Android/Volley client needs charset!
	if poll {
		for _, t := range topics {
			t.Keepalive()
		}
		return s.sendOldMessages(topics, since, scheduled, v, sub)
	}
	// The cancel function is handed to each topic on Subscribe, so the subscription can be
	// ended from outside this handler
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	subscriberIDs := make([]int, 0)
	for _, t := range topics {
		subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v.MaybeUserID(), cancel))
	}
	defer func() {
		// subscriberIDs was filled in the same order as topics, so index i pairs them up
		for i, subscriberID := range subscriberIDs {
			topics[i].Unsubscribe(subscriberID) // Order!
		}
	}()
	if err := sub(v, newOpenMessage(topicsStr)); err != nil { // Send out open message
		return err
	}
	if err := s.sendOldMessages(topics, since, scheduled, v, sub); err != nil {
		return err
	}
	for {
		select {
		case <-ctx.Done():
			return nil
		case <-r.Context().Done():
			return nil
		case <-time.After(s.config.KeepaliveInterval):
			ev := logvr(v, r).Tag(tagSubscribe)
			if len(topics) == 1 {
				ev.With(topics[0]).Trace("Sending keepalive message to %s", topics[0].ID)
			} else {
				ev.Trace("Sending keepalive message to %d topics", len(topics))
			}
			v.Keepalive()
			for _, t := range topics {
				t.Keepalive()
			}
			if err := sub(v, newKeepaliveMessage(topicsStr)); err != nil { // Send keepalive message
				return err
			}
		}
	}
}
  1218. func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *visitor) error {
  1219. if strings.ToLower(r.Header.Get("Upgrade")) != "websocket" {
  1220. return errHTTPBadRequestWebSocketsUpgradeHeaderMissing
  1221. }
  1222. if !v.SubscriptionAllowed() {
  1223. return errHTTPTooManyRequestsLimitSubscriptions
  1224. }
  1225. defer v.RemoveSubscription()
  1226. logvr(v, r).Tag(tagWebsocket).Debug("WebSocket connection opened")
  1227. defer logvr(v, r).Tag(tagWebsocket).Debug("WebSocket connection closed")
  1228. topics, topicsStr, err := s.topicsFromPath(r.URL.Path)
  1229. if err != nil {
  1230. return err
  1231. }
  1232. poll, since, scheduled, filters, rateTopics, err := parseSubscribeParams(r)
  1233. if err != nil {
  1234. return err
  1235. }
  1236. upgrader := &websocket.Upgrader{
  1237. ReadBufferSize: wsBufferSize,
  1238. WriteBufferSize: wsBufferSize,
  1239. CheckOrigin: func(r *http.Request) bool {
  1240. return true // We're open for business!
  1241. },
  1242. }
  1243. conn, err := upgrader.Upgrade(w, r, nil)
  1244. if err != nil {
  1245. return err
  1246. }
  1247. defer conn.Close()
  1248. // Subscription connections can be canceled externally, see topic.CancelSubscribersExceptUser
  1249. cancelCtx, cancel := context.WithCancel(context.Background())
  1250. defer cancel()
  1251. // Use errgroup to run WebSocket reader and writer in Go routines
  1252. var wlock sync.Mutex
  1253. g, gctx := errgroup.WithContext(cancelCtx)
  1254. g.Go(func() error {
  1255. pongWait := s.config.KeepaliveInterval + wsPongWait
  1256. conn.SetReadLimit(wsReadLimit)
  1257. if err := conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil {
  1258. return err
  1259. }
  1260. conn.SetPongHandler(func(appData string) error {
  1261. logvr(v, r).Tag(tagWebsocket).Trace("Received WebSocket pong")
  1262. return conn.SetReadDeadline(time.Now().Add(pongWait))
  1263. })
  1264. for {
  1265. _, _, err := conn.NextReader()
  1266. if err != nil {
  1267. return err
  1268. }
  1269. select {
  1270. case <-gctx.Done():
  1271. return nil
  1272. default:
  1273. }
  1274. }
  1275. })
  1276. g.Go(func() error {
  1277. ping := func() error {
  1278. wlock.Lock()
  1279. defer wlock.Unlock()
  1280. if err := conn.SetWriteDeadline(time.Now().Add(wsWriteWait)); err != nil {
  1281. return err
  1282. }
  1283. logvr(v, r).Tag(tagWebsocket).Trace("Sending WebSocket ping")
  1284. return conn.WriteMessage(websocket.PingMessage, nil)
  1285. }
  1286. for {
  1287. select {
  1288. case <-gctx.Done():
  1289. return nil
  1290. case <-cancelCtx.Done():
  1291. logvr(v, r).Tag(tagWebsocket).Trace("Cancel received, closing subscriber connection")
  1292. conn.Close()
  1293. return &websocket.CloseError{Code: websocket.CloseNormalClosure, Text: "subscription was canceled"}
  1294. case <-time.After(s.config.KeepaliveInterval):
  1295. v.Keepalive()
  1296. for _, t := range topics {
  1297. t.Keepalive()
  1298. }
  1299. if err := ping(); err != nil {
  1300. return err
  1301. }
  1302. }
  1303. }
  1304. })
  1305. sub := func(v *visitor, msg *message) error {
  1306. if !filters.Pass(msg) {
  1307. return nil
  1308. }
  1309. wlock.Lock()
  1310. defer wlock.Unlock()
  1311. if err := conn.SetWriteDeadline(time.Now().Add(wsWriteWait)); err != nil {
  1312. return err
  1313. }
  1314. return conn.WriteJSON(msg)
  1315. }
  1316. if err := s.maybeSetRateVisitors(r, v, topics, rateTopics); err != nil {
  1317. return err
  1318. }
  1319. w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
  1320. if poll {
  1321. for _, t := range topics {
  1322. t.Keepalive()
  1323. }
  1324. return s.sendOldMessages(topics, since, scheduled, v, sub)
  1325. }
  1326. subscriberIDs := make([]int, 0)
  1327. for _, t := range topics {
  1328. subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v.MaybeUserID(), cancel))
  1329. }
  1330. defer func() {
  1331. for i, subscriberID := range subscriberIDs {
  1332. topics[i].Unsubscribe(subscriberID) // Order!
  1333. }
  1334. }()
  1335. if err := sub(v, newOpenMessage(topicsStr)); err != nil { // Send out open message
  1336. return err
  1337. }
  1338. if err := s.sendOldMessages(topics, since, scheduled, v, sub); err != nil {
  1339. return err
  1340. }
  1341. err = g.Wait()
  1342. if err != nil && websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNoStatusReceived) {
  1343. logvr(v, r).Tag(tagWebsocket).Err(err).Fields(websocketErrorContext(err)).Trace("WebSocket connection closed")
  1344. return nil // Normal closures are not errors; note: "1006 (abnormal closure)" is treated as normal, because people disconnect a lot
  1345. }
  1346. return err
  1347. }
  1348. func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, scheduled bool, filters *queryFilter, rateTopics []string, err error) {
  1349. poll = readBoolParam(r, false, "x-poll", "poll", "po")
  1350. scheduled = readBoolParam(r, false, "x-scheduled", "scheduled", "sched")
  1351. since, err = parseSince(r, poll)
  1352. if err != nil {
  1353. return
  1354. }
  1355. filters, err = parseQueryFilters(r)
  1356. if err != nil {
  1357. return
  1358. }
  1359. rateTopics = readCommaSeparatedParam(r, "x-rate-topics", "rate-topics")
  1360. return
  1361. }
  1362. // maybeSetRateVisitors sets the rate visitor on a topic (v.SetRateVisitor), indicating that all messages published
  1363. // to that topic will be rate limited against the rate visitor instead of the publishing visitor.
  1364. //
  1365. // Setting the rate visitor is ony allowed if the `visitor-subscriber-rate-limiting` setting is enabled, AND
  1366. // - auth-file is not set (everything is open by default)
  1367. // - or the topic is reserved, and v.user is the owner
  1368. // - or the topic is not reserved, and v.user has write access
  1369. //
  1370. // Note: This TEMPORARILY also registers all topics starting with "up" (= UnifiedPush). This is to ease the transition
  1371. // until the Android app will send the "Rate-Topics" header.
  1372. func (s *Server) maybeSetRateVisitors(r *http.Request, v *visitor, topics []*topic, rateTopics []string) error {
  1373. // Bail out if not enabled
  1374. if !s.config.VisitorSubscriberRateLimiting {
  1375. return nil
  1376. }
  1377. // Make a list of topics that we'll actually set the RateVisitor on
  1378. eligibleRateTopics := make([]*topic, 0)
  1379. for _, t := range topics {
  1380. if (strings.HasPrefix(t.ID, unifiedPushTopicPrefix) && len(t.ID) == unifiedPushTopicLength) || util.Contains(rateTopics, t.ID) {
  1381. eligibleRateTopics = append(eligibleRateTopics, t)
  1382. }
  1383. }
  1384. if len(eligibleRateTopics) == 0 {
  1385. return nil
  1386. }
  1387. // If access controls are turned off, v has access to everything, and we can set the rate visitor
  1388. if s.userManager == nil {
  1389. return s.setRateVisitors(r, v, eligibleRateTopics)
  1390. }
  1391. // If access controls are enabled, only set rate visitor if
  1392. // - topic is reserved, and v.user is the owner
  1393. // - topic is not reserved, and v.user has write access
  1394. writableRateTopics := make([]*topic, 0)
  1395. for _, t := range topics {
  1396. ownerUserID, err := s.userManager.ReservationOwner(t.ID)
  1397. if err != nil {
  1398. return err
  1399. }
  1400. if ownerUserID == "" {
  1401. if err := s.userManager.Authorize(v.User(), t.ID, user.PermissionWrite); err == nil {
  1402. writableRateTopics = append(writableRateTopics, t)
  1403. }
  1404. } else if ownerUserID == v.MaybeUserID() {
  1405. writableRateTopics = append(writableRateTopics, t)
  1406. }
  1407. }
  1408. return s.setRateVisitors(r, v, writableRateTopics)
  1409. }
  1410. func (s *Server) setRateVisitors(r *http.Request, v *visitor, rateTopics []*topic) error {
  1411. for _, t := range rateTopics {
  1412. logvr(v, r).
  1413. Tag(tagSubscribe).
  1414. With(t).
  1415. Debug("Setting visitor as rate visitor for topic %s", t.ID)
  1416. t.SetRateVisitor(v)
  1417. }
  1418. return nil
  1419. }
  1420. // sendOldMessages selects old messages from the messageCache and calls sub for each of them. It uses since as the
  1421. // marker, returning only messages that are newer than the marker.
  1422. func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled bool, v *visitor, sub subscriber) error {
  1423. if since.IsNone() {
  1424. return nil
  1425. }
  1426. messages := make([]*message, 0)
  1427. for _, t := range topics {
  1428. topicMessages, err := s.messageCache.Messages(t.ID, since, scheduled)
  1429. if err != nil {
  1430. return err
  1431. }
  1432. messages = append(messages, topicMessages...)
  1433. }
  1434. sort.Slice(messages, func(i, j int) bool {
  1435. return messages[i].Time < messages[j].Time
  1436. })
  1437. for _, m := range messages {
  1438. if err := sub(v, m); err != nil {
  1439. return err
  1440. }
  1441. }
  1442. return nil
  1443. }
  1444. // parseSince returns a timestamp identifying the time span from which cached messages should be received.
  1445. //
  1446. // Values in the "since=..." parameter can be either a unix timestamp or a duration (e.g. 12h), or
  1447. // "all" for all messages.
  1448. func parseSince(r *http.Request, poll bool) (sinceMarker, error) {
  1449. since := readParam(r, "x-since", "since", "si")
  1450. // Easy cases (empty, all, none)
  1451. if since == "" {
  1452. if poll {
  1453. return sinceAllMessages, nil
  1454. }
  1455. return sinceNoMessages, nil
  1456. } else if since == "all" {
  1457. return sinceAllMessages, nil
  1458. } else if since == "none" {
  1459. return sinceNoMessages, nil
  1460. }
  1461. // ID, timestamp, duration
  1462. if validMessageID(since) {
  1463. return newSinceID(since), nil
  1464. } else if s, err := strconv.ParseInt(since, 10, 64); err == nil {
  1465. return newSinceTime(s), nil
  1466. } else if d, err := time.ParseDuration(since); err == nil {
  1467. return newSinceTime(time.Now().Add(-1 * d).Unix()), nil
  1468. }
  1469. return sinceNoMessages, errHTTPBadRequestSinceInvalid
  1470. }
  1471. func (s *Server) handleOptions(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
  1472. w.Header().Set("Access-Control-Allow-Methods", "GET, PUT, POST, PATCH, DELETE")
  1473. w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
  1474. w.Header().Set("Access-Control-Allow-Headers", "*") // CORS, allow auth via JS // FIXME is this terrible?
  1475. return nil
  1476. }
  1477. // topicFromPath returns the topic from a root path (e.g. /mytopic), creating it if it doesn't exist.
  1478. func (s *Server) topicFromPath(path string) (*topic, error) {
  1479. parts := strings.Split(path, "/")
  1480. if len(parts) < 2 {
  1481. return nil, errHTTPBadRequestTopicInvalid
  1482. }
  1483. return s.topicFromID(parts[1])
  1484. }
  1485. // topicsFromPath returns the topic from a root path (e.g. /mytopic,mytopic2), creating it if it doesn't exist.
  1486. func (s *Server) topicsFromPath(path string) ([]*topic, string, error) {
  1487. parts := strings.Split(path, "/")
  1488. if len(parts) < 2 {
  1489. return nil, "", errHTTPBadRequestTopicInvalid
  1490. }
  1491. topicIDs := util.SplitNoEmpty(parts[1], ",")
  1492. topics, err := s.topicsFromIDs(topicIDs...)
  1493. if err != nil {
  1494. return nil, "", errHTTPBadRequestTopicInvalid
  1495. }
  1496. return topics, parts[1], nil
  1497. }
  1498. // topicsFromIDs returns the topics with the given IDs, creating them if they don't exist.
  1499. func (s *Server) topicsFromIDs(ids ...string) ([]*topic, error) {
  1500. s.mu.Lock()
  1501. defer s.mu.Unlock()
  1502. topics := make([]*topic, 0)
  1503. for _, id := range ids {
  1504. if util.Contains(s.config.DisallowedTopics, id) {
  1505. return nil, errHTTPBadRequestTopicDisallowed
  1506. }
  1507. if _, ok := s.topics[id]; !ok {
  1508. if len(s.topics) >= s.config.TotalTopicLimit {
  1509. return nil, errHTTPTooManyRequestsLimitTotalTopics
  1510. }
  1511. s.topics[id] = newTopic(id)
  1512. }
  1513. topics = append(topics, s.topics[id])
  1514. }
  1515. return topics, nil
  1516. }
  1517. // topicFromID returns the topic with the given ID, creating it if it doesn't exist.
  1518. func (s *Server) topicFromID(id string) (*topic, error) {
  1519. topics, err := s.topicsFromIDs(id)
  1520. if err != nil {
  1521. return nil, err
  1522. }
  1523. return topics[0], nil
  1524. }
  1525. // topicsFromPattern returns a list of topics matching the given pattern, but it does not create them.
  1526. func (s *Server) topicsFromPattern(pattern string) ([]*topic, error) {
  1527. s.mu.RLock()
  1528. defer s.mu.RUnlock()
  1529. patternRegexp, err := regexp.Compile("^" + strings.ReplaceAll(pattern, "*", ".*") + "$")
  1530. if err != nil {
  1531. return nil, err
  1532. }
  1533. topics := make([]*topic, 0)
  1534. for _, t := range s.topics {
  1535. if patternRegexp.MatchString(t.ID) {
  1536. topics = append(topics, t)
  1537. }
  1538. }
  1539. return topics, nil
  1540. }
  1541. func (s *Server) runSMTPServer() error {
  1542. s.smtpServerBackend = newMailBackend(s.config, s.handle)
  1543. s.smtpServer = smtp.NewServer(s.smtpServerBackend)
  1544. s.smtpServer.Addr = s.config.SMTPServerListen
  1545. s.smtpServer.Domain = s.config.SMTPServerDomain
  1546. s.smtpServer.ReadTimeout = 10 * time.Second
  1547. s.smtpServer.WriteTimeout = 10 * time.Second
  1548. s.smtpServer.MaxMessageBytes = 1024 * 1024 // Must be much larger than message size (headers, multipart, etc.)
  1549. s.smtpServer.MaxRecipients = 1
  1550. s.smtpServer.AllowInsecureAuth = true
  1551. return s.smtpServer.ListenAndServe()
  1552. }
  1553. func (s *Server) runManager() {
  1554. for {
  1555. select {
  1556. case <-time.After(s.config.ManagerInterval):
  1557. log.
  1558. Tag(tagManager).
  1559. Timing(s.execManager).
  1560. Debug("Manager finished")
  1561. case <-s.closeChan:
  1562. return
  1563. }
  1564. }
  1565. }
  1566. // runStatsResetter runs once a day (usually midnight UTC) to reset all the visitor's message and
  1567. // email counters. The stats are used to display the counters in the web app, as well as for rate limiting.
  1568. func (s *Server) runStatsResetter() {
  1569. for {
  1570. runAt := util.NextOccurrenceUTC(s.config.VisitorStatsResetTime, time.Now())
  1571. timer := time.NewTimer(time.Until(runAt))
  1572. log.Tag(tagResetter).Debug("Waiting until %v to reset visitor stats", runAt)
  1573. select {
  1574. case <-timer.C:
  1575. log.Tag(tagResetter).Debug("Running stats resetter")
  1576. s.resetStats()
  1577. case <-s.closeChan:
  1578. log.Tag(tagResetter).Debug("Stopping stats resetter")
  1579. timer.Stop()
  1580. return
  1581. }
  1582. }
  1583. }
  1584. func (s *Server) resetStats() {
  1585. log.Info("Resetting all visitor stats (daily task)")
  1586. s.mu.Lock()
  1587. defer s.mu.Unlock() // Includes the database query to avoid races with other processes
  1588. for _, v := range s.visitors {
  1589. v.ResetStats()
  1590. }
  1591. if s.userManager != nil {
  1592. if err := s.userManager.ResetStats(); err != nil {
  1593. log.Tag(tagResetter).Warn("Failed to write to database: %s", err.Error())
  1594. }
  1595. }
  1596. }
  1597. func (s *Server) runFirebaseKeepaliver() {
  1598. if s.firebaseClient == nil {
  1599. return
  1600. }
  1601. v := newVisitor(s.config, s.messageCache, s.userManager, netip.IPv4Unspecified(), nil) // Background process, not a real visitor, uses IP 0.0.0.0
  1602. for {
  1603. select {
  1604. case <-time.After(s.config.FirebaseKeepaliveInterval):
  1605. s.sendToFirebase(v, newKeepaliveMessage(firebaseControlTopic))
  1606. /*
  1607. FIXME: Disable iOS polling entirely for now due to thundering herd problem (see #677)
  1608. To solve this, we'd have to shard the iOS poll topics to spread out the polling evenly.
  1609. Given that it's not really necessary to poll, turning it off for now should not have any impact.
  1610. case <-time.After(s.config.FirebasePollInterval):
  1611. s.sendToFirebase(v, newKeepaliveMessage(firebasePollTopic))
  1612. */
  1613. case <-s.closeChan:
  1614. return
  1615. }
  1616. }
  1617. }
  1618. func (s *Server) runDelayedSender() {
  1619. for {
  1620. select {
  1621. case <-time.After(s.config.DelayedSenderInterval):
  1622. if err := s.sendDelayedMessages(); err != nil {
  1623. log.Tag(tagPublish).Err(err).Warn("Error sending delayed messages")
  1624. }
  1625. case <-s.closeChan:
  1626. return
  1627. }
  1628. }
  1629. }
  1630. func (s *Server) sendDelayedMessages() error {
  1631. messages, err := s.messageCache.MessagesDue()
  1632. if err != nil {
  1633. return err
  1634. }
  1635. for _, m := range messages {
  1636. var u *user.User
  1637. if s.userManager != nil && m.User != "" {
  1638. u, err = s.userManager.UserByID(m.User)
  1639. if err != nil {
  1640. log.With(m).Err(err).Warn("Error sending delayed message")
  1641. continue
  1642. }
  1643. }
  1644. v := s.visitor(m.Sender, u)
  1645. if err := s.sendDelayedMessage(v, m); err != nil {
  1646. logvm(v, m).Err(err).Warn("Error sending delayed message")
  1647. }
  1648. }
  1649. return nil
  1650. }
  1651. func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
  1652. logvm(v, m).Debug("Sending delayed message")
  1653. s.mu.RLock()
  1654. t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published
  1655. s.mu.RUnlock()
  1656. if ok {
  1657. go func() {
  1658. // We do not rate-limit messages here, since we've rate limited them in the PUT/POST handler
  1659. if err := t.Publish(v, m); err != nil {
  1660. logvm(v, m).Err(err).Warn("Unable to publish message")
  1661. }
  1662. }()
  1663. }
  1664. if s.firebaseClient != nil { // Firebase subscribers may not show up in topics map
  1665. go s.sendToFirebase(v, m)
  1666. }
  1667. if s.config.UpstreamBaseURL != "" {
  1668. go s.forwardPollRequest(v, m)
  1669. }
  1670. if s.config.WebPushPublicKey != "" {
  1671. go s.publishToWebPushEndpoints(v, m)
  1672. }
  1673. if err := s.messageCache.MarkPublished(m); err != nil {
  1674. return err
  1675. }
  1676. return nil
  1677. }
  1678. // transformBodyJSON peeks the request body, reads the JSON, and converts it to headers
  1679. // before passing it on to the next handler. This is meant to be used in combination with handlePublish.
  1680. func (s *Server) transformBodyJSON(next handleFunc) handleFunc {
  1681. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  1682. m, err := readJSONWithLimit[publishMessage](r.Body, s.config.MessageLimit*2, false) // 2x to account for JSON format overhead
  1683. if err != nil {
  1684. return err
  1685. }
  1686. if !topicRegex.MatchString(m.Topic) {
  1687. return errHTTPBadRequestTopicInvalid
  1688. }
  1689. if m.Message == "" {
  1690. m.Message = emptyMessageBody
  1691. }
  1692. r.URL.Path = "/" + m.Topic
  1693. r.Body = io.NopCloser(strings.NewReader(m.Message))
  1694. if m.Title != "" {
  1695. r.Header.Set("X-Title", m.Title)
  1696. }
  1697. if m.Priority != 0 {
  1698. r.Header.Set("X-Priority", fmt.Sprintf("%d", m.Priority))
  1699. }
  1700. if m.Tags != nil && len(m.Tags) > 0 {
  1701. r.Header.Set("X-Tags", strings.Join(m.Tags, ","))
  1702. }
  1703. if m.Attach != "" {
  1704. r.Header.Set("X-Attach", m.Attach)
  1705. }
  1706. if m.Filename != "" {
  1707. r.Header.Set("X-Filename", m.Filename)
  1708. }
  1709. if m.Click != "" {
  1710. r.Header.Set("X-Click", m.Click)
  1711. }
  1712. if m.Icon != "" {
  1713. r.Header.Set("X-Icon", m.Icon)
  1714. }
  1715. if m.Markdown {
  1716. r.Header.Set("X-Markdown", "yes")
  1717. }
  1718. if len(m.Actions) > 0 {
  1719. actionsStr, err := json.Marshal(m.Actions)
  1720. if err != nil {
  1721. return errHTTPBadRequestMessageJSONInvalid
  1722. }
  1723. r.Header.Set("X-Actions", string(actionsStr))
  1724. }
  1725. if m.Email != "" {
  1726. r.Header.Set("X-Email", m.Email)
  1727. }
  1728. if m.Delay != "" {
  1729. r.Header.Set("X-Delay", m.Delay)
  1730. }
  1731. if m.Call != "" {
  1732. r.Header.Set("X-Call", m.Call)
  1733. }
  1734. return next(w, r, v)
  1735. }
  1736. }
  1737. func (s *Server) transformMatrixJSON(next handleFunc) handleFunc {
  1738. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  1739. newRequest, err := newRequestFromMatrixJSON(r, s.config.BaseURL, s.config.MessageLimit)
  1740. if err != nil {
  1741. logvr(v, r).Tag(tagMatrix).Err(err).Debug("Invalid Matrix request")
  1742. if e, ok := err.(*errMatrixPushkeyRejected); ok {
  1743. return writeMatrixResponse(w, e.rejectedPushKey)
  1744. }
  1745. return err
  1746. }
  1747. if err := next(w, newRequest, v); err != nil {
  1748. logvr(v, r).Tag(tagMatrix).Err(err).Debug("Error handling Matrix request")
  1749. return err
  1750. }
  1751. return nil
  1752. }
  1753. }
  1754. func (s *Server) authorizeTopicWrite(next handleFunc) handleFunc {
  1755. return s.autorizeTopic(next, user.PermissionWrite)
  1756. }
  1757. func (s *Server) authorizeTopicRead(next handleFunc) handleFunc {
  1758. return s.autorizeTopic(next, user.PermissionRead)
  1759. }
  1760. func (s *Server) autorizeTopic(next handleFunc, perm user.Permission) handleFunc {
  1761. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  1762. if s.userManager == nil {
  1763. return next(w, r, v)
  1764. }
  1765. topics, _, err := s.topicsFromPath(r.URL.Path)
  1766. if err != nil {
  1767. return err
  1768. }
  1769. u := v.User()
  1770. for _, t := range topics {
  1771. if err := s.userManager.Authorize(u, t.ID, perm); err != nil {
  1772. logvr(v, r).With(t).Err(err).Debug("Access to topic %s not authorized", t.ID)
  1773. return errHTTPForbidden.With(t)
  1774. }
  1775. }
  1776. return next(w, r, v)
  1777. }
  1778. }
  1779. // maybeAuthenticate delegates between auth based on the Authorization header (Bearer/Basic), and auth
  1780. // based on the user-defined header (as defined in the "auth-user-header" setting). The function prefers
  1781. // the user-defined header, if both are present.
  1782. //
  1783. // This function will ALWAYS return a visitor, even if an error occurs (e.g. unauthorized), so
  1784. // that subsequent logging calls still have a visitor context.
  1785. func (s *Server) maybeAuthenticate(r *http.Request) (*visitor, error) {
  1786. ip := extractIPAddress(r, s.config.BehindProxy)
  1787. vip := s.visitor(ip, nil) // IP-based visitor
  1788. if s.userManager == nil {
  1789. return vip, nil
  1790. }
  1791. if s.config.AuthUserHeader != "" && s.config.BehindProxy {
  1792. username := r.Header.Get(s.config.AuthUserHeader) // Do not allow a query param, only a header!
  1793. if username != "" {
  1794. return s.authenticateViaUserDefinedHeader(r, vip, username)
  1795. }
  1796. }
  1797. return s.authenticateViaAuthHeader(r, vip)
  1798. }
  1799. // authenticateViaUserDefinedHeader tries to authenticate the user via the header defined in the "auth-user-header"
  1800. // configuration value if it is set. The value of the passed username is used to lookup the user in the database.
  1801. // If it exists, authentication is successful.
  1802. //
  1803. // This function will ALWAYS return a visitor, even if an error occurs (e.g. unauthorized), so
  1804. // that subsequent logging calls still have a visitor context.
  1805. func (s *Server) authenticateViaUserDefinedHeader(r *http.Request, vip *visitor, username string) (*visitor, error) {
  1806. // Check the rate limiter first
  1807. if !vip.AuthAllowed() {
  1808. return vip, errHTTPTooManyRequestsLimitAuthFailure // Always return visitor, even when error occurs!
  1809. }
  1810. // Retrieve user from database; if found, we have a successful authentication
  1811. u, err := s.userManager.User(username)
  1812. if err != nil || u.Deleted {
  1813. vip.AuthFailed()
  1814. logr(r).Err(err).Debug("Authentication failed")
  1815. return vip, errHTTPUnauthorized
  1816. }
  1817. // User was found, meaning that auth was successful
  1818. return s.visitor(vip.ip, u), nil
  1819. }
  1820. // authenticateViaAuthHeader reads the "Authorization" header and will try to authenticate the user
  1821. // if it is set.
  1822. //
  1823. // - If the header is not set or not supported (anything non-Basic and non-Bearer),
  1824. // an IP-based visitor is returned
  1825. // - If the header is set, authenticate will be called to check the username/password (Basic auth),
  1826. // or the token (Bearer auth), and read the user from the database
  1827. //
  1828. // This function will ALWAYS return a visitor, even if an error occurs (e.g. unauthorized), so
  1829. // that subsequent logging calls still have a visitor context.
  1830. func (s *Server) authenticateViaAuthHeader(r *http.Request, vip *visitor) (*visitor, error) {
  1831. // Read "Authorization" header value, and exit out early if it's not set
  1832. header, err := readAuthHeader(r)
  1833. if err != nil {
  1834. return vip, err
  1835. } else if !supportedAuthHeader(header) {
  1836. return vip, nil
  1837. }
  1838. // If we're trying to auth, check the rate limiter first
  1839. if !vip.AuthAllowed() {
  1840. return vip, errHTTPTooManyRequestsLimitAuthFailure // Always return visitor, even when error occurs!
  1841. }
  1842. u, err := s.authenticate(r, header)
  1843. if err != nil {
  1844. vip.AuthFailed()
  1845. logr(r).Err(err).Debug("Authentication failed")
  1846. return vip, errHTTPUnauthorized // Always return visitor, even when error occurs!
  1847. }
  1848. // Authentication with user was successful
  1849. return s.visitor(vip.ip, u), nil
  1850. }
  1851. // authenticate a user based on basic auth username/password (Authorization: Basic ...), or token auth (Authorization: Bearer ...).
  1852. // The Authorization header can be passed as a header or the ?auth=... query param. The latter is required only to
  1853. // support the WebSocket JavaScript class, which does not support passing headers during the initial request. The auth
  1854. // query param is effectively doubly base64 encoded. Its format is base64(Basic base64(user:pass)).
  1855. func (s *Server) authenticate(r *http.Request, header string) (user *user.User, err error) {
  1856. if strings.HasPrefix(header, "Bearer") {
  1857. return s.authenticateBearerAuth(r, strings.TrimSpace(strings.TrimPrefix(header, "Bearer")))
  1858. }
  1859. return s.authenticateBasicAuth(r, header)
  1860. }
  1861. // readAuthHeader reads the raw value of the Authorization header, either from the actual HTTP header,
  1862. // or from the ?auth... query parameter
  1863. func readAuthHeader(r *http.Request) (string, error) {
  1864. value := strings.TrimSpace(r.Header.Get("Authorization"))
  1865. queryParam := readQueryParam(r, "authorization", "auth")
  1866. if queryParam != "" {
  1867. a, err := base64.RawURLEncoding.DecodeString(queryParam)
  1868. if err != nil {
  1869. return "", err
  1870. }
  1871. value = strings.TrimSpace(string(a))
  1872. }
  1873. return value, nil
  1874. }
  // supportedAuthHeader reports whether the Authorization header value uses one of the supported
  // schemes, i.e. whether it starts with "Basic " or "Bearer " (case-insensitive, including the
  // trailing space). An empty value is not supported, and neither are things like "WebPush",
  // or "vapid" (see #629).
  func supportedAuthHeader(value string) bool {
  	normalized := strings.ToLower(value)
  	for _, scheme := range []string{"basic ", "bearer "} {
  		if strings.HasPrefix(normalized, scheme) {
  			return true
  		}
  	}
  	return false
  }
  1882. func (s *Server) authenticateBasicAuth(r *http.Request, value string) (user *user.User, err error) {
  1883. r.Header.Set("Authorization", value)
  1884. username, password, ok := r.BasicAuth()
  1885. if !ok {
  1886. return nil, errors.New("invalid basic auth")
  1887. } else if username == "" {
  1888. return s.authenticateBearerAuth(r, password) // Treat password as token
  1889. }
  1890. return s.userManager.Authenticate(username, password)
  1891. }
  1892. func (s *Server) authenticateBearerAuth(r *http.Request, token string) (*user.User, error) {
  1893. u, err := s.userManager.AuthenticateToken(token)
  1894. if err != nil {
  1895. return nil, err
  1896. }
  1897. ip := extractIPAddress(r, s.config.BehindProxy)
  1898. go s.userManager.EnqueueTokenUpdate(token, &user.TokenUpdate{
  1899. LastAccess: time.Now(),
  1900. LastOrigin: ip,
  1901. })
  1902. return u, nil
  1903. }
  1904. func (s *Server) visitor(ip netip.Addr, user *user.User) *visitor {
  1905. s.mu.Lock()
  1906. defer s.mu.Unlock()
  1907. id := visitorID(ip, user)
  1908. v, exists := s.visitors[id]
  1909. if !exists {
  1910. s.visitors[id] = newVisitor(s.config, s.messageCache, s.userManager, ip, user)
  1911. return s.visitors[id]
  1912. }
  1913. v.Keepalive()
  1914. v.SetUser(user) // Always update with the latest user, may be nil!
  1915. return v
  1916. }
  1917. func (s *Server) writeJSON(w http.ResponseWriter, v any) error {
  1918. return s.writeJSONWithContentType(w, v, "application/json")
  1919. }
  1920. func (s *Server) writeJSONWithContentType(w http.ResponseWriter, v any, contentType string) error {
  1921. w.Header().Set("Content-Type", contentType)
  1922. w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
  1923. if err := json.NewEncoder(w).Encode(v); err != nil {
  1924. return err
  1925. }
  1926. return nil
  1927. }
  1928. func (s *Server) updateAndWriteStats(messagesCount int64) {
  1929. s.mu.Lock()
  1930. s.messagesHistory = append(s.messagesHistory, messagesCount)
  1931. if len(s.messagesHistory) > messagesHistoryMax {
  1932. s.messagesHistory = s.messagesHistory[1:]
  1933. }
  1934. s.mu.Unlock()
  1935. go func() {
  1936. if err := s.messageCache.UpdateStats(messagesCount); err != nil {
  1937. log.Tag(tagManager).Err(err).Warn("Cannot write messages stats")
  1938. }
  1939. }()
  1940. }