limit.go 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. package util
  2. import (
  3. "errors"
  4. "golang.org/x/time/rate"
  5. "io"
  6. "sync"
  7. "time"
  8. )
  9. // ErrLimitReached is the error returned by the Limiter and LimitWriter when the predefined limit has been reached
  10. var ErrLimitReached = errors.New("limit reached")
  11. // Limiter is an interface that implements a rate limiting mechanism, e.g. based on time or a fixed value
  12. type Limiter interface {
  13. // Allow adds one to the limiters value, or returns false if the limit has been reached
  14. Allow() bool
  15. // AllowN adds n to the limiters value, or returns false if the limit has been reached
  16. AllowN(n int64) bool
  17. // Value returns the current internal limiter value
  18. Value() int64
  19. // Reset resets the state of the limiter
  20. Reset()
  21. }
  22. // FixedLimiter is a helper that allows adding values up to a well-defined limit. Once the limit is reached
  23. // ErrLimitReached will be returned. FixedLimiter may be used by multiple goroutines.
  24. type FixedLimiter struct {
  25. value int64
  26. limit int64
  27. mu sync.Mutex
  28. }
  29. var _ Limiter = (*FixedLimiter)(nil)
  30. // NewFixedLimiter creates a new Limiter
  31. func NewFixedLimiter(limit int64) *FixedLimiter {
  32. return NewFixedLimiterWithValue(limit, 0)
  33. }
  34. // NewFixedLimiterWithValue creates a new Limiter and sets the initial value
  35. func NewFixedLimiterWithValue(limit, value int64) *FixedLimiter {
  36. return &FixedLimiter{
  37. limit: limit,
  38. value: value,
  39. }
  40. }
  41. // Allow adds one to the limiters internal value, but only if the limit has not been reached. If the limit was
  42. // exceeded, false is returned.
  43. func (l *FixedLimiter) Allow() bool {
  44. return l.AllowN(1)
  45. }
  46. // AllowN adds n to the limiters internal value, but only if the limit has not been reached. If the limit was
  47. // exceeded after adding n, false is returned.
  48. func (l *FixedLimiter) AllowN(n int64) bool {
  49. l.mu.Lock()
  50. defer l.mu.Unlock()
  51. if l.value+n > l.limit {
  52. return false
  53. }
  54. l.value += n
  55. return true
  56. }
  57. // Value returns the current limiter value
  58. func (l *FixedLimiter) Value() int64 {
  59. l.mu.Lock()
  60. defer l.mu.Unlock()
  61. return l.value
  62. }
  63. // Reset sets the limiter's value back to zero
  64. func (l *FixedLimiter) Reset() {
  65. l.mu.Lock()
  66. defer l.mu.Unlock()
  67. l.value = 0
  68. }
  69. // RateLimiter is a Limiter that wraps a rate.Limiter, allowing a floating time-based limit.
  70. type RateLimiter struct {
  71. r rate.Limit
  72. b int
  73. value int64
  74. limiter *rate.Limiter
  75. mu sync.Mutex
  76. }
  77. var _ Limiter = (*RateLimiter)(nil)
  78. // NewRateLimiter creates a new RateLimiter
  79. func NewRateLimiter(r rate.Limit, b int) *RateLimiter {
  80. return NewRateLimiterWithValue(r, b, 0)
  81. }
  82. // NewRateLimiterWithValue creates a new RateLimiter with the given starting value.
  83. //
  84. // Note that the starting value only has informational value. It does not impact the underlying
  85. // value of the rate.Limiter.
  86. func NewRateLimiterWithValue(r rate.Limit, b int, value int64) *RateLimiter {
  87. return &RateLimiter{
  88. r: r,
  89. b: b,
  90. value: value,
  91. limiter: rate.NewLimiter(r, b),
  92. }
  93. }
  94. // NewBytesLimiter creates a RateLimiter that is meant to be used for a bytes-per-interval limit,
  95. // e.g. 250 MB per day. And example of the underlying idea can be found here: https://go.dev/play/p/0ljgzIZQ6dJ
  96. func NewBytesLimiter(bytes int, interval time.Duration) *RateLimiter {
  97. return NewRateLimiter(rate.Limit(bytes)*rate.Every(interval), bytes)
  98. }
  99. // Allow adds one to the limiters internal value, but only if the limit has not been reached. If the limit was
  100. // exceeded, false is returned.
  101. func (l *RateLimiter) Allow() bool {
  102. return l.AllowN(1)
  103. }
  104. // AllowN adds n to the limiters internal value, but only if the limit has not been reached. If the limit was
  105. // exceeded after adding n, false is returned.
  106. func (l *RateLimiter) AllowN(n int64) bool {
  107. if n <= 0 {
  108. return false // No-op. Can't take back bytes you're written!
  109. }
  110. l.mu.Lock()
  111. defer l.mu.Unlock()
  112. if !l.limiter.AllowN(time.Now(), int(n)) {
  113. return false
  114. }
  115. l.value += n
  116. return true
  117. }
  118. // Value returns the current limiter value
  119. func (l *RateLimiter) Value() int64 {
  120. l.mu.Lock()
  121. defer l.mu.Unlock()
  122. return l.value
  123. }
  124. // Reset sets the limiter's value back to zero, and resets the underlying rate.Limiter
  125. func (l *RateLimiter) Reset() {
  126. l.mu.Lock()
  127. defer l.mu.Unlock()
  128. l.limiter = rate.NewLimiter(l.r, l.b)
  129. l.value = 0
  130. }
  131. // LimitWriter implements an io.Writer that will pass through all Write calls to the underlying
  132. // writer w until any of the limiter's limit is reached, at which point a Write will return ErrLimitReached.
  133. // Each limiter's value is increased with every write.
  134. type LimitWriter struct {
  135. w io.Writer
  136. written int64
  137. limiters []Limiter
  138. mu sync.Mutex
  139. }
  140. // NewLimitWriter creates a new LimitWriter
  141. func NewLimitWriter(w io.Writer, limiters ...Limiter) *LimitWriter {
  142. return &LimitWriter{
  143. w: w,
  144. limiters: limiters,
  145. }
  146. }
  147. // Write passes through all writes to the underlying writer until any of the given limiter's limit is reached
  148. func (w *LimitWriter) Write(p []byte) (n int, err error) {
  149. w.mu.Lock()
  150. defer w.mu.Unlock()
  151. for i := 0; i < len(w.limiters); i++ {
  152. if !w.limiters[i].AllowN(int64(len(p))) {
  153. for j := i - 1; j >= 0; j-- {
  154. w.limiters[j].AllowN(-int64(len(p))) // Revert limiters limits if not allowed
  155. }
  156. return 0, ErrLimitReached
  157. }
  158. }
  159. n, err = w.w.Write(p)
  160. w.written += int64(n)
  161. return
  162. }