message_reflect.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. // Copyright 2019 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package impl
  5. import (
  6. "fmt"
  7. "reflect"
  8. "google.golang.org/protobuf/internal/pragma"
  9. pref "google.golang.org/protobuf/reflect/protoreflect"
  10. )
// MessageState is a data structure that is nested as the first field in a
// concrete message. It provides a way to implement the ProtoReflect method
// in an allocation-free way without needing to have a shadow Go type generated
// for every message type. This technique only works using unsafe.
//
// Example generated code:
//
//	type M struct {
//		state protoimpl.MessageState
//
//		Field1 int32
//		Field2 string
//		Field3 *BarMessage
//		...
//	}
//
//	func (m *M) ProtoReflect() protoreflect.Message {
//		mi := &file_fizz_buzz_proto_msgInfos[5]
//		if protoimpl.UnsafeEnabled && m != nil {
//			ms := protoimpl.X.MessageStateOf(Pointer(m))
//			if ms.LoadMessageInfo() == nil {
//				ms.StoreMessageInfo(mi)
//			}
//			return ms
//		}
//		return mi.MessageOf(m)
//	}
//
// The MessageState type holds a *MessageInfo, which must be atomically set to
// the message info associated with a given message instance.
// By unsafely converting a *M into a *MessageState, the MessageState object
// has access to all the information needed to implement protobuf reflection.
// It has access to the message info as its first field, and a pointer to the
// MessageState is identical to a pointer to the concrete message value.
//
// Requirements:
//   - The type M must implement protoreflect.ProtoMessage.
//   - The address of m must not be nil.
//   - The address of m and the address of m.state must be equal,
//     even though they are different Go types.
type MessageState struct {
	// The embedded pragma markers restrict how MessageState may be used
	// (unkeyed literals, comparison, copying), as their names indicate.
	// NOTE(review): because a *M is unsafely reinterpreted as a
	// *MessageState, do not reorder or add fields here without checking the
	// layout assumptions described above.
	pragma.NoUnkeyedLiterals
	pragma.DoNotCompare
	pragma.DoNotCopy

	// mi is the message info for the concrete message this state is embedded
	// in; as stated above, it must be set atomically.
	mi *MessageInfo
}
  59. type messageState MessageState
  60. var (
  61. _ pref.Message = (*messageState)(nil)
  62. _ Unwrapper = (*messageState)(nil)
  63. )
// messageDataType is a tuple of a pointer to the message data and
// a pointer to the message type. It is a generalized way of providing a
// reflective view over a message instance. The disadvantage of this approach
// is the need to allocate this tuple of 16B.
// NOTE(review): the 16B figure assumes `pointer` is a single machine word on a
// 64-bit platform — confirm against its definition.
//
// Field order matters: MessageOf builds a messageReflectWrapper (which has this
// struct as its underlying type) with the positional literal {p, mi}.
type messageDataType struct {
	p  pointer      // location of the message data
	mi *MessageInfo // type information used to interpret p
}
  72. type (
  73. messageReflectWrapper messageDataType
  74. messageIfaceWrapper messageDataType
  75. )
  76. var (
  77. _ pref.Message = (*messageReflectWrapper)(nil)
  78. _ Unwrapper = (*messageReflectWrapper)(nil)
  79. _ pref.ProtoMessage = (*messageIfaceWrapper)(nil)
  80. _ Unwrapper = (*messageIfaceWrapper)(nil)
  81. )
  82. // MessageOf returns a reflective view over a message. The input must be a
  83. // pointer to a named Go struct. If the provided type has a ProtoReflect method,
  84. // it must be implemented by calling this method.
  85. func (mi *MessageInfo) MessageOf(m interface{}) pref.Message {
  86. // TODO: Switch the input to be an opaque Pointer.
  87. if reflect.TypeOf(m) != mi.GoType {
  88. panic(fmt.Sprintf("type mismatch: got %T, want %v", m, mi.GoType))
  89. }
  90. p := pointerOfIface(m)
  91. if p.IsNil() {
  92. return mi.nilMessage.Init(mi)
  93. }
  94. return &messageReflectWrapper{p, mi}
  95. }
  96. func (m *messageReflectWrapper) pointer() pointer { return m.p }
  97. func (m *messageReflectWrapper) messageInfo() *MessageInfo { return m.mi }
  98. func (m *messageIfaceWrapper) ProtoReflect() pref.Message {
  99. return (*messageReflectWrapper)(m)
  100. }
  101. func (m *messageIfaceWrapper) ProtoUnwrap() interface{} {
  102. return m.p.AsIfaceOf(m.mi.GoType.Elem())
  103. }
