idp.go 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
  1. package v1
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "net/http"
  6. "strconv"
  7. "github.com/labstack/echo/v4"
  8. "github.com/usememos/memos/store"
  9. )
  10. type IdentityProviderType string
  11. const (
  12. IdentityProviderOAuth2Type IdentityProviderType = "OAUTH2"
  13. )
  14. func (t IdentityProviderType) String() string {
  15. return string(t)
  16. }
  17. type IdentityProviderConfig struct {
  18. OAuth2Config *IdentityProviderOAuth2Config `json:"oauth2Config"`
  19. }
  20. type IdentityProviderOAuth2Config struct {
  21. ClientID string `json:"clientId"`
  22. ClientSecret string `json:"clientSecret"`
  23. AuthURL string `json:"authUrl"`
  24. TokenURL string `json:"tokenUrl"`
  25. UserInfoURL string `json:"userInfoUrl"`
  26. Scopes []string `json:"scopes"`
  27. FieldMapping *FieldMapping `json:"fieldMapping"`
  28. }
  29. type FieldMapping struct {
  30. Identifier string `json:"identifier"`
  31. DisplayName string `json:"displayName"`
  32. Email string `json:"email"`
  33. }
  34. type IdentityProvider struct {
  35. ID int `json:"id"`
  36. Name string `json:"name"`
  37. Type IdentityProviderType `json:"type"`
  38. IdentifierFilter string `json:"identifierFilter"`
  39. Config *IdentityProviderConfig `json:"config"`
  40. }
  41. type CreateIdentityProviderRequest struct {
  42. Name string `json:"name"`
  43. Type IdentityProviderType `json:"type"`
  44. IdentifierFilter string `json:"identifierFilter"`
  45. Config *IdentityProviderConfig `json:"config"`
  46. }
  47. type UpdateIdentityProviderRequest struct {
  48. ID int `json:"-"`
  49. Type IdentityProviderType `json:"type"`
  50. Name *string `json:"name"`
  51. IdentifierFilter *string `json:"identifierFilter"`
  52. Config *IdentityProviderConfig `json:"config"`
  53. }
  54. func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
  55. g.POST("/idp", func(c echo.Context) error {
  56. ctx := c.Request().Context()
  57. userID, ok := c.Get(getUserIDContextKey()).(int)
  58. if !ok {
  59. return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
  60. }
  61. user, err := s.Store.GetUser(ctx, &store.FindUser{
  62. ID: &userID,
  63. })
  64. if err != nil {
  65. return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
  66. }
  67. if user == nil || user.Role != store.RoleHost {
  68. return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
  69. }
  70. identityProviderCreate := &CreateIdentityProviderRequest{}
  71. if err := json.NewDecoder(c.Request().Body).Decode(identityProviderCreate); err != nil {
  72. return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post identity provider request").SetInternal(err)
  73. }
  74. identityProvider, err := s.Store.CreateIdentityProvider(ctx, &store.IdentityProvider{
  75. Name: identityProviderCreate.Name,
  76. Type: store.IdentityProviderType(identityProviderCreate.Type),
  77. IdentifierFilter: identityProviderCreate.IdentifierFilter,
  78. Config: convertIdentityProviderConfigToStore(identityProviderCreate.Config),
  79. })
  80. if err != nil {
  81. return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create identity provider").SetInternal(err)
  82. }
  83. return c.JSON(http.StatusOK, convertIdentityProviderFromStore(identityProvider))
  84. })
  85. g.PATCH("/idp/:idpId", func(c echo.Context) error {
  86. ctx := c.Request().Context()
  87. userID, ok := c.Get(getUserIDContextKey()).(int)
  88. if !ok {
  89. return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
  90. }
  91. user, err := s.Store.GetUser(ctx, &store.FindUser{
  92. ID: &userID,
  93. })
  94. if err != nil {
  95. return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
  96. }
  97. if user == nil || user.Role != store.RoleHost {
  98. return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
  99. }
  100. identityProviderID, err := strconv.Atoi(c.Param("idpId"))
  101. if err != nil {
  102. return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("idpId"))).SetInternal(err)
  103. }
  104. identityProviderPatch := &UpdateIdentityProviderRequest{
  105. ID: identityProviderID,
  106. }
  107. if err := json.NewDecoder(c.Request().Body).Decode(identityProviderPatch); err != nil {
  108. return echo.NewHTTPError(http.StatusBadRequest, "Malformatted patch identity provider request").SetInternal(err)
  109. }
  110. identityProvider, err := s.Store.UpdateIdentityProvider(ctx, &store.UpdateIdentityProvider{
  111. ID: identityProviderPatch.ID,
  112. Type: store.IdentityProviderType(identityProviderPatch.Type),
  113. Name: identityProviderPatch.Name,
  114. IdentifierFilter: identityProviderPatch.IdentifierFilter,
  115. Config: convertIdentityProviderConfigToStore(identityProviderPatch.Config),
  116. })
  117. if err != nil {
  118. return echo.NewHTTPError(http.StatusInternalServerError, "Failed to patch identity provider").SetInternal(err)
  119. }
  120. return c.JSON(http.StatusOK, convertIdentityProviderFromStore(identityProvider))
  121. })
  122. g.GET("/idp", func(c echo.Context) error {
  123. ctx := c.Request().Context()
  124. list, err := s.Store.ListIdentityProviders(ctx, &store.FindIdentityProvider{})
  125. if err != nil {
  126. return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find identity provider list").SetInternal(err)
  127. }
  128. userID, ok := c.Get(getUserIDContextKey()).(int)
  129. isHostUser := false
  130. if ok {
  131. user, err := s.Store.GetUser(ctx, &store.FindUser{
  132. ID: &userID,
  133. })
  134. if err != nil {
  135. return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
  136. }
  137. if user == nil || user.Role == store.RoleHost {
  138. isHostUser = true
  139. }
  140. }
  141. identityProviderList := []*IdentityProvider{}
  142. for _, item := range list {
  143. identityProvider := convertIdentityProviderFromStore(item)
  144. // data desensitize
  145. if !isHostUser {
  146. identityProvider.Config.OAuth2Config.ClientSecret = ""
  147. }
  148. identityProviderList = append(identityProviderList, identityProvider)
  149. }
  150. return c.JSON(http.StatusOK, identityProviderList)
  151. })
  152. g.GET("/idp/:idpId", func(c echo.Context) error {
  153. ctx := c.Request().Context()
  154. userID, ok := c.Get(getUserIDContextKey()).(int)
  155. if !ok {
  156. return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
  157. }
  158. user, err := s.Store.GetUser(ctx, &store.FindUser{
  159. ID: &userID,
  160. })
  161. if err != nil {
  162. return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
  163. }
  164. if user == nil || user.Role != store.RoleHost {
  165. return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
  166. }
  167. identityProviderID, err := strconv.Atoi(c.Param("idpId"))
  168. if err != nil {
  169. return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("idpId"))).SetInternal(err)
  170. }
  171. identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{
  172. ID: &identityProviderID,
  173. })
  174. if err != nil {
  175. return echo.NewHTTPError(http.StatusInternalServerError, "Failed to get identity provider").SetInternal(err)
  176. }
  177. if identityProvider == nil {
  178. return echo.NewHTTPError(http.StatusNotFound, "Identity provider not found")
  179. }
  180. return c.JSON(http.StatusOK, convertIdentityProviderFromStore(identityProvider))
  181. })
  182. g.DELETE("/idp/:idpId", func(c echo.Context) error {
  183. ctx := c.Request().Context()
  184. userID, ok := c.Get(getUserIDContextKey()).(int)
  185. if !ok {
  186. return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
  187. }
  188. user, err := s.Store.GetUser(ctx, &store.FindUser{
  189. ID: &userID,
  190. })
  191. if err != nil {
  192. return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
  193. }
  194. if user == nil || user.Role != store.RoleHost {
  195. return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
  196. }
  197. identityProviderID, err := strconv.Atoi(c.Param("idpId"))
  198. if err != nil {
  199. return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("idpId"))).SetInternal(err)
  200. }
  201. if err = s.Store.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ID: identityProviderID}); err != nil {
  202. return echo.NewHTTPError(http.StatusInternalServerError, "Failed to delete identity provider").SetInternal(err)
  203. }
  204. return c.JSON(http.StatusOK, true)
  205. })
  206. }
  207. func convertIdentityProviderFromStore(identityProvider *store.IdentityProvider) *IdentityProvider {
  208. return &IdentityProvider{
  209. ID: identityProvider.ID,
  210. Name: identityProvider.Name,
  211. Type: IdentityProviderType(identityProvider.Type),
  212. IdentifierFilter: identityProvider.IdentifierFilter,
  213. Config: convertIdentityProviderConfigFromStore(identityProvider.Config),
  214. }
  215. }
  216. func convertIdentityProviderConfigFromStore(config *store.IdentityProviderConfig) *IdentityProviderConfig {
  217. return &IdentityProviderConfig{
  218. OAuth2Config: &IdentityProviderOAuth2Config{
  219. ClientID: config.OAuth2Config.ClientID,
  220. ClientSecret: config.OAuth2Config.ClientSecret,
  221. AuthURL: config.OAuth2Config.AuthURL,
  222. TokenURL: config.OAuth2Config.TokenURL,
  223. UserInfoURL: config.OAuth2Config.UserInfoURL,
  224. Scopes: config.OAuth2Config.Scopes,
  225. FieldMapping: &FieldMapping{
  226. Identifier: config.OAuth2Config.FieldMapping.Identifier,
  227. DisplayName: config.OAuth2Config.FieldMapping.DisplayName,
  228. Email: config.OAuth2Config.FieldMapping.Email,
  229. },
  230. },
  231. }
  232. }
  233. func convertIdentityProviderConfigToStore(config *IdentityProviderConfig) *store.IdentityProviderConfig {
  234. return &store.IdentityProviderConfig{
  235. OAuth2Config: &store.IdentityProviderOAuth2Config{
  236. ClientID: config.OAuth2Config.ClientID,
  237. ClientSecret: config.OAuth2Config.ClientSecret,
  238. AuthURL: config.OAuth2Config.AuthURL,
  239. TokenURL: config.OAuth2Config.TokenURL,
  240. UserInfoURL: config.OAuth2Config.UserInfoURL,
  241. Scopes: config.OAuth2Config.Scopes,
  242. FieldMapping: &store.FieldMapping{
  243. Identifier: config.OAuth2Config.FieldMapping.Identifier,
  244. DisplayName: config.OAuth2Config.FieldMapping.DisplayName,
  245. Email: config.OAuth2Config.FieldMapping.Email,
  246. },
  247. },
  248. }
  249. }