message_reflect.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. // Copyright 2019 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package impl
  5. import (
  6. "fmt"
  7. "reflect"
  8. "google.golang.org/protobuf/internal/pragma"
  9. pvalue "google.golang.org/protobuf/internal/value"
  10. pref "google.golang.org/protobuf/reflect/protoreflect"
  11. piface "google.golang.org/protobuf/runtime/protoiface"
  12. )
// MessageState is a data structure that is nested as the first field in a
// concrete message. It provides a way to implement the ProtoReflect method
// in an allocation-free way without needing to have a shadow Go type generated
// for every message type. This technique only works using unsafe.
//
// Example generated code:
//
//	type M struct {
//		state protoimpl.MessageState
//
//		Field1 int32
//		Field2 string
//		Field3 *BarMessage
//		...
//	}
//
//	func (m *M) ProtoReflect() protoreflect.Message {
//		mi := &file_fizz_buzz_proto_msgInfos[5]
//		if protoimpl.UnsafeEnabled && m != nil {
//			ms := protoimpl.X.MessageStateOf(Pointer(m))
//			if ms.LoadMessageInfo() == nil {
//				ms.StoreMessageInfo(mi)
//			}
//			return ms
//		}
//		return mi.MessageOf(m)
//	}
//
// The MessageState type holds a *MessageInfo, which must be atomically set to
// the message info associated with a given message instance.
// By unsafely converting a *M into a *MessageState, the MessageState object
// has access to all the information needed to implement protobuf reflection.
// It has access to the message info as its first field, and a pointer to the
// MessageState is identical to a pointer to the concrete message value.
//
// Requirements:
//   - The type M must implement protoreflect.ProtoMessage.
//   - The address of m must not be nil.
//   - The address of m and the address of m.state must be equal,
//     even though they are different Go types.
type MessageState struct {
	// Embedded marker types from the internal pragma package.
	// NOTE(review): judging by their names they discourage unkeyed literals,
	// comparison, and copying of this struct — confirm against package pragma.
	pragma.NoUnkeyedLiterals
	pragma.DoNotCompare
	pragma.DoNotCopy

	// mi is the message info associated with the enclosing message instance.
	mi *MessageInfo
}
// messageState is a defined type whose underlying type is MessageState.
// As a distinct defined type it does not inherit MessageState's method set.
// NOTE(review): its methods are not visible in this part of the file —
// presumably they are declared elsewhere in the package.
type messageState MessageState
// Compile-time assertions that *messageState implements both
// pref.Message and pvalue.Unwrapper.
var (
	_ pref.Message     = (*messageState)(nil)
	_ pvalue.Unwrapper = (*messageState)(nil)
)
// messageDataType is a tuple of a pointer to the message data and
// a pointer to the message type. It is a generalized way of providing a
// reflective view over a message instance. The disadvantage of this approach
// is the need to allocate this tuple of 16B.
type messageDataType struct {
	// p points to the message data.
	p pointer
	// mi points to the message type information for the data behind p.
	mi *MessageInfo
}
// messageIfaceWrapper and messageReflectWrapper are two defined types that
// share the messageDataType layout. Because the underlying types are
// identical, a pointer to one can be converted directly to a pointer to the
// other; each carries its own method set.
type (
	messageIfaceWrapper   messageDataType
	messageReflectWrapper messageDataType
)
// Compile-time assertions: *messageReflectWrapper must implement
// pref.Message, *messageIfaceWrapper must implement pref.ProtoMessage,
// and both must implement pvalue.Unwrapper.
var (
	_ pref.Message      = (*messageReflectWrapper)(nil)
	_ pvalue.Unwrapper  = (*messageReflectWrapper)(nil)
	_ pref.ProtoMessage = (*messageIfaceWrapper)(nil)
	_ pvalue.Unwrapper  = (*messageIfaceWrapper)(nil)
)
  84. // MessageOf returns a reflective view over a message. The input must be a
  85. // pointer to a named Go struct. If the provided type has a ProtoReflect method,
  86. // it must be implemented by calling this method.
  87. func (mi *MessageInfo) MessageOf(m interface{}) pref.Message {
  88. // TODO: Switch the input to be an opaque Pointer.
  89. if reflect.TypeOf(m) != mi.GoType {
  90. panic(fmt.Sprintf("type mismatch: got %T, want %v", m, mi.GoType))
  91. }
  92. p := pointerOfIface(m)
  93. if p.IsNil() {
  94. return mi.nilMessage.Init(mi)
  95. }
  96. return &messageReflectWrapper{p, mi}
  97. }
  98. func (m *messageReflectWrapper) pointer() pointer { return m.p }
  99. func (m *messageIfaceWrapper) ProtoReflect() pref.Message {
  100. return (*messageReflectWrapper)(m)
  101. }
  102. func (m *messageIfaceWrapper) XXX_Methods() *piface.Methods {
  103. // TODO: Consider not recreating this on every call.
  104. m.mi.init()
  105. return &piface.Methods{
  106. Flags: piface.MethodFlagDeterministicMarshal,
  107. MarshalAppend: m.marshalAppend,
  108. Unmarshal: m.unmarshal,
  109. Size: m.size,
  110. IsInitialized: m.isInitialized,
  111. }
  112. }
  113. func (m *messageIfaceWrapper) ProtoUnwrap() interface{} {
  114. return m.p.AsIfaceOf(m.mi.GoType.Elem())
  115. }
  116. func (m *messageIfaceWrapper) marshalAppend(b []byte, _ pref.ProtoMessage, opts piface.MarshalOptions) ([]byte, error) {
  117. return m.mi.marshalAppendPointer(b, m.p, newMarshalOptions(opts))
  118. }
  119. func (m *messageIfaceWrapper) unmarshal(b []byte, _ pref.ProtoMessage, opts piface.UnmarshalOptions) error {
  120. _, err := m.mi.unmarshalPointer(b, m.p, 0, newUnmarshalOptions(opts))
  121. return err
  122. }
  123. func (m *messageIfaceWrapper) size(msg pref.ProtoMessage) (size int) {
  124. return m.mi.sizePointer(m.p, 0)
  125. }
  126. func (m *messageIfaceWrapper) isInitialized(_ pref.ProtoMessage) error {
  127. return m.mi.isInitializedPointer(m.p)
  128. }
