util.go 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
  1. package util
  2. import (
  3. "encoding/base64"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "math/rand"
  9. "net/netip"
  10. "os"
  11. "regexp"
  12. "strconv"
  13. "strings"
  14. "sync"
  15. "time"
  16. "github.com/gabriel-vasile/mimetype"
  17. "golang.org/x/term"
  18. )
  // randomStringCharset is the alphabet RandomString draws from and
// ValidRandomString checks against: ASCII letters and digits only.
const randomStringCharset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"

// random is a package-wide pseudo-random source, seeded from the current time.
// A *rand.Rand is not safe for concurrent use, so every access to random must
// happen while holding randomMutex.
var (
	random      = rand.New(rand.NewSource(time.Now().UnixNano()))
	randomMutex sync.Mutex
)

// Regular expressions compiled once at package initialization.
var (
	// sizeStrRegex matches a decimal number followed by an optional,
	// case-insensitive unit letter (g, m, k or b); see ParseSize.
	sizeStrRegex = regexp.MustCompile(`(?i)^(\d+)([gmkb])?$`)

	// noQuotesRegex matches arguments QuoteCommand may print without quotes.
	noQuotesRegex = regexp.MustCompile(`^[-_./:@a-zA-Z0-9]+$`)
)

// errInvalidPriority is returned by ParsePriority and PriorityString for
// values they do not recognize.
var errInvalidPriority = errors.New("invalid priority")
  // Sentinel errors returned by UnmarshalJSON and UnmarshalJSONWithLimit, so
// callers can tell a malformed body apart from an oversized one.
var (
	// ErrUnmarshalJSON is returned when the body cannot be decoded as JSON
	// into the requested type.
	ErrUnmarshalJSON = errors.New("unmarshalling JSON failed")

	// ErrTooLargeJSON is returned by UnmarshalJSONWithLimit when the body
	// reaches the given limit.
	ErrTooLargeJSON = errors.New("too large JSON")
)
  34. // FileExists checks if a file exists, and returns true if it does
  35. func FileExists(filename string) bool {
  36. stat, _ := os.Stat(filename)
  37. return stat != nil
  38. }
  // Contains reports whether needle equals (==) any element of haystack.
// A nil or empty haystack contains nothing.
func Contains[T comparable](haystack []T, needle T) bool {
	for i := range haystack {
		if haystack[i] == needle {
			return true
		}
	}
	return false
}
  48. // ContainsIP returns true if any one of the of prefixes contains the ip.
  49. func ContainsIP(haystack []netip.Prefix, needle netip.Addr) bool {
  50. for _, s := range haystack {
  51. if s.Contains(needle) {
  52. return true
  53. }
  54. }
  55. return false
  56. }
  // ContainsAll reports whether every element of needles also occurs in
// haystack (compared with ==). Each needle only has to appear in haystack at
// least once; duplicates on either side do not affect the result. An empty
// needles slice is trivially contained, so the result is then true.
func ContainsAll[T comparable](haystack []T, needles []T) bool {
	for _, needle := range needles {
		found := false
		for _, s := range haystack {
			if s == needle {
				found = true
				break
			}
		}
		// One missing needle is enough to fail. Counting total matches instead
		// would let duplicate haystack entries stand in for absent needles.
		if !found {
			return false
		}
	}
	return true
}
  // SplitNoEmpty splits s around each occurrence of sep, as strings.Split does,
// and drops the empty pieces. The returned slice is never nil. When nothing
// remains, it is empty but non-nil.
func SplitNoEmpty(s string, sep string) []string {
	parts := strings.Split(s, sep)
	kept := make([]string, 0, len(parts))
	for _, part := range parts {
		if part == "" {
			continue
		}
		kept = append(kept, part)
	}
	return kept
}
  // SplitKV splits s into a key and a value at the first occurrence of sep.
// Whitespace around s is trimmed before splitting, and around both halves
// afterwards. If sep does not occur, the key is empty and the whole trimmed
// string is returned as the value.
func SplitKV(s string, sep string) (key string, value string) {
	parts := strings.SplitN(strings.TrimSpace(s), sep, 2)
	switch len(parts) {
	case 2:
		key, value = strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1])
	default:
		value = strings.TrimSpace(parts[0])
	}
	return key, value
}
  // LastString returns the final element of s, or def when s is nil or empty.
