Bläddra i källkod

all: make handling of zero-value composites more consistent

We occasionally need to work with immutable, empty lists, maps, and
messages. Notably, Message.Get on an empty repeated field will return a
"frozen" empty value.

Move handling of these immutable, zero-length composites into Converter,
to unify the behavior of regular and extension fields.

Add a Zero method to Converter, MessageType, and ExtensionType, to
provide a consistent way to get an empty, frozen value of a composite
type. Adding this method to the public {Message,Extension}Type
interfaces does increase our API surface, but lets us (for example)
cleanly represent an empty map as a nil map rather than a non-nil
one wrapped in a frozenMap type.

Drop the frozen{List,Map,Message} types as no longer necessary.
(These types did have support for creating a read-only view of a
non-empty value, but we are not currently using that feature.)

Change-Id: Ia76f149d591da07b40ce75b7404a7ab8a60cb9d8
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/189339
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
Damien Neil 6 år sedan
förälder
incheckning
d4f0800c42

+ 2 - 1
internal/filetype/build.go

@@ -346,7 +346,8 @@ type Extension struct {
 	conv   pimpl.Converter
 	conv   pimpl.Converter
 }
 }
 
 
-func (t *Extension) New() pref.Value { return t.lazyInit().New() }
+func (t *Extension) New() pref.Value  { return t.lazyInit().New() }
+func (t *Extension) Zero() pref.Value { return t.lazyInit().Zero() }
 func (t *Extension) ValueOf(v interface{}) pref.Value {
 func (t *Extension) ValueOf(v interface{}) pref.Value {
 	return t.lazyInit().PBValueOf(reflect.ValueOf(v))
 	return t.lazyInit().PBValueOf(reflect.ValueOf(v))
 }
 }

+ 24 - 0
internal/impl/convert.go

@@ -19,9 +19,21 @@ type Unwrapper interface {
 
 
 // A Converter coverts to/from Go reflect.Value types and protobuf protoreflect.Value types.
 // A Converter coverts to/from Go reflect.Value types and protobuf protoreflect.Value types.
 type Converter interface {
 type Converter interface {
+	// PBValueOf converts a reflect.Value to a protoreflect.Value.
 	PBValueOf(reflect.Value) pref.Value
 	PBValueOf(reflect.Value) pref.Value
+
+	// GoValueOf converts a protoreflect.Value to a reflect.Value.
 	GoValueOf(pref.Value) reflect.Value
 	GoValueOf(pref.Value) reflect.Value
+
+	// New returns a new field value.
+	// For scalars, it returns the default value of the field.
+	// For composite types, it returns a new mutable value.
 	New() pref.Value
 	New() pref.Value
+
+	// Zero returns a new field value.
+	// For scalars, it returns the default value of the field.
+	// For composite types, it returns an immutable, empty value.
+	Zero() pref.Value
 }
 }
 
 
 // NewConverter matches a Go type with a protobuf field and returns a Converter
 // NewConverter matches a Go type with a protobuf field and returns a Converter
@@ -158,6 +170,10 @@ func (c *scalarConverter) New() pref.Value {
 	return c.def
 	return c.def
 }
 }
 
 
