Browse Source

proto: use XXX_unrecognized exclusively for unknown fields

The protobuf data model makes no distinction between unknown fields
that are within the extension field ranges or not. Now that we eagerly
unmarshal extensions, there is even less need for storing unknown
fields in the extension map. Instead, use the XXX_unrecognized field
exclusively for this purpose.

To support this logic, we fork the v2 internal/encoding/wire package.
This is a temporary measure since the v1 code will be completely
re-written in terms of the v2 API in the near future.

Change-Id: I3dadd04ec2314e6d245d46f6329383bb9e0d00f7
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/175580
Reviewed-by: Damien Neil <dneil@google.com>
Joe Tsai 6 years ago
parent
commit
911a20d792
12 changed files with 759 additions and 180 deletions
  1. 1 1
      go.mod
  2. 3 1
      go.sum
  3. 516 0
      internal/wire/wire.go
  4. 0 4
      proto/clone.go
  5. 3 40
      proto/equal.go
  6. 84 15
      proto/extensions.go
  7. 0 12
      proto/extensions_test.go
  8. 7 17
      proto/message_set.go
  9. 8 2
      proto/message_set_test.go
  10. 70 38
      proto/table_marshal.go
  11. 67 45
      proto/table_unmarshal.go
  12. 0 5
      proto/text.go

+ 1 - 1
go.mod

@@ -1,3 +1,3 @@
 module github.com/golang/protobuf
 
-require github.com/golang/protobuf/v2 v2.0.0-20190420063524-d24bc72368a2
+require github.com/golang/protobuf/v2 v2.0.0-20190509012650-00a323deed55

+ 3 - 1
go.sum

@@ -1,9 +1,11 @@
 github.com/golang/protobuf v1.2.1-0.20190322195920-d94fb84e04b7/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
 github.com/golang/protobuf v1.2.1-0.20190326022002-be03c15fcaa2/go.mod h1:rZ4veVXHB1S2+o7TKqD9Isxml062IeDutnCDtFPUlCc=
 github.com/golang/protobuf v1.2.1-0.20190416233244-13cf6e79fd39/go.mod h1:RgnTNLHWo9HXezTFX5MTeuXnXx9eeQX8y3Cukv+9HaE=
+github.com/golang/protobuf v1.2.1-0.20190420064300-2b4f3c98b458/go.mod h1:hPB+itxf2EbA0J6prVtJg+ohMeLFLEhlSXXPS2qxTZE=
 github.com/golang/protobuf/v2 v2.0.0-20190322201422-f503c300f70e/go.mod h1:25ZALhydMFaBRgPH58a8zpFe9YXMAMjOYWtB6pNPcoo=
 github.com/golang/protobuf/v2 v2.0.0-20190416222953-ab61d41ec93f/go.mod h1:baUT2weUsA1MR7ocRtLXLmi2B1s4VrUT3S6tO8AYzMw=
-github.com/golang/protobuf/v2 v2.0.0-20190420063524-d24bc72368a2 h1:Tp4FhirEYFiZdhirylriHTC/4tGUz3j1r96XDMpYaAQ=
 github.com/golang/protobuf/v2 v2.0.0-20190420063524-d24bc72368a2/go.mod h1:wcEMLTNPNYxBFS3yY7kunR0QKUgP/f+wzZaPeTbHi0g=
+github.com/golang/protobuf/v2 v2.0.0-20190509012650-00a323deed55 h1:zd+Y1Z1XtROfzk20h4yfXgURDkZ13m1dCC1aoaA/T6c=
+github.com/golang/protobuf/v2 v2.0.0-20190509012650-00a323deed55/go.mod h1:pWnbrfE+N2TBYiklDHixM32oa26kuZCiJwxIu0DAl7Y=
 github.com/google/go-cmp v0.2.1-0.20190312032427-6f77996f0c42 h1:q3pnF5JFBNRz8sRD+IRj7Y6DMyYGTNqnZ9axTbSfoNI=
 github.com/google/go-cmp v0.2.1-0.20190312032427-6f77996f0c42/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=

+ 516 - 0
internal/wire/wire.go

