265 lines
5.2 KiB
Go
265 lines
5.2 KiB
Go
package migrations
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"os"
|
|
"path"
|
|
"path/filepath"
|
|
"regexp"
|
|
"runtime"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
)
|
|
|
|
// allMigrations collects every migration added through Register, in the
// order the calls were made. Run works on a copy of this slice.
var allMigrations []Migration
|
|
|
|
type Migration struct {
|
|
Version int64
|
|
Up func(DB) error
|
|
Down func(DB) error
|
|
}
|
|
|
|
func (m *Migration) String() string {
|
|
return strconv.FormatInt(m.Version, 10)
|
|
}
|
|
|
|
// Register registers new database migration. Must be called
|
|
// from file with name like "1_initialize_db.go", where:
|
|
// - 1 - migration version;
|
|
// - initialize_db - comment.
|
|
func Register(up, down func(DB) error) error {
|
|
_, file, _, _ := runtime.Caller(1)
|
|
version, err := extractVersion(file)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
allMigrations = append(allMigrations, Migration{
|
|
Version: version,
|
|
Up: up,
|
|
Down: down,
|
|
})
|
|
return nil
|
|
}
|
|
|
|
// Run runs command on the db. Supported commands are:
|
|
// - init - creates gopg_migrations table.
|
|
// - up - runs all available migrations.
|
|
// - down - reverts last migration.
|
|
// - reset - reverts all migrations.
|
|
// - version - prints current db version.
|
|
// - set_version - sets db version without running migrations.
|
|
func Run(db DB, a ...string) (oldVersion, newVersion int64, err error) {
|
|
// Make a copy so there are no side effects of sorting.
|
|
migrations := make([]Migration, len(allMigrations))
|
|
copy(migrations, allMigrations)
|
|
return RunMigrations(db, migrations, a...)
|
|
}
|
|
|
|
// RunMigrations is like Run, but accepts list of migrations.
|
|
func RunMigrations(db DB, migrations []Migration, a ...string) (oldVersion, newVersion int64, err error) {
|
|
sortMigrations(migrations)
|
|
|
|
var cmd string
|
|
if len(a) > 0 {
|
|
cmd = a[0]
|
|
}
|
|
|
|
if cmd == "init" {
|
|
err = createTables(db)
|
|
if err != nil {
|
|
return
|
|
}
|
|
cmd = "version"
|
|
}
|
|
|
|
oldVersion, err = Version(db)
|
|
if err != nil {
|
|
return
|
|
}
|
|
newVersion = oldVersion
|
|
|
|
switch cmd {
|
|
case "create":
|
|
if len(a) < 2 {
|
|
fmt.Println("Please enter migration description")
|
|
return
|
|
}
|
|
|
|
var version int64
|
|
if len(migrations) > 0 {
|
|
version = migrations[len(migrations)-1].Version
|
|
}
|
|
|
|
filename := fmtMigrationFilename(version+1, strings.Join(a[1:], "_"))
|
|
err = createMigrationFile(filename)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
fmt.Println("created migration", filename)
|
|
return
|
|
case "version":
|
|
return
|
|
case "up", "":
|
|
for i := range migrations {
|
|
m := &migrations[i]
|
|
if m.Version <= oldVersion {
|
|
continue
|
|
}
|
|
err = m.Up(db)
|
|
if err != nil {
|
|
return
|
|
}
|
|
newVersion = m.Version
|
|
err = SetVersion(db, newVersion)
|
|
if err != nil {
|
|
return
|
|
}
|
|
}
|
|
return
|
|
case "down":
|
|
newVersion, err = down(db, migrations, oldVersion)
|
|
return
|
|
case "reset":
|
|
version := oldVersion
|
|
for {
|
|
newVersion, err = down(db, migrations, version)
|
|
if err != nil || newVersion == version {
|
|
return
|
|
}
|
|
version = newVersion
|
|
}
|
|
case "set_version":
|
|
if len(a) < 2 {
|
|
err = fmt.Errorf("set_version requires version as 2nd arg, e.g. set_version 42")
|
|
return
|
|
}
|
|
|
|
newVersion, err = strconv.ParseInt(a[1], 10, 64)
|
|
if err != nil {
|
|
return
|
|
}
|
|
err = SetVersion(db, newVersion)
|
|
return
|
|
default:
|
|
err = fmt.Errorf("unsupported command: %q", cmd)
|
|
return
|
|
}
|
|
}
|
|
|
|
func down(db DB, migrations []Migration, oldVersion int64) (newVersion int64, err error) {
|
|
if oldVersion == 0 {
|
|
return
|
|
}
|
|
|
|
var m *Migration
|
|
for i := len(migrations) - 1; i >= 0; i-- {
|
|
mm := &migrations[i]
|
|
if mm.Version <= oldVersion {
|
|
m = mm
|
|
break
|
|
}
|
|
}
|
|
if m == nil {
|
|
err = fmt.Errorf("migration %d not found\n", oldVersion)
|
|
return
|
|
}
|
|
|
|
if m.Down != nil {
|
|
err = m.Down(db)
|
|
if err != nil {
|
|
return
|
|
}
|
|
}
|
|
|
|
newVersion = m.Version - 1
|
|
err = SetVersion(db, newVersion)
|
|
return
|
|
}
|
|
|
|
// extractVersion parses the migration version out of a file name of the form
// "<version>_<comment>.go". Only the base name of the path is inspected. The
// version is the decimal number in front of the first underscore and must be
// greater than zero.
func extractVersion(name string) (int64, error) {
	base := filepath.Base(name)
	if filepath.Ext(base) != ".go" {
		return 0, fmt.Errorf("can not extract version from %q", base)
	}

	sep := strings.IndexByte(base, '_')
	if sep < 0 {
		return 0, fmt.Errorf("can not extract version from %q", base)
	}

	version, err := strconv.ParseInt(base[:sep], 10, 64)
	switch {
	case err != nil:
		return 0, err
	case version <= 0:
		return 0, errors.New("version must be greater than zero")
	}
	return version, nil
}
|
|
|
|
type migrationSorter []Migration
|
|
|
|
func (ms migrationSorter) Len() int {
|
|
return len(ms)
|
|
}
|
|
|
|
func (ms migrationSorter) Swap(i, j int) {
|
|
ms[i], ms[j] = ms[j], ms[i]
|
|
}
|
|
|
|
func (ms migrationSorter) Less(i, j int) bool {
|
|
return ms[i].Version < ms[j].Version
|
|
}
|
|
|
|
func sortMigrations(migrations []Migration) {
|
|
ms := migrationSorter(migrations)
|
|
sort.Sort(ms)
|
|
}
|
|
|
|
// migrationNameRE matches every run of characters that is not a lowercase
// ASCII letter or a digit.
var migrationNameRE = regexp.MustCompile(`[^a-z0-9]+`)

