Преглед изворни кода

marshal: support nested tuples (#937)

Add support for marshalling/unmarshalling nested tuples via slice, array
and structs.

Also improve tuple handling so they can be used with slices, arrays and
structs for values.
Chris Bannister пре 8 година
родитељ
комит
bb83efe9b6
2 измењених фајлова са 152 додато и 5 уклоњено
  1. 101 3
      marshal.go
  2. 51 2
      tuple_test.go

+ 101 - 3
marshal.go

@@ -1595,12 +1595,11 @@ func marshalTuple(info TypeInfo, value interface{}) ([]byte, error) {
 	case unsetColumn:
 	case unsetColumn:
 		return nil, unmarshalErrorf("Invalid request: UnsetValue is unsupported for tuples")
 		return nil, unmarshalErrorf("Invalid request: UnsetValue is unsupported for tuples")
 	case []interface{}:
 	case []interface{}:
-		var buf []byte
-
 		if len(v) != len(tuple.Elems) {
 		if len(v) != len(tuple.Elems) {
 			return nil, unmarshalErrorf("cannont marshal tuple: wrong number of elements")
 			return nil, unmarshalErrorf("cannont marshal tuple: wrong number of elements")
 		}
 		}
 
 
+		var buf []byte
 		for i, elem := range v {
 		for i, elem := range v {
 			data, err := Marshal(tuple.Elems[i], elem)
 			data, err := Marshal(tuple.Elems[i], elem)
 			if err != nil {
 			if err != nil {
@@ -1615,7 +1614,51 @@ func marshalTuple(info TypeInfo, value interface{}) ([]byte, error) {
 		return buf, nil
 		return buf, nil
 	}
 	}
 
 
-	return nil, unmarshalErrorf("cannot marshal %T into %s", value, tuple)
+	rv := reflect.ValueOf(value)
+	t := rv.Type()
+	k := t.Kind()
+
+	switch k {
+	case reflect.Struct:
+		if v := t.NumField(); v != len(tuple.Elems) {
+			return nil, marshalErrorf("can not marshal tuple into struct %v, not enough fields have %d need %d", t, v, len(tuple.Elems))
+		}
+
+		var buf []byte
+		for i, elem := range tuple.Elems {
+			data, err := Marshal(elem, rv.Field(i).Interface())
+			if err != nil {
+				return nil, err
+			}
+
+			n := len(data)
+			buf = appendInt(buf, int32(n))
+			buf = append(buf, data...)
+		}
+
+		return buf, nil
+	case reflect.Slice, reflect.Array:
+		size := rv.Len()
+		if size != len(tuple.Elems) {
+			return nil, marshalErrorf("can not marshal tuple into %v of length %d need %d elements", k, size, len(tuple.Elems))
+		}
+
+		var buf []byte
+		for i, elem := range tuple.Elems {
+			data, err := Marshal(elem, rv.Index(i).Interface())
+			if err != nil {
+				return nil, err
+			}
+
+			n := len(data)
+			buf = appendInt(buf, int32(n))
+			buf = append(buf, data...)
+		}
+
+		return buf, nil
+	}
+
+	return nil, marshalErrorf("cannot marshal %T into %s", value, tuple)
 }
 }
 
 
 // currently only support unmarshal into a list of values, this makes it possible
 // currently only support unmarshal into a list of values, this makes it possible
@@ -1644,6 +1687,61 @@ func unmarshalTuple(info TypeInfo, data []byte, value interface{}) error {
 		return nil
 		return nil
 	}
 	}
 
 
+	rv := reflect.ValueOf(value)
+	if rv.Kind() != reflect.Ptr {
+		return unmarshalErrorf("can not unmarshal into non-pointer %T", value)
+	}
+
+	rv = rv.Elem()
+	t := rv.Type()
+	k := t.Kind()
+
+	switch k {
+	case reflect.Struct:
+		if v := t.NumField(); v != len(tuple.Elems) {
+			return unmarshalErrorf("can not unmarshal tuple into struct %v, not enough fields have %d need %d", t, v, len(tuple.Elems))
+		}
+
+		for i, elem := range tuple.Elems {
+			m := readInt(data)
+			data = data[4:]
+
+			v := elem.New()
+			if err := Unmarshal(elem, data[:m], v); err != nil {
+				return err
+			}
+			rv.Field(i).Set(reflect.ValueOf(v).Elem())
+
+			data = data[m:]
+		}
+
+		return nil
+	case reflect.Slice, reflect.Array:
+		if k == reflect.Array {
+			size := rv.Len()
+			if size != len(tuple.Elems) {
+				return unmarshalErrorf("can not unmarshal tuple into array of length %d need %d elements", size, len(tuple.Elems))
+			}
+		} else {
+			rv.Set(reflect.MakeSlice(t, len(tuple.Elems), len(tuple.Elems)))
+		}
+
+		for i, elem := range tuple.Elems {
+			m := readInt(data)
+			data = data[4:]
+
+			v := elem.New()
+			if err := Unmarshal(elem, data[:m], v); err != nil {
+				return err
+			}
+			rv.Index(i).Set(reflect.ValueOf(v).Elem())
+
+			data = data[m:]
+		}
+
+		return nil
+	}
+
 	return unmarshalErrorf("cannot unmarshal %s into %T", info, value)
 	return unmarshalErrorf("cannot unmarshal %s into %T", info, value)
 }
 }
 
 

