Explorar o código

marshall: correctly handle null tuple elements (#985)

UDT and Tuple values are [byte] which can return nil. UDT handled this
and lots of special handling for the slicing the data. Create readBytes
which correctly reads the bytes and handles nil. Use this in tuple and
UDT. Add testcase for tuples.
Chris Bannister %!s(int64=8) %!d(string=hai) anos
pai
achega
ce5020aaba
Modificáronse 2 ficheiros con 76 adicións e 48 borrados
  1. 35 44
      marshal.go
  2. 41 4
      tuple_test.go

+ 35 - 44
marshal.go

@@ -1667,6 +1667,16 @@ func marshalTuple(info TypeInfo, value interface{}) ([]byte, error) {
 	return nil, marshalErrorf("cannot marshal %T into %s", value, tuple)
 }
 
+func readBytes(p []byte) ([]byte, []byte) {
+	// TODO: really should use a framer
+	size := readInt(p)
+	p = p[4:]
+	if size < 0 {
+		return nil, p
+	}
+	return p[:size], p[size:]
+}
+
 // currently only support unmarshal into a list of values, this makes it possible
 // to support tuples without changing the query API. In the future this can be extend
 // to allow unmarshalling into custom tuple types.
@@ -1680,14 +1690,13 @@ func unmarshalTuple(info TypeInfo, data []byte, value interface{}) error {
 	case []interface{}:
 		for i, elem := range tuple.Elems {
 			// each element inside data is a [bytes]
-			size := readInt(data)
-			data = data[4:]
+			var p []byte
+			p, data = readBytes(data)
 
-			err := Unmarshal(elem, data[:size], v[i])
+			err := Unmarshal(elem, p, v[i])
 			if err != nil {
 				return err
 			}
-			data = data[size:]
 		}
 
 		return nil
@@ -1864,18 +1873,11 @@ func unmarshalUDT(info TypeInfo, data []byte, value interface{}) error {
 			if len(data) == 0 {
 				return nil
 			}
-			size := readInt(data[:4])
-			data = data[4:]
 
-			var err error
-			if size < 0 {
-				err = v.UnmarshalUDT(e.Name, e.Type, nil)
-			} else {
-				err = v.UnmarshalUDT(e.Name, e.Type, data[:size])
-				data = data[size:]
-			}
+			var p []byte
+			p, data = readBytes(data)
 
-			if err != nil {
+			if err := v.UnmarshalUDT(e.Name, e.Type, p); err != nil {
 				return err
 			}
 		}
@@ -1905,20 +1907,13 @@ func unmarshalUDT(info TypeInfo, data []byte, value interface{}) error {
 			if len(data) == 0 {
 				return nil
 			}
-			size := readInt(data[:4])
-			data = data[4:]
 
 			val := reflect.New(goType(e.Type))
 
-			var err error
-			if size < 0 {
-				err = Unmarshal(e.Type, nil, val.Interface())
-			} else {
-				err = Unmarshal(e.Type, data[:size], val.Interface())
-				data = data[size:]
-			}
+			var p []byte
+			p, data = readBytes(data)
 
-			if err != nil {
+			if err := Unmarshal(e.Type, p, val.Interface()); err != nil {
 				return err
 			}
 
@@ -1958,30 +1953,26 @@ func unmarshalUDT(info TypeInfo, data []byte, value interface{}) error {
 			return nil
 		}
 
-		size := readInt(data[:4])
-		data = data[4:]
+		var p []byte
+		p, data = readBytes(data)
 
-		if size >= 0 {
-			f, ok := fields[e.Name]
-			if !ok {
-				f = k.FieldByName(e.Name)
-				if f == emptyValue {
-					// skip fields which exist in the UDT but not in
-					// the struct passed in
-					data = data[size:] // Skip over this data to go to next
-					continue
-				}
+		f, ok := fields[e.Name]
+		if !ok {
+			f = k.FieldByName(e.Name)
+			if f == emptyValue {
+				// skip fields which exist in the UDT but not in
+				// the struct passed in
+				continue
 			}
+		}
 
-			if !f.IsValid() || !f.CanAddr() {
-				return unmarshalErrorf("cannot unmarshal %s into %T: field %v is not valid", info, value, e.Name)
-			}
+		if !f.IsValid() || !f.CanAddr() {
+			return unmarshalErrorf("cannot unmarshal %s into %T: field %v is not valid", info, value, e.Name)
+		}
 
-			fk := f.Addr().Interface()
-			if err := Unmarshal(e.Type, data[:size], fk); err != nil {
-				return err
-			}
-			data = data[size:]
+		fk := f.Addr().Interface()
+		if err := Unmarshal(e.Type, p, fk); err != nil {
+			return err
 		}
 	}
 

+ 41 - 4
tuple_test.go

@@ -43,13 +43,50 @@ func TestTupleSimple(t *testing.T) {
 
 	if id != 1 {
 		t.Errorf("expected to get id=1 got: %v", id)
-	}
-	if coord.x != 100 {
+	} else if coord.x != 100 {
 		t.Errorf("expected to get coord.x=100 got: %v", coord.x)
-	}
-	if coord.y != -100 {
+	} else if coord.y != -100 {
 		t.Errorf("expected to get coord.y=-100 got: %v", coord.y)
 	}
+
+}
+
+func TestTuple_NullTuple(t *testing.T) {
+	session := createSession(t)
+	defer session.Close()
+	if session.cfg.ProtoVersion < protoVersion3 {
+		t.Skip("tuple types are only available of proto>=3")
+	}
+
+	err := createTable(session, `CREATE TABLE gocql_test.tuple_nil_test(
+		id int,
+		coord frozen<tuple<int, int>>,
+
+		primary key(id))`)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	const id = 1
+
+	err = session.Query("INSERT INTO tuple_nil_test(id, coord) VALUES(?, (?, ?))", id, nil, nil).Exec()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	x := new(int)
+	y := new(int)
+	iter := session.Query("SELECT coord FROM tuple_nil_test WHERE id=?", id)
+	if err := iter.Scan(&x, &y); err != nil {
+		t.Fatal(err)
+	}
+
+	if x != nil {
+		t.Fatalf("should be nil got %+#v", x)
+	} else if y != nil {
+		t.Fatalf("should be nil got %+#v", y)
+	}
+
 }
 
 func TestTupleMapScan(t *testing.T) {