Pārlūkot izejas kodu

reflect/protoreflect: add alternative message reflection API

Added API:
	Message.Len
	Message.Range
	Message.Has
	Message.Clear
	Message.Get
	Message.Set
	Message.Mutable
	Message.NewMessage
	Message.WhichOneof
	Message.GetUnknown
	Message.SetUnknown

Deprecated API (to be removed in subsequent CL):
	Message.KnownFields
	Message.UnknownFields

The primary difference with the new API is that the top-level
Message methods are keyed by FieldDescriptor rather than FieldNumber
with the following semantics:
* For known fields, the FieldDescriptor must exactly match the
field descriptor known by the message.
* For extension fields, the FieldDescriptor must implement ExtensionType,
where ContainingMessage.FullName matches the message name, and
the field number is within the message's extension range.
When setting an extension field, it automatically stores
the extension type information.
* Extension fields are always considered nullable,
implying that repeated extension fields are nullable.
That is, you can distinguish between a unpopulated list and an empty list.
* Message.Get always returns a valid Value even if unpopulated.
The behavior is already well-defined for scalars, but for unpopulated
composite types, it now returns an empty read-only version of it.

Change-Id: Ia120630b4db221aeaaf743d0f64160e1a61a0f61
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/175458
Reviewed-by: Damien Neil <dneil@google.com>
Joe Tsai 6 gadi atpakaļ
vecāks
revīzija
378c1329de
39 mainītis faili ar 1803 papildinājumiem un 1692 dzēšanām
  1. 8 11
      encoding/bench_test.go
  2. 16 48
      encoding/protojson/decode.go
  3. 0 13
      encoding/protojson/decode_test.go
  4. 24 28
      encoding/protojson/encode.go
  5. 2 17
      encoding/protojson/encode_test.go
  6. 78 74
      encoding/protojson/well_known_types.go
  7. 26 56
      encoding/prototext/decode.go
  8. 29 33
      encoding/prototext/encode.go
  9. 4 26
      encoding/prototext/encode_test.go
  10. 16 16
      internal/cmd/generate-types/proto.go
  11. 4 7
      internal/fileinit/fileinit_test.go
  12. 80 249
      internal/impl/legacy_test.go
  13. 222 127
      internal/impl/message.go
  14. 148 0
      internal/impl/message_deprecated.go
  15. 171 45
      internal/impl/message_field.go
  16. 7 3
      internal/impl/message_field_extension.go
  17. 5 1
      internal/impl/message_field_unknown.go
  18. 207 220
      internal/impl/message_test.go
  19. 33 55
      internal/testprotos/irregular/irregular.go
  20. 95 0
      internal/testprotos/irregular/irregular_deprecated.go
  21. 10 8
      internal/value/list.go
  22. 11 9
      internal/value/map.go
  23. 28 39
      proto/decode.go
  24. 10 10
      proto/decode_gen.go
  25. 6 14
      proto/decode_test.go
  26. 26 41
      proto/encode.go
  27. 6 6
      proto/encode_gen.go
  28. 3 15
      proto/encode_test.go
  29. 72 96
      proto/equal.go
  30. 4 43
      proto/equal_test.go
  31. 18 31
      proto/isinit.go
  32. 24 0
      proto/reset.go
  33. 4 18
      proto/size.go
  34. 186 0
      reflect/protoreflect/deprecated.go
  35. 2 10
      reflect/protoreflect/type.go
  36. 125 194
      reflect/protoreflect/value.go
  37. 4 11
      reflect/protoreflect/value_union.go
  38. 5 8
      runtime/protoimpl/impl.go
  39. 84 110
      testing/prototest/prototest.go

+ 8 - 11
encoding/bench_test.go

@@ -44,31 +44,28 @@ func fillMessage(m pref.Message, level int) {
 		return
 	}
 
-	knownFields := m.KnownFields()
 	fieldDescs := m.Descriptor().Fields()
 	for i := 0; i < fieldDescs.Len(); i++ {
 		fd := fieldDescs.Get(i)
-		num := fd.Number()
 		switch {
 		case fd.IsList():
-			setList(knownFields.Get(num).List(), fd, level)
+			setList(m.Mutable(fd).List(), fd, level)
 		case fd.IsMap():
-			setMap(knownFields.Get(num).Map(), fd, level)
+			setMap(m.Mutable(fd).Map(), fd, level)
 		default:
-			setScalarField(knownFields, fd, level)
+			setScalarField(m, fd, level)
 		}
 	}
 }
 
