Pārlūkot izejas kodu

support customize reflection

Tao Wen 9 gadi atpakaļ
vecāks
revīzija
9873b4d32c
3 mainītis faili ar 120 papildinājumiem un 14 dzēšanām
  1. 45 0
      json_customize_test.go
  2. 5 0
      jsoniter_adapter.go
  3. 70 14
      jsoniter_reflect.go

+ 45 - 0
json_customize_test.go

@@ -0,0 +1,45 @@
+package jsoniter
+
+import (
+	"testing"
+	"time"
+	"unsafe"
+	"strconv"
+)
+
+func Test_customize_type_decoder(t *testing.T) {
+	RegisterTypeDecoder("time.Time", func(ptr unsafe.Pointer, iter *Iterator) {
+		t, err := time.ParseInLocation("2006-01-02 15:04:05", iter.ReadString(), time.UTC)
+		if err != nil {
+			iter.Error = err
+			return
+		}
+		*((*time.Time)(ptr)) = t
+	})
+	defer ClearDecoders()
+	val := time.Time{}
+	err := Unmarshal([]byte(`"2016-12-05 08:43:28"`), &val)
+	if err != nil {
+		t.Fatal(err)
+	}
+	year, month, day := val.Date()
+	if year != 2016 || month != 12 || day != 5 {
+		t.Fatal(val)
+	}
+}
+
+type Tom struct {
+	field1 string
+}
+
+func Test_customize_field_decoder(t *testing.T) {
+	RegisterFieldDecoder("jsoniter.Tom", "field1", func(ptr unsafe.Pointer, iter *Iterator) {
+		*((*string)(ptr)) = strconv.Itoa(iter.ReadInt())
+	})
+	defer ClearDecoders()
+	tom := Tom{}
+	err := Unmarshal([]byte(`{"field1": 100}`), &tom)
+	if err != nil {
+		t.Fatal(err)
+	}
+}

+ 5 - 0
jsoniter_adapter.go

@@ -1,9 +1,14 @@
 package jsoniter
 
+import "io"
+
 // adapt to json/encoding api
 
 func Unmarshal(data []byte, v interface{}) error {
 	iter := ParseBytes(data)
 	iter.Read(v)
+	if iter.Error == io.EOF {
+		return nil
+	}
 	return iter.Error
 }

+ 70 - 14
jsoniter_reflect.go

