Sfoglia il codice sorgente

support TextMarshaler as map key

Tao Wen 7 anni fa
parent
commit
d8e64aa825

+ 4 - 1
feature_reflect.go

@@ -358,6 +358,7 @@ func createEncoderOfType(cfg *frozenConfig, prefix string, typ reflect.Type) Val
 		checkIsEmpty := createCheckIsEmpty(cfg, typ)
 		var encoder ValEncoder = &directTextMarshalerEncoder{
 			checkIsEmpty:      checkIsEmpty,
+			stringEncoder: cfg.EncoderOf(reflect.TypeOf("")),
 		}
 		return encoder
 	}
@@ -365,14 +366,16 @@ func createEncoderOfType(cfg *frozenConfig, prefix string, typ reflect.Type) Val
 		checkIsEmpty := createCheckIsEmpty(cfg, typ)
 		var encoder ValEncoder = &textMarshalerEncoder{
 			valType: reflect2.Type2(typ),
+			stringEncoder: cfg.EncoderOf(reflect.TypeOf("")),
 			checkIsEmpty:      checkIsEmpty,
 		}
 		return encoder
 	}
-	if ptrType.Implements(textMarshalerType) {
+	if typ.Kind() == reflect.Map && ptrType.Implements(textMarshalerType) {
 		checkIsEmpty := createCheckIsEmpty(cfg, ptrType)
 		var encoder ValEncoder = &textMarshalerEncoder{
 			valType: reflect2.Type2(ptrType),
+			stringEncoder: cfg.EncoderOf(reflect.TypeOf("")),
 			checkIsEmpty:      checkIsEmpty,
 		}
 		return &referenceEncoder{encoder}

+ 107 - 122
feature_reflect_map.go

@@ -2,11 +2,12 @@ package jsoniter
 
 import (
 	"encoding"
-	"encoding/json"
 	"reflect"
 	"sort"
 	"strconv"
 	"unsafe"
+	"github.com/v2pro/plz/reflect2"
+	"fmt"
 )
 
 func decoderOfMap(cfg *frozenConfig, prefix string, typ reflect.Type) ValDecoder {
@@ -16,13 +17,48 @@ func decoderOfMap(cfg *frozenConfig, prefix string, typ reflect.Type) ValDecoder
 }
 
 func encoderOfMap(cfg *frozenConfig, prefix string, typ reflect.Type) ValEncoder {
-	elemType := typ.Elem()
-	encoder := &emptyInterfaceCodec{}
-	mapInterface := reflect.New(typ).Elem().Interface()
 	if cfg.sortMapKeys {
-		return &sortKeysMapEncoder{typ, elemType, encoder, *((*emptyInterface)(unsafe.Pointer(&mapInterface)))}
+		return &sortKeysMapEncoder{
+			mapType:     reflect2.Type2(typ).(*reflect2.UnsafeMapType),
+			keyEncoder:  encoderOfMapKey(cfg, prefix+" [mapKey]", typ.Key()),
+			elemEncoder: encoderOfType(cfg, prefix+" [mapElem]", typ.Elem()),
+		}
+	}
+	return &mapEncoder{
+		mapType:     reflect2.Type2(typ).(*reflect2.UnsafeMapType),
+		keyEncoder:  encoderOfMapKey(cfg, prefix+" [mapKey]", typ.Key()),
+		elemEncoder: encoderOfType(cfg, prefix+" [mapElem]", typ.Elem()),
+	}
+}
+
+func encoderOfMapKey(cfg *frozenConfig, prefix string, typ reflect.Type) ValEncoder {
+	switch typ.Kind() {
+	case reflect.String:
+		return encoderOfType(cfg, prefix, reflect2.DefaultTypeOfKind(reflect.String).Type1())
+	case reflect.Bool,
+		reflect.Uint8, reflect.Int8,
+		reflect.Uint16, reflect.Int16,
+		reflect.Uint32, reflect.Int32,
+		reflect.Uint64, reflect.Int64,
+		reflect.Uint, reflect.Int,
+		reflect.Float32, reflect.Float64,
+		reflect.Uintptr:
+		typ = reflect2.DefaultTypeOfKind(typ.Kind()).Type1()
+		return &numericMapKeyEncoder{encoderOfType(cfg, prefix, typ)}
+	default:
+		if typ == textMarshalerType {
+			return &directTextMarshalerEncoder{
+				stringEncoder: cfg.EncoderOf(reflect.TypeOf("")),
+			}
+		}
+		if typ.Implements(textMarshalerType) {
+			return &textMarshalerEncoder{
+				valType: reflect2.Type2(typ),
+				stringEncoder: cfg.EncoderOf(reflect.TypeOf("")),
+			}
+		}
+		return &lazyErrorEncoder{err: fmt.Errorf("unsupported map key type: %v", typ)}
 	}
-	return &mapEncoder{typ, elemType, encoder, *((*emptyInterface)(unsafe.Pointer(&mapInterface)))}
 }
 
 type mapDecoder struct {
@@ -99,159 +135,108 @@ func (decoder *mapDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) {
 	})
 }
 
+type numericMapKeyEncoder struct {
+	encoder ValEncoder
+}
+
+func (encoder *numericMapKeyEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
+	stream.writeByte('"')
+	encoder.encoder.Encode(ptr, stream)
+	stream.writeByte('"')
+}
+
+func (encoder *numericMapKeyEncoder) IsEmpty(ptr unsafe.Pointer) bool {
+	return false
+}
+
 type mapEncoder struct {
-	mapType      reflect.Type
-	elemType     reflect.Type
-	elemEncoder  ValEncoder
-	mapInterface emptyInterface
+	mapType     *reflect2.UnsafeMapType
+	keyEncoder  ValEncoder
+	elemEncoder ValEncoder
 }
 
 func (encoder *mapEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
-	mapInterface := encoder.mapInterface
-	mapInterface.word = ptr
-	realInterface := (*interface{})(unsafe.Pointer(&mapInterface))
-	realVal := reflect.ValueOf(*realInterface)
 	stream.WriteObjectStart()
-	for i, key := range realVal.MapKeys() {
+	iter := encoder.mapType.UnsafeIterate(ptr)
+	for i := 0; iter.HasNext(); i++ {
 		if i != 0 {
 			stream.WriteMore()
 		}
-		encodeMapKey(key, stream)
+		key, elem := iter.UnsafeNext()
+		encoder.keyEncoder.Encode(key, stream)
 		if stream.indention > 0 {
 			stream.writeTwoBytes(byte(':'), byte(' '))
 		} else {
 			stream.writeByte(':')
 		}
-		val := realVal.MapIndex(key).Interface()
-		encoder.elemEncoder.Encode(unsafe.Pointer(&val), stream)
+		encoder.elemEncoder.Encode(elem, stream)
 	}
 	stream.WriteObjectEnd()
 }
 
-func encodeMapKey(key reflect.Value, stream *Stream) {
-	if key.Kind() == reflect.String {
-		stream.WriteString(key.String())
-		return
-	}
-	if tm, ok := key.Interface().(encoding.TextMarshaler); ok {
-		buf, err := tm.MarshalText()
-		if err != nil {
-			stream.Error = err
-			return
-		}
-		stream.writeByte('"')
-		stream.Write(buf)
-		stream.writeByte('"')
-		return
-	}
-	switch key.Kind() {
-	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
-		stream.writeByte('"')
-		stream.WriteInt64(key.Int())
-		stream.writeByte('"')
-		return
-	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
-		stream.writeByte('"')
-		stream.WriteUint64(key.Uint())
-		stream.writeByte('"')
-		return
-	}
-	stream.Error = &json.UnsupportedTypeError{Type: key.Type()}
-}
-
 func (encoder *mapEncoder) IsEmpty(ptr unsafe.Pointer) bool {
-	mapInterface := encoder.mapInterface
-	mapInterface.word = ptr
-	realInterface := (*interface{})(unsafe.Pointer(&mapInterface))
-	realVal := reflect.ValueOf(*realInterface)
-	return realVal.Len() == 0
+	iter := encoder.mapType.UnsafeIterate(ptr)
+	return !iter.HasNext()
 }
 
 type sortKeysMapEncoder struct {
-	mapType      reflect.Type
-	elemType     reflect.Type
-	elemEncoder  ValEncoder
-	mapInterface emptyInterface
+	mapType     *reflect2.UnsafeMapType
+	keyEncoder  ValEncoder
+	elemEncoder ValEncoder
 }
 
 func (encoder *sortKeysMapEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
-	ptr = *(*unsafe.Pointer)(ptr)
-	if ptr == nil {
+	if *(*unsafe.Pointer)(ptr) == nil {
 		stream.WriteNil()
 		return
 	}
-	mapInterface := encoder.mapInterface
-	mapInterface.word = ptr
-	realInterface := (*interface{})(unsafe.Pointer(&mapInterface))
-	realVal := reflect.ValueOf(*realInterface)
-
-	// Extract and sort the keys.
-	keys := realVal.MapKeys()
-	sv := stringValues(make([]reflectWithString, len(keys)))
-	for i, v := range keys {
-		sv[i].v = v
-		if err := sv[i].resolve(); err != nil {
-			stream.Error = err
-			return
+	stream.WriteObjectStart()
+	mapIter := encoder.mapType.UnsafeIterate(ptr)
+	subStream := stream.cfg.BorrowStream(nil)
+	subIter := stream.cfg.BorrowIterator(nil)
+	keyValues := encodedKeyValues{}
+	for mapIter.HasNext() {
+		subStream.buf = make([]byte, 0, 64)
+		key, elem := mapIter.UnsafeNext()
+		encoder.keyEncoder.Encode(key, subStream)
+		encodedKey := subStream.Buffer()
+		subIter.ResetBytes(encodedKey)
+		decodedKey := subIter.ReadString()
+		if stream.indention > 0 {
+			subStream.writeTwoBytes(byte(':'), byte(' '))
+		} else {
+			subStream.writeByte(':')
 		}
+		encoder.elemEncoder.Encode(elem, subStream)
+		keyValues = append(keyValues, encodedKV{
+			key:      decodedKey,
+			keyValue: subStream.Buffer(),
+		})
 	}
-	sort.Sort(sv)
-
-	stream.WriteObjectStart()
-	for i, key := range sv {
+	sort.Sort(keyValues)
+	for i, keyValue := range keyValues {
 		if i != 0 {
 			stream.WriteMore()
 		}
-		stream.WriteVal(key.s) // might need html escape, so can not WriteString directly
-		if stream.indention > 0 {
-			stream.writeTwoBytes(byte(':'), byte(' '))
-		} else {
-			stream.writeByte(':')
-		}
-		val := realVal.MapIndex(key.v).Interface()
-		encoder.elemEncoder.Encode(unsafe.Pointer(&val), stream)
+		stream.Write(keyValue.keyValue)
 	}
 	stream.WriteObjectEnd()
+	stream.cfg.ReturnStream(subStream)
+	stream.cfg.ReturnIterator(subIter)
 }
 
-// stringValues is a slice of reflect.Value holding *reflect.StringValue.
-// It implements the methods to sort by string.
-type stringValues []reflectWithString
-
-type reflectWithString struct {
-	v reflect.Value
-	s string
-}
-
-func (w *reflectWithString) resolve() error {
-	if w.v.Kind() == reflect.String {
-		w.s = w.v.String()
-		return nil
-	}
-	if tm, ok := w.v.Interface().(encoding.TextMarshaler); ok {
-		buf, err := tm.MarshalText()
-		w.s = string(buf)
-		return err
-	}
-	switch w.v.Kind() {
-	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
-		w.s = strconv.FormatInt(w.v.Int(), 10)
-		return nil
-	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
-		w.s = strconv.FormatUint(w.v.Uint(), 10)
-		return nil
-	}
-	return &json.UnsupportedTypeError{Type: w.v.Type()}
+func (encoder *sortKeysMapEncoder) IsEmpty(ptr unsafe.Pointer) bool {
+	iter := encoder.mapType.UnsafeIterate(ptr)
+	return !iter.HasNext()
 }
 
-func (sv stringValues) Len() int           { return len(sv) }
-func (sv stringValues) Swap(i, j int)      { sv[i], sv[j] = sv[j], sv[i] }
-func (sv stringValues) Less(i, j int) bool { return sv[i].s < sv[j].s }
+type encodedKeyValues []encodedKV
 
-func (encoder *sortKeysMapEncoder) IsEmpty(ptr unsafe.Pointer) bool {
-	mapInterface := encoder.mapInterface
-	mapInterface.word = ptr
-	realInterface := (*interface{})(unsafe.Pointer(&mapInterface))
-	realVal := reflect.ValueOf(*realInterface)
-	return realVal.Len() == 0
+type encodedKV struct {
+	key      string
+	keyValue []byte
 }
+
+func (sv encodedKeyValues) Len() int           { return len(sv) }
+func (sv encodedKeyValues) Swap(i, j int)      { sv[i], sv[j] = sv[j], sv[i] }
+func (sv encodedKeyValues) Less(i, j int) bool { return sv[i].key < sv[j].key }

+ 6 - 2
feature_reflect_marshaler.go

@@ -55,6 +55,7 @@ func (encoder *directMarshalerEncoder) IsEmpty(ptr unsafe.Pointer) bool {
 
 type textMarshalerEncoder struct {
 	valType			reflect2.Type
+	stringEncoder   ValEncoder
 	checkIsEmpty      checkIsEmpty
 }
 
@@ -69,7 +70,8 @@ func (encoder *textMarshalerEncoder) Encode(ptr unsafe.Pointer, stream *Stream)
 	if err != nil {
 		stream.Error = err
 	} else {
-		stream.WriteString(string(bytes))
+		str := string(bytes)
+		encoder.stringEncoder.Encode(unsafe.Pointer(&str), stream)
 	}
 }
 
@@ -78,6 +80,7 @@ func (encoder *textMarshalerEncoder) IsEmpty(ptr unsafe.Pointer) bool {
 }
 
 type directTextMarshalerEncoder struct {
+	stringEncoder ValEncoder
 	checkIsEmpty checkIsEmpty
 }
 
@@ -91,7 +94,8 @@ func (encoder *directTextMarshalerEncoder) Encode(ptr unsafe.Pointer, stream *St
 	if err != nil {
 		stream.Error = err
 	} else {
-		stream.WriteString(string(bytes))
+		str := string(bytes)
+		encoder.stringEncoder.Encode(unsafe.Pointer(&str), stream)
 	}
 }
 

+ 1 - 1
type_tests/slice_test.go

@@ -74,7 +74,7 @@ func init() {
 		(*[]jsonMarshaler)(nil),
 		(*[]jsonMarshalerMap)(nil),
 		(*[]textMarshaler)(nil),
-		(*[]textMarshalerMap)(nil),
+		selectedSymmetricCase{(*[]textMarshalerMap)(nil)},
 	)
 }
 

+ 24 - 0
value_tests/marshaler_test.go

@@ -8,11 +8,23 @@ import (
 func init() {
 	jsonMarshaler := json.Marshaler(fakeJsonMarshaler{})
 	textMarshaler := encoding.TextMarshaler(fakeTextMarshaler{})
+	textMarshaler2 := encoding.TextMarshaler(&fakeTextMarshaler2{})
 	marshalCases = append(marshalCases,
 		fakeJsonMarshaler{},
 		&jsonMarshaler,
 		fakeTextMarshaler{},
 		&textMarshaler,
+		fakeTextMarshaler2{},
+		&textMarshaler2,
+		map[fakeTextMarshaler]int{
+			fakeTextMarshaler{}: 100,
+		},
+		map[*fakeTextMarshaler]int{
+			&fakeTextMarshaler{}: 100,
+		},
+		map[encoding.TextMarshaler]int{
+			textMarshaler: 100,
+		},
 	)
 }
 
@@ -40,3 +52,15 @@ func (q fakeTextMarshaler) MarshalText() ([]byte, error) {
 func (q *fakeTextMarshaler) UnmarshalText(value []byte) error {
 	return nil
 }
+
+type fakeTextMarshaler2 struct {
+	Field2 int
+}
+
+func (q *fakeTextMarshaler2) MarshalText() ([]byte, error) {
+	return []byte(`"abc"`), nil
+}
+
+func (q *fakeTextMarshaler2) UnmarshalText(value []byte) error {
+	return nil
+}

+ 2 - 2
value_tests/value_test.go

@@ -54,9 +54,9 @@ func Test_marshal(t *testing.T) {
 		t.Run(name, func(t *testing.T) {
 			should := require.New(t)
 			output1, err1 := json.Marshal(testCase)
-			should.NoError(err1)
+			should.NoError(err1, "json")
 			output2, err2 := jsoniter.ConfigCompatibleWithStandardLibrary.Marshal(testCase)
-			should.NoError(err2)
+			should.NoError(err2, "jsoniter")
 			should.Equal(string(output1), string(output2))
 		})
 	}