idp.go 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. package store
  2. import (
  3. "context"
  4. "database/sql"
  5. "encoding/json"
  6. "fmt"
  7. "strings"
  8. "github.com/usememos/memos/common"
  9. )
  // IdentityProviderType names the kind of external identity provider a row of
// the idp table describes; the value is written verbatim to the idp.type column.
type IdentityProviderType string

// IdentityProviderOAuth2 marks a provider whose configuration is carried in
// IdentityProviderConfig.OAuth2Config.
const IdentityProviderOAuth2 IdentityProviderType = "OAUTH2"
  // Provider-specific configuration. The store methods serialise the active
// member to JSON and keep it in the idp.config column.
type (
	// IdentityProviderConfig holds the configuration of a single provider.
	// Only the member that matches the provider's IdentityProviderType is
	// read by the store (currently just OAuth2Config).
	IdentityProviderConfig struct {
		OAuth2Config *IdentityProviderOAuth2Config
	}

	// IdentityProviderOAuth2Config is the OAuth2 configuration. It is encoded
	// with the camelCase JSON keys given in the struct tags.
	IdentityProviderOAuth2Config struct {
		ClientID     string        `json:"clientId"`
		ClientSecret string        `json:"clientSecret"`
		AuthURL      string        `json:"authUrl"`
		TokenURL     string        `json:"tokenUrl"`
		UserInfoURL  string        `json:"userInfoUrl"`
		Scopes       []string      `json:"scopes"`
		FieldMapping *FieldMapping `json:"fieldMapping"`
	}

	// FieldMapping presumably names the fields of the provider's user-info
	// response that supply the identifier, display name and email
	// (NOTE(review): inferred from field names — confirm against the caller).
	FieldMapping struct {
		Identifier  string `json:"identifier"`
		DisplayName string `json:"displayName"`
		Email       string `json:"email"`
	}
)
  31. type IdentityProviderMessage struct {
  32. ID int
  33. Name string
  34. Type IdentityProviderType
  35. IdentifierFilter string
  36. Config *IdentityProviderConfig
  37. }
  38. type FindIdentityProviderMessage struct {
  39. ID *int
  40. }
  41. type UpdateIdentityProviderMessage struct {
  42. ID int
  43. Type IdentityProviderType
  44. Name *string
  45. IdentifierFilter *string
  46. Config *IdentityProviderConfig
  47. }
  48. type DeleteIdentityProviderMessage struct {
  49. ID int
  50. }
  // CreateIdentityProvider inserts a new row into the idp table, sets create.ID
// to the id returned by the database, caches the message and returns it.
//
// Only IdentityProviderOAuth2 is supported; create.Config.OAuth2Config must be
// non-nil and is stored as JSON in the config column. An unsupported type or a
// missing config is rejected before any database work is done.
func (s *Store) CreateIdentityProvider(ctx context.Context, create *IdentityProviderMessage) (*IdentityProviderMessage, error) {
	var configBytes []byte
	switch create.Type {
	case IdentityProviderOAuth2:
		// Guard against a nil dereference (nil Config) and against persisting
		// the JSON literal "null" (nil OAuth2Config).
		if create.Config == nil || create.Config.OAuth2Config == nil {
			return nil, fmt.Errorf("missing oauth2 config for idp type %s", string(create.Type))
		}
		encoded, err := json.Marshal(create.Config.OAuth2Config)
		if err != nil {
			return nil, err
		}
		configBytes = encoded
	default:
		return nil, fmt.Errorf("unsupported idp type %s", string(create.Type))
	}

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return nil, FormatError(err)
	}
	// Rollback after a successful Commit is a harmless no-op.
	defer tx.Rollback()

	query := `
		INSERT INTO idp (
			name,
			type,
			identifier_filter,
			config
		)
		VALUES (?, ?, ?, ?)
		RETURNING id
	`
	if err := tx.QueryRowContext(
		ctx,
		query,
		create.Name,
		create.Type,
		create.IdentifierFilter,
		string(configBytes),
	).Scan(
		&create.ID,
	); err != nil {
		return nil, FormatError(err)
	}
	if err := tx.Commit(); err != nil {
		return nil, FormatError(err)
	}

	// NOTE: the caller's pointer itself is cached, so later mutations made by
	// the caller through it are visible to cache readers.
	identityProviderMessage := create
	s.idpCache.Store(identityProviderMessage.ID, identityProviderMessage)
	return identityProviderMessage, nil
}
  // ListIdentityProviders returns every identity provider matching find, ordered
