Bläddra i källkod

proto: store extension values according to protobuf data model (#746)

The current API represents scalar extension fields as *T and
repeated extension fields as []T.
However, this is not an accurate reflection of the protobuf data model.

For scalars, pointers are usually used to represent nullability. However,
in the case of extension scalars, there is no need to do so since presence
information is captured by checking whether the field is in the extension map.
For this reason, presence on extension scalars is not determined by checking
whether the returned pointer is nil, but whether HasExtension reports true.

For repeated fields, using []T means that the returned value is only a partially
mutable value. You can swap out elements, but you cannot change the length of
the original field value. On the other hand, the reflective API provides methods
on repeated field values that permit operations that do change the length.
Thus, using *[]T is a closer match to the protobuf data model.

This CL changes the implementation of extension fields to always store T
for scalars, and *[]T for repeated fields. However, for backwards compatibility,
the API continues to provide *T and []T. In theory, this could break anyone
that relies on memory aliasing for *T. However, this is unlikely since:
* use of extensions themselves are relatively rare
* if extensions are used, it is recommended practice to use a message as the
field extension and not a scalar
* relying on memory aliasing is generally not a good idiom to follow.
The expected pattern is to call SetExtension to make it explicit that a mutation
is happening on the message.
* analysis within Google demonstrates that no one is relying on this behavior.

Change-Id: Ieec4c3ba8176aabe7c3139d12741c7fcc46ba83d
Cherry-Pick: github.com/golang/protobuf@52132540909e117f2b98b0694383dc0ab1e1deca
Original-Author: Joe Tsai <joetsai@digital-static.net>
Reviewed-on: https://go-review.googlesource.com/c/151436
Reviewed-by: Damien Neil <dneil@google.com>
Joe Tsai 7 år sedan
förälder
incheckning
d6f7ca2ac0
6 ändrade filer med 141 tillägg och 42 borttagningar
  1. 33 17
      proto/clone_test.go
  2. 2 1
      proto/equal.go
  3. 70 6
      proto/extensions.go
  4. 4 1
      proto/pointer_reflect.go
  5. 10 5
      proto/pointer_unsafe.go
  6. 22 12
      proto/table_marshal.go

+ 33 - 17
proto/clone_test.go

@@ -67,34 +67,50 @@ func init() {
 	if err := proto.SetExtension(cloneTestMessage, pb.E_Ext_More, ext); err != nil {
 		panic("SetExtension: " + err.Error())
 	}
+	if err := proto.SetExtension(cloneTestMessage, pb.E_Ext_Text, proto.String("hello")); err != nil {
+		panic("SetExtension: " + err.Error())
+	}
+	if err := proto.SetExtension(cloneTestMessage, pb.E_Greeting, []string{"one", "two"}); err != nil {
+		panic("SetExtension: " + err.Error())
+	}
 }
 
 func TestClone(t *testing.T) {
+	// Create a clone using a marshal/unmarshal roundtrip.
+	vanilla := new(pb.MyMessage)
+	b, err := proto.Marshal(cloneTestMessage)
+	if err != nil {
+		t.Errorf("unexpected Marshal error: %v", err)
+	}
+	if err := proto.Unmarshal(b, vanilla); err != nil {
+		t.Errorf("unexpected Unarshal error: %v", err)
+	}
+
+	// Create a clone using Clone and verify that it is equal to the original.
 	m := proto.Clone(cloneTestMessage).(*pb.MyMessage)
 	if !proto.Equal(m, cloneTestMessage) {
 		t.Fatalf("Clone(%v) = %v", cloneTestMessage, m)
 	}
 
-	// Verify it was a deep copy.
-	*m.Inner.Port++
-	if proto.Equal(m, cloneTestMessage) {
-		t.Error("Mutating clone changed the original")
-	}
-	// Byte fields and repeated fields should be copied.
-	if &m.Pet[0] == &cloneTestMessage.Pet[0] {
-		t.Error("Pet: repeated field not copied")
+	// Mutate the clone, which should not affect the original.
+	x1, err := proto.GetExtension(m, pb.E_Ext_More)
+	if err != nil {
+		t.Errorf("unexpected GetExtension(%v) error: %v", pb.E_Ext_More.Name, err)
 	}
-	if &m.Others[0] == &cloneTestMessage.Others[0] {
-		t.Error("Others: repeated field not copied")
+	x2, err := proto.GetExtension(m, pb.E_Ext_Text)
+	if err != nil {
+		t.Errorf("unexpected GetExtension(%v) error: %v", pb.E_Ext_Text.Name, err)
 	}
-	if &m.Others[0].Value[0] == &cloneTestMessage.Others[0].Value[0] {
-		t.Error("Others[0].Value: bytes field not copied")
+	x3, err := proto.GetExtension(m, pb.E_Greeting)
+	if err != nil {
+		t.Errorf("unexpected GetExtension(%v) error: %v", pb.E_Greeting.Name, err)
 	}
-	if &m.RepBytes[0] == &cloneTestMessage.RepBytes[0] {
-		t.Error("RepBytes: repeated field not copied")
-	}
-	if &m.RepBytes[0][0] == &cloneTestMessage.RepBytes[0][0] {
-		t.Error("RepBytes[0]: bytes field not copied")
+	*m.Inner.Port++
+	*(x1.(*pb.Ext)).Data = "blah blah"
+	*(x2.(*string)) = "goodbye"
+	x3.([]string)[0] = "zero"
+	if !proto.Equal(cloneTestMessage, vanilla) {
+		t.Fatalf("mutation on original detected:\ngot  %v\nwant %v", cloneTestMessage, vanilla)
 	}
 }
 

+ 2 - 1
proto/equal.go

@@ -246,7 +246,8 @@ func equalExtMap(base reflect.Type, em1, em2 map[int32]Extension) bool {
 			return false
 		}
 
-		m1, m2 := e1.value, e2.value
+		m1 := extensionAsLegacyType(e1.value)
+		m2 := extensionAsLegacyType(e2.value)
 
 		if m1 == nil && m2 == nil {
 			// Both have only encoded form.

+ 70 - 6
proto/extensions.go

@@ -185,9 +185,25 @@ type Extension struct {
 	// extension will have only enc set. When such an extension is
 	// accessed using GetExtension (or GetExtensions) desc and value
 	// will be set.
-	desc  *ExtensionDesc
+	desc *ExtensionDesc
+
+	// value is a concrete value for the extension field. Let the type of
+	// desc.ExtensionType be the "API type" and the type of Extension.value
+	// be the "storage type". The API type and storage type are the same except:
+	//	* For scalars (except []byte), the API type uses *T,
+	//	while the storage type uses T.
+	//	* For repeated fields, the API type uses []T, while the storage type
+	//	uses *[]T.
+	//
+	// The reason for the divergence is so that the storage type more naturally
+	// matches what is expected of when retrieving the values through the
+	// protobuf reflection APIs.
+	//
+	// The value may only be populated if desc is also populated.
 	value interface{}
-	enc   []byte
+
+	// enc is the raw bytes for the extension field.
+	enc []byte
 }
 
 // SetRawExtension is for testing only.
@@ -334,7 +350,7 @@ func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) {
 			// descriptors with the same field number.
 			return nil, errors.New("proto: descriptor conflict")
 		}
-		return e.value, nil
+		return extensionAsLegacyType(e.value), nil
 	}
 
 	if extension.ExtensionType == nil {
@@ -349,11 +365,11 @@ func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) {
 
 	// Remember the decoded version and drop the encoded version.
 	// That way it is safe to mutate what we return.
-	e.value = v
+	e.value = extensionAsStorageType(v)
 	e.desc = extension
 	e.enc = nil
 	emap[extension.Field] = e
-	return e.value, nil
+	return extensionAsLegacyType(e.value), nil
 }
 
 // defaultExtensionValue returns the default value for extension.
@@ -500,7 +516,7 @@ func SetExtension(pb Message, extension *ExtensionDesc, value interface{}) error
 	}
 
 	extmap := epb.extensionsWrite()
-	extmap[extension.Field] = Extension{desc: extension, value: value}
+	extmap[extension.Field] = Extension{desc: extension, value: extensionAsStorageType(value)}
 	return nil
 }
 
@@ -541,3 +557,51 @@ func RegisterExtension(desc *ExtensionDesc) {
 func RegisteredExtensions(pb Message) map[int32]*ExtensionDesc {
 	return extensionMaps[reflect.TypeOf(pb).Elem()]
 }
+
+// extensionAsLegacyType converts an value in the storage type as the API type.
+// See Extension.value.
+func extensionAsLegacyType(v interface{}) interface{} {
+	switch rv := reflect.ValueOf(v); rv.Kind() {
+	case reflect.Bool, reflect.Int32, reflect.Int64, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.String:
+		// Represent primitive types as a pointer to the value.
+		rv2 := reflect.New(rv.Type())
+		rv2.Elem().Set(rv)
+		v = rv2.Interface()
+	case reflect.Ptr:
+		// Represent slice types as the value itself.
+		switch rv.Type().Elem().Kind() {
+		case reflect.Slice:
+			if rv.IsNil() {
+				v = reflect.Zero(rv.Type().Elem()).Interface()
+			} else {
+				v = rv.Elem().Interface()
+			}
+		}
+	}
+	return v
+}
+
+// extensionAsStorageType converts an value in the API type as the storage type.
+// See Extension.value.
+func extensionAsStorageType(v interface{}) interface{} {
+	switch rv := reflect.ValueOf(v); rv.Kind() {
+	case reflect.Ptr:
+		// Represent slice types as the value itself.
+		switch rv.Type().Elem().Kind() {
+		case reflect.Bool, reflect.Int32, reflect.Int64, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.String:
+			if rv.IsNil() {
+				v = reflect.Zero(rv.Type().Elem()).Interface()
+			} else {
+				v = rv.Elem().Interface()
+			}
+		}
+	case reflect.Slice:
+		// Represent slice types as a pointer to the value.
+		if rv.Type().Elem().Kind() != reflect.Uint8 {
+			rv2 := reflect.New(rv.Type())
+			rv2.Elem().Set(rv)
+			v = rv2.Interface()
+		}
+	}
+	return v
+}

+ 4 - 1
proto/pointer_reflect.go

@@ -79,10 +79,13 @@ func toPointer(i *Message) pointer {
 
 // toAddrPointer converts an interface to a pointer that points to
 // the interface data.
-func toAddrPointer(i *interface{}, isptr bool) pointer {
+func toAddrPointer(i *interface{}, isptr, deref bool) pointer {
 	v := reflect.ValueOf(*i)
 	u := reflect.New(v.Type())
 	u.Elem().Set(v)
+	if deref {
+		u = u.Elem()
+	}
 	return pointer{v: u}
 }
 

+ 10 - 5
proto/pointer_unsafe.go

@@ -85,16 +85,21 @@ func toPointer(i *Message) pointer {
 
 // toAddrPointer converts an interface to a pointer that points to
 // the interface data.
-func toAddrPointer(i *interface{}, isptr bool) pointer {
+func toAddrPointer(i *interface{}, isptr, deref bool) (p pointer) {
 	// Super-tricky - read or get the address of data word of interface value.
 	if isptr {
 		// The interface is of pointer type, thus it is a direct interface.
 		// The data word is the pointer data itself. We take its address.
-		return pointer{p: unsafe.Pointer(uintptr(unsafe.Pointer(i)) + ptrSize)}
+		p = pointer{p: unsafe.Pointer(uintptr(unsafe.Pointer(i)) + ptrSize)}
+	} else {
+		// The interface is not of pointer type. The data word is the pointer
+		// to the data.
+		p = pointer{p: (*[2]unsafe.Pointer)(unsafe.Pointer(i))[1]}
 	}
-	// The interface is not of pointer type. The data word is the pointer
-	// to the data.
-	return pointer{p: (*[2]unsafe.Pointer)(unsafe.Pointer(i))[1]}
+	if deref {
+		p.p = *(*unsafe.Pointer)(p.p)
+	}
+	return p
 }
 
 // valToPointer converts v to a pointer. v must be of pointer type.

+ 22 - 12
proto/table_marshal.go

@@ -87,6 +87,7 @@ type marshalElemInfo struct {
 	sizer     sizer
 	marshaler marshaler
 	isptr     bool // elem is pointer typed, thus interface of this type is a direct interface (extension only)
+	deref     bool // dereference the pointer before operating on it; implies isptr
 }
 
 var (
@@ -407,13 +408,22 @@ func (u *marshalInfo) getExtElemInfo(desc *ExtensionDesc) *marshalElemInfo {
 		panic("tag is not an integer")
 	}
 	wt := wiretype(tags[0])
+	if t.Kind() == reflect.Ptr && t.Elem().Kind() != reflect.Struct {
+		t = t.Elem()
+	}
 	sizer, marshaler := typeMarshaler(t, tags, false, false)
+	var deref bool
+	if t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8 {
+		t = reflect.PtrTo(t)
+		deref = true
+	}
 	e = &marshalElemInfo{
 		wiretag:   uint64(tag)<<3 | wt,
 		tagsize:   SizeVarint(uint64(tag) << 3),
 		sizer:     sizer,
 		marshaler: marshaler,
 		isptr:     t.Kind() == reflect.Ptr,
+		deref:     deref,
 	}
 
 	// update cache
@@ -2310,8 +2320,8 @@ func makeMapMarshaler(f *reflect.StructField) (sizer, marshaler) {
 			for _, k := range m.MapKeys() {
 				ki := k.Interface()
 				vi := m.MapIndex(k).Interface()
-				kaddr := toAddrPointer(&ki, false)             // pointer to key
-				vaddr := toAddrPointer(&vi, valIsPtr)          // pointer to value
+				kaddr := toAddrPointer(&ki, false, false)      // pointer to key
+				vaddr := toAddrPointer(&vi, valIsPtr, false)   // pointer to value
 				siz := keySizer(kaddr, 1) + valSizer(vaddr, 1) // tag of key = 1 (size=1), tag of val = 2 (size=1)
 				n += siz + SizeVarint(uint64(siz)) + tagsize
 			}
@@ -2329,8 +2339,8 @@ func makeMapMarshaler(f *reflect.StructField) (sizer, marshaler) {
 			for _, k := range keys {
 				ki := k.Interface()
 				vi := m.MapIndex(k).Interface()
-				kaddr := toAddrPointer(&ki, false)    // pointer to key
-				vaddr := toAddrPointer(&vi, valIsPtr) // pointer to value
+				kaddr := toAddrPointer(&ki, false, false)    // pointer to key
+				vaddr := toAddrPointer(&vi, valIsPtr, false) // pointer to value
 				b = appendVarint(b, tag)
 				siz := keySizer(kaddr, 1) + valCachedSizer(vaddr, 1) // tag of key = 1 (size=1), tag of val = 2 (size=1)
 				b = appendVarint(b, uint64(siz))
@@ -2399,7 +2409,7 @@ func (u *marshalInfo) sizeExtensions(ext *XXX_InternalExtensions) int {
 		// the last time this function was called.
 		ei := u.getExtElemInfo(e.desc)
 		v := e.value
-		p := toAddrPointer(&v, ei.isptr)
+		p := toAddrPointer(&v, ei.isptr, ei.deref)
 		n += ei.sizer(p, ei.tagsize)
 	}
 	mu.Unlock()
@@ -2434,7 +2444,7 @@ func (u *marshalInfo) appendExtensions(b []byte, ext *XXX_InternalExtensions, de
 
 			ei := u.getExtElemInfo(e.desc)
 			v := e.value
-			p := toAddrPointer(&v, ei.isptr)
+			p := toAddrPointer(&v, ei.isptr, ei.deref)
 			b, err = ei.marshaler(b, p, ei.wiretag, deterministic)
 			if !nerr.Merge(err) {
 				return b, err
@@ -2465,7 +2475,7 @@ func (u *marshalInfo) appendExtensions(b []byte, ext *XXX_InternalExtensions, de
 
 		ei := u.getExtElemInfo(e.desc)
 		v := e.value
-		p := toAddrPointer(&v, ei.isptr)
+		p := toAddrPointer(&v, ei.isptr, ei.deref)
 		b, err = ei.marshaler(b, p, ei.wiretag, deterministic)
 		if !nerr.Merge(err) {
 			return b, err
@@ -2510,7 +2520,7 @@ func (u *marshalInfo) sizeMessageSet(ext *XXX_InternalExtensions) int {
 
 		ei := u.getExtElemInfo(e.desc)
 		v := e.value
-		p := toAddrPointer(&v, ei.isptr)
+		p := toAddrPointer(&v, ei.isptr, ei.deref)
 		n += ei.sizer(p, 1) // message, tag = 3 (size=1)
 	}
 	mu.Unlock()
@@ -2553,7 +2563,7 @@ func (u *marshalInfo) appendMessageSet(b []byte, ext *XXX_InternalExtensions, de
 
 			ei := u.getExtElemInfo(e.desc)
 			v := e.value
-			p := toAddrPointer(&v, ei.isptr)
+			p := toAddrPointer(&v, ei.isptr, ei.deref)
 			b, err = ei.marshaler(b, p, 3<<3|WireBytes, deterministic)
 			if !nerr.Merge(err) {
 				return b, err
@@ -2591,7 +2601,7 @@ func (u *marshalInfo) appendMessageSet(b []byte, ext *XXX_InternalExtensions, de
 
 		ei := u.getExtElemInfo(e.desc)
 		v := e.value
-		p := toAddrPointer(&v, ei.isptr)
+		p := toAddrPointer(&v, ei.isptr, ei.deref)
 		b, err = ei.marshaler(b, p, 3<<3|WireBytes, deterministic)
 		b = append(b, 1<<3|WireEndGroup)
 		if !nerr.Merge(err) {
@@ -2621,7 +2631,7 @@ func (u *marshalInfo) sizeV1Extensions(m map[int32]Extension) int {
 
 		ei := u.getExtElemInfo(e.desc)
 		v := e.value
-		p := toAddrPointer(&v, ei.isptr)
+		p := toAddrPointer(&v, ei.isptr, ei.deref)
 		n += ei.sizer(p, ei.tagsize)
 	}
 	return n
@@ -2656,7 +2666,7 @@ func (u *marshalInfo) appendV1Extensions(b []byte, m map[int32]Extension, determ
 
 		ei := u.getExtElemInfo(e.desc)
 		v := e.value
-		p := toAddrPointer(&v, ei.isptr)
+		p := toAddrPointer(&v, ei.isptr, ei.deref)
 		b, err = ei.marshaler(b, p, ei.wiretag, deterministic)
 		if !nerr.Merge(err) {
 			return b, err