util.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390
  1. package util
  2. import (
  3. "bytes"
  4. "encoding/base64"
  5. "encoding/json"
  6. "errors"
  7. "fmt"
  8. "golang.org/x/time/rate"
  9. "io"
  10. "math/rand"
  11. "net/netip"
  12. "os"
  13. "regexp"
  14. "strconv"
  15. "strings"
  16. "sync"
  17. "time"
  18. "github.com/gabriel-vasile/mimetype"
  19. "golang.org/x/term"
  20. )
  // randomStringCharset is the alphabet that RandomStringPrefix draws from and that
// ValidRandomString accepts: ASCII lowercase letters, uppercase letters and digits
// (62 single-byte characters).
const randomStringCharset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
  // Package-level state shared by the helpers in this file.
var (
	// random is the pseudo-random source used by RandomStringPrefix. A *rand.Rand built
	// from rand.NewSource is not safe for concurrent use, so callers must hold randomMutex.
	random      = rand.New(rand.NewSource(time.Now().UnixNano()))
	randomMutex sync.Mutex

	// sizeStrRegex matches size strings such as "10", "2k" or "1G" (case-insensitive).
	// Group 1 is the number, group 2 the optional unit letter (g, m, k or b).
	sizeStrRegex = regexp.MustCompile(`(?i)^(\d+)([gmkb])?$`)

	// noQuotesRegex matches command arguments that QuoteCommand can print without quotes.
	noQuotesRegex = regexp.MustCompile(`^[-_./:@a-zA-Z0-9]+$`)

	// errInvalidPriority is returned by ParsePriority and PriorityString for unknown values.
	errInvalidPriority = errors.New("invalid priority")
)
  // Sentinel errors returned by UnmarshalJSON and UnmarshalJSONWithLimit. Compare against
// them with errors.Is.
var (
	// ErrUnmarshalJSON is returned when the body cannot be decoded as JSON into the target type.
	ErrUnmarshalJSON = errors.New("unmarshalling JSON failed")

	// ErrTooLargeJSON is returned by UnmarshalJSONWithLimit when the body reaches the byte limit.
	ErrTooLargeJSON = errors.New("too large JSON")
)
  // FileExists reports whether os.Stat succeeds for filename. Any Stat error, including
// a permission error rather than a missing file, yields false. Directories count as existing.
func FileExists(filename string) bool {
	if _, err := os.Stat(filename); err != nil {
		return false
	}
	return true
}
  // Contains reports whether needle equals at least one element of haystack.
// A nil or empty haystack contains nothing.
func Contains[T comparable](haystack []T, needle T) bool {
	for i := range haystack {
		if haystack[i] == needle {
			return true
		}
	}
	return false
}
  // ContainsIP reports whether needle lies inside at least one of the prefixes in haystack.
// A nil or empty haystack matches no address.
func ContainsIP(haystack []netip.Prefix, needle netip.Addr) bool {
	found := false
	for i := 0; i < len(haystack) && !found; i++ {
		found = haystack[i].Contains(needle)
	}
	return found
}
  // ContainsAll reports whether every element of needles occurs at least once in haystack.
// Duplicates in either slice do not affect the result. An empty needles slice is trivially
// contained in any haystack.
//
// Each needle is looked up individually. Counting matches across both slices is wrong:
// duplicates in haystack inflate the count and can make a missing needle look present.
func ContainsAll[T comparable](haystack []T, needles []T) bool {
	for _, needle := range needles {
		found := false
		for _, s := range haystack {
			if s == needle {
				found = true
				break
			}
		}
		if !found {
			return false
		}
	}
	return true
}
  // SplitNoEmpty splits s around each occurrence of sep, as strings.Split does, and drops
// empty fields. The result is never nil; it is an empty slice when no fields remain.
func SplitNoEmpty(s string, sep string) []string {
	parts := strings.Split(s, sep)
	out := make([]string, 0, len(parts))
	for _, part := range parts {
		if part == "" {
			continue
		}
		out = append(out, part)
	}
	return out
}
  // SplitKV splits s at the first occurrence of sep into a key and a value. Surrounding
