jwt.go 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. package server
  2. import (
  3. "fmt"
  4. "net/http"
  5. "strconv"
  6. "strings"
  7. "time"
  8. "github.com/golang-jwt/jwt/v4"
  9. "github.com/labstack/echo/v4"
  10. "github.com/pkg/errors"
  11. "github.com/usememos/memos/common"
  12. "github.com/usememos/memos/server/auth"
  13. "github.com/usememos/memos/store"
  14. )
  // userIDContextKey is the echo context key under which JWTMiddleware stores
// the authenticated user's ID. That ID is parsed from the JWT subject ("sub")
// claim of the access token.
const userIDContextKey = "user-id"
  21. func getUserIDContextKey() string {
  22. return userIDContextKey
  23. }
  24. // Claims creates a struct that will be encoded to a JWT.
  25. // We add jwt.RegisteredClaims as an embedded type, to provide fields such as name.
  26. type Claims struct {
  27. Name string `json:"name"`
  28. jwt.RegisteredClaims
  29. }
  // extractTokenFromHeader returns the token from an "Authorization: Bearer <token>"
// request header.
//
// A missing (empty) header is not an error and yields "". A header that does
// not split into exactly two whitespace-separated fields whose first field is
// "bearer" (case-insensitive) yields "" and an error.
func extractTokenFromHeader(c echo.Context) (string, error) {
	header := c.Request().Header.Get("Authorization")
	if header == "" {
		return "", nil
	}

	fields := strings.Fields(header)
	if len(fields) == 2 && strings.ToLower(fields[0]) == "bearer" {
		return fields[1], nil
	}
	return "", errors.New("Authorization header format must be Bearer {token}")
}
  41. func findAccessToken(c echo.Context) string {
  42. accessToken := ""
  43. cookie, _ := c.Cookie(auth.AccessTokenCookieName)
  44. if cookie != nil {
  45. accessToken = cookie.Value
  46. }
  47. if accessToken == "" {
  48. accessToken, _ = extractTokenFromHeader(c)
  49. }
  50. return accessToken
  51. }
  52. func audienceContains(audience jwt.ClaimStrings, token string) bool {
  53. for _, v := range audience {
  54. if v == token {
  55. return true
  56. }
  57. }
  58. return false
  59. }
  // JWTMiddleware returns an echo handler that authenticates the request with a
