server.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. package server
  2. import (
  3. "context"
  4. "database/sql"
  5. "encoding/json"
  6. "fmt"
  7. "time"
  8. "github.com/pkg/errors"
  9. "github.com/usememos/memos/api"
  10. metric "github.com/usememos/memos/plugin/metrics"
  11. "github.com/usememos/memos/server/profile"
  12. "github.com/usememos/memos/store"
  13. "github.com/usememos/memos/store/db"
  14. "github.com/gorilla/sessions"
  15. "github.com/labstack/echo-contrib/session"
  16. "github.com/labstack/echo/v4"
  17. "github.com/labstack/echo/v4/middleware"
  18. )
  19. type Server struct {
  20. e *echo.Echo
  21. db *sql.DB
  22. ID string
  23. Profile *profile.Profile
  24. Store *store.Store
  25. Collector *MetricCollector
  26. }
  27. func NewServer(ctx context.Context, profile *profile.Profile) (*Server, error) {
  28. e := echo.New()
  29. e.Debug = true
  30. e.HideBanner = true
  31. e.HidePort = true
  32. db := db.NewDB(profile)
  33. if err := db.Open(ctx); err != nil {
  34. return nil, errors.Wrap(err, "cannot open db")
  35. }
  36. s := &Server{
  37. e: e,
  38. db: db.DBInstance,
  39. Profile: profile,
  40. }
  41. storeInstance := store.New(db.DBInstance, profile)
  42. s.Store = storeInstance
  43. e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{
  44. Format: `{"time":"${time_rfc3339}",` +
  45. `"method":"${method}","uri":"${uri}",` +
  46. `"status":${status},"error":"${error}"}` + "\n",
  47. }))
  48. e.Use(middleware.Gzip())
  49. e.Use(middleware.CSRFWithConfig(middleware.CSRFConfig{
  50. Skipper: s.defaultAuthSkipper,
  51. TokenLookup: "cookie:_csrf",
  52. }))
  53. e.Use(middleware.CORS())
  54. e.Use(middleware.SecureWithConfig(middleware.SecureConfig{
  55. Skipper: defaultGetRequestSkipper,
  56. XSSProtection: "1; mode=block",
  57. ContentTypeNosniff: "nosniff",
  58. XFrameOptions: "SAMEORIGIN",
  59. HSTSPreloadEnabled: false,
  60. }))
  61. e.Use(middleware.TimeoutWithConfig(middleware.TimeoutConfig{
  62. ErrorMessage: "Request timeout",
  63. Timeout: 30 * time.Second,
  64. }))
  65. serverID, err := s.getSystemServerID(ctx)
  66. if err != nil {
  67. return nil, err
  68. }
  69. s.ID = serverID
  70. secretSessionName := "usememos"
  71. if profile.Mode == "prod" {
  72. secretSessionName, err = s.getSystemSecretSessionName(ctx)
  73. if err != nil {
  74. return nil, err
  75. }
  76. }
  77. e.Use(session.Middleware(sessions.NewCookieStore([]byte(secretSessionName))))
  78. embedFrontend(e)
  79. // Register MetricCollector to server.
  80. s.registerMetricCollector()
  81. rootGroup := e.Group("")
  82. s.registerRSSRoutes(rootGroup)
  83. publicGroup := e.Group("/o")
  84. s.registerResourcePublicRoutes(publicGroup)
  85. registerGetterPublicRoutes(publicGroup)
  86. apiGroup := e.Group("/api")
  87. apiGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
  88. return aclMiddleware(s, next)
  89. })
  90. s.registerSystemRoutes(apiGroup)
  91. s.registerAuthRoutes(apiGroup)
  92. s.registerUserRoutes(apiGroup)
  93. s.registerMemoRoutes(apiGroup)
  94. s.registerShortcutRoutes(apiGroup)
  95. s.registerResourceRoutes(apiGroup)
  96. s.registerTagRoutes(apiGroup)
  97. s.registerStorageRoutes(apiGroup)
  98. s.registerIdentityProviderRoutes(apiGroup)
  99. s.registerOpenAIRoutes(apiGroup)
  100. return s, nil
  101. }
  102. func (s *Server) Start(ctx context.Context) error {
  103. if err := s.createServerStartActivity(ctx); err != nil {
  104. return errors.Wrap(err, "failed to create activity")
  105. }
  106. s.Collector.Identify(ctx)
  107. return s.e.Start(fmt.Sprintf(":%d", s.Profile.Port))
  108. }
  109. func (s *Server) Shutdown(ctx context.Context) {
  110. ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
  111. defer cancel()
  112. // Shutdown echo server
  113. if err := s.e.Shutdown(ctx); err != nil {
  114. fmt.Printf("failed to shutdown server, error: %v\n", err)
  115. }
  116. // Close database connection
  117. if err := s.db.Close(); err != nil {
  118. fmt.Printf("failed to close database, error: %v\n", err)
  119. }
  120. fmt.Printf("memos stopped properly\n")
  121. }
  122. func (s *Server) createServerStartActivity(ctx context.Context) error {
  123. payload := api.ActivityServerStartPayload{
  124. ServerID: s.ID,
  125. Profile: s.Profile,
  126. }
  127. payloadBytes, err := json.Marshal(payload)
  128. if err != nil {
  129. return errors.Wrap(err, "failed to marshal activity payload")
  130. }
  131. activity, err := s.Store.CreateActivity(ctx, &api.ActivityCreate{
  132. CreatorID: api.UnknownID,
  133. Type: api.ActivityServerStart,
  134. Level: api.ActivityInfo,
  135. Payload: string(payloadBytes),
  136. })
  137. if err != nil || activity == nil {
  138. return errors.Wrap(err, "failed to create activity")
  139. }
  140. s.Collector.Collect(ctx, &metric.Metric{
  141. Name: string(activity.Type),
  142. })
  143. return err
  144. }