server.go 58 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654
  1. package server
  2. import (
  3. "bytes"
  4. "context"
  5. "crypto/sha256"
  6. "embed"
  7. "encoding/base64"
  8. "encoding/json"
  9. "errors"
  10. "fmt"
  11. "heckel.io/ntfy/user"
  12. "io"
  13. "net"
  14. "net/http"
  15. "net/netip"
  16. "net/url"
  17. "os"
  18. "path"
  19. "path/filepath"
  20. "regexp"
  21. "sort"
  22. "strconv"
  23. "strings"
  24. "sync"
  25. "time"
  26. "unicode/utf8"
  27. "heckel.io/ntfy/log"
  28. "github.com/emersion/go-smtp"
  29. "github.com/gorilla/websocket"
  30. "golang.org/x/sync/errgroup"
  31. "heckel.io/ntfy/util"
  32. )
  33. /*
  34. TODO
  35. races:
  36. - v.user --> see publishSyncEventAsync() test
  37. payments:
  38. - delete messages + reserved topics on ResetTier delete attachments in access.go
  39. - reconciliation
  40. Limits & rate limiting:
  41. users without tier: should the stats be persisted? are they meaningful? -> test that the visitor is based on the IP address!
  42. login/account endpoints
  43. when ResetStats() is run, reset messagesLimiter (and others)?
  44. Delete visitor when tier is changed to refresh rate limiters
  45. Make sure account endpoints make sense for admins
  46. UI:
  47. - revert home page change
  48. - flicker of upgrade banner
  49. - JS constants
  50. Sync:
  51. - sync problems with "deleteAfter=0" and "displayName="
  52. Tests:
  53. - Payment endpoints (make mocks)
  54. - Message rate limiting and reset tests
  55. - test that the visitor is based on the IP address when a user has no tier
  56. */
  57. // Server is the main server, providing the UI and API for ntfy
  58. type Server struct {
  59. config *Config
  60. httpServer *http.Server
  61. httpsServer *http.Server
  62. unixListener net.Listener
  63. smtpServer *smtp.Server
  64. smtpServerBackend *smtpBackend
  65. smtpSender mailer
  66. topics map[string]*topic
  67. visitors map[string]*visitor // ip:<ip> or user:<user>
  68. firebaseClient *firebaseClient
  69. messages int64
  70. userManager *user.Manager // Might be nil!
  71. messageCache *messageCache // Database that stores the messages
  72. fileCache *fileCache // File system based cache that stores attachments
  73. stripe stripeAPI // Stripe API, can be replaced with a mock
  74. priceCache *util.LookupCache[map[string]string] // Stripe price ID -> formatted price
  75. closeChan chan bool
  76. mu sync.Mutex
  77. }
  78. // handleFunc extends the normal http.HandlerFunc to be able to easily return errors
  79. type handleFunc func(http.ResponseWriter, *http.Request, *visitor) error
  80. var (
  81. // If changed, don't forget to update Android App and auth_sqlite.go
  82. topicRegex = regexp.MustCompile(`^[-_A-Za-z0-9]{1,64}$`) // No /!
  83. topicPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}$`) // Regex must match JS & Android app!
  84. externalTopicPathRegex = regexp.MustCompile(`^/[^/]+\.[^/]+/[-_A-Za-z0-9]{1,64}$`) // Extended topic path, for web-app, e.g. /example.com/mytopic
  85. jsonPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/json$`)
  86. ssePathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/sse$`)
  87. rawPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/raw$`)
  88. wsPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/ws$`)
  89. authPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/auth$`)
  90. publishPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}/(publish|send|trigger)$`)
  91. webConfigPath = "/config.js"
  92. accountPath = "/account"
  93. matrixPushPath = "/_matrix/push/v1/notify"
  94. apiHealthPath = "/v1/health"
  95. apiTiers = "/v1/tiers"
  96. apiAccountPath = "/v1/account"
  97. apiAccountTokenPath = "/v1/account/token"
  98. apiAccountPasswordPath = "/v1/account/password"
  99. apiAccountSettingsPath = "/v1/account/settings"
  100. apiAccountSubscriptionPath = "/v1/account/subscription"
  101. apiAccountReservationPath = "/v1/account/reservation"
  102. apiAccountBillingPortalPath = "/v1/account/billing/portal"
  103. apiAccountBillingWebhookPath = "/v1/account/billing/webhook"
  104. apiAccountBillingSubscriptionPath = "/v1/account/billing/subscription"
  105. apiAccountBillingSubscriptionCheckoutSuccessTemplate = "/v1/account/billing/subscription/success/{CHECKOUT_SESSION_ID}"
  106. apiAccountBillingSubscriptionCheckoutSuccessRegex = regexp.MustCompile(`/v1/account/billing/subscription/success/(.+)$`)
  107. apiAccountReservationSingleRegex = regexp.MustCompile(`/v1/account/reservation/([-_A-Za-z0-9]{1,64})$`)
  108. apiAccountSubscriptionSingleRegex = regexp.MustCompile(`^/v1/account/subscription/([-_A-Za-z0-9]{16})$`)
  109. staticRegex = regexp.MustCompile(`^/static/.+`)
  110. docsRegex = regexp.MustCompile(`^/docs(|/.*)$`)
  111. fileRegex = regexp.MustCompile(`^/file/([-_A-Za-z0-9]{1,64})(?:\.[A-Za-z0-9]{1,16})?$`)
  112. disallowedTopics = []string{"docs", "static", "file", "app", "account", "settings", "pricing", "signup", "login", "reset-password"} // If updated, also update in Android and web app
  113. urlRegex = regexp.MustCompile(`^https?://`)
  114. //go:embed site
  115. webFs embed.FS
  116. webFsCached = &util.CachingEmbedFS{ModTime: time.Now(), FS: webFs}
  117. webSiteDir = "/site"
  118. webHomeIndex = "/home.html" // Landing page, only if "web-root: home"
  119. webAppIndex = "/app.html" // React app
  120. //go:embed docs
  121. docsStaticFs embed.FS
  122. docsStaticCached = &util.CachingEmbedFS{ModTime: time.Now(), FS: docsStaticFs}
  123. )
  124. const (
  125. firebaseControlTopic = "~control" // See Android if changed
  126. firebasePollTopic = "~poll" // See iOS if changed
  127. emptyMessageBody = "triggered" // Used if message body is empty
  128. newMessageBody = "New message" // Used in poll requests as generic message
  129. defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment
  130. encodingBase64 = "base64" // Used mainly for binary UnifiedPush messages
  131. jsonBodyBytesLimit = 16384
  132. )
  133. // WebSocket constants
  134. const (
  135. wsWriteWait = 2 * time.Second
  136. wsBufferSize = 1024
  137. wsReadLimit = 64 // We only ever receive PINGs
  138. wsPongWait = 15 * time.Second
  139. )
  140. // New instantiates a new Server. It creates the cache and adds a Firebase
  141. // subscriber (if configured).
  142. func New(conf *Config) (*Server, error) {
  143. var mailer mailer
  144. if conf.SMTPSenderAddr != "" {
  145. mailer = &smtpSender{config: conf}
  146. }
  147. var stripe stripeAPI
  148. if conf.StripeSecretKey != "" {
  149. stripe = newStripeAPI()
  150. }
  151. messageCache, err := createMessageCache(conf)
  152. if err != nil {
  153. return nil, err
  154. }
  155. topics, err := messageCache.Topics()
  156. if err != nil {
  157. return nil, err
  158. }
  159. var fileCache *fileCache
  160. if conf.AttachmentCacheDir != "" {
  161. fileCache, err = newFileCache(conf.AttachmentCacheDir, conf.AttachmentTotalSizeLimit)
  162. if err != nil {
  163. return nil, err
  164. }
  165. }
  166. var userManager *user.Manager
  167. if conf.AuthFile != "" {
  168. userManager, err = user.NewManager(conf.AuthFile, conf.AuthStartupQueries, conf.AuthDefault)
  169. if err != nil {
  170. return nil, err
  171. }
  172. }
  173. var firebaseClient *firebaseClient
  174. if conf.FirebaseKeyFile != "" {
  175. sender, err := newFirebaseSender(conf.FirebaseKeyFile)
  176. if err != nil {
  177. return nil, err
  178. }
  179. firebaseClient = newFirebaseClient(sender, userManager)
  180. }
  181. s := &Server{
  182. config: conf,
  183. messageCache: messageCache,
  184. fileCache: fileCache,
  185. firebaseClient: firebaseClient,
  186. smtpSender: mailer,
  187. topics: topics,
  188. userManager: userManager,
  189. visitors: make(map[string]*visitor),
  190. stripe: stripe,
  191. }
  192. s.priceCache = util.NewLookupCache(s.fetchStripePrices, conf.StripePriceCacheDuration)
  193. return s, nil
  194. }
  195. func createMessageCache(conf *Config) (*messageCache, error) {
  196. if conf.CacheDuration == 0 {
  197. return newNopCache()
  198. } else if conf.CacheFile != "" {
  199. return newSqliteCache(conf.CacheFile, conf.CacheStartupQueries, conf.CacheDuration, conf.CacheBatchSize, conf.CacheBatchTimeout, false)
  200. }
  201. return newMemCache()
  202. }
  203. // Run executes the main server. It listens on HTTP (+ HTTPS, if configured), and starts
  204. // a manager go routine to print stats and prune messages.
  205. func (s *Server) Run() error {
  206. var listenStr string
  207. if s.config.ListenHTTP != "" {
  208. listenStr += fmt.Sprintf(" %s[http]", s.config.ListenHTTP)
  209. }
  210. if s.config.ListenHTTPS != "" {
  211. listenStr += fmt.Sprintf(" %s[https]", s.config.ListenHTTPS)
  212. }
  213. if s.config.ListenUnix != "" {
  214. listenStr += fmt.Sprintf(" %s[unix]", s.config.ListenUnix)
  215. }
  216. if s.config.SMTPServerListen != "" {
  217. listenStr += fmt.Sprintf(" %s[smtp]", s.config.SMTPServerListen)
  218. }
  219. log.Info("Listening on%s, ntfy %s, log level is %s", listenStr, s.config.Version, log.CurrentLevel().String())
  220. mux := http.NewServeMux()
  221. mux.HandleFunc("/", s.handle)
  222. errChan := make(chan error)
  223. s.mu.Lock()
  224. s.closeChan = make(chan bool)
  225. if s.config.ListenHTTP != "" {
  226. s.httpServer = &http.Server{Addr: s.config.ListenHTTP, Handler: mux}
  227. go func() {
  228. errChan <- s.httpServer.ListenAndServe()
  229. }()
  230. }
  231. if s.config.ListenHTTPS != "" {
  232. s.httpsServer = &http.Server{Addr: s.config.ListenHTTPS, Handler: mux}
  233. go func() {
  234. errChan <- s.httpsServer.ListenAndServeTLS(s.config.CertFile, s.config.KeyFile)
  235. }()
  236. }
  237. if s.config.ListenUnix != "" {
  238. go func() {
  239. var err error
  240. s.mu.Lock()
  241. os.Remove(s.config.ListenUnix)
  242. s.unixListener, err = net.Listen("unix", s.config.ListenUnix)
  243. if err != nil {
  244. s.mu.Unlock()
  245. errChan <- err
  246. return
  247. }
  248. defer s.unixListener.Close()
  249. if s.config.ListenUnixMode > 0 {
  250. if err := os.Chmod(s.config.ListenUnix, s.config.ListenUnixMode); err != nil {
  251. s.mu.Unlock()
  252. errChan <- err
  253. return
  254. }
  255. }
  256. s.mu.Unlock()
  257. httpServer := &http.Server{Handler: mux}
  258. errChan <- httpServer.Serve(s.unixListener)
  259. }()
  260. }
  261. if s.config.SMTPServerListen != "" {
  262. go func() {
  263. errChan <- s.runSMTPServer()
  264. }()
  265. }
  266. s.mu.Unlock()
  267. go s.runManager()
  268. go s.runStatsResetter()
  269. go s.runDelayedSender()
  270. go s.runFirebaseKeepaliver()
  271. return <-errChan
  272. }
  273. // Stop stops HTTP (+HTTPS) server and all managers
  274. func (s *Server) Stop() {
  275. s.mu.Lock()
  276. defer s.mu.Unlock()
  277. if s.httpServer != nil {
  278. s.httpServer.Close()
  279. }
  280. if s.httpsServer != nil {
  281. s.httpsServer.Close()
  282. }
  283. if s.unixListener != nil {
  284. s.unixListener.Close()
  285. }
  286. if s.smtpServer != nil {
  287. s.smtpServer.Close()
  288. }
  289. close(s.closeChan)
  290. }
  291. func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
  292. v, err := s.visitor(r) // Note: Always returns v, even when error is returned
  293. if err == nil {
  294. log.Debug("%s Dispatching request", logHTTPPrefix(v, r))
  295. if log.IsTrace() {
  296. log.Trace("%s Entire request (headers and body):\n%s", logHTTPPrefix(v, r), renderHTTPRequest(r))
  297. }
  298. err = s.handleInternal(w, r, v)
  299. }
  300. if err != nil {
  301. if websocket.IsWebSocketUpgrade(r) {
  302. isNormalError := strings.Contains(err.Error(), "i/o timeout")
  303. if isNormalError {
  304. log.Debug("%s WebSocket error (this error is okay, it happens a lot): %s", logHTTPPrefix(v, r), err.Error())
  305. } else {
  306. log.Info("%s WebSocket error: %s", logHTTPPrefix(v, r), err.Error())
  307. }
  308. return // Do not attempt to write to upgraded connection
  309. }
  310. if matrixErr, ok := err.(*errMatrix); ok {
  311. writeMatrixError(w, r, v, matrixErr)
  312. return
  313. }
  314. httpErr, ok := err.(*errHTTP)
  315. if !ok {
  316. httpErr = errHTTPInternalError
  317. }
  318. isNormalError := httpErr.HTTPCode == http.StatusNotFound || httpErr.HTTPCode == http.StatusBadRequest
  319. if isNormalError {
  320. log.Debug("%s Connection closed with HTTP %d (ntfy error %d): %s", logHTTPPrefix(v, r), httpErr.HTTPCode, httpErr.Code, err.Error())
  321. } else {
  322. log.Info("%s Connection closed with HTTP %d (ntfy error %d): %s", logHTTPPrefix(v, r), httpErr.HTTPCode, httpErr.Code, err.Error())
  323. }
  324. w.Header().Set("Content-Type", "application/json")
  325. w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
  326. w.WriteHeader(httpErr.HTTPCode)
  327. io.WriteString(w, httpErr.JSON()+"\n")
  328. }
  329. }
  330. func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visitor) error {
  331. if r.Method == http.MethodGet && r.URL.Path == "/" {
  332. return s.ensureWebEnabled(s.handleHome)(w, r, v)
  333. } else if r.Method == http.MethodHead && r.URL.Path == "/" {
  334. return s.ensureWebEnabled(s.handleEmpty)(w, r, v)
  335. } else if r.Method == http.MethodGet && r.URL.Path == apiHealthPath {
  336. return s.handleHealth(w, r, v)
  337. } else if r.Method == http.MethodGet && r.URL.Path == webConfigPath {
  338. return s.ensureWebEnabled(s.handleWebConfig)(w, r, v)
  339. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountPath {
  340. return s.ensureUserManager(s.handleAccountCreate)(w, r, v)
  341. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountTokenPath {
  342. return s.ensureUser(s.handleAccountTokenIssue)(w, r, v)
  343. } else if r.Method == http.MethodGet && r.URL.Path == apiAccountPath {
  344. return s.handleAccountGet(w, r, v) // Allowed by anonymous
  345. } else if r.Method == http.MethodDelete && r.URL.Path == apiAccountPath {
  346. return s.ensureUser(s.withAccountSync(s.handleAccountDelete))(w, r, v)
  347. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountPasswordPath {
  348. return s.ensureUser(s.handleAccountPasswordChange)(w, r, v)
  349. } else if r.Method == http.MethodPatch && r.URL.Path == apiAccountTokenPath {
  350. return s.ensureUser(s.handleAccountTokenExtend)(w, r, v)
  351. } else if r.Method == http.MethodDelete && r.URL.Path == apiAccountTokenPath {
  352. return s.ensureUser(s.handleAccountTokenDelete)(w, r, v)
  353. } else if r.Method == http.MethodPatch && r.URL.Path == apiAccountSettingsPath {
  354. return s.ensureUser(s.withAccountSync(s.handleAccountSettingsChange))(w, r, v)
  355. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountSubscriptionPath {
  356. return s.ensureUser(s.withAccountSync(s.handleAccountSubscriptionAdd))(w, r, v)
  357. } else if r.Method == http.MethodPatch && apiAccountSubscriptionSingleRegex.MatchString(r.URL.Path) {
  358. return s.ensureUser(s.withAccountSync(s.handleAccountSubscriptionChange))(w, r, v)
  359. } else if r.Method == http.MethodDelete && apiAccountSubscriptionSingleRegex.MatchString(r.URL.Path) {
  360. return s.ensureUser(s.withAccountSync(s.handleAccountSubscriptionDelete))(w, r, v)
  361. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountReservationPath {
  362. return s.ensureUser(s.withAccountSync(s.handleAccountReservationAdd))(w, r, v)
  363. } else if r.Method == http.MethodDelete && apiAccountReservationSingleRegex.MatchString(r.URL.Path) {
  364. return s.ensureUser(s.withAccountSync(s.handleAccountReservationDelete))(w, r, v)
  365. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountBillingSubscriptionPath {
  366. return s.ensurePaymentsEnabled(s.ensureUser(s.handleAccountBillingSubscriptionCreate))(w, r, v) // Account sync via incoming Stripe webhook
  367. } else if r.Method == http.MethodGet && apiAccountBillingSubscriptionCheckoutSuccessRegex.MatchString(r.URL.Path) {
  368. return s.ensurePaymentsEnabled(s.ensureUserManager(s.handleAccountBillingSubscriptionCreateSuccess))(w, r, v) // No user context!
  369. } else if r.Method == http.MethodPut && r.URL.Path == apiAccountBillingSubscriptionPath {
  370. return s.ensurePaymentsEnabled(s.ensureStripeCustomer(s.handleAccountBillingSubscriptionUpdate))(w, r, v) // Account sync via incoming Stripe webhook
  371. } else if r.Method == http.MethodDelete && r.URL.Path == apiAccountBillingSubscriptionPath {
  372. return s.ensurePaymentsEnabled(s.ensureStripeCustomer(s.handleAccountBillingSubscriptionDelete))(w, r, v) // Account sync via incoming Stripe webhook
  373. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountBillingPortalPath {
  374. return s.ensurePaymentsEnabled(s.ensureStripeCustomer(s.handleAccountBillingPortalSessionCreate))(w, r, v)
  375. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountBillingWebhookPath {
  376. return s.ensurePaymentsEnabled(s.ensureUserManager(s.handleAccountBillingWebhook))(w, r, v) // This request comes from Stripe!
  377. } else if r.Method == http.MethodGet && r.URL.Path == apiTiers {
  378. return s.ensurePaymentsEnabled(s.handleBillingTiersGet)(w, r, v)
  379. } else if r.Method == http.MethodGet && r.URL.Path == matrixPushPath {
  380. return s.handleMatrixDiscovery(w)
  381. } else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) {
  382. return s.ensureWebEnabled(s.handleStatic)(w, r, v)
  383. } else if r.Method == http.MethodGet && docsRegex.MatchString(r.URL.Path) {
  384. return s.ensureWebEnabled(s.handleDocs)(w, r, v)
  385. } else if (r.Method == http.MethodGet || r.Method == http.MethodHead) && fileRegex.MatchString(r.URL.Path) && s.config.AttachmentCacheDir != "" {
  386. return s.limitRequests(s.handleFile)(w, r, v)
  387. } else if r.Method == http.MethodOptions {
  388. return s.ensureWebEnabled(s.handleOptions)(w, r, v)
  389. } else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && r.URL.Path == "/" {
  390. return s.limitRequests(s.transformBodyJSON(s.authorizeTopicWrite(s.handlePublish)))(w, r, v)
  391. } else if r.Method == http.MethodPost && r.URL.Path == matrixPushPath {
  392. return s.limitRequests(s.transformMatrixJSON(s.authorizeTopicWrite(s.handlePublishMatrix)))(w, r, v)
  393. } else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicPathRegex.MatchString(r.URL.Path) {
  394. return s.limitRequests(s.authorizeTopicWrite(s.handlePublish))(w, r, v)
  395. } else if r.Method == http.MethodGet && publishPathRegex.MatchString(r.URL.Path) {
  396. return s.limitRequests(s.authorizeTopicWrite(s.handlePublish))(w, r, v)
  397. } else if r.Method == http.MethodGet && jsonPathRegex.MatchString(r.URL.Path) {
  398. return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeJSON))(w, r, v)
  399. } else if r.Method == http.MethodGet && ssePathRegex.MatchString(r.URL.Path) {
  400. return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeSSE))(w, r, v)
  401. } else if r.Method == http.MethodGet && rawPathRegex.MatchString(r.URL.Path) {
  402. return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeRaw))(w, r, v)
  403. } else if r.Method == http.MethodGet && wsPathRegex.MatchString(r.URL.Path) {
  404. return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeWS))(w, r, v)
  405. } else if r.Method == http.MethodGet && authPathRegex.MatchString(r.URL.Path) {
  406. return s.limitRequests(s.authorizeTopicRead(s.handleTopicAuth))(w, r, v)
  407. } else if r.Method == http.MethodGet && (topicPathRegex.MatchString(r.URL.Path) || externalTopicPathRegex.MatchString(r.URL.Path)) {
  408. return s.ensureWebEnabled(s.handleTopic)(w, r, v)
  409. }
  410. return errHTTPNotFound
  411. }
  412. func (s *Server) handleHome(w http.ResponseWriter, r *http.Request, v *visitor) error {
  413. if s.config.WebRootIsApp {
  414. r.URL.Path = webAppIndex
  415. } else {
  416. r.URL.Path = webHomeIndex
  417. }
  418. return s.handleStatic(w, r, v)
  419. }
  420. func (s *Server) handleTopic(w http.ResponseWriter, r *http.Request, v *visitor) error {
  421. unifiedpush := readBoolParam(r, false, "x-unifiedpush", "unifiedpush", "up") // see PUT/POST too!
  422. if unifiedpush {
  423. w.Header().Set("Content-Type", "application/json")
  424. w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
  425. _, err := io.WriteString(w, `{"unifiedpush":{"version":1}}`+"\n")
  426. return err
  427. }
  428. r.URL.Path = webAppIndex
  429. return s.handleStatic(w, r, v)
  430. }
  431. func (s *Server) handleEmpty(_ http.ResponseWriter, _ *http.Request, _ *visitor) error {
  432. return nil
  433. }
  434. func (s *Server) handleTopicAuth(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
  435. return s.writeJSON(w, newSuccessResponse())
  436. }
  437. func (s *Server) handleHealth(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
  438. response := &apiHealthResponse{
  439. Healthy: true,
  440. }
  441. return s.writeJSON(w, response)
  442. }
  443. func (s *Server) handleWebConfig(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
  444. appRoot := "/"
  445. if !s.config.WebRootIsApp {
  446. appRoot = "/app"
  447. }
  448. response := &apiConfigResponse{
  449. BaseURL: "", // Will translate to window.location.origin
  450. AppRoot: appRoot,
  451. EnableLogin: s.config.EnableLogin,
  452. EnableSignup: s.config.EnableSignup,
  453. EnablePayments: s.config.StripeSecretKey != "",
  454. EnableReservations: s.config.EnableReservations,
  455. DisallowedTopics: disallowedTopics,
  456. }
  457. b, err := json.MarshalIndent(response, "", " ")
  458. if err != nil {
  459. return err
  460. }
  461. w.Header().Set("Content-Type", "text/javascript")
  462. _, err = io.WriteString(w, fmt.Sprintf("// Generated server configuration\nvar config = %s;\n", string(b)))
  463. return err
  464. }
  465. func (s *Server) handleStatic(w http.ResponseWriter, r *http.Request, _ *visitor) error {
  466. r.URL.Path = webSiteDir + r.URL.Path
  467. util.Gzip(http.FileServer(http.FS(webFsCached))).ServeHTTP(w, r)
  468. return nil
  469. }
  470. func (s *Server) handleDocs(w http.ResponseWriter, r *http.Request, _ *visitor) error {
  471. util.Gzip(http.FileServer(http.FS(docsStaticCached))).ServeHTTP(w, r)
  472. return nil
  473. }
  474. func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor) error {
  475. if s.config.AttachmentCacheDir == "" {
  476. return errHTTPInternalError
  477. }
  478. matches := fileRegex.FindStringSubmatch(r.URL.Path)
  479. if len(matches) != 2 {
  480. return errHTTPInternalErrorInvalidPath
  481. }
  482. messageID := matches[1]
  483. file := filepath.Join(s.config.AttachmentCacheDir, messageID)
  484. stat, err := os.Stat(file)
  485. if err != nil {
  486. return errHTTPNotFound
  487. }
  488. if r.Method == http.MethodGet {
  489. if err := v.BandwidthLimiter().Allow(stat.Size()); err != nil {
  490. return errHTTPTooManyRequestsLimitAttachmentBandwidth
  491. }
  492. }
  493. w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size()))
  494. w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
  495. if r.Method == http.MethodGet {
  496. f, err := os.Open(file)
  497. if err != nil {
  498. return err
  499. }
  500. defer f.Close()
  501. _, err = io.Copy(util.NewContentTypeWriter(w, r.URL.Path), f)
  502. return err
  503. }
  504. return nil
  505. }
  506. func (s *Server) handleMatrixDiscovery(w http.ResponseWriter) error {
  507. if s.config.BaseURL == "" {
  508. return errHTTPInternalErrorMissingBaseURL
  509. }
  510. return writeMatrixDiscoveryResponse(w)
  511. }
  512. func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*message, error) {
  513. t, err := s.topicFromPath(r.URL.Path)
  514. if err != nil {
  515. return nil, err
  516. }
  517. if err := v.MessageAllowed(); err != nil {
  518. return nil, errHTTPTooManyRequestsLimitMessages
  519. }
  520. body, err := util.Peek(r.Body, s.config.MessageLimit)
  521. if err != nil {
  522. return nil, err
  523. }
  524. m := newDefaultMessage(t.ID, "")
  525. cache, firebase, email, unifiedpush, err := s.parsePublishParams(r, v, m)
  526. if err != nil {
  527. return nil, err
  528. }
  529. if m.PollID != "" {
  530. m = newPollRequestMessage(t.ID, m.PollID)
  531. }
  532. if v.user != nil {
  533. m.User = v.user.Name
  534. }
  535. m.Expires = time.Now().Add(v.Limits().MessagesExpiryDuration).Unix()
  536. if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil {
  537. return nil, err
  538. }
  539. if m.Message == "" {
  540. m.Message = emptyMessageBody
  541. }
  542. delayed := m.Time > time.Now().Unix()
  543. log.Debug("%s Received message: event=%s, user=%s, body=%d byte(s), delayed=%t, firebase=%t, cache=%t, up=%t, email=%s",
  544. logMessagePrefix(v, m), m.Event, m.User, len(m.Message), delayed, firebase, cache, unifiedpush, email)
  545. if log.IsTrace() {
  546. log.Trace("%s Message body: %s", logMessagePrefix(v, m), util.MaybeMarshalJSON(m))
  547. }
  548. if !delayed {
  549. if err := t.Publish(v, m); err != nil {
  550. return nil, err
  551. }
  552. if s.firebaseClient != nil && firebase {
  553. go s.sendToFirebase(v, m)
  554. }
  555. if s.smtpSender != nil && email != "" {
  556. v.IncrementEmails()
  557. go s.sendEmail(v, m, email)
  558. }
  559. if s.config.UpstreamBaseURL != "" {
  560. go s.forwardPollRequest(v, m)
  561. }
  562. } else {
  563. log.Debug("%s Message delayed, will process later", logMessagePrefix(v, m))
  564. }
  565. if cache {
  566. log.Debug("%s Adding message to cache", logMessagePrefix(v, m))
  567. if err := s.messageCache.AddMessage(m); err != nil {
  568. return nil, err
  569. }
  570. }
  571. v.IncrementMessages()
  572. if s.userManager != nil && v.user != nil {
  573. s.userManager.EnqueueStats(v.user)
  574. }
  575. s.mu.Lock()
  576. s.messages++
  577. s.mu.Unlock()
  578. return m, nil
  579. }
  580. func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visitor) error {
  581. m, err := s.handlePublishWithoutResponse(r, v)
  582. if err != nil {
  583. return err
  584. }
  585. return s.writeJSON(w, m)
  586. }
  587. func (s *Server) handlePublishMatrix(w http.ResponseWriter, r *http.Request, v *visitor) error {
  588. _, err := s.handlePublishWithoutResponse(r, v)
  589. if err != nil {
  590. return &errMatrix{pushKey: r.Header.Get(matrixPushKeyHeader), err: err}
  591. }
  592. return writeMatrixSuccess(w)
  593. }
  594. func (s *Server) sendToFirebase(v *visitor, m *message) {
  595. log.Debug("%s Publishing to Firebase", logMessagePrefix(v, m))
  596. if err := s.firebaseClient.Send(v, m); err != nil {
  597. if err == errFirebaseTemporarilyBanned {
  598. log.Debug("%s Unable to publish to Firebase: %v", logMessagePrefix(v, m), err.Error())
  599. } else {
  600. log.Warn("%s Unable to publish to Firebase: %v", logMessagePrefix(v, m), err.Error())
  601. }
  602. }
  603. }
  604. func (s *Server) sendEmail(v *visitor, m *message, email string) {
  605. log.Debug("%s Sending email to %s", logMessagePrefix(v, m), email)
  606. if err := s.smtpSender.Send(v, m, email); err != nil {
  607. log.Warn("%s Unable to send email to %s: %v", logMessagePrefix(v, m), email, err.Error())
  608. }
  609. }
  610. func (s *Server) forwardPollRequest(v *visitor, m *message) {
  611. topicURL := fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic)
  612. topicHash := fmt.Sprintf("%x", sha256.Sum256([]byte(topicURL)))
  613. forwardURL := fmt.Sprintf("%s/%s", s.config.UpstreamBaseURL, topicHash)
  614. log.Debug("%s Publishing poll request to %s", logMessagePrefix(v, m), forwardURL)
  615. req, err := http.NewRequest("POST", forwardURL, strings.NewReader(""))
  616. if err != nil {
  617. log.Warn("%s Unable to publish poll request: %v", logMessagePrefix(v, m), err.Error())
  618. return
  619. }
  620. req.Header.Set("X-Poll-ID", m.ID)
  621. var httpClient = &http.Client{
  622. Timeout: time.Second * 10,
  623. }
  624. response, err := httpClient.Do(req)
  625. if err != nil {
  626. log.Warn("%s Unable to publish poll request: %v", logMessagePrefix(v, m), err.Error())
  627. return
  628. } else if response.StatusCode != http.StatusOK {
  629. log.Warn("%s Unable to publish poll request, unexpected HTTP status: %d", logMessagePrefix(v, m), response.StatusCode)
  630. return
  631. }
  632. }
  633. func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (cache bool, firebase bool, email string, unifiedpush bool, err error) {
  634. cache = readBoolParam(r, true, "x-cache", "cache")
  635. firebase = readBoolParam(r, true, "x-firebase", "firebase")
  636. m.Title = readParam(r, "x-title", "title", "t")
  637. m.Click = readParam(r, "x-click", "click")
  638. icon := readParam(r, "x-icon", "icon")
  639. filename := readParam(r, "x-filename", "filename", "file", "f")
  640. attach := readParam(r, "x-attach", "attach", "a")
  641. if attach != "" || filename != "" {
  642. m.Attachment = &attachment{}
  643. }
  644. if filename != "" {
  645. m.Attachment.Name = filename
  646. }
  647. if attach != "" {
  648. if !urlRegex.MatchString(attach) {
  649. return false, false, "", false, errHTTPBadRequestAttachmentURLInvalid
  650. }
  651. m.Attachment.URL = attach
  652. if m.Attachment.Name == "" {
  653. u, err := url.Parse(m.Attachment.URL)
  654. if err == nil {
  655. m.Attachment.Name = path.Base(u.Path)
  656. if m.Attachment.Name == "." || m.Attachment.Name == "/" {
  657. m.Attachment.Name = ""
  658. }
  659. }
  660. }
  661. if m.Attachment.Name == "" {
  662. m.Attachment.Name = "attachment"
  663. }
  664. }
  665. if icon != "" {
  666. if !urlRegex.MatchString(icon) {
  667. return false, false, "", false, errHTTPBadRequestIconURLInvalid
  668. }
  669. m.Icon = icon
  670. }
  671. email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e")
  672. if email != "" {
  673. if err := v.EmailAllowed(); err != nil {
  674. return false, false, "", false, errHTTPTooManyRequestsLimitEmails
  675. }
  676. }
  677. if s.smtpSender == nil && email != "" {
  678. return false, false, "", false, errHTTPBadRequestEmailDisabled
  679. }
  680. messageStr := strings.ReplaceAll(readParam(r, "x-message", "message", "m"), "\\n", "\n")
  681. if messageStr != "" {
  682. m.Message = messageStr
  683. }
  684. m.Priority, err = util.ParsePriority(readParam(r, "x-priority", "priority", "prio", "p"))
  685. if err != nil {
  686. return false, false, "", false, errHTTPBadRequestPriorityInvalid
  687. }
  688. tagsStr := readParam(r, "x-tags", "tags", "tag", "ta")
  689. if tagsStr != "" {
  690. m.Tags = make([]string, 0)
  691. for _, s := range util.SplitNoEmpty(tagsStr, ",") {
  692. m.Tags = append(m.Tags, strings.TrimSpace(s))
  693. }
  694. }
  695. delayStr := readParam(r, "x-delay", "delay", "x-at", "at", "x-in", "in")
  696. if delayStr != "" {
  697. if !cache {
  698. return false, false, "", false, errHTTPBadRequestDelayNoCache
  699. }
  700. if email != "" {
  701. return false, false, "", false, errHTTPBadRequestDelayNoEmail // we cannot store the email address (yet)
  702. }
  703. delay, err := util.ParseFutureTime(delayStr, time.Now())
  704. if err != nil {
  705. return false, false, "", false, errHTTPBadRequestDelayCannotParse
  706. } else if delay.Unix() < time.Now().Add(s.config.MinDelay).Unix() {
  707. return false, false, "", false, errHTTPBadRequestDelayTooSmall
  708. } else if delay.Unix() > time.Now().Add(s.config.MaxDelay).Unix() {
  709. return false, false, "", false, errHTTPBadRequestDelayTooLarge
  710. }
  711. m.Time = delay.Unix()
  712. m.Sender = v.ip // Important for rate limiting
  713. }
  714. actionsStr := readParam(r, "x-actions", "actions", "action")
  715. if actionsStr != "" {
  716. m.Actions, err = parseActions(actionsStr)
  717. if err != nil {
  718. return false, false, "", false, wrapErrHTTP(errHTTPBadRequestActionsInvalid, err.Error())
  719. }
  720. }
  721. unifiedpush = readBoolParam(r, false, "x-unifiedpush", "unifiedpush", "up") // see GET too!
  722. if unifiedpush {
  723. firebase = false
  724. unifiedpush = true
  725. }
  726. m.PollID = readParam(r, "x-poll-id", "poll-id")
  727. if m.PollID != "" {
  728. unifiedpush = false
  729. cache = false
  730. email = ""
  731. }
  732. return cache, firebase, email, unifiedpush, nil
  733. }
  734. // handlePublishBody consumes the PUT/POST body and decides whether the body is an attachment or the message.
  735. //
  736. // 1. curl -X POST -H "Poll: 1234" ntfy.sh/...
  737. // If a message is flagged as poll request, the body does not matter and is discarded
  738. // 2. curl -T somebinarydata.bin "ntfy.sh/mytopic?up=1"
  739. // If body is binary, encode as base64, if not do not encode
  740. // 3. curl -H "Attach: http://example.com/file.jpg" ntfy.sh/mytopic
  741. // Body must be a message, because we attached an external URL
  742. // 4. curl -T short.txt -H "Filename: short.txt" ntfy.sh/mytopic
  743. // Body must be attachment, because we passed a filename
  744. // 5. curl -T file.txt ntfy.sh/mytopic
  745. // If file.txt is <= 4096 (message limit) and valid UTF-8, treat it as a message
  746. // 6. curl -T file.txt ntfy.sh/mytopic
  747. // If file.txt is > message limit, treat it as an attachment
  748. func (s *Server) handlePublishBody(r *http.Request, v *visitor, m *message, body *util.PeekedReadCloser, unifiedpush bool) error {
  749. if m.Event == pollRequestEvent { // Case 1
  750. return s.handleBodyDiscard(body)
  751. } else if unifiedpush {
  752. return s.handleBodyAsMessageAutoDetect(m, body) // Case 2
  753. } else if m.Attachment != nil && m.Attachment.URL != "" {
  754. return s.handleBodyAsTextMessage(m, body) // Case 3
  755. } else if m.Attachment != nil && m.Attachment.Name != "" {
  756. return s.handleBodyAsAttachment(r, v, m, body) // Case 4
  757. } else if !body.LimitReached && utf8.Valid(body.PeekedBytes) {
  758. return s.handleBodyAsTextMessage(m, body) // Case 5
  759. }
  760. return s.handleBodyAsAttachment(r, v, m, body) // Case 6
  761. }
  762. func (s *Server) handleBodyDiscard(body *util.PeekedReadCloser) error {
  763. _, err := io.Copy(io.Discard, body)
  764. _ = body.Close()
  765. return err
  766. }
  767. func (s *Server) handleBodyAsMessageAutoDetect(m *message, body *util.PeekedReadCloser) error {
  768. if utf8.Valid(body.PeekedBytes) {
  769. m.Message = string(body.PeekedBytes) // Do not trim
  770. } else {
  771. m.Message = base64.StdEncoding.EncodeToString(body.PeekedBytes)
  772. m.Encoding = encodingBase64
  773. }
  774. return nil
  775. }
  776. func (s *Server) handleBodyAsTextMessage(m *message, body *util.PeekedReadCloser) error {
  777. if !utf8.Valid(body.PeekedBytes) {
  778. return errHTTPBadRequestMessageNotUTF8
  779. }
  780. if len(body.PeekedBytes) > 0 { // Empty body should not override message (publish via GET!)
  781. m.Message = strings.TrimSpace(string(body.PeekedBytes)) // Truncates the message to the peek limit if required
  782. }
  783. if m.Attachment != nil && m.Attachment.Name != "" && m.Message == "" {
  784. m.Message = fmt.Sprintf(defaultAttachmentMessage, m.Attachment.Name)
  785. }
  786. return nil
  787. }
  788. func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message, body *util.PeekedReadCloser) error {
  789. if s.fileCache == nil || s.config.BaseURL == "" || s.config.AttachmentCacheDir == "" {
  790. return errHTTPBadRequestAttachmentsDisallowed
  791. }
  792. vinfo, err := v.Info()
  793. if err != nil {
  794. return err
  795. }
  796. attachmentExpiry := time.Now().Add(vinfo.Limits.AttachmentExpiryDuration).Unix()
  797. if m.Time > attachmentExpiry {
  798. return errHTTPBadRequestAttachmentsExpiryBeforeDelivery
  799. }
  800. contentLengthStr := r.Header.Get("Content-Length")
  801. if contentLengthStr != "" { // Early "do-not-trust" check, hard limit see below
  802. contentLength, err := strconv.ParseInt(contentLengthStr, 10, 64)
  803. if err == nil && (contentLength > vinfo.Stats.AttachmentTotalSizeRemaining || contentLength > vinfo.Limits.AttachmentFileSizeLimit) {
  804. return errHTTPEntityTooLargeAttachment
  805. }
  806. }
  807. if m.Attachment == nil {
  808. m.Attachment = &attachment{}
  809. }
  810. var ext string
  811. m.Sender = v.ip // Important for attachment rate limiting
  812. m.Attachment.Expires = attachmentExpiry
  813. m.Attachment.Type, ext = util.DetectContentType(body.PeekedBytes, m.Attachment.Name)
  814. m.Attachment.URL = fmt.Sprintf("%s/file/%s%s", s.config.BaseURL, m.ID, ext)
  815. if m.Attachment.Name == "" {
  816. m.Attachment.Name = fmt.Sprintf("attachment%s", ext)
  817. }
  818. if m.Message == "" {
  819. m.Message = fmt.Sprintf(defaultAttachmentMessage, m.Attachment.Name)
  820. }
  821. limiters := []util.Limiter{
  822. v.BandwidthLimiter(),
  823. util.NewFixedLimiter(vinfo.Limits.AttachmentFileSizeLimit),
  824. util.NewFixedLimiter(vinfo.Stats.AttachmentTotalSizeRemaining),
  825. }
  826. m.Attachment.Size, err = s.fileCache.Write(m.ID, body, limiters...)
  827. if err == util.ErrLimitReached {
  828. return errHTTPEntityTooLargeAttachment
  829. } else if err != nil {
  830. return err
  831. }
  832. return nil
  833. }
  834. func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request, v *visitor) error {
  835. encoder := func(msg *message) (string, error) {
  836. var buf bytes.Buffer
  837. if err := json.NewEncoder(&buf).Encode(&msg); err != nil {
  838. return "", err
  839. }
  840. return buf.String(), nil
  841. }
  842. return s.handleSubscribeHTTP(w, r, v, "application/x-ndjson", encoder)
  843. }
  844. func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request, v *visitor) error {
  845. encoder := func(msg *message) (string, error) {
  846. var buf bytes.Buffer
  847. if err := json.NewEncoder(&buf).Encode(&msg); err != nil {
  848. return "", err
  849. }
  850. if msg.Event != messageEvent {
  851. return fmt.Sprintf("event: %s\ndata: %s\n", msg.Event, buf.String()), nil // Browser's .onmessage() does not fire on this!
  852. }
  853. return fmt.Sprintf("data: %s\n", buf.String()), nil
  854. }
  855. return s.handleSubscribeHTTP(w, r, v, "text/event-stream", encoder)
  856. }
  857. func (s *Server) handleSubscribeRaw(w http.ResponseWriter, r *http.Request, v *visitor) error {
  858. encoder := func(msg *message) (string, error) {
  859. if msg.Event == messageEvent { // only handle default events
  860. return strings.ReplaceAll(msg.Message, "\n", " ") + "\n", nil
  861. }
  862. return "\n", nil // "keepalive" and "open" events just send an empty line
  863. }
  864. return s.handleSubscribeHTTP(w, r, v, "text/plain", encoder)
  865. }
