acl.go 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. package v2
  2. import (
  3. "context"
  4. "net/http"
  5. "strings"
  6. "github.com/golang-jwt/jwt/v4"
  7. "github.com/pkg/errors"
  8. "google.golang.org/grpc"
  9. "google.golang.org/grpc/codes"
  10. "google.golang.org/grpc/metadata"
  11. "google.golang.org/grpc/status"
  12. "github.com/usememos/memos/api/auth"
  13. "github.com/usememos/memos/common/util"
  14. storepb "github.com/usememos/memos/proto/gen/store"
  15. "github.com/usememos/memos/store"
  16. )
  17. // ContextKey is the key type of context value.
  18. type ContextKey int
  19. const (
  20. // The key name used to store username in the context
  21. // user id is extracted from the jwt token subject field.
  22. usernameContextKey ContextKey = iota
  23. )
// GRPCAuthInterceptor is the auth interceptor for gRPC server.
// It verifies JWT access tokens with secret and resolves the caller through Store.
type GRPCAuthInterceptor struct {
	// Store is used to look up users and the access tokens stored for them.
	Store *store.Store
	// secret is the HMAC key used to verify HS256-signed access tokens (kid "v1").
	secret string
}
  29. // NewGRPCAuthInterceptor returns a new API auth interceptor.
  30. func NewGRPCAuthInterceptor(store *store.Store, secret string) *GRPCAuthInterceptor {
  31. return &GRPCAuthInterceptor{
  32. Store: store,
  33. secret: secret,
  34. }
  35. }
  36. // AuthenticationInterceptor is the unary interceptor for gRPC API.
  37. func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, request any, serverInfo *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
  38. md, ok := metadata.FromIncomingContext(ctx)
  39. if !ok {
  40. return nil, status.Errorf(codes.Unauthenticated, "failed to parse metadata from incoming context")
  41. }
  42. accessToken, err := getTokenFromMetadata(md)
  43. if err != nil {
  44. return nil, status.Errorf(codes.Unauthenticated, err.Error())
  45. }
  46. username, err := in.authenticate(ctx, accessToken)
  47. if err != nil {
  48. if isUnauthorizeAllowedMethod(serverInfo.FullMethod) {
  49. return handler(ctx, request)
  50. }
  51. return nil, err
  52. }
  53. user, err := in.Store.GetUser(ctx, &store.FindUser{
  54. Username: &username,
  55. })
  56. if err != nil {
  57. return nil, errors.Wrap(err, "failed to get user")
  58. }
  59. if user == nil {
  60. return nil, errors.Errorf("user %q not exists", username)
  61. }
  62. if isOnlyForAdminAllowedMethod(serverInfo.FullMethod) && user.Role != store.RoleHost && user.Role != store.RoleAdmin {
  63. return nil, errors.Errorf("user %q is not admin", username)
  64. }
  65. // Stores userID into context.
  66. childCtx := context.WithValue(ctx, usernameContextKey, username)
  67. return handler(childCtx, request)
  68. }
  69. func (in *GRPCAuthInterceptor) authenticate(ctx context.Context, accessToken string) (string, error) {
  70. if accessToken == "" {
  71. return "", status.Errorf(codes.Unauthenticated, "access token not found")
  72. }
  73. claims := &auth.ClaimsMessage{}
  74. _, err := jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) {
  75. if t.Method.Alg() != jwt.SigningMethodHS256.Name {
  76. return nil, status.Errorf(codes.Unauthenticated, "unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)
  77. }
  78. if kid, ok := t.Header["kid"].(string); ok {
  79. if kid == "v1" {
  80. return []byte(in.secret), nil
  81. }
  82. }
  83. return nil, status.Errorf(codes.Unauthenticated, "unexpected access token kid=%v", t.Header["kid"])
  84. })
  85. if err != nil {
  86. return "", status.Errorf(codes.Unauthenticated, "Invalid or expired access token")
  87. }
  88. if !audienceContains(claims.Audience, auth.AccessTokenAudienceName) {
  89. return "", status.Errorf(codes.Unauthenticated,
  90. "invalid access token, audience mismatch, got %q, expected %q. you may send request to the wrong environment",
  91. claims.Audience,
  92. auth.AccessTokenAudienceName,
  93. )
  94. }
  95. // We either have a valid access token or we will attempt to generate new access token.
  96. userID, err := util.ConvertStringToInt32(claims.Subject)
  97. if err != nil {
  98. return "", errors.Wrap(err, "malformed ID in the token")
  99. }
  100. user, err := in.Store.GetUser(ctx, &store.FindUser{
  101. ID: &userID,
  102. })
  103. if err != nil {
  104. return "", errors.Wrap(err, "failed to get user")
  105. }
  106. if user == nil {
  107. return "", errors.Errorf("user %q not exists", userID)
  108. }
  109. if user.RowStatus == store.Archived {
  110. return "", errors.Errorf("user %q is archived", userID)
  111. }
  112. accessTokens, err := in.Store.GetUserAccessTokens(ctx, user.ID)
  113. if err != nil {
  114. return "", errors.Wrapf(err, "failed to get user access tokens")
  115. }
  116. if !validateAccessToken(accessToken, accessTokens) {
  117. return "", status.Errorf(codes.Unauthenticated, "invalid access token")
  118. }
  119. return user.Username, nil
  120. }
  121. func getTokenFromMetadata(md metadata.MD) (string, error) {
  122. authorizationHeaders := md.Get("Authorization")
  123. if len(md.Get("Authorization")) > 0 {
  124. authHeaderParts := strings.Fields(authorizationHeaders[0])
  125. if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
  126. return "", errors.Errorf("authorization header format must be Bearer {token}")
  127. }
  128. return authHeaderParts[1], nil
  129. }
  130. // check the HTTP cookie
  131. var accessToken string
  132. for _, t := range append(md.Get("grpcgateway-cookie"), md.Get("cookie")...) {
  133. header := http.Header{}
  134. header.Add("Cookie", t)
  135. request := http.Request{Header: header}
  136. if v, _ := request.Cookie(auth.AccessTokenCookieName); v != nil {
  137. accessToken = v.Value
  138. }
  139. }
  140. return accessToken, nil
  141. }
  142. func audienceContains(audience jwt.ClaimStrings, token string) bool {
  143. for _, v := range audience {
  144. if v == token {
  145. return true
  146. }
  147. }
  148. return false
  149. }
  150. func validateAccessToken(accessTokenString string, userAccessTokens []*storepb.AccessTokensUserSetting_AccessToken) bool {
  151. for _, userAccessToken := range userAccessTokens {
  152. if accessTokenString == userAccessToken.AccessToken {
  153. return true
  154. }
  155. }
  156. return false
  157. }