瀏覽代碼

fix marshaler support for iface case

Tao Wen 7 年之前
父節點
當前提交
43d9384d67
共有 6 個文件被更改,包括 246 次插入91 次删除
  1. 31 0
      feature_reflect.go
  2. 134 0
      feature_reflect_marshaler.go
  3. 0 83
      feature_reflect_native.go
  4. 12 0
      feature_reflect_optional.go
  5. 42 4
      type_tests/slice_test.go
  6. 27 4
      value_tests/marshaler_test.go

+ 31 - 0
feature_reflect.go

@@ -330,6 +330,13 @@ func createEncoderOfType(cfg *frozenConfig, prefix string, typ reflect.Type) Val
 	if typ.AssignableTo(jsoniterNumberType) {
 	if typ.AssignableTo(jsoniterNumberType) {
 		return &jsoniterNumberCodec{}
 		return &jsoniterNumberCodec{}
 	}
 	}
+	if typ == marshalerType {
+		checkIsEmpty := createCheckIsEmpty(cfg, typ)
+		var encoder ValEncoder = &directMarshalerEncoder{
+			checkIsEmpty:      checkIsEmpty,
+		}
+		return encoder
+	}
 	if typ.Implements(marshalerType) {
 	if typ.Implements(marshalerType) {
 		checkIsEmpty := createCheckIsEmpty(cfg, typ)
 		checkIsEmpty := createCheckIsEmpty(cfg, typ)
 		var encoder ValEncoder = &marshalerEncoder{
 		var encoder ValEncoder = &marshalerEncoder{
@@ -338,6 +345,22 @@ func createEncoderOfType(cfg *frozenConfig, prefix string, typ reflect.Type) Val
 		}
 		}
 		return encoder
 		return encoder
 	}
 	}
+	ptrType := reflect.PtrTo(typ)
+	if ptrType.Implements(marshalerType) {
+		checkIsEmpty := createCheckIsEmpty(cfg, ptrType)
+		var encoder ValEncoder = &marshalerEncoder{
+			valType: reflect2.Type2(ptrType),
+			checkIsEmpty:      checkIsEmpty,
+		}
+		return &referenceEncoder{encoder}
+	}
+	if typ == textMarshalerType {
+		checkIsEmpty := createCheckIsEmpty(cfg, typ)
+		var encoder ValEncoder = &directTextMarshalerEncoder{
+			checkIsEmpty:      checkIsEmpty,
+		}
+		return encoder
+	}
 	if typ.Implements(textMarshalerType) {
 	if typ.Implements(textMarshalerType) {
 		checkIsEmpty := createCheckIsEmpty(cfg, typ)
 		checkIsEmpty := createCheckIsEmpty(cfg, typ)
 		var encoder ValEncoder = &textMarshalerEncoder{
 		var encoder ValEncoder = &textMarshalerEncoder{
@@ -346,6 +369,14 @@ func createEncoderOfType(cfg *frozenConfig, prefix string, typ reflect.Type) Val
 		}
 		}
 		return encoder
 		return encoder
 	}
 	}
+	if ptrType.Implements(textMarshalerType) {
+		checkIsEmpty := createCheckIsEmpty(cfg, ptrType)
+		var encoder ValEncoder = &textMarshalerEncoder{
+			valType: reflect2.Type2(ptrType),
+			checkIsEmpty:      checkIsEmpty,
+		}
+		return &referenceEncoder{encoder}
+	}
 	if typ.Kind() == reflect.Slice && typ.Elem().Kind() == reflect.Uint8 {
 	if typ.Kind() == reflect.Slice && typ.Elem().Kind() == reflect.Uint8 {
 		return &base64Codec{}
 		return &base64Codec{}
 	}
 	}

+ 134 - 0
feature_reflect_marshaler.go