// extensionMap stores extension fields keyed by field number (as int32).
// Its methods use a pointer receiver; the read-only methods (Range, Has, Get)
// accept a nil receiver, while the mutating methods (Clear, Set, Mutable)
// dereference the receiver unconditionally.
type extensionMap map[int32]ExtensionField
  130. func (m *extensionMap) Range(f func(pref.FieldDescriptor, pref.Value) bool) {
  131. if m != nil {
  132. for _, x := range *m {
  133. xt := x.GetType()
  134. if !f(xt, xt.ValueOf(x.GetValue())) {
  135. return
  136. }
  137. }
  138. }
  139. }
  140. func (m *extensionMap) Has(xt pref.ExtensionType) (ok bool) {
  141. if m != nil {
  142. _, ok = (*m)[int32(xt.Number())]
  143. }
  144. return ok
  145. }
  146. func (m *extensionMap) Clear(xt pref.ExtensionType) {
  147. delete(*m, int32(xt.Number()))
  148. }
  149. func (m *extensionMap) Get(xt pref.ExtensionType) pref.Value {
  150. if m != nil {
  151. if x, ok := (*m)[int32(xt.Number())]; ok {
  152. return xt.ValueOf(x.GetValue())
  153. }
  154. }
  155. if !isComposite(xt) {
  156. return defaultValueOf(xt)
  157. }
  158. return frozenValueOf(xt.New())
  159. }
  160. func (m *extensionMap) Set(xt pref.ExtensionType, v pref.Value) {
  161. if *m == nil {
  162. *m = make(map[int32]ExtensionField)
  163. }
  164. var x ExtensionField
  165. x.SetType(xt)
  166. x.SetEagerValue(xt.InterfaceOf(v))
  167. (*m)[int32(xt.Number())] = x
  168. }
  169. func (m *extensionMap) Mutable(xt pref.ExtensionType) pref.Value {
  170. if !isComposite(xt) {
  171. panic("invalid Mutable on field with non-composite type")
  172. }
  173. if x, ok := (*m)[int32(xt.Number())]; ok {
  174. return xt.ValueOf(x.GetValue())
  175. }
  176. v := xt.New()
  177. m.Set(xt, v)
  178. return v
  179. }
  180. func isComposite(fd pref.FieldDescriptor) bool {
  181. return fd.Kind() == pref.MessageKind || fd.Kind() == pref.GroupKind || fd.IsList() || fd.IsMap()
  182. }
  183. // checkField verifies that the provided field descriptor is valid.
  184. // Exactly one of the returned values is populated.
  185. func (mi *MessageInfo) checkField(fd pref.FieldDescriptor) (*fieldInfo, pref.ExtensionType) {
  186. if fi := mi.fields[fd.Number()]; fi != nil {
  187. if fi.fieldDesc != fd {
  188. panic("mismatching field descriptor")
  189. }
  190. return fi, nil
  191. }
  192. if fd.IsExtension() {
  193. if fd.ContainingMessage().FullName() != mi.PBType.FullName() {
  194. // TODO: Should this be exact containing message descriptor match?
  195. panic("mismatching containing message")
  196. }
  197. if !mi.PBType.ExtensionRanges().Has(fd.Number()) {
  198. panic("invalid extension field")
  199. }
  200. return nil, fd.(pref.ExtensionType)
  201. }
  202. panic("invalid field descriptor")
  203. }