mysql.go 1.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. package mysql
  2. import (
  3. "database/sql"
  4. "github.com/go-sql-driver/mysql"
  5. "github.com/pkg/errors"
  6. "github.com/usememos/memos/server/profile"
  7. "github.com/usememos/memos/store"
  8. )
  9. type DB struct {
  10. db *sql.DB
  11. profile *profile.Profile
  12. config *mysql.Config
  13. }
  14. func NewDB(profile *profile.Profile) (store.Driver, error) {
  15. // Open MySQL connection with parameter.
  16. // multiStatements=true is required for migration.
  17. // See more in: https://github.com/go-sql-driver/mysql#multistatements
  18. dsn, err := mergeDSN(profile.DSN)
  19. if err != nil {
  20. return nil, err
  21. }
  22. driver := DB{profile: profile}
  23. driver.config, err = mysql.ParseDSN(dsn)
  24. if err != nil {
  25. return nil, errors.New("Parse DSN eroor")
  26. }
  27. driver.db, err = sql.Open("mysql", dsn)
  28. if err != nil {
  29. return nil, errors.Wrapf(err, "failed to open db: %s", profile.DSN)
  30. }
  31. return &driver, nil
  32. }
  33. func (d *DB) GetDB() *sql.DB {
  34. return d.db
  35. }
  36. func (d *DB) Close() error {
  37. return d.db.Close()
  38. }
  39. func mergeDSN(baseDSN string) (string, error) {
  40. config, err := mysql.ParseDSN(baseDSN)
  41. if err != nil {
  42. return "", errors.Wrapf(err, "failed to parse DSN: %s", baseDSN)
  43. }
  44. config.MultiStatements = true
  45. return config.FormatDSN(), nil
  46. }