// fmtMigrationFilename builds the file name for a migration with the given
// version and description: "<version>_<description>.go". The description is
// lowercased and each run of characters matched by migrationNameRE is
// collapsed into a single underscore.
func fmtMigrationFilename(version int64, descr string) string {
	slug := migrationNameRE.ReplaceAllString(strings.ToLower(descr), "_")
	return strconv.FormatInt(version, 10) + "_" + slug + ".go"
}
|
|
|
|
func createMigrationFile(filename string) error {
|
|
basepath, err := os.Getwd()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
filename = path.Join(basepath, filename)
|
|
|
|
_, err = os.Stat(filename)
|
|
if !os.IsNotExist(err) {
|
|
return fmt.Errorf("file=%q already exists (%s)", filename, err)
|
|
}
|
|
|
|
return ioutil.WriteFile(filename, migrationTemplate, 0644)
|
|
}
|
|
|
|
// migrationTemplate is the Go source that createMigrationFile writes for a
// newly created migration: a main-package file whose init function calls
// migrations.Register with placeholder up and down steps, each executing an
// empty statement.
var migrationTemplate = []byte(`package main

import (
	"github.com/go-pg/migrations"
)

func init() {
	migrations.Register(func(db migrations.DB) error {
		_, err := db.Exec("")
		return err
	}, func(db migrations.DB) error {
		_, err := db.Exec("")
		return err
	})
}
`)
|