Преглед изворни кода

Added feature for storing lastSQL query on session

xormplus пре 10 година
родитељ
комит
a85984d0c5
6 измењених фајлова са 369 додато и 299 уклоњено
  1. 4 0
      processors.go
  2. 1 1
      rows.go
  3. 223 237
      session.go
  4. 92 44
      sessionplus.go
  5. 47 15
      statement.go
  6. 2 2
      test/xorm_test.go

+ 4 - 0
processors.go

@@ -23,6 +23,10 @@ type BeforeSetProcessor interface {
 	BeforeSet(string, Cell)
 }
 
+type AfterSetProcessor interface {
+	AfterSet(string, Cell)
+}
+
 // !nashtsai! TODO enable BeforeValidateProcessor when xorm start to support validations
 //// Executed before an object is validated
 //type BeforeValidateProcessor interface {

+ 1 - 1
rows.go

@@ -45,7 +45,7 @@ func newRows(session *Session, bean interface{}) (*Rows, error) {
 		sqlStr = filter.Do(sqlStr, session.Engine.dialect, rows.session.Statement.RefTable)
 	}
 
-	rows.session.Engine.logSQL(sqlStr, args)
+	rows.session.saveLastSQL(sqlStr, args)
 	var err error
 	rows.stmt, err = rows.session.DB().Prepare(sqlStr)
 	if err != nil {

+ 223 - 237
session.go

@@ -6,6 +6,7 @@ package xorm
 
 import (
 	"database/sql"
+	"database/sql/driver"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -46,6 +47,10 @@ type Session struct {
 
 	stmtCache   map[uint32]*core.Stmt //key: hash.Hash32 of (queryStr, len(queryStr))
 	cascadeDeep int
+
+	// !evalphobia! stored the last executed query on this session
+	lastSQL     string
+	lastSQLArgs []interface{}
 }
 
 // Method Init reset the session as the init status.
@@ -63,6 +68,9 @@ func (session *Session) Init() {
 	session.afterDeleteBeans = make(map[interface{}]*[]func(interface{}), 0)
 	session.beforeClosures = make([]func(interface{}), 0)
 	session.afterClosures = make([]func(interface{}), 0)
+
+	session.lastSQL = ""
+	session.lastSQLArgs = []interface{}{}
 }
 
 // Method Close release the connection from pool
@@ -331,8 +339,7 @@ func (session *Session) Begin() error {
 		session.IsAutoCommit = false
 		session.IsCommitedOrRollbacked = false
 		session.Tx = tx
-
-		session.Engine.logSQL("BEGIN TRANSACTION")
+		session.saveLastSQL("BEGIN TRANSACTION")
 	}
 	return nil
 }
@@ -340,7 +347,7 @@ func (session *Session) Begin() error {
 // When using transaction, you can rollback if any error
 func (session *Session) Rollback() error {
 	if !session.IsAutoCommit && !session.IsCommitedOrRollbacked {
-		session.Engine.logSQL(session.Engine.dialect.RollBackStr())
+		session.saveLastSQL(session.Engine.dialect.RollBackStr())
 		session.IsCommitedOrRollbacked = true
 		return session.Tx.Rollback()
 	}
@@ -350,7 +357,7 @@ func (session *Session) Rollback() error {
 // When using transaction, Commit will commit all operations.
 func (session *Session) Commit() error {
 	if !session.IsAutoCommit && !session.IsCommitedOrRollbacked {
-		session.Engine.logSQL("COMMIT")
+		session.saveLastSQL("COMMIT")
 		session.IsCommitedOrRollbacked = true
 		var err error
 		if err = session.Tx.Commit(); err == nil {
@@ -471,7 +478,7 @@ func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, er
 		sqlStr = filter.Do(sqlStr, session.Engine.dialect, session.Statement.RefTable)
 	}
 
-	session.Engine.logSQL(sqlStr, args...)
+	session.saveLastSQL(sqlStr, args...)
 
 	return session.Engine.LogSQLExecutionTime(sqlStr, args, func() (sql.Result, error) {
 		if session.IsAutoCommit {
@@ -614,11 +621,15 @@ func (session *Session) DropTable(beanOrTableName interface{}) error {
 	return nil
 }
 
-func (statement *Statement) JoinColumns(cols []*core.Column) string {
+func (statement *Statement) JoinColumns(cols []*core.Column, includeTableName bool) string {
 	var colnames = make([]string, len(cols))
 	for i, col := range cols {
-		colnames[i] = statement.Engine.Quote(statement.TableName()) +
-			"." + statement.Engine.Quote(col.Name)
+		if includeTableName {
+			colnames[i] = statement.Engine.Quote(statement.TableName()) +
+				"." + statement.Engine.Quote(col.Name)
+		} else {
+			colnames[i] = statement.Engine.Quote(col.Name)
+		}
 	}
 	return strings.Join(colnames, ", ")
 }
@@ -630,11 +641,14 @@ func (statement *Statement) convertIdSql(sqlStr string) string {
 			return ""
 		}
 
-		colstrs := statement.JoinColumns(cols)
-		sqls := splitNNoCase(sqlStr, "from", 2)
+		colstrs := statement.JoinColumns(cols, false)
+		sqls := splitNNoCase(sqlStr, " from ", 2)
 		if len(sqls) != 2 {
 			return ""
 		}
+		if statement.Engine.dialect.DBType() == "ql" {
+			return fmt.Sprintf("SELECT id() FROM %v", sqls[1])
+		}
 		return fmt.Sprintf("SELECT %s FROM %v", colstrs, sqls[1])
 	}
 	return ""
@@ -644,7 +658,7 @@ func (session *Session) canCache() bool {
 	if session.Statement.RefTable == nil ||
 		session.Statement.JoinStr != "" ||
 		session.Statement.RawSQL != "" ||
-		session.Tx != nil || 
+		session.Tx != nil ||
 		len(session.Statement.selectStr) > 0 {
 		return false
 	}
@@ -751,7 +765,7 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
 }
 
 func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr interface{}, args ...interface{}) (err error) {
-	if !session.canCache() || 
+	if !session.canCache() ||
 		indexNoCase(sqlStr, "having") != -1 ||
 		indexNoCase(sqlStr, "group by") != -1 {
 		return ErrCacheFailed
@@ -1052,115 +1066,6 @@ func (session *Session) Get(bean interface{}) (bool, error) {
 	return false, nil
 }
 
-// get retrieve one record from database, bean's non-empty fields
-// will be as conditions
-//func (session *Session) GetWithDateFormat(dateFormat string, bean interface{}) (bool, error) {
-//	defer session.resetStatement()
-//	if session.IsAutoClose {
-//		defer session.Close()
-//	}
-
-//	session.Statement.Limit(1)
-//	var sqlStr string
-//	var args []interface{}
-
-//	if session.Statement.RefTable == nil {
-//		session.Statement.RefTable = session.Engine.TableInfo(bean)
-//	}
-
-//	if session.Statement.RawSQL == "" {
-//		sqlStr, args = session.Statement.genGetSql(bean)
-//	} else {
-//		sqlStr = session.Statement.RawSQL
-//		args = session.Statement.RawParams
-//	}
-
-//	if session.Statement.JoinStr == "" {
-//		if cacher := session.Engine.getCacher2(session.Statement.RefTable); cacher != nil &&
-//			session.Statement.UseCache &&
-//			!session.Statement.unscoped {
-//			has, err := session.cacheGet(bean, sqlStr, args...)
-//			if err != ErrCacheFailed {
-//				return has, err
-//			}
-//		}
-//	}
-
-//	var rawRows *core.Rows
-//	var err error
-//	session.queryPreprocess(&sqlStr, args...)
-//	if session.IsAutoCommit {
-//		stmt, errPrepare := session.doPrepare(sqlStr)
-//		if errPrepare != nil {
-//			return false, errPrepare
-//		}
-//		// defer stmt.Close() // !nashtsai! don't close due to stmt is cached and bounded to this session
-//		rawRows, err = stmt.Query(args...)
-//	} else {
-//		rawRows, err = session.Tx.Query(sqlStr, args...)
-//	}
-//	if err != nil {
-//		return false, err
-//	}
-
-//	defer rawRows.Close()
-
-//	if rawRows.Next() {
-//		if fields, err := rawRows.Columns(); err == nil {
-//			err = session.row2BeanWithDateFormat(dateFormat, rawRows, fields, len(fields), bean)
-//		}
-//		return true, err
-//	}
-//	return false, nil
-//}
-
-//func (session *Session) GetToJsonStringWithDateFormat(dateFormat string, bean interface{}) (bool, string, error) {
-////	has, err, data := session.GetToMap(dateFormat, bean)
-//	has, err:=session.Get(bean)
-//	fmt.Println("数据库查询bean:%#v", bean)
-////	fmt.Println("数据库查询data:%#v", data)
-//	if !has || err != nil {
-//		return false, "", err
-//	}
-
-////		tmpBeanMap := Struct2MapWithDateFormat(dateFormat, bean)
-
-////	result, err1 := JSONString(data, true)
-////result, err1 := JSONString(tmpBeanMap, true)
-//result, err1 := JSONString(bean, true)
-//	if err1 != nil {
-//		return false, "", err
-//	}
-//	return true, result, nil
-//}
-
-func Struct2Map(obj interface{}) map[string]interface{} {
-	t := reflect.TypeOf(obj)
-	v := reflect.ValueOf(obj)
-
-	var data = make(map[string]interface{})
-	for i := 0; i < t.NumField(); i++ {
-		data[t.Field(i).Name] = v.Field(i).Interface()
-	}
-	return data
-}
-
-//func Struct2MapWithDateFormat(dateFormat string, obj interface{}) map[string]interface{} {
-//	t := reflect.TypeOf(obj)
-//	v := reflect.ValueOf(obj)
-
-//	var data = make(map[string]interface{})
-//	for i := 0; i < t.NumField(); i++ {
-//		if t.Field(i).Type == core.TimeType {
-//			data[t.Field(i).Name] = v.Field(i).Interface().(time.Time).Format(dateFormat)
-//		} else {
-//			data[t.Field(i).Name] = v.Field(i).Interface()
-//		}
-
-//	}
-//	return data
-//}
-
 // Count counts the records. bean's non-empty fields
 // are conditions.
 func (session *Session) Count(bean interface{}) (int64, error) {
@@ -1303,7 +1208,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
 		var addedTableName = (len(session.Statement.JoinStr) > 0)
 		colNames, args := buildConditions(session.Engine, table, condiBean[0], true, true,
 			false, true, session.Statement.allUseBool, session.Statement.useAllCols,
-			session.Statement.unscoped, session.Statement.mustColumnMap, 
+			session.Statement.unscoped, session.Statement.mustColumnMap,
 			session.Statement.TableName(), addedTableName)
 		session.Statement.ConditionStr = strings.Join(colNames, " AND ")
 		session.Statement.BeanArgs = args
@@ -1430,7 +1335,6 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
 		}
 
 		table := session.Engine.autoMapType(dataStruct)
-
 		return session.rows2Beans(rawRows, fields, fieldsCount, table, newElemFunc, sliceValueSetFunc)
 	} else {
 		resultsSlice, err := session.query(sqlStr, args...)
@@ -1573,7 +1477,7 @@ func (session *Session) isTableEmpty(tableName string) (bool, error) {
 	var total int64
 	sql := fmt.Sprintf("select count(*) from %s", session.Engine.Quote(tableName))
 	err := session.DB().QueryRow(sql).Scan(&total)
-	session.Engine.logSQL(sql)
+	session.saveLastSQL(sql)
 	if err != nil {
 		return true, err
 	}
@@ -1742,6 +1646,14 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount
 		}
 	}
 
+	defer func() {
+		if b, hasAfterSet := bean.(AfterSetProcessor); hasAfterSet {
+			for ii, key := range fields {
+				b.AfterSet(key, Cell(scanResults[ii].(*interface{})))
+			}
+		}
+	}()
+
 	var tempMap = make(map[string]int)
 	for ii, key := range fields {
 		var idx int
@@ -1791,7 +1703,6 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount
 			hasAssigned := false
 
 			switch fieldType.Kind() {
-
 			case reflect.Complex64, reflect.Complex128:
 				if rawValueType.Kind() == reflect.String {
 					hasAssigned = true
@@ -1802,6 +1713,15 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount
 						return err
 					}
 					fieldValue.Set(x.Elem())
+				} else if rawValueType.Kind() == reflect.Slice {
+					hasAssigned = true
+					x := reflect.New(fieldType)
+					err := json.Unmarshal(vv.Bytes(), x.Interface())
+					if err != nil {
+						session.Engine.LogError(err)
+						return err
+					}
+					fieldValue.Set(x.Elem())
 				}
 			case reflect.Slice, reflect.Array:
 				switch rawValueType.Kind() {
@@ -1846,6 +1766,7 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount
 					fieldValue.SetUint(uint64(vv.Int()))
 				}
 			case reflect.Struct:
+				col := table.GetColumn(key)
 				if fieldType.ConvertibleTo(core.TimeType) {
 					if rawValueType == core.TimeType {
 						hasAssigned = true
@@ -1853,7 +1774,7 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount
 						t := vv.Convert(core.TimeType).Interface().(time.Time)
 						z, _ := t.Zone()
 						if len(z) == 0 || t.Year() == 0 { // !nashtsai! HACK tmp work around for lib/pq doesn't properly time with location
-							session.Engine.LogDebug("empty zone key[%v] : %v | zone: %v | location: %+v\n", key, t, z, *t.Location())
+							session.Engine.LogDebugf("empty zone key[%v] : %v | zone: %v | location: %+v\n", key, t, z, *t.Location())
 							t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(),
 								t.Minute(), t.Second(), t.Nanosecond(), time.Local)
 						}
@@ -1871,13 +1792,42 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount
 						vv = reflect.ValueOf(t)
 						fieldValue.Set(vv)
 					}
+				} else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
+					// !<winxxp>! 增加支持sql.Scanner接口的结构,如sql.NullString
+					hasAssigned = true
+					if err := nulVal.Scan(vv.Interface()); err != nil {
+						//fmt.Println("sql.Sanner error:", err.Error())
+						session.Engine.LogError("sql.Sanner error:", err.Error())
+						hasAssigned = false
+					}
+				} else if col.SQLType.IsJson() {
+					if rawValueType.Kind() == reflect.String {
+						hasAssigned = true
+						x := reflect.New(fieldType)
+						err := json.Unmarshal([]byte(vv.String()), x.Interface())
+						if err != nil {
+							session.Engine.LogError(err)
+							return err
+						}
+						fieldValue.Set(x.Elem())
+					} else if rawValueType.Kind() == reflect.Slice {
+						hasAssigned = true
+						x := reflect.New(fieldType)
+						err := json.Unmarshal(vv.Bytes(), x.Interface())
+						if err != nil {
+							session.Engine.LogError(err)
+							return err
+						}
+						fieldValue.Set(x.Elem())
+					}
 				} else if session.Statement.UseCascade {
 					table := session.Engine.autoMapType(*fieldValue)
 					if table != nil {
-						if len(table.PrimaryKeys) > 1 {
-							panic("unsupported composited primary key cascade")
+						if len(table.PrimaryKeys) != 1 {
+							panic("unsupported non or composited primary key cascade")
 						}
 						var pk = make(core.PK, len(table.PrimaryKeys))
+
 						switch rawValueType.Kind() {
 						case reflect.Int64:
 							pk[0] = vv.Int()
@@ -2065,7 +2015,7 @@ func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{})
 		*sqlStr = filter.Do(*sqlStr, session.Engine.dialect, session.Statement.RefTable)
 	}
 
-	session.Engine.logSQL(*sqlStr, paramStr...)
+	session.saveLastSQL(*sqlStr, paramStr...)
 }
 
 func (session *Session) query(sqlStr string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) {
@@ -2532,108 +2482,115 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
 		fieldValue.SetUint(x)
 	//Currently only support Time type
 	case reflect.Struct:
-		if fieldType.ConvertibleTo(core.TimeType) {
-			x, err := session.byte2Time(col, data)
-			if err != nil {
-				return err
+		// !<winxxp>! 增加支持sql.Scanner接口的结构,如sql.NullString
+		if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
+			if err := nulVal.Scan(data); err != nil {
+				return fmt.Errorf("sql.Scan(%v) failed: %s ", data, err.Error())
 			}
-			v = x
-			fieldValue.Set(reflect.ValueOf(v).Convert(fieldType))
-		} else if session.Statement.UseCascade {
-			table := session.Engine.autoMapType(*fieldValue)
-			if table != nil {
-				if len(table.PrimaryKeys) > 1 {
-					panic("unsupported composited primary key cascade")
-				}
-				var pk = make(core.PK, len(table.PrimaryKeys))
-				rawValueType := table.ColumnType(table.PKColumns()[0].FieldName)
-				switch rawValueType.Kind() {
-				case reflect.Int64:
-					x, err := strconv.ParseInt(string(data), 10, 64)
-					if err != nil {
-						return fmt.Errorf("arg %v as int: %s", key, err.Error())
-					}
-					pk[0] = x
-				case reflect.Int:
-					x, err := strconv.ParseInt(string(data), 10, 64)
-					if err != nil {
-						return fmt.Errorf("arg %v as int: %s", key, err.Error())
-					}
-					pk[0] = int(x)
-				case reflect.Int32:
-					x, err := strconv.ParseInt(string(data), 10, 64)
-					if err != nil {
-						return fmt.Errorf("arg %v as int: %s", key, err.Error())
-					}
-					pk[0] = int32(x)
-				case reflect.Int16:
-					x, err := strconv.ParseInt(string(data), 10, 64)
-					if err != nil {
-						return fmt.Errorf("arg %v as int: %s", key, err.Error())
-					}
-					pk[0] = int16(x)
-				case reflect.Int8:
-					x, err := strconv.ParseInt(string(data), 10, 64)
-					if err != nil {
-						return fmt.Errorf("arg %v as int: %s", key, err.Error())
-					}
-					pk[0] = int8(x)
-				case reflect.Uint64:
-					x, err := strconv.ParseUint(string(data), 10, 64)
-					if err != nil {
-						return fmt.Errorf("arg %v as int: %s", key, err.Error())
-					}
-					pk[0] = x
-				case reflect.Uint:
-					x, err := strconv.ParseUint(string(data), 10, 64)
-					if err != nil {
-						return fmt.Errorf("arg %v as int: %s", key, err.Error())
-					}
-					pk[0] = uint(x)
-				case reflect.Uint32:
-					x, err := strconv.ParseUint(string(data), 10, 64)
-					if err != nil {
-						return fmt.Errorf("arg %v as int: %s", key, err.Error())
-					}
-					pk[0] = uint32(x)
-				case reflect.Uint16:
-					x, err := strconv.ParseUint(string(data), 10, 64)
-					if err != nil {
-						return fmt.Errorf("arg %v as int: %s", key, err.Error())
-					}
-					pk[0] = uint16(x)
-				case reflect.Uint8:
-					x, err := strconv.ParseUint(string(data), 10, 64)
-					if err != nil {
-						return fmt.Errorf("arg %v as int: %s", key, err.Error())
+		} else {
+			if fieldType.ConvertibleTo(core.TimeType) {
+				x, err := session.byte2Time(col, data)
+				if err != nil {
+					return err
+				}
+				v = x
+				fieldValue.Set(reflect.ValueOf(v).Convert(fieldType))
+			} else if session.Statement.UseCascade {
+				table := session.Engine.autoMapType(*fieldValue)
+				if table != nil {
+					if len(table.PrimaryKeys) > 1 {
+						panic("unsupported composited primary key cascade")
 					}
-					pk[0] = uint8(x)
-				case reflect.String:
-					pk[0] = string(data)
-				default:
-					panic("unsupported primary key type cascade")
-				}
-
-				if !isPKZero(pk) {
-					// !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch
-					// however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne
-					// property to be fetched lazily
-					structInter := reflect.New(fieldValue.Type())
-					newsession := session.Engine.NewSession()
-					defer newsession.Close()
-					has, err := newsession.Id(pk).NoCascade().Get(structInter.Interface())
-					if err != nil {
-						return err
+					var pk = make(core.PK, len(table.PrimaryKeys))
+					rawValueType := table.ColumnType(table.PKColumns()[0].FieldName)
+					switch rawValueType.Kind() {
+					case reflect.Int64:
+						x, err := strconv.ParseInt(string(data), 10, 64)
+						if err != nil {
+							return fmt.Errorf("arg %v as int: %s", key, err.Error())
+						}
+						pk[0] = x
+					case reflect.Int:
+						x, err := strconv.ParseInt(string(data), 10, 64)
+						if err != nil {
+							return fmt.Errorf("arg %v as int: %s", key, err.Error())
+						}
+						pk[0] = int(x)
+					case reflect.Int32:
+						x, err := strconv.ParseInt(string(data), 10, 64)
+						if err != nil {
+							return fmt.Errorf("arg %v as int: %s", key, err.Error())
+						}
+						pk[0] = int32(x)
+					case reflect.Int16:
+						x, err := strconv.ParseInt(string(data), 10, 64)
+						if err != nil {
+							return fmt.Errorf("arg %v as int: %s", key, err.Error())
+						}
+						pk[0] = int16(x)
+					case reflect.Int8:
+						x, err := strconv.ParseInt(string(data), 10, 64)
+						if err != nil {
+							return fmt.Errorf("arg %v as int: %s", key, err.Error())
+						}
+						pk[0] = int8(x)
+					case reflect.Uint64:
+						x, err := strconv.ParseUint(string(data), 10, 64)
+						if err != nil {
+							return fmt.Errorf("arg %v as int: %s", key, err.Error())
+						}
+						pk[0] = x
+					case reflect.Uint:
+						x, err := strconv.ParseUint(string(data), 10, 64)
+						if err != nil {
+							return fmt.Errorf("arg %v as int: %s", key, err.Error())
+						}
+						pk[0] = uint(x)
+					case reflect.Uint32:
+						x, err := strconv.ParseUint(string(data), 10, 64)
+						if err != nil {
+							return fmt.Errorf("arg %v as int: %s", key, err.Error())
+						}
+						pk[0] = uint32(x)
+					case reflect.Uint16:
+						x, err := strconv.ParseUint(string(data), 10, 64)
+						if err != nil {
+							return fmt.Errorf("arg %v as int: %s", key, err.Error())
+						}
+						pk[0] = uint16(x)
+					case reflect.Uint8:
+						x, err := strconv.ParseUint(string(data), 10, 64)
+						if err != nil {
+							return fmt.Errorf("arg %v as int: %s", key, err.Error())
+						}
+						pk[0] = uint8(x)
+					case reflect.String:
+						pk[0] = string(data)
+					default:
+						panic("unsupported primary key type cascade")
 					}
-					if has {
-						v = structInter.Elem().Interface()
-						fieldValue.Set(reflect.ValueOf(v))
-					} else {
-						return errors.New("cascade obj is not exist!")
+
+					if !isPKZero(pk) {
+						// !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch
+						// however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne
+						// property to be fetched lazily
+						structInter := reflect.New(fieldValue.Type())
+						newsession := session.Engine.NewSession()
+						defer newsession.Close()
+						has, err := newsession.Id(pk).NoCascade().Get(structInter.Interface())
+						if err != nil {
+							return err
+						}
+						if has {
+							v = structInter.Elem().Interface()
+							fieldValue.Set(reflect.ValueOf(v))
+						} else {
+							return errors.New("cascade obj is not exist!")
+						}
 					}
+				} else {
+					return fmt.Errorf("unsupported struct type in Scan: %s", fieldValue.Type().String())
 				}
-			} else {
-				return fmt.Errorf("unsupported struct type in Scan: %s", fieldValue.Type().String())
 			}
 		}
 	case reflect.Ptr:
@@ -3047,16 +3004,36 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val
 			tf := session.Engine.FormatTime(col.SQLType.Name, t)
 			return tf, nil
 		}
+
 		if fieldTable, ok := session.Engine.Tables[fieldValue.Type()]; ok {
 			if len(fieldTable.PrimaryKeys) == 1 {
 				pkField := reflect.Indirect(fieldValue).FieldByName(fieldTable.PKColumns()[0].FieldName)
 				return pkField.Interface(), nil
-			} else {
-				return 0, fmt.Errorf("no primary key for col %v", col.Name)
 			}
-		} else {
-			return 0, fmt.Errorf("Unsupported type %v", fieldValue.Type())
+			return 0, fmt.Errorf("no primary key for col %v", col.Name)
 		}
+		// !<winxxp>! 增加支持driver.Valuer接口的结构,如sql.NullString
+		if v, ok := fieldValue.Interface().(driver.Valuer); ok {
+			return v.Value()
+		}
+
+		if col.SQLType.IsText() {
+			bytes, err := json.Marshal(fieldValue.Interface())
+			if err != nil {
+				session.Engine.LogError(err)
+				return 0, err
+			}
+			return string(bytes), nil
+		} else if col.SQLType.IsBlob() {
+			bytes, err := json.Marshal(fieldValue.Interface())
+			if err != nil {
+				session.Engine.LogError(err)
+				return 0, err
+			}
+			return bytes, nil
+		}
+
+		return nil, fmt.Errorf("Unsupported type %v", fieldValue.Type())
 	case reflect.Complex64, reflect.Complex128:
 		bytes, err := json.Marshal(fieldValue.Interface())
 		if err != nil {
@@ -3090,9 +3067,8 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val
 				}
 			}
 			return bytes, nil
-		} else {
-			return nil, ErrUnSupportedType
 		}
+		return nil, ErrUnSupportedType
 	case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
 		return int64(fieldValue.Uint()), nil
 	default:
@@ -3114,12 +3090,10 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
 		processor.BeforeInsert()
 	}
 	// --
-
 	colNames, args, err := genCols(table, session, bean, false, false)
 	if err != nil {
 		return 0, err
 	}
-
 	// insert expr columns, override if exists
 	exprColumns := session.Statement.getExpr()
 	exprColVals := make([]string, 0, len(exprColumns))
@@ -3320,7 +3294,7 @@ func (statement *Statement) convertUpdateSql(sqlStr string) (string, string) {
 		return "", ""
 	}
 
-	colstrs := statement.JoinColumns(statement.RefTable.PKColumns())
+	colstrs := statement.JoinColumns(statement.RefTable.PKColumns(), true)
 	sqls := splitNNoCase(sqlStr, "where", 2)
 	if len(sqls) != 2 {
 		if len(sqls) == 1 {
@@ -3530,7 +3504,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
 		if session.Statement.ColumnStr == "" {
 			colNames, args = buildUpdates(session.Engine, table, bean, false, false,
 				false, false, session.Statement.allUseBool, session.Statement.useAllCols,
-				session.Statement.mustColumnMap, session.Statement.nullableMap, 
+				session.Statement.mustColumnMap, session.Statement.nullableMap,
 				session.Statement.columnMap, true)
 		} else {
 			colNames, args, err = genCols(table, session, bean, true, true)
@@ -3812,7 +3786,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
 	session.Statement.RefTable = table
 	colNames, args := buildConditions(session.Engine, table, bean, true, true,
 		false, true, session.Statement.allUseBool, session.Statement.useAllCols,
-		session.Statement.unscoped, session.Statement.mustColumnMap, 
+		session.Statement.unscoped, session.Statement.mustColumnMap,
 		session.Statement.TableName(), false)
 
 	var condition = ""
@@ -3917,6 +3891,18 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
 	return res.RowsAffected()
 }
 
+// saveLastSQL stores executed query information
+func (session *Session) saveLastSQL(sql string, args ...interface{}) {
+	session.lastSQL = sql
+	session.lastSQLArgs = args
+	session.Engine.logSQL(sql, args...)
+}
+
+// LastSQL returns last query information
+func (session *Session) LastSQL() (string, []interface{}) {
+	return session.lastSQL, session.lastSQLArgs
+}
+
 func (s *Session) Sync2(beans ...interface{}) error {
 	engine := s.Engine
 

+ 92 - 44
sessionplus.go

@@ -5,101 +5,103 @@
 package xorm
 
 import (
+	"database/sql"
 	"encoding/json"
 	"errors"
-//	"fmt"
+	"fmt"
+	//	"fmt"
 	"reflect"
 	"regexp"
 	"strings"
 	"time"
 
-	"github.com/xormplus/core"
 	"github.com/Chronokeeper/anyxml"
+	"github.com/xormplus/core"
 )
 
 type ResultBean struct {
-	Has bool
+	Has    bool
 	Result interface{}
-	Error    error
+	Error  error
 }
 
-func (resultBean ResultBean) Json() (bool,string, error) {
+func (resultBean ResultBean) Json() (bool, string, error) {
 	if resultBean.Error != nil {
-		return resultBean.Has,"", resultBean.Error
+		return resultBean.Has, "", resultBean.Error
 	}
-	if !resultBean.Has{
-		return resultBean.Has,"", nil
+	if !resultBean.Has {
+		return resultBean.Has, "", nil
 	}
-	result,err:= JSONString(resultBean.Result, true)
-	return resultBean.Has,result,err
+	result, err := JSONString(resultBean.Result, true)
+	return resultBean.Has, result, err
 }
 
 func (session *Session) GetFirst(bean interface{}) ResultBean {
 	has, err := session.Get(bean)
-	r := ResultBean{Has: has,Result:bean, Error: err}
+	r := ResultBean{Has: has, Result: bean, Error: err}
 	return r
 }
 
-func (resultBean ResultBean) Xml() (bool,string, error) {
-	
+func (resultBean ResultBean) Xml() (bool, string, error) {
+
 	if resultBean.Error != nil {
-		return false,"", resultBean.Error
+		return false, "", resultBean.Error
 	}
-	if !resultBean.Has{
-		return resultBean.Has,"", nil
+	if !resultBean.Has {
+		return resultBean.Has, "", nil
 	}
-	has,result,err:=resultBean.Json()
+	has, result, err := resultBean.Json()
 	if err != nil {
-		return false,"", err
+		return false, "", err
 	}
-	if !has{
-		return has,"", nil
+	if !has {
+		return has, "", nil
 	}
 	var anydata = []byte(result)
 	var i interface{}
 	err = json.Unmarshal(anydata, &i)
 	if err != nil {
-		return false,"", err
+		return false, "", err
 	}
 	resultByte, err := anyxml.Xml(i)
 	if err != nil {
-		return false,"", err
+		return false, "", err
 	}
 
-	return resultBean.Has,string(resultByte),err
+	return resultBean.Has, string(resultByte), err
 }
 
-func (resultBean ResultBean) XmlIndent(prefix string, indent string, recordTag string) (bool,string, error) {
+func (resultBean ResultBean) XmlIndent(prefix string, indent string, recordTag string) (bool, string, error) {
 	if resultBean.Error != nil {
-		return false,"", resultBean.Error
+		return false, "", resultBean.Error
 	}
-	if !resultBean.Has{
-		return resultBean.Has,"", nil
+	if !resultBean.Has {
+		return resultBean.Has, "", nil
 	}
-	has,result,err:=resultBean.Json()
+	has, result, err := resultBean.Json()
 	if err != nil {
-		return false,"", err
+		return false, "", err
 	}
-	if !has{
-		return has,"", nil
+	if !has {
+		return has, "", nil
 	}
 	var anydata = []byte(result)
 	var i interface{}
 	err = json.Unmarshal(anydata, &i)
 	if err != nil {
-		return false,"", err
+		return false, "", err
 	}
-	resultByte, err := anyxml.XmlIndent(i,prefix,indent,recordTag)
+	resultByte, err := anyxml.XmlIndent(i, prefix, indent, recordTag)
 	if err != nil {
-		return false,"", err
+		return false, "", err
 	}
 
-	return resultBean.Has,string(resultByte),err
+	return resultBean.Has, string(resultByte), err
 }
 
 type ResultMap struct {
 	Result []map[string]interface{}
-	Error    error
+	Error  error
 }
 
 func (resultMap ResultMap) Json() (string, error) {
@@ -135,7 +137,7 @@ func (resultMap ResultMap) XmlIndent(prefix string, indent string, recordTag str
 
 type ResultStructs struct {
 	Result interface{}
-	Error    error
+	Error  error
 }
 
 func (resultStructs ResultStructs) Json() (string, error) {
@@ -151,7 +153,7 @@ func (resultStructs ResultStructs) Xml() (string, error) {
 		return "", resultStructs.Error
 	}
 
-	result,err:=resultStructs.Json()
+	result, err := resultStructs.Json()
 	if err != nil {
 		return "", err
 	}
@@ -175,7 +177,7 @@ func (resultStructs ResultStructs) XmlIndent(prefix string, indent string, recor
 		return "", resultStructs.Error
 	}
 
-	result,err:=resultStructs.Json()
+	result, err := resultStructs.Json()
 	if err != nil {
 		return "", err
 	}
@@ -194,7 +196,7 @@ func (resultStructs ResultStructs) XmlIndent(prefix string, indent string, recor
 	return string(resultByte), nil
 }
 
-func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) ResultStructs{
+func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) ResultStructs {
 	err := session.find(rowsSlicePtr, condiBean...)
 	r := ResultStructs{Result: rowsSlicePtr, Error: err}
 	return r
@@ -385,6 +387,14 @@ func (session *Session) _row2BeanWithDateFormat(dateFormat string, rows *core.Ro
 		}
 	}
 
+	defer func() {
+		if b, hasAfterSet := bean.(AfterSetProcessor); hasAfterSet {
+			for ii, key := range fields {
+				b.AfterSet(key, Cell(scanResults[ii].(*interface{})))
+			}
+		}
+	}()
+
 	var tempMap = make(map[string]int)
 	for ii, key := range fields {
 		var idx int
@@ -434,7 +444,6 @@ func (session *Session) _row2BeanWithDateFormat(dateFormat string, rows *core.Ro
 			hasAssigned := false
 
 			switch fieldType.Kind() {
-
 			case reflect.Complex64, reflect.Complex128:
 				if rawValueType.Kind() == reflect.String {
 					hasAssigned = true
@@ -445,6 +454,15 @@ func (session *Session) _row2BeanWithDateFormat(dateFormat string, rows *core.Ro
 						return err
 					}
 					fieldValue.Set(x.Elem())
+				} else if rawValueType.Kind() == reflect.Slice {
+					hasAssigned = true
+					x := reflect.New(fieldType)
+					err := json.Unmarshal(vv.Bytes(), x.Interface())
+					if err != nil {
+						session.Engine.LogError(err)
+						return err
+					}
+					fieldValue.Set(x.Elem())
 				}
 			case reflect.Slice, reflect.Array:
 				switch rawValueType.Kind() {
@@ -489,6 +507,7 @@ func (session *Session) _row2BeanWithDateFormat(dateFormat string, rows *core.Ro
 					fieldValue.SetUint(uint64(vv.Int()))
 				}
 			case reflect.Struct:
+				col := table.GetColumn(key)
 				if fieldType.ConvertibleTo(core.TimeType) {
 					if rawValueType == core.TimeType {
 						hasAssigned = true
@@ -496,7 +515,7 @@ func (session *Session) _row2BeanWithDateFormat(dateFormat string, rows *core.Ro
 						t := vv.Convert(core.TimeType).Interface().(time.Time)
 						z, _ := t.Zone()
 						if len(z) == 0 || t.Year() == 0 { // !nashtsai! HACK tmp work around for lib/pq doesn't properly time with location
-							session.Engine.LogDebug("empty zone key[%v] : %v | zone: %v | location: %+v\n", key, t, z, *t.Location())
+							session.Engine.LogDebugf("empty zone key[%v] : %v | zone: %v | location: %+v\n", key, t, z, *t.Location())
 							t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(),
 								t.Minute(), t.Second(), t.Nanosecond(), time.Local)
 						}
@@ -518,13 +537,42 @@ func (session *Session) _row2BeanWithDateFormat(dateFormat string, rows *core.Ro
 						vv = reflect.ValueOf(t)
 						fieldValue.Set(vv)
 					}
+				} else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
+					// !<winxxp>! 增加支持sql.Scanner接口的结构,如sql.NullString
+					hasAssigned = true
+					if err := nulVal.Scan(vv.Interface()); err != nil {
+						//fmt.Println("sql.Sanner error:", err.Error())
+						session.Engine.LogError("sql.Sanner error:", err.Error())
+						hasAssigned = false
+					}
+				} else if col.SQLType.IsJson() {
+					if rawValueType.Kind() == reflect.String {
+						hasAssigned = true
+						x := reflect.New(fieldType)
+						err := json.Unmarshal([]byte(vv.String()), x.Interface())
+						if err != nil {
+							session.Engine.LogError(err)
+							return err
+						}
+						fieldValue.Set(x.Elem())
+					} else if rawValueType.Kind() == reflect.Slice {
+						hasAssigned = true
+						x := reflect.New(fieldType)
+						err := json.Unmarshal(vv.Bytes(), x.Interface())
+						if err != nil {
+							session.Engine.LogError(err)
+							return err
+						}
+						fieldValue.Set(x.Elem())
+					}
 				} else if session.Statement.UseCascade {
 					table := session.Engine.autoMapType(*fieldValue)
 					if table != nil {
-						if len(table.PrimaryKeys) > 1 {
-							panic("unsupported composited primary key cascade")
+						if len(table.PrimaryKeys) != 1 {
+							panic("unsupported non or composited primary key cascade")
 						}
 						var pk = make(core.PK, len(table.PrimaryKeys))
+
 						switch rawValueType.Kind() {
 						case reflect.Int64:
 							pk[0] = vv.Int()

+ 47 - 15
statement.go

@@ -219,6 +219,7 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{},
 		requiredField := useAllCols
 		includeNil := useAllCols
 		lColName := strings.ToLower(col.Name)
+
 		if b, ok := mustColumnMap[lColName]; ok {
 			if b {
 				requiredField = true
@@ -320,6 +321,8 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{},
 					continue
 				}
 				val = engine.FormatTime(col.SQLType.Name, t)
+			} else if nulType, ok := fieldValue.Interface().(driver.Valuer); ok {
+				val, _ = nulType.Value()
 			} else {
 				engine.autoMapType(fieldValue)
 				if table, ok := engine.Tables[fieldValue.Type()]; ok {
@@ -413,10 +416,13 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{},
 		if engine.dialect.DBType() == core.MSSQL && col.SQLType.Name == core.Text {
 			continue
 		}
+		if col.SQLType.IsJson() {
+			continue
+		}
 
 		var colName string
 		if addedTableName {
-			colName = engine.Quote(tableName)+"."+engine.Quote(col.Name)
+			colName = engine.Quote(tableName) + "." + engine.Quote(col.Name)
 		} else {
 			colName = engine.Quote(col.Name)
 		}
@@ -428,7 +434,7 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{},
 		}
 
 		if col.IsDeleted && !unscoped { // tag "deleted" is enabled
-			colNames = append(colNames, fmt.Sprintf("(%v IS NULL or %v = '0001-01-01 00:00:00')", 
+			colNames = append(colNames, fmt.Sprintf("(%v IS NULL or %v = '0001-01-01 00:00:00')",
 				colName, colName))
 		}
 
@@ -509,24 +515,49 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{},
 				val = engine.FormatTime(col.SQLType.Name, t)
 			} else if _, ok := reflect.New(fieldType).Interface().(core.Conversion); ok {
 				continue
+			} else if valNul, ok := fieldValue.Interface().(driver.Valuer); ok {
+				val, _ = valNul.Value()
+				if val == nil {
+					continue
+				}
 			} else {
-				engine.autoMapType(fieldValue)
-				if table, ok := engine.Tables[fieldValue.Type()]; ok {
-					if len(table.PrimaryKeys) == 1 {
-						pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)
-						// fix non-int pk issues
-						//if pkField.Int() != 0 {
-						if pkField.IsValid() && !isZero(pkField.Interface()) {
-							val = pkField.Interface()
-						} else {
+				if col.SQLType.IsJson() {
+					if col.SQLType.IsText() {
+						bytes, err := json.Marshal(fieldValue.Interface())
+						if err != nil {
+							engine.LogError(err)
 							continue
 						}
-					} else {
-						//TODO: how to handler?
-						panic(fmt.Sprintln("not supported", fieldValue.Interface(), "as", table.PrimaryKeys))
+						val = string(bytes)
+					} else if col.SQLType.IsBlob() {
+						var bytes []byte
+						var err error
+						bytes, err = json.Marshal(fieldValue.Interface())
+						if err != nil {
+							engine.LogError(err)
+							continue
+						}
+						val = bytes
 					}
 				} else {
-					val = fieldValue.Interface()
+					engine.autoMapType(fieldValue)
+					if table, ok := engine.Tables[fieldValue.Type()]; ok {
+						if len(table.PrimaryKeys) == 1 {
+							pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)
+							// fix non-int pk issues
+							//if pkField.Int() != 0 {
+							if pkField.IsValid() && !isZero(pkField.Interface()) {
+								val = pkField.Interface()
+							} else {
+								continue
+							}
+						} else {
+							//TODO: how to handler?
+							panic(fmt.Sprintln("not supported", fieldValue.Interface(), "as", table.PrimaryKeys))
+						}
+					} else {
+						val = fieldValue.Interface()
+					}
 				}
 			}
 		case reflect.Array, reflect.Slice, reflect.Map:
@@ -786,6 +817,7 @@ func (statement *Statement) Cols(columns ...string) *Statement {
 	if strings.Contains(statement.ColumnStr, ".") {
 		statement.ColumnStr = strings.Replace(statement.ColumnStr, ".", statement.Engine.Quote("."), -1)
 	}
+	statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.Quote("*"), "*", -1)
 	return statement
 }
 

+ 2 - 2
test/xorm_test.go

@@ -39,8 +39,8 @@ var db *xorm.Engine
 func Test_InitDB(t *testing.T) {
 	var err error
 	db, err = xorm.NewPostgreSQL("postgres://postgres:root@localhost:5432/mblog?sslmode=disable")
-	db.SqlMap.SqlMapRootDir="./sql/oracle"
-	db.SqlTemplate.SqlTemplateRootDir="./sql/oracle"
+//	db.SqlMap.SqlMapRootDir="./sql/oracle"
+//	db.SqlTemplate.SqlTemplateRootDir="./sql/oracle"
 	if err != nil {
 		t.Fatal(err)
 	}