소스 검색

Merge pull request #521 from Zariel/nil-collections

marshal: handle nil collection correctly in udts
Chris Bannister 10 년 전
부모
커밋
c41eb5218c
4개의 변경된 파일81개의 추가작업 그리고 24개의 파일을 삭제
  1. 9 0
      helpers.go
  2. 34 20
      marshal.go
  3. 4 4
      marshal_test.go
  4. 34 0
      udt_test.go

+ 9 - 0
helpers.go

@@ -106,6 +106,15 @@ func getApacheCassandraType(class string) Type {
 	}
 }
 
+func typeCanBeNull(typ TypeInfo) bool {
+	switch typ.(type) {
+	case CollectionType, UDTTypeInfo, TupleTypeInfo:
+		return false
+	}
+
+	return true
+}
+
 func (r *RowData) rowMap(m map[string]interface{}) {
 	for i, column := range r.Columns {
 		val := dereference(r.Values[i])

+ 34 - 20
marshal.go

@@ -43,9 +43,6 @@ type Unmarshaler interface {
 // Marshal returns the CQL encoding of the value for the Cassandra
 // internal type described by the info parameter.
 func Marshal(info TypeInfo, value interface{}) ([]byte, error) {
-	if value == nil {
-		return nil, nil
-	}
 	if info.Version() < protoVersion1 {
 		panic("protocol version not set")
 	}
@@ -290,7 +287,16 @@ func marshalInt(info TypeInfo, value interface{}) ([]byte, error) {
 		}
 		return encInt(int32(i)), nil
 	}
+
+	if value == nil {
+		return nil, nil
+	}
+
 	rv := reflect.ValueOf(value)
+	if rv.IsNil() {
+		return nil, nil
+	}
+
 	switch rv.Type().Kind() {
 	case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8:
 		v := rv.Int()
@@ -881,9 +887,6 @@ func marshalList(info TypeInfo, value interface{}) ([]byte, error) {
 	k := t.Kind()
 	switch k {
 	case reflect.Slice, reflect.Array:
-		if k == reflect.Slice && rv.IsNil() {
-			return nil, nil
-		}
 		buf := &bytes.Buffer{}
 		n := rv.Len()
 
@@ -989,9 +992,7 @@ func marshalMap(info TypeInfo, value interface{}) ([]byte, error) {
 	if t.Kind() != reflect.Map {
 		return nil, marshalErrorf("can not marshal %T into %s", value, info)
 	}
-	if rv.IsNil() {
-		return nil, nil
-	}
+
 	buf := &bytes.Buffer{}
 	n := rv.Len()
 
@@ -1285,9 +1286,12 @@ func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) {
 				return nil, err
 			}
 
-			n := len(data)
-			buf = appendInt(buf, int32(n))
-			buf = append(buf, data...)
+			if data == nil && typeCanBeNull(e.Type) {
+				buf = appendInt(buf, -1)
+			} else {
+				buf = appendInt(buf, int32(len(data)))
+				buf = append(buf, data...)
+			}
 		}
 
 		return buf, nil
@@ -1304,9 +1308,12 @@ func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) {
 				return nil, err
 			}
 
-			n := len(data)
-			buf = appendInt(buf, int32(n))
-			buf = append(buf, data...)
+			if data == nil && typeCanBeNull(e.Type) {
+				buf = appendInt(buf, -1)
+			} else {
+				buf = appendInt(buf, int32(len(data)))
+				buf = append(buf, data...)
+			}
 		}
 
 		return buf, nil
@@ -1342,8 +1349,12 @@ func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) {
 		}
 
 		if !f.IsValid() {
-			buf = appendInt(buf, -1)
-			continue
+			if _, ok := e.Type.(CollectionType); ok {
+				f = reflect.Zero(goType(e.Type))
+			} else {
+				buf = appendInt(buf, -1)
+				continue
+			}
 		} else if f.Kind() == reflect.Ptr {
 			if f.IsNil() {
 				buf = appendInt(buf, -1)
@@ -1358,9 +1369,12 @@ func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) {
 			return nil, err
 		}
 
-		n := len(data)
-		buf = appendInt(buf, int32(n))
-		buf = append(buf, data...)
+		if data == nil && typeCanBeNull(e.Type) {
+			buf = appendInt(buf, -1)
+		} else {
+			buf = appendInt(buf, int32(len(data)))
+			buf = append(buf, data...)
+		}
 	}
 
 	return buf, nil

+ 4 - 4
marshal_test.go

@@ -257,8 +257,8 @@ var marshalTests = []struct {
 			NativeType: NativeType{proto: 2, typ: TypeSet},
 			Elem:       NativeType{proto: 2, typ: TypeInt},
 		},
-		[]byte(nil),
-		[]int(nil),
+		[]byte{0, 0}, // encoding of a list should always include the size of the collection
+		[]int{},
 	},
 	{
 		CollectionType{
@@ -275,8 +275,8 @@ var marshalTests = []struct {
 			Key:        NativeType{proto: 2, typ: TypeVarchar},
 			Elem:       NativeType{proto: 2, typ: TypeInt},
 		},
-		[]byte(nil),
-		map[string]int(nil),
+		[]byte{0, 0},
+		map[string]int{},
 	},
 	{
 		CollectionType{

+ 34 - 0
udt_test.go

@@ -390,3 +390,37 @@ func TestUDT_MissingField(t *testing.T) {
 		t.Errorf("expected %q: got %q", writeCol.Name, readCol.Name)
 	}
 }
+
+func TestUDT_EmptyCollections(t *testing.T) {
+	if *flagProto < protoVersion3 {
+		t.Skip("UDT are only available on protocol >= 3")
+	}
+
+	session := createSession(t)
+	defer session.Close()
+
+	err := createTable(session, `CREATE TYPE gocql_test.nil_collections(
+		a list<text>,
+		b map<text, text>,
+		c set<text>
+	);`)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = createTable(session, `CREATE TABLE gocql_test.nil_collections(
+		id uuid,
+		udt_col frozen<nil_collections>,
+
+		primary key(id)
+	);`)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	id := TimeUUID()
+	err = session.Query("INSERT INTO nil_collections(id, udt_col) VALUES(?, ?)", id, &struct{}{}).Exec()
+	if err != nil {
+		t.Fatal(err)
+	}
+}