Parcourir la source

#87 fix embedded field sorting order

Tao Wen il y a 8 ans
Parent
commit
a3a2d1cd25
2 fichiers modifiés avec 56 ajouts et 27 suppressions
  1. 34 17
      feature_reflect_extension.go
  2. 22 10
      jsoniter_object_test.go

+ 34 - 17
feature_reflect_extension.go

@@ -6,6 +6,7 @@ import (
 	"strings"
 	"unicode"
 	"unsafe"
+	"sort"
 )
 
 var typeDecoders = map[string]ValDecoder{}
@@ -195,8 +196,7 @@ func _getTypeEncoderFromExtension(typ reflect.Type) ValEncoder {
 }
 
 func describeStruct(cfg *frozenConfig, typ reflect.Type) (*StructDescriptor, error) {
-	headAnonymousBindings := []*Binding{}
-	tailAnonymousBindings := []*Binding{}
+	embeddedBindings := []*Binding{}
 	bindings := []*Binding{}
 	for i := 0; i < typ.NumField(); i++ {
 		field := typ.Field(i)
@@ -210,11 +210,7 @@ func describeStruct(cfg *frozenConfig, typ reflect.Type) (*StructDescriptor, err
 					binding.levels = append([]int{i}, binding.levels...)
 					binding.Encoder = &structFieldEncoder{&field, binding.Encoder, false}
 					binding.Decoder = &structFieldDecoder{&field, binding.Decoder}
-					if field.Offset == 0 {
-						headAnonymousBindings = append(headAnonymousBindings, binding)
-					} else {
-						tailAnonymousBindings = append(tailAnonymousBindings, binding)
-					}
+					embeddedBindings = append(embeddedBindings, binding)
 				}
 				continue
 			} else if field.Type.Kind() == reflect.Ptr && field.Type.Elem().Kind() == reflect.Struct {
@@ -228,11 +224,7 @@ func describeStruct(cfg *frozenConfig, typ reflect.Type) (*StructDescriptor, err
 					binding.Encoder = &structFieldEncoder{&field, binding.Encoder, false}
 					binding.Decoder = &deferenceDecoder{field.Type.Elem(), binding.Decoder}
 					binding.Decoder = &structFieldDecoder{&field, binding.Decoder}
-					if field.Offset == 0 {
-						headAnonymousBindings = append(headAnonymousBindings, binding)
-					} else {
-						tailAnonymousBindings = append(tailAnonymousBindings, binding)
-					}
+					embeddedBindings = append(embeddedBindings, binding)
 				}
 				continue
 			}
@@ -270,9 +262,9 @@ func describeStruct(cfg *frozenConfig, typ reflect.Type) (*StructDescriptor, err
 		binding.levels = []int{i}
 		bindings = append(bindings, binding)
 	}
-	return createStructDescriptor(cfg, typ, bindings, headAnonymousBindings, tailAnonymousBindings), nil
+	return createStructDescriptor(cfg, typ, bindings, embeddedBindings), nil
 }
-func createStructDescriptor(cfg *frozenConfig, typ reflect.Type, bindings []*Binding, headAnonymousBindings []*Binding, tailAnonymousBindings []*Binding) *StructDescriptor {
+func createStructDescriptor(cfg *frozenConfig, typ reflect.Type, bindings []*Binding, embeddedBindings []*Binding) *StructDescriptor {
 	onePtrEmbedded := false
 	onePtrOptimization := false
 	if typ.NumField() == 1 {
@@ -297,12 +289,37 @@ func createStructDescriptor(cfg *frozenConfig, typ reflect.Type, bindings []*Bin
 		extension.UpdateStructDescriptor(structDescriptor)
 	}
 	processTags(structDescriptor, cfg)
-	// insert anonymous bindings to the head
-	structDescriptor.Fields = append(headAnonymousBindings, structDescriptor.Fields...)
-	structDescriptor.Fields = append(structDescriptor.Fields, tailAnonymousBindings...)
+	// merge normal & embedded bindings & sort with original order
+	allBindings := sortableBindings(append(embeddedBindings, structDescriptor.Fields...))
+	sort.Sort(allBindings)
+	structDescriptor.Fields = allBindings
 	return structDescriptor
 }
 
+type sortableBindings []*Binding
+
+func (bindings sortableBindings) Len() int {
+	return len(bindings)
+}
+
+func (bindings sortableBindings) Less(i, j int) bool {
+	left := bindings[i].levels
+	right := bindings[j].levels
+	k := 0
+	for {
+		if left[k] < right[k] {
+			return true
+		} else if left[k] > right[k] {
+			return false
+		}
+		k++
+	}
+}
+
+func (bindings sortableBindings) Swap(i, j int) {
+	bindings[i], bindings[j] = bindings[j], bindings[i]
+}
+
 func processTags(structDescriptor *StructDescriptor, cfg *frozenConfig) {
 	for _, binding := range structDescriptor.Fields {
 		shouldOmitEmpty := false

+ 22 - 10
jsoniter_object_test.go

@@ -84,7 +84,7 @@ func Test_write_object(t *testing.T) {
 	stream.WriteObjectEnd()
 	stream.Flush()
 	should.Nil(stream.Error)
-	should.Equal("{\n  \"hello\":1,\n  \"world\":2\n}", buf.String())
+	should.Equal("{\n  \"hello\": 1,\n  \"world\": 2\n}", buf.String())
 }
 
 func Test_decode_one_field_struct(t *testing.T) {
@@ -381,21 +381,33 @@ func Test_shadow_struct_field(t *testing.T) {
 	should.Contains(output, `"max_age":20`)
 }
 
-func Test_embed_at_last(t *testing.T) {
-	type Base struct {
-		Type string `json:"type"`
+func Test_embeded_order(t *testing.T) {
+	type A struct {
+		Field2 string
+	}
+
+	type C struct {
+		Field5 string
+	}
+
+	type B struct {
+		Field4 string
+		C
+		Field6 string
 	}
 
-	type Struct struct {
-		Field     string `json:"field"`
-		FieldType string `json:"field_type"`
-		Base
+	type TestObject struct {
+		Field1 string
+		A
+		Field3 string
+		B
+		Field7 string
 	}
 	should := require.New(t)
-	s := Struct{Field: "field", FieldType: "field_type", Base: Base{"type"}}
+	s := TestObject{}
 	output, err := MarshalToString(s)
 	should.Nil(err)
-	should.Equal(`{"field":"field","field_type":"field_type","type":"type"}`, output)
+	should.Equal(`{"Field1":"","Field2":"","Field3":"","Field4":"","Field5":"","Field6":"","Field7":""}`, output)
 }
 
 func Test_decode_nested(t *testing.T) {