Browse Source

marshal: Fix unmarshalUDT for less values than UDT fields

The CQL binary protocol specification states that a serialized UDT is
allowed to have _less_ values than it has fields:

  "A UDT value will generally have one value for each field of the type
   it represents, but it is allowed to have less values than the type has
   fields."

This is also evident in the Cassandra Java driver's UDTCodec class:

    public UDTValue deserialize(ByteBuffer bytes) {
        ByteBuffer input = bytes.duplicate();
        UDTValue value = definition.newValue();

        int i = 0;
        while (input.hasRemaining() && i < value.values.length) {
            int n = input.getInt();
            value.values[i++] = n < 0 ? null : readBytes(input, n);
        }
        return value;
    }

Fix the unmarshalUDT() function to check whether there's more data to
read to avoid accessing a slice out of bounds:

  panic: runtime error: slice bounds out of range

  goroutine 1 [running]:
  github.com/gocql/gocql.unmarshalUDT(0x7ffff7fbc380, 0xc208030c60, 0x0, 0x0, 0x0, 0x694f60, 0xc208094da0, 0x0, 0x0)
      /disk/GOPATH/src/github.com/gocql/gocql/marshal.go:1654 +0x574
  github.com/gocql/gocql.Unmarshal(0x7ffff7fbc380, 0xc208030c60, 0x0, 0x0, 0x0, 0x694f60, 0xc208094da0, 0x0, 0x0)
      /disk/GOPATH/src/github.com/gocql/gocql/marshal.go:157 +0x9b7
  github.com/gocql/gocql.(*Iter).Scan(0xc20806c630, 0xc20800adc0, 0x1, 0x1, 0x7ffff7fa9000)
      /disk/GOPATH/src/github.com/gocql/gocql/session.go:1084 +0x6ff
  github.com/gocql/gocql.(*Query).Scan(0xc208082700, 0xc20800adc0, 0x1, 0x1, 0x0, 0x0)
      /disk/GOPATH/src/github.com/gocql/gocql/session.go:918 +0xab
  main.main()
      /home/a.go:86 +0x4cd

Fixes #757.
Pekka Enberg 9 years ago
parent
commit
acbcce0946
3 changed files with 46 additions and 0 deletions
  1. 1 0
      AUTHORS
  2. 6 0
      marshal.go
  3. 39 0
      udt_test.go

+ 1 - 0
AUTHORS

@@ -73,3 +73,4 @@ Michael Highstead <highstead@gmail.com>
 Sarah Brown <esbie.is@gmail.com>
 Caleb Doxsey <caleb@datadoghq.com>
 Frederic Hemery <frederic.hemery@datadoghq.com>
+Pekka Enberg <penberg@scylladb.com>

+ 6 - 0
marshal.go

@@ -1651,6 +1651,9 @@ func unmarshalUDT(info TypeInfo, data []byte, value interface{}) error {
 		udt := info.(UDTTypeInfo)
 
 		for _, e := range udt.Elements {
+			if len(data) == 0 {
+				return nil
+			}
 			size := readInt(data[:4])
 			data = data[4:]
 
@@ -1689,6 +1692,9 @@ func unmarshalUDT(info TypeInfo, data []byte, value interface{}) error {
 		m := *v
 
 		for _, e := range udt.Elements {
+			if len(data) == 0 {
+				return nil
+			}
 			size := readInt(data[:4])
 			data = data[4:]
 

+ 39 - 0
udt_test.go

@@ -503,3 +503,42 @@ func TestUDT_UpdateField(t *testing.T) {
 		t.Errorf("expected %+v: got %+v", *writeCol, *readCol)
 	}
 }
+
+func TestUDT_ScanNullUDT(t *testing.T) {
+	if *flagProto < protoVersion3 {
+		t.Skip("UDT are only available on protocol >= 3")
+	}
+
+	session := createSession(t)
+	defer session.Close()
+
+	err := createTable(session, `CREATE TYPE gocql_test.scan_null_udt_position(
+		lat int,
+		lon int,
+		padding text);`)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = createTable(session, `CREATE TABLE gocql_test.scan_null_udt_houses(
+		id int,
+		name text,
+		loc frozen<position>,
+		primary key(id)
+	);`)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = session.Query("INSERT INTO scan_null_udt_houses(id, name) VALUES(?, ?)", 1, "test" ).Exec()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	pos := &position{}
+
+	err = session.Query("SELECT loc FROM scan_null_udt_houses WHERE id = ?", 1).Scan(pos)
+	if err != nil {
+		t.Fatal(err)
+	}
+}