// whitespace is trimmed from s and from both parts. When sep does not occur, key is ""
// and value is the whole trimmed string.
func SplitKV(s string, sep string) (key string, value string) {
	parts := strings.SplitN(strings.TrimSpace(s), sep, 2)
	switch len(parts) {
	case 2:
		key, value = strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1])
	default:
		value = strings.TrimSpace(parts[0])
	}
	return key, value
}
  // LastString returns the final element of s, or def when s is nil or empty.
func LastString(s []string, def string) string {
	if n := len(s); n > 0 {
		return s[n-1]
	}
	return def
}
  97. // RandomString returns a random string with a given length
  98. func RandomString(length int) string {
  99. return RandomStringPrefix("", length)
  100. }
  101. // RandomStringPrefix returns a random string with a given length, with a prefix
  102. func RandomStringPrefix(prefix string, length int) string {
  103. randomMutex.Lock() // Who would have thought that random.Intn() is not thread-safe?!
  104. defer randomMutex.Unlock()
  105. b := make([]byte, length-len(prefix))
  106. for i := range b {
  107. b[i] = randomStringCharset[random.Intn(len(randomStringCharset))]
  108. }
  109. return prefix + string(b)
  110. }
  111. // ValidRandomString returns true if the given string matches the format created by RandomString
  112. func ValidRandomString(s string, length int) bool {
  113. if len(s) != length {
  114. return false
  115. }
  116. for _, c := range strings.Split(s, "") {
  117. if !strings.Contains(randomStringCharset, c) {
  118. return false
  119. }
  120. }
  121. return true
  122. }
  123. // ParsePriority parses a priority string into its equivalent integer value
  124. func ParsePriority(priority string) (int, error) {
  125. p := strings.TrimSpace(strings.ToLower(priority))
  126. switch p {
  127. case "":
  128. return 0, nil
  129. case "1", "min":
  130. return 1, nil
  131. case "2", "low":
  132. return 2, nil
  133. case "3", "default":
  134. return 3, nil
  135. case "4", "high":
  136. return 4, nil
  137. case "5", "max", "urgent":
  138. return 5, nil
  139. default:
  140. // Ignore new HTTP Priority header (see https://datatracker.ietf.org/doc/html/draft-ietf-httpbis-priority)
  141. // Cloudflare adds this to requests when forwarding to the backend (ntfy), so we just ignore it.
  142. if strings.HasPrefix(p, "u=") {
  143. return 3, nil
  144. }
  145. return 0, errInvalidPriority
  146. }
  147. }
  148. // PriorityString converts a priority number to a string
  149. func PriorityString(priority int) (string, error) {
  150. switch priority {
  151. case 0:
  152. return "default", nil
  153. case 1:
  154. return "min", nil
  155. case 2:
  156. return "low", nil
  157. case 3:
  158. return "default", nil
  159. case 4:
  160. return "high", nil
  161. case 5:
  162. return "max", nil
  163. default:
  164. return "", errInvalidPriority
  165. }
  166. }
  // ShortTopicURL makes a topic URL human-friendly by stripping a leading "https://" and
