浏览代码

support int/string encode

Tao Wen 9 年之前
父节点
当前提交
5b0609f901

+ 2 - 2
feature_adapter.go

@@ -9,7 +9,7 @@ import (
 // Unmarshal adapts to json/encoding APIs
 func Unmarshal(data []byte, v interface{}) error {
 	iter := ParseBytes(data)
-	iter.Read(v)
+	iter.ReadVal(v)
 	if iter.Error == io.EOF {
 		return nil
 	}
@@ -20,7 +20,7 @@ func UnmarshalFromString(str string, v interface{}) error {
 	// safe to do the unsafe cast here, as str is always referenced in this scope
 	data := *(*[]byte)(unsafe.Pointer(&str))
 	iter := ParseBytes(data)
-	iter.Read(v)
+	iter.ReadVal(v)
 	if iter.Error == io.EOF {
 		return nil
 	}

+ 72 - 24
feature_reflect.go

@@ -1,7 +1,6 @@
 package jsoniter
 
 import (
-	"errors"
 	"fmt"
 	"io"
 	"reflect"
@@ -19,10 +18,12 @@ Reflection on value is avoided as we can, as the reflect.Value itself will alloc
 For a simple struct binding, it will be reflect.Value free and allocation free
 */
 
-// Decoder works like a father class for sub-type decoders
 type Decoder interface {
 	decode(ptr unsafe.Pointer, iter *Iterator)
 }
+type Encoder interface {
+	encode(ptr unsafe.Pointer, stream *Stream)
+}
 
 type DecoderFunc func(ptr unsafe.Pointer, iter *Iterator)
 type ExtensionFunc func(typ reflect.Type, field *reflect.StructField) ([]string, DecoderFunc)
@@ -36,6 +37,7 @@ func (decoder *funcDecoder) decode(ptr unsafe.Pointer, iter *Iterator) {
 }
 
 var DECODERS unsafe.Pointer
+var ENCODERS unsafe.Pointer
 
 var typeDecoders map[string]Decoder
 var fieldDecoders map[string]Decoder
@@ -46,19 +48,34 @@ func init() {
 	fieldDecoders = map[string]Decoder{}
 	extensions = []ExtensionFunc{}
 	atomic.StorePointer(&DECODERS, unsafe.Pointer(&map[string]Decoder{}))
+	atomic.StorePointer(&ENCODERS, unsafe.Pointer(&map[string]Encoder{}))
 }
 
 func addDecoderToCache(cacheKey reflect.Type, decoder Decoder) {
-	retry := true
-	for retry {
+	done := false
+	for !done {
 		ptr := atomic.LoadPointer(&DECODERS)
 		cache := *(*map[reflect.Type]Decoder)(ptr)
-		copy := map[reflect.Type]Decoder{}
+		copied := map[reflect.Type]Decoder{}
 		for k, v := range cache {
-			copy[k] = v
+			copied[k] = v
 		}
-		copy[cacheKey] = decoder
-		retry = !atomic.CompareAndSwapPointer(&DECODERS, ptr, unsafe.Pointer(&copy))
+		copied[cacheKey] = decoder
+		done = atomic.CompareAndSwapPointer(&DECODERS, ptr, unsafe.Pointer(&copied))
+	}
+}
+
+func addEncoderToCache(cacheKey reflect.Type, encoder Encoder) {
+	done := false
+	for !done {
+		ptr := atomic.LoadPointer(&ENCODERS)
+		cache := *(*map[reflect.Type]Encoder)(ptr)
+		copied := map[reflect.Type]Encoder{}
+		for k, v := range cache {
+			copied[k] = v
+		}
+		copied[cacheKey] = encoder
+		done = atomic.CompareAndSwapPointer(&ENCODERS, ptr, unsafe.Pointer(&copied))
 	}
 }
 
@@ -68,6 +85,12 @@ func getDecoderFromCache(cacheKey reflect.Type) Decoder {
 	return cache[cacheKey]
 }
 
+func getEncoderFromCache(cacheKey reflect.Type) Encoder {
+	ptr := atomic.LoadPointer(&ENCODERS)
+	cache := *(*map[reflect.Type]Encoder)(ptr)
+	return cache[cacheKey]
+}
+
 // RegisterTypeDecoder can register a type for json object
 func RegisterTypeDecoder(typ string, fun DecoderFunc) {
 	typeDecoders[typ] = &funcDecoder{fun}
@@ -241,12 +264,12 @@ func (iter *Iterator) readNumber() (ret *Any) {
 }
 
 // Read converts an Iterator instance into go interface, same as json.Unmarshal
-func (iter *Iterator) Read(obj interface{}) {
+func (iter *Iterator) ReadVal(obj interface{}) {
 	typ := reflect.TypeOf(obj)
 	cacheKey := typ.Elem()
 	cachedDecoder := getDecoderFromCache(cacheKey)
 	if cachedDecoder == nil {
-		decoder, err := decoderOfType(typ)
+		decoder, err := decoderOfType(cacheKey)
 		if err != nil {
 			iter.Error = err
 			return
@@ -258,6 +281,27 @@ func (iter *Iterator) Read(obj interface{}) {
 	cachedDecoder.decode(e.word, iter)
 }
 
+
+func (stream *Stream) WriteVal(val interface{}) {
+	typ := reflect.TypeOf(val)
+	cacheKey := typ
+	if typ.Kind() == reflect.Ptr {
+		cacheKey = typ.Elem()
+	}
+	cachedEncoder := getEncoderFromCache(cacheKey)
+	if cachedEncoder == nil {
+		encoder, err := encoderOfType(cacheKey)
+		if err != nil {
+			stream.Error = err
+			return
+		}
+		cachedEncoder = encoder
+		addEncoderToCache(cacheKey, encoder)
+	}
+	e := (*emptyInterface)(unsafe.Pointer(&val))
+	cachedEncoder.encode(e.word, stream)
+}
+
 type prefix string
 
 func (p prefix) addTo(decoder Decoder, err error) (Decoder, error) {
@@ -268,15 +312,6 @@ func (p prefix) addTo(decoder Decoder, err error) (Decoder, error) {
 }
 
 func decoderOfType(typ reflect.Type) (Decoder, error) {
-	switch typ.Kind() {
-	case reflect.Ptr:
-		return prefix("ptr").addTo(decoderOfPtr(typ.Elem()))
-	default:
-		return nil, errors.New("expect ptr")
-	}
-}
-
-func decoderOfPtr(typ reflect.Type) (Decoder, error) {
 	typeName := typ.String()
 	if typeName == "jsoniter.Any" {
 		return &anyDecoder{}, nil
@@ -287,9 +322,9 @@ func decoderOfPtr(typ reflect.Type) (Decoder, error) {
 	}
 	switch typ.Kind() {
 	case reflect.String:
-		return &stringDecoder{}, nil
+		return &stringCodec{}, nil
 	case reflect.Int:
-		return &intDecoder{}, nil
+		return &intCodec{}, nil
 	case reflect.Int8:
 		return &int8Decoder{}, nil
 	case reflect.Int16:
@@ -329,8 +364,21 @@ func decoderOfPtr(typ reflect.Type) (Decoder, error) {
 	}
 }
 
+
+
+func encoderOfType(typ reflect.Type) (Encoder, error) {
+	switch typ.Kind() {
+	case reflect.String:
+		return &stringCodec{}, nil
+	case reflect.Int:
+		return &intCodec{}, nil
+	default:
+		return nil, fmt.Errorf("unsupported type: %v", typ)
+	}
+}
+
 func decoderOfOptional(typ reflect.Type) (Decoder, error) {
-	decoder, err := decoderOfPtr(typ)
+	decoder, err := decoderOfType(typ)
 	if err != nil {
 		return nil, err
 	}
@@ -338,7 +386,7 @@ func decoderOfOptional(typ reflect.Type) (Decoder, error) {
 }
 
 func decoderOfSlice(typ reflect.Type) (Decoder, error) {
-	decoder, err := decoderOfPtr(typ.Elem())
+	decoder, err := decoderOfType(typ.Elem())
 	if err != nil {
 		return nil, err
 	}
@@ -346,7 +394,7 @@ func decoderOfSlice(typ reflect.Type) (Decoder, error) {
 }
 
 func decoderOfMap(typ reflect.Type) (Decoder, error) {
-	decoder, err := decoderOfPtr(typ.Elem())
+	decoder, err := decoderOfType(typ.Elem())
 	if err != nil {
 		return nil, err
 	}

+ 12 - 4
feature_reflect_native.go

@@ -2,20 +2,28 @@ package jsoniter
 
 import "unsafe"
 
-type stringDecoder struct {
+type stringCodec struct {
 }
 
-func (decoder *stringDecoder) decode(ptr unsafe.Pointer, iter *Iterator) {
+func (codec *stringCodec) decode(ptr unsafe.Pointer, iter *Iterator) {
 	*((*string)(ptr)) = iter.ReadString()
 }
 
-type intDecoder struct {
+func (codec *stringCodec) encode(ptr unsafe.Pointer, stream *Stream) {
+	stream.WriteString(*((*string)(ptr)))
 }
 
-func (decoder *intDecoder) decode(ptr unsafe.Pointer, iter *Iterator) {
+type intCodec struct {
+}
+
+func (codec *intCodec) decode(ptr unsafe.Pointer, iter *Iterator) {
 	*((*int)(ptr)) = iter.ReadInt()
 }
 
+func (codec *intCodec) encode(ptr unsafe.Pointer, stream *Stream) {
+	stream.WriteInt(*((*int)(ptr)))
+}
+
 type int8Decoder struct {
 }
 

+ 1 - 1
feature_reflect_object.go

@@ -39,7 +39,7 @@ func decoderOfStruct(typ reflect.Type) (Decoder, error) {
 		}
 		if decoder == nil {
 			var err error
-			decoder, err = decoderOfPtr(field.Type)
+			decoder, err = decoderOfType(field.Type)
 			if err != nil {
 				return prefix(fmt.Sprintf("{%s}", field.Name)).addTo(decoder, err)
 			}

+ 3 - 3
jsoniter_demo_test.go

@@ -8,7 +8,7 @@ import (
 func Test_bind_api_demo(t *testing.T) {
 	iter := ParseString(`[0,1,2,3]`)
 	val := []int{}
-	iter.Read(&val)
+	iter.ReadVal(&val)
 	fmt.Println(val[3])
 }
 
@@ -34,7 +34,7 @@ type ABC struct {
 func Test_deep_nested_any_api(t *testing.T) {
 	iter := ParseString(`{"a": {"b": {"c": "d"}}}`)
 	abc := &ABC{}
-	iter.Read(&abc)
+	iter.ReadVal(&abc)
 	fmt.Println(abc.a.Get("b", "c"))
 }
 
@@ -50,7 +50,7 @@ func Test_iterator_and_bind_api(t *testing.T) {
 	iter.ReadArray()
 	user.userID = iter.ReadInt()
 	iter.ReadArray()
-	iter.Read(&user)
+	iter.ReadVal(&user)
 	iter.ReadArray() // array end
 	fmt.Println(user)
 }

+ 21 - 0
jsoniter_int_test.go

@@ -260,6 +260,27 @@ func Test_write_int64(t *testing.T) {
 	should.Equal("a4294967295", buf.String())
 }
 
+func Test_write_val_int(t *testing.T) {
+	should := require.New(t)
+	buf := &bytes.Buffer{}
+	stream := NewStream(buf, 4096)
+	stream.WriteVal(1001)
+	stream.Flush()
+	should.Nil(stream.Error)
+	should.Equal("1001", buf.String())
+}
+
+func Test_write_val_int_ptr(t *testing.T) {
+	should := require.New(t)
+	buf := &bytes.Buffer{}
+	stream := NewStream(buf, 4096)
+	val := 1001
+	stream.WriteVal(&val)
+	stream.Flush()
+	should.Nil(stream.Error)
+	should.Equal("1001", buf.String())
+}
+
 func Benchmark_jsoniter_encode_int(b *testing.B) {
 	stream := NewStream(ioutil.Discard, 64)
 	for n := 0; n < b.N; n++ {

+ 3 - 3
jsoniter_map_test.go

@@ -9,7 +9,7 @@ import (
 func Test_read_map(t *testing.T) {
 	iter := ParseString(`{"hello": "world"}`)
 	m := map[string]string{"1": "2"}
-	iter.Read(&m)
+	iter.ReadVal(&m)
 	copy(iter.buf, []byte{0, 0, 0, 0, 0, 0})
 	if !reflect.DeepEqual(map[string]string{"1": "2", "hello": "world"}, m) {
 		fmt.Println(iter.Error)
@@ -20,7 +20,7 @@ func Test_read_map(t *testing.T) {
 func Test_read_map_of_interface(t *testing.T) {
 	iter := ParseString(`{"hello": "world"}`)
 	m := map[string]interface{}{"1": "2"}
-	iter.Read(&m)
+	iter.ReadVal(&m)
 	if !reflect.DeepEqual(map[string]interface{}{"1": "2", "hello": "world"}, m) {
 		fmt.Println(iter.Error)
 		t.Fatal(m)
@@ -30,7 +30,7 @@ func Test_read_map_of_interface(t *testing.T) {
 func Test_read_map_of_any(t *testing.T) {
 	iter := ParseString(`{"hello": "world"}`)
 	m := map[string]Any{"1": *MakeAny("2")}
-	iter.Read(&m)
+	iter.ReadVal(&m)
 	if !reflect.DeepEqual(map[string]Any{"1": *MakeAny("2"), "hello": *MakeAny("world")}, m) {
 		fmt.Println(iter.Error)
 		t.Fatal(m)

+ 16 - 16
jsoniter_reflect_native_test.go

@@ -8,7 +8,7 @@ import (
 func Test_reflect_str(t *testing.T) {
 	iter := ParseString(`"hello"`)
 	str := ""
-	iter.Read(&str)
+	iter.ReadVal(&str)
 	if str != "hello" {
 		fmt.Println(iter.Error)
 		t.Fatal(str)
@@ -18,7 +18,7 @@ func Test_reflect_str(t *testing.T) {
 func Test_reflect_ptr_str(t *testing.T) {
 	iter := ParseString(`"hello"`)
 	var str *string
-	iter.Read(&str)
+	iter.ReadVal(&str)
 	if *str != "hello" {
 		t.Fatal(str)
 	}
@@ -27,7 +27,7 @@ func Test_reflect_ptr_str(t *testing.T) {
 func Test_reflect_int(t *testing.T) {
 	iter := ParseString(`123`)
 	val := int(0)
-	iter.Read(&val)
+	iter.ReadVal(&val)
 	if val != 123 {
 		t.Fatal(val)
 	}
@@ -36,7 +36,7 @@ func Test_reflect_int(t *testing.T) {
 func Test_reflect_int8(t *testing.T) {
 	iter := ParseString(`123`)
 	val := int8(0)
-	iter.Read(&val)
+	iter.ReadVal(&val)
 	if val != 123 {
 		t.Fatal(val)
 	}
@@ -45,7 +45,7 @@ func Test_reflect_int8(t *testing.T) {
 func Test_reflect_int16(t *testing.T) {
 	iter := ParseString(`123`)
 	val := int16(0)
-	iter.Read(&val)
+	iter.ReadVal(&val)
 	if val != 123 {
 		t.Fatal(val)
 	}
@@ -54,7 +54,7 @@ func Test_reflect_int16(t *testing.T) {
 func Test_reflect_int32(t *testing.T) {
 	iter := ParseString(`123`)
 	val := int32(0)
-	iter.Read(&val)
+	iter.ReadVal(&val)
 	if val != 123 {
 		t.Fatal(val)
 	}
@@ -63,7 +63,7 @@ func Test_reflect_int32(t *testing.T) {
 func Test_reflect_int64(t *testing.T) {
 	iter := ParseString(`123`)
 	val := int64(0)
-	iter.Read(&val)
+	iter.ReadVal(&val)
 	if val != 123 {
 		t.Fatal(val)
 	}
@@ -72,7 +72,7 @@ func Test_reflect_int64(t *testing.T) {
 func Test_reflect_uint(t *testing.T) {
 	iter := ParseString(`123`)
 	val := uint(0)
-	iter.Read(&val)
+	iter.ReadVal(&val)
 	if val != 123 {
 		t.Fatal(val)
 	}
@@ -81,7 +81,7 @@ func Test_reflect_uint(t *testing.T) {
 func Test_reflect_uint8(t *testing.T) {
 	iter := ParseString(`123`)
 	val := uint8(0)
-	iter.Read(&val)
+	iter.ReadVal(&val)
 	if val != 123 {
 		t.Fatal(val)
 	}
@@ -90,7 +90,7 @@ func Test_reflect_uint8(t *testing.T) {
 func Test_reflect_uint16(t *testing.T) {
 	iter := ParseString(`123`)
 	val := uint16(0)
-	iter.Read(&val)
+	iter.ReadVal(&val)
 	if val != 123 {
 		t.Fatal(val)
 	}
@@ -99,7 +99,7 @@ func Test_reflect_uint16(t *testing.T) {
 func Test_reflect_uint32(t *testing.T) {
 	iter := ParseString(`123`)
 	val := uint32(0)
-	iter.Read(&val)
+	iter.ReadVal(&val)
 	if val != 123 {
 		t.Fatal(val)
 	}
@@ -108,7 +108,7 @@ func Test_reflect_uint32(t *testing.T) {
 func Test_reflect_uint64(t *testing.T) {
 	iter := ParseString(`123`)
 	val := uint64(0)
-	iter.Read(&val)
+	iter.ReadVal(&val)
 	if val != 123 {
 		t.Fatal(val)
 	}
@@ -117,7 +117,7 @@ func Test_reflect_uint64(t *testing.T) {
 func Test_reflect_byte(t *testing.T) {
 	iter := ParseString(`123`)
 	val := byte(0)
-	iter.Read(&val)
+	iter.ReadVal(&val)
 	if val != 123 {
 		t.Fatal(val)
 	}
@@ -126,7 +126,7 @@ func Test_reflect_byte(t *testing.T) {
 func Test_reflect_float32(t *testing.T) {
 	iter := ParseString(`1.23`)
 	val := float32(0)
-	iter.Read(&val)
+	iter.ReadVal(&val)
 	if val != 1.23 {
 		fmt.Println(iter.Error)
 		t.Fatal(val)
@@ -136,7 +136,7 @@ func Test_reflect_float32(t *testing.T) {
 func Test_reflect_float64(t *testing.T) {
 	iter := ParseString(`1.23`)
 	val := float64(0)
-	iter.Read(&val)
+	iter.ReadVal(&val)
 	if val != 1.23 {
 		fmt.Println(iter.Error)
 		t.Fatal(val)
@@ -146,7 +146,7 @@ func Test_reflect_float64(t *testing.T) {
 func Test_reflect_bool(t *testing.T) {
 	iter := ParseString(`true`)
 	val := false
-	iter.Read(&val)
+	iter.ReadVal(&val)
 	if val != true {
 		fmt.Println(iter.Error)
 		t.Fatal(val)

+ 3 - 3
jsoniter_reflect_test.go

@@ -29,7 +29,7 @@ func Test_decode_nested(t *testing.T) {
 	}
 	iter := ParseString(`[{"field1": "hello"}, null, {"field2": "world"}]`)
 	slice := []*StructOfString{}
-	iter.Read(&slice)
+	iter.ReadVal(&slice)
 	if len(slice) != 3 {
 		fmt.Println(iter.Error)
 		t.Fatal(len(slice))
@@ -55,7 +55,7 @@ func Test_decode_base64(t *testing.T) {
 		*((*[]byte)(ptr)) = iter.ReadBase64()
 	})
 	defer CleanDecoders()
-	iter.Read(&val)
+	iter.ReadVal(&val)
 	if "abc" != string(val) {
 		t.Fatal(string(val))
 	}
@@ -77,7 +77,7 @@ func Benchmark_jsoniter_reflect(b *testing.B) {
 	//input := []byte(`null`)
 	for n := 0; n < b.N; n++ {
 		iter.ResetBytes(input)
-		iter.Read(&Struct)
+		iter.ReadVal(&Struct)
 	}
 }
 

+ 11 - 1
jsoniter_string_test.go

@@ -92,7 +92,17 @@ func Test_write_string(t *testing.T) {
 	stream.WriteString("hello")
 	stream.Flush()
 	should.Nil(stream.Error)
-	should.Equal("hello", buf.String())
+	should.Equal(`"hello"`, buf.String())
+}
+
+func Test_write_val_string(t *testing.T) {
+	should := require.New(t)
+	buf := &bytes.Buffer{}
+	stream := NewStream(buf, 4096)
+	stream.WriteVal("hello")
+	stream.Flush()
+	should.Nil(stream.Error)
+	should.Equal(`"hello"`, buf.String())
 }
 
 func Benchmark_jsoniter_unicode(b *testing.B) {

+ 0 - 3
stream.go

@@ -206,7 +206,4 @@ func (stream *Stream) writeIndention(delta int) {
 			stream.Flush()
 		}
 	}
-}
-
-func (stream *Stream) WriteVal(val interface{}) {
 }