acl.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. package server
  2. import (
  3. "fmt"
  4. "net/http"
  5. "strconv"
  6. "github.com/usememos/memos/api"
  7. "github.com/usememos/memos/common"
  8. "github.com/gorilla/sessions"
  9. "github.com/labstack/echo-contrib/session"
  10. "github.com/labstack/echo/v4"
  11. )
  // userIDContextKey names the slot that holds the signed-in user's ID. The
  // same key is used for the "memos_session" session values and for the echo
  // request context.
  var userIDContextKey = "user-id"

  // getUserIDContextKey reports the key under which the authenticated user's
  // ID is stored in the echo request context.
  func getUserIDContextKey() string {
  	key := userIDContextKey
  	return key
  }
  18. func setUserSession(ctx echo.Context, user *api.User) error {
  19. sess, _ := session.Get("memos_session", ctx)
  20. sess.Options = &sessions.Options{
  21. Path: "/",
  22. MaxAge: 3600 * 24 * 30,
  23. HttpOnly: true,
  24. SameSite: http.SameSiteStrictMode,
  25. }
  26. sess.Values[userIDContextKey] = user.ID
  27. err := sess.Save(ctx.Request(), ctx.Response())
  28. if err != nil {
  29. return fmt.Errorf("failed to set session, err: %w", err)
  30. }
  31. return nil
  32. }
  33. func removeUserSession(ctx echo.Context) error {
  34. sess, _ := session.Get("memos_session", ctx)
  35. sess.Options = &sessions.Options{
  36. Path: "/",
  37. MaxAge: 0,
  38. HttpOnly: true,
  39. }
  40. sess.Values[userIDContextKey] = nil
  41. err := sess.Save(ctx.Request(), ctx.Response())
  42. if err != nil {
  43. return fmt.Errorf("failed to set session, err: %w", err)
  44. }
  45. return nil
  46. }
  47. func aclMiddleware(s *Server, next echo.HandlerFunc) echo.HandlerFunc {
  48. return func(c echo.Context) error {
  49. ctx := c.Request().Context()
  50. path := c.Path()
  51. // Skip auth.
  52. if common.HasPrefixes(path, "/api/auth") {
  53. return next(c)
  54. }
  55. {
  56. // If there is openId in query string and related user is found, then skip auth.
  57. openID := c.QueryParam("openId")
  58. if openID != "" {
  59. userFind := &api.UserFind{
  60. OpenID: &openID,
  61. }
  62. user, err := s.Store.FindUser(ctx, userFind)
  63. if err != nil && common.ErrorCode(err) != common.NotFound {
  64. return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user by open_id").SetInternal(err)
  65. }
  66. if user != nil {
  67. // Stores userID into context.
  68. c.Set(getUserIDContextKey(), user.ID)
  69. return next(c)
  70. }
  71. }
  72. }
  73. {
  74. sess, _ := session.Get("memos_session", c)
  75. userIDValue := sess.Values[userIDContextKey]
  76. if userIDValue != nil {
  77. userID, _ := strconv.Atoi(fmt.Sprintf("%v", userIDValue))
  78. userFind := &api.UserFind{
  79. ID: &userID,
  80. }
  81. user, err := s.Store.FindUser(ctx, userFind)
  82. if err != nil && common.ErrorCode(err) != common.NotFound {
  83. return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find user by ID: %d", userID)).SetInternal(err)
  84. }
  85. if user != nil {
  86. if user.RowStatus == api.Archived {
  87. return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with username %s", user.Username))
  88. }
  89. c.Set(getUserIDContextKey(), userID)
  90. }
  91. }
  92. }
  93. if common.HasPrefixes(path, "/api/ping", "/api/status", "/api/user/:id", "/api/memo/all", "/api/memo/:memoId", "/api/memo/amount") && c.Request().Method == http.MethodGet {
  94. return next(c)
  95. }
  96. if common.HasPrefixes(path, "/api/memo", "/api/tag", "/api/shortcut") && c.Request().Method == http.MethodGet {
  97. if _, err := strconv.Atoi(c.QueryParam("creatorId")); err == nil {
  98. return next(c)
  99. }
  100. }
  101. userID := c.Get(getUserIDContextKey())
  102. if userID == nil {
  103. return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
  104. }
  105. return next(c)
  106. }
  107. }