server.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. package server
  2. import (
  3. "context"
  4. "database/sql"
  5. "encoding/json"
  6. "fmt"
  7. "time"
  8. "github.com/pkg/errors"
  9. "github.com/usememos/memos/api"
  10. metric "github.com/usememos/memos/plugin/metrics"
  11. "github.com/usememos/memos/server/profile"
  12. "github.com/usememos/memos/store"
  13. "github.com/usememos/memos/store/db"
  14. "github.com/gorilla/sessions"
  15. "github.com/labstack/echo-contrib/session"
  16. "github.com/labstack/echo/v4"
  17. "github.com/labstack/echo/v4/middleware"
  18. )
  19. type Server struct {
  20. e *echo.Echo
  21. db *sql.DB
  22. ID string
  23. Profile *profile.Profile
  24. Store *store.Store
  25. Collector *MetricCollector
  26. }
  27. func NewServer(ctx context.Context, profile *profile.Profile) (*Server, error) {
  28. e := echo.New()
  29. e.Debug = true
  30. e.HideBanner = true
  31. e.HidePort = true
  32. db := db.NewDB(profile)
  33. if err := db.Open(ctx); err != nil {
  34. return nil, errors.Wrap(err, "cannot open db")
  35. }
  36. s := &Server{
  37. e: e,
  38. db: db.DBInstance,
  39. Profile: profile,
  40. }
  41. storeInstance := store.New(db.DBInstance, profile)
  42. s.Store = storeInstance
  43. e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{
  44. Format: `{"time":"${time_rfc3339}",` +
  45. `"method":"${method}","uri":"${uri}",` +
  46. `"status":${status},"error":"${error}"}` + "\n",
  47. }))
  48. e.Use(middleware.Gzip())
  49. e.Use(middleware.CSRFWithConfig(middleware.CSRFConfig{
  50. Skipper: s.defaultAuthSkipper,
  51. TokenLookup: "cookie:_csrf",
  52. }))
  53. e.Use(middleware.CORS())
  54. e.Use(middleware.SecureWithConfig(middleware.SecureConfig{
  55. Skipper: defaultGetRequestSkipper,
  56. XSSProtection: "1; mode=block",
  57. ContentTypeNosniff: "nosniff",
  58. XFrameOptions: "SAMEORIGIN",
  59. HSTSPreloadEnabled: false,
  60. }))
  61. e.Use(middleware.TimeoutWithConfig(middleware.TimeoutConfig{
  62. ErrorMessage: "Request timeout",
  63. Timeout: 30 * time.Second,
  64. }))
  65. serverID, err := s.getSystemServerID(ctx)
  66. if err != nil {
  67. return nil, err
  68. }
  69. s.ID = serverID
  70. secretSessionName := "usememos"
  71. if profile.Mode == "prod" {
  72. secretSessionName, err = s.getSystemSecretSessionName(ctx)
  73. if err != nil {
  74. return nil, err
  75. }
  76. }
  77. e.Use(session.Middleware(sessions.NewCookieStore([]byte(secretSessionName))))
  78. embedFrontend(e)
  79. // Register MetricCollector to server.
  80. s.registerMetricCollector()
  81. rootGroup := e.Group("")
  82. s.registerRSSRoutes(rootGroup)
  83. publicGroup := e.Group("/o")
  84. s.registerResourcePublicRoutes(publicGroup)
  85. registerGetterPublicRoutes(publicGroup)
  86. apiGroup := e.Group("/api")
  87. apiGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
  88. return aclMiddleware(s, next)
  89. })
  90. s.registerSystemRoutes(apiGroup)
  91. s.registerAuthRoutes(apiGroup)
  92. s.registerUserRoutes(apiGroup)
  93. s.registerMemoRoutes(apiGroup)
  94. s.registerShortcutRoutes(apiGroup)
  95. s.registerResourceRoutes(apiGroup)
  96. s.registerTagRoutes(apiGroup)
  97. s.registerStorageRoutes(apiGroup)
  98. s.registerIdentityProviderRoutes(apiGroup)
  99. return s, nil
  100. }
  101. func (s *Server) Start(ctx context.Context) error {
  102. if err := s.createServerStartActivity(ctx); err != nil {
  103. return errors.Wrap(err, "failed to create activity")
  104. }
  105. s.Collector.Identify(ctx)
  106. return s.e.Start(fmt.Sprintf(":%d", s.Profile.Port))
  107. }
  108. func (s *Server) Shutdown(ctx context.Context) {
  109. ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
  110. defer cancel()
  111. // Shutdown echo server
  112. if err := s.e.Shutdown(ctx); err != nil {
  113. fmt.Printf("failed to shutdown server, error: %v\n", err)
  114. }
  115. // Close database connection
  116. if err := s.db.Close(); err != nil {
  117. fmt.Printf("failed to close database, error: %v\n", err)
  118. }
  119. fmt.Printf("memos stopped properly\n")
  120. }
  121. func (s *Server) createServerStartActivity(ctx context.Context) error {
  122. payload := api.ActivityServerStartPayload{
  123. ServerID: s.ID,
  124. Profile: s.Profile,
  125. }
  126. payloadBytes, err := json.Marshal(payload)
  127. if err != nil {
  128. return errors.Wrap(err, "failed to marshal activity payload")
  129. }
  130. activity, err := s.Store.CreateActivity(ctx, &api.ActivityCreate{
  131. CreatorID: api.UnknownID,
  132. Type: api.ActivityServerStart,
  133. Level: api.ActivityInfo,
  134. Payload: string(payloadBytes),
  135. })
  136. if err != nil || activity == nil {
  137. return errors.Wrap(err, "failed to create activity")
  138. }
  139. s.Collector.Collect(ctx, &metric.Metric{
  140. Name: string(activity.Type),
  141. })
  142. return err
  143. }