acl.go 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. package v2
  2. import (
  3. "context"
  4. "net/http"
  5. "strings"
  6. "github.com/golang-jwt/jwt/v5"
  7. "github.com/pkg/errors"
  8. "google.golang.org/grpc"
  9. "google.golang.org/grpc/codes"
  10. "google.golang.org/grpc/metadata"
  11. "google.golang.org/grpc/status"
  12. "github.com/usememos/memos/api/auth"
  13. "github.com/usememos/memos/internal/util"
  14. storepb "github.com/usememos/memos/proto/gen/store"
  15. "github.com/usememos/memos/store"
  16. )
  // ContextKey is the key type for values this package stores in a
  // context.Context. Using a dedicated named type instead of a plain int or
  // string keeps these keys from colliding with context keys defined by other
  // packages.
  type ContextKey int
  const (
  	// usernameContextKey is the context key under which the authentication
  	// interceptor stores the authenticated user's username. The user ID is read
  	// from the JWT subject claim and resolved to a username before it is stored.
  	usernameContextKey ContextKey = iota
  )
  // GRPCAuthInterceptor is the auth interceptor for gRPC server.
  type GRPCAuthInterceptor struct {
  	// Store is used to look up users and the access tokens recorded for them.
  	Store *store.Store
  	// secret is the key handed to the JWT parser to verify HS256 access token
  	// signatures.
  	secret string
  }
  29. // NewGRPCAuthInterceptor returns a new API auth interceptor.
  30. func NewGRPCAuthInterceptor(store *store.Store, secret string) *GRPCAuthInterceptor {
  31. return &GRPCAuthInterceptor{
  32. Store: store,
  33. secret: secret,
  34. }
  35. }
  36. // AuthenticationInterceptor is the unary interceptor for gRPC API.
  37. func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, request any, serverInfo *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
  38. md, ok := metadata.FromIncomingContext(ctx)
  39. if !ok {
  40. return nil, status.Errorf(codes.Unauthenticated, "failed to parse metadata from incoming context")
  41. }
  42. accessToken, err := getTokenFromMetadata(md)
  43. if err != nil {
  44. return nil, status.Errorf(codes.Unauthenticated, err.Error())
  45. }
  46. username, err := in.authenticate(ctx, accessToken)
  47. if err != nil {
  48. if isUnauthorizeAllowedMethod(serverInfo.FullMethod) {
  49. return handler(ctx, request)
  50. }
  51. return nil, err
  52. }
  53. user, err := in.Store.GetUser(ctx, &store.FindUser{
  54. Username: &username,
  55. })
  56. if err != nil {
  57. return nil, errors.Wrap(err, "failed to get user")
  58. }
  59. if user == nil {
  60. return nil, errors.Errorf("user %q not exists", username)
  61. }
  62. if user.RowStatus == store.Archived {
  63. return nil, errors.Errorf("user %q is archived", username)
  64. }
  65. if isOnlyForAdminAllowedMethod(serverInfo.FullMethod) && user.Role != store.RoleHost && user.Role != store.RoleAdmin {
  66. return nil, errors.Errorf("user %q is not admin", username)
  67. }
  68. // Stores userID into context.
  69. childCtx := context.WithValue(ctx, usernameContextKey, username)
  70. return handler(childCtx, request)
  71. }
  72. func (in *GRPCAuthInterceptor) authenticate(ctx context.Context, accessToken string) (string, error) {
  73. if accessToken == "" {
  74. return "", status.Errorf(codes.Unauthenticated, "access token not found")
  75. }
  76. claims := &auth.ClaimsMessage{}
  77. _, err := jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) {
  78. if t.Method.Alg() != jwt.SigningMethodHS256.Name {
  79. return nil, status.Errorf(codes.Unauthenticated, "unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)
  80. }
  81. if kid, ok := t.Header["kid"].(string); ok {
  82. if kid == "v1" {
  83. return []byte(in.secret), nil
  84. }
  85. }
  86. return nil, status.Errorf(codes.Unauthenticated, "unexpected access token kid=%v", t.Header["kid"])
  87. })
  88. if err != nil {
  89. return "", status.Errorf(codes.Unauthenticated, "Invalid or expired access token")
  90. }
  91. // We either have a valid access token or we will attempt to generate new access token.
  92. userID, err := util.ConvertStringToInt32(claims.Subject)
  93. if err != nil {
  94. return "", errors.Wrap(err, "malformed ID in the token")
  95. }
  96. user, err := in.Store.GetUser(ctx, &store.FindUser{
  97. ID: &userID,
  98. })
  99. if err != nil {
  100. return "", errors.Wrap(err, "failed to get user")
  101. }
  102. if user == nil {
  103. return "", errors.Errorf("user %q not exists", userID)
  104. }
  105. if user.RowStatus == store.Archived {
  106. return "", errors.Errorf("user %q is archived", userID)
  107. }
  108. accessTokens, err := in.Store.GetUserAccessTokens(ctx, user.ID)
  109. if err != nil {
  110. return "", errors.Wrapf(err, "failed to get user access tokens")
  111. }
  112. if !validateAccessToken(accessToken, accessTokens) {
  113. return "", status.Errorf(codes.Unauthenticated, "invalid access token")
  114. }
  115. return user.Username, nil
  116. }
  117. func getTokenFromMetadata(md metadata.MD) (string, error) {
  118. // Check the HTTP request header first.
  119. authorizationHeaders := md.Get("Authorization")
  120. if len(md.Get("Authorization")) > 0 {
  121. authHeaderParts := strings.Fields(authorizationHeaders[0])
  122. if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
  123. return "", errors.New("authorization header format must be Bearer {token}")
  124. }
  125. return authHeaderParts[1], nil
  126. }
  127. // Check the cookie header.
  128. var accessToken string
  129. for _, t := range append(md.Get("grpcgateway-cookie"), md.Get("cookie")...) {
  130. header := http.Header{}
  131. header.Add("Cookie", t)
  132. request := http.Request{Header: header}
  133. if v, _ := request.Cookie(auth.AccessTokenCookieName); v != nil {
  134. accessToken = v.Value
  135. }
  136. }
  137. return accessToken, nil
  138. }
  139. func validateAccessToken(accessTokenString string, userAccessTokens []*storepb.AccessTokensUserSetting_AccessToken) bool {
  140. for _, userAccessToken := range userAccessTokens {
  141. if accessTokenString == userAccessToken.AccessToken {
  142. return true
  143. }
  144. }
  145. return false
  146. }