func LastString(s []string, def string) string {
	if n := len(s); n > 0 {
		return s[n-1]
	}
	return def
}
  95. // RandomString returns a random string with a given length
  96. func RandomString(length int) string {
  97. randomMutex.Lock() // Who would have thought that random.Intn() is not thread-safe?!
  98. defer randomMutex.Unlock()
  99. b := make([]byte, length)
  100. for i := range b {
  101. b[i] = randomStringCharset[random.Intn(len(randomStringCharset))]
  102. }
  103. return string(b)
  104. }
  105. // ValidRandomString returns true if the given string matches the format created by RandomString
  106. func ValidRandomString(s string, length int) bool {
  107. if len(s) != length {
  108. return false
  109. }
  110. for _, c := range strings.Split(s, "") {
  111. if !strings.Contains(randomStringCharset, c) {
  112. return false
  113. }
  114. }
  115. return true
  116. }
  117. // ParsePriority parses a priority string into its equivalent integer value
  118. func ParsePriority(priority string) (int, error) {
  119. p := strings.TrimSpace(strings.ToLower(priority))
  120. switch p {
  121. case "":
  122. return 0, nil
  123. case "1", "min":
  124. return 1, nil
  125. case "2", "low":
  126. return 2, nil
  127. case "3", "default":
  128. return 3, nil
  129. case "4", "high":
  130. return 4, nil
  131. case "5", "max", "urgent":
  132. return 5, nil
  133. default:
  134. // Ignore new HTTP Priority header (see https://datatracker.ietf.org/doc/html/draft-ietf-httpbis-priority)
  135. // Cloudflare adds this to requests when forwarding to the backend (ntfy), so we just ignore it.
  136. if strings.HasPrefix(p, "u=") {
  137. return 3, nil
  138. }
  139. return 0, errInvalidPriority
  140. }
  141. }
  142. // PriorityString converts a priority number to a string
  143. func PriorityString(priority int) (string, error) {
  144. switch priority {
  145. case 0:
  146. return "default", nil
  147. case 1:
  148. return "min", nil
  149. case 2:
  150. return "low", nil
  151. case 3:
  152. return "default", nil
  153. case 4:
  154. return "high", nil
  155. case 5:
  156. return "max", nil
  157. default:
  158. return "", errInvalidPriority
  159. }
  160. }
  // ShortTopicURL makes a topic URL human-friendly. It strips a leading
// "https://" and then a leading "http://". Strings without these prefixes are
// returned unchanged.
func ShortTopicURL(s string) string {
	withoutHTTPS := strings.TrimPrefix(s, "https://")
	return strings.TrimPrefix(withoutHTTPS, "http://")
}
  165. // DetectContentType probes the byte array b and returns mime type and file extension.
  166. // The filename is only used to override certain special cases.
  167. func DetectContentType(b []byte, filename string) (mimeType string, ext string) {
  168. if strings.HasSuffix(strings.ToLower(filename), ".apk") {
  169. return "application/vnd.android.package-archive", ".apk"
  170. }
  171. m := mimetype.Detect(b)
  172. mimeType, ext = m.String(), m.Extension()
  173. if ext == "" {
  174. ext = ".bin"
  175. }
  176. return
  177. }
  178. // ParseSize parses a size string like 2K or 2M into bytes. If no unit is found, e.g. 123, bytes is assumed.
  179. func ParseSize(s string) (int64, error) {
  180. matches := sizeStrRegex.FindStringSubmatch(s)
  181. if matches == nil {
  182. return -1, fmt.Errorf("invalid size %s", s)
  183. }
  184. value, err := strconv.Atoi(matches[1])
  185. if err != nil {
  186. return -1, fmt.Errorf("cannot convert number %s", matches[1])
  187. }
  188. switch strings.ToUpper(matches[2]) {
  189. case "G":
  190. return int64(value) * 1024 * 1024 * 1024, nil
  191. case "M":
  192. return int64(value) * 1024 * 1024, nil
  193. case "K":
  194. return int64(value) * 1024, nil
  195. default:
  196. return int64(value), nil
  197. }
  198. }
  199. // ReadPassword will read a password from STDIN. If the terminal supports it, it will not print the
  200. // input characters to the screen. If not, it'll just read using normal readline semantics (useful for testing).
  201. func ReadPassword(in io.Reader) ([]byte, error) {
  202. // If in is a file and a character device (a TTY), use term.ReadPassword
  203. if f, ok := in.(*os.File); ok {
  204. stat, err := f.Stat()
  205. if err != nil {
  206. return nil, err
  207. }
  208. if (stat.Mode() & os.ModeCharDevice) == os.ModeCharDevice {
  209. password, err := term.ReadPassword(int(f.Fd())) // This is always going to be 0
  210. if err != nil {
  211. return nil, err
  212. }
  213. return password, nil
  214. }
  215. }
  216. // Fallback: Manually read util \n if found, see #69 for details why this is so manual
  217. password := make([]byte, 0)
  218. buf := make([]byte, 1)
  219. for {
  220. _, err := in.Read(buf)
  221. if err == io.EOF || buf[0] == '\n' {
  222. break
  223. } else if err != nil {
  224. return nil, err
  225. } else if len(password) > 10240 {
  226. return nil, errors.New("passwords this long are not supported")
  227. }
  228. password = append(password, buf[0])
  229. }
  230. return password, nil
  231. }
  // BasicAuth builds the value of an Authorization header for HTTP Basic