// then a leading "http://". Other schemes are left untouched.
func ShortTopicURL(s string) string {
	s = strings.TrimPrefix(s, "https://")
	return strings.TrimPrefix(s, "http://")
}
  171. // DetectContentType probes the byte array b and returns mime type and file extension.
  172. // The filename is only used to override certain special cases.
  173. func DetectContentType(b []byte, filename string) (mimeType string, ext string) {
  174. if strings.HasSuffix(strings.ToLower(filename), ".apk") {
  175. return "application/vnd.android.package-archive", ".apk"
  176. }
  177. m := mimetype.Detect(b)
  178. mimeType, ext = m.String(), m.Extension()
  179. if ext == "" {
  180. ext = ".bin"
  181. }
  182. return
  183. }
  184. // ParseSize parses a size string like 2K or 2M into bytes. If no unit is found, e.g. 123, bytes is assumed.
  185. func ParseSize(s string) (int64, error) {
  186. matches := sizeStrRegex.FindStringSubmatch(s)
  187. if matches == nil {
  188. return -1, fmt.Errorf("invalid size %s", s)
  189. }
  190. value, err := strconv.Atoi(matches[1])
  191. if err != nil {
  192. return -1, fmt.Errorf("cannot convert number %s", matches[1])
  193. }
  194. switch strings.ToUpper(matches[2]) {
  195. case "G":
  196. return int64(value) * 1024 * 1024 * 1024, nil
  197. case "M":
  198. return int64(value) * 1024 * 1024, nil
  199. case "K":
  200. return int64(value) * 1024, nil
  201. default:
  202. return int64(value), nil
  203. }
  204. }
  // FormatSize renders a byte count in human-readable binary units with one decimal, e.g.
// "1.5 KB" or "2.0 MB". Values below 1024, including negative ones, are printed as
// "<n> bytes".
func FormatSize(b int64) string {
	const unit = 1024
	if b < unit {
		return fmt.Sprintf("%d bytes", b)
	}
	const prefixes = "KMGTPE"
	divisor := float64(unit)
	idx := 0
	for rest := b / unit; rest >= unit; rest /= unit {
		divisor *= unit
		idx++
	}
	return fmt.Sprintf("%.1f %cB", float64(b)/divisor, prefixes[idx])
}
  218. // ReadPassword will read a password from STDIN. If the terminal supports it, it will not print the
  219. // input characters to the screen. If not, it'll just read using normal readline semantics (useful for testing).
  220. func ReadPassword(in io.Reader) ([]byte, error) {
  221. // If in is a file and a character device (a TTY), use term.ReadPassword
  222. if f, ok := in.(*os.File); ok {
  223. stat, err := f.Stat()
  224. if err != nil {
  225. return nil, err
  226. }
  227. if (stat.Mode() & os.ModeCharDevice) == os.ModeCharDevice {
  228. password, err := term.ReadPassword(int(f.Fd())) // This is always going to be 0
  229. if err != nil {
  230. return nil, err
  231. }
  232. return password, nil
  233. }
  234. }
  235. // Fallback: Manually read util \n if found, see #69 for details why this is so manual
  236. password := make([]byte, 0)
  237. buf := make([]byte, 1)
  238. for {
  239. _, err := in.Read(buf)
  240. if err == io.EOF || buf[0] == '\n' {
  241. break
  242. } else if err != nil {
  243. return nil, err
  244. } else if len(password) > 10240 {
  245. return nil, errors.New("passwords this long are not supported")
  246. }
  247. password = append(password, buf[0])
  248. }
  249. return password, nil
  250. }
  // BasicAuth builds the Authorization header value for HTTP basic auth: "Basic " followed by
// the standard base64 encoding of "user:pass".
func BasicAuth(user, pass string) string {
	credentials := user + ":" + pass
	return "Basic " + base64.StdEncoding.EncodeToString([]byte(credentials))
}
  // BearerAuth builds the Authorization header value for bearer/token auth ("Bearer <token>").
// The token is inserted verbatim.
func BearerAuth(token string) string {
	return "Bearer " + token
}
  // MaybeMarshalJSON returns an indented JSON string of v, or "<cannot serialize>" if
