Просмотр исходного кода

Making RateLimiter and FixedLimiter, so they can both work with LimitWriter

Philipp Heckel 3 лет назад
Родитель
Сommit
c76e55a1c8
7 измененных файлов с 127 добавлено и 67 удалено
  1. 2 2
      server/file_cache.go
  2. 2 2
      server/file_cache_test.go
  3. 1 1
      server/server.go
  4. 0 7
      server/server_test.go
  5. 4 4
      server/visitor.go
  6. 42 29
      util/limit.go
  7. 76 22
      util/limit_test.go

+ 2 - 2
server/file_cache.go

@@ -40,7 +40,7 @@ func newFileCache(dir string, totalSizeLimit int64, fileSizeLimit int64) (*fileC
 	}, nil
 }
 
-func (c *fileCache) Write(id string, in io.Reader, limiters ...*util.Limiter) (int64, error) {
+func (c *fileCache) Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) {
 	if !fileIDRegex.MatchString(id) {
 		return 0, errInvalidFileID
 	}
@@ -53,7 +53,7 @@ func (c *fileCache) Write(id string, in io.Reader, limiters ...*util.Limiter) (i
 		return 0, err
 	}
 	defer f.Close()
-	limiters = append(limiters, util.NewLimiter(c.Remaining()), util.NewLimiter(c.fileSizeLimit))
+	limiters = append(limiters, util.NewFixedLimiter(c.Remaining()), util.NewFixedLimiter(c.fileSizeLimit))
 	limitWriter := util.NewLimitWriter(f, limiters...)
 	size, err := io.Copy(limitWriter, in)
 	if err != nil {

+ 2 - 2
server/file_cache_test.go

@@ -16,7 +16,7 @@ var (
 
 func TestFileCache_Write_Success(t *testing.T) {
 	dir, c := newTestFileCache(t)
-	size, err := c.Write("abc", strings.NewReader("normal file"), util.NewLimiter(999))
+	size, err := c.Write("abc", strings.NewReader("normal file"), util.NewFixedLimiter(999))
 	require.Nil(t, err)
 	require.Equal(t, int64(11), size)
 	require.Equal(t, "normal file", readFile(t, dir+"/abc"))
@@ -64,7 +64,7 @@ func TestFileCache_Write_FailedFileSizeLimit(t *testing.T) {
 
 func TestFileCache_Write_FailedAdditionalLimiter(t *testing.T) {
 	dir, c := newTestFileCache(t)
-	_, err := c.Write("abc", bytes.NewReader(make([]byte, 1001)), util.NewLimiter(1000))
+	_, err := c.Write("abc", bytes.NewReader(make([]byte, 1001)), util.NewFixedLimiter(1000))
 	require.Equal(t, util.ErrLimitReached, err)
 	require.NoFileExists(t, dir+"/abc")
 }

+ 1 - 1
server/server.go

@@ -648,7 +648,7 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message,
 	if m.Message == "" {
 		m.Message = fmt.Sprintf(defaultAttachmentMessage, m.Attachment.Name)
 	}
-	m.Attachment.Size, err = s.fileCache.Write(m.ID, body, util.NewLimiter(remainingVisitorAttachmentSize))
+	m.Attachment.Size, err = s.fileCache.Write(m.ID, body, util.NewFixedLimiter(remainingVisitorAttachmentSize))
 	if err == util.ErrLimitReached {
 		return errHTTPBadRequestAttachmentTooLarge
 	} else if err != nil {

+ 0 - 7
server/server_test.go

@@ -909,13 +909,6 @@ func toMessage(t *testing.T, s string) *message {
 	return &m
 }
 
-func tempFile(t *testing.T, length int) (filename string, content string) {
-	filename = filepath.Join(t.TempDir(), util.RandomString(10))
-	content = util.RandomString(length)
-	require.Nil(t, os.WriteFile(filename, []byte(content), 0600))
-	return
-}
-
 func toHTTPError(t *testing.T, s string) *errHTTP {
 	var e errHTTP
 	require.Nil(t, json.NewDecoder(strings.NewReader(s)).Decode(&e))

+ 4 - 4
server/visitor.go

@@ -24,7 +24,7 @@ type visitor struct {
 	config        *Config
 	ip            string
 	requests      *rate.Limiter
-	subscriptions *util.Limiter
+	subscriptions util.Limiter
 	emails        *rate.Limiter
 	seen          time.Time
 	mu            sync.Mutex
@@ -35,7 +35,7 @@ func newVisitor(conf *Config, ip string) *visitor {
 		config:        conf,
 		ip:            ip,
 		requests:      rate.NewLimiter(rate.Every(conf.VisitorRequestLimitReplenish), conf.VisitorRequestLimitBurst),
-		subscriptions: util.NewLimiter(int64(conf.VisitorSubscriptionLimit)),
+		subscriptions: util.NewFixedLimiter(int64(conf.VisitorSubscriptionLimit)),
 		emails:        rate.NewLimiter(rate.Every(conf.VisitorEmailLimitReplenish), conf.VisitorEmailLimitBurst),
 		seen:          time.Now(),
 	}
@@ -62,7 +62,7 @@ func (v *visitor) EmailAllowed() error {
 func (v *visitor) SubscriptionAllowed() error {
 	v.mu.Lock()
 	defer v.mu.Unlock()
-	if err := v.subscriptions.Add(1); err != nil {
+	if err := v.subscriptions.Allow(1); err != nil {
 		return errVisitorLimitReached
 	}
 	return nil
@@ -71,7 +71,7 @@ func (v *visitor) SubscriptionAllowed() error {
 func (v *visitor) RemoveSubscription() {
 	v.mu.Lock()
 	defer v.mu.Unlock()
-	v.subscriptions.Sub(1)
+	v.subscriptions.Allow(-1)
 }
 
 func (v *visitor) Keepalive() {

+ 42 - 29
util/limit.go

@@ -2,31 +2,39 @@ package util
 
 import (
 	"errors"
+	"golang.org/x/time/rate"
 	"io"
 	"sync"
+	"time"
 )
 
 // ErrLimitReached is the error returned by the Limiter and LimitWriter when the predefined limit has been reached
 var ErrLimitReached = errors.New("limit reached")
 
-// Limiter is a helper that allows adding values up to a well-defined limit. Once the limit is reached
-// ErrLimitReached will be returned. Limiter may be used by multiple goroutines.
-type Limiter struct {
+// Limiter is an interface that implements a rate limiting mechanism, e.g. based on time or a fixed value
+type Limiter interface {
+	// Allow adds n to the limiters internal value, or returns ErrLimitReached if the limit has been reached
+	Allow(n int64) error
+}
+
+// FixedLimiter is a helper that allows adding values up to a well-defined limit. Once the limit is reached
+// ErrLimitReached will be returned. FixedLimiter may be used by multiple goroutines.
+type FixedLimiter struct {
 	value int64
 	limit int64
 	mu    sync.Mutex
 }
 
-// NewLimiter creates a new Limiter
-func NewLimiter(limit int64) *Limiter {
-	return &Limiter{
+// NewFixedLimiter creates a new Limiter
+func NewFixedLimiter(limit int64) *FixedLimiter {
+	return &FixedLimiter{
 		limit: limit,
 	}
 }
 
-// Add adds n to the limiters internal value, but only if the limit has not been reached. If the limit was
+// Allow adds n to the limiters internal value, but only if the limit has not been reached. If the limit was
 // exceeded after adding n, ErrLimitReached is returned.
-func (l *Limiter) Add(n int64) error {
+func (l *FixedLimiter) Allow(n int64) error {
 	l.mu.Lock()
 	defer l.mu.Unlock()
 	if l.value+n > l.limit {
@@ -36,29 +44,34 @@ func (l *Limiter) Add(n int64) error {
 	return nil
 }
 
-// Sub subtracts a value from the limiters internal value
-func (l *Limiter) Sub(n int64) {
-	l.Add(-n)
+// RateLimiter is a Limiter that wraps a rate.Limiter, allowing a floating time-based limit.
+type RateLimiter struct {
+	limiter *rate.Limiter
 }
 
-// Set sets the value of the limiter to n. This function ignores the limit. It is meant to set the value
-// based on reality.
-func (l *Limiter) Set(n int64) {
-	l.mu.Lock()
-	l.value = n
-	l.mu.Unlock()
+// NewRateLimiter creates a new RateLimiter
+func NewRateLimiter(r rate.Limit, b int) *RateLimiter {
+	return &RateLimiter{
+		limiter: rate.NewLimiter(r, b),
+	}
 }
 
-// Value returns the internal value of the limiter
-func (l *Limiter) Value() int64 {
-	l.mu.Lock()
-	defer l.mu.Unlock()
-	return l.value
+// NewBytesLimiter creates a RateLimiter that is meant to be used for a bytes-per-interval limit,
+// e.g. 250 MB per day. And example of the underlying idea can be found here: https://go.dev/play/p/0ljgzIZQ6dJ
+func NewBytesLimiter(bytes int, interval time.Duration) *RateLimiter {
+	return NewRateLimiter(rate.Limit(bytes)*rate.Every(interval), bytes)
 }
 
-// Limit returns the defined limit
-func (l *Limiter) Limit() int64 {
-	return l.limit
+// Allow adds n to the limiters internal value, but only if the limit has not been reached. If the limit was
+// exceeded after adding n, ErrLimitReached is returned.
+func (l *RateLimiter) Allow(n int64) error {
+	if n <= 0 {
+		return nil // No-op. Can't take back bytes you're written!
+	}
+	if !l.limiter.AllowN(time.Now(), int(n)) {
+		return ErrLimitReached
+	}
+	return nil
 }
 
 // LimitWriter implements an io.Writer that will pass through all Write calls to the underlying
@@ -67,12 +80,12 @@ func (l *Limiter) Limit() int64 {
 type LimitWriter struct {
 	w        io.Writer
 	written  int64
-	limiters []*Limiter
+	limiters []Limiter
 	mu       sync.Mutex
 }
 
 // NewLimitWriter creates a new LimitWriter
-func NewLimitWriter(w io.Writer, limiters ...*Limiter) *LimitWriter {
+func NewLimitWriter(w io.Writer, limiters ...Limiter) *LimitWriter {
 	return &LimitWriter{
 		w:        w,
 		limiters: limiters,
@@ -84,9 +97,9 @@ func (w *LimitWriter) Write(p []byte) (n int, err error) {
 	w.mu.Lock()
 	defer w.mu.Unlock()
 	for i := 0; i < len(w.limiters); i++ {
-		if err := w.limiters[i].Add(int64(len(p))); err != nil {
+		if err := w.limiters[i].Allow(int64(len(p))); err != nil {
 			for j := i - 1; j >= 0; j-- {
-				w.limiters[j].Sub(int64(len(p)))
+				w.limiters[j].Allow(-int64(len(p))) // Revert limiters limits if allowed
 			}
 			return 0, ErrLimitReached
 		}

+ 76 - 22
util/limit_test.go

@@ -2,34 +2,51 @@ package util
 
 import (
 	"bytes"
+	"github.com/stretchr/testify/require"
 	"testing"
+	"time"
 )
 
-func TestLimiter_Add(t *testing.T) {
-	l := NewLimiter(10)
-	if err := l.Add(5); err != nil {
+func TestFixedLimiter_Add(t *testing.T) {
+	l := NewFixedLimiter(10)
+	if err := l.Allow(5); err != nil {
 		t.Fatal(err)
 	}
-	if err := l.Add(5); err != nil {
+	if err := l.Allow(5); err != nil {
 		t.Fatal(err)
 	}
-	if err := l.Add(5); err != ErrLimitReached {
+	if err := l.Allow(5); err != ErrLimitReached {
 		t.Fatalf("expected ErrLimitReached, got %#v", err)
 	}
 }
 
-func TestLimiter_AddSet(t *testing.T) {
-	l := NewLimiter(10)
-	l.Add(5)
-	if l.Value() != 5 {
-		t.Fatalf("expected value to be %d, got %d", 5, l.Value())
+func TestFixedLimiter_AddSub(t *testing.T) {
+	l := NewFixedLimiter(10)
+	l.Allow(5)
+	if l.value != 5 {
+		t.Fatalf("expected value to be %d, got %d", 5, l.value)
 	}
-	l.Set(7)
-	if l.Value() != 7 {
-		t.Fatalf("expected value to be %d, got %d", 7, l.Value())
+	l.Allow(-2)
+	if l.value != 3 {
+		t.Fatalf("expected value to be %d, got %d", 7, l.value)
 	}
 }
 
+func TestBytesLimiter_Add_Simple(t *testing.T) {
+	l := NewBytesLimiter(250*1024*1024, 24*time.Hour) // 250 MB per 24h
+	require.Nil(t, l.Allow(100*1024*1024))
+	require.Nil(t, l.Allow(100*1024*1024))
+	require.Equal(t, ErrLimitReached, l.Allow(300*1024*1024))
+}
+
+func TestBytesLimiter_Add_Wait(t *testing.T) {
+	l := NewBytesLimiter(250*1024*1024, 24*time.Hour) // 250 MB per 24h (~ 303 bytes per 100ms)
+	require.Nil(t, l.Allow(250*1024*1024))
+	require.Equal(t, ErrLimitReached, l.Allow(400))
+	time.Sleep(200 * time.Millisecond)
+	require.Nil(t, l.Allow(400))
+}
+
 func TestLimitWriter_WriteNoLimiter(t *testing.T) {
 	var buf bytes.Buffer
 	lw := NewLimitWriter(&buf)
@@ -46,7 +63,7 @@ func TestLimitWriter_WriteNoLimiter(t *testing.T) {
 
 func TestLimitWriter_WriteOneLimiter(t *testing.T) {
 	var buf bytes.Buffer
-	l := NewLimiter(10)
+	l := NewFixedLimiter(10)
 	lw := NewLimitWriter(&buf, l)
 	if _, err := lw.Write(make([]byte, 10)); err != nil {
 		t.Fatal(err)
@@ -57,15 +74,15 @@ func TestLimitWriter_WriteOneLimiter(t *testing.T) {
 	if buf.Len() != 10 {
 		t.Fatalf("expected buffer length to be %d, got %d", 10, buf.Len())
 	}
-	if l.Value() != 10 {
-		t.Fatalf("expected limiter value to be %d, got %d", 10, l.Value())
+	if l.value != 10 {
+		t.Fatalf("expected limiter value to be %d, got %d", 10, l.value)
 	}
 }
 
 func TestLimitWriter_WriteTwoLimiters(t *testing.T) {
 	var buf bytes.Buffer
-	l1 := NewLimiter(11)
-	l2 := NewLimiter(9)
+	l1 := NewFixedLimiter(11)
+	l2 := NewFixedLimiter(9)
 	lw := NewLimitWriter(&buf, l1, l2)
 	if _, err := lw.Write(make([]byte, 8)); err != nil {
 		t.Fatal(err)
@@ -76,10 +93,47 @@ func TestLimitWriter_WriteTwoLimiters(t *testing.T) {
 	if buf.Len() != 8 {
 		t.Fatalf("expected buffer length to be %d, got %d", 8, buf.Len())
 	}
-	if l1.Value() != 8 {
-		t.Fatalf("expected limiter 1 value to be %d, got %d", 8, l1.Value())
+	if l1.value != 8 {
+		t.Fatalf("expected limiter 1 value to be %d, got %d", 8, l1.value)
 	}
-	if l2.Value() != 8 {
-		t.Fatalf("expected limiter 2 value to be %d, got %d", 8, l2.Value())
+	if l2.value != 8 {
+		t.Fatalf("expected limiter 2 value to be %d, got %d", 8, l2.value)
 	}
 }
+
+func TestLimitWriter_WriteTwoDifferentLimiters(t *testing.T) {
+	var buf bytes.Buffer
+	l1 := NewFixedLimiter(32)
+	l2 := NewBytesLimiter(8, 200*time.Millisecond)
+	lw := NewLimitWriter(&buf, l1, l2)
+	_, err := lw.Write(make([]byte, 8))
+	require.Nil(t, err)
+	_, err = lw.Write(make([]byte, 4))
+	require.Equal(t, ErrLimitReached, err)
+}
+
+func TestLimitWriter_WriteTwoDifferentLimiters_Wait(t *testing.T) {
+	var buf bytes.Buffer
+	l1 := NewFixedLimiter(32)
+	l2 := NewBytesLimiter(8, 200*time.Millisecond)
+	lw := NewLimitWriter(&buf, l1, l2)
+	_, err := lw.Write(make([]byte, 8))
+	require.Nil(t, err)
+	time.Sleep(250 * time.Millisecond)
+	_, err = lw.Write(make([]byte, 8))
+	require.Nil(t, err)
+	_, err = lw.Write(make([]byte, 4))
+	require.Equal(t, ErrLimitReached, err)
+}
+
+func TestLimitWriter_WriteTwoDifferentLimiters_Wait_FixedLimiterFail(t *testing.T) {
+	var buf bytes.Buffer
+	l1 := NewFixedLimiter(11) // <<< This fails below
+	l2 := NewBytesLimiter(8, 200*time.Millisecond)
+	lw := NewLimitWriter(&buf, l1, l2)
+	_, err := lw.Write(make([]byte, 8))
+	require.Nil(t, err)
+	time.Sleep(250 * time.Millisecond)
+	_, err = lw.Write(make([]byte, 8)) // <<< FixedLimiter fails
+	require.Equal(t, ErrLimitReached, err)
+}