client.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343
  1. // Package client provides a ntfy client to publish and subscribe to topics
  2. package client
  3. import (
  4. "bufio"
  5. "bytes"
  6. "context"
  7. "encoding/json"
  8. "errors"
  9. "fmt"
  10. "github.com/stretchr/testify/require"
  11. "heckel.io/ntfy/crypto"
  12. "heckel.io/ntfy/log"
  13. "heckel.io/ntfy/util"
  14. "io"
  15. "mime/multipart"
  16. "net/http"
  17. "net/http/httptest"
  18. "strings"
  19. "sync"
  20. "time"
  21. )
  // Values found in the Event field of messages received from a ntfy server.
// performSubscribeRequest only forwards messages whose Event is MessageEvent;
// everything else is dropped.
const (
	// MessageEvent identifies a regular message.
	MessageEvent = "message"
	// KeepaliveEvent identifies a keepalive event.
	KeepaliveEvent = "keepalive"
	// OpenEvent identifies an open event.
	OpenEvent = "open"
	// PollRequestEvent identifies a poll request event.
	PollRequestEvent = "poll_request"
)
  // Size limits applied by this client.
const (
	// maxResponseBytes caps how much of an HTTP response body is read, both for
	// successful publish replies and for error bodies.
	maxResponseBytes = 4096

	// encryptedMessageBytesLimit is the limit (100 MB) passed to util.PeekLimit
	// when reading the plaintext in PublishEncryptedReader.
	encryptedMessageBytesLimit = 100 << 20
)
  33. // Client is the ntfy client that can be used to publish and subscribe to ntfy topics
  34. type Client struct {
  35. Messages chan *Message
  36. config *Config
  37. subscriptions map[string]*subscription
  38. mu sync.Mutex
  39. }
  40. // Message is a struct that represents a ntfy message
  41. type Message struct { // TODO combine with server.message
  42. ID string
  43. Event string
  44. Time int64
  45. Topic string
  46. Message string
  47. Title string
  48. Priority int
  49. Tags []string
  50. Click string
  51. Icon string
  52. Attachment *Attachment
  53. // Additional fields
  54. TopicURL string
  55. SubscriptionID string
  56. Raw string
  57. }
  // Attachment describes a file attached to a Message. It is (de)serialized
// using the lower-case JSON keys from its struct tags; Owner is never serialized.
type Attachment struct {
	Name    string `json:"name"`
	Type    string `json:"type,omitempty"`
	Size    int64  `json:"size,omitempty"`
	Expires int64  `json:"expires,omitempty"`
	URL     string `json:"url"`
	Owner   string `json:"-"` // IP address of uploader, used for rate limiting
}
  // subscription records one active Subscribe call so it can be cancelled later.
