extension_test.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. package test
  2. import (
  3. "github.com/json-iterator/go"
  4. "github.com/modern-go/reflect2"
  5. "github.com/stretchr/testify/require"
  6. "reflect"
  7. "strconv"
  8. "testing"
  9. "unsafe"
  10. )
  // TestObject1 is the fixture struct whose single string field is rewired
  // by testExtension in the field-customization test.
  type TestObject1 struct{ Field1 string }
  14. type testExtension struct {
  15. jsoniter.DummyExtension
  16. }
  17. func (extension *testExtension) UpdateStructDescriptor(structDescriptor *jsoniter.StructDescriptor) {
  18. if structDescriptor.Type.String() != "test.TestObject1" {
  19. return
  20. }
  21. binding := structDescriptor.GetField("Field1")
  22. binding.Encoder = &funcEncoder{fun: func(ptr unsafe.Pointer, stream *jsoniter.Stream) {
  23. str := *((*string)(ptr))
  24. val, _ := strconv.Atoi(str)
  25. stream.WriteInt(val)
  26. }}
  27. binding.Decoder = &funcDecoder{func(ptr unsafe.Pointer, iter *jsoniter.Iterator) {
  28. *((*string)(ptr)) = strconv.Itoa(iter.ReadInt())
  29. }}
  30. binding.ToNames = []string{"field-1"}
  31. binding.FromNames = []string{"field-1"}
  32. }
  33. func Test_customize_field_by_extension(t *testing.T) {
  34. should := require.New(t)
  35. cfg := jsoniter.Config{}.Froze()
  36. cfg.RegisterExtension(&testExtension{})
  37. obj := TestObject1{}
  38. err := cfg.UnmarshalFromString(`{"field-1": 100}`, &obj)
  39. should.Nil(err)
  40. should.Equal("100", obj.Field1)
  41. str, err := cfg.MarshalToString(obj)
  42. should.Nil(err)
  43. should.Equal(`{"field-1":100}`, str)
  44. }
  45. func Test_customize_map_key_encoder(t *testing.T) {
  46. should := require.New(t)
  47. cfg := jsoniter.Config{}.Froze()
  48. cfg.RegisterExtension(&testMapKeyExtension{})
  49. m := map[int]int{1: 2}
  50. output, err := cfg.MarshalToString(m)
  51. should.NoError(err)
  52. should.Equal(`{"2":2}`, output)
  53. m = map[int]int{}
  54. should.NoError(cfg.UnmarshalFromString(output, &m))
  55. should.Equal(map[int]int{1: 2}, m)
  56. }
  57. type testMapKeyExtension struct {
  58. jsoniter.DummyExtension
  59. }
  60. func (extension *testMapKeyExtension) CreateMapKeyEncoder(typ reflect2.Type) jsoniter.ValEncoder {
  61. if typ.Kind() == reflect.Int {
  62. return &funcEncoder{
  63. fun: func(ptr unsafe.Pointer, stream *jsoniter.Stream) {
  64. stream.WriteRaw(`"`)
  65. stream.WriteInt(*(*int)(ptr) + 1)
  66. stream.WriteRaw(`"`)
  67. },
  68. }
  69. }
  70. return nil
  71. }
  72. func (extension *testMapKeyExtension) CreateMapKeyDecoder(typ reflect2.Type) jsoniter.ValDecoder {
  73. if typ.Kind() == reflect.Int {
  74. return &funcDecoder{
  75. fun: func(ptr unsafe.Pointer, iter *jsoniter.Iterator) {
  76. i, err := strconv.Atoi(iter.ReadString())
  77. if err != nil {
  78. iter.ReportError("read map key", err.Error())
  79. return
  80. }
  81. i--
  82. *(*int)(ptr) = i
  83. },
  84. }
  85. }
  86. return nil
  87. }
  88. type funcDecoder struct {
  89. fun jsoniter.DecoderFunc
  90. }
  91. func (decoder *funcDecoder) Decode(ptr unsafe.Pointer, iter *jsoniter.Iterator) {
  92. decoder.fun(ptr, iter)
  93. }
  94. type funcEncoder struct {
  95. fun jsoniter.EncoderFunc
  96. isEmptyFunc func(ptr unsafe.Pointer) bool
  97. }
  98. func (encoder *funcEncoder) Encode(ptr unsafe.Pointer, stream *jsoniter.Stream) {
  99. encoder.fun(ptr, stream)
  100. }
  101. func (encoder *funcEncoder) IsEmpty(ptr unsafe.Pointer) bool {
  102. if encoder.isEmptyFunc == nil {
  103. return false
  104. }
  105. return encoder.isEmptyFunc(ptr)
  106. }