Philipp Heckel 2 лет назад
Родитель
Сommit
cae06c5c61
7 измененных файлов с 78 добавлено и 15 удалено
  1. 4 0
      client/options.go
  2. 30 2
      cmd/publish.go
  3. 8 8
      crypto/crypto.go
  4. 11 4
      crypto/crypto_test.go
  5. 8 0
      server/server.go
  6. 7 1
      server/types.go
  7. 10 0
      util/peek.go

+ 4 - 0
client/options.go

@@ -92,6 +92,10 @@ func WithNoFirebase() PublishOption {
 	return WithHeader("X-Firebase", "no")
 }
 
+func WithEncrypted() PublishOption {
+	return WithHeader("X-Encryption", "jwe")
+}
+
 // WithSince limits the number of messages returned from the server. The parameter since can be a Unix
 // timestamp (see WithSinceUnixTime), a duration (WithSinceDuration) the word "all" (see WithSinceAll).
 func WithSince(since string) SubscribeOption {

+ 30 - 2
cmd/publish.go

@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"github.com/urfave/cli/v2"
 	"heckel.io/ntfy/client"
+	"heckel.io/ntfy/crypto"
 	"heckel.io/ntfy/log"
 	"heckel.io/ntfy/util"
 	"io"
@@ -15,6 +16,10 @@ import (
 	"time"
 )
 
+const (
+	encryptedMessageBytesLimit = 100 * 1024 * 1024 // 100 MB
+)
+
 func init() {
 	commands = append(commands, cmdPublish)
 }
@@ -100,7 +105,7 @@ func execPublish(c *cli.Context) error {
 	noFirebase := c.Bool("no-firebase")
 	quiet := c.Bool("quiet")
 	pid := c.Int("wait-pid")
-	//password := os.Getenv("NTFY_PASSWORD")
+	password := os.Getenv("NTFY_PASSWORD")
 	topic, message, command, err := parseTopicMessageCommand(c)
 	if err != nil {
 		return err
@@ -193,6 +198,20 @@ func execPublish(c *cli.Context) error {
 			}
 		}
 	}
+	if password != "" {
+		topicURL := expandTopicURL(topic, conf.DefaultHost)
+		key := crypto.DeriveKey(password, topicURL)
+		peaked, err := util.PeekLimit(io.NopCloser(body), encryptedMessageBytesLimit)
+		if err != nil {
+			return err
+		}
+		ciphertext, err := crypto.Encrypt(peaked.PeekedBytes, key)
+		if err != nil {
+			return err
+		}
+		body = strings.NewReader(ciphertext)
+		options = append(options, client.WithEncrypted())
+	}
 	cl := client.New(conf)
 	m, err := cl.PublishReader(topic, body, options...)
 	if err != nil {
@@ -204,8 +223,17 @@ func execPublish(c *cli.Context) error {
 	return nil
 }
 
-// parseTopicMessageCommand reads the topic and the remaining arguments from the context.
+func expandTopicURL(topic, defaultHost string) string {
+	if strings.HasPrefix(topic, "http://") || strings.HasPrefix(topic, "https://") {
+		return topic
+	} else if strings.Contains(topic, "/") {
+		return fmt.Sprintf("https://%s", topic)
+	}
+	return fmt.Sprintf("%s/%s", defaultHost, topic)
+}
 
+// parseTopicMessageCommand reads the topic and the remaining arguments from the context.
+//
 // There are a few cases to consider:
 //   ntfy publish <topic> [<message>]
 //   ntfy publish --wait-cmd <topic> <command>

+ 8 - 8
crypto/crypto.go

@@ -13,31 +13,31 @@ const (
 	keyDerivIter  = 50000
 )
 
-func DeriveKey(password string, topicURL string) []byte {
+func DeriveKey(password, topicURL string) []byte {
 	salt := sha256.Sum256([]byte(topicURL))
 	return pbkdf2.Key([]byte(password), salt[:], keyDerivIter, keyLenBytes, sha256.New)
 }
 
-func Encrypt(plaintext string, key []byte) (string, error) {
+func Encrypt(plaintext []byte, key []byte) (string, error) {
 	enc, err := jose.NewEncrypter(jweEncryption, jose.Recipient{Algorithm: jweAlgorithm, Key: key}, nil)
 	if err != nil {
 		return "", err
 	}
-	jwe, err := enc.Encrypt([]byte(plaintext))
+	jwe, err := enc.Encrypt(plaintext)
 	if err != nil {
 		return "", err
 	}
 	return jwe.CompactSerialize()
 }
 
-func Decrypt(input string, key []byte) (string, error) {
-	jwe, err := jose.ParseEncrypted(input)
+func Decrypt(ciphertext string, key []byte) ([]byte, error) {
+	jwe, err := jose.ParseEncrypted(ciphertext)
 	if err != nil {
-		return "", err
+		return nil, err
 	}
 	out, err := jwe.Decrypt(key)
 	if err != nil {
-		return "", err
+		return nil, err
 	}
-	return string(out), nil
+	return out, nil
 }

+ 11 - 4
crypto/crypto_test.go

@@ -1,25 +1,32 @@
 package crypto
 
 import (
+	"fmt"
 	"github.com/stretchr/testify/require"
 	"testing"
 )
 
+func TestDeriveKey(t *testing.T) {
+	key := DeriveKey("secr3t password", "https://ntfy.sh/mysecret")
+	require.Equal(t, "30b7e72f6273da6e59d2dec535466e548da3eafc98650c9664c06edab707fa25", fmt.Sprintf("%x", key))
+}
+
 func TestEncryptDecrypt(t *testing.T) {
 	message := "this is a message or is it?"
-	ciphertext, err := Encrypt(message, []byte("AES256Key-32Characters1234567890"))
+	ciphertext, err := Encrypt([]byte(message), []byte("AES256Key-32Characters1234567890"))
 	require.Nil(t, err)
 	plaintext, err := Decrypt(ciphertext, []byte("AES256Key-32Characters1234567890"))
 	require.Nil(t, err)
-	require.Equal(t, message, plaintext)
+	require.Equal(t, message, string(plaintext))
 }
 
 func TestEncryptDecrypt_FromPHP(t *testing.T) {
 	ciphertext := "eyJhbGciOiJkaXIiLCJlbmMiOiJBMjU2R0NNIn0..vbe1Qv_-mKYbUgce.EfmOUIUi7lxXZG_o4bqXZ9pmpr1Rzs4Y5QLE2XD2_aw_SQ.y2hadrN5b2LEw7_PJHhbcA"
 	key := DeriveKey("secr3t password", "https://ntfy.sh/mysecret")
+	fmt.Printf("%x", key)
 	plaintext, err := Decrypt(ciphertext, key)
 	require.Nil(t, err)
-	require.Equal(t, `{"message":"Secret!","priority":5}`, plaintext)
+	require.Equal(t, `{"message":"Secret!","priority":5}`, string(plaintext))
 }
 
 func TestEncryptDecrypt_FromPython(t *testing.T) {
@@ -27,5 +34,5 @@ func TestEncryptDecrypt_FromPython(t *testing.T) {
 	key := DeriveKey("secr3t password", "https://ntfy.sh/mysecret")
 	plaintext, err := Decrypt(ciphertext, key)
 	require.Nil(t, err)
-	require.Equal(t, `{"message":"Python says hi","tags":["secret"]}`, plaintext)
+	require.Equal(t, `{"message":"Python says hi","tags":["secret"]}`, string(plaintext))
 }

+ 8 - 0
server/server.go

@@ -95,6 +95,7 @@ const (
 	newMessageBody           = "New message"             // Used in poll requests as generic message
 	defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment
 	encodingBase64           = "base64"
+	encodingJWE              = "jwe"
 )
 
 // WebSocket constants
@@ -461,6 +462,9 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
 	if m.PollID != "" {
 		m = newPollRequestMessage(t.ID, m.PollID)
 	}
+	if m.Encoding == encodingJWE {
+		m = newEncryptedMessage(t.ID, m.Message)
+	}
 	if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil {
 		return nil, err
 	}
@@ -644,6 +648,10 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca
 			return false, false, "", false, wrapErrHTTP(errHTTPBadRequestActionsInvalid, err.Error())
 		}
 	}
+	encryption := readParam(r, "x-encryption", "encryption", "encrypted", "encrypt", "enc")
+	if encryption == "yes" || encryption == "true" || encryption == "1" || encryption == encodingJWE {
+		m.Encoding = encodingJWE
+	}
 	unifiedpush = readBoolParam(r, false, "x-unifiedpush", "unifiedpush", "up") // see GET too!
 	if unifiedpush {
 		firebase = false

+ 7 - 1
server/types.go

@@ -33,7 +33,7 @@ type message struct {
 	Attachment *attachment `json:"attachment,omitempty"`
 	PollID     string      `json:"poll_id,omitempty"`
 	Sender     string      `json:"-"`                  // IP address of uploader, used for rate limiting
-	Encoding   string      `json:"encoding,omitempty"` // empty for raw UTF-8, or "base64" for encoded bytes
+	Encoding   string      `json:"encoding,omitempty"` // empty for UTF-8, "base64", or "jwe" (encrypted)
 }
 
 type attachment struct {
@@ -115,6 +115,12 @@ func newPollRequestMessage(topic, pollID string) *message {
 	return m
 }
 
+func newEncryptedMessage(topic, msg string) *message {
+	m := newMessage(messageEvent, topic, msg)
+	m.Encoding = encodingJWE
+	return m
+}
+
 func validMessageID(s string) bool {
 	return util.ValidRandomString(s, messageIDLength)
 }

+ 10 - 0
util/peek.go

@@ -38,6 +38,16 @@ func Peek(underlying io.ReadCloser, limit int) (*PeekedReadCloser, error) {
 	}, nil
 }
 
+func PeekLimit(underlying io.ReadCloser, limit int) (*PeekedReadCloser, error) {
+	rc, err := Peek(underlying, limit)
+	if err != nil {
+		return nil, err
+	} else if rc.LimitReached {
+		return nil, ErrLimitReached
+	}
+	return rc, nil
+}
+
 // Read reads from the peeked bytes and then from the underlying stream
 func (r *PeekedReadCloser) Read(p []byte) (n int, err error) {
 	if r.closed {