Jelajahi Sumber

Add SetSchema for engine

fix postgres with schema

fix test

fix tablename

refactor tableName

fix schema support

improve the interface of EngineInterface

add test for resolve

add test for (id) replace

Add test for count with orderby and limit
xormplus 7 tahun lalu
induk
melakukan
b5723783b6

+ 18 - 3
dialect_postgres.go

@@ -910,7 +910,16 @@ func (db *postgres) ModifyColumnSql(tableName string, col *core.Column) string {
 
 
 func (db *postgres) CreateIndexSql(tableName string, index *core.Index) string {
 func (db *postgres) CreateIndexSql(tableName string, index *core.Index) string {
 	quote := db.Quote
 	quote := db.Quote
-	return fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(tableName, index.Name)),
+	idxName := index.Name
+
+	tableName = strings.Replace(tableName, `"`, "", -1)
+	tableName = strings.Replace(tableName, `.`, "_", -1)
+
+	if db.Uri.Schema != "" {
+		idxName = db.Uri.Schema + "." + idxName
+	}
+
+	return fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(tableName, idxName)),
 		quote(tableName), quote(strings.Join(index.Cols, quote(","))))
 		quote(tableName), quote(strings.Join(index.Cols, quote(","))))
 }
 }
 
 
@@ -918,6 +927,9 @@ func (db *postgres) DropIndexSql(tableName string, index *core.Index) string {
 	quote := db.Quote
 	quote := db.Quote
 	idxName := index.Name
 	idxName := index.Name
 
 
+	tableName = strings.Replace(tableName, `"`, "", -1)
+	tableName = strings.Replace(tableName, `.`, "_", -1)
+
 	if !strings.HasPrefix(idxName, "UQE_") &&
 	if !strings.HasPrefix(idxName, "UQE_") &&
 		!strings.HasPrefix(idxName, "IDX_") {
 		!strings.HasPrefix(idxName, "IDX_") {
 		if index.Type == core.UniqueType {
 		if index.Type == core.UniqueType {
@@ -926,6 +938,9 @@ func (db *postgres) DropIndexSql(tableName string, index *core.Index) string {
 			idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name)
 			idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name)
 		}
 		}
 	}
 	}
+	if db.Uri.Schema != "" {
+		idxName = db.Uri.Schema + "." + idxName
+	}
 	return fmt.Sprintf("DROP INDEX %v", quote(idxName))
 	return fmt.Sprintf("DROP INDEX %v", quote(idxName))
 }
 }
 
 
@@ -966,7 +981,7 @@ WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.att
 	var f string
 	var f string
 	if len(db.Schema) != 0 {
 	if len(db.Schema) != 0 {
 		args = append(args, db.Schema)
 		args = append(args, db.Schema)
-		f = "AND s.table_schema = $2"
+		f = " AND s.table_schema = $2"
 	}
 	}
 	s = fmt.Sprintf(s, f)
 	s = fmt.Sprintf(s, f)
 
 
