server.go 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. package server
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "net/http"
  7. "time"
  8. "github.com/google/uuid"
  9. "github.com/pkg/errors"
  10. apiv1 "github.com/usememos/memos/api/v1"
  11. "github.com/usememos/memos/common/util"
  12. "github.com/usememos/memos/plugin/telegram"
  13. "github.com/usememos/memos/server/profile"
  14. "github.com/usememos/memos/store"
  15. "github.com/labstack/echo/v4"
  16. "github.com/labstack/echo/v4/middleware"
  17. )
  18. type Server struct {
  19. e *echo.Echo
  20. ID string
  21. Secret string
  22. Profile *profile.Profile
  23. Store *store.Store
  24. telegramBot *telegram.Bot
  25. }
  26. func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store) (*Server, error) {
  27. e := echo.New()
  28. e.Debug = true
  29. e.HideBanner = true
  30. e.HidePort = true
  31. s := &Server{
  32. e: e,
  33. Store: store,
  34. Profile: profile,
  35. }
  36. telegramBotHandler := newTelegramHandler(store)
  37. s.telegramBot = telegram.NewBotWithHandler(telegramBotHandler)
  38. e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{
  39. Format: `{"time":"${time_rfc3339}",` +
  40. `"method":"${method}","uri":"${uri}",` +
  41. `"status":${status},"error":"${error}"}` + "\n",
  42. }))
  43. e.Use(middleware.Gzip())
  44. e.Use(middleware.CORS())
  45. e.Use(middleware.SecureWithConfig(middleware.SecureConfig{
  46. Skipper: defaultGetRequestSkipper,
  47. XSSProtection: "1; mode=block",
  48. ContentTypeNosniff: "nosniff",
  49. XFrameOptions: "SAMEORIGIN",
  50. HSTSPreloadEnabled: false,
  51. }))
  52. e.Use(middleware.TimeoutWithConfig(middleware.TimeoutConfig{
  53. Skipper: func(c echo.Context) bool {
  54. // this is a hack to skip timeout for openai chat streaming
  55. // because streaming require to flush response. But the timeout middleware will break it.
  56. return c.Request().URL.Path == "/api/v1/openai/chat-streaming"
  57. },
  58. ErrorMessage: "Request timeout",
  59. Timeout: 30 * time.Second,
  60. }))
  61. serverID, err := s.getSystemServerID(ctx)
  62. if err != nil {
  63. return nil, fmt.Errorf("failed to retrieve system server ID: %w", err)
  64. }
  65. s.ID = serverID
  66. embedFrontend(e)
  67. secret := "usememos"
  68. if profile.Mode == "prod" {
  69. secret, err = s.getSystemSecretSessionName(ctx)
  70. if err != nil {
  71. return nil, fmt.Errorf("failed to retrieve system secret session name: %w", err)
  72. }
  73. }
  74. s.Secret = secret
  75. rootGroup := e.Group("")
  76. apiV1Service := apiv1.NewAPIV1Service(s.Secret, profile, store)
  77. apiV1Service.Register(rootGroup)
  78. return s, nil
  79. }
  80. func (s *Server) Start(ctx context.Context) error {
  81. if err := s.createServerStartActivity(ctx); err != nil {
  82. return errors.Wrap(err, "failed to create activity")
  83. }
  84. go s.telegramBot.Start(ctx)
  85. go autoBackup(ctx, s.Store)
  86. return s.e.Start(fmt.Sprintf(":%d", s.Profile.Port))
  87. }
  88. func (s *Server) Shutdown(ctx context.Context) {
  89. ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
  90. defer cancel()
  91. // Shutdown echo server
  92. if err := s.e.Shutdown(ctx); err != nil {
  93. fmt.Printf("failed to shutdown server, error: %v\n", err)
  94. }
  95. // Close database connection
  96. if err := s.Store.GetDB().Close(); err != nil {
  97. fmt.Printf("failed to close database, error: %v\n", err)
  98. }
  99. fmt.Printf("memos stopped properly\n")
  100. }
  101. func (s *Server) GetEcho() *echo.Echo {
  102. return s.e
  103. }
  104. func (s *Server) getSystemServerID(ctx context.Context) (string, error) {
  105. serverIDSetting, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{
  106. Name: apiv1.SystemSettingServerIDName.String(),
  107. })
  108. if err != nil {
  109. return "", err
  110. }
  111. if serverIDSetting == nil || serverIDSetting.Value == "" {
  112. serverIDSetting, err = s.Store.UpsertSystemSetting(ctx, &store.SystemSetting{
  113. Name: apiv1.SystemSettingServerIDName.String(),
  114. Value: uuid.NewString(),
  115. })
  116. if err != nil {
  117. return "", err
  118. }
  119. }
  120. return serverIDSetting.Value, nil
  121. }
  122. func (s *Server) getSystemSecretSessionName(ctx context.Context) (string, error) {
  123. secretSessionNameValue, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{
  124. Name: apiv1.SystemSettingSecretSessionName.String(),
  125. })
  126. if err != nil {
  127. return "", err
  128. }
  129. if secretSessionNameValue == nil || secretSessionNameValue.Value == "" {
  130. secretSessionNameValue, err = s.Store.UpsertSystemSetting(ctx, &store.SystemSetting{
  131. Name: apiv1.SystemSettingSecretSessionName.String(),
  132. Value: uuid.NewString(),
  133. })
  134. if err != nil {
  135. return "", err
  136. }
  137. }
  138. return secretSessionNameValue.Value, nil
  139. }
  140. func (s *Server) createServerStartActivity(ctx context.Context) error {
  141. payload := apiv1.ActivityServerStartPayload{
  142. ServerID: s.ID,
  143. Profile: s.Profile,
  144. }
  145. payloadBytes, err := json.Marshal(payload)
  146. if err != nil {
  147. return errors.Wrap(err, "failed to marshal activity payload")
  148. }
  149. activity, err := s.Store.CreateActivity(ctx, &store.Activity{
  150. CreatorID: apiv1.UnknownID,
  151. Type: apiv1.ActivityServerStart.String(),
  152. Level: apiv1.ActivityInfo.String(),
  153. Payload: string(payloadBytes),
  154. })
  155. if err != nil || activity == nil {
  156. return errors.Wrap(err, "failed to create activity")
  157. }
  158. return err
  159. }
  160. func defaultGetRequestSkipper(c echo.Context) bool {
  161. return c.Request().Method == http.MethodGet
  162. }
  163. func defaultAPIRequestSkipper(c echo.Context) bool {
  164. path := c.Path()
  165. return util.HasPrefixes(path, "/api", "/api/v1")
  166. }