@@ -0,0 +1,134 @@
+package jsoniter
+
+import (
+	"github.com/v2pro/plz/reflect2"
+	"unsafe"
+	"encoding"
+	"encoding/json"
+)
+
+type marshalerEncoder struct {
+	checkIsEmpty checkIsEmpty
+	valType      reflect2.Type
+}
+
+func (encoder *marshalerEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
+	obj := encoder.valType.UnsafeIndirect(ptr)
+	if encoder.valType.IsNullable() && reflect2.IsNil(obj) {
+		stream.WriteNil()
+		return
+	}
+	marshaler := obj.(json.Marshaler)
+	bytes, err := marshaler.MarshalJSON()
+	if err != nil {
+		stream.Error = err
+	} else {
+		stream.Write(bytes)
+	}
+}
+
+func (encoder *marshalerEncoder) IsEmpty(ptr unsafe.Pointer) bool {
+	return encoder.checkIsEmpty.IsEmpty(ptr)
+}
+
+type directMarshalerEncoder struct {
+	checkIsEmpty checkIsEmpty
+}
+
+func (encoder *directMarshalerEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
+	marshaler := *(*json.Marshaler)(ptr)
+	if marshaler == nil {
+		stream.WriteNil()
+		return
+	}
+	bytes, err := marshaler.MarshalJSON()
+	if err != nil {
+		stream.Error = err
+	} else {
+		stream.Write(bytes)
+	}
+}
+
+func (encoder *directMarshalerEncoder) IsEmpty(ptr unsafe.Pointer) bool {
+	return encoder.checkIsEmpty.IsEmpty(ptr)
+}
+
+type textMarshalerEncoder struct {
+	valType			reflect2.Type
+	checkIsEmpty      checkIsEmpty
+}
+
+func (encoder *textMarshalerEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
+	obj := encoder.valType.UnsafeIndirect(ptr)
+	if encoder.valType.IsNullable() && reflect2.IsNil(obj) {
+		stream.WriteNil()
+		return
+	}
+	marshaler := (obj).(encoding.TextMarshaler)
+	bytes, err := marshaler.MarshalText()
+	if err != nil {
+		stream.Error = err
+	} else {
+		stream.WriteString(string(bytes))
+	}
+}
+
+func (encoder *textMarshalerEncoder) IsEmpty(ptr unsafe.Pointer) bool {
+	return encoder.checkIsEmpty.IsEmpty(ptr)
+}
+
+type directTextMarshalerEncoder struct {
+	checkIsEmpty checkIsEmpty
+}
+
+func (encoder *directTextMarshalerEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
+	marshaler := *(*encoding.TextMarshaler)(ptr)
+	if marshaler == nil {
+		stream.WriteNil()
+		return
+	}
+	bytes, err := marshaler.MarshalText()
+	if err != nil {
+		stream.Error = err
+	} else {
+		stream.WriteString(string(bytes))
+	}
+}
+
+func (encoder *directTextMarshalerEncoder) IsEmpty(ptr unsafe.Pointer) bool {
+	return encoder.checkIsEmpty.IsEmpty(ptr)
+}
+
+type unmarshalerDecoder struct {
+	templateInterface emptyInterface
+}
+
+func (decoder *unmarshalerDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) {
+	templateInterface := decoder.templateInterface
+	templateInterface.word = ptr
+	realInterface := (*interface{})(unsafe.Pointer(&templateInterface))
+	unmarshaler := (*realInterface).(json.Unmarshaler)
+	iter.nextToken()
+	iter.unreadByte() // skip spaces
+	bytes := iter.SkipAndReturnBytes()
+	err := unmarshaler.UnmarshalJSON(bytes)
+	if err != nil {
+		iter.ReportError("unmarshalerDecoder", err.Error())
+	}
+}
+
+type textUnmarshalerDecoder struct {
+	templateInterface emptyInterface
+}
+
+func (decoder *textUnmarshalerDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) {
+	templateInterface := decoder.templateInterface
+	templateInterface.word = ptr
+	realInterface := (*interface{})(unsafe.Pointer(&templateInterface))
+	unmarshaler := (*realInterface).(encoding.TextUnmarshaler)
+	str := iter.ReadString()
+	err := unmarshaler.UnmarshalText([]byte(str))
+	if err != nil {
+		iter.ReportError("textUnmarshalerDecoder", err.Error())
+	}
+}

+ 0 - 83
feature_reflect_native.go

