Sfoglia il codice sorgente

use reflect2 for json.Marshaler

Tao Wen 7 anni fa
parent
commit
a7a34507ab

+ 2 - 14
feature_reflect.go

@@ -7,6 +7,7 @@ import (
 	"reflect"
 	"time"
 	"unsafe"
+	"github.com/v2pro/plz/reflect2"
 )
 
 // ValDecoder is an internal type registered to cache as needed.
@@ -331,21 +332,8 @@ func createEncoderOfType(cfg *frozenConfig, prefix string, typ reflect.Type) Val
 	}
 	if typ.Implements(marshalerType) {
 		checkIsEmpty := createCheckIsEmpty(cfg, typ)
-		templateInterface := reflect.New(typ).Elem().Interface()
-		var encoder ValEncoder = &marshalerEncoder{
-			templateInterface: extractInterface(templateInterface),
-			checkIsEmpty:      checkIsEmpty,
-		}
-		if typ.Kind() == reflect.Ptr {
-			encoder = &OptionalEncoder{encoder}
-		}
-		return encoder
-	}
-	if reflect.PtrTo(typ).Implements(marshalerType) {
-		checkIsEmpty := createCheckIsEmpty(cfg, reflect.PtrTo(typ))
-		templateInterface := reflect.New(typ).Interface()
 		var encoder ValEncoder = &marshalerEncoder{
-			templateInterface: extractInterface(templateInterface),
+			valType: reflect2.Type2(typ),
 			checkIsEmpty:      checkIsEmpty,
 		}
 		return encoder

+ 4 - 0
feature_reflect_map.go

@@ -176,6 +176,10 @@ type sortKeysMapEncoder struct {
 
 func (encoder *sortKeysMapEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
 	ptr = *(*unsafe.Pointer)(ptr)
+	if ptr == nil {
+		stream.WriteNil()
+		return
+	}
 	mapInterface := encoder.mapInterface
 	mapInterface.word = ptr
 	realInterface := (*interface{})(unsafe.Pointer(&mapInterface))

+ 8 - 10
feature_reflect_native.go

@@ -6,6 +6,7 @@ import (
 	"encoding/json"
 	"reflect"
 	"unsafe"
+	"github.com/v2pro/plz/reflect2"
 )
 
 type stringCodec struct {
@@ -473,7 +474,7 @@ func (codec *base64Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
 	case StringValue:
 		encoding := base64.StdEncoding
 		src := iter.SkipAndReturnBytes()
-		src = src[1 : len(src)-1]
+		src = src[1: len(src)-1]
 		decodedLen := encoding.DecodedLen(len(src))
 		dst := make([]byte, decodedLen)
 		len, err := encoding.Decode(dst, src)
@@ -578,20 +579,17 @@ func (encoder *stringModeStringEncoder) IsEmpty(ptr unsafe.Pointer) bool {
 }
 
 type marshalerEncoder struct {
-	templateInterface emptyInterface
-	checkIsEmpty      checkIsEmpty
+	checkIsEmpty checkIsEmpty
+	valType      reflect2.Type
 }
 
 func (encoder *marshalerEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
-	templateInterface := encoder.templateInterface
-	templateInterface.word = ptr
-	realInterface := (*interface{})(unsafe.Pointer(&templateInterface))
-	marshaler, ok := (*realInterface).(json.Marshaler)
-	if !ok {
-		stream.WriteVal(nil)
+	obj := encoder.valType.UnsafeIndirect(ptr)
+	if obj == nil {
+		stream.WriteNil()
 		return
 	}
-
+	marshaler := obj.(json.Marshaler)
 	bytes, err := marshaler.MarshalJSON()
 	if err != nil {
 		stream.Error = err

+ 2 - 5
value_tests/struct_test.go

@@ -10,6 +10,7 @@ func init() {
 	var pString = func(val string) *string {
 		return &val
 	}
+	epoch := time.Unix(0, 0)
 	unmarshalCases = append(unmarshalCases, unmarshalCase{
 		ptr: (*struct {
 			Field interface{}
@@ -83,13 +84,9 @@ func init() {
 		struct {
 			F *float64
 		}{},
-		// TODO: fix this
-		//struct {
-		//	*time.Time
-		//}{},
 		struct {
 			*time.Time
-		}{&time.Time{}},
+		}{&epoch},
 		struct {
 			*StructVarious
 		}{&StructVarious{}},

+ 13 - 1
value_tests/value_test.go

@@ -6,6 +6,7 @@ import (
 	"encoding/json"
 	"github.com/stretchr/testify/require"
 	"github.com/json-iterator/go"
+	"fmt"
 )
 
 type unmarshalCase struct {
@@ -19,6 +20,10 @@ var marshalCases = []interface{}{
 	nil,
 }
 
+type selectedMarshalCase struct  {
+	marshalCase interface{}
+}
+
 func Test_unmarshal(t *testing.T) {
 	should := require.New(t)
 	for _, testCase := range unmarshalCases {
@@ -35,9 +40,16 @@ func Test_unmarshal(t *testing.T) {
 
 func Test_marshal(t *testing.T) {
 	for _, testCase := range marshalCases {
+		selectedMarshalCase, found := testCase.(selectedMarshalCase)
+		if found {
+			marshalCases = []interface{}{selectedMarshalCase.marshalCase}
+			break
+		}
+	}
+	for i, testCase := range marshalCases {
 		var name string
 		if testCase != nil {
-			name = reflect.TypeOf(testCase).String()
+			name = fmt.Sprintf("[%v]%v/%s", i, testCase, reflect.TypeOf(testCase).String())
 		}
 		t.Run(name, func(t *testing.T) {
 			should := require.New(t)