2020-05-28 17:06:53 +00:00
|
|
|
package migrate
|
|
|
|
|
|
|
|
import (
|
|
|
|
"database/sql"
|
|
|
|
"fmt"
|
|
|
|
"io/ioutil"
|
|
|
|
"os"
|
|
|
|
|
|
|
|
"github.com/jmoiron/sqlx"
|
|
|
|
)
|
|
|
|
|
|
|
|
// Sqlx is a migrator that uses github.com/jmoiron/sqlx
|
|
|
|
type Sqlx struct {
|
|
|
|
Migrations []SqlxMigration
|
|
|
|
// Printf is used to print out additional information during a migration, such
|
|
|
|
// as which step the migration is currently on. It can be replaced with any
|
|
|
|
// custom printf function, including one that just ignores inputs. If nil it
|
|
|
|
// will default to fmt.Printf.
|
|
|
|
Printf func(format string, a ...interface{}) (n int, err error)
|
|
|
|
}
|
|
|
|
|
|
|
|
// Migrate will run the migrations using the provided db connection.
|
|
|
|
func (s *Sqlx) Migrate(sqlDB *sql.DB, dialect string) error {
|
|
|
|
db := sqlx.NewDb(sqlDB, dialect)
|
|
|
|
|
|
|
|
s.printf("Creating/checking migrations table...\n")
|
|
|
|
err := s.createMigrationTable(db)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
for _, m := range s.Migrations {
|
|
|
|
var found string
|
|
|
|
err := db.Get(&found, "SELECT id FROM migrations WHERE id=$1", m.ID)
|
|
|
|
switch err {
|
|
|
|
case sql.ErrNoRows:
|
|
|
|
s.printf("Running migration: %v\n", m.ID)
|
|
|
|
// we need to run the migration so we continue to code below
|
|
|
|
case nil:
|
|
|
|
s.printf("Skipping migration: %v\n", m.ID)
|
|
|
|
continue
|
|
|
|
default:
|
|
|
|
return fmt.Errorf("looking up migration by id: %w", err)
|
|
|
|
}
|
|
|
|
err = s.runMigration(db, m)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2020-05-29 17:03:59 +00:00
|
|
|
// Rollback will run all rollbacks using the provided db connection.
|
|
|
|
func (s *Sqlx) Rollback(sqlDB *sql.DB, dialect string) error {
|
|
|
|
db := sqlx.NewDb(sqlDB, dialect)
|
|
|
|
|
|
|
|
s.printf("Creating/checking migrations table...\n")
|
|
|
|
err := s.createMigrationTable(db)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
for i := len(s.Migrations) - 1; i >= 0; i-- {
|
|
|
|
m := s.Migrations[i]
|
|
|
|
if m.Rollback == nil {
|
|
|
|
s.printf("Rollback not provided: %v\n", m.ID)
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
var found string
|
|
|
|
err := db.Get(&found, "SELECT id FROM migrations WHERE id=$1", m.ID)
|
|
|
|
switch err {
|
|
|
|
case sql.ErrNoRows:
|
|
|
|
s.printf("Skipping rollback: %v\n", m.ID)
|
|
|
|
continue
|
|
|
|
case nil:
|
|
|
|
s.printf("Running rollback: %v\n", m.ID)
|
|
|
|
// we need to run the rollback so we continue to code below
|
|
|
|
default:
|
|
|
|
return fmt.Errorf("looking up rollback by id: %w", err)
|
|
|
|
}
|
|
|
|
err = s.runRollback(db, m)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2020-05-28 17:06:53 +00:00
|
|
|
func (s *Sqlx) printf(format string, a ...interface{}) (n int, err error) {
|
|
|
|
printf := s.Printf
|
|
|
|
if printf == nil {
|
|
|
|
printf = fmt.Printf
|
|
|
|
}
|
|
|
|
return printf(format, a...)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *Sqlx) createMigrationTable(db *sqlx.DB) error {
|
|
|
|
_, err := db.Exec("CREATE TABLE IF NOT EXISTS migrations (id TEXT PRIMARY KEY )")
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("creating migrations table: %w", err)
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *Sqlx) runMigration(db *sqlx.DB, m SqlxMigration) error {
|
|
|
|
errorf := func(err error) error { return fmt.Errorf("running migration: %w", err) }
|
|
|
|
|
|
|
|
tx, err := db.Beginx()
|
|
|
|
if err != nil {
|
|
|
|
return errorf(err)
|
|
|
|
}
|
2021-08-18 21:30:04 +00:00
|
|
|
_, err = tx.Exec("INSERT INTO migrations (id) VALUES ($1)", m.ID)
|
2020-05-28 17:06:53 +00:00
|
|
|
if err != nil {
|
|
|
|
tx.Rollback()
|
|
|
|
return errorf(err)
|
|
|
|
}
|
|
|
|
err = m.Migrate(tx)
|
|
|
|
if err != nil {
|
|
|
|
tx.Rollback()
|
|
|
|
return errorf(err)
|
|
|
|
}
|
|
|
|
err = tx.Commit()
|
|
|
|
if err != nil {
|
|
|
|
return errorf(err)
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2020-05-29 17:03:59 +00:00
|
|
|
func (s *Sqlx) runRollback(db *sqlx.DB, m SqlxMigration) error {
|
|
|
|
errorf := func(err error) error { return fmt.Errorf("running rollback: %w", err) }
|
|
|
|
|
|
|
|
tx, err := db.Beginx()
|
|
|
|
if err != nil {
|
|
|
|
return errorf(err)
|
|
|
|
}
|
2021-08-18 21:30:04 +00:00
|
|
|
_, err = tx.Exec("DELETE FROM migrations WHERE id=$1", m.ID)
|
2020-05-29 17:03:59 +00:00
|
|
|
if err != nil {
|
|
|
|
tx.Rollback()
|
|
|
|
return errorf(err)
|
|
|
|
}
|
|
|
|
err = m.Rollback(tx)
|
|
|
|
if err != nil {
|
|
|
|
tx.Rollback()
|
|
|
|
return errorf(err)
|
|
|
|
}
|
|
|
|
err = tx.Commit()
|
|
|
|
if err != nil {
|
|
|
|
return errorf(err)
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2020-05-28 17:06:53 +00:00
|
|
|
// SqlxMigration is a unique ID plus a function that uses a sqlx transaction
|
|
|
|
// to perform a database migration step.
|
|
|
|
//
|
|
|
|
// Note: Long term this could have a Rollback field if we wanted to support
|
|
|
|
// that.
|
|
|
|
type SqlxMigration struct {
|
2020-05-29 17:03:59 +00:00
|
|
|
ID string
|
|
|
|
Migrate func(tx *sqlx.Tx) error
|
|
|
|
Rollback func(tx *sqlx.Tx) error
|
2020-05-28 17:06:53 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
// SqlxQueryMigration will create a SqlxMigration using the provided id and
|
|
|
|
// query string. It is a helper function designed to simplify the process of
|
|
|
|
// creating migrations that only depending on a SQL query string.
|
2020-05-29 17:03:59 +00:00
|
|
|
func SqlxQueryMigration(id, upQuery, downQuery string) SqlxMigration {
|
|
|
|
queryFn := func(query string) func(tx *sqlx.Tx) error {
|
|
|
|
if query == "" {
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
return func(tx *sqlx.Tx) error {
|
2020-05-28 17:06:53 +00:00
|
|
|
_, err := tx.Exec(query)
|
|
|
|
return err
|
2020-05-29 17:03:59 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
m := SqlxMigration{
|
|
|
|
ID: id,
|
|
|
|
Migrate: queryFn(upQuery),
|
|
|
|
Rollback: queryFn(downQuery),
|
2020-05-28 17:06:53 +00:00
|
|
|
}
|
|
|
|
return m
|
|
|
|
}
|
|
|
|
|
|
|
|
// SqlxFileMigration will create a SqlxMigration using the provided file.
|
2020-05-29 17:03:59 +00:00
|
|
|
func SqlxFileMigration(id, upFile, downFile string) SqlxMigration {
|
|
|
|
fileFn := func(filename string) func(tx *sqlx.Tx) error {
|
|
|
|
if filename == "" {
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
f, err := os.Open(filename)
|
|
|
|
if err != nil {
|
|
|
|
// We could return a migration that errors when the migration is run, but I
|
|
|
|
// think it makes more sense to panic here.
|
|
|
|
panic(err)
|
|
|
|
}
|
|
|
|
fileBytes, err := ioutil.ReadAll(f)
|
|
|
|
if err != nil {
|
|
|
|
panic(err)
|
|
|
|
}
|
|
|
|
return func(tx *sqlx.Tx) error {
|
2020-05-28 17:06:53 +00:00
|
|
|
_, err := tx.Exec(string(fileBytes))
|
|
|
|
return err
|
2020-05-29 17:03:59 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
m := SqlxMigration{
|
|
|
|
ID: id,
|
|
|
|
Migrate: fileFn(upFile),
|
|
|
|
Rollback: fileFn(downFile),
|
2020-05-28 17:06:53 +00:00
|
|
|
}
|
|
|
|
return m
|
|
|
|
}
|