// authentication: "Basic " followed by the standard base64 encoding of
// "user:pass".
func BasicAuth(user, pass string) string {
	credentials := fmt.Sprintf("%s:%s", user, pass)
	return "Basic " + base64.StdEncoding.EncodeToString([]byte(credentials))
}
  // BearerAuth builds the value of an Authorization header for bearer token
// authentication: "Bearer " followed by token, used verbatim.
func BearerAuth(token string) string {
	return fmt.Sprint("Bearer ", token)
}
  // MaybeMarshalJSON returns v rendered as indented JSON, or "<cannot serialize>"
// if json.MarshalIndent fails. It is meant for log output, where a failure
// doesn't matter much.
//
// Output longer than 5000 bytes is truncated. The cut is moved back to the
// start of a UTF-8 sequence, so the result never ends in a partial
// multi-byte character.
func MaybeMarshalJSON(v any) string {
	const maxLen = 5000
	jsonBytes, err := json.MarshalIndent(v, "", " ")
	if err != nil {
		return "<cannot serialize>"
	}
	if len(jsonBytes) > maxLen {
		cut := maxLen
		// Bytes of the form 10xxxxxx are UTF-8 continuation bytes. Back up
		// until the first excluded byte starts a new character.
		for cut > 0 && jsonBytes[cut]&0xC0 == 0x80 {
			cut--
		}
		jsonBytes = jsonBytes[:cut]
	}
	return string(jsonBytes)
}
  252. // QuoteCommand combines a command array to a string, quoting arguments that need quoting.
  253. // This function is naive, and sometimes wrong. It is only meant for lo pretty-printing a command.
  254. //
  255. // Warning: Never use this function with the intent to run the resulting command.
  256. //
  257. // Example:
  258. //
  259. // []string{"ls", "-al", "Document Folder"} -> ls -al "Document Folder"
  260. func QuoteCommand(command []string) string {
  261. var quoted []string
  262. for _, c := range command {
  263. if noQuotesRegex.MatchString(c) {
  264. quoted = append(quoted, c)
  265. } else {
  266. quoted = append(quoted, fmt.Sprintf(`"%s"`, c))
  267. }
  268. }
  269. return strings.Join(quoted, " ")
  270. }
  271. // UnmarshalJSON reads the given io.ReadCloser into a struct
  272. func UnmarshalJSON[T any](body io.ReadCloser) (*T, error) {
  273. var obj T
  274. if err := json.NewDecoder(body).Decode(&obj); err != nil {
  275. return nil, ErrUnmarshalJSON
  276. }
  277. return &obj, nil
  278. }
  279. // UnmarshalJSONWithLimit reads the given io.ReadCloser into a struct, but only until limit is reached
  280. func UnmarshalJSONWithLimit[T any](r io.ReadCloser, limit int) (*T, error) {
  281. defer r.Close()
  282. p, err := Peek(r, limit)
  283. if err != nil {
  284. return nil, err
  285. } else if p.LimitReached {
  286. return nil, ErrTooLargeJSON
  287. }
  288. var obj T
  289. if err := json.NewDecoder(p).Decode(&obj); err != nil {
  290. return nil, ErrUnmarshalJSON
  291. }
  292. return &obj, nil
  293. }