diff --git a/internal/db/db.go b/internal/db/db.go index 1d97788..4a37159 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -5,11 +5,8 @@ package db import ( "context" "database/sql" - "embed" "encoding/json" "fmt" - "io/fs" - "regexp" "strconv" "strings" "time" @@ -25,75 +22,6 @@ func init() { }) } -// embed the migrations into applications so they can update databases. - -//go:embed migrations -var migrations embed.FS - -var migrationRegex = regexp.MustCompile(`^([0-9]+)_(.*)_(down|up)\.sql$`) - -type Migration struct { - Name string - Version uint - FileName string -} - -// GetMigrations returns a list of migrations, which are correctly index. zero is nil. - -// use len to get the highest number migration. -func RunMigrations(currentVer int) (finalVer int) { - - res := make(map[int]map[string]Migration) // version number -> direction -> migration. - - fs.WalkDir(migrations, ".", func(path string, d fs.DirEntry, err error) error { - - if d.IsDir() { - return nil - } - m := migrationRegex.FindStringSubmatch(d.Name()) - if len(m) != 4 { - panic("error parsing migration name") - } - migrationVer, _ := strconv.ParseInt(m[1], 10, 64) - - mig := Migration{ - Name: m[2], - Version: uint(migrationVer), - FileName: d.Name(), - } - - var mMap map[string]Migration - mMap, ok := res[int(migrationVer)] - if !ok { - mMap = make(map[string]Migration) - } - mMap[m[3]] = mig - - res[int(migrationVer)] = mMap - - return nil - }) - - // now apply the mappings based on current ver. - - for v := currentVer; v < finalVer; v++ { - // attempt to get the "up" migration. - mMap, ok := res[v] - if !ok { - panic("aa") - } - upMigration, ok := mMap["up"] - if !ok { - panic("aaa") - } - // open the file name - // execute the file. - - } - - return res -} - type TelemDb struct { db *sqlx.DB } @@ -126,8 +54,10 @@ func OpenTelemDb(path string, options ...TelemDbOption) (tdb *TelemDb, err error } // get latest version of migrations - then run the SQL in order. + fmt.Printf("starting version %d\n", version) - _, err = tdb.db.Exec(sqlDbUp) + version, err = RunMigrations(version, tdb) + fmt.Printf("ending version %d\n", version) return tdb, err } diff --git a/internal/db/migration.go b/internal/db/migration.go new file mode 100644 index 0000000..1f1a31a --- /dev/null +++ b/internal/db/migration.go @@ -0,0 +1,132 @@ +package db + +import ( + "embed" + "errors" + "io" + "io/fs" + "path" + "regexp" + "sort" + "strconv" +) + +// embed the migrations into applications so they can update databases. + +//go:embed migrations/* +var migrationsFs embed.FS + +var migrationRegex = regexp.MustCompile(`^([0-9]+)_(.*)_(down|up)\.sql$`) + +type Migration struct { + Name string + Version uint + FileName string +} + +// getMigrations returns a list of migrations, which are correctly index. zero is nil. +func getMigrations(files fs.FS) map[int]map[string]Migration { + + res := make(map[int]map[string]Migration) // version number -> direction -> migration. + + fs.WalkDir(files, ".", func(path string, d fs.DirEntry, err error) error { + + if d.IsDir() { + return nil + } + m := migrationRegex.FindStringSubmatch(d.Name()) + if len(m) != 4 { + panic("error parsing migration name") + } + migrationVer, _ := strconv.ParseInt(m[1], 10, 64) + + mig := Migration{ + Name: m[2], + Version: uint(migrationVer), + FileName: d.Name(), + } + + var mMap map[string]Migration + mMap, ok := res[int(migrationVer)] + if !ok { + mMap = make(map[string]Migration) + } + mMap[m[3]] = mig + + res[int(migrationVer)] = mMap + + return nil + }) + return res +} + +// use len to get the highest number migration. +func RunMigrations(currentVer int, tdb *TelemDb) (finalVer int, err error) { + + migrations := getMigrations(migrationsFs) + + // get a sorted list of versions. + vers := make([]int, len(migrations)) + + i := 0 + for k := range migrations { + vers[i] = k + i++ + } + sort.Ints(vers) + expectedVer := 1 + + // check to make sure that there are no gaps (increasing by one each time) + for _, v := range vers { + if v != expectedVer { + err = errors.New("missing update between") + return + // invalid + } + expectedVer = v + 1 + } + + finalVer = vers[len(vers)-1] + // now apply the mappings based on current ver. + + tx, err := tdb.db.Begin() + if err != nil { + return + } + for v := currentVer + 1; v < finalVer; v++ { + // attempt to get the "up" migration. + mMap, ok := migrations[v] + if !ok { + err = errors.New("could not find migration for version") + goto rollback + } + upMigration, ok := mMap["up"] + if !ok { + err = errors.New("could not get up migration") + goto rollback + } + upFile, err := migrationsFs.Open(path.Join("migrations", upMigration.FileName)) + if err != nil { + goto rollback + } + + upStmt, err := io.ReadAll(upFile) + if err != nil { + goto rollback + } + // open the file name + // execute the file. + _, err = tx.Exec(string(upStmt)) + if err != nil { + goto rollback + } + + } + tx.Commit() + + return + +rollback: + tx.Rollback() + return +} diff --git a/internal/db/migration_test.go b/internal/db/migration_test.go new file mode 100644 index 0000000..1909c02 --- /dev/null +++ b/internal/db/migration_test.go @@ -0,0 +1,83 @@ +package db + +import ( + "embed" + "reflect" + "testing" +) + +//go:embed migrations/1_*.sql +//go:embed migrations/2_*.sql +var testFs embed.FS + +func Test_getMigrations(t *testing.T) { + tests := []struct { + name string + want map[int]map[string]Migration + }{ + { + name: "main test", + want: map[int]map[string]Migration{ + 1: { + "up": Migration{ + Name: "initial", + Version: 1, + FileName: "1_initial_up.sql", + }, + "down": Migration{ + Name: "initial", + Version: 1, + FileName: "1_initial_down.sql", + }, + }, + + 2: { + "up": Migration{ + Name: "addl_tables", + Version: 2, + FileName: "2_addl_tables_up.sql", + }, + "down": Migration{ + Name: "addl_tables", + Version: 2, + FileName: "2_addl_tables_down.sql", + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := getMigrations(testFs); !reflect.DeepEqual(got, tt.want) { + t.Errorf("getMigrations() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestRunMigrations(t *testing.T) { + type args struct { + currentVer int + tdb *TelemDb + } + tests := []struct { + name string + args args + wantFinalVer int + wantErr bool + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotFinalVer, err := RunMigrations(tt.args.currentVer, tt.args.tdb) + if (err != nil) != tt.wantErr { + t.Errorf("RunMigrations() error = %v, wantErr %v", err, tt.wantErr) + return + } + if gotFinalVer != tt.wantFinalVer { + t.Errorf("RunMigrations() = %v, want %v", gotFinalVer, tt.wantFinalVer) + } + }) + } +}