user_service.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. package v2
  2. import (
  3. "context"
  4. "net/http"
  5. "regexp"
  6. "strings"
  7. "time"
  8. "github.com/golang-jwt/jwt/v4"
  9. "github.com/labstack/echo/v4"
  10. "github.com/pkg/errors"
  11. "golang.org/x/crypto/bcrypt"
  12. "golang.org/x/exp/slices"
  13. "google.golang.org/grpc/codes"
  14. "google.golang.org/grpc/status"
  15. "google.golang.org/protobuf/types/known/timestamppb"
  16. "github.com/usememos/memos/api/auth"
  17. apiv2pb "github.com/usememos/memos/proto/gen/api/v2"
  18. storepb "github.com/usememos/memos/proto/gen/store"
  19. "github.com/usememos/memos/store"
  20. )
  // usernameMatcher validates usernames. A name must start with a lowercase
// ASCII letter, optionally followed by 2–30 characters from [a-z0-9-] and a
// final lowercase letter or digit. Because the optional tail is at least three
// characters long, the accepted lengths are 1 and 4–32.
var usernameMatcher = regexp.MustCompile("^[a-z]([a-z0-9-]{2,30}[a-z0-9])?$")
  24. type UserService struct {
  25. apiv2pb.UnimplementedUserServiceServer
  26. Store *store.Store
  27. Secret string
  28. }
  29. // NewUserService creates a new UserService.
  30. func NewUserService(store *store.Store, secret string) *UserService {
  31. return &UserService{
  32. Store: store,
  33. Secret: secret,
  34. }
  35. }
  36. func (s *UserService) GetUser(ctx context.Context, request *apiv2pb.GetUserRequest) (*apiv2pb.GetUserResponse, error) {
  37. user, err := s.Store.GetUser(ctx, &store.FindUser{
  38. Username: &request.Username,
  39. })
  40. if err != nil {
  41. return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
  42. }
  43. if user == nil {
  44. return nil, status.Errorf(codes.NotFound, "user not found")
  45. }
  46. userMessage := convertUserFromStore(user)
  47. response := &apiv2pb.GetUserResponse{
  48. User: userMessage,
  49. }
  50. return response, nil
  51. }
  52. func (s *UserService) UpdateUser(ctx context.Context, request *apiv2pb.UpdateUserRequest) (*apiv2pb.UpdateUserResponse, error) {
  53. currentUser, err := getCurrentUser(ctx, s.Store)
  54. if err != nil {
  55. return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
  56. }
  57. if currentUser.Username != request.Username && currentUser.Role != store.RoleAdmin {
  58. return nil, status.Errorf(codes.PermissionDenied, "permission denied")
  59. }
  60. if request.UpdateMask == nil || len(request.UpdateMask) == 0 {
  61. return nil, status.Errorf(codes.InvalidArgument, "update mask is empty")
  62. }
  63. currentTs := time.Now().Unix()
  64. update := &store.UpdateUser{
  65. ID: currentUser.ID,
  66. UpdatedTs: &currentTs,
  67. }
  68. for _, path := range request.UpdateMask {
  69. if path == "username" {
  70. if !usernameMatcher.MatchString(strings.ToLower(request.User.Username)) {
  71. return nil, status.Errorf(codes.InvalidArgument, "invalid username: %s", request.User.Username)
  72. }
  73. update.Username = &request.User.Username
  74. } else if path == "nickname" {
  75. update.Nickname = &request.User.Nickname
  76. } else if path == "email" {
  77. update.Email = &request.User.Email
  78. } else if path == "avatar_url" {
  79. update.AvatarURL = &request.User.AvatarUrl
  80. } else if path == "role" {
  81. role := convertUserRoleToStore(request.User.Role)
  82. update.Role = &role
  83. } else if path == "password" {
  84. passwordHash, err := bcrypt.GenerateFromPassword([]byte(request.User.Password), bcrypt.DefaultCost)
  85. if err != nil {
  86. return nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to generate password hash").SetInternal(err)
  87. }
  88. passwordHashStr := string(passwordHash)
  89. update.PasswordHash = &passwordHashStr
  90. } else if path == "row_status" {
  91. rowStatus := convertRowStatusToStore(request.User.RowStatus)
  92. update.RowStatus = &rowStatus
  93. } else {
  94. return nil, status.Errorf(codes.InvalidArgument, "invalid update path: %s", path)
  95. }
  96. }
  97. user, err := s.Store.UpdateUser(ctx, update)
  98. if err != nil {
  99. return nil, status.Errorf(codes.Internal, "failed to update user: %v", err)
  100. }
  101. response := &apiv2pb.UpdateUserResponse{
  102. User: convertUserFromStore(user),
  103. }
  104. return response, nil
  105. }
  106. func (s *UserService) ListUserAccessTokens(ctx context.Context, request *apiv2pb.ListUserAccessTokensRequest) (*apiv2pb.ListUserAccessTokensResponse, error) {
  107. user, err := getCurrentUser(ctx, s.Store)
  108. if err != nil {
  109. return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
  110. }
  111. if user == nil || user.Username != request.Username {
  112. return nil, status.Errorf(codes.PermissionDenied, "permission denied")
  113. }
  114. userAccessTokens, err := s.Store.GetUserAccessTokens(ctx, user.ID)
  115. if err != nil {
  116. return nil, status.Errorf(codes.Internal, "failed to list access tokens: %v", err)
  117. }
  118. accessTokens := []*apiv2pb.UserAccessToken{}
  119. for _, userAccessToken := range userAccessTokens {
  120. claims := &auth.ClaimsMessage{}
  121. _, err := jwt.ParseWithClaims(userAccessToken.AccessToken, claims, func(t *jwt.Token) (any, error) {
  122. if t.Method.Alg() != jwt.SigningMethodHS256.Name {
  123. return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)
  124. }
  125. if kid, ok := t.Header["kid"].(string); ok {
  126. if kid == "v1" {
  127. return []byte(s.Secret), nil
  128. }
  129. }
  130. return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"])
  131. })
  132. if err != nil {
  133. // If the access token is invalid or expired, just ignore it.
  134. continue
  135. }
  136. userAccessToken := &apiv2pb.UserAccessToken{
  137. AccessToken: userAccessToken.AccessToken,
  138. Description: userAccessToken.Description,
  139. IssuedAt: timestamppb.New(claims.IssuedAt.Time),
  140. }
  141. if claims.ExpiresAt != nil {
  142. userAccessToken.ExpiresAt = timestamppb.New(claims.ExpiresAt.Time)
  143. }
  144. accessTokens = append(accessTokens, userAccessToken)
  145. }
  146. // Sort by issued time in descending order.
  147. slices.SortFunc(accessTokens, func(i, j *apiv2pb.UserAccessToken) bool {
  148. return i.IssuedAt.Seconds > j.IssuedAt.Seconds
  149. })
  150. response := &apiv2pb.ListUserAccessTokensResponse{
  151. AccessTokens: accessTokens,
  152. }
  153. return response, nil
  154. }
  155. func (s *UserService) CreateUserAccessToken(ctx context.Context, request *apiv2pb.CreateUserAccessTokenRequest) (*apiv2pb.CreateUserAccessTokenResponse, error) {
  156. user, err := getCurrentUser(ctx, s.Store)
  157. if err != nil {
  158. return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
  159. }
  160. accessToken, err := auth.GenerateAccessToken(user.Username, user.ID, request.UserAccessToken.ExpiresAt.AsTime(), s.Secret)
  161. if err != nil {
  162. return nil, status.Errorf(codes.Internal, "failed to generate access token: %v", err)
  163. }
  164. claims := &auth.ClaimsMessage{}
  165. _, err = jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) {
  166. if t.Method.Alg() != jwt.SigningMethodHS256.Name {
  167. return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)
  168. }
  169. if kid, ok := t.Header["kid"].(string); ok {
  170. if kid == "v1" {
  171. return []byte(s.Secret), nil
  172. }
  173. }
  174. return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"])
  175. })
  176. if err != nil {
  177. return nil, status.Errorf(codes.Internal, "failed to parse access token: %v", err)
  178. }
  179. // Upsert the access token to user setting store.
  180. if err := s.UpsertAccessTokenToStore(ctx, user, accessToken, request.UserAccessToken.Description); err != nil {
  181. return nil, status.Errorf(codes.Internal, "failed to upsert access token to store: %v", err)
  182. }
  183. userAccessToken := &apiv2pb.UserAccessToken{
  184. AccessToken: accessToken,
  185. Description: request.UserAccessToken.Description,
  186. IssuedAt: timestamppb.New(claims.IssuedAt.Time),
  187. }
  188. if claims.ExpiresAt != nil {
  189. userAccessToken.ExpiresAt = timestamppb.New(claims.ExpiresAt.Time)
  190. }
  191. response := &apiv2pb.CreateUserAccessTokenResponse{
  192. AccessToken: userAccessToken,
  193. }
  194. return response, nil
  195. }
  196. func (s *UserService) DeleteUserAccessToken(ctx context.Context, request *apiv2pb.DeleteUserAccessTokenRequest) (*apiv2pb.DeleteUserAccessTokenResponse, error) {
  197. user, err := getCurrentUser(ctx, s.Store)
  198. if err != nil {
  199. return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
  200. }
  201. userAccessTokens, err := s.Store.GetUserAccessTokens(ctx, user.ID)
  202. if err != nil {
  203. return nil, status.Errorf(codes.Internal, "failed to list access tokens: %v", err)
  204. }
  205. updatedUserAccessTokens := []*storepb.AccessTokensUserSetting_AccessToken{}
  206. for _, userAccessToken := range userAccessTokens {
  207. if userAccessToken.AccessToken == request.AccessToken {
  208. continue
  209. }
  210. updatedUserAccessTokens = append(updatedUserAccessTokens, userAccessToken)
  211. }
  212. if _, err := s.Store.UpsertUserSettingV1(ctx, &storepb.UserSetting{
  213. UserId: user.ID,
  214. Key: storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS,
  215. Value: &storepb.UserSetting_AccessTokens{
  216. AccessTokens: &storepb.AccessTokensUserSetting{
  217. AccessTokens: updatedUserAccessTokens,
  218. },
  219. },
  220. }); err != nil {
  221. return nil, status.Errorf(codes.Internal, "failed to upsert user setting: %v", err)
  222. }
  223. return &apiv2pb.DeleteUserAccessTokenResponse{}, nil
  224. }
  225. func (s *UserService) UpsertAccessTokenToStore(ctx context.Context, user *store.User, accessToken, description string) error {
  226. userAccessTokens, err := s.Store.GetUserAccessTokens(ctx, user.ID)
  227. if err != nil {
  228. return errors.Wrap(err, "failed to get user access tokens")
  229. }
  230. userAccessToken := storepb.AccessTokensUserSetting_AccessToken{
  231. AccessToken: accessToken,
  232. Description: description,
  233. }
  234. userAccessTokens = append(userAccessTokens, &userAccessToken)
  235. if _, err := s.Store.UpsertUserSettingV1(ctx, &storepb.UserSetting{
  236. UserId: user.ID,
  237. Key: storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS,
  238. Value: &storepb.UserSetting_AccessTokens{
  239. AccessTokens: &storepb.AccessTokensUserSetting{
  240. AccessTokens: userAccessTokens,
  241. },
  242. },
  243. }); err != nil {
  244. return errors.Wrap(err, "failed to upsert user setting")
  245. }
  246. return nil
  247. }
  248. func convertUserFromStore(user *store.User) *apiv2pb.User {
  249. return &apiv2pb.User{
  250. Id: int32(user.ID),
  251. RowStatus: convertRowStatusFromStore(user.RowStatus),
  252. CreateTime: timestamppb.New(time.Unix(user.CreatedTs, 0)),
  253. UpdateTime: timestamppb.New(time.Unix(user.UpdatedTs, 0)),
  254. Username: user.Username,
  255. Role: convertUserRoleFromStore(user.Role),
  256. Email: user.Email,
  257. Nickname: user.Nickname,
  258. AvatarUrl: user.AvatarURL,
  259. }
  260. }
  261. func convertUserRoleFromStore(role store.Role) apiv2pb.User_Role {
  262. switch role {
  263. case store.RoleHost:
  264. return apiv2pb.User_HOST
  265. case store.RoleAdmin:
  266. return apiv2pb.User_ADMIN
  267. case store.RoleUser:
  268. return apiv2pb.User_USER
  269. default:
  270. return apiv2pb.User_ROLE_UNSPECIFIED
  271. }
  272. }
  273. func convertUserRoleToStore(role apiv2pb.User_Role) store.Role {
  274. switch role {
  275. case apiv2pb.User_HOST:
  276. return store.RoleHost
  277. case apiv2pb.User_ADMIN:
  278. return store.RoleAdmin
  279. case apiv2pb.User_USER:
  280. return store.RoleUser
  281. default:
  282. return store.RoleUser
  283. }
  284. }