Gary Burd 12 gadi atpakaļ
vecāks
revīzija
575286126e
2 mainītis faili ar 222 papildinājumiem un 17 dzēšanām
  1. 98 5
      redis/scan.go
  2. 124 12
      redis/scan_test.go

+ 98 - 5
redis/scan.go

@@ -23,6 +23,14 @@ import (
 	"sync"
 )
 
+func ensureLen(d reflect.Value, n int) {
+	if n > d.Cap() {
+		d.Set(reflect.MakeSlice(d.Type(), n, n))
+	} else {
+		d.SetLen(n)
+	}
+}
+
 func cannotConvert(d reflect.Value, s interface{}) error {
 	return fmt.Errorf("redigo: Scan cannot convert from %s to %s",
 		reflect.TypeOf(s), d.Type())
@@ -91,11 +99,7 @@ func convertAssignValues(d reflect.Value, s []interface{}) (err error) {
 	if d.Type().Kind() != reflect.Slice {
 		return cannotConvert(d, s)
 	}
-	if len(s) > d.Cap() {
-		d.Set(reflect.MakeSlice(d.Type(), len(s), len(s)))
-	} else {
-		d.SetLen(len(s))
-	}
+	ensureLen(d, len(s))
 	for i := 0; i < len(s); i++ {
 		switch s := s[i].(type) {
 		case []byte:
@@ -375,6 +379,95 @@ func ScanStruct(src []interface{}, dest interface{}) error {
 	return nil
 }
 
+var (
+	scanSliceValueError = errors.New("redigo: ScanSlice dest must be non-nil pointer to a struct.")
+	scanSliceSrcError   = errors.New("redigo: ScanSlice src element must be bulk or nil.")
+)
+
+// ScanSlice scans multi-bulk src to the slice pointed to by dest. The elements
+// the dest slice must be integer, float, boolean, string, struct or pointer to
+// struct values.
+//
+// Struct fields must be integer, float, boolean or string values. All struct
+// fields are used unless a subset is specified using fieldNames.
+func ScanSlice(src []interface{}, dest interface{}, fieldNames ...string) error {
+	d := reflect.ValueOf(dest)
+	if d.Kind() != reflect.Ptr || d.IsNil() {
+		return scanSliceValueError
+	}
+	d = d.Elem()
+	if d.Kind() != reflect.Slice {
+		return scanSliceValueError
+	}
+
+	isPtr := false
+	t := d.Type().Elem()
+	if t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct {
+		isPtr = true
+		t = t.Elem()
+	}
+
+	if t.Kind() != reflect.Struct {
+		ensureLen(d, len(src))
+		for i, s := range src {
+			if s == nil {
+				continue
+			}
+			s, ok := s.([]byte)
+			if !ok {
+				return scanSliceSrcError
+			}
+			if err := convertAssignBytes(d.Index(i), s); err != nil {
+				return err
+			}
+		}
+		return nil
+	}
+
+	ss := structSpecForType(t)
+	fss := ss.l
+	if len(fieldNames) > 0 {
+		fss = make([]*fieldSpec, len(fieldNames))
+		for i, name := range fieldNames {
+			fss[i] = ss.m[name]
+			if fss[i] == nil {
+				return errors.New("redigo: bad field name " + name)
+			}
+		}
+	}
+
+	n := len(src) / len(fss)
+	if n*len(fss) != len(src) {
+		return errors.New("redigo: length of ScanSlice not a multiple of struct field count.")
+	}
+
+	ensureLen(d, n)
+	for i := 0; i < n; i++ {
+		d := d.Index(i)
+		if isPtr {
+			if d.IsNil() {
+				d.Set(reflect.New(t))
+			}
+			d = d.Elem()
+		}
+		for j, fs := range fss {
+			s := src[i*len(fss)+j]
+			if s == nil {
+				continue
+			}
+			sb, ok := s.([]byte)
+			if !ok {
+				return scanSliceSrcError
+			}
+			d := d.FieldByIndex(fs.index)
+			if err := convertAssignBytes(d, sb); err != nil {
+				return err
+			}
+		}
+	}
+	return nil
+}
+
 // Args is a helper for constructing command arguments from structured values.
 type Args []interface{}
 

+ 124 - 12
redis/scan_test.go

@@ -56,18 +56,6 @@ var scanConversionTests = []struct {
 	{[]interface{}{[]byte("1")}, []bool{true}},
 }
 
-var scanConversionErrorTests = []struct {
-	src  interface{}
-	dest interface{}
-}{
-	{[]byte("1234"), byte(0)},
-	{int64(1234), byte(0)},
-	{[]byte("-1"), byte(0)},
-	{int64(-1), byte(0)},
-	{[]byte("junk"), false},
-	{redis.Error("blah"), false},
-}
-
 func TestScanConversion(t *testing.T) {
 	for _, tt := range scanConversionTests {
 		values := []interface{}{tt.src}
@@ -83,6 +71,18 @@ func TestScanConversion(t *testing.T) {
 	}
 }
 
+var scanConversionErrorTests = []struct {
+	src  interface{}
+	dest interface{}
+}{
+	{[]byte("1234"), byte(0)},
+	{int64(1234), byte(0)},
+	{[]byte("-1"), byte(0)},
+	{int64(-1), byte(0)},
+	{[]byte("junk"), false},
+	{redis.Error("blah"), false},
+}
+
 func TestScanConversionError(t *testing.T) {
 	for _, tt := range scanConversionErrorTests {
 		values := []interface{}{tt.src}
@@ -204,6 +204,118 @@ func TestBadScanStructArgs(t *testing.T) {
 	test(&v2)
 }
 
+var scanSliceTests = []struct {
+	src        []interface{}
+	fieldNames []string
+	ok         bool
+	dest       interface{}
+}{
+	{
+		[]interface{}{[]byte("1"), nil, []byte("-1")},
+		nil,
+		true,
+		[]int{1, 0, -1},
+	},
+	{
+		[]interface{}{[]byte("1"), nil, []byte("2")},
+		nil,
+		true,
+		[]uint{1, 0, 2},
+	},
+	{
+		[]interface{}{[]byte("-1")},
+		nil,
+		false,
+		[]uint{1},
+	},
+	{
+		[]interface{}{[]byte("hello"), nil, []byte("world")},
+		nil,
+		true,
+		[][]byte{[]byte("hello"), nil, []byte("world")},
+	},
+	{
+		[]interface{}{[]byte("hello"), nil, []byte("world")},
+		nil,
+		true,
+		[]string{"hello", "", "world"},
+	},
+	{
+		[]interface{}{[]byte("a1"), []byte("b1"), []byte("a2"), []byte("b2")},
+		nil,
+		true,
+		[]struct{ A, B string }{{"a1", "b1"}, {"a2", "b2"}},
+	},
+	{
+		[]interface{}{[]byte("a1"), []byte("b1")},
+		nil,
+		false,
+		[]struct{ A, B, C string }{{"a1", "b1", ""}},
+	},
+	{
+		[]interface{}{[]byte("a1"), []byte("b1"), []byte("a2"), []byte("b2")},
+		nil,
+		true,
+		[]*struct{ A, B string }{{"a1", "b1"}, {"a2", "b2"}},
+	},
+	{
+		[]interface{}{[]byte("a1"), []byte("b1"), []byte("a2"), []byte("b2")},
+		[]string{"A", "B"},
+		true,
+		[]struct{ A, C, B string }{{"a1", "", "b1"}, {"a2", "", "b2"}},
+	},
+}
+
+func TestScanSlice(t *testing.T) {
+	for _, tt := range scanSliceTests {
+
+		typ := reflect.ValueOf(tt.dest).Type()
+		dest := reflect.New(typ)
+
+		err := redis.ScanSlice(tt.src, dest.Interface(), tt.fieldNames...)
+		if tt.ok != (err == nil) {
+			t.Errorf("ScanSlice(%v, []%s, %v) returned error %v", tt.src, typ, tt.fieldNames, err)
+			continue
+		}
+		if tt.ok && !reflect.DeepEqual(dest.Elem().Interface(), tt.dest) {
+			t.Errorf("ScanSlice(src, []%s) returned %#v, want %#v", typ, dest.Elem().Interface(), tt.dest)
+		}
+	}
+}
+
+func ExampleScanSlice() {
+	c, err := dial()
+	if err != nil {
+		panic(err)
+	}
+	defer c.Close()
+
+	c.Send("HMSET", "album:1", "title", "Red", "rating", 5)
+	c.Send("HMSET", "album:2", "title", "Earthbound", "rating", 1)
+	c.Send("HMSET", "album:3", "title", "Beat", "rating", 4)
+	c.Send("LPUSH", "albums", "1")
+	c.Send("LPUSH", "albums", "2")
+	c.Send("LPUSH", "albums", "3")
+	values, err := redis.Values(c.Do("SORT", "albums",
+		"BY", "album:*->rating",
+		"GET", "album:*->title",
+		"GET", "album:*->rating"))
+	if err != nil {
+		panic(err)
+	}
+
+	var albums []struct {
+		Title  string
+		Rating int
+	}
+	if err := redis.ScanSlice(values, &albums); err != nil {
+		panic(err)
+	}
+	fmt.Printf("%v\n", albums)
+	// Output:
+	// [{Earthbound 1} {Beat 4} {Red 5}]
+}
+
 var argsTests = []struct {
 	title    string
 	actual   redis.Args