Explorar o código

Merge pull request #284 from probkiizokna/unmarshal_nil

Unmarshal null values into double pointer values.
Ben Hood %!s(int64=11) %!d(string=hai) anos
pai
achega
243807fe32
Modificáronse 2 ficheiros con 47 adicións e 48 borrados
  1. 33 20
      marshal.go
  2. 14 28
      marshal_test.go

+ 33 - 20
marshal.go

@@ -90,6 +90,10 @@ func Unmarshal(info *TypeInfo, data []byte, value interface{}) error {
 	if v, ok := value.(Unmarshaler); ok {
 		return v.UnmarshalCQL(info, data)
 	}
+	if isNullableValue(value) {
+		return unmarshalNullable(info, data, value)
+	}
+
 	switch info.Type {
 	case TypeVarchar, TypeAscii, TypeBlob:
 		return unmarshalVarchar(info, data, value)
@@ -124,6 +128,29 @@ func Unmarshal(info *TypeInfo, data []byte, value interface{}) error {
 	return fmt.Errorf("can not unmarshal %s into %T", info, value)
 }
 
+func isNullableValue(value interface{}) bool {
+	v := reflect.ValueOf(value)
+	return v.Kind() == reflect.Ptr && v.Type().Elem().Kind() == reflect.Ptr
+}
+
+func isNullData(info *TypeInfo, data []byte) bool {
+	return len(data) == 0
+}
+
+func unmarshalNullable(info *TypeInfo, data []byte, value interface{}) error {
+	valueRef := reflect.ValueOf(value)
+
+	if isNullData(info, data) {
+		nilValue := reflect.Zero(valueRef.Type().Elem())
+		valueRef.Elem().Set(nilValue)
+		return nil
+	} else {
+		newValue := reflect.New(valueRef.Type().Elem().Elem())
+		valueRef.Elem().Set(newValue)
+		return Unmarshal(info, data, newValue.Interface())
+	}
+}
+
 func marshalVarchar(info *TypeInfo, value interface{}) ([]byte, error) {
 	switch v := value.(type) {
 	case Marshaler:
@@ -323,7 +350,7 @@ func unmarshalInt(info *TypeInfo, data []byte, value interface{}) error {
 
 func unmarshalVarint(info *TypeInfo, data []byte, value interface{}) error {
 	switch value.(type) {
-	case *big.Int, **big.Int:
+	case *big.Int:
 		return unmarshalIntlike(info, 0, data, value)
 	}
 
@@ -449,13 +476,6 @@ func unmarshalIntlike(info *TypeInfo, int64Val int64, data []byte, value interfa
 	case *big.Int:
 		decBigInt2C(data, v)
 		return nil
-	case **big.Int:
-		if len(data) == 0 {
-			*v = nil
-		} else {
-			*v = decBigInt2C(data, nil)
-		}
-		return nil
 	}
 
 	rv := reflect.ValueOf(value)
@@ -680,18 +700,11 @@ func unmarshalDecimal(info *TypeInfo, data []byte, value interface{}) error {
 	switch v := value.(type) {
 	case Unmarshaler:
 		return v.UnmarshalCQL(info, data)
-	case **inf.Dec:
-		if len(data) > 4 {
-			scale := decInt(data[0:4])
-			unscaled := decBigInt2C(data[4:], nil)
-			*v = inf.NewDecBig(unscaled, inf.Scale(scale))
-			return nil
-		} else if len(data) == 0 {
-			*v = nil
-			return nil
-		} else {
-			return unmarshalErrorf("can not unmarshal %s into %T", info, value)
-		}
+	case *inf.Dec:
+		scale := decInt(data[0:4])
+		unscaled := decBigInt2C(data[4:], nil)
+		*v = *inf.NewDecBig(unscaled, inf.Scale(scale))
+		return nil
 	}
 	return unmarshalErrorf("can not unmarshal %s into %T", info, value)
 }

+ 14 - 28
marshal_test.go

@@ -315,13 +315,6 @@ var marshalTests = []struct {
 		[]byte("\xfe\x80\x00\x00\x00\x00\x00\x00\x02\x02\xb3\xff\xfe\x1e\x83\x29"),
 		net.ParseIP("fe80::202:b3ff:fe1e:8329"),
 	},
-}
-
-var marshalNilTests = []struct {
-	Info  *TypeInfo
-	Data  []byte
-	Value interface{}
-}{
 	{
 		&TypeInfo{Type: TypeInt},
 		[]byte(nil),
@@ -493,29 +486,22 @@ func TestMarshal(t *testing.T) {
 	}
 }
 
-func TestMarshalNil(t *testing.T) {
-	for i, test := range marshalNilTests {
-		data, err := Marshal(test.Info, test.Value)
-		if err != nil {
-			t.Errorf("marshalNilTest[%d]: %v", i, err)
-			continue
-		}
-		if !bytes.Equal(data, test.Data) {
-			t.Errorf("marshalNilTest[%d]: expected %q, got %q.(%#v)", i, test.Data, data, test.Value)
-		}
-	}
-}
-
 func TestUnmarshal(t *testing.T) {
 	for i, test := range marshalTests {
-		v := reflect.New(reflect.TypeOf(test.Value))
-		err := Unmarshal(test.Info, test.Data, v.Interface())
-		if err != nil {
-			t.Errorf("marshalTest[%d]: %v", i, err)
-			continue
-		}
-		if !reflect.DeepEqual(v.Elem().Interface(), test.Value) {
-			t.Errorf("marshalTest[%d]: expected %#v, got %#v.", i, test.Value, v.Elem().Interface())
+		if test.Value != nil {
+			v := reflect.New(reflect.TypeOf(test.Value))
+			err := Unmarshal(test.Info, test.Data, v.Interface())
+			if err != nil {
+				t.Errorf("unmarshalTest[%d]: %v", i, err)
+				continue
+			}
+			if !reflect.DeepEqual(v.Elem().Interface(), test.Value) {
+				t.Errorf("unmarshalTest[%d]: expected %#v, got %#v.", i, test.Value, v.Elem().Interface())
+			}
+		} else {
+			if err := Unmarshal(test.Info, test.Data, test.Value); nil == err {
+				t.Errorf("unmarshalTest[%d]: %#v not return error.", i, test.Value)
+			}
 		}
 	}
 }