Selaa lähdekoodia

Merge pull request #283 from probkiizokna/marshal_nil

Marshal pointer values
Phillip Couto 11 vuotta sitten
vanhempi
commit
6c91b07893
3 muutettua tiedostoa jossa 175 lisäystä ja 30 poistoa
  1. 1 1
      AUTHORS
  2. 11 22
      marshal.go
  3. 163 7
      marshal_test.go

+ 1 - 1
AUTHORS

@@ -37,4 +37,4 @@ Jeremy Schlatter <jeremy.schlatter@gmail.com>
 Matthias Kadenbach <matthias.kadenbach@gmail.com>
 Matthias Kadenbach <matthias.kadenbach@gmail.com>
 Dean Elbaz <elbaz.dean@gmail.com>
 Dean Elbaz <elbaz.dean@gmail.com>
 Mike Berman <evencode@gmail.com>
 Mike Berman <evencode@gmail.com>
-
+Dmitriy Fedorenko <c0va23@gmail.com>

+ 11 - 22
marshal.go

@@ -40,6 +40,14 @@ func Marshal(info *TypeInfo, value interface{}) ([]byte, error) {
 		return nil, nil
 		return nil, nil
 	}
 	}
 
 
+	if valueRef := reflect.ValueOf(value); valueRef.Kind() == reflect.Ptr {
+		if valueRef.IsNil() {
+			return nil, nil
+		} else {
+			return Marshal(info, valueRef.Elem().Interface())
+		}
+	}
+
 	if v, ok := value.(Marshaler); ok {
 	if v, ok := value.(Marshaler); ok {
 		return v.MarshalCQL(info)
 		return v.MarshalCQL(info)
 	}
 	}
@@ -133,8 +141,6 @@ func marshalVarchar(info *TypeInfo, value interface{}) ([]byte, error) {
 		return []byte(rv.String()), nil
 		return []byte(rv.String()), nil
 	case k == reflect.Slice && t.Elem().Kind() == reflect.Uint8:
 	case k == reflect.Slice && t.Elem().Kind() == reflect.Uint8:
 		return rv.Bytes(), nil
 		return rv.Bytes(), nil
-	case k == reflect.Ptr:
-		return marshalVarchar(info, rv.Elem().Interface())
 	}
 	}
 	return nil, marshalErrorf("can not marshal %T into %s", value, info)
 	return nil, marshalErrorf("can not marshal %T into %s", value, info)
 }
 }
@@ -232,8 +238,6 @@ func marshalInt(info *TypeInfo, value interface{}) ([]byte, error) {
 			return nil, marshalErrorf("marshal int: value %d out of range", v)
 			return nil, marshalErrorf("marshal int: value %d out of range", v)
 		}
 		}
 		return encInt(int32(v)), nil
 		return encInt(int32(v)), nil
-	case reflect.Ptr:
-		return marshalInt(info, rv.Elem().Interface())
 	}
 	}
 	return nil, marshalErrorf("can not marshal %T into %s", value, info)
 	return nil, marshalErrorf("can not marshal %T into %s", value, info)
 }
 }
@@ -279,8 +283,8 @@ func marshalBigInt(info *TypeInfo, value interface{}) ([]byte, error) {
 		return encBigInt(int64(v)), nil
 		return encBigInt(int64(v)), nil
 	case uint8:
 	case uint8:
 		return encBigInt(int64(v)), nil
 		return encBigInt(int64(v)), nil
-	case *big.Int:
-		return encBigInt2C(v), nil
+	case big.Int:
+		return encBigInt2C(&v), nil
 	}
 	}
 	rv := reflect.ValueOf(value)
 	rv := reflect.ValueOf(value)
 	switch rv.Type().Kind() {
 	switch rv.Type().Kind() {
@@ -293,8 +297,6 @@ func marshalBigInt(info *TypeInfo, value interface{}) ([]byte, error) {
 			return nil, marshalErrorf("marshal bigint: value %d out of range", v)
 			return nil, marshalErrorf("marshal bigint: value %d out of range", v)
 		}
 		}
 		return encBigInt(int64(v)), nil
 		return encBigInt(int64(v)), nil
-	case reflect.Ptr:
-		return marshalBigInt(info, rv.Elem().Interface())
 	}
 	}
 	return nil, marshalErrorf("can not marshal %T into %s", value, info)
 	return nil, marshalErrorf("can not marshal %T into %s", value, info)
 }
 }
@@ -545,8 +547,6 @@ func marshalBool(info *TypeInfo, value interface{}) ([]byte, error) {
 	switch rv.Type().Kind() {
 	switch rv.Type().Kind() {
 	case reflect.Bool:
 	case reflect.Bool:
 		return encBool(rv.Bool()), nil
 		return encBool(rv.Bool()), nil
-	case reflect.Ptr:
-		return marshalBool(info, rv.Elem().Interface())
 	}
 	}
 	return nil, marshalErrorf("can not marshal %T into %s", value, info)
 	return nil, marshalErrorf("can not marshal %T into %s", value, info)
 }
 }
