legacy_message.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417
  1. // Copyright 2018 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package impl
  5. import (
  6. "fmt"
  7. "math"
  8. "reflect"
  9. "strconv"
  10. "strings"
  11. "sync"
  12. "unicode"
  13. protoV1 "github.com/golang/protobuf/proto"
  14. descriptorV1 "github.com/golang/protobuf/protoc-gen-go/descriptor"
  15. "github.com/golang/protobuf/v2/internal/encoding/text"
  16. pref "github.com/golang/protobuf/v2/reflect/protoreflect"
  17. ptype "github.com/golang/protobuf/v2/reflect/prototype"
  18. )
// messageDescCache memoizes the message descriptors derived from legacy Go
// message types so that each type is only processed once. It is populated by
// messageDescSet.Load and also consulted by messageDescSet.parseField.
var messageDescCache sync.Map // map[reflect.Type]protoreflect.MessageDescriptor
  20. // loadMessageDesc returns an MessageDescriptor derived from the Go type,
  21. // which must be an *struct kind and not implement the v2 API already.
  22. func loadMessageDesc(t reflect.Type) pref.MessageDescriptor {
  23. return messageDescSet{}.Load(t)
  24. }
  25. type messageDescSet struct {
  26. visited map[reflect.Type]*ptype.StandaloneMessage
  27. descs []*ptype.StandaloneMessage
  28. types []reflect.Type
  29. }
  30. func (ms messageDescSet) Load(t reflect.Type) pref.MessageDescriptor {
  31. // Fast-path: check if a MessageDescriptor is cached for this concrete type.
  32. if mi, ok := messageDescCache.Load(t); ok {
  33. return mi.(pref.MessageDescriptor)
  34. }
  35. // Slow-path: initialize MessageDescriptor from the Go type.
  36. // Processing t recursively populates descs and types with all sub-messages.
  37. // The descriptor for the first type is guaranteed to be at the front.
  38. ms.processMessage(t)
  39. // Within a proto file it is possible for cyclic dependencies to exist
  40. // between multiple message types. When these cases arise, the set of
  41. // message descriptors must be created together.
  42. mds, err := ptype.NewMessages(ms.descs)
  43. if err != nil {
  44. panic(err)
  45. }
  46. for i, md := range mds {
  47. // Protobuf semantics represents map entries under-the-hood as
  48. // pseudo-messages (has a descriptor, but no generated Go type).
  49. // Avoid caching these fake messages.
  50. if t := ms.types[i]; t.Kind() != reflect.Map {
  51. messageDescCache.Store(t, md)
  52. }
  53. }
  54. return mds[0]
  55. }
  56. func (ms *messageDescSet) processMessage(t reflect.Type) pref.MessageDescriptor {
  57. // Fast-path: Obtain a placeholder if the message is already processed.
  58. if m, ok := ms.visited[t]; ok {
  59. return ptype.PlaceholderMessage(m.FullName)
  60. }
  61. // Slow-path: Walk over the struct fields to derive the message descriptor.
  62. if t.Kind() != reflect.Ptr && t.Elem().Kind() != reflect.Struct {
  63. panic(fmt.Sprintf("got %v, want *struct kind", t))
  64. }
  65. // Derive name and syntax from the raw descriptor.
  66. m := new(ptype.StandaloneMessage)
  67. mv := reflect.New(t.Elem()).Interface()
  68. if _, ok := mv.(pref.ProtoMessage); ok {
  69. panic(fmt.Sprintf("%v already implements proto.Message", t))
  70. }
  71. if md, ok := mv.(legacyMessage); ok {
  72. b, idxs := md.Descriptor()
  73. fd := loadFileDesc(b)
  74. // Derive syntax.
  75. switch fd.GetSyntax() {
  76. case "proto2", "":
  77. m.Syntax = pref.Proto2
  78. case "proto3":
  79. m.Syntax = pref.Proto3
  80. }
  81. // Derive full name.
  82. md := fd.MessageType[idxs[0]]
  83. m.FullName = pref.FullName(fd.GetPackage()).Append(pref.Name(md.GetName()))
  84. for _, i := range idxs[1:] {
  85. md = md.NestedType[i]
  86. m.FullName = m.FullName.Append(pref.Name(md.GetName()))
  87. }
  88. } else {
  89. // If the type does not implement legacyMessage, then the only way to
  90. // obtain the full name is through the registry. However, this is
  91. // unreliable as some generated messages register with a fork of
  92. // golang/protobuf, so the registry may not have this information.
  93. m.FullName = deriveFullName(t.Elem())
  94. m.Syntax = pref.Proto2
  95. // Try to determine if the message is using proto3 by checking scalars.
  96. for i := 0; i < t.Elem().NumField(); i++ {
  97. f := t.Elem().Field(i)
  98. if tag := f.Tag.Get("protobuf"); tag != "" {
  99. switch f.Type.Kind() {
  100. case reflect.Bool, reflect.Int32, reflect.Int64, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.String:
  101. m.Syntax = pref.Proto3
  102. }
  103. for _, s := range strings.Split(tag, ",") {
  104. if s == "proto3" {
  105. m.Syntax = pref.Proto3
  106. }
  107. }
  108. }
  109. }
  110. }
  111. ms.visit(m, t)
  112. // Obtain a list of oneof wrapper types.
  113. var oneofWrappers []reflect.Type
  114. if fn, ok := t.MethodByName("XXX_OneofFuncs"); ok {
  115. vs := fn.Func.Call([]reflect.Value{reflect.Zero(fn.Type.In(0))})[3]
  116. for _, v := range vs.Interface().([]interface{}) {
  117. oneofWrappers = append(oneofWrappers, reflect.TypeOf(v))
  118. }
  119. }
  120. // Obtain a list of the extension ranges.
  121. if fn, ok := t.MethodByName("ExtensionRangeArray"); ok {
  122. vs := fn.Func.Call([]reflect.Value{reflect.Zero(fn.Type.In(0))})[0]
  123. for i := 0; i < vs.Len(); i++ {
  124. v := vs.Index(i)
  125. m.ExtensionRanges = append(m.ExtensionRanges, [2]pref.FieldNumber{
  126. pref.FieldNumber(v.FieldByName("Start").Int()),
  127. pref.FieldNumber(v.FieldByName("End").Int() + 1),
  128. })
  129. }
  130. }
  131. // Derive the message fields by inspecting the struct fields.
  132. for i := 0; i < t.Elem().NumField(); i++ {
  133. f := t.Elem().Field(i)
  134. if tag := f.Tag.Get("protobuf"); tag != "" {
  135. tagKey := f.Tag.Get("protobuf_key")
  136. tagVal := f.Tag.Get("protobuf_val")
  137. m.Fields = append(m.Fields, ms.parseField(tag, tagKey, tagVal, f.Type, m))
  138. }
  139. if tag := f.Tag.Get("protobuf_oneof"); tag != "" {
  140. name := pref.Name(tag)
  141. m.Oneofs = append(m.Oneofs, ptype.Oneof{Name: name})
  142. for _, t := range oneofWrappers {
  143. if t.Implements(f.Type) {
  144. f := t.Elem().Field(0)
  145. if tag := f.Tag.Get("protobuf"); tag != "" {
  146. ft := ms.parseField(tag, "", "", f.Type, m)
  147. ft.OneofName = name
  148. m.Fields = append(m.Fields, ft)
  149. }
  150. }
  151. }
  152. }
  153. }
  154. return ptype.PlaceholderMessage(m.FullName)
  155. }
  156. func (ms *messageDescSet) parseField(tag, tagKey, tagVal string, t reflect.Type, parent *ptype.StandaloneMessage) (f ptype.Field) {
  157. isOptional := t.Kind() == reflect.Ptr && t.Elem().Kind() != reflect.Struct
  158. isRepeated := t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8
  159. if isOptional || isRepeated {
  160. t = t.Elem()
  161. }
  162. f.Options = &descriptorV1.FieldOptions{
  163. Packed: protoV1.Bool(false),
  164. }
  165. for len(tag) > 0 {
  166. i := strings.IndexByte(tag, ',')
  167. if i < 0 {
  168. i = len(tag)
  169. }
  170. switch s := tag[:i]; {
  171. case strings.HasPrefix(s, "name="):
  172. f.Name = pref.Name(s[len("name="):])
  173. case strings.Trim(s, "0123456789") == "":
  174. n, _ := strconv.ParseUint(s, 10, 32)
  175. f.Number = pref.FieldNumber(n)
  176. case s == "opt":
  177. f.Cardinality = pref.Optional
  178. case s == "req":
  179. f.Cardinality = pref.Required
  180. case s == "rep":
  181. f.Cardinality = pref.Repeated
  182. case s == "varint":
  183. switch t.Kind() {
  184. case reflect.Bool:
  185. f.Kind = pref.BoolKind
  186. case reflect.Int32:
  187. f.Kind = pref.Int32Kind
  188. case reflect.Int64:
  189. f.Kind = pref.Int64Kind
  190. case reflect.Uint32:
  191. f.Kind = pref.Uint32Kind
  192. case reflect.Uint64:
  193. f.Kind = pref.Uint64Kind
  194. }
  195. case s == "zigzag32":
  196. if t.Kind() == reflect.Int32 {
  197. f.Kind = pref.Sint32Kind
  198. }
  199. case s == "zigzag64":
  200. if t.Kind() == reflect.Int64 {
  201. f.Kind = pref.Sint64Kind
  202. }
  203. case s == "fixed32":
  204. switch t.Kind() {
  205. case reflect.Int32:
  206. f.Kind = pref.Sfixed32Kind
  207. case reflect.Uint32:
  208. f.Kind = pref.Fixed32Kind
  209. case reflect.Float32:
  210. f.Kind = pref.FloatKind
  211. }
  212. case s == "fixed64":
  213. switch t.Kind() {
  214. case reflect.Int64:
  215. f.Kind = pref.Sfixed64Kind
  216. case reflect.Uint64:
  217. f.Kind = pref.Fixed64Kind
  218. case reflect.Float64:
  219. f.Kind = pref.DoubleKind
  220. }
  221. case s == "bytes":
  222. switch {
  223. case t.Kind() == reflect.String:
  224. f.Kind = pref.StringKind
  225. case t.Kind() == reflect.Slice && t.Elem() == byteType:
  226. f.Kind = pref.BytesKind
  227. default:
  228. f.Kind = pref.MessageKind
  229. }
  230. case s == "group":
  231. f.Kind = pref.GroupKind
  232. case strings.HasPrefix(s, "enum="):
  233. f.Kind = pref.EnumKind
  234. case strings.HasPrefix(s, "json="):
  235. f.JSONName = s[len("json="):]
  236. case s == "packed":
  237. *f.Options.Packed = true
  238. case strings.HasPrefix(s, "weak="):
  239. f.Options.Weak = protoV1.Bool(true)
  240. f.MessageType = ptype.PlaceholderMessage(pref.FullName(s[len("weak="):]))
  241. case strings.HasPrefix(s, "def="):
  242. // The default tag is special in that everything afterwards is the
  243. // default regardless of the presence of commas.
  244. s, i = tag[len("def="):], len(tag)
  245. // Defaults are parsed last, so Kind is populated.
  246. switch f.Kind {
  247. case pref.BoolKind:
  248. switch s {
  249. case "1":
  250. f.Default = pref.ValueOf(true)
  251. case "0":
  252. f.Default = pref.ValueOf(false)
  253. }
  254. case pref.EnumKind:
  255. n, _ := strconv.ParseInt(s, 10, 32)
  256. f.Default = pref.ValueOf(pref.EnumNumber(n))
  257. case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind:
  258. n, _ := strconv.ParseInt(s, 10, 32)
  259. f.Default = pref.ValueOf(int32(n))
  260. case pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
  261. n, _ := strconv.ParseInt(s, 10, 64)
  262. f.Default = pref.ValueOf(int64(n))
  263. case pref.Uint32Kind, pref.Fixed32Kind:
  264. n, _ := strconv.ParseUint(s, 10, 32)
  265. f.Default = pref.ValueOf(uint32(n))
  266. case pref.Uint64Kind, pref.Fixed64Kind:
  267. n, _ := strconv.ParseUint(s, 10, 64)
  268. f.Default = pref.ValueOf(uint64(n))
  269. case pref.FloatKind, pref.DoubleKind:
  270. n, _ := strconv.ParseFloat(s, 64)
  271. switch s {
  272. case "nan":
  273. n = math.NaN()
  274. case "inf":
  275. n = math.Inf(+1)
  276. case "-inf":
  277. n = math.Inf(-1)
  278. }
  279. if f.Kind == pref.FloatKind {
  280. f.Default = pref.ValueOf(float32(n))
  281. } else {
  282. f.Default = pref.ValueOf(float64(n))
  283. }
  284. case pref.StringKind:
  285. f.Default = pref.ValueOf(string(s))
  286. case pref.BytesKind:
  287. // The default value is in escaped form (C-style).
  288. // TODO: Export unmarshalString in the text package to avoid this hack.
  289. v, err := text.Unmarshal([]byte(`["` + s + `"]:0`))
  290. if err == nil && len(v.Message()) == 1 {
  291. s := v.Message()[0][0].String()
  292. f.Default = pref.ValueOf([]byte(s))
  293. }
  294. }
  295. }
  296. tag = strings.TrimPrefix(tag[i:], ",")
  297. }
  298. // The generator uses the group message name instead of the field name.
  299. // We obtain the real field name by lowercasing the group name.
  300. if f.Kind == pref.GroupKind {
  301. f.Name = pref.Name(strings.ToLower(string(f.Name)))
  302. }
  303. // Populate EnumType and MessageType.
  304. if f.EnumType == nil && f.Kind == pref.EnumKind {
  305. if ev, ok := reflect.Zero(t).Interface().(pref.ProtoEnum); ok {
  306. f.EnumType = ev.ProtoReflect().Type()
  307. } else {
  308. f.EnumType = loadEnumDesc(t)
  309. }
  310. }
  311. if f.MessageType == nil && (f.Kind == pref.MessageKind || f.Kind == pref.GroupKind) {
  312. if mv, ok := reflect.Zero(t).Interface().(pref.ProtoMessage); ok {
  313. f.MessageType = mv.ProtoReflect().Type()
  314. } else if t.Kind() == reflect.Map {
  315. m := &ptype.StandaloneMessage{
  316. Syntax: parent.Syntax,
  317. FullName: parent.FullName.Append(mapEntryName(f.Name)),
  318. Options: &descriptorV1.MessageOptions{MapEntry: protoV1.Bool(true)},
  319. Fields: []ptype.Field{
  320. ms.parseField(tagKey, "", "", t.Key(), nil),
  321. ms.parseField(tagVal, "", "", t.Elem(), nil),
  322. },
  323. }
  324. ms.visit(m, t)
  325. f.MessageType = ptype.PlaceholderMessage(m.FullName)
  326. } else if mv, ok := messageDescCache.Load(t); ok {
  327. f.MessageType = mv.(pref.MessageDescriptor)
  328. } else {
  329. f.MessageType = ms.processMessage(t)
  330. }
  331. }
  332. return f
  333. }
  334. func (ms *messageDescSet) visit(m *ptype.StandaloneMessage, t reflect.Type) {
  335. if ms.visited == nil {
  336. ms.visited = make(map[reflect.Type]*ptype.StandaloneMessage)
  337. }
  338. if t.Kind() != reflect.Map {
  339. ms.visited[t] = m
  340. }
  341. ms.descs = append(ms.descs, m)
  342. ms.types = append(ms.types, t)
  343. }
  344. // deriveFullName derives a fully qualified protobuf name for the given Go type
  345. // The provided name is not guaranteed to be stable nor universally unique.
  346. // It should be sufficiently unique within a program.
  347. func deriveFullName(t reflect.Type) pref.FullName {
  348. sanitize := func(r rune) rune {
  349. switch {
  350. case r == '/':
  351. return '.'
  352. case 'a' <= r && r <= 'z', 'A' <= r && r <= 'Z', '0' <= r && r <= '9':
  353. return r
  354. default:
  355. return '_'
  356. }
  357. }
  358. prefix := strings.Map(sanitize, t.PkgPath())
  359. suffix := strings.Map(sanitize, t.Name())
  360. if suffix == "" {
  361. suffix = fmt.Sprintf("UnknownX%X", reflect.ValueOf(t).Pointer())
  362. }
  363. ss := append(strings.Split(prefix, "."), suffix)
  364. for i, s := range ss {
  365. if s == "" || ('0' <= s[0] && s[0] <= '9') {
  366. ss[i] = "x" + s
  367. }
  368. }
  369. return pref.FullName(strings.Join(ss, "."))
  370. }
  371. // mapEntryName derives the message name for a map field of a given name.
  372. // This is identical to MapEntryName from parser.cc in the protoc source.
  373. func mapEntryName(s pref.Name) pref.Name {
  374. var b []byte
  375. nextUpper := true
  376. for i := 0; i < len(s); i++ {
  377. if c := s[i]; c == '_' {
  378. nextUpper = true
  379. } else {
  380. if nextUpper {
  381. c = byte(unicode.ToUpper(rune(c)))
  382. nextUpper = false
  383. }
  384. b = append(b, c)
  385. }
  386. }
  387. return pref.Name(append(b, "Entry"...))
  388. }