auth_service.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289
  1. package v1
  2. import (
  3. "context"
  4. "fmt"
  5. "log/slog"
  6. "regexp"
  7. "strings"
  8. "time"
  9. "github.com/pkg/errors"
  10. "golang.org/x/crypto/bcrypt"
  11. "google.golang.org/grpc"
  12. "google.golang.org/grpc/codes"
  13. "google.golang.org/grpc/metadata"
  14. "google.golang.org/grpc/status"
  15. "google.golang.org/protobuf/types/known/emptypb"
  16. "github.com/usememos/memos/internal/util"
  17. "github.com/usememos/memos/plugin/idp"
  18. "github.com/usememos/memos/plugin/idp/oauth2"
  19. v1pb "github.com/usememos/memos/proto/gen/api/v1"
  20. storepb "github.com/usememos/memos/proto/gen/store"
  21. "github.com/usememos/memos/store"
  22. )
  23. func (s *APIV1Service) GetAuthStatus(ctx context.Context, _ *v1pb.GetAuthStatusRequest) (*v1pb.User, error) {
  24. user, err := s.GetCurrentUser(ctx)
  25. if err != nil {
  26. return nil, status.Errorf(codes.Unauthenticated, "failed to get current user: %v", err)
  27. }
  28. if user == nil {
  29. // Set the cookie header to expire access token.
  30. if err := s.clearAccessTokenCookie(ctx); err != nil {
  31. return nil, status.Errorf(codes.Internal, "failed to set grpc header: %v", err)
  32. }
  33. return nil, status.Errorf(codes.Unauthenticated, "user not found")
  34. }
  35. return convertUserFromStore(user), nil
  36. }
  37. func (s *APIV1Service) SignIn(ctx context.Context, request *v1pb.SignInRequest) (*v1pb.User, error) {
  38. user, err := s.Store.GetUser(ctx, &store.FindUser{
  39. Username: &request.Username,
  40. })
  41. if err != nil {
  42. return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to find user by username %s", request.Username))
  43. }
  44. if user == nil {
  45. return nil, status.Errorf(codes.InvalidArgument, fmt.Sprintf("user not found with username %s", request.Username))
  46. } else if user.RowStatus == store.Archived {
  47. return nil, status.Errorf(codes.PermissionDenied, fmt.Sprintf("user has been archived with username %s", request.Username))
  48. }
  49. // Compare the stored hashed password, with the hashed version of the password that was received.
  50. if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(request.Password)); err != nil {
  51. return nil, status.Errorf(codes.InvalidArgument, "unmatched email and password")
  52. }
  53. expireTime := time.Now().Add(AccessTokenDuration)
  54. if request.NeverExpire {
  55. // Set the expire time to 100 years.
  56. expireTime = time.Now().Add(100 * 365 * 24 * time.Hour)
  57. }
  58. if err := s.doSignIn(ctx, user, expireTime); err != nil {
  59. return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to sign in, err: %s", err))
  60. }
  61. return convertUserFromStore(user), nil
  62. }
  63. func (s *APIV1Service) SignInWithSSO(ctx context.Context, request *v1pb.SignInWithSSORequest) (*v1pb.User, error) {
  64. identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{
  65. ID: &request.IdpId,
  66. })
  67. if err != nil {
  68. return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to get identity provider, err: %s", err))
  69. }
  70. if identityProvider == nil {
  71. return nil, status.Errorf(codes.InvalidArgument, fmt.Sprintf("identity provider not found with id %d", request.IdpId))
  72. }
  73. var userInfo *idp.IdentityProviderUserInfo
  74. if identityProvider.Type == storepb.IdentityProvider_OAUTH2 {
  75. oauth2IdentityProvider, err := oauth2.NewIdentityProvider(identityProvider.Config.GetOauth2Config())
  76. if err != nil {
  77. return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to create oauth2 identity provider, err: %s", err))
  78. }
  79. token, err := oauth2IdentityProvider.ExchangeToken(ctx, request.RedirectUri, request.Code)
  80. if err != nil {
  81. return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to exchange token, err: %s", err))
  82. }
  83. userInfo, err = oauth2IdentityProvider.UserInfo(token)
  84. if err != nil {
  85. return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to get user info, err: %s", err))
  86. }
  87. }
  88. identifierFilter := identityProvider.IdentifierFilter
  89. if identifierFilter != "" {
  90. identifierFilterRegex, err := regexp.Compile(identifierFilter)
  91. if err != nil {
  92. return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to compile identifier filter regex, err: %s", err))
  93. }
  94. if !identifierFilterRegex.MatchString(userInfo.Identifier) {
  95. return nil, status.Errorf(codes.PermissionDenied, fmt.Sprintf("identifier %s is not allowed", userInfo.Identifier))
  96. }
  97. }
  98. user, err := s.Store.GetUser(ctx, &store.FindUser{
  99. Username: &userInfo.Identifier,
  100. })
  101. if err != nil {
  102. return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to find user by username %s", userInfo.Identifier))
  103. }
  104. if user == nil {
  105. userCreate := &store.User{
  106. Username: userInfo.Identifier,
  107. // The new signup user should be normal user by default.
  108. Role: store.RoleUser,
  109. Nickname: userInfo.DisplayName,
  110. Email: userInfo.Email,
  111. }
  112. password, err := util.RandomString(20)
  113. if err != nil {
  114. return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to generate random password, err: %s", err))
  115. }
  116. passwordHash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
  117. if err != nil {
  118. return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to generate password hash, err: %s", err))
  119. }
  120. userCreate.PasswordHash = string(passwordHash)
  121. user, err = s.Store.CreateUser(ctx, userCreate)
  122. if err != nil {
  123. return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to create user, err: %s", err))
  124. }
  125. }
  126. if user.RowStatus == store.Archived {
  127. return nil, status.Errorf(codes.PermissionDenied, fmt.Sprintf("user has been archived with username %s", userInfo.Identifier))
  128. }
  129. if err := s.doSignIn(ctx, user, time.Now().Add(AccessTokenDuration)); err != nil {
  130. return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to sign in, err: %s", err))
  131. }
  132. return convertUserFromStore(user), nil
  133. }
  134. func (s *APIV1Service) doSignIn(ctx context.Context, user *store.User, expireTime time.Time) error {
  135. accessToken, err := GenerateAccessToken(user.Email, user.ID, expireTime, []byte(s.Secret))
  136. if err != nil {
  137. return status.Errorf(codes.Internal, fmt.Sprintf("failed to generate tokens, err: %s", err))
  138. }
  139. if err := s.UpsertAccessTokenToStore(ctx, user, accessToken, "user login"); err != nil {
  140. return status.Errorf(codes.Internal, fmt.Sprintf("failed to upsert access token to store, err: %s", err))
  141. }
  142. cookie, err := s.buildAccessTokenCookie(ctx, accessToken, expireTime)
  143. if err != nil {
  144. return status.Errorf(codes.Internal, fmt.Sprintf("failed to build access token cookie, err: %s", err))
  145. }
  146. if err := grpc.SetHeader(ctx, metadata.New(map[string]string{
  147. "Set-Cookie": cookie,
  148. })); err != nil {
  149. return status.Errorf(codes.Internal, "failed to set grpc header, error: %v", err)
  150. }
  151. return nil
  152. }
  153. func (s *APIV1Service) SignUp(ctx context.Context, request *v1pb.SignUpRequest) (*v1pb.User, error) {
  154. if !s.Profile.Public {
  155. return nil, status.Errorf(codes.PermissionDenied, "sign up is not allowed")
  156. }
  157. passwordHash, err := bcrypt.GenerateFromPassword([]byte(request.Password), bcrypt.DefaultCost)
  158. if err != nil {
  159. return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to generate password hash, err: %s", err))
  160. }
  161. create := &store.User{
  162. Username: request.Username,
  163. Nickname: request.Username,
  164. PasswordHash: string(passwordHash),
  165. }
  166. if !util.UIDMatcher.MatchString(strings.ToLower(create.Username)) {
  167. return nil, status.Errorf(codes.InvalidArgument, "invalid username: %s", create.Username)
  168. }
  169. hostUserType := store.RoleHost
  170. existedHostUsers, err := s.Store.ListUsers(ctx, &store.FindUser{
  171. Role: &hostUserType,
  172. })
  173. if err != nil {
  174. return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to list users, err: %s", err))
  175. }
  176. if len(existedHostUsers) == 0 {
  177. // Change the default role to host if there is no host user.
  178. create.Role = store.RoleHost
  179. } else {
  180. create.Role = store.RoleUser
  181. }
  182. user, err := s.Store.CreateUser(ctx, create)
  183. if err != nil {
  184. return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to create user, err: %s", err))
  185. }
  186. if err := s.doSignIn(ctx, user, time.Now().Add(AccessTokenDuration)); err != nil {
  187. return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to sign in, err: %s", err))
  188. }
  189. return convertUserFromStore(user), nil
  190. }
  191. func (s *APIV1Service) SignOut(ctx context.Context, _ *v1pb.SignOutRequest) (*emptypb.Empty, error) {
  192. accessToken, ok := ctx.Value(accessTokenContextKey).(string)
  193. // Try to delete the access token from the store.
  194. if ok {
  195. user, _ := s.GetCurrentUser(ctx)
  196. if user != nil {
  197. if _, err := s.DeleteUserAccessToken(ctx, &v1pb.DeleteUserAccessTokenRequest{
  198. Name: fmt.Sprintf("%s%d", UserNamePrefix, user.ID),
  199. AccessToken: accessToken,
  200. }); err != nil {
  201. slog.Error("failed to delete access token", slog.Any("err", err))
  202. }
  203. }
  204. }
  205. if err := s.clearAccessTokenCookie(ctx); err != nil {
  206. return nil, status.Errorf(codes.Internal, "failed to set grpc header, error: %v", err)
  207. }
  208. return &emptypb.Empty{}, nil
  209. }
  210. func (s *APIV1Service) clearAccessTokenCookie(ctx context.Context) error {
  211. cookie, err := s.buildAccessTokenCookie(ctx, "", time.Time{})
  212. if err != nil {
  213. return errors.Wrap(err, "failed to build access token cookie")
  214. }
  215. if err := grpc.SetHeader(ctx, metadata.New(map[string]string{
  216. "Set-Cookie": cookie,
  217. })); err != nil {
  218. return errors.Wrap(err, "failed to set grpc header")
  219. }
  220. return nil
  221. }
  222. func (*APIV1Service) buildAccessTokenCookie(ctx context.Context, accessToken string, expireTime time.Time) (string, error) {
  223. attrs := []string{
  224. fmt.Sprintf("%s=%s", AccessTokenCookieName, accessToken),
  225. "Path=/",
  226. "HttpOnly",
  227. }
  228. if expireTime.IsZero() {
  229. attrs = append(attrs, "Expires=Thu, 01 Jan 1970 00:00:00 GMT")
  230. } else {
  231. attrs = append(attrs, "Expires="+expireTime.Format(time.RFC1123))
  232. }
  233. md, ok := metadata.FromIncomingContext(ctx)
  234. if !ok {
  235. return "", errors.New("failed to get metadata from context")
  236. }
  237. var origin string
  238. for _, v := range md.Get("origin") {
  239. origin = v
  240. }
  241. isHTTPS := strings.HasPrefix(origin, "https://")
  242. if isHTTPS {
  243. attrs = append(attrs, "SameSite=None")
  244. attrs = append(attrs, "Secure")
  245. } else {
  246. attrs = append(attrs, "SameSite=Strict")
  247. }
  248. return strings.Join(attrs, "; "), nil
  249. }
  250. func (s *APIV1Service) GetCurrentUser(ctx context.Context) (*store.User, error) {
  251. username, ok := ctx.Value(usernameContextKey).(string)
  252. if !ok {
  253. return nil, nil
  254. }
  255. user, err := s.Store.GetUser(ctx, &store.FindUser{
  256. Username: &username,
  257. })
  258. if err != nil {
  259. return nil, err
  260. }
  261. return user, nil
  262. }