Bladeren bron

proto: expose accessors for raw value of extensions (#483)

Modify GetExtension to return the raw bytes if the ExtensionDesc is type-incomplete
(i.e., the ExtensionType is nil).
Joshua Humphries 8 jaren geleden
bovenliggende
commit
5f34c20e59
2 gewijzigde bestanden met toevoegingen van 115 en 6 verwijderingen
  1. 24 4
      proto/extensions.go
  2. 91 2
      proto/extensions_test.go

+ 24 - 4
proto/extensions.go

@@ -291,16 +291,26 @@ func ClearExtension(pb Message, extension *ExtensionDesc) {
 	delete(extmap, extension.Field)
 }
 
-// GetExtension parses and returns the given extension of pb.
-// If the extension is not present and has no default value it returns ErrMissingExtension.
+// GetExtension retrieves a proto2 extended field from pb.
+//
+// If the descriptor is type complete (i.e., ExtensionDesc.ExtensionType is non-nil),
+// then GetExtension parses the encoded field and returns a Go value of the specified type.
+// If the field is not present, then the default value is returned (if one is specified),
+// otherwise ErrMissingExtension is reported.
+//
+// If the descriptor is not type complete (i.e., ExtensionDesc.ExtensionType is nil),
+// then GetExtension returns the raw encoded bytes of the field extension.
 func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) {
 	epb, err := extendable(pb)
 	if err != nil {
 		return nil, err
 	}
 
-	if err := checkExtensionTypes(epb, extension); err != nil {
-		return nil, err
+	if extension.ExtendedType != nil {
+		// can only check type if this is a complete descriptor
+		if err := checkExtensionTypes(epb, extension); err != nil {
+			return nil, err
+		}
 	}
 
 	emap, mu := epb.extensionsRead()
@@ -327,6 +337,11 @@ func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) {
 		return e.value, nil
 	}
 
+	if extension.ExtensionType == nil {
+		// incomplete descriptor
+		return e.enc, nil
+	}
+
 	v, err := decodeExtension(e.enc, extension)
 	if err != nil {
 		return nil, err
@@ -344,6 +359,11 @@ func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) {
 // defaultExtensionValue returns the default value for extension.
 // If no default for an extension is defined ErrMissingExtension is returned.
 func defaultExtensionValue(extension *ExtensionDesc) (interface{}, error) {
+	if extension.ExtensionType == nil {
+		// incomplete descriptor, so no default
+		return nil, ErrMissingExtension
+	}
+
 	t := reflect.TypeOf(extension.ExtensionType)
 	props := extensionProperties(extension)
 

+ 91 - 2
proto/extensions_test.go

@@ -77,7 +77,96 @@ func TestGetExtensionWithEmptyBuffer(t *testing.T) {
 	}
 }
 
-func TestExtensionDescsWithMissingExtensions(t *testing.T) {
+func TestGetExtensionForIncompleteDesc(t *testing.T) {
+	msg := &pb.MyMessage{Count: proto.Int32(0)}
+	extdesc1 := &proto.ExtensionDesc{
+		ExtendedType:  (*pb.MyMessage)(nil),
+		ExtensionType: (*bool)(nil),
+		Field:         123456789,
+		Name:          "a.b",
+		Tag:           "varint,123456789,opt",
+	}
+	ext1 := proto.Bool(true)
+	if err := proto.SetExtension(msg, extdesc1, ext1); err != nil {
+		t.Fatalf("Could not set ext1: %s", err)
+	}
+	extdesc2 := &proto.ExtensionDesc{
+		ExtendedType:  (*pb.MyMessage)(nil),
+		ExtensionType: ([]byte)(nil),
+		Field:         123456790,
+		Name:          "a.c",
+		Tag:           "bytes,123456790,opt",
+	}
+	ext2 := []byte{0,1,2,3,4,5,6,7}
+	if err := proto.SetExtension(msg, extdesc2, ext2); err != nil {
+		t.Fatalf("Could not set ext2: %s", err)
+	}
+	extdesc3 := &proto.ExtensionDesc{
+		ExtendedType:  (*pb.MyMessage)(nil),
+		ExtensionType: (*pb.Ext)(nil),
+		Field:         123456791,
+		Name:          "a.d",
+		Tag:           "bytes,123456791,opt",
+	}
+	ext3 := &pb.Ext{Data: proto.String("foo")}
+	if err := proto.SetExtension(msg, extdesc3, ext3); err != nil {
+		t.Fatalf("Could not set ext3: %s", err)
+	}
+
+	b, err := proto.Marshal(msg)
+	if err != nil {
+		t.Fatalf("Could not marshal msg: %v", err)
+	}
+	if err := proto.Unmarshal(b, msg); err != nil {
+		t.Fatalf("Could not unmarshal into msg: %v", err)
+	}
+
+	var expected proto.Buffer
+	if err := expected.EncodeVarint(uint64((extdesc1.Field << 3) | proto.WireVarint)); err != nil {
+		t.Fatalf("failed to compute expected prefix for ext1: %s", err)
+	}
+	if err := expected.EncodeVarint(1 /* bool true */); err != nil {
+		t.Fatalf("failed to compute expected value for ext1: %s", err)
+	}
+
+	if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc1.Field}); err != nil {
+		t.Fatalf("Failed to get raw value for ext1: %s", err)
+	} else if !reflect.DeepEqual(b, expected.Bytes()) {
+		t.Fatalf("Raw value for ext1: got %v, want %v", b, expected.Bytes())
+	}
+
+	expected = proto.Buffer{} // reset
+	if err := expected.EncodeVarint(uint64((extdesc2.Field << 3) | proto.WireBytes)); err != nil {
+		t.Fatalf("failed to compute expected prefix for ext2: %s", err)
+	}
+	if err := expected.EncodeRawBytes(ext2); err != nil {
+		t.Fatalf("failed to compute expected value for ext2: %s", err)
+	}
+
+	if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc2.Field}); err != nil {
+		t.Fatalf("Failed to get raw value for ext2: %s", err)
+	} else if !reflect.DeepEqual(b, expected.Bytes()) {
+		t.Fatalf("Raw value for ext2: got %v, want %v", b, expected.Bytes())
+	}
+
+	expected = proto.Buffer{} // reset
+	if err := expected.EncodeVarint(uint64((extdesc3.Field << 3) | proto.WireBytes)); err != nil {
+		t.Fatalf("failed to compute expected prefix for ext3: %s", err)
+	}
+	if b, err := proto.Marshal(ext3); err != nil {
+		t.Fatalf("failed to compute expected value for ext3: %s", err)
+	} else if err := expected.EncodeRawBytes(b); err != nil {
+		t.Fatalf("failed to compute expected value for ext3: %s", err)
+	}
+
+	if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc3.Field}); err != nil {
+		t.Fatalf("Failed to get raw value for ext3: %s", err)
+	} else if !reflect.DeepEqual(b, expected.Bytes()) {
+		t.Fatalf("Raw value for ext3: got %v, want %v", b, expected.Bytes())
+	}
+}
+
+func TestExtensionDescsWithUnregisteredExtensions(t *testing.T) {
 	msg := &pb.MyMessage{Count: proto.Int32(0)}
 	extdesc1 := pb.E_Ext_More
 	if descs, err := proto.ExtensionDescs(msg); len(descs) != 0 || err != nil {
@@ -113,7 +202,7 @@ func TestExtensionDescsWithMissingExtensions(t *testing.T) {
 		t.Fatalf("proto.ExtensionDescs: got error %v", err)
 	}
 	sortExtDescs(descs)
-	wantDescs := []*proto.ExtensionDesc{extdesc1, &proto.ExtensionDesc{Field: extdesc2.Field}}
+	wantDescs := []*proto.ExtensionDesc{extdesc1, {Field: extdesc2.Field}}
 	if !reflect.DeepEqual(descs, wantDescs) {
 		t.Errorf("proto.ExtensionDescs(msg) sorted extension ids: got %+v, want %+v", descs, wantDescs)
 	}