topic.go 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. package server
  2. import (
  3. "heckel.io/ntfy/log"
  4. "math/rand"
  5. "sync"
  6. )
  7. // topic represents a channel to which subscribers can subscribe, and publishers
  8. // can publish a message
  9. type topic struct {
  10. ID string
  11. subscribers map[int]subscriber
  12. mu sync.Mutex
  13. }
  14. // subscriber is a function that is called for every new message on a topic
  15. type subscriber func(v *visitor, msg *message) error
  16. // newTopic creates a new topic
  17. func newTopic(id string) *topic {
  18. return &topic{
  19. ID: id,
  20. subscribers: make(map[int]subscriber),
  21. }
  22. }
  23. // Subscribe subscribes to this topic
  24. func (t *topic) Subscribe(s subscriber) int {
  25. t.mu.Lock()
  26. defer t.mu.Unlock()
  27. subscriberID := rand.Int()
  28. t.subscribers[subscriberID] = s
  29. return subscriberID
  30. }
  31. // Unsubscribe removes the subscription from the list of subscribers
  32. func (t *topic) Unsubscribe(id int) {
  33. t.mu.Lock()
  34. defer t.mu.Unlock()
  35. delete(t.subscribers, id)
  36. }
  37. // Publish asynchronously publishes to all subscribers
  38. func (t *topic) Publish(v *visitor, m *message) error {
  39. go func() {
  40. // We want to lock the topic as short as possible, so we make a shallow copy of the
  41. // subscribers map here. Actually sending out the messages then doesn't have to lock.
  42. subscribers := t.subscribersCopy()
  43. if len(subscribers) > 0 {
  44. log.Debug("%s Forwarding to %d subscriber(s)", logMessagePrefix(v, m), len(subscribers))
  45. for _, s := range subscribers {
  46. // We call the subscriber functions in their own Go routines because they are blocking, and
  47. // we don't want individual slow subscribers to be able to block others.
  48. go func(s subscriber) {
  49. if err := s(v, m); err != nil {
  50. log.Warn("%s Error forwarding to subscriber", logMessagePrefix(v, m))
  51. }
  52. }(s)
  53. }
  54. } else {
  55. log.Trace("%s No stream or WebSocket subscribers, not forwarding", logMessagePrefix(v, m))
  56. }
  57. }()
  58. return nil
  59. }
  60. // SubscribersCount returns the number of subscribers to this topic
  61. func (t *topic) SubscribersCount() int {
  62. t.mu.Lock()
  63. defer t.mu.Unlock()
  64. return len(t.subscribers)
  65. }
  66. // subscribersCopy returns a shallow copy of the subscribers map
  67. func (t *topic) subscribersCopy() map[int]subscriber {
  68. t.mu.Lock()
  69. defer t.mu.Unlock()
  70. subscribers := make(map[int]subscriber)
  71. for k, v := range t.subscribers {
  72. subscribers[k] = v
  73. }
  74. return subscribers
  75. }