@@ -1091,11 +1106,11 @@ func (db *postgres) GetTables() ([]*core.Table, error) {
 func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) {
 func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) {
 	args := []interface{}{tableName}
 	args := []interface{}{tableName}
 	s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1")
 	s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1")
-	db.LogSQL(s, args)
 	if len(db.Schema) != 0 {
 	if len(db.Schema) != 0 {
 		args = append(args, db.Schema)
 		args = append(args, db.Schema)
 		s = s + " AND schemaname=$2"
 		s = s + " AND schemaname=$2"
 	}
 	}
+	db.LogSQL(s, args)
 
 
 	rows, err := db.DB().Query(s, args...)
 	rows, err := db.DB().Query(s, args...)
 	if err != nil {
 	if err != nil {

+ 20 - 67
engine.go

@@ -545,46 +545,6 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D
 	return nil
 	return nil
 }
 }
 
 
-func (engine *Engine) tableName(beanOrTableName interface{}) (string, error) {
-	v := rValue(beanOrTableName)
-	if v.Type().Kind() == reflect.String {
-		return beanOrTableName.(string), nil
-	} else if v.Type().Kind() == reflect.Struct {
-		return engine.tbName(v), nil
-	}
-	return "", errors.New("bean should be a struct or struct's point")
-}
-
-func (engine *Engine) tbSchemaName(v string) string {
-	// Add schema name as prefix of table name.
-	// Only for postgres database.
-	if engine.dialect.DBType() == core.POSTGRES &&
-		engine.dialect.URI().Schema != "" &&
-		engine.dialect.URI().Schema != postgresPublicSchema &&
-		strings.Index(v, ".") == -1 {
-		return engine.dialect.URI().Schema + "." + v
-	}
-	return v
-}
-
-func (engine *Engine) tbName(v reflect.Value) string {
-	if tb, ok := v.Interface().(TableName); ok {
-		return engine.tbSchemaName(tb.TableName())
-
-	}
-
-	if v.Type().Kind() == reflect.Ptr {
-		if tb, ok := reflect.Indirect(v).Interface().(TableName); ok {
-			return engine.tbSchemaName(tb.TableName())
-		}
-	} else if v.CanAddr() {
-		if tb, ok := v.Addr().Interface().(TableName); ok {
-			return engine.tbSchemaName(tb.TableName())
-		}
-	}
-	return engine.tbSchemaName(engine.TableMapper.Obj2Table(reflect.Indirect(v).Type().Name()))
-}
-
 // Cascade use cascade or not
 // Cascade use cascade or not
 func (engine *Engine) Cascade(trueOrFalse ...bool) *Session {
 func (engine *Engine) Cascade(trueOrFalse ...bool) *Session {
 	session := engine.NewSession()
 	session := engine.NewSession()
@@ -868,7 +828,7 @@ func (engine *Engine) TableInfo(bean interface{}) *Table {
 	if err != nil {
 	if err != nil {
 		engine.logger.Error(err)
 		engine.logger.Error(err)
 	}
 	}
-	return &Table{tb, engine.tbName(v)}
+	return &Table{tb, engine.TableName(bean)}
 }
 }
 
 
 func addIndex(indexName string, table *core.Table, col *core.Column, indexType int) {
 func addIndex(indexName string, table *core.Table, col *core.Column, indexType int) {
@@ -904,20 +864,8 @@ var (
 func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) {
 func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) {
 	t := v.Type()
 	t := v.Type()
 	table := engine.newTable()
 	table := engine.newTable()
-	if tb, ok := v.Interface().(TableName); ok {
-		table.Name = tb.TableName()
-	} else {
-		if v.CanAddr() {
-			if tb, ok = v.Addr().Interface().(TableName); ok {
-				table.Name = tb.TableName()
-			}
-		}
-		if table.Name == "" {
-			table.Name = engine.TableMapper.Obj2Table(t.Name())
-		}
-	}
-
 	table.Type = t
 	table.Type = t
+	table.Name = engine.tbNameForMap(v)
 
 
 	var idFieldColName string
 	var idFieldColName string
 	var hasCacheTag, hasNoCacheTag bool
 	var hasCacheTag, hasNoCacheTag bool
@@ -1195,7 +1143,7 @@ func (engine *Engine) ClearCacheBean(bean interface{}, id string) error {
 	if t.Kind() != reflect.Struct {
 	if t.Kind() != reflect.Struct {
 		return errors.New("error params")
 		return errors.New("error params")
 	}
 	}
-	tableName := engine.tbName(v)
+	tableName := engine.TableName(bean)
 	table, err := engine.autoMapType(v)
 	table, err := engine.autoMapType(v)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
@@ -1219,7 +1167,7 @@ func (engine *Engine) ClearCache(beans ...interface{}) error {
 		if t.Kind() != reflect.Struct {
 		if t.Kind() != reflect.Struct {
 			return errors.New("error params")
 			return errors.New("error params")
 		}
 		}
-		tableName := engine.tbName(v)
+		tableName := engine.TableName(bean)
 		table, err := engine.autoMapType(v)
 		table, err := engine.autoMapType(v)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
@@ -1246,13 +1194,13 @@ func (engine *Engine) Sync(beans ...interface{}) error {
 
 
 	for _, bean := range beans {
 	for _, bean := range beans {
 		v := rValue(bean)
 		v := rValue(bean)
-		tableName := engine.tbName(v)
+		tableNameNoSchema := engine.tbNameNoSchema(v.Interface())
 		table, err := engine.autoMapType(v)
 		table, err := engine.autoMapType(v)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
 
 
-		isExist, err := session.Table(bean).isTableExist(tableName)
+		isExist, err := session.Table(bean).isTableExist(tableNameNoSchema)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
@@ -1278,12 +1226,12 @@ func (engine *Engine) Sync(beans ...interface{}) error {
 			}
 			}
 		} else {
 		} else {
 			for _, col := range table.Columns() {
 			for _, col := range table.Columns() {
-				isExist, err := engine.dialect.IsColumnExist(tableName, col.Name)
+				isExist, err := engine.dialect.IsColumnExist(tableNameNoSchema, col.Name)
 				if err != nil {
 				if err != nil {
 					return err
 					return err
 				}
 				}
 				if !isExist {
 				if !isExist {
-					if err := session.statement.setRefValue(v); err != nil {
+					if err := session.statement.setRefBean(bean); err != nil {
 						return err
 						return err
 					}
 					}
 					err = session.addColumn(col.Name)
 					err = session.addColumn(col.Name)
@@ -1294,35 +1242,35 @@ func (engine *Engine) Sync(beans ...interface{}) error {
 			}
 			}
 
 
 			for name, index := range table.Indexes {
 			for name, index := range table.Indexes {
-				if err := session.statement.setRefValue(v); err != nil {
+				if err := session.statement.setRefBean(bean); err != nil {
 					return err
 					return err
 				}
 				}
 				if index.Type == core.UniqueType {
 				if index.Type == core.UniqueType {
-					isExist, err := session.isIndexExist2(tableName, index.Cols, true)
+					isExist, err := session.isIndexExist2(tableNameNoSchema, index.Cols, true)
 					if err != nil {
 					if err != nil {
 						return err
 						return err
 					}
 					}
 					if !isExist {
 					if !isExist {
-						if err := session.statement.setRefValue(v); err != nil {
+						if err := session.statement.setRefBean(bean); err != nil {
 							return err
 							return err
 						}
 						}
 
 
-						err = session.addUnique(tableName, name)
+						err = session.addUnique(tableNameNoSchema, name)
 						if err != nil {
 						if err != nil {
 							return err
 							return err
 						}
 						}
 					}
 					}
 				} else if index.Type == core.IndexType {
 				} else if index.Type == core.IndexType {
-					isExist, err := session.isIndexExist2(tableName, index.Cols, false)
+					isExist, err := session.isIndexExist2(tableNameNoSchema, index.Cols, false)
 					if err != nil {
 					if err != nil {
 						return err
 						return err
 					}
 					}
 					if !isExist {
 					if !isExist {
-						if err := session.statement.setRefValue(v); err != nil {
+						if err := session.statement.setRefBean(bean); err != nil {
 							return err
 							return err
 						}
 						}
 
 
-						err = session.addIndex(tableName, name)
+						err = session.addIndex(tableNameNoSchema, name)
 						if err != nil {
 						if err != nil {
 							return err
 							return err
 						}
 						}
@@ -1661,6 +1609,11 @@ func (engine *Engine) SetTZDatabase(tz *time.Location) {
 	engine.DatabaseTZ = tz
 	engine.DatabaseTZ = tz
 }
 }
 
 
+// SetSchema sets the schema of database
+func (engine *Engine) SetSchema(schema string) {
+	engine.dialect.URI().Schema = schema
+}
+
 // Unscoped always disable struct tag "deleted"
 // Unscoped always disable struct tag "deleted"
 func (engine *Engine) Unscoped() *Session {
 func (engine *Engine) Unscoped() *Session {
 	session := engine.NewSession()
 	session := engine.NewSession()

+ 109 - 0
engine_table.go

@@ -0,0 +1,109 @@
+// Copyright 2018 The Xorm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package xorm
+
+import (
+	"fmt"
+	"reflect"
+	"strings"
+
+	"github.com/xormplus/core"
+)
+
+// TableNameWithSchema will automatically add schema prefix on table name
+func (engine *Engine) tbNameWithSchema(v string) string {
+	// Add schema name as prefix of table name.
+	// Only for postgres database.
+	if engine.dialect.DBType() == core.POSTGRES &&
+		engine.dialect.URI().Schema != "" &&
+		engine.dialect.URI().Schema != postgresPublicSchema &&
+		strings.Index(v, ".") == -1 {
+		return engine.dialect.URI().Schema + "." + v
+	}
+	return v
+}
+
+// TableName returns table name with schema prefix if has
+func (engine *Engine) TableName(bean interface{}, includeSchema ...bool) string {
+	tbName := engine.tbNameNoSchema(bean)
+	if len(includeSchema) > 0 && includeSchema[0] {
+		tbName = engine.tbNameWithSchema(tbName)
+	}
+
+	return tbName
+}
+
+// tbName get some table's table name
+func (session *Session) tbNameNoSchema(table *core.Table) string {
+	if len(session.statement.AltTableName) > 0 {
+		return session.statement.AltTableName
+	}
+
+	return table.Name
+}
+
+func (engine *Engine) tbNameForMap(v reflect.Value) string {
+	t := v.Type()
+	if tb, ok := v.Interface().(TableName); ok {
+		return tb.TableName()
+	}
+	if v.CanAddr() {
+		if tb, ok := v.Addr().Interface().(TableName); ok {
+			return tb.TableName()
+		}
+	}
+	return engine.TableMapper.Obj2Table(t.Name())
+}
+
+func (engine *Engine) tbNameNoSchema(tablename interface{}) string {
+	switch tablename.(type) {
+	case []string:
+		t := tablename.([]string)
+		if len(t) > 1 {
+			return fmt.Sprintf("%v AS %v", engine.Quote(t[0]), engine.Quote(t[1]))
+		} else if len(t) == 1 {
+			return engine.Quote(t[0])
+		}
+	case []interface{}:
+		t := tablename.([]interface{})
+		l := len(t)
+		var table string
+		if l > 0 {
+			f := t[0]
+			switch f.(type) {
+			case string:
+				table = f.(string)
+			case TableName:
+				table = f.(TableName).TableName()
+			default:
+				v := rValue(f)
+				t := v.Type()
+				if t.Kind() == reflect.Struct {
+					table = engine.tbNameForMap(v)
+				} else {
+					table = engine.Quote(fmt.Sprintf("%v", f))
+				}
+			}
+		}
+		if l > 1 {
+			return fmt.Sprintf("%v AS %v", engine.Quote(table),
+				engine.Quote(fmt.Sprintf("%v", t[1])))
+		} else if l == 1 {
+			return engine.Quote(table)
+		}
+	case TableName:
+		return tablename.(TableName).TableName()
+	case string:
+		return tablename.(string)
+	default:
+		v := rValue(tablename)
+		t := v.Type()
+		if t.Kind() == reflect.Struct {
+			return engine.tbNameForMap(v)
+		}
+		return engine.Quote(fmt.Sprintf("%v", tablename))
+	}
+	return ""
+}

+ 2 - 0
interface.go

@@ -87,6 +87,7 @@ type EngineInterface interface {
 	SetDefaultCacher(core.Cacher)
 	SetDefaultCacher(core.Cacher)
 	SetLogLevel(core.LogLevel)
 	SetLogLevel(core.LogLevel)
 	SetMapper(core.IMapper)
 	SetMapper(core.IMapper)
+	SetSchema(string)
 	SetTZDatabase(tz *time.Location)
 	SetTZDatabase(tz *time.Location)
 	SetTZLocation(tz *time.Location)
 	SetTZLocation(tz *time.Location)
 	ShowSQL(show ...bool)
 	ShowSQL(show ...bool)
@@ -94,6 +95,7 @@ type EngineInterface interface {
 	Sync2(...interface{}) error
 	Sync2(...interface{}) error
 	StoreEngine(storeEngine string) *Session
 	StoreEngine(storeEngine string) *Session
 	TableInfo(bean interface{}) *Table
 	TableInfo(bean interface{}) *Table
+	TableName(interface{}, ...bool) string
 	UnMapType(reflect.Type)
 	UnMapType(reflect.Type)
 }
 }
 
 

+ 3 - 3
rows.go

@@ -32,7 +32,7 @@ func newRows(session *Session, bean interface{}) (*Rows, error) {
 	var args []interface{}
 	var args []interface{}
 	var err error
 	var err error
 
 
-	if err = rows.session.statement.setRefValue(rValue(bean)); err != nil {
+	if err = rows.session.statement.setRefBean(bean); err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
@@ -94,8 +94,7 @@ func (rows *Rows) Scan(bean interface{}) error {
 		return fmt.Errorf("scan arg is incompatible type to [%v]", rows.beanType)
 		return fmt.Errorf("scan arg is incompatible type to [%v]", rows.beanType)
 	}
 	}
 
 
-	dataStruct := rValue(bean)
-	if err := rows.session.statement.setRefValue(dataStruct); err != nil {
+	if err := rows.session.statement.setRefBean(bean); err != nil {
 		return err
 		return err
 	}
 	}
 
 
@@ -104,6 +103,7 @@ func (rows *Rows) Scan(bean interface{}) error {
 		return err
 		return err
 	}
 	}
 
 
+	dataStruct := rValue(bean)
 	_, err = rows.session.slice2Bean(scanResults, rows.fields, bean, &dataStruct, rows.session.statement.RefTable)
 	_, err = rows.session.slice2Bean(scanResults, rows.fields, bean, &dataStruct, rows.session.statement.RefTable)
 	if err != nil {
 	if err != nil {
 		return err
 		return err

+ 2 - 1
rows_test.go

@@ -54,7 +54,8 @@ func TestRows(t *testing.T) {
 	}
 	}
 	assert.EqualValues(t, 1, cnt)
 	assert.EqualValues(t, 1, cnt)
 
 
-	rows2, err := testEngine.SQL("SELECT * FROM user_rows").Rows(new(UserRows))
+	var tbName = testEngine.Quote(testEngine.TableName(user, true))
+	rows2, err := testEngine.SQL("SELECT * FROM " + tbName).Rows(new(UserRows))
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 	defer rows2.Close()
 	defer rows2.Close()
 
 

+ 0 - 9
session.go

@@ -834,15 +834,6 @@ func (session *Session) LastSQL() (string, []interface{}) {
 	return session.lastSQL, session.lastSQLArgs
 	return session.lastSQL, session.lastSQLArgs
 }
 }
 
 
-// tbName get some table's table name
-func (session *Session) tbNameNoSchema(table *core.Table) string {
-	if len(session.statement.AltTableName) > 0 {
-		return session.statement.AltTableName
-	}
-
-	return table.Name
-}
-
 // Unscoped always disable struct tag "deleted"
 // Unscoped always disable struct tag "deleted"
 func (session *Session) Unscoped() *Session {
 func (session *Session) Unscoped() *Session {
 	session.statement.Unscoped()
 	session.statement.Unscoped()

+ 22 - 90
session_cond_test.go

@@ -122,18 +122,11 @@ func TestIn(t *testing.T) {
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 	assert.EqualValues(t, 3, cnt)
 	assert.EqualValues(t, 3, cnt)
 
 
+	department := "`" + testEngine.GetColumnMapper().Obj2Table("Departname") + "`"
 	var usrs []Userinfo
 	var usrs []Userinfo
-	err = testEngine.Limit(3).Find(&usrs)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
-
-	if len(usrs) != 3 {
-		err = errors.New("there are not 3 records")
-		t.Error(err)
-		panic(err)
-	}
+	err = testEngine.Where(department+" = ?", "dev").Limit(3).Find(&usrs)
+	assert.NoError(t, err)
+	assert.EqualValues(t, 3, len(usrs))
 
 
 	var ids []int64
 	var ids []int64
 	var idsStr string
 	var idsStr string
@@ -145,35 +138,20 @@ func TestIn(t *testing.T) {
 
 
 	users := make([]Userinfo, 0)
 	users := make([]Userinfo, 0)
 	err = testEngine.In("(id)", ids[0], ids[1], ids[2]).Find(&users)
 	err = testEngine.In("(id)", ids[0], ids[1], ids[2]).Find(&users)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 	fmt.Println(users)
 	fmt.Println(users)
-	if len(users) != 3 {
-		err = errors.New("in uses should be " + idsStr + " total 3")
-		t.Error(err)
-		panic(err)
-	}
+	assert.EqualValues(t, 3, len(users))
 
 
 	users = make([]Userinfo, 0)
 	users = make([]Userinfo, 0)
 	err = testEngine.In("(id)", ids).Find(&users)
 	err = testEngine.In("(id)", ids).Find(&users)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 	fmt.Println(users)
 	fmt.Println(users)
-	if len(users) != 3 {
-		err = errors.New("in uses should be " + idsStr + " total 3")
-		t.Error(err)
-		panic(err)
-	}
+	assert.EqualValues(t, 3, len(users))
 
 
 	for _, user := range users {
 	for _, user := range users {
 		if user.Uid != ids[0] && user.Uid != ids[1] && user.Uid != ids[2] {
 		if user.Uid != ids[0] && user.Uid != ids[1] && user.Uid != ids[2] {
 			err = errors.New("in uses should be " + idsStr + " total 3")
 			err = errors.New("in uses should be " + idsStr + " total 3")
-			t.Error(err)
-			panic(err)
+			assert.NoError(t, err)
 		}
 		}
 	}
 	}
 
 
@@ -183,87 +161,41 @@ func TestIn(t *testing.T) {
 		idsInterface = append(idsInterface, id)
 		idsInterface = append(idsInterface, id)
 	}
 	}
 
 
-	department := "`" + testEngine.GetColumnMapper().Obj2Table("Departname") + "`"
 	err = testEngine.Where(department+" = ?", "dev").In("(id)", idsInterface...).Find(&users)
 	err = testEngine.Where(department+" = ?", "dev").In("(id)", idsInterface...).Find(&users)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 	fmt.Println(users)
 	fmt.Println(users)
-
-	if len(users) != 3 {
-		err = errors.New("in uses should be " + idsStr + " total 3")
-		t.Error(err)
-		panic(err)
-	}
+	assert.EqualValues(t, 3, len(users))
 
 
 	for _, user := range users {
 	for _, user := range users {
 		if user.Uid != ids[0] && user.Uid != ids[1] && user.Uid != ids[2] {
 		if user.Uid != ids[0] && user.Uid != ids[1] && user.Uid != ids[2] {
 			err = errors.New("in uses should be " + idsStr + " total 3")
 			err = errors.New("in uses should be " + idsStr + " total 3")
-			t.Error(err)
-			panic(err)
+			assert.NoError(t, err)
 		}
 		}
 	}
 	}
 
 
 	dev := testEngine.GetColumnMapper().Obj2Table("Dev")
 	dev := testEngine.GetColumnMapper().Obj2Table("Dev")
 
 
 	err = testEngine.In("(id)", 1).In("(id)", 2).In(department, dev).Find(&users)
 	err = testEngine.In("(id)", 1).In("(id)", 2).In(department, dev).Find(&users)
-
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 	fmt.Println(users)
 	fmt.Println(users)
 
 
 	cnt, err = testEngine.In("(id)", ids[0]).Update(&Userinfo{Departname: "dev-"})
 	cnt, err = testEngine.In("(id)", ids[0]).Update(&Userinfo{Departname: "dev-"})
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
-	if cnt != 1 {
-		err = errors.New("update records not 1")
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
+	assert.EqualValues(t, 1, cnt)
 
 
 	user := new(Userinfo)
 	user := new(Userinfo)
 	has, err := testEngine.ID(ids[0]).Get(user)
 	has, err := testEngine.ID(ids[0]).Get(user)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
-	if !has {
-		err = errors.New("get record not 1")
-		t.Error(err)
-		panic(err)
-	}
-	if user.Departname != "dev-" {
-		err = errors.New("update not success")
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
+	assert.True(t, has)
+	assert.EqualValues(t, "dev-", user.Departname)
 
 
 	cnt, err = testEngine.In("(id)", ids[0]).Update(&Userinfo{Departname: "dev"})
 	cnt, err = testEngine.In("(id)", ids[0]).Update(&Userinfo{Departname: "dev"})
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
-	if cnt != 1 {
-		err = errors.New("update records not 1")
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
+	assert.EqualValues(t, 1, cnt)
 
 
 	cnt, err = testEngine.In("(id)", ids[1]).Delete(&Userinfo{})
 	cnt, err = testEngine.In("(id)", ids[1]).Delete(&Userinfo{})
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
-	if cnt != 1 {
-		err = errors.New("deleted records not 1")
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
+	assert.EqualValues(t, 1, cnt)
 }
 }
 
 
 func TestFindAndCount(t *testing.T) {
 func TestFindAndCount(t *testing.T) {

+ 1 - 1
session_delete.go

@@ -79,7 +79,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
 		defer session.Close()
 		defer session.Close()
 	}
 	}
 
 
-	if err := session.statement.setRefValue(rValue(bean)); err != nil {
+	if err := session.statement.setRefBean(bean); err != nil {
 		return 0, err
 		return 0, err
 	}
 	}
 
 

+ 1 - 1
session_exist.go

@@ -57,7 +57,7 @@ func (session *Session) Exist(bean ...interface{}) (bool, error) {
 			}
 			}
 
 
 			if beanValue.Elem().Kind() == reflect.Struct {
 			if beanValue.Elem().Kind() == reflect.Struct {
-				if err := session.statement.setRefValue(beanValue.Elem()); err != nil {
+				if err := session.statement.setRefBean(bean[0]); err != nil {
 					return false, err
 					return false, err
 				}
 				}
 			}
 			}

+ 2 - 2
session_exist_test.go

@@ -54,11 +54,11 @@ func TestExistStruct(t *testing.T) {
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 	assert.False(t, has)
 	assert.False(t, has)
 
 
-	has, err = testEngine.SQL("select * from record_exist where name = ?", "test1").Exist()
+	has, err = testEngine.SQL("select * from "+testEngine.TableName("record_exist", true)+" where name = ?", "test1").Exist()
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 	assert.True(t, has)
 	assert.True(t, has)
 
 
-	has, err = testEngine.SQL("select * from record_exist where name = ?", "test2").Exist()
+	has, err = testEngine.SQL("select * from "+testEngine.TableName("record_exist", true)+" where name = ?", "test2").Exist()
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 	assert.False(t, has)
 	assert.False(t, has)
 
 

+ 1 - 1
session_find.go

@@ -182,7 +182,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
 		}
 		}
 
 
 		args = append(session.statement.joinArgs, condArgs...)
 		args = append(session.statement.joinArgs, condArgs...)
-		sqlStr, err = session.statement.genSelectSQL(columnStr, condSQL, true)
+		sqlStr, err = session.statement.genSelectSQL(columnStr, condSQL, true, true)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}

+ 13 - 42
session_find_test.go

@@ -96,21 +96,15 @@ func TestFind(t *testing.T) {
 	users := make([]Userinfo, 0)
 	users := make([]Userinfo, 0)
 
 
 	err := testEngine.Find(&users)
 	err := testEngine.Find(&users)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 	for _, user := range users {
 	for _, user := range users {
 		fmt.Println(user)
 		fmt.Println(user)
 	}
 	}
 
 
 	users2 := make([]Userinfo, 0)
 	users2 := make([]Userinfo, 0)
-	userinfo := testEngine.GetTableMapper().Obj2Table("Userinfo")
-	err = testEngine.SQL("select * from " + testEngine.Quote(userinfo)).Find(&users2)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	var tbName = testEngine.Quote(testEngine.TableName(new(Userinfo), true))
+	err = testEngine.SQL("select * from " + tbName).Find(&users2)
+	assert.NoError(t, err)
 }
 }
 
 
 func TestFind2(t *testing.T) {
 func TestFind2(t *testing.T) {
@@ -238,14 +232,8 @@ func TestDistinct(t *testing.T) {
 	users := make([]Userinfo, 0)
 	users := make([]Userinfo, 0)
 	departname := testEngine.GetTableMapper().Obj2Table("Departname")
 	departname := testEngine.GetTableMapper().Obj2Table("Departname")
 	err = testEngine.Distinct(departname).Find(&users)
 	err = testEngine.Distinct(departname).Find(&users)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
-	if len(users) != 1 {
-		t.Error(err)
-		panic(errors.New("should be one record"))
-	}
+	assert.NoError(t, err)
+	assert.EqualValues(t, 1, len(users))
 
 
 	fmt.Println(users)
 	fmt.Println(users)
 
 
@@ -255,11 +243,9 @@ func TestDistinct(t *testing.T) {
 
 
 	users2 := make([]Depart, 0)
 	users2 := make([]Depart, 0)
 	err = testEngine.Distinct(departname).Table(new(Userinfo)).Find(&users2)
 	err = testEngine.Distinct(departname).Table(new(Userinfo)).Find(&users2)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 	if len(users2) != 1 {
 	if len(users2) != 1 {
+		fmt.Println(len(users2))
 		t.Error(err)
 		t.Error(err)
 		panic(errors.New("should be one record"))
 		panic(errors.New("should be one record"))
 	}
 	}
@@ -272,18 +258,12 @@ func TestOrder(t *testing.T) {
 
 
 	users := make([]Userinfo, 0)
 	users := make([]Userinfo, 0)
 	err := testEngine.OrderBy("id desc").Find(&users)
 	err := testEngine.OrderBy("id desc").Find(&users)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 	fmt.Println(users)
 	fmt.Println(users)
 
 
 	users2 := make([]Userinfo, 0)
 	users2 := make([]Userinfo, 0)
 	err = testEngine.Asc("id", "username").Desc("height").Find(&users2)
 	err = testEngine.Asc("id", "username").Desc("height").Find(&users2)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 	fmt.Println(users2)
 	fmt.Println(users2)
 }
 }
 
 
@@ -293,10 +273,7 @@ func TestHaving(t *testing.T) {
 
 
 	users := make([]Userinfo, 0)
 	users := make([]Userinfo, 0)
 	err := testEngine.GroupBy("username").Having("username='xlw'").Find(&users)
 	err := testEngine.GroupBy("username").Having("username='xlw'").Find(&users)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 	fmt.Println(users)
 	fmt.Println(users)
 
 
 	/*users = make([]Userinfo, 0)
 	/*users = make([]Userinfo, 0)
@@ -324,18 +301,12 @@ func TestOrderSameMapper(t *testing.T) {
 
 
 	users := make([]Userinfo, 0)
 	users := make([]Userinfo, 0)
 	err := testEngine.OrderBy("(id) desc").Find(&users)
 	err := testEngine.OrderBy("(id) desc").Find(&users)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 	fmt.Println(users)
 	fmt.Println(users)
 
 
 	users2 := make([]Userinfo, 0)
 	users2 := make([]Userinfo, 0)
 	err = testEngine.Asc("(id)", "Username").Desc("Height").Find(&users2)
 	err = testEngine.Asc("(id)", "Username").Desc("Height").Find(&users2)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 	fmt.Println(users2)
 	fmt.Println(users2)
 }
 }
 
 

+ 1 - 1
session_get.go

@@ -31,7 +31,7 @@ func (session *Session) get(bean interface{}) (bool, error) {
 	}
 	}
 
 
 	if beanValue.Elem().Kind() == reflect.Struct {
 	if beanValue.Elem().Kind() == reflect.Struct {
-		if err := session.statement.setRefValue(beanValue.Elem()); err != nil {
+		if err := session.statement.setRefBean(bean); err != nil {
 			return false, err
 			return false, err
 		}
 		}
 	}
 	}

+ 32 - 1
session_get_test.go

@@ -84,11 +84,16 @@ func TestGetVar(t *testing.T) {
 	assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money))
 	assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money))
 
 
 	var money2 float64
 	var money2 float64
-	has, err = testEngine.SQL("SELECT money FROM get_var LIMIT 1").Get(&money2)
+	has, err = testEngine.SQL("SELECT money FROM " + testEngine.TableName("get_var", true) + " LIMIT 1").Get(&money2)
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 	assert.Equal(t, true, has)
 	assert.Equal(t, true, has)
 	assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money2))
 	assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money2))
 
 
+	var money3 float64
+	has, err = testEngine.SQL("SELECT money FROM " + testEngine.TableName("get_var", true) + " WHERE money > 20").Get(&money3)
+	assert.NoError(t, err)
+	assert.Equal(t, false, has)
+
 	var valuesString = make(map[string]string)
 	var valuesString = make(map[string]string)
 	has, err = testEngine.Table("get_var").Get(&valuesString)
 	has, err = testEngine.Table("get_var").Get(&valuesString)
 	assert.NoError(t, err)
 	assert.NoError(t, err)
@@ -279,3 +284,29 @@ func TestGetActionMapping(t *testing.T) {
 		ID(1).Get(&valuesSlice)
 		ID(1).Get(&valuesSlice)
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 }
 }
+
+func TestGetStructId(t *testing.T) {
+	type TestGetStruct struct {
+		Id int64
+	}
+
+	assert.NoError(t, prepareEngine())
+	assertSync(t, new(TestGetStruct))
+
+	_, err := testEngine.Insert(&TestGetStruct{})
+	assert.NoError(t, err)
+	_, err = testEngine.Insert(&TestGetStruct{})
+	assert.NoError(t, err)
+
+	type maxidst struct {
+		Id int64
+	}
+
+	//var id int64
+	var maxid maxidst
+	sql := "select max(id) as id from " + testEngine.TableName(&TestGetStruct{}, true)
+	has, err := testEngine.SQL(sql).Get(&maxid)
+	assert.NoError(t, err)
+	assert.True(t, has)
+	assert.EqualValues(t, 2, maxid.Id)
+}

+ 1 - 1
session_insert.go

@@ -298,7 +298,7 @@ func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) {
 }
 }
 
 
 func (session *Session) innerInsert(bean interface{}) (int64, error) {
 func (session *Session) innerInsert(bean interface{}) (int64, error) {
-	if err := session.statement.setRefValue(rValue(bean)); err != nil {
+	if err := session.statement.setRefBean(bean); err != nil {
 		return 0, err
 		return 0, err
 	}
 	}
 	if len(session.statement.TableName()) <= 0 {
 	if len(session.statement.TableName()) <= 0 {

+ 2 - 1
session_insert_test.go

@@ -732,8 +732,9 @@ func (MyUserinfo2) TableName() string {
 func TestInsertMulti4(t *testing.T) {
 func TestInsertMulti4(t *testing.T) {
 	assert.NoError(t, prepareEngine())
 	assert.NoError(t, prepareEngine())
 
 
-	testEngine.ShowSQL(true)
+	testEngine.ShowSQL(false)
 	assertSync(t, new(MyUserinfo2))
 	assertSync(t, new(MyUserinfo2))
+	testEngine.ShowSQL(true)
 
 
 	users := []MyUserinfo2{
 	users := []MyUserinfo2{
 		{Username: "xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()},
 		{Username: "xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()},

+ 19 - 4
session_pk_test.go

@@ -1118,13 +1118,28 @@ func TestCompositePK(t *testing.T) {
 	}
 	}
 
 
 	assert.NoError(t, prepareEngine())
 	assert.NoError(t, prepareEngine())
-	assertSync(t, new(TaskSolution))
 
 
+	tables1, err := testEngine.DBMetas()
+	assert.NoError(t, err)
+
+	assertSync(t, new(TaskSolution))
 	assert.NoError(t, testEngine.Sync2(new(TaskSolution)))
 	assert.NoError(t, testEngine.Sync2(new(TaskSolution)))
-	tables, err := testEngine.DBMetas()
+
+	tables2, err := testEngine.DBMetas()
 	assert.NoError(t, err)
 	assert.NoError(t, err)
-	assert.EqualValues(t, 1, len(tables))
-	pkCols := tables[0].PKColumns()
+	assert.EqualValues(t, 1+len(tables1), len(tables2))
+
+	var table *core.Table
+	for _, t := range tables2 {
+		if t.Name == testEngine.GetTableMapper().Obj2Table("TaskSolution") {
+			table = t
+			break
+		}
+	}
+
+	assert.NotEqual(t, nil, table)
+
+	pkCols := table.PKColumns()
 	assert.EqualValues(t, 2, len(pkCols))
 	assert.EqualValues(t, 2, len(pkCols))
 	assert.EqualValues(t, "uid", pkCols[0].Name)
 	assert.EqualValues(t, "uid", pkCols[0].Name)
 	assert.EqualValues(t, "tid", pkCols[1].Name)
 	assert.EqualValues(t, "tid", pkCols[1].Name)

+ 1 - 1
session_query.go

@@ -90,7 +90,7 @@ func (session *Session) genQuerySQL(sqlorArgs ...interface{}) (string, []interfa
 	}
 	}
 
 
 	args := append(session.statement.joinArgs, condArgs...)
 	args := append(session.statement.joinArgs, condArgs...)
-	sqlStr, err := session.statement.genSelectSQL(columnStr, condSQL, true)
+	sqlStr, err := session.statement.genSelectSQL(columnStr, condSQL, true, true)
 	if err != nil {
 	if err != nil {
 		return "", nil, err
 		return "", nil, err
 	}
 	}

+ 5 - 5
session_query_test.go

@@ -36,7 +36,7 @@ func TestQueryString(t *testing.T) {
 	_, err := testEngine.InsertOne(data)
 	_, err := testEngine.InsertOne(data)
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 
 
-	records, err := testEngine.QueryString("select * from get_var2")
+	records, err := testEngine.QueryString("select * from " + testEngine.TableName("get_var2", true))
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 	assert.Equal(t, 1, len(records))
 	assert.Equal(t, 1, len(records))
 	assert.Equal(t, 5, len(records[0]))
 	assert.Equal(t, 5, len(records[0]))
@@ -62,7 +62,7 @@ func TestQueryString2(t *testing.T) {
 	_, err := testEngine.Insert(data)
 	_, err := testEngine.Insert(data)
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 
 
-	records, err := testEngine.QueryString("select * from get_var3")
+	records, err := testEngine.QueryString("select * from " + testEngine.TableName("get_var3", true))
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 	assert.Equal(t, 1, len(records))
 	assert.Equal(t, 1, len(records))
 	assert.Equal(t, 2, len(records[0]))
 	assert.Equal(t, 2, len(records[0]))
@@ -127,7 +127,7 @@ func TestQueryInterface(t *testing.T) {
 	_, err := testEngine.InsertOne(data)
 	_, err := testEngine.InsertOne(data)
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 
 
-	records, err := testEngine.QueryInterface("select * from get_var_interface")
+	records, err := testEngine.QueryInterface("select * from " + testEngine.TableName("get_var_interface", true))
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 	assert.Equal(t, 1, len(records))
 	assert.Equal(t, 1, len(records))
 	assert.Equal(t, 5, len(records[0]))
 	assert.Equal(t, 5, len(records[0]))
@@ -181,7 +181,7 @@ func TestQueryNoParams(t *testing.T) {
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 	assertResult(t, results)
 	assertResult(t, results)
 
 
-	results, err = testEngine.SQL("select * from query_no_params").Query()
+	results, err = testEngine.SQL("select * from " + testEngine.TableName("query_no_params", true)).Query()
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 	assertResult(t, results)
 	assertResult(t, results)
 }
 }
@@ -226,7 +226,7 @@ func TestQueryWithBuilder(t *testing.T) {
 		assert.EqualValues(t, 3000, money)
 		assert.EqualValues(t, 3000, money)
 	}
 	}
 
 
-	results, err := testEngine.Query(builder.Select("*").From("query_with_builder"))
+	results, err := testEngine.Query(builder.Select("*").From(testEngine.TableName("query_with_builder", true)))
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 	assertResult(t, results)
 	assertResult(t, results)
 }
 }

+ 2 - 2
session_raw_test.go

@@ -21,13 +21,13 @@ func TestExecAndQuery(t *testing.T) {
 
 
 	assert.NoError(t, testEngine.Sync2(new(UserinfoQuery)))
 	assert.NoError(t, testEngine.Sync2(new(UserinfoQuery)))
 
 
-	res, err := testEngine.Exec("INSERT INTO `userinfo_query` (uid, name) VALUES (?, ?)", 1, "user")
+	res, err := testEngine.Exec("INSERT INTO "+testEngine.TableName("`userinfo_query`", true)+" (uid, name) VALUES (?, ?)", 1, "user")
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 	cnt, err := res.RowsAffected()
 	cnt, err := res.RowsAffected()
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 	assert.EqualValues(t, 1, cnt)
 	assert.EqualValues(t, 1, cnt)
 
 
-	results, err := testEngine.Query("select * from userinfo_query")
+	results, err := testEngine.Query("select * from " + testEngine.TableName("userinfo_query", true))
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 	assert.EqualValues(t, 1, len(results))
 	assert.EqualValues(t, 1, len(results))
 	id, err := strconv.Atoi(string(results[0]["uid"]))
 	id, err := strconv.Atoi(string(results[0]["uid"]))

+ 30 - 51
session_schema.go

@@ -6,9 +6,7 @@ package xorm
 
 
 import (
 import (
 	"database/sql"
 	"database/sql"
-	"errors"
 	"fmt"
 	"fmt"
-	"reflect"
 	"strings"
 	"strings"
 
 
 	"github.com/xormplus/core"
 	"github.com/xormplus/core"
@@ -34,8 +32,7 @@ func (session *Session) CreateTable(bean interface{}) error {
 }
 }
 
 
 func (session *Session) createTable(bean interface{}) error {
 func (session *Session) createTable(bean interface{}) error {
-	v := rValue(bean)
-	if err := session.statement.setRefValue(v); err != nil {
+	if err := session.statement.setRefBean(bean); err != nil {
 		return err
 		return err
 	}
 	}
 
 
@@ -54,8 +51,7 @@ func (session *Session) CreateIndexes(bean interface{}) error {
 }
 }
 
 
 func (session *Session) createIndexes(bean interface{}) error {
 func (session *Session) createIndexes(bean interface{}) error {
-	v := rValue(bean)
-	if err := session.statement.setRefValue(v); err != nil {
+	if err := session.statement.setRefBean(bean); err != nil {
 		return err
 		return err
 	}
 	}
 
 
@@ -78,8 +74,7 @@ func (session *Session) CreateUniques(bean interface{}) error {
 }
 }
 
 
 func (session *Session) createUniques(bean interface{}) error {
 func (session *Session) createUniques(bean interface{}) error {
-	v := rValue(bean)
-	if err := session.statement.setRefValue(v); err != nil {
+	if err := session.statement.setRefBean(bean); err != nil {
 		return err
 		return err
 	}
 	}
 
 
@@ -103,8 +98,7 @@ func (session *Session) DropIndexes(bean interface{}) error {
 }
 }
 
 
 func (session *Session) dropIndexes(bean interface{}) error {
 func (session *Session) dropIndexes(bean interface{}) error {
-	v := rValue(bean)
-	if err := session.statement.setRefValue(v); err != nil {
+	if err := session.statement.setRefBean(bean); err != nil {
 		return err
 		return err
 	}
 	}
 
 
@@ -128,11 +122,7 @@ func (session *Session) DropTable(beanOrTableName interface{}) error {
 }
 }
 
 
 func (session *Session) dropTable(beanOrTableName interface{}) error {
 func (session *Session) dropTable(beanOrTableName interface{}) error {
-	tableName, err := session.engine.tableName(beanOrTableName)
-	if err != nil {
-		return err
-	}
-
+	tableName := session.engine.tbNameNoSchema(beanOrTableName)
 	var needDrop = true
 	var needDrop = true
 	if !session.engine.dialect.SupportDropIfExists() {
 	if !session.engine.dialect.SupportDropIfExists() {
 		sqlStr, args := session.engine.dialect.TableCheckSql(tableName)
 		sqlStr, args := session.engine.dialect.TableCheckSql(tableName)
@@ -144,8 +134,8 @@ func (session *Session) dropTable(beanOrTableName interface{}) error {
 	}
 	}
 
 
 	if needDrop {
 	if needDrop {
-		sqlStr := session.engine.Dialect().DropTableSql(tableName)
-		_, err = session.exec(sqlStr)
+		sqlStr := session.engine.Dialect().DropTableSql(session.engine.TableName(tableName, true))
+		_, err := session.exec(sqlStr)
 		return err
 		return err
 	}
 	}
 	return nil
 	return nil
@@ -157,10 +147,7 @@ func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error)
 		defer session.Close()
 		defer session.Close()
 	}
 	}
 
 
-	tableName, err := session.engine.tableName(beanOrTableName)
-	if err != nil {
-		return false, err
-	}
+	tableName := session.engine.tbNameNoSchema(beanOrTableName)
 
 
 	return session.isTableExist(tableName)
 	return session.isTableExist(tableName)
 }
 }
@@ -173,24 +160,15 @@ func (session *Session) isTableExist(tableName string) (bool, error) {
 
 
 // IsTableEmpty if table have any records
 // IsTableEmpty if table have any records
 func (session *Session) IsTableEmpty(bean interface{}) (bool, error) {
 func (session *Session) IsTableEmpty(bean interface{}) (bool, error) {
-	v := rValue(bean)
-	t := v.Type()
-
-	if t.Kind() == reflect.String {
-		if session.isAutoClose {
-			defer session.Close()
-		}
-		return session.isTableEmpty(bean.(string))
-	} else if t.Kind() == reflect.Struct {
-		rows, err := session.Count(bean)
-		return rows == 0, err
+	if session.isAutoClose {
+		defer session.Close()
 	}
 	}
-	return false, errors.New("bean should be a struct or struct's point")
+	return session.isTableEmpty(session.engine.tbNameNoSchema(bean))
 }
 }
 
 
 func (session *Session) isTableEmpty(tableName string) (bool, error) {
 func (session *Session) isTableEmpty(tableName string) (bool, error) {
 	var total int64
 	var total int64
-	sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.Quote(tableName))
+	sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.Quote(session.engine.TableName(tableName, true)))
 	err := session.queryRow(sqlStr).Scan(&total)
 	err := session.queryRow(sqlStr).Scan(&total)
 	if err != nil {
 	if err != nil {
 		if err == sql.ErrNoRows {
 		if err == sql.ErrNoRows {
@@ -270,7 +248,8 @@ func (session *Session) Sync2(beans ...interface{}) error {
 			return err
 			return err
 		}
 		}
 		structTables = append(structTables, table)
 		structTables = append(structTables, table)
-		var tbName = session.tbNameNoSchema(table)
+		tbName := session.tbNameNoSchema(table)
+		tbNameWithSchema := engine.TableName(tbName, true)
 
 
 		var oriTable *core.Table
 		var oriTable *core.Table
 		for _, tb := range tables {
 		for _, tb := range tables {
@@ -315,32 +294,32 @@ func (session *Session) Sync2(beans ...interface{}) error {
 							if engine.dialect.DBType() == core.MYSQL ||
 							if engine.dialect.DBType() == core.MYSQL ||
 								engine.dialect.DBType() == core.POSTGRES {
 								engine.dialect.DBType() == core.POSTGRES {
 								engine.logger.Infof("Table %s column %s change type from %s to %s\n",
 								engine.logger.Infof("Table %s column %s change type from %s to %s\n",
-									tbName, col.Name, curType, expectedType)
-								_, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col))
+									tbNameWithSchema, col.Name, curType, expectedType)
+								_, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col))
 							} else {
 							} else {
 								engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n",
 								engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n",
-									tbName, col.Name, curType, expectedType)
+									tbNameWithSchema, col.Name, curType, expectedType)
 							}
 							}
 						} else if strings.HasPrefix(curType, core.Varchar) && strings.HasPrefix(expectedType, core.Varchar) {
 						} else if strings.HasPrefix(curType, core.Varchar) && strings.HasPrefix(expectedType, core.Varchar) {
 							if engine.dialect.DBType() == core.MYSQL {
 							if engine.dialect.DBType() == core.MYSQL {
 								if oriCol.Length < col.Length {
 								if oriCol.Length < col.Length {
 									engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
 									engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
-										tbName, col.Name, oriCol.Length, col.Length)
-									_, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col))
+										tbNameWithSchema, col.Name, oriCol.Length, col.Length)
+									_, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col))
 								}
 								}
 							}
 							}
 						} else {
 						} else {
 							if !(strings.HasPrefix(curType, expectedType) && curType[len(expectedType)] == '(') {
 							if !(strings.HasPrefix(curType, expectedType) && curType[len(expectedType)] == '(') {
 								engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s",
 								engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s",
-									tbName, col.Name, curType, expectedType)
+									tbNameWithSchema, col.Name, curType, expectedType)
 							}
 							}
 						}
 						}
 					} else if expectedType == core.Varchar {
 					} else if expectedType == core.Varchar {
 						if engine.dialect.DBType() == core.MYSQL {
 						if engine.dialect.DBType() == core.MYSQL {
 							if oriCol.Length < col.Length {
 							if oriCol.Length < col.Length {
 								engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
 								engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
-									tbName, col.Name, oriCol.Length, col.Length)
-								_, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col))
+									tbNameWithSchema, col.Name, oriCol.Length, col.Length)
+								_, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col))
 							}
 							}
 						}
 						}
 					}
 					}
@@ -354,7 +333,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
 					}
 					}
 				} else {
 				} else {
 					session.statement.RefTable = table
 					session.statement.RefTable = table
-					session.statement.tableName = tbName
+					session.statement.tableName = tbNameWithSchema
 					err = session.addColumn(col.Name)
 					err = session.addColumn(col.Name)
 				}
 				}
 				if err != nil {
 				if err != nil {
@@ -378,7 +357,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
 				if oriIndex != nil {
 				if oriIndex != nil {
 					if oriIndex.Type != index.Type {
 					if oriIndex.Type != index.Type {
 
 
-						sql := engine.dialect.DropIndexSql(tbName, oriIndex)
+						sql := engine.dialect.DropIndexSql(tbNameWithSchema, oriIndex)
 
 
 						if sql != "" {
 						if sql != "" {
 							_, err = session.exec(sql)
 							_, err = session.exec(sql)
@@ -398,7 +377,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
 			for name2, index2 := range oriTable.Indexes {
 			for name2, index2 := range oriTable.Indexes {
 				if _, ok := foundIndexNames[name2]; !ok {
 				if _, ok := foundIndexNames[name2]; !ok {
 
 
-					sql := engine.dialect.DropIndexSql(tbName, index2)
+					sql := engine.dialect.DropIndexSql(tbNameWithSchema, index2)
 
 
 					if sql != "" {
 					if sql != "" {
 						_, err = session.exec(sql)
 						_, err = session.exec(sql)
@@ -412,12 +391,12 @@ func (session *Session) Sync2(beans ...interface{}) error {
 			for name, index := range addedNames {
 			for name, index := range addedNames {
 				if index.Type == core.UniqueType {
 				if index.Type == core.UniqueType {
 					session.statement.RefTable = table
 					session.statement.RefTable = table
-					session.statement.tableName = tbName
-					err = session.addUnique(tbName, name)
+					session.statement.tableName = tbNameWithSchema
+					err = session.addUnique(tbNameWithSchema, name)
 				} else if index.Type == core.IndexType {
 				} else if index.Type == core.IndexType {
 					session.statement.RefTable = table
 					session.statement.RefTable = table
-					session.statement.tableName = tbName
-					err = session.addIndex(tbName, name)
+					session.statement.tableName = tbNameWithSchema
+					err = session.addIndex(tbNameWithSchema, name)
 				}
 				}
 				if err != nil {
 				if err != nil {
 					return err
 					return err
@@ -442,7 +421,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
 
 
 		for _, colName := range table.ColumnsSeq() {
 		for _, colName := range table.ColumnsSeq() {
 			if oriTable.GetColumn(colName) == nil {
 			if oriTable.GetColumn(colName) == nil {
-				engine.logger.Warnf("Table %s has column %s but struct has not related field", table.Name, colName)
+				engine.logger.Warnf("Table %s has column %s but struct has not related field", engine.TableName(table.Name, true), colName)
 			}
 			}
 		}
 		}
 	}
 	}

+ 26 - 1
session_stats_test.go

@@ -153,8 +153,33 @@ func TestSQLCount(t *testing.T) {
 
 
 	assertSync(t, new(UserinfoCount2), new(UserinfoBooks))
 	assertSync(t, new(UserinfoCount2), new(UserinfoBooks))
 
 
-	total, err := testEngine.SQL("SELECT count(id) FROM userinfo_count2").
+	total, err := testEngine.SQL("SELECT count(id) FROM " + testEngine.TableName("userinfo_count2", true)).
 		Count()
 		Count()
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 	assert.EqualValues(t, 0, total)
 	assert.EqualValues(t, 0, total)
 }
 }
+
+func TestCountWithOthers(t *testing.T) {
+	assert.NoError(t, prepareEngine())
+
+	type CountWithOthers struct {
+		Id   int64
+		Name string
+	}
+
+	assertSync(t, new(CountWithOthers))
+
+	_, err := testEngine.Insert(&CountWithOthers{
+		Name: "orderby",
+	})
+	assert.NoError(t, err)
+
+	_, err = testEngine.Insert(&CountWithOthers{
+		Name: "limit",
+	})
+	assert.NoError(t, err)
+
+	total, err := testEngine.OrderBy("id desc").Limit(1).Count(new(CountWithOthers))
+	assert.NoError(t, err)
+	assert.EqualValues(t, 2, total)
+}

+ 19 - 81
session_tx_test.go

@@ -32,45 +32,21 @@ func TestTransaction(t *testing.T) {
 	defer session.Close()
 	defer session.Close()
 
 
 	err := session.Begin()
 	err := session.Begin()
-	if err != nil {
-		t.Error(err)
-		panic(err)
-		return
-	}
+	assert.NoError(t, err)
 
 
 	user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()}
 	user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()}
 	_, err = session.Insert(&user1)
 	_, err = session.Insert(&user1)
-	if err != nil {
-		session.Rollback()
-		t.Error(err)
-		panic(err)
-		return
-	}
+	assert.NoError(t, err)
 
 
 	user2 := Userinfo{Username: "yyy"}
 	user2 := Userinfo{Username: "yyy"}
 	_, err = session.Where("(id) = ?", 0).Update(&user2)
 	_, err = session.Where("(id) = ?", 0).Update(&user2)
-	if err != nil {
-		session.Rollback()
-		fmt.Println(err)
-		//t.Error(err)
-		return
-	}
+	assert.NoError(t, err)
 
 
 	_, err = session.Delete(&user2)
 	_, err = session.Delete(&user2)
-	if err != nil {
-		session.Rollback()
-		t.Error(err)
-		panic(err)
-		return
-	}
+	assert.NoError(t, err)
 
 
 	err = session.Commit()
 	err = session.Commit()
-	if err != nil {
-		t.Error(err)
-		panic(err)
-		return
-	}
-	// panic(err) !nashtsai! should remove this
+	assert.NoError(t, err)
 }
 }
 
 
 func TestCombineTransaction(t *testing.T) {
 func TestCombineTransaction(t *testing.T) {
@@ -91,38 +67,21 @@ func TestCombineTransaction(t *testing.T) {
 	defer session.Close()
 	defer session.Close()
 
 
 	err := session.Begin()
 	err := session.Begin()
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 
 
 	user1 := Userinfo{Username: "xiaoxiao2", Departname: "dev", Alias: "lunny", Created: time.Now()}
 	user1 := Userinfo{Username: "xiaoxiao2", Departname: "dev", Alias: "lunny", Created: time.Now()}
 	_, err = session.Insert(&user1)
 	_, err = session.Insert(&user1)
-	if err != nil {
-		session.Rollback()
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
+
 	user2 := Userinfo{Username: "zzz"}
 	user2 := Userinfo{Username: "zzz"}
 	_, err = session.Where("id = ?", 0).Update(&user2)
 	_, err = session.Where("id = ?", 0).Update(&user2)
-	if err != nil {
-		session.Rollback()
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 
 
-	_, err = session.Exec("delete from userinfo where username = ?", user2.Username)
-	if err != nil {
-		session.Rollback()
-		t.Error(err)
-		panic(err)
-	}
+	_, err = session.Exec("delete from "+testEngine.TableName("userinfo", true)+" where username = ?", user2.Username)
+	assert.NoError(t, err)
 
 
 	err = session.Commit()
 	err = session.Commit()
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 }
 }
 
 
 func TestCombineTransactionSameMapper(t *testing.T) {
 func TestCombineTransactionSameMapper(t *testing.T) {
@@ -148,45 +107,24 @@ func TestCombineTransactionSameMapper(t *testing.T) {
 
 
 	counter()
 	counter()
 	defer counter()
 	defer counter()
+
 	session := testEngine.NewSession()
 	session := testEngine.NewSession()
 	defer session.Close()
 	defer session.Close()
 
 
 	err := session.Begin()
 	err := session.Begin()
-	if err != nil {
-		t.Error(err)
-		panic(err)
-		return
-	}
+	assert.NoError(t, err)
 
 
 	user1 := Userinfo{Username: "xiaoxiao2", Departname: "dev", Alias: "lunny", Created: time.Now()}
 	user1 := Userinfo{Username: "xiaoxiao2", Departname: "dev", Alias: "lunny", Created: time.Now()}
 	_, err = session.Insert(&user1)
 	_, err = session.Insert(&user1)
-	if err != nil {
-		session.Rollback()
-		t.Error(err)
-		panic(err)
-		return
-	}
+	assert.NoError(t, err)
 
 
 	user2 := Userinfo{Username: "zzz"}
 	user2 := Userinfo{Username: "zzz"}
 	_, err = session.Where("(id) = ?", 0).Update(&user2)
 	_, err = session.Where("(id) = ?", 0).Update(&user2)
-	if err != nil {
-		session.Rollback()
-		t.Error(err)
-		panic(err)
-		return
-	}
+	assert.NoError(t, err)
 
 
-	_, err = session.Exec("delete from `Userinfo` where `Username` = ?", user2.Username)
-	if err != nil {
-		session.Rollback()
-		t.Error(err)
-		panic(err)
-		return
-	}
+	_, err = session.Exec("delete from  "+testEngine.TableName("`Userinfo`", true)+" where `Username` = ?", user2.Username)
+	assert.NoError(t, err)
 
 
 	err = session.Commit()
 	err = session.Commit()
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 }
 }

+ 1 - 1
session_update.go

@@ -167,7 +167,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
 	var isMap = t.Kind() == reflect.Map
 	var isMap = t.Kind() == reflect.Map
 	var isStruct = t.Kind() == reflect.Struct
 	var isStruct = t.Kind() == reflect.Struct
 	if isStruct {
 	if isStruct {
-		if err := session.statement.setRefValue(v); err != nil {
+		if err := session.statement.setRefBean(bean); err != nil {
 			return 0, err
 			return 0, err
 		}
 		}
 
 

+ 27 - 109
session_update_test.go

@@ -462,30 +462,18 @@ func TestUpdate1(t *testing.T) {
 
 
 	col1 := &UpdateAllCols{Ptr: &s}
 	col1 := &UpdateAllCols{Ptr: &s}
 	err = testEngine.Sync(col1)
 	err = testEngine.Sync(col1)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 
 
 	_, err = testEngine.Insert(col1)
 	_, err = testEngine.Insert(col1)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 
 
 	col2 := &UpdateAllCols{col1.Id, true, "", nil}
 	col2 := &UpdateAllCols{col1.Id, true, "", nil}
 	_, err = testEngine.ID(col2.Id).AllCols().Update(col2)
 	_, err = testEngine.ID(col2.Id).AllCols().Update(col2)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 
 
 	col3 := &UpdateAllCols{}
 	col3 := &UpdateAllCols{}
 	has, err = testEngine.ID(col2.Id).Get(col3)
 	has, err = testEngine.ID(col2.Id).Get(col3)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 
 
 	if !has {
 	if !has {
 		err = errors.New(fmt.Sprintf("cannot get id %d", col2.Id))
 		err = errors.New(fmt.Sprintf("cannot get id %d", col2.Id))
@@ -759,7 +747,7 @@ func TestUpdateUpdated(t *testing.T) {
 func TestUpdateSameMapper(t *testing.T) {
 func TestUpdateSameMapper(t *testing.T) {
 	assert.NoError(t, prepareEngine())
 	assert.NoError(t, prepareEngine())
 
 
-	oldMapper := testEngine.GetColumnMapper()
+	oldMapper := testEngine.GetTableMapper()
 	testEngine.UnMapType(rValue(new(Userinfo)).Type())
 	testEngine.UnMapType(rValue(new(Userinfo)).Type())
 	testEngine.UnMapType(rValue(new(Condi)).Type())
 	testEngine.UnMapType(rValue(new(Condi)).Type())
 	testEngine.UnMapType(rValue(new(Article)).Type())
 	testEngine.UnMapType(rValue(new(Article)).Type())
@@ -786,81 +774,38 @@ func TestUpdateSameMapper(t *testing.T) {
 
 
 	var ori Userinfo
 	var ori Userinfo
 	has, err := testEngine.Get(&ori)
 	has, err := testEngine.Get(&ori)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
-	if !has {
-		t.Error(errors.New("not exist"))
-		panic(errors.New("not exist"))
-	}
+	assert.NoError(t, err)
+	assert.True(t, has)
+
 	// update by id
 	// update by id
 	user := Userinfo{Username: "xxx", Height: 1.2}
 	user := Userinfo{Username: "xxx", Height: 1.2}
 	cnt, err := testEngine.ID(ori.Uid).Update(&user)
 	cnt, err := testEngine.ID(ori.Uid).Update(&user)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
-	if cnt != 1 {
-		err = errors.New("update not returned 1")
-		t.Error(err)
-		panic(err)
-		return
-	}
+	assert.NoError(t, err)
+	assert.EqualValues(t, 1, cnt)
 
 
 	condi := Condi{"Username": "zzz", "Departname": ""}
 	condi := Condi{"Username": "zzz", "Departname": ""}
 	cnt, err = testEngine.Table(&user).ID(ori.Uid).Update(&condi)
 	cnt, err = testEngine.Table(&user).ID(ori.Uid).Update(&condi)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
-
-	if cnt != 1 {
-		err = errors.New("update not returned 1")
-		t.Error(err)
-		panic(err)
-		return
-	}
+	assert.NoError(t, err)
+	assert.EqualValues(t, 1, cnt)
 
 
 	cnt, err = testEngine.Update(&Userinfo{Username: "yyy"}, &user)
 	cnt, err = testEngine.Update(&Userinfo{Username: "yyy"}, &user)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 
 
 	total, err := testEngine.Count(&user)
 	total, err := testEngine.Count(&user)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
-
-	if cnt != total {
-		err = errors.New("insert not returned 1")
-		t.Error(err)
-		panic(err)
-		return
-	}
+	assert.NoError(t, err)
+	assert.EqualValues(t, cnt, total)
 
 
 	err = testEngine.Sync(&Article{})
 	err = testEngine.Sync(&Article{})
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 
 
 	defer func() {
 	defer func() {
 		err = testEngine.DropTables(&Article{})
 		err = testEngine.DropTables(&Article{})
-		if err != nil {
-			t.Error(err)
-			panic(err)
-		}
+		assert.NoError(t, err)
 	}()
 	}()
 
 
 	a := &Article{0, "1", "2", "3", "4", "5", 2}
 	a := &Article{0, "1", "2", "3", "4", "5", 2}
 	cnt, err = testEngine.Insert(a)
 	cnt, err = testEngine.Insert(a)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 
 
 	if cnt != 1 {
 	if cnt != 1 {
 		err = errors.New(fmt.Sprintf("insert not returned 1 but %d", cnt))
 		err = errors.New(fmt.Sprintf("insert not returned 1 but %d", cnt))
@@ -875,10 +820,7 @@ func TestUpdateSameMapper(t *testing.T) {
 	}
 	}
 
 
 	cnt, err = testEngine.ID(a.Id).Update(&Article{Name: "6"})
 	cnt, err = testEngine.ID(a.Id).Update(&Article{Name: "6"})
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 
 
 	if cnt != 1 {
 	if cnt != 1 {
 		err = errors.New(fmt.Sprintf("insert not returned 1 but %d", cnt))
 		err = errors.New(fmt.Sprintf("insert not returned 1 but %d", cnt))
@@ -889,30 +831,18 @@ func TestUpdateSameMapper(t *testing.T) {
 
 
 	col1 := &UpdateAllCols{}
 	col1 := &UpdateAllCols{}
 	err = testEngine.Sync(col1)
 	err = testEngine.Sync(col1)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 
 
 	_, err = testEngine.Insert(col1)
 	_, err = testEngine.Insert(col1)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 
 
 	col2 := &UpdateAllCols{col1.Id, true, "", nil}
 	col2 := &UpdateAllCols{col1.Id, true, "", nil}
 	_, err = testEngine.ID(col2.Id).AllCols().Update(col2)
 	_, err = testEngine.ID(col2.Id).AllCols().Update(col2)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 
 
 	col3 := &UpdateAllCols{}
 	col3 := &UpdateAllCols{}
 	has, err = testEngine.ID(col2.Id).Get(col3)
 	has, err = testEngine.ID(col2.Id).Get(col3)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 
 
 	if !has {
 	if !has {
 		err = errors.New(fmt.Sprintf("cannot get id %d", col2.Id))
 		err = errors.New(fmt.Sprintf("cannot get id %d", col2.Id))
@@ -931,32 +861,20 @@ func TestUpdateSameMapper(t *testing.T) {
 	{
 	{
 		col1 := &UpdateMustCols{}
 		col1 := &UpdateMustCols{}
 		err = testEngine.Sync(col1)
 		err = testEngine.Sync(col1)
-		if err != nil {
-			t.Error(err)
-			panic(err)
-		}
+		assert.NoError(t, err)
 
 
 		_, err = testEngine.Insert(col1)
 		_, err = testEngine.Insert(col1)
-		if err != nil {
-			t.Error(err)
-			panic(err)
-		}
+		assert.NoError(t, err)
 
 
 		col2 := &UpdateMustCols{col1.Id, true, ""}
 		col2 := &UpdateMustCols{col1.Id, true, ""}
 		boolStr := testEngine.GetColumnMapper().Obj2Table("Bool")
 		boolStr := testEngine.GetColumnMapper().Obj2Table("Bool")
 		stringStr := testEngine.GetColumnMapper().Obj2Table("String")
 		stringStr := testEngine.GetColumnMapper().Obj2Table("String")
 		_, err = testEngine.ID(col2.Id).MustCols(boolStr, stringStr).Update(col2)
 		_, err = testEngine.ID(col2.Id).MustCols(boolStr, stringStr).Update(col2)
-		if err != nil {
-			t.Error(err)
-			panic(err)
-		}
+		assert.NoError(t, err)
 
 
 		col3 := &UpdateMustCols{}
 		col3 := &UpdateMustCols{}
 		has, err := testEngine.ID(col2.Id).Get(col3)
 		has, err := testEngine.ID(col2.Id).Get(col3)
-		if err != nil {
-			t.Error(err)
-			panic(err)
-		}
+		assert.NoError(t, err)
 
 
 		if !has {
 		if !has {
 			err = errors.New(fmt.Sprintf("cannot get id %d", col2.Id))
 			err = errors.New(fmt.Sprintf("cannot get id %d", col2.Id))

+ 43 - 78
statement.go

@@ -221,26 +221,18 @@ func (statement *Statement) setRefValue(v reflect.Value) error {
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	statement.tableName = statement.Engine.tbName(v)
+	statement.tableName = statement.Engine.TableName(v.Interface(), true)
 	return nil
 	return nil
 }
 }
 
 
-// Table tempororily set table name, the parameter could be a string or a pointer of struct
-func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
-	v := rValue(tableNameOrBean)
-	t := v.Type()
-	if t.Kind() == reflect.String {
-		statement.AltTableName = tableNameOrBean.(string)
-	} else if t.Kind() == reflect.Struct {
-		var err error
-		statement.RefTable, err = statement.Engine.autoMapType(v)
-		if err != nil {
-			statement.Engine.logger.Error(err)
-			return statement
-		}
-		statement.AltTableName = statement.Engine.tbName(v)
+func (statement *Statement) setRefBean(bean interface{}) error {
+	var err error
+	statement.RefTable, err = statement.Engine.autoMapType(rValue(bean))
+	if err != nil {
+		return err
 	}
 	}
-	return statement
+	statement.tableName = statement.Engine.TableName(bean, true)
+	return nil
 }
 }
 
 
 // Auto generating update columnes and values according a struct
 // Auto generating update columnes and values according a struct
@@ -743,6 +735,23 @@ func (statement *Statement) Asc(colNames ...string) *Statement {
 	return statement
 	return statement
 }
 }
 
 
+// Table tempororily set table name, the parameter could be a string or a pointer of struct
+func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
+	v := rValue(tableNameOrBean)
+	t := v.Type()
+	if t.Kind() == reflect.Struct {
+		var err error
+		statement.RefTable, err = statement.Engine.autoMapType(v)
+		if err != nil {
+			statement.Engine.logger.Error(err)
+			return statement
+		}
+	}
+
+	statement.AltTableName = statement.Engine.TableName(tableNameOrBean, true)
+	return statement
+}
+
 // Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
 // Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
 func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement {
 func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement {
 	var buf bytes.Buffer
 	var buf bytes.Buffer
@@ -752,56 +761,9 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
 		fmt.Fprintf(&buf, "%v JOIN ", joinOP)
 		fmt.Fprintf(&buf, "%v JOIN ", joinOP)
 	}
 	}
 
 
-	switch tablename.(type) {
-	case []string:
-		t := tablename.([]string)
-		if len(t) > 1 {
-			fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(t[0]), statement.Engine.Quote(t[1]))
-		} else if len(t) == 1 {
-			fmt.Fprintf(&buf, statement.Engine.Quote(t[0]))
-		}
-	case []interface{}:
-		t := tablename.([]interface{})
-		l := len(t)
-		var table string
-		if l > 0 {
-			f := t[0]
-			switch f.(type) {
-			case string:
-				table = f.(string)
-			case TableName:
-				table = f.(TableName).TableName()
-			default:
-				v := rValue(f)
-				t := v.Type()
-				if t.Kind() == reflect.Struct {
-					fmt.Fprintf(&buf, statement.Engine.tbName(v))
-				} else {
-					fmt.Fprintf(&buf, statement.Engine.Quote(fmt.Sprintf("%v", f)))
-				}
-			}
-		}
-		if l > 1 {
-			fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(table),
-				statement.Engine.Quote(fmt.Sprintf("%v", t[1])))
-		} else if l == 1 {
-			fmt.Fprintf(&buf, statement.Engine.Quote(table))
-		}
-	case TableName:
-		fmt.Fprintf(&buf, tablename.(TableName).TableName())
-	case string:
-		fmt.Fprintf(&buf, tablename.(string))
-	default:
-		v := rValue(tablename)
-		t := v.Type()
-		if t.Kind() == reflect.Struct {
-			fmt.Fprintf(&buf, statement.Engine.tbName(v))
-		} else {
-			fmt.Fprintf(&buf, statement.Engine.Quote(fmt.Sprintf("%v", tablename)))
-		}
-	}
+	tbName := statement.Engine.TableName(tablename, true)
 
 
-	fmt.Fprintf(&buf, " ON %v", condition)
+	fmt.Fprintf(&buf, "%s ON %v", tbName, condition)
 	statement.JoinStr = buf.String()
 	statement.JoinStr = buf.String()
 	statement.joinArgs = append(statement.joinArgs, args...)
 	statement.joinArgs = append(statement.joinArgs, args...)
 	return statement
 	return statement
@@ -906,16 +868,18 @@ func (statement *Statement) genUniqueSQL() []string {
 func (statement *Statement) genDelIndexSQL() []string {
 func (statement *Statement) genDelIndexSQL() []string {
 	var sqls []string
 	var sqls []string
 	tbName := statement.TableName()
 	tbName := statement.TableName()
+	idxPrefixName := strings.Replace(tbName, `"`, "", -1)
+	idxPrefixName = strings.Replace(idxPrefixName, `.`, "_", -1)
 	for idxName, index := range statement.RefTable.Indexes {
 	for idxName, index := range statement.RefTable.Indexes {
 		var rIdxName string
 		var rIdxName string
 		if index.Type == core.UniqueType {
 		if index.Type == core.UniqueType {
-			rIdxName = uniqueName(tbName, idxName)
+			rIdxName = uniqueName(idxPrefixName, idxName)
 		} else if index.Type == core.IndexType {
 		} else if index.Type == core.IndexType {
-			rIdxName = indexName(tbName, idxName)
+			rIdxName = indexName(idxPrefixName, idxName)
 		}
 		}
-		sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(rIdxName))
+		sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(statement.Engine.TableName(rIdxName, true)))
 		if statement.Engine.dialect.IndexOnTable() {
 		if statement.Engine.dialect.IndexOnTable() {
-			sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(statement.TableName()))
+			sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(tbName))
 		}
 		}
 		sqls = append(sqls, sql)
 		sqls = append(sqls, sql)
 	}
 	}
@@ -966,7 +930,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
 	v := rValue(bean)
 	v := rValue(bean)
 	isStruct := v.Kind() == reflect.Struct
 	isStruct := v.Kind() == reflect.Struct
 	if isStruct {
 	if isStruct {
-		statement.setRefValue(v)
+		statement.setRefBean(bean)
 	}
 	}
 
 
 	var columnStr = statement.ColumnStr
 	var columnStr = statement.ColumnStr
@@ -1005,7 +969,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
 		return "", nil, err
 		return "", nil, err
 	}
 	}
 
 
-	sqlStr, err := statement.genSelectSQL(columnStr, condSQL, true)
+	sqlStr, err := statement.genSelectSQL(columnStr, condSQL, true, true)
 	if err != nil {
 	if err != nil {
 		return "", nil, err
 		return "", nil, err
 	}
 	}
@@ -1018,7 +982,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa
 	var condArgs []interface{}
 	var condArgs []interface{}
 	var err error
 	var err error
 	if len(beans) > 0 {
 	if len(beans) > 0 {
-		statement.setRefValue(rValue(beans[0]))
+		statement.setRefBean(beans[0])
 		condSQL, condArgs, err = statement.genConds(beans[0])
 		condSQL, condArgs, err = statement.genConds(beans[0])
 	} else {
 	} else {
 		condSQL, condArgs, err = builder.ToSQL(statement.cond)
 		condSQL, condArgs, err = builder.ToSQL(statement.cond)
@@ -1035,7 +999,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa
 			selectSQL = "count(*)"
 			selectSQL = "count(*)"
 		}
 		}
 	}
 	}
-	sqlStr, err := statement.genSelectSQL(selectSQL, condSQL, false)
+	sqlStr, err := statement.genSelectSQL(selectSQL, condSQL, false, false)
 	if err != nil {
 	if err != nil {
 		return "", nil, err
 		return "", nil, err
 	}
 	}
@@ -1044,7 +1008,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa
 }
 }
 
 
 func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) {
 func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) {
-	statement.setRefValue(rValue(bean))
+	statement.setRefBean(bean)
 
 
 	var sumStrs = make([]string, 0, len(columns))
 	var sumStrs = make([]string, 0, len(columns))
 	for _, colName := range columns {
 	for _, colName := range columns {
@@ -1060,7 +1024,7 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri
 		return "", nil, err
 		return "", nil, err
 	}
 	}
 
 
-	sqlStr, err := statement.genSelectSQL(sumSelect, condSQL, true)
+	sqlStr, err := statement.genSelectSQL(sumSelect, condSQL, true, true)
 	if err != nil {
 	if err != nil {
 		return "", nil, err
 		return "", nil, err
 	}
 	}
@@ -1068,7 +1032,7 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri
 	return sqlStr, append(statement.joinArgs, condArgs...), nil
 	return sqlStr, append(statement.joinArgs, condArgs...), nil
 }
 }
 
 
-func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit bool) (a string, err error) {
+func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, needOrderBy bool) (a string, err error) {
 	var distinct string
 	var distinct string
 	if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
 	if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
 		distinct = "DISTINCT "
 		distinct = "DISTINCT "
@@ -1135,9 +1099,10 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit bo
 			}
 			}
 
 
 			var orderStr string
 			var orderStr string
-			if len(statement.OrderStr) > 0 {
+			if needOrderBy && len(statement.OrderStr) > 0 {
 				orderStr = " ORDER BY " + statement.OrderStr
 				orderStr = " ORDER BY " + statement.OrderStr
 			}
 			}
+
 			var groupStr string
 			var groupStr string
 			if len(statement.GroupByStr) > 0 {
 			if len(statement.GroupByStr) > 0 {
 				groupStr = " GROUP BY " + statement.GroupByStr
 				groupStr = " GROUP BY " + statement.GroupByStr
@@ -1163,7 +1128,7 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit bo
 	if statement.HavingStr != "" {
 	if statement.HavingStr != "" {
 		a = fmt.Sprintf("%v %v", a, statement.HavingStr)
 		a = fmt.Sprintf("%v %v", a, statement.HavingStr)
 	}
 	}
-	if statement.OrderStr != "" {
+	if needOrderBy && statement.OrderStr != "" {
 		a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr)
 		a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr)
 	}
 	}
 	if needLimit {
 	if needLimit {

+ 27 - 36
tag_extends_test.go

@@ -202,17 +202,14 @@ func TestExtends(t *testing.T) {
 
 
 	var info UserAndDetail
 	var info UserAndDetail
 	qt := testEngine.Quote
 	qt := testEngine.Quote
-	ui := testEngine.GetTableMapper().Obj2Table("Userinfo")
-	ud := testEngine.GetTableMapper().Obj2Table("Userdetail")
-	uiid := testEngine.GetTableMapper().Obj2Table("Id")
+	ui := testEngine.TableName(new(Userinfo), true)
+	ud := testEngine.TableName(&detail, true)
+	uiid := testEngine.GetColumnMapper().Obj2Table("Id")
 	udid := "detail_id"
 	udid := "detail_id"
 	sql := fmt.Sprintf("select * from %s, %s where %s.%s = %s.%s",
 	sql := fmt.Sprintf("select * from %s, %s where %s.%s = %s.%s",
 		qt(ui), qt(ud), qt(ui), qt(udid), qt(ud), qt(uiid))
 		qt(ui), qt(ud), qt(ui), qt(udid), qt(ud), qt(uiid))
 	b, err := testEngine.SQL(sql).NoCascade().Get(&info)
 	b, err := testEngine.SQL(sql).NoCascade().Get(&info)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 	if !b {
 	if !b {
 		err = errors.New("should has lest one record")
 		err = errors.New("should has lest one record")
 		t.Error(err)
 		t.Error(err)
@@ -341,19 +338,17 @@ func TestExtends2(t *testing.T) {
 	}
 	}
 
 
 	var mapper = testEngine.GetTableMapper().Obj2Table
 	var mapper = testEngine.GetTableMapper().Obj2Table
-	userTableName := mapper("MessageUser")
-	typeTableName := mapper("MessageType")
-	msgTableName := mapper("Message")
+	var quote = testEngine.Quote
+	userTableName := quote(testEngine.TableName(mapper("MessageUser"), true))
+	typeTableName := quote(testEngine.TableName(mapper("MessageType"), true))
+	msgTableName := quote(testEngine.TableName(mapper("Message"), true))
 
 
 	list := make([]Message, 0)
 	list := make([]Message, 0)
-	err = testEngine.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Uid")+"`").
-		Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("ToUid")+"`").
-		Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Id")+"`").
+	err = testEngine.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`").
+		Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("ToUid")+"`").
+		Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Id")+"`").
 		Find(&list)
 		Find(&list)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 
 
 	if len(list) != 1 {
 	if len(list) != 1 {
 		err = errors.New(fmt.Sprintln("should have 1 message, got", len(list)))
 		err = errors.New(fmt.Sprintln("should have 1 message, got", len(list)))
@@ -406,25 +401,20 @@ func TestExtends3(t *testing.T) {
 		assert.NoError(t, err)
 		assert.NoError(t, err)
 	}
 	}
 	_, err = testEngine.Insert(&msg)
 	_, err = testEngine.Insert(&msg)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 
 
 	var mapper = testEngine.GetTableMapper().Obj2Table
 	var mapper = testEngine.GetTableMapper().Obj2Table
-	userTableName := mapper("MessageUser")
-	typeTableName := mapper("MessageType")
-	msgTableName := mapper("Message")
+	var quote = testEngine.Quote
+	userTableName := quote(testEngine.TableName(mapper("MessageUser"), true))
+	typeTableName := quote(testEngine.TableName(mapper("MessageType"), true))
+	msgTableName := quote(testEngine.TableName(mapper("Message"), true))
 
 
 	list := make([]MessageExtend3, 0)
 	list := make([]MessageExtend3, 0)
-	err = testEngine.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Uid")+"`").
-		Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("ToUid")+"`").
-		Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Id")+"`").
+	err = testEngine.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`").
+		Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("ToUid")+"`").
+		Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Id")+"`").
 		Find(&list)
 		Find(&list)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 
 
 	if len(list) != 1 {
 	if len(list) != 1 {
 		err = errors.New(fmt.Sprintln("should have 1 message, got", len(list)))
 		err = errors.New(fmt.Sprintln("should have 1 message, got", len(list)))
@@ -499,13 +489,14 @@ func TestExtends4(t *testing.T) {
 	}
 	}
 
 
 	var mapper = testEngine.GetTableMapper().Obj2Table
 	var mapper = testEngine.GetTableMapper().Obj2Table
-	userTableName := mapper("MessageUser")
-	typeTableName := mapper("MessageType")
-	msgTableName := mapper("Message")
+	var quote = testEngine.Quote
+	userTableName := quote(testEngine.TableName(mapper("MessageUser"), true))
+	typeTableName := quote(testEngine.TableName(mapper("MessageType"), true))
+	msgTableName := quote(testEngine.TableName(mapper("Message"), true))
 
 
 	list := make([]MessageExtend4, 0)
 	list := make([]MessageExtend4, 0)
-	err = testEngine.Table(msgTableName).Join("LEFT", userTableName, "`"+userTableName+"`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Uid")+"`").
-		Join("LEFT", typeTableName, "`"+typeTableName+"`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Id")+"`").
+	err = testEngine.Table(msgTableName).Join("LEFT", userTableName, userTableName+".`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`").
+		Join("LEFT", typeTableName, typeTableName+".`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Id")+"`").
 		Find(&list)
 		Find(&list)
 	if err != nil {
 	if err != nil {
 		t.Error(err)
 		t.Error(err)

+ 1 - 1
test/xorm_test.go

@@ -635,7 +635,7 @@ func Test_SqlTemplateClient_Search_Json(t *testing.T) {
 
 
 func Test_Query(t *testing.T) {
 func Test_Query(t *testing.T) {
 
 
-	result, err := db.Query("select * from category where id =25")
+	result, err := db.QueryInterface("select * from category where id =25")
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}

+ 4 - 3
types_test.go

@@ -301,10 +301,11 @@ type UserCus struct {
 func TestCustomType2(t *testing.T) {
 func TestCustomType2(t *testing.T) {
 	assert.NoError(t, prepareEngine())
 	assert.NoError(t, prepareEngine())
 
 
-	err := testEngine.CreateTables(&UserCus{})
+	var uc UserCus
+	err := testEngine.CreateTables(&uc)
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 
 
-	tableName := testEngine.TableMapper.Obj2Table("UserCus")
+	tableName := testEngine.TableName(&uc, true)
 	_, err = testEngine.Exec("delete from " + testEngine.Quote(tableName))
 	_, err = testEngine.Exec("delete from " + testEngine.Quote(tableName))
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 
 
@@ -327,7 +328,7 @@ func TestCustomType2(t *testing.T) {
 	fmt.Println(user)
 	fmt.Println(user)
 
 
 	users := make([]UserCus, 0)
 	users := make([]UserCus, 0)
-	err = testEngine.Where("`"+testEngine.ColumnMapper.Obj2Table("Status")+"` = ?", "Registed").Find(&users)
+	err = testEngine.Where("`"+testEngine.GetColumnMapper().Obj2Table("Status")+"` = ?", "Registed").Find(&users)
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 	assert.EqualValues(t, 1, len(users))
 	assert.EqualValues(t, 1, len(users))
 
 

+ 4 - 1
xorm_test.go

@@ -27,6 +27,7 @@ var (
 	cache      = flag.Bool("cache", false, "if enable cache")
 	cache      = flag.Bool("cache", false, "if enable cache")
 	cluster    = flag.Bool("cluster", false, "if this is a cluster")
 	cluster    = flag.Bool("cluster", false, "if this is a cluster")
 	splitter   = flag.String("splitter", ";", "the splitter on connstr for cluster")
 	splitter   = flag.String("splitter", ";", "the splitter on connstr for cluster")
+	schema     = flag.String("schema", "", "specify the schema")
 )
 )
 
 
 func createEngine(dbType, connStr string) error {
 func createEngine(dbType, connStr string) error {
@@ -35,7 +36,6 @@ func createEngine(dbType, connStr string) error {
 
 
 		if !*cluster {
 		if !*cluster {
 			testEngine, err = NewEngine(dbType, connStr)
 			testEngine, err = NewEngine(dbType, connStr)
-
 		} else {
 		} else {
 			testEngine, err = NewEngineGroup(dbType, strings.Split(connStr, *splitter))
 			testEngine, err = NewEngineGroup(dbType, strings.Split(connStr, *splitter))
 		}
 		}
@@ -43,6 +43,9 @@ func createEngine(dbType, connStr string) error {
 			return err
 			return err
 		}
 		}
 
 
+		if *schema != "" {
+			testEngine.SetSchema(*schema)
+		}
 		testEngine.ShowSQL(*showSQL)
 		testEngine.ShowSQL(*showSQL)
 		testEngine.SetLogLevel(core.LOG_DEBUG)
 		testEngine.SetLogLevel(core.LOG_DEBUG)
 		if *cache {
 		if *cache {