// extensionMap stores extension field values keyed by the extension's field
// number (its descriptor's Number(), converted to int32).
type extensionMap map[int32]ExtensionField
  105. func (m *extensionMap) Range(f func(pref.FieldDescriptor, pref.Value) bool) {
  106. if m != nil {
  107. for _, x := range *m {
  108. xt := x.GetType()
  109. if !f(xt.Descriptor(), xt.ValueOf(x.GetValue())) {
  110. return
  111. }
  112. }
  113. }
  114. }
  115. func (m *extensionMap) Has(xt pref.ExtensionType) (ok bool) {
  116. if m != nil {
  117. _, ok = (*m)[int32(xt.Descriptor().Number())]
  118. }
  119. return ok
  120. }
  121. func (m *extensionMap) Clear(xt pref.ExtensionType) {
  122. delete(*m, int32(xt.Descriptor().Number()))
  123. }
  124. func (m *extensionMap) Get(xt pref.ExtensionType) pref.Value {
  125. xd := xt.Descriptor()
  126. if m != nil {
  127. if x, ok := (*m)[int32(xd.Number())]; ok {
  128. return xt.ValueOf(x.GetValue())
  129. }
  130. }
  131. return xt.Zero()
  132. }
  133. func (m *extensionMap) Set(xt pref.ExtensionType, v pref.Value) {
  134. if *m == nil {
  135. *m = make(map[int32]ExtensionField)
  136. }
  137. var x ExtensionField
  138. x.SetType(xt)
  139. x.SetEagerValue(xt.InterfaceOf(v))
  140. (*m)[int32(xt.Descriptor().Number())] = x
  141. }
  142. func (m *extensionMap) Mutable(xt pref.ExtensionType) pref.Value {
  143. xd := xt.Descriptor()
  144. if !isComposite(xd) {
  145. panic("invalid Mutable on field with non-composite type")
  146. }
  147. if x, ok := (*m)[int32(xd.Number())]; ok {
  148. return xt.ValueOf(x.GetValue())
  149. }
  150. v := xt.New()
  151. m.Set(xt, v)
  152. return v
  153. }
  154. func isComposite(fd pref.FieldDescriptor) bool {
  155. return fd.Kind() == pref.MessageKind || fd.Kind() == pref.GroupKind || fd.IsList() || fd.IsMap()
  156. }
  157. // checkField verifies that the provided field descriptor is valid.
  158. // Exactly one of the returned values is populated.
  159. func (mi *MessageInfo) checkField(fd pref.FieldDescriptor) (*fieldInfo, pref.ExtensionType) {
  160. if fi := mi.fields[fd.Number()]; fi != nil {
  161. if fi.fieldDesc != fd {
  162. panic("mismatching field descriptor")
  163. }
  164. return fi, nil
  165. }
  166. if fd.IsExtension() {
  167. if fd.ContainingMessage().FullName() != mi.PBType.Descriptor().FullName() {
  168. // TODO: Should this be exact containing message descriptor match?
  169. panic("mismatching containing message")
  170. }
  171. if !mi.PBType.Descriptor().ExtensionRanges().Has(fd.Number()) {
  172. panic("invalid extension field")
  173. }
  174. xtd, ok := fd.(pref.ExtensionTypeDescriptor)
  175. if !ok {
  176. panic("extension descriptor does not implement ExtensionTypeDescriptor")
  177. }
  178. return nil, xtd.Type()
  179. }
  180. panic("invalid field descriptor")
  181. }