type subscription struct {
	ID       string             // random ID returned to the caller of Subscribe
	topicURL string             // expanded topic URL; UnsubscribeAll matches on it
	cancel   context.CancelFunc // cancels the context of the background connection loop
}
  72. // New creates a new Client using a given Config
  73. func New(config *Config) *Client {
  74. return &Client{
  75. Messages: make(chan *Message, 50), // Allow reading a few messages
  76. config: config,
  77. subscriptions: make(map[string]*subscription),
  78. }
  79. }
  80. // Publish sends a message to a specific topic, optionally using options.
  81. // See PublishReader for details.
  82. func (c *Client) Publish(topic, message string, options ...PublishOption) (*Message, error) {
  83. return c.PublishReader(topic, strings.NewReader(message), options...)
  84. }
  85. // PublishReader sends a message to a specific topic, optionally using options.
  86. //
  87. // A topic can be either a full URL (e.g. https://myhost.lan/mytopic), a short URL which is then prepended https://
  88. // (e.g. myhost.lan -> https://myhost.lan), or a short name which is expanded using the default host in the
  89. // config (e.g. mytopic -> https://ntfy.sh/mytopic).
  90. //
  91. // To pass title, priority and tags, check out WithTitle, WithPriority, WithTagsList, WithDelay, WithNoCache,
  92. // WithNoFirebase, and the generic WithHeader.
  93. func (c *Client) PublishReader(topic string, body io.Reader, options ...PublishOption) (*Message, error) {
  94. topicURL := util.ExpandTopicURL(topic, c.config.DefaultHost)
  95. req, _ := http.NewRequest("POST", topicURL, body)
  96. for _, option := range options {
  97. if err := option(req); err != nil {
  98. return nil, err
  99. }
  100. }
  101. log.Debug("%s Publishing message with headers %s", util.ShortTopicURL(topicURL), req.Header)
  102. resp, err := http.DefaultClient.Do(req)
  103. if err != nil {
  104. return nil, err
  105. }
  106. defer resp.Body.Close()
  107. b, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
  108. if err != nil {
  109. return nil, err
  110. }
  111. if resp.StatusCode != http.StatusOK {
  112. return nil, errors.New(strings.TrimSpace(string(b)))
  113. }
  114. m, err := toMessage(string(b), topicURL, "")
  115. if err != nil {
  116. return nil, err
  117. }
  118. return m, nil
  119. }
  120. func (c *Client) PublishEncryptedReader(topic string, body io.Reader, password string, options ...PublishOption) (*Message, error) {
  121. topicURL := util.ExpandTopicURL(topic, c.config.DefaultHost)
  122. key := crypto.DeriveKey(password, topicURL)
  123. peaked, err := util.PeekLimit(io.NopCloser(body), encryptedMessageBytesLimit)
  124. if err != nil {
  125. return nil, err
  126. }
  127. ciphertext, err := crypto.Encrypt(peaked.PeekedBytes, key)
  128. if err != nil {
  129. return nil, err
  130. }
  131. var b bytes.Buffer
  132. body = strings.NewReader(ciphertext)
  133. w := multipart.NewWriter(&b)
  134. for _, part := range parts {
  135. mw, _ := w.CreateFormField(part.key)
  136. _, err := io.Copy(mw, strings.NewReader(part.value))
  137. require.Nil(t, err)
  138. }
  139. require.Nil(t, w.Close())
  140. rr := httptest.NewRecorder()
  141. req, err := http.NewRequest(method, url, &b)
  142. if err != nil {
  143. t.Fatal(err)
  144. }
  145. req, _ := http.NewRequest("POST", topicURL, body)
  146. req.Header.Set("X-Encoding", "jwe")
  147. for _, option := range options {
  148. if err := option(req); err != nil {
  149. return nil, err
  150. }
  151. }
  152. log.Debug("%s Publishing message with headers %s", util.ShortTopicURL(topicURL), req.Header)
  153. resp, err := http.DefaultClient.Do(req)
  154. if err != nil {
  155. return nil, err
  156. }
  157. defer resp.Body.Close()
  158. b, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
  159. if err != nil {
  160. return nil, err
  161. }
  162. if resp.StatusCode != http.StatusOK {
  163. return nil, errors.New(strings.TrimSpace(string(b)))
  164. }
  165. m, err := toMessage(string(b), topicURL, "")
  166. if err != nil {
  167. return nil, err
  168. }
  169. return m, nil
  170. }
  171. // Poll queries a topic for all (or a limited set) of messages. Unlike Subscribe, this method only polls for
  172. // messages and does not subscribe to messages that arrive after this call.
  173. //
  174. // A topic can be either a full URL (e.g. https://myhost.lan/mytopic), a short URL which is then prepended https://
  175. // (e.g. myhost.lan -> https://myhost.lan), or a short name which is expanded using the default host in the
  176. // config (e.g. mytopic -> https://ntfy.sh/mytopic).
  177. //
  178. // By default, all messages will be returned, but you can change this behavior using a SubscribeOption.
  179. // See WithSince, WithSinceAll, WithSinceUnixTime, WithScheduled, and the generic WithQueryParam.
  180. func (c *Client) Poll(topic string, options ...SubscribeOption) ([]*Message, error) {
  181. ctx := context.Background()
  182. messages := make([]*Message, 0)
  183. msgChan := make(chan *Message)
  184. errChan := make(chan error)
  185. topicURL := util.ExpandTopicURL(topic, c.config.DefaultHost)
  186. log.Debug("%s Polling from topic", util.ShortTopicURL(topicURL))
  187. options = append(options, WithPoll())
  188. go func() {
  189. err := performSubscribeRequest(ctx, msgChan, topicURL, "", options...)
  190. close(msgChan)
  191. errChan <- err
  192. }()
  193. for m := range msgChan {
  194. messages = append(messages, m)
  195. }
  196. return messages, <-errChan
  197. }
  198. // Subscribe subscribes to a topic to listen for newly incoming messages. The method starts a connection in the
  199. // background and returns new messages via the Messages channel.
  200. //
  201. // A topic can be either a full URL (e.g. https://myhost.lan/mytopic), a short URL which is then prepended https://
  202. // (e.g. myhost.lan -> https://myhost.lan), or a short name which is expanded using the default host in the
  203. // config (e.g. mytopic -> https://ntfy.sh/mytopic).
  204. //
  205. // By default, only new messages will be returned, but you can change this behavior using a SubscribeOption.
  206. // See WithSince, WithSinceAll, WithSinceUnixTime, WithScheduled, and the generic WithQueryParam.
  207. //
  208. // The method returns a unique subscriptionID that can be used in Unsubscribe.
  209. //
  210. // Example:
  211. //
  212. // c := client.New(client.NewConfig())
  213. // subscriptionID := c.Subscribe("mytopic")
  214. // for m := range c.Messages {
  215. // fmt.Printf("New message: %s", m.Message)
  216. // }
  217. func (c *Client) Subscribe(topic string, options ...SubscribeOption) string {
  218. c.mu.Lock()
  219. defer c.mu.Unlock()
  220. subscriptionID := util.RandomString(10)
  221. topicURL := util.ExpandTopicURL(topic, c.config.DefaultHost)
  222. log.Debug("%s Subscribing to topic", util.ShortTopicURL(topicURL))
  223. ctx, cancel := context.WithCancel(context.Background())
  224. c.subscriptions[subscriptionID] = &subscription{
  225. ID: subscriptionID,
  226. topicURL: topicURL,
  227. cancel: cancel,
  228. }
  229. go handleSubscribeConnLoop(ctx, c.Messages, topicURL, subscriptionID, options...)
  230. return subscriptionID
  231. }
  232. // Unsubscribe unsubscribes from a topic that has been previously subscribed to using the unique
  233. // subscriptionID returned in Subscribe.
  234. func (c *Client) Unsubscribe(subscriptionID string) {
  235. c.mu.Lock()
  236. defer c.mu.Unlock()
  237. sub, ok := c.subscriptions[subscriptionID]
  238. if !ok {
  239. return
  240. }
  241. delete(c.subscriptions, subscriptionID)
  242. sub.cancel()
  243. }
  244. // UnsubscribeAll unsubscribes from a topic that has been previously subscribed with Subscribe.
  245. // If there are multiple subscriptions matching the topic, all of them are unsubscribed from.
  246. //
  247. // A topic can be either a full URL (e.g. https://myhost.lan/mytopic), a short URL which is then prepended https://
  248. // (e.g. myhost.lan -> https://myhost.lan), or a short name which is expanded using the default host in the
  249. // config (e.g. mytopic -> https://ntfy.sh/mytopic).
  250. func (c *Client) UnsubscribeAll(topic string) {
  251. c.mu.Lock()
  252. defer c.mu.Unlock()
  253. topicURL := util.ExpandTopicURL(topic, c.config.DefaultHost)
  254. for _, sub := range c.subscriptions {
  255. if sub.topicURL == topicURL {
  256. delete(c.subscriptions, sub.ID)
  257. sub.cancel()
  258. }
  259. }
  260. }
  261. func handleSubscribeConnLoop(ctx context.Context, msgChan chan *Message, topicURL, subcriptionID string, options ...SubscribeOption) {
  262. for {
  263. // TODO The retry logic is crude and may lose messages. It should record the last message like the
  264. // Android client, use since=, and do incremental backoff too
  265. if err := performSubscribeRequest(ctx, msgChan, topicURL, subcriptionID, options...); err != nil {
  266. log.Warn("%s Connection failed: %s", util.ShortTopicURL(topicURL), err.Error())
  267. }
  268. select {
  269. case <-ctx.Done():
  270. log.Info("%s Connection exited", util.ShortTopicURL(topicURL))
  271. return
  272. case <-time.After(10 * time.Second): // TODO Add incremental backoff
  273. }
  274. }
  275. }
  276. func performSubscribeRequest(ctx context.Context, msgChan chan *Message, topicURL string, subscriptionID string, options ...SubscribeOption) error {
  277. streamURL := fmt.Sprintf("%s/json", topicURL)
  278. log.Debug("%s Listening to %s", util.ShortTopicURL(topicURL), streamURL)
  279. req, err := http.NewRequestWithContext(ctx, http.MethodGet, streamURL, nil)
  280. if err != nil {
  281. return err
  282. }
  283. for _, option := range options {
  284. if err := option(req); err != nil {
  285. return err
  286. }
  287. }
  288. resp, err := http.DefaultClient.Do(req)
  289. if err != nil {
  290. return err
  291. }
  292. defer resp.Body.Close()
  293. if resp.StatusCode != http.StatusOK {
  294. b, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
  295. if err != nil {
  296. return err
  297. }
  298. return errors.New(strings.TrimSpace(string(b)))
  299. }
  300. scanner := bufio.NewScanner(resp.Body)
  301. for scanner.Scan() {
  302. messageJSON := scanner.Text()
  303. m, err := toMessage(messageJSON, topicURL, subscriptionID)
  304. if err != nil {
  305. return err
  306. }
  307. log.Trace("%s Message received: %s", util.ShortTopicURL(topicURL), messageJSON)
  308. if m.Event == MessageEvent {
  309. msgChan <- m
  310. }
  311. }
  312. return nil
  313. }
  314. func toMessage(s, topicURL, subscriptionID string) (*Message, error) {
  315. var m *Message
  316. if err := json.NewDecoder(strings.NewReader(s)).Decode(&m); err != nil {
  317. return nil, err
  318. }
  319. m.TopicURL = topicURL
  320. m.SubscriptionID = subscriptionID
  321. m.Raw = s
  322. return m, nil
  323. }