浏览代码

fix bugs

* fix error when get null var
* add support get for null var
* fix bugs
* add test for SQL get
* fix tests
Unknown 6 年之前
父节点
当前提交
b8f0029aaa
共有 4 个文件被更改,包括 272 次插入1 次删除
  1. 2 0
      session_find.go
  2. 118 0
      session_get.go
  3. 151 0
      session_get_test.go
  4. 1 1
      xorm.go

+ 2 - 0
session_find.go

@@ -84,6 +84,8 @@ func (session *Session) FindAndCount(rowsSlicePtr interface{}, condiBean ...inte
 }
 
 func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) error {
+	defer session.resetStatement()
+
 	if session.statement.lastError != nil {
 		return session.statement.lastError
 	}

+ 118 - 0
session_get.go

@@ -24,6 +24,8 @@ func (session *Session) Get(bean interface{}) (bool, error) {
 }
 
 func (session *Session) get(bean interface{}) (bool, error) {
+	defer session.resetStatement()
+
 	if session.statement.lastError != nil {
 		return false, session.statement.lastError
 	}
@@ -86,6 +88,8 @@ func (session *Session) get(bean interface{}) (bool, error) {
 	if context != nil {
 		res := context.Get(fmt.Sprintf("%v-%v", sqlStr, args))
 		if res != nil {
+			session.engine.logger.Debug("hit context cache", sqlStr)
+
 			structValue := reflect.Indirect(reflect.ValueOf(bean))
 			structValue.Set(reflect.Indirect(reflect.ValueOf(res)))
 			session.lastSQL = ""
@@ -93,13 +97,16 @@ func (session *Session) get(bean interface{}) (bool, error) {
 			return true, nil
 		}
 	}
+
 	has, err := session.nocacheGet(beanValue.Elem().Kind(), table, bean, sqlStr, args...)
 	if err != nil || !has {
 		return has, err
 	}
+
 	if context != nil {
 		context.Put(fmt.Sprintf("%v-%v", sqlStr, args), bean)
 	}
+
 	return true, nil
 }
 
@@ -138,6 +145,114 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *core.Table, bea
 			vvv.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(Value(v)))
 		}
 
+		return true, nil
+	case *string:
+		var res sql.NullString
+		if err := rows.Scan(&res); err != nil {
+			return true, err
+		}
+		if res.Valid {
+			*(bean.(*string)) = res.String
+		}
+		return true, nil
+	case *int:
+		var res sql.NullInt64
+		if err := rows.Scan(&res); err != nil {
+			return true, err
+		}
+		if res.Valid {
+			*(bean.(*int)) = int(res.Int64)
+		}
+		return true, nil
+	case *int8:
+		var res sql.NullInt64
+		if err := rows.Scan(&res); err != nil {
+			return true, err
+		}
+		if res.Valid {
+			*(bean.(*int8)) = int8(res.Int64)
+		}
+		return true, nil
+	case *int16:
+		var res sql.NullInt64
+		if err := rows.Scan(&res); err != nil {
+			return true, err
+		}
+		if res.Valid {
+			*(bean.(*int16)) = int16(res.Int64)
+		}
+		return true, nil
+	case *int32:
+		var res sql.NullInt64
+		if err := rows.Scan(&res); err != nil {
+			return true, err
+		}
+		if res.Valid {
+			*(bean.(*int32)) = int32(res.Int64)
+		}
+		return true, nil
+	case *int64:
+		var res sql.NullInt64
+		if err := rows.Scan(&res); err != nil {
+			return true, err
+		}
+		if res.Valid {
+			*(bean.(*int64)) = int64(res.Int64)
+		}
+		return true, nil
+	case *uint:
+		var res sql.NullInt64
+		if err := rows.Scan(&res); err != nil {
+			return true, err
+		}
+		if res.Valid {
+			*(bean.(*uint)) = uint(res.Int64)
+		}
+		return true, nil
+	case *uint8:
+		var res sql.NullInt64
+		if err := rows.Scan(&res); err != nil {
+			return true, err
+		}
+		if res.Valid {
+			*(bean.(*uint8)) = uint8(res.Int64)
+		}
+		return true, nil
+	case *uint16:
+		var res sql.NullInt64
+		if err := rows.Scan(&res); err != nil {
+			return true, err
+		}
+		if res.Valid {
+			*(bean.(*uint16)) = uint16(res.Int64)
+		}
+		return true, nil
+	case *uint32:
+		var res sql.NullInt64
+		if err := rows.Scan(&res); err != nil {
+			return true, err
+		}
+		if res.Valid {
+			*(bean.(*uint32)) = uint32(res.Int64)
+		}
+		return true, nil
+	case *uint64:
+		var res sql.NullInt64
+		if err := rows.Scan(&res); err != nil {
+			return true, err
+		}
+		if res.Valid {
+			*(bean.(*uint64)) = uint64(res.Int64)
+		}
+		return true, nil
+	case *bool:
+		var res sql.NullBool
+		if err := rows.Scan(&res); err != nil {
+			return true, err
+		}
+		if res.Valid {
+			*(bean.(*bool)) = res.Bool
+		}
 		return true, nil
 	}
 
@@ -167,6 +282,9 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *core.Table, bea
 		err = rows.ScanSlice(bean)
 	case reflect.Map:
 		err = rows.ScanMap(bean)
+	case reflect.String, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
+		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+		err = rows.Scan(&bean)
 	default:
 		err = rows.Scan(bean)
 	}

+ 151 - 0
session_get_test.go

