server.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. package server
  2. import (
  3. "context"
  4. "database/sql"
  5. "encoding/json"
  6. "fmt"
  7. "time"
  8. "github.com/pkg/errors"
  9. "github.com/usememos/memos/api"
  10. "github.com/usememos/memos/server/profile"
  11. "github.com/usememos/memos/store"
  12. "github.com/usememos/memos/store/db"
  13. "github.com/gorilla/sessions"
  14. "github.com/labstack/echo-contrib/session"
  15. "github.com/labstack/echo/v4"
  16. "github.com/labstack/echo/v4/middleware"
  17. )
  18. type Server struct {
  19. e *echo.Echo
  20. db *sql.DB
  21. ID string
  22. Profile *profile.Profile
  23. Store *store.Store
  24. }
  25. func NewServer(ctx context.Context, profile *profile.Profile) (*Server, error) {
  26. e := echo.New()
  27. e.Debug = true
  28. e.HideBanner = true
  29. e.HidePort = true
  30. db := db.NewDB(profile)
  31. if err := db.Open(ctx); err != nil {
  32. return nil, errors.Wrap(err, "cannot open db")
  33. }
  34. s := &Server{
  35. e: e,
  36. db: db.DBInstance,
  37. Profile: profile,
  38. }
  39. storeInstance := store.New(db.DBInstance, profile)
  40. s.Store = storeInstance
  41. e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{
  42. Format: `{"time":"${time_rfc3339}",` +
  43. `"method":"${method}","uri":"${uri}",` +
  44. `"status":${status},"error":"${error}"}` + "\n",
  45. }))
  46. e.Use(middleware.Gzip())
  47. e.Use(middleware.CSRFWithConfig(middleware.CSRFConfig{
  48. Skipper: s.defaultAuthSkipper,
  49. TokenLookup: "cookie:_csrf",
  50. }))
  51. e.Use(middleware.CORS())
  52. e.Use(middleware.SecureWithConfig(middleware.SecureConfig{
  53. Skipper: defaultGetRequestSkipper,
  54. XSSProtection: "1; mode=block",
  55. ContentTypeNosniff: "nosniff",
  56. XFrameOptions: "SAMEORIGIN",
  57. HSTSPreloadEnabled: false,
  58. }))
  59. e.Use(middleware.TimeoutWithConfig(middleware.TimeoutConfig{
  60. ErrorMessage: "Request timeout",
  61. Timeout: 30 * time.Second,
  62. }))
  63. serverID, err := s.getSystemServerID(ctx)
  64. if err != nil {
  65. return nil, err
  66. }
  67. s.ID = serverID
  68. secretSessionName := "usememos"
  69. if profile.Mode == "prod" {
  70. secretSessionName, err = s.getSystemSecretSessionName(ctx)
  71. if err != nil {
  72. return nil, err
  73. }
  74. }
  75. e.Use(session.Middleware(sessions.NewCookieStore([]byte(secretSessionName))))
  76. embedFrontend(e)
  77. rootGroup := e.Group("")
  78. s.registerRSSRoutes(rootGroup)
  79. publicGroup := e.Group("/o")
  80. s.registerResourcePublicRoutes(publicGroup)
  81. registerGetterPublicRoutes(publicGroup)
  82. apiGroup := e.Group("/api")
  83. apiGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
  84. return aclMiddleware(s, next)
  85. })
  86. s.registerSystemRoutes(apiGroup)
  87. s.registerAuthRoutes(apiGroup)
  88. s.registerUserRoutes(apiGroup)
  89. s.registerMemoRoutes(apiGroup)
  90. s.registerShortcutRoutes(apiGroup)
  91. s.registerResourceRoutes(apiGroup)
  92. s.registerTagRoutes(apiGroup)
  93. s.registerStorageRoutes(apiGroup)
  94. s.registerIdentityProviderRoutes(apiGroup)
  95. s.registerOpenAIRoutes(apiGroup)
  96. return s, nil
  97. }
  98. func (s *Server) Start(ctx context.Context) error {
  99. if err := s.createServerStartActivity(ctx); err != nil {
  100. return errors.Wrap(err, "failed to create activity")
  101. }
  102. return s.e.Start(fmt.Sprintf(":%d", s.Profile.Port))
  103. }
  104. func (s *Server) Shutdown(ctx context.Context) {
  105. ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
  106. defer cancel()
  107. // Shutdown echo server
  108. if err := s.e.Shutdown(ctx); err != nil {
  109. fmt.Printf("failed to shutdown server, error: %v\n", err)
  110. }
  111. // Close database connection
  112. if err := s.db.Close(); err != nil {
  113. fmt.Printf("failed to close database, error: %v\n", err)
  114. }
  115. fmt.Printf("memos stopped properly\n")
  116. }
  117. func (s *Server) createServerStartActivity(ctx context.Context) error {
  118. payload := api.ActivityServerStartPayload{
  119. ServerID: s.ID,
  120. Profile: s.Profile,
  121. }
  122. payloadBytes, err := json.Marshal(payload)
  123. if err != nil {
  124. return errors.Wrap(err, "failed to marshal activity payload")
  125. }
  126. activity, err := s.Store.CreateActivity(ctx, &api.ActivityCreate{
  127. CreatorID: api.UnknownID,
  128. Type: api.ActivityServerStart,
  129. Level: api.ActivityInfo,
  130. Payload: string(payloadBytes),
  131. })
  132. if err != nil || activity == nil {
  133. return errors.Wrap(err, "failed to create activity")
  134. }
  135. return err
  136. }