Преглед изворни кода

#27 support json.Marshaler

Tao Wen пре 8 година
родитељ
комит
f6f159e108
4 измењених фајлова са 176 додато и 103 уклоњено
  1. 15 103
      feature_reflect.go
  2. 105 0
      feature_reflect_map.go
  3. 34 0
      feature_reflect_native.go
  4. 22 0
      jsoniter_customize_test.go

+ 15 - 103
feature_reflect.go

@@ -6,6 +6,7 @@ import (
 	"sync/atomic"
 	"unsafe"
 	"errors"
+	"encoding/json"
 )
 
 /*
@@ -73,6 +74,7 @@ var typeEncoders map[string]Encoder
 var fieldEncoders map[string]Encoder
 var extensions []ExtensionFunc
 var anyType reflect.Type
+var marshalerType reflect.Type
 
 func init() {
 	typeDecoders = map[string]Decoder{}
@@ -83,6 +85,7 @@ func init() {
 	atomic.StorePointer(&DECODERS, unsafe.Pointer(&map[string]Decoder{}))
 	atomic.StorePointer(&ENCODERS, unsafe.Pointer(&map[string]Encoder{}))
 	anyType = reflect.TypeOf((*Any)(nil)).Elem()
+	marshalerType = reflect.TypeOf((*json.Marshaler)(nil)).Elem()
 }
 
 func addDecoderToCache(cacheKey reflect.Type, decoder Decoder) {
@@ -228,105 +231,6 @@ func (decoder *placeholderDecoder) decode(ptr unsafe.Pointer, iter *Iterator) {
 	decoder.valueDecoder.decode(ptr, iter)
 }
 
-type mapDecoder struct {
-	mapType      reflect.Type
-	elemType     reflect.Type
-	elemDecoder  Decoder
-	mapInterface emptyInterface
-}
-
-func (decoder *mapDecoder) decode(ptr unsafe.Pointer, iter *Iterator) {
-	// dark magic to cast unsafe.Pointer back to interface{} using reflect.Type
-	mapInterface := decoder.mapInterface
-	mapInterface.word = ptr
-	realInterface := (*interface{})(unsafe.Pointer(&mapInterface))
-	realVal := reflect.ValueOf(*realInterface).Elem()
-	if realVal.IsNil() {
-		realVal.Set(reflect.MakeMap(realVal.Type()))
-	}
-	for field := iter.ReadObject(); field != ""; field = iter.ReadObject() {
-		elem := reflect.New(decoder.elemType)
-		decoder.elemDecoder.decode(unsafe.Pointer(elem.Pointer()), iter)
-		// to put into map, we have to use reflection
-		realVal.SetMapIndex(reflect.ValueOf(string([]byte(field))), elem.Elem())
-	}
-}
-
-type mapEncoder struct {
-	mapType      reflect.Type
-	elemType     reflect.Type
-	elemEncoder  Encoder
-	mapInterface emptyInterface
-}
-
-func (encoder *mapEncoder) encode(ptr unsafe.Pointer, stream *Stream) {
-	mapInterface := encoder.mapInterface
-	mapInterface.word = ptr
-	realInterface := (*interface{})(unsafe.Pointer(&mapInterface))
-	realVal := reflect.ValueOf(*realInterface)
-
-	stream.WriteObjectStart()
-	for i, key := range realVal.MapKeys() {
-		if i != 0 {
-			stream.WriteMore()
-		}
-		stream.WriteObjectField(key.String())
-		val := realVal.MapIndex(key).Interface()
-		encoder.elemEncoder.encodeInterface(val, stream)
-	}
-	stream.WriteObjectEnd()
-}
-
-func (encoder *mapEncoder) encodeInterface(val interface{}, stream *Stream) {
-	writeToStream(val, stream, encoder)
-}
-
-func (encoder *mapEncoder) isEmpty(ptr unsafe.Pointer) bool {
-	mapInterface := encoder.mapInterface
-	mapInterface.word = ptr
-	realInterface := (*interface{})(unsafe.Pointer(&mapInterface))
-	realVal := reflect.ValueOf(*realInterface)
-	return realVal.Len() == 0
-}
-
-type mapInterfaceEncoder struct {
-	mapType      reflect.Type
-	elemType     reflect.Type
-	elemEncoder  Encoder
-	mapInterface emptyInterface
-}
-
-func (encoder *mapInterfaceEncoder) encode(ptr unsafe.Pointer, stream *Stream) {
-	mapInterface := encoder.mapInterface
-	mapInterface.word = ptr
-	realInterface := (*interface{})(unsafe.Pointer(&mapInterface))
-	realVal := reflect.ValueOf(*realInterface)
-
-	stream.WriteObjectStart()
-	for i, key := range realVal.MapKeys() {
-		if i != 0 {
-			stream.WriteMore()
-		}
-		stream.WriteObjectField(key.String())
-		val := realVal.MapIndex(key).Interface()
-		encoder.elemEncoder.encode(unsafe.Pointer(&val), stream)
-	}
-	stream.WriteObjectEnd()
-}
-
-func (encoder *mapInterfaceEncoder) encodeInterface(val interface{}, stream *Stream) {
-	writeToStream(val, stream, encoder)
-}
-
-func (encoder *mapInterfaceEncoder) isEmpty(ptr unsafe.Pointer) bool {
-	mapInterface := encoder.mapInterface
-	mapInterface.word = ptr
-	realInterface := (*interface{})(unsafe.Pointer(&mapInterface))
-	realVal := reflect.ValueOf(*realInterface)
-
-	return realVal.Len() == 0
-}
-
 // emptyInterface is the header for an interface{} value.
 type emptyInterface struct {
 	typ  *struct{}
@@ -460,9 +364,6 @@ func createDecoderOfType(typ reflect.Type) (Decoder, error) {
 }
 
 func encoderOfType(typ reflect.Type) (Encoder, error) {
-	if typ.ConvertibleTo(anyType) {
-		return &anyCodec{}, nil
-	}
 	typeName := typ.String()
 	typeEncoder := typeEncoders[typeName]
 	if typeEncoder != nil {
@@ -482,6 +383,13 @@ func encoderOfType(typ reflect.Type) (Encoder, error) {
 }
 
 func createEncoderOfType(typ reflect.Type) (Encoder, error) {
+	if typ.ConvertibleTo(anyType) {
+		return &anyCodec{}, nil
+	}
+	if typ.ConvertibleTo(marshalerType) {
+		templateInterface := reflect.New(typ).Elem().Interface()
+		return &marshalerEncoder{extractInterface(templateInterface)}, nil
+	}
 	switch typ.Kind() {
 	case reflect.String:
 		return &stringCodec{}, nil
@@ -550,7 +458,11 @@ func decoderOfMap(typ reflect.Type) (Decoder, error) {
 		return nil, err
 	}
 	mapInterface := reflect.New(typ).Interface()
-	return &mapDecoder{typ, typ.Elem(), decoder, *((*emptyInterface)(unsafe.Pointer(&mapInterface)))}, nil
+	return &mapDecoder{typ, typ.Elem(), decoder, extractInterface(mapInterface)}, nil
+}
+
+func extractInterface(val interface{}) emptyInterface {
+	return *((*emptyInterface)(unsafe.Pointer(&val)))
 }
 
 func encoderOfMap(typ reflect.Type) (Encoder, error) {

+ 105 - 0
feature_reflect_map.go

@@ -0,0 +1,105 @@
+package jsoniter
+
+import (
+	"unsafe"
+	"reflect"
+)
+
+type mapDecoder struct {
+	mapType      reflect.Type
+	elemType     reflect.Type
+	elemDecoder  Decoder
+	mapInterface emptyInterface
+}
+
+func (decoder *mapDecoder) decode(ptr unsafe.Pointer, iter *Iterator) {
+	// dark magic to cast unsafe.Pointer back to interface{} using reflect.Type
+	mapInterface := decoder.mapInterface
+	mapInterface.word = ptr
+	realInterface := (*interface{})(unsafe.Pointer(&mapInterface))
+	realVal := reflect.ValueOf(*realInterface).Elem()
+	if realVal.IsNil() {
+		realVal.Set(reflect.MakeMap(realVal.Type()))
+	}
+	for field := iter.ReadObject(); field != ""; field = iter.ReadObject() {
+		elem := reflect.New(decoder.elemType)
+		decoder.elemDecoder.decode(unsafe.Pointer(elem.Pointer()), iter)
+		// to put into map, we have to use reflection
+		realVal.SetMapIndex(reflect.ValueOf(string([]byte(field))), elem.Elem())
+	}
+}
+
+type mapEncoder struct {
+	mapType      reflect.Type
+	elemType     reflect.Type
+	elemEncoder  Encoder
+	mapInterface emptyInterface
+}
+
+func (encoder *mapEncoder) encode(ptr unsafe.Pointer, stream *Stream) {
+	mapInterface := encoder.mapInterface
+	mapInterface.word = ptr
+	realInterface := (*interface{})(unsafe.Pointer(&mapInterface))
+	realVal := reflect.ValueOf(*realInterface)
+
+	stream.WriteObjectStart()
+	for i, key := range realVal.MapKeys() {
+		if i != 0 {
+			stream.WriteMore()
+		}
+		stream.WriteObjectField(key.String())
+		val := realVal.MapIndex(key).Interface()
+		encoder.elemEncoder.encodeInterface(val, stream)
+	}
+	stream.WriteObjectEnd()
+}
+
+func (encoder *mapEncoder) encodeInterface(val interface{}, stream *Stream) {
+	writeToStream(val, stream, encoder)
+}
+
+func (encoder *mapEncoder) isEmpty(ptr unsafe.Pointer) bool {
+	mapInterface := encoder.mapInterface
+	mapInterface.word = ptr
+	realInterface := (*interface{})(unsafe.Pointer(&mapInterface))
+	realVal := reflect.ValueOf(*realInterface)
+	return realVal.Len() == 0
+}
+
+type mapInterfaceEncoder struct {
+	mapType      reflect.Type
+	elemType     reflect.Type
+	elemEncoder  Encoder
+	mapInterface emptyInterface
+}
+
+func (encoder *mapInterfaceEncoder) encode(ptr unsafe.Pointer, stream *Stream) {
+	mapInterface := encoder.mapInterface
+	mapInterface.word = ptr
+	realInterface := (*interface{})(unsafe.Pointer(&mapInterface))
+	realVal := reflect.ValueOf(*realInterface)
+
+	stream.WriteObjectStart()
+	for i, key := range realVal.MapKeys() {
+		if i != 0 {
+			stream.WriteMore()
+		}
+		stream.WriteObjectField(key.String())
+		val := realVal.MapIndex(key).Interface()
+		encoder.elemEncoder.encode(unsafe.Pointer(&val), stream)
+	}
+	stream.WriteObjectEnd()
+}
+
+func (encoder *mapInterfaceEncoder) encodeInterface(val interface{}, stream *Stream) {
+	writeToStream(val, stream, encoder)
+}
+
+func (encoder *mapInterfaceEncoder) isEmpty(ptr unsafe.Pointer) bool {
+	mapInterface := encoder.mapInterface
+	mapInterface.word = ptr
+	realInterface := (*interface{})(unsafe.Pointer(&mapInterface))
+	realVal := reflect.ValueOf(*realInterface)
+
+	return realVal.Len() == 0
+}

+ 34 - 0
feature_reflect_native.go

@@ -2,6 +2,7 @@ package jsoniter
 
 import (
 	"unsafe"
+	"encoding/json"
 )
 
 type stringCodec struct {
@@ -328,4 +329,37 @@ func (decoder *stringNumberDecoder) decode(ptr unsafe.Pointer, iter *Iterator) {
 		iter.reportError("stringNumberDecoder", `expect "`)
 		return
 	}
+}
+
+type marshalerEncoder struct {
+	templateInterface emptyInterface
+}
+
+func (encoder *marshalerEncoder) encode(ptr unsafe.Pointer, stream *Stream) {
+	templateInterface := encoder.templateInterface
+	templateInterface.word = ptr
+	realInterface := (*interface{})(unsafe.Pointer(&templateInterface))
+	marshaler := (*realInterface).(json.Marshaler)
+	bytes, err := marshaler.MarshalJSON()
+	if err != nil {
+		stream.Error = err
+	} else {
+		stream.Write(bytes)
+	}
+}
+func (encoder *marshalerEncoder) encodeInterface(val interface{}, stream *Stream) {
+	writeToStream(val, stream, encoder)
+}
+
+func (encoder *marshalerEncoder) isEmpty(ptr unsafe.Pointer) bool {
+	templateInterface := encoder.templateInterface
+	templateInterface.word = ptr
+	realInterface := (*interface{})(unsafe.Pointer(&templateInterface))
+	marshaler := (*realInterface).(json.Marshaler)
+	bytes, err := marshaler.MarshalJSON()
+	if err != nil {
+		return true
+	} else {
+		return len(bytes) > 0
+	}
 }

+ 22 - 0
jsoniter_customize_test.go

@@ -7,6 +7,7 @@ import (
 	"time"
 	"unsafe"
 	"github.com/json-iterator/go/require"
+	"encoding/json"
 )
 
 func Test_customize_type_decoder(t *testing.T) {
@@ -127,4 +128,25 @@ func Test_unexported_fields(t *testing.T) {
 	str, err := MarshalToString(obj)
 	should.Nil(err)
 	should.Equal(`{"field1":"world","field-2":"abc"}`, str)
+}
+
+type ObjectImplementedMarshaler int
+
+func (obj *ObjectImplementedMarshaler) MarshalJSON() ([]byte, error) {
+	return []byte(`"hello"`), nil
+}
+
+func Test_marshaler(t *testing.T) {
+	type TestObject struct {
+		Field *ObjectImplementedMarshaler
+	}
+	should := require.New(t)
+	val := ObjectImplementedMarshaler(100)
+	obj := TestObject{&val}
+	bytes, err := json.Marshal(obj)
+	should.Nil(err)
+	should.Equal(`{"Field":"hello"}`, string(bytes))
+	str, err := MarshalToString(obj)
+	should.Nil(err)
+	should.Equal(`{"Field":"hello"}`, str)
 }