Browse Source

Add SetSchema for engine

fix postgres with schema

fix test

fix tablename

refactor tableName

fix schema support

improve the interface of EngineInterface

add test for resolve

add test for (id) replace

Add test for count with orderby and limit
xormplus 7 years ago
parent
commit
b5723783b6

+ 18 - 3
dialect_postgres.go

@@ -910,7 +910,16 @@ func (db *postgres) ModifyColumnSql(tableName string, col *core.Column) string {
 
 func (db *postgres) CreateIndexSql(tableName string, index *core.Index) string {
 	quote := db.Quote
-	return fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(tableName, index.Name)),
+	idxName := index.Name
+
+	tableName = strings.Replace(tableName, `"`, "", -1)
+	tableName = strings.Replace(tableName, `.`, "_", -1)
+
+	if db.Uri.Schema != "" {
+		idxName = db.Uri.Schema + "." + idxName
+	}
+
+	return fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(tableName, idxName)),
 		quote(tableName), quote(strings.Join(index.Cols, quote(","))))
 }
 
@@ -918,6 +927,9 @@ func (db *postgres) DropIndexSql(tableName string, index *core.Index) string {
 	quote := db.Quote
 	idxName := index.Name
 
+	tableName = strings.Replace(tableName, `"`, "", -1)
+	tableName = strings.Replace(tableName, `.`, "_", -1)
+
 	if !strings.HasPrefix(idxName, "UQE_") &&
 		!strings.HasPrefix(idxName, "IDX_") {
 		if index.Type == core.UniqueType {
@@ -926,6 +938,9 @@ func (db *postgres) DropIndexSql(tableName string, index *core.Index) string {
 			idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name)
 		}
 	}
+	if db.Uri.Schema != "" {
+		idxName = db.Uri.Schema + "." + idxName
+	}
 	return fmt.Sprintf("DROP INDEX %v", quote(idxName))
 }
 
@@ -966,7 +981,7 @@ WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.att
 	var f string
 	if len(db.Schema) != 0 {
 		args = append(args, db.Schema)
-		f = "AND s.table_schema = $2"
+		f = " AND s.table_schema = $2"
 	}
 	s = fmt.Sprintf(s, f)
 