@@ -0,0 +1,516 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package wire parses and formats the protobuf wire encoding.
+//
+// See https://developers.google.com/protocol-buffers/docs/encoding.
+package wire
+
+import (
+	"errors"
+	"io"
+	"math"
+	"math/bits"
+
+	"github.com/golang/protobuf/v2/reflect/protoreflect"
+)
+
+// Number represents the field number; it is an alias for protoreflect.FieldNumber.
+type Number = protoreflect.FieldNumber
+
+const (
+	MinValidNumber      Number = 1         // ConsumeTag rejects any field number below this
+	FirstReservedNumber Number = 19000     // first number of the reserved range
+	LastReservedNumber  Number = 19999     // last number of the reserved range
+	MaxValidNumber      Number = 1<<29 - 1 // 536870911
+)
+
+// Type represents the wire type, i.e. the low three bits of an encoded tag (see DecodeTag).
+type Type int8
+
+const (
+	VarintType     Type = 0 // value parsed by ConsumeVarint
+	Fixed32Type    Type = 5 // value parsed by ConsumeFixed32
+	Fixed64Type    Type = 1 // value parsed by ConsumeFixed64
+	BytesType      Type = 2 // value parsed by ConsumeBytes
+	StartGroupType Type = 3 // opens a group that runs until a matching EndGroupType tag
+	EndGroupType   Type = 4 // closes a group; invalid as a stand-alone field value
+)
+
+const (
+	_ = -iota // zero is skipped so that every code below is negative (-1, -2, ...)
+	errCodeTruncated   // mapped to io.ErrUnexpectedEOF by ParseError
+	errCodeFieldNumber // mapped to errFieldNumber by ParseError
+	errCodeOverflow    // mapped to errOverflow by ParseError
+	errCodeReserved    // mapped to errReserved by ParseError
+	errCodeEndGroup    // mapped to errEndGroup by ParseError
+)
+
+var (
+	errFieldNumber = errors.New("invalid field number")
+	errOverflow    = errors.New("variable length integer overflow")
+	errReserved    = errors.New("cannot parse reserved wire type")
+	errEndGroup    = errors.New("mismatching end group marker")
+	errParse       = errors.New("parse error") // fallback for negative codes ParseError does not recognize
+)
+
+// ParseError translates a length/error code returned by the Consume
+// functions into an error value. Non-negative inputs are successful lengths
+// and yield nil; negative codes that are not recognized yield a generic
+// parse error.
+func ParseError(n int) error {
+	switch {
+	case n >= 0:
+		return nil
+	case n == errCodeTruncated:
+		return io.ErrUnexpectedEOF
+	case n == errCodeFieldNumber:
+		return errFieldNumber
+	case n == errCodeOverflow:
+		return errOverflow
+	case n == errCodeReserved:
+		return errReserved
+	case n == errCodeEndGroup:
+		return errEndGroup
+	}
+	return errParse
+}
+
+// ConsumeField parses one complete field record, tag followed by value, from
+// the front of b. It returns the field number, the wire type, and the number
+// of bytes the record occupies. For a group that count covers everything up
+// to and including the end group marker.
+//
+// On failure the returned length is a negative error code (see ParseError)
+// and the number and type are zero.
+func ConsumeField(b []byte) (Number, Type, int) {
+	num, typ, tagLen := ConsumeTag(b)
+	if tagLen < 0 {
+		return 0, 0, tagLen // propagate the tag error code
+	}
+	valLen := ConsumeFieldValue(num, typ, b[tagLen:])
+	if valLen < 0 {
+		return 0, 0, valLen // propagate the value error code
+	}
+	return num, typ, tagLen + valLen
+}
+
+// ConsumeFieldValue measures the value of a field whose tag (num, typ) has
+// already been consumed, returning the value's length in bytes.
+// A negative result is an error code (see ParseError).
+//
+// For a start-group type, nested fields are skipped until an end group tag
+// is found; that tag must carry the same field number as num, and its bytes
+// are included in the returned length. An end-group type on its own, or a
+// wire type not listed here, is an error.
+func ConsumeFieldValue(num Number, typ Type, b []byte) (n int) {
+	switch typ {
+	case VarintType:
+		_, n = ConsumeVarint(b)
+	case Fixed32Type:
+		_, n = ConsumeFixed32(b)
+	case Fixed64Type:
+		_, n = ConsumeFixed64(b)
+	case BytesType:
+		_, n = ConsumeBytes(b)
+	case StartGroupType:
+		total := len(b)
+		for {
+			innerNum, innerTyp, tagLen := ConsumeTag(b)
+			if tagLen < 0 {
+				return tagLen // propagate the tag error code
+			}
+			b = b[tagLen:]
+			if innerTyp == EndGroupType {
+				if innerNum != num {
+					return errCodeEndGroup
+				}
+				return total - len(b)
+			}
+
+			valLen := ConsumeFieldValue(innerNum, innerTyp, b)
+			if valLen < 0 {
+				return valLen // propagate the nested value error code
+			}
+			b = b[valLen:]
+		}
+	case EndGroupType:
+		n = errCodeEndGroup
+	default:
+		n = errCodeReserved
+	}
+	return n
+}
+
+// AppendTag combines num and typ into a single tag value and appends its
+// varint encoding to b.
+func AppendTag(b []byte, num Number, typ Type) []byte {
+	tag := EncodeTag(num, typ)
+	return AppendVarint(b, tag)
+}
+
+// ConsumeTag reads a varint-encoded tag from the front of b and returns the
+// field number, the wire type, and the number of bytes consumed.
+// A negative length is an error code (see ParseError); a field number below
+// MinValidNumber is reported as an invalid field number.
+func ConsumeTag(b []byte) (Number, Type, int) {
+	raw, n := ConsumeVarint(b)
+	if n < 0 {
+		return 0, 0, n // propagate the varint error code
+	}
+	if num, typ := DecodeTag(raw); num >= MinValidNumber {
+		return num, typ, n
+	}
+	return 0, 0, errCodeFieldNumber
+}
+
+// SizeTag returns the encoded size of the tag for field number num.
+// The wire type occupies only the low three bits of the tag and therefore
+// never changes the size, so VarintType (zero) is used as a placeholder.
+func SizeTag(num Number) int {
+	return SizeVarint(EncodeTag(num, VarintType))
+}
+
+// AppendVarint appends v to b as a varint-encoded uint64 (1 to 10 bytes, least-significant 7-bit group first).
+func AppendVarint(b []byte, v uint64) []byte {
+	// TODO: Specialize for sizes 1 and 2 with mid-stack inlining.
+	switch {
+	case v < 1<<7: // 1 byte
+		b = append(b, byte(v))
+	case v < 1<<14: // 2 bytes
+		b = append(b,
+			byte((v>>0)&0x7f|0x80), // low 7 bits with the continuation bit (0x80) set
+			byte(v>>7))
+	case v < 1<<21: // 3 bytes
+		b = append(b,
+			byte((v>>0)&0x7f|0x80),
+			byte((v>>7)&0x7f|0x80),
+			byte(v>>14))
+	case v < 1<<28: // 4 bytes
+		b = append(b,
+			byte((v>>0)&0x7f|0x80),
+			byte((v>>7)&0x7f|0x80),
+			byte((v>>14)&0x7f|0x80),
+			byte(v>>21))
+	case v < 1<<35: // 5 bytes
+		b = append(b,
+			byte((v>>0)&0x7f|0x80),
+			byte((v>>7)&0x7f|0x80),
+			byte((v>>14)&0x7f|0x80),
+			byte((v>>21)&0x7f|0x80),
+			byte(v>>28))
+	case v < 1<<42: // 6 bytes
+		b = append(b,
+			byte((v>>0)&0x7f|0x80),
+			byte((v>>7)&0x7f|0x80),
+			byte((v>>14)&0x7f|0x80),
+			byte((v>>21)&0x7f|0x80),
+			byte((v>>28)&0x7f|0x80),
+			byte(v>>35))
+	case v < 1<<49: // 7 bytes
+		b = append(b,
+			byte((v>>0)&0x7f|0x80),
+			byte((v>>7)&0x7f|0x80),
+			byte((v>>14)&0x7f|0x80),
+			byte((v>>21)&0x7f|0x80),
+			byte((v>>28)&0x7f|0x80),
+			byte((v>>35)&0x7f|0x80),
+			byte(v>>42))
+	case v < 1<<56: // 8 bytes
+		b = append(b,
+			byte((v>>0)&0x7f|0x80),
+			byte((v>>7)&0x7f|0x80),
+			byte((v>>14)&0x7f|0x80),
+			byte((v>>21)&0x7f|0x80),
+			byte((v>>28)&0x7f|0x80),
+			byte((v>>35)&0x7f|0x80),
+			byte((v>>42)&0x7f|0x80),
+			byte(v>>49))
+	case v < 1<<63: // 9 bytes
+		b = append(b,
+			byte((v>>0)&0x7f|0x80),
+			byte((v>>7)&0x7f|0x80),
+			byte((v>>14)&0x7f|0x80),
+			byte((v>>21)&0x7f|0x80),
+			byte((v>>28)&0x7f|0x80),
+			byte((v>>35)&0x7f|0x80),
+			byte((v>>42)&0x7f|0x80),
+			byte((v>>49)&0x7f|0x80),
+			byte(v>>56))
+	default: // 10 bytes; reached only when v >= 1<<63
+		b = append(b,
+			byte((v>>0)&0x7f|0x80),
+			byte((v>>7)&0x7f|0x80),
+			byte((v>>14)&0x7f|0x80),
+			byte((v>>21)&0x7f|0x80),
+			byte((v>>28)&0x7f|0x80),
+			byte((v>>35)&0x7f|0x80),
+			byte((v>>42)&0x7f|0x80),
+			byte((v>>49)&0x7f|0x80),
+			byte((v>>56)&0x7f|0x80),
+			1) // bit 63 of v, which is always set in this branch
+	}
+	return b
+}
+
+// ConsumeVarint parses b as a varint-encoded uint64, reporting its length.
+// This returns a negative length upon an error (see ParseError): truncated input, or a varint that does not fit in 64 bits.
+func ConsumeVarint(b []byte) (v uint64, n int) {
+	// TODO: Specialize for sizes 1 and 2 with mid-stack inlining.
+	var y uint64 // the byte currently being decoded (second byte onwards)
+	if len(b) <= 0 {
+		return 0, errCodeTruncated
+	}
+	v = uint64(b[0])
+	if v < 0x80 { // continuation bit clear: the varint ends here
+		return v, 1
+	}
+	v -= 0x80 // remove the continuation bit that was folded into v
+
+	if len(b) <= 1 {
+		return 0, errCodeTruncated
+	}
+	y = uint64(b[1])
+	v += y << 7
+	if y < 0x80 {
+		return v, 2
+	}
+	v -= 0x80 << 7 // same continuation-bit removal, at this byte's shift
+
+	if len(b) <= 2 {
+		return 0, errCodeTruncated
+	}
+	y = uint64(b[2])
+	v += y << 14
+	if y < 0x80 {
+		return v, 3
+	}
+	v -= 0x80 << 14
+
+	if len(b) <= 3 {
+		return 0, errCodeTruncated
+	}
+	y = uint64(b[3])
+	v += y << 21
+	if y < 0x80 {
+		return v, 4
+	}
+	v -= 0x80 << 21
+
+	if len(b) <= 4 {
+		return 0, errCodeTruncated
+	}
+	y = uint64(b[4])
+	v += y << 28
+	if y < 0x80 {
+		return v, 5
+	}
+	v -= 0x80 << 28
+
+	if len(b) <= 5 {
+		return 0, errCodeTruncated
+	}
+	y = uint64(b[5])
+	v += y << 35
+	if y < 0x80 {
+		return v, 6
+	}
+	v -= 0x80 << 35
+
+	if len(b) <= 6 {
+		return 0, errCodeTruncated
+	}
+	y = uint64(b[6])
+	v += y << 42
+	if y < 0x80 {
+		return v, 7
+	}
+	v -= 0x80 << 42
+
+	if len(b) <= 7 {
+		return 0, errCodeTruncated
+	}
+	y = uint64(b[7])
+	v += y << 49
+	if y < 0x80 {
+		return v, 8
+	}
+	v -= 0x80 << 49
+
+	if len(b) <= 8 {
+		return 0, errCodeTruncated
+	}
+	y = uint64(b[8])
+	v += y << 56
+	if y < 0x80 {
+		return v, 9
+	}
+	v -= 0x80 << 56
+
+	if len(b) <= 9 {
+		return 0, errCodeTruncated
+	}
+	y = uint64(b[9])
+	v += y << 63
+	if y < 2 { // the tenth byte may only supply bit 63, so it must be 0 or 1
+		return v, 10
+	}
+	return 0, errCodeOverflow // tenth byte is 2 or more: value exceeds 64 bits or the varint continues
+}
+
+// SizeVarint reports how many bytes the varint encoding of v occupies.
+// The result always lies in the range 1 through 10.
+func SizeVarint(v uint64) int {
+	// OR-ing in the lowest bit makes zero count as a one-bit value, so the
+	// ceiling division below never yields zero bytes.
+	significantBits := bits.Len64(v | 1)
+	return (significantBits + 6) / 7
+}
+
+// AppendFixed32 appends the four bytes of v to b, least-significant byte
+// first, and returns the extended slice.
+func AppendFixed32(b []byte, v uint32) []byte {
+	var le [4]byte
+	for i := range le {
+		le[i] = byte(v >> (8 * uint(i)))
+	}
+	return append(b, le[:]...)
+}
+
+// ConsumeFixed32 decodes the first four bytes of b as a little-endian uint32
+// and returns the value together with the number of bytes read (4).
+// If b is shorter than four bytes the length is a negative error code
+// (see ParseError).
+func ConsumeFixed32(b []byte) (v uint32, n int) {
+	const size = 4
+	if len(b) < size {
+		return 0, errCodeTruncated
+	}
+	// Fold from the most-significant byte down so b[0] ends up lowest.
+	for i := size - 1; i >= 0; i-- {
+		v = v<<8 | uint32(b[i])
+	}
+	return v, size
+}
+
+// SizeFixed32 reports the encoded size of a fixed32 value, which is 4 bytes
+// regardless of the value.
+func SizeFixed32() int {
+	const size = 4
+	return size
+}
+
+// AppendFixed64 appends the eight bytes of v to b, least-significant byte
+// first, and returns the extended slice.
+func AppendFixed64(b []byte, v uint64) []byte {
+	var le [8]byte
+	for i := range le {
+		le[i] = byte(v >> (8 * uint(i)))
+	}
+	return append(b, le[:]...)
+}
+
+// ConsumeFixed64 decodes the first eight bytes of b as a little-endian uint64
+// and returns the value together with the number of bytes read (8).
+// If b is shorter than eight bytes the length is a negative error code
+// (see ParseError).
+func ConsumeFixed64(b []byte) (v uint64, n int) {
+	const size = 8
+	if len(b) < size {
+		return 0, errCodeTruncated
+	}
+	// Fold from the most-significant byte down so b[0] ends up lowest.
+	for i := size - 1; i >= 0; i-- {
+		v = v<<8 | uint64(b[i])
+	}
+	return v, size
+}
+
+// SizeFixed64 reports the encoded size of a fixed64 value, which is 8 bytes
+// regardless of the value.
+func SizeFixed64() int {
+	const size = 8
+	return size
+}
+
+// AppendBytes appends v to b in length-prefixed form: first len(v) as a
+// varint, then the bytes of v themselves.
+func AppendBytes(b []byte, v []byte) []byte {
+	b = AppendVarint(b, uint64(len(v)))
+	return append(b, v...)
+}
+
+// ConsumeBytes reads a length-prefixed bytes value from the front of b.
+// It returns the payload (a sub-slice of b, not a copy) and the total number
+// of bytes consumed, i.e. the varint prefix plus the payload.
+// A negative length is an error code (see ParseError); a prefix that claims
+// more bytes than remain is reported as truncation.
+func ConsumeBytes(b []byte) (v []byte, n int) {
+	size, prefixLen := ConsumeVarint(b)
+	if prefixLen < 0 {
+		return nil, prefixLen // propagate the varint error code
+	}
+	payload := b[prefixLen:]
+	if size > uint64(len(payload)) {
+		return nil, errCodeTruncated
+	}
+	return payload[:size], prefixLen + int(size)
+}
+
+// SizeBytes reports the encoded size of a length-prefixed bytes value whose
+// payload is n bytes long: the varint prefix plus the payload itself.
+func SizeBytes(n int) int {
+	prefix := SizeVarint(uint64(n))
+	return prefix + n
+}
+
+// AppendGroup appends the group contents v to b, followed by the end group
+// marker for field number num. The caller must pass v without an end marker.
+func AppendGroup(b []byte, num Number, v []byte) []byte {
+	b = append(b, v...)
+	return AppendTag(b, num, EndGroupType)
+}
+
+// ConsumeGroup reads a group value from the front of b, up to and including
+// the end group marker, and checks that the marker carries field number num.
+// The returned slice v excludes the end marker, whereas the returned length
+// includes it. A negative length is an error code (see ParseError).
+func ConsumeGroup(num Number, b []byte) (v []byte, n int) {
+	n = ConsumeFieldValue(num, StartGroupType, b)
+	if n < 0 {
+		return nil, n // propagate the error code
+	}
+
+	// Walk backwards over the end marker. A denormalized varint ends in
+	// bytes whose low 7 bits are all zero; since EndGroupType is non-zero the
+	// marker itself is never zero, so those padding bytes can be dropped
+	// first and the normalized tag size removed afterwards.
+	end := n
+	for end > 0 && b[end-1]&0x7f == 0 {
+		end--
+	}
+	end -= SizeTag(num)
+	return b[:end], n
+}
+
+// SizeGroup reports the encoded size of a group whose contents are n bytes
+// long: the contents plus the end group tag for field number num.
+func SizeGroup(num Number, n int) int {
+	return SizeTag(num) + n
+}
+
+// DecodeTag splits a unified tag value into its field Number (all bits above
+// the low three) and wire Type (the low three bits).
+// A field number larger than math.MaxInt32 is reported as -1 with a zero
+// Type; no other validity check is performed on the number.
+func DecodeTag(x uint64) (Number, Type) {
+	// NOTE: MessageSet allows for larger field numbers than normal.
+	fieldBits := x >> 3
+	if fieldBits > uint64(math.MaxInt32) {
+		return -1, 0
+	}
+	return Number(fieldBits), Type(x & 7)
+}
+
+// EncodeTag packs a field Number and wire Type into a unified tag value:
+// the number shifted left by three bits, with the low three bits of the
+// type below it.
+func EncodeTag(num Number, typ Type) uint64 {
+	wireBits := uint64(typ & 7)
+	return uint64(num)<<3 | wireBits
+}
+
+// DecodeZigZag converts a zig-zag-encoded uint64 back into an int64.
+//	Input:  {…,  5,  3,  1,  0,  2,  4,  6, …}
+//	Output: {…, -3, -2, -1,  0, +1, +2, +3, …}
+func DecodeZigZag(x uint64) int64 {
+	// The low bit carries the sign: negating it gives an all-ones mask for
+	// odd inputs and zero for even inputs.
+	signMask := -int64(x & 1)
+	return int64(x>>1) ^ signMask
+}
+
+// EncodeZigZag converts an int64 into its zig-zag-encoded uint64 form.
+//	Input:  {…, -3, -2, -1,  0, +1, +2, +3, …}
+//	Output: {…,  5,  3,  1,  0,  2,  4,  6, …}
+func EncodeZigZag(x int64) uint64 {
+	encoded := uint64(x) << 1
+	if x < 0 {
+		// Negative inputs map to odd outputs: flip every bit.
+		encoded = ^encoded
+	}
+	return encoded
+}
+
+// DecodeBool interprets a uint64 as a bool: zero is false, anything else is true.
+//	Input:  {    0,    1,    2, …}
+//	Output: {false, true, true, …}
+func DecodeBool(x uint64) bool {
+	return x > 0
+}
+
+// EncodeBool converts a bool into its uint64 wire value.
+//	Input:  {false, true}
+//	Output: {    0,    1}
+func EncodeBool(x bool) uint64 {
+	var out uint64
+	if x {
+		out = 1
+	}
+	return out
+}

