auth_service.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
  1. package v1
  2. import (
  3. "context"
  4. "fmt"
  5. "log/slog"
  6. "regexp"
  7. "strings"
  8. "time"
  9. "github.com/pkg/errors"
  10. "golang.org/x/crypto/bcrypt"
  11. "google.golang.org/grpc"
  12. "google.golang.org/grpc/codes"
  13. "google.golang.org/grpc/metadata"
  14. "google.golang.org/grpc/status"
  15. "google.golang.org/protobuf/types/known/emptypb"
  16. "github.com/usememos/memos/internal/util"
  17. "github.com/usememos/memos/plugin/idp"
  18. "github.com/usememos/memos/plugin/idp/oauth2"
  19. v1pb "github.com/usememos/memos/proto/gen/api/v1"
  20. storepb "github.com/usememos/memos/proto/gen/store"
  21. "github.com/usememos/memos/store"
  22. )
  // unmatchedUsernameAndPasswordError is the message SignIn returns both when no
  // user matches the requested username and when the password comparison fails,
  // so the two cases are indistinguishable to the caller.
  const unmatchedUsernameAndPasswordError = "unmatched username and password"
  26. func (s *APIV1Service) GetAuthStatus(ctx context.Context, _ *v1pb.GetAuthStatusRequest) (*v1pb.User, error) {
  27. user, err := s.GetCurrentUser(ctx)
  28. if err != nil {
  29. return nil, status.Errorf(codes.Unauthenticated, "failed to get current user: %v", err)
  30. }
  31. if user == nil {
  32. // Set the cookie header to expire access token.
  33. if err := s.clearAccessTokenCookie(ctx); err != nil {
  34. return nil, status.Errorf(codes.Internal, "failed to set grpc header: %v", err)
  35. }
  36. return nil, status.Errorf(codes.Unauthenticated, "user not found")
  37. }
  38. return convertUserFromStore(user), nil
  39. }
  40. func (s *APIV1Service) SignIn(ctx context.Context, request *v1pb.SignInRequest) (*v1pb.User, error) {
  41. user, err := s.Store.GetUser(ctx, &store.FindUser{
  42. Username: &request.Username,
  43. })
  44. if err != nil {
  45. return nil, status.Errorf(codes.Internal, "failed to get user, error: %v", err)
  46. }
  47. if user == nil {
  48. return nil, status.Errorf(codes.InvalidArgument, unmatchedUsernameAndPasswordError)
  49. }
  50. // Compare the stored hashed password, with the hashed version of the password that was received.
  51. if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(request.Password)); err != nil {
  52. return nil, status.Errorf(codes.InvalidArgument, unmatchedUsernameAndPasswordError)
  53. }
  54. workspaceGeneralSetting, err := s.Store.GetWorkspaceGeneralSetting(ctx)
  55. if err != nil {
  56. return nil, status.Errorf(codes.Internal, "failed to get workspace general setting, error: %v", err)
  57. }
  58. // Check if the password auth in is allowed.
  59. if workspaceGeneralSetting.DisallowPasswordAuth && user.Role == store.RoleUser {
  60. return nil, status.Errorf(codes.PermissionDenied, "password signin is not allowed")
  61. }
  62. if user.RowStatus == store.Archived {
  63. return nil, status.Errorf(codes.PermissionDenied, "user has been archived with username %s", request.Username)
  64. }
  65. expireTime := time.Now().Add(AccessTokenDuration)
  66. if request.NeverExpire {
  67. // Set the expire time to 100 years.
  68. expireTime = time.Now().Add(100 * 365 * 24 * time.Hour)
  69. }
  70. if err := s.doSignIn(ctx, user, expireTime); err != nil {
  71. return nil, status.Errorf(codes.Internal, "failed to sign in, error: %v", err)
  72. }
  73. return convertUserFromStore(user), nil
  74. }
  75. func (s *APIV1Service) SignInWithSSO(ctx context.Context, request *v1pb.SignInWithSSORequest) (*v1pb.User, error) {
  76. identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{
  77. ID: &request.IdpId,
  78. })
  79. if err != nil {
  80. return nil, status.Errorf(codes.Internal, "failed to get identity provider, error: %v", err)
  81. }
  82. if identityProvider == nil {
  83. return nil, status.Errorf(codes.InvalidArgument, "identity provider not found")
  84. }
  85. var userInfo *idp.IdentityProviderUserInfo
  86. if identityProvider.Type == storepb.IdentityProvider_OAUTH2 {
  87. oauth2IdentityProvider, err := oauth2.NewIdentityProvider(identityProvider.Config.GetOauth2Config())
  88. if err != nil {
  89. return nil, status.Errorf(codes.Internal, "failed to create oauth2 identity provider, error: %v", err)
  90. }
  91. token, err := oauth2IdentityProvider.ExchangeToken(ctx, request.RedirectUri, request.Code)
  92. if err != nil {
  93. return nil, status.Errorf(codes.Internal, "failed to exchange token, error: %v", err)
  94. }
  95. userInfo, err = oauth2IdentityProvider.UserInfo(token)
  96. if err != nil {
  97. return nil, status.Errorf(codes.Internal, "failed to get user info, error: %v", err)
  98. }
  99. }
  100. identifierFilter := identityProvider.IdentifierFilter
  101. if identifierFilter != "" {
  102. identifierFilterRegex, err := regexp.Compile(identifierFilter)
  103. if err != nil {
  104. return nil, status.Errorf(codes.Internal, "failed to compile identifier filter regex, error: %v", err)
  105. }
  106. if !identifierFilterRegex.MatchString(userInfo.Identifier) {
  107. return nil, status.Errorf(codes.PermissionDenied, "identifier %s is not allowed", userInfo.Identifier)
  108. }
  109. }
  110. user, err := s.Store.GetUser(ctx, &store.FindUser{
  111. Username: &userInfo.Identifier,
  112. })
  113. if err != nil {
  114. return nil, status.Errorf(codes.Internal, "failed to get user, error: %v", err)
  115. }
  116. if user == nil {
  117. userCreate := &store.User{
  118. Username: userInfo.Identifier,
  119. // The new signup user should be normal user by default.
  120. Role: store.RoleUser,
  121. Nickname: userInfo.DisplayName,
  122. Email: userInfo.Email,
  123. }
  124. password, err := util.RandomString(20)
  125. if err != nil {
  126. return nil, status.Errorf(codes.Internal, "failed to generate random password, error: %v", err)
  127. }
  128. passwordHash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
  129. if err != nil {
  130. return nil, status.Errorf(codes.Internal, "failed to generate password hash, error: %v", err)
  131. }
  132. userCreate.PasswordHash = string(passwordHash)
  133. user, err = s.Store.CreateUser(ctx, userCreate)
  134. if err != nil {
  135. return nil, status.Errorf(codes.Internal, "failed to create user, error: %v", err)
  136. }
  137. }
  138. if user.RowStatus == store.Archived {
  139. return nil, status.Errorf(codes.PermissionDenied, "user has been archived with username %s", userInfo.Identifier)
  140. }
  141. if err := s.doSignIn(ctx, user, time.Now().Add(AccessTokenDuration)); err != nil {
  142. return nil, status.Errorf(codes.Internal, "failed to sign in, error: %v", err)
  143. }
  144. return convertUserFromStore(user), nil
  145. }
  146. func (s *APIV1Service) doSignIn(ctx context.Context, user *store.User, expireTime time.Time) error {
  147. accessToken, err := GenerateAccessToken(user.Email, user.ID, expireTime, []byte(s.Secret))
  148. if err != nil {
  149. return status.Errorf(codes.Internal, "failed to generate access token, error: %v", err)
  150. }
  151. if err := s.UpsertAccessTokenToStore(ctx, user, accessToken, "user login"); err != nil {
  152. return status.Errorf(codes.Internal, "failed to upsert access token to store, error: %v", err)
  153. }
  154. cookie, err := s.buildAccessTokenCookie(ctx, accessToken, expireTime)
  155. if err != nil {
  156. return status.Errorf(codes.Internal, "failed to build access token cookie, error: %v", err)
  157. }
  158. if err := grpc.SetHeader(ctx, metadata.New(map[string]string{
  159. "Set-Cookie": cookie,
  160. })); err != nil {
  161. return status.Errorf(codes.Internal, "failed to set grpc header, error: %v", err)
  162. }
  163. return nil
  164. }
  165. func (s *APIV1Service) SignUp(ctx context.Context, request *v1pb.SignUpRequest) (*v1pb.User, error) {
  166. workspaceGeneralSetting, err := s.Store.GetWorkspaceGeneralSetting(ctx)
  167. if err != nil {
  168. return nil, status.Errorf(codes.Internal, "failed to get workspace general setting, error: %v", err)
  169. }
  170. if workspaceGeneralSetting.DisallowUserRegistration {
  171. return nil, status.Errorf(codes.PermissionDenied, "sign up is not allowed")
  172. }
  173. passwordHash, err := bcrypt.GenerateFromPassword([]byte(request.Password), bcrypt.DefaultCost)
  174. if err != nil {
  175. return nil, status.Errorf(codes.Internal, "failed to generate password hash, error: %v", err)
  176. }
  177. create := &store.User{
  178. Username: request.Username,
  179. Nickname: request.Username,
  180. PasswordHash: string(passwordHash),
  181. }
  182. if !util.UIDMatcher.MatchString(strings.ToLower(create.Username)) {
  183. return nil, status.Errorf(codes.InvalidArgument, "invalid username: %s", create.Username)
  184. }
  185. hostUserType := store.RoleHost
  186. existedHostUsers, err := s.Store.ListUsers(ctx, &store.FindUser{
  187. Role: &hostUserType,
  188. })
  189. if err != nil {
  190. return nil, status.Errorf(codes.Internal, "failed to list host users, error: %v", err)
  191. }
  192. if len(existedHostUsers) == 0 {
  193. // Change the default role to host if there is no host user.
  194. create.Role = store.RoleHost
  195. } else {
  196. create.Role = store.RoleUser
  197. }
  198. user, err := s.Store.CreateUser(ctx, create)
  199. if err != nil {
  200. return nil, status.Errorf(codes.Internal, "failed to create user, error: %v", err)
  201. }
  202. if err := s.doSignIn(ctx, user, time.Now().Add(AccessTokenDuration)); err != nil {
  203. return nil, status.Errorf(codes.Internal, "failed to sign in, error: %v", err)
  204. }
  205. return convertUserFromStore(user), nil
  206. }
  207. func (s *APIV1Service) SignOut(ctx context.Context, _ *v1pb.SignOutRequest) (*emptypb.Empty, error) {
  208. accessToken, ok := ctx.Value(accessTokenContextKey).(string)
  209. // Try to delete the access token from the store.
  210. if ok {
  211. user, _ := s.GetCurrentUser(ctx)
  212. if user != nil {
  213. if _, err := s.DeleteUserAccessToken(ctx, &v1pb.DeleteUserAccessTokenRequest{
  214. Name: fmt.Sprintf("%s%d", UserNamePrefix, user.ID),
  215. AccessToken: accessToken,
  216. }); err != nil {
  217. slog.Error("failed to delete access token", "error", err)
  218. }
  219. }
  220. }
  221. if err := s.clearAccessTokenCookie(ctx); err != nil {
  222. return nil, status.Errorf(codes.Internal, "failed to set grpc header, error: %v", err)
  223. }
  224. return &emptypb.Empty{}, nil
  225. }
  226. func (s *APIV1Service) clearAccessTokenCookie(ctx context.Context) error {
  227. cookie, err := s.buildAccessTokenCookie(ctx, "", time.Time{})
  228. if err != nil {
  229. return errors.Wrap(err, "failed to build access token cookie")
  230. }
  231. if err := grpc.SetHeader(ctx, metadata.New(map[string]string{
  232. "Set-Cookie": cookie,
  233. })); err != nil {
  234. return errors.Wrap(err, "failed to set grpc header")
  235. }
  236. return nil
  237. }
  238. func (*APIV1Service) buildAccessTokenCookie(ctx context.Context, accessToken string, expireTime time.Time) (string, error) {
  239. attrs := []string{
  240. fmt.Sprintf("%s=%s", AccessTokenCookieName, accessToken),
  241. "Path=/",
  242. "HttpOnly",
  243. }
  244. if expireTime.IsZero() {
  245. attrs = append(attrs, "Expires=Thu, 01 Jan 1970 00:00:00 GMT")
  246. } else {
  247. attrs = append(attrs, "Expires="+expireTime.Format(time.RFC1123))
  248. }
  249. md, ok := metadata.FromIncomingContext(ctx)
  250. if !ok {
  251. return "", errors.New("failed to get metadata from context")
  252. }
  253. var origin string
  254. for _, v := range md.Get("origin") {
  255. origin = v
  256. }
  257. isHTTPS := strings.HasPrefix(origin, "https://")
  258. if isHTTPS {
  259. attrs = append(attrs, "SameSite=None")
  260. attrs = append(attrs, "Secure")
  261. } else {
  262. attrs = append(attrs, "SameSite=Strict")
  263. }
  264. return strings.Join(attrs, "; "), nil
  265. }
  266. func (s *APIV1Service) GetCurrentUser(ctx context.Context) (*store.User, error) {
  267. username, ok := ctx.Value(usernameContextKey).(string)
  268. if !ok {
  269. return nil, nil
  270. }
  271. user, err := s.Store.GetUser(ctx, &store.FindUser{
  272. Username: &username,
  273. })
  274. if err != nil {
  275. return nil, err
  276. }
  277. return user, nil
  278. }