Explorar o código

Add migrate package for schema versioned migrations

xormplus %!s(int64=8) %!d(string=hai) anos
pai
achega
6e3eebadbd
Modificáronse 3 ficheiros con 385 adicións e 6 borrados
  1. 214 0
      migrate/migrate.go
  2. 150 0
      migrate/migrate_test.go
  3. 21 6
      test/xorm_test.go

+ 214 - 0
migrate/migrate.go

@@ -0,0 +1,214 @@
+package migrate
+
+import (
+	"errors"
+	"fmt"
+
+	"github.com/xormplus/xorm"
+)
+
+// MigrateFunc is the func signature for migrating.
+type MigrateFunc func(*xorm.Engine) error
+
+// RollbackFunc is the func signature for rollbacking.
+type RollbackFunc func(*xorm.Engine) error
+
+// InitSchemaFunc is the func signature for initializing the schema.
+type InitSchemaFunc func(*xorm.Engine) error
+
+// Options define options for all migrations.
+type Options struct {
+	// TableName is the migration table.
+	TableName string
+	// IDColumnName is the name of column where the migration id will be stored.
+	IDColumnName string
+}
+
+// Migration represents a database migration (a modification to be made on the database).
+type Migration struct {
+	// ID is the migration identifier. Usually a timestamp like "201601021504".
+	ID string
+	// Migrate is a function that will br executed while running this migration.
+	Migrate MigrateFunc
+	// Rollback will be executed on rollback. Can be nil.
+	Rollback RollbackFunc
+}
+
+// Migrate represents a collection of all migrations of a database schema.
+type Migrate struct {
+	db         *xorm.Engine
+	options    *Options
+	migrations []*Migration
+	initSchema InitSchemaFunc
+}
+
+var (
+	// DefaultOptions can be used if you don't want to think about options.
+	DefaultOptions = &Options{
+		TableName:    "migrations",
+		IDColumnName: "id",
+	}
+
+	// ErrRollbackImpossible is returned when trying to rollback a migration
+	// that has no rollback function.
+	ErrRollbackImpossible = errors.New("It's impossible to rollback this migration")
+
+	// ErrNoMigrationDefined is returned when no migration is defined.
+	ErrNoMigrationDefined = errors.New("No migration defined")
+
+	// ErrMissingID is returned when the ID od migration is equal to ""
+	ErrMissingID = errors.New("Missing ID in migration")
+
+	// ErrNoRunnedMigration is returned when any runned migration was found while
+	// running RollbackLast
+	ErrNoRunnedMigration = errors.New("Could not find last runned migration")
+)
+
+// New returns a new Gormigrate.
+func New(db *xorm.Engine, options *Options, migrations []*Migration) *Migrate {
+	return &Migrate{
+		db:         db,
+		options:    options,
+		migrations: migrations,
+	}
+}
+
+// InitSchema sets a function that is run if no migration is found.
+// The idea is preventing to run all migrations when a new clean database
+// is being migrating. In this function you should create all tables and
+// foreign key necessary to your application.
+func (m *Migrate) InitSchema(initSchema InitSchemaFunc) {
+	m.initSchema = initSchema
+}
+
+// Migrate executes all migrations that did not run yet.
+func (m *Migrate) Migrate() error {
+	if err := m.createMigrationTableIfNotExists(); err != nil {
+		return err
+	}
+
+	if m.initSchema != nil && m.isFirstRun() {
+		if err := m.runInitSchema(); err != nil {
+			return err
+		}
+		return nil
+	}
+
+	for _, migration := range m.migrations {
+		if err := m.runMigration(migration); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// RollbackLast undo the last migration
+func (m *Migrate) RollbackLast() error {
+	if len(m.migrations) == 0 {
+		return ErrNoMigrationDefined
+	}
+
+	lastRunnedMigration, err := m.getLastRunnedMigration()
+	if err != nil {
+		return err
+	}
+
+	if err := m.RollbackMigration(lastRunnedMigration); err != nil {
+		return err
+	}
+	return nil
+}
+
+func (m *Migrate) getLastRunnedMigration() (*Migration, error) {
+	for i := len(m.migrations) - 1; i >= 0; i-- {
+		migration := m.migrations[i]
+		if m.migrationDidRun(migration) {
+			return migration, nil
+		}
+	}
+	return nil, ErrNoRunnedMigration
+}
+
+// RollbackMigration undo a migration.
+func (m *Migrate) RollbackMigration(mig *Migration) error {
+	if mig.Rollback == nil {
+		return ErrRollbackImpossible
+	}
+
+	if err := mig.Rollback(m.db); err != nil {
+		return err
+	}
+
+	sql := fmt.Sprintf("DELETE FROM %s WHERE %s = ?", m.options.TableName, m.options.IDColumnName)
+	if _, err := m.db.Exec(sql, mig.ID); err != nil {
+		return err
+	}
+	return nil
+}
+
+func (m *Migrate) runInitSchema() error {
+	if err := m.initSchema(m.db); err != nil {
+		return err
+	}
+
+	for _, migration := range m.migrations {
+		if err := m.insertMigration(migration.ID); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+func (m *Migrate) runMigration(migration *Migration) error {
+	if len(migration.ID) == 0 {
+		return ErrMissingID
+	}
+
+	if !m.migrationDidRun(migration) {
+		if err := migration.Migrate(m.db); err != nil {
+			return err
+		}
+
+		if err := m.insertMigration(migration.ID); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func (m *Migrate) createMigrationTableIfNotExists() error {
+	exists, err := m.db.IsTableExist(m.options.TableName)
+	if err != nil {
+		return err
+	}
+	if exists {
+		return nil
+	}
+
+	sql := fmt.Sprintf("CREATE TABLE %s (%s VARCHAR(255) PRIMARY KEY)", m.options.TableName, m.options.IDColumnName)
+	if _, err := m.db.Exec(sql); err != nil {
+		return err
+	}
+	return nil
+}
+
+func (m *Migrate) migrationDidRun(mig *Migration) bool {
+	row := m.db.DB().QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE %s = ?", m.options.TableName, m.options.IDColumnName), mig.ID)
+	var count int
+	row.Scan(&count)
+	return count > 0
+}
+
+func (m *Migrate) isFirstRun() bool {
+	row := m.db.DB().QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM %s", m.options.TableName))
+	var count int
+	row.Scan(&count)
+	return count == 0
+}
+
+func (m *Migrate) insertMigration(id string) error {
+	sql := fmt.Sprintf("INSERT INTO %s (%s) VALUES (?)", m.options.TableName, m.options.IDColumnName)
+	_, err := m.db.Exec(sql, id)
+	return err
+}

+ 150 - 0
migrate/migrate_test.go

@@ -0,0 +1,150 @@
+package migrate
+
+import (
+	"fmt"
+	"log"
+	"os"
+	"testing"
+
+	_ "github.com/mattn/go-sqlite3"
+	"github.com/xormplus/xorm"
+	"gopkg.in/stretchr/testify.v1/assert"
+)
+
+type Person struct {
+	ID   int64
+	Name string
+}
+
+type Pet struct {
+	ID       int64
+	Name     string
+	PersonID int
+}
+
+const (
+	dbName = "testdb.sqlite3"
+)
+
+var (
+	migrations = []*Migration{
+		{
+			ID: "201608301400",
+			Migrate: func(tx *xorm.Engine) error {
+				return tx.Sync2(&Person{})
+			},
+			Rollback: func(tx *xorm.Engine) error {
+				return tx.DropTables(&Person{})
+			},
+		},
+		{
+			ID: "201608301430",
+			Migrate: func(tx *xorm.Engine) error {
+				return tx.Sync2(&Pet{})
+			},
+			Rollback: func(tx *xorm.Engine) error {
+				return tx.DropTables(&Pet{})
+			},
+		},
+	}
+)
+
+func TestMigration(t *testing.T) {
+	_ = os.Remove(dbName)
+
+	db, err := xorm.NewEngine("sqlite3", dbName)
+	if err != nil {
+		log.Fatal(err)
+	}
+	defer db.Close()
+
+	if err = db.DB().Ping(); err != nil {
+		log.Fatal(err)
+	}
+
+	m := New(db, DefaultOptions, migrations)
+
+	err = m.Migrate()
+	assert.NoError(t, err)
+	exists, _ := db.IsTableExist(&Person{})
+	assert.True(t, exists)
+	exists, _ = db.IsTableExist(&Pet{})
+	assert.True(t, exists)
+	assert.Equal(t, 2, tableCount(db, "migrations"))
+
+	err = m.RollbackLast()
+	assert.NoError(t, err)
+	exists, _ = db.IsTableExist(&Person{})
+	assert.True(t, exists)
+	exists, _ = db.IsTableExist(&Pet{})
+	assert.False(t, exists)
+	assert.Equal(t, 1, tableCount(db, "migrations"))
+
+	err = m.RollbackLast()
+	assert.NoError(t, err)
+	exists, _ = db.IsTableExist(&Person{})
+	assert.False(t, exists)
+	exists, _ = db.IsTableExist(&Pet{})
+	assert.False(t, exists)
+	assert.Equal(t, 0, tableCount(db, "migrations"))
+}
+
+func TestInitSchema(t *testing.T) {
+	os.Remove(dbName)
+
+	db, err := xorm.NewEngine("sqlite3", dbName)
+	if err != nil {
+		log.Fatal(err)
+	}
+	defer db.Close()
+	if err = db.DB().Ping(); err != nil {
+		log.Fatal(err)
+	}
+
+	m := New(db, DefaultOptions, migrations)
+	m.InitSchema(func(tx *xorm.Engine) error {
+		if err := tx.Sync2(&Person{}); err != nil {
+			return err
+		}
+		if err := tx.Sync2(&Pet{}); err != nil {
+			return err
+		}
+		return nil
+	})
+
+	err = m.Migrate()
+	assert.NoError(t, err)
+	exists, _ := db.IsTableExist(&Person{})
+	assert.True(t, exists)
+	exists, _ = db.IsTableExist(&Pet{})
+	assert.True(t, exists)
+	assert.Equal(t, 2, tableCount(db, "migrations"))
+}
+
+func TestMissingID(t *testing.T) {
+	os.Remove(dbName)
+
+	db, err := xorm.NewEngine("sqlite3", dbName)
+	assert.NoError(t, err)
+	if db != nil {
+		defer db.Close()
+	}
+	assert.NoError(t, db.DB().Ping())
+
+	migrationsMissingID := []*Migration{
+		{
+			Migrate: func(tx *xorm.Engine) error {
+				return nil
+			},
+		},
+	}
+
+	m := New(db, DefaultOptions, migrationsMissingID)
+	assert.Equal(t, ErrMissingID, m.Migrate())
+}
+
+func tableCount(db *xorm.Engine, tableName string) (count int) {
+	row := db.DB().QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM %s", tableName))
+	row.Scan(&count)
+	return
+}

+ 21 - 6
test/xorm_test.go

@@ -73,6 +73,7 @@ func Test_InitDB(t *testing.T) {
 	}
 
 	db.ShowSQL(true)
+
 	log.Println(db)
 	//	db.NewSession().SqlMapClient().Execute()
 	log.Println(db.GetSqlMap("json_category-16-17"))
@@ -303,7 +304,7 @@ func Test_QueryByParamMapWithDateFormat_XmlIndent(t *testing.T) {
 
 func Test_SqlMapClient_QueryByParamMap_Json(t *testing.T) {
 	paramMap := map[string]interface{}{"1": 2, "2": 5}
-	rows, err := db.SqlMapClient("selectAllArticle", &paramMap).Query().Json()
+	rows, err := db.SqlMapClient("json_selectAllArticle", &paramMap).Query().Json()
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -312,7 +313,7 @@ func Test_SqlMapClient_QueryByParamMap_Json(t *testing.T) {
 
 func Test_SqlMapClient_QueryByParamMapWithDateFormat_Json(t *testing.T) {
 	paramMap := map[string]interface{}{"1": 2, "2": 5}
-	rows, err := db.SqlMapClient("selectAllArticle", &paramMap).QueryWithDateFormat("2006-01-02 15:04").Json()
+	rows, err := db.SqlMapClient("json_selectAllArticle", &paramMap).QueryWithDateFormat("2006-01-02 15:04").Json()
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -321,7 +322,7 @@ func Test_SqlMapClient_QueryByParamMapWithDateFormat_Json(t *testing.T) {
 
 func Test_SqlMapClient_QueryByParamMap_Xml(t *testing.T) {
 	paramMap := map[string]interface{}{"1": 2, "2": 5}
-	rows, err := db.SqlMapClient("selectAllArticle", &paramMap).Query().Xml()
+	rows, err := db.SqlMapClient("json_selectAllArticle", &paramMap).Query().Xml()
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -330,7 +331,7 @@ func Test_SqlMapClient_QueryByParamMap_Xml(t *testing.T) {
 
 func Test_SqlMapClient_QueryByParamMapWithDateFormat_Xml(t *testing.T) {
 	paramMap := map[string]interface{}{"1": 2, "2": 5}
-	rows, err := db.SqlMapClient("selectAllArticle", &paramMap).QueryWithDateFormat("2006-01-02 15:04").Xml()
+	rows, err := db.SqlMapClient("json_selectAllArticle", &paramMap).QueryWithDateFormat("2006-01-02 15:04").Xml()
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -339,7 +340,7 @@ func Test_SqlMapClient_QueryByParamMapWithDateFormat_Xml(t *testing.T) {
 
 func Test_SqlMapClient_QueryByParamMap_XmlIndent(t *testing.T) {
 	paramMap := map[string]interface{}{"1": 2, "2": 5}
-	rows, err := db.SqlMapClient("selectAllArticle", &paramMap).Query().XmlIndent("", "  ", "article")
+	rows, err := db.SqlMapClient("json_selectAllArticle", &paramMap).Query().XmlIndent("", "  ", "article")
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -348,7 +349,7 @@ func Test_SqlMapClient_QueryByParamMap_XmlIndent(t *testing.T) {
 
 func Test_SqlMapClient_QueryByParamMapWithDateFormat_XmlIndent(t *testing.T) {
 	paramMap := map[string]interface{}{"1": 2, "2": 5}
-	rows, err := db.SqlMapClient("selectAllArticle", &paramMap).QueryWithDateFormat("2006-01-02 15:04").XmlIndent("", "  ", "article")
+	rows, err := db.SqlMapClient("json_selectAllArticle", &paramMap).QueryWithDateFormat("2006-01-02 15:04").XmlIndent("", "  ", "article")
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -660,3 +661,17 @@ func Test_GetSqlTemplates(t *testing.T) {
 	}
 	t.Log("[Test_GetSqlTemplates]->Test_GetSqlTemplates_2->strSqlTemplate:\n", strSqlTemplate)
 }
+
+func Test_Limit_Func(t *testing.T) {
+	res, _ := db.Sql("SELECT b.id,a.name,b.title FROM category a,article b where a.id=b.categorysubid ORDER BY b.id").Limit(10, 3).Query().List()
+	t.Log(res)
+}
+
+func Test_Find(t *testing.T) {
+	var category []Category
+	err := db.Find(&category)
+	if err != nil {
+		t.Fatal(err)
+	}
+	t.Log(category)
+}