acl.go 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. package server
  2. import (
  3. "fmt"
  4. "net/http"
  5. "strconv"
  6. "github.com/usememos/memos/api"
  7. "github.com/usememos/memos/common"
  8. "github.com/gorilla/sessions"
  9. "github.com/labstack/echo-contrib/session"
  10. "github.com/labstack/echo/v4"
  11. )
  // userIDContextKey is the key under which the user ID is kept, both in the
  // session values and in the echo request context (see getUserIDContextKey).
  var userIDContextKey = "user-id"

  // sessionName is the name passed to session.Get when loading the session.
  var sessionName = "memos_session"
  16. func getUserIDContextKey() string {
  17. return userIDContextKey
  18. }
  19. func setUserSession(ctx echo.Context, user *api.User) error {
  20. sess, _ := session.Get(sessionName, ctx)
  21. sess.Options = &sessions.Options{
  22. Path: "/",
  23. MaxAge: 3600 * 24 * 30,
  24. HttpOnly: true,
  25. SameSite: http.SameSiteStrictMode,
  26. }
  27. sess.Values[userIDContextKey] = user.ID
  28. err := sess.Save(ctx.Request(), ctx.Response())
  29. if err != nil {
  30. return fmt.Errorf("failed to set session, err: %w", err)
  31. }
  32. return nil
  33. }
  34. func removeUserSession(ctx echo.Context) error {
  35. sess, _ := session.Get(sessionName, ctx)
  36. sess.Options = &sessions.Options{
  37. Path: "/",
  38. MaxAge: 0,
  39. HttpOnly: true,
  40. }
  41. sess.Values[userIDContextKey] = nil
  42. err := sess.Save(ctx.Request(), ctx.Response())
  43. if err != nil {
  44. return fmt.Errorf("failed to set session, err: %w", err)
  45. }
  46. return nil
  47. }
  48. func aclMiddleware(s *Server, next echo.HandlerFunc) echo.HandlerFunc {
  49. return func(c echo.Context) error {
  50. ctx := c.Request().Context()
  51. path := c.Path()
  52. if s.defaultAuthSkipper(c) {
  53. return next(c)
  54. }
  55. sess, _ := session.Get(sessionName, c)
  56. userIDValue := sess.Values[userIDContextKey]
  57. if userIDValue != nil {
  58. userID, _ := strconv.Atoi(fmt.Sprintf("%v", userIDValue))
  59. userFind := &api.UserFind{
  60. ID: &userID,
  61. }
  62. user, err := s.Store.FindUser(ctx, userFind)
  63. if err != nil && common.ErrorCode(err) != common.NotFound {
  64. return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find user by ID: %d", userID)).SetInternal(err)
  65. }
  66. if user != nil {
  67. if user.RowStatus == api.Archived {
  68. return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with username %s", user.Username))
  69. }
  70. c.Set(getUserIDContextKey(), userID)
  71. }
  72. }
  73. if common.HasPrefixes(path, "/api/ping", "/api/status", "/api/idp", "/api/user/:id", "/api/memo") && c.Request().Method == http.MethodGet {
  74. return next(c)
  75. }
  76. userID := c.Get(getUserIDContextKey())
  77. if userID == nil {
  78. return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
  79. }
  80. return next(c)
  81. }
  82. }