auth.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293
  1. package v1
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "net/http"
  6. "regexp"
  7. "github.com/labstack/echo/v4"
  8. "github.com/pkg/errors"
  9. "github.com/usememos/memos/common/util"
  10. "github.com/usememos/memos/plugin/idp"
  11. "github.com/usememos/memos/plugin/idp/oauth2"
  12. "github.com/usememos/memos/store"
  13. "golang.org/x/crypto/bcrypt"
  14. )
  // SignIn is the JSON request body accepted by POST /auth/signin.
type SignIn struct {
	// Username identifies the account to sign in to.
	Username string `json:"username"`
	// Password is the plaintext password; the handler compares it against the
	// user's stored bcrypt hash.
	Password string `json:"password"`
}
  // SSOSignIn is the JSON request body accepted by POST /auth/signin/sso.
type SSOSignIn struct {
	// IdentityProviderID selects the stored identity provider to sign in with.
	IdentityProviderID int32 `json:"identityProviderId"`
	// Code is the authorization code the handler exchanges with the provider
	// for a token.
	Code string `json:"code"`
	// RedirectURI is passed to the provider together with Code during the
	// token exchange.
	RedirectURI string `json:"redirectUri"`
}
  // SignUp is the JSON request body accepted by POST /auth/signup.
type SignUp struct {
	// Username of the new account; the handler also uses it as the initial
	// nickname.
	Username string `json:"username"`
	// Password is the plaintext password; only its bcrypt hash is persisted.
	Password string `json:"password"`
}
  28. func (s *APIV1Service) registerAuthRoutes(g *echo.Group) {
  29. // POST /auth/signin - Sign in.
  30. g.POST("/auth/signin", func(c echo.Context) error {
  31. ctx := c.Request().Context()
  32. signin := &SignIn{}
  33. disablePasswordLoginSystemSetting, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{
  34. Name: SystemSettingDisablePasswordLoginName.String(),
  35. })
  36. if err != nil {
  37. return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find system setting").SetInternal(err)
  38. }
  39. if disablePasswordLoginSystemSetting != nil {
  40. disablePasswordLogin := false
  41. err = json.Unmarshal([]byte(disablePasswordLoginSystemSetting.Value), &disablePasswordLogin)
  42. if err != nil {
  43. return echo.NewHTTPError(http.StatusInternalServerError, "Failed to unmarshal system setting").SetInternal(err)
  44. }
  45. if disablePasswordLogin {
  46. return echo.NewHTTPError(http.StatusUnauthorized, "Password login is deactivated")
  47. }
  48. }
  49. if err := json.NewDecoder(c.Request().Body).Decode(signin); err != nil {
  50. return echo.NewHTTPError(http.StatusBadRequest, "Malformatted signin request").SetInternal(err)
  51. }
  52. user, err := s.Store.GetUser(ctx, &store.FindUser{
  53. Username: &signin.Username,
  54. })
  55. if err != nil {
  56. return echo.NewHTTPError(http.StatusInternalServerError, "Incorrect login credentials, please try again")
  57. }
  58. if user == nil {
  59. return echo.NewHTTPError(http.StatusUnauthorized, "Incorrect login credentials, please try again")
  60. } else if user.RowStatus == store.Archived {
  61. return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with username %s", signin.Username))
  62. }
  63. // Compare the stored hashed password, with the hashed version of the password that was received.
  64. if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(signin.Password)); err != nil {
  65. // If the two passwords don't match, return a 401 status.
  66. return echo.NewHTTPError(http.StatusUnauthorized, "Incorrect login credentials, please try again")
  67. }
  68. if err := GenerateTokensAndSetCookies(c, user, s.Secret); err != nil {
  69. return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate tokens").SetInternal(err)
  70. }
  71. if err := s.createAuthSignInActivity(c, user); err != nil {
  72. return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create activity").SetInternal(err)
  73. }
  74. userMessage := convertUserFromStore(user)
  75. return c.JSON(http.StatusOK, userMessage)
  76. })
  77. // POST /auth/signin/sso - Sign in with SSO
  78. g.POST("/auth/signin/sso", func(c echo.Context) error {
  79. ctx := c.Request().Context()
  80. signin := &SSOSignIn{}
  81. if err := json.NewDecoder(c.Request().Body).Decode(signin); err != nil {
  82. return echo.NewHTTPError(http.StatusBadRequest, "Malformatted signin request").SetInternal(err)
  83. }
  84. identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{
  85. ID: &signin.IdentityProviderID,
  86. })
  87. if err != nil {
  88. return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find identity provider").SetInternal(err)
  89. }
  90. if identityProvider == nil {
  91. return echo.NewHTTPError(http.StatusNotFound, "Identity provider not found")
  92. }
  93. var userInfo *idp.IdentityProviderUserInfo
  94. if identityProvider.Type == store.IdentityProviderOAuth2Type {
  95. oauth2IdentityProvider, err := oauth2.NewIdentityProvider(identityProvider.Config.OAuth2Config)
  96. if err != nil {
  97. return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create identity provider instance").SetInternal(err)
  98. }
  99. token, err := oauth2IdentityProvider.ExchangeToken(ctx, signin.RedirectURI, signin.Code)
  100. if err != nil {
  101. return echo.NewHTTPError(http.StatusInternalServerError, "Failed to exchange token").SetInternal(err)
  102. }
  103. userInfo, err = oauth2IdentityProvider.UserInfo(token)
  104. if err != nil {
  105. return echo.NewHTTPError(http.StatusInternalServerError, "Failed to get user info").SetInternal(err)
  106. }
  107. }
  108. identifierFilter := identityProvider.IdentifierFilter
  109. if identifierFilter != "" {
  110. identifierFilterRegex, err := regexp.Compile(identifierFilter)
  111. if err != nil {
  112. return echo.NewHTTPError(http.StatusInternalServerError, "Failed to compile identifier filter").SetInternal(err)
  113. }
  114. if !identifierFilterRegex.MatchString(userInfo.Identifier) {
  115. return echo.NewHTTPError(http.StatusUnauthorized, "Access denied, identifier does not match the filter.").SetInternal(err)
  116. }
  117. }
  118. user, err := s.Store.GetUser(ctx, &store.FindUser{
  119. Username: &userInfo.Identifier,
  120. })
  121. if err != nil {
  122. return echo.NewHTTPError(http.StatusInternalServerError, "Incorrect login credentials, please try again")
  123. }
  124. if user == nil {
  125. userCreate := &store.User{
  126. Username: userInfo.Identifier,
  127. // The new signup user should be normal user by default.
  128. Role: store.RoleUser,
  129. Nickname: userInfo.DisplayName,
  130. Email: userInfo.Email,
  131. OpenID: util.GenUUID(),
  132. }
  133. password, err := util.RandomString(20)
  134. if err != nil {
  135. return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate random password").SetInternal(err)
  136. }
  137. passwordHash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
  138. if err != nil {
  139. return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate password hash").SetInternal(err)
  140. }
  141. userCreate.PasswordHash = string(passwordHash)
  142. user, err = s.Store.CreateUser(ctx, userCreate)
  143. if err != nil {
  144. return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create user").SetInternal(err)
  145. }
  146. }
  147. if user.RowStatus == store.Archived {
  148. return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with username %s", userInfo.Identifier))
  149. }
  150. if err := GenerateTokensAndSetCookies(c, user, s.Secret); err != nil {
  151. return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate tokens").SetInternal(err)
  152. }
  153. if err := s.createAuthSignInActivity(c, user); err != nil {
  154. return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create activity").SetInternal(err)
  155. }
  156. userMessage := convertUserFromStore(user)
  157. return c.JSON(http.StatusOK, userMessage)
  158. })
  159. // POST /auth/signup - Sign up a new user.
  160. g.POST("/auth/signup", func(c echo.Context) error {
  161. ctx := c.Request().Context()
  162. signup := &SignUp{}
  163. if err := json.NewDecoder(c.Request().Body).Decode(signup); err != nil {
  164. return echo.NewHTTPError(http.StatusBadRequest, "Malformatted signup request").SetInternal(err)
  165. }
  166. hostUserType := store.RoleHost
  167. existedHostUsers, err := s.Store.ListUsers(ctx, &store.FindUser{
  168. Role: &hostUserType,
  169. })
  170. if err != nil {
  171. return echo.NewHTTPError(http.StatusBadRequest, "Failed to find users").SetInternal(err)
  172. }
  173. userCreate := &store.User{
  174. Username: signup.Username,
  175. // The new signup user should be normal user by default.
  176. Role: store.RoleUser,
  177. Nickname: signup.Username,
  178. OpenID: util.GenUUID(),
  179. }
  180. if len(existedHostUsers) == 0 {
  181. // Change the default role to host if there is no host user.
  182. userCreate.Role = store.RoleHost
  183. } else {
  184. allowSignUpSetting, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{
  185. Name: SystemSettingAllowSignUpName.String(),
  186. })
  187. if err != nil {
  188. return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find system setting").SetInternal(err)
  189. }
  190. allowSignUpSettingValue := false
  191. if allowSignUpSetting != nil {
  192. err = json.Unmarshal([]byte(allowSignUpSetting.Value), &allowSignUpSettingValue)
  193. if err != nil {
  194. return echo.NewHTTPError(http.StatusInternalServerError, "Failed to unmarshal system setting allow signup").SetInternal(err)
  195. }
  196. }
  197. if !allowSignUpSettingValue {
  198. return echo.NewHTTPError(http.StatusUnauthorized, "signup is disabled").SetInternal(err)
  199. }
  200. }
  201. passwordHash, err := bcrypt.GenerateFromPassword([]byte(signup.Password), bcrypt.DefaultCost)
  202. if err != nil {
  203. return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate password hash").SetInternal(err)
  204. }
  205. userCreate.PasswordHash = string(passwordHash)
  206. user, err := s.Store.CreateUser(ctx, userCreate)
  207. if err != nil {
  208. return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create user").SetInternal(err)
  209. }
  210. if err := GenerateTokensAndSetCookies(c, user, s.Secret); err != nil {
  211. return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate tokens").SetInternal(err)
  212. }
  213. if err := s.createAuthSignUpActivity(c, user); err != nil {
  214. return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create activity").SetInternal(err)
  215. }
  216. userMessage := convertUserFromStore(user)
  217. return c.JSON(http.StatusOK, userMessage)
  218. })
  219. // POST /auth/signout - Sign out.
  220. g.POST("/auth/signout", func(c echo.Context) error {
  221. RemoveTokensAndCookies(c)
  222. return c.JSON(http.StatusOK, true)
  223. })
  224. }
  225. func (s *APIV1Service) createAuthSignInActivity(c echo.Context, user *store.User) error {
  226. ctx := c.Request().Context()
  227. payload := ActivityUserAuthSignInPayload{
  228. UserID: user.ID,
  229. IP: echo.ExtractIPFromRealIPHeader()(c.Request()),
  230. }
  231. payloadBytes, err := json.Marshal(payload)
  232. if err != nil {
  233. return errors.Wrap(err, "failed to marshal activity payload")
  234. }
  235. activity, err := s.Store.CreateActivity(ctx, &store.Activity{
  236. CreatorID: user.ID,
  237. Type: string(ActivityUserAuthSignIn),
  238. Level: string(ActivityInfo),
  239. Payload: string(payloadBytes),
  240. })
  241. if err != nil || activity == nil {
  242. return errors.Wrap(err, "failed to create activity")
  243. }
  244. return err
  245. }
  246. func (s *APIV1Service) createAuthSignUpActivity(c echo.Context, user *store.User) error {
  247. ctx := c.Request().Context()
  248. payload := ActivityUserAuthSignUpPayload{
  249. Username: user.Username,
  250. IP: echo.ExtractIPFromRealIPHeader()(c.Request()),
  251. }
  252. payloadBytes, err := json.Marshal(payload)
  253. if err != nil {
  254. return errors.Wrap(err, "failed to marshal activity payload")
  255. }
  256. activity, err := s.Store.CreateActivity(ctx, &store.Activity{
  257. CreatorID: user.ID,
  258. Type: string(ActivityUserAuthSignUp),
  259. Level: string(ActivityInfo),
  260. Payload: string(payloadBytes),
  261. })
  262. if err != nil || activity == nil {
  263. return errors.Wrap(err, "failed to create activity")
  264. }
  265. return err
  266. }