Explorar o código

Merge pull request #172 from olegshaldybin/more-stdlib-compat

Improve stdlib compatibility
Tao Wen %!s(int64=8) %!d(string=hai) anos
pai
achega
3c298d8a76
Modificáronse 3 ficheiros con 210 adicións e 57 borrados
  1. 57 56
      feature_reflect_native.go
  2. 121 0
      jsoniter_interface_test.go
  3. 32 1
      jsoniter_null_test.go

+ 57 - 56
feature_reflect_native.go

@@ -32,11 +32,9 @@ type intCodec struct {
 }
 
 func (codec *intCodec) Decode(ptr unsafe.Pointer, iter *Iterator) {
-	if iter.ReadNil() {
-		*((*int)(ptr)) = 0
-		return
+	if !iter.ReadNil() {
+		*((*int)(ptr)) = iter.ReadInt()
 	}
-	*((*int)(ptr)) = iter.ReadInt()
 }
 
 func (codec *intCodec) Encode(ptr unsafe.Pointer, stream *Stream) {
@@ -55,11 +53,9 @@ type uintptrCodec struct {
 }
 
 func (codec *uintptrCodec) Decode(ptr unsafe.Pointer, iter *Iterator) {
-	if iter.ReadNil() {
-		*((*uintptr)(ptr)) = 0
-		return
+	if !iter.ReadNil() {
+		*((*uintptr)(ptr)) = uintptr(iter.ReadUint64())
 	}
-	*((*uintptr)(ptr)) = uintptr(iter.ReadUint64())
 }
 
 func (codec *uintptrCodec) Encode(ptr unsafe.Pointer, stream *Stream) {
@@ -78,11 +74,9 @@ type int8Codec struct {
 }
 
 func (codec *int8Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
-	if iter.ReadNil() {
-		*((*uint8)(ptr)) = 0
-		return
+	if !iter.ReadNil() {
+		*((*int8)(ptr)) = iter.ReadInt8()
 	}
-	*((*int8)(ptr)) = iter.ReadInt8()
 }
 
 func (codec *int8Codec) Encode(ptr unsafe.Pointer, stream *Stream) {
@@ -101,11 +95,9 @@ type int16Codec struct {
 }
 
 func (codec *int16Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
-	if iter.ReadNil() {
-		*((*int16)(ptr)) = 0
-		return
+	if !iter.ReadNil() {
+		*((*int16)(ptr)) = iter.ReadInt16()
 	}
-	*((*int16)(ptr)) = iter.ReadInt16()
 }
 
 func (codec *int16Codec) Encode(ptr unsafe.Pointer, stream *Stream) {
@@ -124,11 +116,9 @@ type int32Codec struct {
 }
 
 func (codec *int32Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
-	if iter.ReadNil() {
-		*((*int32)(ptr)) = 0
-		return
+	if !iter.ReadNil() {
+		*((*int32)(ptr)) = iter.ReadInt32()
 	}
-	*((*int32)(ptr)) = iter.ReadInt32()
 }
 
 func (codec *int32Codec) Encode(ptr unsafe.Pointer, stream *Stream) {
@@ -147,11 +137,9 @@ type int64Codec struct {
 }
 
 func (codec *int64Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
-	if iter.ReadNil() {
-		*((*int64)(ptr)) = 0
-		return
+	if !iter.ReadNil() {
+		*((*int64)(ptr)) = iter.ReadInt64()
 	}
-	*((*int64)(ptr)) = iter.ReadInt64()
 }
 
 func (codec *int64Codec) Encode(ptr unsafe.Pointer, stream *Stream) {
@@ -170,11 +158,10 @@ type uintCodec struct {
 }
 
 func (codec *uintCodec) Decode(ptr unsafe.Pointer, iter *Iterator) {
-	if iter.ReadNil() {
-		*((*uint)(ptr)) = 0
+	if !iter.ReadNil() {
+		*((*uint)(ptr)) = iter.ReadUint()
 		return
 	}
-	*((*uint)(ptr)) = iter.ReadUint()
 }
 
 func (codec *uintCodec) Encode(ptr unsafe.Pointer, stream *Stream) {
@@ -193,11 +180,9 @@ type uint8Codec struct {
 }
 
 func (codec *uint8Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
-	if iter.ReadNil() {
-		*((*uint8)(ptr)) = 0
-		return
+	if !iter.ReadNil() {
+		*((*uint8)(ptr)) = iter.ReadUint8()
 	}
-	*((*uint8)(ptr)) = iter.ReadUint8()
 }
 
 func (codec *uint8Codec) Encode(ptr unsafe.Pointer, stream *Stream) {
@@ -216,11 +201,9 @@ type uint16Codec struct {
 }
 
 func (codec *uint16Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
-	if iter.ReadNil() {
-		*((*uint16)(ptr)) = 0
-		return
+	if !iter.ReadNil() {
+		*((*uint16)(ptr)) = iter.ReadUint16()
 	}
-	*((*uint16)(ptr)) = iter.ReadUint16()
 }
 
 func (codec *uint16Codec) Encode(ptr unsafe.Pointer, stream *Stream) {
@@ -239,11 +222,9 @@ type uint32Codec struct {
 }
 
 func (codec *uint32Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
-	if iter.ReadNil() {
-		*((*uint32)(ptr)) = 0
-		return
+	if !iter.ReadNil() {
+		*((*uint32)(ptr)) = iter.ReadUint32()
 	}
-	*((*uint32)(ptr)) = iter.ReadUint32()
 }
 
 func (codec *uint32Codec) Encode(ptr unsafe.Pointer, stream *Stream) {
@@ -262,11 +243,9 @@ type uint64Codec struct {
 }
 
 func (codec *uint64Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
-	if iter.ReadNil() {
-		*((*uint64)(ptr)) = 0
-		return
+	if !iter.ReadNil() {
+		*((*uint64)(ptr)) = iter.ReadUint64()
 	}
-	*((*uint64)(ptr)) = iter.ReadUint64()
 }
 
 func (codec *uint64Codec) Encode(ptr unsafe.Pointer, stream *Stream) {
@@ -285,11 +264,9 @@ type float32Codec struct {
 }
 
 func (codec *float32Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
-	if iter.ReadNil() {
-		*((*float32)(ptr)) = 0
-		return
+	if !iter.ReadNil() {
+		*((*float32)(ptr)) = iter.ReadFloat32()
 	}
-	*((*float32)(ptr)) = iter.ReadFloat32()
 }
 
 func (codec *float32Codec) Encode(ptr unsafe.Pointer, stream *Stream) {
@@ -308,11 +285,9 @@ type float64Codec struct {
 }
 
 func (codec *float64Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
-	if iter.ReadNil() {
-		*((*float64)(ptr)) = 0
-		return
+	if !iter.ReadNil() {
+		*((*float64)(ptr)) = iter.ReadFloat64()
 	}
-	*((*float64)(ptr)) = iter.ReadFloat64()
 }
 
 func (codec *float64Codec) Encode(ptr unsafe.Pointer, stream *Stream) {
@@ -352,13 +327,39 @@ type emptyInterfaceCodec struct {
 }
 
 func (codec *emptyInterfaceCodec) Decode(ptr unsafe.Pointer, iter *Iterator) {
-	if iter.ReadNil() {
-		*((*interface{})(ptr)) = nil
+	existing := *((*interface{})(ptr))
+
+	// Checking for both typed and untyped nil pointers.
+	if existing != nil &&
+		reflect.TypeOf(existing).Kind() == reflect.Ptr &&
+		!reflect.ValueOf(existing).IsNil() {
+
+		var ptrToExisting interface{}
+		for {
+			elem := reflect.ValueOf(existing).Elem()
+			if elem.Kind() != reflect.Ptr || elem.IsNil() {
+				break
+			}
+			ptrToExisting = existing
+			existing = elem.Interface()
+		}
+
+		if iter.ReadNil() {
+			if ptrToExisting != nil {
+				nilPtr := reflect.Zero(reflect.TypeOf(ptrToExisting).Elem())
+				reflect.ValueOf(ptrToExisting).Elem().Set(nilPtr)
+			} else {
+				*((*interface{})(ptr)) = nil
+			}
+		} else {
+			iter.ReadVal(existing)
+		}
+
 		return
 	}
-	existing := *((*interface{})(ptr))
-	if existing != nil && reflect.TypeOf(existing).Kind() == reflect.Ptr {
-		iter.ReadVal(existing)
+
+	if iter.ReadNil() {
+		*((*interface{})(ptr)) = nil
 	} else {
 		*((*interface{})(ptr)) = iter.Read()
 	}

+ 121 - 0
jsoniter_interface_test.go

@@ -437,3 +437,124 @@ func Test_marshal_nil_nonempty_interface(t *testing.T) {
 	should.NoError(err)
 	should.Equal(nil, obj.Field)
 }
+
+func Test_overwrite_interface_ptr_value_with_nil(t *testing.T) {
+	type Wrapper struct {
+		Payload interface{} `json:"payload,omitempty"`
+	}
+	type Payload struct {
+		Value int `json:"val,omitempty"`
+	}
+
+	should := require.New(t)
+
+	payload := &Payload{}
+	wrapper := &Wrapper{
+		Payload: &payload,
+	}
+
+	err := json.Unmarshal([]byte(`{"payload": {"val": 42}}`), &wrapper)
+	should.Equal(nil, err)
+	should.Equal(&payload, wrapper.Payload)
+	should.Equal(42, (*(wrapper.Payload.(**Payload))).Value)
+
+	err = json.Unmarshal([]byte(`{"payload": null}`), &wrapper)
+	should.Equal(nil, err)
+	should.Equal(&payload, wrapper.Payload)
+	should.Equal((*Payload)(nil), payload)
+
+	payload = &Payload{}
+	wrapper = &Wrapper{
+		Payload: &payload,
+	}
+
+	err = Unmarshal([]byte(`{"payload": {"val": 42}}`), &wrapper)
+	should.Equal(nil, err)
+	should.Equal(&payload, wrapper.Payload)
+	should.Equal(42, (*(wrapper.Payload.(**Payload))).Value)
+
+	err = Unmarshal([]byte(`{"payload": null}`), &wrapper)
+	should.Equal(nil, err)
+	should.Equal(&payload, wrapper.Payload)
+	should.Equal((*Payload)(nil), payload)
+}
+
+func Test_overwrite_interface_value_with_nil(t *testing.T) {
+	type Wrapper struct {
+		Payload interface{} `json:"payload,omitempty"`
+	}
+	type Payload struct {
+		Value int `json:"val,omitempty"`
+	}
+
+	should := require.New(t)
+
+	payload := &Payload{}
+	wrapper := &Wrapper{
+		Payload: payload,
+	}
+
+	err := json.Unmarshal([]byte(`{"payload": {"val": 42}}`), &wrapper)
+	should.Equal(nil, err)
+	should.Equal(42, (*(wrapper.Payload.(*Payload))).Value)
+
+	err = json.Unmarshal([]byte(`{"payload": null}`), &wrapper)
+	should.Equal(nil, err)
+	should.Equal(nil, wrapper.Payload)
+	should.Equal(42, payload.Value)
+
+	payload = &Payload{}
+	wrapper = &Wrapper{
+		Payload: payload,
+	}
+
+	err = Unmarshal([]byte(`{"payload": {"val": 42}}`), &wrapper)
+	should.Equal(nil, err)
+	should.Equal(42, (*(wrapper.Payload.(*Payload))).Value)
+
+	err = Unmarshal([]byte(`{"payload": null}`), &wrapper)
+	should.Equal(nil, err)
+	should.Equal(nil, wrapper.Payload)
+	should.Equal(42, payload.Value)
+}
+
+func Test_unmarshal_into_nil(t *testing.T) {
+	type Payload struct {
+		Value int `json:"val,omitempty"`
+	}
+	type Wrapper struct {
+		Payload interface{} `json:"payload,omitempty"`
+	}
+
+	should := require.New(t)
+
+	var payload *Payload
+	wrapper := &Wrapper{
+		Payload: payload,
+	}
+
+	err := json.Unmarshal([]byte(`{"payload": {"val": 42}}`), &wrapper)
+	should.Nil(err)
+	should.NotNil(wrapper.Payload)
+	should.Nil(payload)
+
+	err = json.Unmarshal([]byte(`{"payload": null}`), &wrapper)
+	should.Nil(err)
+	should.Nil(wrapper.Payload)
+	should.Nil(payload)
+
+	payload = nil
+	wrapper = &Wrapper{
+		Payload: payload,
+	}
+
+	err = Unmarshal([]byte(`{"payload": {"val": 42}}`), &wrapper)
+	should.Nil(err)
+	should.NotNil(wrapper.Payload)
+	should.Nil(payload)
+
+	err = Unmarshal([]byte(`{"payload": null}`), &wrapper)
+	should.Nil(err)
+	should.Nil(wrapper.Payload)
+	should.Nil(payload)
+}

+ 32 - 1
jsoniter_null_test.go

@@ -3,9 +3,10 @@ package jsoniter
 import (
 	"bytes"
 	"encoding/json"
-	"github.com/stretchr/testify/require"
 	"io"
 	"testing"
+
+	"github.com/stretchr/testify/require"
 )
 
 func Test_read_null(t *testing.T) {
@@ -135,3 +136,33 @@ func Test_encode_nil_array(t *testing.T) {
 	should.Nil(err)
 	should.Equal("null", string(output))
 }
+
+func Test_decode_nil_num(t *testing.T) {
+	type TestData struct {
+		Field int `json:"field"`
+	}
+	should := require.New(t)
+
+	data1 := []byte(`{"field": 42}`)
+	data2 := []byte(`{"field": null}`)
+
+	// Checking stdlib behavior as well
+	obj2 := TestData{}
+	err := json.Unmarshal(data1, &obj2)
+	should.Equal(nil, err)
+	should.Equal(42, obj2.Field)
+
+	err = json.Unmarshal(data2, &obj2)
+	should.Equal(nil, err)
+	should.Equal(42, obj2.Field)
+
+	obj := TestData{}
+
+	err = Unmarshal(data1, &obj)
+	should.Equal(nil, err)
+	should.Equal(42, obj.Field)
+
+	err = Unmarshal(data2, &obj)
+	should.Equal(nil, err)
+	should.Equal(42, obj.Field)
+}