From eee25e8d2782258fe62b27c4fad643c91818f82d Mon Sep 17 00:00:00 2001 From: Jon Calhoun Date: Fri, 29 May 2020 13:03:59 -0400 Subject: [PATCH] Implemented the Rollback feature. In most dev/testing environments it is nice to have a way to reset/rollback the entire database. This commit adds this feature in the most basic form; when Rollback() is called, it will rollback ALL migrations that have run leaving the database is a mostly pristine state. There are a few potential issues here. For starters, devs may not always want all migrations to be rolled back. The only current fix is to course in this case is to create a migrator that only contains migrations they want rolled back. Another potential issue is that any migration that doesn't provide a Rollback function will not be rolled back, but there also won't be any errors. A message is printed out when this happens to help avoid some confusion, but I could see this still causing issues. I'm not 100% sure what the best long term solution is, but feel the current version is good enough to move forward since it satisfies my needs. --- sqlx.go | 120 +++++++++++++++++++++++++++++++------- sqlx_test.go | 66 ++++++++++++++++++--- testdata/widgets.down.sql | 2 + 3 files changed, 159 insertions(+), 29 deletions(-) create mode 100644 testdata/widgets.down.sql diff --git a/sqlx.go b/sqlx.go index e74b687..a624f60 100644 --- a/sqlx.go +++ b/sqlx.go @@ -49,6 +49,41 @@ func (s *Sqlx) Migrate(sqlDB *sql.DB, dialect string) error { return nil } +// Rollback will run all rollbacks using the provided db connection. +func (s *Sqlx) Rollback(sqlDB *sql.DB, dialect string) error { + db := sqlx.NewDb(sqlDB, dialect) + + s.printf("Creating/checking migrations table...\n") + err := s.createMigrationTable(db) + if err != nil { + return err + } + for i := len(s.Migrations) - 1; i >= 0; i-- { + m := s.Migrations[i] + if m.Rollback == nil { + s.printf("Rollback not provided: %v\n", m.ID) + continue + } + var found string + err := db.Get(&found, "SELECT id FROM migrations WHERE id=$1", m.ID) + switch err { + case sql.ErrNoRows: + s.printf("Skipping rollback: %v\n", m.ID) + continue + case nil: + s.printf("Running rollback: %v\n", m.ID) + // we need to run the rollback so we continue to code below + default: + return fmt.Errorf("looking up rollback by id: %w", err) + } + err = s.runRollback(db, m) + if err != nil { + return err + } + } + return nil +} + func (s *Sqlx) printf(format string, a ...interface{}) (n int, err error) { printf := s.Printf if printf == nil { @@ -89,48 +124,89 @@ func (s *Sqlx) runMigration(db *sqlx.DB, m SqlxMigration) error { return nil } +func (s *Sqlx) runRollback(db *sqlx.DB, m SqlxMigration) error { + errorf := func(err error) error { return fmt.Errorf("running rollback: %w", err) } + + tx, err := db.Beginx() + if err != nil { + return errorf(err) + } + _, err = db.Exec("DELETE FROM migrations WHERE id=$1", m.ID) + if err != nil { + tx.Rollback() + return errorf(err) + } + err = m.Rollback(tx) + if err != nil { + tx.Rollback() + return errorf(err) + } + err = tx.Commit() + if err != nil { + return errorf(err) + } + return nil +} + // SqlxMigration is a unique ID plus a function that uses a sqlx transaction // to perform a database migration step. // // Note: Long term this could have a Rollback field if we wanted to support // that. type SqlxMigration struct { - ID string - Migrate func(tx *sqlx.Tx) error + ID string + Migrate func(tx *sqlx.Tx) error + Rollback func(tx *sqlx.Tx) error } // SqlxQueryMigration will create a SqlxMigration using the provided id and // query string. It is a helper function designed to simplify the process of // creating migrations that only depending on a SQL query string. -func SqlxQueryMigration(id, query string) SqlxMigration { - m := SqlxMigration{ - ID: id, - Migrate: func(tx *sqlx.Tx) error { +func SqlxQueryMigration(id, upQuery, downQuery string) SqlxMigration { + queryFn := func(query string) func(tx *sqlx.Tx) error { + if query == "" { + return nil + } + return func(tx *sqlx.Tx) error { _, err := tx.Exec(query) return err - }, + } + } + + m := SqlxMigration{ + ID: id, + Migrate: queryFn(upQuery), + Rollback: queryFn(downQuery), } return m } // SqlxFileMigration will create a SqlxMigration using the provided file. -func SqlxFileMigration(id, filename string) SqlxMigration { - f, err := os.Open(filename) - if err != nil { - // We could return a migration that errors when the migration is run, but I - // think it makes more sense to panic here. - panic(err) - } - fileBytes, err := ioutil.ReadAll(f) - if err != nil { - panic(err) - } - m := SqlxMigration{ - ID: id, - Migrate: func(tx *sqlx.Tx) error { +func SqlxFileMigration(id, upFile, downFile string) SqlxMigration { + fileFn := func(filename string) func(tx *sqlx.Tx) error { + if filename == "" { + return nil + } + f, err := os.Open(filename) + if err != nil { + // We could return a migration that errors when the migration is run, but I + // think it makes more sense to panic here. + panic(err) + } + fileBytes, err := ioutil.ReadAll(f) + if err != nil { + panic(err) + } + return func(tx *sqlx.Tx) error { _, err := tx.Exec(string(fileBytes)) return err - }, + } + } + + m := SqlxMigration{ + ID: id, + Migrate: fileFn(upFile), + Rollback: fileFn(downFile), } return m } diff --git a/sqlx_test.go b/sqlx_test.go index 2f377de..7e7c86d 100644 --- a/sqlx_test.go +++ b/sqlx_test.go @@ -33,7 +33,7 @@ func TestSqlx(t *testing.T) { return 0, nil }, Migrations: []migrate.SqlxMigration{ - migrate.SqlxQueryMigration("001_create_courses", createCoursesSql), + migrate.SqlxQueryMigration("001_create_courses", createCoursesSql, ""), }, } err := migrator.Migrate(db, "sqlite3") @@ -54,7 +54,7 @@ func TestSqlx(t *testing.T) { return 0, nil }, Migrations: []migrate.SqlxMigration{ - migrate.SqlxQueryMigration("001_create_courses", createCoursesSql), + migrate.SqlxQueryMigration("001_create_courses", createCoursesSql, ""), }, } err := migrator.Migrate(db, "sqlite3") @@ -73,8 +73,8 @@ func TestSqlx(t *testing.T) { return 0, nil }, Migrations: []migrate.SqlxMigration{ - migrate.SqlxQueryMigration("001_create_courses", createCoursesSql), - migrate.SqlxQueryMigration("002_create_users", createUsersSql), + migrate.SqlxQueryMigration("001_create_courses", createCoursesSql, ""), + migrate.SqlxQueryMigration("002_create_users", createUsersSql, ""), }, } err = migrator.Migrate(db, "sqlite3") @@ -95,7 +95,7 @@ func TestSqlx(t *testing.T) { return 0, nil }, Migrations: []migrate.SqlxMigration{ - migrate.SqlxFileMigration("001_create_widgets", "testdata/widgets.sql"), + migrate.SqlxFileMigration("001_create_widgets", "testdata/widgets.sql", ""), }, } err := migrator.Migrate(db, "sqlite3") @@ -107,16 +107,68 @@ func TestSqlx(t *testing.T) { t.Fatalf("db.Exec() err = %v; want nil", err) } }) + + t.Run("rollback", func(t *testing.T) { + db := sqliteInMem(t) + migrator := migrate.Sqlx{ + Printf: func(format string, args ...interface{}) (int, error) { + t.Logf(format, args...) + return 0, nil + }, + Migrations: []migrate.SqlxMigration{ + migrate.SqlxQueryMigration("001_create_courses", createCoursesSql, dropCoursesSql), + }, + } + err := migrator.Migrate(db, "sqlite3") + if err != nil { + t.Fatalf("Migrate() err = %v; want nil", err) + } + _, err = db.Exec("INSERT INTO courses (name) VALUES ($1) ", "cor_test") + if err != nil { + t.Fatalf("db.Exec() err = %v; want nil", err) + } + err = migrator.Rollback(db, "sqlite3") + if err != nil { + t.Fatalf("Rollback() err = %v; want nil", err) + } + var count int + err = db.QueryRow("SELECT COUNT(id) FROM courses;").Scan(&count) + if err == nil { + // Want an error here + t.Fatalf("db.QueryRow() err = nil; want table missing error") + } + // Don't want to test inner workings of lib, so let's just migrate again and verify we have a table now + err = migrator.Migrate(db, "sqlite3") + if err != nil { + t.Fatalf("Migrate() err = %v; want nil", err) + } + _, err = db.Exec("INSERT INTO courses (name) VALUES ($1) ", "cor_test") + if err != nil { + t.Fatalf("db.Exec() err = %v; want nil", err) + } + err = db.QueryRow("SELECT COUNT(*) FROM courses;").Scan(&count) + if err != nil { + // Want an error here + t.Fatalf("db.QueryRow() err = %v; want nil", err) + } + if count != 1 { + t.Fatalf("count = %d; want %d", count, 1) + } + }) } -var createCoursesSql = ` +var ( + createCoursesSql = ` CREATE TABLE courses ( id serial PRIMARY KEY, name text );` + dropCoursesSql = `DROP TABLE courses;` -var createUsersSql = ` + createUsersSql = ` CREATE TABLE users ( id serial PRIMARY KEY, email text UNIQUE NOT NULL );` + dropUsersSql = `DROP TABLE users;` +) diff --git a/testdata/widgets.down.sql b/testdata/widgets.down.sql new file mode 100644 index 0000000..b498cb2 --- /dev/null +++ b/testdata/widgets.down.sql @@ -0,0 +1,2 @@ +DROP TABLE widgets; +