Browse Source

more Find supports

xormplus 8 years ago
parent
commit
35acd228da
2 changed files with 54 additions and 41 deletions
  1. 2 2
      session.go
  2. 52 39
      session_find.go

+ 2 - 2
session.go

@@ -445,10 +445,10 @@ func (session *Session) getField(dataStruct *reflect.Value, key string, table *c
 type Cell *interface{}
 
 func (session *Session) rows2Beans(rows *core.Rows, fields []string, fieldsCount int,
-	table *core.Table, newElemFunc func() reflect.Value,
+	table *core.Table, newElemFunc func([]string) reflect.Value,
 	sliceValueSetFunc func(*reflect.Value, core.PK) error) error {
 	for rows.Next() {
-		var newValue = newElemFunc()
+		var newValue = newElemFunc(fields)
 		bean := newValue.Interface()
 		dataStruct := rValue(bean)
 		pk, err := session._row2Bean(rows, fields, fieldsCount, bean, &dataStruct, table)

+ 52 - 39
session_find.go

@@ -193,31 +193,43 @@ func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Va
 		return err
 	}
 
-	var newElemFunc func() reflect.Value
+	var newElemFunc func(fields []string) reflect.Value
 	elemType := containerValue.Type().Elem()
+	var isPointer bool
 	if elemType.Kind() == reflect.Ptr {
-		newElemFunc = func() reflect.Value {
-			return reflect.New(elemType.Elem())
-		}
-	} else {
-		newElemFunc = func() reflect.Value {
-			return reflect.New(elemType)
+		isPointer = true
+		elemType = elemType.Elem()
+	}
+	if elemType.Kind() == reflect.Ptr {
+		return errors.New("pointer to pointer is not supported")
+	}
+
+	newElemFunc = func(fields []string) reflect.Value {
+		switch elemType.Kind() {
+		case reflect.Slice:
+			slice := reflect.MakeSlice(elemType, len(fields), len(fields))
+			x := reflect.New(slice.Type())
+			x.Elem().Set(slice)
+			return x
+		case reflect.Map:
+			mp := reflect.MakeMap(elemType)
+			x := reflect.New(mp.Type())
+			x.Elem().Set(mp)
+			return x
 		}
+		return reflect.New(elemType)
 	}
 
 	var containerValueSetFunc func(*reflect.Value, core.PK) error
 
 	if containerValue.Kind() == reflect.Slice {
-		if elemType.Kind() == reflect.Ptr {
-			containerValueSetFunc = func(newValue *reflect.Value, pk core.PK) error {
-				containerValue.Set(reflect.Append(containerValue, reflect.ValueOf(newValue.Interface())))
-				return nil
-			}
-		} else {
-			containerValueSetFunc = func(newValue *reflect.Value, pk core.PK) error {
-				containerValue.Set(reflect.Append(containerValue, reflect.Indirect(reflect.ValueOf(newValue.Interface()))))
-				return nil
+		containerValueSetFunc = func(newValue *reflect.Value, pk core.PK) error {
+			if isPointer {
+				containerValue.Set(reflect.Append(containerValue, newValue.Elem().Addr()))
+			} else {
+				containerValue.Set(reflect.Append(containerValue, newValue.Elem()))
 			}
+			return nil
 		}
 	} else {
 		keyType := containerValue.Type().Key()
@@ -228,40 +240,41 @@ func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Va
 			return errors.New("don't support multiple primary key's map has non-slice key type")
 		}
 
-		if elemType.Kind() == reflect.Ptr {
-			containerValueSetFunc = func(newValue *reflect.Value, pk core.PK) error {
-				keyValue := reflect.New(keyType)
-				err := convertPKToValue(table, keyValue.Interface(), pk)
-				if err != nil {
-					return err
-				}
-				containerValue.SetMapIndex(keyValue.Elem(), reflect.ValueOf(newValue.Interface()))
-				return nil
+		containerValueSetFunc = func(newValue *reflect.Value, pk core.PK) error {
+			keyValue := reflect.New(keyType)
+			err := convertPKToValue(table, keyValue.Interface(), pk)
+			if err != nil {
+				return err
 			}
-		} else {
-			containerValueSetFunc = func(newValue *reflect.Value, pk core.PK) error {
-				keyValue := reflect.New(keyType)
-				err := convertPKToValue(table, keyValue.Interface(), pk)
-				if err != nil {
-					return err
-				}
-				containerValue.SetMapIndex(keyValue.Elem(), reflect.Indirect(reflect.ValueOf(newValue.Interface())))
-				return nil
+			if isPointer {
+				containerValue.SetMapIndex(keyValue.Elem(), newValue.Elem().Addr())
+			} else {
+				containerValue.SetMapIndex(keyValue.Elem(), newValue.Elem())
 			}
+			return nil
 		}
 	}
 
-	var newValue = newElemFunc()
-	dataStruct := rValue(newValue.Interface())
-	if dataStruct.Kind() == reflect.Struct {
+	if elemType.Kind() == reflect.Struct {
+		var newValue = newElemFunc(fields)
+		dataStruct := rValue(newValue.Interface())
 		return session.rows2Beans(rawRows, fields, len(fields), session.Engine.autoMapType(dataStruct), newElemFunc, containerValueSetFunc)
 	}
 
 	for rawRows.Next() {
-		var newValue = newElemFunc()
+		var newValue = newElemFunc(fields)
 		bean := newValue.Interface()
 
-		if err := rawRows.Scan(bean); err != nil {
+		switch elemType.Kind() {
+		case reflect.Slice:
+			err = rawRows.ScanSlice(bean)
+		case reflect.Map:
+			err = rawRows.ScanMap(bean)
+		default:
+			err = rawRows.Scan(bean)
+		}
+
+		if err != nil {
 			return err
 		}