memos.go 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. package cmd
  import (
  	"context"
  	"errors"
  	"fmt"
  	"net/http"
  	"os"
  	"os/signal"
  	"syscall"

  	"github.com/spf13/cobra"
  	"github.com/spf13/viper"
  	"go.uber.org/zap"

  	"github.com/usememos/memos/internal/log"
  	"github.com/usememos/memos/server"
  	_profile "github.com/usememos/memos/server/profile"
  	"github.com/usememos/memos/server/service/metric"
  	"github.com/usememos/memos/store"
  	"github.com/usememos/memos/store/db"
  )
  const (
  	// greetingBanner is the ASCII-art "MEMOS" logo printed by printGreetings
  	// at startup. Every row is exactly 47 runes wide so the glyphs line up;
  	// tools that trim or collapse whitespace inside this raw string will skew it.
  	greetingBanner = `
███╗   ███╗███████╗███╗   ███╗ ██████╗ ███████╗
████╗ ████║██╔════╝████╗ ████║██╔═══██╗██╔════╝
██╔████╔██║█████╗  ██╔████╔██║██║   ██║███████╗
██║╚██╔╝██║██╔══╝  ██║╚██╔╝██║██║   ██║╚════██║
██║ ╚═╝ ██║███████╗██║ ╚═╝ ██║╚██████╔╝███████║
╚═╝     ╚═╝╚══════╝╚═╝     ╚═╝ ╚═════╝ ╚══════╝
`
  )
  29. var (
  30. profile *_profile.Profile
  31. mode string
  32. addr string
  33. port int
  34. data string
  35. driver string
  36. dsn string
  37. enableMetric bool
  38. rootCmd = &cobra.Command{
  39. Use: "memos",
  40. Short: `An open-source, self-hosted memo hub with knowledge management and social networking.`,
  41. Run: func(_cmd *cobra.Command, _args []string) {
  42. ctx, cancel := context.WithCancel(context.Background())
  43. dbDriver, err := db.NewDBDriver(profile)
  44. if err != nil {
  45. cancel()
  46. log.Error("failed to create db driver", zap.Error(err))
  47. return
  48. }
  49. if err := dbDriver.Migrate(ctx); err != nil {
  50. cancel()
  51. log.Error("failed to migrate db", zap.Error(err))
  52. return
  53. }
  54. store := store.New(dbDriver, profile)
  55. s, err := server.NewServer(ctx, profile, store)
  56. if err != nil {
  57. cancel()
  58. log.Error("failed to create server", zap.Error(err))
  59. return
  60. }
  61. if profile.Metric {
  62. // nolint
  63. metric.NewMetricClient(s.ID, *profile)
  64. }
  65. c := make(chan os.Signal, 1)
  66. // Trigger graceful shutdown on SIGINT or SIGTERM.
  67. // The default signal sent by the `kill` command is SIGTERM,
  68. // which is taken as the graceful shutdown signal for many systems, eg., Kubernetes, Gunicorn.
  69. signal.Notify(c, os.Interrupt, syscall.SIGTERM)
  70. go func() {
  71. sig := <-c
  72. log.Info(fmt.Sprintf("%s received.\n", sig.String()))
  73. s.Shutdown(ctx)
  74. cancel()
  75. }()
  76. printGreetings()
  77. if err := s.Start(ctx); err != nil {
  78. if err != http.ErrServerClosed {
  79. log.Error("failed to start server", zap.Error(err))
  80. cancel()
  81. }
  82. }
  83. // Wait for CTRL-C.
  84. <-ctx.Done()
  85. },
  86. }
  87. )
  88. func Execute() error {
  89. defer log.Sync()
  90. return rootCmd.Execute()
  91. }
  92. func init() {
  93. cobra.OnInitialize(initConfig)
  94. rootCmd.PersistentFlags().StringVarP(&mode, "mode", "m", "demo", `mode of server, can be "prod" or "dev" or "demo"`)
  95. rootCmd.PersistentFlags().StringVarP(&addr, "addr", "a", "", "address of server")
  96. rootCmd.PersistentFlags().IntVarP(&port, "port", "p", 8081, "port of server")
  97. rootCmd.PersistentFlags().StringVarP(&data, "data", "d", "", "data directory")
  98. rootCmd.PersistentFlags().StringVarP(&driver, "driver", "", "", "database driver")
  99. rootCmd.PersistentFlags().StringVarP(&dsn, "dsn", "", "", "database source name(aka. DSN)")
  100. rootCmd.PersistentFlags().BoolVarP(&enableMetric, "metric", "", true, "allow metric collection")
  101. err := viper.BindPFlag("mode", rootCmd.PersistentFlags().Lookup("mode"))
  102. if err != nil {
  103. panic(err)
  104. }
  105. err = viper.BindPFlag("addr", rootCmd.PersistentFlags().Lookup("addr"))
  106. if err != nil {
  107. panic(err)
  108. }
  109. err = viper.BindPFlag("port", rootCmd.PersistentFlags().Lookup("port"))
  110. if err != nil {
  111. panic(err)
  112. }
  113. err = viper.BindPFlag("data", rootCmd.PersistentFlags().Lookup("data"))
  114. if err != nil {
  115. panic(err)
  116. }
  117. err = viper.BindPFlag("driver", rootCmd.PersistentFlags().Lookup("driver"))
  118. if err != nil {
  119. panic(err)
  120. }
  121. err = viper.BindPFlag("dsn", rootCmd.PersistentFlags().Lookup("dsn"))
  122. if err != nil {
  123. panic(err)
  124. }
  125. err = viper.BindPFlag("metric", rootCmd.PersistentFlags().Lookup("metric"))
  126. if err != nil {
  127. panic(err)
  128. }
  129. viper.SetDefault("mode", "demo")
  130. viper.SetDefault("driver", "sqlite")
  131. viper.SetDefault("addr", "")
  132. viper.SetDefault("port", 8081)
  133. viper.SetDefault("metric", true)
  134. viper.SetEnvPrefix("memos")
  135. }
  // initConfig runs before the root command executes (registered with
  // cobra.OnInitialize in init). It enables automatic environment lookup in
  // viper, loads the server profile into the package-level profile variable
  // and prints a summary of it to standard output.
  //
  // On failure the error is written to standard error and profile is left
  // nil; anything that dereferences profile afterwards must check for that.
  func initConfig() {
  	viper.AutomaticEnv()

  	var err error
  	profile, err = _profile.GetProfile()
  	if err != nil {
  		// Do not keep a possibly partial value from a failed load.
  		profile = nil
  		fmt.Fprintf(os.Stderr, "failed to get profile, error: %+v\n", err)
  		return
  	}

  	// fmt.Println instead of the builtin println: the builtin writes to
  	// stderr and is only meant for bootstrapping the runtime.
  	fmt.Println("---")
  	fmt.Println("Server profile")
  	fmt.Println("data:", profile.Data)
  	fmt.Println("dsn:", profile.DSN)
  	fmt.Println("addr:", profile.Addr)
  	fmt.Println("port:", profile.Port)
  	fmt.Println("mode:", profile.Mode)
  	fmt.Println("driver:", profile.Driver)
  	fmt.Println("version:", profile.Version)
  	fmt.Println("metric:", profile.Metric)
  	fmt.Println("---")
  }
  // printGreetings writes the startup banner, the listening address/port and
  // the project links to standard output. It dereferences profile, so it must
  // only be called after initConfig has loaded it successfully.
  //
  // Everything goes through fmt so the whole greeting lands on one stream;
  // the builtin print/println would send part of it to stderr.
  func printGreetings() {
  	fmt.Print(greetingBanner)
  	if len(profile.Addr) == 0 {
  		fmt.Printf("Version %s has been started on port %d\n", profile.Version, profile.Port)
  	} else {
  		fmt.Printf("Version %s has been started on address '%s' and port %d\n", profile.Version, profile.Addr, profile.Port)
  	}
  	fmt.Println("---")
  	fmt.Println("See more in:")
  	fmt.Printf("👉Website: %s\n", "https://usememos.com")
  	fmt.Printf("👉GitHub: %s\n", "https://github.com/usememos/memos")
  	fmt.Println("---")
  }