// handleSubscribeHTTP serves a streaming HTTP subscription to one or more comma-separated topics taken from the
// URL path. It is shared by the JSON, SSE and raw subscribe handlers, which differ only in contentType and encoder.
//
// Flow: check the visitor's subscription limit, resolve the topics and the subscribe parameters, then either
//   - poll=1: write the cached messages selected by "since" once and return, or
//   - stream: subscribe to every topic, write an "open" message, replay cached messages, and then write a
//     keepalive message every s.config.KeepaliveInterval until the request context is done.
//
// Every message passes through filters and encoder before being written to, and flushed on, w.
func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *visitor, contentType string, encoder messageEncoder) error {
	log.Debug("%s HTTP stream connection opened", logHTTPPrefix(v, r))
	defer log.Debug("%s HTTP stream connection closed", logHTTPPrefix(v, r))
	if err := v.SubscriptionAllowed(); err != nil {
		return errHTTPTooManyRequestsLimitSubscriptions
	}
	defer v.RemoveSubscription() // presumably releases what SubscriptionAllowed accounted for — confirm in visitor
	topics, topicsStr, err := s.topicsFromPath(r.URL.Path)
	if err != nil {
		return err
	}
	poll, since, scheduled, filters, err := parseSubscribeParams(r)
	if err != nil {
		return err
	}
	// wlock serializes writes to w: sub is registered with the topics below (and thus invoked by publishers) and is
	// also called by the keepalive loop, so it may run on different goroutines.
	var wlock sync.Mutex
	defer func() {
		// Hack: This is the fix for a horrible data race that I have not been able to figure out in quite some time.
		// It appears to be happening when the Go HTTP code reads from the socket when closing the request (i.e. AFTER
		// this function returns), and causes a data race with the ResponseWriter. Locking wlock here silences the
		// data race detector. See https://github.com/binwiederhier/ntfy/issues/338#issuecomment-1163425889.
		wlock.TryLock()
	}()
	// sub filters, encodes, writes and flushes a single message to the client.
	sub := func(v *visitor, msg *message) error {
		if !filters.Pass(msg) {
			return nil
		}
		m, err := encoder(msg)
		if err != nil {
			return err
		}
		wlock.Lock()
		defer wlock.Unlock()
		if _, err := w.Write([]byte(m)); err != nil {
			return err
		}
		if fl, ok := w.(http.Flusher); ok {
			fl.Flush()
		}
		return nil
	}
	w.Header().Set("Access-Control-Allow-Origin", "*")            // CORS, allow cross-origin requests
	w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset!
	if poll {
		return s.sendOldMessages(topics, since, scheduled, v, sub)
	}
	subscriberIDs := make([]int, 0)
	for _, t := range topics {
		subscriberIDs = append(subscriberIDs, t.Subscribe(sub))
	}
	defer func() {
		for i, subscriberID := range subscriberIDs {
			topics[i].Unsubscribe(subscriberID) // Order! subscriberIDs[i] was returned by topics[i].Subscribe
		}
	}()
	if err := sub(v, newOpenMessage(topicsStr)); err != nil { // Send out open message
		return err
	}
	if err := s.sendOldMessages(topics, since, scheduled, v, sub); err != nil {
		return err
	}
	// Block until the client goes away, sending a keepalive after each KeepaliveInterval of this loop
	for {
		select {
		case <-r.Context().Done():
			return nil
		case <-time.After(s.config.KeepaliveInterval):
			log.Trace("%s Sending keepalive message", logHTTPPrefix(v, r))
			v.Keepalive()
			if err := sub(v, newKeepaliveMessage(topicsStr)); err != nil { // Send keepalive message
				return err
			}
		}
	}
}
// handleSubscribeWS serves a WebSocket subscription to one or more comma-separated topics taken from the URL path.
// After checking the Upgrade header and the visitor's subscription limit, it upgrades the connection and runs two
// goroutines in an errgroup:
//   - a reader that discards incoming frames and extends the read deadline whenever a pong arrives, and
//   - a pinger that sends a WebSocket ping every s.config.KeepaliveInterval.
//
// Messages are written to the client as JSON by sub. With poll=1, the cached messages are sent once and the handler
// returns; otherwise it subscribes to the topics, sends an "open" message and the cached messages, and then blocks
// in g.Wait(). Close errors 1000, 1001 and 1006 are treated as a normal end of the connection.
func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *visitor) error {
	if strings.ToLower(r.Header.Get("Upgrade")) != "websocket" {
		return errHTTPBadRequestWebSocketsUpgradeHeaderMissing
	}
	if err := v.SubscriptionAllowed(); err != nil {
		return errHTTPTooManyRequestsLimitSubscriptions
	}
	defer v.RemoveSubscription()
	log.Debug("%s WebSocket connection opened", logHTTPPrefix(v, r))
	defer log.Debug("%s WebSocket connection closed", logHTTPPrefix(v, r))
	topics, topicsStr, err := s.topicsFromPath(r.URL.Path)
	if err != nil {
		return err
	}
	poll, since, scheduled, filters, err := parseSubscribeParams(r)
	if err != nil {
		return err
	}
	upgrader := &websocket.Upgrader{
		ReadBufferSize:  wsBufferSize,
		WriteBufferSize: wsBufferSize,
		CheckOrigin: func(r *http.Request) bool {
			return true // We're open for business!
		},
	}
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		return err
	}
	defer conn.Close()
	// wlock serializes writes to conn: both the pinger goroutine and sub (invoked by topic publishers) write to it.
	var wlock sync.Mutex
	// NOTE(review): ctx derives from context.Background(), not r.Context(); errgroup cancels it when a goroutine
	// returns a non-nil error or when g.Wait() returns.
	g, ctx := errgroup.WithContext(context.Background())
	// Reader: each pong extends the read deadline by pongWait; if nothing arrives in time, NextReader returns an
	// error, which ends this goroutine and cancels ctx.
	g.Go(func() error {
		pongWait := s.config.KeepaliveInterval + wsPongWait
		conn.SetReadLimit(wsReadLimit)
		if err := conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil {
			return err
		}
		conn.SetPongHandler(func(appData string) error {
			log.Trace("%s Received WebSocket pong", logHTTPPrefix(v, r))
			return conn.SetReadDeadline(time.Now().Add(pongWait))
		})
		for {
			_, _, err := conn.NextReader()
			if err != nil {
				return err
			}
		}
	})
	// Pinger: sends a ping every KeepaliveInterval until ctx is cancelled or a write fails.
	g.Go(func() error {
		ping := func() error {
			wlock.Lock()
			defer wlock.Unlock()
			if err := conn.SetWriteDeadline(time.Now().Add(wsWriteWait)); err != nil {
				return err
			}
			log.Trace("%s Sending WebSocket ping", logHTTPPrefix(v, r))
			return conn.WriteMessage(websocket.PingMessage, nil)
		}
		for {
			select {
			case <-ctx.Done():
				return nil
			case <-time.After(s.config.KeepaliveInterval):
				v.Keepalive()
				if err := ping(); err != nil {
					return err
				}
			}
		}
	})
	// sub filters a message and writes it to the client as JSON, bounded by wsWriteWait.
	sub := func(v *visitor, msg *message) error {
		if !filters.Pass(msg) {
			return nil
		}
		wlock.Lock()
		defer wlock.Unlock()
		if err := conn.SetWriteDeadline(time.Now().Add(wsWriteWait)); err != nil {
			return err
		}
		return conn.WriteJSON(msg)
	}
	// NOTE(review): this runs after upgrader.Upgrade has already written the handshake response, so the header is
	// likely never sent — confirm; it could be passed via Upgrade's responseHeader argument instead.
	w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
	if poll {
		// NOTE(review): returns without g.Wait(); the deferred conn.Close() presumably makes the reader's
		// NextReader fail, which cancels ctx and stops the pinger.
		return s.sendOldMessages(topics, since, scheduled, v, sub)
	}
	subscriberIDs := make([]int, 0)
	for _, t := range topics {
		subscriberIDs = append(subscriberIDs, t.Subscribe(sub))
	}
	defer func() {
		for i, subscriberID := range subscriberIDs {
			topics[i].Unsubscribe(subscriberID) // Order! subscriberIDs[i] was returned by topics[i].Subscribe
		}
	}()
	if err := sub(v, newOpenMessage(topicsStr)); err != nil { // Send out open message
		return err
	}
	if err := s.sendOldMessages(topics, since, scheduled, v, sub); err != nil {
		return err
	}
	err = g.Wait()
	if err != nil && websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
		log.Trace("%s WebSocket connection closed: %s", logHTTPPrefix(v, r), err.Error())
		return nil // Normal closures are not errors; note: "1006 (abnormal closure)" is treated as normal, because people disconnect a lot
	}
	return err
}
  1048. func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, scheduled bool, filters *queryFilter, err error) {
  1049. poll = readBoolParam(r, false, "x-poll", "poll", "po")
  1050. scheduled = readBoolParam(r, false, "x-scheduled", "scheduled", "sched")
  1051. since, err = parseSince(r, poll)
  1052. if err != nil {
  1053. return
  1054. }
  1055. filters, err = parseQueryFilters(r)
  1056. if err != nil {
  1057. return
  1058. }
  1059. return
  1060. }
  1061. // sendOldMessages selects old messages from the messageCache and calls sub for each of them. It uses since as the
  1062. // marker, returning only messages that are newer than the marker.
  1063. func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled bool, v *visitor, sub subscriber) error {
  1064. if since.IsNone() {
  1065. return nil
  1066. }
  1067. messages := make([]*message, 0)
  1068. for _, t := range topics {
  1069. topicMessages, err := s.messageCache.Messages(t.ID, since, scheduled)
  1070. if err != nil {
  1071. return err
  1072. }
  1073. messages = append(messages, topicMessages...)
  1074. }
  1075. sort.Slice(messages, func(i, j int) bool {
  1076. return messages[i].Time < messages[j].Time
  1077. })
  1078. for _, m := range messages {
  1079. if err := sub(v, m); err != nil {
  1080. return err
  1081. }
  1082. }
  1083. return nil
  1084. }
  1085. // parseSince returns a timestamp identifying the time span from which cached messages should be received.
  1086. //
  1087. // Values in the "since=..." parameter can be either a unix timestamp or a duration (e.g. 12h), or
  1088. // "all" for all messages.
  1089. func parseSince(r *http.Request, poll bool) (sinceMarker, error) {
  1090. since := readParam(r, "x-since", "since", "si")
  1091. // Easy cases (empty, all, none)
  1092. if since == "" {
  1093. if poll {
  1094. return sinceAllMessages, nil
  1095. }
  1096. return sinceNoMessages, nil
  1097. } else if since == "all" {
  1098. return sinceAllMessages, nil
  1099. } else if since == "none" {
  1100. return sinceNoMessages, nil
  1101. }
  1102. // ID, timestamp, duration
  1103. if validMessageID(since) {
  1104. return newSinceID(since), nil
  1105. } else if s, err := strconv.ParseInt(since, 10, 64); err == nil {
  1106. return newSinceTime(s), nil
  1107. } else if d, err := time.ParseDuration(since); err == nil {
  1108. return newSinceTime(time.Now().Add(-1 * d).Unix()), nil
  1109. }
  1110. return sinceNoMessages, errHTTPBadRequestSinceInvalid
  1111. }
  1112. func (s *Server) handleOptions(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
  1113. w.Header().Set("Access-Control-Allow-Methods", "GET, PUT, POST, PATCH, DELETE")
  1114. w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
  1115. w.Header().Set("Access-Control-Allow-Headers", "*") // CORS, allow auth via JS // FIXME is this terrible?
  1116. return nil
  1117. }
  1118. func (s *Server) topicFromPath(path string) (*topic, error) {
  1119. parts := strings.Split(path, "/")
  1120. if len(parts) < 2 {
  1121. return nil, errHTTPBadRequestTopicInvalid
  1122. }
  1123. topics, err := s.topicsFromIDs(parts[1])
  1124. if err != nil {
  1125. return nil, err
  1126. }
  1127. return topics[0], nil
  1128. }
  1129. func (s *Server) topicsFromPath(path string) ([]*topic, string, error) {
  1130. parts := strings.Split(path, "/")
  1131. if len(parts) < 2 {
  1132. return nil, "", errHTTPBadRequestTopicInvalid
  1133. }
  1134. topicIDs := util.SplitNoEmpty(parts[1], ",")
  1135. topics, err := s.topicsFromIDs(topicIDs...)
  1136. if err != nil {
  1137. return nil, "", errHTTPBadRequestTopicInvalid
  1138. }
  1139. return topics, parts[1], nil
  1140. }
  1141. func (s *Server) topicsFromIDs(ids ...string) ([]*topic, error) {
  1142. s.mu.Lock()
  1143. defer s.mu.Unlock()
  1144. topics := make([]*topic, 0)
  1145. for _, id := range ids {
  1146. if util.Contains(disallowedTopics, id) {
  1147. return nil, errHTTPBadRequestTopicDisallowed
  1148. }
  1149. if _, ok := s.topics[id]; !ok {
  1150. if len(s.topics) >= s.config.TotalTopicLimit {
  1151. return nil, errHTTPTooManyRequestsLimitTotalTopics
  1152. }
  1153. s.topics[id] = newTopic(id)
  1154. }
  1155. topics = append(topics, s.topics[id])
  1156. }
  1157. return topics, nil
  1158. }
  1159. func (s *Server) execManager() {
  1160. log.Debug("Manager: Starting")
  1161. defer log.Debug("Manager: Finished")
  1162. // WARNING: Make sure to only selectively lock with the mutex, and be aware that this
  1163. // there is no mutex for the entire function.
  1164. // Expire visitors from rate visitors map
  1165. s.mu.Lock()
  1166. staleVisitors := 0
  1167. for ip, v := range s.visitors {
  1168. if v.Stale() {
  1169. log.Trace("Deleting stale visitor %s", v.ip)
  1170. delete(s.visitors, ip)
  1171. staleVisitors++
  1172. }
  1173. }
  1174. s.mu.Unlock()
  1175. log.Debug("Manager: Deleted %d stale visitor(s)", staleVisitors)
  1176. // Delete expired user tokens
  1177. if s.userManager != nil {
  1178. if err := s.userManager.RemoveExpiredTokens(); err != nil {
  1179. log.Warn("Error expiring user tokens: %s", err.Error())
  1180. }
  1181. }
  1182. // Delete expired attachments
  1183. if s.fileCache != nil {
  1184. ids, err := s.messageCache.AttachmentsExpired()
  1185. if err != nil {
  1186. log.Warn("Manager: Error retrieving expired attachments: %s", err.Error())
  1187. } else if len(ids) > 0 {
  1188. if log.IsDebug() {
  1189. log.Debug("Manager: Deleting attachments %s", strings.Join(ids, ", "))
  1190. }
  1191. if err := s.fileCache.Remove(ids...); err != nil {
  1192. log.Warn("Manager: Error deleting attachments: %s", err.Error())
  1193. }
  1194. if err := s.messageCache.MarkAttachmentsDeleted(ids...); err != nil {
  1195. log.Warn("Manager: Error marking attachments deleted: %s", err.Error())
  1196. }
  1197. } else {
  1198. log.Debug("Manager: No expired attachments to delete")
  1199. }
  1200. }
  1201. // DeleteMessages message cache
  1202. log.Debug("Manager: Pruning messages")
  1203. expiredMessageIDs, err := s.messageCache.MessagesExpired()
  1204. if err != nil {
  1205. log.Warn("Manager: Error retrieving expired messages: %s", err.Error())
  1206. } else if len(expiredMessageIDs) > 0 {
  1207. if err := s.fileCache.Remove(expiredMessageIDs...); err != nil {
  1208. log.Warn("Manager: Error deleting attachments for expired messages: %s", err.Error())
  1209. }
  1210. if err := s.messageCache.DeleteMessages(expiredMessageIDs...); err != nil {
  1211. log.Warn("Manager: Error marking attachments deleted: %s", err.Error())
  1212. }
  1213. } else {
  1214. log.Debug("Manager: No expired messages to delete")
  1215. }
  1216. // Message count per topic
  1217. var messages int
  1218. messageCounts, err := s.messageCache.MessageCounts()
  1219. if err != nil {
  1220. log.Warn("Manager: Cannot get message counts: %s", err.Error())
  1221. messageCounts = make(map[string]int) // Empty, so we can continue
  1222. }
  1223. for _, count := range messageCounts {
  1224. messages += count
  1225. }
  1226. // Remove subscriptions without subscribers
  1227. s.mu.Lock()
  1228. var subscribers int
  1229. for _, t := range s.topics {
  1230. subs := t.SubscribersCount()
  1231. msgs, exists := messageCounts[t.ID]
  1232. if subs == 0 && (!exists || msgs == 0) {
  1233. log.Trace("Deleting empty topic %s", t.ID)
  1234. delete(s.topics, t.ID)
  1235. continue
  1236. }
  1237. subscribers += subs
  1238. }
  1239. s.mu.Unlock()
  1240. // Mail stats
  1241. var receivedMailTotal, receivedMailSuccess, receivedMailFailure int64
  1242. if s.smtpServerBackend != nil {
  1243. receivedMailTotal, receivedMailSuccess, receivedMailFailure = s.smtpServerBackend.Counts()
  1244. }
  1245. var sentMailTotal, sentMailSuccess, sentMailFailure int64
  1246. if s.smtpSender != nil {
  1247. sentMailTotal, sentMailSuccess, sentMailFailure = s.smtpSender.Counts()
  1248. }
  1249. // Print stats
  1250. s.mu.Lock()
  1251. messagesCount, topicsCount, visitorsCount := s.messages, len(s.topics), len(s.visitors)
  1252. s.mu.Unlock()
  1253. log.Info("Stats: %d messages published, %d in cache, %d topic(s) active, %d subscriber(s), %d visitor(s), %d mails received (%d successful, %d failed), %d mails sent (%d successful, %d failed)",
  1254. messagesCount, messages, topicsCount, subscribers, visitorsCount,
  1255. receivedMailTotal, receivedMailSuccess, receivedMailFailure,
  1256. sentMailTotal, sentMailSuccess, sentMailFailure)
  1257. }
  1258. func (s *Server) runSMTPServer() error {
  1259. s.smtpServerBackend = newMailBackend(s.config, s.handle)
  1260. s.smtpServer = smtp.NewServer(s.smtpServerBackend)
  1261. s.smtpServer.Addr = s.config.SMTPServerListen
  1262. s.smtpServer.Domain = s.config.SMTPServerDomain
  1263. s.smtpServer.ReadTimeout = 10 * time.Second
  1264. s.smtpServer.WriteTimeout = 10 * time.Second
  1265. s.smtpServer.MaxMessageBytes = 1024 * 1024 // Must be much larger than message size (headers, multipart, etc.)
  1266. s.smtpServer.MaxRecipients = 1
  1267. s.smtpServer.AllowInsecureAuth = true
  1268. return s.smtpServer.ListenAndServe()
  1269. }
  1270. func (s *Server) runManager() {
  1271. for {
  1272. select {
  1273. case <-time.After(s.config.ManagerInterval):
  1274. s.execManager()
  1275. case <-s.closeChan:
  1276. return
  1277. }
  1278. }
  1279. }
  1280. // runStatsResetter runs once a day (usually midnight UTC) to reset all the visitor's message and
  1281. // email counters. The stats are used to display the counters in the web app, as well as for rate limiting.
  1282. func (s *Server) runStatsResetter() {
  1283. for {
  1284. runAt := util.NextOccurrenceUTC(s.config.VisitorStatsResetTime, time.Now())
  1285. timer := time.NewTimer(time.Until(runAt))
  1286. log.Debug("Stats resetter: Waiting until %v to reset visitor stats", runAt)
  1287. select {
  1288. case <-timer.C:
  1289. s.resetStats()
  1290. case <-s.closeChan:
  1291. timer.Stop()
  1292. return
  1293. }
  1294. }
  1295. }
  1296. func (s *Server) resetStats() {
  1297. log.Info("Resetting all visitor stats (daily task)")
  1298. s.mu.Lock()
  1299. defer s.mu.Unlock() // Includes the database query to avoid races with other processes
  1300. for _, v := range s.visitors {
  1301. v.ResetStats()
  1302. }
  1303. if s.userManager != nil {
  1304. if err := s.userManager.ResetStats(); err != nil {
  1305. log.Warn("Failed to write to database: %s", err.Error())
  1306. }
  1307. }
  1308. }
  1309. func (s *Server) runFirebaseKeepaliver() {
  1310. if s.firebaseClient == nil {
  1311. return
  1312. }
  1313. v := newVisitor(s.config, s.messageCache, s.userManager, netip.IPv4Unspecified(), nil) // Background process, not a real visitor, uses IP 0.0.0.0
  1314. for {
  1315. select {
  1316. case <-time.After(s.config.FirebaseKeepaliveInterval):
  1317. s.sendToFirebase(v, newKeepaliveMessage(firebaseControlTopic))
  1318. case <-time.After(s.config.FirebasePollInterval):
  1319. s.sendToFirebase(v, newKeepaliveMessage(firebasePollTopic))
  1320. case <-s.closeChan:
  1321. return
  1322. }
  1323. }
  1324. }
  1325. func (s *Server) runDelayedSender() {
  1326. for {
  1327. select {
  1328. case <-time.After(s.config.DelayedSenderInterval):
  1329. if err := s.sendDelayedMessages(); err != nil {
  1330. log.Warn("Error sending delayed messages: %s", err.Error())
  1331. }
  1332. case <-s.closeChan:
  1333. return
  1334. }
  1335. }
  1336. }
  1337. func (s *Server) sendDelayedMessages() error {
  1338. messages, err := s.messageCache.MessagesDue()
  1339. if err != nil {
  1340. return err
  1341. }
  1342. for _, m := range messages {
  1343. var v *visitor
  1344. if s.userManager != nil && m.User != "" {
  1345. u, err := s.userManager.User(m.User)
  1346. if err != nil {
  1347. log.Warn("%s Error sending delayed message: %s", logMessagePrefix(v, m), err.Error())
  1348. continue
  1349. }
  1350. v = s.visitorFromUser(u, m.Sender)
  1351. } else {
  1352. v = s.visitorFromIP(m.Sender)
  1353. }
  1354. if err := s.sendDelayedMessage(v, m); err != nil {
  1355. log.Warn("%s Error sending delayed message: %s", logMessagePrefix(v, m), err.Error())
  1356. }
  1357. }
  1358. return nil
  1359. }
  1360. func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
  1361. log.Debug("%s Sending delayed message", logMessagePrefix(v, m))
  1362. s.mu.Lock()
  1363. t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published
  1364. s.mu.Unlock()
  1365. if ok {
  1366. go func() {
  1367. // We do not rate-limit messages here, since we've rate limited them in the PUT/POST handler
  1368. if err := t.Publish(v, m); err != nil {
  1369. log.Warn("%s Unable to publish message: %v", logMessagePrefix(v, m), err.Error())
  1370. }
  1371. }()
  1372. }
  1373. if s.firebaseClient != nil { // Firebase subscribers may not show up in topics map
  1374. go s.sendToFirebase(v, m)
  1375. }
  1376. if s.config.UpstreamBaseURL != "" {
  1377. go s.forwardPollRequest(v, m)
  1378. }
  1379. if err := s.messageCache.MarkPublished(m); err != nil {
  1380. return err
  1381. }
  1382. return nil
  1383. }
  1384. func (s *Server) limitRequests(next handleFunc) handleFunc {
  1385. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  1386. if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
  1387. return next(w, r, v)
  1388. } else if err := v.RequestAllowed(); err != nil {
  1389. return errHTTPTooManyRequestsLimitRequests
  1390. }
  1391. return next(w, r, v)
  1392. }
  1393. }
  1394. // transformBodyJSON peeks the request body, reads the JSON, and converts it to headers
  1395. // before passing it on to the next handler. This is meant to be used in combination with handlePublish.
  1396. func (s *Server) transformBodyJSON(next handleFunc) handleFunc {
  1397. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  1398. m, err := readJSONWithLimit[publishMessage](r.Body, s.config.MessageLimit*2) // 2x to account for JSON format overhead
  1399. if err != nil {
  1400. return err
  1401. }
  1402. if !topicRegex.MatchString(m.Topic) {
  1403. return errHTTPBadRequestTopicInvalid
  1404. }
  1405. if m.Message == "" {
  1406. m.Message = emptyMessageBody
  1407. }
  1408. r.URL.Path = "/" + m.Topic
  1409. r.Body = io.NopCloser(strings.NewReader(m.Message))
  1410. if m.Title != "" {
  1411. r.Header.Set("X-Title", m.Title)
  1412. }
  1413. if m.Priority != 0 {
  1414. r.Header.Set("X-Priority", fmt.Sprintf("%d", m.Priority))
  1415. }
  1416. if m.Tags != nil && len(m.Tags) > 0 {
  1417. r.Header.Set("X-Tags", strings.Join(m.Tags, ","))
  1418. }
  1419. if m.Attach != "" {
  1420. r.Header.Set("X-Attach", m.Attach)
  1421. }
  1422. if m.Filename != "" {
  1423. r.Header.Set("X-Filename", m.Filename)
  1424. }
  1425. if m.Click != "" {
  1426. r.Header.Set("X-Click", m.Click)
  1427. }
  1428. if m.Icon != "" {
  1429. r.Header.Set("X-Icon", m.Icon)
  1430. }
  1431. if len(m.Actions) > 0 {
  1432. actionsStr, err := json.Marshal(m.Actions)
  1433. if err != nil {
  1434. return errHTTPBadRequestMessageJSONInvalid
  1435. }
  1436. r.Header.Set("X-Actions", string(actionsStr))
  1437. }
  1438. if m.Email != "" {
  1439. r.Header.Set("X-Email", m.Email)
  1440. }
  1441. if m.Delay != "" {
  1442. r.Header.Set("X-Delay", m.Delay)
  1443. }
  1444. return next(w, r, v)
  1445. }
  1446. }
  1447. func (s *Server) transformMatrixJSON(next handleFunc) handleFunc {
  1448. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  1449. newRequest, err := newRequestFromMatrixJSON(r, s.config.BaseURL, s.config.MessageLimit)
  1450. if err != nil {
  1451. return err
  1452. }
  1453. if err := next(w, newRequest, v); err != nil {
  1454. return &errMatrix{pushKey: newRequest.Header.Get(matrixPushKeyHeader), err: err}
  1455. }
  1456. return nil
  1457. }
  1458. }
  1459. func (s *Server) authorizeTopicWrite(next handleFunc) handleFunc {
  1460. return s.autorizeTopic(next, user.PermissionWrite)
  1461. }
  1462. func (s *Server) authorizeTopicRead(next handleFunc) handleFunc {
  1463. return s.autorizeTopic(next, user.PermissionRead)
  1464. }
  1465. func (s *Server) autorizeTopic(next handleFunc, perm user.Permission) handleFunc {
  1466. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  1467. if s.userManager == nil {
  1468. return next(w, r, v)
  1469. }
  1470. topics, _, err := s.topicsFromPath(r.URL.Path)
  1471. if err != nil {
  1472. return err
  1473. }
  1474. for _, t := range topics {
  1475. if err := s.userManager.Authorize(v.user, t.ID, perm); err != nil {
  1476. log.Info("unauthorized: %s", err.Error())
  1477. return errHTTPForbidden
  1478. }
  1479. }
  1480. return next(w, r, v)
  1481. }
  1482. }
  1483. // visitor creates or retrieves a rate.Limiter for the given visitor.
  1484. // Note that this function will always return a visitor, even if an error occurs.
  1485. func (s *Server) visitor(r *http.Request) (v *visitor, err error) {
  1486. ip := extractIPAddress(r, s.config.BehindProxy)
  1487. var u *user.User // may stay nil if no auth header!
  1488. if u, err = s.authenticate(r); err != nil {
  1489. log.Debug("authentication failed: %s", err.Error())
  1490. err = errHTTPUnauthorized // Always return visitor, even when error occurs!
  1491. }
  1492. if u != nil {
  1493. v = s.visitorFromUser(u, ip)
  1494. } else {
  1495. v = s.visitorFromIP(ip)
  1496. }
  1497. v.mu.Lock()
  1498. v.user = u
  1499. v.mu.Unlock()
  1500. return v, err // Always return visitor, even when error occurs!
  1501. }
  1502. // authenticate a user based on basic auth username/password (Authorization: Basic ...), or token auth (Authorization: Bearer ...).
  1503. // The Authorization header can be passed as a header or the ?auth=... query param. The latter is required only to
  1504. // support the WebSocket JavaScript class, which does not support passing headers during the initial request. The auth
  1505. // query param is effectively double base64 encoded. Its format is base64(Basic base64(user:pass)).
  1506. func (s *Server) authenticate(r *http.Request) (user *user.User, err error) {
  1507. value := strings.TrimSpace(r.Header.Get("Authorization"))
  1508. queryParam := readQueryParam(r, "authorization", "auth")
  1509. if queryParam != "" {
  1510. a, err := base64.RawURLEncoding.DecodeString(queryParam)
  1511. if err != nil {
  1512. return nil, err
  1513. }
  1514. value = strings.TrimSpace(string(a))
  1515. }
  1516. if value == "" {
  1517. return nil, nil
  1518. } else if s.userManager == nil {
  1519. return nil, errHTTPUnauthorized
  1520. }
  1521. if strings.HasPrefix(value, "Bearer") {
  1522. return s.authenticateBearerAuth(value)
  1523. }
  1524. return s.authenticateBasicAuth(r, value)
  1525. }
  1526. func (s *Server) authenticateBasicAuth(r *http.Request, value string) (user *user.User, err error) {
  1527. r.Header.Set("Authorization", value)
  1528. username, password, ok := r.BasicAuth()
  1529. if !ok {
  1530. return nil, errors.New("invalid basic auth")
  1531. }
  1532. return s.userManager.Authenticate(username, password)
  1533. }
  1534. func (s *Server) authenticateBearerAuth(value string) (user *user.User, err error) {
  1535. token := strings.TrimSpace(strings.TrimPrefix(value, "Bearer"))
  1536. return s.userManager.AuthenticateToken(token)
  1537. }
  1538. func (s *Server) visitorFromID(visitorID string, ip netip.Addr, user *user.User) *visitor {
  1539. s.mu.Lock()
  1540. defer s.mu.Unlock()
  1541. v, exists := s.visitors[visitorID]
  1542. if !exists {
  1543. s.visitors[visitorID] = newVisitor(s.config, s.messageCache, s.userManager, ip, user)
  1544. return s.visitors[visitorID]
  1545. }
  1546. v.Keepalive()
  1547. return v
  1548. }
  1549. func (s *Server) visitorFromIP(ip netip.Addr) *visitor {
  1550. return s.visitorFromID(fmt.Sprintf("ip:%s", ip.String()), ip, nil)
  1551. }
  1552. func (s *Server) visitorFromUser(user *user.User, ip netip.Addr) *visitor {
  1553. return s.visitorFromID(fmt.Sprintf("user:%s", user.Name), ip, user)
  1554. }
  1555. func (s *Server) writeJSON(w http.ResponseWriter, v any) error {
  1556. w.Header().Set("Content-Type", "application/json")
  1557. w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
  1558. if err := json.NewEncoder(w).Encode(v); err != nil {
  1559. return err
  1560. }
  1561. return nil
  1562. }