Kaynağa Gözat

proto: eagerly unmarshal extensions

CL/172399 switches the v1 code to eagerly unmarshal extensions.
This CL does the equivalent for v2.

For the test, we simply switch from protoV1.Equal to protoV2.Equal,
since the v2 equal does not magically unmarshal raw extensions.

Change-Id: I6f64455b0a75bbc9a9a82108558641a29bd2b982
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/175838
Reviewed-by: Damien Neil <dneil@google.com>
Joe Tsai 6 yıl önce
ebeveyn
işleme
db38ddde7d
3 değiştirilmiş dosya ile 22 ekleme ve 1 silme
  1. 19 0
      proto/decode.go
  2. 1 1
      proto/decode_test.go
  3. 2 0
      runtime/protoiface/methods.go

+ 19 - 0
proto/decode.go

@@ -9,6 +9,7 @@ import (
 	"github.com/golang/protobuf/v2/internal/errors"
 	"github.com/golang/protobuf/v2/internal/pragma"
 	"github.com/golang/protobuf/v2/reflect/protoreflect"
+	"github.com/golang/protobuf/v2/reflect/protoregistry"
 	"github.com/golang/protobuf/v2/runtime/protoiface"
 )
 
@@ -25,6 +26,10 @@ type UnmarshalOptions struct {
 	// If DiscardUnknown is set, unknown fields are ignored.
 	DiscardUnknown bool
 
+	// Resolver is used for looking up types when unmarshaling extension fields.
+	// If nil, this defaults to using protoregistry.GlobalTypes.
+	Resolver *protoregistry.Types
+
 	pragma.NoUnkeyedLiterals
 }
 
@@ -37,6 +42,10 @@ func Unmarshal(b []byte, m Message) error {
 
 // Unmarshal parses the wire-format message in b and places the result in m.
 func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
+	if o.Resolver == nil {
+		o.Resolver = protoregistry.GlobalTypes
+	}
+
 	// TODO: Reset m?
 	err := o.unmarshalMessageFast(b, m)
 	if err == errInternalNoFast {
@@ -77,6 +86,16 @@ func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) err
 		fieldType := fieldTypes.ByNumber(num)
 		if fieldType == nil {
 			fieldType = knownFields.ExtensionTypes().ByNumber(num)
+			if fieldType == nil && messageType.ExtensionRanges().Has(num) {
+				extType, err := o.Resolver.FindExtensionByNumber(messageType.FullName(), num)
+				if err != nil && err != protoregistry.NotFound {
+					return err
+				}
+				if extType != nil {
+					knownFields.ExtensionTypes().Register(extType)
+					fieldType = extType
+				}
+			}
 		}
 		var err error
 		var valLen int

+ 1 - 1
proto/decode_test.go

@@ -54,7 +54,7 @@ func TestDecode(t *testing.T) {
 					// Equal doesn't work on messages containing invalid extension data.
 					return
 				}
-				if !protoV1.Equal(got.(protoV1.Message), want.(protoV1.Message)) {
+				if !proto.Equal(got, want) {
 					t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(got), marshalText(want))
 				}
 			})

+ 2 - 0
runtime/protoiface/methods.go

@@ -7,6 +7,7 @@ package protoiface
 import (
 	"github.com/golang/protobuf/v2/internal/pragma"
 	"github.com/golang/protobuf/v2/reflect/protoreflect"
+	"github.com/golang/protobuf/v2/reflect/protoregistry"
 )
 
 // Methoder is an optional interface implemented by generated messages to
@@ -62,6 +63,7 @@ type MarshalOptions struct {
 type UnmarshalOptions struct {
 	AllowPartial   bool
 	DiscardUnknown bool
+	Resolver       *protoregistry.Types
 
 	pragma.NoUnkeyedLiterals
 }