فهرست منبع

proto: eagerly unmarshal extensions

CL/172399 switches the v1 code to eagerly unmarshal extensions.
This CL does the equivalent for v2.

For the test, we simply switch from protoV1.Equal to protoV2.Equal,
since the v2 equal does not magically unmarshal raw extensions.

Change-Id: I6f64455b0a75bbc9a9a82108558641a29bd2b982
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/175838
Reviewed-by: Damien Neil <dneil@google.com>
Joe Tsai 6 سال پیش
والد
کامیت
db38ddde7d
3فایلهای تغییر یافته به همراه22 افزوده شده و 1 حذف شده
  1. 19 0
      proto/decode.go
  2. 1 1
      proto/decode_test.go
  3. 2 0
      runtime/protoiface/methods.go

+ 19 - 0
proto/decode.go

@@ -9,6 +9,7 @@ import (
 	"github.com/golang/protobuf/v2/internal/errors"
 	"github.com/golang/protobuf/v2/internal/pragma"
 	"github.com/golang/protobuf/v2/reflect/protoreflect"
+	"github.com/golang/protobuf/v2/reflect/protoregistry"
 	"github.com/golang/protobuf/v2/runtime/protoiface"
 )
 
@@ -25,6 +26,10 @@ type UnmarshalOptions struct {
 	// If DiscardUnknown is set, unknown fields are ignored.
 	DiscardUnknown bool
 
+	// Resolver is used for looking up types when unmarshaling extension fields.
+	// If nil, this defaults to using protoregistry.GlobalTypes.
+	Resolver *protoregistry.Types
+
 	pragma.NoUnkeyedLiterals
 }
 
@@ -37,6 +42,10 @@ func Unmarshal(b []byte, m Message) error {
 
 // Unmarshal parses the wire-format message in b and places the result in m.
 func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
+	if o.Resolver == nil {
+		o.Resolver = protoregistry.GlobalTypes
+	}
+
 	// TODO: Reset m?
 	err := o.unmarshalMessageFast(b, m)
 	if err == errInternalNoFast {
@@ -77,6 +86,16 @@ func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) err
 		fieldType := fieldTypes.ByNumber(num)
 		if fieldType == nil {
 			fieldType = knownFields.ExtensionTypes().ByNumber(num)
+			if fieldType == nil && messageType.ExtensionRanges().Has(num) {
+				extType, err := o.Resolver.FindExtensionByNumber(messageType.FullName(), num)
+				if err != nil && err != protoregistry.NotFound {
+					return err
+				}
+				if extType != nil {
+					knownFields.ExtensionTypes().Register(extType)
+					fieldType = extType
+				}
+			}
 		}
 		var err error
 		var valLen int

+ 1 - 1
proto/decode_test.go

@@ -54,7 +54,7 @@ func TestDecode(t *testing.T) {
 					// Equal doesn't work on messages containing invalid extension data.
 					return
 				}
-				if !protoV1.Equal(got.(protoV1.Message), want.(protoV1.Message)) {
+				if !proto.Equal(got, want) {
 					t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(got), marshalText(want))
 				}
 			})

+ 2 - 0
runtime/protoiface/methods.go

@@ -7,6 +7,7 @@ package protoiface
 import (
 	"github.com/golang/protobuf/v2/internal/pragma"
 	"github.com/golang/protobuf/v2/reflect/protoreflect"
+	"github.com/golang/protobuf/v2/reflect/protoregistry"
 )
 
 // Methoder is an optional interface implemented by generated messages to
@@ -62,6 +63,7 @@ type MarshalOptions struct {
 type UnmarshalOptions struct {
 	AllowPartial   bool
 	DiscardUnknown bool
+	Resolver       *protoregistry.Types
 
 	pragma.NoUnkeyedLiterals
 }