@@ -1091,11 +1106,11 @@ func (db *postgres) GetTables() ([]*core.Table, error) {
 func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) {
 	args := []interface{}{tableName}
 	s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1")
-	db.LogSQL(s, args)
 	if len(db.Schema) != 0 {
 		args = append(args, db.Schema)
 		s = s + " AND schemaname=$2"
 	}
+	db.LogSQL(s, args)
 
 	rows, err := db.DB().Query(s, args...)
 	if err != nil {

+ 20 - 67
engine.go

@@ -545,46 +545,6 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D
 	return nil
 }
 
-func (engine *Engine) tableName(beanOrTableName interface{}) (string, error) {
-	v := rValue(beanOrTableName)
-	if v.Type().Kind() == reflect.String {
-		return beanOrTableName.(string), nil
-	} else if v.Type().Kind() == reflect.Struct {
-		return engine.tbName(v), nil
-	}
-	return "", errors.New("bean should be a struct or struct's point")
-}
-
-func (engine *Engine) tbSchemaName(v string) string {
-	// Add schema name as prefix of table name.
-	// Only for postgres database.
-	if engine.dialect.DBType() == core.POSTGRES &&
-		engine.dialect.URI().Schema != "" &&
-		engine.dialect.URI().Schema != postgresPublicSchema &&
-		strings.Index(v, ".") == -1 {
-		return engine.dialect.URI().Schema + "." + v
-	}
-	return v
-}
-
-func (engine *Engine) tbName(v reflect.Value) string {
-	if tb, ok := v.Interface().(TableName); ok {
-		return engine.tbSchemaName(tb.TableName())
-
-	}
-
-	if v.Type().Kind() == reflect.Ptr {
-		if tb, ok := reflect.Indirect(v).Interface().(TableName); ok {
-			return engine.tbSchemaName(tb.TableName())
-		}
-	} else if v.CanAddr() {
-		if tb, ok := v.Addr().Interface().(TableName); ok {
-			return engine.tbSchemaName(tb.TableName())
-		}
-	}
-	return engine.tbSchemaName(engine.TableMapper.Obj2Table(reflect.Indirect(v).Type().Name()))
-}
-
 // Cascade use cascade or not
 func (engine *Engine) Cascade(trueOrFalse ...bool) *Session {
 	session := engine.NewSession()
@@ -868,7 +828,7 @@ func (engine *Engine) TableInfo(bean interface{}) *Table {
 	if err != nil {
 		engine.logger.Error(err)
 	}
-	return &Table{tb, engine.tbName(v)}
+	return &Table{tb, engine.TableName(bean)}
 }
 
 func addIndex(indexName string, table *core.Table, col *core.Column, indexType int) {
@@ -904,20 +864,8 @@ var (
 func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) {
 	t := v.Type()
 	table := engine.newTable()
-	if tb, ok := v.Interface().(TableName); ok {
-		table.Name = tb.TableName()
-	} else {
-		if v.CanAddr() {
-			if tb, ok = v.Addr().Interface().(TableName); ok {
-				table.Name = tb.TableName()
-			}
-		}
-		if table.Name == "" {
-			table.Name = engine.TableMapper.Obj2Table(t.Name())
-		}
-	}
-
 	table.Type = t
+	table.Name = engine.tbNameForMap(v)
 
 	var idFieldColName string
 	var hasCacheTag, hasNoCacheTag bool
@@ -1195,7 +1143,7 @@ func (engine *Engine) ClearCacheBean(bean interface{}, id string) error {
 	if t.Kind() != reflect.Struct {
 		return errors.New("error params")
 	}
-	tableName := engine.tbName(v)
+	tableName := engine.TableName(bean)
 	table, err := engine.autoMapType(v)
 	if err != nil {
 		return err
@@ -1219,7 +1167,7 @@ func (engine *Engine) ClearCache(beans ...interface{}) error {
 		if t.Kind() != reflect.Struct {
 			return errors.New("error params")
 		}
-		tableName := engine.tbName(v)
+		tableName := engine.TableName(bean)
 		table, err := engine.autoMapType(v)
 		if err != nil {
 			return err
@@ -1246,13 +1194,13 @@ func (engine *Engine) Sync(beans ...interface{}) error {
 
 	for _, bean := range beans {
 		v := rValue(bean)
-		tableName := engine.tbName(v)
+		tableNameNoSchema := engine.tbNameNoSchema(v.Interface())
 		table, err := engine.autoMapType(v)
 		if err != nil {
 			return err
 		}
 
-		isExist, err := session.Table(bean).isTableExist(tableName)
+		isExist, err := session.Table(bean).isTableExist(tableNameNoSchema)
 		if err != nil {
 			return err
 		}
@@ -1278,12 +1226,12 @@ func (engine *Engine) Sync(beans ...interface{}) error {
 			}
 		} else {
 			for _, col := range table.Columns() {
-				isExist, err := engine.dialect.IsColumnExist(tableName, col.Name)
+				isExist, err := engine.dialect.IsColumnExist(tableNameNoSchema, col.Name)
 				if err != nil {
 					return err
 				}
 				if !isExist {
-					if err := session.statement.setRefValue(v); err != nil {
+					if err := session.statement.setRefBean(bean); err != nil {
 						return err
 					}
 					err = session.addColumn(col.Name)
@@ -1294,35 +1242,35 @@ func (engine *Engine) Sync(beans ...interface{}) error {
 			}
 
 			for name, index := range table.Indexes {
-				if err := session.statement.setRefValue(v); err != nil {
+				if err := session.statement.setRefBean(bean); err != nil {
 					return err
 				}
 				if index.Type == core.UniqueType {
-					isExist, err := session.isIndexExist2(tableName, index.Cols, true)
+					isExist, err := session.isIndexExist2(tableNameNoSchema, index.Cols, true)
 					if err != nil {
 						return err
 					}
 					if !isExist {
-						if err := session.statement.setRefValue(v); err != nil {
+						if err := session.statement.setRefBean(bean); err != nil {
 							return err
 						}
 
-						err = session.addUnique(tableName, name)
+						err = session.addUnique(tableNameNoSchema, name)
 						if err != nil {
 							return err
 						}
 					}
 				} else if index.Type == core.IndexType {
-					isExist, err := session.isIndexExist2(tableName, index.Cols, false)
+					isExist, err := session.isIndexExist2(tableNameNoSchema, index.Cols, false)
 					if err != nil {
 						return err
 					}
 					if !isExist {
-						if err := session.statement.setRefValue(v); err != nil {
+						if err := session.statement.setRefBean(bean); err != nil {
 							return err
 						}
 
-						err = session.addIndex(tableName, name)
+						err = session.addIndex(tableNameNoSchema, name)
 						if err != nil {
 							return err
 						}
@@ -1661,6 +1609,11 @@ func (engine *Engine) SetTZDatabase(tz *time.Location) {
 	engine.DatabaseTZ = tz
 }
 
+// SetSchema sets the schema of database
+func (engine *Engine) SetSchema(schema string) {
+	engine.dialect.URI().Schema = schema
+}
+
 // Unscoped always disable struct tag "deleted"
 func (engine *Engine) Unscoped() *Session {
 	session := engine.NewSession()

+ 109 - 0
engine_table.go

@@ -0,0 +1,109 @@
+// Copyright 2018 The Xorm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package xorm
+
+import (
+	"fmt"
+	"reflect"
+	"strings"
+
+	"github.com/xormplus/core"
+)
+
+// TableNameWithSchema will automatically add schema prefix on table name
+func (engine *Engine) tbNameWithSchema(v string) string {
+	// Add schema name as prefix of table name.
+	// Only for postgres database.
+	if engine.dialect.DBType() == core.POSTGRES &&
+		engine.dialect.URI().Schema != "" &&
+		engine.dialect.URI().Schema != postgresPublicSchema &&
+		strings.Index(v, ".") == -1 {
+		return engine.dialect.URI().Schema + "." + v
+	}
+	return v
+}
+
+// TableName returns table name with schema prefix if has
+func (engine *Engine) TableName(bean interface{}, includeSchema ...bool) string {
+	tbName := engine.tbNameNoSchema(bean)
+	if len(includeSchema) > 0 && includeSchema[0] {
+		tbName = engine.tbNameWithSchema(tbName)
+	}
+
+	return tbName
+}
+
+// tbName get some table's table name
+func (session *Session) tbNameNoSchema(table *core.Table) string {
+	if len(session.statement.AltTableName) > 0 {
+		return session.statement.AltTableName
+	}
+
+	return table.Name
+}
+
+func (engine *Engine) tbNameForMap(v reflect.Value) string {
+	t := v.Type()
+	if tb, ok := v.Interface().(TableName); ok {
+		return tb.TableName()
+	}
+	if v.CanAddr() {
+		if tb, ok := v.Addr().Interface().(TableName); ok {
+			return tb.TableName()
+		}
+	}
+	return engine.TableMapper.Obj2Table(t.Name())
+}
+
+func (engine *Engine) tbNameNoSchema(tablename interface{}) string {
+	switch tablename.(type) {
+	case []string:
+		t := tablename.([]string)
+		if len(t) > 1 {
+			return fmt.Sprintf("%v AS %v", engine.Quote(t[0]), engine.Quote(t[1]))
+		} else if len(t) == 1 {
+			return engine.Quote(t[0])
+		}
+	case []interface{}:
+		t := tablename.([]interface{})
+		l := len(t)
+		var table string
+		if l > 0 {
+			f := t[0]
+			switch f.(type) {
+			case string:
+				table = f.(string)
+			case TableName:
+				table = f.(TableName).TableName()
+			default:
+				v := rValue(f)
+				t := v.Type()
+				if t.Kind() == reflect.Struct {
+					table = engine.tbNameForMap(v)
+				} else {
+					table = engine.Quote(fmt.Sprintf("%v", f))
+				}
+			}
+		}
+		if l > 1 {
+			return fmt.Sprintf("%v AS %v", engine.Quote(table),
+				engine.Quote(fmt.Sprintf("%v", t[1])))
+		} else if l == 1 {
+			return engine.Quote(table)
+		}
+	case TableName:
+		return tablename.(TableName).TableName()
+	case string:
+		return tablename.(string)
+	default:
+		v := rValue(tablename)
+		t := v.Type()
+		if t.Kind() == reflect.Struct {
+			return engine.tbNameForMap(v)
+		}
+		return engine.Quote(fmt.Sprintf("%v", tablename))
+	}
+	return ""
+}

+ 2 - 0
interface.go

@@ -87,6 +87,7 @@ type EngineInterface interface {
 	SetDefaultCacher(core.Cacher)
 	SetLogLevel(core.LogLevel)
 	SetMapper(core.IMapper)
+	SetSchema(string)
 	SetTZDatabase(tz *time.Location)
 	SetTZLocation(tz *time.Location)
 	ShowSQL(show ...bool)
@@ -94,6 +95,7 @@ type EngineInterface interface {
 	Sync2(...interface{}) error
 	StoreEngine(storeEngine string) *Session
 	TableInfo(bean interface{}) *Table
+	TableName(interface{}, ...bool) string
 	UnMapType(reflect.Type)
 }
 

+ 3 - 3
rows.go

@@ -32,7 +32,7 @@ func newRows(session *Session, bean interface{}) (*Rows, error) {
 	var args []interface{}
 	var err error
 
-	if err = rows.session.statement.setRefValue(rValue(bean)); err != nil {
+	if err = rows.session.statement.setRefBean(bean); err != nil {
 		return nil, err
 	}
 
@@ -94,8 +94,7 @@ func (rows *Rows) Scan(bean interface{}) error {
 		return fmt.Errorf("scan arg is incompatible type to [%v]", rows.beanType)
 	}
 
-	dataStruct := rValue(bean)
-	if err := rows.session.statement.setRefValue(dataStruct); err != nil {
+	if err := rows.session.statement.setRefBean(bean); err != nil {
 		return err
 	}
 
@@ -104,6 +103,7 @@ func (rows *Rows) Scan(bean interface{}) error {
 		return err
 	}
 
+	dataStruct := rValue(bean)
 	_, err = rows.session.slice2Bean(scanResults, rows.fields, bean, &dataStruct, rows.session.statement.RefTable)
 	if err != nil {
 		return err

+ 2 - 1
rows_test.go

@@ -54,7 +54,8 @@ func TestRows(t *testing.T) {
 	}
 	assert.EqualValues(t, 1, cnt)
 
-	rows2, err := testEngine.SQL("SELECT * FROM user_rows").Rows(new(UserRows))
+	var tbName = testEngine.Quote(testEngine.TableName(user, true))
+	rows2, err := testEngine.SQL("SELECT * FROM " + tbName).Rows(new(UserRows))
 	assert.NoError(t, err)
 	defer rows2.Close()
 

+ 0 - 9
session.go

@@ -834,15 +834,6 @@ func (session *Session) LastSQL() (string, []interface{}) {
 	return session.lastSQL, session.lastSQLArgs
 }
 
-// tbName get some table's table name
-func (session *Session) tbNameNoSchema(table *core.Table) string {
-	if len(session.statement.AltTableName) > 0 {
-		return session.statement.AltTableName
-	}
-
-	return table.Name
-}
-
 // Unscoped always disable struct tag "deleted"
 func (session *Session) Unscoped() *Session {
 	session.statement.Unscoped()

+ 22 - 90
session_cond_test.go

@@ -122,18 +122,11 @@ func TestIn(t *testing.T) {
 	assert.NoError(t, err)
 	assert.EqualValues(t, 3, cnt)
 
+	department := "`" + testEngine.GetColumnMapper().Obj2Table("Departname") + "`"
 	var usrs []Userinfo
-	err = testEngine.Limit(3).Find(&usrs)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
-
-	if len(usrs) != 3 {
-		err = errors.New("there are not 3 records")
-		t.Error(err)
-		panic(err)
-	}
+	err = testEngine.Where(department+" = ?", "dev").Limit(3).Find(&usrs)
+	assert.NoError(t, err)
+	assert.EqualValues(t, 3, len(usrs))
 
 	var ids []int64
 	var idsStr string
@@ -145,35 +138,20 @@ func TestIn(t *testing.T) {
 
 	users := make([]Userinfo, 0)
 	err = testEngine.In("(id)", ids[0], ids[1], ids[2]).Find(&users)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 	fmt.Println(users)
-	if len(users) != 3 {
-		err = errors.New("in uses should be " + idsStr + " total 3")
-		t.Error(err)
-		panic(err)
-	}
+	assert.EqualValues(t, 3, len(users))
 
 	users = make([]Userinfo, 0)
 	err = testEngine.In("(id)", ids).Find(&users)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 	fmt.Println(users)
-	if len(users) != 3 {
-		err = errors.New("in uses should be " + idsStr + " total 3")
-		t.Error(err)
-		panic(err)
-	}
+	assert.EqualValues(t, 3, len(users))
 
 	for _, user := range users {
 		if user.Uid != ids[0] && user.Uid != ids[1] && user.Uid != ids[2] {
 			err = errors.New("in uses should be " + idsStr + " total 3")
-			t.Error(err)
-			panic(err)
+			assert.NoError(t, err)
 		}
 	}
 
@@ -183,87 +161,41 @@ func TestIn(t *testing.T) {
 		idsInterface = append(idsInterface, id)
 	}
 
-	department := "`" + testEngine.GetColumnMapper().Obj2Table("Departname") + "`"
 	err = testEngine.Where(department+" = ?", "dev").In("(id)", idsInterface...).Find(&users)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 	fmt.Println(users)
-
-	if len(users) != 3 {
-		err = errors.New("in uses should be " + idsStr + " total 3")
-		t.Error(err)
-		panic(err)
-	}
+	assert.EqualValues(t, 3, len(users))
 
 	for _, user := range users {
 		if user.Uid != ids[0] && user.Uid != ids[1] && user.Uid != ids[2] {
 			err = errors.New("in uses should be " + idsStr + " total 3")
-			t.Error(err)
-			panic(err)
+			assert.NoError(t, err)
 		}
 	}
 
 	dev := testEngine.GetColumnMapper().Obj2Table("Dev")
 
 	err = testEngine.In("(id)", 1).In("(id)", 2).In(department, dev).Find(&users)
-
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 	fmt.Println(users)
 
 	cnt, err = testEngine.In("(id)", ids[0]).Update(&Userinfo{Departname: "dev-"})
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
-	if cnt != 1 {
-		err = errors.New("update records not 1")
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
+	assert.EqualValues(t, 1, cnt)
 
 	user := new(Userinfo)
 	has, err := testEngine.ID(ids[0]).Get(user)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
-	if !has {
-		err = errors.New("get record not 1")
-		t.Error(err)
-		panic(err)
-	}
-	if user.Departname != "dev-" {
-		err = errors.New("update not success")
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
+	assert.True(t, has)
+	assert.EqualValues(t, "dev-", user.Departname)
 
 	cnt, err = testEngine.In("(id)", ids[0]).Update(&Userinfo{Departname: "dev"})
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
-	if cnt != 1 {
-		err = errors.New("update records not 1")
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
+	assert.EqualValues(t, 1, cnt)
 
 	cnt, err = testEngine.In("(id)", ids[1]).Delete(&Userinfo{})
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
-	if cnt != 1 {
-		err = errors.New("deleted records not 1")
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
+	assert.EqualValues(t, 1, cnt)
 }
 
 func TestFindAndCount(t *testing.T) {

+ 1 - 1
session_delete.go

@@ -79,7 +79,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
 		defer session.Close()
 	}
 
-	if err := session.statement.setRefValue(rValue(bean)); err != nil {
+	if err := session.statement.setRefBean(bean); err != nil {
 		return 0, err
 	}
 

+ 1 - 1
session_exist.go

@@ -57,7 +57,7 @@ func (session *Session) Exist(bean ...interface{}) (bool, error) {
 			}
 
 			if beanValue.Elem().Kind() == reflect.Struct {
-				if err := session.statement.setRefValue(beanValue.Elem()); err != nil {
+				if err := session.statement.setRefBean(bean[0]); err != nil {
 					return false, err
 				}
 			}

+ 2 - 2
session_exist_test.go

@@ -54,11 +54,11 @@ func TestExistStruct(t *testing.T) {
 	assert.NoError(t, err)
 	assert.False(t, has)
 
-	has, err = testEngine.SQL("select * from record_exist where name = ?", "test1").Exist()
+	has, err = testEngine.SQL("select * from "+testEngine.TableName("record_exist", true)+" where name = ?", "test1").Exist()
 	assert.NoError(t, err)
 	assert.True(t, has)
 
-	has, err = testEngine.SQL("select * from record_exist where name = ?", "test2").Exist()
+	has, err = testEngine.SQL("select * from "+testEngine.TableName("record_exist", true)+" where name = ?", "test2").Exist()
 	assert.NoError(t, err)
 	assert.False(t, has)
 

+ 1 - 1
session_find.go

@@ -182,7 +182,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
 		}
 
 		args = append(session.statement.joinArgs, condArgs...)
-		sqlStr, err = session.statement.genSelectSQL(columnStr, condSQL, true)
+		sqlStr, err = session.statement.genSelectSQL(columnStr, condSQL, true, true)
 		if err != nil {
 			return err
 		}

+ 13 - 42
session_find_test.go

@@ -96,21 +96,15 @@ func TestFind(t *testing.T) {
 	users := make([]Userinfo, 0)
 
 	err := testEngine.Find(&users)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 	for _, user := range users {
 		fmt.Println(user)
 	}
 
 	users2 := make([]Userinfo, 0)
-	userinfo := testEngine.GetTableMapper().Obj2Table("Userinfo")
-	err = testEngine.SQL("select * from " + testEngine.Quote(userinfo)).Find(&users2)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	var tbName = testEngine.Quote(testEngine.TableName(new(Userinfo), true))
+	err = testEngine.SQL("select * from " + tbName).Find(&users2)
+	assert.NoError(t, err)
 }
 
 func TestFind2(t *testing.T) {
@@ -238,14 +232,8 @@ func TestDistinct(t *testing.T) {
 	users := make([]Userinfo, 0)
 	departname := testEngine.GetTableMapper().Obj2Table("Departname")
 	err = testEngine.Distinct(departname).Find(&users)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
-	if len(users) != 1 {
-		t.Error(err)
-		panic(errors.New("should be one record"))
-	}
+	assert.NoError(t, err)
+	assert.EqualValues(t, 1, len(users))
 
 	fmt.Println(users)
 
@@ -255,11 +243,9 @@ func TestDistinct(t *testing.T) {
 
 	users2 := make([]Depart, 0)
 	err = testEngine.Distinct(departname).Table(new(Userinfo)).Find(&users2)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 	if len(users2) != 1 {
+		fmt.Println(len(users2))
 		t.Error(err)
 		panic(errors.New("should be one record"))
 	}
@@ -272,18 +258,12 @@ func TestOrder(t *testing.T) {
 
 	users := make([]Userinfo, 0)
 	err := testEngine.OrderBy("id desc").Find(&users)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 	fmt.Println(users)
 
 	users2 := make([]Userinfo, 0)
 	err = testEngine.Asc("id", "username").Desc("height").Find(&users2)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 	fmt.Println(users2)
 }
 
@@ -293,10 +273,7 @@ func TestHaving(t *testing.T) {
 
 	users := make([]Userinfo, 0)
 	err := testEngine.GroupBy("username").Having("username='xlw'").Find(&users)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 	fmt.Println(users)
 
 	/*users = make([]Userinfo, 0)
@@ -324,18 +301,12 @@ func TestOrderSameMapper(t *testing.T) {
 
 	users := make([]Userinfo, 0)
 	err := testEngine.OrderBy("(id) desc").Find(&users)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 	fmt.Println(users)
 
 	users2 := make([]Userinfo, 0)
 	err = testEngine.Asc("(id)", "Username").Desc("Height").Find(&users2)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 	fmt.Println(users2)
 }
 

+ 1 - 1
session_get.go

@@ -31,7 +31,7 @@ func (session *Session) get(bean interface{}) (bool, error) {
 	}
 
 	if beanValue.Elem().Kind() == reflect.Struct {
-		if err := session.statement.setRefValue(beanValue.Elem()); err != nil {
+		if err := session.statement.setRefBean(bean); err != nil {
 			return false, err
 		}
 	}

+ 32 - 1
session_get_test.go

@@ -84,11 +84,16 @@ func TestGetVar(t *testing.T) {
 	assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money))
 
 	var money2 float64
-	has, err = testEngine.SQL("SELECT money FROM get_var LIMIT 1").Get(&money2)
+	has, err = testEngine.SQL("SELECT money FROM " + testEngine.TableName("get_var", true) + " LIMIT 1").Get(&money2)
 	assert.NoError(t, err)
 	assert.Equal(t, true, has)
 	assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money2))
 
+	var money3 float64
+	has, err = testEngine.SQL("SELECT money FROM " + testEngine.TableName("get_var", true) + " WHERE money > 20").Get(&money3)
+	assert.NoError(t, err)
+	assert.Equal(t, false, has)
+
 	var valuesString = make(map[string]string)
 	has, err = testEngine.Table("get_var").Get(&valuesString)
 	assert.NoError(t, err)
@@ -279,3 +284,29 @@ func TestGetActionMapping(t *testing.T) {
 		ID(1).Get(&valuesSlice)
 	assert.NoError(t, err)
 }
+
+func TestGetStructId(t *testing.T) {
+	type TestGetStruct struct {
+		Id int64
+	}
+
+	assert.NoError(t, prepareEngine())
+	assertSync(t, new(TestGetStruct))
+
+	_, err := testEngine.Insert(&TestGetStruct{})
+	assert.NoError(t, err)
+	_, err = testEngine.Insert(&TestGetStruct{})
+	assert.NoError(t, err)
+
+	type maxidst struct {
+		Id int64
+	}
+
+	//var id int64
+	var maxid maxidst
+	sql := "select max(id) as id from " + testEngine.TableName(&TestGetStruct{}, true)
+	has, err := testEngine.SQL(sql).Get(&maxid)
+	assert.NoError(t, err)
+	assert.True(t, has)
+	assert.EqualValues(t, 2, maxid.Id)
+}

+ 1 - 1
session_insert.go

@@ -298,7 +298,7 @@ func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) {
 }
 
 func (session *Session) innerInsert(bean interface{}) (int64, error) {
-	if err := session.statement.setRefValue(rValue(bean)); err != nil {
+	if err := session.statement.setRefBean(bean); err != nil {
 		return 0, err
 	}
 	if len(session.statement.TableName()) <= 0 {

+ 2 - 1
session_insert_test.go

@@ -732,8 +732,9 @@ func (MyUserinfo2) TableName() string {
 func TestInsertMulti4(t *testing.T) {
 	assert.NoError(t, prepareEngine())
 
-	testEngine.ShowSQL(true)
+	testEngine.ShowSQL(false)
 	assertSync(t, new(MyUserinfo2))
+	testEngine.ShowSQL(true)
 
 	users := []MyUserinfo2{
 		{Username: "xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()},

+ 19 - 4
session_pk_test.go

@@ -1118,13 +1118,28 @@ func TestCompositePK(t *testing.T) {
 	}
 
 	assert.NoError(t, prepareEngine())
-	assertSync(t, new(TaskSolution))
 
+	tables1, err := testEngine.DBMetas()
+	assert.NoError(t, err)
+
+	assertSync(t, new(TaskSolution))
 	assert.NoError(t, testEngine.Sync2(new(TaskSolution)))
-	tables, err := testEngine.DBMetas()
+
+	tables2, err := testEngine.DBMetas()
 	assert.NoError(t, err)
-	assert.EqualValues(t, 1, len(tables))
-	pkCols := tables[0].PKColumns()
+	assert.EqualValues(t, 1+len(tables1), len(tables2))
+
+	var table *core.Table
+	for _, t := range tables2 {
+		if t.Name == testEngine.GetTableMapper().Obj2Table("TaskSolution") {
+			table = t
+			break
+		}
+	}
+
+	assert.NotEqual(t, nil, table)
+
+	pkCols := table.PKColumns()
 	assert.EqualValues(t, 2, len(pkCols))
 	assert.EqualValues(t, "uid", pkCols[0].Name)
 	assert.EqualValues(t, "tid", pkCols[1].Name)

+ 1 - 1
session_query.go

@@ -90,7 +90,7 @@ func (session *Session) genQuerySQL(sqlorArgs ...interface{}) (string, []interfa
 	}
 
 	args := append(session.statement.joinArgs, condArgs...)
-	sqlStr, err := session.statement.genSelectSQL(columnStr, condSQL, true)
+	sqlStr, err := session.statement.genSelectSQL(columnStr, condSQL, true, true)
 	if err != nil {
 		return "", nil, err
 	}

+ 5 - 5
session_query_test.go

@@ -36,7 +36,7 @@ func TestQueryString(t *testing.T) {
 	_, err := testEngine.InsertOne(data)
 	assert.NoError(t, err)
 
-	records, err := testEngine.QueryString("select * from get_var2")
+	records, err := testEngine.QueryString("select * from " + testEngine.TableName("get_var2", true))
 	assert.NoError(t, err)
 	assert.Equal(t, 1, len(records))
 	assert.Equal(t, 5, len(records[0]))
@@ -62,7 +62,7 @@ func TestQueryString2(t *testing.T) {
 	_, err := testEngine.Insert(data)
 	assert.NoError(t, err)
 
-	records, err := testEngine.QueryString("select * from get_var3")
+	records, err := testEngine.QueryString("select * from " + testEngine.TableName("get_var3", true))
 	assert.NoError(t, err)
 	assert.Equal(t, 1, len(records))
 	assert.Equal(t, 2, len(records[0]))
@@ -127,7 +127,7 @@ func TestQueryInterface(t *testing.T) {
 	_, err := testEngine.InsertOne(data)
 	assert.NoError(t, err)
 
-	records, err := testEngine.QueryInterface("select * from get_var_interface")
+	records, err := testEngine.QueryInterface("select * from " + testEngine.TableName("get_var_interface", true))
 	assert.NoError(t, err)
 	assert.Equal(t, 1, len(records))
 	assert.Equal(t, 5, len(records[0]))
@@ -181,7 +181,7 @@ func TestQueryNoParams(t *testing.T) {
 	assert.NoError(t, err)
 	assertResult(t, results)
 
-	results, err = testEngine.SQL("select * from query_no_params").Query()
+	results, err = testEngine.SQL("select * from " + testEngine.TableName("query_no_params", true)).Query()
 	assert.NoError(t, err)
 	assertResult(t, results)
 }
@@ -226,7 +226,7 @@ func TestQueryWithBuilder(t *testing.T) {
 		assert.EqualValues(t, 3000, money)
 	}
 
-	results, err := testEngine.Query(builder.Select("*").From("query_with_builder"))
+	results, err := testEngine.Query(builder.Select("*").From(testEngine.TableName("query_with_builder", true)))
 	assert.NoError(t, err)
 	assertResult(t, results)
 }

+ 2 - 2
session_raw_test.go

@@ -21,13 +21,13 @@ func TestExecAndQuery(t *testing.T) {
 
 	assert.NoError(t, testEngine.Sync2(new(UserinfoQuery)))
 
-	res, err := testEngine.Exec("INSERT INTO `userinfo_query` (uid, name) VALUES (?, ?)", 1, "user")
+	res, err := testEngine.Exec("INSERT INTO "+testEngine.TableName("`userinfo_query`", true)+" (uid, name) VALUES (?, ?)", 1, "user")
 	assert.NoError(t, err)
 	cnt, err := res.RowsAffected()
 	assert.NoError(t, err)
 	assert.EqualValues(t, 1, cnt)
 
-	results, err := testEngine.Query("select * from userinfo_query")
+	results, err := testEngine.Query("select * from " + testEngine.TableName("userinfo_query", true))
 	assert.NoError(t, err)
 	assert.EqualValues(t, 1, len(results))
 	id, err := strconv.Atoi(string(results[0]["uid"]))

+ 30 - 51
session_schema.go

@@ -6,9 +6,7 @@ package xorm
 
 import (
 	"database/sql"
-	"errors"
 	"fmt"
-	"reflect"
 	"strings"
 
 	"github.com/xormplus/core"
@@ -34,8 +32,7 @@ func (session *Session) CreateTable(bean interface{}) error {
 }
 
 func (session *Session) createTable(bean interface{}) error {
-	v := rValue(bean)
-	if err := session.statement.setRefValue(v); err != nil {
+	if err := session.statement.setRefBean(bean); err != nil {
 		return err
 	}
 
@@ -54,8 +51,7 @@ func (session *Session) CreateIndexes(bean interface{}) error {
 }
 
 func (session *Session) createIndexes(bean interface{}) error {
-	v := rValue(bean)
-	if err := session.statement.setRefValue(v); err != nil {
+	if err := session.statement.setRefBean(bean); err != nil {
 		return err
 	}
 
@@ -78,8 +74,7 @@ func (session *Session) CreateUniques(bean interface{}) error {
 }
 
 func (session *Session) createUniques(bean interface{}) error {
-	v := rValue(bean)
-	if err := session.statement.setRefValue(v); err != nil {
+	if err := session.statement.setRefBean(bean); err != nil {
 		return err
 	}
 
@@ -103,8 +98,7 @@ func (session *Session) DropIndexes(bean interface{}) error {
 }
 
 func (session *Session) dropIndexes(bean interface{}) error {
-	v := rValue(bean)
-	if err := session.statement.setRefValue(v); err != nil {
+	if err := session.statement.setRefBean(bean); err != nil {
 		return err
 	}
 
@@ -128,11 +122,7 @@ func (session *Session) DropTable(beanOrTableName interface{}) error {
 }
 
 func (session *Session) dropTable(beanOrTableName interface{}) error {
-	tableName, err := session.engine.tableName(beanOrTableName)
-	if err != nil {
-		return err
-	}
-
+	tableName := session.engine.tbNameNoSchema(beanOrTableName)
 	var needDrop = true
 	if !session.engine.dialect.SupportDropIfExists() {
 		sqlStr, args := session.engine.dialect.TableCheckSql(tableName)
@@ -144,8 +134,8 @@ func (session *Session) dropTable(beanOrTableName interface{}) error {
 	}
 
 	if needDrop {
-		sqlStr := session.engine.Dialect().DropTableSql(tableName)
-		_, err = session.exec(sqlStr)
+		sqlStr := session.engine.Dialect().DropTableSql(session.engine.TableName(tableName, true))
+		_, err := session.exec(sqlStr)
 		return err
 	}
 	return nil
@@ -157,10 +147,7 @@ func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error)
 		defer session.Close()
 	}
 
-	tableName, err := session.engine.tableName(beanOrTableName)
-	if err != nil {
-		return false, err
-	}
+	tableName := session.engine.tbNameNoSchema(beanOrTableName)
 
 	return session.isTableExist(tableName)
 }
@@ -173,24 +160,15 @@ func (session *Session) isTableExist(tableName string) (bool, error) {
 
 // IsTableEmpty if table have any records
 func (session *Session) IsTableEmpty(bean interface{}) (bool, error) {
-	v := rValue(bean)
-	t := v.Type()
-
-	if t.Kind() == reflect.String {
-		if session.isAutoClose {
-			defer session.Close()
-		}
-		return session.isTableEmpty(bean.(string))
-	} else if t.Kind() == reflect.Struct {
-		rows, err := session.Count(bean)
-		return rows == 0, err
+	if session.isAutoClose {
+		defer session.Close()
 	}
-	return false, errors.New("bean should be a struct or struct's point")
+	return session.isTableEmpty(session.engine.tbNameNoSchema(bean))
 }
 
 func (session *Session) isTableEmpty(tableName string) (bool, error) {
 	var total int64
-	sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.Quote(tableName))
+	sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.Quote(session.engine.TableName(tableName, true)))
 	err := session.queryRow(sqlStr).Scan(&total)
 	if err != nil {
 		if err == sql.ErrNoRows {
@@ -270,7 +248,8 @@ func (session *Session) Sync2(beans ...interface{}) error {
 			return err
 		}
 		structTables = append(structTables, table)
-		var tbName = session.tbNameNoSchema(table)
+		tbName := session.tbNameNoSchema(table)
+		tbNameWithSchema := engine.TableName(tbName, true)
 
 		var oriTable *core.Table
 		for _, tb := range tables {
@@ -315,32 +294,32 @@ func (session *Session) Sync2(beans ...interface{}) error {
 							if engine.dialect.DBType() == core.MYSQL ||
 								engine.dialect.DBType() == core.POSTGRES {
 								engine.logger.Infof("Table %s column %s change type from %s to %s\n",
-									tbName, col.Name, curType, expectedType)
-								_, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col))
+									tbNameWithSchema, col.Name, curType, expectedType)
+								_, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col))
 							} else {
 								engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n",
-									tbName, col.Name, curType, expectedType)
+									tbNameWithSchema, col.Name, curType, expectedType)
 							}
 						} else if strings.HasPrefix(curType, core.Varchar) && strings.HasPrefix(expectedType, core.Varchar) {
 							if engine.dialect.DBType() == core.MYSQL {
 								if oriCol.Length < col.Length {
 									engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
-										tbName, col.Name, oriCol.Length, col.Length)
-									_, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col))
+										tbNameWithSchema, col.Name, oriCol.Length, col.Length)
+									_, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col))
 								}
 							}
 						} else {
 							if !(strings.HasPrefix(curType, expectedType) && curType[len(expectedType)] == '(') {
 								engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s",
-									tbName, col.Name, curType, expectedType)
+									tbNameWithSchema, col.Name, curType, expectedType)
 							}
 						}
 					} else if expectedType == core.Varchar {
 						if engine.dialect.DBType() == core.MYSQL {
 							if oriCol.Length < col.Length {
 								engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
-									tbName, col.Name, oriCol.Length, col.Length)
-								_, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col))
+									tbNameWithSchema, col.Name, oriCol.Length, col.Length)
+								_, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col))
 							}
 						}
 					}
@@ -354,7 +333,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
 					}
 				} else {
 					session.statement.RefTable = table
-					session.statement.tableName = tbName
+					session.statement.tableName = tbNameWithSchema
 					err = session.addColumn(col.Name)
 				}
 				if err != nil {
@@ -378,7 +357,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
 				if oriIndex != nil {
 					if oriIndex.Type != index.Type {
 
-						sql := engine.dialect.DropIndexSql(tbName, oriIndex)
+						sql := engine.dialect.DropIndexSql(tbNameWithSchema, oriIndex)
 
 						if sql != "" {
 							_, err = session.exec(sql)
@@ -398,7 +377,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
 			for name2, index2 := range oriTable.Indexes {
 				if _, ok := foundIndexNames[name2]; !ok {
 
-					sql := engine.dialect.DropIndexSql(tbName, index2)
+					sql := engine.dialect.DropIndexSql(tbNameWithSchema, index2)
 
 					if sql != "" {
 						_, err = session.exec(sql)
@@ -412,12 +391,12 @@ func (session *Session) Sync2(beans ...interface{}) error {
 			for name, index := range addedNames {
 				if index.Type == core.UniqueType {
 					session.statement.RefTable = table
-					session.statement.tableName = tbName
-					err = session.addUnique(tbName, name)
+					session.statement.tableName = tbNameWithSchema
+					err = session.addUnique(tbNameWithSchema, name)
 				} else if index.Type == core.IndexType {
 					session.statement.RefTable = table
-					session.statement.tableName = tbName
-					err = session.addIndex(tbName, name)
+					session.statement.tableName = tbNameWithSchema
+					err = session.addIndex(tbNameWithSchema, name)
 				}
 				if err != nil {
 					return err
@@ -442,7 +421,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
 
 		for _, colName := range table.ColumnsSeq() {
 			if oriTable.GetColumn(colName) == nil {
-				engine.logger.Warnf("Table %s has column %s but struct has not related field", table.Name, colName)
+				engine.logger.Warnf("Table %s has column %s but struct has not related field", engine.TableName(table.Name, true), colName)
 			}
 		}
 	}

+ 26 - 1
session_stats_test.go

@@ -153,8 +153,33 @@ func TestSQLCount(t *testing.T) {
 
 	assertSync(t, new(UserinfoCount2), new(UserinfoBooks))
 
-	total, err := testEngine.SQL("SELECT count(id) FROM userinfo_count2").
+	total, err := testEngine.SQL("SELECT count(id) FROM " + testEngine.TableName("userinfo_count2", true)).
 		Count()
 	assert.NoError(t, err)
 	assert.EqualValues(t, 0, total)
 }
+
+func TestCountWithOthers(t *testing.T) {
+	assert.NoError(t, prepareEngine())
+
+	type CountWithOthers struct {
+		Id   int64
+		Name string
+	}
+
+	assertSync(t, new(CountWithOthers))
+
+	_, err := testEngine.Insert(&CountWithOthers{
+		Name: "orderby",
+	})
+	assert.NoError(t, err)
+
+	_, err = testEngine.Insert(&CountWithOthers{
+		Name: "limit",
+	})
+	assert.NoError(t, err)
+
+	total, err := testEngine.OrderBy("id desc").Limit(1).Count(new(CountWithOthers))
+	assert.NoError(t, err)
+	assert.EqualValues(t, 2, total)
+}

+ 19 - 81
session_tx_test.go

@@ -32,45 +32,21 @@ func TestTransaction(t *testing.T) {
 	defer session.Close()
 
 	err := session.Begin()
-	if err != nil {
-		t.Error(err)
-		panic(err)
-		return
-	}
+	assert.NoError(t, err)
 
 	user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()}
 	_, err = session.Insert(&user1)
-	if err != nil {
-		session.Rollback()
-		t.Error(err)
-		panic(err)
-		return
-	}
+	assert.NoError(t, err)
 
 	user2 := Userinfo{Username: "yyy"}
 	_, err = session.Where("(id) = ?", 0).Update(&user2)
-	if err != nil {
-		session.Rollback()
-		fmt.Println(err)
-		//t.Error(err)
-		return
-	}
+	assert.NoError(t, err)
 
 	_, err = session.Delete(&user2)
-	if err != nil {
-		session.Rollback()
-		t.Error(err)
-		panic(err)
-		return
-	}
+	assert.NoError(t, err)
 
 	err = session.Commit()
-	if err != nil {
-		t.Error(err)
-		panic(err)
-		return
-	}
-	// panic(err) !nashtsai! should remove this
+	assert.NoError(t, err)
 }
 
 func TestCombineTransaction(t *testing.T) {
@@ -91,38 +67,21 @@ func TestCombineTransaction(t *testing.T) {
 	defer session.Close()
 
 	err := session.Begin()
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 
 	user1 := Userinfo{Username: "xiaoxiao2", Departname: "dev", Alias: "lunny", Created: time.Now()}
 	_, err = session.Insert(&user1)
-	if err != nil {
-		session.Rollback()
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
+
 	user2 := Userinfo{Username: "zzz"}
 	_, err = session.Where("id = ?", 0).Update(&user2)
-	if err != nil {
-		session.Rollback()
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 
-	_, err = session.Exec("delete from userinfo where username = ?", user2.Username)
-	if err != nil {
-		session.Rollback()
-		t.Error(err)
-		panic(err)
-	}
+	_, err = session.Exec("delete from "+testEngine.TableName("userinfo", true)+" where username = ?", user2.Username)
+	assert.NoError(t, err)
 
 	err = session.Commit()
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 }
 
 func TestCombineTransactionSameMapper(t *testing.T) {
@@ -148,45 +107,24 @@ func TestCombineTransactionSameMapper(t *testing.T) {
 
 	counter()
 	defer counter()
+
 	session := testEngine.NewSession()
 	defer session.Close()
 
 	err := session.Begin()
-	if err != nil {
-		t.Error(err)
-		panic(err)
-		return
-	}
+	assert.NoError(t, err)
 
 	user1 := Userinfo{Username: "xiaoxiao2", Departname: "dev", Alias: "lunny", Created: time.Now()}
 	_, err = session.Insert(&user1)
-	if err != nil {
-		session.Rollback()
-		t.Error(err)
-		panic(err)
-		return
-	}
+	assert.NoError(t, err)
 
 	user2 := Userinfo{Username: "zzz"}
 	_, err = session.Where("(id) = ?", 0).Update(&user2)
-	if err != nil {
-		session.Rollback()
-		t.Error(err)
-		panic(err)
-		return
-	}
+	assert.NoError(t, err)
 
-	_, err = session.Exec("delete from `Userinfo` where `Username` = ?", user2.Username)
-	if err != nil {
-		session.Rollback()
-		t.Error(err)
-		panic(err)
-		return
-	}
+	_, err = session.Exec("delete from  "+testEngine.TableName("`Userinfo`", true)+" where `Username` = ?", user2.Username)
+	assert.NoError(t, err)
 
 	err = session.Commit()
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 }

+ 1 - 1
session_update.go

@@ -167,7 +167,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
 	var isMap = t.Kind() == reflect.Map
 	var isStruct = t.Kind() == reflect.Struct
 	if isStruct {
-		if err := session.statement.setRefValue(v); err != nil {
+		if err := session.statement.setRefBean(bean); err != nil {
 			return 0, err
 		}
 

+ 27 - 109
session_update_test.go

@@ -462,30 +462,18 @@ func TestUpdate1(t *testing.T) {
 
 	col1 := &UpdateAllCols{Ptr: &s}
 	err = testEngine.Sync(col1)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 
 	_, err = testEngine.Insert(col1)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 
 	col2 := &UpdateAllCols{col1.Id, true, "", nil}
 	_, err = testEngine.ID(col2.Id).AllCols().Update(col2)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 
 	col3 := &UpdateAllCols{}
 	has, err = testEngine.ID(col2.Id).Get(col3)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 
 	if !has {
 		err = errors.New(fmt.Sprintf("cannot get id %d", col2.Id))
@@ -759,7 +747,7 @@ func TestUpdateUpdated(t *testing.T) {
 func TestUpdateSameMapper(t *testing.T) {
 	assert.NoError(t, prepareEngine())
 
-	oldMapper := testEngine.GetColumnMapper()
+	oldMapper := testEngine.GetTableMapper()
 	testEngine.UnMapType(rValue(new(Userinfo)).Type())
 	testEngine.UnMapType(rValue(new(Condi)).Type())
 	testEngine.UnMapType(rValue(new(Article)).Type())
@@ -786,81 +774,38 @@ func TestUpdateSameMapper(t *testing.T) {
 
 	var ori Userinfo
 	has, err := testEngine.Get(&ori)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
-	if !has {
-		t.Error(errors.New("not exist"))
-		panic(errors.New("not exist"))
-	}
+	assert.NoError(t, err)
+	assert.True(t, has)
+
 	// update by id
 	user := Userinfo{Username: "xxx", Height: 1.2}
 	cnt, err := testEngine.ID(ori.Uid).Update(&user)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
-	if cnt != 1 {
-		err = errors.New("update not returned 1")
-		t.Error(err)
-		panic(err)
-		return
-	}
+	assert.NoError(t, err)
+	assert.EqualValues(t, 1, cnt)
 
 	condi := Condi{"Username": "zzz", "Departname": ""}
 	cnt, err = testEngine.Table(&user).ID(ori.Uid).Update(&condi)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
-
-	if cnt != 1 {
-		err = errors.New("update not returned 1")
-		t.Error(err)
-		panic(err)
-		return
-	}
+	assert.NoError(t, err)
+	assert.EqualValues(t, 1, cnt)
 
 	cnt, err = testEngine.Update(&Userinfo{Username: "yyy"}, &user)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 
 	total, err := testEngine.Count(&user)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
-
-	if cnt != total {
-		err = errors.New("insert not returned 1")
-		t.Error(err)
-		panic(err)
-		return
-	}
+	assert.NoError(t, err)
+	assert.EqualValues(t, cnt, total)
 
 	err = testEngine.Sync(&Article{})
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 
 	defer func() {
 		err = testEngine.DropTables(&Article{})
-		if err != nil {
-			t.Error(err)
-			panic(err)
-		}
+		assert.NoError(t, err)
 	}()
 
 	a := &Article{0, "1", "2", "3", "4", "5", 2}
 	cnt, err = testEngine.Insert(a)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 
 	if cnt != 1 {
 		err = errors.New(fmt.Sprintf("insert not returned 1 but %d", cnt))
@@ -875,10 +820,7 @@ func TestUpdateSameMapper(t *testing.T) {
 	}
 
 	cnt, err = testEngine.ID(a.Id).Update(&Article{Name: "6"})
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 
 	if cnt != 1 {
 		err = errors.New(fmt.Sprintf("insert not returned 1 but %d", cnt))
@@ -889,30 +831,18 @@ func TestUpdateSameMapper(t *testing.T) {
 
 	col1 := &UpdateAllCols{}
 	err = testEngine.Sync(col1)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 
 	_, err = testEngine.Insert(col1)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 
 	col2 := &UpdateAllCols{col1.Id, true, "", nil}
 	_, err = testEngine.ID(col2.Id).AllCols().Update(col2)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 
 	col3 := &UpdateAllCols{}
 	has, err = testEngine.ID(col2.Id).Get(col3)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 
 	if !has {
 		err = errors.New(fmt.Sprintf("cannot get id %d", col2.Id))
@@ -931,32 +861,20 @@ func TestUpdateSameMapper(t *testing.T) {
 	{
 		col1 := &UpdateMustCols{}
 		err = testEngine.Sync(col1)
-		if err != nil {
-			t.Error(err)
-			panic(err)
-		}
+		assert.NoError(t, err)
 
 		_, err = testEngine.Insert(col1)
-		if err != nil {
-			t.Error(err)
-			panic(err)
-		}
+		assert.NoError(t, err)
 
 		col2 := &UpdateMustCols{col1.Id, true, ""}
 		boolStr := testEngine.GetColumnMapper().Obj2Table("Bool")
 		stringStr := testEngine.GetColumnMapper().Obj2Table("String")
 		_, err = testEngine.ID(col2.Id).MustCols(boolStr, stringStr).Update(col2)
-		if err != nil {
-			t.Error(err)
-			panic(err)
-		}
+		assert.NoError(t, err)
 
 		col3 := &UpdateMustCols{}
 		has, err := testEngine.ID(col2.Id).Get(col3)
-		if err != nil {
-			t.Error(err)
-			panic(err)
-		}
+		assert.NoError(t, err)
 
 		if !has {
 			err = errors.New(fmt.Sprintf("cannot get id %d", col2.Id))

+ 43 - 78
statement.go

@@ -221,26 +221,18 @@ func (statement *Statement) setRefValue(v reflect.Value) error {
 	if err != nil {
 		return err
 	}
-	statement.tableName = statement.Engine.tbName(v)
+	statement.tableName = statement.Engine.TableName(v.Interface(), true)
 	return nil
 }
 
-// Table tempororily set table name, the parameter could be a string or a pointer of struct
-func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
-	v := rValue(tableNameOrBean)
-	t := v.Type()
-	if t.Kind() == reflect.String {
-		statement.AltTableName = tableNameOrBean.(string)
-	} else if t.Kind() == reflect.Struct {
-		var err error
-		statement.RefTable, err = statement.Engine.autoMapType(v)
-		if err != nil {
-			statement.Engine.logger.Error(err)
-			return statement
-		}
-		statement.AltTableName = statement.Engine.tbName(v)
+func (statement *Statement) setRefBean(bean interface{}) error {
+	var err error
+	statement.RefTable, err = statement.Engine.autoMapType(rValue(bean))
+	if err != nil {
+		return err
 	}
-	return statement
+	statement.tableName = statement.Engine.TableName(bean, true)
+	return nil
 }
 
 // Auto generating update columnes and values according a struct
@@ -743,6 +735,23 @@ func (statement *Statement) Asc(colNames ...string) *Statement {
 	return statement
 }
 
+// Table tempororily set table name, the parameter could be a string or a pointer of struct
+func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
+	v := rValue(tableNameOrBean)
+	t := v.Type()
+	if t.Kind() == reflect.Struct {
+		var err error
+		statement.RefTable, err = statement.Engine.autoMapType(v)
+		if err != nil {
+			statement.Engine.logger.Error(err)
+			return statement
+		}
+	}
+
+	statement.AltTableName = statement.Engine.TableName(tableNameOrBean, true)
+	return statement
+}
+
 // Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
 func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement {
 	var buf bytes.Buffer
@@ -752,56 +761,9 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
 		fmt.Fprintf(&buf, "%v JOIN ", joinOP)
 	}
 
-	switch tablename.(type) {
-	case []string:
-		t := tablename.([]string)
-		if len(t) > 1 {
-			fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(t[0]), statement.Engine.Quote(t[1]))
-		} else if len(t) == 1 {
-			fmt.Fprintf(&buf, statement.Engine.Quote(t[0]))
-		}
-	case []interface{}:
-		t := tablename.([]interface{})
-		l := len(t)
-		var table string
-		if l > 0 {
-			f := t[0]
-			switch f.(type) {
-			case string:
-				table = f.(string)
-			case TableName:
-				table = f.(TableName).TableName()
-			default:
-				v := rValue(f)
-				t := v.Type()
-				if t.Kind() == reflect.Struct {
-					fmt.Fprintf(&buf, statement.Engine.tbName(v))
-				} else {
-					fmt.Fprintf(&buf, statement.Engine.Quote(fmt.Sprintf("%v", f)))
-				}
-			}
-		}
-		if l > 1 {
-			fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(table),
-				statement.Engine.Quote(fmt.Sprintf("%v", t[1])))
-		} else if l == 1 {
-			fmt.Fprintf(&buf, statement.Engine.Quote(table))
-		}
-	case TableName:
-		fmt.Fprintf(&buf, tablename.(TableName).TableName())
-	case string:
-		fmt.Fprintf(&buf, tablename.(string))
-	default:
-		v := rValue(tablename)
-		t := v.Type()
-		if t.Kind() == reflect.Struct {
-			fmt.Fprintf(&buf, statement.Engine.tbName(v))
-		} else {
-			fmt.Fprintf(&buf, statement.Engine.Quote(fmt.Sprintf("%v", tablename)))
-		}
-	}
+	tbName := statement.Engine.TableName(tablename, true)
 
-	fmt.Fprintf(&buf, " ON %v", condition)
+	fmt.Fprintf(&buf, "%s ON %v", tbName, condition)
 	statement.JoinStr = buf.String()
 	statement.joinArgs = append(statement.joinArgs, args...)
 	return statement
@@ -906,16 +868,18 @@ func (statement *Statement) genUniqueSQL() []string {
 func (statement *Statement) genDelIndexSQL() []string {
 	var sqls []string
 	tbName := statement.TableName()
+	idxPrefixName := strings.Replace(tbName, `"`, "", -1)
+	idxPrefixName = strings.Replace(idxPrefixName, `.`, "_", -1)
 	for idxName, index := range statement.RefTable.Indexes {
 		var rIdxName string
 		if index.Type == core.UniqueType {
-			rIdxName = uniqueName(tbName, idxName)
+			rIdxName = uniqueName(idxPrefixName, idxName)
 		} else if index.Type == core.IndexType {
-			rIdxName = indexName(tbName, idxName)
+			rIdxName = indexName(idxPrefixName, idxName)
 		}
-		sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(rIdxName))
+		sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(statement.Engine.TableName(rIdxName, true)))
 		if statement.Engine.dialect.IndexOnTable() {
-			sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(statement.TableName()))
+			sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(tbName))
 		}
 		sqls = append(sqls, sql)
 	}
@@ -966,7 +930,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
 	v := rValue(bean)
 	isStruct := v.Kind() == reflect.Struct
 	if isStruct {
-		statement.setRefValue(v)
+		statement.setRefBean(bean)
 	}
 
 	var columnStr = statement.ColumnStr
@@ -1005,7 +969,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
 		return "", nil, err
 	}
 
-	sqlStr, err := statement.genSelectSQL(columnStr, condSQL, true)
+	sqlStr, err := statement.genSelectSQL(columnStr, condSQL, true, true)
 	if err != nil {
 		return "", nil, err
 	}
@@ -1018,7 +982,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa
 	var condArgs []interface{}
 	var err error
 	if len(beans) > 0 {
-		statement.setRefValue(rValue(beans[0]))
+		statement.setRefBean(beans[0])
 		condSQL, condArgs, err = statement.genConds(beans[0])
 	} else {
 		condSQL, condArgs, err = builder.ToSQL(statement.cond)
@@ -1035,7 +999,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa
 			selectSQL = "count(*)"
 		}
 	}
-	sqlStr, err := statement.genSelectSQL(selectSQL, condSQL, false)
+	sqlStr, err := statement.genSelectSQL(selectSQL, condSQL, false, false)
 	if err != nil {
 		return "", nil, err
 	}
@@ -1044,7 +1008,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa
 }
 
 func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) {
-	statement.setRefValue(rValue(bean))
+	statement.setRefBean(bean)
 
 	var sumStrs = make([]string, 0, len(columns))
 	for _, colName := range columns {
@@ -1060,7 +1024,7 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri
 		return "", nil, err
 	}
 
-	sqlStr, err := statement.genSelectSQL(sumSelect, condSQL, true)
+	sqlStr, err := statement.genSelectSQL(sumSelect, condSQL, true, true)
 	if err != nil {
 		return "", nil, err
 	}
@@ -1068,7 +1032,7 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri
 	return sqlStr, append(statement.joinArgs, condArgs...), nil
 }
 
-func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit bool) (a string, err error) {
+func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, needOrderBy bool) (a string, err error) {
 	var distinct string
 	if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
 		distinct = "DISTINCT "
@@ -1135,9 +1099,10 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit bo
 			}
 
 			var orderStr string
-			if len(statement.OrderStr) > 0 {
+			if needOrderBy && len(statement.OrderStr) > 0 {
 				orderStr = " ORDER BY " + statement.OrderStr
 			}
+
 			var groupStr string
 			if len(statement.GroupByStr) > 0 {
 				groupStr = " GROUP BY " + statement.GroupByStr
@@ -1163,7 +1128,7 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit bo
 	if statement.HavingStr != "" {
 		a = fmt.Sprintf("%v %v", a, statement.HavingStr)
 	}
-	if statement.OrderStr != "" {
+	if needOrderBy && statement.OrderStr != "" {
 		a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr)
 	}
 	if needLimit {

+ 27 - 36
tag_extends_test.go

@@ -202,17 +202,14 @@ func TestExtends(t *testing.T) {
 
 	var info UserAndDetail
 	qt := testEngine.Quote
-	ui := testEngine.GetTableMapper().Obj2Table("Userinfo")
-	ud := testEngine.GetTableMapper().Obj2Table("Userdetail")
-	uiid := testEngine.GetTableMapper().Obj2Table("Id")
+	ui := testEngine.TableName(new(Userinfo), true)
+	ud := testEngine.TableName(&detail, true)
+	uiid := testEngine.GetColumnMapper().Obj2Table("Id")
 	udid := "detail_id"
 	sql := fmt.Sprintf("select * from %s, %s where %s.%s = %s.%s",
 		qt(ui), qt(ud), qt(ui), qt(udid), qt(ud), qt(uiid))
 	b, err := testEngine.SQL(sql).NoCascade().Get(&info)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 	if !b {
 		err = errors.New("should has lest one record")
 		t.Error(err)
@@ -341,19 +338,17 @@ func TestExtends2(t *testing.T) {
 	}
 
 	var mapper = testEngine.GetTableMapper().Obj2Table
-	userTableName := mapper("MessageUser")
-	typeTableName := mapper("MessageType")
-	msgTableName := mapper("Message")
+	var quote = testEngine.Quote
+	userTableName := quote(testEngine.TableName(mapper("MessageUser"), true))
+	typeTableName := quote(testEngine.TableName(mapper("MessageType"), true))
+	msgTableName := quote(testEngine.TableName(mapper("Message"), true))
 
 	list := make([]Message, 0)
-	err = testEngine.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Uid")+"`").
-		Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("ToUid")+"`").
-		Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Id")+"`").
+	err = testEngine.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`").
+		Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("ToUid")+"`").
+		Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Id")+"`").
 		Find(&list)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 
 	if len(list) != 1 {
 		err = errors.New(fmt.Sprintln("should have 1 message, got", len(list)))
@@ -406,25 +401,20 @@ func TestExtends3(t *testing.T) {
 		assert.NoError(t, err)
 	}
 	_, err = testEngine.Insert(&msg)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 
 	var mapper = testEngine.GetTableMapper().Obj2Table
-	userTableName := mapper("MessageUser")
-	typeTableName := mapper("MessageType")
-	msgTableName := mapper("Message")
+	var quote = testEngine.Quote
+	userTableName := quote(testEngine.TableName(mapper("MessageUser"), true))
+	typeTableName := quote(testEngine.TableName(mapper("MessageType"), true))
+	msgTableName := quote(testEngine.TableName(mapper("Message"), true))
 
 	list := make([]MessageExtend3, 0)
-	err = testEngine.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Uid")+"`").
-		Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("ToUid")+"`").
-		Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Id")+"`").
+	err = testEngine.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`").
+		Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("ToUid")+"`").
+		Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Id")+"`").
 		Find(&list)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
 
 	if len(list) != 1 {
 		err = errors.New(fmt.Sprintln("should have 1 message, got", len(list)))
@@ -499,13 +489,14 @@ func TestExtends4(t *testing.T) {
 	}
 
 	var mapper = testEngine.GetTableMapper().Obj2Table
-	userTableName := mapper("MessageUser")
-	typeTableName := mapper("MessageType")
-	msgTableName := mapper("Message")
+	var quote = testEngine.Quote
+	userTableName := quote(testEngine.TableName(mapper("MessageUser"), true))
+	typeTableName := quote(testEngine.TableName(mapper("MessageType"), true))
+	msgTableName := quote(testEngine.TableName(mapper("Message"), true))
 
 	list := make([]MessageExtend4, 0)
-	err = testEngine.Table(msgTableName).Join("LEFT", userTableName, "`"+userTableName+"`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Uid")+"`").
-		Join("LEFT", typeTableName, "`"+typeTableName+"`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Id")+"`").
+	err = testEngine.Table(msgTableName).Join("LEFT", userTableName, userTableName+".`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`").
+		Join("LEFT", typeTableName, typeTableName+".`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Id")+"`").
 		Find(&list)
 	if err != nil {
 		t.Error(err)

+ 1 - 1
test/xorm_test.go

@@ -635,7 +635,7 @@ func Test_SqlTemplateClient_Search_Json(t *testing.T) {
 
 func Test_Query(t *testing.T) {
 
-	result, err := db.Query("select * from category where id =25")
+	result, err := db.QueryInterface("select * from category where id =25")
 	if err != nil {
 		t.Fatal(err)
 	}

+ 4 - 3
types_test.go

@@ -301,10 +301,11 @@ type UserCus struct {
 func TestCustomType2(t *testing.T) {
 	assert.NoError(t, prepareEngine())
 
-	err := testEngine.CreateTables(&UserCus{})
+	var uc UserCus
+	err := testEngine.CreateTables(&uc)
 	assert.NoError(t, err)
 
-	tableName := testEngine.TableMapper.Obj2Table("UserCus")
+	tableName := testEngine.TableName(&uc, true)
 	_, err = testEngine.Exec("delete from " + testEngine.Quote(tableName))
 	assert.NoError(t, err)
 
@@ -327,7 +328,7 @@ func TestCustomType2(t *testing.T) {
 	fmt.Println(user)
 
 	users := make([]UserCus, 0)
-	err = testEngine.Where("`"+testEngine.ColumnMapper.Obj2Table("Status")+"` = ?", "Registed").Find(&users)
+	err = testEngine.Where("`"+testEngine.GetColumnMapper().Obj2Table("Status")+"` = ?", "Registed").Find(&users)
 	assert.NoError(t, err)
 	assert.EqualValues(t, 1, len(users))
 

+ 4 - 1
xorm_test.go

@@ -27,6 +27,7 @@ var (
 	cache      = flag.Bool("cache", false, "if enable cache")
 	cluster    = flag.Bool("cluster", false, "if this is a cluster")
 	splitter   = flag.String("splitter", ";", "the splitter on connstr for cluster")
+	schema     = flag.String("schema", "", "specify the schema")
 )
 
 func createEngine(dbType, connStr string) error {
@@ -35,7 +36,6 @@ func createEngine(dbType, connStr string) error {
 
 		if !*cluster {
 			testEngine, err = NewEngine(dbType, connStr)
-
 		} else {
 			testEngine, err = NewEngineGroup(dbType, strings.Split(connStr, *splitter))
 		}
@@ -43,6 +43,9 @@ func createEngine(dbType, connStr string) error {
 			return err
 		}
 
+		if *schema != "" {
+			testEngine.SetSchema(*schema)
+		}
 		testEngine.ShowSQL(*showSQL)
 		testEngine.SetLogLevel(core.LOG_DEBUG)
 		if *cache {