migrate.go 5.4 KB


  1. package migrate
  2. import (
  3. "errors"
  4. "fmt"
  5. "github.com/xormplus/xorm"
  6. )
  7. // MigrateFunc is the func signature for migrating.
  8. type MigrateFunc func(*xorm.Engine) error
  9. // RollbackFunc is the func signature for rollbacking.
  10. type RollbackFunc func(*xorm.Engine) error
  11. // InitSchemaFunc is the func signature for initializing the schema.
  12. type InitSchemaFunc func(*xorm.Engine) error
  // Options configures where Migrate records which migrations have been applied.
  type Options struct {
  	// TableName names the table that stores the IDs of applied migrations.
  	TableName string
  	// IDColumnName names the column of TableName that holds each migration ID.
  	IDColumnName string
  }
  20. // Migration represents a database migration (a modification to be made on the database).
  21. type Migration struct {
  22. // ID is the migration identifier. Usually a timestamp like "201601021504".
  23. ID string
  24. // Migrate is a function that will br executed while running this migration.
  25. Migrate MigrateFunc
  26. // Rollback will be executed on rollback. Can be nil.
  27. Rollback RollbackFunc
  28. }
  29. // Migrate represents a collection of all migrations of a database schema.
  30. type Migrate struct {
  31. db *xorm.Engine
  32. options *Options
  33. migrations []*Migration
  34. initSchema InitSchemaFunc
  35. }
  36. var (
  37. // DefaultOptions can be used if you don't want to think about options.
  38. DefaultOptions = &Options{
  39. TableName: "migrations",
  40. IDColumnName: "id",
  41. }
  42. // ErrRollbackImpossible is returned when trying to rollback a migration
  43. // that has no rollback function.
  44. ErrRollbackImpossible = errors.New("It's impossible to rollback this migration")
  45. // ErrNoMigrationDefined is returned when no migration is defined.
  46. ErrNoMigrationDefined = errors.New("No migration defined")
  47. // ErrMissingID is returned when the ID od migration is equal to ""
  48. ErrMissingID = errors.New("Missing ID in migration")
  49. // ErrNoRunnedMigration is returned when any runned migration was found while
  50. // running RollbackLast
  51. ErrNoRunnedMigration = errors.New("Could not find last runned migration")
  52. )
  53. // New returns a new Gormigrate.
  54. func New(db *xorm.Engine, options *Options, migrations []*Migration) *Migrate {
  55. return &Migrate{
  56. db: db,
  57. options: options,
  58. migrations: migrations,
  59. }
  60. }
  61. // InitSchema sets a function that is run if no migration is found.
  62. // The idea is preventing to run all migrations when a new clean database
  63. // is being migrating. In this function you should create all tables and
  64. // foreign key necessary to your application.
  65. func (m *Migrate) InitSchema(initSchema InitSchemaFunc) {
  66. m.initSchema = initSchema
  67. }
  68. // Migrate executes all migrations that did not run yet.
  69. func (m *Migrate) Migrate() error {
  70. if err := m.createMigrationTableIfNotExists(); err != nil {
  71. return err
  72. }
  73. if m.initSchema != nil && m.isFirstRun() {
  74. return m.runInitSchema()
  75. }
  76. for _, migration := range m.migrations {
  77. if err := m.runMigration(migration); err != nil {
  78. return err
  79. }
  80. }
  81. return nil
  82. }
  83. // RollbackLast undo the last migration
  84. func (m *Migrate) RollbackLast() error {
  85. if len(m.migrations) == 0 {
  86. return ErrNoMigrationDefined
  87. }
  88. lastRunnedMigration, err := m.getLastRunnedMigration()
  89. if err != nil {
  90. return err
  91. }
  92. if err := m.RollbackMigration(lastRunnedMigration); err != nil {
  93. return err
  94. }
  95. return nil
  96. }
  97. func (m *Migrate) getLastRunnedMigration() (*Migration, error) {
  98. for i := len(m.migrations) - 1; i >= 0; i-- {
  99. migration := m.migrations[i]
  100. run, err := m.migrationDidRun(migration)
  101. if err != nil {
  102. return nil, err
  103. } else if run {
  104. return migration, nil
  105. }
  106. }
  107. return nil, ErrNoRunnedMigration
  108. }
  109. // RollbackMigration undo a migration.
  110. func (m *Migrate) RollbackMigration(mig *Migration) error {
  111. if mig.Rollback == nil {
  112. return ErrRollbackImpossible
  113. }
  114. if err := mig.Rollback(m.db); err != nil {
  115. return err
  116. }
  117. sql := fmt.Sprintf("DELETE FROM %s WHERE %s = ?", m.options.TableName, m.options.IDColumnName)
  118. if _, err := m.db.Exec(sql, mig.ID); err != nil {
  119. return err
  120. }
  121. return nil
  122. }
  123. func (m *Migrate) runInitSchema() error {
  124. if err := m.initSchema(m.db); err != nil {
  125. return err
  126. }
  127. for _, migration := range m.migrations {
  128. if err := m.insertMigration(migration.ID); err != nil {
  129. return err
  130. }
  131. }
  132. return nil
  133. }
  134. func (m *Migrate) runMigration(migration *Migration) error {
  135. if len(migration.ID) == 0 {
  136. return ErrMissingID
  137. }
  138. run, err := m.migrationDidRun(migration)
  139. if err != nil {
  140. return err
  141. }
  142. if !run {
  143. if err := migration.Migrate(m.db); err != nil {
  144. return err
  145. }
  146. if err := m.insertMigration(migration.ID); err != nil {
  147. return err
  148. }
  149. }
  150. return nil
  151. }
  152. func (m *Migrate) createMigrationTableIfNotExists() error {
  153. exists, err := m.db.IsTableExist(m.options.TableName)
  154. if err != nil {
  155. return err
  156. }
  157. if exists {
  158. return nil
  159. }
  160. sql := fmt.Sprintf("CREATE TABLE %s (%s VARCHAR(255) PRIMARY KEY)", m.options.TableName, m.options.IDColumnName)
  161. if _, err := m.db.Exec(sql); err != nil {
  162. return err
  163. }
  164. return nil
  165. }
  166. func (m *Migrate) migrationDidRun(mig *Migration) (bool, error) {
  167. count, err := m.db.SQL(fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE %s = ?", m.options.TableName, m.options.IDColumnName), mig.ID).Count()
  168. return count > 0, err
  169. }
  170. func (m *Migrate) isFirstRun() bool {
  171. row := m.db.DB().QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM %s", m.options.TableName))
  172. var count int
  173. row.Scan(&count)
  174. return count == 0
  175. }
  176. func (m *Migrate) insertMigration(id string) error {
  177. sql := fmt.Sprintf("INSERT INTO %s (%s) VALUES (?)", m.options.TableName, m.options.IDColumnName)
  178. _, err := m.db.Exec(sql, id)
  179. return err
  180. }