Explorar o código

fix #242 add CreateMapKeyEncoder and CreateMapKeyDecoder to extension spi

Tao Wen %!s(int64=7) %!d(string=hai) anos
pai
achega
8d6662b81b
Modificáronse 3 ficheiros con 93 adicións e 0 borrados
  1. 49 0
      extension_tests/extension_test.go
  2. 32 0
      reflect_extension.go
  3. 12 0
      reflect_map.go

+ 49 - 0
extension_tests/extension_test.go

@@ -6,6 +6,8 @@ import (
 	"testing"
 	"github.com/stretchr/testify/require"
 	"github.com/json-iterator/go"
+	"github.com/v2pro/plz/reflect2"
+	"reflect"
 )
 
 type TestObject1 struct {
@@ -46,6 +48,53 @@ func Test_customize_field_by_extension(t *testing.T) {
 	should.Equal(`{"field-1":100}`, str)
 }
 
+func Test_customize_map_key_encoder(t *testing.T) {
+	should := require.New(t)
+	cfg := jsoniter.Config{}.Froze()
+	cfg.RegisterExtension(&testMapKeyExtension{})
+	m := map[int]int{1: 2}
+	output, err := cfg.MarshalToString(m)
+	should.NoError(err)
+	should.Equal(`{"2":2}`, output)
+	m = map[int]int{}
+	should.NoError(cfg.UnmarshalFromString(output, &m))
+	should.Equal(map[int]int{1: 2}, m)
+}
+
+type testMapKeyExtension struct {
+	jsoniter.DummyExtension
+}
+
+func (extension *testMapKeyExtension) CreateMapKeyEncoder(typ reflect2.Type) jsoniter.ValEncoder {
+	if typ.Kind() == reflect.Int {
+		return &funcEncoder{
+			fun: func(ptr unsafe.Pointer, stream *jsoniter.Stream) {
+				stream.WriteRaw(`"`)
+				stream.WriteInt(*(*int)(ptr) + 1)
+				stream.WriteRaw(`"`)
+			},
+		}
+	}
+	return nil
+}
+
+func (extension *testMapKeyExtension) CreateMapKeyDecoder(typ reflect2.Type) jsoniter.ValDecoder {
+	if typ.Kind() == reflect.Int {
+		return &funcDecoder{
+			fun: func(ptr unsafe.Pointer, iter *jsoniter.Iterator) {
+				i, err := strconv.Atoi(iter.ReadString())
+				if err != nil {
+					iter.ReportError("read map key", err.Error())
+					return
+				}
+				i--
+				*(*int)(ptr) = i
+			},
+		}
+	}
+	return nil
+}
+
 type funcDecoder struct {
 	fun jsoniter.DecoderFunc
 }

+ 32 - 0
reflect_extension.go

@@ -47,6 +47,8 @@ type Binding struct {
 // Can also rename fields by UpdateStructDescriptor.
 type Extension interface {
 	UpdateStructDescriptor(structDescriptor *StructDescriptor)
+	CreateMapKeyDecoder(typ reflect2.Type) ValDecoder
+	CreateMapKeyEncoder(typ reflect2.Type) ValEncoder
 	CreateDecoder(typ reflect2.Type) ValDecoder
 	CreateEncoder(typ reflect2.Type) ValEncoder
 	DecorateDecoder(typ reflect2.Type, decoder ValDecoder) ValDecoder
@@ -61,6 +63,16 @@ type DummyExtension struct {
 func (extension *DummyExtension) UpdateStructDescriptor(structDescriptor *StructDescriptor) {
 }
 
+// CreateMapKeyDecoder No-op
+func (extension *DummyExtension) CreateMapKeyDecoder(typ reflect2.Type) ValDecoder {
+	return nil
+}
+
+// CreateMapKeyEncoder No-op
+func (extension *DummyExtension) CreateMapKeyEncoder(typ reflect2.Type) ValEncoder {
+	return nil
+}
+
 // CreateDecoder No-op
 func (extension *DummyExtension) CreateDecoder(typ reflect2.Type) ValDecoder {
 	return nil
@@ -97,6 +109,16 @@ func (extension EncoderExtension) CreateEncoder(typ reflect2.Type) ValEncoder {
 	return extension[typ]
 }
 
+// CreateMapKeyDecoder No-op
+func (extension EncoderExtension) CreateMapKeyDecoder(typ reflect2.Type) ValDecoder {
+	return nil
+}
+
+// CreateMapKeyEncoder No-op
+func (extension EncoderExtension) CreateMapKeyEncoder(typ reflect2.Type) ValEncoder {
+	return nil
+}
+
 // DecorateDecoder No-op
 func (extension EncoderExtension) DecorateDecoder(typ reflect2.Type, decoder ValDecoder) ValDecoder {
 	return decoder
@@ -113,6 +135,16 @@ type DecoderExtension map[reflect2.Type]ValDecoder
 func (extension DecoderExtension) UpdateStructDescriptor(structDescriptor *StructDescriptor) {
 }
 
+// CreateMapKeyDecoder No-op
+func (extension DecoderExtension) CreateMapKeyDecoder(typ reflect2.Type) ValDecoder {
+	return nil
+}
+
+// CreateMapKeyEncoder No-op
+func (extension DecoderExtension) CreateMapKeyEncoder(typ reflect2.Type) ValEncoder {
+	return nil
+}
+
 // CreateDecoder get decoder from map
 func (extension DecoderExtension) CreateDecoder(typ reflect2.Type) ValDecoder {
 	return extension[typ]

+ 12 - 0
reflect_map.go

@@ -38,6 +38,12 @@ func encoderOfMap(ctx *ctx, typ reflect2.Type) ValEncoder {
 }
 
 func decoderOfMapKey(ctx *ctx, typ reflect2.Type) ValDecoder {
+	for _, extension := range ctx.extensions {
+		decoder := extension.CreateMapKeyDecoder(typ)
+		if decoder != nil {
+			return decoder
+		}
+	}
 	switch typ.Kind() {
 	case reflect.String:
 		return decoderOfType(ctx, reflect2.DefaultTypeOfKind(reflect.String))
@@ -70,6 +76,12 @@ func decoderOfMapKey(ctx *ctx, typ reflect2.Type) ValDecoder {
 }
 
 func encoderOfMapKey(ctx *ctx, typ reflect2.Type) ValEncoder {
+	for _, extension := range ctx.extensions {
+		encoder := extension.CreateMapKeyEncoder(typ)
+		if encoder != nil {
+			return encoder
+		}
+	}
 	switch typ.Kind() {
 	case reflect.String:
 		return encoderOfType(ctx, reflect2.DefaultTypeOfKind(reflect.String))