gotelem/migration.go
2024-03-02 21:49:18 -06:00

139 lines
2.7 KiB
Go

package gotelem
import (
"embed"
"errors"
"io"
"io/fs"
"path"
"regexp"
"sort"
"strconv"
)
// migrationsFs holds everything matched by the migrations/* pattern, embedded
// into the binary so applications can update databases.
//
//go:embed migrations/*
var migrationsFs embed.FS
// migrationRegex matches migration file names of the form
// <version>_<name>_<up|down>.sql. Its capture groups are, in order, the
// numeric version, the descriptive name, and the direction.
var migrationRegex = regexp.MustCompile(`^([0-9]+)_(.*)_(down|up)\.sql$`)
// Migration describes a single migration script, as derived from its file name.
type Migration struct {
// Name is the descriptive part of the file name, between the version
// prefix and the up/down suffix.
Name string
// Version is the numeric prefix of the file name.
Version uint
// FileName is the name of the script file as reported by its directory
// entry (no directory component).
FileName string
}
// MigrationError is an empty type; nothing in the visible code constructs or
// returns it.
// NOTE(review): presumably a placeholder for a structured migration error — it
// has no fields and no Error method here; confirm the intent or remove it.
type MigrationError struct {
}
// getMigrations walks files and indexes every migration script it finds.
//
// The result maps version number -> direction ("up" or "down") -> Migration.
// Directories are skipped; every regular file must be named
// <version>_<name>_<up|down>.sql.
//
// The function has no error return, so problems are reported by panicking:
// a file whose name does not match the expected pattern, a version number
// that does not fit in an int, or a failure while walking files.
func getMigrations(files fs.FS) map[int]map[string]Migration {
	res := make(map[int]map[string]Migration) // version number -> direction -> migration.

	walkErr := fs.WalkDir(files, ".", func(name string, d fs.DirEntry, err error) error {
		if err != nil {
			// d can be nil when err is set (for example when the root itself
			// cannot be read), so stop before touching it.
			return err
		}
		if d.IsDir() {
			return nil
		}

		m := migrationRegex.FindStringSubmatch(d.Name())
		if len(m) != 4 {
			panic("error parsing migration name" + ": " + name)
		}

		// The regex guarantees digits only, but the number can still be too
		// large for an int; refuse it instead of silently clamping.
		ver, err := strconv.Atoi(m[1])
		if err != nil {
			panic("error parsing migration version: " + name)
		}

		byDirection, ok := res[ver]
		if !ok {
			byDirection = make(map[string]Migration)
			res[ver] = byDirection
		}
		byDirection[m[3]] = Migration{
			Name:     m[2],
			Version:  uint(ver),
			FileName: d.Name(),
		}
		return nil
	})
	if walkErr != nil {
		panic("error walking migrations: " + walkErr.Error())
	}
	return res
}
// RunMigrations applies every embedded "up" migration newer than the version
// currently reported by tdb, then records the newest version through
// tdb.SetVersion.
//
// The embedded versions must form the unbroken sequence 1, 2, 3, ...; a gap is
// reported as an error before the database is touched. All pending scripts are
// executed inside one transaction, which is rolled back if any step fails. The
// version is recorded only after that transaction has committed.
//
// It returns the version the database was brought to. When no migrations are
// embedded there is nothing to do and the current version is returned
// unchanged. On any error the returned version is 0.
func RunMigrations(tdb *TelemDb) (finalVer int, err error) {
	currentVer, err := tdb.GetVersion()
	if err != nil {
		return 0, err
	}

	migrations := getMigrations(migrationsFs)
	if len(migrations) == 0 {
		// Nothing embedded: indexing the last element below would panic.
		return currentVer, nil
	}

	// Get a sorted list of versions.
	vers := make([]int, 0, len(migrations))
	for k := range migrations {
		vers = append(vers, k)
	}
	sort.Ints(vers)

	// Check that there are no gaps (increasing by one each time, from 1).
	expectedVer := 1
	for _, v := range vers {
		if v != expectedVer {
			return 0, errors.New("missing update: expected migration version " +
				strconv.Itoa(expectedVer) + ", found " + strconv.Itoa(v))
		}
		expectedVer = v + 1
	}
	latest := vers[len(vers)-1]

	// NOTE(review): if currentVer is already greater than latest, no script
	// runs but the recorded version is still set to latest below — confirm
	// that lowering the recorded version is intended.

	tx, err := tdb.db.Begin()
	if err != nil {
		return 0, err
	}
	// Registered only after Begin succeeded so tx is never nil here. Assuming
	// database/sql semantics, Rollback after a successful Commit is harmless.
	defer tx.Rollback()

	for v := currentVer + 1; v <= latest; v++ {
		mMap, ok := migrations[v]
		if !ok {
			return 0, errors.New("could not find migration for version " + strconv.Itoa(v))
		}
		upMigration, ok := mMap["up"]
		if !ok {
			return 0, errors.New("could not get up migration for version " + strconv.Itoa(v))
		}

		script, readErr := readEmbeddedMigration(upMigration.FileName)
		if readErr != nil {
			return 0, readErr
		}
		if _, execErr := tx.Exec(string(script)); execErr != nil {
			return 0, execErr
		}
	}

	// Only record the new version once the scripts are durably applied.
	if err = tx.Commit(); err != nil {
		return 0, err
	}
	if err = tdb.SetVersion(latest); err != nil {
		return 0, err
	}
	return latest, nil
}

// readEmbeddedMigration returns the full contents of the named script from the
// embedded migrations directory, closing the file before it returns.
func readEmbeddedMigration(fileName string) ([]byte, error) {
	f, err := migrationsFs.Open(path.Join("migrations", fileName))
	if err != nil {
		return nil, err
	}
	defer f.Close()
	return io.ReadAll(f)
}