+ 0 - 4
proto/clone.go

@@ -214,10 +214,6 @@ func mergeExtension(out, in *extensionMap) {
 			mergeAny(v, reflect.ValueOf(eIn.Value), false, nil)
 			eOut.Value = v.Interface()
 		}
-		if eIn.Raw != nil {
-			eOut.Raw = make([]byte, len(eIn.Raw))
-			copy(eOut.Raw, eIn.Raw)
-		}
 
 		out.Set(extNum, eOut)
 		return true

+ 3 - 40
proto/equal.go

@@ -223,12 +223,7 @@ func equalExtensions(base reflect.Type, em1, em2 *extensionMap) bool {
 		m2 := extensionAsLegacyType(e2.Value)
 
 		if m1 == nil && m2 == nil {
-			// Both have only encoded form.
-			if bytes.Equal(e1.Raw, e2.Raw) {
-				return true
-			}
-			// The bytes are different, but the extensions might still be
-			// equal. We need to decode them to compare.
+			return true
 		}
 
 		if m1 != nil && m2 != nil {
@@ -240,40 +235,8 @@ func equalExtensions(base reflect.Type, em1, em2 *extensionMap) bool {
 			return true
 		}
 
-		// At least one is encoded. To do a semantically correct comparison
-		// we need to unmarshal them first.
-		var desc *ExtensionDesc
-		mz := reflect.Zero(reflect.PtrTo(base)).Interface().(Message)
-		if m := RegisteredExtensions(mz); m != nil {
-			desc = m[int32(extNum)]
-		}
-		if desc == nil {
-			// If both have only encoded form and the bytes are the same,
-			// it is handled above. We get here when the bytes are different.
-			// We don't know how to decode it, so just compare them as byte
-			// slices.
-			log.Printf("proto: don't know how to compare extension %d of %v", extNum, base)
-			equal = false
-			return false
-		}
-		var err error
-		if m1 == nil {
-			m1, err = decodeExtension(e1.Raw, desc)
-		}
-		if m2 == nil && err == nil {
-			m2, err = decodeExtension(e2.Raw, desc)
-		}
-		if err != nil {
-			// The encoded form is invalid.
-			log.Printf("proto: badly encoded extension %d of %v: %v", extNum, base, err)
-			equal = false
-			return false
-		}
-		if !equalAny(reflect.ValueOf(m1), reflect.ValueOf(m2), nil) {
-			equal = false
-			return false
-		}
-		return true
+		equal = false
+		return false
 	})
 
 	return equal

