Browse Source

use reflect2 to implement map decoder

Tao Wen 7 years ago
parent
commit
08218647c3
5 changed files with 202 additions and 134 deletions
  1. 10 30
      feature_reflect.go
  2. 112 71
      feature_reflect_map.go
  3. 30 10
      feature_reflect_marshaler.go
  4. 8 0
      feature_reflect_optional.go
  5. 42 23
      value_tests/marshaler_test.go

+ 10 - 30
feature_reflect.go

@@ -82,19 +82,22 @@ func (stream *Stream) WriteVal(val interface{}) {
 	encoder.Encode(reflect2.PtrOf(val), stream)
 }
 
-func decoderOfType(cfg *frozenConfig, prefix string, typ reflect.Type) ValDecoder {
+func (cfg *frozenConfig) DecoderOf(typ reflect.Type) ValDecoder {
 	cacheKey := typ
 	decoder := cfg.getDecoderFromCache(cacheKey)
 	if decoder != nil {
 		return decoder
 	}
-	decoder = getTypeDecoderFromExtension(cfg, typ)
+	decoder = decoderOfType(cfg, "", typ)
+	cfg.addDecoderToCache(cacheKey, decoder)
+	return decoder
+}
+
+func decoderOfType(cfg *frozenConfig, prefix string, typ reflect.Type) ValDecoder {
+	decoder := getTypeDecoderFromExtension(cfg, typ)
 	if decoder != nil {
-		cfg.addDecoderToCache(cacheKey, decoder)
 		return decoder
 	}
-	decoder = &placeholderDecoder{cfg: cfg, cacheKey: cacheKey}
-	cfg.addDecoderToCache(cacheKey, decoder)
 	decoder = createDecoderOfType(cfg, prefix, typ)
 	for _, extension := range extensions {
 		decoder = extension.DecorateDecoder(typ, decoder)
@@ -102,7 +105,6 @@ func decoderOfType(cfg *frozenConfig, prefix string, typ reflect.Type) ValDecode
 	for _, extension := range cfg.extensions {
 		decoder = extension.DecorateDecoder(typ, decoder)
 	}
-	cfg.addDecoderToCache(cacheKey, decoder)
 	return decoder
 }
 
@@ -120,30 +122,8 @@ func createDecoderOfType(cfg *frozenConfig, prefix string, typ reflect.Type) Val
 	if typ.AssignableTo(jsoniterNumberType) {
 		return &jsoniterNumberCodec{}
 	}
-	if typ.Implements(unmarshalerType) {
-		templateInterface := reflect.New(typ).Elem().Interface()
-		var decoder ValDecoder = &unmarshalerDecoder{extractInterface(templateInterface)}
-		if typ.Kind() == reflect.Ptr {
-			decoder = &OptionalDecoder{typ.Elem(), decoder}
-		}
-		return decoder
-	}
-	if reflect.PtrTo(typ).Implements(unmarshalerType) {
-		templateInterface := reflect.New(typ).Interface()
-		var decoder ValDecoder = &unmarshalerDecoder{extractInterface(templateInterface)}
-		return decoder
-	}
-	if typ.Implements(textUnmarshalerType) {
-		templateInterface := reflect.New(typ).Elem().Interface()
-		var decoder ValDecoder = &textUnmarshalerDecoder{extractInterface(templateInterface)}
-		if typ.Kind() == reflect.Ptr {
-			decoder = &OptionalDecoder{typ.Elem(), decoder}
-		}
-		return decoder
-	}
-	if reflect.PtrTo(typ).Implements(textUnmarshalerType) {
-		templateInterface := reflect.New(typ).Interface()
-		var decoder ValDecoder = &textUnmarshalerDecoder{extractInterface(templateInterface)}
+	decoder := createDecoderOfMarshaler(cfg, prefix, typ)
+	if decoder != nil {
 		return decoder
 	}
 	if typ.Kind() == reflect.Slice && typ.Elem().Kind() == reflect.Uint8 {

+ 112 - 71
feature_reflect_map.go

@@ -1,19 +1,24 @@
 package jsoniter
 
 import (
-	"encoding"
 	"reflect"
 	"sort"
-	"strconv"
 	"unsafe"
 	"github.com/v2pro/plz/reflect2"
 	"fmt"
 )
 
 func decoderOfMap(cfg *frozenConfig, prefix string, typ reflect.Type) ValDecoder {
-	decoder := decoderOfType(cfg, prefix+"[map]->", typ.Elem())
-	mapInterface := reflect.New(typ).Interface()
-	return &mapDecoder{typ, typ.Key(), typ.Elem(), decoder, extractInterface(mapInterface)}
+	keyDecoder := decoderOfMapKey(cfg, prefix+" [mapKey]", typ.Key())
+	elemDecoder := decoderOfType(cfg, prefix+" [mapElem]", typ.Elem())
+	mapType := reflect2.Type2(typ).(*reflect2.UnsafeMapType)
+	return &mapDecoder{
+		mapType:     mapType,
+		keyType:     mapType.Key(),
+		elemType:    mapType.Elem(),
+		keyDecoder:  keyDecoder,
+		elemDecoder: elemDecoder,
+	}
 }
 
 func encoderOfMap(cfg *frozenConfig, prefix string, typ reflect.Type) ValEncoder {
@@ -31,6 +36,38 @@ func encoderOfMap(cfg *frozenConfig, prefix string, typ reflect.Type) ValEncoder
 	}
 }
 
+func decoderOfMapKey(cfg *frozenConfig, prefix string, typ reflect.Type) ValDecoder {
+	switch typ.Kind() {
+	case reflect.String:
+		return decoderOfType(cfg, prefix, reflect2.DefaultTypeOfKind(reflect.String).Type1())
+	case reflect.Bool,
+		reflect.Uint8, reflect.Int8,
+		reflect.Uint16, reflect.Int16,
+		reflect.Uint32, reflect.Int32,
+		reflect.Uint64, reflect.Int64,
+		reflect.Uint, reflect.Int,
+		reflect.Float32, reflect.Float64,
+		reflect.Uintptr:
+		typ = reflect2.DefaultTypeOfKind(typ.Kind()).Type1()
+		return &numericMapKeyDecoder{decoderOfType(cfg, prefix, typ)}
+	default:
+		ptrType := reflect.PtrTo(typ)
+		if ptrType.Implements(textMarshalerType) {
+			return &referenceDecoder{
+				&textUnmarshalerDecoder{
+					valType: reflect2.Type2(ptrType),
+				},
+			}
+		}
+		if typ.Implements(textMarshalerType) {
+			return &textUnmarshalerDecoder{
+				valType:       reflect2.Type2(typ),
+			}
+		}
+		return &lazyErrorDecoder{err: fmt.Errorf("unsupported map key type: %v", typ)}
+	}
+}
+
 func encoderOfMapKey(cfg *frozenConfig, prefix string, typ reflect.Type) ValEncoder {
 	switch typ.Kind() {
 	case reflect.String:
@@ -53,7 +90,7 @@ func encoderOfMapKey(cfg *frozenConfig, prefix string, typ reflect.Type) ValEnco
 		}
 		if typ.Implements(textMarshalerType) {
 			return &textMarshalerEncoder{
-				valType: reflect2.Type2(typ),
+				valType:       reflect2.Type2(typ),
 				stringEncoder: cfg.EncoderOf(reflect.TypeOf("")),
 			}
 		}
@@ -62,77 +99,81 @@ func encoderOfMapKey(cfg *frozenConfig, prefix string, typ reflect.Type) ValEnco
 }
 
 type mapDecoder struct {
-	mapType      reflect.Type
-	keyType      reflect.Type
-	elemType     reflect.Type
-	elemDecoder  ValDecoder
-	mapInterface emptyInterface
+	mapType     *reflect2.UnsafeMapType
+	keyType     reflect2.Type
+	elemType    reflect2.Type
+	keyDecoder  ValDecoder
+	elemDecoder ValDecoder
 }
 
 func (decoder *mapDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) {
-	// dark magic to cast unsafe.Pointer back to interface{} using reflect.Type
-	mapInterface := decoder.mapInterface
-	mapInterface.word = ptr
-	realInterface := (*interface{})(unsafe.Pointer(&mapInterface))
-	realVal := reflect.ValueOf(*realInterface).Elem()
-	if iter.ReadNil() {
-		realVal.Set(reflect.Zero(decoder.mapType))
+	mapType := decoder.mapType
+	c := iter.nextToken()
+	if c == 'n' {
+		iter.skipThreeBytes('u', 'l', 'l')
+		*(*unsafe.Pointer)(ptr) = nil
+		mapType.UnsafeSet(ptr, mapType.UnsafeNew())
 		return
 	}
-	if realVal.IsNil() {
-		realVal.Set(reflect.MakeMap(realVal.Type()))
-	}
-	iter.ReadMapCB(func(iter *Iterator, keyStr string) bool {
-		elem := reflect.New(decoder.elemType)
-		decoder.elemDecoder.Decode(extractInterface(elem.Interface()).word, iter)
-		// to put into map, we have to use reflection
-		keyType := decoder.keyType
-		// TODO: remove this from loop
-		switch {
-		case keyType.Kind() == reflect.String:
-			realVal.SetMapIndex(reflect.ValueOf(keyStr).Convert(keyType), elem.Elem())
-			return true
-		case keyType.Implements(textUnmarshalerType):
-			textUnmarshaler := reflect.New(keyType.Elem()).Interface().(encoding.TextUnmarshaler)
-			err := textUnmarshaler.UnmarshalText([]byte(keyStr))
-			if err != nil {
-				iter.ReportError("read map key as TextUnmarshaler", err.Error())
-				return false
-			}
-			realVal.SetMapIndex(reflect.ValueOf(textUnmarshaler), elem.Elem())
-			return true
-		case reflect.PtrTo(keyType).Implements(textUnmarshalerType):
-			textUnmarshaler := reflect.New(keyType).Interface().(encoding.TextUnmarshaler)
-			err := textUnmarshaler.UnmarshalText([]byte(keyStr))
-			if err != nil {
-				iter.ReportError("read map key as TextUnmarshaler", err.Error())
-				return false
-			}
-			realVal.SetMapIndex(reflect.ValueOf(textUnmarshaler).Elem(), elem.Elem())
-			return true
-		default:
-			switch keyType.Kind() {
-			case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
-				n, err := strconv.ParseInt(keyStr, 10, 64)
-				if err != nil || reflect.Zero(keyType).OverflowInt(n) {
-					iter.ReportError("read map key as int64", "read int64 failed")
-					return false
-				}
-				realVal.SetMapIndex(reflect.ValueOf(n).Convert(keyType), elem.Elem())
-				return true
-			case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
-				n, err := strconv.ParseUint(keyStr, 10, 64)
-				if err != nil || reflect.Zero(keyType).OverflowUint(n) {
-					iter.ReportError("read map key as uint64", "read uint64 failed")
-					return false
-				}
-				realVal.SetMapIndex(reflect.ValueOf(n).Convert(keyType), elem.Elem())
-				return true
-			}
+	if mapType.UnsafeIsNil(ptr) {
+		mapType.UnsafeSet(ptr, mapType.UnsafeMakeMap(0))
+	}
+	if c != '{' {
+		iter.ReportError("ReadMapCB", `expect { or n, but found `+string([]byte{c}))
+		return
+	}
+	c = iter.nextToken()
+	if c == '}' {
+		return
+	}
+	if c != '"' {
+		iter.ReportError("ReadMapCB", `expect " after }, but found `+string([]byte{c}))
+		return
+	}
+	iter.unreadByte()
+	key := decoder.keyType.UnsafeNew()
+	decoder.keyDecoder.Decode(key, iter)
+	c = iter.nextToken()
+	if c != ':' {
+		iter.ReportError("ReadMapCB", "expect : after object field, but found "+string([]byte{c}))
+		return
+	}
+	elem := decoder.elemType.UnsafeNew()
+	decoder.elemDecoder.Decode(elem, iter)
+	decoder.mapType.UnsafeSetIndex(ptr, key, elem)
+	for c = iter.nextToken(); c == ','; c = iter.nextToken() {
+		key := decoder.keyType.UnsafeNew()
+		decoder.keyDecoder.Decode(key, iter)
+		c = iter.nextToken()
+		if c != ':' {
+			iter.ReportError("ReadMapCB", "expect : after object field, but found "+string([]byte{c}))
+			return
 		}
-		iter.ReportError("read map key", "unexpected map key type "+keyType.String())
-		return true
-	})
+		elem := decoder.elemType.UnsafeNew()
+		decoder.elemDecoder.Decode(elem, iter)
+		decoder.mapType.UnsafeSetIndex(ptr, key, elem)
+	}
+	if c != '}' {
+		iter.ReportError("ReadMapCB", `expect }, but found `+string([]byte{c}))
+	}
+}
+
+type numericMapKeyDecoder struct {
+	decoder ValDecoder
+}
+
+func (decoder *numericMapKeyDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) {
+	c := iter.nextToken()
+	if c != '"' {
+		iter.ReportError("ReadMapCB", `expect ", but found `+string([]byte{c}))
+		return
+	}
+	decoder.decoder.Decode(ptr, iter)
+	c = iter.nextToken()
+	if c != '"' {
+		iter.ReportError("ReadMapCB", `expect ", but found `+string([]byte{c}))
+		return
+	}
 }
 
 type numericMapKeyEncoder struct {

+ 30 - 10
feature_reflect_marshaler.go

@@ -8,6 +8,21 @@ import (
 	"reflect"
 )
 
+func createDecoderOfMarshaler(cfg *frozenConfig, prefix string, typ reflect.Type) ValDecoder {
+	ptrType := reflect.PtrTo(typ)
+	if ptrType.Implements(unmarshalerType) {
+		return &referenceDecoder{
+			&unmarshalerDecoder{reflect2.Type2(ptrType)},
+		}
+	}
+	if ptrType.Implements(textUnmarshalerType) {
+		return &referenceDecoder{
+			&textUnmarshalerDecoder{reflect2.Type2(ptrType)},
+		}
+	}
+	return nil
+}
+
 func createEncoderOfMarshaler(cfg *frozenConfig, prefix string, typ reflect.Type) ValEncoder {
 	if typ == marshalerType {
 		checkIsEmpty := createCheckIsEmpty(cfg, typ)
@@ -160,14 +175,13 @@ func (encoder *directTextMarshalerEncoder) IsEmpty(ptr unsafe.Pointer) bool {
 }
 
 type unmarshalerDecoder struct {
-	templateInterface emptyInterface
+	valType reflect2.Type
 }
 
 func (decoder *unmarshalerDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) {
-	templateInterface := decoder.templateInterface
-	templateInterface.word = ptr
-	realInterface := (*interface{})(unsafe.Pointer(&templateInterface))
-	unmarshaler := (*realInterface).(json.Unmarshaler)
+	valType := decoder.valType
+	obj := valType.UnsafeIndirect(ptr)
+	unmarshaler := obj.(json.Unmarshaler)
 	iter.nextToken()
 	iter.unreadByte() // skip spaces
 	bytes := iter.SkipAndReturnBytes()
@@ -178,14 +192,20 @@ func (decoder *unmarshalerDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) {
 }
 
 type textUnmarshalerDecoder struct {
-	templateInterface emptyInterface
+	valType reflect2.Type
 }
 
 func (decoder *textUnmarshalerDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) {
-	templateInterface := decoder.templateInterface
-	templateInterface.word = ptr
-	realInterface := (*interface{})(unsafe.Pointer(&templateInterface))
-	unmarshaler := (*realInterface).(encoding.TextUnmarshaler)
+	valType := decoder.valType
+	obj := valType.UnsafeIndirect(ptr)
+	if reflect2.IsNil(obj) {
+		ptrType := valType.(*reflect2.UnsafePtrType)
+		elemType := ptrType.Elem()
+		elem := elemType.UnsafeNew()
+		ptrType.UnsafeSet(ptr, unsafe.Pointer(&elem))
+		obj = valType.UnsafeIndirect(ptr)
+	}
+	unmarshaler := (obj).(encoding.TextUnmarshaler)
 	str := iter.ReadString()
 	err := unmarshaler.UnmarshalText([]byte(str))
 	if err != nil {

+ 8 - 0
feature_reflect_optional.go

@@ -118,4 +118,12 @@ func (encoder *referenceEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
 
 func (encoder *referenceEncoder) IsEmpty(ptr unsafe.Pointer) bool {
 	return encoder.encoder.IsEmpty(unsafe.Pointer(&ptr))
+}
+
+type referenceDecoder struct {
+	decoder ValDecoder
+}
+
+func (decoder *referenceDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) {
+	decoder.decoder.Decode(unsafe.Pointer(&ptr), iter)
 }

+ 42 - 23
value_tests/marshaler_test.go

@@ -6,61 +6,80 @@ import (
 )
 
 func init() {
-	jsonMarshaler := json.Marshaler(fakeJsonMarshaler{})
-	textMarshaler := encoding.TextMarshaler(fakeTextMarshaler{})
-	textMarshaler2 := encoding.TextMarshaler(&fakeTextMarshaler2{})
+	jm := json.Marshaler(jmOfStruct{})
+	tm1 := encoding.TextMarshaler(tmOfStruct{})
+	tm2 := encoding.TextMarshaler(&tmOfStructInt{})
 	marshalCases = append(marshalCases,
-		fakeJsonMarshaler{},
-		&jsonMarshaler,
-		fakeTextMarshaler{},
-		&textMarshaler,
-		fakeTextMarshaler2{},
-		&textMarshaler2,
-		map[fakeTextMarshaler]int{
-			fakeTextMarshaler{}: 100,
+		jmOfStruct{},
+		&jm,
+		tmOfStruct{},
+		&tm1,
+		tmOfStructInt{},
+		&tm2,
+		map[tmOfStruct]int{
+			tmOfStruct{}: 100,
 		},
-		map[*fakeTextMarshaler]int{
-			&fakeTextMarshaler{}: 100,
+		map[*tmOfStruct]int{
+			&tmOfStruct{}: 100,
 		},
 		map[encoding.TextMarshaler]int{
-			textMarshaler: 100,
+			tm1: 100,
 		},
 	)
+	unmarshalCases = append(unmarshalCases, unmarshalCase{
+		ptr: (*tmOfMap)(nil),
+		input: `"{1:2}"`,
+	}, unmarshalCase{
+		ptr: (*tmOfMapPtr)(nil),
+		input: `"{1:2}"`,
+	})
 }
 
-type fakeJsonMarshaler struct {
+type jmOfStruct struct {
 	F2 chan []byte
 }
 
-func (q fakeJsonMarshaler) MarshalJSON() ([]byte, error) {
+func (q jmOfStruct) MarshalJSON() ([]byte, error) {
 	return []byte(`""`), nil
 }
 
-func (q *fakeJsonMarshaler) UnmarshalJSON(value []byte) error {
+func (q *jmOfStruct) UnmarshalJSON(value []byte) error {
 	return nil
 }
 
 
-type fakeTextMarshaler struct {
+type tmOfStruct struct {
 	F2 chan []byte
 }
 
-func (q fakeTextMarshaler) MarshalText() ([]byte, error) {
+func (q tmOfStruct) MarshalText() ([]byte, error) {
 	return []byte(`""`), nil
 }
 
-func (q *fakeTextMarshaler) UnmarshalText(value []byte) error {
+func (q *tmOfStruct) UnmarshalText(value []byte) error {
 	return nil
 }
 
-type fakeTextMarshaler2 struct {
+type tmOfStructInt struct {
 	Field2 int
 }
 
-func (q *fakeTextMarshaler2) MarshalText() ([]byte, error) {
+func (q *tmOfStructInt) MarshalText() ([]byte, error) {
 	return []byte(`"abc"`), nil
 }
 
-func (q *fakeTextMarshaler2) UnmarshalText(value []byte) error {
+func (q *tmOfStructInt) UnmarshalText(value []byte) error {
+	return nil
+}
+
+type tmOfMap map[int]int
+
+func (q tmOfMap) UnmarshalText(value []byte) error {
+	return nil
+}
+
+type tmOfMapPtr map[int]int
+
+func (q *tmOfMapPtr) UnmarshalText(value []byte) error {
 	return nil
 }