Parcourir la source

support struct

Tao Wen il y a 9 ans
Parent
commit
dd431da523
2 fichiers modifiés avec 149 ajouts et 7 suppressions
  1. 92 6
      jsoniter_reflect.go
  2. 57 1
      jsoniter_reflect_test.go

+ 92 - 6
jsoniter_reflect.go

@@ -3,24 +3,59 @@ package jsoniter
 import (
 	"reflect"
 	"errors"
+	"fmt"
+	"unsafe"
+	"sync/atomic"
 )
 
 type Decoder interface {
-	decode(iter *Iterator, obj interface{})
+	decode(ptr unsafe.Pointer, iter *Iterator)
 }
 
 type stringDecoder struct {
 }
 
-func (decoder *stringDecoder) decode(iter *Iterator, obj interface{}) {
-	ptr := obj.(*string)
-	*ptr = iter.ReadString()
+func (decoder *stringDecoder) decode(ptr unsafe.Pointer, iter *Iterator) {
+	*((*string)(ptr)) = iter.ReadString()
+}
+
+type structDecoder struct {
+	fields map[string]Decoder
+}
+
+func (decoder *structDecoder) decode(ptr unsafe.Pointer, iter *Iterator) {
+	for field := iter.ReadObject(); field != ""; field = iter.ReadObject() {
+		fieldDecoder := decoder.fields[field]
+		if fieldDecoder == nil {
+			iter.Skip()
+		} else {
+			fieldDecoder.decode(ptr, iter)
+		}
+	}
+}
+
+type structFieldDecoder struct {
+	offset       uintptr
+	fieldDecoder Decoder
+}
+
+func (decoder *structFieldDecoder) decode(ptr unsafe.Pointer, iter *Iterator) {
+	fieldPtr := uintptr(ptr) + decoder.offset
+	decoder.fieldDecoder.decode(unsafe.Pointer(fieldPtr), iter)
 }
 
 var DECODER_STRING *stringDecoder
+var DECODERS_STRUCT unsafe.Pointer
 
 func init() {
 	DECODER_STRING = &stringDecoder{}
+	atomic.StorePointer(&DECODERS_STRUCT, unsafe.Pointer(&map[string]*structDecoder{}))
+}
+
+// emptyInterface is the header for an interface{} value.
+type emptyInterface struct {
+	typ  *struct{}
+	word unsafe.Pointer
 }
 
 func (iter *Iterator) Read(obj interface{}) {
@@ -30,13 +65,23 @@ func (iter *Iterator) Read(obj interface{}) {
 		iter.Error = err
 		return
 	}
-	decoder.decode(iter, obj)
+	e := (*emptyInterface)(unsafe.Pointer(&obj))
+	decoder.decode(e.word, iter)
+}
+
+type prefix string
+
+func (p prefix) addTo(decoder Decoder, err error) (Decoder, error) {
+	if err != nil {
+		return nil, fmt.Errorf("%s: %s", p, err.Error())
+	}
+	return decoder, err
 }
 
 func decoderOfType(type_ reflect.Type) (Decoder, error) {
 	switch type_.Kind() {
 	case reflect.Ptr:
-		return decoderOfPtr(type_.Elem())
+		return prefix("ptr").addTo(decoderOfPtr(type_.Elem()))
 	default:
 		return nil, errors.New("expect ptr")
 	}
@@ -46,8 +91,49 @@ func decoderOfPtr(type_ reflect.Type) (Decoder, error) {
 	switch type_.Kind() {
 	case reflect.String:
 		return DECODER_STRING, nil
+	case reflect.Struct:
+		return decoderOfStruct(type_)
 	default:
 		return nil, errors.New("expect string")
 	}
 }
 
+func decoderOfStruct(type_ reflect.Type) (Decoder, error) {
+	cacheKey := type_.String()
+	cachedDecoder := getStructDecoderFromCache(cacheKey)
+	if cachedDecoder == nil {
+		fields := map[string]Decoder{}
+		for i := 0; i < type_.NumField(); i++ {
+			field := type_.Field(i)
+			decoder, err := decoderOfPtr(field.Type)
+			if err != nil {
+				return prefix(fmt.Sprintf("[%s]", field.Name)).addTo(decoder, err)
+			}
+			fields[field.Name] = &structFieldDecoder{field.Offset, decoder}
+		}
+		cachedDecoder = &structDecoder{fields}
+		addStructDecoderToCache(cacheKey, cachedDecoder)
+	}
+	return cachedDecoder, nil
+}
+
+func addStructDecoderToCache(cacheKey string, decoder *structDecoder) {
+	retry := true
+	for retry {
+		ptr := atomic.LoadPointer(&DECODERS_STRUCT)
+		cache := *(*map[string]*structDecoder)(ptr)
+		copy := map[string]*structDecoder{}
+		for k, v := range cache {
+			copy[k] = v
+		}
+		copy[cacheKey] = decoder
+		retry = !atomic.CompareAndSwapPointer(&DECODERS_STRUCT, ptr, unsafe.Pointer(&copy))
+	}
+}
+
+func getStructDecoderFromCache(cacheKey string) *structDecoder {
+	ptr := atomic.LoadPointer(&DECODERS_STRUCT)
+	cache := *(*map[string]*structDecoder)(ptr)
+	return cache[cacheKey]
+}
+

+ 57 - 1
jsoniter_reflect_test.go

@@ -2,6 +2,8 @@ package jsoniter
 
 import (
 	"testing"
+	"fmt"
+	"encoding/json"
 )
 
 func Test_reflect_str(t *testing.T) {
@@ -9,6 +11,60 @@ func Test_reflect_str(t *testing.T) {
 	str := ""
 	iter.Read(&str)
 	if str != "hello" {
-		t.FailNow()
+		t.Fatal(str)
 	}
 }
+
+type StructOfString struct {
+	field1 string
+	field2 string
+}
+
+func Test_reflect_struct(t *testing.T) {
+	iter := ParseString(`{"field1": "hello", "field2": "world"}`)
+	struct_ := StructOfString{}
+	iter.Read(&struct_)
+	if struct_.field1 != "hello" {
+		fmt.Println(iter.Error)
+		t.Fatal(struct_.field1)
+	}
+	if struct_.field2 != "world" {
+		fmt.Println(iter.Error)
+		t.Fatal(struct_.field1)
+	}
+}
+
+func Benchmark_jsoniter_reflect(b *testing.B) {
+	b.ReportAllocs()
+	for n := 0; n < b.N; n++ {
+		iter := ParseString(`{"field1": "hello", "field2": "world"}`)
+		struct_ := StructOfString{}
+		iter.Read(&struct_)
+	}
+}
+
+func Benchmark_jsoniter_direct(b *testing.B) {
+	b.ReportAllocs()
+	for n := 0; n < b.N; n++ {
+		iter := ParseString(`{"field1": "hello", "field2": "world"}`)
+		struct_ := StructOfString{}
+		for field := iter.ReadObject(); field != ""; field = iter.ReadObject() {
+			switch field {
+			case "field1":
+				struct_.field1 = iter.ReadString()
+			case "field2":
+				struct_.field2 = iter.ReadString()
+			default:
+				iter.Skip()
+			}
+		}
+	}
+}
+
+func Benchmark_json_reflect(b *testing.B) {
+	b.ReportAllocs()
+	for n := 0; n < b.N; n++ {
+		struct_ := StructOfString{}
+		json.Unmarshal([]byte(`{"field1": "hello", "field2": "world"}`), &struct_)
+	}
+}