batching_queue.go 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. package util
  2. import (
  3. "sync"
  4. "time"
  5. )
  // BatchingQueue groups enqueued elements into batches ([]T) and delivers
// them on the channel returned by Dequeue. A batch is emitted as soon as
// batchSize elements are pending. Whatever is still pending is flushed each
// time the timeout ticker fires. A zero timeout disables time-based flushing.
//
// The flush ticker runs on a fixed period from the moment the queue is
// created; it is not restarted when a full batch is emitted. A partial batch
// may therefore be flushed sooner than one full timeout after its first
// element arrived.
//
// The output channel is unbuffered, so whoever emits a batch (Enqueue or the
// ticker) blocks until a consumer receives it.
//
// Example:
//
//	q := NewBatchingQueue[int](2, 500*time.Millisecond)
//	go func() {
//		for batch := range q.Dequeue() {
//			fmt.Println(batch)
//		}
//	}()
//	q.Enqueue(1)
//	q.Enqueue(2)
//	q.Enqueue(3)
//	time.Sleep(time.Second)
//
// This prints [1 2] right away, because the second Enqueue reaches the batch
// size of 2. It then prints [3] on the first timeout tick, roughly 500ms after
// q was created.
type BatchingQueue[T any] struct {
	// batchSize is the number of pending elements that triggers an immediate emit.
	batchSize int
	// timeout is the period of the flush ticker; zero disables it.
	timeout time.Duration
	// in buffers elements that have not been emitted yet; guarded by mu.
	in []T
	// out carries emitted batches to the consumer.
	out chan []T
	// mu guards in.
	mu sync.Mutex
}
  31. // NewBatchingQueue creates a new BatchingQueue
  32. func NewBatchingQueue[T any](batchSize int, timeout time.Duration) *BatchingQueue[T] {
  33. q := &BatchingQueue[T]{
  34. batchSize: batchSize,
  35. timeout: timeout,
  36. in: make([]T, 0),
  37. out: make(chan []T),
  38. }
  39. go q.timeoutTicker()
  40. return q
  41. }
  42. // Enqueue enqueues an element to the queue. If the configured batch size is reached,
  43. // the batch will be emitted immediately.
  44. func (q *BatchingQueue[T]) Enqueue(element T) {
  45. q.mu.Lock()
  46. q.in = append(q.in, element)
  47. var elements []T
  48. if len(q.in) == q.batchSize {
  49. elements = q.dequeueAll()
  50. }
  51. q.mu.Unlock()
  52. if len(elements) > 0 {
  53. q.out <- elements
  54. }
  55. }
  56. // Dequeue returns a channel emitting batches of elements
  57. func (q *BatchingQueue[T]) Dequeue() <-chan []T {
  58. return q.out
  59. }
  60. func (q *BatchingQueue[T]) dequeueAll() []T {
  61. elements := make([]T, len(q.in))
  62. copy(elements, q.in)
  63. q.in = q.in[:0]
  64. return elements
  65. }
  66. func (q *BatchingQueue[T]) timeoutTicker() {
  67. if q.timeout == 0 {
  68. return
  69. }
  70. ticker := time.NewTicker(q.timeout)
  71. for range ticker.C {
  72. q.mu.Lock()
  73. elements := q.dequeueAll()
  74. q.mu.Unlock()
  75. if len(elements) > 0 {
  76. q.out <- elements
  77. }
  78. }
  79. }