// by id, and refreshes the cache entry of each provider it returns. The read
// runs in a transaction that is rolled back, never committed.
func (s *Store) ListIdentityProviders(ctx context.Context, find *FindIdentityProviderMessage) ([]*IdentityProviderMessage, error) {
	tx, txErr := s.db.BeginTx(ctx, nil)
	if txErr != nil {
		return nil, FormatError(txErr)
	}
	defer tx.Rollback()

	providers, listErr := listIdentityProviders(ctx, tx, find)
	if listErr != nil {
		return nil, listErr
	}
	for i := range providers {
		s.idpCache.Store(providers[i].ID, providers[i])
	}
	return providers, nil
}
  110. func (s *Store) GetIdentityProvider(ctx context.Context, find *FindIdentityProviderMessage) (*IdentityProviderMessage, error) {
  111. if find.ID != nil {
  112. if cache, ok := s.idpCache.Load(*find.ID); ok {
  113. return cache.(*IdentityProviderMessage), nil
  114. }
  115. }
  116. tx, err := s.db.BeginTx(ctx, nil)
  117. if err != nil {
  118. return nil, FormatError(err)
  119. }
  120. defer tx.Rollback()
  121. list, err := listIdentityProviders(ctx, tx, find)
  122. if err != nil {
  123. return nil, err
  124. }
  125. if len(list) == 0 {
  126. return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("not found")}
  127. }
  128. identityProviderMessage := list[0]
  129. s.idpCache.Store(identityProviderMessage.ID, identityProviderMessage)
  130. return identityProviderMessage, nil
  131. }
  // UpdateIdentityProvider applies the non-nil fields of update to the identity
// provider with ID update.ID and returns the row as stored afterwards. The
// returned row is cached as a *IdentityProviderMessage (the form that
// GetIdentityProvider reads back).
//
// update.Type selects how a non-nil update.Config is encoded; only
// IdentityProviderOAuth2 is supported and its OAuth2Config must be non-nil.
// An update with no fields set performs no write: it returns the current row,
// or a *common.Error with Code common.NotFound if the id does not exist.
func (s *Store) UpdateIdentityProvider(ctx context.Context, update *UpdateIdentityProviderMessage) (*IdentityProviderMessage, error) {
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return nil, FormatError(err)
	}
	defer tx.Rollback()

	set, args := []string{}, []interface{}{}
	if v := update.Name; v != nil {
		set, args = append(set, "name = ?"), append(args, *v)
	}
	if v := update.IdentifierFilter; v != nil {
		set, args = append(set, "identifier_filter = ?"), append(args, *v)
	}
	if v := update.Config; v != nil {
		switch update.Type {
		case IdentityProviderOAuth2:
			// Marshalling a nil pointer would store "null" and wipe the config.
			if v.OAuth2Config == nil {
				return nil, fmt.Errorf("missing oauth2 config for idp type %s", string(update.Type))
			}
			configBytes, marshalErr := json.Marshal(v.OAuth2Config)
			if marshalErr != nil {
				return nil, marshalErr
			}
			set, args = append(set, "config = ?"), append(args, string(configBytes))
		default:
			return nil, fmt.Errorf("unsupported idp type %s", string(update.Type))
		}
	}

	if len(set) == 0 {
		// "UPDATE idp SET  WHERE ..." is invalid SQL, so an empty update is
		// served as a read of the current row instead.
		id := update.ID
		list, listErr := listIdentityProviders(ctx, tx, &FindIdentityProviderMessage{ID: &id})
		if listErr != nil {
			return nil, listErr
		}
		if len(list) == 0 {
			return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("idp not found")}
		}
		s.idpCache.Store(list[0].ID, list[0])
		return list[0], nil
	}

	args = append(args, update.ID)
	query := `
		UPDATE idp
		SET ` + strings.Join(set, ", ") + `
		WHERE id = ?
		RETURNING id, name, type, identifier_filter, config
	`
	identityProviderMessage := &IdentityProviderMessage{}
	var identityProviderConfig string
	if err := tx.QueryRowContext(ctx, query, args...).Scan(
		&identityProviderMessage.ID,
		&identityProviderMessage.Name,
		&identityProviderMessage.Type,
		&identityProviderMessage.IdentifierFilter,
		&identityProviderConfig,
	); err != nil {
		return nil, FormatError(err)
	}

	// Decode the stored config according to the type read back from the row.
	switch identityProviderMessage.Type {
	case IdentityProviderOAuth2:
		oauth2Config := &IdentityProviderOAuth2Config{}
		if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil {
			return nil, err
		}
		identityProviderMessage.Config = &IdentityProviderConfig{
			OAuth2Config: oauth2Config,
		}
	default:
		return nil, fmt.Errorf("unsupported idp type %s", string(identityProviderMessage.Type))
	}

	if err := tx.Commit(); err != nil {
		return nil, FormatError(err)
	}
	// Cache a pointer: GetIdentityProvider reads entries back as *IdentityProviderMessage.
	s.idpCache.Store(identityProviderMessage.ID, identityProviderMessage)
	return identityProviderMessage, nil
}
  // DeleteIdentityProvider removes the identity provider with ID delete.ID and
