migrator.go 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. package mysql
  2. import (
  3. "context"
  4. "embed"
  5. "fmt"
  6. "io/fs"
  7. "regexp"
  8. "sort"
  9. "strings"
  10. "github.com/pkg/errors"
  11. "github.com/usememos/memos/server/version"
  12. "github.com/usememos/memos/store"
  13. )
  14. //go:embed migration
  15. var migrationFS embed.FS
  16. const (
  17. latestSchemaFileName = "LATEST__SCHEMA.sql"
  18. )
  19. func (d *DB) Migrate(ctx context.Context) error {
  20. if d.profile.IsDev() {
  21. return d.nonProdMigrate(ctx)
  22. }
  23. return d.prodMigrate(ctx)
  24. }
  25. func (d *DB) nonProdMigrate(ctx context.Context) error {
  26. rows, err := d.db.QueryContext(ctx, "SHOW TABLES")
  27. if err != nil {
  28. return errors.Errorf("failed to query database tables: %s", err)
  29. }
  30. if rows.Err() != nil {
  31. return errors.Errorf("failed to query database tables: %s", err)
  32. }
  33. defer rows.Close()
  34. var tables []string
  35. for rows.Next() {
  36. var table string
  37. err := rows.Scan(&table)
  38. if err != nil {
  39. return errors.Errorf("failed to scan table name: %s", err)
  40. }
  41. tables = append(tables, table)
  42. }
  43. if len(tables) != 0 {
  44. return nil
  45. }
  46. buf, err := migrationFS.ReadFile("migration/dev/" + latestSchemaFileName)
  47. if err != nil {
  48. return errors.Errorf("failed to read latest schema file: %s", err)
  49. }
  50. stmt := string(buf)
  51. if _, err := d.db.ExecContext(ctx, stmt); err != nil {
  52. return errors.Errorf("failed to exec SQL %s: %s", stmt, err)
  53. }
  54. return nil
  55. }
  56. func (d *DB) prodMigrate(ctx context.Context) error {
  57. currentVersion := version.GetCurrentVersion(d.profile.Mode)
  58. migrationHistoryList, err := d.FindMigrationHistoryList(ctx, &store.FindMigrationHistory{})
  59. // If there is no migration history, we should apply the latest schema.
  60. if err != nil || len(migrationHistoryList) == 0 {
  61. buf, err := migrationFS.ReadFile("migration/prod/" + latestSchemaFileName)
  62. if err != nil {
  63. return errors.Errorf("failed to read latest schema file: %s", err)
  64. }
  65. if _, err := d.db.ExecContext(ctx, string(buf)); err != nil {
  66. return errors.Errorf("failed to exec latest schema: %s", err)
  67. }
  68. if _, err := d.UpsertMigrationHistory(ctx, &store.UpsertMigrationHistory{
  69. Version: currentVersion,
  70. }); err != nil {
  71. return errors.Wrap(err, "failed to upsert migration history")
  72. }
  73. return nil
  74. }
  75. migrationHistoryVersionList := []string{}
  76. for _, migrationHistory := range migrationHistoryList {
  77. migrationHistoryVersionList = append(migrationHistoryVersionList, migrationHistory.Version)
  78. }
  79. sort.Sort(version.SortVersion(migrationHistoryVersionList))
  80. latestMigrationHistoryVersion := migrationHistoryVersionList[len(migrationHistoryVersionList)-1]
  81. if !version.IsVersionGreaterThan(version.GetSchemaVersion(currentVersion), latestMigrationHistoryVersion) {
  82. return nil
  83. }
  84. fmt.Println("start to migrate database schema")
  85. for _, minorVersion := range getMinorVersionList() {
  86. normalizedVersion := minorVersion + ".0"
  87. if version.IsVersionGreaterThan(normalizedVersion, latestMigrationHistoryVersion) && version.IsVersionGreaterOrEqualThan(currentVersion, normalizedVersion) {
  88. fmt.Println("applying migration of", normalizedVersion)
  89. if err := d.applyMigrationForMinorVersion(ctx, minorVersion); err != nil {
  90. return errors.Wrap(err, "failed to apply minor version migration")
  91. }
  92. }
  93. }
  94. fmt.Println("end migrate")
  95. return nil
  96. }
  97. func (d *DB) applyMigrationForMinorVersion(ctx context.Context, minorVersion string) error {
  98. filenames, err := fs.Glob(migrationFS, fmt.Sprintf("migration/prod/%s/*.sql", minorVersion))
  99. if err != nil {
  100. return errors.Wrap(err, "failed to read ddl files")
  101. }
  102. sort.Strings(filenames)
  103. // Loop over all migration files and execute them in order.
  104. for _, filename := range filenames {
  105. buf, err := migrationFS.ReadFile(filename)
  106. if err != nil {
  107. return errors.Wrapf(err, "failed to read minor version migration file, filename=%s", filename)
  108. }
  109. for _, stmt := range strings.Split(string(buf), ";") {
  110. if strings.TrimSpace(stmt) == "" {
  111. continue
  112. }
  113. if _, err := d.db.ExecContext(ctx, stmt); err != nil {
  114. return errors.Wrapf(err, "migrate error: %s", stmt)
  115. }
  116. }
  117. }
  118. // Upsert the newest version to migration_history.
  119. version := minorVersion + ".0"
  120. if _, err = d.UpsertMigrationHistory(ctx, &store.UpsertMigrationHistory{Version: version}); err != nil {
  121. return errors.Wrapf(err, "failed to upsert migration history with version: %s", version)
  122. }
  123. return nil
  124. }
  125. // minorDirRegexp is a regular expression for minor version directory.
  126. var minorDirRegexp = regexp.MustCompile(`^migration/prod/[0-9]+\.[0-9]+$`)
  127. func getMinorVersionList() []string {
  128. minorVersionList := []string{}
  129. if err := fs.WalkDir(migrationFS, "migration", func(path string, file fs.DirEntry, err error) error {
  130. if err != nil {
  131. return err
  132. }
  133. if file.IsDir() && minorDirRegexp.MatchString(path) {
  134. minorVersionList = append(minorVersionList, file.Name())
  135. }
  136. return nil
  137. }); err != nil {
  138. panic(err)
  139. }
  140. sort.Sort(version.SortVersion(minorVersionList))
  141. return minorVersionList
  142. }