@@ -1,7 +1,6 @@
 package jsoniter
 package jsoniter
 
 
 import (
 import (
-	"encoding"
 	"encoding/base64"
 	"encoding/base64"
 	"encoding/json"
 	"encoding/json"
 	"reflect"
 	"reflect"
@@ -591,85 +590,3 @@ func (encoder *stringModeStringEncoder) Encode(ptr unsafe.Pointer, stream *Strea
 func (encoder *stringModeStringEncoder) IsEmpty(ptr unsafe.Pointer) bool {
 func (encoder *stringModeStringEncoder) IsEmpty(ptr unsafe.Pointer) bool {
 	return encoder.elemEncoder.IsEmpty(ptr)
 	return encoder.elemEncoder.IsEmpty(ptr)
 }
 }
-
-type marshalerEncoder struct {
-	checkIsEmpty checkIsEmpty
-	valType      reflect2.Type
-}
-
-func (encoder *marshalerEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
-	obj := encoder.valType.UnsafeIndirect(ptr)
-	if reflect2.IsNil(obj) {
-		stream.WriteNil()
-		return
-	}
-	marshaler := obj.(json.Marshaler)
-	bytes, err := marshaler.MarshalJSON()
-	if err != nil {
-		stream.Error = err
-	} else {
-		stream.Write(bytes)
-	}
-}
-
-func (encoder *marshalerEncoder) IsEmpty(ptr unsafe.Pointer) bool {
-	return encoder.checkIsEmpty.IsEmpty(ptr)
-}
-
-type textMarshalerEncoder struct {
-	valType			reflect2.Type
-	checkIsEmpty      checkIsEmpty
-}
-
-func (encoder *textMarshalerEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
-	obj := encoder.valType.UnsafeIndirect(ptr)
-	if reflect2.IsNil(obj) {
-		stream.WriteNil()
-		return
-	}
-	marshaler := (obj).(encoding.TextMarshaler)
-	bytes, err := marshaler.MarshalText()
-	if err != nil {
-		stream.Error = err
-	} else {
-		stream.WriteString(string(bytes))
-	}
-}
-
-func (encoder *textMarshalerEncoder) IsEmpty(ptr unsafe.Pointer) bool {
-	return encoder.checkIsEmpty.IsEmpty(ptr)
-}
-
-type unmarshalerDecoder struct {
-	templateInterface emptyInterface
-}
-
-func (decoder *unmarshalerDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) {
-	templateInterface := decoder.templateInterface
-	templateInterface.word = ptr
-	realInterface := (*interface{})(unsafe.Pointer(&templateInterface))
-	unmarshaler := (*realInterface).(json.Unmarshaler)
-	iter.nextToken()
-	iter.unreadByte() // skip spaces
-	bytes := iter.SkipAndReturnBytes()
-	err := unmarshaler.UnmarshalJSON(bytes)
-	if err != nil {
-		iter.ReportError("unmarshalerDecoder", err.Error())
-	}
-}
-
-type textUnmarshalerDecoder struct {
-	templateInterface emptyInterface
-}
-
-func (decoder *textUnmarshalerDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) {
-	templateInterface := decoder.templateInterface
-	templateInterface.word = ptr
-	realInterface := (*interface{})(unsafe.Pointer(&templateInterface))
-	unmarshaler := (*realInterface).(encoding.TextUnmarshaler)
-	str := iter.ReadString()
-	err := unmarshaler.UnmarshalText([]byte(str))
-	if err != nil {
-		iter.ReportError("textUnmarshalerDecoder", err.Error())
-	}
-}

+ 12 - 0
feature_reflect_optional.go

@@ -106,4 +106,16 @@ func (encoder *dereferenceEncoder) IsEmbeddedPtrNil(ptr unsafe.Pointer) bool {
 	}
 	}
 	fieldPtr := unsafe.Pointer(deReferenced)
 	fieldPtr := unsafe.Pointer(deReferenced)
 	return isEmbeddedPtrNil.IsEmbeddedPtrNil(fieldPtr)
 	return isEmbeddedPtrNil.IsEmbeddedPtrNil(fieldPtr)
+}
+
+type referenceEncoder struct {
+	encoder ValEncoder
+}
+
+func (encoder *referenceEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
+	encoder.encoder.Encode(unsafe.Pointer(&ptr), stream)
+}
+
+func (encoder *referenceEncoder) IsEmpty(ptr unsafe.Pointer) bool {
+	return encoder.encoder.IsEmpty(unsafe.Pointer(&ptr))
 }
 }

