prototest.go 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568
  1. // Copyright 2019 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. // Package prototest exercises protobuf reflection.
  5. package prototest
  6. import (
  7. "bytes"
  8. "fmt"
  9. "math"
  10. "sort"
  11. "testing"
  12. textpb "github.com/golang/protobuf/v2/encoding/textpb"
  13. "github.com/golang/protobuf/v2/proto"
  14. pref "github.com/golang/protobuf/v2/reflect/protoreflect"
  15. )
  16. // TestMessage runs the provided message through a series of tests
  17. // exercising the protobuf reflection API.
  18. func TestMessage(t testing.TB, message proto.Message) {
  19. md := message.ProtoReflect().Type()
  20. m := md.New()
  21. for i := 0; i < md.Fields().Len(); i++ {
  22. fd := md.Fields().Get(i)
  23. switch {
  24. case fd.IsMap():
  25. testFieldMap(t, m, fd)
  26. case fd.Cardinality() == pref.Repeated:
  27. testFieldList(t, m, fd)
  28. case fd.Kind() == pref.FloatKind || fd.Kind() == pref.DoubleKind:
  29. testFieldFloat(t, m, fd)
  30. }
  31. testField(t, m, fd)
  32. }
  33. for i := 0; i < md.Oneofs().Len(); i++ {
  34. testOneof(t, m, md.Oneofs().Get(i))
  35. }
  36. // Test has/get/clear on a non-existent field.
  37. for num := pref.FieldNumber(1); ; num++ {
  38. if md.Fields().ByNumber(num) != nil {
  39. continue
  40. }
  41. if md.ExtensionRanges().Has(num) {
  42. continue
  43. }
  44. // Field num does not exist.
  45. if m.KnownFields().Has(num) {
  46. t.Errorf("non-existent field: Has(%v) = true, want false", num)
  47. }
  48. if v := m.KnownFields().Get(num); v.IsValid() {
  49. t.Errorf("non-existent field: Get(%v) = %v, want invalid", num, formatValue(v))
  50. }
  51. m.KnownFields().Clear(num) // noop
  52. break
  53. }
  54. // Test WhichOneof on a non-existent oneof.
  55. const invalidName = "invalid-name"
  56. if got, want := m.KnownFields().WhichOneof(invalidName), pref.FieldNumber(0); got != want {
  57. t.Errorf("non-existent oneof: WhichOneof(%q) = %v, want %v", invalidName, got, want)
  58. }
  59. // TODO: Extensions, unknown fields.
  60. // Test round-trip marshal/unmarshal.
  61. m1 := md.New().Interface()
  62. populateMessage(m1.ProtoReflect(), 1, nil)
  63. b, err := proto.Marshal(m1)
  64. if err != nil {
  65. t.Errorf("Marshal() = %v, want nil\n%v", err, marshalText(m1))
  66. }
  67. m2 := md.New().Interface()
  68. if err := proto.Unmarshal(b, m2); err != nil {
  69. t.Errorf("Unmarshal() = %v, want nil\n%v", err, marshalText(m1))
  70. }
  71. if !proto.Equal(m1, m2) {
  72. t.Errorf("round-trip marshal/unmarshal did not preserve message.\nOriginal:\n%v\nNew:\n%v", marshalText(m1), marshalText(m2))
  73. }
  74. }
  75. func marshalText(m proto.Message) string {
  76. b, _ := textpb.MarshalOptions{Indent: " "}.Marshal(m)
  77. return string(b)
  78. }
  79. // testField exericises set/get/has/clear of a field.
  80. func testField(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
  81. num := fd.Number()
  82. name := fd.FullName()
  83. known := m.KnownFields()
  84. // Set to a non-zero value, the zero value, different non-zero values.
  85. for _, n := range []seed{1, 0, minVal, maxVal} {
  86. v := newValue(m, fd, n, nil)
  87. known.Set(num, v)
  88. wantHas := true
  89. if n == 0 {
  90. if fd.Syntax() == pref.Proto3 && fd.Message() == nil {
  91. wantHas = false
  92. }
  93. if fd.Cardinality() == pref.Repeated {
  94. wantHas = false
  95. }
  96. if fd.Oneof() != nil {
  97. wantHas = true
  98. }
  99. }
  100. if got, want := known.Has(num), wantHas; got != want {
  101. t.Errorf("after setting %q to %v:\nHas(%v) = %v, want %v", name, formatValue(v), num, got, want)
  102. }
  103. if got, want := known.Get(num), v; !valueEqual(got, want) {
  104. t.Errorf("after setting %q:\nGet(%v) = %v, want %v", name, num, formatValue(got), formatValue(want))
  105. }
  106. }
  107. known.Clear(num)
  108. if got, want := known.Has(num), false; got != want {
  109. t.Errorf("after clearing %q:\nHas(%v) = %v, want %v", name, num, got, want)
  110. }
  111. switch {
  112. case fd.IsMap():
  113. if got := known.Get(num); got.Map().Len() != 0 {
  114. t.Errorf("after clearing %q:\nGet(%v) = %v, want empty list", name, num, formatValue(got))
  115. }
  116. case fd.Cardinality() == pref.Repeated:
  117. if got := known.Get(num); got.List().Len() != 0 {
  118. t.Errorf("after clearing %q:\nGet(%v) = %v, want empty list", name, num, formatValue(got))
  119. }
  120. default:
  121. if got, want := known.Get(num), fd.Default(); !valueEqual(got, want) {
  122. t.Errorf("after clearing %q:\nGet(%v) = %v, want default %v", name, num, formatValue(got), formatValue(want))
  123. }
  124. }
  125. }
  126. // testFieldMap tests set/get/has/clear of entries in a map field.
  127. func testFieldMap(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
  128. num := fd.Number()
  129. name := fd.FullName()
  130. known := m.KnownFields()
  131. known.Clear(num) // start with an empty map
  132. mapv := known.Get(num).Map()
  133. // Add values.
  134. want := make(testMap)
  135. for i, n := range []seed{1, 0, minVal, maxVal} {
  136. if got, want := known.Has(num), i > 0; got != want {
  137. t.Errorf("after inserting %d elements to %q:\nHas(%v) = %v, want %v", i, name, num, got, want)
  138. }
  139. k := newMapKey(fd, n)
  140. v := newMapValue(fd, mapv, n, nil)
  141. mapv.Set(k, v)
  142. want.Set(k, v)
  143. if got, want := known.Get(num), pref.ValueOf(want); !valueEqual(got, want) {
  144. t.Errorf("after inserting %d elements to %q:\nGet(%v) = %v, want %v", i, name, num, formatValue(got), formatValue(want))
  145. }
  146. }
  147. // Set values.
  148. want.Range(func(k pref.MapKey, v pref.Value) bool {
  149. nv := newMapValue(fd, mapv, 10, nil)
  150. mapv.Set(k, nv)
  151. want.Set(k, nv)
  152. if got, want := m.KnownFields().Get(num), pref.ValueOf(want); !valueEqual(got, want) {
  153. t.Errorf("after setting element %v of %q:\nGet(%v) = %v, want %v", formatValue(k.Value()), name, num, formatValue(got), formatValue(want))
  154. }
  155. return true
  156. })
  157. // Clear values.
  158. want.Range(func(k pref.MapKey, v pref.Value) bool {
  159. mapv.Clear(k)
  160. want.Clear(k)
  161. if got, want := known.Has(num), want.Len() > 0; got != want {
  162. t.Errorf("after clearing elements of %q:\nHas(%v) = %v, want %v", name, num, got, want)
  163. }
  164. if got, want := m.KnownFields().Get(num), pref.ValueOf(want); !valueEqual(got, want) {
  165. t.Errorf("after clearing elements of %q:\nGet(%v) = %v, want %v", name, num, formatValue(got), formatValue(want))
  166. }
  167. return true
  168. })
  169. // Non-existent map keys.
  170. missingKey := newMapKey(fd, 1)
  171. if got, want := mapv.Has(missingKey), false; got != want {
  172. t.Errorf("non-existent map key in %q: Has(%v) = %v, want %v", name, formatValue(missingKey.Value()), got, want)
  173. }
  174. if got, want := mapv.Get(missingKey).IsValid(), false; got != want {
  175. t.Errorf("non-existent map key in %q: Get(%v).IsValid() = %v, want %v", name, formatValue(missingKey.Value()), got, want)
  176. }
  177. mapv.Clear(missingKey) // noop
  178. }
  179. type testMap map[interface{}]pref.Value
  180. func (m testMap) Get(k pref.MapKey) pref.Value { return m[k.Interface()] }
  181. func (m testMap) Set(k pref.MapKey, v pref.Value) { m[k.Interface()] = v }
  182. func (m testMap) Has(k pref.MapKey) bool { return m.Get(k).IsValid() }
  183. func (m testMap) Clear(k pref.MapKey) { delete(m, k.Interface()) }
  184. func (m testMap) Len() int { return len(m) }
  185. func (m testMap) NewMessage() pref.Message { panic("unimplemented") }
  186. func (m testMap) Range(f func(pref.MapKey, pref.Value) bool) {
  187. for k, v := range m {
  188. if !f(pref.ValueOf(k).MapKey(), v) {
  189. return
  190. }
  191. }
  192. }
  193. // testFieldList exercises set/get/append/truncate of values in a list.
  194. func testFieldList(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
  195. num := fd.Number()
  196. name := fd.FullName()
  197. known := m.KnownFields()
  198. known.Clear(num) // start with an empty list
  199. list := known.Get(num).List()
  200. // Append values.
  201. var want pref.List = &testList{}
  202. for i, n := range []seed{1, 0, minVal, maxVal} {
  203. if got, want := known.Has(num), i > 0; got != want {
  204. t.Errorf("after appending %d elements to %q:\nHas(%v) = %v, want %v", i, name, num, got, want)
  205. }
  206. v := newListElement(fd, list, n, nil)
  207. want.Append(v)
  208. list.Append(v)
  209. if got, want := m.KnownFields().Get(num), pref.ValueOf(want); !valueEqual(got, want) {
  210. t.Errorf("after appending %d elements to %q:\nGet(%v) = %v, want %v", i+1, name, num, formatValue(got), formatValue(want))
  211. }
  212. }
  213. // Set values.
  214. for i := 0; i < want.Len(); i++ {
  215. v := newListElement(fd, list, seed(i+10), nil)
  216. want.Set(i, v)
  217. list.Set(i, v)
  218. if got, want := m.KnownFields().Get(num), pref.ValueOf(want); !valueEqual(got, want) {
  219. t.Errorf("after setting element %d of %q:\nGet(%v) = %v, want %v", i, name, num, formatValue(got), formatValue(want))
  220. }
  221. }
  222. // Truncate.
  223. for want.Len() > 0 {
  224. n := want.Len() - 1
  225. want.Truncate(n)
  226. list.Truncate(n)
  227. if got, want := known.Has(num), want.Len() > 0; got != want {
  228. t.Errorf("after truncating %q to %d:\nHas(%v) = %v, want %v", name, n, num, got, want)
  229. }
  230. if got, want := m.KnownFields().Get(num), pref.ValueOf(want); !valueEqual(got, want) {
  231. t.Errorf("after truncating %q to %d:\nGet(%v) = %v, want %v", name, n, num, formatValue(got), formatValue(want))
  232. }
  233. }
  234. }
  235. type testList struct {
  236. a []pref.Value
  237. }
  238. func (l *testList) Append(v pref.Value) { l.a = append(l.a, v) }
  239. func (l *testList) Get(n int) pref.Value { return l.a[n] }
  240. func (l *testList) Len() int { return len(l.a) }
  241. func (l *testList) Set(n int, v pref.Value) { l.a[n] = v }
  242. func (l *testList) Truncate(n int) { l.a = l.a[:n] }
  243. func (l *testList) NewMessage() pref.Message { panic("unimplemented") }
  244. // testFieldFloat exercises some interesting floating-point scalar field values.
  245. func testFieldFloat(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
  246. num := fd.Number()
  247. name := fd.FullName()
  248. known := m.KnownFields()
  249. for _, v := range []float64{math.Inf(-1), math.Inf(1), math.NaN(), math.Copysign(0, -1)} {
  250. var val pref.Value
  251. if fd.Kind() == pref.FloatKind {
  252. val = pref.ValueOf(float32(v))
  253. } else {
  254. val = pref.ValueOf(v)
  255. }
  256. known.Set(num, val)
  257. // Note that Has is true for -0.
  258. if got, want := known.Has(num), true; got != want {
  259. t.Errorf("after setting %v to %v: Get(%v) = %v, want %v", name, v, num, got, want)
  260. }
  261. if got, want := known.Get(num), val; !valueEqual(got, want) {
  262. t.Errorf("after setting %v: Get(%v) = %v, want %v", name, num, formatValue(got), formatValue(want))
  263. }
  264. }
  265. }
  266. // testOneof tests the behavior of fields in a oneof.
  267. func testOneof(t testing.TB, m pref.Message, od pref.OneofDescriptor) {
  268. known := m.KnownFields()
  269. for i := 0; i < od.Fields().Len(); i++ {
  270. fda := od.Fields().Get(i)
  271. known.Set(fda.Number(), newValue(m, fda, 1, nil))
  272. if got, want := known.WhichOneof(od.Name()), fda.Number(); got != want {
  273. t.Errorf("after setting oneof field %q:\nWhichOneof(%q) = %v, want %v", fda.FullName(), fda.Name(), got, want)
  274. }
  275. for j := 0; j < od.Fields().Len(); j++ {
  276. fdb := od.Fields().Get(j)
  277. if got, want := known.Has(fdb.Number()), i == j; got != want {
  278. t.Errorf("after setting oneof field %q:\nGet(%q) = %v, want %v", fda.FullName(), fdb.FullName(), got, want)
  279. }
  280. }
  281. }
  282. }
  283. func formatValue(v pref.Value) string {
  284. switch v := v.Interface().(type) {
  285. case pref.List:
  286. var buf bytes.Buffer
  287. buf.WriteString("list[")
  288. for i := 0; i < v.Len(); i++ {
  289. if i > 0 {
  290. buf.WriteString(" ")
  291. }
  292. buf.WriteString(formatValue(v.Get(i)))
  293. }
  294. buf.WriteString("]")
  295. return buf.String()
  296. case pref.Map:
  297. var buf bytes.Buffer
  298. buf.WriteString("map[")
  299. var keys []pref.MapKey
  300. v.Range(func(k pref.MapKey, v pref.Value) bool {
  301. keys = append(keys, k)
  302. return true
  303. })
  304. sort.Slice(keys, func(i, j int) bool {
  305. return keys[i].String() < keys[j].String()
  306. })
  307. for i, k := range keys {
  308. if i > 0 {
  309. buf.WriteString(" ")
  310. }
  311. buf.WriteString(formatValue(k.Value()))
  312. buf.WriteString(":")
  313. buf.WriteString(formatValue(v.Get(k)))
  314. }
  315. buf.WriteString("]")
  316. return buf.String()
  317. case pref.Message:
  318. b, err := textpb.Marshal(v.Interface())
  319. if err != nil {
  320. return fmt.Sprintf("<%v>", err)
  321. }
  322. return fmt.Sprintf("%v{%v}", v.Type().FullName(), string(b))
  323. case string:
  324. return fmt.Sprintf("%q", v)
  325. default:
  326. return fmt.Sprint(v)
  327. }
  328. }
  329. func valueEqual(a, b pref.Value) bool {
  330. ai, bi := a.Interface(), b.Interface()
  331. switch ai.(type) {
  332. case pref.Message:
  333. return proto.Equal(
  334. a.Message().Interface(),
  335. b.Message().Interface(),
  336. )
  337. case pref.List:
  338. lista, listb := a.List(), b.List()
  339. if lista.Len() != listb.Len() {
  340. return false
  341. }
  342. for i := 0; i < lista.Len(); i++ {
  343. if !valueEqual(lista.Get(i), listb.Get(i)) {
  344. return false
  345. }
  346. }
  347. return true
  348. case pref.Map:
  349. mapa, mapb := a.Map(), b.Map()
  350. if mapa.Len() != mapb.Len() {
  351. return false
  352. }
  353. equal := true
  354. mapa.Range(func(k pref.MapKey, v pref.Value) bool {
  355. if !valueEqual(v, mapb.Get(k)) {
  356. equal = false
  357. return false
  358. }
  359. return true
  360. })
  361. return equal
  362. case []byte:
  363. return bytes.Equal(a.Bytes(), b.Bytes())
  364. case float32, float64:
  365. // NaNs are equal, but must be the same NaN.
  366. return math.Float64bits(a.Float()) == math.Float64bits(a.Float())
  367. default:
  368. return ai == bi
  369. }
  370. }
  // A seed is used to vary the content of a value.