@@ -597,8 +597,6 @@ func marshalFloat(info *TypeInfo, value interface{}) ([]byte, error) {
 	switch rv.Type().Kind() {
 	switch rv.Type().Kind() {
 	case reflect.Float32:
 	case reflect.Float32:
 		return encInt(int32(math.Float32bits(float32(rv.Float())))), nil
 		return encInt(int32(math.Float32bits(float32(rv.Float())))), nil
-	case reflect.Ptr:
-		return marshalFloat(info, rv.Elem().Interface())
 	}
 	}
 	return nil, marshalErrorf("can not marshal %T into %s", value, info)
 	return nil, marshalErrorf("can not marshal %T into %s", value, info)
 }
 }
@@ -635,8 +633,6 @@ func marshalDouble(info *TypeInfo, value interface{}) ([]byte, error) {
 	switch rv.Type().Kind() {
 	switch rv.Type().Kind() {
 	case reflect.Float64:
 	case reflect.Float64:
 		return encBigInt(int64(math.Float64bits(rv.Float()))), nil
 		return encBigInt(int64(math.Float64bits(rv.Float()))), nil
-	case reflect.Ptr:
-		return marshalDouble(info, rv.Elem().Interface())
 	}
 	}
 	return nil, marshalErrorf("can not marshal %T into %s", value, info)
 	return nil, marshalErrorf("can not marshal %T into %s", value, info)
 }
 }
