Browse Source

marshal: correctly marshal nil values

When marshalling nil values for types which are nullable in
Cassandra we should return a nil slice which will be encoded
as -1 length on the wire which Cassandra will interpret as null.

Fixes #523
Chris Bannister 10 năm trước cách đây
mục cha
commit
965ebcb166
2 tập tin đã thay đổi với 71 bổ sung0 xóa
  1. 42 0
      marshal.go
  2. 29 0
      marshal_test.go

+ 42 - 0
marshal.go

@@ -190,6 +190,11 @@ func marshalVarchar(info TypeInfo, value interface{}) ([]byte, error) {
 	case []byte:
 		return v, nil
 	}
+
+	if value == nil {
+		return nil, nil
+	}
+
 	rv := reflect.ValueOf(value)
 	t := rv.Type()
 	k := t.Kind()
@@ -364,6 +369,11 @@ func marshalBigInt(info TypeInfo, value interface{}) ([]byte, error) {
 		}
 		return encBigInt(i), nil
 	}
+
+	if value == nil {
+		return nil, nil
+	}
+
 	rv := reflect.ValueOf(value)
 	switch rv.Type().Kind() {
 	case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8:
@@ -617,6 +627,11 @@ func marshalBool(info TypeInfo, value interface{}) ([]byte, error) {
 	case bool:
 		return encBool(v), nil
 	}
+
+	if value == nil {
+		return nil, nil
+	}
+
 	rv := reflect.ValueOf(value)
 	switch rv.Type().Kind() {
 	case reflect.Bool:
@@ -667,6 +682,11 @@ func marshalFloat(info TypeInfo, value interface{}) ([]byte, error) {
 	case float32:
 		return encInt(int32(math.Float32bits(v))), nil
 	}
+
+	if value == nil {
+		return nil, nil
+	}
+
 	rv := reflect.ValueOf(value)
 	switch rv.Type().Kind() {
 	case reflect.Float32:
@@ -703,6 +723,9 @@ func marshalDouble(info TypeInfo, value interface{}) ([]byte, error) {
 	case float64:
 		return encBigInt(int64(math.Float64bits(v))), nil
 	}
+	if value == nil {
+		return nil, nil
+	}
 	rv := reflect.ValueOf(value)
 	switch rv.Type().Kind() {
 	case reflect.Float64:
@@ -733,6 +756,10 @@ func unmarshalDouble(info TypeInfo, data []byte, value interface{}) error {
 }
 
 func marshalDecimal(info TypeInfo, value interface{}) ([]byte, error) {
+	if value == nil {
+		return nil, nil
+	}
+
 	switch v := value.(type) {
 	case Marshaler:
 		return v.MarshalCQL(info)
@@ -816,6 +843,11 @@ func marshalTimestamp(info TypeInfo, value interface{}) ([]byte, error) {
 		x := int64(v.UTC().Unix()*1e3) + int64(v.UTC().Nanosecond()/1e6)
 		return encBigInt(x), nil
 	}
+
+	if value == nil {
+		return nil, nil
+	}
+
 	rv := reflect.ValueOf(value)
 	switch rv.Type().Kind() {
 	case reflect.Int64:
@@ -1088,6 +1120,11 @@ func marshalUUID(info TypeInfo, value interface{}) ([]byte, error) {
 		}
 		return b[:], nil
 	}
+
+	if value == nil {
+		return nil, nil
+	}
+
 	return nil, marshalErrorf("can not marshal %T into %s", value, info)
 }
 
@@ -1166,6 +1203,11 @@ func marshalInet(info TypeInfo, value interface{}) ([]byte, error) {
 		}
 		return nil, marshalErrorf("cannot marshal. invalid ip string %s", val)
 	}
+
+	if value == nil {
+		return nil, nil
+	}
+
 	return nil, marshalErrorf("cannot marshal %T into %s", value, info)
 }
 

+ 29 - 0
marshal_test.go

@@ -897,3 +897,32 @@ func TestMarshalTuple(t *testing.T) {
 		t.Errorf("unmarshalTest: expected [foo, bar], got [%s, %s]", s1, s2)
 	}
 }
+
+func TestMarshalNil(t *testing.T) {
+	types := []Type{
+		TypeAscii,
+		TypeBlob,
+		TypeBoolean,
+		TypeBigInt,
+		TypeCounter,
+		TypeDecimal,
+		TypeDouble,
+		TypeFloat,
+		TypeInt,
+		TypeTimestamp,
+		TypeUUID,
+		TypeVarchar,
+		TypeVarint,
+		TypeTimeUUID,
+		TypeInet,
+	}
+
+	for _, typ := range types {
+		data, err := Marshal(NativeType{proto: 3, typ: typ}, nil)
+		if err != nil {
+			t.Errorf("unable to marshal nil %v: %v\n", typ, err)
+		} else if data != nil {
+			t.Errorf("expected to get nil byte for nil %v got % X", typ, data)
+		}
+	}
+}