@@ -7,6 +7,7 @@ import (
 	"unsafe"
 	"sync/atomic"
 	"strings"
+	"io"
 )
 
 type Decoder interface {
@@ -117,12 +118,21 @@ type stringNumberDecoder struct {
 
 func (decoder *stringNumberDecoder) decode(ptr unsafe.Pointer, iter *Iterator) {
 	c := iter.readByte()
+	if iter.Error != nil {
+		return
+	}
 	if c != '"' {
 		iter.ReportError("stringNumberDecoder", `expect "`)
 		return
 	}
 	decoder.elemDecoder.decode(ptr, iter)
+	if iter.Error != nil {
+		return
+	}
 	c = iter.readByte()
+	if iter.Error != nil {
+		return
+	}
 	if c != '"' {
 		iter.ReportError("stringNumberDecoder", `expect "`)
 		return
@@ -130,7 +140,7 @@ func (decoder *stringNumberDecoder) decode(ptr unsafe.Pointer, iter *Iterator) {
 }
 
 type optionalDecoder struct {
-	valueType reflect.Type
+	valueType    reflect.Type
 	valueDecoder Decoder
 }
 
@@ -145,11 +155,12 @@ func (decoder *optionalDecoder) decode(ptr unsafe.Pointer, iter *Iterator) {
 }
 
 type structDecoder struct {
+	type_ reflect.Type
 	fields map[string]Decoder
 }
 
 func (decoder *structDecoder) decode(ptr unsafe.Pointer, iter *Iterator) {
-	for field := iter.ReadObject(); field != ""; field = iter.ReadObject() {
+	for field := iter.ReadObject(); field != "" && iter.Error == nil; field = iter.ReadObject() {
 		fieldDecoder := decoder.fields[field]
 		if fieldDecoder == nil {
 			iter.Skip()
@@ -157,20 +168,26 @@ func (decoder *structDecoder) decode(ptr unsafe.Pointer, iter *Iterator) {
 			fieldDecoder.decode(ptr, iter)
 		}
 	}
+	if iter.Error != nil && iter.Error != io.EOF {
+		iter.Error = fmt.Errorf("%v: %s", decoder.type_, iter.Error.Error())
+	}
 }
 
 type structFieldDecoder struct {
-	offset       uintptr
+	field       *reflect.StructField
 	fieldDecoder Decoder
 }
 
 func (decoder *structFieldDecoder) decode(ptr unsafe.Pointer, iter *Iterator) {
-	fieldPtr := uintptr(ptr) + decoder.offset
+	fieldPtr := uintptr(ptr) + decoder.field.Offset
 	decoder.fieldDecoder.decode(unsafe.Pointer(fieldPtr), iter)
+	if iter.Error != nil && iter.Error != io.EOF {
+		iter.Error = fmt.Errorf("%s: %s", decoder.field.Name, iter.Error.Error())
+	}
 }
 
 type sliceDecoder struct {
-	sliceType    reflect.Type
+	sliceType   reflect.Type
 	elemType    reflect.Type
 	elemDecoder Decoder
 }
@@ -185,12 +202,15 @@ type sliceHeader struct {
 func (decoder *sliceDecoder) decode(ptr unsafe.Pointer, iter *Iterator) {
 	slice := (*sliceHeader)(ptr)
 	slice.Len = 0
-	for iter.ReadArray() {
+	for iter.ReadArray() && iter.Error == nil {
 		offset := uintptr(slice.Len) * decoder.elemType.Size()
 		growOne(slice, decoder.sliceType, decoder.elemType)
 		dataPtr := uintptr(slice.Data) + offset
 		decoder.elemDecoder.decode(unsafe.Pointer(dataPtr), iter)
 	}
+	if iter.Error != nil && iter.Error != io.EOF {
+		iter.Error = fmt.Errorf("%v: %s", decoder.sliceType, iter.Error.Error())
+	}
 }
 
 // grow grows the slice s so that it can hold extra more values, allocating
@@ -215,8 +235,8 @@ func growOne(slice *sliceHeader, sliceType reflect.Type, elementType reflect.Typ
 	}
 	dst := unsafe.Pointer(reflect.MakeSlice(sliceType, newLen, newCap).Pointer())
 	originalBytesCount := uintptr(slice.Len) * elementType.Size()
-	srcPtr := (*[1<<30]byte)(slice.Data)
-	dstPtr := (*[1<<30]byte)(dst)
+	srcPtr := (*[1 << 30]byte)(slice.Data)
+	dstPtr := (*[1 << 30]byte)(dst)
 	for i := uintptr(0); i < originalBytesCount; i++ {
 		dstPtr[i] = srcPtr[i]
 	}
@@ -247,10 +267,38 @@ func getDecoderFromCache(cacheKey string) Decoder {
 	return cache[cacheKey]
 }
 
+var typeDecoders map[string]Decoder
+var fieldDecoders map[string]Decoder
+
 func init() {
+	typeDecoders = map[string]Decoder{}
+	fieldDecoders = map[string]Decoder{}
 	atomic.StorePointer(&DECODERS, unsafe.Pointer(&map[string]Decoder{}))
 }
 
+type DecoderFunc func(ptr unsafe.Pointer, iter *Iterator)
+
+type funcDecoder struct {
+	func_ DecoderFunc
+}
+
+func (decoder *funcDecoder) decode(ptr unsafe.Pointer, iter *Iterator) {
+	decoder.func_(ptr, iter)
+}
+
+func RegisterTypeDecoder(type_ string, func_ DecoderFunc) {
+	typeDecoders[type_] = &funcDecoder{func_}
+}
+
+func RegisterFieldDecoder(type_ string, field string, func_ DecoderFunc) {
+	fieldDecoders[fmt.Sprintf("%s/%s", type_, field)] = &funcDecoder{func_}
+}
+
+func ClearDecoders() {
+	typeDecoders = map[string]Decoder{}
+	fieldDecoders = map[string]Decoder{}
+}
+
 // emptyInterface is the header for an interface{} value.
 type emptyInterface struct {
 	typ  *struct{}
@@ -293,6 +341,10 @@ func decoderOfType(type_ reflect.Type) (Decoder, error) {
 }
 
 func decoderOfPtr(type_ reflect.Type) (Decoder, error) {
+	typeDecoder := typeDecoders[type_.String()]
+	if typeDecoder != nil {
+		return typeDecoder, nil
+	}
 	switch type_.Kind() {
 	case reflect.String:
 		return &stringDecoder{}, nil
@@ -341,28 +393,32 @@ func decoderOfOptional(type_ reflect.Type) (Decoder, error) {
 	return &optionalDecoder{type_, decoder}, nil
 }
 
-
 func decoderOfStruct(type_ reflect.Type) (Decoder, error) {
 	fields := map[string]Decoder{}
 	for i := 0; i < type_.NumField(); i++ {
 		field := type_.Field(i)
+		fieldDecoderKey := fmt.Sprintf("%s/%s", type_.String(), field.Name)
+		decoder := fieldDecoders[fieldDecoderKey]
 		tagParts := strings.Split(field.Tag.Get("json"), ",")
 		jsonFieldName := tagParts[0]
 		if jsonFieldName == "" {
 			jsonFieldName = field.Name
 		}
-		decoder, err := decoderOfPtr(field.Type)
-		if err != nil {
-			return prefix(fmt.Sprintf("{%s}", field.Name)).addTo(decoder, err)
+		if decoder == nil {
+			var err error
+			decoder, err = decoderOfPtr(field.Type)
+			if err != nil {
+				return prefix(fmt.Sprintf("{%s}", field.Name)).addTo(decoder, err)
+			}
 		}
 		if len(tagParts) > 1 && tagParts[1] == "string" {
 			decoder = &stringNumberDecoder{decoder}
 		}
 		if jsonFieldName != "-" {
-			fields[jsonFieldName] = &structFieldDecoder{field.Offset, decoder}
+			fields[jsonFieldName] = &structFieldDecoder{&field, decoder}
 		}
 	}
-	return &structDecoder{fields}, nil
+	return &structDecoder{type_, fields}, nil
 }
 
 func decoderOfSlice(type_ reflect.Type) (Decoder, error) {