get migrations working

This commit is contained in:
saji 2023-07-06 11:26:00 -05:00
parent d5b960ad8a
commit 969e17a169
3 changed files with 218 additions and 73 deletions

View file

@ -5,11 +5,8 @@ package db
import (
"context"
"database/sql"
"embed"
"encoding/json"
"fmt"
"io/fs"
"regexp"
"strconv"
"strings"
"time"
@ -25,75 +22,6 @@ func init() {
})
}
// embed the migrations into applications so they can update databases.
//go:embed migrations
var migrations embed.FS
var migrationRegex = regexp.MustCompile(`^([0-9]+)_(.*)_(down|up)\.sql$`)
type Migration struct {
Name string
Version uint
FileName string
}
// GetMigrations returns a list of migrations, which are correctly index. zero is nil.
// use len to get the highest number migration.
func RunMigrations(currentVer int) (finalVer int) {
res := make(map[int]map[string]Migration) // version number -> direction -> migration.
fs.WalkDir(migrations, ".", func(path string, d fs.DirEntry, err error) error {
if d.IsDir() {
return nil
}
m := migrationRegex.FindStringSubmatch(d.Name())
if len(m) != 4 {
panic("error parsing migration name")
}
migrationVer, _ := strconv.ParseInt(m[1], 10, 64)
mig := Migration{
Name: m[2],
Version: uint(migrationVer),
FileName: d.Name(),
}
var mMap map[string]Migration
mMap, ok := res[int(migrationVer)]
if !ok {
mMap = make(map[string]Migration)
}
mMap[m[3]] = mig
res[int(migrationVer)] = mMap
return nil
})
// now apply the mappings based on current ver.
for v := currentVer; v < finalVer; v++ {
// attempt to get the "up" migration.
mMap, ok := res[v]
if !ok {
panic("aa")
}
upMigration, ok := mMap["up"]
if !ok {
panic("aaa")
}
// open the file name
// execute the file.
}
return res
}
type TelemDb struct {
db *sqlx.DB
}
@ -126,8 +54,10 @@ func OpenTelemDb(path string, options ...TelemDbOption) (tdb *TelemDb, err error
}
// get latest version of migrations - then run the SQL in order.
fmt.Printf("starting version %d\n", version)
_, err = tdb.db.Exec(sqlDbUp)
version, err = RunMigrations(version, tdb)
fmt.Printf("ending version %d\n", version)
return tdb, err
}

132
internal/db/migration.go Normal file
View file

@ -0,0 +1,132 @@
package db
import (
"embed"
"errors"
"io"
"io/fs"
"path"
"regexp"
"sort"
"strconv"
)
// embed the migrations into applications so they can update databases.
//go:embed migrations/*
var migrationsFs embed.FS
var migrationRegex = regexp.MustCompile(`^([0-9]+)_(.*)_(down|up)\.sql$`)
type Migration struct {
Name string
Version uint
FileName string
}
// getMigrations returns a list of migrations, which are correctly index. zero is nil.
func getMigrations(files fs.FS) map[int]map[string]Migration {
res := make(map[int]map[string]Migration) // version number -> direction -> migration.
fs.WalkDir(files, ".", func(path string, d fs.DirEntry, err error) error {
if d.IsDir() {
return nil
}
m := migrationRegex.FindStringSubmatch(d.Name())
if len(m) != 4 {
panic("error parsing migration name")
}
migrationVer, _ := strconv.ParseInt(m[1], 10, 64)
mig := Migration{
Name: m[2],
Version: uint(migrationVer),
FileName: d.Name(),
}
var mMap map[string]Migration
mMap, ok := res[int(migrationVer)]
if !ok {
mMap = make(map[string]Migration)
}
mMap[m[3]] = mig
res[int(migrationVer)] = mMap
return nil
})
return res
}
// use len to get the highest number migration.
func RunMigrations(currentVer int, tdb *TelemDb) (finalVer int, err error) {
migrations := getMigrations(migrationsFs)
// get a sorted list of versions.
vers := make([]int, len(migrations))
i := 0
for k := range migrations {
vers[i] = k
i++
}
sort.Ints(vers)
expectedVer := 1
// check to make sure that there are no gaps (increasing by one each time)
for _, v := range vers {
if v != expectedVer {
err = errors.New("missing update between")
return
// invalid
}
expectedVer = v + 1
}
finalVer = vers[len(vers)-1]
// now apply the mappings based on current ver.
tx, err := tdb.db.Begin()
if err != nil {
return
}
for v := currentVer + 1; v < finalVer; v++ {
// attempt to get the "up" migration.
mMap, ok := migrations[v]
if !ok {
err = errors.New("could not find migration for version")
goto rollback
}
upMigration, ok := mMap["up"]
if !ok {
err = errors.New("could not get up migration")
goto rollback
}
upFile, err := migrationsFs.Open(path.Join("migrations", upMigration.FileName))
if err != nil {
goto rollback
}
upStmt, err := io.ReadAll(upFile)
if err != nil {
goto rollback
}
// open the file name
// execute the file.
_, err = tx.Exec(string(upStmt))
if err != nil {
goto rollback
}
}
tx.Commit()
return
rollback:
tx.Rollback()
return
}

View file

@ -0,0 +1,83 @@
package db
import (
"embed"
"reflect"
"testing"
)
//go:embed migrations/1_*.sql
//go:embed migrations/2_*.sql
var testFs embed.FS
func Test_getMigrations(t *testing.T) {
tests := []struct {
name string
want map[int]map[string]Migration
}{
{
name: "main test",
want: map[int]map[string]Migration{
1: {
"up": Migration{
Name: "initial",
Version: 1,
FileName: "1_initial_up.sql",
},
"down": Migration{
Name: "initial",
Version: 1,
FileName: "1_initial_down.sql",
},
},
2: {
"up": Migration{
Name: "addl_tables",
Version: 2,
FileName: "2_addl_tables_up.sql",
},
"down": Migration{
Name: "addl_tables",
Version: 2,
FileName: "2_addl_tables_down.sql",
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := getMigrations(testFs); !reflect.DeepEqual(got, tt.want) {
t.Errorf("getMigrations() = %v, want %v", got, tt.want)
}
})
}
}
func TestRunMigrations(t *testing.T) {
type args struct {
currentVer int
tdb *TelemDb
}
tests := []struct {
name string
args args
wantFinalVer int
wantErr bool
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotFinalVer, err := RunMigrations(tt.args.currentVer, tt.args.tdb)
if (err != nil) != tt.wantErr {
t.Errorf("RunMigrations() error = %v, wantErr %v", err, tt.wantErr)
return
}
if gotFinalVer != tt.wantFinalVer {
t.Errorf("RunMigrations() = %v, want %v", gotFinalVer, tt.wantFinalVer)
}
})
}
}