浏览代码

marshal: handle UDTs which are not complete

If a UDT is updated then it type description will include all fields but
the data will not include all of them.

fixes #595
Chris Bannister 10 年之前
父节点
当前提交
779476a151
共有 2 个文件被更改,包括 63 次插入1 次删除
  1. 6 1
      marshal.go
  2. 57 0
      udt_test.go

+ 6 - 1
marshal.go

@@ -1521,6 +1521,11 @@ func unmarshalUDT(info TypeInfo, data []byte, value interface{}) error {
 
 	udt := info.(UDTTypeInfo)
 	for _, e := range udt.Elements {
+		if len(data) < 4 {
+			// UDT def does not match the column value
+			return nil
+		}
+
 		size := readInt(data[:4])
 		data = data[4:]
 
@@ -1532,7 +1537,7 @@ func unmarshalUDT(info TypeInfo, data []byte, value interface{}) error {
 			}
 
 			if !f.IsValid() || !f.CanAddr() {
-				return unmarshalErrorf("cannot unmarshal %s into %T", info, value)
+				return unmarshalErrorf("cannot unmarshal %s into %T: field %v is not valid", info, value, e.Name)
 			}
 
 			fk := f.Addr().Interface()

+ 57 - 0
udt_test.go

@@ -424,3 +424,60 @@ func TestUDT_EmptyCollections(t *testing.T) {
 		t.Fatal(err)
 	}
 }
+
+func TestUDT_UpdateField(t *testing.T) {
+	if *flagProto < protoVersion3 {
+		t.Skip("UDT are only available on protocol >= 3")
+	}
+
+	session := createSession(t)
+	defer session.Close()
+
+	err := createTable(session, `CREATE TYPE gocql_test.update_field_udt(
+		name text,
+		owner text);`)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = createTable(session, `CREATE TABLE gocql_test.update_field(
+		id uuid,
+		udt_col frozen<update_field_udt>,
+
+		primary key(id)
+	);`)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	type col struct {
+		Name  string `cql:"name"`
+		Owner string `cql:"owner"`
+		Data  string `cql:"data"`
+	}
+
+	writeCol := &col{
+		Name:  "test-name",
+		Owner: "test-owner",
+	}
+
+	id := TimeUUID()
+	err = session.Query("INSERT INTO update_field(id, udt_col) VALUES(?, ?)", id, writeCol).Exec()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if err := createTable(session, `ALTER TYPE gocql_test.update_field_udt ADD data text;`); err != nil {
+		t.Fatal(err)
+	}
+
+	readCol := &col{}
+	err = session.Query("SELECT udt_col FROM update_field WHERE id = ?", id).Scan(readCol)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if *readCol != *writeCol {
+		t.Errorf("expected %+v: got %+v", *writeCol, *readCol)
+	}
+}