-func setScalarField(knownFields pref.KnownFields, fd pref.FieldDescriptor, level int) {
-	num := fd.Number()
+func setScalarField(m pref.Message, fd pref.FieldDescriptor, level int) {
 	switch fd.Kind() {
 	case pref.MessageKind, pref.GroupKind:
-		m := knownFields.NewMessage(num)
-		fillMessage(m, level+1)
-		knownFields.Set(num, pref.ValueOf(m))
+		m2 := m.NewMessage(fd)
+		fillMessage(m2, level+1)
+		m.Set(fd, pref.ValueOf(m2))
 	default:
-		knownFields.Set(num, scalarField(fd.Kind()))
+		m.Set(fd, scalarField(fd.Kind()))
 	}
 }
 

+ 16 - 48
encoding/protojson/decode.go

@@ -52,11 +52,10 @@ type UnmarshalOptions struct {
 // setting the fields. If it returns an error, the given message may be
 // partially set.
 func (o UnmarshalOptions) Unmarshal(b []byte, m proto.Message) error {
-	mr := m.ProtoReflect()
 	// TODO: Determine if we would like to have an option for merging or only
-	// have merging behavior.  We should at least be consistent with textproto
+	// have merging behavior. We should at least be consistent with textproto
 	// marshaling.
-	resetMessage(mr)
+	proto.Reset(m)
 
 	if o.Resolver == nil {
 		o.Resolver = protoregistry.GlobalTypes
@@ -64,7 +63,7 @@ func (o UnmarshalOptions) Unmarshal(b []byte, m proto.Message) error {
 	o.decoder = json.NewDecoder(b)
 
 	var nerr errors.NonFatal
-	if err := o.unmarshalMessage(mr, false); !nerr.Merge(err) {
+	if err := o.unmarshalMessage(m.ProtoReflect(), false); !nerr.Merge(err) {
 		return err
 	}
 
@@ -83,25 +82,6 @@ func (o UnmarshalOptions) Unmarshal(b []byte, m proto.Message) error {
 	return nerr.E
 }
 
-// resetMessage clears all fields of given protoreflect.Message.
-func resetMessage(m pref.Message) {
-	knownFields := m.KnownFields()
-	knownFields.Range(func(num pref.FieldNumber, _ pref.Value) bool {
-		knownFields.Clear(num)
-		return true
-	})
-	unknownFields := m.UnknownFields()
-	unknownFields.Range(func(num pref.FieldNumber, _ pref.RawFields) bool {
-		unknownFields.Set(num, nil)
-		return true
-	})
-	extTypes := knownFields.ExtensionTypes()
-	extTypes.Range(func(xt pref.ExtensionType) bool {
-		extTypes.Remove(xt)
-		return true
-	})
-}
-
 // unexpectedJSONError is an error that contains the unexpected json.Value. This
 // is returned by methods to provide callers the read json.Value that it did not
 // expect.
@@ -164,9 +144,7 @@ func (o UnmarshalOptions) unmarshalFields(m pref.Message, skipTypeURL bool) erro
 	var seenOneofs set.Ints
 
 	messageDesc := m.Descriptor()
-	knownFields := m.KnownFields()
 	fieldDescs := messageDesc.Fields()
-	xtTypes := knownFields.ExtensionTypes()
 
 Loop:
 	for {
@@ -200,20 +178,12 @@ Loop:
 		var fd pref.FieldDescriptor
 		if strings.HasPrefix(name, "[") && strings.HasSuffix(name, "]") {
 			// Only extension names are in [name] format.
-			xtName := pref.FullName(name[1 : len(name)-1])
-			xt := xtTypes.ByName(xtName)
-			if xt == nil {
-				xt, err = o.findExtension(xtName)
-				if err != nil && err != protoregistry.NotFound {
-					return errors.New("unable to resolve [%v]: %v", xtName, err)
-				}
-				if xt != nil {
-					xtTypes.Register(xt)
-				}
-			}
-			if xt != nil {
-				fd = xt.Descriptor()
+			extName := pref.FullName(name[1 : len(name)-1])
+			extType, err := o.findExtension(extName)
+			if err != nil && err != protoregistry.NotFound {
+				return errors.New("unable to resolve [%v]: %v", extName, err)
 			}
+			fd = extType
 		} else {
 			// The name can either be the JSON name or the proto field name.
 			fd = fieldDescs.ByJSONName(name)
@@ -249,12 +219,12 @@ Loop:
 
 		switch {
 		case fd.IsList():
-			list := knownFields.Get(fd.Number()).List()
+			list := m.Mutable(fd).List()
 			if err := o.unmarshalList(list, fd); !nerr.Merge(err) {
 				return errors.New("%v|%q: %v", fd.FullName(), name, err)
 			}
 		case fd.IsMap():
-			mmap := knownFields.Get(fd.Number()).Map()
+			mmap := m.Mutable(fd).Map()
 			if err := o.unmarshalMap(mmap, fd); !nerr.Merge(err) {
 				return errors.New("%v|%q: %v", fd.FullName(), name, err)
 			}
@@ -269,7 +239,7 @@ Loop:
 			}
 
 			// Required or optional fields.
-			if err := o.unmarshalSingular(knownFields, fd); !nerr.Merge(err) {
+			if err := o.unmarshalSingular(m, fd); !nerr.Merge(err) {
 				return errors.New("%v|%q: %v", fd.FullName(), name, err)
 			}
 		}
@@ -305,16 +275,14 @@ func isNullValue(fd pref.FieldDescriptor) bool {
 
 // unmarshalSingular unmarshals to the non-repeated field specified by the given
 // FieldDescriptor.
-func (o UnmarshalOptions) unmarshalSingular(knownFields pref.KnownFields, fd pref.FieldDescriptor) error {
+func (o UnmarshalOptions) unmarshalSingular(m pref.Message, fd pref.FieldDescriptor) error {
 	var val pref.Value
 	var err error
-	num := fd.Number()
-
 	switch fd.Kind() {
 	case pref.MessageKind, pref.GroupKind:
-		m := knownFields.NewMessage(num)
-		err = o.unmarshalMessage(m, false)
-		val = pref.ValueOf(m)
+		m2 := m.NewMessage(fd)
+		err = o.unmarshalMessage(m2, false)
+		val = pref.ValueOf(m2)
 	default:
 		val, err = o.unmarshalScalar(fd)
 	}
@@ -323,7 +291,7 @@ func (o UnmarshalOptions) unmarshalSingular(knownFields pref.KnownFields, fd pre
 	if !nerr.Merge(err) {
 		return err
 	}
-	knownFields.Set(num, val)
+	m.Set(fd, val)
 	return nerr.E
 }
 

+ 0 - 13
encoding/protojson/decode_test.go

@@ -1370,19 +1370,6 @@ func TestUnmarshal(t *testing.T) {
 			})
 			return m
 		}(),
-	}, {
-		desc:         "extension field set to null",
-		inputMessage: &pb2.Extensions{},
-		inputText: `{
-  "[pb2.ExtensionsContainer.opt_ext_bool]": null,
-  "[pb2.ExtensionsContainer.opt_ext_nested]": null
-}`,
-		wantMessage: func() proto.Message {
-			m := &pb2.Extensions{}
-			setExtension(m, pb2.E_ExtensionsContainer_OptExtBool, nil)
-			setExtension(m, pb2.E_ExtensionsContainer_OptExtNested, nil)
-			return m
-		}(),
 	}, {
 		desc:         "extensions of repeated field contains null",
 		inputMessage: &pb2.Extensions{},

+ 24 - 28
encoding/protojson/encode.go

@@ -88,20 +88,17 @@ func (o MarshalOptions) marshalMessage(m pref.Message) error {
 // marshalFields marshals the fields in the given protoreflect.Message.
 func (o MarshalOptions) marshalFields(m pref.Message) error {
 	var nerr errors.NonFatal
-	fieldDescs := m.Descriptor().Fields()
-	knownFields := m.KnownFields()
 
 	// Marshal out known fields.
+	fieldDescs := m.Descriptor().Fields()
 	for i := 0; i < fieldDescs.Len(); i++ {
 		fd := fieldDescs.Get(i)
-		num := fd.Number()
-
-		if !knownFields.Has(num) {
+		if !m.Has(fd) {
 			continue
 		}
 
 		name := fd.JSONName()
-		val := knownFields.Get(num)
+		val := m.Get(fd)
 		if err := o.encoder.WriteName(name); !nerr.Merge(err) {
 			return err
 		}
@@ -111,7 +108,7 @@ func (o MarshalOptions) marshalFields(m pref.Message) error {
 	}
 
 	// Marshal out extensions.
-	if err := o.marshalExtensions(knownFields); !nerr.Merge(err) {
+	if err := o.marshalExtensions(m); !nerr.Merge(err) {
 		return err
 	}
 	return nerr.E
@@ -254,34 +251,33 @@ func sortMap(keyKind pref.Kind, values []mapEntry) {
 }
 
 // marshalExtensions marshals extension fields.
-func (o MarshalOptions) marshalExtensions(knownFields pref.KnownFields) error {
-	type xtEntry struct {
-		key    string
-		value  pref.Value
-		xtType pref.ExtensionType
+func (o MarshalOptions) marshalExtensions(m pref.Message) error {
+	type entry struct {
+		key   string
+		value pref.Value
+		desc  pref.FieldDescriptor
 	}
 
-	xtTypes := knownFields.ExtensionTypes()
-
 	// Get a sorted list based on field key first.
-	entries := make([]xtEntry, 0, xtTypes.Len())
-	xtTypes.Range(func(xt pref.ExtensionType) bool {
-		name := xt.Descriptor().FullName()
+	var entries []entry
+	m.Range(func(fd pref.FieldDescriptor, v pref.Value) bool {
+		if !fd.IsExtension() {
+			return true
+		}
+		xt := fd.(pref.ExtensionType)
+
 		// If extended type is a MessageSet, set field name to be the message type name.
+		name := xt.Descriptor().FullName()
 		if isMessageSetExtension(xt) {
 			name = xt.Descriptor().Message().FullName()
 		}
 
-		num := xt.Descriptor().Number()
-		if knownFields.Has(num) {
-			// Use [name] format for JSON field name.
-			pval := knownFields.Get(num)
-			entries = append(entries, xtEntry{
-				key:    string(name),
-				value:  pval,
-				xtType: xt,
-			})
-		}
+		// Use [name] format for JSON field name.
+		entries = append(entries, entry{
+			key:   string(name),
+			value: v,
+			desc:  fd,
+		})
 		return true
 	})
 
@@ -299,7 +295,7 @@ func (o MarshalOptions) marshalExtensions(knownFields pref.KnownFields) error {
 		if err := o.encoder.WriteName("[" + entry.key + "]"); !nerr.Merge(err) {
 			return err
 		}
-		if err := o.marshalValue(entry.value, entry.xtType.Descriptor()); !nerr.Merge(err) {
+		if err := o.marshalValue(entry.value, entry.desc); !nerr.Merge(err) {
 			return err
 		}
 	}

+ 2 - 17
encoding/protojson/encode_test.go

@@ -15,7 +15,6 @@ import (
 	"github.com/google/go-cmp/cmp/cmpopts"
 	"google.golang.org/protobuf/encoding/protojson"
 	"google.golang.org/protobuf/internal/encoding/pack"
-	"google.golang.org/protobuf/internal/encoding/wire"
 	pimpl "google.golang.org/protobuf/internal/impl"
 	"google.golang.org/protobuf/internal/scalar"
 	"google.golang.org/protobuf/proto"
@@ -50,15 +49,9 @@ func pb2Enums_NestedEnum(i int32) *pb2.Enums_NestedEnum {
 	return p
 }
 
+// TODO: Replace this with proto.SetExtension.
 func setExtension(m proto.Message, xd *protoiface.ExtensionDescV1, val interface{}) {
-	knownFields := m.ProtoReflect().KnownFields()
-	extTypes := knownFields.ExtensionTypes()
-	extTypes.Register(xd.Type)
-	if val == nil {
-		return
-	}
-	pval := xd.Type.ValueOf(val)
-	knownFields.Set(wire.Number(xd.Field), pval)
+	m.ProtoReflect().Set(xd.Type, xd.Type.ValueOf(val))
 }
 
 // dhex decodes a hex-string and returns the bytes and panics if s is invalid.
@@ -944,14 +937,6 @@ func TestMarshal(t *testing.T) {
   },
   "[pb2.opt_ext_string]": "extension field"
 }`,
-	}, {
-		desc: "extension message field set to nil",
-		input: func() proto.Message {
-			m := &pb2.Extensions{}
-			setExtension(m, pb2.E_OptExtNested, nil)
-			return m
-		}(),
-		want: "{}",
 	}, {
 		desc: "extensions of repeated fields",
 		input: func() proto.Message {

+ 78 - 74
encoding/protojson/well_known_types.go

@@ -142,25 +142,26 @@ func (o UnmarshalOptions) unmarshalCustomType(m pref.Message) error {
 // field `value` which holds the custom JSON in addition to the `@type` field.
 
 func (o MarshalOptions) marshalAny(m pref.Message) error {
-	messageDesc := m.Descriptor()
-	knownFields := m.KnownFields()
+	fds := m.Descriptor().Fields()
+	fdType := fds.ByNumber(fieldnum.Any_TypeUrl)
+	fdValue := fds.ByNumber(fieldnum.Any_Value)
 
 	// Start writing the JSON object.
 	o.encoder.StartObject()
 	defer o.encoder.EndObject()
 
-	if !knownFields.Has(fieldnum.Any_TypeUrl) {
-		if !knownFields.Has(fieldnum.Any_Value) {
+	if !m.Has(fdType) {
+		if !m.Has(fdValue) {
 			// If message is empty, marshal out empty JSON object.
 			return nil
 		} else {
 			// Return error if type_url field is not set, but value is set.
-			return errors.New("%s: type_url is not set", messageDesc.FullName())
+			return errors.New("%s: type_url is not set", m.Descriptor().FullName())
 		}
 	}
 
-	typeVal := knownFields.Get(fieldnum.Any_TypeUrl)
-	valueVal := knownFields.Get(fieldnum.Any_Value)
+	typeVal := m.Get(fdType)
+	valueVal := m.Get(fdValue)
 
 	// Marshal out @type field.
 	typeURL := typeVal.String()
@@ -173,7 +174,7 @@ func (o MarshalOptions) marshalAny(m pref.Message) error {
 	// Resolve the type in order to unmarshal value field.
 	emt, err := o.Resolver.FindMessageByURL(typeURL)
 	if !nerr.Merge(err) {
-		return errors.New("%s: unable to resolve %q: %v", messageDesc.FullName(), typeURL, err)
+		return errors.New("%s: unable to resolve %q: %v", m.Descriptor().FullName(), typeURL, err)
 	}
 
 	em := emt.New()
@@ -185,7 +186,7 @@ func (o MarshalOptions) marshalAny(m pref.Message) error {
 		AllowPartial: o.AllowPartial,
 	}.Unmarshal(valueVal.Bytes(), em.Interface())
 	if !nerr.Merge(err) {
-		return errors.New("%s: unable to unmarshal %q: %v", messageDesc.FullName(), typeURL, err)
+		return errors.New("%s: unable to unmarshal %q: %v", m.Descriptor().FullName(), typeURL, err)
 	}
 
 	// If type of value has custom JSON encoding, marshal out a field "value"
@@ -263,9 +264,12 @@ func (o UnmarshalOptions) unmarshalAny(m pref.Message) error {
 		return errors.New("google.protobuf.Any: %v", err)
 	}
 
-	knownFields := m.KnownFields()
-	knownFields.Set(fieldnum.Any_TypeUrl, pref.ValueOf(typeURL))
-	knownFields.Set(fieldnum.Any_Value, pref.ValueOf(b))
+	fds := m.Descriptor().Fields()
+	fdType := fds.ByNumber(fieldnum.Any_TypeUrl)
+	fdValue := fds.ByNumber(fieldnum.Any_Value)
+
+	m.Set(fdType, pref.ValueOf(typeURL))
+	m.Set(fdValue, pref.ValueOf(b))
 	return nerr.E
 }
 
@@ -446,7 +450,7 @@ const wrapperFieldNumber = fieldnum.BoolValue_Value
 
 func (o MarshalOptions) marshalWrapperType(m pref.Message) error {
 	fd := m.Descriptor().Fields().ByNumber(wrapperFieldNumber)
-	val := m.KnownFields().Get(wrapperFieldNumber)
+	val := m.Get(fd)
 	return o.marshalSingular(val, fd)
 }
 
@@ -457,7 +461,7 @@ func (o UnmarshalOptions) unmarshalWrapperType(m pref.Message) error {
 	if !nerr.Merge(err) {
 		return err
 	}
-	m.KnownFields().Set(wrapperFieldNumber, val)
+	m.Set(fd, val)
 	return nerr.E
 }
 
@@ -509,14 +513,12 @@ func (o UnmarshalOptions) unmarshalEmpty(pref.Message) error {
 
 func (o MarshalOptions) marshalStruct(m pref.Message) error {
 	fd := m.Descriptor().Fields().ByNumber(fieldnum.Struct_Fields)
-	val := m.KnownFields().Get(fieldnum.Struct_Fields)
-	return o.marshalMap(val.Map(), fd)
+	return o.marshalMap(m.Get(fd).Map(), fd)
 }
 
 func (o UnmarshalOptions) unmarshalStruct(m pref.Message) error {
 	fd := m.Descriptor().Fields().ByNumber(fieldnum.Struct_Fields)
-	val := m.KnownFields().Get(fieldnum.Struct_Fields)
-	return o.unmarshalMap(val.Map(), fd)
+	return o.unmarshalMap(m.Mutable(fd).Map(), fd)
 }
 
 // The JSON representation for ListValue is JSON array that contains the encoded
@@ -525,14 +527,12 @@ func (o UnmarshalOptions) unmarshalStruct(m pref.Message) error {
 
 func (o MarshalOptions) marshalListValue(m pref.Message) error {
 	fd := m.Descriptor().Fields().ByNumber(fieldnum.ListValue_Values)
-	val := m.KnownFields().Get(fieldnum.ListValue_Values)
-	return o.marshalList(val.List(), fd)
+	return o.marshalList(m.Get(fd).List(), fd)
 }
 
 func (o UnmarshalOptions) unmarshalListValue(m pref.Message) error {
 	fd := m.Descriptor().Fields().ByNumber(fieldnum.ListValue_Values)
-	val := m.KnownFields().Get(fieldnum.ListValue_Values)
-	return o.unmarshalList(val.List(), fd)
+	return o.unmarshalList(m.Mutable(fd).List(), fd)
 }
 
 // The JSON representation for a Value is dependent on the oneof field that is
@@ -540,27 +540,21 @@ func (o UnmarshalOptions) unmarshalListValue(m pref.Message) error {
 // Value message needs to be a oneof field set, else it is an error.
 
 func (o MarshalOptions) marshalKnownValue(m pref.Message) error {
-	messageDesc := m.Descriptor()
-	knownFields := m.KnownFields()
-	num := knownFields.WhichOneof("kind")
-	if num == 0 {
-		// Return error if none of the fields is set.
-		return errors.New("%s: none of the oneof fields is set", messageDesc.FullName())
+	od := m.Descriptor().Oneofs().ByName("kind")
+	fd := m.WhichOneof(od)
+	if fd == nil {
+		return errors.New("%s: none of the oneof fields is set", m.Descriptor().FullName())
 	}
-
-	fd := messageDesc.Fields().ByNumber(num)
-	val := knownFields.Get(num)
-	return o.marshalSingular(val, fd)
+	return o.marshalSingular(m.Get(fd), fd)
 }
 
 func (o UnmarshalOptions) unmarshalKnownValue(m pref.Message) error {
 	var nerr errors.NonFatal
-	knownFields := m.KnownFields()
-
 	switch o.decoder.Peek() {
 	case json.Null:
 		o.decoder.Read()
-		knownFields.Set(fieldnum.Value_NullValue, pref.ValueOf(pref.EnumNumber(0)))
+		fd := m.Descriptor().Fields().ByNumber(fieldnum.Value_NullValue)
+		m.Set(fd, pref.ValueOf(pref.EnumNumber(0)))
 
 	case json.Bool:
 		jval, err := o.decoder.Read()
@@ -571,7 +565,8 @@ func (o UnmarshalOptions) unmarshalKnownValue(m pref.Message) error {
 		if err != nil {
 			return err
 		}
-		knownFields.Set(fieldnum.Value_BoolValue, val)
+		fd := m.Descriptor().Fields().ByNumber(fieldnum.Value_BoolValue)
+		m.Set(fd, val)
 
 	case json.Number:
 		jval, err := o.decoder.Read()
@@ -582,7 +577,8 @@ func (o UnmarshalOptions) unmarshalKnownValue(m pref.Message) error {
 		if err != nil {
 			return err
 		}
-		knownFields.Set(fieldnum.Value_NumberValue, val)
+		fd := m.Descriptor().Fields().ByNumber(fieldnum.Value_NumberValue)
+		m.Set(fd, val)
 
 	case json.String:
 		// A JSON string may have been encoded from the number_value field,
@@ -599,21 +595,24 @@ func (o UnmarshalOptions) unmarshalKnownValue(m pref.Message) error {
 		if !nerr.Merge(err) {
 			return err
 		}
-		knownFields.Set(fieldnum.Value_StringValue, val)
+		fd := m.Descriptor().Fields().ByNumber(fieldnum.Value_StringValue)
+		m.Set(fd, val)
 
 	case json.StartObject:
-		m := knownFields.NewMessage(fieldnum.Value_StructValue)
-		if err := o.unmarshalStruct(m); !nerr.Merge(err) {
+		fd := m.Descriptor().Fields().ByNumber(fieldnum.Value_StructValue)
+		m2 := m.NewMessage(fd)
+		if err := o.unmarshalStruct(m2); !nerr.Merge(err) {
 			return err
 		}
-		knownFields.Set(fieldnum.Value_StructValue, pref.ValueOf(m))
+		m.Set(fd, pref.ValueOf(m2))
 
 	case json.StartArray:
-		m := knownFields.NewMessage(fieldnum.Value_ListValue)
-		if err := o.unmarshalListValue(m); !nerr.Merge(err) {
+		fd := m.Descriptor().Fields().ByNumber(fieldnum.Value_ListValue)
+		m2 := m.NewMessage(fd)
+		if err := o.unmarshalListValue(m2); !nerr.Merge(err) {
 			return err
 		}
-		knownFields.Set(fieldnum.Value_ListValue, pref.ValueOf(m))
+		m.Set(fd, pref.ValueOf(m2))
 
 	default:
 		jval, err := o.decoder.Read()
@@ -622,7 +621,6 @@ func (o UnmarshalOptions) unmarshalKnownValue(m pref.Message) error {
 		}
 		return unexpectedJSONError{jval}
 	}
-
 	return nerr.E
 }
 
@@ -644,21 +642,22 @@ const (
 )
 
 func (o MarshalOptions) marshalDuration(m pref.Message) error {
-	messageDesc := m.Descriptor()
-	knownFields := m.KnownFields()
+	fds := m.Descriptor().Fields()
+	fdSeconds := fds.ByNumber(fieldnum.Duration_Seconds)
+	fdNanos := fds.ByNumber(fieldnum.Duration_Nanos)
 
-	secsVal := knownFields.Get(fieldnum.Duration_Seconds)
-	nanosVal := knownFields.Get(fieldnum.Duration_Nanos)
+	secsVal := m.Get(fdSeconds)
+	nanosVal := m.Get(fdNanos)
 	secs := secsVal.Int()
 	nanos := nanosVal.Int()
 	if secs < -maxSecondsInDuration || secs > maxSecondsInDuration {
-		return errors.New("%s: seconds out of range %v", messageDesc.FullName(), secs)
+		return errors.New("%s: seconds out of range %v", m.Descriptor().FullName(), secs)
 	}
 	if nanos < -secondsInNanos || nanos > secondsInNanos {
-		return errors.New("%s: nanos out of range %v", messageDesc.FullName(), nanos)
+		return errors.New("%s: nanos out of range %v", m.Descriptor().FullName(), nanos)
 	}
 	if (secs > 0 && nanos < 0) || (secs < 0 && nanos > 0) {
-		return errors.New("%s: signs of seconds and nanos do not match", messageDesc.FullName())
+		return errors.New("%s: signs of seconds and nanos do not match", m.Descriptor().FullName())
 	}
 	// Generated output always contains 0, 3, 6, or 9 fractional digits,
 	// depending on required precision, followed by the suffix "s".
@@ -687,21 +686,23 @@ func (o UnmarshalOptions) unmarshalDuration(m pref.Message) error {
 		return unexpectedJSONError{jval}
 	}
 
-	messageDesc := m.Descriptor()
 	input := jval.String()
 	secs, nanos, ok := parseDuration(input)
 	if !ok {
-		return errors.New("%s: invalid duration value %q", messageDesc.FullName(), input)
+		return errors.New("%s: invalid duration value %q", m.Descriptor().FullName(), input)
 	}
 	// Validate seconds. No need to validate nanos because parseDuration would
 	// have covered that already.
 	if secs < -maxSecondsInDuration || secs > maxSecondsInDuration {
-		return errors.New("%s: out of range %q", messageDesc.FullName(), input)
+		return errors.New("%s: out of range %q", m.Descriptor().FullName(), input)
 	}
 
-	knownFields := m.KnownFields()
-	knownFields.Set(fieldnum.Duration_Seconds, pref.ValueOf(secs))
-	knownFields.Set(fieldnum.Duration_Nanos, pref.ValueOf(nanos))
+	fds := m.Descriptor().Fields()
+	fdSeconds := fds.ByNumber(fieldnum.Duration_Seconds)
+	fdNanos := fds.ByNumber(fieldnum.Duration_Nanos)
+
+	m.Set(fdSeconds, pref.ValueOf(secs))
+	m.Set(fdNanos, pref.ValueOf(nanos))
 	return nerr.E
 }
 
@@ -834,18 +835,19 @@ const (
 )
 
 func (o MarshalOptions) marshalTimestamp(m pref.Message) error {
-	messageDesc := m.Descriptor()
-	knownFields := m.KnownFields()
+	fds := m.Descriptor().Fields()
+	fdSeconds := fds.ByNumber(fieldnum.Timestamp_Seconds)
+	fdNanos := fds.ByNumber(fieldnum.Timestamp_Nanos)
 
-	secsVal := knownFields.Get(fieldnum.Timestamp_Seconds)
-	nanosVal := knownFields.Get(fieldnum.Timestamp_Nanos)
+	secsVal := m.Get(fdSeconds)
+	nanosVal := m.Get(fdNanos)
 	secs := secsVal.Int()
 	nanos := nanosVal.Int()
 	if secs < minTimestampSeconds || secs > maxTimestampSeconds {
-		return errors.New("%s: seconds out of range %v", messageDesc.FullName(), secs)
+		return errors.New("%s: seconds out of range %v", m.Descriptor().FullName(), secs)
 	}
 	if nanos < 0 || nanos > secondsInNanos {
-		return errors.New("%s: nanos out of range %v", messageDesc.FullName(), nanos)
+		return errors.New("%s: nanos out of range %v", m.Descriptor().FullName(), nanos)
 	}
 	// Uses RFC 3339, where generated output will be Z-normalized and uses 0, 3,
 	// 6 or 9 fractional digits.
@@ -868,22 +870,24 @@ func (o UnmarshalOptions) unmarshalTimestamp(m pref.Message) error {
 		return unexpectedJSONError{jval}
 	}
 
-	messageDesc := m.Descriptor()
 	input := jval.String()
 	t, err := time.Parse(time.RFC3339Nano, input)
 	if err != nil {
-		return errors.New("%s: invalid timestamp value %q", messageDesc.FullName(), input)
+		return errors.New("%s: invalid timestamp value %q", m.Descriptor().FullName(), input)
 	}
 	// Validate seconds. No need to validate nanos because time.Parse would have
 	// covered that already.
 	secs := t.Unix()
 	if secs < minTimestampSeconds || secs > maxTimestampSeconds {
-		return errors.New("%s: out of range %q", messageDesc.FullName(), input)
+		return errors.New("%s: out of range %q", m.Descriptor().FullName(), input)
 	}
 
-	knownFields := m.KnownFields()
-	knownFields.Set(fieldnum.Timestamp_Seconds, pref.ValueOf(secs))
-	knownFields.Set(fieldnum.Timestamp_Nanos, pref.ValueOf(int32(t.Nanosecond())))
+	fds := m.Descriptor().Fields()
+	fdSeconds := fds.ByNumber(fieldnum.Timestamp_Seconds)
+	fdNanos := fds.ByNumber(fieldnum.Timestamp_Nanos)
+
+	m.Set(fdSeconds, pref.ValueOf(secs))
+	m.Set(fdNanos, pref.ValueOf(int32(t.Nanosecond())))
 	return nerr.E
 }
 
@@ -893,8 +897,8 @@ func (o UnmarshalOptions) unmarshalTimestamp(m pref.Message) error {
 // end up differently after a round-trip.
 
 func (o MarshalOptions) marshalFieldMask(m pref.Message) error {
-	val := m.KnownFields().Get(fieldnum.FieldMask_Paths)
-	list := val.List()
+	fd := m.Descriptor().Fields().ByNumber(fieldnum.FieldMask_Paths)
+	list := m.Get(fd).List()
 	paths := make([]string, 0, list.Len())
 
 	for i := 0; i < list.Len(); i++ {
@@ -926,8 +930,8 @@ func (o UnmarshalOptions) unmarshalFieldMask(m pref.Message) error {
 	}
 	paths := strings.Split(str, ",")
 
-	val := m.KnownFields().Get(fieldnum.FieldMask_Paths)
-	list := val.List()
+	fd := m.Descriptor().Fields().ByNumber(fieldnum.FieldMask_Paths)
+	list := m.Mutable(fd).List()
 
 	for _, s := range paths {
 		s = strings.TrimSpace(s)

+ 26 - 56
encoding/prototext/decode.go

@@ -47,12 +47,11 @@ type UnmarshalOptions struct {
 func (o UnmarshalOptions) Unmarshal(b []byte, m proto.Message) error {
 	var nerr errors.NonFatal
 
-	mr := m.ProtoReflect()
 	// Clear all fields before populating it.
 	// TODO: Determine if this needs to be consistent with protojson and binary unmarshal where
 	// behavior is to merge values into existing message. If decision is to not clear the fields
 	// ahead, code will need to be updated properly when merging nested messages.
-	resetMessage(mr)
+	proto.Reset(m)
 
 	// Parse into text.Value of message type.
 	val, err := text.Unmarshal(b)
@@ -63,7 +62,7 @@ func (o UnmarshalOptions) Unmarshal(b []byte, m proto.Message) error {
 	if o.Resolver == nil {
 		o.Resolver = protoregistry.GlobalTypes
 	}
-	err = o.unmarshalMessage(val.Message(), mr)
+	err = o.unmarshalMessage(val.Message(), m.ProtoReflect())
 	if !nerr.Merge(err) {
 		return err
 	}
@@ -75,41 +74,19 @@ func (o UnmarshalOptions) Unmarshal(b []byte, m proto.Message) error {
 	return nerr.E
 }
 
-// resetMessage clears all fields of given protoreflect.Message.
-// TODO: This should go into the proto package.
-func resetMessage(m pref.Message) {
-	knownFields := m.KnownFields()
-	knownFields.Range(func(num pref.FieldNumber, _ pref.Value) bool {
-		knownFields.Clear(num)
-		return true
-	})
-	unknownFields := m.UnknownFields()
-	unknownFields.Range(func(num pref.FieldNumber, _ pref.RawFields) bool {
-		unknownFields.Set(num, nil)
-		return true
-	})
-	extTypes := knownFields.ExtensionTypes()
-	extTypes.Range(func(xt pref.ExtensionType) bool {
-		extTypes.Remove(xt)
-		return true
-	})
-}
-
 // unmarshalMessage unmarshals a [][2]text.Value message into the given protoreflect.Message.
 func (o UnmarshalOptions) unmarshalMessage(tmsg [][2]text.Value, m pref.Message) error {
 	var nerr errors.NonFatal
 
 	messageDesc := m.Descriptor()
-	knownFields := m.KnownFields()
 
 	// Handle expanded Any message.
 	if messageDesc.FullName() == "google.protobuf.Any" && isExpandedAny(tmsg) {
-		return o.unmarshalAny(tmsg[0], knownFields)
+		return o.unmarshalAny(tmsg[0], m)
 	}
 
 	fieldDescs := messageDesc.Fields()
 	reservedNames := messageDesc.ReservedNames()
-	xtTypes := knownFields.ExtensionTypes()
 	var seenNums set.Ints
 	var seenOneofs set.Ints
 
@@ -134,23 +111,14 @@ func (o UnmarshalOptions) unmarshalMessage(tmsg [][2]text.Value, m pref.Message)
 			}
 			// Extensions have to be registered first in the message's
 			// ExtensionTypes before setting a value to it.
-			xtName := pref.FullName(tkey.String())
+			extName := pref.FullName(tkey.String())
 			// Check first if it is already registered. This is the case for
 			// repeated fields.
-			xt := xtTypes.ByName(xtName)
-			if xt == nil {
-				var err error
-				xt, err = o.findExtension(xtName)
-				if err != nil && err != protoregistry.NotFound {
-					return errors.New("unable to resolve [%v]: %v", xtName, err)
-				}
-				if xt != nil {
-					xtTypes.Register(xt)
-				}
-			}
-			if xt != nil {
-				fd = xt.Descriptor()
+			xt, err := o.findExtension(extName)
+			if err != nil && err != protoregistry.NotFound {
+				return errors.New("unable to resolve [%v]: %v", extName, err)
 			}
+			fd = xt
 		}
 
 		if fd == nil {
@@ -172,7 +140,7 @@ func (o UnmarshalOptions) unmarshalMessage(tmsg [][2]text.Value, m pref.Message)
 				items = tval.List()
 			}
 
-			list := knownFields.Get(fd.Number()).List()
+			list := m.Mutable(fd).List()
 			if err := o.unmarshalList(items, fd, list); !nerr.Merge(err) {
 				return err
 			}
@@ -185,7 +153,7 @@ func (o UnmarshalOptions) unmarshalMessage(tmsg [][2]text.Value, m pref.Message)
 				items = tval.List()
 			}
 
-			mmap := knownFields.Get(fd.Number()).Map()
+			mmap := m.Mutable(fd).Map()
 			if err := o.unmarshalMap(items, fd, mmap); !nerr.Merge(err) {
 				return err
 			}
@@ -204,7 +172,7 @@ func (o UnmarshalOptions) unmarshalMessage(tmsg [][2]text.Value, m pref.Message)
 			if seenNums.Has(num) {
 				return errors.New("non-repeated field %v is repeated", fd.FullName())
 			}
-			if err := o.unmarshalSingular(tval, fd, knownFields); !nerr.Merge(err) {
+			if err := o.unmarshalSingular(tval, fd, m); !nerr.Merge(err) {
 				return err
 			}
 			seenNums.Set(num)
@@ -230,9 +198,7 @@ func (o UnmarshalOptions) findExtension(xtName pref.FullName) (pref.ExtensionTyp
 }
 
 // unmarshalSingular unmarshals given text.Value into the non-repeated field.
-func (o UnmarshalOptions) unmarshalSingular(input text.Value, fd pref.FieldDescriptor, knownFields pref.KnownFields) error {
-	num := fd.Number()
-
+func (o UnmarshalOptions) unmarshalSingular(input text.Value, fd pref.FieldDescriptor, m pref.Message) error {
 	var nerr errors.NonFatal
 	var val pref.Value
 	switch fd.Kind() {
@@ -240,11 +206,11 @@ func (o UnmarshalOptions) unmarshalSingular(input text.Value, fd pref.FieldDescr
 		if input.Type() != text.Message {
 			return errors.New("%v contains invalid message/group value: %v", fd.FullName(), input)
 		}
-		m := knownFields.NewMessage(num)
-		if err := o.unmarshalMessage(input.Message(), m); !nerr.Merge(err) {
+		m2 := m.NewMessage(fd)
+		if err := o.unmarshalMessage(input.Message(), m2); !nerr.Merge(err) {
 			return err
 		}
-		val = pref.ValueOf(m)
+		val = pref.ValueOf(m2)
 	default:
 		var err error
 		val, err = unmarshalScalar(input, fd)
@@ -252,7 +218,7 @@ func (o UnmarshalOptions) unmarshalSingular(input text.Value, fd pref.FieldDescr
 			return err
 		}
 	}
-	knownFields.Set(num, val)
+	m.Set(fd, val)
 
 	return nerr.E
 }
@@ -480,7 +446,7 @@ func isExpandedAny(tmsg [][2]text.Value) bool {
 
 // unmarshalAny unmarshals an expanded Any textproto. This method assumes that the given
 // tfield has key type of text.String and value type of text.Message.
-func (o UnmarshalOptions) unmarshalAny(tfield [2]text.Value, knownFields pref.KnownFields) error {
+func (o UnmarshalOptions) unmarshalAny(tfield [2]text.Value, m pref.Message) error {
 	var nerr errors.NonFatal
 
 	typeURL := tfield[0].String()
@@ -492,8 +458,8 @@ func (o UnmarshalOptions) unmarshalAny(tfield [2]text.Value, knownFields pref.Kn
 	}
 	// Create new message for the embedded message type and unmarshal the
 	// value into it.
-	m := mt.New()
-	if err := o.unmarshalMessage(value, m); !nerr.Merge(err) {
+	m2 := mt.New()
+	if err := o.unmarshalMessage(value, m2); !nerr.Merge(err) {
 		return err
 	}
 	// Serialize the embedded message and assign the resulting bytes to the value field.
@@ -503,13 +469,17 @@ func (o UnmarshalOptions) unmarshalAny(tfield [2]text.Value, knownFields pref.Kn
 	b, err := proto.MarshalOptions{
 		AllowPartial:  o.AllowPartial,
 		Deterministic: true,
-	}.Marshal(m.Interface())
+	}.Marshal(m2.Interface())
 	if !nerr.Merge(err) {
 		return err
 	}
 
-	knownFields.Set(fieldnum.Any_TypeUrl, pref.ValueOf(typeURL))
-	knownFields.Set(fieldnum.Any_Value, pref.ValueOf(b))
+	fds := m.Descriptor().Fields()
+	fdType := fds.ByNumber(fieldnum.Any_TypeUrl)
+	fdValue := fds.ByNumber(fieldnum.Any_Value)
+
+	m.Set(fdType, pref.ValueOf(typeURL))
+	m.Set(fdValue, pref.ValueOf(b))
 
 	return nerr.E
 }

+ 29 - 33
encoding/prototext/encode.go

@@ -88,13 +88,10 @@ func (o MarshalOptions) marshalMessage(m pref.Message) (text.Value, error) {
 
 	// Handle known fields.
 	fieldDescs := messageDesc.Fields()
-	knownFields := m.KnownFields()
 	size := fieldDescs.Len()
 	for i := 0; i < size; i++ {
 		fd := fieldDescs.Get(i)
-		num := fd.Number()
-
-		if !knownFields.Has(num) {
+		if !m.Has(fd) {
 			continue
 		}
 
@@ -103,7 +100,7 @@ func (o MarshalOptions) marshalMessage(m pref.Message) (text.Value, error) {
 		if fd.Kind() == pref.GroupKind {
 			name = text.ValueOf(fd.Message().Name())
 		}
-		pval := knownFields.Get(num)
+		pval := m.Get(fd)
 		var err error
 		msgFields, err = o.appendField(msgFields, name, pval, fd)
 		if !nerr.Merge(err) {
@@ -113,17 +110,14 @@ func (o MarshalOptions) marshalMessage(m pref.Message) (text.Value, error) {
 
 	// Handle extensions.
 	var err error
-	msgFields, err = o.appendExtensions(msgFields, knownFields)
+	msgFields, err = o.appendExtensions(msgFields, m)
 	if !nerr.Merge(err) {
 		return text.Value{}, err
 	}
 
 	// Handle unknown fields.
 	// TODO: Provide option to exclude or include unknown fields.
-	m.UnknownFields().Range(func(_ pref.FieldNumber, raw pref.RawFields) bool {
-		msgFields = appendUnknown(msgFields, raw)
-		return true
-	})
+	msgFields = appendUnknown(msgFields, m.GetUnknown())
 
 	return text.ValueOf(msgFields), nerr.E
 }
@@ -259,30 +253,29 @@ func (o MarshalOptions) marshalMap(mmap pref.Map, fd pref.FieldDescriptor) ([]te
 }
 
 // appendExtensions marshals extension fields and appends them to the given [][2]text.Value.
-func (o MarshalOptions) appendExtensions(msgFields [][2]text.Value, knownFields pref.KnownFields) ([][2]text.Value, error) {
-	xtTypes := knownFields.ExtensionTypes()
-	xtFields := make([][2]text.Value, 0, xtTypes.Len())
-
+func (o MarshalOptions) appendExtensions(msgFields [][2]text.Value, m pref.Message) ([][2]text.Value, error) {
 	var nerr errors.NonFatal
 	var err error
-	xtTypes.Range(func(xt pref.ExtensionType) bool {
-		name := xt.Descriptor().FullName()
+	var entries [][2]text.Value
+	m.Range(func(fd pref.FieldDescriptor, v pref.Value) bool {
+		if !fd.IsExtension() {
+			return true
+		}
+		xt := fd.(pref.ExtensionType)
+
 		// If extended type is a MessageSet, set field name to be the message type name.
+		name := xt.Descriptor().FullName()
 		if isMessageSetExtension(xt) {
 			name = xt.Descriptor().Message().FullName()
 		}
 
-		num := xt.Descriptor().Number()
-		if knownFields.Has(num) {
-			// Use string type to produce [name] format.
-			tname := text.ValueOf(string(name))
-			pval := knownFields.Get(num)
-			xtFields, err = o.appendField(xtFields, tname, pval, xt.Descriptor())
-			if !nerr.Merge(err) {
-				return false
-			}
-			err = nil
+		// Use string type to produce [name] format.
+		tname := text.ValueOf(string(name))
+		entries, err = o.appendField(entries, tname, v, xt)
+		if !nerr.Merge(err) {
+			return false
 		}
+		err = nil
 		return true
 	})
 	if err != nil {
@@ -290,10 +283,10 @@ func (o MarshalOptions) appendExtensions(msgFields [][2]text.Value, knownFields
 	}
 
 	// Sort extensions lexicographically and append to output.
-	sort.SliceStable(xtFields, func(i, j int) bool {
-		return xtFields[i][0].String() < xtFields[j][0].String()
+	sort.SliceStable(entries, func(i, j int) bool {
+		return entries[i][0].String() < entries[j][0].String()
 	})
-	return append(msgFields, xtFields...), nerr.E
+	return append(msgFields, entries...), nerr.E
 }
 
 // isMessageSetExtension reports whether extension extends a message set.
@@ -347,11 +340,14 @@ func appendUnknown(fields [][2]text.Value, b []byte) [][2]text.Value {
 
 // marshalAny converts a google.protobuf.Any protoreflect.Message to a text.Value.
 func (o MarshalOptions) marshalAny(m pref.Message) (text.Value, error) {
-	var nerr errors.NonFatal
-	knownFields := m.KnownFields()
-	typeURL := knownFields.Get(fieldnum.Any_TypeUrl).String()
-	value := knownFields.Get(fieldnum.Any_Value)
+	fds := m.Descriptor().Fields()
+	fdType := fds.ByNumber(fieldnum.Any_TypeUrl)
+	fdValue := fds.ByNumber(fieldnum.Any_Value)
+
+	typeURL := m.Get(fdType).String()
+	value := m.Get(fdValue)
 
+	var nerr errors.NonFatal
 	emt, err := o.Resolver.FindMessageByURL(typeURL)
 	if !nerr.Merge(err) {
 		return text.Value{}, err

+ 4 - 26
encoding/prototext/encode_test.go

@@ -8,15 +8,12 @@ import (
 	"bytes"
 	"encoding/hex"
 	"math"
-	"strings"
 	"testing"
 
 	"github.com/google/go-cmp/cmp"
-	"github.com/google/go-cmp/cmp/cmpopts"
 	"google.golang.org/protobuf/encoding/prototext"
 	"google.golang.org/protobuf/internal/detrand"
 	"google.golang.org/protobuf/internal/encoding/pack"
-	"google.golang.org/protobuf/internal/encoding/wire"
 	pimpl "google.golang.org/protobuf/internal/impl"
 	"google.golang.org/protobuf/internal/scalar"
 	"google.golang.org/protobuf/proto"
@@ -33,11 +30,6 @@ func init() {
 	detrand.Disable()
 }
 
-// splitLines is a cmpopts.Option for comparing strings with line breaks.
-var splitLines = cmpopts.AcyclicTransformer("SplitLines", func(s string) []string {
-	return strings.Split(s, "\n")
-})
-
 func pb2Enum(i int32) *pb2.Enum {
 	p := new(pb2.Enum)
 	*p = pb2.Enum(i)
@@ -50,15 +42,9 @@ func pb2Enums_NestedEnum(i int32) *pb2.Enums_NestedEnum {
 	return p
 }
 
+// TODO: Use proto.SetExtension when available.
 func setExtension(m proto.Message, xd *protoiface.ExtensionDescV1, val interface{}) {
-	knownFields := m.ProtoReflect().KnownFields()
-	extTypes := knownFields.ExtensionTypes()
-	extTypes.Register(xd.Type)
-	if val == nil {
-		return
-	}
-	pval := xd.Type.ValueOf(val)
-	knownFields.Set(wire.Number(xd.Field), pval)
+	m.ProtoReflect().Set(xd.Type, xd.Type.ValueOf(val))
 }
 
 // dhex decodes a hex-string and returns the bytes and panics if s is invalid.
@@ -938,8 +924,8 @@ req_nested: {}
 			}.Marshal(),
 		},
 		want: `101: "\x01\x00\x01"
-101: 1
 102: "hello"
+101: 1
 102: "世界"
 `,
 	}, {
@@ -1018,14 +1004,6 @@ opt_int32: 42
   opt_string: "partial1"
 }
 `,
-	}, {
-		desc: "extension message field set to nil",
-		input: func() proto.Message {
-			m := &pb2.Extensions{}
-			setExtension(m, pb2.E_OptExtNested, nil)
-			return m
-		}(),
-		want: "\n",
 	}, {
 		desc: "extensions of repeated fields",
 		input: func() proto.Message {
@@ -1295,7 +1273,7 @@ value: "\x80"
 			got := string(b)
 			if tt.want != "" && got != tt.want {
 				t.Errorf("Marshal()\n<got>\n%v\n<want>\n%v\n", got, tt.want)
-				if diff := cmp.Diff(tt.want, got, splitLines); diff != "" {
+				if diff := cmp.Diff(tt.want, got); diff != "" {
 					t.Errorf("Marshal() diff -want +got\n%v\n", diff)
 				}
 			}

+ 16 - 16
internal/cmd/generate-types/proto.go

@@ -244,15 +244,15 @@ var protoDecodeTemplate = template.Must(template.New("").Parse(`
 // unmarshalScalar decodes a value of the given kind.
 //
 // Message values are decoded into a []byte which aliases the input data.
-func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, num wire.Number, field protoreflect.FieldDescriptor) (val protoreflect.Value, n int, err error) {
-	switch field.Kind() {
+func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, fd protoreflect.FieldDescriptor) (val protoreflect.Value, n int, err error) {
+	switch fd.Kind() {
 	{{- range .}}
 	case {{.Expr}}:
 		if wtyp != {{.WireType.Expr}} {
 			return val, 0, errUnknown
 		}
 		{{if (eq .WireType "Group") -}}
-		v, n := wire.ConsumeGroup(num, b)
+		v, n := wire.ConsumeGroup(fd.Number(), b)
 		{{- else -}}
 		v, n := wire.Consume{{.WireType}}(b)
 		{{- end}}
@@ -260,9 +260,9 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, num wire.Num
 			return val, 0, wire.ParseError(n)
 		}
 		{{if (eq .Name "String") -}}
-		if field.Syntax() == protoreflect.Proto3 && !utf8.Valid(v) {
+		if fd.Syntax() == protoreflect.Proto3 && !utf8.Valid(v) {
 			var nerr errors.NonFatal
-			nerr.AppendInvalidUTF8(string(field.FullName()))
+			nerr.AppendInvalidUTF8(string(fd.FullName()))
 			return protoreflect.ValueOf(string(v)), n, nerr.E
 		}
 		{{end -}}
@@ -273,9 +273,9 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, num wire.Num
 	}
 }
 
-func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, num wire.Number, list protoreflect.List, field protoreflect.FieldDescriptor) (n int, err error) {
+func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, list protoreflect.List, fd protoreflect.FieldDescriptor) (n int, err error) {
 	var nerr errors.NonFatal
-	switch field.Kind() {
+	switch fd.Kind() {
 	{{- range .}}
 	case {{.Expr}}:
 		{{- if .WireType.Packable}}
@@ -299,7 +299,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, num wire.Numbe
 			return 0, errUnknown
 		}
 		{{if (eq .WireType "Group") -}}
-		v, n := wire.ConsumeGroup(num, b)
+		v, n := wire.ConsumeGroup(fd.Number(), b)
 		{{- else -}}
 		v, n := wire.Consume{{.WireType}}(b)
 		{{- end}}
@@ -307,8 +307,8 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, num wire.Numbe
 			return 0, wire.ParseError(n)
 		}
 		{{if (eq .Name "String") -}}
-		if field.Syntax() == protoreflect.Proto3 && !utf8.Valid(v) {
-			nerr.AppendInvalidUTF8(string(field.FullName()))
+		if fd.Syntax() == protoreflect.Proto3 && !utf8.Valid(v) {
+			nerr.AppendInvalidUTF8(string(fd.FullName()))
 		}
 		{{end -}}
 		{{if or (eq .Name "Message") (eq .Name "Group") -}}
@@ -339,14 +339,14 @@ var wireTypes = map[protoreflect.Kind]wire.Type{
 {{- end}}
 }
 
-func (o MarshalOptions) marshalSingular(b []byte, num wire.Number, field protoreflect.FieldDescriptor, v protoreflect.Value) ([]byte, error) {
+func (o MarshalOptions) marshalSingular(b []byte, fd protoreflect.FieldDescriptor, v protoreflect.Value) ([]byte, error) {
 	var nerr errors.NonFatal
-	switch field.Kind() {
+	switch fd.Kind() {
 	{{- range .}}
 	case {{.Expr}}:
 		{{- if (eq .Name "String") }}
-		if field.Syntax() == protoreflect.Proto3 && !utf8.ValidString(v.String()) {
-			nerr.AppendInvalidUTF8(string(field.FullName()))
+		if fd.Syntax() == protoreflect.Proto3 && !utf8.ValidString(v.String()) {
+			nerr.AppendInvalidUTF8(string(fd.FullName()))
 		}
 		{{end -}}
 		{{- if (eq .Name "Message") -}}
@@ -364,13 +364,13 @@ func (o MarshalOptions) marshalSingular(b []byte, num wire.Number, field protore
 		if !nerr.Merge(err) {
 			return b, err
 		}
-		b = wire.AppendVarint(b, wire.EncodeTag(num, wire.EndGroupType))
+		b = wire.AppendVarint(b, wire.EncodeTag(fd.Number(), wire.EndGroupType))
 		{{- else -}}
 		b = wire.Append{{.WireType}}(b, {{.FromValue}})
 		{{- end}}
 	{{- end}}
 	default:
-		return b, errors.New("invalid kind %v", field.Kind())
+		return b, errors.New("invalid kind %v", fd.Kind())
 	}
 	return b, nerr.E
 }

+ 4 - 7
internal/fileinit/fileinit_test.go

@@ -82,14 +82,11 @@ func TestInit(t *testing.T) {
 
 // visitFields calls f for every field set in m and its children.
 func visitFields(m protoreflect.Message, f func(protoreflect.FieldDescriptor)) {
-	fieldDescs := m.Descriptor().Fields()
-	k := m.KnownFields()
-	k.Range(func(num protoreflect.FieldNumber, value protoreflect.Value) bool {
-		field := fieldDescs.ByNumber(num)
-		f(field)
-		switch field.Kind() {
+	m.Range(func(fd protoreflect.FieldDescriptor, value protoreflect.Value) bool {
+		f(fd)
+		switch fd.Kind() {
 		case protoreflect.MessageKind, protoreflect.GroupKind:
-			if field.IsList() {
+			if fd.IsList() {
 				for i, list := 0, value.List(); i < list.Len(); i++ {
 					visitFields(list.Get(i).Message(), f)
 				}

+ 80 - 249
internal/impl/legacy_test.go

@@ -5,19 +5,16 @@
 package impl_test
 
 import (
-	"bytes"
-	"math"
 	"reflect"
 	"sync"
 	"testing"
 
-	cmp "github.com/google/go-cmp/cmp"
-	cmpopts "github.com/google/go-cmp/cmp/cmpopts"
-	pack "google.golang.org/protobuf/internal/encoding/pack"
+	"github.com/google/go-cmp/cmp"
+	"github.com/google/go-cmp/cmp/cmpopts"
 	pimpl "google.golang.org/protobuf/internal/impl"
 	pragma "google.golang.org/protobuf/internal/pragma"
 	ptype "google.golang.org/protobuf/internal/prototype"
-	scalar "google.golang.org/protobuf/internal/scalar"
+	"google.golang.org/protobuf/internal/scalar"
 	pref "google.golang.org/protobuf/reflect/protoreflect"
 	preg "google.golang.org/protobuf/reflect/protoregistry"
 	piface "google.golang.org/protobuf/runtime/protoiface"
@@ -42,136 +39,6 @@ func init() {
 	preg.GlobalTypes.Register(mt)
 }
 
-func TestLegacyUnknown(t *testing.T) {
-	rawOf := func(toks ...pack.Token) pref.RawFields {
-		return pref.RawFields(pack.Message(toks).Marshal())
-	}
-	raw1 := rawOf(pack.Tag{1, pack.BytesType}, pack.Bytes("1"))                      // 0a0131
-	raw1a := rawOf(pack.Tag{1, pack.VarintType}, pack.Svarint(-4321))                // 08c143
-	raw1b := rawOf(pack.Tag{1, pack.Fixed32Type}, pack.Uint32(0xdeadbeef))           // 0defbeadde
-	raw1c := rawOf(pack.Tag{1, pack.Fixed64Type}, pack.Float64(math.Pi))             // 09182d4454fb210940
-	raw2a := rawOf(pack.Tag{2, pack.BytesType}, pack.String("hello, world!"))        // 120d68656c6c6f2c20776f726c6421
-	raw2b := rawOf(pack.Tag{2, pack.VarintType}, pack.Uvarint(1234))                 // 10d209
-	raw3a := rawOf(pack.Tag{3, pack.StartGroupType}, pack.Tag{3, pack.EndGroupType}) // 1b1c
-	raw3b := rawOf(pack.Tag{3, pack.BytesType}, pack.Bytes("\xde\xad\xbe\xef"))      // 1a04deadbeef
-
-	joinRaw := func(bs ...pref.RawFields) (out []byte) {
-		for _, b := range bs {
-			out = append(out, b...)
-		}
-		return out
-	}
-
-	m := new(legacyTestMessage)
-	fs := pimpl.Export{}.MessageOf(m).UnknownFields()
-
-	if got, want := fs.Len(), 0; got != want {
-		t.Errorf("Len() = %d, want %d", got, want)
-	}
-	if got, want := m.XXX_unrecognized, joinRaw(); !bytes.Equal(got, want) {
-		t.Errorf("data mismatch:\ngot:  %x\nwant: %x", got, want)
-	}
-
-	fs.Set(1, raw1a)
-	fs.Set(1, append(fs.Get(1), raw1b...))
-	fs.Set(1, append(fs.Get(1), raw1c...))
-	if got, want := fs.Len(), 1; got != want {
-		t.Errorf("Len() = %d, want %d", got, want)
-	}
-	if got, want := m.XXX_unrecognized, joinRaw(raw1a, raw1b, raw1c); !bytes.Equal(got, want) {
-		t.Errorf("data mismatch:\ngot:  %x\nwant: %x", got, want)
-	}
-
-	fs.Set(2, raw2a)
-	if got, want := fs.Len(), 2; got != want {
-		t.Errorf("Len() = %d, want %d", got, want)
-	}
-	if got, want := m.XXX_unrecognized, joinRaw(raw1a, raw1b, raw1c, raw2a); !bytes.Equal(got, want) {
-		t.Errorf("data mismatch:\ngot:  %x\nwant: %x", got, want)
-	}
-
-	if got, want := fs.Get(1), joinRaw(raw1a, raw1b, raw1c); !bytes.Equal(got, want) {
-		t.Errorf("Get(%d) = %x, want %x", 1, got, want)
-	}
-	if got, want := fs.Get(2), joinRaw(raw2a); !bytes.Equal(got, want) {
-		t.Errorf("Get(%d) = %x, want %x", 2, got, want)
-	}
-	if got, want := fs.Get(3), joinRaw(); !bytes.Equal(got, want) {
-		t.Errorf("Get(%d) = %x, want %x", 3, got, want)
-	}
-
-	fs.Set(1, nil) // remove field 1
-	if got, want := fs.Len(), 1; got != want {
-		t.Errorf("Len() = %d, want %d", got, want)
-	}
-	if got, want := m.XXX_unrecognized, joinRaw(raw2a); !bytes.Equal(got, want) {
-		t.Errorf("data mismatch:\ngot:  %x\nwant: %x", got, want)
-	}
-
-	// Simulate manual appending of raw field data.
-	m.XXX_unrecognized = append(m.XXX_unrecognized, joinRaw(raw3a, raw1a, raw1b, raw2b, raw3b, raw1c)...)
-	if got, want := fs.Len(), 3; got != want {
-		t.Errorf("Len() = %d, want %d", got, want)
-	}
-
-	// Verify range iteration order.
-	var i int
-	want := []struct {
-		num pref.FieldNumber
-		raw pref.RawFields
-	}{
-		{2, joinRaw(raw2a, raw2b)},
-		{3, joinRaw(raw3a, raw3b)},
-		{1, joinRaw(raw1a, raw1b, raw1c)},
-	}
-	fs.Range(func(num pref.FieldNumber, raw pref.RawFields) bool {
-		if i < len(want) {
-			if num != want[i].num || !bytes.Equal(raw, want[i].raw) {
-				t.Errorf("Range(%d) = (%d, %x), want (%d, %x)", i, num, raw, want[i].num, want[i].raw)
-			}
-		} else {
-			t.Errorf("unexpected Range iteration: %d", i)
-		}
-		i++
-		return true
-	})
-
-	fs.Set(2, fs.Get(2)) // moves field 2 to the end
-	if got, want := fs.Len(), 3; got != want {
-		t.Errorf("Len() = %d, want %d", got, want)
-	}
-	if got, want := m.XXX_unrecognized, joinRaw(raw3a, raw1a, raw1b, raw3b, raw1c, raw2a, raw2b); !bytes.Equal(got, want) {
-		t.Errorf("data mismatch:\ngot:  %x\nwant: %x", got, want)
-	}
-	fs.Set(1, nil) // remove field 1
-	if got, want := fs.Len(), 2; got != want {
-		t.Errorf("Len() = %d, want %d", got, want)
-	}
-	if got, want := m.XXX_unrecognized, joinRaw(raw3a, raw3b, raw2a, raw2b); !bytes.Equal(got, want) {
-		t.Errorf("data mismatch:\ngot:  %x\nwant: %x", got, want)
-	}
-
-	// Remove all fields.
-	fs.Range(func(n pref.FieldNumber, b pref.RawFields) bool {
-		fs.Set(n, nil)
-		return true
-	})
-	if got, want := fs.Len(), 0; got != want {
-		t.Errorf("Len() = %d, want %d", got, want)
-	}
-	if got, want := m.XXX_unrecognized, joinRaw(); !bytes.Equal(got, want) {
-		t.Errorf("data mismatch:\ngot:  %x\nwant: %x", got, want)
-	}
-
-	fs.Set(1, raw1)
-	if got, want := fs.Len(), 1; got != want {
-		t.Errorf("Len() = %d, want %d", got, want)
-	}
-	if got, want := m.XXX_unrecognized, joinRaw(raw1); !bytes.Equal(got, want) {
-		t.Errorf("data mismatch:\ngot:  %x\nwant: %x", got, want)
-	}
-}
-
 var (
 	testParentDesc    = pimpl.Export{}.MessageDescriptorOf((*legacyTestMessage)(nil))
 	testEnumV1Desc    = pimpl.Export{}.EnumDescriptorOf(proto2_20180125.Message_ChildEnum(0))
@@ -466,63 +333,41 @@ func TestLegacyExtensions(t *testing.T) {
 		return x == y // pointer compare messages for object identity
 	})}
 
-	m := new(legacyTestMessage)
-	fs := pimpl.Export{}.MessageOf(m).KnownFields()
-	ts := fs.ExtensionTypes()
+	m := pimpl.Export{}.MessageOf(new(legacyTestMessage))
 
-	if n := fs.Len(); n != 0 {
+	if n := m.Len(); n != 0 {
 		t.Errorf("KnownFields.Len() = %v, want 0", n)
 	}
-	if n := ts.Len(); n != 0 {
-		t.Errorf("ExtensionFieldTypes.Len() = %v, want 0", n)
-	}
-
-	// Register all the extension types.
-	for _, xt := range extensionTypes {
-		ts.Register(xt)
-	}
 
 	// Check that getting the zero value returns the default value for scalars,
 	// nil for singular messages, and an empty list for repeated fields.
-	defaultValues := []interface{}{
-		bool(true),
-		int32(-12345),
-		uint32(3200),
-		float32(3.14159),
-		string("hello, \"world!\"\n"),
-		[]byte("dead\xde\xad\xbe\xefbeef"),
-		proto2_20180125.Message_ALPHA,
-		nil,
-		EnumProto2(0xdead),
-		nil,
-		new([]bool),
-		new([]int32),
-		new([]uint32),
-		new([]float32),
-		new([]string),
-		new([][]byte),
-		new([]proto2_20180125.Message_ChildEnum),
-		new([]*proto2_20180125.Message_ChildMessage),
-		new([]EnumProto2),
-		new([]*EnumMessages),
+	defaultValues := map[int]interface{}{
+		0: bool(true),
+		1: int32(-12345),
+		2: uint32(3200),
+		3: float32(3.14159),
+		4: string("hello, \"world!\"\n"),
+		5: []byte("dead\xde\xad\xbe\xefbeef"),
+		6: proto2_20180125.Message_ALPHA,
+		7: nil,
+		8: EnumProto2(0xdead),
+		9: nil,
 	}
 	for i, xt := range extensionTypes {
 		var got interface{}
-		num := xt.Descriptor().Number()
-		if v := fs.Get(num); v.IsValid() {
-			got = xt.InterfaceOf(v)
+		if !(xt.IsList() || xt.IsMap() || xt.Message() != nil) {
+			got = xt.InterfaceOf(m.Get(xt))
 		}
 		want := defaultValues[i]
 		if diff := cmp.Diff(want, got, opts); diff != "" {
-			t.Errorf("KnownFields.Get(%d) mismatch (-want +got):\n%v", num, diff)
+			t.Errorf("Message.Get(%d) mismatch (-want +got):\n%v", xt.Number(), diff)
 		}
 	}
 
 	// All fields should be unpopulated.
 	for _, xt := range extensionTypes {
-		num := xt.Descriptor().Number()
-		if fs.Has(num) {
-			t.Errorf("KnownFields.Has(%d) = true, want false", num)
+		if m.Has(xt) {
+			t.Errorf("Message.Has(%d) = true, want false", xt.Number())
 		}
 	}
 
@@ -531,102 +376,88 @@ func TestLegacyExtensions(t *testing.T) {
 	m1b := &proto2_20180125.Message_ChildMessage{F1: scalar.String("m2b")}
 	m2a := &EnumMessages{EnumP2: EnumProto2(0x1b).Enum()}
 	m2b := &EnumMessages{EnumP2: EnumProto2(0x2b).Enum()}
-	setValues := []interface{}{
-		bool(false),
-		int32(-54321),
-		uint32(6400),
-		float32(2.71828),
-		string("goodbye, \"world!\"\n"),
-		[]byte("live\xde\xad\xbe\xefchicken"),
-		proto2_20180125.Message_CHARLIE,
-		m1a,
-		EnumProto2(0xbeef),
-		m2a,
-		&[]bool{true},
-		&[]int32{-1000},
-		&[]uint32{1280},
-		&[]float32{1.6180},
-		&[]string{"zero"},
-		&[][]byte{[]byte("zero")},
-		&[]proto2_20180125.Message_ChildEnum{proto2_20180125.Message_BRAVO},
-		&[]*proto2_20180125.Message_ChildMessage{m1b},
-		&[]EnumProto2{0xdead},
-		&[]*EnumMessages{m2b},
+	setValues := map[int]interface{}{
+		0:  bool(false),
+		1:  int32(-54321),
+		2:  uint32(6400),
+		3:  float32(2.71828),
+		4:  string("goodbye, \"world!\"\n"),
+		5:  []byte("live\xde\xad\xbe\xefchicken"),
+		6:  proto2_20180125.Message_CHARLIE,
+		7:  m1a,
+		8:  EnumProto2(0xbeef),
+		9:  m2a,
+		10: &[]bool{true},
+		11: &[]int32{-1000},
+		12: &[]uint32{1280},
+		13: &[]float32{1.6180},
+		14: &[]string{"zero"},
+		15: &[][]byte{[]byte("zero")},
+		16: &[]proto2_20180125.Message_ChildEnum{proto2_20180125.Message_BRAVO},
+		17: &[]*proto2_20180125.Message_ChildMessage{m1b},
+		18: &[]EnumProto2{0xdead},
+		19: &[]*EnumMessages{m2b},
 	}
 	for i, xt := range extensionTypes {
-		fs.Set(xt.Descriptor().Number(), xt.ValueOf(setValues[i]))
+		m.Set(xt, xt.ValueOf(setValues[i]))
 	}
 	for i, xt := range extensionTypes[len(extensionTypes)/2:] {
 		v := extensionTypes[i].ValueOf(setValues[i])
-		fs.Get(xt.Descriptor().Number()).List().Append(v)
+		m.Get(xt).List().Append(v)
 	}
 
 	// Get the values and check for equality.
-	getValues := []interface{}{
-		bool(false),
-		int32(-54321),
-		uint32(6400),
-		float32(2.71828),
-		string("goodbye, \"world!\"\n"),
-		[]byte("live\xde\xad\xbe\xefchicken"),
-		proto2_20180125.Message_ChildEnum(proto2_20180125.Message_CHARLIE),
-		m1a,
-		EnumProto2(0xbeef),
-		m2a,
-		&[]bool{true, false},
-		&[]int32{-1000, -54321},
-		&[]uint32{1280, 6400},
-		&[]float32{1.6180, 2.71828},
-		&[]string{"zero", "goodbye, \"world!\"\n"},
-		&[][]byte{[]byte("zero"), []byte("live\xde\xad\xbe\xefchicken")},
-		&[]proto2_20180125.Message_ChildEnum{proto2_20180125.Message_BRAVO, proto2_20180125.Message_CHARLIE},
-		&[]*proto2_20180125.Message_ChildMessage{m1b, m1a},
-		&[]EnumProto2{0xdead, 0xbeef},
-		&[]*EnumMessages{m2b, m2a},
+	getValues := map[int]interface{}{
+		0:  bool(false),
+		1:  int32(-54321),
+		2:  uint32(6400),
+		3:  float32(2.71828),
+		4:  string("goodbye, \"world!\"\n"),
+		5:  []byte("live\xde\xad\xbe\xefchicken"),
+		6:  proto2_20180125.Message_ChildEnum(proto2_20180125.Message_CHARLIE),
+		7:  m1a,
+		8:  EnumProto2(0xbeef),
+		9:  m2a,
+		10: &[]bool{true, false},
+		11: &[]int32{-1000, -54321},
+		12: &[]uint32{1280, 6400},
+		13: &[]float32{1.6180, 2.71828},
+		14: &[]string{"zero", "goodbye, \"world!\"\n"},
+		15: &[][]byte{[]byte("zero"), []byte("live\xde\xad\xbe\xefchicken")},
+		16: &[]proto2_20180125.Message_ChildEnum{proto2_20180125.Message_BRAVO, proto2_20180125.Message_CHARLIE},
+		17: &[]*proto2_20180125.Message_ChildMessage{m1b, m1a},
+		18: &[]EnumProto2{0xdead, 0xbeef},
+		19: &[]*EnumMessages{m2b, m2a},
 	}
 	for i, xt := range extensionTypes {
-		num := xt.Descriptor().Number()
-		got := xt.InterfaceOf(fs.Get(num))
+		got := xt.InterfaceOf(m.Get(xt))
 		want := getValues[i]
 		if diff := cmp.Diff(want, got, opts); diff != "" {
-			t.Errorf("KnownFields.Get(%d) mismatch (-want +got):\n%v", num, diff)
+			t.Errorf("Message.Get(%d) mismatch (-want +got):\n%v", xt.Number(), diff)
 		}
 	}
 
-	if n := fs.Len(); n != 20 {
-		t.Errorf("KnownFields.Len() = %v, want 0", n)
-	}
-	if n := ts.Len(); n != 20 {
-		t.Errorf("ExtensionFieldTypes.Len() = %v, want 20", n)
+	if n := m.Len(); n != 20 {
+		t.Errorf("Message.Len() = %v, want 0", n)
 	}
 
-	// Clear the field for all extension types.
+	// Clear all singular fields and truncate all repeated fields.
 	for _, xt := range extensionTypes[:len(extensionTypes)/2] {
-		fs.Clear(xt.Descriptor().Number())
+		m.Clear(xt)
 	}
-	for i, xt := range extensionTypes[len(extensionTypes)/2:] {
-		if i%2 == 0 {
-			fs.Clear(xt.Descriptor().Number())
-		} else {
-			fs.Get(xt.Descriptor().Number()).List().Truncate(0)
-		}
-	}
-	if n := fs.Len(); n != 0 {
-		t.Errorf("KnownFields.Len() = %v, want 0", n)
+	for _, xt := range extensionTypes[len(extensionTypes)/2:] {
+		m.Get(xt).List().Truncate(0)
 	}
-	if n := ts.Len(); n != 20 {
-		t.Errorf("ExtensionFieldTypes.Len() = %v, want 20", n)
+	if n := m.Len(); n != 10 {
+		t.Errorf("Message.Len() = %v, want 10", n)
 	}
 
-	// De-register all extension types.
-	for _, xt := range extensionTypes {
-		ts.Remove(xt)
-	}
-	if n := fs.Len(); n != 0 {
-		t.Errorf("KnownFields.Len() = %v, want 0", n)
+	// Clear all repeated fields.
+	for _, xt := range extensionTypes[len(extensionTypes)/2:] {
+		m.Clear(xt)
 	}
-	if n := ts.Len(); n != 0 {
-		t.Errorf("ExtensionFieldTypes.Len() = %v, want 0", n)
+	if n := m.Len(); n != 0 {
+		t.Errorf("Message.Len() = %v, want 0", n)
 	}
 }
 

+ 222 - 127
internal/impl/message.go

@@ -40,12 +40,17 @@ type MessageInfo struct {
 
 	oneofs map[pref.Name]*oneofInfo
 
+	getUnknown func(pointer) pref.RawFields
+	setUnknown func(pointer, pref.RawFields)
+
+	extensionMap func(pointer) *extensionMap
+
 	unknownFields   func(*messageDataType) pref.UnknownFields
 	extensionFields func(*messageDataType) pref.KnownFields
 	methods         piface.Methods
 
-	extensionOffset       offset
 	sizecacheOffset       offset
+	extensionOffset       offset
 	unknownOffset         offset
 	extensionFieldInfosMu sync.RWMutex
 	extensionFieldInfos   map[pref.ExtensionType]*extensionFieldInfo
@@ -106,23 +111,33 @@ func (mi *MessageInfo) initOnce() {
 	atomic.StoreUint32(&mi.initDone, 1)
 }
 
-var sizecacheType = reflect.TypeOf(int32(0))
+type (
+	SizeCache       = int32
+	UnknownFields   = []byte
+	ExtensionFields = map[int32]ExtensionField
+)
+
+var (
+	sizecacheType       = reflect.TypeOf(SizeCache(0))
+	unknownFieldsType   = reflect.TypeOf(UnknownFields(nil))
+	extensionFieldsType = reflect.TypeOf(ExtensionFields(nil))
+)
 
 func (mi *MessageInfo) makeMethods(t reflect.Type) {
-	mi.extensionOffset = invalidOffset
-	if fx, _ := t.FieldByName("XXX_InternalExtensions"); fx.Type == extType {
-		mi.extensionOffset = offsetOf(fx)
-	} else if fx, _ = t.FieldByName("XXX_extensions"); fx.Type == extType {
-		mi.extensionOffset = offsetOf(fx)
-	}
 	mi.sizecacheOffset = invalidOffset
 	if fx, _ := t.FieldByName("XXX_sizecache"); fx.Type == sizecacheType {
 		mi.sizecacheOffset = offsetOf(fx)
 	}
 	mi.unknownOffset = invalidOffset
-	if fx, _ := t.FieldByName("XXX_unrecognized"); fx.Type == bytesType {
+	if fx, _ := t.FieldByName("XXX_unrecognized"); fx.Type == unknownFieldsType {
 		mi.unknownOffset = offsetOf(fx)
 	}
+	mi.extensionOffset = invalidOffset
+	if fx, _ := t.FieldByName("XXX_InternalExtensions"); fx.Type == extensionFieldsType {
+		mi.extensionOffset = offsetOf(fx)
+	} else if fx, _ = t.FieldByName("XXX_extensions"); fx.Type == extensionFieldsType {
+		mi.extensionOffset = offsetOf(fx)
+	}
 	mi.methods.Flags = piface.MethodFlagDeterministicMarshal
 	mi.methods.MarshalAppend = mi.marshalAppend
 	mi.methods.Size = mi.size
@@ -231,22 +246,56 @@ func (mi *MessageInfo) makeKnownFieldsFunc(si structInfo) {
 }
 
 func (mi *MessageInfo) makeUnknownFieldsFunc(t reflect.Type) {
-	if f := makeLegacyUnknownFieldsFunc(t); f != nil {
-		mi.unknownFields = f
-		return
-	}
-	mi.unknownFields = func(*messageDataType) pref.UnknownFields {
-		return emptyUnknownFields{}
+	mi.unknownFields = makeLegacyUnknownFieldsFunc(t)
+
+	mi.getUnknown = func(pointer) pref.RawFields { return nil }
+	mi.setUnknown = func(pointer, pref.RawFields) { return }
+	fu, _ := t.FieldByName("XXX_unrecognized")
+	if fu.Type == unknownFieldsType {
+		fieldOffset := offsetOf(fu)
+		mi.getUnknown = func(p pointer) pref.RawFields {
+			if p.IsNil() {
+				return nil
+			}
+			rv := p.Apply(fieldOffset).AsValueOf(unknownFieldsType)
+			return pref.RawFields(*rv.Interface().(*[]byte))
+		}
+		mi.setUnknown = func(p pointer, b pref.RawFields) {
+			if p.IsNil() {
+				panic("invalid SetUnknown on nil Message")
+			}
+			rv := p.Apply(fieldOffset).AsValueOf(unknownFieldsType)
+			*rv.Interface().(*[]byte) = []byte(b)
+		}
+	} else {
+		mi.getUnknown = func(pointer) pref.RawFields {
+			return nil
+		}
+		mi.setUnknown = func(p pointer, _ pref.RawFields) {
+			if p.IsNil() {
+				panic("invalid SetUnknown on nil Message")
+			}
+		}
 	}
 }
 
 func (mi *MessageInfo) makeExtensionFieldsFunc(t reflect.Type) {
-	if f := makeLegacyExtensionFieldsFunc(t); f != nil {
-		mi.extensionFields = f
-		return
+	mi.extensionFields = makeLegacyExtensionFieldsFunc(t)
+
+	fx, _ := t.FieldByName("XXX_extensions")
+	if fx.Type != extensionFieldsType {
+		fx, _ = t.FieldByName("XXX_InternalExtensions")
 	}
-	mi.extensionFields = func(*messageDataType) pref.KnownFields {
-		return emptyExtensionFields{}
+	if fx.Type == extensionFieldsType {
+		fieldOffset := offsetOf(fx)
+		mi.extensionMap = func(p pointer) *extensionMap {
+			v := p.Apply(fieldOffset).AsValueOf(extensionFieldsType)
+			return (*extensionMap)(v.Interface().(*map[int32]ExtensionField))
+		}
+	} else {
+		mi.extensionMap = func(pointer) *extensionMap {
+			return (*extensionMap)(nil)
+		}
 	}
 }
 
@@ -295,21 +344,9 @@ type messageDataType struct {
 
 type messageReflectWrapper messageDataType
 
-// TODO: Remove this.
-func (m *messageReflectWrapper) Type() pref.MessageType {
-	return m.mi.PBType
-}
 func (m *messageReflectWrapper) Descriptor() pref.MessageDescriptor {
 	return m.mi.PBType.Descriptor()
 }
-func (m *messageReflectWrapper) KnownFields() pref.KnownFields {
-	m.mi.init()
-	return (*knownFields)(m)
-}
-func (m *messageReflectWrapper) UnknownFields() pref.UnknownFields {
-	m.mi.init()
-	return m.mi.unknownFields((*messageDataType)(m))
-}
 func (m *messageReflectWrapper) New() pref.Message {
 	return m.mi.PBType.New()
 }
@@ -323,134 +360,192 @@ func (m *messageReflectWrapper) ProtoUnwrap() interface{} {
 	return m.p.AsIfaceOf(m.mi.GoType.Elem())
 }
 
-var _ pvalue.Unwrapper = (*messageReflectWrapper)(nil)
-
-type messageIfaceWrapper messageDataType
-
-func (m *messageIfaceWrapper) ProtoReflect() pref.Message {
-	return (*messageReflectWrapper)(m)
-}
-func (m *messageIfaceWrapper) XXX_Methods() *piface.Methods {
-	// TODO: Consider not recreating this on every call.
+func (m *messageReflectWrapper) Len() (cnt int) {
 	m.mi.init()
-	return &piface.Methods{
-		Flags:         piface.MethodFlagDeterministicMarshal,
-		MarshalAppend: m.marshalAppend,
-		Size:          m.size,
+	for _, fi := range m.mi.fields {
+		if fi.has(m.p) {
+			cnt++
+		}
 	}
+	return cnt + m.mi.extensionMap(m.p).Len()
 }
-func (m *messageIfaceWrapper) ProtoUnwrap() interface{} {
-	return m.p.AsIfaceOf(m.mi.GoType.Elem())
+func (m *messageReflectWrapper) Range(f func(pref.FieldDescriptor, pref.Value) bool) {
+	m.mi.init()
+	for _, fi := range m.mi.fields {
+		if fi.has(m.p) {
+			if !f(fi.fieldDesc, fi.get(m.p)) {
+				return
+			}
+		}
+	}
+	m.mi.extensionMap(m.p).Range(f)
 }
-func (m *messageIfaceWrapper) marshalAppend(b []byte, _ pref.ProtoMessage, opts piface.MarshalOptions) ([]byte, error) {
-	return m.mi.marshalAppendPointer(b, m.p, newMarshalOptions(opts))
+func (m *messageReflectWrapper) Has(fd pref.FieldDescriptor) bool {
+	if fi, xt := m.checkField(fd); fi != nil {
+		return fi.has(m.p)
+	} else {
+		return m.mi.extensionMap(m.p).Has(xt)
+	}
 }
-func (m *messageIfaceWrapper) size(msg pref.ProtoMessage) (size int) {
-	return m.mi.sizePointer(m.p, 0)
+func (m *messageReflectWrapper) Clear(fd pref.FieldDescriptor) {
+	if fi, xt := m.checkField(fd); fi != nil {
+		fi.clear(m.p)
+	} else {
+		m.mi.extensionMap(m.p).Clear(xt)
+	}
 }
-
-type knownFields messageDataType
-
-func (fs *knownFields) Len() (cnt int) {
-	for _, fi := range fs.mi.fields {
-		if fi.has(fs.p) {
-			cnt++
-		}
+func (m *messageReflectWrapper) Get(fd pref.FieldDescriptor) pref.Value {
+	if fi, xt := m.checkField(fd); fi != nil {
+		return fi.get(m.p)
+	} else {
+		return m.mi.extensionMap(m.p).Get(xt)
 	}
-	return cnt + fs.extensionFields().Len()
 }
-func (fs *knownFields) Has(n pref.FieldNumber) bool {
-	if fi := fs.mi.fields[n]; fi != nil {
-		return fi.has(fs.p)
+func (m *messageReflectWrapper) Set(fd pref.FieldDescriptor, v pref.Value) {
+	if fi, xt := m.checkField(fd); fi != nil {
+		fi.set(m.p, v)
+	} else {
+		m.mi.extensionMap(m.p).Set(xt, v)
 	}
-	return fs.extensionFields().Has(n)
 }
-func (fs *knownFields) Get(n pref.FieldNumber) pref.Value {
-	if fi := fs.mi.fields[n]; fi != nil {
-		return fi.get(fs.p)
+func (m *messageReflectWrapper) Mutable(fd pref.FieldDescriptor) pref.Value {
+	if fi, xt := m.checkField(fd); fi != nil {
+		return fi.mutable(m.p)
+	} else {
+		return m.mi.extensionMap(m.p).Mutable(xt)
 	}
-	return fs.extensionFields().Get(n)
 }
-func (fs *knownFields) Set(n pref.FieldNumber, v pref.Value) {
-	if fi := fs.mi.fields[n]; fi != nil {
-		fi.set(fs.p, v)
-		return
+func (m *messageReflectWrapper) NewMessage(fd pref.FieldDescriptor) pref.Message {
+	if fi, xt := m.checkField(fd); fi != nil {
+		return fi.newMessage()
+	} else {
+		return xt.New().Message()
 	}
-	if fs.mi.PBType.Descriptor().ExtensionRanges().Has(n) {
-		fs.extensionFields().Set(n, v)
-		return
+}
+func (m *messageReflectWrapper) WhichOneof(od pref.OneofDescriptor) pref.FieldDescriptor {
+	m.mi.init()
+	if oi := m.mi.oneofs[od.Name()]; oi != nil && oi.oneofDesc == od {
+		return od.Fields().ByNumber(oi.which(m.p))
 	}
-	panic(fmt.Sprintf("invalid field: %d", n))
+	panic("invalid oneof descriptor")
 }
-func (fs *knownFields) Clear(n pref.FieldNumber) {
-	if fi := fs.mi.fields[n]; fi != nil {
-		fi.clear(fs.p)
-		return
+func (m *messageReflectWrapper) GetUnknown() pref.RawFields {
+	m.mi.init()
+	return m.mi.getUnknown(m.p)
+}
+func (m *messageReflectWrapper) SetUnknown(b pref.RawFields) {
+	m.mi.init()
+	m.mi.setUnknown(m.p, b)
+}
+
+// checkField verifies that the provided field descriptor is valid.
+// Exactly one of the returned values is populated.
+func (m *messageReflectWrapper) checkField(fd pref.FieldDescriptor) (*fieldInfo, pref.ExtensionType) {
+	m.mi.init()
+	if fi := m.mi.fields[fd.Number()]; fi != nil {
+		if fi.fieldDesc != fd {
+			panic("mismatching field descriptor")
+		}
+		return fi, nil
 	}
-	if fs.mi.PBType.Descriptor().ExtensionRanges().Has(n) {
-		fs.extensionFields().Clear(n)
-		return
+	if fd.IsExtension() {
+		if fd.ContainingMessage().FullName() != m.mi.PBType.FullName() {
+			// TODO: Should this be exact containing message descriptor match?
+			panic("mismatching containing message")
+		}
+		if !m.mi.PBType.ExtensionRanges().Has(fd.Number()) {
+			panic("invalid extension field")
+		}
+		return nil, fd.(pref.ExtensionType)
 	}
+	panic("invalid field descriptor")
 }
-func (fs *knownFields) WhichOneof(s pref.Name) pref.FieldNumber {
-	if oi := fs.mi.oneofs[s]; oi != nil {
-		return oi.which(fs.p)
+
+type extensionMap map[int32]ExtensionField
+
+func (m *extensionMap) Len() int {
+	if m != nil {
+		return len(*m)
 	}
 	return 0
 }
-func (fs *knownFields) Range(f func(pref.FieldNumber, pref.Value) bool) {
-	for n, fi := range fs.mi.fields {
-		if fi.has(fs.p) {
-			if !f(n, fi.get(fs.p)) {
+func (m *extensionMap) Range(f func(pref.FieldDescriptor, pref.Value) bool) {
+	if m != nil {
+		for _, x := range *m {
+			xt := x.GetType()
+			if !f(xt, xt.ValueOf(x.GetValue())) {
 				return
 			}
 		}
 	}
-	fs.extensionFields().Range(f)
 }
-func (fs *knownFields) NewMessage(n pref.FieldNumber) pref.Message {
-	if fi := fs.mi.fields[n]; fi != nil {
-		return fi.newMessage()
+func (m *extensionMap) Has(xt pref.ExtensionType) (ok bool) {
+	if m != nil {
+		_, ok = (*m)[int32(xt.Number())]
+	}
+	return ok
+}
+func (m *extensionMap) Clear(xt pref.ExtensionType) {
+	delete(*m, int32(xt.Number()))
+}
+func (m *extensionMap) Get(xt pref.ExtensionType) pref.Value {
+	if m != nil {
+		if x, ok := (*m)[int32(xt.Number())]; ok {
+			return xt.ValueOf(x.GetValue())
+		}
 	}
-	if fs.mi.PBType.Descriptor().ExtensionRanges().Has(n) {
-		return fs.extensionFields().NewMessage(n)
+	if !isComposite(xt) {
+		return defaultValueOf(xt)
 	}
-	panic(fmt.Sprintf("invalid field: %d", n))
+	return frozenValueOf(xt.New())
 }
-func (fs *knownFields) ExtensionTypes() pref.ExtensionFieldTypes {
-	return fs.extensionFields().ExtensionTypes()
+func (m *extensionMap) Set(xt pref.ExtensionType, v pref.Value) {
+	if *m == nil {
+		*m = make(map[int32]ExtensionField)
+	}
+	var x ExtensionField
+	x.SetType(xt)
+	x.SetEagerValue(xt.InterfaceOf(v))
+	(*m)[int32(xt.Number())] = x
 }
-func (fs *knownFields) extensionFields() pref.KnownFields {
-	return fs.mi.extensionFields((*messageDataType)(fs))
+func (m *extensionMap) Mutable(xt pref.ExtensionType) pref.Value {
+	if !isComposite(xt) {
+		panic("invalid Mutable on field with non-composite type")
+	}
+	if x, ok := (*m)[int32(xt.Number())]; ok {
+		return xt.ValueOf(x.GetValue())
+	}
+	v := xt.New()
+	m.Set(xt, v)
+	return v
 }
 
-type emptyUnknownFields struct{}
+func isComposite(fd pref.FieldDescriptor) bool {
+	return fd.Kind() == pref.MessageKind || fd.Kind() == pref.GroupKind || fd.IsList() || fd.IsMap()
+}
 
-func (emptyUnknownFields) Len() int                                          { return 0 }
-func (emptyUnknownFields) Get(pref.FieldNumber) pref.RawFields               { return nil }
-func (emptyUnknownFields) Set(pref.FieldNumber, pref.RawFields)              { return } // noop
-func (emptyUnknownFields) Range(func(pref.FieldNumber, pref.RawFields) bool) { return }
-func (emptyUnknownFields) IsSupported() bool                                 { return false }
+var _ pvalue.Unwrapper = (*messageReflectWrapper)(nil)
 
-type emptyExtensionFields struct{}
+type messageIfaceWrapper messageDataType
 
-func (emptyExtensionFields) Len() int                                      { return 0 }
-func (emptyExtensionFields) Has(pref.FieldNumber) bool                     { return false }
-func (emptyExtensionFields) Get(pref.FieldNumber) pref.Value               { return pref.Value{} }
-func (emptyExtensionFields) Set(pref.FieldNumber, pref.Value)              { panic("extensions not supported") }
-func (emptyExtensionFields) Clear(pref.FieldNumber)                        { return } // noop
-func (emptyExtensionFields) WhichOneof(pref.Name) pref.FieldNumber         { return 0 }
-func (emptyExtensionFields) Range(func(pref.FieldNumber, pref.Value) bool) { return }
-func (emptyExtensionFields) NewMessage(pref.FieldNumber) pref.Message {
-	panic("extensions not supported")
+func (m *messageIfaceWrapper) ProtoReflect() pref.Message {
+	return (*messageReflectWrapper)(m)
+}
+func (m *messageIfaceWrapper) XXX_Methods() *piface.Methods {
+	// TODO: Consider not recreating this on every call.
+	m.mi.init()
+	return &piface.Methods{
+		Flags:         piface.MethodFlagDeterministicMarshal,
+		MarshalAppend: m.marshalAppend,
+		Size:          m.size,
+	}
+}
+func (m *messageIfaceWrapper) ProtoUnwrap() interface{} {
+	return m.p.AsIfaceOf(m.mi.GoType.Elem())
+}
+func (m *messageIfaceWrapper) marshalAppend(b []byte, _ pref.ProtoMessage, opts piface.MarshalOptions) ([]byte, error) {
+	return m.mi.marshalAppendPointer(b, m.p, newMarshalOptions(opts))
+}
+func (m *messageIfaceWrapper) size(msg pref.ProtoMessage) (size int) {
+	return m.mi.sizePointer(m.p, 0)
 }
-func (emptyExtensionFields) ExtensionTypes() pref.ExtensionFieldTypes { return emptyExtensionTypes{} }
-
-type emptyExtensionTypes struct{}
-
-func (emptyExtensionTypes) Len() int                                     { return 0 }
-func (emptyExtensionTypes) Register(pref.ExtensionType)                  { panic("extensions not supported") }
-func (emptyExtensionTypes) Remove(pref.ExtensionType)                    { return } // noop
-func (emptyExtensionTypes) ByNumber(pref.FieldNumber) pref.ExtensionType { return nil }
-func (emptyExtensionTypes) ByName(pref.FullName) pref.ExtensionType      { return nil }
-func (emptyExtensionTypes) Range(func(pref.ExtensionType) bool)          { return }

+ 148 - 0
internal/impl/message_deprecated.go

@@ -0,0 +1,148 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package impl
+
+import (
+	"fmt"
+
+	pref "google.golang.org/protobuf/reflect/protoreflect"
+)
+
+// TODO: Remove this file.
+
+// TODO: Remove this.
+func (m *messageReflectWrapper) Type() pref.MessageType {
+	return m.mi.PBType
+}
+
+// TODO: Remove this.
+func (m *messageReflectWrapper) KnownFields() pref.KnownFields {
+	m.mi.init()
+	return (*knownFields)(m)
+}
+
+// TODO: Remove this.
+func (m *messageReflectWrapper) UnknownFields() pref.UnknownFields {
+	m.mi.init()
+	return m.mi.unknownFields((*messageDataType)(m))
+}
+
+// TODO: Remove this.
+type knownFields messageDataType
+
+func (fs *knownFields) Len() (cnt int) {
+	for _, fi := range fs.mi.fields {
+		if fi.has(fs.p) {
+			cnt++
+		}
+	}
+	return cnt + fs.extensionFields().Len()
+}
+func (fs *knownFields) Has(n pref.FieldNumber) bool {
+	if fi := fs.mi.fields[n]; fi != nil {
+		return fi.has(fs.p)
+	}
+	return fs.extensionFields().Has(n)
+}
+func (fs *knownFields) Get(n pref.FieldNumber) pref.Value {
+	if fi := fs.mi.fields[n]; fi != nil {
+		if !fi.has(fs.p) && isComposite(fi.fieldDesc) {
+			if fi.newMessage != nil {
+				return pref.Value{}
+			}
+			if !fs.p.IsNil() {
+				return fi.mutable(fs.p)
+			}
+		}
+		return fi.get(fs.p)
+	}
+	return fs.extensionFields().Get(n)
+}
+func (fs *knownFields) Set(n pref.FieldNumber, v pref.Value) {
+	if fi := fs.mi.fields[n]; fi != nil {
+		fi.set(fs.p, v)
+		return
+	}
+	if fs.mi.PBType.Descriptor().ExtensionRanges().Has(n) {
+		fs.extensionFields().Set(n, v)
+		return
+	}
+	panic(fmt.Sprintf("invalid field: %d", n))
+}
+func (fs *knownFields) Clear(n pref.FieldNumber) {
+	if fi := fs.mi.fields[n]; fi != nil {
+		fi.clear(fs.p)
+		return
+	}
+	if fs.mi.PBType.Descriptor().ExtensionRanges().Has(n) {
+		fs.extensionFields().Clear(n)
+		return
+	}
+}
+func (fs *knownFields) WhichOneof(s pref.Name) pref.FieldNumber {
+	if oi := fs.mi.oneofs[s]; oi != nil {
+		return oi.which(fs.p)
+	}
+	return 0
+}
+func (fs *knownFields) Range(f func(pref.FieldNumber, pref.Value) bool) {
+	for n, fi := range fs.mi.fields {
+		if fi.has(fs.p) {
+			if !f(n, fi.get(fs.p)) {
+				return
+			}
+		}
+	}
+	fs.extensionFields().Range(f)
+}
+func (fs *knownFields) NewMessage(n pref.FieldNumber) pref.Message {
+	if fi := fs.mi.fields[n]; fi != nil {
+		return fi.newMessage()
+	}
+	if fs.mi.PBType.Descriptor().ExtensionRanges().Has(n) {
+		return fs.extensionFields().NewMessage(n)
+	}
+	panic(fmt.Sprintf("invalid field: %d", n))
+}
+func (fs *knownFields) ExtensionTypes() pref.ExtensionFieldTypes {
+	return fs.extensionFields().ExtensionTypes()
+}
+func (fs *knownFields) extensionFields() pref.KnownFields {
+	return fs.mi.extensionFields((*messageDataType)(fs))
+}
+
+// TODO: Remove this.
+type emptyUnknownFields struct{}
+
+func (emptyUnknownFields) Len() int                                          { return 0 }
+func (emptyUnknownFields) Get(pref.FieldNumber) pref.RawFields               { return nil }
+func (emptyUnknownFields) Set(pref.FieldNumber, pref.RawFields)              { return } // noop
+func (emptyUnknownFields) Range(func(pref.FieldNumber, pref.RawFields) bool) { return }
+func (emptyUnknownFields) IsSupported() bool                                 { return false }
+
+// TODO: Remove this.
+type emptyExtensionFields struct{}
+
+func (emptyExtensionFields) Len() int                                      { return 0 }
+func (emptyExtensionFields) Has(pref.FieldNumber) bool                     { return false }
+func (emptyExtensionFields) Get(pref.FieldNumber) pref.Value               { return pref.Value{} }
+func (emptyExtensionFields) Set(pref.FieldNumber, pref.Value)              { panic("extensions not supported") }
+func (emptyExtensionFields) Clear(pref.FieldNumber)                        { return } // noop
+func (emptyExtensionFields) WhichOneof(pref.Name) pref.FieldNumber         { return 0 }
+func (emptyExtensionFields) Range(func(pref.FieldNumber, pref.Value) bool) { return }
+func (emptyExtensionFields) NewMessage(pref.FieldNumber) pref.Message {
+	panic("extensions not supported")
+}
+func (emptyExtensionFields) ExtensionTypes() pref.ExtensionFieldTypes { return emptyExtensionTypes{} }
+
+// TODO: Remove this.
+type emptyExtensionTypes struct{}
+
+func (emptyExtensionTypes) Len() int                                     { return 0 }
+func (emptyExtensionTypes) Register(pref.ExtensionType)                  { panic("extensions not supported") }
+func (emptyExtensionTypes) Remove(pref.ExtensionType)                    { return } // noop
+func (emptyExtensionTypes) ByNumber(pref.FieldNumber) pref.ExtensionType { return nil }
+func (emptyExtensionTypes) ByName(pref.FullName) pref.ExtensionType      { return nil }
+func (emptyExtensionTypes) Range(func(pref.ExtensionType) bool)          { return }

+ 171 - 45
internal/impl/message_field.go

@@ -16,11 +16,14 @@ import (
 )
 
 type fieldInfo struct {
+	fieldDesc pref.FieldDescriptor
+
 	// These fields are used for protobuf reflection support.
 	has        func(pointer) bool
+	clear      func(pointer)
 	get        func(pointer) pref.Value
 	set        func(pointer, pref.Value)
-	clear      func(pointer)
+	mutable    func(pointer) pref.Value
 	newMessage func() pref.Message
 
 	// These fields are used for fast-path functions.
@@ -44,13 +47,19 @@ func fieldInfoForOneof(fd pref.FieldDescriptor, fs reflect.StructField, ot refle
 		panic(fmt.Sprintf("invalid type: %v does not implement %v", ot, ft))
 	}
 	conv, _ := newConverter(ot.Field(0).Type, fd.Kind())
-	fieldOffset := offsetOf(fs)
+	var frozenEmpty pref.Value
+	if conv.NewMessage != nil {
+		frozenEmpty = pref.ValueOf(frozenMessage{conv.NewMessage()})
+	}
+
 	// TODO: Implement unsafe fast path?
+	fieldOffset := offsetOf(fs)
 	return fieldInfo{
 		// NOTE: The logic below intentionally assumes that oneof fields are
 		// well-formatted. That is, the oneof interface never contains a
 		// typed nil pointer to one of the wrapper structs.
 
+		fieldDesc: fd,
 		has: func(p pointer) bool {
 			if p.IsNil() {
 				return false
@@ -61,12 +70,25 @@ func fieldInfoForOneof(fd pref.FieldDescriptor, fs reflect.StructField, ot refle
 			}
 			return true
 		},
+		clear: func(p pointer) {
+			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
+			if rv.IsNil() || rv.Elem().Type().Elem() != ot {
+				return
+			}
+			rv.Set(reflect.Zero(rv.Type()))
+		},
 		get: func(p pointer) pref.Value {
 			if p.IsNil() {
+				if frozenEmpty.IsValid() {
+					return frozenEmpty
+				}
 				return defaultValueOf(fd)
 			}
 			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
 			if rv.IsNil() || rv.Elem().Type().Elem() != ot {
+				if frozenEmpty.IsValid() {
+					return frozenEmpty
+				}
 				return defaultValueOf(fd)
 			}
 			rv = rv.Elem().Elem().Field(0)
@@ -80,12 +102,19 @@ func fieldInfoForOneof(fd pref.FieldDescriptor, fs reflect.StructField, ot refle
 			rv = rv.Elem().Elem().Field(0)
 			rv.Set(conv.GoValueOf(v))
 		},
-		clear: func(p pointer) {
+		mutable: func(p pointer) pref.Value {
+			if conv.NewMessage == nil {
+				panic("invalid Mutable on field with non-composite type")
+			}
 			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
 			if rv.IsNil() || rv.Elem().Type().Elem() != ot {
-				return
+				rv.Set(reflect.New(ot))
 			}
-			rv.Set(reflect.Zero(rv.Type()))
+			rv = rv.Elem().Elem().Field(0)
+			if rv.IsNil() {
+				rv.Set(conv.GoValueOf(pref.ValueOf(conv.NewMessage())))
+			}
+			return conv.PBValueOf(rv)
 		},
 		newMessage: conv.NewMessage,
 		offset:     fieldOffset,
@@ -101,9 +130,14 @@ func fieldInfoForMap(fd pref.FieldDescriptor, fs reflect.StructField) fieldInfo
 	keyConv, _ := newConverter(ft.Key(), fd.MapKey().Kind())
 	valConv, _ := newConverter(ft.Elem(), fd.MapValue().Kind())
 	wiretag := wire.EncodeTag(fd.Number(), wireTypes[fd.Kind()])
-	fieldOffset := offsetOf(fs)
+	frozenEmpty := pref.ValueOf(frozenMap{
+		pvalue.MapOf(reflect.Zero(reflect.PtrTo(fs.Type)).Interface(), keyConv, valConv),
+	})
+
 	// TODO: Implement unsafe fast path?
+	fieldOffset := offsetOf(fs)
 	return fieldInfo{
+		fieldDesc: fd,
 		has: func(p pointer) bool {
 			if p.IsNil() {
 				return false
@@ -111,21 +145,27 @@ func fieldInfoForMap(fd pref.FieldDescriptor, fs reflect.StructField) fieldInfo
 			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
 			return rv.Len() > 0
 		},
+		clear: func(p pointer) {
+			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
+			rv.Set(reflect.Zero(rv.Type()))
+		},
 		get: func(p pointer) pref.Value {
 			if p.IsNil() {
-				v := reflect.Zero(reflect.PtrTo(fs.Type)).Interface()
-				return pref.ValueOf(pvalue.MapOf(v, keyConv, valConv))
+				return frozenEmpty
 			}
-			v := p.Apply(fieldOffset).AsIfaceOf(fs.Type)
-			return pref.ValueOf(pvalue.MapOf(v, keyConv, valConv))
+			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
+			if rv.IsNil() {
+				return frozenEmpty
+			}
+			return pref.ValueOf(pvalue.MapOf(rv.Addr().Interface(), keyConv, valConv))
 		},
 		set: func(p pointer, v pref.Value) {
 			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
 			rv.Set(reflect.ValueOf(v.Map().(pvalue.Unwrapper).ProtoUnwrap()).Elem())
 		},
-		clear: func(p pointer) {
-			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
-			rv.Set(reflect.Zero(rv.Type()))
+		mutable: func(p pointer) pref.Value {
+			v := p.Apply(fieldOffset).AsIfaceOf(fs.Type)
+			return pref.ValueOf(pvalue.MapOf(v, keyConv, valConv))
 		},
 		funcs:     encoderFuncsForMap(fd, ft),
 		offset:    fieldOffset,
@@ -147,9 +187,14 @@ func fieldInfoForList(fd pref.FieldDescriptor, fs reflect.StructField) fieldInfo
 	} else {
 		wiretag = wire.EncodeTag(fd.Number(), wire.BytesType)
 	}
-	fieldOffset := offsetOf(fs)
+	frozenEmpty := pref.ValueOf(frozenList{
+		pvalue.ListOf(reflect.Zero(reflect.PtrTo(fs.Type)).Interface(), conv),
+	})
+
 	// TODO: Implement unsafe fast path?
+	fieldOffset := offsetOf(fs)
 	return fieldInfo{
+		fieldDesc: fd,
 		has: func(p pointer) bool {
 			if p.IsNil() {
 				return false
@@ -157,21 +202,27 @@ func fieldInfoForList(fd pref.FieldDescriptor, fs reflect.StructField) fieldInfo
 			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
 			return rv.Len() > 0
 		},
+		clear: func(p pointer) {
+			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
+			rv.Set(reflect.Zero(rv.Type()))
+		},
 		get: func(p pointer) pref.Value {
 			if p.IsNil() {
-				v := reflect.Zero(reflect.PtrTo(fs.Type)).Interface()
-				return pref.ValueOf(pvalue.ListOf(v, conv))
+				return frozenEmpty
 			}
-			v := p.Apply(fieldOffset).AsIfaceOf(fs.Type)
-			return pref.ValueOf(pvalue.ListOf(v, conv))
+			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
+			if rv.Len() == 0 {
+				return frozenEmpty
+			}
+			return pref.ValueOf(pvalue.ListOf(rv.Addr().Interface(), conv))
 		},
 		set: func(p pointer, v pref.Value) {
 			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
 			rv.Set(reflect.ValueOf(v.List().(pvalue.Unwrapper).ProtoUnwrap()).Elem())
 		},
-		clear: func(p pointer) {
-			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
-			rv.Set(reflect.Zero(rv.Type()))
+		mutable: func(p pointer) pref.Value {
+			v := p.Apply(fieldOffset).AsIfaceOf(fs.Type)
+			return pref.ValueOf(pvalue.ListOf(v, conv))
 		},
 		funcs:     fieldCoder(fd, ft),
 		offset:    fieldOffset,
@@ -196,10 +247,12 @@ func fieldInfoForScalar(fd pref.FieldDescriptor, fs reflect.StructField) fieldIn
 		}
 	}
 	conv, _ := newConverter(ft, fd.Kind())
-	fieldOffset := offsetOf(fs)
 	wiretag := wire.EncodeTag(fd.Number(), wireTypes[fd.Kind()])
+
 	// TODO: Implement unsafe fast path?
+	fieldOffset := offsetOf(fs)
 	return fieldInfo{
+		fieldDesc: fd,
 		has: func(p pointer) bool {
 			if p.IsNil() {
 				return false
@@ -223,6 +276,10 @@ func fieldInfoForScalar(fd pref.FieldDescriptor, fs reflect.StructField) fieldIn
 				panic(fmt.Sprintf("invalid type: %v", rv.Type())) // should never happen
 			}
 		},
+		clear: func(p pointer) {
+			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
+			rv.Set(reflect.Zero(rv.Type()))
+		},
 		get: func(p pointer) pref.Value {
 			if p.IsNil() {
 				return defaultValueOf(fd)
@@ -251,10 +308,6 @@ func fieldInfoForScalar(fd pref.FieldDescriptor, fs reflect.StructField) fieldIn
 				rv.Set(emptyBytes)
 			}
 		},
-		clear: func(p pointer) {
-			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
-			rv.Set(reflect.Zero(rv.Type()))
-		},
 		funcs:     funcs,
 		offset:    fieldOffset,
 		isPointer: nullable,
@@ -266,10 +319,13 @@ func fieldInfoForScalar(fd pref.FieldDescriptor, fs reflect.StructField) fieldIn
 func fieldInfoForMessage(fd pref.FieldDescriptor, fs reflect.StructField) fieldInfo {
 	ft := fs.Type
 	conv, _ := newConverter(ft, fd.Kind())
-	fieldOffset := offsetOf(fs)
-	// TODO: Implement unsafe fast path?
 	wiretag := wire.EncodeTag(fd.Number(), wireTypes[fd.Kind()])
+	frozenEmpty := pref.ValueOf(frozenMessage{conv.NewMessage()})
+
+	// TODO: Implement unsafe fast path?
+	fieldOffset := offsetOf(fs)
 	return fieldInfo{
+		fieldDesc: fd,
 		has: func(p pointer) bool {
 			if p.IsNil() {
 				return false
@@ -277,13 +333,17 @@ func fieldInfoForMessage(fd pref.FieldDescriptor, fs reflect.StructField) fieldI
 			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
 			return !rv.IsNil()
 		},
+		clear: func(p pointer) {
+			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
+			rv.Set(reflect.Zero(rv.Type()))
+		},
 		get: func(p pointer) pref.Value {
 			if p.IsNil() {
-				return pref.Value{}
+				return frozenEmpty
 			}
 			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
 			if rv.IsNil() {
-				return pref.Value{}
+				return frozenEmpty
 			}
 			return conv.PBValueOf(rv)
 		},
@@ -294,9 +354,12 @@ func fieldInfoForMessage(fd pref.FieldDescriptor, fs reflect.StructField) fieldI
 				panic("invalid nil pointer")
 			}
 		},
-		clear: func(p pointer) {
+		mutable: func(p pointer) pref.Value {
 			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
-			rv.Set(reflect.Zero(rv.Type()))
+			if rv.IsNil() {
+				rv.Set(conv.GoValueOf(pref.ValueOf(conv.NewMessage())))
+			}
+			return conv.PBValueOf(rv)
 		},
 		newMessage: conv.NewMessage,
 		funcs:      fieldCoder(fd, ft),
@@ -307,25 +370,15 @@ func fieldInfoForMessage(fd pref.FieldDescriptor, fs reflect.StructField) fieldI
 	}
 }
 
-// defaultValueOf returns the default value for the field.
-func defaultValueOf(fd pref.FieldDescriptor) pref.Value {
-	if fd == nil {
-		return pref.Value{}
-	}
-	pv := fd.Default() // invalid Value for messages and repeated fields
-	if fd.Kind() == pref.BytesKind && pv.IsValid() && len(pv.Bytes()) > 0 {
-		return pref.ValueOf(append([]byte(nil), pv.Bytes()...)) // copy default bytes for safety
-	}
-	return pv
-}
-
 type oneofInfo struct {
-	which func(pointer) pref.FieldNumber
+	oneofDesc pref.OneofDescriptor
+	which     func(pointer) pref.FieldNumber
 }
 
 func makeOneofInfo(od pref.OneofDescriptor, fs reflect.StructField, wrappersByType map[reflect.Type]pref.FieldNumber) *oneofInfo {
 	fieldOffset := offsetOf(fs)
 	return &oneofInfo{
+		oneofDesc: od,
 		which: func(p pointer) pref.FieldNumber {
 			if p.IsNil() {
 				return 0
@@ -388,3 +441,76 @@ func newConverter(t reflect.Type, k pref.Kind) (conv pvalue.Converter, isLegacy
 	}
 	return pvalue.NewConverter(t, k), false
 }
+
+// defaultValueOf returns the default value for the field.
+func defaultValueOf(fd pref.FieldDescriptor) pref.Value {
+	if fd == nil {
+		return pref.Value{}
+	}
+	pv := fd.Default() // invalid Value for messages and repeated fields
+	if fd.Kind() == pref.BytesKind && pv.IsValid() && len(pv.Bytes()) > 0 {
+		return pref.ValueOf(append([]byte(nil), pv.Bytes()...)) // copy default bytes for safety
+	}
+	return pv
+}
+
+// frozenValueOf returns a frozen version of any composite value.
+func frozenValueOf(v pref.Value) pref.Value {
+	switch v := v.Interface().(type) {
+	case pref.Message:
+		if _, ok := v.(frozenMessage); !ok {
+			return pref.ValueOf(frozenMessage{v})
+		}
+	case pref.List:
+		if _, ok := v.(frozenList); !ok {
+			return pref.ValueOf(frozenList{v})
+		}
+	case pref.Map:
+		if _, ok := v.(frozenMap); !ok {
+			return pref.ValueOf(frozenMap{v})
+		}
+	}
+	return v
+}
+
+type frozenMessage struct{ pref.Message }
+
+func (m frozenMessage) ProtoReflect() pref.Message   { return m }
+func (m frozenMessage) Interface() pref.ProtoMessage { return m }
+func (m frozenMessage) Range(f func(pref.FieldDescriptor, pref.Value) bool) {
+	m.Message.Range(func(fd pref.FieldDescriptor, v pref.Value) bool {
+		return f(fd, frozenValueOf(v))
+	})
+}
+func (m frozenMessage) Get(fd pref.FieldDescriptor) pref.Value {
+	v := m.Message.Get(fd)
+	return frozenValueOf(v)
+}
+func (frozenMessage) Clear(pref.FieldDescriptor)              { panic("invalid on read-only Message") }
+func (frozenMessage) Set(pref.FieldDescriptor, pref.Value)    { panic("invalid on read-only Message") }
+func (frozenMessage) Mutable(pref.FieldDescriptor) pref.Value { panic("invalid on read-only Message") }
+func (frozenMessage) SetUnknown(pref.RawFields)               { panic("invalid on read-only Message") }
+
+type frozenList struct{ pref.List }
+
+func (ls frozenList) Get(i int) pref.Value {
+	v := ls.List.Get(i)
+	return frozenValueOf(v)
+}
+func (frozenList) Set(i int, v pref.Value) { panic("invalid on read-only List") }
+func (frozenList) Append(v pref.Value)     { panic("invalid on read-only List") }
+func (frozenList) Truncate(i int)          { panic("invalid on read-only List") }
+
+type frozenMap struct{ pref.Map }
+
+func (ms frozenMap) Get(k pref.MapKey) pref.Value {
+	v := ms.Map.Get(k)
+	return frozenValueOf(v)
+}
+func (ms frozenMap) Range(f func(pref.MapKey, pref.Value) bool) {
+	ms.Map.Range(func(k pref.MapKey, v pref.Value) bool {
+		return f(k, frozenValueOf(v))
+	})
+}
+func (frozenMap) Set(k pref.MapKey, v pref.Value) { panic("invalid n read-only Map") }
+func (frozenMap) Clear(k pref.MapKey)             { panic("invalid on read-only Map") }

+ 7 - 3
internal/impl/message_field_extension.go

@@ -10,10 +10,16 @@ import (
 	pref "google.golang.org/protobuf/reflect/protoreflect"
 )
 
+// TODO: Remove this file.
+
+var extType = reflect.TypeOf(map[int32]ExtensionField(nil))
+
 func makeLegacyExtensionFieldsFunc(t reflect.Type) func(p *messageDataType) pref.KnownFields {
 	f := makeLegacyExtensionMapFunc(t)
 	if f == nil {
-		return nil
+		return func(*messageDataType) pref.KnownFields {
+			return emptyExtensionFields{}
+		}
 	}
 	return func(p *messageDataType) pref.KnownFields {
 		if p.p.IsNil() {
@@ -23,8 +29,6 @@ func makeLegacyExtensionFieldsFunc(t reflect.Type) func(p *messageDataType) pref
 	}
 }
 
-var extType = reflect.TypeOf(map[int32]ExtensionField{})
-
 func makeLegacyExtensionMapFunc(t reflect.Type) func(*messageDataType) *legacyExtensionMap {
 	fx, _ := t.FieldByName("XXX_extensions")
 	if fx.Type != extType {

+ 5 - 1
internal/impl/message_field_unknown.go

@@ -12,12 +12,16 @@ import (
 	pref "google.golang.org/protobuf/reflect/protoreflect"
 )
 
+// TODO: Remove this file.
+
 var bytesType = reflect.TypeOf([]byte(nil))
 
 func makeLegacyUnknownFieldsFunc(t reflect.Type) func(p *messageDataType) pref.UnknownFields {
 	fu, ok := t.FieldByName("XXX_unrecognized")
 	if !ok || fu.Type != bytesType {
-		return nil
+		return func(*messageDataType) pref.UnknownFields {
+			return emptyUnknownFields{}
+		}
 	}
 	fieldOffset := offsetOf(fu)
 	return func(p *messageDataType) pref.UnknownFields {

+ 207 - 220
internal/impl/message_test.go

@@ -17,7 +17,6 @@ import (
 	pimpl "google.golang.org/protobuf/internal/impl"
 	ptype "google.golang.org/protobuf/internal/prototype"
 	scalar "google.golang.org/protobuf/internal/scalar"
-	pvalue "google.golang.org/protobuf/internal/value"
 	pref "google.golang.org/protobuf/reflect/protoreflect"
 	"google.golang.org/protobuf/reflect/prototype"
 
@@ -52,25 +51,31 @@ type (
 	// check for the presence of specific oneof member fields.
 	whichOneofs map[pref.Name]pref.FieldNumber
 	// apply messageOps on each specified message field
-	messageFields map[pref.FieldNumber]messageOps
+	messageFields        map[pref.FieldNumber]messageOps
+	messageFieldsMutable map[pref.FieldNumber]messageOps
 	// apply listOps on each specified list field
-	listFields map[pref.FieldNumber]listOps
+	listFields        map[pref.FieldNumber]listOps
+	listFieldsMutable map[pref.FieldNumber]listOps
 	// apply mapOps on each specified map fields
-	mapFields map[pref.FieldNumber]mapOps
+	mapFields        map[pref.FieldNumber]mapOps
+	mapFieldsMutable map[pref.FieldNumber]mapOps
 	// range through all fields and check that they match
 	rangeFields map[pref.FieldNumber]pref.Value
 )
 
-func (equalMessage) isMessageOp()  {}
-func (hasFields) isMessageOp()     {}
-func (getFields) isMessageOp()     {}
-func (setFields) isMessageOp()     {}
-func (clearFields) isMessageOp()   {}
-func (whichOneofs) isMessageOp()   {}
-func (messageFields) isMessageOp() {}
-func (listFields) isMessageOp()    {}
-func (mapFields) isMessageOp()     {}
-func (rangeFields) isMessageOp()   {}
+func (equalMessage) isMessageOp()         {}
+func (hasFields) isMessageOp()            {}
+func (getFields) isMessageOp()            {}
+func (setFields) isMessageOp()            {}
+func (clearFields) isMessageOp()          {}
+func (whichOneofs) isMessageOp()          {}
+func (messageFields) isMessageOp()        {}
+func (messageFieldsMutable) isMessageOp() {}
+func (listFields) isMessageOp()           {}
+func (listFieldsMutable) isMessageOp()    {}
+func (mapFields) isMessageOp()            {}
+func (mapFieldsMutable) isMessageOp()     {}
+func (rangeFields) isMessageOp()          {}
 
 // Test operations performed on a list.
 type (
@@ -221,27 +226,14 @@ var scalarProto2Type = pimpl.MessageInfo{GoType: reflect.TypeOf(new(ScalarProto2
 		},
 	}),
 	NewMessage: func() pref.Message {
-		return new(ScalarProto2)
+		return pref.ProtoMessage(new(ScalarProto2)).ProtoReflect()
 	},
 }}
 
-// TODO: Remove this.
-func (m *ScalarProto2) Type() pref.MessageType { return scalarProto2Type.PBType }
-func (m *ScalarProto2) Descriptor() pref.MessageDescriptor {
-	return scalarProto2Type.PBType.Descriptor()
-}
-func (m *ScalarProto2) KnownFields() pref.KnownFields {
-	return scalarProto2Type.MessageOf(m).KnownFields()
-}
-func (m *ScalarProto2) UnknownFields() pref.UnknownFields {
-	return scalarProto2Type.MessageOf(m).UnknownFields()
-}
-func (m *ScalarProto2) New() pref.Message            { return new(ScalarProto2) }
-func (m *ScalarProto2) Interface() pref.ProtoMessage { return m }
-func (m *ScalarProto2) ProtoReflect() pref.Message   { return m }
+func (m *ScalarProto2) ProtoReflect() pref.Message { return scalarProto2Type.MessageOf(m) }
 
 func TestScalarProto2(t *testing.T) {
-	testMessage(t, nil, &ScalarProto2{}, messageOps{
+	testMessage(t, nil, new(ScalarProto2).ProtoReflect(), messageOps{
 		hasFields{
 			1: false, 2: false, 3: false, 4: false, 5: false, 6: false, 7: false, 8: false, 9: false, 10: false, 11: false,
 			12: false, 13: false, 14: false, 15: false, 16: false, 17: false, 18: false, 19: false, 20: false, 21: false, 22: false,
@@ -258,16 +250,16 @@ func TestScalarProto2(t *testing.T) {
 			1: true, 2: true, 3: true, 4: true, 5: true, 6: true, 7: true, 8: true, 9: true, 10: true, 11: true,
 			12: true, 13: true, 14: true, 15: true, 16: true, 17: true, 18: true, 19: true, 20: true, 21: true, 22: true,
 		},
-		equalMessage{&ScalarProto2{
+		equalMessage{(&ScalarProto2{
 			new(bool), new(int32), new(int64), new(uint32), new(uint64), new(float32), new(float64), new(string), []byte{}, []byte{}, new(string),
 			new(MyBool), new(MyInt32), new(MyInt64), new(MyUint32), new(MyUint64), new(MyFloat32), new(MyFloat64), new(MyString), MyBytes{}, MyBytes{}, new(MyString),
-		}},
+		}).ProtoReflect()},
 		clearFields{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22},
-		equalMessage{&ScalarProto2{}},
+		equalMessage{new(ScalarProto2).ProtoReflect()},
 	})
 
 	// Test read-only operations on nil message.
-	testMessage(t, nil, (*ScalarProto2)(nil), messageOps{
+	testMessage(t, nil, (*ScalarProto2)(nil).ProtoReflect(), messageOps{
 		hasFields{
 			1: false, 2: false, 3: false, 4: false, 5: false, 6: false, 7: false, 8: false, 9: false, 10: false, 11: false,
 			12: false, 13: false, 14: false, 15: false, 16: false, 17: false, 18: false, 19: false, 20: false, 21: false, 22: false,
@@ -336,27 +328,14 @@ var scalarProto3Type = pimpl.MessageInfo{GoType: reflect.TypeOf(new(ScalarProto3
 		},
 	}),
 	NewMessage: func() pref.Message {
-		return new(ScalarProto3)
+		return pref.ProtoMessage(new(ScalarProto3)).ProtoReflect()
 	},
 }}
 
-// TODO: Remove this.
-func (m *ScalarProto3) Type() pref.MessageType { return scalarProto3Type.PBType }
-func (m *ScalarProto3) Descriptor() pref.MessageDescriptor {
-	return scalarProto3Type.PBType.Descriptor()
-}
-func (m *ScalarProto3) KnownFields() pref.KnownFields {
-	return scalarProto3Type.MessageOf(m).KnownFields()
-}
-func (m *ScalarProto3) UnknownFields() pref.UnknownFields {
-	return scalarProto3Type.MessageOf(m).UnknownFields()
-}
-func (m *ScalarProto3) New() pref.Message            { return new(ScalarProto3) }
-func (m *ScalarProto3) Interface() pref.ProtoMessage { return m }
-func (m *ScalarProto3) ProtoReflect() pref.Message   { return m }
+func (m *ScalarProto3) ProtoReflect() pref.Message { return scalarProto3Type.MessageOf(m) }
 
 func TestScalarProto3(t *testing.T) {
-	testMessage(t, nil, &ScalarProto3{}, messageOps{
+	testMessage(t, nil, new(ScalarProto3).ProtoReflect(), messageOps{
 		hasFields{
 			1: false, 2: false, 3: false, 4: false, 5: false, 6: false, 7: false, 8: false, 9: false, 10: false, 11: false,
 			12: false, 13: false, 14: false, 15: false, 16: false, 17: false, 18: false, 19: false, 20: false, 21: false, 22: false,
@@ -373,7 +352,7 @@ func TestScalarProto3(t *testing.T) {
 			1: false, 2: false, 3: false, 4: false, 5: false, 6: false, 7: false, 8: false, 9: false, 10: false, 11: false,
 			12: false, 13: false, 14: false, 15: false, 16: false, 17: false, 18: false, 19: false, 20: false, 21: false, 22: false,
 		},
-		equalMessage{&ScalarProto3{}},
+		equalMessage{new(ScalarProto3).ProtoReflect()},
 		setFields{
 			1: V(bool(true)), 2: V(int32(2)), 3: V(int64(3)), 4: V(uint32(4)), 5: V(uint64(5)), 6: V(float32(6)), 7: V(float64(7)), 8: V(string("8")), 9: V(string("9")), 10: V([]byte("10")), 11: V([]byte("11")),
 			12: V(bool(true)), 13: V(int32(13)), 14: V(int64(14)), 15: V(uint32(15)), 16: V(uint64(16)), 17: V(float32(17)), 18: V(float64(18)), 19: V(string("19")), 20: V(string("20")), 21: V([]byte("21")), 22: V([]byte("22")),
@@ -382,10 +361,10 @@ func TestScalarProto3(t *testing.T) {
 			1: true, 2: true, 3: true, 4: true, 5: true, 6: true, 7: true, 8: true, 9: true, 10: true, 11: true,
 			12: true, 13: true, 14: true, 15: true, 16: true, 17: true, 18: true, 19: true, 20: true, 21: true, 22: true,
 		},
-		equalMessage{&ScalarProto3{
+		equalMessage{(&ScalarProto3{
 			true, 2, 3, 4, 5, 6, 7, "8", []byte("9"), []byte("10"), "11",
 			true, 13, 14, 15, 16, 17, 18, "19", []byte("20"), []byte("21"), "22",
-		}},
+		}).ProtoReflect()},
 		setFields{
 			2: V(int32(-2)), 3: V(int64(-3)), 6: V(float32(math.Inf(-1))), 7: V(float64(math.NaN())),
 		},
@@ -393,7 +372,7 @@ func TestScalarProto3(t *testing.T) {
 			2: true, 3: true, 6: true, 7: true,
 		},
 		clearFields{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22},
-		equalMessage{&ScalarProto3{}},
+		equalMessage{new(ScalarProto3).ProtoReflect()},
 
 		// Verify that -0 triggers proper Has behavior.
 		hasFields{6: false, 7: false},
@@ -402,7 +381,7 @@ func TestScalarProto3(t *testing.T) {
 	})
 
 	// Test read-only operations on nil message.
-	testMessage(t, nil, (*ScalarProto3)(nil), messageOps{
+	testMessage(t, nil, (*ScalarProto3)(nil).ProtoReflect(), messageOps{
 		hasFields{
 			1: false, 2: false, 3: false, 4: false, 5: false, 6: false, 7: false, 8: false, 9: false, 10: false, 11: false,
 			12: false, 13: false, 14: false, 15: false, 16: false, 17: false, 18: false, 19: false, 20: false, 21: false, 22: false,
@@ -467,28 +446,15 @@ var listScalarsType = pimpl.MessageInfo{GoType: reflect.TypeOf(new(ListScalars))
 		},
 	}),
 	NewMessage: func() pref.Message {
-		return new(ListScalars)
+		return pref.ProtoMessage(new(ListScalars)).ProtoReflect()
 	},
 }}
 
-// TODO: Remove this.
-func (m *ListScalars) Type() pref.MessageType             { return listScalarsType.PBType }
-func (m *ListScalars) Descriptor() pref.MessageDescriptor { return listScalarsType.PBType.Descriptor() }
-func (m *ListScalars) KnownFields() pref.KnownFields {
-	return listScalarsType.MessageOf(m).KnownFields()
-}
-func (m *ListScalars) UnknownFields() pref.UnknownFields {
-	return listScalarsType.MessageOf(m).UnknownFields()
-}
-func (m *ListScalars) New() pref.Message            { return new(ListScalars) }
-func (m *ListScalars) Interface() pref.ProtoMessage { return m }
-func (m *ListScalars) ProtoReflect() pref.Message   { return m }
+func (m *ListScalars) ProtoReflect() pref.Message { return listScalarsType.MessageOf(m) }
 
 func TestListScalars(t *testing.T) {
-	empty := &ListScalars{}
-	emptyFS := empty.KnownFields()
-
-	want := &ListScalars{
+	empty := new(ListScalars).ProtoReflect()
+	want := (&ListScalars{
 		Bools:    []bool{true, false, true},
 		Int32s:   []int32{2, math.MinInt32, math.MaxInt32},
 		Int64s:   []int64{3, math.MinInt64, math.MaxInt64},
@@ -510,19 +476,18 @@ func TestListScalars(t *testing.T) {
 		MyStrings4: ListBytes{[]byte("17"), nil, []byte("seventeen")},
 		MyBytes3:   ListBytes{[]byte("18"), nil, []byte("eighteen")},
 		MyBytes4:   ListStrings{"19", "", "nineteen"},
-	}
-	wantFS := want.KnownFields()
+	}).ProtoReflect()
 
-	testMessage(t, nil, &ListScalars{}, messageOps{
+	testMessage(t, nil, new(ListScalars).ProtoReflect(), messageOps{
 		hasFields{1: false, 2: false, 3: false, 4: false, 5: false, 6: false, 7: false, 8: false, 9: false, 10: false, 11: false, 12: false, 13: false, 14: false, 15: false, 16: false, 17: false, 18: false, 19: false},
-		getFields{1: emptyFS.Get(1), 3: emptyFS.Get(3), 5: emptyFS.Get(5), 7: emptyFS.Get(7), 9: emptyFS.Get(9), 11: emptyFS.Get(11), 13: emptyFS.Get(13), 15: emptyFS.Get(15), 17: emptyFS.Get(17), 19: emptyFS.Get(19)},
-		setFields{1: wantFS.Get(1), 3: wantFS.Get(3), 5: wantFS.Get(5), 7: wantFS.Get(7), 9: wantFS.Get(9), 11: wantFS.Get(11), 13: wantFS.Get(13), 15: wantFS.Get(15), 17: wantFS.Get(17), 19: wantFS.Get(19)},
-		listFields{
+		getFields{1: getField(empty, 1), 3: getField(empty, 3), 5: getField(empty, 5), 7: getField(empty, 7), 9: getField(empty, 9), 11: getField(empty, 11), 13: getField(empty, 13), 15: getField(empty, 15), 17: getField(empty, 17), 19: getField(empty, 19)},
+		setFields{1: getField(want, 1), 3: getField(want, 3), 5: getField(want, 5), 7: getField(want, 7), 9: getField(want, 9), 11: getField(want, 11), 13: getField(want, 13), 15: getField(want, 15), 17: getField(want, 17), 19: getField(want, 19)},
+		listFieldsMutable{
 			2: {
 				lenList(0),
 				appendList{V(int32(2)), V(int32(math.MinInt32)), V(int32(math.MaxInt32))},
 				getList{0: V(int32(2)), 1: V(int32(math.MinInt32)), 2: V(int32(math.MaxInt32))},
-				equalList{wantFS.Get(2).List()},
+				equalList{getField(want, 2).List()},
 			},
 			4: {
 				appendList{V(uint32(0)), V(uint32(0)), V(uint32(0))},
@@ -531,39 +496,39 @@ func TestListScalars(t *testing.T) {
 			},
 			6: {
 				appendList{V(float32(6)), V(float32(math.SmallestNonzeroFloat32)), V(float32(math.NaN())), V(float32(math.MaxFloat32))},
-				equalList{wantFS.Get(6).List()},
+				equalList{getField(want, 6).List()},
 			},
 			8: {
 				appendList{V(""), V(""), V(""), V(""), V(""), V("")},
 				lenList(6),
 				setList{0: V("8"), 2: V("eight")},
 				truncList(3),
-				equalList{wantFS.Get(8).List()},
+				equalList{getField(want, 8).List()},
 			},
 			10: {
 				appendList{V([]byte(nil)), V([]byte(nil))},
 				setList{0: V([]byte("10"))},
 				appendList{V([]byte("wrong"))},
 				setList{2: V([]byte("ten"))},
-				equalList{wantFS.Get(10).List()},
+				equalList{getField(want, 10).List()},
 			},
 			12: {
 				appendList{V("12"), V("wrong"), V("twelve")},
 				setList{1: V("")},
-				equalList{wantFS.Get(12).List()},
+				equalList{getField(want, 12).List()},
 			},
 			14: {
 				appendList{V([]byte("14")), V([]byte(nil)), V([]byte("fourteen"))},
-				equalList{wantFS.Get(14).List()},
+				equalList{getField(want, 14).List()},
 			},
 			16: {
 				appendList{V("16"), V(""), V("sixteen"), V("extra")},
 				truncList(3),
-				equalList{wantFS.Get(16).List()},
+				equalList{getField(want, 16).List()},
 			},
 			18: {
 				appendList{V([]byte("18")), V([]byte(nil)), V([]byte("eighteen"))},
-				equalList{wantFS.Get(18).List()},
+				equalList{getField(want, 18).List()},
 			},
 		},
 		hasFields{1: true, 2: true, 3: true, 4: true, 5: true, 6: true, 7: true, 8: true, 9: true, 10: true, 11: true, 12: true, 13: true, 14: true, 15: true, 16: true, 17: true, 18: true, 19: true},
@@ -573,7 +538,7 @@ func TestListScalars(t *testing.T) {
 	})
 
 	// Test read-only operations on nil message.
-	testMessage(t, nil, (*ListScalars)(nil), messageOps{
+	testMessage(t, nil, (*ListScalars)(nil).ProtoReflect(), messageOps{
 		hasFields{1: false, 2: false, 3: false, 4: false, 5: false, 6: false, 7: false, 8: false, 9: false, 10: false, 11: false, 12: false, 13: false, 14: false, 15: false, 16: false, 17: false, 18: false, 19: false},
 		listFields{2: {lenList(0)}, 4: {lenList(0)}, 6: {lenList(0)}, 8: {lenList(0)}, 10: {lenList(0)}, 12: {lenList(0)}, 14: {lenList(0)}, 16: {lenList(0)}, 18: {lenList(0)}},
 	})
@@ -665,28 +630,15 @@ var mapScalarsType = pimpl.MessageInfo{GoType: reflect.TypeOf(new(MapScalars)),
 		},
 	}),
 	NewMessage: func() pref.Message {
-		return new(MapScalars)
+		return pref.ProtoMessage(new(MapScalars)).ProtoReflect()
 	},
 }}
 
-// TODO: Remove this.
-func (m *MapScalars) Type() pref.MessageType             { return mapScalarsType.PBType }
-func (m *MapScalars) Descriptor() pref.MessageDescriptor { return mapScalarsType.PBType.Descriptor() }
-func (m *MapScalars) KnownFields() pref.KnownFields {
-	return mapScalarsType.MessageOf(m).KnownFields()
-}
-func (m *MapScalars) UnknownFields() pref.UnknownFields {
-	return mapScalarsType.MessageOf(m).UnknownFields()
-}
-func (m *MapScalars) New() pref.Message            { return new(MapScalars) }
-func (m *MapScalars) Interface() pref.ProtoMessage { return m }
-func (m *MapScalars) ProtoReflect() pref.Message   { return m }
+func (m *MapScalars) ProtoReflect() pref.Message { return mapScalarsType.MessageOf(m) }
 
 func TestMapScalars(t *testing.T) {
-	empty := &MapScalars{}
-	emptyFS := empty.KnownFields()
-
-	want := &MapScalars{
+	empty := new(MapScalars).ProtoReflect()
+	want := (&MapScalars{
 		KeyBools:   map[bool]string{true: "true", false: "false"},
 		KeyInt32s:  map[int32]string{0: "zero", -1: "one", 2: "two"},
 		KeyInt64s:  map[int64]string{0: "zero", -10: "ten", 20: "twenty"},
@@ -715,14 +667,13 @@ func TestMapScalars(t *testing.T) {
 		MyStrings4: MapBytes{"s1": []byte("s1"), "s2": []byte("s2")},
 		MyBytes3:   MapBytes{"s1": []byte("s1"), "s2": []byte("s2")},
 		MyBytes4:   MapStrings{"s1": "s1", "s2": "s2"},
-	}
-	wantFS := want.KnownFields()
+	}).ProtoReflect()
 
-	testMessage(t, nil, &MapScalars{}, messageOps{
+	testMessage(t, nil, new(MapScalars).ProtoReflect(), messageOps{
 		hasFields{1: false, 2: false, 3: false, 4: false, 5: false, 6: false, 7: false, 8: false, 9: false, 10: false, 11: false, 12: false, 13: false, 14: false, 15: false, 16: false, 17: false, 18: false, 19: false, 20: false, 21: false, 22: false, 23: false, 24: false, 25: false},
-		getFields{1: emptyFS.Get(1), 3: emptyFS.Get(3), 5: emptyFS.Get(5), 7: emptyFS.Get(7), 9: emptyFS.Get(9), 11: emptyFS.Get(11), 13: emptyFS.Get(13), 15: emptyFS.Get(15), 17: emptyFS.Get(17), 19: emptyFS.Get(19), 21: emptyFS.Get(21), 23: emptyFS.Get(23), 25: emptyFS.Get(25)},
-		setFields{1: wantFS.Get(1), 3: wantFS.Get(3), 5: wantFS.Get(5), 7: wantFS.Get(7), 9: wantFS.Get(9), 11: wantFS.Get(11), 13: wantFS.Get(13), 15: wantFS.Get(15), 17: wantFS.Get(17), 19: wantFS.Get(19), 21: wantFS.Get(21), 23: wantFS.Get(23), 25: wantFS.Get(25)},
-		mapFields{
+		getFields{1: getField(empty, 1), 3: getField(empty, 3), 5: getField(empty, 5), 7: getField(empty, 7), 9: getField(empty, 9), 11: getField(empty, 11), 13: getField(empty, 13), 15: getField(empty, 15), 17: getField(empty, 17), 19: getField(empty, 19), 21: getField(empty, 21), 23: getField(empty, 23), 25: getField(empty, 25)},
+		setFields{1: getField(want, 1), 3: getField(want, 3), 5: getField(want, 5), 7: getField(want, 7), 9: getField(want, 9), 11: getField(want, 11), 13: getField(want, 13), 15: getField(want, 15), 17: getField(want, 17), 19: getField(want, 19), 21: getField(want, 21), 23: getField(want, 23), 25: getField(want, 25)},
+		mapFieldsMutable{
 			2: {
 				lenMap(0),
 				hasMap{int32(0): false, int32(-1): false, int32(2): false},
@@ -738,7 +689,7 @@ func TestMapScalars(t *testing.T) {
 			},
 			4: {
 				setMap{uint32(0): V("zero"), uint32(1): V("one"), uint32(2): V("two")},
-				equalMap{wantFS.Get(4).Map()},
+				equalMap{getField(want, 4).Map()},
 			},
 			6: {
 				clearMap{"noexist"},
@@ -749,13 +700,13 @@ func TestMapScalars(t *testing.T) {
 				clearMap{"extra", "noexist"},
 			},
 			8: {
-				equalMap{emptyFS.Get(8).Map()},
+				equalMap{getField(empty, 8).Map()},
 				setMap{"one": V(int32(1)), "two": V(int32(2)), "three": V(int32(3))},
 			},
 			10: {
 				setMap{"0x00": V(uint32(0x00)), "0xff": V(uint32(0xff)), "0xdead": V(uint32(0xdead))},
 				lenMap(3),
-				equalMap{wantFS.Get(10).Map()},
+				equalMap{getField(want, 10).Map()},
 				getMap{"0x00": V(uint32(0x00)), "0xff": V(uint32(0xff)), "0xdead": V(uint32(0xdead)), "0xdeadbeef": V(nil)},
 			},
 			12: {
@@ -764,12 +715,12 @@ func TestMapScalars(t *testing.T) {
 				rangeMap{"nan": V(float32(math.NaN())), "pi": V(float32(math.Pi))},
 			},
 			14: {
-				equalMap{emptyFS.Get(14).Map()},
+				equalMap{getField(empty, 14).Map()},
 				setMap{"s1": V("s1"), "s2": V("s2")},
 			},
 			16: {
 				setMap{"s1": V([]byte("s1")), "s2": V([]byte("s2"))},
-				equalMap{wantFS.Get(16).Map()},
+				equalMap{getField(want, 16).Map()},
 			},
 			18: {
 				hasMap{"s1": false, "s2": false, "s3": false},
@@ -777,7 +728,7 @@ func TestMapScalars(t *testing.T) {
 				hasMap{"s1": true, "s2": true, "s3": false},
 			},
 			20: {
-				equalMap{emptyFS.Get(20).Map()},
+				equalMap{getField(empty, 20).Map()},
 				setMap{"s1": V([]byte("s1")), "s2": V([]byte("s2"))},
 			},
 			22: {
@@ -788,7 +739,7 @@ func TestMapScalars(t *testing.T) {
 			},
 			24: {
 				setMap{"s1": V([]byte("s1")), "s2": V([]byte("s2"))},
-				equalMap{wantFS.Get(24).Map()},
+				equalMap{getField(want, 24).Map()},
 			},
 		},
 		hasFields{1: true, 2: true, 3: true, 4: true, 5: true, 6: true, 7: true, 8: true, 9: true, 10: true, 11: true, 12: true, 13: true, 14: true, 15: true, 16: true, 17: true, 18: true, 19: true, 20: true, 21: true, 22: true, 23: true, 24: true, 25: true},
@@ -798,7 +749,7 @@ func TestMapScalars(t *testing.T) {
 	})
 
 	// Test read-only operations on nil message.
-	testMessage(t, nil, (*MapScalars)(nil), messageOps{
+	testMessage(t, nil, (*MapScalars)(nil).ProtoReflect(), messageOps{
 		hasFields{1: false, 2: false, 3: false, 4: false, 5: false, 6: false, 7: false, 8: false, 9: false, 10: false, 11: false, 12: false, 13: false, 14: false, 15: false, 16: false, 17: false, 18: false, 19: false, 20: false, 21: false, 22: false, 23: false, 24: false, 25: false},
 		mapFields{2: {lenMap(0)}, 4: {lenMap(0)}, 6: {lenMap(0)}, 8: {lenMap(0)}, 10: {lenMap(0)}, 12: {lenMap(0)}, 14: {lenMap(0)}, 16: {lenMap(0)}, 18: {lenMap(0)}, 20: {lenMap(0)}, 22: {lenMap(0)}, 24: {lenMap(0)}},
 	})
@@ -830,24 +781,11 @@ var oneofScalarsType = pimpl.MessageInfo{GoType: reflect.TypeOf(new(OneofScalars
 		Oneofs: []ptype.Oneof{{Name: "union"}},
 	}),
 	NewMessage: func() pref.Message {
-		return new(OneofScalars)
+		return pref.ProtoMessage(new(OneofScalars)).ProtoReflect()
 	},
 }}
 
-// TODO: Remove this.
-func (m *OneofScalars) Type() pref.MessageType { return oneofScalarsType.PBType }
-func (m *OneofScalars) Descriptor() pref.MessageDescriptor {
-	return oneofScalarsType.PBType.Descriptor()
-}
-func (m *OneofScalars) KnownFields() pref.KnownFields {
-	return oneofScalarsType.MessageOf(m).KnownFields()
-}
-func (m *OneofScalars) UnknownFields() pref.UnknownFields {
-	return oneofScalarsType.MessageOf(m).UnknownFields()
-}
-func (m *OneofScalars) New() pref.Message            { return new(OneofScalars) }
-func (m *OneofScalars) Interface() pref.ProtoMessage { return m }
-func (m *OneofScalars) ProtoReflect() pref.Message   { return m }
+func (m *OneofScalars) ProtoReflect() pref.Message { return oneofScalarsType.MessageOf(m) }
 
 func (*OneofScalars) XXX_OneofWrappers() []interface{} {
 	return []interface{}{
@@ -942,41 +880,41 @@ func TestOneofs(t *testing.T) {
 	want12 := &OneofScalars{Union: &OneofScalars_BytesA{string("120")}}
 	want13 := &OneofScalars{Union: &OneofScalars_BytesB{MyBytes("130")}}
 
-	testMessage(t, nil, &OneofScalars{}, messageOps{
+	testMessage(t, nil, new(OneofScalars).ProtoReflect(), messageOps{
 		hasFields{1: false, 2: false, 3: false, 4: false, 5: false, 6: false, 7: false, 8: false, 9: false, 10: false, 11: false, 12: false, 13: false},
 		getFields{1: V(bool(true)), 2: V(int32(2)), 3: V(int64(3)), 4: V(uint32(4)), 5: V(uint64(5)), 6: V(float32(6)), 7: V(float64(7)), 8: V(string("8")), 9: V(string("9")), 10: V(string("10")), 11: V([]byte("11")), 12: V([]byte("12")), 13: V([]byte("13"))},
-		whichOneofs{"union": 0, "Union": 0},
+		whichOneofs{"union": 0},
 
-		setFields{1: V(bool(true))}, hasFields{1: true}, equalMessage{want1},
-		setFields{2: V(int32(20))}, hasFields{2: true}, equalMessage{want2},
-		setFields{3: V(int64(30))}, hasFields{3: true}, equalMessage{want3},
-		setFields{4: V(uint32(40))}, hasFields{4: true}, equalMessage{want4},
-		setFields{5: V(uint64(50))}, hasFields{5: true}, equalMessage{want5},
-		setFields{6: V(float32(60))}, hasFields{6: true}, equalMessage{want6},
-		setFields{7: V(float64(70))}, hasFields{7: true}, equalMessage{want7},
+		setFields{1: V(bool(true))}, hasFields{1: true}, equalMessage{want1.ProtoReflect()},
+		setFields{2: V(int32(20))}, hasFields{2: true}, equalMessage{want2.ProtoReflect()},
+		setFields{3: V(int64(30))}, hasFields{3: true}, equalMessage{want3.ProtoReflect()},
+		setFields{4: V(uint32(40))}, hasFields{4: true}, equalMessage{want4.ProtoReflect()},
+		setFields{5: V(uint64(50))}, hasFields{5: true}, equalMessage{want5.ProtoReflect()},
+		setFields{6: V(float32(60))}, hasFields{6: true}, equalMessage{want6.ProtoReflect()},
+		setFields{7: V(float64(70))}, hasFields{7: true}, equalMessage{want7.ProtoReflect()},
 
 		hasFields{1: false, 2: false, 3: false, 4: false, 5: false, 6: false, 7: true, 8: false, 9: false, 10: false, 11: false, 12: false, 13: false},
-		whichOneofs{"union": 7, "Union": 0},
+		whichOneofs{"union": 7},
 
-		setFields{8: V(string("80"))}, hasFields{8: true}, equalMessage{want8},
-		setFields{9: V(string("90"))}, hasFields{9: true}, equalMessage{want9},
-		setFields{10: V(string("100"))}, hasFields{10: true}, equalMessage{want10},
-		setFields{11: V([]byte("110"))}, hasFields{11: true}, equalMessage{want11},
-		setFields{12: V([]byte("120"))}, hasFields{12: true}, equalMessage{want12},
-		setFields{13: V([]byte("130"))}, hasFields{13: true}, equalMessage{want13},
+		setFields{8: V(string("80"))}, hasFields{8: true}, equalMessage{want8.ProtoReflect()},
+		setFields{9: V(string("90"))}, hasFields{9: true}, equalMessage{want9.ProtoReflect()},
+		setFields{10: V(string("100"))}, hasFields{10: true}, equalMessage{want10.ProtoReflect()},
+		setFields{11: V([]byte("110"))}, hasFields{11: true}, equalMessage{want11.ProtoReflect()},
+		setFields{12: V([]byte("120"))}, hasFields{12: true}, equalMessage{want12.ProtoReflect()},
+		setFields{13: V([]byte("130"))}, hasFields{13: true}, equalMessage{want13.ProtoReflect()},
 
 		hasFields{1: false, 2: false, 3: false, 4: false, 5: false, 6: false, 7: false, 8: false, 9: false, 10: false, 11: false, 12: false, 13: true},
 		getFields{1: V(bool(true)), 2: V(int32(2)), 3: V(int64(3)), 4: V(uint32(4)), 5: V(uint64(5)), 6: V(float32(6)), 7: V(float64(7)), 8: V(string("8")), 9: V(string("9")), 10: V(string("10")), 11: V([]byte("11")), 12: V([]byte("12")), 13: V([]byte("130"))},
 		clearFields{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
-		whichOneofs{"union": 13, "Union": 0},
-		equalMessage{want13},
+		whichOneofs{"union": 13},
+		equalMessage{want13.ProtoReflect()},
 		clearFields{13},
-		whichOneofs{"union": 0, "Union": 0},
-		equalMessage{empty},
+		whichOneofs{"union": 0},
+		equalMessage{empty.ProtoReflect()},
 	})
 
 	// Test read-only operations on nil message.
-	testMessage(t, nil, (*OneofScalars)(nil), messageOps{
+	testMessage(t, nil, (*OneofScalars)(nil).ProtoReflect(), messageOps{
 		hasFields{1: false, 2: false, 3: false, 4: false, 5: false, 6: false, 7: false, 8: false, 9: false, 10: false, 11: false, 12: false, 13: false},
 		getFields{1: V(bool(true)), 2: V(int32(2)), 3: V(int64(3)), 4: V(uint32(4)), 5: V(uint64(5)), 6: V(float32(6)), 7: V(float64(7)), 8: V(string("8")), 9: V(string("9")), 10: V(string("10")), 11: V([]byte("11")), 12: V([]byte("12")), 13: V([]byte("13"))},
 	})
@@ -1053,7 +991,7 @@ var enumMessagesType = pimpl.MessageInfo{GoType: reflect.TypeOf(new(EnumMessages
 		Oneofs: []ptype.Oneof{{Name: "union"}},
 	}),
 	NewMessage: func() pref.Message {
-		return new(EnumMessages)
+		return pref.ProtoMessage(new(EnumMessages)).ProtoReflect()
 	},
 }}
 
@@ -1079,20 +1017,7 @@ var messageMapDesc = mustMakeMessageDesc(ptype.StandaloneMessage{
 	IsMapEntry: true,
 })
 
-// TODO: Remove this.
-func (m *EnumMessages) Type() pref.MessageType { return enumMessagesType.PBType }
-func (m *EnumMessages) Descriptor() pref.MessageDescriptor {
-	return enumMessagesType.PBType.Descriptor()
-}
-func (m *EnumMessages) KnownFields() pref.KnownFields {
-	return enumMessagesType.MessageOf(m).KnownFields()
-}
-func (m *EnumMessages) UnknownFields() pref.UnknownFields {
-	return enumMessagesType.MessageOf(m).UnknownFields()
-}
-func (m *EnumMessages) New() pref.Message            { return new(EnumMessages) }
-func (m *EnumMessages) Interface() pref.ProtoMessage { return m }
-func (m *EnumMessages) ProtoReflect() pref.Message   { return m }
+func (m *EnumMessages) ProtoReflect() pref.Message { return enumMessagesType.MessageOf(m) }
 
 func (*EnumMessages) XXX_OneofWrappers() []interface{} {
 	return []interface{}{
@@ -1127,22 +1052,27 @@ func (*EnumMessages_OneofM2) isEnumMessages_Union() {}
 func (*EnumMessages_OneofM3) isEnumMessages_Union() {}
 
 func TestEnumMessages(t *testing.T) {
+	emptyL := pimpl.Export{}.MessageOf(new(proto2_20180125.Message))
+	emptyM := new(EnumMessages).ProtoReflect()
+	emptyM2 := new(ScalarProto2).ProtoReflect()
+	emptyM3 := new(ScalarProto3).ProtoReflect()
+
 	wantL := pimpl.Export{}.MessageOf(&proto2_20180125.Message{OptionalFloat: scalar.Float32(math.E)})
-	wantM := &EnumMessages{EnumP2: EnumProto2(1234).Enum()}
+	wantM := (&EnumMessages{EnumP2: EnumProto2(1234).Enum()}).ProtoReflect()
 	wantM2a := &ScalarProto2{Float32: scalar.Float32(math.Pi)}
 	wantM2b := &ScalarProto2{Float32: scalar.Float32(math.Phi)}
 	wantM3a := &ScalarProto3{Float32: math.Pi}
 	wantM3b := &ScalarProto3{Float32: math.Ln2}
 
-	wantList5 := (&EnumMessages{EnumList: []EnumProto2{333, 222}}).KnownFields().Get(5)
-	wantList6 := (&EnumMessages{MessageList: []*ScalarProto2{wantM2a, wantM2b}}).KnownFields().Get(6)
+	wantList5 := getField((&EnumMessages{EnumList: []EnumProto2{333, 222}}).ProtoReflect(), 5)
+	wantList6 := getField((&EnumMessages{MessageList: []*ScalarProto2{wantM2a, wantM2b}}).ProtoReflect(), 6)
 
-	wantMap7 := (&EnumMessages{EnumMap: map[string]EnumProto3{"one": 1, "two": 2}}).KnownFields().Get(7)
-	wantMap8 := (&EnumMessages{MessageMap: map[string]*ScalarProto3{"pi": wantM3a, "ln2": wantM3b}}).KnownFields().Get(8)
+	wantMap7 := getField((&EnumMessages{EnumMap: map[string]EnumProto3{"one": 1, "two": 2}}).ProtoReflect(), 7)
+	wantMap8 := getField((&EnumMessages{MessageMap: map[string]*ScalarProto3{"pi": wantM3a, "ln2": wantM3b}}).ProtoReflect(), 8)
 
-	testMessage(t, nil, &EnumMessages{}, messageOps{
+	testMessage(t, nil, new(EnumMessages).ProtoReflect(), messageOps{
 		hasFields{1: false, 2: false, 3: false, 4: false, 5: false, 6: false, 7: false, 8: false, 9: false, 10: false, 11: false, 12: false},
-		getFields{1: VE(0xbeef), 2: VE(1), 3: V(nil), 4: V(nil), 9: VE(0xbeef), 10: VE(1)},
+		getFields{1: VE(0xbeef), 2: VE(1), 3: V(emptyL), 4: V(emptyM), 9: VE(0xbeef), 10: VE(1)},
 
 		// Test singular enums.
 		setFields{1: VE(0xdead), 2: VE(0)},
@@ -1150,8 +1080,8 @@ func TestEnumMessages(t *testing.T) {
 		hasFields{1: true, 2: true},
 
 		// Test singular messages.
-		messageFields{3: messageOps{setFields{109: V(float32(math.E))}}},
-		messageFields{4: messageOps{setFields{1: VE(1234)}}},
+		messageFieldsMutable{3: messageOps{setFields{109: V(float32(math.E))}}},
+		messageFieldsMutable{4: messageOps{setFields{1: VE(1234)}}},
 		getFields{3: V(wantL), 4: V(wantM)},
 		clearFields{3, 4},
 		hasFields{3: false, 4: false},
@@ -1159,7 +1089,7 @@ func TestEnumMessages(t *testing.T) {
 		hasFields{3: true, 4: true},
 
 		// Test list of enums and messages.
-		listFields{
+		listFieldsMutable{
 			5: listOps{
 				appendList{VE(111), VE(222)},
 				setList{0: VE(333)},
@@ -1169,8 +1099,8 @@ func TestEnumMessages(t *testing.T) {
 			6: listOps{
 				appendMessageList{setFields{4: V(uint32(1e6))}},
 				appendMessageList{setFields{6: V(float32(math.Phi))}},
-				setList{0: V(wantM2a)},
-				getList{0: V(wantM2a), 1: V(wantM2b)},
+				setList{0: V(wantM2a.ProtoReflect())},
+				getList{0: V(wantM2a.ProtoReflect()), 1: V(wantM2b.ProtoReflect())},
 			},
 		},
 		getFields{5: wantList5, 6: wantList6},
@@ -1179,7 +1109,7 @@ func TestEnumMessages(t *testing.T) {
 		hasFields{5: false, 6: true},
 
 		// Test maps of enums and messages.
-		mapFields{
+		mapFieldsMutable{
 			7: mapOps{
 				setMap{"one": VE(1), "two": VE(2)},
 				hasMap{"one": true, "two": true, "three": false},
@@ -1187,8 +1117,8 @@ func TestEnumMessages(t *testing.T) {
 			},
 			8: mapOps{
 				messageMap{"pi": messageOps{setFields{6: V(float32(math.Pi))}}},
-				setMap{"ln2": V(wantM3b)},
-				getMap{"pi": V(wantM3a), "ln2": V(wantM3b), "none": V(nil)},
+				setMap{"ln2": V(wantM3b.ProtoReflect())},
+				getMap{"pi": V(wantM3a.ProtoReflect()), "ln2": V(wantM3b.ProtoReflect()), "none": V(nil)},
 				lenMap(2),
 			},
 		},
@@ -1202,32 +1132,32 @@ func TestEnumMessages(t *testing.T) {
 		hasFields{1: true, 2: true, 9: true, 10: false, 11: false, 12: false},
 		setFields{10: VE(0)},
 		hasFields{1: true, 2: true, 9: false, 10: true, 11: false, 12: false},
-		messageFields{11: messageOps{setFields{6: V(float32(math.Pi))}}},
-		getFields{11: V(wantM2a)},
+		messageFieldsMutable{11: messageOps{setFields{6: V(float32(math.Pi))}}},
+		getFields{11: V(wantM2a.ProtoReflect())},
 		hasFields{1: true, 2: true, 9: false, 10: false, 11: true, 12: false},
-		messageFields{12: messageOps{setFields{6: V(float32(math.Pi))}}},
-		getFields{12: V(wantM3a)},
+		messageFieldsMutable{12: messageOps{setFields{6: V(float32(math.Pi))}}},
+		getFields{12: V(wantM3a.ProtoReflect())},
 		hasFields{1: true, 2: true, 9: false, 10: false, 11: false, 12: true},
 
 		// Check entire message.
-		rangeFields{1: VE(0xdead), 2: VE(0), 3: V(wantL), 4: V(wantM), 6: wantList6, 7: wantMap7, 12: V(wantM3a)},
-		equalMessage{&EnumMessages{
+		rangeFields{1: VE(0xdead), 2: VE(0), 3: V(wantL), 4: V(wantM), 6: wantList6, 7: wantMap7, 12: V(wantM3a.ProtoReflect())},
+		equalMessage{(&EnumMessages{
 			EnumP2:        EnumProto2(0xdead).Enum(),
 			EnumP3:        EnumProto3(0).Enum(),
 			MessageLegacy: &proto2_20180125.Message{OptionalFloat: scalar.Float32(math.E)},
-			MessageCycle:  wantM,
+			MessageCycle:  wantM.Interface().(*EnumMessages),
 			MessageList:   []*ScalarProto2{wantM2a, wantM2b},
 			EnumMap:       map[string]EnumProto3{"one": 1, "two": 2},
 			Union:         &EnumMessages_OneofM3{wantM3a},
-		}},
+		}).ProtoReflect()},
 		clearFields{1, 2, 3, 4, 6, 7, 12},
-		equalMessage{&EnumMessages{}},
+		equalMessage{new(EnumMessages).ProtoReflect()},
 	})
 
 	// Test read-only operations on nil message.
-	testMessage(t, nil, (*EnumMessages)(nil), messageOps{
+	testMessage(t, nil, (*EnumMessages)(nil).ProtoReflect(), messageOps{
 		hasFields{1: false, 2: false, 3: false, 4: false, 5: false, 6: false, 7: false, 8: false, 9: false, 10: false, 11: false, 12: false},
-		getFields{1: VE(0xbeef), 2: VE(1), 3: V(nil), 4: V(nil), 9: VE(0xbeef), 10: VE(1), 11: V(nil), 12: V(nil)},
+		getFields{1: VE(0xbeef), 2: VE(1), 3: V(emptyL), 4: V(emptyM), 9: VE(0xbeef), 10: VE(1), 11: V(emptyM2), 12: V(emptyM3)},
 		listFields{5: {lenList(0)}, 6: {lenList(0)}},
 		mapFields{7: {lenMap(0)}, 8: {lenMap(0)}},
 	})
@@ -1238,89 +1168,141 @@ var cmpOpts = cmp.Options{
 		return protoV1.Equal(x, y)
 	}),
 	cmp.Transformer("UnwrapValue", func(pv pref.Value) interface{} {
-		return pv.Interface()
-	}),
-	cmp.Transformer("UnwrapGeneric", func(x pvalue.Unwrapper) interface{} {
-		return x.ProtoUnwrap()
+		switch v := pv.Interface().(type) {
+		case pref.Message:
+			out := make(map[pref.FieldNumber]pref.Value)
+			v.Range(func(fd pref.FieldDescriptor, v pref.Value) bool {
+				out[fd.Number()] = v
+				return true
+			})
+			return out
+		case pref.List:
+			var out []pref.Value
+			for i := 0; i < v.Len(); i++ {
+				out = append(out, v.Get(i))
+			}
+			return out
+		case pref.Map:
+			out := make(map[interface{}]pref.Value)
+			v.Range(func(k pref.MapKey, v pref.Value) bool {
+				out[k.Interface()] = v
+				return true
+			})
+			return out
+		default:
+			return v
+		}
 	}),
 	cmpopts.EquateNaNs(),
 	cmpopts.EquateEmpty(),
 }
 
 func testMessage(t *testing.T, p path, m pref.Message, tt messageOps) {
-	fs := m.KnownFields()
+	fieldDescs := m.Descriptor().Fields()
+	oneofDescs := m.Descriptor().Oneofs()
 	for i, op := range tt {
 		p.Push(i)
 		switch op := op.(type) {
 		case equalMessage:
-			if diff := cmp.Diff(op.Message, m, cmpOpts); diff != "" {
+			if diff := cmp.Diff(V(op.Message), V(m), cmpOpts); diff != "" {
 				t.Errorf("operation %v, message mismatch (-want, +got):\n%s", p, diff)
 			}
 		case hasFields:
 			got := map[pref.FieldNumber]bool{}
 			want := map[pref.FieldNumber]bool(op)
 			for n := range want {
-				got[n] = fs.Has(n)
+				fd := fieldDescs.ByNumber(n)
+				got[n] = m.Has(fd)
 			}
 			if diff := cmp.Diff(want, got); diff != "" {
-				t.Errorf("operation %v, KnownFields.Has mismatch (-want, +got):\n%s", p, diff)
+				t.Errorf("operation %v, Message.Has mismatch (-want, +got):\n%s", p, diff)
 			}
 		case getFields:
 			got := map[pref.FieldNumber]pref.Value{}
 			want := map[pref.FieldNumber]pref.Value(op)
 			for n := range want {
-				got[n] = fs.Get(n)
+				fd := fieldDescs.ByNumber(n)
+				got[n] = m.Get(fd)
 			}
 			if diff := cmp.Diff(want, got, cmpOpts); diff != "" {
-				t.Errorf("operation %v, KnownFields.Get mismatch (-want, +got):\n%s", p, diff)
+				t.Errorf("operation %v, Message.Get mismatch (-want, +got):\n%s", p, diff)
 			}
 		case setFields:
 			for n, v := range op {
-				fs.Set(n, v)
+				fd := fieldDescs.ByNumber(n)
+				m.Set(fd, v)
 			}
 		case clearFields:
 			for _, n := range op {
-				fs.Clear(n)
+				fd := fieldDescs.ByNumber(n)
+				m.Clear(fd)
 			}
 		case whichOneofs:
 			got := map[pref.Name]pref.FieldNumber{}
 			want := map[pref.Name]pref.FieldNumber(op)
 			for s := range want {
-				got[s] = fs.WhichOneof(s)
+				od := oneofDescs.ByName(s)
+				fd := m.WhichOneof(od)
+				if fd == nil {
+					got[s] = 0
+				} else {
+					got[s] = fd.Number()
+				}
 			}
 			if diff := cmp.Diff(want, got); diff != "" {
-				t.Errorf("operation %v, KnownFields.WhichOneof mismatch (-want, +got):\n%s", p, diff)
+				t.Errorf("operation %v, Message.WhichOneof mismatch (-want, +got):\n%s", p, diff)
 			}
 		case messageFields:
 			for n, tt := range op {
 				p.Push(int(n))
-				if !fs.Has(n) {
-					fs.Set(n, V(fs.NewMessage(n)))
-				}
-				testMessage(t, p, fs.Get(n).Message(), tt)
+				fd := fieldDescs.ByNumber(n)
+				testMessage(t, p, m.Get(fd).Message(), tt)
+				p.Pop()
+			}
+		case messageFieldsMutable:
+			for n, tt := range op {
+				p.Push(int(n))
+				fd := fieldDescs.ByNumber(n)
+				testMessage(t, p, m.Mutable(fd).Message(), tt)
 				p.Pop()
 			}
 		case listFields:
 			for n, tt := range op {
 				p.Push(int(n))
-				testLists(t, p, fs.Get(n).List(), tt)
+				fd := fieldDescs.ByNumber(n)
+				testLists(t, p, m.Get(fd).List(), tt)
+				p.Pop()
+			}
+		case listFieldsMutable:
+			for n, tt := range op {
+				p.Push(int(n))
+				fd := fieldDescs.ByNumber(n)
+				testLists(t, p, m.Mutable(fd).List(), tt)
 				p.Pop()
 			}
 		case mapFields:
 			for n, tt := range op {
 				p.Push(int(n))
-				testMaps(t, p, fs.Get(n).Map(), tt)
+				fd := fieldDescs.ByNumber(n)
+				testMaps(t, p, m.Get(fd).Map(), tt)
+				p.Pop()
+			}
+		case mapFieldsMutable:
+			for n, tt := range op {
+				p.Push(int(n))
+				fd := fieldDescs.ByNumber(n)
+				testMaps(t, p, m.Mutable(fd).Map(), tt)
 				p.Pop()
 			}
 		case rangeFields:
 			got := map[pref.FieldNumber]pref.Value{}
 			want := map[pref.FieldNumber]pref.Value(op)
-			fs.Range(func(n pref.FieldNumber, v pref.Value) bool {
-				got[n] = v
+			m.Range(func(fd pref.FieldDescriptor, v pref.Value) bool {
+				got[fd.Number()] = v
 				return true
 			})
 			if diff := cmp.Diff(want, got, cmpOpts); diff != "" {
-				t.Errorf("operation %v, KnownFields.Range mismatch (-want, +got):\n%s", p, diff)
+				t.Errorf("operation %v, Message.Range mismatch (-want, +got):\n%s", p, diff)
 			}
 		default:
 			t.Fatalf("operation %v, invalid operation: %T", p, op)
@@ -1334,7 +1316,7 @@ func testLists(t *testing.T, p path, v pref.List, tt listOps) {
 		p.Push(i)
 		switch op := op.(type) {
 		case equalList:
-			if diff := cmp.Diff(op.List, v, cmpOpts); diff != "" {
+			if diff := cmp.Diff(V(op.List), V(v), cmpOpts); diff != "" {
 				t.Errorf("operation %v, list mismatch (-want, +got):\n%s", p, diff)
 			}
 		case lenList:
@@ -1376,7 +1358,7 @@ func testMaps(t *testing.T, p path, m pref.Map, tt mapOps) {
 		p.Push(i)
 		switch op := op.(type) {
 		case equalMap:
-			if diff := cmp.Diff(op.Map, m, cmpOpts); diff != "" {
+			if diff := cmp.Diff(V(op.Map), V(m), cmpOpts); diff != "" {
 				t.Errorf("operation %v, map mismatch (-want, +got):\n%s", p, diff)
 			}
 		case lenMap:
@@ -1434,6 +1416,11 @@ func testMaps(t *testing.T, p path, m pref.Map, tt mapOps) {
 	}
 }
 
+func getField(m pref.Message, n pref.FieldNumber) pref.Value {
+	fd := m.Descriptor().Fields().ByNumber(n)
+	return m.Get(fd)
+}
+
 type path []int
 
 func (p *path) Push(i int) { *p = append(*p, i) }

+ 33 - 55
internal/testprotos/irregular/irregular.go

@@ -21,92 +21,70 @@ func (m *IrregularMessage) ProtoReflect() pref.Message { return (*message)(m) }
 type message IrregularMessage
 
 func (m *message) Descriptor() pref.MessageDescriptor { return descriptor.Messages().Get(0) }
-func (m *message) Type() pref.MessageType             { return nil }
-func (m *message) KnownFields() pref.KnownFields      { return (*known)(m) }
-func (m *message) UnknownFields() pref.UnknownFields  { return (*unknown)(m) }
 func (m *message) New() pref.Message                  { return &message{} }
 func (m *message) Interface() pref.ProtoMessage       { return (*IrregularMessage)(m) }
 
-type known IrregularMessage
+var fieldDescS = descriptor.Messages().Get(0).Fields().Get(0)
 
-func (m *known) Len() int {
+func (m *message) Len() int {
 	if m.set {
 		return 1
 	}
 	return 0
 }
 
-func (m *known) Has(num pref.FieldNumber) bool {
-	switch num {
-	case fieldS:
-		return m.set
-	}
-	return false
-}
-
-func (m *known) Get(num pref.FieldNumber) pref.Value {
-	switch num {
-	case fieldS:
-		return pref.ValueOf(m.value)
+func (m *message) Range(f func(pref.FieldDescriptor, pref.Value) bool) {
+	if m.set {
+		f(fieldDescS, pref.ValueOf(m.value))
 	}
-	return pref.Value{}
 }
 
-func (m *known) Set(num pref.FieldNumber, v pref.Value) {
-	switch num {
-	case fieldS:
-		m.value = v.String()
-	default:
-		panic("unknown field")
+func (m *message) Has(fd pref.FieldDescriptor) bool {
+	if fd == fieldDescS {
+		return m.set
 	}
+	panic("invalid field descriptor")
 }
 
-func (m *known) Clear(num pref.FieldNumber) {
-	switch num {
-	case fieldS:
+func (m *message) Clear(fd pref.FieldDescriptor) {
+	if fd == fieldDescS {
 		m.value = ""
 		m.set = false
-	default:
-		panic("unknown field")
+		return
 	}
+	panic("invalid field descriptor")
 }
 
-func (m *known) WhichOneof(name pref.Name) pref.FieldNumber {
-	return 0
+func (m *message) Get(fd pref.FieldDescriptor) pref.Value {
+	if fd == fieldDescS {
+		return pref.ValueOf(m.value)
+	}
+	panic("invalid field descriptor")
 }
 
-func (m *known) Range(f func(pref.FieldNumber, pref.Value) bool) {
-	if m.set {
-		f(fieldS, pref.ValueOf(m.value))
+func (m *message) Set(fd pref.FieldDescriptor, v pref.Value) {
+	if fd == fieldDescS {
+		m.value = v.String()
+		m.set = true
+		return
 	}
+	panic("invalid field descriptor")
 }
 
-func (m *known) NewMessage(num pref.FieldNumber) pref.Message {
-	panic("not a message field")
+func (m *message) Mutable(pref.FieldDescriptor) pref.Value {
+	panic("invalid field descriptor")
 }
 
-func (m *known) ExtensionTypes() pref.ExtensionFieldTypes {
-	return (*exttypes)(m)
+func (m *message) NewMessage(pref.FieldDescriptor) pref.Message {
+	panic("invalid field descriptor")
 }
 
-type unknown IrregularMessage
-
-func (m *unknown) Len() int                                          { return 0 }
-func (m *unknown) Get(pref.FieldNumber) pref.RawFields               { return nil }
-func (m *unknown) Set(pref.FieldNumber, pref.RawFields)              {}
-func (m *unknown) Range(func(pref.FieldNumber, pref.RawFields) bool) {}
-func (m *unknown) IsSupported() bool                                 { return false }
-
-type exttypes IrregularMessage
-
-func (m *exttypes) Len() int                                     { return 0 }
-func (m *exttypes) Register(pref.ExtensionType)                  { panic("not extendable") }
-func (m *exttypes) Remove(pref.ExtensionType)                    {}
-func (m *exttypes) ByNumber(pref.FieldNumber) pref.ExtensionType { return nil }
-func (m *exttypes) ByName(pref.FullName) pref.ExtensionType      { return nil }
-func (m *exttypes) Range(func(pref.ExtensionType) bool)          {}
+func (m *message) WhichOneof(pref.OneofDescriptor) pref.FieldDescriptor {
+	panic("invalid oneof descriptor")
+}
 
-const fieldS = pref.FieldNumber(1)
+func (m *message) GetUnknown() pref.RawFields { return nil }
+func (m *message) SetUnknown(pref.RawFields)  { return }
 
 var descriptor = func() pref.FileDescriptor {
 	p := &descriptorpb.FileDescriptorProto{}

+ 95 - 0
internal/testprotos/irregular/irregular_deprecated.go

@@ -0,0 +1,95 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package irregular
+
+import (
+	pref "google.golang.org/protobuf/reflect/protoreflect"
+)
+
+// TODO: Remove this.
+func (m *message) Type() pref.MessageType            { return nil }
+func (m *message) KnownFields() pref.KnownFields     { return (*known)(m) }
+func (m *message) UnknownFields() pref.UnknownFields { return (*unknown)(m) }
+
+type known IrregularMessage
+
+func (m *known) Len() int {
+	if m.set {
+		return 1
+	}
+	return 0
+}
+
+func (m *known) Has(num pref.FieldNumber) bool {
+	switch num {
+	case fieldS:
+		return m.set
+	}
+	return false
+}
+
+func (m *known) Get(num pref.FieldNumber) pref.Value {
+	switch num {
+	case fieldS:
+		return pref.ValueOf(m.value)
+	}
+	return pref.Value{}
+}
+
+func (m *known) Set(num pref.FieldNumber, v pref.Value) {
+	switch num {
+	case fieldS:
+		m.value = v.String()
+	default:
+		panic("unknown field")
+	}
+}
+
+func (m *known) Clear(num pref.FieldNumber) {
+	switch num {
+	case fieldS:
+		m.value = ""
+		m.set = false
+	default:
+		panic("unknown field")
+	}
+}
+
+func (m *known) WhichOneof(name pref.Name) pref.FieldNumber {
+	return 0
+}
+
+func (m *known) Range(f func(pref.FieldNumber, pref.Value) bool) {
+	if m.set {
+		f(fieldS, pref.ValueOf(m.value))
+	}
+}
+
+func (m *known) NewMessage(num pref.FieldNumber) pref.Message {
+	panic("not a message field")
+}
+
+func (m *known) ExtensionTypes() pref.ExtensionFieldTypes {
+	return (*exttypes)(m)
+}
+
+const fieldS = pref.FieldNumber(1)
+
+type unknown IrregularMessage
+
+func (m *unknown) Len() int                                          { return 0 }
+func (m *unknown) Get(pref.FieldNumber) pref.RawFields               { return nil }
+func (m *unknown) Set(pref.FieldNumber, pref.RawFields)              {}
+func (m *unknown) Range(func(pref.FieldNumber, pref.RawFields) bool) {}
+func (m *unknown) IsSupported() bool                                 { return false }
+
+type exttypes IrregularMessage
+
+func (m *exttypes) Len() int                                     { return 0 }
+func (m *exttypes) Register(pref.ExtensionType)                  { panic("not extendable") }
+func (m *exttypes) Remove(pref.ExtensionType)                    {}
+func (m *exttypes) ByNumber(pref.FieldNumber) pref.ExtensionType { return nil }
+func (m *exttypes) ByName(pref.FullName) pref.ExtensionType      { return nil }
+func (m *exttypes) Range(func(pref.ExtensionType) bool)          {}

+ 10 - 8
internal/value/list.go

@@ -10,13 +10,15 @@ import (
 	pref "google.golang.org/protobuf/reflect/protoreflect"
 )
 
+// ListOf returns a protoreflect.List view of p, which must be a *[]T.
+// If p is nil, this returns an empty, read-only list.
 func ListOf(p interface{}, c Converter) interface {
 	pref.List
 	Unwrapper
 } {
 	// TODO: Validate that p is a *[]T?
 	rv := reflect.ValueOf(p)
-	return listReflect{rv, c}
+	return &listReflect{rv, c}
 }
 
 type listReflect struct {
@@ -24,27 +26,27 @@ type listReflect struct {
 	conv Converter
 }
 
-func (ls listReflect) Len() int {
+func (ls *listReflect) Len() int {
 	if ls.v.IsNil() {
 		return 0
 	}
 	return ls.v.Elem().Len()
 }
-func (ls listReflect) Get(i int) pref.Value {
+func (ls *listReflect) Get(i int) pref.Value {
 	return ls.conv.PBValueOf(ls.v.Elem().Index(i))
 }
-func (ls listReflect) Set(i int, v pref.Value) {
+func (ls *listReflect) Set(i int, v pref.Value) {
 	ls.v.Elem().Index(i).Set(ls.conv.GoValueOf(v))
 }
-func (ls listReflect) Append(v pref.Value) {
+func (ls *listReflect) Append(v pref.Value) {
 	ls.v.Elem().Set(reflect.Append(ls.v.Elem(), ls.conv.GoValueOf(v)))
 }
-func (ls listReflect) Truncate(i int) {
+func (ls *listReflect) Truncate(i int) {
 	ls.v.Elem().Set(ls.v.Elem().Slice(0, i))
 }
-func (ls listReflect) NewMessage() pref.Message {
+func (ls *listReflect) NewMessage() pref.Message {
 	return ls.conv.NewMessage()
 }
-func (ls listReflect) ProtoUnwrap() interface{} {
+func (ls *listReflect) ProtoUnwrap() interface{} {
 	return ls.v.Interface()
 }

+ 11 - 9
internal/value/map.go

@@ -10,13 +10,15 @@ import (
 	pref "google.golang.org/protobuf/reflect/protoreflect"
 )
 
+// MapOf returns a protoreflect.Map view of p, which must a *map[K]V.
+// If p is nil, this returns an empty, read-only map.
 func MapOf(p interface{}, kc, kv Converter) interface {
 	pref.Map
 	Unwrapper
 } {
 	// TODO: Validate that p is a *map[K]V?
 	rv := reflect.ValueOf(p)
-	return mapReflect{rv, kc, kv}
+	return &mapReflect{rv, kc, kv}
 }
 
 type mapReflect struct {
@@ -25,13 +27,13 @@ type mapReflect struct {
 	valConv Converter
 }
 
-func (ms mapReflect) Len() int {
+func (ms *mapReflect) Len() int {
 	if ms.v.IsNil() {
 		return 0
 	}
 	return ms.v.Elem().Len()
 }
-func (ms mapReflect) Has(k pref.MapKey) bool {
+func (ms *mapReflect) Has(k pref.MapKey) bool {
 	if ms.v.IsNil() {
 		return false
 	}
@@ -39,7 +41,7 @@ func (ms mapReflect) Has(k pref.MapKey) bool {
 	rv := ms.v.Elem().MapIndex(rk)
 	return rv.IsValid()
 }
-func (ms mapReflect) Get(k pref.MapKey) pref.Value {
+func (ms *mapReflect) Get(k pref.MapKey) pref.Value {
 	if ms.v.IsNil() {
 		return pref.Value{}
 	}
@@ -50,7 +52,7 @@ func (ms mapReflect) Get(k pref.MapKey) pref.Value {
 	}
 	return ms.valConv.PBValueOf(rv)
 }
-func (ms mapReflect) Set(k pref.MapKey, v pref.Value) {
+func (ms *mapReflect) Set(k pref.MapKey, v pref.Value) {
 	if ms.v.Elem().IsNil() {
 		ms.v.Elem().Set(reflect.MakeMap(ms.v.Elem().Type()))
 	}
@@ -58,11 +60,11 @@ func (ms mapReflect) Set(k pref.MapKey, v pref.Value) {
 	rv := ms.valConv.GoValueOf(v)
 	ms.v.Elem().SetMapIndex(rk, rv)
 }
-func (ms mapReflect) Clear(k pref.MapKey) {
+func (ms *mapReflect) Clear(k pref.MapKey) {
 	rk := ms.keyConv.GoValueOf(k.Value())
 	ms.v.Elem().SetMapIndex(rk, reflect.Value{})
 }
-func (ms mapReflect) Range(f func(pref.MapKey, pref.Value) bool) {
+func (ms *mapReflect) Range(f func(pref.MapKey, pref.Value) bool) {
 	if ms.v.IsNil() {
 		return
 	}
@@ -76,9 +78,9 @@ func (ms mapReflect) Range(f func(pref.MapKey, pref.Value) bool) {
 		}
 	}
 }
-func (ms mapReflect) NewMessage() pref.Message {
+func (ms *mapReflect) NewMessage() pref.Message {
 	return ms.valConv.NewMessage()
 }
-func (ms mapReflect) ProtoUnwrap() interface{} {
+func (ms *mapReflect) ProtoUnwrap() interface{} {
 	return ms.v.Interface()
 }

+ 28 - 39
proto/decode.go

@@ -74,8 +74,6 @@ func (o UnmarshalOptions) unmarshalMessageFast(b []byte, m Message) error {
 func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) error {
 	messageDesc := m.Descriptor()
 	fieldDescs := messageDesc.Fields()
-	knownFields := m.KnownFields()
-	unknownFields := m.UnknownFields()
 	var nerr errors.NonFatal
 	for len(b) > 0 {
 		// Parse the tag (field number and wire type).
@@ -85,41 +83,32 @@ func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) err
 		}
 
 		// Parse the field value.
-		fieldDesc := fieldDescs.ByNumber(num)
-		if fieldDesc == nil {
-			extType := knownFields.ExtensionTypes().ByNumber(num)
-			if extType == nil && messageDesc.ExtensionRanges().Has(num) {
-				var err error
-				extType, err = o.Resolver.FindExtensionByNumber(messageDesc.FullName(), num)
-				if err != nil && err != protoregistry.NotFound {
-					return err
-				}
-				if extType != nil {
-					knownFields.ExtensionTypes().Register(extType)
-				}
-			}
-			if extType != nil {
-				fieldDesc = extType.Descriptor()
+		fd := fieldDescs.ByNumber(num)
+		if fd == nil && messageDesc.ExtensionRanges().Has(num) {
+			extType, err := o.Resolver.FindExtensionByNumber(messageDesc.FullName(), num)
+			if err != nil && err != protoregistry.NotFound {
+				return err
 			}
+			fd = extType
 		}
 		var err error
 		var valLen int
 		switch {
-		case fieldDesc == nil:
+		case fd == nil:
 			err = errUnknown
-		case fieldDesc.IsList():
-			valLen, err = o.unmarshalList(b[tagLen:], wtyp, num, knownFields.Get(num).List(), fieldDesc)
-		case fieldDesc.IsMap():
-			valLen, err = o.unmarshalMap(b[tagLen:], wtyp, num, knownFields.Get(num).Map(), fieldDesc)
+		case fd.IsList():
+			valLen, err = o.unmarshalList(b[tagLen:], wtyp, m.Mutable(fd).List(), fd)
+		case fd.IsMap():
+			valLen, err = o.unmarshalMap(b[tagLen:], wtyp, m.Mutable(fd).Map(), fd)
 		default:
-			valLen, err = o.unmarshalScalarField(b[tagLen:], wtyp, num, knownFields, fieldDesc)
+			valLen, err = o.unmarshalSingular(b[tagLen:], wtyp, m, fd)
 		}
 		if err == errUnknown {
 			valLen = wire.ConsumeFieldValue(num, wtyp, b[tagLen:])
 			if valLen < 0 {
 				return wire.ParseError(valLen)
 			}
-			unknownFields.Set(num, append(unknownFields.Get(num), b[:tagLen+valLen]...))
+			m.SetUnknown(append(m.GetUnknown(), b[:tagLen+valLen]...))
 		} else if !nerr.Merge(err) {
 			return err
 		}
@@ -128,38 +117,38 @@ func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) err
 	return nerr.E
 }
 
-func (o UnmarshalOptions) unmarshalScalarField(b []byte, wtyp wire.Type, num wire.Number, knownFields protoreflect.KnownFields, field protoreflect.FieldDescriptor) (n int, err error) {
+func (o UnmarshalOptions) unmarshalSingular(b []byte, wtyp wire.Type, m protoreflect.Message, fd protoreflect.FieldDescriptor) (n int, err error) {
 	var nerr errors.NonFatal
-	v, n, err := o.unmarshalScalar(b, wtyp, num, field)
+	v, n, err := o.unmarshalScalar(b, wtyp, fd)
 	if !nerr.Merge(err) {
 		return 0, err
 	}
-	switch field.Kind() {
+	switch fd.Kind() {
 	case protoreflect.GroupKind, protoreflect.MessageKind:
 		// Messages are merged with any existing message value,
 		// unless the message is part of a oneof.
 		//
 		// TODO: C++ merges into oneofs, while v1 does not.
 		// Evaluate which behavior to pick.
-		var m protoreflect.Message
-		if knownFields.Has(num) && field.ContainingOneof() == nil {
-			m = knownFields.Get(num).Message()
+		var m2 protoreflect.Message
+		if m.Has(fd) && fd.ContainingOneof() == nil {
+			m2 = m.Mutable(fd).Message()
 		} else {
-			m = knownFields.NewMessage(num)
-			knownFields.Set(num, protoreflect.ValueOf(m))
+			m2 = m.NewMessage(fd)
+			m.Set(fd, protoreflect.ValueOf(m2))
 		}
 		// Pass up errors (fatal and otherwise).
-		if err := o.unmarshalMessage(v.Bytes(), m); !nerr.Merge(err) {
+		if err := o.unmarshalMessage(v.Bytes(), m2); !nerr.Merge(err) {
 			return n, err
 		}
 	default:
 		// Non-message scalars replace the previous value.
-		knownFields.Set(num, v)
+		m.Set(fd, v)
 	}
 	return n, nerr.E
 }
 
-func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp wire.Type, num wire.Number, mapv protoreflect.Map, field protoreflect.FieldDescriptor) (n int, err error) {
+func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp wire.Type, mapv protoreflect.Map, fd protoreflect.FieldDescriptor) (n int, err error) {
 	if wtyp != wire.BytesType {
 		return 0, errUnknown
 	}
@@ -168,8 +157,8 @@ func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp wire.Type, num wire.Number
 		return 0, wire.ParseError(n)
 	}
 	var (
-		keyField = field.MapKey()
-		valField = field.MapValue()
+		keyField = fd.MapKey()
+		valField = fd.MapValue()
 		key      protoreflect.Value
 		val      protoreflect.Value
 		haveKey  bool
@@ -191,7 +180,7 @@ func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp wire.Type, num wire.Number
 		err = errUnknown
 		switch num {
 		case 1:
-			key, n, err = o.unmarshalScalar(b, wtyp, num, keyField)
+			key, n, err = o.unmarshalScalar(b, wtyp, keyField)
 			if !nerr.Merge(err) {
 				break
 			}
@@ -199,7 +188,7 @@ func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp wire.Type, num wire.Number
 			haveKey = true
 		case 2:
 			var v protoreflect.Value
-			v, n, err = o.unmarshalScalar(b, wtyp, num, valField)
+			v, n, err = o.unmarshalScalar(b, wtyp, valField)
 			if !nerr.Merge(err) {
 				break
 			}

+ 10 - 10
proto/decode_gen.go

@@ -18,8 +18,8 @@ import (
 // unmarshalScalar decodes a value of the given kind.
 //
 // Message values are decoded into a []byte which aliases the input data.
-func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, num wire.Number, field protoreflect.FieldDescriptor) (val protoreflect.Value, n int, err error) {
-	switch field.Kind() {
+func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, fd protoreflect.FieldDescriptor) (val protoreflect.Value, n int, err error) {
+	switch fd.Kind() {
 	case protoreflect.BoolKind:
 		if wtyp != wire.VarintType {
 			return val, 0, errUnknown
@@ -154,9 +154,9 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, num wire.Num
 		if n < 0 {
 			return val, 0, wire.ParseError(n)
 		}
-		if field.Syntax() == protoreflect.Proto3 && !utf8.Valid(v) {
+		if fd.Syntax() == protoreflect.Proto3 && !utf8.Valid(v) {
 			var nerr errors.NonFatal
-			nerr.AppendInvalidUTF8(string(field.FullName()))
+			nerr.AppendInvalidUTF8(string(fd.FullName()))
 			return protoreflect.ValueOf(string(v)), n, nerr.E
 		}
 		return protoreflect.ValueOf(string(v)), n, nil
@@ -182,7 +182,7 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, num wire.Num
 		if wtyp != wire.StartGroupType {
 			return val, 0, errUnknown
 		}
-		v, n := wire.ConsumeGroup(num, b)
+		v, n := wire.ConsumeGroup(fd.Number(), b)
 		if n < 0 {
 			return val, 0, wire.ParseError(n)
 		}
@@ -192,9 +192,9 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, num wire.Num
 	}
 }
 
-func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, num wire.Number, list protoreflect.List, field protoreflect.FieldDescriptor) (n int, err error) {
+func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, list protoreflect.List, fd protoreflect.FieldDescriptor) (n int, err error) {
 	var nerr errors.NonFatal
-	switch field.Kind() {
+	switch fd.Kind() {
 	case protoreflect.BoolKind:
 		if wtyp == wire.BytesType {
 			buf, n := wire.ConsumeBytes(b)
@@ -553,8 +553,8 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, num wire.Numbe
 		if n < 0 {
 			return 0, wire.ParseError(n)
 		}
-		if field.Syntax() == protoreflect.Proto3 && !utf8.Valid(v) {
-			nerr.AppendInvalidUTF8(string(field.FullName()))
+		if fd.Syntax() == protoreflect.Proto3 && !utf8.Valid(v) {
+			nerr.AppendInvalidUTF8(string(fd.FullName()))
 		}
 		list.Append(protoreflect.ValueOf(string(v)))
 		return n, nerr.E
@@ -586,7 +586,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, num wire.Numbe
 		if wtyp != wire.StartGroupType {
 			return 0, errUnknown
 		}
-		v, n := wire.ConsumeGroup(num, b)
+		v, n := wire.ConsumeGroup(fd.Number(), b)
 		if n < 0 {
 			return 0, wire.ParseError(n)
 		}

+ 6 - 14
proto/decode_test.go

@@ -16,7 +16,6 @@ import (
 	"google.golang.org/protobuf/internal/scalar"
 	"google.golang.org/protobuf/proto"
 	pref "google.golang.org/protobuf/reflect/protoreflect"
-	"google.golang.org/protobuf/runtime/protoimpl"
 
 	legacypb "google.golang.org/protobuf/internal/testprotos/legacy"
 	legacy1pb "google.golang.org/protobuf/internal/testprotos/legacy/proto2.v0.0.0-20160225-2fc053c5"
@@ -910,12 +909,12 @@ var testProtos = []testProto{
 		desc: "unknown fields",
 		decodeTo: []proto.Message{build(
 			&testpb.TestAllTypes{},
-			unknown(100000, pack.Message{
+			unknown(pack.Message{
 				pack.Tag{100000, pack.VarintType}, pack.Varint(1),
 			}.Marshal()),
 		), build(
 			&test3pb.TestAllTypes{},
-			unknown(100000, pack.Message{
+			unknown(pack.Message{
 				pack.Tag{100000, pack.VarintType}, pack.Varint(1),
 			}.Marshal()),
 		)},
@@ -927,12 +926,12 @@ var testProtos = []testProto{
 		desc: "field type mismatch",
 		decodeTo: []proto.Message{build(
 			&testpb.TestAllTypes{},
-			unknown(1, pack.Message{
+			unknown(pack.Message{
 				pack.Tag{1, pack.BytesType}, pack.String("string"),
 			}.Marshal()),
 		), build(
 			&test3pb.TestAllTypes{},
-			unknown(1, pack.Message{
+			unknown(pack.Message{
 				pack.Tag{1, pack.BytesType}, pack.String("string"),
 			}.Marshal()),
 		)},
@@ -1345,16 +1344,9 @@ func build(m proto.Message, opts ...buildOpt) proto.Message {
 
 type buildOpt func(proto.Message)
 
-func unknown(num pref.FieldNumber, raw pref.RawFields) buildOpt {
+func unknown(raw pref.RawFields) buildOpt {
 	return func(m proto.Message) {
-		m.ProtoReflect().UnknownFields().Set(num, raw)
-	}
-}
-
-func registerExtension(desc *protoV1.ExtensionDesc) buildOpt {
-	return func(m proto.Message) {
-		et := protoimpl.X.ExtensionTypeFromDesc(desc)
-		m.ProtoReflect().KnownFields().ExtensionTypes().Register(et)
+		m.ProtoReflect().SetUnknown(raw)
 	}
 }
 

+ 26 - 41
proto/encode.go

@@ -5,7 +5,6 @@
 package proto
 
 import (
-	"fmt"
 	"sort"
 
 	"google.golang.org/protobuf/internal/encoding/wire"
@@ -125,23 +124,14 @@ func (o MarshalOptions) marshalMessageFast(b []byte, m Message) ([]byte, error)
 func (o MarshalOptions) marshalMessage(b []byte, m protoreflect.Message) ([]byte, error) {
 	// There are many choices for what order we visit fields in. The default one here
 	// is chosen for reasonable efficiency and simplicity given the protoreflect API.
-	// It is not deterministic, since KnownFields.Range does not return fields in any
+	// It is not deterministic, since Message.Range does not return fields in any
 	// defined order.
 	//
 	// When using deterministic serialization, we sort the known fields by field number.
-	fieldDescs := m.Descriptor().Fields()
-	knownFields := m.KnownFields()
 	var err error
 	var nerr errors.NonFatal
-	o.rangeKnown(knownFields, func(num protoreflect.FieldNumber, value protoreflect.Value) bool {
-		field := fieldDescs.ByNumber(num)
-		if field == nil {
-			field = knownFields.ExtensionTypes().ByNumber(num).Descriptor()
-			if field == nil {
-				panic(fmt.Errorf("no descriptor for field %d in %q", num, m.Descriptor().FullName()))
-			}
-		}
-		b, err = o.marshalField(b, field, value)
+	o.rangeFields(m, func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
+		b, err = o.marshalField(b, fd, v)
 		if nerr.Merge(err) {
 			err = nil
 			return true
@@ -151,57 +141,52 @@ func (o MarshalOptions) marshalMessage(b []byte, m protoreflect.Message) ([]byte
 	if err != nil {
 		return b, err
 	}
-	m.UnknownFields().Range(func(_ protoreflect.FieldNumber, raw protoreflect.RawFields) bool {
-		b = append(b, raw...)
-		return true
-	})
+	b = append(b, m.GetUnknown()...)
 	return b, nerr.E
 }
 
-// rangeKnown visits known fields in field number order when deterministic
+// rangeFields visits fields in field number order when deterministic
 // serialization is enabled.
-func (o MarshalOptions) rangeKnown(knownFields protoreflect.KnownFields, f func(protoreflect.FieldNumber, protoreflect.Value) bool) {
+func (o MarshalOptions) rangeFields(m protoreflect.Message, f func(protoreflect.FieldDescriptor, protoreflect.Value) bool) {
 	if !o.Deterministic {
-		knownFields.Range(f)
+		m.Range(f)
 		return
 	}
-	nums := make([]protoreflect.FieldNumber, 0, knownFields.Len())
-	knownFields.Range(func(num protoreflect.FieldNumber, _ protoreflect.Value) bool {
-		nums = append(nums, num)
+	fds := make([]protoreflect.FieldDescriptor, 0, m.Len())
+	m.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
+		fds = append(fds, fd)
 		return true
 	})
-	sort.Slice(nums, func(a, b int) bool {
-		return nums[a] < nums[b]
+	sort.Slice(fds, func(a, b int) bool {
+		return fds[a].Number() < fds[b].Number()
 	})
-	for _, num := range nums {
-		if !f(num, knownFields.Get(num)) {
+	for _, fd := range fds {
+		if !f(fd, m.Get(fd)) {
 			break
 		}
 	}
 }
 
 func (o MarshalOptions) marshalField(b []byte, fd protoreflect.FieldDescriptor, value protoreflect.Value) ([]byte, error) {
-	num := fd.Number()
-	kind := fd.Kind()
 	switch {
 	case fd.IsList():
-		return o.marshalList(b, num, fd, value.List())
+		return o.marshalList(b, fd, value.List())
 	case fd.IsMap():
-		return o.marshalMap(b, num, fd, value.Map())
+		return o.marshalMap(b, fd, value.Map())
 	default:
-		b = wire.AppendTag(b, num, wireTypes[kind])
-		return o.marshalSingular(b, num, fd, value)
+		b = wire.AppendTag(b, fd.Number(), wireTypes[fd.Kind()])
+		return o.marshalSingular(b, fd, value)
 	}
 }
 
-func (o MarshalOptions) marshalList(b []byte, num wire.Number, fd protoreflect.FieldDescriptor, list protoreflect.List) ([]byte, error) {
-	if fd.IsPacked() {
-		b = wire.AppendTag(b, num, wire.BytesType)
+func (o MarshalOptions) marshalList(b []byte, fd protoreflect.FieldDescriptor, list protoreflect.List) ([]byte, error) {
+	if fd.IsPacked() && list.Len() > 0 {
+		b = wire.AppendTag(b, fd.Number(), wire.BytesType)
 		b, pos := appendSpeculativeLength(b)
 		var nerr errors.NonFatal
 		for i, llen := 0, list.Len(); i < llen; i++ {
 			var err error
-			b, err = o.marshalSingular(b, num, fd, list.Get(i))
+			b, err = o.marshalSingular(b, fd, list.Get(i))
 			if !nerr.Merge(err) {
 				return b, err
 			}
@@ -214,8 +199,8 @@ func (o MarshalOptions) marshalList(b []byte, num wire.Number, fd protoreflect.F
 	var nerr errors.NonFatal
 	for i, llen := 0, list.Len(); i < llen; i++ {
 		var err error
-		b = wire.AppendTag(b, num, wireTypes[kind])
-		b, err = o.marshalSingular(b, num, fd, list.Get(i))
+		b = wire.AppendTag(b, fd.Number(), wireTypes[kind])
+		b, err = o.marshalSingular(b, fd, list.Get(i))
 		if !nerr.Merge(err) {
 			return b, err
 		}
@@ -223,13 +208,13 @@ func (o MarshalOptions) marshalList(b []byte, num wire.Number, fd protoreflect.F
 	return b, nerr.E
 }
 
-func (o MarshalOptions) marshalMap(b []byte, num wire.Number, fd protoreflect.FieldDescriptor, mapv protoreflect.Map) ([]byte, error) {
+func (o MarshalOptions) marshalMap(b []byte, fd protoreflect.FieldDescriptor, mapv protoreflect.Map) ([]byte, error) {
 	keyf := fd.MapKey()
 	valf := fd.MapValue()
 	var nerr errors.NonFatal
 	var err error
 	o.rangeMap(mapv, keyf.Kind(), func(key protoreflect.MapKey, value protoreflect.Value) bool {
-		b = wire.AppendTag(b, num, wire.BytesType)
+		b = wire.AppendTag(b, fd.Number(), wire.BytesType)
 		var pos int
 		b, pos = appendSpeculativeLength(b)
 

+ 6 - 6
proto/encode_gen.go

@@ -36,9 +36,9 @@ var wireTypes = map[protoreflect.Kind]wire.Type{
 	protoreflect.GroupKind:    wire.StartGroupType,
 }
 
-func (o MarshalOptions) marshalSingular(b []byte, num wire.Number, field protoreflect.FieldDescriptor, v protoreflect.Value) ([]byte, error) {
+func (o MarshalOptions) marshalSingular(b []byte, fd protoreflect.FieldDescriptor, v protoreflect.Value) ([]byte, error) {
 	var nerr errors.NonFatal
-	switch field.Kind() {
+	switch fd.Kind() {
 	case protoreflect.BoolKind:
 		b = wire.AppendVarint(b, wire.EncodeBool(v.Bool()))
 	case protoreflect.EnumKind:
@@ -68,8 +68,8 @@ func (o MarshalOptions) marshalSingular(b []byte, num wire.Number, field protore
 	case protoreflect.DoubleKind:
 		b = wire.AppendFixed64(b, math.Float64bits(v.Float()))
 	case protoreflect.StringKind:
-		if field.Syntax() == protoreflect.Proto3 && !utf8.ValidString(v.String()) {
-			nerr.AppendInvalidUTF8(string(field.FullName()))
+		if fd.Syntax() == protoreflect.Proto3 && !utf8.ValidString(v.String()) {
+			nerr.AppendInvalidUTF8(string(fd.FullName()))
 		}
 		b = wire.AppendBytes(b, []byte(v.String()))
 	case protoreflect.BytesKind:
@@ -89,9 +89,9 @@ func (o MarshalOptions) marshalSingular(b []byte, num wire.Number, field protore
 		if !nerr.Merge(err) {
 			return b, err
 		}
-		b = wire.AppendVarint(b, wire.EncodeTag(num, wire.EndGroupType))
+		b = wire.AppendVarint(b, wire.EncodeTag(fd.Number(), wire.EndGroupType))
 	default:
-		return b, errors.New("invalid kind %v", field.Kind())
+		return b, errors.New("invalid kind %v", fd.Kind())
 	}
 	return b, nerr.E
 }

+ 3 - 15
proto/encode_test.go

@@ -7,13 +7,11 @@ package proto_test
 import (
 	"bytes"
 	"fmt"
-	"reflect"
 	"testing"
 
 	protoV1 "github.com/golang/protobuf/proto"
 	"github.com/google/go-cmp/cmp"
 	"google.golang.org/protobuf/proto"
-	pref "google.golang.org/protobuf/reflect/protoreflect"
 
 	test3pb "google.golang.org/protobuf/internal/testprotos/test3"
 )
@@ -35,7 +33,7 @@ func TestEncode(t *testing.T) {
 					t.Errorf("Size and marshal disagree: Size(m)=%v; len(Marshal(m))=%v\nMessage:\n%v", size, len(wire), marshalText(want))
 				}
 
-				got := newMessage(want)
+				got := want.ProtoReflect().New().Interface()
 				uopts := proto.UnmarshalOptions{
 					AllowPartial: test.partial,
 				}
@@ -76,7 +74,7 @@ func TestEncodeDeterministic(t *testing.T) {
 					t.Fatalf("deterministic marshal returned varying results:\n%v", cmp.Diff(wire, wire2))
 				}
 
-				got := newMessage(want)
+				got := want.ProtoReflect().New().Interface()
 				uopts := proto.UnmarshalOptions{
 					AllowPartial: test.partial,
 				}
@@ -105,7 +103,7 @@ func TestEncodeInvalidUTF8(t *testing.T) {
 				if !isErrInvalidUTF8(err) {
 					t.Errorf("Marshal did not return expected error for invalid UTF8: %v\nMessage:\n%v", err, marshalText(want))
 				}
-				got := newMessage(want)
+				got := want.ProtoReflect().New().Interface()
 				if err := proto.Unmarshal(wire, got); !isErrInvalidUTF8(err) {
 					t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, marshalText(want))
 					return
@@ -147,13 +145,3 @@ func TestMarshalAppend(t *testing.T) {
 		t.Fatalf("MarshalAppend modified prefix: got %v, want prefix %v", got, want)
 	}
 }
-
-// newMessage returns a new message with the same type and extension fields as m.
-func newMessage(m proto.Message) proto.Message {
-	n := reflect.New(reflect.TypeOf(m).Elem()).Interface().(proto.Message)
-	m.ProtoReflect().KnownFields().ExtensionTypes().Range(func(xt pref.ExtensionType) bool {
-		n.ProtoReflect().KnownFields().ExtensionTypes().Register(xt)
-		return true
-	})
-	return n
-}

+ 72 - 96
proto/equal.go

@@ -6,147 +6,123 @@ package proto
 
 import (
 	"bytes"
+	"reflect"
 
+	"google.golang.org/protobuf/internal/encoding/wire"
 	pref "google.golang.org/protobuf/reflect/protoreflect"
 )
 
-// Equal returns true of two messages are equal.
+// Equal reports whether two messages are equal.
 //
-// Two messages are equal if they have identical types and registered extension fields,
-// marshal to the same bytes under deterministic serialization,
-// and contain no floating point NaNs.
-func Equal(a, b Message) bool {
-	return equalMessage(a.ProtoReflect(), b.ProtoReflect())
+// Two messages are equal if they belong to the same message descriptor,
+// have the same set of populated known and extension field values,
+// and the same set of unknown fields values.
+//
+// Scalar values are compared with the equivalent of the == operator in Go,
+// except bytes values which are compared using bytes.Equal.
+// Note that this means that floating point NaNs are considered inequal.
+// Message values are compared by recursively calling Equal.
+// Lists are equal if each element value is also equal.
+// Maps are equal if they have the same set of keys, where the pair of values
+// for each key is also equal.
+func Equal(x, y Message) bool {
+	return equalMessage(x.ProtoReflect(), y.ProtoReflect())
 }
 
 // equalMessage compares two messages.
-func equalMessage(a, b pref.Message) bool {
-	mda, mdb := a.Descriptor(), b.Descriptor()
-	if mda != mdb && mda.FullName() != mdb.FullName() {
+func equalMessage(mx, my pref.Message) bool {
+	if mx.Descriptor() != my.Descriptor() {
 		return false
 	}
 
-	// TODO: The v1 says that a nil message is not equal to an empty one.
-	// Decide what to do about this when v1 wraps v2.
-
-	knowna, knownb := a.KnownFields(), b.KnownFields()
-
-	fields := mda.Fields()
-	for i, flen := 0, fields.Len(); i < flen; i++ {
-		fd := fields.Get(i)
-		num := fd.Number()
-		hasa, hasb := knowna.Has(num), knownb.Has(num)
-		if !hasa && !hasb {
-			continue
-		}
-		if hasa != hasb || !equalFields(fd, knowna.Get(num), knownb.Get(num)) {
-			return false
-		}
-	}
-	equal := true
-
-	unknowna, unknownb := a.UnknownFields(), b.UnknownFields()
-	ulen := unknowna.Len()
-	if ulen != unknownb.Len() {
+	if mx.Len() != my.Len() {
 		return false
 	}
-	unknowna.Range(func(num pref.FieldNumber, ra pref.RawFields) bool {
-		rb := unknownb.Get(num)
-		if !bytes.Equal([]byte(ra), []byte(rb)) {
-			equal = false
-			return false
-		}
-		return true
+	equal := true
+	mx.Range(func(fd pref.FieldDescriptor, vx pref.Value) bool {
+		vy := my.Get(fd)
+		equal = my.Has(fd) && equalField(fd, vx, vy)
+		return equal
 	})
 	if !equal {
 		return false
 	}
 
-	// If the set of extension types is not identical for both messages, we report
-	// a inequality.
-	//
-	// This requirement is stringent. Registering an extension type for a message
-	// without setting a value for the extension will cause that message to compare
-	// as inequal to the same message without the registration.
-	//
-	// TODO: Revisit this behavior after eager decoding of extensions is implemented.
-	xtypesa, xtypesb := knowna.ExtensionTypes(), knownb.ExtensionTypes()
-	if la, lb := xtypesa.Len(), xtypesb.Len(); la != lb {
-		return false
-	} else if la == 0 {
-		return true
-	}
-	xtypesa.Range(func(xt pref.ExtensionType) bool {
-		num := xt.Descriptor().Number()
-		if xtypesb.ByNumber(num) != xt {
-			equal = false
-			return false
-		}
-		hasa, hasb := knowna.Has(num), knownb.Has(num)
-		if !hasa && !hasb {
-			return true
-		}
-		if hasa != hasb || !equalFields(xt.Descriptor(), knowna.Get(num), knownb.Get(num)) {
-			equal = false
-			return false
-		}
-		return true
-	})
-	return equal
+	return equalUnknown(mx.GetUnknown(), my.GetUnknown())
 }
 
-// equalFields compares two fields.
-func equalFields(fd pref.FieldDescriptor, a, b pref.Value) bool {
+// equalField compares two fields.
+func equalField(fd pref.FieldDescriptor, x, y pref.Value) bool {
 	switch {
 	case fd.IsList():
-		return equalList(fd, a.List(), b.List())
+		return equalList(fd, x.List(), y.List())
 	case fd.IsMap():
-		return equalMap(fd, a.Map(), b.Map())
+		return equalMap(fd, x.Map(), y.Map())
 	default:
-		return equalValue(fd, a, b)
+		return equalValue(fd, x, y)
 	}
 }
 
-// equalMap compares a map field.
-func equalMap(fd pref.FieldDescriptor, a, b pref.Map) bool {
-	alen := a.Len()
-	if alen != b.Len() {
+// equalMap compares two maps.
+func equalMap(fd pref.FieldDescriptor, x, y pref.Map) bool {
+	if x.Len() != y.Len() {
 		return false
 	}
 	equal := true
-	a.Range(func(k pref.MapKey, va pref.Value) bool {
-		vb := b.Get(k)
-		if !vb.IsValid() || !equalValue(fd.MapValue(), va, vb) {
-			equal = false
-			return false
-		}
-		return true
+	x.Range(func(k pref.MapKey, vx pref.Value) bool {
+		vy := y.Get(k)
+		equal = y.Has(k) && equalValue(fd.MapValue(), vx, vy)
+		return equal
 	})
 	return equal
 }
 
-// equalList compares a non-map repeated field.
-func equalList(fd pref.FieldDescriptor, a, b pref.List) bool {
-	alen := a.Len()
-	if alen != b.Len() {
+// equalList compares two lists.
+func equalList(fd pref.FieldDescriptor, x, y pref.List) bool {
+	if x.Len() != y.Len() {
 		return false
 	}
-	for i := 0; i < alen; i++ {
-		if !equalValue(fd, a.Get(i), b.Get(i)) {
+	for i := x.Len() - 1; i >= 0; i-- {
+		if !equalValue(fd, x.Get(i), y.Get(i)) {
 			return false
 		}
 	}
 	return true
 }
 
-// equalValue compares the scalar value type of a field.
-func equalValue(fd pref.FieldDescriptor, a, b pref.Value) bool {
+// equalValue compares two singular values.
+func equalValue(fd pref.FieldDescriptor, x, y pref.Value) bool {
 	switch {
 	case fd.Message() != nil:
-		return equalMessage(a.Message(), b.Message())
+		return equalMessage(x.Message(), y.Message())
 	case fd.Kind() == pref.BytesKind:
-		return bytes.Equal(a.Bytes(), b.Bytes())
+		return bytes.Equal(x.Bytes(), y.Bytes())
 	default:
-		return a.Interface() == b.Interface()
+		return x.Interface() == y.Interface()
+	}
+}
+
+// equalUnknown compares unknown fields by direct comparison on the raw bytes
+// of each individual field number.
+func equalUnknown(x, y pref.RawFields) bool {
+	if len(x) != len(y) {
+		return false
+	}
+	if bytes.Equal([]byte(x), []byte(y)) {
+		return true
+	}
+
+	mx := make(map[pref.FieldNumber]pref.RawFields)
+	my := make(map[pref.FieldNumber]pref.RawFields)
+	for len(x) > 0 {
+		fnum, _, n := wire.ConsumeField(x)
+		mx[fnum] = append(mx[fnum], x[:n]...)
+		x = x[n:]
+	}
+	for len(y) > 0 {
+		fnum, _, n := wire.ConsumeField(y)
+		my[fnum] = append(my[fnum], y[:n]...)
+		y = y[n:]
 	}
+	return reflect.DeepEqual(mx, my)
 }

+ 4 - 43
proto/equal_test.go

@@ -371,22 +371,22 @@ var inequalities = []struct{ a, b proto.Message }{
 	},
 	// Unknown fields.
 	{
-		build(&testpb.TestAllTypes{}, unknown(100000, pack.Message{
+		build(&testpb.TestAllTypes{}, unknown(pack.Message{
 			pack.Tag{100000, pack.VarintType}, pack.Varint(1),
 		}.Marshal())),
-		build(&testpb.TestAllTypes{}, unknown(100000, pack.Message{
+		build(&testpb.TestAllTypes{}, unknown(pack.Message{
 			pack.Tag{100000, pack.VarintType}, pack.Varint(2),
 		}.Marshal())),
 	},
 	{
-		build(&testpb.TestAllTypes{}, unknown(100000, pack.Message{
+		build(&testpb.TestAllTypes{}, unknown(pack.Message{
 			pack.Tag{100000, pack.VarintType}, pack.Varint(1),
 		}.Marshal())),
 		&testpb.TestAllTypes{},
 	},
 	{
 		&testpb.TestAllTypes{},
-		build(&testpb.TestAllTypes{}, unknown(100000, pack.Message{
+		build(&testpb.TestAllTypes{}, unknown(pack.Message{
 			pack.Tag{100000, pack.VarintType}, pack.Varint(1),
 		}.Marshal())),
 	},
@@ -399,34 +399,12 @@ var inequalities = []struct{ a, b proto.Message }{
 			extend(testpb.E_OptionalInt32Extension, scalar.Int32(2)),
 		),
 	},
-	{
-		build(&testpb.TestAllExtensions{},
-			registerExtension(testpb.E_OptionalInt32Extension),
-		),
-		build(&testpb.TestAllExtensions{},
-			extend(testpb.E_OptionalInt32Extension, scalar.Int32(2)),
-		),
-	},
-	{
-		build(&testpb.TestAllExtensions{},
-			extend(testpb.E_OptionalInt32Extension, scalar.Int32(1)),
-		),
-		build(&testpb.TestAllExtensions{},
-			registerExtension(testpb.E_OptionalInt32Extension),
-		),
-	},
 	{
 		&testpb.TestAllExtensions{},
 		build(&testpb.TestAllExtensions{},
 			extend(testpb.E_OptionalInt32Extension, scalar.Int32(2)),
 		),
 	},
-	{
-		&testpb.TestAllExtensions{},
-		build(&testpb.TestAllExtensions{},
-			registerExtension(testpb.E_OptionalInt32Extension),
-		),
-	},
 	// Proto2 default values are not considered by Equal, so the following are still unequal.
 	{
 		&testpb.TestAllTypes{DefaultInt32: scalar.Int32(81)},
@@ -496,21 +474,4 @@ var inequalities = []struct{ a, b proto.Message }{
 		&testpb.TestAllTypes{},
 		&testpb.TestAllTypes{DefaultNestedEnum: testpb.TestAllTypes_BAR.Enum()},
 	},
-	// Extension ddefault values are not considered by Equal, so the following are still unequal.
-	{
-		build(&testpb.TestAllExtensions{},
-			registerExtension(testpb.E_DefaultInt32Extension),
-		),
-		build(&testpb.TestAllExtensions{},
-			extend(testpb.E_DefaultInt32Extension, scalar.Int32(81)),
-		),
-	},
-	{
-		build(&testpb.TestAllExtensions{},
-			extend(testpb.E_DefaultInt32Extension, scalar.Int32(81)),
-		),
-		build(&testpb.TestAllExtensions{},
-			registerExtension(testpb.E_DefaultInt32Extension),
-		),
-	},
 }

+ 18 - 31
proto/isinit.go

@@ -28,53 +28,40 @@ func IsInitialized(m Message) error {
 // IsInitialized returns an error if any required fields in m are not set.
 func isInitialized(m pref.Message, stack []interface{}) error {
 	md := m.Descriptor()
-	known := m.KnownFields()
-	fields := md.Fields()
+	fds := md.Fields()
 	for i, nums := 0, md.RequiredNumbers(); i < nums.Len(); i++ {
-		num := nums.Get(i)
-		if !known.Has(num) {
-			stack = append(stack, fields.ByNumber(num).Name())
+		fd := fds.ByNumber(nums.Get(i))
+		if !m.Has(fd) {
+			stack = append(stack, fd.Name())
 			return newRequiredNotSetError(stack)
 		}
 	}
 	var err error
-	known.Range(func(num pref.FieldNumber, v pref.Value) bool {
-		field := fields.ByNumber(num)
-		if field == nil {
-			field = known.ExtensionTypes().ByNumber(num).Descriptor()
-		}
-		if field == nil {
-			panic(fmt.Errorf("no descriptor for field %d in %q", num, md.FullName()))
-		}
-		// Look for fields containing a message: Messages, groups, and maps
-		// with a message or group value.
-		md := field.Message()
-		if md == nil {
-			return true
-		}
-		if field.IsMap() {
-			if field.MapValue().Message() == nil {
+	m.Range(func(fd pref.FieldDescriptor, v pref.Value) bool {
+		// Recurse into fields containing message values.
+		stack := append(stack, fd.Name())
+		switch {
+		case fd.IsList():
+			if fd.Message() == nil {
 				return true
 			}
-		}
-		// Recurse into the field
-		stack := append(stack, field.Name())
-		switch {
-		case field.IsList():
-			for i, list := 0, v.List(); i < list.Len(); i++ {
+			for i, list := 0, v.List(); i < list.Len() && err == nil; i++ {
 				stack := append(stack, "[", i, "].")
 				err = isInitialized(list.Get(i).Message(), stack)
-				if err != nil {
-					break
-				}
 			}
-		case field.IsMap():
+		case fd.IsMap():
+			if fd.MapValue().Message() == nil {
+				return true
+			}
 			v.Map().Range(func(key pref.MapKey, v pref.Value) bool {
 				stack := append(stack, "[", key, "].")
 				err = isInitialized(v.Message(), stack)
 				return err == nil
 			})
 		default:
+			if fd.Message() == nil {
+				return true
+			}
 			stack := append(stack, ".")
 			err = isInitialized(v.Message(), stack)
 		}

+ 24 - 0
proto/reset.go

@@ -0,0 +1,24 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style.
+// license that can be found in the LICENSE file.
+
+package proto
+
+import "google.golang.org/protobuf/reflect/protoreflect"
+
+// Reset clears every field in the message.
+func Reset(m Message) {
+	// TODO: Document memory aliasing guarantees.
+	// TODO: Add fast-path for reset?
+	resetMessage(m.ProtoReflect())
+}
+
+func resetMessage(m protoreflect.Message) {
+	m.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
+		m.Clear(fd)
+		return true
+	})
+	if m.GetUnknown() != nil {
+		m.SetUnknown(nil)
+	}
+}

+ 4 - 18
proto/size.go

@@ -5,8 +5,6 @@
 package proto
 
 import (
-	"fmt"
-
 	"google.golang.org/protobuf/internal/encoding/wire"
 	"google.golang.org/protobuf/reflect/protoreflect"
 )
@@ -34,23 +32,11 @@ func sizeMessageFast(m Message) (int, error) {
 }
 
 func sizeMessage(m protoreflect.Message) (size int) {
-	fieldDescs := m.Descriptor().Fields()
-	knownFields := m.KnownFields()
-	m.KnownFields().Range(func(num protoreflect.FieldNumber, value protoreflect.Value) bool {
-		field := fieldDescs.ByNumber(num)
-		if field == nil {
-			field = knownFields.ExtensionTypes().ByNumber(num).Descriptor()
-			if field == nil {
-				panic(fmt.Errorf("no descriptor for field %d in %q", num, m.Descriptor().FullName()))
-			}
-		}
-		size += sizeField(field, value)
-		return true
-	})
-	m.UnknownFields().Range(func(_ protoreflect.FieldNumber, raw protoreflect.RawFields) bool {
-		size += len(raw)
+	m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
+		size += sizeField(fd, v)
 		return true
 	})
+	size += len(m.GetUnknown())
 	return size
 }
 
@@ -67,7 +53,7 @@ func sizeField(fd protoreflect.FieldDescriptor, value protoreflect.Value) (size
 }
 
 func sizeList(num wire.Number, fd protoreflect.FieldDescriptor, list protoreflect.List) (size int) {
-	if fd.IsPacked() {
+	if fd.IsPacked() && list.Len() > 0 {
 		content := 0
 		for i, llen := 0, list.Len(); i < llen; i++ {
 			content += sizeSingular(num, fd.Kind(), list.Get(i))

+ 186 - 0
reflect/protoreflect/deprecated.go

@@ -0,0 +1,186 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package protoreflect
+
+// TODO: Remove this file.
+
+type (
+	deprecatedEnum interface {
+		// Deprecated: Use Descriptor instead.
+		Type() EnumType
+	}
+
+	deprecatedMessage interface {
+		// Deprecated: Use Descriptor instead.
+		Type() MessageType
+		// Deprecated: Use methods on Message directly.
+		KnownFields() KnownFields
+		// Deprecated: Use methods on Message directly.
+		UnknownFields() UnknownFields
+	}
+)
+
+// KnownFields provides accessor and mutator methods for known fields.
+//
+// Each field Value can either be a scalar, Message, List, or Map.
+// The field is a List or Map if FieldDescriptor.Cardinality is Repeated and
+// a Map if and only if FieldDescriptor.IsMap is true. The scalar type or
+// underlying repeated element type is determined by the FieldDescriptor.Kind.
+// See Value for a list of Go types associated with each Kind.
+//
+// Field extensions are handled as known fields once the extension type has been
+// registered with KnownFields.ExtensionTypes.
+//
+// Len, Has, Get, Range, and ExtensionTypes are safe for concurrent use.
+//
+// Deprecated: Use direct methods on Message instead.
+type KnownFields interface {
+	// Len reports the number of fields that are populated.
+	Len() int
+
+	// Has reports whether a field is populated.
+	//
+	// Some fields have the property of nullability where it is possible to
+	// distinguish between the default value of a field and whether the field
+	// was explicitly populated with the default value. Only scalars in proto2,
+	// member fields of a oneof, and singular messages are nullable.
+	//
+	// A nullable field is populated only if explicitly set.
+	// A scalar field in proto3 is populated if it contains a non-zero value.
+	// A repeated field is populated only if it is non-empty.
+	Has(FieldNumber) bool
+
+	// Get retrieves the value for a field with the given field number.
+	// If the field is unpopulated, it returns the default value for scalars,
+	// a mutable empty List for empty repeated fields, a mutable empty Map for
+	// empty map fields, and an invalid value for message fields.
+	// If the field is unknown (does not appear in MessageDescriptor.Fields
+	// or ExtensionFieldTypes), it returns an invalid value.
+	Get(FieldNumber) Value
+
+	// Set stores the value for a field with the given field number.
+	// Setting a field belonging to a oneof implicitly clears any other field
+	// that may be currently set by the same oneof.
+	//
+	// When setting a composite type, it is unspecified whether the set
+	// value aliases the source's memory in any way.
+	//
+	// It panics if the field number does not correspond with a known field
+	// in MessageDescriptor.Fields or an extension field in ExtensionTypes.
+	Set(FieldNumber, Value)
+
+	// TODO: Document memory aliasing behavior when a field is cleared?
+	// For example, if Mutable is called later, can it reuse memory?
+
+	// Clear clears the field such that a subsequent call to Has reports false.
+	// The operation does nothing if the field number does not correspond with
+	// a known field or extension field.
+	Clear(FieldNumber)
+
+	// WhichOneof reports which field within the named oneof is populated.
+	// It returns 0 if the oneof does not exist or no fields are populated.
+	WhichOneof(Name) FieldNumber
+
+	// Range iterates over every populated field in an undefined order,
+	// calling f for each field number and value encountered.
+	// Range calls f Len times unless f returns false, which stops iteration.
+	// While iterating, mutating operations through Set, Clear, or Mutable
+	// may only be performed on the current field number.
+	Range(f func(FieldNumber, Value) bool)
+
+	// NewMessage returns a newly allocated empty message assignable to
+	// the field of the given number.
+	// It panics if the field is not a singular message.
+	NewMessage(FieldNumber) Message
+
+	// ExtensionTypes are extension field types that are known by this
+	// specific message instance.
+	ExtensionTypes() ExtensionFieldTypes
+}
+
+// UnknownFields are a list of unknown or unparsed fields and may contain
+// field numbers corresponding with defined fields or extension fields.
+// The ordering of fields is maintained for fields of the same field number.
+// However, the relative ordering of fields with different field numbers
+// is undefined.
+//
+// Len, Get, and Range are safe for concurrent use.
+//
+// Deprecated: Use direct methods on Message instead.
+type UnknownFields interface {
+	// Len reports the number of fields that are populated.
+	Len() int
+
+	// Get retrieves the raw bytes of fields with the given field number.
+	// It returns an empty RawFields if there are no populated fields.
+	//
+	// The caller must not mutate the content of the retrieved RawFields.
+	Get(FieldNumber) RawFields
+
+	// Set stores the raw bytes of fields with the given field number.
+	// The RawFields must be valid and correspond with the given field number;
+	// an implementation may panic if the fields are invalid.
+	// An empty RawFields may be passed to clear the fields.
+	//
+	// The caller must not mutate the content of the RawFields being stored.
+	Set(FieldNumber, RawFields)
+
+	// Range iterates over every populated field in an undefined order,
+	// calling f for each field number and raw field value encountered.
+	// Range calls f Len times unless f returns false, which stops iteration.
+	// While iterating, mutating operations through Set may only be performed
+	// on the current field number.
+	//
+	// While the iteration order is undefined, it is deterministic.
+	// It is recommended, but not required, that fields be presented
+	// in the order that they were encountered in the wire data.
+	Range(f func(FieldNumber, RawFields) bool)
+
+	// TODO: Should IsSupported be renamed as ReadOnly?
+	// TODO: Should IsSupported panic on Set instead of silently ignore?
+
+	// IsSupported reports whether this message supports unknown fields.
+	// If false, UnknownFields ignores all Set operations.
+	IsSupported() bool
+}
+
+// ExtensionFieldTypes are the extension field types that this message instance
+// has been extended with.
+//
+// Len, Get, and Range are safe for concurrent use.
+//
+// Deprecated: Use direct methods on Message instead.
+type ExtensionFieldTypes interface {
+	// Len reports the number of field extensions.
+	Len() int
+
+	// Register stores an ExtensionType.
+	// The ExtensionType.ExtendedType must match the containing message type
+	// and the field number must be within the valid extension ranges
+	// (see MessageDescriptor.ExtensionRanges).
+	// It panics if the extension has already been registered (i.e.,
+	// a conflict by number or by full name).
+	Register(ExtensionType)
+
+	// Remove removes the ExtensionType.
+	// It panics if a value for this extension field is still populated.
+	// The operation does nothing if there is no associated type to remove.
+	Remove(ExtensionType)
+
+	// ByNumber looks up an extension by field number.
+	// It returns nil if not found.
+	ByNumber(FieldNumber) ExtensionType
+
+	// ByName looks up an extension field by full name.
+	// It returns nil if not found.
+	ByName(FullName) ExtensionType
+
+	// Range iterates over every registered field in an undefined order,
+	// calling f for each extension descriptor encountered.
+	// Range calls f Len times unless f returns false, which stops iteration.
+	// While iterating, mutating operations through Remove may only
+	// be performed on the current descriptor.
+	Range(f func(ExtensionType) bool)
+}

+ 2 - 10
reflect/protoreflect/type.go

@@ -6,9 +6,6 @@ package protoreflect
 
 import "reflect"
 
-// TODO: For all ByX methods (e.g., ByName, ByJSONName, ByNumber, etc),
-// should they use a (v, ok) signature for the return value?
-
 // Descriptor provides a set of accessors that are common to every descriptor.
 // Each descriptor type wraps the equivalent google.protobuf.XXXDescriptorProto,
 // but provides efficient lookup and immutability.
@@ -231,8 +228,6 @@ type isMessageDescriptor interface{ ProtoType(MessageDescriptor) }
 
 // MessageType encapsulates a MessageDescriptor with a concrete Go implementation.
 type MessageType interface {
-	// TODO: Remove this.
-	// Deprecated: Do not rely on these methods.
 	MessageDescriptor
 
 	// New returns a newly allocated empty message.
@@ -444,8 +439,6 @@ type ExtensionDescriptors interface {
 // Field "bar_field" is an extension of FooMessage, but its full name is
 // "example.BarMessage.bar_field" instead of "example.FooMessage.bar_field".
 type ExtensionType interface {
-	// TODO: Remove this.
-	// Deprecated: Do not rely on these methods.
 	ExtensionDescriptor
 
 	// New returns a new value for the field.
@@ -475,7 +468,8 @@ type ExtensionType interface {
 
 	// InterfaceOf completely unwraps the Value to the underlying Go type.
 	// InterfaceOf panics if the input is nil or does not represent the
-	// appropriate underlying Go type.
+	// appropriate underlying Go type. For composite types, it panics if the
+	// value is not mutable.
 	//
 	// InterfaceOf is able to unwrap the Value further than Value.Interface
 	// as it has more type information available.
@@ -504,8 +498,6 @@ type isEnumDescriptor interface{ ProtoType(EnumDescriptor) }
 
 // EnumType encapsulates an EnumDescriptor with a concrete Go implementation.
 type EnumType interface {
-	// TODO: Remove this.
-	// Deprecated: Do not rely on these methods.
 	EnumDescriptor
 
 	// New returns an instance of this enum type with its value set to n.

+ 125 - 194
reflect/protoreflect/value.go

@@ -13,294 +13,225 @@ import "google.golang.org/protobuf/internal/encoding/wire"
 type Enum interface {
 	Descriptor() EnumDescriptor
 
-	// TODO: Remove this.
-	// Deprecated: Use Descriptor instead.
-	Type() EnumType
-
 	// Number returns the enum value as an integer.
 	Number() EnumNumber
+
+	deprecatedEnum
 }
 
 // Message is a reflective interface for a concrete message value,
-// which provides type information and getters/setters for individual fields.
+// encapsulating both type and value information for the message.
+//
+// Accessor/mutators for individual fields are keyed by FieldDescriptor.
+// For non-extension fields, the descriptor must exactly match the
+// field known by the parent message.
+// For extension fields, the descriptor must implement ExtensionType,
+// extend the parent message (i.e., have the same message FullName), and
+// be within the parent's extension range.
 //
-// Concrete types may implement interfaces defined in proto/protoiface,
-// which provide specialized, performant implementations of high-level
-// operations such as Marshal and Unmarshal.
+// Each field Value can be a scalar or a composite type (Message, List, or Map).
+// See Value for the Go types associated with a FieldDescriptor.
+// Providing a Value that is invalid or of an incorrect type panics.
 type Message interface {
 	Descriptor() MessageDescriptor
 
-	// TODO: Remove this.
-	// Deprecated: Use Descriptor instead.
-	Type() MessageType
-
-	// KnownFields returns an interface to access/mutate known fields.
-	KnownFields() KnownFields
-
-	// UnknownFields returns an interface to access/mutate unknown fields.
-	UnknownFields() UnknownFields
-
-	// New returns a newly allocated empty message.
+	// New returns a newly allocated and mutable empty message.
 	New() Message
 
 	// Interface unwraps the message reflection interface and
-	// returns the underlying proto.Message interface.
+	// returns the underlying ProtoMessage interface.
 	Interface() ProtoMessage
-}
 
-// KnownFields provides accessor and mutator methods for known fields.
-//
-// Each field Value can either be a scalar, Message, List, or Map.
-// The field is a List or Map if FieldDescriptor.Cardinality is Repeated and
-// a Map if and only if FieldDescriptor.IsMap is true. The scalar type or
-// underlying repeated element type is determined by the FieldDescriptor.Kind.
-// See Value for a list of Go types associated with each Kind.
-//
-// Field extensions are handled as known fields once the extension type has been
-// registered with KnownFields.ExtensionTypes.
-//
-// Len, Has, Get, Range, and ExtensionTypes are safe for concurrent use.
-type KnownFields interface {
-	// Len reports the number of fields that are populated.
+	// Len reports the number of populated fields (i.e., Has reports true).
 	Len() int
 
+	// Range iterates over every populated field in an undefined order,
+	// calling f for each field descriptor and value encountered.
+	// Range calls f Len times unless f returns false, which stops iteration.
+	// While iterating, mutating operations may only be performed
+	// on the current field descriptor.
+	Range(f func(FieldDescriptor, Value) bool)
+
 	// Has reports whether a field is populated.
 	//
 	// Some fields have the property of nullability where it is possible to
 	// distinguish between the default value of a field and whether the field
-	// was explicitly populated with the default value. Only scalars in proto2,
-	// member fields of a oneof, and singular messages are nullable.
+	// was explicitly populated with the default value. Singular message fields,
+	// member fields of a oneof, proto2 scalar fields, and extension fields
+	// are nullable. Such fields are populated only if explicitly set.
 	//
-	// A nullable field is populated only if explicitly set.
-	// A scalar field in proto3 is populated if it contains a non-zero value.
-	// A repeated field is populated only if it is non-empty.
-	Has(FieldNumber) bool
-
-	// Get retrieves the value for a field with the given field number.
-	// If the field is unpopulated, it returns the default value for scalars,
-	// a mutable empty List for empty repeated fields, a mutable empty Map for
-	// empty map fields, and an invalid value for message fields.
-	// If the field is unknown (does not appear in MessageDescriptor.Fields
-	// or ExtensionFieldTypes), it returns an invalid value.
-	Get(FieldNumber) Value
-
-	// Set stores the value for a field with the given field number.
-	// Setting a field belonging to a oneof implicitly clears any other field
-	// that may be currently set by the same oneof.
+	// In other cases (aside from the nullable cases above),
+	// a proto3 scalar field is populated if it contains a non-zero value, and
+	// a repeated field is populated if it is non-empty.
+	Has(FieldDescriptor) bool
+
+	// Clear clears the field such that a subsequent Has call reports false.
 	//
-	// When setting a composite type, it is unspecified whether the set
-	// value aliases the source's memory in any way.
+	// Clearing an extension field clears both the extension type and value
+	// associated with the given field number.
 	//
-	// It panics if the field number does not correspond with a known field
-	// in MessageDescriptor.Fields or an extension field in ExtensionTypes.
-	Set(FieldNumber, Value)
+	// Clear is a mutating operation and unsafe for concurrent use.
+	Clear(FieldDescriptor)
 
-	// TODO: Document memory aliasing behavior when a field is cleared?
-	// For example, if Mutable is called later, can it reuse memory?
+	// Get retrieves the value for a field.
+	//
+	// For unpopulated scalars, it returns the default value, where
+	// the default value of a bytes scalar is guaranteed to be a copy.
+	// For unpopulated composite types, it returns an empty, read-only view
+	// of the value; to obtain a mutable reference, use Mutable.
+	Get(FieldDescriptor) Value
 
-	// Clear clears the field such that a subsequent call to Has reports false.
-	// The operation does nothing if the field number does not correspond with
-	// a known field or extension field.
-	Clear(FieldNumber)
+	// TODO: Should Set of a empty, read-only value be equivalent to Clear?
 
-	// WhichOneof reports which field within the named oneof is populated.
-	// It returns 0 if the oneof does not exist or no fields are populated.
-	WhichOneof(Name) FieldNumber
+	// Set stores the value for a field.
+	//
+	// For a field belonging to a oneof, it implicitly clears any other field
+	// that may be currently set within the same oneof.
+	// For extension fields, it implicitly stores the provided ExtensionType.
+	// When setting a composite type, it is unspecified whether the stored value
+	// aliases the source's memory in any way. If the composite value is an
+	// empty, read-only value, then it panics.
+	//
+	// Set is a mutating operation and unsafe for concurrent use.
+	Set(FieldDescriptor, Value)
 
-	// Range iterates over every populated field in an undefined order,
-	// calling f for each field number and value encountered.
-	// Range calls f Len times unless f returns false, which stops iteration.
-	// While iterating, mutating operations through Set, Clear, or Mutable
-	// may only be performed on the current field number.
-	Range(f func(FieldNumber, Value) bool)
+	// Mutable returns a mutable reference to a composite type.
+	//
+	// If the field is unpopulated, it may allocate a composite value.
+	// For a field belonging to a oneof, it implicitly clears any other field
+	// that may be currently set within the same oneof.
+	// For extension fields, it implicitly stores the provided ExtensionType
+	// if not already stored.
+	// It panics if the field does not contain a composite type.
+	//
+	// Mutable is a mutating operation and unsafe for concurrent use.
+	Mutable(FieldDescriptor) Value
 
 	// NewMessage returns a newly allocated empty message assignable to
-	// the field of the given number.
+	// the field of the given descriptor.
 	// It panics if the field is not a singular message.
-	NewMessage(FieldNumber) Message
-
-	// ExtensionTypes are extension field types that are known by this
-	// specific message instance.
-	ExtensionTypes() ExtensionFieldTypes
-}
-
-// UnknownFields are a list of unknown or unparsed fields and may contain
-// field numbers corresponding with defined fields or extension fields.
-// The ordering of fields is maintained for fields of the same field number.
-// However, the relative ordering of fields with different field numbers
-// is undefined.
-//
-// Len, Get, and Range are safe for concurrent use.
-type UnknownFields interface {
-	// Len reports the number of fields that are populated.
-	Len() int
-
-	// Get retrieves the raw bytes of fields with the given field number.
-	// It returns an empty RawFields if there are no populated fields.
-	//
-	// The caller must not mutate the content of the retrieved RawFields.
-	Get(FieldNumber) RawFields
-
-	// Set stores the raw bytes of fields with the given field number.
-	// The RawFields must be valid and correspond with the given field number;
-	// an implementation may panic if the fields are invalid.
+	NewMessage(FieldDescriptor) Message
+
+	// WhichOneof reports which field within the oneof is populated,
+	// returning nil if none are populated.
+	// It panics if the oneof descriptor does not belong to this message.
+	WhichOneof(OneofDescriptor) FieldDescriptor
+
+	// GetUnknown retrieves the entire list of unknown fields.
+	// The caller may only mutate the contents of the RawFields
+	// if the mutated bytes are stored back into the message with SetUnknown.
+	GetUnknown() RawFields
+
+	// SetUnknown stores an entire list of unknown fields.
+	// The raw fields must be syntactically valid according to the wire format.
+	// An implementation may panic if this is not the case.
+	// Once stored, the caller must not mutate the content of the RawFields.
 	// An empty RawFields may be passed to clear the fields.
 	//
-	// The caller must not mutate the content of the RawFields being stored.
-	Set(FieldNumber, RawFields)
-
-	// Range iterates over every populated field in an undefined order,
-	// calling f for each field number and raw field value encountered.
-	// Range calls f Len times unless f returns false, which stops iteration.
-	// While iterating, mutating operations through Set may only be performed
-	// on the current field number.
-	//
-	// While the iteration order is undefined, it is deterministic.
-	// It is recommended, but not required, that fields be presented
-	// in the order that they were encountered in the wire data.
-	Range(f func(FieldNumber, RawFields) bool)
+	// SetUnknown is a mutating operation and unsafe for concurrent use.
+	SetUnknown(RawFields)
 
-	// TODO: Should IsSupported be renamed as ReadOnly?
-	// TODO: Should IsSupported panic on Set instead of silently ignore?
+	// TODO: Add method to retrieve ExtensionType by FieldNumber?
 
-	// IsSupported reports whether this message supports unknown fields.
-	// If false, UnknownFields ignores all Set operations.
-	IsSupported() bool
+	deprecatedMessage
 }
 
 // RawFields is the raw bytes for an ordered sequence of fields.
 // Each field contains both the tag (representing field number and wire type),
 // and also the wire data itself.
-//
-// Once stored, the content of a RawFields must be treated as immutable.
-// The capacity of RawFields may be treated as mutable only for the use-case of
-// appending additional data to store back into UnknownFields.
 type RawFields []byte
 
-// IsValid reports whether RawFields is syntactically correct wire format.
-// All fields must belong to the same field number.
+// IsValid reports whether b is syntactically correct wire format.
 func (b RawFields) IsValid() bool {
-	var want FieldNumber
 	for len(b) > 0 {
-		got, _, n := wire.ConsumeField(b)
-		if n < 0 || (want > 0 && got != want) {
+		_, _, n := wire.ConsumeField(b)
+		if n < 0 {
 			return false
 		}
-		want = got
 		b = b[n:]
 	}
 	return true
 }
 
-// ExtensionFieldTypes are the extension field types that this message instance
-// has been extended with.
-//
-// Len, Get, and Range are safe for concurrent use.
-type ExtensionFieldTypes interface {
-	// Len reports the number of field extensions.
-	Len() int
-
-	// Register stores an ExtensionType.
-	// The ExtensionType.ExtendedType must match the containing message type
-	// and the field number must be within the valid extension ranges
-	// (see MessageDescriptor.ExtensionRanges).
-	// It panics if the extension has already been registered (i.e.,
-	// a conflict by number or by full name).
-	Register(ExtensionType)
-
-	// Remove removes the ExtensionType.
-	// It panics if a value for this extension field is still populated.
-	// The operation does nothing if there is no associated type to remove.
-	Remove(ExtensionType)
-
-	// ByNumber looks up an extension by field number.
-	// It returns nil if not found.
-	ByNumber(FieldNumber) ExtensionType
-
-	// ByName looks up an extension field by full name.
-	// It returns nil if not found.
-	ByName(FullName) ExtensionType
-
-	// Range iterates over every registered field in an undefined order,
-	// calling f for each extension descriptor encountered.
-	// Range calls f Len times unless f returns false, which stops iteration.
-	// While iterating, mutating operations through Remove may only
-	// be performed on the current descriptor.
-	Range(f func(ExtensionType) bool)
-}
-
-// List is an ordered list. Every element is considered populated
-// (i.e., Get never provides and Set never accepts invalid Values).
-// The element Value type is determined by the associated FieldDescriptor.Kind
-// and cannot be a Map or List.
-//
-// Len and Get are safe for concurrent use.
+// List is a zero-indexed, ordered list.
+// The element Value type is determined by FieldDescriptor.Kind.
+// Providing a Value that is invalid or of an incorrect type panics.
 type List interface {
 	// Len reports the number of entries in the List.
-	// Get, Set, Mutable, and Truncate panic with out of bound indexes.
+	// Get, Set, and Truncate panic with out of bound indexes.
 	Len() int
 
 	// Get retrieves the value at the given index.
+	// It never returns an invalid value.
 	Get(int) Value
 
 	// Set stores a value for the given index.
-	//
 	// When setting a composite type, it is unspecified whether the set
 	// value aliases the source's memory in any way.
+	//
+	// Set is a mutating operation and unsafe for concurrent use.
 	Set(int, Value)
 
 	// Append appends the provided value to the end of the list.
-	//
 	// When appending a composite type, it is unspecified whether the appended
 	// value aliases the source's memory in any way.
+	//
+	// Append is a mutating operation and unsafe for concurrent use.
 	Append(Value)
 
+	// TODO: Should there be a Mutable and MutableAppend method?
+
 	// TODO: Should truncate accept two indexes similar to slicing?
 
 	// Truncate truncates the list to a smaller length.
+	//
+	// Truncate is a mutating operation and unsafe for concurrent use.
 	Truncate(int)
 
-	// NewMessage returns a newly allocated empty message assignable to a list entry.
+	// NewMessage returns a newly allocated empty message assignable as a list entry.
 	// It panics if the list entry type is not a message.
 	NewMessage() Message
 }
 
-// Map is an unordered, associative map. Only elements within the map
-// is considered populated. The entry Value type is determined by the associated
-// FieldDescripto.Kind and cannot be a Map or List.
-//
-// Len, Has, Get, and Range are safe for concurrent use.
+// Map is an unordered, associative map.
+// The entry MapKey type is determined by FieldDescriptor.MapKey.Kind.
+// The entry Value type is determined by FieldDescriptor.MapValue.Kind.
+// Providing a MapKey or Value that is invalid or of an incorrect type panics.
 type Map interface {
 	// Len reports the number of elements in the map.
 	Len() int
 
+	// Range iterates over every map entry in an undefined order,
+	// calling f for each key and value encountered.
+	// Range calls f Len times unless f returns false, which stops iteration.
+	// While iterating, mutating operations may only be performed
+	// on the current map key.
+	Range(f func(MapKey, Value) bool)
+
 	// Has reports whether an entry with the given key is in the map.
 	Has(MapKey) bool
 
+	// Clear clears the entry associated with they given key.
+	// The operation does nothing if there is no entry associated with the key.
+	//
+	// Clear is a mutating operation and unsafe for concurrent use.
+	Clear(MapKey)
+
 	// Get retrieves the value for an entry with the given key.
 	// It returns an invalid value for non-existent entries.
 	Get(MapKey) Value
 
 	// Set stores the value for an entry with the given key.
-	//
+	// It panics when given a key or value that is invalid or the wrong type.
 	// When setting a composite type, it is unspecified whether the set
 	// value aliases the source's memory in any way.
 	//
-	// It panics if either the key or value are invalid.
+	// Set is a mutating operation and unsafe for concurrent use.
 	Set(MapKey, Value)
 
-	// Clear clears the entry associated with they given key.
-	// The operation does nothing if there is no entry associated with the key.
-	Clear(MapKey)
-
-	// Range iterates over every map entry in an undefined order,
-	// calling f for each key and value encountered.
-	// Range calls f Len times unless f returns false, which stops iteration.
-	// While iterating, mutating operations through Set, Clear, or Mutable
-	// may only be performed on the current map key.
-	Range(f func(MapKey, Value) bool)
+	// TODO: Should there be a Mutable method?
 
-	// NewMessage returns a newly allocated empty message assignable to a map value.
+	// NewMessage returns a newly allocated empty message assignable as a map value.
 	// It panics if the map value type is not a message.
 	NewMessage() Message
 }

+ 4 - 11
reflect/protoreflect/value_union.go

@@ -12,7 +12,7 @@ import (
 
 // Value is a union where only one Go type may be set at a time.
 // The Value is used to represent all possible values a field may take.
-// The following shows what Go type is used to represent each proto Kind:
+// The following shows which Go type is used to represent each proto Kind:
 //
 //	+------------+-------------------------------------+
 //	| Go type    | Protobuf kind                       |
@@ -27,10 +27,7 @@ import (
 //	| string     | StringKind                          |
 //	| []byte     | BytesKind                           |
 //	| EnumNumber | EnumKind                            |
-//	+------------+-------------------------------------+
 //	| Message    | MessageKind, GroupKind              |
-//	| List       |                                     |
-//	| Map        |                                     |
 //	+------------+-------------------------------------+
 //
 // Multiple protobuf Kinds may be represented by a single Go type if the type
@@ -38,9 +35,9 @@ import (
 // Int64Kind, Sint64Kind, and Sfixed64Kind are all represented by int64,
 // but use different integer encoding methods.
 //
-// The List or Map types are used if the FieldDescriptor.Cardinality of the
-// corresponding field is Repeated and a Map if and only if
-// FieldDescriptor.IsMap is true.
+// The List or Map types are used if the field cardinality is repeated.
+// A field is a List if FieldDescriptor.IsList reports true.
+// A field is a Map if FieldDescriptor.IsMap reports true.
 //
 // Converting to/from a Value and a concrete Go value panics on type mismatch.
 // For example, ValueOf("hello").Int() panics because this attempts to
@@ -63,8 +60,6 @@ type Value value
 // ValueOf returns a Value initialized with the concrete value stored in v.
 // This panics if the type does not match one of the allowed types in the
 // Value union.
-//
-// After calling ValueOf on a []byte, the slice must no longer be mutated.
 func ValueOf(v interface{}) Value {
 	switch v := v.(type) {
 	case nil:
@@ -108,7 +103,6 @@ func (v Value) IsValid() bool {
 }
 
 // Interface returns v as an interface{}.
-// Returned []byte values must not be mutated.
 //
 // Invariant: v == ValueOf(v).Interface()
 func (v Value) Interface() interface{} {
@@ -192,7 +186,6 @@ func (v Value) String() string {
 }
 
 // Bytes returns v as a []byte and panics if the type is not a []byte.
-// The returned slice must not be mutated.
 func (v Value) Bytes() []byte {
 	switch v.typ {
 	case bytesType:

+ 5 - 8
runtime/protoimpl/impl.go

@@ -58,14 +58,11 @@ type (
 	// being a compilation failure (guaranteed by the Go specification).
 	EnforceVersion uint
 
-	MessageInfo = impl.MessageInfo
-	FileBuilder = fileinit.FileBuilder
-
-	// TODO: Change these to more efficient data structures.
-	ExtensionFields = map[int32]impl.ExtensionField
-	UnknownFields   = []byte
-	SizeCache       = int32
-
+	FileBuilder      = fileinit.FileBuilder
+	MessageInfo      = impl.MessageInfo
+	SizeCache        = impl.SizeCache
+	UnknownFields    = impl.UnknownFields
+	ExtensionFields  = impl.ExtensionFields
 	ExtensionFieldV1 = impl.ExtensionField
 )
 

+ 84 - 110
testing/prototest/prototest.go

@@ -17,68 +17,46 @@ import (
 	pref "google.golang.org/protobuf/reflect/protoreflect"
 )
 
-// TestMessage runs the provided message through a series of tests
-// exercising the protobuf reflection API.
-func TestMessage(t testing.TB, message proto.Message) {
-	md := message.ProtoReflect().Descriptor()
+// TODO: Test read-only properties of unpopulated composite values.
+// TODO: Test invalid field descriptors or oneof descriptors.
+// TODO: This should test the functionality that can be provided by fast-paths.
 
-	m := message.ProtoReflect().New()
+// TestMessage runs the provided m through a series of tests
+// exercising the protobuf reflection API.
+func TestMessage(t testing.TB, m proto.Message) {
+	md := m.ProtoReflect().Descriptor()
+	m1 := m.ProtoReflect().New()
 	for i := 0; i < md.Fields().Len(); i++ {
 		fd := md.Fields().Get(i)
 		switch {
 		case fd.IsList():
-			testFieldList(t, m, fd)
+			testFieldList(t, m1, fd)
 		case fd.IsMap():
-			testFieldMap(t, m, fd)
+			testFieldMap(t, m1, fd)
 		case fd.Kind() == pref.FloatKind || fd.Kind() == pref.DoubleKind:
-			testFieldFloat(t, m, fd)
+			testFieldFloat(t, m1, fd)
 		}
-		testField(t, m, fd)
+		testField(t, m1, fd)
 	}
 	for i := 0; i < md.Oneofs().Len(); i++ {
-		testOneof(t, m, md.Oneofs().Get(i))
-	}
-
-	// Test has/get/clear on a non-existent field.
-	for num := pref.FieldNumber(1); ; num++ {
-		if md.Fields().ByNumber(num) != nil {
-			continue
-		}
-		if md.ExtensionRanges().Has(num) {
-			continue
-		}
-		// Field num does not exist.
-		if m.KnownFields().Has(num) {
-			t.Errorf("non-existent field: Has(%v) = true, want false", num)
-		}
-		if v := m.KnownFields().Get(num); v.IsValid() {
-			t.Errorf("non-existent field: Get(%v) = %v, want invalid", num, formatValue(v))
-		}
-		m.KnownFields().Clear(num) // noop
-		break
-	}
-
-	// Test WhichOneof on a non-existent oneof.
-	const invalidName = "invalid-name"
-	if got, want := m.KnownFields().WhichOneof(invalidName), pref.FieldNumber(0); got != want {
-		t.Errorf("non-existent oneof: WhichOneof(%q) = %v, want %v", invalidName, got, want)
+		testOneof(t, m1, md.Oneofs().Get(i))
 	}
 
 	// TODO: Extensions, unknown fields.
 
 	// Test round-trip marshal/unmarshal.
-	m1 := message.ProtoReflect().New().Interface()
-	populateMessage(m1.ProtoReflect(), 1, nil)
-	b, err := proto.Marshal(m1)
+	m2 := m.ProtoReflect().New().Interface()
+	populateMessage(m2.ProtoReflect(), 1, nil)
+	b, err := proto.Marshal(m2)
 	if err != nil {
-		t.Errorf("Marshal() = %v, want nil\n%v", err, marshalText(m1))
+		t.Errorf("Marshal() = %v, want nil\n%v", err, marshalText(m2))
 	}
-	m2 := message.ProtoReflect().New().Interface()
-	if err := proto.Unmarshal(b, m2); err != nil {
-		t.Errorf("Unmarshal() = %v, want nil\n%v", err, marshalText(m1))
+	m3 := m.ProtoReflect().New().Interface()
+	if err := proto.Unmarshal(b, m3); err != nil {
+		t.Errorf("Unmarshal() = %v, want nil\n%v", err, marshalText(m2))
 	}
-	if !proto.Equal(m1, m2) {
-		t.Errorf("round-trip marshal/unmarshal did not preserve message.\nOriginal:\n%v\nNew:\n%v", marshalText(m1), marshalText(m2))
+	if !proto.Equal(m2, m3) {
+		t.Errorf("round-trip marshal/unmarshal did not preserve message\nOriginal:\n%v\nNew:\n%v", marshalText(m2), marshalText(m3))
 	}
 }
 
@@ -87,16 +65,15 @@ func marshalText(m proto.Message) string {
 	return string(b)
 }
 
-// testField exericises set/get/has/clear of a field.
+// testField exercises set/get/has/clear of a field.
 func testField(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
-	num := fd.Number()
 	name := fd.FullName()
-	known := m.KnownFields()
+	num := fd.Number()
 
 	// Set to a non-zero value, the zero value, different non-zero values.
 	for _, n := range []seed{1, 0, minVal, maxVal} {
 		v := newValue(m, fd, n, nil)
-		known.Set(num, v)
+		m.Set(fd, v)
 		wantHas := true
 		if n == 0 {
 			if fd.Syntax() == pref.Proto3 && fd.Message() == nil {
@@ -109,55 +86,55 @@ func testField(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
 				wantHas = true
 			}
 		}
-		if got, want := known.Has(num), wantHas; got != want {
-			t.Errorf("after setting %q to %v:\nHas(%v) = %v, want %v", name, formatValue(v), num, got, want)
+		if got, want := m.Has(fd), wantHas; got != want {
+			t.Errorf("after setting %q to %v:\nMessage.Has(%v) = %v, want %v", name, formatValue(v), num, got, want)
 		}
-		if got, want := known.Get(num), v; !valueEqual(got, want) {
-			t.Errorf("after setting %q:\nGet(%v) = %v, want %v", name, num, formatValue(got), formatValue(want))
+		if got, want := m.Get(fd), v; !valueEqual(got, want) {
+			t.Errorf("after setting %q:\nMessage.Get(%v) = %v, want %v", name, num, formatValue(got), formatValue(want))
 		}
 	}
 
-	known.Clear(num)
-	if got, want := known.Has(num), false; got != want {
-		t.Errorf("after clearing %q:\nHas(%v) = %v, want %v", name, num, got, want)
+	m.Clear(fd)
+	if got, want := m.Has(fd), false; got != want {
+		t.Errorf("after clearing %q:\nMessage.Has(%v) = %v, want %v", name, num, got, want)
 	}
 	switch {
 	case fd.IsList():
-		if got := known.Get(num); got.List().Len() != 0 {
-			t.Errorf("after clearing %q:\nGet(%v) = %v, want empty list", name, num, formatValue(got))
+		if got := m.Get(fd); got.List().Len() != 0 {
+			t.Errorf("after clearing %q:\nMessage.Get(%v) = %v, want empty list", name, num, formatValue(got))
 		}
 	case fd.IsMap():
-		if got := known.Get(num); got.Map().Len() != 0 {
-			t.Errorf("after clearing %q:\nGet(%v) = %v, want empty list", name, num, formatValue(got))
+		if got := m.Get(fd); got.Map().Len() != 0 {
+			t.Errorf("after clearing %q:\nMessage.Get(%v) = %v, want empty list", name, num, formatValue(got))
 		}
-	default:
-		if got, want := known.Get(num), fd.Default(); !valueEqual(got, want) {
-			t.Errorf("after clearing %q:\nGet(%v) = %v, want default %v", name, num, formatValue(got), formatValue(want))
+	case fd.Message() == nil:
+		if got, want := m.Get(fd), fd.Default(); !valueEqual(got, want) {
+			t.Errorf("after clearing %q:\nMessage.Get(%v) = %v, want default %v", name, num, formatValue(got), formatValue(want))
 		}
 	}
 }
 
 // testFieldMap tests set/get/has/clear of entries in a map field.
 func testFieldMap(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
-	num := fd.Number()
 	name := fd.FullName()
-	known := m.KnownFields()
-	known.Clear(num) // start with an empty map
-	mapv := known.Get(num).Map()
+	num := fd.Number()
+
+	m.Clear(fd) // start with an empty map
+	mapv := m.Mutable(fd).Map()
 
 	// Add values.
 	want := make(testMap)
 	for i, n := range []seed{1, 0, minVal, maxVal} {
-		if got, want := known.Has(num), i > 0; got != want {
-			t.Errorf("after inserting %d elements to %q:\nHas(%v) = %v, want %v", i, name, num, got, want)
+		if got, want := m.Has(fd), i > 0; got != want {
+			t.Errorf("after inserting %d elements to %q:\nMessage.Has(%v) = %v, want %v", i, name, num, got, want)
 		}
 
 		k := newMapKey(fd, n)
 		v := newMapValue(fd, mapv, n, nil)
 		mapv.Set(k, v)
 		want.Set(k, v)
-		if got, want := known.Get(num), pref.ValueOf(want); !valueEqual(got, want) {
-			t.Errorf("after inserting %d elements to %q:\nGet(%v) = %v, want %v", i, name, num, formatValue(got), formatValue(want))
+		if got, want := m.Get(fd), pref.ValueOf(want); !valueEqual(got, want) {
+			t.Errorf("after inserting %d elements to %q:\nMessage.Get(%v) = %v, want %v", i, name, num, formatValue(got), formatValue(want))
 		}
 	}
 
@@ -166,8 +143,8 @@ func testFieldMap(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
 		nv := newMapValue(fd, mapv, 10, nil)
 		mapv.Set(k, nv)
 		want.Set(k, nv)
-		if got, want := m.KnownFields().Get(num), pref.ValueOf(want); !valueEqual(got, want) {
-			t.Errorf("after setting element %v of %q:\nGet(%v) = %v, want %v", formatValue(k.Value()), name, num, formatValue(got), formatValue(want))
+		if got, want := m.Get(fd), pref.ValueOf(want); !valueEqual(got, want) {
+			t.Errorf("after setting element %v of %q:\nMessage.Get(%v) = %v, want %v", formatValue(k.Value()), name, num, formatValue(got), formatValue(want))
 		}
 		return true
 	})
@@ -176,11 +153,11 @@ func testFieldMap(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
 	want.Range(func(k pref.MapKey, v pref.Value) bool {
 		mapv.Clear(k)
 		want.Clear(k)
-		if got, want := known.Has(num), want.Len() > 0; got != want {
-			t.Errorf("after clearing elements of %q:\nHas(%v) = %v, want %v", name, num, got, want)
+		if got, want := m.Has(fd), want.Len() > 0; got != want {
+			t.Errorf("after clearing elements of %q:\nMessage.Has(%v) = %v, want %v", name, num, got, want)
 		}
-		if got, want := m.KnownFields().Get(num), pref.ValueOf(want); !valueEqual(got, want) {
-			t.Errorf("after clearing elements of %q:\nGet(%v) = %v, want %v", name, num, formatValue(got), formatValue(want))
+		if got, want := m.Get(fd), pref.ValueOf(want); !valueEqual(got, want) {
+			t.Errorf("after clearing elements of %q:\nMessage.Get(%v) = %v, want %v", name, num, formatValue(got), formatValue(want))
 		}
 		return true
 	})
@@ -188,10 +165,10 @@ func testFieldMap(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
 	// Non-existent map keys.
 	missingKey := newMapKey(fd, 1)
 	if got, want := mapv.Has(missingKey), false; got != want {
-		t.Errorf("non-existent map key in %q: Has(%v) = %v, want %v", name, formatValue(missingKey.Value()), got, want)
+		t.Errorf("non-existent map key in %q: Map.Has(%v) = %v, want %v", name, formatValue(missingKey.Value()), got, want)
 	}
 	if got, want := mapv.Get(missingKey).IsValid(), false; got != want {
-		t.Errorf("non-existent map key in %q: Get(%v).IsValid() = %v, want %v", name, formatValue(missingKey.Value()), got, want)
+		t.Errorf("non-existent map key in %q: Map.Get(%v).IsValid() = %v, want %v", name, formatValue(missingKey.Value()), got, want)
 	}
 	mapv.Clear(missingKey) // noop
 }
@@ -214,24 +191,24 @@ func (m testMap) Range(f func(pref.MapKey, pref.Value) bool) {
 
 // testFieldList exercises set/get/append/truncate of values in a list.
 func testFieldList(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
-	num := fd.Number()
 	name := fd.FullName()
-	known := m.KnownFields()
-	known.Clear(num) // start with an empty list
-	list := known.Get(num).List()
+	num := fd.Number()
+
+	m.Clear(fd) // start with an empty list
+	list := m.Mutable(fd).List()
 
 	// Append values.
 	var want pref.List = &testList{}
 	for i, n := range []seed{1, 0, minVal, maxVal} {
-		if got, want := known.Has(num), i > 0; got != want {
-			t.Errorf("after appending %d elements to %q:\nHas(%v) = %v, want %v", i, name, num, got, want)
+		if got, want := m.Has(fd), i > 0; got != want {
+			t.Errorf("after appending %d elements to %q:\nMessage.Has(%v) = %v, want %v", i, name, num, got, want)
 		}
 		v := newListElement(fd, list, n, nil)
 		want.Append(v)
 		list.Append(v)
 
-		if got, want := m.KnownFields().Get(num), pref.ValueOf(want); !valueEqual(got, want) {
-			t.Errorf("after appending %d elements to %q:\nGet(%v) = %v, want %v", i+1, name, num, formatValue(got), formatValue(want))
+		if got, want := m.Get(fd), pref.ValueOf(want); !valueEqual(got, want) {
+			t.Errorf("after appending %d elements to %q:\nMessage.Get(%v) = %v, want %v", i+1, name, num, formatValue(got), formatValue(want))
 		}
 	}
 
@@ -240,8 +217,8 @@ func testFieldList(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
 		v := newListElement(fd, list, seed(i+10), nil)
 		want.Set(i, v)
 		list.Set(i, v)
-		if got, want := m.KnownFields().Get(num), pref.ValueOf(want); !valueEqual(got, want) {
-			t.Errorf("after setting element %d of %q:\nGet(%v) = %v, want %v", i, name, num, formatValue(got), formatValue(want))
+		if got, want := m.Get(fd), pref.ValueOf(want); !valueEqual(got, want) {
+			t.Errorf("after setting element %d of %q:\nMessage.Get(%v) = %v, want %v", i, name, num, formatValue(got), formatValue(want))
 		}
 	}
 
@@ -250,11 +227,11 @@ func testFieldList(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
 		n := want.Len() - 1
 		want.Truncate(n)
 		list.Truncate(n)
-		if got, want := known.Has(num), want.Len() > 0; got != want {
-			t.Errorf("after truncating %q to %d:\nHas(%v) = %v, want %v", name, n, num, got, want)
+		if got, want := m.Has(fd), want.Len() > 0; got != want {
+			t.Errorf("after truncating %q to %d:\nMessage.Has(%v) = %v, want %v", name, n, num, got, want)
 		}
-		if got, want := m.KnownFields().Get(num), pref.ValueOf(want); !valueEqual(got, want) {
-			t.Errorf("after truncating %q to %d:\nGet(%v) = %v, want %v", name, n, num, formatValue(got), formatValue(want))
+		if got, want := m.Get(fd), pref.ValueOf(want); !valueEqual(got, want) {
+			t.Errorf("after truncating %q to %d:\nMessage.Get(%v) = %v, want %v", name, n, num, formatValue(got), formatValue(want))
 		}
 	}
 }
@@ -272,9 +249,9 @@ func (l *testList) NewMessage() pref.Message { panic("unimplemented") }
 
 // testFieldFloat exercises some interesting floating-point scalar field values.
 func testFieldFloat(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
-	num := fd.Number()
 	name := fd.FullName()
-	known := m.KnownFields()
+	num := fd.Number()
+
 	for _, v := range []float64{math.Inf(-1), math.Inf(1), math.NaN(), math.Copysign(0, -1)} {
 		var val pref.Value
 		if fd.Kind() == pref.FloatKind {
@@ -282,29 +259,28 @@ func testFieldFloat(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
 		} else {
 			val = pref.ValueOf(v)
 		}
-		known.Set(num, val)
+		m.Set(fd, val)
 		// Note that Has is true for -0.
-		if got, want := known.Has(num), true; got != want {
-			t.Errorf("after setting %v to %v: Get(%v) = %v, want %v", name, v, num, got, want)
+		if got, want := m.Has(fd), true; got != want {
+			t.Errorf("after setting %v to %v: Message.Has(%v) = %v, want %v", name, v, num, got, want)
 		}
-		if got, want := known.Get(num), val; !valueEqual(got, want) {
-			t.Errorf("after setting %v: Get(%v) = %v, want %v", name, num, formatValue(got), formatValue(want))
+		if got, want := m.Get(fd), val; !valueEqual(got, want) {
+			t.Errorf("after setting %v: Message.Get(%v) = %v, want %v", name, num, formatValue(got), formatValue(want))
 		}
 	}
 }
 
 // testOneof tests the behavior of fields in a oneof.
 func testOneof(t testing.TB, m pref.Message, od pref.OneofDescriptor) {
-	known := m.KnownFields()
 	for i := 0; i < od.Fields().Len(); i++ {
 		fda := od.Fields().Get(i)
-		known.Set(fda.Number(), newValue(m, fda, 1, nil))
-		if got, want := known.WhichOneof(od.Name()), fda.Number(); got != want {
+		m.Set(fda, newValue(m, fda, 1, nil))
+		if got, want := m.WhichOneof(od), fda; got != want {
 			t.Errorf("after setting oneof field %q:\nWhichOneof(%q) = %v, want %v", fda.FullName(), fda.Name(), got, want)
 		}
 		for j := 0; j < od.Fields().Len(); j++ {
 			fdb := od.Fields().Get(j)
-			if got, want := known.Has(fdb.Number()), i == j; got != want {
+			if got, want := m.Has(fdb), i == j; got != want {
 				t.Errorf("after setting oneof field %q:\nGet(%q) = %v, want %v", fda.FullName(), fdb.FullName(), got, want)
 			}
 		}
@@ -422,10 +398,9 @@ const (
 // The stack parameter is used to avoid infinite recursion when populating circular
 // data structures.
 func newValue(m pref.Message, fd pref.FieldDescriptor, n seed, stack []pref.MessageDescriptor) pref.Value {
-	num := fd.Number()
 	switch {
 	case fd.IsList():
-		list := m.New().KnownFields().Get(num).List()
+		list := m.New().Mutable(fd).List()
 		if n == 0 {
 			return pref.ValueOf(list)
 		}
@@ -435,7 +410,7 @@ func newValue(m pref.Message, fd pref.FieldDescriptor, n seed, stack []pref.Mess
 		list.Append(newListElement(fd, list, n, stack))
 		return pref.ValueOf(list)
 	case fd.IsMap():
-		mapv := m.New().KnownFields().Get(num).Map()
+		mapv := m.New().Mutable(fd).Map()
 		if n == 0 {
 			return pref.ValueOf(mapv)
 		}
@@ -445,7 +420,7 @@ func newValue(m pref.Message, fd pref.FieldDescriptor, n seed, stack []pref.Mess
 		mapv.Set(newMapKey(fd, n), newMapValue(fd, mapv, 10*n, stack))
 		return pref.ValueOf(mapv)
 	case fd.Message() != nil:
-		return populateMessage(m.KnownFields().NewMessage(num), n, stack)
+		return populateMessage(m.Mutable(fd).Message(), n, stack)
 	default:
 		return newScalarValue(fd, n)
 	}
@@ -476,7 +451,7 @@ func newScalarValue(fd pref.FieldDescriptor, n seed) pref.Value {
 	case pref.BoolKind:
 		return pref.ValueOf(n != 0)
 	case pref.EnumKind:
-		// TODO use actual value
+		// TODO: use actual value
 		return pref.ValueOf(pref.EnumNumber(n))
 	case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind:
 		switch n {
@@ -559,13 +534,12 @@ func populateMessage(m pref.Message, n seed, stack []pref.MessageDescriptor) pre
 		}
 	}
 	stack = append(stack, md)
-	known := m.KnownFields()
 	for i := 0; i < md.Fields().Len(); i++ {
 		fd := md.Fields().Get(i)
 		if fd.IsWeak() {
 			continue
 		}
-		known.Set(fd.Number(), newValue(m, fd, 10*n+seed(i), stack))
+		m.Set(fd, newValue(m, fd, 10*n+seed(i), stack))
 	}
 	return pref.ValueOf(m)
 }