Pārlūkot izejas kodu

#139 unmarshal non base64 into []byte

Tao Wen 8 gadi atpakaļ
vecāks
revīzija
c15b4d116c
4 mainītis faili ar 42 papildinājumiem un 21 dzēšanām
  1. 6 2
      feature_reflect.go
  2. 23 17
      feature_reflect_native.go
  3. 12 1
      jsoniter_array_test.go
  4. 1 1
      jsoniter_large_file_test.go

+ 6 - 2
feature_reflect.go

@@ -281,7 +281,11 @@ func createDecoderOfType(cfg *frozenConfig, typ reflect.Type) (ValDecoder, error
 		return &jsonNumberCodec{}, nil
 	}
 	if typ.Kind() == reflect.Slice && typ.Elem().Kind() == reflect.Uint8 {
-		return &base64Codec{}, nil
+		sliceDecoder, err := prefix("[slice]").addToDecoder(decoderOfSlice(cfg, typ))
+		if err != nil {
+			return nil, err
+		}
+		return &base64Codec{sliceDecoder: sliceDecoder}, nil
 	}
 	if typ.Implements(unmarshalerType) {
 		templateInterface := reflect.New(typ).Elem().Interface()
@@ -440,7 +444,7 @@ func createEncoderOfType(cfg *frozenConfig, typ reflect.Type) (ValEncoder, error
 		return &jsonNumberCodec{}, nil
 	}
 	if typ.Kind() == reflect.Slice && typ.Elem().Kind() == reflect.Uint8 {
-		return &base64Codec{typ}, nil
+		return &base64Codec{}, nil
 	}
 	if typ.Implements(marshalerType) {
 		checkIsEmpty, err := createCheckIsEmpty(typ)

+ 23 - 17
feature_reflect_native.go

@@ -4,7 +4,6 @@ import (
 	"encoding"
 	"encoding/base64"
 	"encoding/json"
-	"reflect"
 	"unsafe"
 )
 
@@ -425,7 +424,7 @@ func (codec *jsoniterRawMessageCodec) IsEmpty(ptr unsafe.Pointer) bool {
 }
 
 type base64Codec struct {
-	actualType reflect.Type
+	sliceDecoder ValDecoder
 }
 
 func (codec *base64Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
@@ -436,21 +435,28 @@ func (codec *base64Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
 		ptrSlice.Data = nil
 		return
 	}
-	encoding := base64.StdEncoding
-	src := iter.SkipAndReturnBytes()
-	src = src[1 : len(src)-1]
-	decodedLen := encoding.DecodedLen(len(src))
-	dst := make([]byte, decodedLen)
-	len, err := encoding.Decode(dst, src)
-	if err != nil {
-		iter.ReportError("decode base64", err.Error())
-	} else {
-		dst = dst[:len]
-		dstSlice := (*sliceHeader)(unsafe.Pointer(&dst))
-		ptrSlice := (*sliceHeader)(ptr)
-		ptrSlice.Data = dstSlice.Data
-		ptrSlice.Cap = dstSlice.Cap
-		ptrSlice.Len = dstSlice.Len
+	switch iter.WhatIsNext() {
+	case String:
+		encoding := base64.StdEncoding
+		src := iter.SkipAndReturnBytes()
+		src = src[1 : len(src)-1]
+		decodedLen := encoding.DecodedLen(len(src))
+		dst := make([]byte, decodedLen)
+		len, err := encoding.Decode(dst, src)
+		if err != nil {
+			iter.ReportError("decode base64", err.Error())
+		} else {
+			dst = dst[:len]
+			dstSlice := (*sliceHeader)(unsafe.Pointer(&dst))
+			ptrSlice := (*sliceHeader)(ptr)
+			ptrSlice.Data = dstSlice.Data
+			ptrSlice.Cap = dstSlice.Cap
+			ptrSlice.Len = dstSlice.Len
+		}
+	case Array:
+		codec.sliceDecoder.Decode(ptr, iter)
+	default:
+		iter.ReportError("base64Codec", "invalid input")
 	}
 }
 

+ 12 - 1
jsoniter_array_test.go

@@ -156,7 +156,7 @@ func Test_encode_byte_array(t *testing.T) {
 	should.Equal(`"AQID"`, string(bytes))
 }
 
-func Test_decode_byte_array(t *testing.T) {
+func Test_decode_byte_array_from_base64(t *testing.T) {
 	should := require.New(t)
 	data := []byte{}
 	err := json.Unmarshal([]byte(`"AQID"`), &data)
@@ -167,6 +167,17 @@ func Test_decode_byte_array(t *testing.T) {
 	should.Equal([]byte{1, 2, 3}, data)
 }
 
+func Test_decode_byte_array_from_array(t *testing.T) {
+	should := require.New(t)
+	data := []byte{}
+	err := json.Unmarshal([]byte(`[1,2,3]`), &data)
+	should.Nil(err)
+	should.Equal([]byte{1, 2, 3}, data)
+	err = Unmarshal([]byte(`[1,2,3]`), &data)
+	should.Nil(err)
+	should.Equal([]byte{1, 2, 3}, data)
+}
+
 func Test_decode_slice(t *testing.T) {
 	should := require.New(t)
 	slice := make([]string, 0, 5)

+ 1 - 1
jsoniter_large_file_test.go

@@ -122,7 +122,7 @@ func init() {
 /*
 200000	      8886 ns/op	    4336 B/op	       6 allocs/op
 50000	     34244 ns/op	    6744 B/op	      14 allocs/op
- */
+*/
 func Benchmark_jsoniter_large_file(b *testing.B) {
 	b.ReportAllocs()
 	for n := 0; n < b.N; n++ {