Jelajahi Sumber

internal/impl: change Go representation of extension lists to []T

Change-Id: Iebcefe0330c8f858c7735f9362abfd87043ee39d
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/192458
Reviewed-by: Joe Tsai <joetsai@google.com>
Damien Neil 6 tahun lalu
induk
melakukan
293dc761cb

+ 6 - 6
encoding/protojson/decode_test.go

@@ -1225,9 +1225,9 @@ func TestUnmarshal(t *testing.T) {
 }`,
 		wantMessage: func() proto.Message {
 			m := &pb2.Extensions{}
-			proto.SetExtension(m, pb2.E_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE})
-			proto.SetExtension(m, pb2.E_RptExtFixed32, &[]uint32{42, 47})
-			proto.SetExtension(m, pb2.E_RptExtNested, &[]*pb2.Nested{
+			proto.SetExtension(m, pb2.E_RptExtEnum, []pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE})
+			proto.SetExtension(m, pb2.E_RptExtFixed32, []uint32{42, 47})
+			proto.SetExtension(m, pb2.E_RptExtNested, []*pb2.Nested{
 				&pb2.Nested{OptString: proto.String("one")},
 				&pb2.Nested{OptString: proto.String("two")},
 				&pb2.Nested{OptString: proto.String("three")},
@@ -1282,9 +1282,9 @@ func TestUnmarshal(t *testing.T) {
 				OptBool:   proto.Bool(true),
 				OptInt32:  proto.Int32(42),
 			}
-			proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE})
-			proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtString, &[]string{"hello", "world"})
-			proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtNested, &[]*pb2.Nested{
+			proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtEnum, []pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE})
+			proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtString, []string{"hello", "world"})
+			proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtNested, []*pb2.Nested{
 				&pb2.Nested{OptString: proto.String("one")},
 				&pb2.Nested{OptString: proto.String("two")},
 				&pb2.Nested{OptString: proto.String("three")},

+ 6 - 6
encoding/protojson/encode_test.go

@@ -909,9 +909,9 @@ func TestMarshal(t *testing.T) {
 		desc: "extensions of repeated fields",
 		input: func() proto.Message {
 			m := &pb2.Extensions{}
-			proto.SetExtension(m, pb2.E_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE})
-			proto.SetExtension(m, pb2.E_RptExtFixed32, &[]uint32{42, 47})
-			proto.SetExtension(m, pb2.E_RptExtNested, &[]*pb2.Nested{
+			proto.SetExtension(m, pb2.E_RptExtEnum, []pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE})
+			proto.SetExtension(m, pb2.E_RptExtFixed32, []uint32{42, 47})
+			proto.SetExtension(m, pb2.E_RptExtNested, []*pb2.Nested{
 				&pb2.Nested{OptString: proto.String("one")},
 				&pb2.Nested{OptString: proto.String("two")},
 				&pb2.Nested{OptString: proto.String("three")},
@@ -974,9 +974,9 @@ func TestMarshal(t *testing.T) {
 				OptBool:   proto.Bool(true),
 				OptInt32:  proto.Int32(42),
 			}
-			proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE})
-			proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtString, &[]string{"hello", "world"})
-			proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtNested, &[]*pb2.Nested{
+			proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtEnum, []pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE})
+			proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtString, []string{"hello", "world"})
+			proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtNested, []*pb2.Nested{
 				&pb2.Nested{OptString: proto.String("one")},
 				&pb2.Nested{OptString: proto.String("two")},
 				&pb2.Nested{OptString: proto.String("three")},

+ 6 - 6
encoding/prototext/decode_test.go

@@ -1207,9 +1207,9 @@ opt_int32: 42
 `,
 		wantMessage: func() proto.Message {
 			m := &pb2.Extensions{}
-			proto.SetExtension(m, pb2.E_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE})
-			proto.SetExtension(m, pb2.E_RptExtFixed32, &[]uint32{42, 47})
-			proto.SetExtension(m, pb2.E_RptExtNested, &[]*pb2.Nested{
+			proto.SetExtension(m, pb2.E_RptExtEnum, []pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE})
+			proto.SetExtension(m, pb2.E_RptExtFixed32, []uint32{42, 47})
+			proto.SetExtension(m, pb2.E_RptExtNested, []*pb2.Nested{
 				&pb2.Nested{OptString: proto.String("one")},
 				&pb2.Nested{OptString: proto.String("two")},
 				&pb2.Nested{OptString: proto.String("three")},
@@ -1269,9 +1269,9 @@ opt_int32: 42
 				OptBool:   proto.Bool(true),
 				OptInt32:  proto.Int32(42),
 			}
-			proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE})
-			proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtString, &[]string{"hello", "world"})
-			proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtNested, &[]*pb2.Nested{
+			proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtEnum, []pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE})
+			proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtString, []string{"hello", "world"})
+			proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtNested, []*pb2.Nested{
 				&pb2.Nested{OptString: proto.String("one")},
 				&pb2.Nested{OptString: proto.String("two")},
 				&pb2.Nested{OptString: proto.String("three")},

+ 6 - 6
encoding/prototext/encode_test.go

@@ -969,9 +969,9 @@ opt_int32: 42
 		desc: "extensions of repeated fields",
 		input: func() proto.Message {
 			m := &pb2.Extensions{}
-			proto.SetExtension(m, pb2.E_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE})
-			proto.SetExtension(m, pb2.E_RptExtFixed32, &[]uint32{42, 47})
-			proto.SetExtension(m, pb2.E_RptExtNested, &[]*pb2.Nested{
+			proto.SetExtension(m, pb2.E_RptExtEnum, []pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE})
+			proto.SetExtension(m, pb2.E_RptExtFixed32, []uint32{42, 47})
+			proto.SetExtension(m, pb2.E_RptExtNested, []*pb2.Nested{
 				&pb2.Nested{OptString: proto.String("one")},
 				&pb2.Nested{OptString: proto.String("two")},
 				&pb2.Nested{OptString: proto.String("three")},
@@ -1026,9 +1026,9 @@ opt_int32: 42
 				OptBool:   proto.Bool(true),
 				OptInt32:  proto.Int32(42),
 			}
-			proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE})
-			proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtString, &[]string{"hello", "world"})
-			proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtNested, &[]*pb2.Nested{
+			proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtEnum, []pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE})
+			proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtString, []string{"hello", "world"})
+			proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtNested, []*pb2.Nested{
 				&pb2.Nested{OptString: proto.String("one")},
 				&pb2.Nested{OptString: proto.String("two")},
 				&pb2.Nested{OptString: proto.String("three")},

+ 10 - 4
internal/impl/codec_tables.go

@@ -431,10 +431,13 @@ func fieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCoderFuncs {
 func encoderFuncsForValue(fd pref.FieldDescriptor, ft reflect.Type) valueCoderFuncs {
 	switch {
 	case fd.Cardinality() == pref.Repeated && !fd.IsPacked():
-		if ft.Kind() != reflect.Ptr || ft.Elem().Kind() != reflect.Slice {
+		if ft.Kind() == reflect.Ptr {
+			ft = ft.Elem()
+		}
+		if ft.Kind() != reflect.Slice {
 			break
 		}
-		ft := ft.Elem().Elem()
+		ft := ft.Elem()
 		switch fd.Kind() {
 		case pref.BoolKind:
 			if ft.Kind() == reflect.Bool {
@@ -512,10 +515,13 @@ func encoderFuncsForValue(fd pref.FieldDescriptor, ft reflect.Type) valueCoderFu
 			return coderGroupSliceValue
 		}
 	case fd.Cardinality() == pref.Repeated && fd.IsPacked():
-		if ft.Kind() != reflect.Ptr || ft.Elem().Kind() != reflect.Slice {
+		if ft.Kind() == reflect.Ptr {
+			ft = ft.Elem()
+		}
+		if ft.Kind() != reflect.Slice {
 			break
 		}
-		ft := ft.Elem().Elem()
+		ft := ft.Elem()
 		switch fd.Kind() {
 		case pref.BoolKind:
 			if ft.Kind() == reflect.Bool {

+ 55 - 10
internal/impl/convert_list.go

@@ -11,30 +11,75 @@ import (
 	pref "google.golang.org/protobuf/reflect/protoreflect"
 )
 
+func newListConverter(t reflect.Type, fd pref.FieldDescriptor) Converter {
+	switch {
+	case t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Slice:
+		return &listPtrConverter{t, newSingularConverter(t.Elem().Elem(), fd)}
+	case t.Kind() == reflect.Slice:
+		return &listConverter{t, newSingularConverter(t.Elem(), fd)}
+	}
+	panic(fmt.Sprintf("invalid Go type %v for field %v", t, fd.FullName()))
+}
+
 type listConverter struct {
 	goType reflect.Type
 	c      Converter
 }
 
-func newListConverter(t reflect.Type, fd pref.FieldDescriptor) Converter {
-	if t.Kind() != reflect.Ptr || t.Elem().Kind() != reflect.Slice {
-		panic(fmt.Sprintf("invalid Go type %v for field %v", t, fd.FullName()))
+func (c *listConverter) PBValueOf(v reflect.Value) pref.Value {
+	if v.Type() != c.goType {
+		panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType))
 	}
-	return &listConverter{t, newSingularConverter(t.Elem().Elem(), fd)}
+	pv := reflect.New(c.goType)
+	pv.Elem().Set(v)
+	return pref.ValueOf(&listReflect{pv, c.c})
 }
 
-func (c *listConverter) PBValueOf(v reflect.Value) pref.Value {
+func (c *listConverter) GoValueOf(v pref.Value) reflect.Value {
+	rv := v.List().(*listReflect).v
+	if rv.IsNil() {
+		return reflect.Zero(c.goType)
+	}
+	return rv.Elem()
+}
+
+func (c *listConverter) IsValidPB(v pref.Value) bool {
+	list, ok := v.Interface().(*listReflect)
+	if !ok {
+		return false
+	}
+	return list.v.Type().Elem() == c.goType
+}
+
+func (c *listConverter) IsValidGo(v reflect.Value) bool {
+	return v.Type() == c.goType
+}
+
+func (c *listConverter) New() pref.Value {
+	return pref.ValueOf(&listReflect{reflect.New(c.goType), c.c})
+}
+
+func (c *listConverter) Zero() pref.Value {
+	return pref.ValueOf(&listReflect{reflect.Zero(reflect.PtrTo(c.goType)), c.c})
+}
+
+type listPtrConverter struct {
+	goType reflect.Type
+	c      Converter
+}
+
+func (c *listPtrConverter) PBValueOf(v reflect.Value) pref.Value {
 	if v.Type() != c.goType {
 		panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType))
 	}
 	return pref.ValueOf(&listReflect{v, c.c})
 }
 
-func (c *listConverter) GoValueOf(v pref.Value) reflect.Value {
+func (c *listPtrConverter) GoValueOf(v pref.Value) reflect.Value {
 	return v.List().(*listReflect).v
 }
 
-func (c *listConverter) IsValidPB(v pref.Value) bool {
+func (c *listPtrConverter) IsValidPB(v pref.Value) bool {
 	list, ok := v.Interface().(*listReflect)
 	if !ok {
 		return false
@@ -42,15 +87,15 @@ func (c *listConverter) IsValidPB(v pref.Value) bool {
 	return list.v.Type() == c.goType
 }
 
-func (c *listConverter) IsValidGo(v reflect.Value) bool {
+func (c *listPtrConverter) IsValidGo(v reflect.Value) bool {
 	return v.Type() == c.goType
 }
 
-func (c *listConverter) New() pref.Value {
+func (c *listPtrConverter) New() pref.Value {
 	return c.PBValueOf(reflect.New(c.goType.Elem()))
 }
 
-func (c *listConverter) Zero() pref.Value {
+func (c *listPtrConverter) Zero() pref.Value {
 	return c.PBValueOf(reflect.Zero(c.goType))
 }
 

+ 1 - 1
internal/impl/extension.go

@@ -151,7 +151,7 @@ func (xi *ExtensionInfo) lazyInitSlow() {
 		xi.initFromLegacy()
 	} else if xi.desc.Cardinality() == pref.Repeated {
 		// Cardinality is initialized lazily, so we defer consulting it until here.
-		xi.goType = reflect.PtrTo(reflect.SliceOf(xi.goType))
+		xi.goType = reflect.SliceOf(xi.goType)
 	}
 	xi.conv = NewConverter(xi.goType, xi.desc)
 	xi.tdesc.ExtensionDescriptor = xi.desc

+ 0 - 6
internal/impl/legacy_extension.go

@@ -71,10 +71,6 @@ func (xi *ExtensionInfo) initToLegacy() {
 	switch extType.Kind() {
 	case reflect.Bool, reflect.Int32, reflect.Int64, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.String:
 		extType = reflect.PtrTo(extType) // T -> *T for singular scalar fields
-	case reflect.Ptr:
-		if extType.Elem().Kind() == reflect.Slice {
-			extType = extType.Elem() // *[]T -> []T for repeated fields
-		}
 	}
 
 	// Reconstruct the legacy enum full name.
@@ -154,8 +150,6 @@ func (xi *ExtensionInfo) initFromLegacy() {
 	tt := reflect.TypeOf(xi.ExtensionType)
 	if isOptional {
 		tt = tt.Elem()
-	} else if isRepeated {
-		tt = reflect.PtrTo(tt)
 	}
 	xi.desc = xd
 	xi.goType = tt

+ 20 - 20
internal/impl/legacy_test.go

@@ -392,16 +392,16 @@ func TestLegacyExtensions(t *testing.T) {
 		7:  m1a,
 		8:  EnumProto2(0xbeef),
 		9:  m2a,
-		10: &[]bool{true},
-		11: &[]int32{-1000},
-		12: &[]uint32{1280},
-		13: &[]float32{1.6180},
-		14: &[]string{"zero"},
-		15: &[][]byte{[]byte("zero")},
-		16: &[]proto2_20180125.Message_ChildEnum{proto2_20180125.Message_BRAVO},
-		17: &[]*proto2_20180125.Message_ChildMessage{m1b},
-		18: &[]EnumProto2{0xdead},
-		19: &[]*EnumMessages{m2b},
+		10: []bool{true},
+		11: []int32{-1000},
+		12: []uint32{1280},
+		13: []float32{1.6180},
+		14: []string{"zero"},
+		15: [][]byte{[]byte("zero")},
+		16: []proto2_20180125.Message_ChildEnum{proto2_20180125.Message_BRAVO},
+		17: []*proto2_20180125.Message_ChildMessage{m1b},
+		18: []EnumProto2{0xdead},
+		19: []*EnumMessages{m2b},
 	}
 	for i, xt := range extensionTypes {
 		m.Set(xt.TypeDescriptor(), xt.ValueOf(setValues[i]))
@@ -423,16 +423,16 @@ func TestLegacyExtensions(t *testing.T) {
 		7:  m1a,
 		8:  EnumProto2(0xbeef),
 		9:  m2a,
-		10: &[]bool{true, false},
-		11: &[]int32{-1000, -54321},
-		12: &[]uint32{1280, 6400},
-		13: &[]float32{1.6180, 2.71828},
-		14: &[]string{"zero", "goodbye, \"world!\"\n"},
-		15: &[][]byte{[]byte("zero"), []byte("live\xde\xad\xbe\xefchicken")},
-		16: &[]proto2_20180125.Message_ChildEnum{proto2_20180125.Message_BRAVO, proto2_20180125.Message_CHARLIE},
-		17: &[]*proto2_20180125.Message_ChildMessage{m1b, m1a},
-		18: &[]EnumProto2{0xdead, 0xbeef},
-		19: &[]*EnumMessages{m2b, m2a},
+		10: []bool{true, false},
+		11: []int32{-1000, -54321},
+		12: []uint32{1280, 6400},
+		13: []float32{1.6180, 2.71828},
+		14: []string{"zero", "goodbye, \"world!\"\n"},
+		15: [][]byte{[]byte("zero"), []byte("live\xde\xad\xbe\xefchicken")},
+		16: []proto2_20180125.Message_ChildEnum{proto2_20180125.Message_BRAVO, proto2_20180125.Message_CHARLIE},
+		17: []*proto2_20180125.Message_ChildMessage{m1b, m1a},
+		18: []EnumProto2{0xdead, 0xbeef},
+		19: []*EnumMessages{m2b, m2a},
 	}
 	for i, xt := range extensionTypes {
 		xd := xt.TypeDescriptor()

+ 0 - 7
proto/decode_test.go

@@ -1742,13 +1742,6 @@ func unknown(raw protoreflect.RawFields) buildOpt {
 }
 
 func extend(desc protoreflect.ExtensionType, value interface{}) buildOpt {
-	// TODO: Should ExtensionType.ValueOf accept []T instead of *[]T?
-	t := reflect.TypeOf(value)
-	if t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8 {
-		v := reflect.New(t)
-		v.Elem().Set(reflect.ValueOf(value))
-		value = v.Interface()
-	}
 	return func(m proto.Message) {
 		proto.SetExtension(m, desc, value)
 	}

+ 2 - 2
proto/extension_test.go

@@ -35,8 +35,8 @@ func TestExtensionFuncs(t *testing.T) {
 			ext:     testpb.E_RepeatedStringExtension,
 			// TODO: Represent repeated extension fields as []T.
 			// https://github.com/golang/protobuf/issues/901
-			wantDefault: (*[]string)(nil),
-			value:       &[]string{"a", "b", "c"},
+			wantDefault: ([]string)(nil),
+			value:       []string{"a", "b", "c"},
 		},
 		{
 			message:     protoimpl.X.MessageOf(&legacy1pb.Message{}).Interface(),

+ 4 - 4
proto/merge_test.go

@@ -280,7 +280,7 @@ func TestMerge(t *testing.T) {
 					A: proto.Int32(50),
 				},
 			)
-			proto.SetExtension(m, testpb.E_RepeatedFixed32Extension, &[]uint32{1, 2, 3})
+			proto.SetExtension(m, testpb.E_RepeatedFixed32Extension, []uint32{1, 2, 3})
 			return m
 		}(),
 		src: func() proto.Message {
@@ -293,7 +293,7 @@ func TestMerge(t *testing.T) {
 					},
 				},
 			)
-			proto.SetExtension(m, testpb.E_RepeatedFixed32Extension, &[]uint32{4, 5, 6})
+			proto.SetExtension(m, testpb.E_RepeatedFixed32Extension, []uint32{4, 5, 6})
 			return m
 		}(),
 		want: func() proto.Message {
@@ -308,7 +308,7 @@ func TestMerge(t *testing.T) {
 					},
 				},
 			)
-			proto.SetExtension(m, testpb.E_RepeatedFixed32Extension, &[]uint32{1, 2, 3, 4, 5, 6})
+			proto.SetExtension(m, testpb.E_RepeatedFixed32Extension, []uint32{1, 2, 3, 4, 5, 6})
 			return m
 		}(),
 	}, {
@@ -363,7 +363,7 @@ func TestMerge(t *testing.T) {
 				tt.mutator(tt.src) // should not be observable by dst
 			}
 			if !proto.Equal(tt.dst, tt.want) {
-				t.Fatalf("Merge() mismatch: got %v, want %v", tt.dst, tt.want)
+				t.Fatalf("Merge() mismatch:\n got %v\nwant %v", tt.dst, tt.want)
 			}
 		})
 	}