Tao Wen преди 8 години
родител
ревизия
8f6a840c63
променени са 4 файла, в които са добавени 169 реда и са изтрити 67 реда
  1. 74 46
      feature_reflect_extension.go
  2. 21 20
      feature_reflect_object.go
  3. 1 1
      jsoniter_customize_test.go
  4. 73 0
      jsoniter_object_test.go

+ 74 - 46
feature_reflect_extension.go

@@ -16,14 +16,22 @@ var extensions = []Extension{}
 
 type StructDescriptor struct {
 	Type   reflect.Type
-	Fields map[string]*Binding
+	Fields []*Binding
+}
+
+func (structDescriptor *StructDescriptor) GetField(fieldName string) *Binding {
+	for _, binding := range structDescriptor.Fields {
+		if binding.Field.Name == fieldName {
+			return binding
+		}
+	}
+	return nil
 }
 
 type Binding struct {
 	Field           *reflect.StructField
 	FromNames       []string
 	ToNames         []string
-	ShouldOmitEmpty bool
 	Encoder         ValEncoder
 	Decoder         ValDecoder
 }
@@ -131,47 +139,75 @@ func getTypeEncoderFromExtension(typ reflect.Type) ValEncoder {
 }
 
 func describeStruct(cfg *frozenConfig, typ reflect.Type) (*StructDescriptor, error) {
-	bindings := map[string]*Binding{}
-	for _, field := range listStructFields(typ) {
-		tagParts := strings.Split(field.Tag.Get("json"), ",")
-		fieldNames := calcFieldNames(field.Name, tagParts[0])
-		fieldCacheKey := fmt.Sprintf("%s/%s", typ.String(), field.Name)
-		decoder := fieldDecoders[fieldCacheKey]
-		if decoder == nil && len(fieldNames) > 0 {
-			var err error
-			decoder, err = decoderOfType(cfg, field.Type)
-			if err != nil {
-				return nil, err
+	bindings := []*Binding{}
+	for i := 0; i < typ.NumField(); i++ {
+		field := typ.Field(i)
+		if field.Anonymous {
+			if field.Type.Kind() == reflect.Struct {
+				structDescriptor, err := describeStruct(cfg, field.Type)
+				if err != nil {
+					return nil, err
+				}
+				for _, binding := range structDescriptor.Fields {
+					bindings = append(bindings, binding)
+				}
+			} else if field.Type.Kind() == reflect.Ptr && field.Type.Elem().Kind() == reflect.Struct {
+				structDescriptor, err := describeStruct(cfg, field.Type.Elem())
+				if err != nil {
+					return nil, err
+				}
+				for _, binding := range structDescriptor.Fields {
+					binding.Encoder = &optionalEncoder{binding.Encoder}
+					binding.Encoder = &structFieldEncoder{&field, binding.Encoder, false}
+					binding.Decoder = &optionalDecoder{field.Type, binding.Decoder}
+					binding.Decoder = &structFieldDecoder{&field, binding.Decoder}
+					bindings = append(bindings, binding)
+				}
 			}
-		}
-		encoder := fieldEncoders[fieldCacheKey]
-		if encoder == nil && len(fieldNames) > 0 {
-			var err error
-			encoder, err = encoderOfType(cfg, field.Type)
-			if err != nil {
-				return nil, err
+		} else {
+			tagParts := strings.Split(field.Tag.Get("json"), ",")
+			fieldNames := calcFieldNames(field.Name, tagParts[0])
+			fieldCacheKey := fmt.Sprintf("%s/%s", typ.String(), field.Name)
+			decoder := fieldDecoders[fieldCacheKey]
+			if decoder == nil && len(fieldNames) > 0 {
+				var err error
+				decoder, err = decoderOfType(cfg, field.Type)
+				if err != nil {
+					return nil, err
+				}
 			}
-			// map is stored as pointer in the struct
-			if field.Type.Kind() == reflect.Map {
-				encoder = &optionalEncoder{encoder}
+			encoder := fieldEncoders[fieldCacheKey]
+			if encoder == nil && len(fieldNames) > 0 {
+				var err error
+				encoder, err = encoderOfType(cfg, field.Type)
+				if err != nil {
+					return nil, err
+				}
+				// map is stored as pointer in the struct
+				if field.Type.Kind() == reflect.Map {
+					encoder = &optionalEncoder{encoder}
+				}
 			}
-		}
-		binding := &Binding{
-			Field:     field,
-			FromNames: fieldNames,
-			ToNames:   fieldNames,
-			Decoder:   decoder,
-			Encoder:   encoder,
-		}
-		for _, tagPart := range tagParts[1:] {
-			if tagPart == "omitempty" {
-				binding.ShouldOmitEmpty = true
-			} else if tagPart == "string" {
-				binding.Decoder = &stringModeDecoder{binding.Decoder}
-				binding.Encoder = &stringModeEncoder{binding.Encoder}
+			binding := &Binding{
+				Field:     &field,
+				FromNames: fieldNames,
+				ToNames:   fieldNames,
+				Decoder:   decoder,
+				Encoder:   encoder,
+			}
+			shouldOmitEmpty := false
+			for _, tagPart := range tagParts[1:] {
+				if tagPart == "omitempty" {
+					shouldOmitEmpty = true
+				} else if tagPart == "string" {
+					binding.Decoder = &stringModeDecoder{binding.Decoder}
+					binding.Encoder = &stringModeEncoder{binding.Encoder}
+				}
 			}
+			binding.Decoder = &structFieldDecoder{&field, binding.Decoder}
+			binding.Encoder = &structFieldEncoder{&field, binding.Encoder, shouldOmitEmpty}
+			bindings = append(bindings, binding)
 		}
-		bindings[field.Name] = binding
 	}
 	structDescriptor := &StructDescriptor{
 		Type:   typ,
@@ -185,14 +221,6 @@ func describeStruct(cfg *frozenConfig, typ reflect.Type) (*StructDescriptor, err
 
 func listStructFields(typ reflect.Type) []*reflect.StructField {
 	fields := []*reflect.StructField{}
-	for i := 0; i < typ.NumField(); i++ {
-		field := typ.Field(i)
-		if field.Anonymous {
-			fields = append(fields, listStructFields(field.Type)...)
-		} else {
-			fields = append(fields, &field)
-		}
-	}
 	return fields
 }
 

+ 21 - 20
feature_reflect_object.go

@@ -8,24 +8,20 @@ import (
 )
 
 func encoderOfStruct(cfg *frozenConfig, typ reflect.Type) (ValEncoder, error) {
-	structEncoder_ := &structEncoder{}
 	fields := map[string]*structFieldEncoder{}
 	structDescriptor, err := describeStruct(cfg, typ)
 	if err != nil {
 		return nil, err
 	}
 	for _, binding := range structDescriptor.Fields {
-		for _, fieldName := range binding.ToNames {
-			fields[fieldName] = &structFieldEncoder{binding.Field, fieldName, binding.Encoder, binding.ShouldOmitEmpty}
+		for _, toName := range binding.ToNames {
+			fields[toName] = binding.Encoder.(*structFieldEncoder)
 		}
 	}
 	if len(fields) == 0 {
 		return &emptyStructEncoder{}, nil
 	}
-	for _, field := range fields {
-		structEncoder_.fields = append(structEncoder_.fields, field)
-	}
-	return structEncoder_, nil
+	return &structEncoder{fields}, nil
 }
 
 func decoderOfStruct(cfg *frozenConfig, typ reflect.Type) (ValDecoder, error) {
@@ -35,8 +31,8 @@ func decoderOfStruct(cfg *frozenConfig, typ reflect.Type) (ValDecoder, error) {
 		return nil, err
 	}
 	for _, binding := range structDescriptor.Fields {
-		for _, fieldName := range binding.FromNames {
-			fields[fieldName] = &structFieldDecoder{binding.Field, binding.Decoder}
+		for _, fromName := range binding.FromNames {
+			fields[fromName] = binding.Decoder.(*structFieldDecoder)
 		}
 	}
 	return createStructDecoder(typ, fields)
@@ -959,14 +955,12 @@ func (decoder *structFieldDecoder) decode(ptr unsafe.Pointer, iter *Iterator) {
 
 type structFieldEncoder struct {
 	field        *reflect.StructField
-	fieldName    string
 	fieldEncoder ValEncoder
 	omitempty    bool
 }
 
 func (encoder *structFieldEncoder) encode(ptr unsafe.Pointer, stream *Stream) {
 	fieldPtr := uintptr(ptr) + encoder.field.Offset
-	stream.WriteObjectField(encoder.fieldName)
 	encoder.fieldEncoder.encode(unsafe.Pointer(fieldPtr), stream)
 	if stream.Error != nil && stream.Error != io.EOF {
 		stream.Error = fmt.Errorf("%s: %s", encoder.field.Name, stream.Error.Error())
@@ -983,19 +977,20 @@ func (encoder *structFieldEncoder) isEmpty(ptr unsafe.Pointer) bool {
 }
 
 type structEncoder struct {
-	fields []*structFieldEncoder
+	fields map[string]*structFieldEncoder
 }
 
 func (encoder *structEncoder) encode(ptr unsafe.Pointer, stream *Stream) {
 	stream.WriteObjectStart()
 	isNotFirst := false
-	for _, field := range encoder.fields {
+	for fieldName, field := range encoder.fields {
 		if field.omitempty && field.isEmpty(ptr) {
 			continue
 		}
 		if isNotFirst {
 			stream.WriteMore()
 		}
+		stream.WriteObjectField(fieldName)
 		field.encode(ptr, stream)
 		isNotFirst = true
 	}
@@ -1006,17 +1001,23 @@ func (encoder *structEncoder) encodeInterface(val interface{}, stream *Stream) {
 	var encoderToUse ValEncoder
 	encoderToUse = encoder
 	if len(encoder.fields) == 1 {
-		firstEncoder := encoder.fields[0].fieldEncoder
+		var firstField *structFieldEncoder
+		var firstFieldName string
+		for fieldName, field := range encoder.fields {
+			firstFieldName = fieldName
+			firstField = field
+		}
+		firstEncoder := firstField.fieldEncoder
 		firstEncoderName := reflect.TypeOf(firstEncoder).String()
 		// interface{} has inline optimization for this case
 		if firstEncoderName == "*jsoniter.optionalEncoder" {
 			encoderToUse = &structEncoder{
-				fields: []*structFieldEncoder{{
-					field:        encoder.fields[0].field,
-					fieldName:    encoder.fields[0].fieldName,
-					fieldEncoder: firstEncoder.(*optionalEncoder).valueEncoder,
-					omitempty:    encoder.fields[0].omitempty,
-				}},
+				fields: map[string]*structFieldEncoder{
+					firstFieldName: {
+						field:        firstField.field,
+						fieldEncoder: firstEncoder.(*optionalEncoder).valueEncoder,
+						omitempty:    firstField.omitempty,
+					}},
 			}
 		}
 	}

+ 1 - 1
jsoniter_customize_test.go

@@ -93,7 +93,7 @@ func (extension *testExtension) UpdateStructDescriptor(structDescriptor *StructD
 	if structDescriptor.Type.String() != "jsoniter.TestObject1" {
 		return
 	}
-	binding := structDescriptor.Fields["field1"]
+	binding := structDescriptor.GetField("field1")
 	binding.Encoder = &funcEncoder{fun: func(ptr unsafe.Pointer, stream *Stream) {
 		str := *((*string)(ptr))
 		val, _ := strconv.Atoi(str)

+ 73 - 0
jsoniter_object_test.go

@@ -323,6 +323,79 @@ func Test_decode_anonymous_struct(t *testing.T) {
 	should.Equal("value", outer.Key)
 }
 
+func Test_multiple_level_anonymous_struct(t *testing.T) {
+	type Level1 struct {
+		Field1 string
+	}
+	type Level2 struct {
+		Level1
+		Field2 string
+	}
+	type Level3 struct {
+		Level2
+		Field3 string
+	}
+	should := require.New(t)
+	output, err := MarshalToString(Level3{Level2{Level1{"1"}, "2"}, "3"})
+	should.Nil(err)
+	should.Contains(output, `"Field1":"1"`)
+	should.Contains(output, `"Field2":"2"`)
+	should.Contains(output, `"Field3":"3"`)
+}
+
+func Test_multiple_level_anonymous_struct_with_ptr(t *testing.T) {
+	type Level1 struct {
+		Field1 string
+		Field2 string
+		Field4 string
+	}
+	type Level2 struct {
+		*Level1
+		Field2 string
+		Field3 string
+	}
+	type Level3 struct {
+		*Level2
+		Field3 string
+	}
+	should := require.New(t)
+	output, err := MarshalToString(Level3{&Level2{&Level1{"1", "", "4"}, "2", ""}, "3"})
+	should.Nil(err)
+	should.Contains(output, `"Field1":"1"`)
+	should.Contains(output, `"Field2":"2"`)
+	should.Contains(output, `"Field3":"3"`)
+	should.Contains(output, `"Field4":"4"`)
+}
+
+
+
+func Test_shadow_struct_field(t *testing.T) {
+	should := require.New(t)
+	type omit *struct{}
+	type CacheItem struct {
+		Key    string `json:"key"`
+		MaxAge int    `json:"cacheAge"`
+	}
+	output, err := MarshalToString(struct {
+		*CacheItem
+
+		// Omit bad keys
+		OmitMaxAge omit `json:"cacheAge,omitempty"`
+
+		// Add nice keys
+		MaxAge int    `json:"max_age"`
+	}{
+		CacheItem: &CacheItem{
+			Key:    "value",
+			MaxAge: 100,
+		},
+		MaxAge: 20,
+	})
+	should.Nil(err)
+	should.Contains(output, `"key":"value"`)
+	should.Contains(output, `"max_age":20`)
+}
+
 func Test_decode_nested(t *testing.T) {
 	type StructOfString struct {
 		Field1 string