// JWT access token before delegating to next.
//
// The access token is read by findAccessToken: the access-token cookie first,
// then the "Authorization: Bearer <token>" header. Requests without a token are
// only let through for a few public GET endpoints (and paths under "/o").
//
// A token is accepted when all of the following hold:
//   - it is signed with HS256 using secret under kid "v1";
//   - it carries the access-token audience and an expiry;
//   - its subject is the numeric ID of an existing user.
//
// If the access token is about to expire (within auth.RefreshThresholdDuration)
// or has expired, and the request has a valid refresh token for the same user,
// new access and refresh tokens are generated and set as cookies. An expired
// access token whose refresh fails is rejected. A still-valid one is accepted
// even if the refresh fails.
//
// On success the user ID is stored in the context under getUserIDContextKey().
func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.HandlerFunc {
	// keyFunc builds the jwt.Keyfunc shared by access- and refresh-token parsing.
	// tokenKind ("access"/"refresh") only affects the error messages.
	keyFunc := func(tokenKind string) jwt.Keyfunc {
		return func(t *jwt.Token) (any, error) {
			if t.Method.Alg() != jwt.SigningMethodHS256.Name {
				return nil, errors.Errorf("unexpected %s token signing method=%v, expected %v", tokenKind, t.Header["alg"], jwt.SigningMethodHS256.Name)
			}
			if kid, ok := t.Header["kid"].(string); ok && kid == "v1" {
				return []byte(secret), nil
			}
			return nil, errors.Errorf("unexpected %s token kid=%v", tokenKind, t.Header["kid"])
		}
	}
	accessKeyFunc := keyFunc("access")
	refreshKeyFunc := keyFunc("refresh")

	return func(c echo.Context) error {
		path := c.Request().URL.Path
		method := c.Request().Method

		if server.defaultAuthSkipper(c) {
			return next(c)
		}

		// Skip validation for server status endpoints.
		// NOTE(review): "/api/user/:id" looks like an echo route pattern. As a
		// literal path prefix it only matches paths that literally contain ":id",
		// so it never matches a request such as "/api/user/1". It is kept as-is to
		// avoid widening unauthenticated access; confirm the intent.
		if common.HasPrefixes(path, "/api/ping", "/api/v1/idp", "/api/user/:id") && method == http.MethodGet {
			return next(c)
		}

		token := findAccessToken(c)
		if token == "" {
			// Allow the user to access the public endpoints.
			if common.HasPrefixes(path, "/o") {
				return next(c)
			}
			// When the request is not authenticated, we allow the user to access the memo endpoints for those public memos.
			if common.HasPrefixes(path, "/api/status", "/api/memo") && method == http.MethodGet {
				return next(c)
			}
			return echo.NewHTTPError(http.StatusUnauthorized, "Missing access token")
		}

		claims := &Claims{}
		// Only the error is inspected: the returned *jwt.Token can be nil on error
		// (e.g. malformed input), so dereferencing it here would panic.
		_, err := jwt.ParseWithClaims(token, claims, accessKeyFunc)

		// The only tolerated failure is "expired and nothing else". jwt reports
		// signature and other claim failures with their own ValidationError flags,
		// so that case means the token is otherwise valid and may be refreshed
		// below. Every other failure is rejected outright.
		expired := false
		if err != nil {
			var ve *jwt.ValidationError
			if !errors.As(err, &ve) || ve.Errors != jwt.ValidationErrorExpired {
				return echo.NewHTTPError(http.StatusUnauthorized, "Invalid access token.").SetInternal(err)
			}
			expired = true
		}

		if !audienceContains(claims.Audience, auth.AccessTokenAudienceName) {
			return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("Invalid access token, audience mismatch, got %q, expected %q.", claims.Audience, auth.AccessTokenAudienceName))
		}

		// ExpiresAt is a pointer that is dereferenced below, and a token without
		// "exp" would never expire, so such tokens are rejected.
		if claims.ExpiresAt == nil {
			return echo.NewHTTPError(http.StatusUnauthorized, "Invalid access token, missing expiration.")
		}
		generateToken := expired || time.Until(claims.ExpiresAt.Time) < auth.RefreshThresholdDuration

		// We either have a valid access token or we will attempt to generate new access token and refresh token
		ctx := c.Request().Context()
		userID, err := strconv.Atoi(claims.Subject)
		if err != nil {
			return echo.NewHTTPError(http.StatusUnauthorized, "Malformed ID in the token.")
		}

		// Even if there is no error, we still need to make sure the user still exists.
		user, err := server.Store.GetUser(ctx, &store.FindUserMessage{
			ID: &userID,
		})
		if err != nil {
			return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Server error to find user ID: %d", userID)).SetInternal(err)
		}
		if user == nil {
			return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("Failed to find user ID: %d", userID))
		}

		if generateToken {
			generateTokenFunc := func() error {
				rc, err := c.Cookie(auth.RefreshTokenCookieName)
				if err != nil {
					return echo.NewHTTPError(http.StatusUnauthorized, "Failed to generate access token. Missing refresh token.")
				}

				// Parses the refresh token and checks that it is valid.
				refreshTokenClaims := &Claims{}
				if _, err := jwt.ParseWithClaims(rc.Value, refreshTokenClaims, refreshKeyFunc); err != nil {
					// The parser wraps failures in *jwt.ValidationError, so a plain ==
					// comparison against the sentinel never matches; use errors.Is.
					if errors.Is(err, jwt.ErrSignatureInvalid) {
						return echo.NewHTTPError(http.StatusUnauthorized, "Failed to generate access token. Invalid refresh token signature.")
					}
					// Expired, malformed or otherwise unusable refresh tokens are a client
					// authentication failure, not a server error.
					return echo.NewHTTPError(http.StatusUnauthorized, "Failed to generate access token. Invalid refresh token.").SetInternal(err)
				}
				if !audienceContains(refreshTokenClaims.Audience, auth.RefreshTokenAudienceName) {
					return echo.NewHTTPError(http.StatusUnauthorized,
						fmt.Sprintf("Invalid refresh token, audience mismatch, got %q, expected %q. you may send request to the wrong environment",
							refreshTokenClaims.Audience,
							auth.RefreshTokenAudienceName,
						))
				}

				// Only refresh on behalf of the user the access token belongs to.
				// Otherwise a stale access token for one user combined with another
				// user's refresh token would mint fresh tokens for the first user.
				refreshUserID, err := strconv.Atoi(refreshTokenClaims.Subject)
				if err != nil || refreshUserID != userID {
					return echo.NewHTTPError(http.StatusUnauthorized, "Failed to generate access token. Refresh token does not match the access token user.")
				}

				// We have a valid refresh token, so generate new access and refresh tokens.
				if err := auth.GenerateTokensAndSetCookies(c, user, secret); err != nil {
					return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Server error to refresh expired token. User Id %d", userID)).SetInternal(err)
				}
				return nil
			}

			// A still-valid access token keeps the request authenticated even if
			// refreshing fails; an expired one requires a successful refresh.
			if err := generateTokenFunc(); err != nil && expired {
				return err
			}
		}

		// Stores userID into context.
		c.Set(getUserIDContextKey(), userID)
		return next(c)
	}
}