|
|
@@ -1,72 +1,73 @@
|
|
|
package xorm
|
|
|
|
|
|
import (
|
|
|
- "errors"
|
|
|
"flag"
|
|
|
+ "fmt"
|
|
|
"os"
|
|
|
"testing"
|
|
|
|
|
|
+ _ "github.com/go-sql-driver/mysql"
|
|
|
+ _ "github.com/lib/pq"
|
|
|
_ "github.com/mattn/go-sqlite3"
|
|
|
)
|
|
|
|
|
|
var (
|
|
|
testEngine *Engine
|
|
|
- dbType string
|
|
|
- connStr string
|
|
|
-)
|
|
|
+ connString string
|
|
|
|
|
|
-func prepareSqlite3Engine() error {
|
|
|
- //if testEngine == nil {
|
|
|
- os.Remove("./test.db")
|
|
|
- var err error
|
|
|
- testEngine, err = NewEngine("sqlite3", "./test.db")
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
- testEngine.ShowSQL(*showSQL)
|
|
|
- //}
|
|
|
- return nil
|
|
|
-}
|
|
|
+ db = flag.String("db", "sqlite3", "the tested database")
|
|
|
+ showSQL = flag.Bool("show_sql", true, "show generated SQLs")
|
|
|
+ ptrConnStr = flag.String("conn_str", "", "test database connection string")
|
|
|
+ mapType = flag.String("map_type", "snake", "indicate the name mapping")
|
|
|
+ cache = flag.Bool("cache", false, "if enable cache")
|
|
|
+)
|
|
|
|
|
|
-func prepareMysqlEngine() error {
|
|
|
+func createEngine(dbType, connStr string) error {
|
|
|
if testEngine == nil {
|
|
|
var err error
|
|
|
- testEngine, err = NewEngine("mysql", connStr)
|
|
|
+ testEngine, err = NewEngine(dbType, connStr)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
+
|
|
|
testEngine.ShowSQL(*showSQL)
|
|
|
- _, err = testEngine.Exec("DROP DATABASE")
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
}
|
|
|
- return nil
|
|
|
-}
|
|
|
|
|
|
-func prepareEngine() error {
|
|
|
- if dbType == "sqlite" {
|
|
|
- return prepareSqlite3Engine()
|
|
|
- } else if dbType == "mysql" {
|
|
|
- return prepareMysqlEngine()
|
|
|
+ tables, err := testEngine.DBMetas()
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
}
|
|
|
- return errors.New("Unknown test database driver")
|
|
|
+ var tableNames = make([]interface{}, 0, len(tables))
|
|
|
+ for _, table := range tables {
|
|
|
+ tableNames = append(tableNames, table.Name)
|
|
|
+ }
|
|
|
+ return testEngine.DropTables(tableNames...)
|
|
|
}
|
|
|
|
|
|
-var (
|
|
|
- db = flag.String("db", "sqlite", "the tested database")
|
|
|
- showSQL = flag.Bool("show_sql", true, "show generated SQLs")
|
|
|
-)
|
|
|
+func prepareEngine() error {
|
|
|
+ return createEngine(*db, connString)
|
|
|
+}
|
|
|
|
|
|
func TestMain(m *testing.M) {
|
|
|
flag.Parse()
|
|
|
|
|
|
- if db != nil {
|
|
|
- dbType = *db
|
|
|
+ if *db == "sqlite3" {
|
|
|
+ if ptrConnStr == nil {
|
|
|
+ connString = "./test.db"
|
|
|
+ } else {
|
|
|
+ connString = *ptrConnStr
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ if ptrConnStr == nil {
|
|
|
+ fmt.Println("you should indicate conn string")
|
|
|
+ return
|
|
|
+ }
|
|
|
+ connString = *ptrConnStr
|
|
|
}
|
|
|
|
|
|
if err := prepareEngine(); err != nil {
|
|
|
- panic(err)
|
|
|
+ fmt.Println(err)
|
|
|
+ return
|
|
|
}
|
|
|
os.Exit(m.Run())
|
|
|
}
|