123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162 |
- package v2
- import (
- "context"
- "net/http"
- "strings"
- "github.com/golang-jwt/jwt/v5"
- "github.com/pkg/errors"
- "google.golang.org/grpc"
- "google.golang.org/grpc/codes"
- "google.golang.org/grpc/metadata"
- "google.golang.org/grpc/status"
- "github.com/usememos/memos/api/auth"
- "github.com/usememos/memos/internal/util"
- storepb "github.com/usememos/memos/proto/gen/store"
- "github.com/usememos/memos/store"
- )
- // ContextKey is the key type of context value.
- type ContextKey int
- const (
- // The key name used to store username in the context
- // user id is extracted from the jwt token subject field.
- usernameContextKey ContextKey = iota
- )
- // GRPCAuthInterceptor is the auth interceptor for gRPC server.
- type GRPCAuthInterceptor struct {
- Store *store.Store
- secret string
- }
- // NewGRPCAuthInterceptor returns a new API auth interceptor.
- func NewGRPCAuthInterceptor(store *store.Store, secret string) *GRPCAuthInterceptor {
- return &GRPCAuthInterceptor{
- Store: store,
- secret: secret,
- }
- }
- // AuthenticationInterceptor is the unary interceptor for gRPC API.
- func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, request any, serverInfo *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
- md, ok := metadata.FromIncomingContext(ctx)
- if !ok {
- return nil, status.Errorf(codes.Unauthenticated, "failed to parse metadata from incoming context")
- }
- accessToken, err := getTokenFromMetadata(md)
- if err != nil {
- return nil, status.Errorf(codes.Unauthenticated, err.Error())
- }
- username, err := in.authenticate(ctx, accessToken)
- if err != nil {
- if isUnauthorizeAllowedMethod(serverInfo.FullMethod) {
- return handler(ctx, request)
- }
- return nil, err
- }
- user, err := in.Store.GetUser(ctx, &store.FindUser{
- Username: &username,
- })
- if err != nil {
- return nil, errors.Wrap(err, "failed to get user")
- }
- if user == nil {
- return nil, errors.Errorf("user %q not exists", username)
- }
- if user.RowStatus == store.Archived {
- return nil, errors.Errorf("user %q is archived", username)
- }
- if isOnlyForAdminAllowedMethod(serverInfo.FullMethod) && user.Role != store.RoleHost && user.Role != store.RoleAdmin {
- return nil, errors.Errorf("user %q is not admin", username)
- }
- // Stores userID into context.
- childCtx := context.WithValue(ctx, usernameContextKey, username)
- return handler(childCtx, request)
- }
- func (in *GRPCAuthInterceptor) authenticate(ctx context.Context, accessToken string) (string, error) {
- if accessToken == "" {
- return "", status.Errorf(codes.Unauthenticated, "access token not found")
- }
- claims := &auth.ClaimsMessage{}
- _, err := jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) {
- if t.Method.Alg() != jwt.SigningMethodHS256.Name {
- return nil, status.Errorf(codes.Unauthenticated, "unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)
- }
- if kid, ok := t.Header["kid"].(string); ok {
- if kid == "v1" {
- return []byte(in.secret), nil
- }
- }
- return nil, status.Errorf(codes.Unauthenticated, "unexpected access token kid=%v", t.Header["kid"])
- })
- if err != nil {
- return "", status.Errorf(codes.Unauthenticated, "Invalid or expired access token")
- }
- // We either have a valid access token or we will attempt to generate new access token.
- userID, err := util.ConvertStringToInt32(claims.Subject)
- if err != nil {
- return "", errors.Wrap(err, "malformed ID in the token")
- }
- user, err := in.Store.GetUser(ctx, &store.FindUser{
- ID: &userID,
- })
- if err != nil {
- return "", errors.Wrap(err, "failed to get user")
- }
- if user == nil {
- return "", errors.Errorf("user %q not exists", userID)
- }
- if user.RowStatus == store.Archived {
- return "", errors.Errorf("user %q is archived", userID)
- }
- accessTokens, err := in.Store.GetUserAccessTokens(ctx, user.ID)
- if err != nil {
- return "", errors.Wrapf(err, "failed to get user access tokens")
- }
- if !validateAccessToken(accessToken, accessTokens) {
- return "", status.Errorf(codes.Unauthenticated, "invalid access token")
- }
- return user.Username, nil
- }
- func getTokenFromMetadata(md metadata.MD) (string, error) {
- // Check the HTTP request header first.
- authorizationHeaders := md.Get("Authorization")
- if len(md.Get("Authorization")) > 0 {
- authHeaderParts := strings.Fields(authorizationHeaders[0])
- if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
- return "", errors.New("authorization header format must be Bearer {token}")
- }
- return authHeaderParts[1], nil
- }
- // Check the cookie header.
- var accessToken string
- for _, t := range append(md.Get("grpcgateway-cookie"), md.Get("cookie")...) {
- header := http.Header{}
- header.Add("Cookie", t)
- request := http.Request{Header: header}
- if v, _ := request.Cookie(auth.AccessTokenCookieName); v != nil {
- accessToken = v.Value
- }
- }
- return accessToken, nil
- }
- func validateAccessToken(accessTokenString string, userAccessTokens []*storepb.AccessTokensUserSetting_AccessToken) bool {
- for _, userAccessToken := range userAccessTokens {
- if accessTokenString == userAccessToken.AccessToken {
- return true
- }
- }
- return false
- }
|