+ 42 - 4
type_tests/slice_test.go

@@ -71,19 +71,57 @@ func init() {
 			Map   map[string]string
 			Map   map[string]string
 		})(nil),
 		})(nil),
 		(*[]uint8)(nil),
 		(*[]uint8)(nil),
-		(*[]GeoLocation)(nil),
+		(*[]jsonMarshaler)(nil),
+		(*[]jsonMarshalerMap)(nil),
+		(*[]textMarshaler)(nil),
+		(*[]textMarshalerMap)(nil),
 	)
 	)
 }
 }
 
 
-type GeoLocation struct {
+type jsonMarshaler struct {
 	Id string `json:"id,omitempty" db:"id"`
 	Id string `json:"id,omitempty" db:"id"`
 }
 }
 
 
-func (p *GeoLocation) MarshalJSON() ([]byte, error) {
+func (p *jsonMarshaler) MarshalJSON() ([]byte, error) {
 	return []byte(`{}`), nil
 	return []byte(`{}`), nil
 }
 }
 
 
-func (p *GeoLocation) UnmarshalJSON(input []byte) error {
+func (p *jsonMarshaler) UnmarshalJSON(input []byte) error {
 	p.Id = "hello"
 	p.Id = "hello"
 	return nil
 	return nil
+}
+
+
+type jsonMarshalerMap map[int]int
+
+func (p *jsonMarshalerMap) MarshalJSON() ([]byte, error) {
+	return []byte(`{}`), nil
+}
+
+func (p *jsonMarshalerMap) UnmarshalJSON(input []byte) error {
+	return nil
+}
+
+type textMarshaler struct {
+	Id string `json:"id,omitempty" db:"id"`
+}
+
+func (p *textMarshaler) MarshalText() ([]byte, error) {
+	return []byte(`{}`), nil
+}
+
+func (p *textMarshaler) UnmarshalText(input []byte) error {
+	p.Id = "hello"
+	return nil
+}
+
+type textMarshalerMap map[int]int
+
+
+func (p *textMarshalerMap) MarshalText() ([]byte, error) {
+	return []byte(`{}`), nil
+}
+
+func (p *textMarshalerMap) UnmarshalText(input []byte) error {
+	return nil
 }
 }

+ 27 - 4
value_tests/marshaler_test.go

@@ -1,19 +1,42 @@
 package test
 package test
 
 
+import (
+	"encoding/json"
+	"encoding"
+)
+
 func init() {
 func init() {
+	jsonMarshaler := json.Marshaler(fakeJsonMarshaler{})
+	textMarshaler := encoding.TextMarshaler(fakeTextMarshaler{})
 	marshalCases = append(marshalCases,
 	marshalCases = append(marshalCases,
-		withChan{},
+		fakeJsonMarshaler{},
+		&jsonMarshaler,
+		fakeTextMarshaler{},
+		&textMarshaler,
 	)
 	)
 }
 }
 
 
-type withChan struct {
+type fakeJsonMarshaler struct {
+	F2 chan []byte
+}
+
+func (q fakeJsonMarshaler) MarshalJSON() ([]byte, error) {
+	return []byte(`""`), nil
+}
+
+func (q *fakeJsonMarshaler) UnmarshalJSON(value []byte) error {
+	return nil
+}
+
+
+type fakeTextMarshaler struct {
 	F2 chan []byte
 	F2 chan []byte
 }
 }
 
 
-func (q withChan) MarshalJSON() ([]byte, error) {
+func (q fakeTextMarshaler) MarshalText() ([]byte, error) {
 	return []byte(`""`), nil
 	return []byte(`""`), nil
 }
 }
 
 
-func (q *withChan) UnmarshalJSON(value []byte) error {
+func (q *fakeTextMarshaler) UnmarshalText(value []byte) error {
 	return nil
 	return nil
 }
 }