// serialization fails. This is useful for logging, where a failure doesn't matter much.
//
// Output longer than 5000 bytes is truncated. The cut is moved back to a UTF-8 character
// boundary so the result never ends in a partial, invalid multi-byte character.
func MaybeMarshalJSON(v any) string {
	const maxLen = 5000
	jsonBytes, err := json.MarshalIndent(v, "", " ")
	if err != nil {
		return "<cannot serialize>"
	}
	if len(jsonBytes) > maxLen {
		cut := maxLen
		// Bytes of the form 0b10xxxxxx are UTF-8 continuation bytes (see utf8.RuneStart).
		// Back up until jsonBytes[cut] starts a character, so nothing is split.
		for cut > 0 && jsonBytes[cut]&0xC0 == 0x80 {
			cut--
		}
		jsonBytes = jsonBytes[:cut]
	}
	return string(jsonBytes)
}
  271. // QuoteCommand combines a command array to a string, quoting arguments that need quoting.
  272. // This function is naive, and sometimes wrong. It is only meant for lo pretty-printing a command.
  273. //
  274. // Warning: Never use this function with the intent to run the resulting command.
  275. //
  276. // Example:
  277. //
  278. // []string{"ls", "-al", "Document Folder"} -> ls -al "Document Folder"
  279. func QuoteCommand(command []string) string {
  280. var quoted []string
  281. for _, c := range command {
  282. if noQuotesRegex.MatchString(c) {
  283. quoted = append(quoted, c)
  284. } else {
  285. quoted = append(quoted, fmt.Sprintf(`"%s"`, c))
  286. }
  287. }
  288. return strings.Join(quoted, " ")
  289. }
  290. // UnmarshalJSON reads the given io.ReadCloser into a struct
  291. func UnmarshalJSON[T any](body io.ReadCloser) (*T, error) {
  292. var obj T
  293. if err := json.NewDecoder(body).Decode(&obj); err != nil {
  294. return nil, ErrUnmarshalJSON
  295. }
  296. return &obj, nil
  297. }
  298. // UnmarshalJSONWithLimit reads the given io.ReadCloser into a struct, but only until limit is reached
  299. func UnmarshalJSONWithLimit[T any](r io.ReadCloser, limit int, allowEmpty bool) (*T, error) {
  300. defer r.Close()
  301. p, err := Peek(r, limit)
  302. if err != nil {
  303. return nil, err
  304. } else if p.LimitReached {
  305. return nil, ErrTooLargeJSON
  306. }
  307. var obj T
  308. if len(bytes.TrimSpace(p.PeekedBytes)) == 0 && allowEmpty {
  309. return &obj, nil
  310. } else if err := json.NewDecoder(p).Decode(&obj); err != nil {
  311. return nil, ErrUnmarshalJSON
  312. }
  313. return &obj, nil
  314. }
  // Retry calls f and returns its result as soon as a call succeeds. After each failed attempt
// it sleeps for the next duration in after and then tries again, so f is called at most
// len(after)+1 times. If every attempt fails, Retry returns nil and the last error.
//
// f is always called at least once, so Retry(f) with no delays never returns (nil, nil).
// Retry never sleeps without retrying afterwards.
func Retry[T any](f func() (*T, error), after ...time.Duration) (t *T, err error) {
	if t, err = f(); err == nil {
		return t, nil
	}
	for _, delay := range after {
		time.Sleep(delay)
		if t, err = f(); err == nil {
			return t, nil
		}
	}
	return nil, err
}
  // MinMax clamps value to the range [min, max]: it returns min when value is below min, max
// when value is above max, and value otherwise. The min check takes precedence, so if
// min > max, any value below min yields min.
func MinMax[T int | int64](value, min, max T) T {
	switch {
	case value < min:
		return min
	case value > max:
		return max
	default:
		return value
	}
}
  336. // Max returns the maximum value of the two given values
  337. func Max[T int | int64 | rate.Limit](a, b T) T {
  338. if a > b {
  339. return a
  340. }
  341. return b
  342. }
  // String returns a pointer to a fresh copy of v, for filling optional *string fields.
func String(v string) *string {
	p := new(string)
	*p = v
	return p
}
  // Int returns a pointer to a fresh copy of v, for filling optional *int fields.
func Int(v int) *int {
	p := new(int)
	*p = v
	return p
}
  // Time returns a pointer to a fresh copy of v, for filling optional *time.Time fields.
func Time(v time.Time) *time.Time {
	p := new(time.Time)
	*p = v
	return p
}