Просмотр исходного кода

Attachment behavior fix for Firefox

Philipp Heckel 2 лет назад
Родитель
Commit
aba7e86cbc
10 измененных файлов с 163 добавлено и 97 удалено
  1. 1 1
      server/message_cache.go
  2. 2 2
      server/message_cache_test.go
  3. 35 21
      server/server.go
  4. 3 3
      server/server_test.go
  5. 27 1
      server/visitor.go
  6. 16 0
      util/limit.go
  7. 0 61
      util/peak.go
  8. 61 0
      util/peek.go
  9. 7 7
      util/peek_test.go
  10. 11 1
      web/src/app/Api.js

+ 1 - 1
server/message_cache.go

@@ -355,7 +355,7 @@ func (c *messageCache) Prune(olderThan time.Time) error {
 	return err
 }
 
-func (c *messageCache) AttachmentsSize(owner string) (int64, error) {
+func (c *messageCache) AttachmentBytesUsed(owner string) (int64, error) {
 	rows, err := c.db.Query(selectAttachmentsSizeQuery, owner, time.Now().Unix())
 	if err != nil {
 		return 0, err

+ 2 - 2
server/message_cache_test.go

@@ -337,11 +337,11 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
 	require.Equal(t, "https://ntfy.sh/file/aCaRURL.jpg", messages[1].Attachment.URL)
 	require.Equal(t, "1.2.3.4", messages[1].Attachment.Owner)
 
-	size, err := c.AttachmentsSize("1.2.3.4")
+	size, err := c.AttachmentBytesUsed("1.2.3.4")
 	require.Nil(t, err)
 	require.Equal(t, int64(30000), size)
 
-	size, err = c.AttachmentsSize("5.6.7.8")
+	size, err = c.AttachmentBytesUsed("5.6.7.8")
 	require.Nil(t, err)
 	require.Equal(t, int64(0), size)
 

+ 35 - 21
server/server.go

@@ -66,6 +66,7 @@ var (
 	publishPathRegex       = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}/(publish|send|trigger)$`)
 
 	webConfigPath    = "/config.js"
+	userStatsPath    = "/user/stats"
 	staticRegex      = regexp.MustCompile(`^/static/.+`)
 	docsRegex        = regexp.MustCompile(`^/docs(|/.*)$`)
 	fileRegex        = regexp.MustCompile(`^/file/([-_A-Za-z0-9]{1,64})(?:\.[A-Za-z0-9]{1,16})?$`)
@@ -269,6 +270,8 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
 		return s.handleEmpty(w, r, v)
 	} else if r.Method == http.MethodGet && r.URL.Path == webConfigPath {
 		return s.handleWebConfig(w, r)
+	} else if r.Method == http.MethodGet && r.URL.Path == userStatsPath {
+		return s.handleUserStats(w, r, v)
 	} else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) {
 		return s.handleStatic(w, r)
 	} else if r.Method == http.MethodGet && docsRegex.MatchString(r.URL.Path) {
@@ -351,6 +354,19 @@ var config = {
 	return err
 }
 
+func (s *Server) handleUserStats(w http.ResponseWriter, r *http.Request, v *visitor) error {
+	stats, err := v.Stats()
+	if err != nil {
+		return err
+	}
+	w.Header().Set("Content-Type", "text/json")
+	w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
+	if err := json.NewEncoder(w).Encode(stats); err != nil {
+		return err
+	}
+	return nil
+}
+
 func (s *Server) handleStatic(w http.ResponseWriter, r *http.Request) error {
 	r.URL.Path = webSiteDir + r.URL.Path
 	util.Gzip(http.FileServer(http.FS(webFsCached))).ServeHTTP(w, r)
@@ -395,8 +411,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
 	if err != nil {
 		return err
 	}
-	return errHTTPEntityTooLargeAttachmentTooLarge
-	body, err := util.Peak(r.Body, s.config.MessageLimit)
+	body, err := util.Peek(r.Body, s.config.MessageLimit)
 	if err != nil {
 		return err
 	}
@@ -540,35 +555,35 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca
 //    If file.txt is <= 4096 (message limit) and valid UTF-8, treat it as a message
 // 5. curl -T file.txt ntfy.sh/mytopic
 //    If file.txt is > message limit, treat it as an attachment
-func (s *Server) handlePublishBody(r *http.Request, v *visitor, m *message, body *util.PeakedReadCloser, unifiedpush bool) error {
+func (s *Server) handlePublishBody(r *http.Request, v *visitor, m *message, body *util.PeekedReadCloser, unifiedpush bool) error {
 	if unifiedpush {
 		return s.handleBodyAsMessageAutoDetect(m, body) // Case 1
 	} else if m.Attachment != nil && m.Attachment.URL != "" {
 		return s.handleBodyAsTextMessage(m, body) // Case 2
 	} else if m.Attachment != nil && m.Attachment.Name != "" {
 		return s.handleBodyAsAttachment(r, v, m, body) // Case 3
-	} else if !body.LimitReached && utf8.Valid(body.PeakedBytes) {
+	} else if !body.LimitReached && utf8.Valid(body.PeekedBytes) {
 		return s.handleBodyAsTextMessage(m, body) // Case 4
 	}
 	return s.handleBodyAsAttachment(r, v, m, body) // Case 5
 }
 
-func (s *Server) handleBodyAsMessageAutoDetect(m *message, body *util.PeakedReadCloser) error {
-	if utf8.Valid(body.PeakedBytes) {
-		m.Message = string(body.PeakedBytes) // Do not trim
+func (s *Server) handleBodyAsMessageAutoDetect(m *message, body *util.PeekedReadCloser) error {
+	if utf8.Valid(body.PeekedBytes) {
+		m.Message = string(body.PeekedBytes) // Do not trim
 	} else {
-		m.Message = base64.StdEncoding.EncodeToString(body.PeakedBytes)
+		m.Message = base64.StdEncoding.EncodeToString(body.PeekedBytes)
 		m.Encoding = encodingBase64
 	}
 	return nil
 }
 
-func (s *Server) handleBodyAsTextMessage(m *message, body *util.PeakedReadCloser) error {
-	if !utf8.Valid(body.PeakedBytes) {
+func (s *Server) handleBodyAsTextMessage(m *message, body *util.PeekedReadCloser) error {
+	if !utf8.Valid(body.PeekedBytes) {
 		return errHTTPBadRequestMessageNotUTF8
 	}
-	if len(body.PeakedBytes) > 0 { // Empty body should not override message (publish via GET!)
-		m.Message = strings.TrimSpace(string(body.PeakedBytes)) // Truncates the message to the peak limit if required
+	if len(body.PeekedBytes) > 0 { // Empty body should not override message (publish via GET!)
+		m.Message = strings.TrimSpace(string(body.PeekedBytes)) // Truncates the message to the peek limit if required
 	}
 	if m.Attachment != nil && m.Attachment.Name != "" && m.Message == "" {
 		m.Message = fmt.Sprintf(defaultAttachmentMessage, m.Attachment.Name)
@@ -576,21 +591,20 @@ func (s *Server) handleBodyAsTextMessage(m *message, body *util.PeakedReadCloser
 	return nil
 }
 
-func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message, body *util.PeakedReadCloser) error {
+func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message, body *util.PeekedReadCloser) error {
 	if s.fileCache == nil || s.config.BaseURL == "" || s.config.AttachmentCacheDir == "" {
 		return errHTTPBadRequestAttachmentsDisallowed
 	} else if m.Time > time.Now().Add(s.config.AttachmentExpiryDuration).Unix() {
 		return errHTTPBadRequestAttachmentsExpiryBeforeDelivery
 	}
-	visitorAttachmentsSize, err := s.messageCache.AttachmentsSize(v.ip)
+	visitorStats, err := v.Stats()
 	if err != nil {
 		return err
 	}
-	remainingVisitorAttachmentSize := s.config.VisitorAttachmentTotalSizeLimit - visitorAttachmentsSize
 	contentLengthStr := r.Header.Get("Content-Length")
 	if contentLengthStr != "" { // Early "do-not-trust" check, hard limit see below
 		contentLength, err := strconv.ParseInt(contentLengthStr, 10, 64)
-		if err == nil && (contentLength > remainingVisitorAttachmentSize || contentLength > s.config.AttachmentFileSizeLimit) {
+		if err == nil && (contentLength > visitorStats.VisitorAttachmentBytesRemaining || contentLength > s.config.AttachmentFileSizeLimit) {
 			return errHTTPEntityTooLargeAttachmentTooLarge
 		}
 	}
@@ -600,7 +614,7 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message,
 	var ext string
 	m.Attachment.Owner = v.ip // Important for attachment rate limiting
 	m.Attachment.Expires = time.Now().Add(s.config.AttachmentExpiryDuration).Unix()
-	m.Attachment.Type, ext = util.DetectContentType(body.PeakedBytes, m.Attachment.Name)
+	m.Attachment.Type, ext = util.DetectContentType(body.PeekedBytes, m.Attachment.Name)
 	m.Attachment.URL = fmt.Sprintf("%s/file/%s%s", s.config.BaseURL, m.ID, ext)
 	if m.Attachment.Name == "" {
 		m.Attachment.Name = fmt.Sprintf("attachment%s", ext)
@@ -608,7 +622,7 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message,
 	if m.Message == "" {
 		m.Message = fmt.Sprintf(defaultAttachmentMessage, m.Attachment.Name)
 	}
-	m.Attachment.Size, err = s.fileCache.Write(m.ID, body, v.BandwidthLimiter(), util.NewFixedLimiter(remainingVisitorAttachmentSize))
+	m.Attachment.Size, err = s.fileCache.Write(m.ID, body, v.BandwidthLimiter(), util.NewFixedLimiter(visitorStats.VisitorAttachmentBytesRemaining))
 	if err == util.ErrLimitReached {
 		return errHTTPEntityTooLargeAttachmentTooLarge
 	} else if err != nil {
@@ -1097,11 +1111,11 @@ func (s *Server) limitRequests(next handleFunc) handleFunc {
 	}
 }
 
-// transformBodyJSON peaks the request body, reads the JSON, and converts it to headers
+// transformBodyJSON peeks the request body, reads the JSON, and converts it to headers
 // before passing it on to the next handler. This is meant to be used in combination with handlePublish.
 func (s *Server) transformBodyJSON(next handleFunc) handleFunc {
 	return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
-		body, err := util.Peak(r.Body, s.config.MessageLimit)
+		body, err := util.Peek(r.Body, s.config.MessageLimit)
 		if err != nil {
 			return err
 		}
@@ -1217,7 +1231,7 @@ func (s *Server) visitor(r *http.Request) *visitor {
 	}
 	v, exists := s.visitors[ip]
 	if !exists {
-		s.visitors[ip] = newVisitor(s.config, ip)
+		s.visitors[ip] = newVisitor(s.config, s.messageCache, ip)
 		return s.visitors[ip]
 	}
 	v.Keepalive()

+ 3 - 3
server/server_test.go

@@ -938,7 +938,7 @@ func TestServer_PublishAttachment(t *testing.T) {
 	require.Equal(t, content, response.Body.String())
 
 	// Slightly unrelated cross-test: make sure we add an owner for internal attachments
-	size, err := s.messageCache.AttachmentsSize("9.9.9.9") // See request()
+	size, err := s.messageCache.AttachmentBytesUsed("9.9.9.9") // See request()
 	require.Nil(t, err)
 	require.Equal(t, int64(5000), size)
 }
@@ -967,7 +967,7 @@ func TestServer_PublishAttachmentShortWithFilename(t *testing.T) {
 	require.Equal(t, content, response.Body.String())
 
 	// Slightly unrelated cross-test: make sure we add an owner for internal attachments
-	size, err := s.messageCache.AttachmentsSize("1.2.3.4")
+	size, err := s.messageCache.AttachmentBytesUsed("1.2.3.4")
 	require.Nil(t, err)
 	require.Equal(t, int64(21), size)
 }
@@ -987,7 +987,7 @@ func TestServer_PublishAttachmentExternalWithoutFilename(t *testing.T) {
 	require.Equal(t, "", msg.Attachment.Owner)
 
 	// Slightly unrelated cross-test: make sure we don't add an owner for external attachments
-	size, err := s.messageCache.AttachmentsSize("127.0.0.1")
+	size, err := s.messageCache.AttachmentBytesUsed("127.0.0.1")
 	require.Nil(t, err)
 	require.Equal(t, int64(0), size)
 }

+ 27 - 1
server/visitor.go

@@ -22,6 +22,7 @@ var (
 // visitor represents an API user, and its associated rate.Limiter used for rate limiting
 type visitor struct {
 	config        *Config
+	messageCache  *messageCache
 	ip            string
 	requests      *rate.Limiter
 	emails        *rate.Limiter
@@ -31,9 +32,17 @@ type visitor struct {
 	mu            sync.Mutex
 }
 
-func newVisitor(conf *Config, ip string) *visitor {
+type visitorStats struct {
+	AttachmentFileSizeLimit         int64 `json:"attachmentFileSizeLimit"`
+	VisitorAttachmentBytesTotal     int64 `json:"visitorAttachmentBytesTotal"`
+	VisitorAttachmentBytesUsed      int64 `json:"visitorAttachmentBytesUsed"`
+	VisitorAttachmentBytesRemaining int64 `json:"visitorAttachmentBytesRemaining"`
+}
+
+func newVisitor(conf *Config, messageCache *messageCache, ip string) *visitor {
 	return &visitor{
 		config:        conf,
+		messageCache:  messageCache,
 		ip:            ip,
 		requests:      rate.NewLimiter(rate.Every(conf.VisitorRequestLimitReplenish), conf.VisitorRequestLimitBurst),
 		emails:        rate.NewLimiter(rate.Every(conf.VisitorEmailLimitReplenish), conf.VisitorEmailLimitBurst),
@@ -91,3 +100,20 @@ func (v *visitor) Stale() bool {
 	defer v.mu.Unlock()
 	return time.Since(v.seen) > visitorExpungeAfter
 }
+
+func (v *visitor) Stats() (*visitorStats, error) {
+	attachmentsBytesUsed, err := v.messageCache.AttachmentBytesUsed(v.ip)
+	if err != nil {
+		return nil, err
+	}
+	attachmentsBytesRemaining := v.config.VisitorAttachmentTotalSizeLimit - attachmentsBytesUsed
+	if attachmentsBytesRemaining < 0 {
+		attachmentsBytesRemaining = 0
+	}
+	return &visitorStats{
+		AttachmentFileSizeLimit:         v.config.AttachmentFileSizeLimit,
+		VisitorAttachmentBytesTotal:     v.config.VisitorAttachmentTotalSizeLimit,
+		VisitorAttachmentBytesUsed:      attachmentsBytesUsed,
+		VisitorAttachmentBytesRemaining: attachmentsBytesRemaining,
+	}, nil
+}

+ 16 - 0
util/limit.go

@@ -15,6 +15,10 @@ var ErrLimitReached = errors.New("limit reached")
 type Limiter interface {
 	// Allow adds n to the limiters internal value, or returns ErrLimitReached if the limit has been reached
 	Allow(n int64) error
+
+	// Remaining returns the remaining count until the limit is reached; may return -1 if the implementation
+	// does not support this operation.
+	Remaining() int64
 }
 
 // FixedLimiter is a helper that allows adding values up to a well-defined limit. Once the limit is reached
@@ -44,6 +48,13 @@ func (l *FixedLimiter) Allow(n int64) error {
 	return nil
 }
 
+// Remaining returns the remaining count until the limit is reached
+func (l *FixedLimiter) Remaining() int64 {
+	l.mu.Lock()
+	defer l.mu.Unlock()
+	return l.limit - l.value
+}
+
 // RateLimiter is a Limiter that wraps a rate.Limiter, allowing a floating time-based limit.
 type RateLimiter struct {
 	limiter *rate.Limiter
@@ -74,6 +85,11 @@ func (l *RateLimiter) Allow(n int64) error {
 	return nil
 }
 
+// Remaining is not implemented for RateLimiter. It always returns -1.
+func (l *RateLimiter) Remaining() int64 {
+	return -1
+}
+
 // LimitWriter implements an io.Writer that will pass through all Write calls to the underlying
 // writer w until any of the limiter's limit is reached, at which point a Write will return ErrLimitReached.
 // Each limiter's value is increased with every write.

+ 0 - 61
util/peak.go

@@ -1,61 +0,0 @@
-package util
-
-import (
-	"bytes"
-	"io"
-	"strings"
-)
-
-// PeakedReadCloser is a ReadCloser that allows peaking into a stream and buffering it in memory.
-// It can be instantiated using the Peak function. After a stream has been peaked, it can still be fully
-// read by reading the PeakedReadCloser. It first drained from the memory buffer, and then from the remaining
-// underlying reader.
-type PeakedReadCloser struct {
-	PeakedBytes  []byte
-	LimitReached bool
-	peaked       io.Reader
-	underlying   io.ReadCloser
-	closed       bool
-}
-
-// Peak reads the underlying ReadCloser into memory up until the limit and returns a PeakedReadCloser
-func Peak(underlying io.ReadCloser, limit int) (*PeakedReadCloser, error) {
-	if underlying == nil {
-		underlying = io.NopCloser(strings.NewReader(""))
-	}
-	peaked := make([]byte, limit)
-	read, err := io.ReadFull(underlying, peaked)
-	if err != nil && err != io.ErrUnexpectedEOF && err != io.EOF {
-		return nil, err
-	}
-	return &PeakedReadCloser{
-		PeakedBytes:  peaked[:read],
-		LimitReached: read == limit,
-		underlying:   underlying,
-		peaked:       bytes.NewReader(peaked[:read]),
-		closed:       false,
-	}, nil
-}
-
-// Read reads from the peaked bytes and then from the underlying stream
-func (r *PeakedReadCloser) Read(p []byte) (n int, err error) {
-	if r.closed {
-		return 0, io.EOF
-	}
-	n, err = r.peaked.Read(p)
-	if err == io.EOF {
-		return r.underlying.Read(p)
-	} else if err != nil {
-		return 0, err
-	}
-	return
-}
-
-// Close closes the underlying stream
-func (r *PeakedReadCloser) Close() error {
-	if r.closed {
-		return io.EOF
-	}
-	r.closed = true
-	return r.underlying.Close()
-}

+ 61 - 0
util/peek.go

@@ -0,0 +1,61 @@
+package util
+
+import (
+	"bytes"
+	"io"
+	"strings"
+)
+
+// PeekedReadCloser is a ReadCloser that allows peeking into a stream and buffering it in memory.
+// It can be instantiated using the Peek function. After a stream has been peeked, it can still be fully
+// read by reading the PeekedReadCloser. It first drained from the memory buffer, and then from the remaining
+// underlying reader.
+type PeekedReadCloser struct {
+	PeekedBytes  []byte
+	LimitReached bool
+	peeked       io.Reader
+	underlying   io.ReadCloser
+	closed       bool
+}
+
+// Peek reads the underlying ReadCloser into memory up until the limit and returns a PeekedReadCloser
+func Peek(underlying io.ReadCloser, limit int) (*PeekedReadCloser, error) {
+	if underlying == nil {
+		underlying = io.NopCloser(strings.NewReader(""))
+	}
+	peeked := make([]byte, limit)
+	read, err := io.ReadFull(underlying, peeked)
+	if err != nil && err != io.ErrUnexpectedEOF && err != io.EOF {
+		return nil, err
+	}
+	return &PeekedReadCloser{
+		PeekedBytes:  peeked[:read],
+		LimitReached: read == limit,
+		underlying:   underlying,
+		peeked:       bytes.NewReader(peeked[:read]),
+		closed:       false,
+	}, nil
+}
+
+// Read reads from the peeked bytes and then from the underlying stream
+func (r *PeekedReadCloser) Read(p []byte) (n int, err error) {
+	if r.closed {
+		return 0, io.EOF
+	}
+	n, err = r.peeked.Read(p)
+	if err == io.EOF {
+		return r.underlying.Read(p)
+	} else if err != nil {
+		return 0, err
+	}
+	return
+}
+
+// Close closes the underlying stream
+func (r *PeekedReadCloser) Close() error {
+	if r.closed {
+		return io.EOF
+	}
+	r.closed = true
+	return r.underlying.Close()
+}

+ 7 - 7
util/peak_test.go → util/peek_test.go

@@ -9,11 +9,11 @@ import (
 
 func TestPeak_LimitReached(t *testing.T) {
 	underlying := io.NopCloser(strings.NewReader("1234567890"))
-	peaked, err := Peak(underlying, 5)
+	peaked, err := Peek(underlying, 5)
 	if err != nil {
 		t.Fatal(err)
 	}
-	require.Equal(t, []byte("12345"), peaked.PeakedBytes)
+	require.Equal(t, []byte("12345"), peaked.PeekedBytes)
 	require.Equal(t, true, peaked.LimitReached)
 
 	all, err := io.ReadAll(peaked)
@@ -21,13 +21,13 @@ func TestPeak_LimitReached(t *testing.T) {
 		t.Fatal(err)
 	}
 	require.Equal(t, []byte("1234567890"), all)
-	require.Equal(t, []byte("12345"), peaked.PeakedBytes)
+	require.Equal(t, []byte("12345"), peaked.PeekedBytes)
 	require.Equal(t, true, peaked.LimitReached)
 }
 
 func TestPeak_LimitNotReached(t *testing.T) {
 	underlying := io.NopCloser(strings.NewReader("1234567890"))
-	peaked, err := Peak(underlying, 15)
+	peaked, err := Peek(underlying, 15)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -36,12 +36,12 @@ func TestPeak_LimitNotReached(t *testing.T) {
 		t.Fatal(err)
 	}
 	require.Equal(t, []byte("1234567890"), all)
-	require.Equal(t, []byte("1234567890"), peaked.PeakedBytes)
+	require.Equal(t, []byte("1234567890"), peaked.PeekedBytes)
 	require.Equal(t, false, peaked.LimitReached)
 }
 
 func TestPeak_Nil(t *testing.T) {
-	peaked, err := Peak(nil, 15)
+	peaked, err := Peek(nil, 15)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -50,6 +50,6 @@ func TestPeak_Nil(t *testing.T) {
 		t.Fatal(err)
 	}
 	require.Equal(t, []byte(""), all)
-	require.Equal(t, []byte(""), peaked.PeakedBytes)
+	require.Equal(t, []byte(""), peaked.PeekedBytes)
 	require.Equal(t, false, peaked.LimitReached)
 }

+ 11 - 1
web/src/app/Api.js

@@ -7,7 +7,7 @@ import {
     topicUrl,
     topicUrlAuth,
     topicUrlJsonPoll,
-    topicUrlJsonPollWithSince
+    topicUrlJsonPollWithSince, userStatsUrl
 } from "./utils";
 import userManager from "./UserManager";
 
@@ -93,6 +93,16 @@ class Api {
         }
         throw new Error(`Unexpected server response ${response.status}`);
     }
+
+    async userStats(baseUrl) {
+        const url = userStatsUrl(baseUrl);
+        console.log(`[Api] Fetching user stats ${url}`);
+        const response = await fetch(url);
+        if (response.status !== 200) {
+            throw new Error(`Unexpected server response ${response.status}`);
+        }
+        return response.json();
+    }
 }
 
 const api = new Api();

Некоторые файлы не были показаны из-за большого количества измененных файлов