Explorar o código

Merge pull request #678 from Zariel/marshal-nil-udt-collections

marshal: fix udt nil collection marshalling
Chris Bannister %!s(int64=9) %!d(string=hai) anos
pai
achega
f34daf3479
Modificáronse 3 ficheiros con 52 adicións e 38 borrados
  1. 9 0
      frame.go
  2. 20 37
      marshal.go
  3. 23 1
      udt_test.go

+ 9 - 0
frame.go

@@ -1646,6 +1646,15 @@ func (f *framer) writeByte(b byte) {
 	f.wbuf = append(f.wbuf, b)
 }
 
+func appendBytes(p []byte, d []byte) []byte {
+	if d == nil {
+		return appendInt(p, -1)
+	}
+	p = appendInt(p, int32(len(d)))
+	p = append(p, d...)
+	return p
+}
+
 func appendShort(p []byte, n uint16) []byte {
 	return append(p,
 		byte(n>>8),

+ 20 - 37
marshal.go

@@ -929,6 +929,10 @@ func marshalList(info TypeInfo, value interface{}) ([]byte, error) {
 	rv := reflect.ValueOf(value)
 	t := rv.Type()
 	k := t.Kind()
+	if k == reflect.Slice && rv.IsNil() {
+		return nil, nil
+	}
+
 	switch k {
 	case reflect.Slice, reflect.Array:
 		buf := &bytes.Buffer{}
@@ -994,6 +998,9 @@ func unmarshalList(info TypeInfo, data []byte, value interface{}) error {
 			if k == reflect.Array {
 				return unmarshalErrorf("unmarshal list: can not store nil in array value")
 			}
+			if rv.IsNil() {
+				return nil
+			}
 			rv.Set(reflect.Zero(t))
 			return nil
 		}
@@ -1032,6 +1039,10 @@ func marshalMap(info TypeInfo, value interface{}) ([]byte, error) {
 	}
 
 	rv := reflect.ValueOf(value)
+	if rv.IsNil() {
+		return nil, nil
+	}
+
 	t := rv.Type()
 	if t.Kind() != reflect.Map {
 		return nil, marshalErrorf("can not marshal %T into %s", value, info)
@@ -1344,12 +1355,7 @@ func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) {
 				return nil, err
 			}
 
-			if data == nil && typeCanBeNull(e.Type) {
-				buf = appendInt(buf, -1)
-			} else {
-				buf = appendInt(buf, int32(len(data)))
-				buf = append(buf, data...)
-			}
+			buf = appendBytes(buf, data)
 		}
 
 		return buf, nil
@@ -1366,12 +1372,7 @@ func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) {
 				return nil, err
 			}
 
-			if data == nil && typeCanBeNull(e.Type) {
-				buf = appendInt(buf, -1)
-			} else {
-				buf = appendInt(buf, int32(len(data)))
-				buf = append(buf, data...)
-			}
+			buf = appendBytes(buf, data)
 		}
 
 		return buf, nil
@@ -1406,37 +1407,19 @@ func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) {
 			f = k.FieldByName(e.Name)
 		}
 
-		if !f.IsValid() {
-			if _, ok := e.Type.(CollectionType); ok {
-				f = reflect.Zero(goType(e.Type))
-			} else {
-				buf = appendInt(buf, -1)
-				continue
-			}
-		} else if f.Kind() == reflect.Ptr {
-			if f.IsNil() {
-				buf = appendInt(buf, -1)
-				continue
-			} else {
-				f = f.Elem()
+		var data []byte
+		if f.IsValid() && f.CanInterface() {
+			var err error
+			data, err = Marshal(e.Type, f.Interface())
+			if err != nil {
+				return nil, err
 			}
 		}
 
-		data, err := Marshal(e.Type, f.Interface())
-		if err != nil {
-			return nil, err
-		}
-
-		if data == nil && typeCanBeNull(e.Type) {
-			buf = appendInt(buf, -1)
-		} else {
-			buf = appendInt(buf, int32(len(data)))
-			buf = append(buf, data...)
-		}
+		buf = appendBytes(buf, data)
 	}
 
 	return buf, nil
-
 }
 
 func unmarshalUDT(info TypeInfo, data []byte, value interface{}) error {

+ 23 - 1
udt_test.go

@@ -418,11 +418,33 @@ func TestUDT_EmptyCollections(t *testing.T) {
 		t.Fatal(err)
 	}
 
+	type udt struct {
+		A []string          `cql:"a"`
+		B map[string]string `cql:"b"`
+		C []string          `cql:"c"`
+	}
+
 	id := TimeUUID()
-	err = session.Query("INSERT INTO nil_collections(id, udt_col) VALUES(?, ?)", id, &struct{}{}).Exec()
+	err = session.Query("INSERT INTO nil_collections(id, udt_col) VALUES(?, ?)", id, &udt{}).Exec()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	var val udt
+	err = session.Query("SELECT udt_col FROM nil_collections WHERE id=?", id).Scan(&val)
 	if err != nil {
 		t.Fatal(err)
 	}
+
+	if val.A != nil {
+		t.Errorf("expected to get nil got %#+v", val.A)
+	}
+	if val.B != nil {
+		t.Errorf("expected to get nil got %#+v", val.B)
+	}
+	if val.C != nil {
+		t.Errorf("expected to get nil got %#+v", val.C)
+	}
 }
 
 func TestUDT_UpdateField(t *testing.T) {