user_service.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379
  1. package v2
  2. import (
  3. "context"
  4. "fmt"
  5. "net/http"
  6. "regexp"
  7. "strings"
  8. "time"
  9. "github.com/golang-jwt/jwt/v4"
  10. "github.com/labstack/echo/v4"
  11. "github.com/pkg/errors"
  12. "golang.org/x/crypto/bcrypt"
  13. "golang.org/x/exp/slices"
  14. "google.golang.org/grpc/codes"
  15. "google.golang.org/grpc/status"
  16. "google.golang.org/protobuf/types/known/timestamppb"
  17. "github.com/usememos/memos/api/auth"
  18. apiv2pb "github.com/usememos/memos/proto/gen/api/v2"
  19. storepb "github.com/usememos/memos/proto/gen/store"
  20. "github.com/usememos/memos/store"
  21. )
  // usernameMatcher accepts usernames of 3 to 32 characters drawn from lowercase
  // ASCII letters, digits and '-', where neither the first nor the last
  // character may be '-'.
  var usernameMatcher = regexp.MustCompile("^[a-z0-9]([a-z0-9-]{1,30}[a-z0-9])$")
  25. func (s *APIV2Service) GetUser(ctx context.Context, request *apiv2pb.GetUserRequest) (*apiv2pb.GetUserResponse, error) {
  26. username, err := ExtractUsernameFromName(request.Name)
  27. if err != nil {
  28. return nil, status.Errorf(codes.InvalidArgument, "name is required")
  29. }
  30. user, err := s.Store.GetUser(ctx, &store.FindUser{
  31. Username: &username,
  32. })
  33. if err != nil {
  34. return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
  35. }
  36. if user == nil {
  37. return nil, status.Errorf(codes.NotFound, "user not found")
  38. }
  39. userMessage := convertUserFromStore(user)
  40. response := &apiv2pb.GetUserResponse{
  41. User: userMessage,
  42. }
  43. return response, nil
  44. }
  45. func (s *APIV2Service) CreateUser(ctx context.Context, request *apiv2pb.CreateUserRequest) (*apiv2pb.CreateUserResponse, error) {
  46. currentUser, err := getCurrentUser(ctx, s.Store)
  47. if err != nil {
  48. return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
  49. }
  50. if currentUser.Role != store.RoleHost {
  51. return nil, status.Errorf(codes.PermissionDenied, "permission denied")
  52. }
  53. username, err := ExtractUsernameFromName(request.User.Name)
  54. if err != nil {
  55. return nil, status.Errorf(codes.InvalidArgument, "name is required")
  56. }
  57. if !usernameMatcher.MatchString(strings.ToLower(username)) {
  58. return nil, status.Errorf(codes.InvalidArgument, "invalid username: %s", username)
  59. }
  60. passwordHash, err := bcrypt.GenerateFromPassword([]byte(request.User.Password), bcrypt.DefaultCost)
  61. if err != nil {
  62. return nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to generate password hash").SetInternal(err)
  63. }
  64. user, err := s.Store.CreateUser(ctx, &store.User{
  65. Username: username,
  66. Role: convertUserRoleToStore(request.User.Role),
  67. Email: request.User.Email,
  68. Nickname: request.User.Nickname,
  69. PasswordHash: string(passwordHash),
  70. })
  71. if err != nil {
  72. return nil, status.Errorf(codes.Internal, "failed to create user: %v", err)
  73. }
  74. response := &apiv2pb.CreateUserResponse{
  75. User: convertUserFromStore(user),
  76. }
  77. return response, nil
  78. }
  79. func (s *APIV2Service) UpdateUser(ctx context.Context, request *apiv2pb.UpdateUserRequest) (*apiv2pb.UpdateUserResponse, error) {
  80. username, err := ExtractUsernameFromName(request.User.Name)
  81. if err != nil {
  82. return nil, status.Errorf(codes.InvalidArgument, "name is required")
  83. }
  84. currentUser, err := getCurrentUser(ctx, s.Store)
  85. if err != nil {
  86. return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
  87. }
  88. if currentUser.Username != username && currentUser.Role != store.RoleAdmin && currentUser.Role != store.RoleHost {
  89. return nil, status.Errorf(codes.PermissionDenied, "permission denied")
  90. }
  91. if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
  92. return nil, status.Errorf(codes.InvalidArgument, "update mask is empty")
  93. }
  94. user, err := s.Store.GetUser(ctx, &store.FindUser{Username: &username})
  95. if err != nil {
  96. return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
  97. }
  98. if user == nil {
  99. return nil, status.Errorf(codes.NotFound, "user not found")
  100. }
  101. currentTs := time.Now().Unix()
  102. update := &store.UpdateUser{
  103. ID: user.ID,
  104. UpdatedTs: &currentTs,
  105. }
  106. for _, field := range request.UpdateMask.Paths {
  107. if field == "username" {
  108. if !usernameMatcher.MatchString(strings.ToLower(username)) {
  109. return nil, status.Errorf(codes.InvalidArgument, "invalid username: %s", username)
  110. }
  111. update.Username = &username
  112. } else if field == "nickname" {
  113. update.Nickname = &request.User.Nickname
  114. } else if field == "email" {
  115. update.Email = &request.User.Email
  116. } else if field == "avatar_url" {
  117. update.AvatarURL = &request.User.AvatarUrl
  118. } else if field == "role" {
  119. role := convertUserRoleToStore(request.User.Role)
  120. update.Role = &role
  121. } else if field == "password" {
  122. passwordHash, err := bcrypt.GenerateFromPassword([]byte(request.User.Password), bcrypt.DefaultCost)
  123. if err != nil {
  124. return nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to generate password hash").SetInternal(err)
  125. }
  126. passwordHashStr := string(passwordHash)
  127. update.PasswordHash = &passwordHashStr
  128. } else if field == "row_status" {
  129. rowStatus := convertRowStatusToStore(request.User.RowStatus)
  130. update.RowStatus = &rowStatus
  131. } else {
  132. return nil, status.Errorf(codes.InvalidArgument, "invalid update path: %s", field)
  133. }
  134. }
  135. updatedUser, err := s.Store.UpdateUser(ctx, update)
  136. if err != nil {
  137. return nil, status.Errorf(codes.Internal, "failed to update user: %v", err)
  138. }
  139. response := &apiv2pb.UpdateUserResponse{
  140. User: convertUserFromStore(updatedUser),
  141. }
  142. return response, nil
  143. }
  144. func (s *APIV2Service) ListUserAccessTokens(ctx context.Context, request *apiv2pb.ListUserAccessTokensRequest) (*apiv2pb.ListUserAccessTokensResponse, error) {
  145. user, err := getCurrentUser(ctx, s.Store)
  146. if err != nil {
  147. return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
  148. }
  149. if user == nil {
  150. return nil, status.Errorf(codes.PermissionDenied, "permission denied")
  151. }
  152. userID := user.ID
  153. username, err := ExtractUsernameFromName(request.Name)
  154. if err != nil {
  155. return nil, status.Errorf(codes.InvalidArgument, "name is required")
  156. }
  157. // List access token for other users need to be verified.
  158. if user.Username != username {
  159. // Normal users can only list their access tokens.
  160. if user.Role == store.RoleUser {
  161. return nil, status.Errorf(codes.PermissionDenied, "permission denied")
  162. }
  163. // The request user must be exist.
  164. requestUser, err := s.Store.GetUser(ctx, &store.FindUser{Username: &username})
  165. if requestUser == nil || err != nil {
  166. return nil, status.Errorf(codes.NotFound, "fail to find user %s", username)
  167. }
  168. userID = requestUser.ID
  169. }
  170. userAccessTokens, err := s.Store.GetUserAccessTokens(ctx, userID)
  171. if err != nil {
  172. return nil, status.Errorf(codes.Internal, "failed to list access tokens: %v", err)
  173. }
  174. accessTokens := []*apiv2pb.UserAccessToken{}
  175. for _, userAccessToken := range userAccessTokens {
  176. claims := &auth.ClaimsMessage{}
  177. _, err := jwt.ParseWithClaims(userAccessToken.AccessToken, claims, func(t *jwt.Token) (any, error) {
  178. if t.Method.Alg() != jwt.SigningMethodHS256.Name {
  179. return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)
  180. }
  181. if kid, ok := t.Header["kid"].(string); ok {
  182. if kid == "v1" {
  183. return []byte(s.Secret), nil
  184. }
  185. }
  186. return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"])
  187. })
  188. if err != nil {
  189. // If the access token is invalid or expired, just ignore it.
  190. continue
  191. }
  192. userAccessToken := &apiv2pb.UserAccessToken{
  193. AccessToken: userAccessToken.AccessToken,
  194. Description: userAccessToken.Description,
  195. IssuedAt: timestamppb.New(claims.IssuedAt.Time),
  196. }
  197. if claims.ExpiresAt != nil {
  198. userAccessToken.ExpiresAt = timestamppb.New(claims.ExpiresAt.Time)
  199. }
  200. accessTokens = append(accessTokens, userAccessToken)
  201. }
  202. // Sort by issued time in descending order.
  203. slices.SortFunc(accessTokens, func(i, j *apiv2pb.UserAccessToken) int {
  204. return int(i.IssuedAt.Seconds - j.IssuedAt.Seconds)
  205. })
  206. response := &apiv2pb.ListUserAccessTokensResponse{
  207. AccessTokens: accessTokens,
  208. }
  209. return response, nil
  210. }
  211. func (s *APIV2Service) CreateUserAccessToken(ctx context.Context, request *apiv2pb.CreateUserAccessTokenRequest) (*apiv2pb.CreateUserAccessTokenResponse, error) {
  212. user, err := getCurrentUser(ctx, s.Store)
  213. if err != nil {
  214. return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
  215. }
  216. expiresAt := time.Time{}
  217. if request.ExpiresAt != nil {
  218. expiresAt = request.ExpiresAt.AsTime()
  219. }
  220. accessToken, err := auth.GenerateAccessToken(user.Username, user.ID, expiresAt, []byte(s.Secret))
  221. if err != nil {
  222. return nil, status.Errorf(codes.Internal, "failed to generate access token: %v", err)
  223. }
  224. claims := &auth.ClaimsMessage{}
  225. _, err = jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) {
  226. if t.Method.Alg() != jwt.SigningMethodHS256.Name {
  227. return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)
  228. }
  229. if kid, ok := t.Header["kid"].(string); ok {
  230. if kid == "v1" {
  231. return []byte(s.Secret), nil
  232. }
  233. }
  234. return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"])
  235. })
  236. if err != nil {
  237. return nil, status.Errorf(codes.Internal, "failed to parse access token: %v", err)
  238. }
  239. // Upsert the access token to user setting store.
  240. if err := s.UpsertAccessTokenToStore(ctx, user, accessToken, request.Description); err != nil {
  241. return nil, status.Errorf(codes.Internal, "failed to upsert access token to store: %v", err)
  242. }
  243. userAccessToken := &apiv2pb.UserAccessToken{
  244. AccessToken: accessToken,
  245. Description: request.Description,
  246. IssuedAt: timestamppb.New(claims.IssuedAt.Time),
  247. }
  248. if claims.ExpiresAt != nil {
  249. userAccessToken.ExpiresAt = timestamppb.New(claims.ExpiresAt.Time)
  250. }
  251. response := &apiv2pb.CreateUserAccessTokenResponse{
  252. AccessToken: userAccessToken,
  253. }
  254. return response, nil
  255. }
  256. func (s *APIV2Service) DeleteUserAccessToken(ctx context.Context, request *apiv2pb.DeleteUserAccessTokenRequest) (*apiv2pb.DeleteUserAccessTokenResponse, error) {
  257. user, err := getCurrentUser(ctx, s.Store)
  258. if err != nil {
  259. return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
  260. }
  261. userAccessTokens, err := s.Store.GetUserAccessTokens(ctx, user.ID)
  262. if err != nil {
  263. return nil, status.Errorf(codes.Internal, "failed to list access tokens: %v", err)
  264. }
  265. updatedUserAccessTokens := []*storepb.AccessTokensUserSetting_AccessToken{}
  266. for _, userAccessToken := range userAccessTokens {
  267. if userAccessToken.AccessToken == request.AccessToken {
  268. continue
  269. }
  270. updatedUserAccessTokens = append(updatedUserAccessTokens, userAccessToken)
  271. }
  272. if _, err := s.Store.UpsertUserSettingV1(ctx, &storepb.UserSetting{
  273. UserId: user.ID,
  274. Key: storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS,
  275. Value: &storepb.UserSetting_AccessTokens{
  276. AccessTokens: &storepb.AccessTokensUserSetting{
  277. AccessTokens: updatedUserAccessTokens,
  278. },
  279. },
  280. }); err != nil {
  281. return nil, status.Errorf(codes.Internal, "failed to upsert user setting: %v", err)
  282. }
  283. return &apiv2pb.DeleteUserAccessTokenResponse{}, nil
  284. }
  285. func (s *APIV2Service) UpsertAccessTokenToStore(ctx context.Context, user *store.User, accessToken, description string) error {
  286. userAccessTokens, err := s.Store.GetUserAccessTokens(ctx, user.ID)
  287. if err != nil {
  288. return errors.Wrap(err, "failed to get user access tokens")
  289. }
  290. userAccessToken := storepb.AccessTokensUserSetting_AccessToken{
  291. AccessToken: accessToken,
  292. Description: description,
  293. }
  294. userAccessTokens = append(userAccessTokens, &userAccessToken)
  295. if _, err := s.Store.UpsertUserSettingV1(ctx, &storepb.UserSetting{
  296. UserId: user.ID,
  297. Key: storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS,
  298. Value: &storepb.UserSetting_AccessTokens{
  299. AccessTokens: &storepb.AccessTokensUserSetting{
  300. AccessTokens: userAccessTokens,
  301. },
  302. },
  303. }); err != nil {
  304. return errors.Wrap(err, "failed to upsert user setting")
  305. }
  306. return nil
  307. }
  308. func convertUserFromStore(user *store.User) *apiv2pb.User {
  309. return &apiv2pb.User{
  310. Name: fmt.Sprintf("%s%s", UserNamePrefix, user.Username),
  311. Id: user.ID,
  312. RowStatus: convertRowStatusFromStore(user.RowStatus),
  313. CreateTime: timestamppb.New(time.Unix(user.CreatedTs, 0)),
  314. UpdateTime: timestamppb.New(time.Unix(user.UpdatedTs, 0)),
  315. Role: convertUserRoleFromStore(user.Role),
  316. Email: user.Email,
  317. Nickname: user.Nickname,
  318. AvatarUrl: user.AvatarURL,
  319. }
  320. }
  321. func convertUserRoleFromStore(role store.Role) apiv2pb.User_Role {
  322. switch role {
  323. case store.RoleHost:
  324. return apiv2pb.User_HOST
  325. case store.RoleAdmin:
  326. return apiv2pb.User_ADMIN
  327. case store.RoleUser:
  328. return apiv2pb.User_USER
  329. default:
  330. return apiv2pb.User_ROLE_UNSPECIFIED
  331. }
  332. }
  333. func convertUserRoleToStore(role apiv2pb.User_Role) store.Role {
  334. switch role {
  335. case apiv2pb.User_HOST:
  336. return store.RoleHost
  337. case apiv2pb.User_ADMIN:
  338. return store.RoleAdmin
  339. case apiv2pb.User_USER:
  340. return store.RoleUser
  341. default:
  342. return store.RoleUser
  343. }
  344. }