@@ -666,12 +662,7 @@ func marshalDecimal(info *TypeInfo, value interface{}) ([]byte, error) {
 	switch v := value.(type) {
 	switch v := value.(type) {
 	case Marshaler:
 	case Marshaler:
 		return v.MarshalCQL(info)
 		return v.MarshalCQL(info)
-	case *inf.Dec:
-
-		if v == nil {
-			return nil, nil
-		}
-
+	case inf.Dec:
 		unscaled := encBigInt2C(v.UnscaledBig())
 		unscaled := encBigInt2C(v.UnscaledBig())
 		if unscaled == nil {
 		if unscaled == nil {
 			return nil, marshalErrorf("can not marshal %T into %s", value, info)
 			return nil, marshalErrorf("can not marshal %T into %s", value, info)
@@ -759,8 +750,6 @@ func marshalTimestamp(info *TypeInfo, value interface{}) ([]byte, error) {
 	switch rv.Type().Kind() {
 	switch rv.Type().Kind() {
 	case reflect.Int64:
 	case reflect.Int64:
 		return encBigInt(rv.Int()), nil
 		return encBigInt(rv.Int()), nil
-	case reflect.Ptr:
-		return marshalTimestamp(info, rv.Elem().Interface())
 	}
 	}
 	return nil, marshalErrorf("can not marshal %T into %s", value, info)
 	return nil, marshalErrorf("can not marshal %T into %s", value, info)
 }
 }

+ 163 - 7
marshal_test.go

@@ -317,6 +317,159 @@ var marshalTests = []struct {
 	},
 	},
 }
 }
 
 
+var marshalNilTests = []struct {
+	Info  *TypeInfo
+	Data  []byte
+	Value interface{}
+}{
+	{
+		&TypeInfo{Type: TypeInt},
+		[]byte(nil),
+		nil,
+	},
+	{
+		&TypeInfo{Type: TypeVarchar},
+		[]byte("nullable string"),
+		func() *string {
+			value := "nullable string"
+			return &value
+		}(),
+	},
+	{
+		&TypeInfo{Type: TypeVarchar},
+		[]byte{},
+		(*string)(nil),
+	},
+	{
+		&TypeInfo{Type: TypeInt},
+		[]byte("\x7f\xff\xff\xff"),
+		func() *int {
+			var value int = math.MaxInt32
+			return &value
+		}(),
+	},
+	{
+		&TypeInfo{Type: TypeInt},
+		[]byte(nil),
+		(*int)(nil),
+	},
+	{
+		&TypeInfo{Type: TypeTimeUUID},
+		[]byte{0x3d, 0xcd, 0x98, 0x0, 0xf3, 0xd9, 0x11, 0xbf, 0x86, 0xd4, 0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0},
+		&UUID{0x3d, 0xcd, 0x98, 0x0, 0xf3, 0xd9, 0x11, 0xbf, 0x86, 0xd4, 0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0},
+	},
+	{
+		&TypeInfo{Type: TypeTimeUUID},
+		[]byte{},
+		(*UUID)(nil),
+	},
+	{
+		&TypeInfo{Type: TypeTimestamp},
+		[]byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"),
+		func() *time.Time {
+			t := time.Date(2013, time.August, 13, 9, 52, 3, 0, time.UTC)
+			return &t
+		}(),
+	},
+	{
+		&TypeInfo{Type: TypeTimestamp},
+		[]byte(nil),
+		(*time.Time)(nil),
+	},
+	{
+		&TypeInfo{Type: TypeBoolean},
+		[]byte("\x00"),
+		func() *bool {
+			b := false
+			return &b
+		}(),
+	},
+	{
+		&TypeInfo{Type: TypeBoolean},
+		[]byte("\x01"),
+		func() *bool {
+			b := true
+			return &b
+		}(),
+	},
+	{
+		&TypeInfo{Type: TypeBoolean},
+		[]byte(nil),
+		(*bool)(nil),
+	},
+	{
+		&TypeInfo{Type: TypeFloat},
+		[]byte("\x40\x49\x0f\xdb"),
+		func() *float32 {
+			f := float32(3.14159265)
+			return &f
+		}(),
+	},
+	{
+		&TypeInfo{Type: TypeFloat},
+		[]byte(nil),
+		(*float32)(nil),
+	},
+	{
+		&TypeInfo{Type: TypeDouble},
+		[]byte("\x40\x09\x21\xfb\x53\xc8\xd4\xf1"),
+		func() *float64 {
+			d := float64(3.14159265)
+			return &d
+		}(),
+	},
+	{
+		&TypeInfo{Type: TypeDouble},
+		[]byte(nil),
+		(*float64)(nil),
+	},
+	{
+		&TypeInfo{Type: TypeInet},
+		[]byte("\x7F\x00\x00\x01"),
+		func() *net.IP {
+			ip := net.ParseIP("127.0.0.1").To4()
+			return &ip
+		}(),
+	},
+	{
+		&TypeInfo{Type: TypeInet},
+		[]byte(nil),
+		(*net.IP)(nil),
+	},
+	{
+		&TypeInfo{Type: TypeList, Elem: &TypeInfo{Type: TypeInt}},
+		[]byte("\x00\x02\x00\x04\x00\x00\x00\x01\x00\x04\x00\x00\x00\x02"),
+		func() *[]int {
+			l := []int{1, 2}
+			return &l
+		}(),
+	},
+	{
+		&TypeInfo{Type: TypeList, Elem: &TypeInfo{Type: TypeInt}},
+		[]byte(nil),
+		(*[]int)(nil),
+	},
+	{
+		&TypeInfo{Type: TypeMap,
+			Key:  &TypeInfo{Type: TypeVarchar},
+			Elem: &TypeInfo{Type: TypeInt},
+		},
+		[]byte("\x00\x01\x00\x03foo\x00\x04\x00\x00\x00\x01"),
+		func() *map[string]int {
+			m := map[string]int{"foo": 1}
+			return &m
+		}(),
+	},
+	{
+		&TypeInfo{Type: TypeMap,
+			Key:  &TypeInfo{Type: TypeVarchar},
+			Elem: &TypeInfo{Type: TypeInt},
+		},
+		[]byte(nil),
+		(*map[string]int)(nil),
+	},
+}
+
 func decimalize(s string) *inf.Dec {
 func decimalize(s string) *inf.Dec {
 	i, _ := new(inf.Dec).SetString(s)
 	i, _ := new(inf.Dec).SetString(s)
 	return i
 	return i
@@ -335,18 +488,21 @@ func TestMarshal(t *testing.T) {
 			continue
 			continue
 		}
 		}
 		if !bytes.Equal(data, test.Data) {
 		if !bytes.Equal(data, test.Data) {
-			t.Errorf("marshalTest[%d]: expected %q, got %q.", i, test.Data, data)
+			t.Errorf("marshalTest[%d]: expected %q, got %q (%#v)", i, test.Data, data, test.Value)
 		}
 		}
 	}
 	}
 }
 }
 
 
 func TestMarshalNil(t *testing.T) {
 func TestMarshalNil(t *testing.T) {
-	data, err := Marshal(&TypeInfo{Type: TypeInt}, nil)
-	if err != nil {
-		t.Errorf("failed to marshal nil with err: %v", err)
-	}
-	if data != nil {
-		t.Errorf("expected nil, got %v", data)
+	for i, test := range marshalNilTests {
+		data, err := Marshal(test.Info, test.Value)
+		if err != nil {
+			t.Errorf("marshalNilTest[%d]: %v", i, err)
+			continue
+		}
+		if !bytes.Equal(data, test.Data) {
+			t.Errorf("marshalNilTest[%d]: expected %q, got %q.(%#v)", i, test.Data, data, test.Value)
+		}
 	}
 	}
 }
 }