Просмотр исходного кода

proto: fix merge semantics for oneof message

The proper semantics for a message field within a oneof
when unmarshaling is to merge into an existing message,
rather than replacing it.

Change-Id: I7c08f6e4fa958c6ee6241e9083f7311515a97e15
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/185957
Reviewed-by: Damien Neil <dneil@google.com>
Joe Tsai 6 лет назад
Родитель
Сommit
6c28674cea
4 измененных файлов с 14 добавлено и 21 удалено
  1. 9 3
      internal/impl/codec_field.go
  2. 1 13
      proto/decode.go
  3. 3 1
      proto/decode_test.go
  4. 1 4
      proto/merge_test.go

+ 9 - 3
internal/impl/codec_field.go

@@ -52,12 +52,18 @@ func makeOneofFieldCoder(si structInfo, fd pref.FieldDescriptor) pointerCoderFun
 			return funcs.marshal(b, v, wiretag, opts)
 		},
 		unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
-			v := reflect.New(ot)
-			n, err := funcs.unmarshal(b, pointerOfValue(v).Apply(zeroOffset), wtyp, opts)
+			var vw reflect.Value         // pointer to wrapper type
+			vi := p.AsValueOf(ft).Elem() // oneof field value of interface kind
+			if !vi.IsNil() && !vi.Elem().IsNil() && vi.Elem().Elem().Type() == ot {
+				vw = vi.Elem()
+			} else {
+				vw = reflect.New(ot)
+			}
+			n, err := funcs.unmarshal(b, pointerOfValue(vw).Apply(zeroOffset), wtyp, opts)
 			if err != nil {
 				return 0, err
 			}
-			p.AsValueOf(ft).Elem().Set(v)
+			vi.Set(vw)
 			return n, nil
 		},
 	}

+ 1 - 13
proto/decode.go

@@ -122,19 +122,7 @@ func (o UnmarshalOptions) unmarshalSingular(b []byte, wtyp wire.Type, m protoref
 	}
 	switch fd.Kind() {
 	case protoreflect.GroupKind, protoreflect.MessageKind:
-		// Messages are merged with any existing message value,
-		// unless the message is part of a oneof.
-		//
-		// TODO: C++ merges into oneofs, while v1 does not.
-		// Evaluate which behavior to pick.
-		var m2 protoreflect.Message
-		if m.Has(fd) && fd.ContainingOneof() == nil {
-			m2 = m.Mutable(fd).Message()
-		} else {
-			m2 = m.NewMessage(fd)
-			m.Set(fd, protoreflect.ValueOf(m2))
-		}
-		// Pass up errors (fatal and otherwise).
+		m2 := m.Mutable(fd).Message()
 		if err := o.unmarshalMessage(v.Bytes(), m2); err != nil {
 			return n, err
 		}

+ 3 - 1
proto/decode_test.go

@@ -881,16 +881,18 @@ var testProtos = []testProto{
 		wire: pack.Message{pack.Tag{112, pack.BytesType}, pack.LengthPrefix(pack.Message{})}.Marshal(),
 	},
 	{
-		desc: "oneof (overridden message)",
+		desc: "oneof (merged message)",
 		decodeTo: []proto.Message{
 			&testpb.TestAllTypes{OneofField: &testpb.TestAllTypes_OneofNestedMessage{
 				&testpb.TestAllTypes_NestedMessage{
+					A: scalar.Int32(1),
 					Corecursive: &testpb.TestAllTypes{
 						OptionalInt32: scalar.Int32(43),
 					},
 				},
 			}}, &test3pb.TestAllTypes{OneofField: &test3pb.TestAllTypes_OneofNestedMessage{
 				&test3pb.TestAllTypes_NestedMessage{
+					A: 1,
 					Corecursive: &test3pb.TestAllTypes{
 						OptionalInt32: 43,
 					},

+ 1 - 4
proto/merge_test.go

@@ -27,8 +27,6 @@ func TestMerge(t *testing.T) {
 		src     proto.Message
 		want    proto.Message
 		mutator func(proto.Message) // if provided, is run on src after merging
-
-		skipMarshalUnmarshal bool // TODO: Remove this when proto.Unmarshal is fixed for messages in oneofs
 	}{{
 		desc: "merge from nil message",
 		dst:  new(testpb.TestAllTypes),
@@ -258,7 +256,6 @@ func TestMerge(t *testing.T) {
 			m := mi.(*testpb.TestAllTypes)
 			*m.OneofField.(*testpb.TestAllTypes_OneofNestedMessage).OneofNestedMessage.Corecursive.OptionalInt64++
 		},
-		skipMarshalUnmarshal: true,
 	}, {
 		desc: "merge oneof scalar fields",
 		dst: &testpb.TestAllTypes{
@@ -382,7 +379,7 @@ func TestMerge(t *testing.T) {
 			if err != nil {
 				t.Fatalf("Unmarshal() error: %v", err)
 			}
-			if !proto.Equal(dst, tt.want) && !tt.skipMarshalUnmarshal {
+			if !proto.Equal(dst, tt.want) {
 				t.Fatalf("Unmarshal(Marshal(dst)+Marshal(src)) mismatch: got %v, want %v", dst, tt.want)
 			}