浏览代码

fix tablename bug

* fix tablename bug

* fix test
xormplus 7 年之前
父节点
当前提交
512ecd4a8a
共有 6 个文件被更改,包括 89 次插入12 次删除
  1. 11 7
      engine_table.go
  2. 2 2
      session_find.go
  3. 73 0
      session_find_test.go
  4. 1 1
      session_insert.go
  5. 1 1
      statement.go
  6. 1 1
      xorm.go

+ 11 - 7
engine_table.go

@@ -45,16 +45,17 @@ func (session *Session) tbNameNoSchema(table *core.Table) string {
 }
 
 func (engine *Engine) tbNameForMap(v reflect.Value) string {
-	t := v.Type()
-	if tb, ok := v.Interface().(TableName); ok {
-		return tb.TableName()
+	if v.Type().Implements(tpTableName) {
+		return v.Interface().(TableName).TableName()
 	}
-	if v.CanAddr() {
-		if tb, ok := v.Addr().Interface().(TableName); ok {
-			return tb.TableName()
+	if v.Kind() == reflect.Ptr {
+		v = v.Elem()
+		if v.Type().Implements(tpTableName) {
+			return v.Interface().(TableName).TableName()
 		}
 	}
-	return engine.TableMapper.Obj2Table(t.Name())
+
+	return engine.TableMapper.Obj2Table(v.Type().Name())
 }
 
 func (engine *Engine) tbNameNoSchema(tablename interface{}) string {
@@ -97,6 +98,9 @@ func (engine *Engine) tbNameNoSchema(tablename interface{}) string {
 		return tablename.(TableName).TableName()
 	case string:
 		return tablename.(string)
+	case reflect.Value:
+		v := tablename.(reflect.Value)
+		return engine.tbNameForMap(v)
 	default:
 		v := rValue(tablename)
 		t := v.Type()

+ 2 - 2
session_find.go

@@ -96,7 +96,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
 		if sliceElementType.Kind() == reflect.Ptr {
 			if sliceElementType.Elem().Kind() == reflect.Struct {
 				pv := reflect.New(sliceElementType.Elem())
-				if err := session.statement.setRefValue(pv.Elem()); err != nil {
+				if err := session.statement.setRefValue(pv); err != nil {
 					return err
 				}
 			} else {
@@ -104,7 +104,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
 			}
 		} else if sliceElementType.Kind() == reflect.Struct {
 			pv := reflect.New(sliceElementType)
-			if err := session.statement.setRefValue(pv.Elem()); err != nil {
+			if err := session.statement.setRefValue(pv); err != nil {
 				return err
 			}
 		} else {

+ 73 - 0
session_find_test.go

@@ -584,3 +584,76 @@ func TestFindAndCountOneFunc(t *testing.T) {
 	assert.EqualValues(t, 1, len(results))
 	assert.EqualValues(t, 1, cnt)
 }
+
+type FindMapDevice struct {
+	Deviceid string `xorm:"pk"`
+	Status   int
+}
+
+func (device *FindMapDevice) TableName() string {
+	return "devices"
+}
+
+func TestFindMapStringId(t *testing.T) {
+	assert.NoError(t, prepareEngine())
+	assertSync(t, new(FindMapDevice))
+
+	cnt, err := testEngine.Insert(&FindMapDevice{
+		Deviceid: "1",
+		Status:   1,
+	})
+	assert.NoError(t, err)
+	assert.EqualValues(t, 1, cnt)
+
+	deviceIDs := []string{"1"}
+
+	deviceMaps := make(map[string]*FindMapDevice, len(deviceIDs))
+	err = testEngine.
+		Where("status = ?", 1).
+		In("deviceid", deviceIDs).
+		Find(&deviceMaps)
+	assert.NoError(t, err)
+
+	deviceMaps2 := make(map[string]FindMapDevice, len(deviceIDs))
+	err = testEngine.
+		Where("status = ?", 1).
+		In("deviceid", deviceIDs).
+		Find(&deviceMaps2)
+	assert.NoError(t, err)
+
+	devices := make([]*FindMapDevice, 0, len(deviceIDs))
+	err = testEngine.Find(&devices)
+	assert.NoError(t, err)
+
+	devices2 := make([]FindMapDevice, 0, len(deviceIDs))
+	err = testEngine.Find(&devices2)
+	assert.NoError(t, err)
+
+	var device FindMapDevice
+	has, err := testEngine.Get(&device)
+	assert.NoError(t, err)
+	assert.True(t, has)
+
+	has, err = testEngine.Exist(&FindMapDevice{})
+	assert.NoError(t, err)
+	assert.True(t, has)
+
+	cnt, err = testEngine.Count(new(FindMapDevice))
+	assert.NoError(t, err)
+	assert.EqualValues(t, 1, cnt)
+
+	cnt, err = testEngine.ID("1").Update(&FindMapDevice{
+		Status: 2,
+	})
+	assert.NoError(t, err)
+	assert.EqualValues(t, 1, cnt)
+
+	sum, err := testEngine.SumInt(new(FindMapDevice), "status")
+	assert.NoError(t, err)
+	assert.EqualValues(t, 2, sum)
+
+	cnt, err = testEngine.ID("1").Delete(new(FindMapDevice))
+	assert.NoError(t, err)
+	assert.EqualValues(t, 1, cnt)
+
+}

+ 1 - 1
session_insert.go

@@ -66,7 +66,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
 		return 0, errors.New("could not insert a empty slice")
 	}
 
-	if err := session.statement.setRefValue(reflect.ValueOf(sliceValue.Index(0).Interface())); err != nil {
+	if err := session.statement.setRefBean(sliceValue.Index(0).Interface()); err != nil {
 		return 0, err
 	}
 

+ 1 - 1
statement.go

@@ -208,7 +208,7 @@ func (statement *Statement) setRefValue(v reflect.Value) error {
 	if err != nil {
 		return err
 	}
-	statement.tableName = statement.Engine.TableName(v.Interface(), true)
+	statement.tableName = statement.Engine.TableName(v, true)
 	return nil
 }
 

+ 1 - 1
xorm.go

@@ -17,7 +17,7 @@ import (
 
 const (
 	// Version show the xorm's version
-	Version string = "0.6.4.1122"
+	Version string = "0.6.5.0411"
 )
 
 func regDrvsNDialects() bool {