Преглед изворни кода

support TextMarshaler as map key

Tao Wen пре 8 година
родитељ
комит
3979955e69
3 измењених фајлова са 58 додато и 32 уклоњено
  1. 3 0
      feature_reflect.go
  2. 36 32
      feature_reflect_map.go
  3. 19 0
      jsoniter_map_test.go

+ 3 - 0
feature_reflect.go

@@ -6,6 +6,7 @@ import (
 	"sync/atomic"
 	"unsafe"
 	"encoding/json"
+	"encoding"
 )
 
 /*
@@ -77,6 +78,7 @@ var jsonRawMessageType reflect.Type
 var anyType reflect.Type
 var marshalerType reflect.Type
 var unmarshalerType reflect.Type
+var textUnmarshalerType reflect.Type
 
 func init() {
 	typeDecoders = map[string]Decoder{}
@@ -91,6 +93,7 @@ func init() {
 	anyType = reflect.TypeOf((*Any)(nil)).Elem()
 	marshalerType = reflect.TypeOf((*json.Marshaler)(nil)).Elem()
 	unmarshalerType = reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()
+	textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
 }
 
 func addDecoderToCache(cacheKey reflect.Type, decoder Decoder) {

+ 36 - 32
feature_reflect_map.go

@@ -25,43 +25,47 @@ func (decoder *mapDecoder) decode(ptr unsafe.Pointer, iter *Iterator) {
 	if realVal.IsNil() {
 		realVal.Set(reflect.MakeMap(realVal.Type()))
 	}
-	iter.ReadObjectCB(func(iter *Iterator, keyStr string) bool{
+	iter.ReadObjectCB(func(iter *Iterator, keyStr string) bool {
 		elem := reflect.New(decoder.elemType)
 		decoder.elemDecoder.decode(unsafe.Pointer(elem.Pointer()), iter)
 		// to put into map, we have to use reflection
-		realVal.SetMapIndex(decodeMapKey(iter, keyStr, decoder.keyType), elem.Elem())
-		return true
-	})
-}
-
-func decodeMapKey(iter *Iterator, keyStr string, keyType reflect.Type) reflect.Value {
-	switch {
-	case keyType.Kind() == reflect.String:
-		return reflect.ValueOf(keyStr)
-	//case reflect.PtrTo(kt).Implements(textUnmarshalerType):
-	//	kv = reflect.New(v.Type().Key())
-	//	d.literalStore(item, kv, true)
-	//	kv = kv.Elem()
-	default:
-		switch keyType.Kind() {
-		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
-			n, err := strconv.ParseInt(keyStr, 10, 64)
-			if err != nil || reflect.Zero(keyType).OverflowInt(n) {
-				iter.reportError("read map key as int64", "read int64 failed")
-				return reflect.ValueOf("")
+		keyType := decoder.keyType
+		switch {
+		case keyType.Kind() == reflect.String:
+			realVal.SetMapIndex(reflect.ValueOf(keyStr), elem.Elem())
+			return true
+		case keyType.Implements(textUnmarshalerType):
+			textUnmarshaler := reflect.New(keyType.Elem()).Interface().(encoding.TextUnmarshaler)
+			err := textUnmarshaler.UnmarshalText([]byte(keyStr))
+			if err != nil {
+				iter.reportError("read map key as TextUnmarshaler", err.Error())
+				return false
 			}
-			return reflect.ValueOf(n).Convert(keyType)
-		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
-			n, err := strconv.ParseUint(keyStr, 10, 64)
-			if err != nil || reflect.Zero(keyType).OverflowUint(n) {
-				iter.reportError("read map key as uint64", "read uint64 failed")
-				return reflect.ValueOf("")
+			realVal.SetMapIndex(reflect.ValueOf(textUnmarshaler), elem.Elem())
+			return true
+		default:
+			switch keyType.Kind() {
+			case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+				n, err := strconv.ParseInt(keyStr, 10, 64)
+				if err != nil || reflect.Zero(keyType).OverflowInt(n) {
+					iter.reportError("read map key as int64", "read int64 failed")
+					return false
+				}
+				realVal.SetMapIndex(reflect.ValueOf(n).Convert(keyType), elem.Elem())
+				return true
+			case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+				n, err := strconv.ParseUint(keyStr, 10, 64)
+				if err != nil || reflect.Zero(keyType).OverflowUint(n) {
+					iter.reportError("read map key as uint64", "read uint64 failed")
+					return false
+				}
+				realVal.SetMapIndex(reflect.ValueOf(n).Convert(keyType), elem.Elem())
+				return true
 			}
-			return reflect.ValueOf(n).Convert(keyType)
 		}
-	}
-	iter.reportError("read map key", "json: Unexpected key type")
-	return reflect.ValueOf("")
+		iter.reportError("read map key", "unexpected map key type "+keyType.String())
+		return true
+	})
 }
 
 type mapEncoder struct {
@@ -131,4 +135,4 @@ func (encoder *mapEncoder) isEmpty(ptr unsafe.Pointer) bool {
 	realInterface := (*interface{})(unsafe.Pointer(&mapInterface))
 	realVal := reflect.ValueOf(*realInterface)
 	return realVal.Len() == 0
-}
+}

+ 19 - 0
jsoniter_map_test.go

@@ -3,6 +3,7 @@ package jsoniter
 import (
 	"testing"
 	"github.com/json-iterator/go/require"
+	"math/big"
 )
 
 func Test_read_map(t *testing.T) {
@@ -81,4 +82,22 @@ func Test_decode_int_key_map(t *testing.T) {
 	var val map[int]string
 	should.Nil(UnmarshalFromString(`{"1":"2"}`, &val))
 	should.Equal(map[int]string{1: "2"}, val)
+}
+
+func Test_encode_TextMarshaler_key_map(t *testing.T) {
+	should := require.New(t)
+	f, _, _  := big.ParseFloat("1", 10, 64, big.ToZero)
+	val := map[*big.Float]string{f: "2"}
+	str, err := MarshalToString(val)
+	should.Nil(err)
+	should.Equal(`{"1":"2"}`, str)
+}
+
+func Test_decode_TextMarshaler_key_map(t *testing.T) {
+	should := require.New(t)
+	var val map[*big.Float]string
+	should.Nil(UnmarshalFromString(`{"1":"2"}`, &val))
+	str, err := MarshalToString(val)
+	should.Nil(err)
+	should.Equal(`{"1":"2"}`, str)
 }