+ 51 - 2
tuple_test.go

@@ -2,7 +2,10 @@
 
 
 package gocql
 package gocql
 
 
-import "testing"
+import (
+	"reflect"
+	"testing"
+)
 
 
 func TestTupleSimple(t *testing.T) {
 func TestTupleSimple(t *testing.T) {
 	session := createSession(t)
 	session := createSession(t)
@@ -55,7 +58,6 @@ func TestTupleMapScan(t *testing.T) {
 	if session.cfg.ProtoVersion < protoVersion3 {
 	if session.cfg.ProtoVersion < protoVersion3 {
 		t.Skip("tuple types are only available of proto>=3")
 		t.Skip("tuple types are only available of proto>=3")
 	}
 	}
-	defer session.Close()
 
 
 	err := createTable(session, `CREATE TABLE gocql_test.tuple_map_scan(
 	err := createTable(session, `CREATE TABLE gocql_test.tuple_map_scan(
 		id int,
 		id int,
@@ -76,3 +78,50 @@ func TestTupleMapScan(t *testing.T) {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
 }
 }
+
+func TestTuple_NestedCollection(t *testing.T) {
+	session := createSession(t)
+	defer session.Close()
+	if session.cfg.ProtoVersion < protoVersion3 {
+		t.Skip("tuple types are only available of proto>=3")
+	}
+
+	err := createTable(session, `CREATE TABLE gocql_test.nested_tuples(
+		id int,
+		val list<frozen<tuple<int, text>>>,
+
+		primary key(id))`)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	type typ struct {
+		A int
+		B string
+	}
+
+	tests := []struct {
+		name string
+		val  interface{}
+	}{
+		{name: "slice", val: [][]interface{}{{1, "2"}, {3, "4"}}},
+		{name: "array", val: [][2]interface{}{{1, "2"}, {3, "4"}}},
+		{name: "struct", val: []typ{{1, "2"}, {3, "4"}}},
+	}
+
+	for i, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			if err := session.Query(`INSERT INTO nested_tuples (id, val) VALUES (?, ?);`, i, test.val).Exec(); err != nil {
+				t.Fatal(err)
+			}
+
+			rv := reflect.ValueOf(test.val)
+			res := reflect.New(rv.Type()).Elem().Addr().Interface()
+
+			err = session.Query(`SELECT val FROM nested_tuples WHERE id=?`, i).Scan(res)
+			if err != nil {
+				t.Fatal(err)
+			}
+		})
+	}
+}