migrator.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  1. package store
  2. import (
  3. "context"
  4. "database/sql"
  5. "embed"
  6. "fmt"
  7. "io/fs"
  8. "log/slog"
  9. "path/filepath"
  10. "sort"
  11. "strconv"
  12. "strings"
  13. "github.com/pkg/errors"
  14. storepb "github.com/usememos/memos/proto/gen/store"
  15. "github.com/usememos/memos/server/version"
  16. )
  17. //go:embed migration
  18. var migrationFS embed.FS
  19. //go:embed seed
  20. var seedFS embed.FS
  21. const (
  22. // MigrateFileNameSplit is the split character between the patch version and the description in the migration file name.
  23. // For example, "1__create_table.sql".
  24. MigrateFileNameSplit = "__"
  25. // LatestSchemaFileName is the name of the latest schema file.
  26. // This file is used to apply the latest schema when no migration history is found.
  27. LatestSchemaFileName = "LATEST.sql"
  28. )
  29. // Migrate applies the latest schema to the database.
  30. func (s *Store) Migrate(ctx context.Context) error {
  31. if err := s.preMigrate(ctx); err != nil {
  32. return errors.Wrap(err, "failed to pre-migrate")
  33. }
  34. if s.Profile.Mode == "prod" {
  35. migrationHistoryList, err := s.driver.FindMigrationHistoryList(ctx, &FindMigrationHistory{})
  36. if err != nil {
  37. return errors.Wrap(err, "failed to find migration history")
  38. }
  39. if len(migrationHistoryList) == 0 {
  40. return errors.Errorf("no migration history found")
  41. }
  42. migrationHistoryVersions := []string{}
  43. for _, migrationHistory := range migrationHistoryList {
  44. migrationHistoryVersions = append(migrationHistoryVersions, migrationHistory.Version)
  45. }
  46. sort.Sort(version.SortVersion(migrationHistoryVersions))
  47. latestMigrationHistoryVersion := migrationHistoryVersions[len(migrationHistoryVersions)-1]
  48. schemaVersion, err := s.GetCurrentSchemaVersion()
  49. if err != nil {
  50. return errors.Wrap(err, "failed to get current schema version")
  51. }
  52. if version.IsVersionGreaterThan(schemaVersion, latestMigrationHistoryVersion) {
  53. filePaths, err := fs.Glob(migrationFS, fmt.Sprintf("%s*/*.sql", s.getMigrationBasePath()))
  54. if err != nil {
  55. return errors.Wrap(err, "failed to read migration files")
  56. }
  57. sort.Strings(filePaths)
  58. // Start a transaction to apply the latest schema.
  59. tx, err := s.driver.GetDB().Begin()
  60. if err != nil {
  61. return errors.Wrap(err, "failed to start transaction")
  62. }
  63. defer tx.Rollback()
  64. slog.Info("start migration", slog.String("currentSchemaVersion", latestMigrationHistoryVersion), slog.String("targetSchemaVersion", schemaVersion))
  65. for _, filePath := range filePaths {
  66. fileSchemaVersion, err := s.getSchemaVersionOfMigrateScript(filePath)
  67. if err != nil {
  68. return errors.Wrap(err, "failed to get schema version of migrate script")
  69. }
  70. if version.IsVersionGreaterThan(fileSchemaVersion, latestMigrationHistoryVersion) && version.IsVersionGreaterOrEqualThan(schemaVersion, fileSchemaVersion) {
  71. bytes, err := migrationFS.ReadFile(filePath)
  72. if err != nil {
  73. return errors.Wrapf(err, "failed to read minor version migration file: %s", filePath)
  74. }
  75. stmt := string(bytes)
  76. if err := s.execute(ctx, tx, stmt); err != nil {
  77. return errors.Wrapf(err, "migrate error: %s", stmt)
  78. }
  79. }
  80. }
  81. if err := tx.Commit(); err != nil {
  82. return errors.Wrap(err, "failed to commit transaction")
  83. }
  84. slog.Info("end migrate")
  85. // Upsert the current schema version to migration_history.
  86. // TODO: retire using migration history later.
  87. if _, err = s.driver.UpsertMigrationHistory(ctx, &UpsertMigrationHistory{
  88. Version: schemaVersion,
  89. }); err != nil {
  90. return errors.Wrapf(err, "failed to upsert migration history with version: %s", schemaVersion)
  91. }
  92. if err := s.updateCurrentSchemaVersion(ctx, schemaVersion); err != nil {
  93. return errors.Wrap(err, "failed to update current schema version")
  94. }
  95. }
  96. } else if s.Profile.Mode == "demo" {
  97. // In demo mode, we should seed the database.
  98. if err := s.seed(ctx); err != nil {
  99. return errors.Wrap(err, "failed to seed")
  100. }
  101. }
  102. return nil
  103. }
  104. func (s *Store) preMigrate(ctx context.Context) error {
  105. // TODO: using schema version in basic setting instead of migration history.
  106. migrationHistoryList, err := s.driver.FindMigrationHistoryList(ctx, &FindMigrationHistory{})
  107. // If any error occurs or no migration history found, apply the latest schema.
  108. if err != nil || len(migrationHistoryList) == 0 {
  109. if err != nil {
  110. slog.Warn("failed to find migration history in pre-migrate", slog.String("error", err.Error()))
  111. }
  112. filePath := s.getMigrationBasePath() + LatestSchemaFileName
  113. bytes, err := migrationFS.ReadFile(filePath)
  114. if err != nil {
  115. return errors.Errorf("failed to read latest schema file: %s", err)
  116. }
  117. schemaVersion, err := s.GetCurrentSchemaVersion()
  118. if err != nil {
  119. return errors.Wrap(err, "failed to get current schema version")
  120. }
  121. // Start a transaction to apply the latest schema.
  122. tx, err := s.driver.GetDB().Begin()
  123. if err != nil {
  124. return errors.Wrap(err, "failed to start transaction")
  125. }
  126. defer tx.Rollback()
  127. if err := s.execute(ctx, tx, string(bytes)); err != nil {
  128. return errors.Errorf("failed to execute SQL file %s, err %s", filePath, err)
  129. }
  130. if err := tx.Commit(); err != nil {
  131. return errors.Wrap(err, "failed to commit transaction")
  132. }
  133. // TODO: using schema version in basic setting instead of migration history.
  134. if _, err := s.driver.UpsertMigrationHistory(ctx, &UpsertMigrationHistory{
  135. Version: schemaVersion,
  136. }); err != nil {
  137. return errors.Wrap(err, "failed to upsert migration history")
  138. }
  139. if err := s.updateCurrentSchemaVersion(ctx, schemaVersion); err != nil {
  140. return errors.Wrap(err, "failed to update current schema version")
  141. }
  142. }
  143. if s.Profile.Mode == "prod" {
  144. if err := s.normalizedMigrationHistoryList(ctx); err != nil {
  145. return errors.Wrap(err, "failed to normalize migration history list")
  146. }
  147. }
  148. return nil
  149. }
  150. func (s *Store) getMigrationBasePath() string {
  151. mode := "dev"
  152. if s.Profile.Mode == "prod" {
  153. mode = "prod"
  154. }
  155. return fmt.Sprintf("migration/%s/%s/", s.Profile.Driver, mode)
  156. }
  157. func (s *Store) getSeedBasePath() string {
  158. return fmt.Sprintf("seed/%s/", s.Profile.Driver)
  159. }
  160. func (s *Store) seed(ctx context.Context) error {
  161. // Only seed for SQLite.
  162. if s.Profile.Driver != "sqlite" {
  163. slog.Warn("seed is only supported for SQLite")
  164. return nil
  165. }
  166. filenames, err := fs.Glob(seedFS, fmt.Sprintf("%s*.sql", s.getSeedBasePath()))
  167. if err != nil {
  168. return errors.Wrap(err, "failed to read seed files")
  169. }
  170. // Sort seed files by name. This is important to ensure that seed files are applied in order.
  171. sort.Strings(filenames)
  172. // Start a transaction to apply the seed files.
  173. tx, err := s.driver.GetDB().Begin()
  174. if err != nil {
  175. return errors.Wrap(err, "failed to start transaction")
  176. }
  177. defer tx.Rollback()
  178. // Loop over all seed files and execute them in order.
  179. for _, filename := range filenames {
  180. bytes, err := seedFS.ReadFile(filename)
  181. if err != nil {
  182. return errors.Wrapf(err, "failed to read seed file, filename=%s", filename)
  183. }
  184. if err := s.execute(ctx, tx, string(bytes)); err != nil {
  185. return errors.Wrapf(err, "seed error: %s", filename)
  186. }
  187. }
  188. return tx.Commit()
  189. }
  190. func (s *Store) GetCurrentSchemaVersion() (string, error) {
  191. currentVersion := version.GetCurrentVersion(s.Profile.Mode)
  192. minorVersion := version.GetMinorVersion(currentVersion)
  193. filePaths, err := fs.Glob(migrationFS, fmt.Sprintf("%s%s/*.sql", s.getMigrationBasePath(), minorVersion))
  194. if err != nil {
  195. return "", errors.Wrap(err, "failed to read migration files")
  196. }
  197. sort.Strings(filePaths)
  198. if len(filePaths) == 0 {
  199. return fmt.Sprintf("%s.0", minorVersion), nil
  200. }
  201. return s.getSchemaVersionOfMigrateScript(filePaths[len(filePaths)-1])
  202. }
  203. func (s *Store) getSchemaVersionOfMigrateScript(filePath string) (string, error) {
  204. // If the file is the latest schema file, return the current schema version.
  205. if strings.HasSuffix(filePath, LatestSchemaFileName) {
  206. return s.GetCurrentSchemaVersion()
  207. }
  208. normalizedPath := filepath.ToSlash(filePath)
  209. elements := strings.Split(normalizedPath, "/")
  210. if len(elements) < 2 {
  211. return "", errors.Errorf("invalid file path: %s", filePath)
  212. }
  213. minorVersion := elements[len(elements)-2]
  214. rawPatchVersion := strings.Split(elements[len(elements)-1], MigrateFileNameSplit)[0]
  215. patchVersion, err := strconv.Atoi(rawPatchVersion)
  216. if err != nil {
  217. return "", errors.Wrapf(err, "failed to convert patch version to int: %s", rawPatchVersion)
  218. }
  219. return fmt.Sprintf("%s.%d", minorVersion, patchVersion+1), nil
  220. }
  221. // execute runs a single SQL statement within a transaction.
  222. func (*Store) execute(ctx context.Context, tx *sql.Tx, stmt string) error {
  223. if _, err := tx.ExecContext(ctx, stmt); err != nil {
  224. return errors.Wrap(err, "failed to execute statement")
  225. }
  226. return nil
  227. }
  228. func (s *Store) normalizedMigrationHistoryList(ctx context.Context) error {
  229. migrationHistoryList, err := s.driver.FindMigrationHistoryList(ctx, &FindMigrationHistory{})
  230. if err != nil {
  231. return errors.Wrap(err, "failed to find migration history")
  232. }
  233. versions := []string{}
  234. for _, migrationHistory := range migrationHistoryList {
  235. versions = append(versions, migrationHistory.Version)
  236. }
  237. sort.Sort(version.SortVersion(versions))
  238. latestVersion := versions[len(versions)-1]
  239. latestMinorVersion := version.GetMinorVersion(latestVersion)
  240. // If the latest version is greater than 0.22, return.
  241. // As of 0.22, the migration history is already normalized.
  242. if version.IsVersionGreaterThan(latestMinorVersion, "0.22") {
  243. return nil
  244. }
  245. schemaVersionMap := map[string]string{}
  246. filePaths, err := fs.Glob(migrationFS, fmt.Sprintf("%s*/*.sql", s.getMigrationBasePath()))
  247. if err != nil {
  248. return errors.Wrap(err, "failed to read migration files")
  249. }
  250. sort.Strings(filePaths)
  251. for _, filePath := range filePaths {
  252. fileSchemaVersion, err := s.getSchemaVersionOfMigrateScript(filePath)
  253. if err != nil {
  254. return errors.Wrap(err, "failed to get schema version of migrate script")
  255. }
  256. schemaVersionMap[version.GetMinorVersion(fileSchemaVersion)] = fileSchemaVersion
  257. }
  258. latestSchemaVersion := schemaVersionMap[latestMinorVersion]
  259. if latestSchemaVersion == "" {
  260. return errors.Errorf("latest schema version not found")
  261. }
  262. if version.IsVersionGreaterOrEqualThan(latestVersion, latestSchemaVersion) {
  263. return nil
  264. }
  265. // Start a transaction to insert the latest schema version to migration_history.
  266. tx, err := s.driver.GetDB().Begin()
  267. if err != nil {
  268. return errors.Wrap(err, "failed to start transaction")
  269. }
  270. defer tx.Rollback()
  271. if err := s.execute(ctx, tx, fmt.Sprintf("INSERT INTO migration_history (version) VALUES ('%s')", latestSchemaVersion)); err != nil {
  272. return errors.Wrap(err, "failed to insert migration history")
  273. }
  274. return tx.Commit()
  275. }
  276. func (s *Store) updateCurrentSchemaVersion(ctx context.Context, schemaVersion string) error {
  277. workspaceBasicSetting, err := s.GetWorkspaceBasicSetting(ctx)
  278. if err != nil {
  279. return errors.Wrap(err, "failed to get workspace basic setting")
  280. }
  281. workspaceBasicSetting.SchemaVersion = schemaVersion
  282. if _, err := s.UpsertWorkspaceSetting(ctx, &storepb.WorkspaceSetting{
  283. Key: storepb.WorkspaceSettingKey_BASIC,
  284. Value: &storepb.WorkspaceSetting_BasicSetting{BasicSetting: workspaceBasicSetting},
  285. }); err != nil {
  286. return errors.Wrap(err, "failed to upsert workspace setting")
  287. }
  288. return nil
  289. }