get migrations working
This commit is contained in:
parent
d5b960ad8a
commit
969e17a169
|
@ -5,11 +5,8 @@ package db
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"embed"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/fs"
|
|
||||||
"regexp"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
@ -25,75 +22,6 @@ func init() {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// embed the migrations into applications so they can update databases.
|
|
||||||
|
|
||||||
//go:embed migrations
|
|
||||||
var migrations embed.FS
|
|
||||||
|
|
||||||
var migrationRegex = regexp.MustCompile(`^([0-9]+)_(.*)_(down|up)\.sql$`)
|
|
||||||
|
|
||||||
type Migration struct {
|
|
||||||
Name string
|
|
||||||
Version uint
|
|
||||||
FileName string
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetMigrations returns a list of migrations, which are correctly index. zero is nil.
|
|
||||||
|
|
||||||
// use len to get the highest number migration.
|
|
||||||
func RunMigrations(currentVer int) (finalVer int) {
|
|
||||||
|
|
||||||
res := make(map[int]map[string]Migration) // version number -> direction -> migration.
|
|
||||||
|
|
||||||
fs.WalkDir(migrations, ".", func(path string, d fs.DirEntry, err error) error {
|
|
||||||
|
|
||||||
if d.IsDir() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
m := migrationRegex.FindStringSubmatch(d.Name())
|
|
||||||
if len(m) != 4 {
|
|
||||||
panic("error parsing migration name")
|
|
||||||
}
|
|
||||||
migrationVer, _ := strconv.ParseInt(m[1], 10, 64)
|
|
||||||
|
|
||||||
mig := Migration{
|
|
||||||
Name: m[2],
|
|
||||||
Version: uint(migrationVer),
|
|
||||||
FileName: d.Name(),
|
|
||||||
}
|
|
||||||
|
|
||||||
var mMap map[string]Migration
|
|
||||||
mMap, ok := res[int(migrationVer)]
|
|
||||||
if !ok {
|
|
||||||
mMap = make(map[string]Migration)
|
|
||||||
}
|
|
||||||
mMap[m[3]] = mig
|
|
||||||
|
|
||||||
res[int(migrationVer)] = mMap
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
// now apply the mappings based on current ver.
|
|
||||||
|
|
||||||
for v := currentVer; v < finalVer; v++ {
|
|
||||||
// attempt to get the "up" migration.
|
|
||||||
mMap, ok := res[v]
|
|
||||||
if !ok {
|
|
||||||
panic("aa")
|
|
||||||
}
|
|
||||||
upMigration, ok := mMap["up"]
|
|
||||||
if !ok {
|
|
||||||
panic("aaa")
|
|
||||||
}
|
|
||||||
// open the file name
|
|
||||||
// execute the file.
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
type TelemDb struct {
|
type TelemDb struct {
|
||||||
db *sqlx.DB
|
db *sqlx.DB
|
||||||
}
|
}
|
||||||
|
@ -126,8 +54,10 @@ func OpenTelemDb(path string, options ...TelemDbOption) (tdb *TelemDb, err error
|
||||||
}
|
}
|
||||||
|
|
||||||
// get latest version of migrations - then run the SQL in order.
|
// get latest version of migrations - then run the SQL in order.
|
||||||
|
fmt.Printf("starting version %d\n", version)
|
||||||
|
|
||||||
_, err = tdb.db.Exec(sqlDbUp)
|
version, err = RunMigrations(version, tdb)
|
||||||
|
fmt.Printf("ending version %d\n", version)
|
||||||
|
|
||||||
return tdb, err
|
return tdb, err
|
||||||
}
|
}
|
||||||
|
|
132
internal/db/migration.go
Normal file
132
internal/db/migration.go
Normal file
|
@ -0,0 +1,132 @@
|
||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"embed"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"io/fs"
|
||||||
|
"path"
|
||||||
|
"regexp"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
// embed the migrations into applications so they can update databases.
|
||||||
|
|
||||||
|
//go:embed migrations/*
|
||||||
|
var migrationsFs embed.FS
|
||||||
|
|
||||||
|
var migrationRegex = regexp.MustCompile(`^([0-9]+)_(.*)_(down|up)\.sql$`)
|
||||||
|
|
||||||
|
type Migration struct {
|
||||||
|
Name string
|
||||||
|
Version uint
|
||||||
|
FileName string
|
||||||
|
}
|
||||||
|
|
||||||
|
// getMigrations returns a list of migrations, which are correctly index. zero is nil.
|
||||||
|
func getMigrations(files fs.FS) map[int]map[string]Migration {
|
||||||
|
|
||||||
|
res := make(map[int]map[string]Migration) // version number -> direction -> migration.
|
||||||
|
|
||||||
|
fs.WalkDir(files, ".", func(path string, d fs.DirEntry, err error) error {
|
||||||
|
|
||||||
|
if d.IsDir() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
m := migrationRegex.FindStringSubmatch(d.Name())
|
||||||
|
if len(m) != 4 {
|
||||||
|
panic("error parsing migration name")
|
||||||
|
}
|
||||||
|
migrationVer, _ := strconv.ParseInt(m[1], 10, 64)
|
||||||
|
|
||||||
|
mig := Migration{
|
||||||
|
Name: m[2],
|
||||||
|
Version: uint(migrationVer),
|
||||||
|
FileName: d.Name(),
|
||||||
|
}
|
||||||
|
|
||||||
|
var mMap map[string]Migration
|
||||||
|
mMap, ok := res[int(migrationVer)]
|
||||||
|
if !ok {
|
||||||
|
mMap = make(map[string]Migration)
|
||||||
|
}
|
||||||
|
mMap[m[3]] = mig
|
||||||
|
|
||||||
|
res[int(migrationVer)] = mMap
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
// use len to get the highest number migration.
|
||||||
|
func RunMigrations(currentVer int, tdb *TelemDb) (finalVer int, err error) {
|
||||||
|
|
||||||
|
migrations := getMigrations(migrationsFs)
|
||||||
|
|
||||||
|
// get a sorted list of versions.
|
||||||
|
vers := make([]int, len(migrations))
|
||||||
|
|
||||||
|
i := 0
|
||||||
|
for k := range migrations {
|
||||||
|
vers[i] = k
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
sort.Ints(vers)
|
||||||
|
expectedVer := 1
|
||||||
|
|
||||||
|
// check to make sure that there are no gaps (increasing by one each time)
|
||||||
|
for _, v := range vers {
|
||||||
|
if v != expectedVer {
|
||||||
|
err = errors.New("missing update between")
|
||||||
|
return
|
||||||
|
// invalid
|
||||||
|
}
|
||||||
|
expectedVer = v + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
finalVer = vers[len(vers)-1]
|
||||||
|
// now apply the mappings based on current ver.
|
||||||
|
|
||||||
|
tx, err := tdb.db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for v := currentVer + 1; v < finalVer; v++ {
|
||||||
|
// attempt to get the "up" migration.
|
||||||
|
mMap, ok := migrations[v]
|
||||||
|
if !ok {
|
||||||
|
err = errors.New("could not find migration for version")
|
||||||
|
goto rollback
|
||||||
|
}
|
||||||
|
upMigration, ok := mMap["up"]
|
||||||
|
if !ok {
|
||||||
|
err = errors.New("could not get up migration")
|
||||||
|
goto rollback
|
||||||
|
}
|
||||||
|
upFile, err := migrationsFs.Open(path.Join("migrations", upMigration.FileName))
|
||||||
|
if err != nil {
|
||||||
|
goto rollback
|
||||||
|
}
|
||||||
|
|
||||||
|
upStmt, err := io.ReadAll(upFile)
|
||||||
|
if err != nil {
|
||||||
|
goto rollback
|
||||||
|
}
|
||||||
|
// open the file name
|
||||||
|
// execute the file.
|
||||||
|
_, err = tx.Exec(string(upStmt))
|
||||||
|
if err != nil {
|
||||||
|
goto rollback
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
tx.Commit()
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
rollback:
|
||||||
|
tx.Rollback()
|
||||||
|
return
|
||||||
|
}
|
83
internal/db/migration_test.go
Normal file
83
internal/db/migration_test.go
Normal file
|
@ -0,0 +1,83 @@
|
||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"embed"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed migrations/1_*.sql
|
||||||
|
//go:embed migrations/2_*.sql
|
||||||
|
var testFs embed.FS
|
||||||
|
|
||||||
|
func Test_getMigrations(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
want map[int]map[string]Migration
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "main test",
|
||||||
|
want: map[int]map[string]Migration{
|
||||||
|
1: {
|
||||||
|
"up": Migration{
|
||||||
|
Name: "initial",
|
||||||
|
Version: 1,
|
||||||
|
FileName: "1_initial_up.sql",
|
||||||
|
},
|
||||||
|
"down": Migration{
|
||||||
|
Name: "initial",
|
||||||
|
Version: 1,
|
||||||
|
FileName: "1_initial_down.sql",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
2: {
|
||||||
|
"up": Migration{
|
||||||
|
Name: "addl_tables",
|
||||||
|
Version: 2,
|
||||||
|
FileName: "2_addl_tables_up.sql",
|
||||||
|
},
|
||||||
|
"down": Migration{
|
||||||
|
Name: "addl_tables",
|
||||||
|
Version: 2,
|
||||||
|
FileName: "2_addl_tables_down.sql",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := getMigrations(testFs); !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("getMigrations() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunMigrations(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
currentVer int
|
||||||
|
tdb *TelemDb
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
wantFinalVer int
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
// TODO: Add test cases.
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
gotFinalVer, err := RunMigrations(tt.args.currentVer, tt.args.tdb)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("RunMigrations() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if gotFinalVer != tt.wantFinalVer {
|
||||||
|
t.Errorf("RunMigrations() = %v, want %v", gotFinalVer, tt.wantFinalVer)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue