Bläddra i källkod

proto: eagerly unmarshal extensions

CL/172399 switches the v1 code to eagerly unmarshal extensions.
This CL does the equivalent for v2.

For the test, we simply switch from protoV1.Equal to protoV2.Equal,
since the v2 equal does not magically unmarshal raw extensions.

Change-Id: I6f64455b0a75bbc9a9a82108558641a29bd2b982
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/175838
Reviewed-by: Damien Neil <dneil@google.com>
Joe Tsai 6 år sedan
förälder
incheckning
db38ddde7d
3 ändrade filer med 22 tillägg och 1 borttagningar
  1. 19 0
      proto/decode.go
  2. 1 1
      proto/decode_test.go
  3. 2 0
      runtime/protoiface/methods.go

+ 19 - 0
proto/decode.go

@@ -9,6 +9,7 @@ import (
 	"github.com/golang/protobuf/v2/internal/errors"
 	"github.com/golang/protobuf/v2/internal/errors"
 	"github.com/golang/protobuf/v2/internal/pragma"
 	"github.com/golang/protobuf/v2/internal/pragma"
 	"github.com/golang/protobuf/v2/reflect/protoreflect"
 	"github.com/golang/protobuf/v2/reflect/protoreflect"
+	"github.com/golang/protobuf/v2/reflect/protoregistry"
 	"github.com/golang/protobuf/v2/runtime/protoiface"
 	"github.com/golang/protobuf/v2/runtime/protoiface"
 )
 )
 
 
@@ -25,6 +26,10 @@ type UnmarshalOptions struct {
 	// If DiscardUnknown is set, unknown fields are ignored.
 	// If DiscardUnknown is set, unknown fields are ignored.
 	DiscardUnknown bool
 	DiscardUnknown bool
 
 
+	// Resolver is used for looking up types when unmarshaling extension fields.
+	// If nil, this defaults to using protoregistry.GlobalTypes.
+	Resolver *protoregistry.Types
+
 	pragma.NoUnkeyedLiterals
 	pragma.NoUnkeyedLiterals
 }
 }
 
 
@@ -37,6 +42,10 @@ func Unmarshal(b []byte, m Message) error {
 
 
 // Unmarshal parses the wire-format message in b and places the result in m.
 // Unmarshal parses the wire-format message in b and places the result in m.
 func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
 func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
+	if o.Resolver == nil {
+		o.Resolver = protoregistry.GlobalTypes
+	}
+
 	// TODO: Reset m?
 	// TODO: Reset m?
 	err := o.unmarshalMessageFast(b, m)
 	err := o.unmarshalMessageFast(b, m)
 	if err == errInternalNoFast {
 	if err == errInternalNoFast {
@@ -77,6 +86,16 @@ func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) err
 		fieldType := fieldTypes.ByNumber(num)
 		fieldType := fieldTypes.ByNumber(num)
 		if fieldType == nil {
 		if fieldType == nil {
 			fieldType = knownFields.ExtensionTypes().ByNumber(num)
 			fieldType = knownFields.ExtensionTypes().ByNumber(num)
+			if fieldType == nil && messageType.ExtensionRanges().Has(num) {
+				extType, err := o.Resolver.FindExtensionByNumber(messageType.FullName(), num)
+				if err != nil && err != protoregistry.NotFound {
+					return err
+				}
+				if extType != nil {
+					knownFields.ExtensionTypes().Register(extType)
+					fieldType = extType
+				}
+			}
 		}
 		}
 		var err error
 		var err error
 		var valLen int
 		var valLen int

+ 1 - 1
proto/decode_test.go

@@ -54,7 +54,7 @@ func TestDecode(t *testing.T) {
 					// Equal doesn't work on messages containing invalid extension data.
 					// Equal doesn't work on messages containing invalid extension data.
 					return
 					return
 				}
 				}
-				if !protoV1.Equal(got.(protoV1.Message), want.(protoV1.Message)) {
+				if !proto.Equal(got, want) {
 					t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(got), marshalText(want))
 					t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(got), marshalText(want))
 				}
 				}
 			})
 			})

+ 2 - 0
runtime/protoiface/methods.go

@@ -7,6 +7,7 @@ package protoiface
 import (
 import (
 	"github.com/golang/protobuf/v2/internal/pragma"
 	"github.com/golang/protobuf/v2/internal/pragma"
 	"github.com/golang/protobuf/v2/reflect/protoreflect"
 	"github.com/golang/protobuf/v2/reflect/protoreflect"
+	"github.com/golang/protobuf/v2/reflect/protoregistry"
 )
 )
 
 
 // Methoder is an optional interface implemented by generated messages to
 // Methoder is an optional interface implemented by generated messages to
@@ -62,6 +63,7 @@ type MarshalOptions struct {
 type UnmarshalOptions struct {
 type UnmarshalOptions struct {
 	AllowPartial   bool
 	AllowPartial   bool
 	DiscardUnknown bool
 	DiscardUnknown bool
+	Resolver       *protoregistry.Types
 
 
 	pragma.NoUnkeyedLiterals
 	pragma.NoUnkeyedLiterals
 }
 }