浏览代码

marshal: handle nil collection correctly in udts

When a UDT has a missing value or a nil value for a collection
we should correctly encode it as a collection with 0 elements not
as a null value.

Fixes #518
Chris Bannister 10 年之前
父节点
当前提交
e59ea7f720
共有 3 个文件被更改,包括 67 次插入11 次删除
  1. 9 0
      helpers.go
  2. 24 11
      marshal.go
  3. 34 0
      udt_test.go

+ 9 - 0
helpers.go

@@ -106,6 +106,15 @@ func getApacheCassandraType(class string) Type {
 	}
 }
 
+func typeCanBeNull(typ TypeInfo) bool {
+	switch typ.(type) {
+	case CollectionType, UDTTypeInfo, TupleTypeInfo:
+		return false
+	}
+
+	return true
+}
+
 func (r *RowData) rowMap(m map[string]interface{}) {
 	for i, column := range r.Columns {
 		val := dereference(r.Values[i])

+ 24 - 11
marshal.go

@@ -1285,9 +1285,12 @@ func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) {
 				return nil, err
 			}
 
-			n := len(data)
-			buf = appendInt(buf, int32(n))
-			buf = append(buf, data...)
+			if data == nil && typeCanBeNull(e.Type) {
+				buf = appendInt(buf, -1)
+			} else {
+				buf = appendInt(buf, int32(len(data)))
+				buf = append(buf, data...)
+			}
 		}
 
 		return buf, nil
@@ -1304,9 +1307,12 @@ func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) {
 				return nil, err
 			}
 
-			n := len(data)
-			buf = appendInt(buf, int32(n))
-			buf = append(buf, data...)
+			if data == nil && typeCanBeNull(e.Type) {
+				buf = appendInt(buf, -1)
+			} else {
+				buf = appendInt(buf, int32(len(data)))
+				buf = append(buf, data...)
+			}
 		}
 
 		return buf, nil
@@ -1342,8 +1348,12 @@ func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) {
 		}
 
 		if !f.IsValid() {
-			buf = appendInt(buf, -1)
-			continue
+			if _, ok := e.Type.(CollectionType); ok {
+				f = reflect.Zero(goType(e.Type))
+			} else {
+				buf = appendInt(buf, -1)
+				continue
+			}
 		} else if f.Kind() == reflect.Ptr {
 			if f.IsNil() {
 				buf = appendInt(buf, -1)
@@ -1358,9 +1368,12 @@ func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) {
 			return nil, err
 		}
 
-		n := len(data)
-		buf = appendInt(buf, int32(n))
-		buf = append(buf, data...)
+		if data == nil && typeCanBeNull(e.Type) {
+			buf = appendInt(buf, -1)
+		} else {
+			buf = appendInt(buf, int32(len(data)))
+			buf = append(buf, data...)
+		}
 	}
 
 	return buf, nil

+ 34 - 0
udt_test.go

@@ -390,3 +390,37 @@ func TestUDT_MissingField(t *testing.T) {
 		t.Errorf("expected %q: got %q", writeCol.Name, readCol.Name)
 	}
 }
+
+func TestUDT_EmptyCollections(t *testing.T) {
+	if *flagProto < protoVersion3 {
+		t.Skip("UDT are only available on protocol >= 3")
+	}
+
+	session := createSession(t)
+	defer session.Close()
+
+	err := createTable(session, `CREATE TYPE gocql_test.nil_collections(
+		a list<text>,
+		b map<text, text>,
+		c set<text>
+	);`)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = createTable(session, `CREATE TABLE gocql_test.nil_collections(
+		id uuid,
+		udt_col frozen<nil_collections>,
+
+		primary key(id)
+	);`)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	id := TimeUUID()
+	err = session.Query("INSERT INTO nil_collections(id, udt_col) VALUES(?, ?)", id, &struct{}{}).Exec()
+	if err != nil {
+		t.Fatal(err)
+	}
+}