Browse Source

bug fixed

xormplus 9 years ago
parent
commit
9779f2cd44
9 changed files with 242 additions and 210 deletions
  1. 59 67
      engine.go
  2. 15 0
      helpers.go
  3. 7 2
      rows.go
  4. 117 100
      session.go
  5. 4 2
      sessionplus.go
  6. 30 35
      statement.go
  7. 8 2
      test/sql/oracle/studygolang.xml
  8. 1 1
      test/xorm_test.go
  9. 1 1
      xorm.go

+ 59 - 67
engine.go

@@ -289,46 +289,6 @@ func (engine *Engine) logSQLExecutionTime(sqlStr string, args []interface{}, exe
 	}
 }
 
-// LogError logging error
-/*func (engine *Engine) LogError(contents ...interface{}) {
-	engine.logger.Err(contents...)
-}
-
-// LogErrorf logging errorf
-func (engine *Engine) LogErrorf(format string, contents ...interface{}) {
-	engine.logger.Errf(format, contents...)
-}
-
-// LogInfo logging info
-func (engine *Engine) LogInfo(contents ...interface{}) {
-	engine.logger.Info(contents...)
-}
-
-// LogInfof logging infof
-func (engine *Engine) LogInfof(format string, contents ...interface{}) {
-	engine.logger.Infof(format, contents...)
-}
-
-// LogDebug logging debug
-func (engine *Engine) LogDebug(contents ...interface{}) {
-	engine.logger.Debug(contents...)
-}
-
-// LogDebugf logging debugf
-func (engine *Engine) LogDebugf(format string, contents ...interface{}) {
-	engine.logger.Debugf(format, contents...)
-}
-
-// LogWarn logging warn
-func (engine *Engine) LogWarn(contents ...interface{}) {
-	engine.logger.Warning(contents...)
-}
-
-// LogWarnf logging warnf
-func (engine *Engine) LogWarnf(format string, contents ...interface{}) {
-	engine.logger.Warningf(format, contents...)
-}*/
-
 // Sql method let's you manualy write raw sql and operate
 // For example:
 //
@@ -425,8 +385,26 @@ func (engine *Engine) DumpTables(tables []*core.Table, w io.Writer, tp ...core.D
 	return engine.dumpTables(tables, w, tp...)
 }
 
