Kaynağa Gözat

#27 support json.Unmarshaler

Tao Wen 8 yıl önce
ebeveyn
işleme
7d2ae80c37

+ 11 - 0
feature_iter_skip.go

@@ -30,6 +30,17 @@ func (iter *Iterator) ReadBool() (ret bool) {
 }
 
 
+func (iter *Iterator) SkipAndReturnBytes() []byte {
+	if iter.reader != nil {
+		panic("reader input does not support this api")
+	}
+	before := iter.head
+	iter.Skip()
+	after := iter.head
+	return iter.buf[before:after]
+}
+
+
 // Skip skips a json object and positions to relatively the next json object
 func (iter *Iterator) Skip() {
 	c := iter.nextToken()

+ 9 - 3
feature_reflect.go

@@ -75,6 +75,7 @@ var fieldEncoders map[string]Encoder
 var extensions []ExtensionFunc
 var anyType reflect.Type
 var marshalerType reflect.Type
+var unmarshalerType reflect.Type
 
 func init() {
 	typeDecoders = map[string]Decoder{}
@@ -86,6 +87,7 @@ func init() {
 	atomic.StorePointer(&ENCODERS, unsafe.Pointer(&map[string]Encoder{}))
 	anyType = reflect.TypeOf((*Any)(nil)).Elem()
 	marshalerType = reflect.TypeOf((*json.Marshaler)(nil)).Elem()
+	unmarshalerType = reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()
 }
 
 func addDecoderToCache(cacheKey reflect.Type, decoder Decoder) {
@@ -293,9 +295,6 @@ func (p prefix) addToEncoder(encoder Encoder, err error) (Encoder, error) {
 }
 
 func decoderOfType(typ reflect.Type) (Decoder, error) {
-	if typ.ConvertibleTo(anyType) {
-		return &anyCodec{}, nil
-	}
 	typeName := typ.String()
 	typeDecoder := typeDecoders[typeName]
 	if typeDecoder != nil {
@@ -315,6 +314,13 @@ func decoderOfType(typ reflect.Type) (Decoder, error) {
 }
 
 func createDecoderOfType(typ reflect.Type) (Decoder, error) {
+	if typ.ConvertibleTo(anyType) {
+		return &anyCodec{}, nil
+	}
+	if typ.ConvertibleTo(unmarshalerType) {
+		templateInterface := reflect.New(typ).Elem().Interface()
+		return &optionalDecoder{typ, &unmarshalerDecoder{extractInterface(templateInterface)}}, nil
+	}
 	switch typ.Kind() {
 	case reflect.String:
 		return &stringCodec{}, nil

+ 16 - 0
feature_reflect_native.go

@@ -362,4 +362,20 @@ func (encoder *marshalerEncoder) isEmpty(ptr unsafe.Pointer) bool {
 	} else {
 		return len(bytes) > 0
 	}
+}
+
+type unmarshalerDecoder struct {
+	templateInterface emptyInterface
+}
+
+func (decoder *unmarshalerDecoder) decode(ptr unsafe.Pointer, iter *Iterator) {
+	templateInterface := decoder.templateInterface
+	templateInterface.word = ptr
+	realInterface := (*interface{})(unsafe.Pointer(&templateInterface))
+	unmarshaler := (*realInterface).(json.Unmarshaler)
+	bytes := iter.SkipAndReturnBytes()
+	err := unmarshaler.UnmarshalJSON(bytes)
+	if err != nil {
+		iter.reportError("unmarshaler", err.Error())
+	}
 }

+ 25 - 1
jsoniter_customize_test.go

@@ -149,4 +149,28 @@ func Test_marshaler(t *testing.T) {
 	str, err := MarshalToString(obj)
 	should.Nil(err)
 	should.Equal(`{"Field":"hello"}`, str)
-}
+}
+
+type ObjectImplementedUnmarshaler int
+
+func (obj *ObjectImplementedUnmarshaler) UnmarshalJSON([]byte) error {
+	*obj = 100
+	return nil
+}
+
+func Test_unmarshaler(t *testing.T) {
+	type TestObject struct {
+		Field *ObjectImplementedUnmarshaler
+		Field2 string
+	}
+	should := require.New(t)
+	obj := TestObject{}
+	val := ObjectImplementedUnmarshaler(0)
+	obj.Field = &val
+	err := json.Unmarshal([]byte(`{"Field":"hello"}`), &obj)
+	should.Nil(err)
+	should.Equal(100, int(*obj.Field))
+	err = Unmarshal([]byte(`{"Field":"hello"}`), &obj)
+	should.Nil(err)
+	should.Equal(100, int(*obj.Field))
+}