// evicts it from the cache once the deletion is committed. If no row has that
// id, it returns a *common.Error with Code common.NotFound and nothing is
// committed. Database errors are passed through FormatError, matching the
// other store methods in this file.
func (s *Store) DeleteIdentityProvider(ctx context.Context, delete *DeleteIdentityProviderMessage) error {
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return FormatError(err)
	}
	defer tx.Rollback()

	result, err := tx.ExecContext(ctx, `DELETE FROM idp WHERE id = ?`, delete.ID)
	if err != nil {
		return FormatError(err)
	}
	rows, err := result.RowsAffected()
	if err != nil {
		return FormatError(err)
	}
	if rows == 0 {
		return &common.Error{Code: common.NotFound, Err: fmt.Errorf("idp not found")}
	}
	if err := tx.Commit(); err != nil {
		return FormatError(err)
	}

	s.idpCache.Delete(delete.ID)
	return nil
}
  // listIdentityProviders returns the idp rows matching find, ordered by id
// ascending, with each row's config column decoded according to its type.
// A nil find, or a nil find.ID, applies no filter. It runs inside tx and
// neither commits nor rolls it back. A row with an unsupported type aborts the
// whole listing with an error.
func listIdentityProviders(ctx context.Context, tx *sql.Tx, find *FindIdentityProviderMessage) ([]*IdentityProviderMessage, error) {
	where, args := []string{"TRUE"}, []interface{}{}
	if find != nil && find.ID != nil {
		// Use "?" like every other query in this file against the same database.
		where, args = append(where, "id = ?"), append(args, *find.ID)
	}

	rows, err := tx.QueryContext(ctx, `
		SELECT
			id,
			name,
			type,
			identifier_filter,
			config
		FROM idp
		WHERE `+strings.Join(where, " AND ")+`
		ORDER BY id ASC`,
		args...,
	)
	if err != nil {
		return nil, FormatError(err)
	}
	defer rows.Close()

	var identityProviderMessages []*IdentityProviderMessage
	for rows.Next() {
		identityProviderMessage := &IdentityProviderMessage{}
		var identityProviderConfig string
		if err := rows.Scan(
			&identityProviderMessage.ID,
			&identityProviderMessage.Name,
			&identityProviderMessage.Type,
			&identityProviderMessage.IdentifierFilter,
			&identityProviderConfig,
		); err != nil {
			return nil, FormatError(err)
		}

		switch identityProviderMessage.Type {
		case IdentityProviderOAuth2:
			oauth2Config := &IdentityProviderOAuth2Config{}
			if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil {
				return nil, err
			}
			identityProviderMessage.Config = &IdentityProviderConfig{
				OAuth2Config: oauth2Config,
			}
		default:
			return nil, fmt.Errorf("unsupported idp type %s", string(identityProviderMessage.Type))
		}
		identityProviderMessages = append(identityProviderMessages, identityProviderMessage)
	}
	// rows.Next also returns false on an iteration error; surface it instead of
	// handing back a silently truncated list.
	if err := rows.Err(); err != nil {
		return nil, FormatError(err)
	}
	return identityProviderMessages, nil
}