idp.go 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. package server
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "net/http"
  6. "strconv"
  7. "github.com/labstack/echo/v4"
  8. "github.com/usememos/memos/api"
  9. "github.com/usememos/memos/common"
  10. "github.com/usememos/memos/store"
  11. )
  12. func (s *Server) registerIdentityProviderRoutes(g *echo.Group) {
  13. g.POST("/idp", func(c echo.Context) error {
  14. ctx := c.Request().Context()
  15. userID, ok := c.Get(getUserIDContextKey()).(int)
  16. if !ok {
  17. return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
  18. }
  19. user, err := s.Store.FindUser(ctx, &api.UserFind{
  20. ID: &userID,
  21. })
  22. if err != nil {
  23. return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
  24. }
  25. if user == nil || user.Role != api.Host {
  26. return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
  27. }
  28. identityProviderCreate := &api.IdentityProviderCreate{}
  29. if err := json.NewDecoder(c.Request().Body).Decode(identityProviderCreate); err != nil {
  30. return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post identity provider request").SetInternal(err)
  31. }
  32. identityProviderMessage, err := s.Store.CreateIdentityProvider(ctx, &store.IdentityProviderMessage{
  33. Name: identityProviderCreate.Name,
  34. Type: store.IdentityProviderType(identityProviderCreate.Type),
  35. IdentifierFilter: identityProviderCreate.IdentifierFilter,
  36. Config: convertIdentityProviderConfigToStore(identityProviderCreate.Config),
  37. })
  38. if err != nil {
  39. return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create identity provider").SetInternal(err)
  40. }
  41. return c.JSON(http.StatusOK, composeResponse(convertIdentityProviderFromStore(identityProviderMessage)))
  42. })
  43. g.PATCH("/idp/:idpId", func(c echo.Context) error {
  44. ctx := c.Request().Context()
  45. userID, ok := c.Get(getUserIDContextKey()).(int)
  46. if !ok {
  47. return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
  48. }
  49. user, err := s.Store.FindUser(ctx, &api.UserFind{
  50. ID: &userID,
  51. })
  52. if err != nil {
  53. return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
  54. }
  55. if user == nil || user.Role != api.Host {
  56. return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
  57. }
  58. identityProviderID, err := strconv.Atoi(c.Param("idpId"))
  59. if err != nil {
  60. return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("idpId"))).SetInternal(err)
  61. }
  62. identityProviderPatch := &api.IdentityProviderPatch{
  63. ID: identityProviderID,
  64. }
  65. if err := json.NewDecoder(c.Request().Body).Decode(identityProviderPatch); err != nil {
  66. return echo.NewHTTPError(http.StatusBadRequest, "Malformatted patch identity provider request").SetInternal(err)
  67. }
  68. identityProviderMessage, err := s.Store.UpdateIdentityProvider(ctx, &store.UpdateIdentityProviderMessage{
  69. ID: identityProviderPatch.ID,
  70. Type: store.IdentityProviderType(identityProviderPatch.Type),
  71. Name: identityProviderPatch.Name,
  72. IdentifierFilter: identityProviderPatch.IdentifierFilter,
  73. Config: convertIdentityProviderConfigToStore(identityProviderPatch.Config),
  74. })
  75. if err != nil {
  76. return echo.NewHTTPError(http.StatusInternalServerError, "Failed to patch identity provider").SetInternal(err)
  77. }
  78. return c.JSON(http.StatusOK, composeResponse(convertIdentityProviderFromStore(identityProviderMessage)))
  79. })
  80. g.GET("/idp", func(c echo.Context) error {
  81. ctx := c.Request().Context()
  82. identityProviderMessageList, err := s.Store.ListIdentityProviders(ctx, &store.FindIdentityProviderMessage{})
  83. if err != nil {
  84. return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find identity provider list").SetInternal(err)
  85. }
  86. userID, ok := c.Get(getUserIDContextKey()).(int)
  87. isHostUser := false
  88. if ok {
  89. user, err := s.Store.FindUser(ctx, &api.UserFind{
  90. ID: &userID,
  91. })
  92. if err != nil && common.ErrorCode(err) != common.NotFound {
  93. return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
  94. }
  95. if user != nil && user.Role == api.Host {
  96. isHostUser = true
  97. }
  98. }
  99. identityProviderList := []*api.IdentityProvider{}
  100. for _, identityProviderMessage := range identityProviderMessageList {
  101. identityProvider := convertIdentityProviderFromStore(identityProviderMessage)
  102. // data desensitize
  103. if !isHostUser {
  104. identityProvider.Config.OAuth2Config.ClientSecret = ""
  105. }
  106. identityProviderList = append(identityProviderList, identityProvider)
  107. }
  108. return c.JSON(http.StatusOK, composeResponse(identityProviderList))
  109. })
  110. g.GET("/idp/:idpId", func(c echo.Context) error {
  111. ctx := c.Request().Context()
  112. userID, ok := c.Get(getUserIDContextKey()).(int)
  113. if !ok {
  114. return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
  115. }
  116. user, err := s.Store.FindUser(ctx, &api.UserFind{
  117. ID: &userID,
  118. })
  119. if err != nil {
  120. return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
  121. }
  122. // We should only show identity provider list to host user.
  123. if user == nil || user.Role != api.Host {
  124. return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
  125. }
  126. identityProviderID, err := strconv.Atoi(c.Param("idpId"))
  127. if err != nil {
  128. return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("idpId"))).SetInternal(err)
  129. }
  130. identityProviderMessage, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProviderMessage{
  131. ID: &identityProviderID,
  132. })
  133. if err != nil {
  134. return echo.NewHTTPError(http.StatusInternalServerError, "Failed to get identity provider").SetInternal(err)
  135. }
  136. return c.JSON(http.StatusOK, composeResponse(convertIdentityProviderFromStore(identityProviderMessage)))
  137. })
  138. g.DELETE("/idp/:idpId", func(c echo.Context) error {
  139. ctx := c.Request().Context()
  140. userID, ok := c.Get(getUserIDContextKey()).(int)
  141. if !ok {
  142. return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
  143. }
  144. user, err := s.Store.FindUser(ctx, &api.UserFind{
  145. ID: &userID,
  146. })
  147. if err != nil {
  148. return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
  149. }
  150. if user == nil || user.Role != api.Host {
  151. return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
  152. }
  153. identityProviderID, err := strconv.Atoi(c.Param("idpId"))
  154. if err != nil {
  155. return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("idpId"))).SetInternal(err)
  156. }
  157. if err = s.Store.DeleteIdentityProvider(ctx, &store.DeleteIdentityProviderMessage{ID: identityProviderID}); err != nil {
  158. if common.ErrorCode(err) == common.NotFound {
  159. return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Identity provider ID not found: %d", identityProviderID))
  160. }
  161. return echo.NewHTTPError(http.StatusInternalServerError, "Failed to delete identity provider").SetInternal(err)
  162. }
  163. return c.JSON(http.StatusOK, true)
  164. })
  165. }
  166. func convertIdentityProviderFromStore(identityProviderMessage *store.IdentityProviderMessage) *api.IdentityProvider {
  167. return &api.IdentityProvider{
  168. ID: identityProviderMessage.ID,
  169. Name: identityProviderMessage.Name,
  170. Type: api.IdentityProviderType(identityProviderMessage.Type),
  171. IdentifierFilter: identityProviderMessage.IdentifierFilter,
  172. Config: convertIdentityProviderConfigFromStore(identityProviderMessage.Config),
  173. }
  174. }
  175. func convertIdentityProviderConfigFromStore(config *store.IdentityProviderConfig) *api.IdentityProviderConfig {
  176. return &api.IdentityProviderConfig{
  177. OAuth2Config: &api.IdentityProviderOAuth2Config{
  178. ClientID: config.OAuth2Config.ClientID,
  179. ClientSecret: config.OAuth2Config.ClientSecret,
  180. AuthURL: config.OAuth2Config.AuthURL,
  181. TokenURL: config.OAuth2Config.TokenURL,
  182. UserInfoURL: config.OAuth2Config.UserInfoURL,
  183. Scopes: config.OAuth2Config.Scopes,
  184. FieldMapping: &api.FieldMapping{
  185. Identifier: config.OAuth2Config.FieldMapping.Identifier,
  186. DisplayName: config.OAuth2Config.FieldMapping.DisplayName,
  187. Email: config.OAuth2Config.FieldMapping.Email,
  188. },
  189. },
  190. }
  191. }
  192. func convertIdentityProviderConfigToStore(config *api.IdentityProviderConfig) *store.IdentityProviderConfig {
  193. return &store.IdentityProviderConfig{
  194. OAuth2Config: &store.IdentityProviderOAuth2Config{
  195. ClientID: config.OAuth2Config.ClientID,
  196. ClientSecret: config.OAuth2Config.ClientSecret,
  197. AuthURL: config.OAuth2Config.AuthURL,
  198. TokenURL: config.OAuth2Config.TokenURL,
  199. UserInfoURL: config.OAuth2Config.UserInfoURL,
  200. Scopes: config.OAuth2Config.Scopes,
  201. FieldMapping: &store.FieldMapping{
  202. Identifier: config.OAuth2Config.FieldMapping.Identifier,
  203. DisplayName: config.OAuth2Config.FieldMapping.DisplayName,
  204. Email: config.OAuth2Config.FieldMapping.Email,
  205. },
  206. },
  207. }
  208. }