+func (c *scalarConverter) Zero() pref.Value {
+	return c.New()
+}
+
 type enumConverter struct {
 type enumConverter struct {
 	goType reflect.Type
 	goType reflect.Type
 	def    pref.Value
 	def    pref.Value
@@ -188,6 +204,10 @@ func (c *enumConverter) New() pref.Value {
 	return c.def
 	return c.def
 }
 }
 
 
+func (c *enumConverter) Zero() pref.Value {
+	return c.def
+}
+
 type messageConverter struct {
 type messageConverter struct {
 	goType reflect.Type
 	goType reflect.Type
 }
 }
@@ -223,3 +243,7 @@ func (c *messageConverter) GoValueOf(v pref.Value) reflect.Value {
 func (c *messageConverter) New() pref.Value {
 func (c *messageConverter) New() pref.Value {
 	return c.PBValueOf(reflect.New(c.goType.Elem()))
 	return c.PBValueOf(reflect.New(c.goType.Elem()))
 }
 }
+
+func (c *messageConverter) Zero() pref.Value {
+	return c.PBValueOf(reflect.Zero(c.goType))
+}

+ 4 - 0
internal/impl/convert_list.go

@@ -38,6 +38,10 @@ func (c *listConverter) New() pref.Value {
 	return c.PBValueOf(reflect.New(c.goType.Elem()))
 	return c.PBValueOf(reflect.New(c.goType.Elem()))
 }
 }
 
 
+func (c *listConverter) Zero() pref.Value {
+	return c.PBValueOf(reflect.Zero(c.goType))
+}
+
 type listReflect struct {
 type listReflect struct {
 	v    reflect.Value // *[]T
 	v    reflect.Value // *[]T
 	conv Converter
 	conv Converter

+ 4 - 0
internal/impl/convert_map.go

@@ -42,6 +42,10 @@ func (c *mapConverter) New() pref.Value {
 	return c.PBValueOf(reflect.MakeMap(c.goType))
 	return c.PBValueOf(reflect.MakeMap(c.goType))
 }
 }
 
 
+func (c *mapConverter) Zero() pref.Value {
+	return c.PBValueOf(reflect.Zero(c.goType))
+}
+
 type mapReflect struct {
 type mapReflect struct {
 	v       reflect.Value // map[K]V
 	v       reflect.Value // map[K]V
 	keyConv Converter
 	keyConv Converter

+ 1 - 0
internal/impl/legacy_extension.go

@@ -232,6 +232,7 @@ type legacyExtensionType struct {
 
 
 func (x *legacyExtensionType) GoType() reflect.Type { return x.typ }
 func (x *legacyExtensionType) GoType() reflect.Type { return x.typ }
 func (x *legacyExtensionType) New() pref.Value      { return x.conv.New() }
 func (x *legacyExtensionType) New() pref.Value      { return x.conv.New() }
+func (x *legacyExtensionType) Zero() pref.Value     { return x.conv.Zero() }
 func (x *legacyExtensionType) ValueOf(v interface{}) pref.Value {
 func (x *legacyExtensionType) ValueOf(v interface{}) pref.Value {
 	return x.conv.PBValueOf(reflect.ValueOf(v))
 	return x.conv.PBValueOf(reflect.ValueOf(v))
 }
 }

+ 2 - 2
internal/impl/legacy_test.go

@@ -491,8 +491,8 @@ func TestExtensionConvert(t *testing.T) {
 							switch name {
 							switch name {
 							case "ParentFile", "Parent":
 							case "ParentFile", "Parent":
 							// Ignore parents to avoid recursive cycle.
 							// Ignore parents to avoid recursive cycle.
-							case "New":
-								// Ignore New since it a constructor.
+							case "New", "Zero":
+								// Ignore constructors.
 							case "Options":
 							case "Options":
 								// Ignore descriptor options since protos are not cmperable.
 								// Ignore descriptor options since protos are not cmperable.
 							case "ContainingOneof", "ContainingMessage", "Enum", "Message":
 							case "ContainingOneof", "ContainingMessage", "Enum", "Message":

+ 9 - 106
internal/impl/message_field.go

@@ -42,10 +42,6 @@ func fieldInfoForOneof(fd pref.FieldDescriptor, fs reflect.StructField, x export
 	}
 	}
 	conv := NewConverter(ot.Field(0).Type, fd)
 	conv := NewConverter(ot.Field(0).Type, fd)
 	isMessage := fd.Message() != nil
 	isMessage := fd.Message() != nil
-	var frozenEmpty pref.Value
-	if isMessage {
-		frozenEmpty = pref.ValueOf(frozenMessage{conv.New().Message()})
-	}
 
 
 	// TODO: Implement unsafe fast path?
 	// TODO: Implement unsafe fast path?
 	fieldOffset := offsetOf(fs, x)
 	fieldOffset := offsetOf(fs, x)
@@ -74,17 +70,11 @@ func fieldInfoForOneof(fd pref.FieldDescriptor, fs reflect.StructField, x export
 		},
 		},
 		get: func(p pointer) pref.Value {
 		get: func(p pointer) pref.Value {
 			if p.IsNil() {
 			if p.IsNil() {
-				if frozenEmpty.IsValid() {
-					return frozenEmpty
-				}
-				return defaultValueOf(fd)
+				return conv.Zero()
 			}
 			}
 			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
 			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
 			if rv.IsNil() || rv.Elem().Type().Elem() != ot {
 			if rv.IsNil() || rv.Elem().Type().Elem() != ot {
-				if frozenEmpty.IsValid() {
-					return frozenEmpty
-				}
-				return defaultValueOf(fd)
+				return conv.Zero()
 			}
 			}
 			rv = rv.Elem().Elem().Field(0)
 			rv = rv.Elem().Elem().Field(0)
 			return conv.PBValueOf(rv)
 			return conv.PBValueOf(rv)
@@ -126,7 +116,6 @@ func fieldInfoForMap(fd pref.FieldDescriptor, fs reflect.StructField, x exporter
 		panic(fmt.Sprintf("invalid type: got %v, want map kind", ft))
 		panic(fmt.Sprintf("invalid type: got %v, want map kind", ft))
 	}
 	}
 	conv := NewConverter(ft, fd)
 	conv := NewConverter(ft, fd)
-	frozenEmpty := pref.ValueOf(frozenMap{conv.New().Map()})
 
 
 	// TODO: Implement unsafe fast path?
 	// TODO: Implement unsafe fast path?
 	fieldOffset := offsetOf(fs, x)
 	fieldOffset := offsetOf(fs, x)
@@ -145,12 +134,9 @@ func fieldInfoForMap(fd pref.FieldDescriptor, fs reflect.StructField, x exporter
 		},
 		},
 		get: func(p pointer) pref.Value {
 		get: func(p pointer) pref.Value {
 			if p.IsNil() {
 			if p.IsNil() {
-				return frozenEmpty
+				return conv.Zero()
 			}
 			}
 			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
 			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
-			if rv.Len() == 0 {
-				return frozenEmpty
-			}
 			return conv.PBValueOf(rv)
 			return conv.PBValueOf(rv)
 		},
 		},
 		set: func(p pointer, v pref.Value) {
 		set: func(p pointer, v pref.Value) {
@@ -176,7 +162,6 @@ func fieldInfoForList(fd pref.FieldDescriptor, fs reflect.StructField, x exporte
 		panic(fmt.Sprintf("invalid type: got %v, want slice kind", ft))
 		panic(fmt.Sprintf("invalid type: got %v, want slice kind", ft))
 	}
 	}
 	conv := NewConverter(reflect.PtrTo(ft), fd)
 	conv := NewConverter(reflect.PtrTo(ft), fd)
-	frozenEmpty := pref.ValueOf(frozenList{conv.New().List()})
 
 
 	// TODO: Implement unsafe fast path?
 	// TODO: Implement unsafe fast path?
 	fieldOffset := offsetOf(fs, x)
 	fieldOffset := offsetOf(fs, x)
@@ -195,12 +180,9 @@ func fieldInfoForList(fd pref.FieldDescriptor, fs reflect.StructField, x exporte
 		},
 		},
 		get: func(p pointer) pref.Value {
 		get: func(p pointer) pref.Value {
 			if p.IsNil() {
 			if p.IsNil() {
-				return frozenEmpty
+				return conv.Zero()
 			}
 			}
 			rv := p.Apply(fieldOffset).AsValueOf(fs.Type)
 			rv := p.Apply(fieldOffset).AsValueOf(fs.Type)
-			if rv.Elem().Len() == 0 {
-				return frozenEmpty
-			}
 			return conv.PBValueOf(rv)
 			return conv.PBValueOf(rv)
 		},
 		},
 		set: func(p pointer, v pref.Value) {
 		set: func(p pointer, v pref.Value) {
@@ -269,12 +251,12 @@ func fieldInfoForScalar(fd pref.FieldDescriptor, fs reflect.StructField, x expor
 		},
 		},
 		get: func(p pointer) pref.Value {
 		get: func(p pointer) pref.Value {
 			if p.IsNil() {
 			if p.IsNil() {
-				return defaultValueOf(fd)
+				return conv.Zero()
 			}
 			}
 			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
 			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
 			if nullable {
 			if nullable {
 				if rv.IsNil() {
 				if rv.IsNil() {
-					return defaultValueOf(fd)
+					return conv.Zero()
 				}
 				}
 				if rv.Kind() == reflect.Ptr {
 				if rv.Kind() == reflect.Ptr {
 					rv = rv.Elem()
 					rv = rv.Elem()
@@ -312,7 +294,6 @@ func fieldInfoForWeakMessage(fd pref.FieldDescriptor, weakOffset offset) fieldIn
 
 
 	var once sync.Once
 	var once sync.Once
 	var messageType pref.MessageType
 	var messageType pref.MessageType
-	var frozenEmpty pref.Value
 	lazyInit := func() {
 	lazyInit := func() {
 		once.Do(func() {
 		once.Do(func() {
 			messageName := fd.Message().FullName()
 			messageName := fd.Message().FullName()
@@ -320,7 +301,6 @@ func fieldInfoForWeakMessage(fd pref.FieldDescriptor, weakOffset offset) fieldIn
 			if messageType == nil {
 			if messageType == nil {
 				panic(fmt.Sprintf("weak message %v is not linked in", messageName))
 				panic(fmt.Sprintf("weak message %v is not linked in", messageName))
 			}
 			}
-			frozenEmpty = pref.ValueOf(frozenMessage{messageType.New()})
 		})
 		})
 	}
 	}
 
 
@@ -342,12 +322,12 @@ func fieldInfoForWeakMessage(fd pref.FieldDescriptor, weakOffset offset) fieldIn
 		get: func(p pointer) pref.Value {
 		get: func(p pointer) pref.Value {
 			lazyInit()
 			lazyInit()
 			if p.IsNil() {
 			if p.IsNil() {
-				return frozenEmpty
+				return pref.ValueOf(messageType.Zero())
 			}
 			}
 			fs := p.Apply(weakOffset).WeakFields()
 			fs := p.Apply(weakOffset).WeakFields()
 			m, ok := (*fs)[num]
 			m, ok := (*fs)[num]
 			if !ok {
 			if !ok {
-				return frozenEmpty
+				return pref.ValueOf(messageType.Zero())
 			}
 			}
 			return pref.ValueOf(m.(pref.ProtoMessage).ProtoReflect())
 			return pref.ValueOf(m.(pref.ProtoMessage).ProtoReflect())
 		},
 		},
@@ -390,7 +370,6 @@ func fieldInfoForWeakMessage(fd pref.FieldDescriptor, weakOffset offset) fieldIn
 func fieldInfoForMessage(fd pref.FieldDescriptor, fs reflect.StructField, x exporter) fieldInfo {
 func fieldInfoForMessage(fd pref.FieldDescriptor, fs reflect.StructField, x exporter) fieldInfo {
 	ft := fs.Type
 	ft := fs.Type
 	conv := NewConverter(ft, fd)
 	conv := NewConverter(ft, fd)
-	frozenEmpty := pref.ValueOf(frozenMessage{conv.New().Message()})
 
 
 	// TODO: Implement unsafe fast path?
 	// TODO: Implement unsafe fast path?
 	fieldOffset := offsetOf(fs, x)
 	fieldOffset := offsetOf(fs, x)
@@ -409,12 +388,9 @@ func fieldInfoForMessage(fd pref.FieldDescriptor, fs reflect.StructField, x expo
 		},
 		},
 		get: func(p pointer) pref.Value {
 		get: func(p pointer) pref.Value {
 			if p.IsNil() {
 			if p.IsNil() {
-				return frozenEmpty
+				return conv.Zero()
 			}
 			}
 			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
 			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
-			if rv.IsNil() {
-				return frozenEmpty
-			}
 			return conv.PBValueOf(rv)
 			return conv.PBValueOf(rv)
 		},
 		},
 		set: func(p pointer, v pref.Value) {
 		set: func(p pointer, v pref.Value) {
@@ -461,76 +437,3 @@ func makeOneofInfo(od pref.OneofDescriptor, fs reflect.StructField, x exporter,
 		},
 		},
 	}
 	}
 }
 }
-
-// defaultValueOf returns the default value for the field.
-func defaultValueOf(fd pref.FieldDescriptor) pref.Value {
-	if fd == nil {
-		return pref.Value{}
-	}
-	pv := fd.Default() // invalid Value for messages and repeated fields
-	if fd.Kind() == pref.BytesKind && pv.IsValid() && len(pv.Bytes()) > 0 {
-		return pref.ValueOf(append([]byte(nil), pv.Bytes()...)) // copy default bytes for safety
-	}
-	return pv
-}
-
-// frozenValueOf returns a frozen version of any composite value.
-func frozenValueOf(v pref.Value) pref.Value {
-	switch v := v.Interface().(type) {
-	case pref.Message:
-		if _, ok := v.(frozenMessage); !ok {
-			return pref.ValueOf(frozenMessage{v})
-		}
-	case pref.List:
-		if _, ok := v.(frozenList); !ok {
-			return pref.ValueOf(frozenList{v})
-		}
-	case pref.Map:
-		if _, ok := v.(frozenMap); !ok {
-			return pref.ValueOf(frozenMap{v})
-		}
-	}
-	return v
-}
-
-type frozenMessage struct{ pref.Message }
-
-func (m frozenMessage) ProtoReflect() pref.Message   { return m }
-func (m frozenMessage) Interface() pref.ProtoMessage { return m }
-func (m frozenMessage) Range(f func(pref.FieldDescriptor, pref.Value) bool) {
-	m.Message.Range(func(fd pref.FieldDescriptor, v pref.Value) bool {
-		return f(fd, frozenValueOf(v))
-	})
-}
-func (m frozenMessage) Get(fd pref.FieldDescriptor) pref.Value {
-	v := m.Message.Get(fd)
-	return frozenValueOf(v)
-}
-func (frozenMessage) Clear(pref.FieldDescriptor)              { panic("invalid on read-only Message") }
-func (frozenMessage) Set(pref.FieldDescriptor, pref.Value)    { panic("invalid on read-only Message") }
-func (frozenMessage) Mutable(pref.FieldDescriptor) pref.Value { panic("invalid on read-only Message") }
-func (frozenMessage) SetUnknown(pref.RawFields)               { panic("invalid on read-only Message") }
-
-type frozenList struct{ pref.List }
-
-func (ls frozenList) Get(i int) pref.Value {
-	v := ls.List.Get(i)
-	return frozenValueOf(v)
-}
-func (frozenList) Set(i int, v pref.Value) { panic("invalid on read-only List") }
-func (frozenList) Append(v pref.Value)     { panic("invalid on read-only List") }
-func (frozenList) Truncate(i int)          { panic("invalid on read-only List") }
-
-type frozenMap struct{ pref.Map }
-
-func (ms frozenMap) Get(k pref.MapKey) pref.Value {
-	v := ms.Map.Get(k)
-	return frozenValueOf(v)
-}
-func (ms frozenMap) Range(f func(pref.MapKey, pref.Value) bool) {
-	ms.Map.Range(func(k pref.MapKey, v pref.Value) bool {
-		return f(k, frozenValueOf(v))
-	})
-}
-func (frozenMap) Set(k pref.MapKey, v pref.Value) { panic("invalid n read-only Map") }
-func (frozenMap) Clear(k pref.MapKey)             { panic("invalid on read-only Map") }

+ 1 - 4
internal/impl/message_reflect.go

@@ -142,10 +142,7 @@ func (m *extensionMap) Get(xt pref.ExtensionType) pref.Value {
 			return xt.ValueOf(x.GetValue())
 			return xt.ValueOf(x.GetValue())
 		}
 		}
 	}
 	}
-	if !isComposite(xt) {
-		return defaultValueOf(xt)
-	}
-	return frozenValueOf(xt.New())
+	return xt.Zero()
 }
 }
 func (m *extensionMap) Set(xt pref.ExtensionType, v pref.Value) {
 func (m *extensionMap) Set(xt pref.ExtensionType, v pref.Value) {
 	if *m == nil {
 	if *m == nil {

+ 8 - 0
reflect/protoreflect/type.go

@@ -234,6 +234,9 @@ type MessageType interface {
 	// New returns a newly allocated empty message.
 	// New returns a newly allocated empty message.
 	New() Message
 	New() Message
 
 
+	// Zero returns an immutable empty message.
+	Zero() Message
+
 	// GoType returns the Go type of the allocated message.
 	// GoType returns the Go type of the allocated message.
 	//
 	//
 	// Invariant: t.GoType() == reflect.TypeOf(t.New().Interface())
 	// Invariant: t.GoType() == reflect.TypeOf(t.New().Interface())
@@ -439,6 +442,11 @@ type ExtensionType interface {
 	// For scalars, this returns the default value in native Go form.
 	// For scalars, this returns the default value in native Go form.
 	New() Value
 	New() Value
 
 
+	// Zero returns a new value for the field.
+	// For scalars, this returns the default value in native Go form.
+	// For composite types, this returns an empty, immutable message, list, or map.
+	Zero() Value
+
 	// GoType returns the Go type of the field value.
 	// GoType returns the Go type of the field value.
 	//
 	//
 	// Invariants:
 	// Invariants:

+ 4 - 0
reflect/prototype/type.go

@@ -95,6 +95,10 @@ func (t *Message) New() protoreflect.Message {
 	return m
 	return m
 }
 }
 
 
+func (t *Message) Zero() protoreflect.Message {
+	return t.New() // TODO: return a read-only message instead
+}
+
 func (t *Message) GoType() reflect.Type {
 func (t *Message) GoType() reflect.Type {
 	t.New() // initialize t.goType
 	t.New() // initialize t.goType
 	return t.goType
 	return t.goType