@@ -47,6 +47,12 @@ func TestGetVar(t *testing.T) {
 	assert.Equal(t, true, has)
 	assert.Equal(t, 28, age)
 
+	var ageMax int
+	has, err = testEngine.SQL("SELECT max(age) FROM "+testEngine.TableName("get_var", true)+" WHERE `id` = ?", data.Id).Get(&ageMax)
+	assert.NoError(t, err)
+	assert.Equal(t, true, has)
+	assert.Equal(t, 28, ageMax)
+
 	var age2 int64
 	has, err = testEngine.Table("get_var").Cols("age").
 		Where("age > ?", 20).
@@ -56,6 +62,69 @@ func TestGetVar(t *testing.T) {
 	assert.Equal(t, true, has)
 	assert.EqualValues(t, 28, age2)
 
+	var age3 int8
+	has, err = testEngine.Table("get_var").Cols("age").Get(&age3)
+	assert.NoError(t, err)
+	assert.Equal(t, true, has)
+	assert.EqualValues(t, 28, age3)
+
+	var age4 int16
+	has, err = testEngine.Table("get_var").Cols("age").
+		Where("age > ?", 20).
+		And("age < ?", 30).
+		Get(&age4)
+	assert.NoError(t, err)
+	assert.Equal(t, true, has)
+	assert.EqualValues(t, 28, age4)
+
+	var age5 int32
+	has, err = testEngine.Table("get_var").Cols("age").
+		Where("age > ?", 20).
+		And("age < ?", 30).
+		Get(&age5)
+	assert.NoError(t, err)
+	assert.Equal(t, true, has)
+	assert.EqualValues(t, 28, age5)
+
+	var age6 int
+	has, err = testEngine.Table("get_var").Cols("age").Get(&age6)
+	assert.NoError(t, err)
+	assert.Equal(t, true, has)
+	assert.EqualValues(t, 28, age6)
+
+	var age7 int64
+	has, err = testEngine.Table("get_var").Cols("age").
+		Where("age > ?", 20).
+		And("age < ?", 30).
+		Get(&age7)
+	assert.NoError(t, err)
+	assert.Equal(t, true, has)
+	assert.EqualValues(t, 28, age7)
+
+	var age8 int8
+	has, err = testEngine.Table("get_var").Cols("age").Get(&age8)
+	assert.NoError(t, err)
+	assert.Equal(t, true, has)
+	assert.EqualValues(t, 28, age8)
+
+	var age9 int16
+	has, err = testEngine.Table("get_var").Cols("age").
+		Where("age > ?", 20).
+		And("age < ?", 30).
+		Get(&age9)
+	assert.NoError(t, err)
+	assert.Equal(t, true, has)
+	assert.EqualValues(t, 28, age9)
+
+	var age10 int32
+	has, err = testEngine.Table("get_var").Cols("age").
+		Where("age > ?", 20).
+		And("age < ?", 30).
+		Get(&age10)
+	assert.NoError(t, err)
+	assert.Equal(t, true, has)
+	assert.EqualValues(t, 28, age10)
+
 	var id sql.NullInt64
 	has, err = testEngine.Table("get_var").Cols("id").Get(&id)
 	assert.NoError(t, err)
@@ -433,3 +502,85 @@ func TestGetCustomTableInterface(t *testing.T) {
 	assert.NoError(t, err)
 	assert.True(t, has)
 }
+
+func TestGetNullVar(t *testing.T) {
+	type TestGetNullVarStruct struct {
+		Id   int64
+		Name string
+		Age  int
+	}
+
+	assert.NoError(t, prepareEngine())
+	assertSync(t, new(TestGetNullVarStruct))
+
+	affected, err := testEngine.Exec("insert into " + testEngine.TableName(new(TestGetNullVarStruct), true) + " (name,age) values (null,null)")
+	assert.NoError(t, err)
+	a, _ := affected.RowsAffected()
+	assert.EqualValues(t, 1, a)
+
+	var name string
+	has, err := testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("name").Get(&name)
+	assert.NoError(t, err)
+	assert.True(t, has)
+	assert.EqualValues(t, "", name)
+
+	var age int
+	has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age)
+	assert.NoError(t, err)
+	assert.True(t, has)
+	assert.EqualValues(t, 0, age)
+
+	var age2 int8
+	has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age2)
+	assert.NoError(t, err)
+	assert.True(t, has)
+	assert.EqualValues(t, 0, age2)
+
+	var age3 int16
+	has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age3)
+	assert.NoError(t, err)
+	assert.True(t, has)
+	assert.EqualValues(t, 0, age3)
+
+	var age4 int32
+	has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age4)
+	assert.NoError(t, err)
+	assert.True(t, has)
+	assert.EqualValues(t, 0, age4)
+
+	var age5 int64
+	has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age5)
+	assert.NoError(t, err)
+	assert.True(t, has)
+	assert.EqualValues(t, 0, age5)
+
+	var age6 uint
+	has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age6)
+	assert.NoError(t, err)
+	assert.True(t, has)
+	assert.EqualValues(t, 0, age6)
+
+	var age7 uint8
+	has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age7)
+	assert.NoError(t, err)
+	assert.True(t, has)
+	assert.EqualValues(t, 0, age7)
+
+	var age8 int16
+	has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age8)
+	assert.NoError(t, err)
+	assert.True(t, has)
+	assert.EqualValues(t, 0, age8)
+
+	var age9 int32
+	has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age9)
+	assert.NoError(t, err)
+	assert.True(t, has)
+	assert.EqualValues(t, 0, age9)
+
+	var age10 int64
+	has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age10)
+	assert.NoError(t, err)
+	assert.True(t, has)
+	assert.EqualValues(t, 0, age10)
+}

+ 1 - 1
xorm.go

@@ -20,7 +20,7 @@ import (
 
 const (
 	// Version show the xorm's version
-	Version string = "0.7.4.0724"
+	Version string = "0.7.5.0803"
 )
 
 func regDrvsNDialects() bool {