-func (engine *Engine) tbName(tb *core.Table) string {
-	return tb.Name
+func (engine *Engine) tableName(beanOrTableName interface{}) (string, error) {
+	v := rValue(beanOrTableName)
+	if v.Type().Kind() == reflect.String {
+		return beanOrTableName.(string), nil
+	} else if v.Type().Kind() == reflect.Struct {
+		return engine.tbName(v), nil
+	}
+	return "", errors.New("bean should be a struct or struct's point")
+}
+
+func (engine *Engine) tbName(v reflect.Value) string {
+	if tb, ok := v.Interface().(TableName); ok {
+		return tb.TableName()
+	}
+	if v.CanAddr() {
+		if tb, ok := v.Addr().Interface().(TableName); ok {
+			return tb.TableName()
+		}
+	}
+	return engine.TableMapper.Obj2Table(v.Type().Name())
 }
 
 // DumpAll dump database all table structs and data to w with specify db type
@@ -465,16 +443,17 @@ func (engine *Engine) dumpAll(w io.Writer, tp ...core.DbType) error {
 			return err
 		}
 		for _, index := range table.Indexes {
-			_, err = io.WriteString(w, dialect.CreateIndexSql(engine.tbName(table), index)+";\n")
+			_, err = io.WriteString(w, dialect.CreateIndexSql(table.Name, index)+";\n")
 			if err != nil {
 				return err
 			}
 		}
 
-		rows, err := engine.DB().Query("SELECT * FROM " + engine.Quote(engine.tbName(table)))
+		rows, err := engine.DB().Query("SELECT * FROM " + engine.Quote(table.Name))
 		if err != nil {
 			return err
 		}
+		defer rows.Close()
 
 		cols, err := rows.Columns()
 		if err != nil {
@@ -490,7 +469,7 @@ func (engine *Engine) dumpAll(w io.Writer, tp ...core.DbType) error {
 				return err
 			}
 
-			_, err = io.WriteString(w, "INSERT INTO "+dialect.Quote(engine.tbName(table))+" ("+dialect.Quote(strings.Join(cols, dialect.Quote(", ")))+") VALUES (")
+			_, err = io.WriteString(w, "INSERT INTO "+dialect.Quote(table.Name)+" ("+dialect.Quote(strings.Join(cols, dialect.Quote(", ")))+") VALUES (")
 			if err != nil {
 				return err
 			}
@@ -565,16 +544,17 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D
 			return err
 		}
 		for _, index := range table.Indexes {
-			_, err = io.WriteString(w, dialect.CreateIndexSql(engine.tbName(table), index)+";\n")
+			_, err = io.WriteString(w, dialect.CreateIndexSql(table.Name, index)+";\n")
 			if err != nil {
 				return err
 			}
 		}
 
-		rows, err := engine.DB().Query("SELECT * FROM " + engine.Quote(engine.tbName(table)))
+		rows, err := engine.DB().Query("SELECT * FROM " + engine.Quote(table.Name))
 		if err != nil {
 			return err
 		}
+		defer rows.Close()
 
 		cols, err := rows.Columns()
 		if err != nil {
@@ -590,7 +570,7 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D
 				return err
 			}
 
-			_, err = io.WriteString(w, "INSERT INTO "+dialect.Quote(engine.tbName(table))+" ("+dialect.Quote(strings.Join(cols, dialect.Quote(", ")))+") VALUES (")
+			_, err = io.WriteString(w, "INSERT INTO "+dialect.Quote(table.Name)+" ("+dialect.Quote(strings.Join(cols, dialect.Quote(", ")))+") VALUES (")
 			if err != nil {
 				return err
 			}
@@ -872,9 +852,14 @@ func (engine *Engine) GobRegister(v interface{}) *Engine {
 	return engine
 }
 
-func (engine *Engine) TableInfo(bean interface{}) *core.Table {
+type Table struct {
+	*core.Table
+	Name string
+}
+
+func (engine *Engine) TableInfo(bean interface{}) *Table {
 	v := rValue(bean)
-	return engine.autoMapType(v)
+	return &Table{engine.autoMapType(v), engine.tbName(v)}
 }
 
 func addIndex(indexName string, table *core.Table, col *core.Column, indexType int) {
@@ -1251,18 +1236,20 @@ func (engine *Engine) getCacher(v reflect.Value) core.Cacher {
 
 // If enabled cache, clear the cache bean
 func (engine *Engine) ClearCacheBean(bean interface{}, id string) error {
-	t := rType(bean)
+	v := rValue(bean)
+	t := v.Type()
 	if t.Kind() != reflect.Struct {
 		return errors.New("error params")
 	}
-	table := engine.TableInfo(bean)
+	tableName := engine.tbName(v)
+	table := engine.autoMapType(v)
 	cacher := table.Cacher
 	if cacher == nil {
 		cacher = engine.Cacher
 	}
 	if cacher != nil {
-		cacher.ClearIds(table.Name)
-		cacher.DelBean(table.Name, id)
+		cacher.ClearIds(tableName)
+		cacher.DelBean(tableName, id)
 	}
 	return nil
 }
@@ -1270,18 +1257,20 @@ func (engine *Engine) ClearCacheBean(bean interface{}, id string) error {
 // If enabled cache, clear some tables' cache
 func (engine *Engine) ClearCache(beans ...interface{}) error {
 	for _, bean := range beans {
-		t := rType(bean)
+		v := rValue(bean)
+		t := v.Type()
 		if t.Kind() != reflect.Struct {
 			return errors.New("error params")
 		}
-		table := engine.TableInfo(bean)
+		tableName := engine.tbName(v)
+		table := engine.autoMapType(v)
 		cacher := table.Cacher
 		if cacher == nil {
 			cacher = engine.Cacher
 		}
 		if cacher != nil {
-			cacher.ClearIds(table.Name)
-			cacher.ClearBeans(table.Name)
+			cacher.ClearIds(tableName)
+			cacher.ClearBeans(tableName)
 		}
 	}
 	return nil
@@ -1292,11 +1281,13 @@ func (engine *Engine) ClearCache(beans ...interface{}) error {
 // If you change some field, you should change the database manually.
 func (engine *Engine) Sync(beans ...interface{}) error {
 	for _, bean := range beans {
-		table := engine.TableInfo(bean)
+		v := rValue(bean)
+		tableName := engine.tbName(v)
+		table := engine.autoMapType(v)
 
 		s := engine.NewSession()
 		defer s.Close()
-		isExist, err := s.Table(bean).isTableExist(table.Name)
+		isExist, err := s.Table(bean).isTableExist(tableName)
 		if err != nil {
 			return err
 		}
@@ -1310,7 +1301,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
 		  if err != nil {
 		      return err
 		  }*/
-		var isEmpty bool = false
+		var isEmpty bool
 		if isEmpty {
 			err = engine.DropTables(bean)
 			if err != nil {
@@ -1325,7 +1316,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
 				session := engine.NewSession()
 				session.Statement.RefTable = table
 				defer session.Close()
-				isExist, err := session.Engine.dialect.IsColumnExist(table.Name, col.Name)
+				isExist, err := session.Engine.dialect.IsColumnExist(tableName, col.Name)
 				if err != nil {
 					return err
 				}
@@ -1346,7 +1337,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
 				defer session.Close()
 				if index.Type == core.UniqueType {
 					//isExist, err := session.isIndexExist(table.Name, name, true)
-					isExist, err := session.isIndexExist2(table.Name, index.Cols, true)
+					isExist, err := session.isIndexExist2(tableName, index.Cols, true)
 					if err != nil {
 						return err
 					}
@@ -1354,13 +1345,13 @@ func (engine *Engine) Sync(beans ...interface{}) error {
 						session := engine.NewSession()
 						session.Statement.RefTable = table
 						defer session.Close()
-						err = session.addUnique(engine.tbName(table), name)
+						err = session.addUnique(tableName, name)
 						if err != nil {
 							return err
 						}
 					}
 				} else if index.Type == core.IndexType {
-					isExist, err := session.isIndexExist2(table.Name, index.Cols, false)
+					isExist, err := session.isIndexExist2(tableName, index.Cols, false)
 					if err != nil {
 						return err
 					}
@@ -1368,7 +1359,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
 						session := engine.NewSession()
 						session.Statement.RefTable = table
 						defer session.Close()
-						err = session.addIndex(engine.tbName(table), name)
+						err = session.addIndex(tableName, name)
 						if err != nil {
 							return err
 						}
@@ -1420,8 +1411,9 @@ func (engine *Engine) dropAll() error {
 // CreateTables create tabls according bean
 func (engine *Engine) CreateTables(beans ...interface{}) error {
 	session := engine.NewSession()
-	err := session.Begin()
 	defer session.Close()
+
+	err := session.Begin()
 	if err != nil {
 		return err
 	}

+ 15 - 0
helpers.go

@@ -457,6 +457,21 @@ func query2(db *core.DB, sqlStr string, params ...interface{}) (resultsSlice []m
 	return rows2Strings(rows)
 }
 
+func setColumnInt(bean interface{}, col *core.Column, t int64) {
+	v, err := col.ValueOf(bean)
+	if err != nil {
+		return
+	}
+	if v.CanSet() {
+		switch v.Type().Kind() {
+		case reflect.Int, reflect.Int64, reflect.Int32:
+			v.SetInt(t)
+		case reflect.Uint, reflect.Uint64, reflect.Uint32:
+			v.SetUint(uint64(t))
+		}
+	}
+}
+
 func setColumnTime(bean interface{}, col *core.Column, t time.Time) {
 	v, err := col.ValueOf(bean)
 	if err != nil {

+ 7 - 2
rows.go

@@ -29,11 +29,16 @@ func newRows(session *Session, bean interface{}) (*Rows, error) {
 	rows.session = session
 	rows.beanType = reflect.Indirect(reflect.ValueOf(bean)).Type()
 
-	defer rows.session.Statement.Init()
+	defer rows.session.resetStatement()
 
 	var sqlStr string
 	var args []interface{}
-	rows.session.Statement.RefTable = rows.session.Engine.TableInfo(bean)
+
+	rows.session.Statement.setRefValue(rValue(bean))
+	if len(session.Statement.TableName()) <= 0 {
+		return nil, ErrTableNotFound
+	}
+
 	if rows.session.Statement.RawSQL == "" {
 		sqlStr, args = rows.session.Statement.genGetSql(bean)
 	} else {

+ 117 - 100
session.go

@@ -446,7 +446,9 @@ func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]b
 	}
 
 	var col *core.Column
-	table := session.Engine.autoMapType(dataStruct)
+	session.Statement.setRefValue(dataStruct)
+	table := session.Statement.RefTable
+	tableName := session.Statement.tableName
 
 	for key, data := range objMap {
 		if col = table.GetColumn(key); col == nil {
@@ -470,7 +472,7 @@ func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]b
 			fieldValue = dataStruct.FieldByName(fieldName)
 		}
 		if !fieldValue.IsValid() || !fieldValue.CanSet() {
-			session.Engine.logger.Warnf("table %v's column %v is not valid or cannot set", table.Name, key)
+			session.Engine.logger.Warnf("table %v's column %v is not valid or cannot set", tableName, key)
 			continue
 		}
 
@@ -536,7 +538,7 @@ func (session *Session) Exec(sqlStr string, args ...interface{}) (sql.Result, er
 // CreateTable create a table according a bean
 func (session *Session) CreateTable(bean interface{}) error {
 	v := rValue(bean)
-	session.Statement.RefTable = session.Engine.mapType(v)
+	session.Statement.setRefValue(v)
 
 	defer session.resetStatement()
 	if session.IsAutoClose {
@@ -549,7 +551,7 @@ func (session *Session) CreateTable(bean interface{}) error {
 // CreateIndexes create indexes
 func (session *Session) CreateIndexes(bean interface{}) error {
 	v := rValue(bean)
-	session.Statement.RefTable = session.Engine.mapType(v)
+	session.Statement.setRefValue(v)
 
 	defer session.resetStatement()
 	if session.IsAutoClose {
@@ -569,7 +571,7 @@ func (session *Session) CreateIndexes(bean interface{}) error {
 // CreateUniques create uniques
 func (session *Session) CreateUniques(bean interface{}) error {
 	v := rValue(bean)
-	session.Statement.RefTable = session.Engine.mapType(v)
+	session.Statement.setRefValue(v)
 
 	defer session.resetStatement()
 	if session.IsAutoClose {
@@ -594,14 +596,15 @@ func (session *Session) createOneTable() error {
 
 // to be deleted
 func (session *Session) createAll() error {
-	defer session.resetStatement()
 	if session.IsAutoClose {
 		defer session.Close()
 	}
 
 	for _, table := range session.Engine.Tables {
 		session.Statement.RefTable = table
+		session.Statement.tableName = table.Name
 		err := session.createOneTable()
+		session.resetStatement()
 		if err != nil {
 			return err
 		}
@@ -611,6 +614,9 @@ func (session *Session) createAll() error {
 
 // drop indexes
 func (session *Session) DropIndexes(bean interface{}) error {
+	v := rValue(bean)
+	session.Statement.setRefValue(v)
+
 	defer session.resetStatement()
 	if session.IsAutoClose {
 		defer session.Close()
@@ -777,9 +783,11 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
 		return ErrCacheFailed
 	}
 
+	tableName := session.Statement.TableName()
+
 	table := session.Statement.RefTable
 	cacher := session.Engine.getCacher2(table)
-	ids, err := core.GetCacheSql(cacher, session.Statement.TableName(), newsql, args)
+	ids, err := core.GetCacheSql(cacher, tableName, newsql, args)
 	if err != nil {
 		rows, err := session.DB().Query(newsql, args...)
 		if err != nil {
@@ -819,8 +827,6 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
 			ids = append(ids, pk)
 		}
 
-		tableName := session.Statement.TableName()
-
 		session.Engine.logger.Debug("[cacheFind] cache sql:", ids, tableName, newsql, args)
 		err = core.PutCacheSql(cacher, ids, tableName, newsql, args)
 		if err != nil {
@@ -835,7 +841,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
 	ididxes := make(map[string]int)
 	var ides []core.PK = make([]core.PK, 0)
 	var temps []interface{} = make([]interface{}, len(ids))
-	tableName := session.Statement.TableName()
+
 	for idx, id := range ids {
 		sid, err := id.ToString()
 		if err != nil {
@@ -1009,14 +1015,16 @@ func (session *Session) Get(bean interface{}) (bool, error) {
 		defer session.Close()
 	}
 
+	session.Statement.setRefValue(rValue(bean))
+	if len(session.Statement.TableName()) <= 0 {
+		return false, ErrTableNotFound
+	}
+
 	session.Statement.Limit(1)
+
 	var sqlStr string
 	var args []interface{}
 
-	if session.Statement.RefTable == nil {
-		session.Statement.RefTable = session.Engine.TableInfo(bean)
-	}
-
 	if session.Statement.RawSQL == "" {
 		sqlStr, args = session.Statement.genGetSql(bean)
 	} else {
@@ -1201,26 +1209,29 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
 	}
 
 	sliceElementType := sliceValue.Type().Elem()
-	var table *core.Table
+
 	if session.Statement.RefTable == nil {
 		if sliceElementType.Kind() == reflect.Ptr {
 			if sliceElementType.Elem().Kind() == reflect.Struct {
 				pv := reflect.New(sliceElementType.Elem())
-				table = session.Engine.autoMapType(pv.Elem())
+				session.Statement.setRefValue(pv.Elem())
 			} else {
 				return errors.New("slice type")
 			}
 		} else if sliceElementType.Kind() == reflect.Struct {
 			pv := reflect.New(sliceElementType)
-			table = session.Engine.autoMapType(pv.Elem())
+			session.Statement.setRefValue(pv.Elem())
 		} else {
 			return errors.New("slice type")
 		}
-		session.Statement.RefTable = table
-	} else {
-		table = session.Statement.RefTable
 	}
 
+	if len(session.Statement.TableName()) <= 0 {
+		return ErrTableNotFound
+	}
+
+	var table = session.Statement.RefTable
+
 	var addedTableName = (len(session.Statement.JoinStr) > 0)
 	if !session.Statement.noAutoCondition && len(condiBean) > 0 {
 		colNames, args := session.Statement.buildConditions(table, condiBean[0], true, true, false, true, addedTableName)
@@ -1433,17 +1444,6 @@ func (session *Session) Ping() error {
 	return session.DB().Ping()
 }
 
-func (engine *Engine) tableName(beanOrTableName interface{}) (string, error) {
-	v := rValue(beanOrTableName)
-	if v.Type().Kind() == reflect.String {
-		return beanOrTableName.(string), nil
-	} else if v.Type().Kind() == reflect.Struct {
-		table := engine.autoMapType(v)
-		return table.Name, nil
-	}
-	return "", errors.New("bean should be a struct or struct's point")
-}
-
 // IsTableExist if a table is exist
 func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error) {
 	tableName, err := session.Engine.tableName(beanOrTableName)
@@ -1472,7 +1472,6 @@ func (session *Session) IsTableEmpty(bean interface{}) (bool, error) {
 	if t.Kind() == reflect.String {
 		return session.isTableEmpty(bean.(string))
 	} else if t.Kind() == reflect.Struct {
-		session.Engine.autoMapType(v)
 		rows, err := session.Count(bean)
 		return rows == 0, err
 	}
@@ -1635,8 +1634,9 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i
 		return errors.New("Expected a pointer to a struct")
 	}
 
-	table := session.Engine.autoMapType(dataStruct)
-	return session._row2Bean(rows, fields, fieldsCount, bean, &dataStruct, table)
+	session.Statement.setRefValue(dataStruct)
+
+	return session._row2Bean(rows, fields, fieldsCount, bean, &dataStruct, session.Statement.RefTable)
 }
 
 func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount int, bean interface{}, dataStruct *reflect.Value, table *core.Table) error {
@@ -1909,6 +1909,7 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount
 				} else if session.Statement.UseCascade {
 					table := session.Engine.autoMapType(*fieldValue)
 					if table != nil {
+						hasAssigned = true
 						if len(table.PrimaryKeys) != 1 {
 							panic("unsupported non or composited primary key cascade")
 						}
@@ -2198,7 +2199,7 @@ func (session *Session) query2(sqlStr string, paramStr ...interface{}) (resultsS
 func (session *Session) Insert(beans ...interface{}) (int64, error) {
 	var affected int64
 	var err error
-	defer session.resetStatement()
+
 	if session.IsAutoClose {
 		defer session.Close()
 	}
@@ -2210,6 +2211,7 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) {
 			if size > 0 {
 				if session.Engine.SupportInsertMany() {
 					cnt, err := session.innerInsertMulti(bean)
+					session.resetStatement()
 					if err != nil {
 						return affected, err
 					}
@@ -2217,6 +2219,7 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) {
 				} else {
 					for i := 0; i < size; i++ {
 						cnt, err := session.innerInsert(sliceValue.Index(i).Interface())
+						session.resetStatement()
 						if err != nil {
 							return affected, err
 						}
@@ -2226,6 +2229,7 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) {
 			}
 		} else {
 			cnt, err := session.innerInsert(bean)
+			session.resetStatement()
 			if err != nil {
 				return affected, err
 			}
@@ -2244,23 +2248,24 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
 
 	bean := sliceValue.Index(0).Interface()
 	elementValue := rValue(bean)
-	//sliceElementType := elementValue.Type()
-
-	table := session.Engine.autoMapType(elementValue)
-	session.Statement.RefTable = table
+	session.Statement.setRefValue(elementValue)
+	if len(session.Statement.TableName()) <= 0 {
+		return 0, ErrTableNotFound
+	}
 
+	table := session.Statement.RefTable
 	size := sliceValue.Len()
 
-	colNames := make([]string, 0)
-	colMultiPlaces := make([]string, 0)
-	var args = make([]interface{}, 0)
-	cols := make([]*core.Column, 0)
+	var colNames []string
+	var colMultiPlaces []string
+	var args []interface{}
+	var cols []*core.Column
 
 	for i := 0; i < size; i++ {
 		v := sliceValue.Index(i)
 		vv := reflect.Indirect(v)
 		elemValue := v.Interface()
-		colPlaces := make([]string, 0)
+		var colPlaces []string
 
 		// handle BeforeInsertProcessor
 		// !nashtsai! does user expect it's same slice to passed closure when using Before()/After() when insert multi??
@@ -2308,6 +2313,13 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
 						col := table.GetColumn(colName)
 						setColumnTime(bean, col, t)
 					})
+				} else if col.IsVersion && session.Statement.checkVersion {
+					args = append(args, 1)
+					var colName = col.Name
+					session.afterClosures = append(session.afterClosures, func(bean interface{}) {
+						col := table.GetColumn(colName)
+						setColumnInt(bean, col, 1)
+					})
 				} else {
 					arg, err := session.value2Interface(col, fieldValue)
 					if err != nil {
@@ -2356,6 +2368,13 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
 						col := table.GetColumn(colName)
 						setColumnTime(bean, col, t)
 					})
+				} else if col.IsVersion && session.Statement.checkVersion {
+					args = append(args, 1)
+					var colName = col.Name
+					session.afterClosures = append(session.afterClosures, func(bean interface{}) {
+						col := table.GetColumn(colName)
+						setColumnInt(bean, col, 1)
+					})
 				} else {
 					arg, err := session.value2Interface(col, fieldValue)
 					if err != nil {
@@ -2416,24 +2435,29 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
 			}
 		}
 	}
+
 	cleanupProcessorsClosures(&session.afterClosures)
 	return res.RowsAffected()
 }
 
 // InsertMulti insert multiple records
 func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) {
+	defer session.resetStatement()
+	if session.IsAutoClose {
+		defer session.Close()
+	}
+
 	sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
-	if sliceValue.Kind() == reflect.Slice {
-		if sliceValue.Len() > 0 {
-			defer session.resetStatement()
-			if session.IsAutoClose {
-				defer session.Close()
-			}
-			return session.innerInsertMulti(rowsSlicePtr)
-		}
+	if sliceValue.Kind() != reflect.Slice {
+		return 0, ErrParamsType
+
+	}
+
+	if sliceValue.Len() <= 0 {
 		return 0, nil
 	}
-	return 0, ErrParamsType
+
+	return session.innerInsertMulti(rowsSlicePtr)
 }
 
 func (session *Session) str2Time(col *core.Column, data string) (outTime time.Time, outErr error) {
@@ -3084,8 +3108,12 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val
 }
 
 func (session *Session) innerInsert(bean interface{}) (int64, error) {
-	table := session.Engine.TableInfo(bean)
-	session.Statement.RefTable = table
+	session.Statement.setRefValue(rValue(bean))
+	if len(session.Statement.TableName()) <= 0 {
+		return 0, ErrTableNotFound
+	}
+
+	table := session.Statement.RefTable
 
 	// handle BeforeInsertProcessor
 	for _, closure := range session.beforeClosures {
@@ -3097,7 +3125,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
 		processor.BeforeInsert()
 	}
 	// --
-	colNames, args, err := genCols(table, session, bean, false, false)
+	colNames, args, err := genCols(session.Statement.RefTable, session, bean, false, false)
 	if err != nil {
 		return 0, err
 	}
@@ -3460,11 +3488,11 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
 		defer session.Close()
 	}
 
-	t := rType(bean)
+	v := rValue(bean)
+	t := v.Type()
 
 	var colNames []string
 	var args []interface{}
-	var table *core.Table
 
 	// handle before update processors
 	for _, closure := range session.beforeClosures {
@@ -3480,25 +3508,24 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
 	var isMap = t.Kind() == reflect.Map
 	var isStruct = t.Kind() == reflect.Struct
 	if isStruct {
-		table = session.Engine.TableInfo(bean)
-		session.Statement.RefTable = table
+		session.Statement.setRefValue(v)
+
+		if len(session.Statement.TableName()) <= 0 {
+			return 0, ErrTableNotFound
+		}
 
 		if session.Statement.ColumnStr == "" {
-			colNames, args = buildUpdates(session.Engine, table, bean, false, false,
+			colNames, args = buildUpdates(session.Engine, session.Statement.RefTable, bean, false, false,
 				false, false, session.Statement.allUseBool, session.Statement.useAllCols,
 				session.Statement.mustColumnMap, session.Statement.nullableMap,
 				session.Statement.columnMap, true, session.Statement.unscoped)
 		} else {
-			colNames, args, err = genCols(table, session, bean, true, true)
+			colNames, args, err = genCols(session.Statement.RefTable, session, bean, true, true)
 			if err != nil {
 				return 0, err
 			}
 		}
 	} else if isMap {
-		if session.Statement.RefTable == nil {
-			return 0, ErrTableNotFound
-		}
-		table = session.Statement.RefTable
 		colNames = make([]string, 0)
 		args = make([]interface{}, 0)
 		bValue := reflect.Indirect(reflect.ValueOf(bean))
@@ -3511,7 +3538,9 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
 		return 0, ErrParamsType
 	}
 
-	if session.Statement.UseAutoTime && table.Updated != "" {
+	table := session.Statement.RefTable
+
+	if session.Statement.UseAutoTime && table != nil && table.Updated != "" {
 		colNames = append(colNames, session.Engine.Quote(table.Updated)+" = ?")
 		col := table.UpdatedColumn()
 		val, t := session.Engine.NowTime2(col.SQLType.Name)
@@ -3574,7 +3603,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
 	var inArgs []interface{}
 	doIncVer := false
 	var verValue *reflect.Value
-	if table.Version != "" && session.Statement.checkVersion {
+	if table != nil && table.Version != "" && session.Statement.checkVersion {
 		if condition != "" {
 			condition = fmt.Sprintf("WHERE (%v) %v %v = ?", condition, session.Engine.Dialect().AndStr(),
 				session.Engine.Quote(table.Version))
@@ -3643,9 +3672,11 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
 		}
 	}
 
-	if cacher := session.Engine.getCacher2(table); cacher != nil && session.Statement.UseCache {
-		cacher.ClearIds(session.Statement.TableName())
-		cacher.ClearBeans(session.Statement.TableName())
+	if table != nil {
+		if cacher := session.Engine.getCacher2(table); cacher != nil && session.Statement.UseCache {
+			cacher.ClearIds(session.Statement.TableName())
+			cacher.ClearBeans(session.Statement.TableName())
+		}
 	}
 
 	// handle after update processors
@@ -3712,18 +3743,16 @@ func (session *Session) cacheDelete(sqlStr string, args ...interface{}) error {
 				for _, col := range session.Statement.RefTable.PKColumns() {
 					if v, ok := data[col.Name]; !ok {
 						return errors.New("no id")
-					} else {
-						if col.SQLType.IsText() {
-							pk = append(pk, string(v))
-						} else if col.SQLType.IsNumeric() {
-							id, err = strconv.ParseInt(string(v), 10, 64)
-							if err != nil {
-								return err
-							}
-							pk = append(pk, id)
-						} else {
-							return errors.New("not supported primary key type")
+					} else if col.SQLType.IsText() {
+						pk = append(pk, string(v))
+					} else if col.SQLType.IsNumeric() {
+						id, err = strconv.ParseInt(string(v), 10, 64)
+						if err != nil {
+							return err
 						}
+						pk = append(pk, id)
+					} else {
+						return errors.New("not supported primary key type")
 					}
 				}
 				ids = append(ids, pk)
@@ -3754,6 +3783,9 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
 		defer session.Close()
 	}
 
+	session.Statement.setRefValue(rValue(bean))
+	var table = session.Statement.RefTable
+
 	// handle before delete processors
 	for _, closure := range session.beforeClosures {
 		closure(bean)
@@ -3765,8 +3797,6 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
 	}
 	// --
 
-	table := session.Engine.TableInfo(bean)
-	session.Statement.RefTable = table
 	var colNames []string
 	var args []interface{}
 
@@ -3946,19 +3976,6 @@ func (session *Session) LastSQL() (string, []interface{}) {
 	return session.lastSQL, session.lastSQLArgs
 }
 
-// tbName get some table's table name
-func (session *Session) tbName(table *core.Table) string {
-	var tbName = table.Name
-	if len(session.Statement.AltTableName) > 0 {
-		tbName = session.Statement.AltTableName
-	}
-
-	/*if len(session.Engine.dialect.URI().Schema) > 0 {
-		return session.Engine.dialect.URI().Schema + "." + tbName
-	}*/
-	return tbName
-}
-
 // tbName get some table's table name
 func (session *Session) tbNameNoSchema(table *core.Table) string {
 	if len(session.Statement.AltTableName) > 0 {
@@ -4029,7 +4046,7 @@ func (s *Session) Sync2(beans ...interface{}) error {
 								engine.dialect.DBType() == core.POSTGRES {
 								engine.logger.Infof("Table %s column %s change type from %s to %s\n",
 									tbName, col.Name, curType, expectedType)
-								_, err = engine.Exec(engine.dialect.ModifyColumnSql(engine.tbName(table), col))
+								_, err = engine.Exec(engine.dialect.ModifyColumnSql(table.Name, col))
 							} else {
 								engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n",
 									tbName, col.Name, curType, expectedType)
@@ -4039,7 +4056,7 @@ func (s *Session) Sync2(beans ...interface{}) error {
 								if oriCol.Length < col.Length {
 									engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
 										tbName, col.Name, oriCol.Length, col.Length)
-									_, err = engine.Exec(engine.dialect.ModifyColumnSql(engine.tbName(table), col))
+									_, err = engine.Exec(engine.dialect.ModifyColumnSql(table.Name, col))
 								}
 							}
 						} else {
@@ -4053,7 +4070,7 @@ func (s *Session) Sync2(beans ...interface{}) error {
 							if oriCol.Length < col.Length {
 								engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
 									tbName, col.Name, oriCol.Length, col.Length)
-								_, err = engine.Exec(engine.dialect.ModifyColumnSql(engine.tbName(table), col))
+								_, err = engine.Exec(engine.dialect.ModifyColumnSql(table.Name, col))
 							}
 						}
 					}

+ 4 - 2
sessionplus.go

@@ -607,8 +607,9 @@ func (session *Session) row2BeanWithDateFormat(dateFormat string, rows *core.Row
 		return errors.New("Expected a pointer to a struct")
 	}
 
-	table := session.Engine.autoMapType(dataStruct)
-	return session._row2BeanWithDateFormat(dateFormat, rows, fields, fieldsCount, bean, &dataStruct, table)
+	session.Statement.setRefValue(dataStruct)
+
+	return session._row2BeanWithDateFormat(dateFormat, rows, fields, fieldsCount, bean, &dataStruct, session.Statement.RefTable)
 }
 
 func (session *Session) _row2BeanWithDateFormat(dateFormat string, rows *core.Rows, fields []string, fieldsCount int, bean interface{}, dataStruct *reflect.Value, table *core.Table) error {
@@ -882,6 +883,7 @@ func (session *Session) _row2BeanWithDateFormat(dateFormat string, rows *core.Ro
 				} else if session.Statement.UseCascade {
 					table := session.Engine.autoMapType(*fieldValue)
 					if table != nil {
+						hasAssigned = true
 						if len(table.PrimaryKeys) != 1 {
 							panic("unsupported non or composited primary key cascade")
 						}

+ 30 - 35
statement.go

@@ -58,6 +58,7 @@ type Statement struct {
 	OmitStr         string
 	ConditionStr    string
 	AltTableName    string
+	tableName       string
 	RawSQL          string
 	RawParams       []interface{}
 	UseCascade      bool
@@ -100,6 +101,7 @@ func (statement *Statement) Init() {
 	statement.columnMap = make(map[string]bool)
 	statement.ConditionStr = ""
 	statement.AltTableName = ""
+	statement.tableName = ""
 	statement.IdParam = nil
 	statement.RawSQL = ""
 	statement.RawParams = make([]interface{}, 0)
@@ -188,6 +190,11 @@ func (statement *Statement) Or(querystring string, args ...interface{}) *Stateme
 	return statement
 }
 
+func (statement *Statement) setRefValue(v reflect.Value) {
+	statement.RefTable = statement.Engine.autoMapType(v)
+	statement.tableName = statement.Engine.tbName(v)
+}
+
 // Table tempororily set table name, the parameter could be a string or a pointer of struct
 func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
 	v := rValue(tableNameOrBean)
@@ -196,6 +203,7 @@ func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
 		statement.AltTableName = tableNameOrBean.(string)
 	} else if t.Kind() == reflect.Struct {
 		statement.RefTable = statement.Engine.autoMapType(v)
+		statement.AltTableName = statement.Engine.tbName(v)
 	}
 	return statement
 }
@@ -678,14 +686,7 @@ func (statement *Statement) TableName() string {
 		return statement.AltTableName
 	}
 
-	if statement.RefTable != nil {
-		/*schema := statement.Engine.dialect.URI().Schema
-		if len(schema) > 0 {
-			return schema + "." + statement.RefTable.Name
-		}*/
-		return statement.RefTable.Name
-	}
-	return ""
+	return statement.tableName
 }
 
 // Id generate "where id = ? " statment or for composite key "where key1 = ? and key2 = ?"
@@ -998,8 +999,7 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
 			if t.Kind() == reflect.String {
 				table = f.(string)
 			} else if t.Kind() == reflect.Struct {
-				r := statement.Engine.autoMapType(v)
-				table = r.Name
+				table = statement.Engine.tbName(v)
 			}
 		}
 		if l > 1 {
@@ -1038,7 +1038,7 @@ func (statement *Statement) Unscoped() *Statement {
 
 func (statement *Statement) genColumnStr() string {
 	table := statement.RefTable
-	colNames := make([]string, 0)
+	var colNames []string
 	for _, col := range table.Columns() {
 		if statement.OmitStr != "" {
 			if _, ok := statement.columnMap[strings.ToLower(col.Name)]; ok {
@@ -1075,17 +1075,17 @@ func (statement *Statement) genColumnStr() string {
 }
 
 func (statement *Statement) genCreateTableSQL() string {
-	return statement.Engine.dialect.CreateTableSql(statement.RefTable, statement.AltTableName,
+	return statement.Engine.dialect.CreateTableSql(statement.RefTable, statement.TableName(),
 		statement.StoreEngine, statement.Charset)
 }
 
 func (s *Statement) genIndexSQL() []string {
-	var sqls []string = make([]string, 0)
+	var sqls []string
 	tbName := s.TableName()
 	quote := s.Engine.Quote
 	for idxName, index := range s.RefTable.Indexes {
 		if index.Type == core.IndexType {
-			sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(s.RefTable.Name, idxName)),
+			sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(tbName, idxName)),
 				quote(tbName), quote(strings.Join(index.Cols, quote(","))))
 			sqls = append(sqls, sql)
 		}
@@ -1098,10 +1098,11 @@ func uniqueName(tableName, uqeName string) string {
 }
 
 func (s *Statement) genUniqueSQL() []string {
-	var sqls []string = make([]string, 0)
+	var sqls []string
+	tbName := s.TableName()
 	for _, index := range s.RefTable.Indexes {
 		if index.Type == core.UniqueType {
-			sql := s.Engine.dialect.CreateIndexSql(s.RefTable.Name, index)
+			sql := s.Engine.dialect.CreateIndexSql(tbName, index)
 			sqls = append(sqls, sql)
 		}
 	}
@@ -1109,13 +1110,14 @@ func (s *Statement) genUniqueSQL() []string {
 }
 
 func (s *Statement) genDelIndexSQL() []string {
-	var sqls []string = make([]string, 0)
+	var sqls []string
+	tbName := s.TableName()
 	for idxName, index := range s.RefTable.Indexes {
 		var rIdxName string
 		if index.Type == core.UniqueType {
-			rIdxName = uniqueName(s.RefTable.Name, idxName)
+			rIdxName = uniqueName(tbName, idxName)
 		} else if index.Type == core.IndexType {
-			rIdxName = indexName(s.RefTable.Name, idxName)
+			rIdxName = indexName(tbName, idxName)
 		}
 		sql := fmt.Sprintf("DROP INDEX %v", s.Engine.Quote(rIdxName))
 		if s.Engine.dialect.IndexOnTable() {
@@ -1127,14 +1129,9 @@ func (s *Statement) genDelIndexSQL() []string {
 }
 
 func (statement *Statement) genGetSql(bean interface{}) (string, []interface{}) {
-	var table *core.Table
-	if statement.RefTable == nil {
-		table = statement.Engine.TableInfo(bean)
-		statement.RefTable = table
-	} else {
-		table = statement.RefTable
-	}
+	statement.setRefValue(rValue(bean))
 
+	var table = statement.RefTable
 	var addedTableName = (len(statement.JoinStr) > 0)
 
 	if !statement.noAutoCondition {
@@ -1144,7 +1141,7 @@ func (statement *Statement) genGetSql(bean interface{}) (string, []interface{})
 		statement.BeanArgs = args
 	}
 
-	var columnStr string = statement.ColumnStr
+	var columnStr = statement.ColumnStr
 	if len(statement.selectStr) > 0 {
 		columnStr = statement.selectStr
 	} else {
@@ -1199,13 +1196,12 @@ func (statement *Statement) buildConditions(table *core.Table, bean interface{},
 }
 
 func (statement *Statement) genCountSql(bean interface{}) (string, []interface{}) {
-	table := statement.Engine.TableInfo(bean)
-	statement.RefTable = table
+	statement.setRefValue(rValue(bean))
 
 	var addedTableName = (len(statement.JoinStr) > 0)
 
 	if !statement.noAutoCondition {
-		colNames, args := statement.buildConditions(table, bean, true, true, false, true, addedTableName)
+		colNames, args := statement.buildConditions(statement.RefTable, bean, true, true, false, true, addedTableName)
 
 		statement.ConditionStr = strings.Join(colNames, " "+statement.Engine.Dialect().AndStr()+" ")
 		statement.BeanArgs = args
@@ -1221,13 +1217,12 @@ func (statement *Statement) genCountSql(bean interface{}) (string, []interface{}
 }
 
 func (statement *Statement) genSumSql(bean interface{}, columns ...string) (string, []interface{}) {
-	table := statement.Engine.TableInfo(bean)
-	statement.RefTable = table
+	statement.setRefValue(rValue(bean))
 
 	var addedTableName = (len(statement.JoinStr) > 0)
 
 	if !statement.noAutoCondition {
-		colNames, args := statement.buildConditions(table, bean, true, true, false, true, addedTableName)
+		colNames, args := statement.buildConditions(statement.RefTable, bean, true, true, false, true, addedTableName)
 
 		statement.ConditionStr = strings.Join(colNames, " "+statement.Engine.Dialect().AndStr()+" ")
 		statement.BeanArgs = args
@@ -1269,7 +1264,7 @@ func (statement *Statement) genSelectSQL(columnStr string) (a string) {
 	}
 	var whereStr = buf.String()
 
-	var fromStr string = " FROM " + quote(statement.TableName())
+	var fromStr = " FROM " + quote(statement.TableName())
 	if statement.TableAlias != "" {
 		if dialect.DBType() == core.ORACLE {
 			fromStr += " " + quote(statement.TableAlias)
@@ -1286,7 +1281,7 @@ func (statement *Statement) genSelectSQL(columnStr string) (a string) {
 			top = fmt.Sprintf(" TOP %d ", statement.LimitN)
 		}
 		if statement.Start > 0 {
-			var column string = "(id)"
+			var column = "(id)"
 			if len(statement.RefTable.PKColumns()) == 0 {
 				for _, index := range statement.RefTable.Indexes {
 					if len(index.Cols) == 1 {

+ 8 - 2
test/sql/oracle/studygolang.xml

@@ -1,6 +1,6 @@
 <sqlMap>
 	<sql id="selectAllArticle">
-		select id,title,createdatetime,content 
+		select id,title,createdatetime,content
 		from Article where id in (?1,?2)
 	</sql>
 	<sql id="selectStudentById1">
@@ -15,6 +15,12 @@
 	<sql id="sql_i_2">
 		INSERT INTO categories VALUES (?id, ?name, ?counts, ?orders, ?pid)
 	</sql>
+	<sql id="category">
+		select * from category
+	</sql>
+	<sql id="category-16-17">
+		select * from category where id in (16,17)
+	</sql>
 	<sql id="create_1">
 		<![CDATA[
 				DROP TABLE IF EXISTS "public"."categories11";
@@ -45,4 +51,4 @@
 		ALTER TABLE "public"."categories11" ADD PRIMARY KEY ("id");
         ]]>
 	</sql>
-</sqlMap>
+</sqlMap>

+ 1 - 1
test/xorm_test.go

@@ -62,7 +62,7 @@ func Test_InitDB(t *testing.T) {
 		t.Fatal(err)
 	}
 
-	err = db.SetSqlTemplateRootDir("./sql/oracle").InitSqlTemplate(xorm.SqlTemplateOptions{Extension: ".xx"})
+	err = db.SetSqlTemplateRootDir("./sql/oracle").InitSqlTemplate(xorm.SqlTemplateOptions{Extension: ".stpl"})
 	if err != nil {
 		t.Fatal(err)
 	}

+ 1 - 1
xorm.go

@@ -17,7 +17,7 @@ import (
 
 const (
 	// Version show the xorm's version
-	Version string = "0.5.5.0707"
+	Version string = "0.5.5.0709"
 )
 
 func regDrvsNDialects() bool {