topic.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. package server
  2. import (
  3. "heckel.io/ntfy/log"
  4. "math/rand"
  5. "sync"
  6. )
  // topic represents a channel to which subscribers can subscribe, and publishers
  // can publish a message
  type topic struct {
  	// ID is the topic identifier that was passed to newTopic.
  	ID string
  	// subscribers holds the current subscriptions, keyed by the ID returned from Subscribe.
  	subscribers map[int]*topicSubscriber
  	// mu is locked by the methods of topic while they read or modify subscribers.
  	mu sync.Mutex
  }
  // topicSubscriber is a single entry in a topic's subscribers map: the subscriber
  // callback together with the data stored alongside it by Subscribe.
  type topicSubscriber struct {
  	userID string // User ID associated with this subscription, may be empty
  	// subscriber is the callback that Publish invokes for each published message.
  	subscriber subscriber
  	// cancel is called by CancelSubscribers for every subscription it does not exempt.
  	cancel func()
  }
  // subscriber is a function that is called for every new message on a topic.
  // When it returns a non-nil error, Publish logs that error as a warning and
  // otherwise ignores it.
  type subscriber func(v *visitor, msg *message) error
  21. // newTopic creates a new topic
  22. func newTopic(id string) *topic {
  23. return &topic{
  24. ID: id,
  25. subscribers: make(map[int]*topicSubscriber),
  26. }
  27. }
  28. // Subscribe subscribes to this topic
  29. func (t *topic) Subscribe(s subscriber, userID string, cancel func()) int {
  30. t.mu.Lock()
  31. defer t.mu.Unlock()
  32. subscriberID := rand.Int()
  33. t.subscribers[subscriberID] = &topicSubscriber{
  34. userID: userID, // May be empty
  35. subscriber: s,
  36. cancel: cancel,
  37. }
  38. return subscriberID
  39. }
  40. // Unsubscribe removes the subscription from the list of subscribers
  41. func (t *topic) Unsubscribe(id int) {
  42. t.mu.Lock()
  43. defer t.mu.Unlock()
  44. delete(t.subscribers, id)
  45. }
  46. // Publish asynchronously publishes to all subscribers
  47. func (t *topic) Publish(v *visitor, m *message) error {
  48. go func() {
  49. // We want to lock the topic as short as possible, so we make a shallow copy of the
  50. // subscribers map here. Actually sending out the messages then doesn't have to lock.
  51. subscribers := t.subscribersCopy()
  52. if len(subscribers) > 0 {
  53. logvm(v, m).Tag(tagPublish).Debug("Forwarding to %d subscriber(s)", len(subscribers))
  54. for _, s := range subscribers {
  55. // We call the subscriber functions in their own Go routines because they are blocking, and
  56. // we don't want individual slow subscribers to be able to block others.
  57. go func(s subscriber) {
  58. if err := s(v, m); err != nil {
  59. logvm(v, m).Tag(tagPublish).Err(err).Warn("Error forwarding to subscriber")
  60. }
  61. }(s.subscriber)
  62. }
  63. } else {
  64. logvm(v, m).Tag(tagPublish).Trace("No stream or WebSocket subscribers, not forwarding")
  65. }
  66. }()
  67. return nil
  68. }
  69. // SubscribersCount returns the number of subscribers to this topic
  70. func (t *topic) SubscribersCount() int {
  71. t.mu.Lock()
  72. defer t.mu.Unlock()
  73. return len(t.subscribers)
  74. }
  75. // CancelSubscribers calls the cancel function for all subscribers, forcing
  76. func (t *topic) CancelSubscribers(exceptUserID string) {
  77. t.mu.Lock()
  78. defer t.mu.Unlock()
  79. for _, s := range t.subscribers {
  80. if s.userID != exceptUserID {
  81. log.Tag(tagSubscribe).Field("topic", t.ID).Debug("Canceling subscriber %s", s.userID)
  82. s.cancel()
  83. }
  84. }
  85. }
  86. // subscribersCopy returns a shallow copy of the subscribers map
  87. func (t *topic) subscribersCopy() map[int]*topicSubscriber {
  88. t.mu.Lock()
  89. defer t.mu.Unlock()
  90. subscribers := make(map[int]*topicSubscriber)
  91. for k, sub := range t.subscribers {
  92. subscribers[k] = &topicSubscriber{
  93. userID: sub.userID,
  94. subscriber: sub.subscriber,
  95. cancel: sub.cancel,
  96. }
  97. }
  98. return subscribers
  99. }