jwt.go 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
  1. package server
  2. import (
  3. "errors"
  4. "fmt"
  5. "net/http"
  6. "strconv"
  7. "strings"
  8. "time"
  9. "github.com/golang-jwt/jwt/v4"
  10. "github.com/labstack/echo/v4"
  11. pkgerrors "github.com/pkg/errors"
  12. "github.com/usememos/memos/api"
  13. "github.com/usememos/memos/common"
  14. "github.com/usememos/memos/server/auth"
  15. )
  // userIDContextKey is the echo.Context key under which JWTMiddleware stores the
// authenticated user's numeric ID, which it parses from the JWT subject claim.
const userIDContextKey = "user-id"
  22. // Claims creates a struct that will be encoded to a JWT.
  23. // We add jwt.RegisteredClaims as an embedded type, to provide fields such as name.
  24. type Claims struct {
  25. Name string `json:"name"`
  26. jwt.RegisteredClaims
  27. }
  28. func getUserIDContextKey() string {
  29. return userIDContextKey
  30. }
  31. // GenerateTokensAndSetCookies generates jwt token and saves it to the http-only cookie.
  32. func GenerateTokensAndSetCookies(c echo.Context, user *api.User, mode string, secret string) error {
  33. accessToken, err := auth.GenerateAccessToken(user.Username, user.ID, mode, secret)
  34. if err != nil {
  35. return pkgerrors.Wrap(err, "failed to generate access token")
  36. }
  37. cookieExp := time.Now().Add(auth.CookieExpDuration)
  38. setTokenCookie(c, auth.AccessTokenCookieName, accessToken, cookieExp)
  39. // We generate here a new refresh token and saving it to the cookie.
  40. refreshToken, err := auth.GenerateRefreshToken(user.Username, user.ID, mode, secret)
  41. if err != nil {
  42. return pkgerrors.Wrap(err, "failed to generate refresh token")
  43. }
  44. setTokenCookie(c, auth.RefreshTokenCookieName, refreshToken, cookieExp)
  45. return nil
  46. }
  47. // RemoveTokensAndCookies removes the jwt token and refresh token from the cookies.
  48. func RemoveTokensAndCookies(c echo.Context) {
  49. // We set the expiration time to the past, so that the cookie will be removed.
  50. cookieExp := time.Now().Add(-1 * time.Hour)
  51. setTokenCookie(c, auth.AccessTokenCookieName, "", cookieExp)
  52. setTokenCookie(c, auth.RefreshTokenCookieName, "", cookieExp)
  53. }
  54. // Here we are creating a new cookie, which will store the valid JWT token.
  55. func setTokenCookie(c echo.Context, name, token string, expiration time.Time) {
  56. cookie := new(http.Cookie)
  57. cookie.Name = name
  58. cookie.Value = token
  59. cookie.Expires = expiration
  60. cookie.Path = "/"
  61. // Http-only helps mitigate the risk of client side script accessing the protected cookie.
  62. cookie.HttpOnly = true
  63. cookie.SameSite = http.SameSiteStrictMode
  64. c.SetCookie(cookie)
  65. }
  66. func extractTokenFromHeader(c echo.Context) (string, error) {
  67. authHeader := c.Request().Header.Get("Authorization")
  68. if authHeader == "" {
  69. return "", nil
  70. }
  71. authHeaderParts := strings.Fields(authHeader)
  72. if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
  73. return "", errors.New("Authorization header format must be Bearer {token}")
  74. }
  75. return authHeaderParts[1], nil
  76. }
  77. func findAccessToken(c echo.Context) string {
  78. accessToken := ""
  79. cookie, _ := c.Cookie(auth.AccessTokenCookieName)
  80. if cookie != nil {
  81. accessToken = cookie.Value
  82. }
  83. if accessToken == "" {
  84. accessToken, _ = extractTokenFromHeader(c)
  85. }
  86. return accessToken
  87. }
  88. // JWTMiddleware validates the access token.
  89. // If the access token is about to expire or has expired and the request has a valid refresh token, it
  90. // will try to generate new access token and refresh token.
  91. func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.HandlerFunc {
  92. return func(c echo.Context) error {
  93. path := c.Request().URL.Path
  94. method := c.Request().Method
  95. mode := server.Profile.Mode
  96. if server.defaultAuthSkipper(c) {
  97. return next(c)
  98. }
  99. // Skip validation for server status endpoints.
  100. if common.HasPrefixes(path, "/api/ping", "/api/status", "/api/idp", "/api/user/:id") && method == http.MethodGet {
  101. return next(c)
  102. }
  103. token := findAccessToken(c)
  104. if token == "" {
  105. // Allow the user to access the public endpoints.
  106. if common.HasPrefixes(path, "/o") {
  107. return next(c)
  108. }
  109. // When the request is not authenticated, we allow the user to access the memo endpoints for those public memos.
  110. if common.HasPrefixes(path, "/api/memo") && method == http.MethodGet {
  111. return next(c)
  112. }
  113. return echo.NewHTTPError(http.StatusUnauthorized, "Missing access token")
  114. }
  115. claims := &Claims{}
  116. accessToken, err := jwt.ParseWithClaims(token, claims, func(t *jwt.Token) (any, error) {
  117. if t.Method.Alg() != jwt.SigningMethodHS256.Name {
  118. return nil, pkgerrors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)
  119. }
  120. if kid, ok := t.Header["kid"].(string); ok {
  121. if kid == "v1" {
  122. return []byte(secret), nil
  123. }
  124. }
  125. return nil, pkgerrors.Errorf("unexpected access token kid=%v", t.Header["kid"])
  126. })
  127. if !audienceContains(claims.Audience, fmt.Sprintf(auth.AccessTokenAudienceFmt, mode)) {
  128. return echo.NewHTTPError(http.StatusUnauthorized,
  129. fmt.Sprintf("Invalid access token, audience mismatch, got %q, expected %q. you may send request to the wrong environment",
  130. claims.Audience,
  131. fmt.Sprintf(auth.AccessTokenAudienceFmt, mode),
  132. ))
  133. }
  134. generateToken := time.Until(claims.ExpiresAt.Time) < auth.RefreshThresholdDuration
  135. if err != nil {
  136. var ve *jwt.ValidationError
  137. if errors.As(err, &ve) {
  138. // If expiration error is the only error, we will clear the err
  139. // and generate new access token and refresh token
  140. if ve.Errors == jwt.ValidationErrorExpired {
  141. generateToken = true
  142. }
  143. } else {
  144. return &echo.HTTPError{
  145. Code: http.StatusUnauthorized,
  146. Message: "Invalid or expired access token",
  147. Internal: err,
  148. }
  149. }
  150. }
  151. // We either have a valid access token or we will attempt to generate new access token and refresh token
  152. ctx := c.Request().Context()
  153. userID, err := strconv.Atoi(claims.Subject)
  154. if err != nil {
  155. return echo.NewHTTPError(http.StatusUnauthorized, "Malformed ID in the token.")
  156. }
  157. // Even if there is no error, we still need to make sure the user still exists.
  158. user, err := server.Store.FindUser(ctx, &api.UserFind{
  159. ID: &userID,
  160. })
  161. if err != nil {
  162. return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Server error to find user ID: %d", userID)).SetInternal(err)
  163. }
  164. if user == nil {
  165. return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("Failed to find user ID: %d", userID))
  166. }
  167. if generateToken {
  168. generateTokenFunc := func() error {
  169. rc, err := c.Cookie(auth.RefreshTokenCookieName)
  170. if err != nil {
  171. return echo.NewHTTPError(http.StatusUnauthorized, "Failed to generate access token. Missing refresh token.")
  172. }
  173. // Parses token and checks if it's valid.
  174. refreshTokenClaims := &Claims{}
  175. refreshToken, err := jwt.ParseWithClaims(rc.Value, refreshTokenClaims, func(t *jwt.Token) (any, error) {
  176. if t.Method.Alg() != jwt.SigningMethodHS256.Name {
  177. return nil, pkgerrors.Errorf("unexpected refresh token signing method=%v, expected %v", t.Header["alg"], jwt.SigningMethodHS256)
  178. }
  179. if kid, ok := t.Header["kid"].(string); ok {
  180. if kid == "v1" {
  181. return []byte(secret), nil
  182. }
  183. }
  184. return nil, pkgerrors.Errorf("unexpected refresh token kid=%v", t.Header["kid"])
  185. })
  186. if err != nil {
  187. if err == jwt.ErrSignatureInvalid {
  188. return echo.NewHTTPError(http.StatusUnauthorized, "Failed to generate access token. Invalid refresh token signature.")
  189. }
  190. return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Server error to refresh expired token. User Id %d", userID)).SetInternal(err)
  191. }
  192. if !audienceContains(refreshTokenClaims.Audience, fmt.Sprintf(auth.RefreshTokenAudienceFmt, mode)) {
  193. return echo.NewHTTPError(http.StatusUnauthorized,
  194. fmt.Sprintf("Invalid refresh token, audience mismatch, got %q, expected %q. you may send request to the wrong environment",
  195. refreshTokenClaims.Audience,
  196. fmt.Sprintf(auth.RefreshTokenAudienceFmt, mode),
  197. ))
  198. }
  199. // If we have a valid refresh token, we will generate new access token and refresh token
  200. if refreshToken != nil && refreshToken.Valid {
  201. if err := GenerateTokensAndSetCookies(c, user, mode, secret); err != nil {
  202. return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Server error to refresh expired token. User Id %d", userID)).SetInternal(err)
  203. }
  204. }
  205. return nil
  206. }
  207. // It may happen that we still have a valid access token, but we encounter issue when trying to generate new token
  208. // In such case, we won't return the error.
  209. if err := generateTokenFunc(); err != nil && !accessToken.Valid {
  210. return err
  211. }
  212. }
  213. // Stores userID into context.
  214. c.Set(getUserIDContextKey(), userID)
  215. return next(c)
  216. }
  217. }
  218. func audienceContains(audience jwt.ClaimStrings, token string) bool {
  219. for _, v := range audience {
  220. if v == token {
  221. return true
  222. }
  223. }
  224. return false
  225. }