//
// A seed of 0 is the zero value. Messages do not have a zero value; a
// 0-seeded message is left unpopulated.
//
// A seed of minVal or maxVal selects an extreme value of the value type. For
// some kinds newScalarValue substitutes a nearby value; for example, unsigned
// kinds use 1 rather than 0 as the minimum.
type seed int

// Sentinel seeds selecting the least and greatest values; see seed.
const (
	minVal = seed(-1)
	maxVal = seed(-2)
)
  382. // newValue returns a new value assignable to a field.
  383. //
  384. // The stack parameter is used to avoid infinite recursion when populating circular
  385. // data structures.
  386. func newValue(m pref.Message, fd pref.FieldDescriptor, n seed, stack []pref.MessageType) pref.Value {
  387. num := fd.Number()
  388. switch {
  389. case fd.IsMap():
  390. mapv := m.Type().New().KnownFields().Get(num).Map()
  391. if n == 0 {
  392. return pref.ValueOf(mapv)
  393. }
  394. mapv.Set(newMapKey(fd, 0), newMapValue(fd, mapv, 0, stack))
  395. mapv.Set(newMapKey(fd, minVal), newMapValue(fd, mapv, minVal, stack))
  396. mapv.Set(newMapKey(fd, maxVal), newMapValue(fd, mapv, maxVal, stack))
  397. mapv.Set(newMapKey(fd, n), newMapValue(fd, mapv, 10*n, stack))
  398. return pref.ValueOf(mapv)
  399. case fd.Cardinality() == pref.Repeated:
  400. list := m.Type().New().KnownFields().Get(num).List()
  401. if n == 0 {
  402. return pref.ValueOf(list)
  403. }
  404. list.Append(newListElement(fd, list, 0, stack))
  405. list.Append(newListElement(fd, list, minVal, stack))
  406. list.Append(newListElement(fd, list, maxVal, stack))
  407. list.Append(newListElement(fd, list, n, stack))
  408. return pref.ValueOf(list)
  409. case fd.Message() != nil:
  410. return populateMessage(m.KnownFields().NewMessage(num), n, stack)
  411. default:
  412. return newScalarValue(fd, n)
  413. }
  414. }
  415. func newListElement(fd pref.FieldDescriptor, list pref.List, n seed, stack []pref.MessageType) pref.Value {
  416. if fd.Message() == nil {
  417. return newScalarValue(fd, n)
  418. }
  419. return populateMessage(list.NewMessage(), n, stack)
  420. }
  421. func newMapKey(fd pref.FieldDescriptor, n seed) pref.MapKey {
  422. kd := fd.Message().Fields().ByNumber(1)
  423. return newScalarValue(kd, n).MapKey()
  424. }
  425. func newMapValue(fd pref.FieldDescriptor, mapv pref.Map, n seed, stack []pref.MessageType) pref.Value {
  426. vd := fd.Message().Fields().ByNumber(2)
  427. if vd.Message() == nil {
  428. return newScalarValue(vd, n)
  429. }
  430. return populateMessage(mapv.NewMessage(), n, stack)
  431. }
  432. func newScalarValue(fd pref.FieldDescriptor, n seed) pref.Value {
  433. switch fd.Kind() {
  434. case pref.BoolKind:
  435. return pref.ValueOf(n != 0)
  436. case pref.EnumKind:
  437. // TODO use actual value
  438. return pref.ValueOf(pref.EnumNumber(n))
  439. case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind:
  440. switch n {
  441. case minVal:
  442. return pref.ValueOf(int32(math.MinInt32))
  443. case maxVal:
  444. return pref.ValueOf(int32(math.MaxInt32))
  445. default:
  446. return pref.ValueOf(int32(n))
  447. }
  448. case pref.Uint32Kind, pref.Fixed32Kind:
  449. switch n {
  450. case minVal:
  451. // Only use 0 for the zero value.
  452. return pref.ValueOf(uint32(1))
  453. case maxVal:
  454. return pref.ValueOf(uint32(math.MaxInt32))
  455. default:
  456. return pref.ValueOf(uint32(n))
  457. }
  458. case pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
  459. switch n {
  460. case minVal:
  461. return pref.ValueOf(int64(math.MinInt64))
  462. case maxVal:
  463. return pref.ValueOf(int64(math.MaxInt64))
  464. default:
  465. return pref.ValueOf(int64(n))
  466. }
  467. case pref.Uint64Kind, pref.Fixed64Kind:
  468. switch n {
  469. case minVal:
  470. // Only use 0 for the zero value.
  471. return pref.ValueOf(uint64(1))
  472. case maxVal:
  473. return pref.ValueOf(uint64(math.MaxInt64))
  474. default:
  475. return pref.ValueOf(uint64(n))
  476. }
  477. case pref.FloatKind:
  478. switch n {
  479. case minVal:
  480. return pref.ValueOf(float32(math.SmallestNonzeroFloat32))
  481. case maxVal:
  482. return pref.ValueOf(float32(math.MaxFloat32))
  483. default:
  484. return pref.ValueOf(1.5 * float32(n))
  485. }
  486. case pref.DoubleKind:
  487. switch n {
  488. case minVal:
  489. return pref.ValueOf(float64(math.SmallestNonzeroFloat64))
  490. case maxVal:
  491. return pref.ValueOf(float64(math.MaxFloat64))
  492. default:
  493. return pref.ValueOf(1.5 * float64(n))
  494. }
  495. case pref.StringKind:
  496. if n == 0 {
  497. return pref.ValueOf("")
  498. }
  499. return pref.ValueOf(fmt.Sprintf("%d", n))
  500. case pref.BytesKind:
  501. if n == 0 {
  502. return pref.ValueOf([]byte(nil))
  503. }
  504. return pref.ValueOf([]byte{byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n)})
  505. }
  506. panic("unhandled kind")
  507. }
  508. func populateMessage(m pref.Message, n seed, stack []pref.MessageType) pref.Value {
  509. if n == 0 {
  510. return pref.ValueOf(m)
  511. }
  512. md := m.Type()
  513. for _, x := range stack {
  514. if md == x {
  515. return pref.ValueOf(m)
  516. }
  517. }
  518. stack = append(stack, md)
  519. known := m.KnownFields()
  520. for i := 0; i < md.Fields().Len(); i++ {
  521. fd := md.Fields().Get(i)
  522. if fd.IsWeak() {
  523. continue
  524. }
  525. known.Set(fd.Number(), newValue(m, fd, 10*n+seed(i), stack))
  526. }
  527. return pref.ValueOf(m)
  528. }