Implemented the Rollback feature.

In most dev/testing environments it is nice to have a way to reset/rollback
the entire database. This commit adds this feature in the most basic form;
when Rollback() is called, it will rollback ALL migrations that have run
leaving the database is a mostly pristine state.

There are a few potential issues here. For starters, devs may not always
want all migrations to be rolled back. The only current fix is to course
in this case is to create a migrator that only contains migrations they
want rolled back.

Another potential issue is that any migration that doesn't provide a
Rollback function will not be rolled back, but there also won't be any
errors. A message is printed out when this happens to help avoid some
confusion, but I could see this still causing issues. I'm not 100% sure
what the best long term solution is, but feel the current version is
good enough to move forward since it satisfies my needs.
This commit is contained in:
Jon Calhoun 2020-05-29 13:03:59 -04:00
parent 006e7c7cc3
commit eee25e8d27
3 changed files with 159 additions and 29 deletions

96
sqlx.go
View file

@ -49,6 +49,41 @@ func (s *Sqlx) Migrate(sqlDB *sql.DB, dialect string) error {
return nil return nil
} }
// Rollback will run all rollbacks using the provided db connection.
func (s *Sqlx) Rollback(sqlDB *sql.DB, dialect string) error {
db := sqlx.NewDb(sqlDB, dialect)
s.printf("Creating/checking migrations table...\n")
err := s.createMigrationTable(db)
if err != nil {
return err
}
for i := len(s.Migrations) - 1; i >= 0; i-- {
m := s.Migrations[i]
if m.Rollback == nil {
s.printf("Rollback not provided: %v\n", m.ID)
continue
}
var found string
err := db.Get(&found, "SELECT id FROM migrations WHERE id=$1", m.ID)
switch err {
case sql.ErrNoRows:
s.printf("Skipping rollback: %v\n", m.ID)
continue
case nil:
s.printf("Running rollback: %v\n", m.ID)
// we need to run the rollback so we continue to code below
default:
return fmt.Errorf("looking up rollback by id: %w", err)
}
err = s.runRollback(db, m)
if err != nil {
return err
}
}
return nil
}
func (s *Sqlx) printf(format string, a ...interface{}) (n int, err error) { func (s *Sqlx) printf(format string, a ...interface{}) (n int, err error) {
printf := s.Printf printf := s.Printf
if printf == nil { if printf == nil {
@ -89,6 +124,30 @@ func (s *Sqlx) runMigration(db *sqlx.DB, m SqlxMigration) error {
return nil return nil
} }
func (s *Sqlx) runRollback(db *sqlx.DB, m SqlxMigration) error {
errorf := func(err error) error { return fmt.Errorf("running rollback: %w", err) }
tx, err := db.Beginx()
if err != nil {
return errorf(err)
}
_, err = db.Exec("DELETE FROM migrations WHERE id=$1", m.ID)
if err != nil {
tx.Rollback()
return errorf(err)
}
err = m.Rollback(tx)
if err != nil {
tx.Rollback()
return errorf(err)
}
err = tx.Commit()
if err != nil {
return errorf(err)
}
return nil
}
// SqlxMigration is a unique ID plus a function that uses a sqlx transaction // SqlxMigration is a unique ID plus a function that uses a sqlx transaction
// to perform a database migration step. // to perform a database migration step.
// //
@ -97,24 +156,37 @@ func (s *Sqlx) runMigration(db *sqlx.DB, m SqlxMigration) error {
type SqlxMigration struct { type SqlxMigration struct {
ID string ID string
Migrate func(tx *sqlx.Tx) error Migrate func(tx *sqlx.Tx) error
Rollback func(tx *sqlx.Tx) error
} }
// SqlxQueryMigration will create a SqlxMigration using the provided id and // SqlxQueryMigration will create a SqlxMigration using the provided id and
// query string. It is a helper function designed to simplify the process of // query string. It is a helper function designed to simplify the process of
// creating migrations that only depending on a SQL query string. // creating migrations that only depending on a SQL query string.
func SqlxQueryMigration(id, query string) SqlxMigration { func SqlxQueryMigration(id, upQuery, downQuery string) SqlxMigration {
m := SqlxMigration{ queryFn := func(query string) func(tx *sqlx.Tx) error {
ID: id, if query == "" {
Migrate: func(tx *sqlx.Tx) error { return nil
}
return func(tx *sqlx.Tx) error {
_, err := tx.Exec(query) _, err := tx.Exec(query)
return err return err
}, }
}
m := SqlxMigration{
ID: id,
Migrate: queryFn(upQuery),
Rollback: queryFn(downQuery),
} }
return m return m
} }
// SqlxFileMigration will create a SqlxMigration using the provided file. // SqlxFileMigration will create a SqlxMigration using the provided file.
func SqlxFileMigration(id, filename string) SqlxMigration { func SqlxFileMigration(id, upFile, downFile string) SqlxMigration {
fileFn := func(filename string) func(tx *sqlx.Tx) error {
if filename == "" {
return nil
}
f, err := os.Open(filename) f, err := os.Open(filename)
if err != nil { if err != nil {
// We could return a migration that errors when the migration is run, but I // We could return a migration that errors when the migration is run, but I
@ -125,12 +197,16 @@ func SqlxFileMigration(id, filename string) SqlxMigration {
if err != nil { if err != nil {
panic(err) panic(err)
} }
m := SqlxMigration{ return func(tx *sqlx.Tx) error {
ID: id,
Migrate: func(tx *sqlx.Tx) error {
_, err := tx.Exec(string(fileBytes)) _, err := tx.Exec(string(fileBytes))
return err return err
}, }
}
m := SqlxMigration{
ID: id,
Migrate: fileFn(upFile),
Rollback: fileFn(downFile),
} }
return m return m
} }

View file

@ -33,7 +33,7 @@ func TestSqlx(t *testing.T) {
return 0, nil return 0, nil
}, },
Migrations: []migrate.SqlxMigration{ Migrations: []migrate.SqlxMigration{
migrate.SqlxQueryMigration("001_create_courses", createCoursesSql), migrate.SqlxQueryMigration("001_create_courses", createCoursesSql, ""),
}, },
} }
err := migrator.Migrate(db, "sqlite3") err := migrator.Migrate(db, "sqlite3")
@ -54,7 +54,7 @@ func TestSqlx(t *testing.T) {
return 0, nil return 0, nil
}, },
Migrations: []migrate.SqlxMigration{ Migrations: []migrate.SqlxMigration{
migrate.SqlxQueryMigration("001_create_courses", createCoursesSql), migrate.SqlxQueryMigration("001_create_courses", createCoursesSql, ""),
}, },
} }
err := migrator.Migrate(db, "sqlite3") err := migrator.Migrate(db, "sqlite3")
@ -73,8 +73,8 @@ func TestSqlx(t *testing.T) {
return 0, nil return 0, nil
}, },
Migrations: []migrate.SqlxMigration{ Migrations: []migrate.SqlxMigration{
migrate.SqlxQueryMigration("001_create_courses", createCoursesSql), migrate.SqlxQueryMigration("001_create_courses", createCoursesSql, ""),
migrate.SqlxQueryMigration("002_create_users", createUsersSql), migrate.SqlxQueryMigration("002_create_users", createUsersSql, ""),
}, },
} }
err = migrator.Migrate(db, "sqlite3") err = migrator.Migrate(db, "sqlite3")
@ -95,7 +95,7 @@ func TestSqlx(t *testing.T) {
return 0, nil return 0, nil
}, },
Migrations: []migrate.SqlxMigration{ Migrations: []migrate.SqlxMigration{
migrate.SqlxFileMigration("001_create_widgets", "testdata/widgets.sql"), migrate.SqlxFileMigration("001_create_widgets", "testdata/widgets.sql", ""),
}, },
} }
err := migrator.Migrate(db, "sqlite3") err := migrator.Migrate(db, "sqlite3")
@ -107,16 +107,68 @@ func TestSqlx(t *testing.T) {
t.Fatalf("db.Exec() err = %v; want nil", err) t.Fatalf("db.Exec() err = %v; want nil", err)
} }
}) })
t.Run("rollback", func(t *testing.T) {
db := sqliteInMem(t)
migrator := migrate.Sqlx{
Printf: func(format string, args ...interface{}) (int, error) {
t.Logf(format, args...)
return 0, nil
},
Migrations: []migrate.SqlxMigration{
migrate.SqlxQueryMigration("001_create_courses", createCoursesSql, dropCoursesSql),
},
}
err := migrator.Migrate(db, "sqlite3")
if err != nil {
t.Fatalf("Migrate() err = %v; want nil", err)
}
_, err = db.Exec("INSERT INTO courses (name) VALUES ($1) ", "cor_test")
if err != nil {
t.Fatalf("db.Exec() err = %v; want nil", err)
}
err = migrator.Rollback(db, "sqlite3")
if err != nil {
t.Fatalf("Rollback() err = %v; want nil", err)
}
var count int
err = db.QueryRow("SELECT COUNT(id) FROM courses;").Scan(&count)
if err == nil {
// Want an error here
t.Fatalf("db.QueryRow() err = nil; want table missing error")
}
// Don't want to test inner workings of lib, so let's just migrate again and verify we have a table now
err = migrator.Migrate(db, "sqlite3")
if err != nil {
t.Fatalf("Migrate() err = %v; want nil", err)
}
_, err = db.Exec("INSERT INTO courses (name) VALUES ($1) ", "cor_test")
if err != nil {
t.Fatalf("db.Exec() err = %v; want nil", err)
}
err = db.QueryRow("SELECT COUNT(*) FROM courses;").Scan(&count)
if err != nil {
// Want an error here
t.Fatalf("db.QueryRow() err = %v; want nil", err)
}
if count != 1 {
t.Fatalf("count = %d; want %d", count, 1)
}
})
} }
var createCoursesSql = ` var (
createCoursesSql = `
CREATE TABLE courses ( CREATE TABLE courses (
id serial PRIMARY KEY, id serial PRIMARY KEY,
name text name text
);` );`
dropCoursesSql = `DROP TABLE courses;`
var createUsersSql = ` createUsersSql = `
CREATE TABLE users ( CREATE TABLE users (
id serial PRIMARY KEY, id serial PRIMARY KEY,
email text UNIQUE NOT NULL email text UNIQUE NOT NULL
);` );`
dropUsersSql = `DROP TABLE users;`
)

2
testdata/widgets.down.sql vendored Normal file
View file

@ -0,0 +1,2 @@
DROP TABLE widgets;