+ 84 - 15
proto/extensions.go

@@ -15,6 +15,7 @@ import (
 	"reflect"
 	"sync"
 
+	"github.com/golang/protobuf/internal/wire"
 	"github.com/golang/protobuf/v2/reflect/protoreflect"
 	"github.com/golang/protobuf/v2/runtime/protoiface"
 	"github.com/golang/protobuf/v2/runtime/protoimpl"
@@ -97,11 +98,38 @@ func isRepeatedExtension(ed *ExtensionDesc) bool {
 
 // SetRawExtension is for testing only.
 func SetRawExtension(base Message, id int32, b []byte) {
-	epb, err := extendable(base)
-	if err != nil {
+	v := reflect.ValueOf(base)
+	if !v.IsValid() || v.Kind() != reflect.Ptr || v.IsNil() || v.Elem().Kind() != reflect.Struct {
+		return
+	}
+	v = v.Elem().FieldByName("XXX_unrecognized")
+	if !v.IsValid() {
 		return
 	}
-	epb.Set(protoreflect.FieldNumber(id), Extension{Raw: b})
+
+	// Verify that the raw field is valid.
+	for b0 := b; len(b0) > 0; {
+		fieldNum, _, n := wire.ConsumeField(b0)
+		if int32(fieldNum) != id {
+			panic(fmt.Sprintf("mismatching field number: got %d, want %d", fieldNum, id))
+		}
+		b0 = b0[n:]
+	}
+
+	fnum := protoreflect.FieldNumber(id)
+	v.SetBytes(append(removeRawFields(v.Bytes(), fnum), b...))
+}
+
+// removeRawFields returns b with every field record whose number equals fnum
+// removed. The filtering is done in place: the result shares b's backing
+// array and overwrites its contents.
+//
+// If the remaining bytes cannot be parsed as a field record, they are kept
+// verbatim and scanning stops. wire.ConsumeField reports such failures as a
+// negative length, which must not be used as a slice index.
+func removeRawFields(b []byte, fnum protoreflect.FieldNumber) []byte {
+	out := b[:0]
+	for len(b) > 0 {
+		got, _, n := wire.ConsumeField(b)
+		if n < 0 {
+			// Malformed tail: preserve it untouched rather than panic.
+			return append(out, b...)
+		}
+		if got != fnum {
+			out = append(out, b[:n]...)
+		}
+		b = b[n:]
+	}
+	return out
+}
 
 // isExtensionField returns true iff the given field number is in an extension range.
@@ -172,13 +200,24 @@ func extensionProperties(pb Message, ed *ExtensionDesc) *Properties {
 func HasExtension(pb Message, extension *ExtensionDesc) bool {
 	// TODO: Check types, field numbers, etc.?
 	epb, err := extendable(pb)
-	if err != nil {
+	if err != nil || epb == nil {
 		return false
 	}
-	if epb == nil {
-		return false
+	if epb.Has(protoreflect.FieldNumber(extension.Field)) {
+		return true
 	}
-	return epb.Has(protoreflect.FieldNumber(extension.Field))
+
+	// Check whether this field exists in raw form.
+	unrecognized := reflect.ValueOf(pb).Elem().FieldByName("XXX_unrecognized")
+	fnum := protoreflect.FieldNumber(extension.Field)
+	for b := unrecognized.Bytes(); len(b) > 0; {
+		got, _, n := wire.ConsumeField(b)
+		if got == fnum {
+			return true
+		}
+		b = b[n:]
+	}
+	return false
 }
 
 // ClearExtension removes the given extension from pb.
@@ -211,16 +250,24 @@ func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) {
 		return nil, err
 	}
 
-	if epb == nil {
-		return defaultExtensionValue(pb, extension)
+	unrecognized := reflect.ValueOf(pb).Elem().FieldByName("XXX_unrecognized")
+	var out []byte
+	fnum := protoreflect.FieldNumber(extension.Field)
+	for b := unrecognized.Bytes(); len(b) > 0; {
+		got, _, n := wire.ConsumeField(b)
+		if got == fnum {
+			out = append(out, b[:n]...)
+		}
+		b = b[n:]
 	}
-	if !epb.Has(protoreflect.FieldNumber(extension.Field)) {
+
+	if !epb.Has(protoreflect.FieldNumber(extension.Field)) && len(out) == 0 {
 		// defaultExtensionValue returns the default value or
 		// ErrMissingExtension if there is no default.
 		return defaultExtensionValue(pb, extension)
 	}
-	e := epb.Get(protoreflect.FieldNumber(extension.Field))
 
+	e := epb.Get(protoreflect.FieldNumber(extension.Field))
 	if e.Value != nil {
 		// Already decoded. Check the descriptor, though.
 		if e.Desc != extension {
@@ -232,12 +279,13 @@ func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) {
 		return extensionAsLegacyType(e.Value), nil
 	}
 
+	// Descriptor without type information.
 	if extension.ExtensionType == nil {
-		// incomplete descriptor
-		return e.Raw, nil
+		return out, nil
 	}
 
-	v, err := decodeExtension(e.Raw, extension)
+	// TODO: Remove this logic for automatically unmarshaling the unknown fields.
+	v, err := decodeExtension(out, extension)
 	if err != nil {
 		return nil, err
 	}
@@ -246,7 +294,7 @@ func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) {
 	// That way it is safe to mutate what we return.
 	e.Value = extensionAsStorageType(v)
 	e.Desc = extension
-	e.Raw = nil
+	unrecognized.SetBytes(removeRawFields(unrecognized.Bytes(), fnum))
 	epb.Set(protoreflect.FieldNumber(extension.Field), e)
 	return extensionAsLegacyType(e.Value), nil
 }
@@ -367,6 +415,27 @@ func ExtensionDescs(pb Message) ([]*ExtensionDesc, error) {
 		extensions = append(extensions, desc)
 		return true
 	})
+
+	unrecognized := reflect.ValueOf(pb).Elem().FieldByName("XXX_unrecognized")
+	if b := unrecognized.Bytes(); len(b) > 0 {
+		fieldNums := make(map[int32]bool)
+		for len(b) > 0 {
+			fnum, _, n := wire.ConsumeField(b)
+			if isExtensionField(pb, int32(fnum)) {
+				fieldNums[int32(fnum)] = true
+			}
+			b = b[n:]
+		}
+
+		for id := range fieldNums {
+			desc := registeredExtensions[id]
+			if desc == nil {
+				desc = &ExtensionDesc{Field: id}
+			}
+			extensions = append(extensions, desc)
+		}
+	}
+
 	return extensions, nil
 }
 

+ 0 - 12
proto/extensions_test.go

@@ -7,7 +7,6 @@ package proto_test
 import (
 	"bytes"
 	"fmt"
-	"io"
 	"reflect"
 	"sort"
 	"strings"
@@ -39,17 +38,6 @@ func TestGetExtensionsWithMissingExtensions(t *testing.T) {
 	}
 }
 
-func TestGetExtensionWithEmptyBuffer(t *testing.T) {
-	// Make sure that GetExtension returns an error if its
-	// undecoded buffer is empty.
-	msg := &pb.MyMessage{}
-	proto.SetRawExtension(msg, pb.E_Ext_More.Field, []byte{})
-	_, err := proto.GetExtension(msg, pb.E_Ext_More)
-	if want := io.ErrUnexpectedEOF; err != want {
-		t.Errorf("unexpected error in GetExtension from empty buffer: got %v, want %v", err, want)
-	}
-}
-
 func TestGetExtensionForIncompleteDesc(t *testing.T) {
 	msg := &pb.MyMessage{Count: proto.Int32(0)}
 	extdesc1 := &proto.ExtensionDesc{

+ 7 - 17
proto/message_set.go

@@ -10,6 +10,7 @@ package proto
 
 import (
 	"errors"
+	"reflect"
 
 	"github.com/golang/protobuf/v2/reflect/protoreflect"
 )
@@ -116,35 +117,24 @@ func skipVarint(buf []byte) []byte {
 
 // unmarshalMessageSet decodes the extension map encoded in buf in the message set wire format.
 // It is called by Unmarshal methods on protocol buffer messages with the message_set_wire_format option.
-func unmarshalMessageSet(buf []byte, exts interface{}) error {
-	m := extensionFieldsOf(exts)
-
+func unmarshalMessageSet(buf []byte, mi Message, exts interface{}) error {
 	ms := new(messageSet)
 	if err := Unmarshal(buf, ms); err != nil {
 		return err
 	}
+	unrecognized := reflect.ValueOf(mi).Elem().FieldByName("XXX_unrecognized").Addr().Interface().(*[]byte)
+
 	for _, item := range ms.Item {
 		id := protoreflect.FieldNumber(*item.TypeId)
 		msg := item.Message
 
 		// Restore wire type and field number varint, plus length varint.
-		// Be careful to preserve duplicate items.
 		b := EncodeVarint(uint64(id)<<3 | WireBytes)
-		if m.Has(id) {
-			ext := m.Get(id)
-
-			// Existing data; rip off the tag and length varint
-			// so we join the new data correctly.
-			// We can assume that ext.Raw is set because we are unmarshaling.
-			o := ext.Raw[len(b):]   // skip wire type and field number
-			_, n := DecodeVarint(o) // calculate length of length varint
-			o = o[n:]               // skip length varint
-			msg = append(o, msg...) // join old data and new data
-		}
 		b = append(b, EncodeVarint(uint64(len(msg)))...)
 		b = append(b, msg...)
 
-		m.Set(id, Extension{Raw: b})
+		*unrecognized = append(*unrecognized, b...)
 	}
-	return nil
+
+	return unmarshalExtensions(mi, unrecognized)
 }

+ 8 - 2
proto/message_set_test.go

@@ -38,13 +38,19 @@ func TestUnmarshalMessageSetWithDuplicate(t *testing.T) {
 			Tag{1, StartGroup},
 			Message{
 				Tag{2, Varint}, Uvarint(12345),
-				Tag{3, Bytes}, Bytes("hoohah"),
+				Tag{3, Bytes}, Bytes("hoo"),
+			},
+			Tag{1, EndGroup},
+			Tag{1, StartGroup},
+			Message{
+				Tag{2, Varint}, Uvarint(12345),
+				Tag{3, Bytes}, Bytes("hah"),
 			},
 			Tag{1, EndGroup},
 		}
 	*/
 	var want []byte
-	fmt.Sscanf("0b10b9601a06686f6f6861680c", "%x", &want)
+	fmt.Sscanf("0b10b9601a03686f6f0c0b10b9601a036861680c", "%x", &want)
 
 	var m MyMessageSet
 	if err := proto.Unmarshal(in, &m); err != nil {

+ 70 - 38
proto/table_marshal.go

@@ -16,6 +16,7 @@ import (
 	"sync/atomic"
 	"unicode/utf8"
 
+	"github.com/golang/protobuf/internal/wire"
 	"github.com/golang/protobuf/v2/reflect/protoreflect"
 )
 
@@ -160,7 +161,7 @@ func (u *marshalInfo) size(ptr pointer) int {
 	if u.extensions.IsValid() {
 		e := ptr.offset(u.extensions).toExtensions()
 		if u.messageset {
-			n += u.sizeMessageSet(e)
+			n += u.sizeMessageSet(e, *ptr.offset(u.unrecognized).toBytes())
 		} else {
 			n += u.sizeExtensions(e)
 		}
@@ -169,7 +170,7 @@ func (u *marshalInfo) size(ptr pointer) int {
 		m := *ptr.offset(u.v1extensions).toOldExtensions()
 		n += u.sizeV1Extensions(m)
 	}
-	if u.unrecognized.IsValid() {
+	if u.unrecognized.IsValid() && !u.messageset {
 		s := *ptr.offset(u.unrecognized).toBytes()
 		n += len(s)
 	}
@@ -212,7 +213,7 @@ func (u *marshalInfo) marshal(b []byte, ptr pointer, deterministic bool) ([]byte
 	if u.extensions.IsValid() {
 		e := ptr.offset(u.extensions).toExtensions()
 		if u.messageset {
-			b, err = u.appendMessageSet(b, e, deterministic)
+			b, err = u.appendMessageSet(b, e, *ptr.offset(u.unrecognized).toBytes(), deterministic)
 		} else {
 			b, err = u.appendExtensions(b, e, deterministic)
 		}
@@ -266,7 +267,7 @@ func (u *marshalInfo) marshal(b []byte, ptr pointer, deterministic bool) ([]byte
 			return b, err
 		}
 	}
-	if u.unrecognized.IsValid() {
+	if u.unrecognized.IsValid() && !u.messageset {
 		s := *ptr.offset(u.unrecognized).toBytes()
 		b = append(b, s...)
 	}
@@ -2373,9 +2374,7 @@ func (u *marshalInfo) sizeExtensions(ext *XXX_InternalExtensions) int {
 	n := 0
 	m.Range(func(_ protoreflect.FieldNumber, e Extension) bool {
 		if e.Value == nil || e.Desc == nil {
-			// Extension is only in its encoded form.
-			n += len(e.Raw)
-			return true
+			return true // should never happen
 		}
 
 		// We don't skip extensions that have an encoded form set,
@@ -2405,9 +2404,7 @@ func (u *marshalInfo) appendExtensions(b []byte, ext *XXX_InternalExtensions, de
 	if m.Len() <= 1 {
 		m.Range(func(_ protoreflect.FieldNumber, e Extension) bool {
 			if e.Value == nil || e.Desc == nil {
-				// Extension is only in its encoded form.
-				b = append(b, e.Raw...)
-				return true
+				return true // should never happen
 			}
 
 			// We don't skip extensions that have an encoded form set,
@@ -2439,9 +2436,7 @@ func (u *marshalInfo) appendExtensions(b []byte, ext *XXX_InternalExtensions, de
 	for _, k := range keys {
 		e := m.Get(protoreflect.FieldNumber(k))
 		if e.Value == nil || e.Desc == nil {
-			// Extension is only in its encoded form.
-			b = append(b, e.Raw...)
-			continue
+			continue // should never happen
 		}
 
 		// We don't skip extensions that have an encoded form set,
@@ -2469,7 +2464,7 @@ func (u *marshalInfo) appendExtensions(b []byte, ext *XXX_InternalExtensions, de
 
 // sizeMessageSet computes the size of encoded data for a XXX_InternalExtensions field
 // in message set format (above).
-func (u *marshalInfo) sizeMessageSet(ext *XXX_InternalExtensions) int {
+func (u *marshalInfo) sizeMessageSet(ext *XXX_InternalExtensions, unk []byte) int {
 	m := extensionFieldsOf(ext)
 	if m == nil {
 		return 0
@@ -2481,11 +2476,7 @@ func (u *marshalInfo) sizeMessageSet(ext *XXX_InternalExtensions) int {
 		n += SizeVarint(uint64(id)) + 1 // type_id, tag = 2 (size=1)
 
 		if e.Value == nil || e.Desc == nil {
-			// Extension is only in its encoded form.
-			msgWithLen := skipVarint(e.Raw) // skip old tag, but leave the length varint
-			siz := len(msgWithLen)
-			n += siz + 1 // message, tag = 3 (size=1)
-			return true
+			return true // should never happen
 		}
 
 		// We don't skip extensions that have an encoded form set,
@@ -2498,12 +2489,29 @@ func (u *marshalInfo) sizeMessageSet(ext *XXX_InternalExtensions) int {
 		n += ei.sizer(p, 1) // message, tag = 3 (size=1)
 		return true
 	})
+
+	// Extension is only in its encoded form.
+	for len(unk) > 0 {
+		id, _, fieldLen := wire.ConsumeField(unk)
+		if fieldLen < 0 {
+			break
+		}
+
+		msgWithLen := skipVarint(unk[:fieldLen]) // skip old tag, but leave the length varint
+		siz := len(msgWithLen)
+		n += 2                          // start group, end group. tag = 1 (size=1)
+		n += SizeVarint(uint64(id)) + 1 // type_id, tag = 2 (size=1)
+		n += siz + 1                    // message, tag = 3 (size=1)
+
+		unk = unk[fieldLen:]
+	}
+
 	return n
 }
 
 // appendMessageSet marshals a XXX_InternalExtensions field in message set format (above)
 // to the end of byte slice b.
-func (u *marshalInfo) appendMessageSet(b []byte, ext *XXX_InternalExtensions, deterministic bool) ([]byte, error) {
+func (u *marshalInfo) appendMessageSet(b []byte, ext *XXX_InternalExtensions, unk []byte, deterministic bool) ([]byte, error) {
 	m := extensionFieldsOf(ext)
 	if m == nil {
 		return b, nil
@@ -2521,12 +2529,7 @@ func (u *marshalInfo) appendMessageSet(b []byte, ext *XXX_InternalExtensions, de
 			b = appendVarint(b, uint64(id))
 
 			if e.Value == nil || e.Desc == nil {
-				// Extension is only in its encoded form.
-				msgWithLen := skipVarint(e.Raw) // skip old tag, but leave the length varint
-				b = append(b, 3<<3|WireBytes)
-				b = append(b, msgWithLen...)
-				b = append(b, 1<<3|WireEndGroup)
-				return true
+				return true // should never happen
 			}
 
 			// We don't skip extensions that have an encoded form set,
@@ -2544,6 +2547,25 @@ func (u *marshalInfo) appendMessageSet(b []byte, ext *XXX_InternalExtensions, de
 			err = nerr.E
 			return true
 		})
+
+		// Extension is only in its encoded form.
+		for len(unk) > 0 {
+			id, _, fieldLen := wire.ConsumeField(unk)
+			if fieldLen < 0 {
+				return b, wire.ParseError(fieldLen)
+			}
+
+			msgWithLen := skipVarint(unk[:fieldLen]) // skip old tag, but leave the length varint
+			b = append(b, 1<<3|WireStartGroup)
+			b = append(b, 2<<3|WireVarint)
+			b = appendVarint(b, uint64(id))
+			b = append(b, 3<<3|WireBytes)
+			b = append(b, msgWithLen...)
+			b = append(b, 1<<3|WireEndGroup)
+
+			unk = unk[fieldLen:]
+		}
+
 		return b, err
 	}
 
@@ -2562,12 +2584,7 @@ func (u *marshalInfo) appendMessageSet(b []byte, ext *XXX_InternalExtensions, de
 		b = appendVarint(b, uint64(id))
 
 		if e.Value == nil || e.Desc == nil {
-			// Extension is only in its encoded form.
-			msgWithLen := skipVarint(e.Raw) // skip old tag, but leave the length varint
-			b = append(b, 3<<3|WireBytes)
-			b = append(b, msgWithLen...)
-			b = append(b, 1<<3|WireEndGroup)
-			continue
+			continue // should never happen
 		}
 
 		// We don't skip extensions that have an encoded form set,
@@ -2583,6 +2600,25 @@ func (u *marshalInfo) appendMessageSet(b []byte, ext *XXX_InternalExtensions, de
 			return b, err
 		}
 	}
+
+	// Extension is only in its encoded form.
+	for len(unk) > 0 {
+		id, _, fieldLen := wire.ConsumeField(unk)
+		if fieldLen < 0 {
+			return b, wire.ParseError(fieldLen)
+		}
+
+		msgWithLen := skipVarint(unk[:fieldLen]) // skip old tag, but leave the length varint
+		b = append(b, 1<<3|WireStartGroup)
+		b = append(b, 2<<3|WireVarint)
+		b = appendVarint(b, uint64(id))
+		b = append(b, 3<<3|WireBytes)
+		b = append(b, msgWithLen...)
+		b = append(b, 1<<3|WireEndGroup)
+
+		unk = unk[fieldLen:]
+	}
+
 	return b, nerr.E
 }
 
@@ -2595,9 +2631,7 @@ func (u *marshalInfo) sizeV1Extensions(m map[int32]Extension) int {
 	n := 0
 	for _, e := range m {
 		if e.Value == nil || e.Desc == nil {
-			// Extension is only in its encoded form.
-			n += len(e.Raw)
-			continue
+			continue // should never happen
 		}
 
 		// We don't skip extensions that have an encoded form set,
@@ -2630,9 +2664,7 @@ func (u *marshalInfo) appendV1Extensions(b []byte, m map[int32]Extension, determ
 	for _, k := range keys {
 		e := m[int32(k)]
 		if e.Value == nil || e.Desc == nil {
-			// Extension is only in its encoded form.
-			b = append(b, e.Raw...)
-			continue
+			continue // should never happen
 		}
 
 		// We don't skip extensions that have an encoded form set,

+ 67 - 45
proto/table_unmarshal.go

@@ -16,7 +16,7 @@ import (
 	"sync/atomic"
 	"unicode/utf8"
 
-	"github.com/golang/protobuf/v2/reflect/protoreflect"
+	"github.com/golang/protobuf/internal/wire"
 )
 
 // Unmarshal is the entry point from the generated .pb.go files.
@@ -111,7 +111,7 @@ func (u *unmarshalInfo) unmarshal(m pointer, b []byte) error {
 		u.computeUnmarshalInfo()
 	}
 	if u.isMessageSet {
-		return unmarshalMessageSet(b, m.offset(u.extensions).toExtensions())
+		return unmarshalMessageSet(b, m.asPointerTo(u.typ).Interface().(Message), m.offset(u.extensions).toExtensions())
 	}
 	var reqMask uint64 // bitmask of required fields we've seen.
 	var errLater error
@@ -187,26 +187,9 @@ func (u *unmarshalInfo) unmarshal(m pointer, b []byte) error {
 		// Keep unrecognized data around.
 		// maybe in extensions, maybe in the unrecognized field.
 		z := m.offset(u.unrecognized).toBytes()
-		var emap *extensionMap
-		var e Extension
 		for _, r := range u.extensionRanges {
 			if uint64(r.Start) <= tag && tag <= uint64(r.End) {
 				hasExtensions = true
-				if u.extensions.IsValid() {
-					mp := m.offset(u.extensions).toExtensions()
-					emap = extensionFieldsOf(mp)
-					e = emap.Get(protoreflect.FieldNumber(tag))
-					z = &e.Raw
-					break
-				}
-				if u.oldExtensions.IsValid() {
-					p := m.offset(u.oldExtensions).toOldExtensions()
-					emap = extensionFieldsOf(p)
-					e = emap.Get(protoreflect.FieldNumber(tag))
-					z = &e.Raw
-					break
-				}
-				panic("no extensions field available")
 			}
 		}
 
@@ -219,37 +202,15 @@ func (u *unmarshalInfo) unmarshal(m pointer, b []byte) error {
 		}
 		*z = encodeVarint(*z, tag<<3|uint64(wire))
 		*z = append(*z, b0[:len(b0)-len(b)]...)
-
-		if emap != nil {
-			emap.Set(protoreflect.FieldNumber(tag), e)
-		}
 	}
 
 	// If there were unknown extensions, eagerly unmarshal them.
 	if hasExtensions {
+		var nerr nonFatal
 		mi := m.asPointerTo(u.typ).Interface().(Message)
-		ep, _ := extendable(mi)
-		if ep != nil {
-			var errFatal error
-			emap := RegisteredExtensions(mi) // map[int32]*ExtensionDesc
-			ep.Range(func(id protoreflect.FieldNumber, ef Extension) bool {
-				ed := emap[int32(id)]
-				if ed != nil {
-					_, err := GetExtension(mi, ed)
-					var nerr nonFatal
-					if !nerr.Merge(err) {
-						errFatal = err
-						return false
-					}
-					if errLater == nil {
-						errLater = nerr.E
-					}
-				}
-				return true
-			})
-			if errFatal != nil {
-				return errFatal
-			}
+		unrecognized := m.offset(u.unrecognized).toBytes()
+		if err := unmarshalExtensions(mi, unrecognized); !nerr.Merge(err) {
+			return err
 		}
 	}
 
@@ -265,6 +226,67 @@ func (u *unmarshalInfo) unmarshal(m pointer, b []byte) error {
 	return errLater
 }
 
+// unmarshalExtensions walks the raw wire-format bytes in *unrecognized and
+// eagerly unmarshals every field whose number has a registered extension
+// descriptor (with a non-nil ExtensionType) for mi, storing the decoded value
+// in mi's extension fields. All other fields are kept, in their original
+// order, and written back to *unrecognized.
+//
+// It returns nil immediately when extendable(mi) yields no extension fields.
+// It returns an error when a tag or field value cannot be parsed, when
+// checkExtensionTypeAndRanges rejects the descriptor, or when the
+// type-specific unmarshaler fails with an error that nonFatal.Merge does not
+// accept.
+//
+// NOTE(review): errors that nerr.Merge does accept are recorded in a
+// loop-local nonFatal and never returned (the function ends with
+// "return nil") -- confirm that dropping them is intended.
+//
+// NOTE(review): on any early error return *unrecognized is not reassigned,
+// but its backing array may already have been partially overwritten by the
+// in-place compaction below -- confirm callers discard the message on error.
+func unmarshalExtensions(mi Message, unrecognized *[]byte) error {
+	extFields, _ := extendable(mi)
+	if extFields == nil {
+		return nil
+	}
+
+	emap := RegisteredExtensions(mi) // map[int32]*ExtensionDesc
+	oldUnknownFields := *unrecognized
+	// newUnknownFields shares the backing array of the input. Only bytes that
+	// were already consumed from oldUnknownFields are ever appended to it, so
+	// the write position never passes the read position.
+	newUnknownFields := oldUnknownFields[:0]
+
+	for len(oldUnknownFields) > 0 {
+		fieldNum, wireTyp, tagLen := wire.ConsumeTag(oldUnknownFields)
+		if tagLen < 0 {
+			return wire.ParseError(tagLen)
+		}
+		extDesc, ok := emap[int32(fieldNum)]
+		if !ok || extDesc.ExtensionType == nil {
+			// No usable extension descriptor for this field number:
+			// measure the value and retain tag+value verbatim as unknown.
+			valLen := wire.ConsumeFieldValue(fieldNum, wireTyp, oldUnknownFields[tagLen:])
+			if valLen < 0 {
+				return wire.ParseError(valLen)
+			}
+
+			newUnknownFields = append(newUnknownFields, oldUnknownFields[:tagLen+valLen]...)
+			oldUnknownFields = oldUnknownFields[tagLen+valLen:]
+			continue
+		}
+		// Strip the tag; the type-specific unmarshaler below is handed the
+		// bytes starting at the value and returns whatever remains after it.
+		oldUnknownFields = oldUnknownFields[tagLen:]
+
+		if err := checkExtensionTypeAndRanges(mi, extDesc); err != nil {
+			return err
+		}
+
+		// Create a new value or reuse an existing one.
+		// Seeding fieldVal with the stored value means the unmarshaler
+		// operates on the prior contents instead of a zero value.
+		fieldType := reflect.TypeOf(extDesc.ExtensionType)
+		fieldVal := reflect.New(fieldType).Elem() // E.g., *int32, *Message, []T
+		if extField := extFields.Get(fieldNum); extField.Value != nil {
+			fieldVal.Set(reflect.ValueOf(extensionAsLegacyType(extField.Value)))
+		}
+
+		// Unmarshal the value.
+		// nerr is scoped to this iteration; an error it accepts lets the loop
+		// continue and is not reported to the caller.
+		var err error
+		var nerr nonFatal
+		unmarshal := typeUnmarshaler(fieldType, extDesc.Tag)
+		if oldUnknownFields, err = unmarshal(oldUnknownFields, valToPointer(fieldVal.Addr()), int(wireTyp)); !nerr.Merge(err) {
+			return err
+		}
+
+		// Store the value into the extension field.
+		extFields.Set(fieldNum, Extension{
+			Desc:  extDesc,
+			Value: extensionAsStorageType(fieldVal.Interface()),
+		})
+	}
+
+	// Normalize an empty (but non-nil) result to nil before writing it back.
+	if len(newUnknownFields) == 0 {
+		newUnknownFields = nil // NOTE: code actually depends on this...
+	}
+	*unrecognized = newUnknownFields
+	return nil
+}
+
 // computeUnmarshalInfo fills in u with information for use
 // in unmarshaling protocol buffers of type u.typ.
 func (u *unmarshalInfo) computeUnmarshalInfo() {

+ 0 - 5
proto/text.go

@@ -669,16 +669,11 @@ func (tm *textMarshaler) writeExtensions(w *textWriter, pv reflect.Value) error
 	sort.Sort(fieldNumSlice(ids))
 
 	for _, extNum := range ids {
-		ext := ep.Get(extNum)
 		var desc *ExtensionDesc
 		if emap != nil {
 			desc = emap[int32(extNum)]
 		}
 		if desc == nil {
-			// Unknown extension.
-			if err := writeUnknownStruct(w, ext.Raw); err != nil {
-				return err
-			}
 			continue
 		}