registry_test.go 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640
  1. // Copyright 2018 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package protoregistry_test
  5. import (
  6. "fmt"
  7. "strings"
  8. "testing"
  9. "github.com/google/go-cmp/cmp"
  10. "github.com/google/go-cmp/cmp/cmpopts"
  11. "google.golang.org/protobuf/encoding/prototext"
  12. pimpl "google.golang.org/protobuf/internal/impl"
  13. pdesc "google.golang.org/protobuf/reflect/protodesc"
  14. pref "google.golang.org/protobuf/reflect/protoreflect"
  15. preg "google.golang.org/protobuf/reflect/protoregistry"
  16. testpb "google.golang.org/protobuf/reflect/protoregistry/testprotos"
  17. "google.golang.org/protobuf/types/descriptorpb"
  18. )
  19. func mustMakeFile(s string) pref.FileDescriptor {
  20. pb := new(descriptorpb.FileDescriptorProto)
  21. if err := prototext.Unmarshal([]byte(s), pb); err != nil {
  22. panic(err)
  23. }
  24. fd, err := pdesc.NewFile(pb, nil)
  25. if err != nil {
  26. panic(err)
  27. }
  28. return fd
  29. }
  30. func TestFiles(t *testing.T) {
  31. type (
  32. file struct {
  33. Path string
  34. Pkg pref.FullName
  35. }
  36. testFile struct {
  37. inFile pref.FileDescriptor
  38. wantErr string
  39. }
  40. testFindDesc struct {
  41. inName pref.FullName
  42. wantFound bool
  43. }
  44. testRangePkg struct {
  45. inPkg pref.FullName
  46. wantFiles []file
  47. }
  48. testFindPath struct {
  49. inPath string
  50. wantFiles []file
  51. }
  52. )
  53. tests := []struct {
  54. files []testFile
  55. findDescs []testFindDesc
  56. rangePkgs []testRangePkg
  57. findPaths []testFindPath
  58. }{{
  59. // Test that overlapping packages and files are permitted.
  60. files: []testFile{
  61. {inFile: mustMakeFile(`syntax:"proto2" name:"test1.proto" package:"foo.bar"`)},
  62. {inFile: mustMakeFile(`syntax:"proto2" name:"foo/bar/test.proto" package:"my.test"`)},
  63. {inFile: mustMakeFile(`syntax:"proto2" name:"foo/bar/test.proto" package:"foo.bar.baz"`), wantErr: "already registered"},
  64. {inFile: mustMakeFile(`syntax:"proto2" name:"test2.proto" package:"my.test.package"`)},
  65. {inFile: mustMakeFile(`syntax:"proto2" name:"weird" package:"foo.bar"`)},
  66. {inFile: mustMakeFile(`syntax:"proto2" name:"foo/bar/baz/../test.proto" package:"my.test"`)},
  67. },
  68. rangePkgs: []testRangePkg{{
  69. inPkg: "nothing",
  70. }, {
  71. inPkg: "",
  72. }, {
  73. inPkg: ".",
  74. }, {
  75. inPkg: "foo",
  76. }, {
  77. inPkg: "foo.",
  78. }, {
  79. inPkg: "foo..",
  80. }, {
  81. inPkg: "foo.bar",
  82. wantFiles: []file{
  83. {"test1.proto", "foo.bar"},
  84. {"weird", "foo.bar"},
  85. },
  86. }, {
  87. inPkg: "my.test",
  88. wantFiles: []file{
  89. {"foo/bar/baz/../test.proto", "my.test"},
  90. {"foo/bar/test.proto", "my.test"},
  91. },
  92. }, {
  93. inPkg: "fo",
  94. }},
  95. findPaths: []testFindPath{{
  96. inPath: "nothing",
  97. }, {
  98. inPath: "weird",
  99. wantFiles: []file{
  100. {"weird", "foo.bar"},
  101. },
  102. }, {
  103. inPath: "foo/bar/test.proto",
  104. wantFiles: []file{
  105. {"foo/bar/test.proto", "my.test"},
  106. },
  107. }},
  108. }, {
  109. // Test when new enum conflicts with existing package.
  110. files: []testFile{{
  111. inFile: mustMakeFile(`syntax:"proto2" name:"test1a.proto" package:"foo.bar.baz"`),
  112. }, {
  113. inFile: mustMakeFile(`syntax:"proto2" name:"test1b.proto" enum_type:[{name:"foo" value:[{name:"VALUE" number:0}]}]`),
  114. wantErr: `file "test1b.proto" has a name conflict over foo`,
  115. }},
  116. }, {
  117. // Test when new package conflicts with existing enum.
  118. files: []testFile{{
  119. inFile: mustMakeFile(`syntax:"proto2" name:"test2a.proto" enum_type:[{name:"foo" value:[{name:"VALUE" number:0}]}]`),
  120. }, {
  121. inFile: mustMakeFile(`syntax:"proto2" name:"test2b.proto" package:"foo.bar.baz"`),
  122. wantErr: `file "test2b.proto" has a package name conflict over foo`,
  123. }},
  124. }, {
  125. // Test when new enum conflicts with existing enum in same package.
  126. files: []testFile{{
  127. inFile: mustMakeFile(`syntax:"proto2" name:"test3a.proto" package:"foo" enum_type:[{name:"BAR" value:[{name:"VALUE" number:0}]}]`),
  128. }, {
  129. inFile: mustMakeFile(`syntax:"proto2" name:"test3b.proto" package:"foo" enum_type:[{name:"BAR" value:[{name:"VALUE2" number:0}]}]`),
  130. wantErr: `file "test3b.proto" has a name conflict over foo.BAR`,
  131. }},
  132. }, {
  133. files: []testFile{{
  134. inFile: mustMakeFile(`
  135. syntax: "proto2"
  136. name: "test1.proto"
  137. package: "fizz.buzz"
  138. message_type: [{
  139. name: "Message"
  140. field: [
  141. {name:"Field" number:1 label:LABEL_OPTIONAL type:TYPE_STRING oneof_index:0}
  142. ]
  143. oneof_decl: [{name:"Oneof"}]
  144. extension_range: [{start:1000 end:2000}]
  145. enum_type: [
  146. {name:"Enum" value:[{name:"EnumValue" number:0}]}
  147. ]
  148. nested_type: [
  149. {name:"Message" field:[{name:"Field" number:1 label:LABEL_OPTIONAL type:TYPE_STRING}]}
  150. ]
  151. extension: [
  152. {name:"Extension" number:1001 label:LABEL_OPTIONAL type:TYPE_STRING extendee:".fizz.buzz.Message"}
  153. ]
  154. }]
  155. enum_type: [{
  156. name: "Enum"
  157. value: [{name:"EnumValue" number:0}]
  158. }]
  159. extension: [
  160. {name:"Extension" number:1000 label:LABEL_OPTIONAL type:TYPE_STRING extendee:".fizz.buzz.Message"}
  161. ]
  162. service: [{
  163. name: "Service"
  164. method: [{
  165. name: "Method"
  166. input_type: ".fizz.buzz.Message"
  167. output_type: ".fizz.buzz.Message"
  168. client_streaming: true
  169. server_streaming: true
  170. }]
  171. }]
  172. `),
  173. }, {
  174. inFile: mustMakeFile(`
  175. syntax: "proto2"
  176. name: "test2.proto"
  177. package: "fizz.buzz.gazz"
  178. enum_type: [{
  179. name: "Enum"
  180. value: [{name:"EnumValue" number:0}]
  181. }]
  182. `),
  183. }, {
  184. inFile: mustMakeFile(`
  185. syntax: "proto2"
  186. name: "test3.proto"
  187. package: "fizz.buzz"
  188. enum_type: [{
  189. name: "Enum1"
  190. value: [{name:"EnumValue1" number:0}]
  191. }, {
  192. name: "Enum2"
  193. value: [{name:"EnumValue2" number:0}]
  194. }]
  195. `),
  196. }, {
  197. // Make sure we can register without package name.
  198. inFile: mustMakeFile(`
  199. name: "weird"
  200. syntax: "proto2"
  201. message_type: [{
  202. name: "Message"
  203. nested_type: [{
  204. name: "Message"
  205. nested_type: [{
  206. name: "Message"
  207. }]
  208. }]
  209. }]
  210. `),
  211. }},
  212. findDescs: []testFindDesc{
  213. {inName: "fizz.buzz.message", wantFound: false},
  214. {inName: "fizz.buzz.Message", wantFound: true},
  215. {inName: "fizz.buzz.Message.X", wantFound: false},
  216. {inName: "fizz.buzz.Field", wantFound: false},
  217. {inName: "fizz.buzz.Oneof", wantFound: false},
  218. {inName: "fizz.buzz.Message.Field", wantFound: true},
  219. {inName: "fizz.buzz.Message.Field.X", wantFound: false},
  220. {inName: "fizz.buzz.Message.Oneof", wantFound: true},
  221. {inName: "fizz.buzz.Message.Oneof.X", wantFound: false},
  222. {inName: "fizz.buzz.Message.Message", wantFound: true},
  223. {inName: "fizz.buzz.Message.Message.X", wantFound: false},
  224. {inName: "fizz.buzz.Message.Enum", wantFound: true},
  225. {inName: "fizz.buzz.Message.Enum.X", wantFound: false},
  226. {inName: "fizz.buzz.Message.EnumValue", wantFound: true},
  227. {inName: "fizz.buzz.Message.EnumValue.X", wantFound: false},
  228. {inName: "fizz.buzz.Message.Extension", wantFound: true},
  229. {inName: "fizz.buzz.Message.Extension.X", wantFound: false},
  230. {inName: "fizz.buzz.enum", wantFound: false},
  231. {inName: "fizz.buzz.Enum", wantFound: true},
  232. {inName: "fizz.buzz.Enum.X", wantFound: false},
  233. {inName: "fizz.buzz.EnumValue", wantFound: true},
  234. {inName: "fizz.buzz.EnumValue.X", wantFound: false},
  235. {inName: "fizz.buzz.Enum.EnumValue", wantFound: false},
  236. {inName: "fizz.buzz.Extension", wantFound: true},
  237. {inName: "fizz.buzz.Extension.X", wantFound: false},
  238. {inName: "fizz.buzz.service", wantFound: false},
  239. {inName: "fizz.buzz.Service", wantFound: true},
  240. {inName: "fizz.buzz.Service.X", wantFound: false},
  241. {inName: "fizz.buzz.Method", wantFound: false},
  242. {inName: "fizz.buzz.Service.Method", wantFound: true},
  243. {inName: "fizz.buzz.Service.Method.X", wantFound: false},
  244. {inName: "fizz.buzz.gazz", wantFound: false},
  245. {inName: "fizz.buzz.gazz.Enum", wantFound: true},
  246. {inName: "fizz.buzz.gazz.EnumValue", wantFound: true},
  247. {inName: "fizz.buzz.gazz.Enum.EnumValue", wantFound: false},
  248. {inName: "fizz.buzz", wantFound: false},
  249. {inName: "fizz.buzz.Enum1", wantFound: true},
  250. {inName: "fizz.buzz.EnumValue1", wantFound: true},
  251. {inName: "fizz.buzz.Enum1.EnumValue1", wantFound: false},
  252. {inName: "fizz.buzz.Enum2", wantFound: true},
  253. {inName: "fizz.buzz.EnumValue2", wantFound: true},
  254. {inName: "fizz.buzz.Enum2.EnumValue2", wantFound: false},
  255. {inName: "fizz.buzz.Enum3", wantFound: false},
  256. {inName: "", wantFound: false},
  257. {inName: "Message", wantFound: true},
  258. {inName: "Message.Message", wantFound: true},
  259. {inName: "Message.Message.Message", wantFound: true},
  260. {inName: "Message.Message.Message.Message", wantFound: false},
  261. },
  262. }}
  263. sortFiles := cmpopts.SortSlices(func(x, y file) bool {
  264. return x.Path < y.Path || (x.Path == y.Path && x.Pkg < y.Pkg)
  265. })
  266. for _, tt := range tests {
  267. t.Run("", func(t *testing.T) {
  268. var files preg.Files
  269. for i, tc := range tt.files {
  270. gotErr := files.Register(tc.inFile)
  271. if ((gotErr == nil) != (tc.wantErr == "")) || !strings.Contains(fmt.Sprint(gotErr), tc.wantErr) {
  272. t.Errorf("file %d, Register() = %v, want %v", i, gotErr, tc.wantErr)
  273. }
  274. }
  275. for _, tc := range tt.findDescs {
  276. d, _ := files.FindDescriptorByName(tc.inName)
  277. gotFound := d != nil
  278. if gotFound != tc.wantFound {
  279. t.Errorf("FindDescriptorByName(%v) find mismatch: got %v, want %v", tc.inName, gotFound, tc.wantFound)
  280. }
  281. }
  282. for _, tc := range tt.rangePkgs {
  283. var gotFiles []file
  284. var gotCnt int
  285. wantCnt := files.NumFilesByPackage(tc.inPkg)
  286. files.RangeFilesByPackage(tc.inPkg, func(fd pref.FileDescriptor) bool {
  287. gotFiles = append(gotFiles, file{fd.Path(), fd.Package()})
  288. gotCnt++
  289. return true
  290. })
  291. if gotCnt != wantCnt {
  292. t.Errorf("NumFilesByPackage(%v) = %v, want %v", tc.inPkg, gotCnt, wantCnt)
  293. }
  294. if diff := cmp.Diff(tc.wantFiles, gotFiles, sortFiles); diff != "" {
  295. t.Errorf("RangeFilesByPackage(%v) mismatch (-want +got):\n%v", tc.inPkg, diff)
  296. }
  297. }
  298. for _, tc := range tt.findPaths {
  299. var gotFiles []file
  300. if fd, err := files.FindFileByPath(tc.inPath); err == nil {
  301. gotFiles = append(gotFiles, file{fd.Path(), fd.Package()})
  302. }
  303. if diff := cmp.Diff(tc.wantFiles, gotFiles, sortFiles); diff != "" {
  304. t.Errorf("FindFileByPath(%v) mismatch (-want +got):\n%v", tc.inPath, diff)
  305. }
  306. }
  307. })
  308. }
  309. }
  310. func TestTypes(t *testing.T) {
  311. mt1 := pimpl.Export{}.MessageTypeOf(&testpb.Message1{})
  312. et1 := pimpl.Export{}.EnumTypeOf(testpb.Enum1_ONE)
  313. xt1 := testpb.E_StringField
  314. xt2 := testpb.E_Message4_MessageField
  315. registry := new(preg.Types)
  316. if err := registry.Register(mt1, et1, xt1, xt2); err != nil {
  317. t.Fatalf("registry.Register() returns unexpected error: %v", err)
  318. }
  319. t.Run("FindMessageByName", func(t *testing.T) {
  320. tests := []struct {
  321. name string
  322. messageType pref.MessageType
  323. wantErr bool
  324. wantNotFound bool
  325. }{{
  326. name: "testprotos.Message1",
  327. messageType: mt1,
  328. }, {
  329. name: "testprotos.NoSuchMessage",
  330. wantErr: true,
  331. wantNotFound: true,
  332. }, {
  333. name: "testprotos.Enum1",
  334. wantErr: true,
  335. }, {
  336. name: "testprotos.Enum2",
  337. wantErr: true,
  338. }, {
  339. name: "testprotos.Enum3",
  340. wantErr: true,
  341. }}
  342. for _, tc := range tests {
  343. got, err := registry.FindMessageByName(pref.FullName(tc.name))
  344. gotErr := err != nil
  345. if gotErr != tc.wantErr {
  346. t.Errorf("FindMessageByName(%v) = (_, %v), want error? %t", tc.name, err, tc.wantErr)
  347. continue
  348. }
  349. if tc.wantNotFound && err != preg.NotFound {
  350. t.Errorf("FindMessageByName(%v) got error: %v, want NotFound error", tc.name, err)
  351. continue
  352. }
  353. if got != tc.messageType {
  354. t.Errorf("FindMessageByName(%v) got wrong value: %v", tc.name, got)
  355. }
  356. }
  357. })
  358. t.Run("FindMessageByURL", func(t *testing.T) {
  359. tests := []struct {
  360. name string
  361. messageType pref.MessageType
  362. wantErr bool
  363. wantNotFound bool
  364. }{{
  365. name: "testprotos.Message1",
  366. messageType: mt1,
  367. }, {
  368. name: "type.googleapis.com/testprotos.Nada",
  369. wantErr: true,
  370. wantNotFound: true,
  371. }, {
  372. name: "testprotos.Enum1",
  373. wantErr: true,
  374. }}
  375. for _, tc := range tests {
  376. got, err := registry.FindMessageByURL(tc.name)
  377. gotErr := err != nil
  378. if gotErr != tc.wantErr {
  379. t.Errorf("FindMessageByURL(%v) = (_, %v), want error? %t", tc.name, err, tc.wantErr)
  380. continue
  381. }
  382. if tc.wantNotFound && err != preg.NotFound {
  383. t.Errorf("FindMessageByURL(%v) got error: %v, want NotFound error", tc.name, err)
  384. continue
  385. }
  386. if got != tc.messageType {
  387. t.Errorf("FindMessageByURL(%v) got wrong value: %v", tc.name, got)
  388. }
  389. }
  390. })
  391. t.Run("FindEnumByName", func(t *testing.T) {
  392. tests := []struct {
  393. name string
  394. enumType pref.EnumType
  395. wantErr bool
  396. wantNotFound bool
  397. }{{
  398. name: "testprotos.Enum1",
  399. enumType: et1,
  400. }, {
  401. name: "testprotos.None",
  402. wantErr: true,
  403. wantNotFound: true,
  404. }, {
  405. name: "testprotos.Message1",
  406. wantErr: true,
  407. }}
  408. for _, tc := range tests {
  409. got, err := registry.FindEnumByName(pref.FullName(tc.name))
  410. gotErr := err != nil
  411. if gotErr != tc.wantErr {
  412. t.Errorf("FindEnumByName(%v) = (_, %v), want error? %t", tc.name, err, tc.wantErr)
  413. continue
  414. }
  415. if tc.wantNotFound && err != preg.NotFound {
  416. t.Errorf("FindEnumByName(%v) got error: %v, want NotFound error", tc.name, err)
  417. continue
  418. }
  419. if got != tc.enumType {
  420. t.Errorf("FindEnumByName(%v) got wrong value: %v", tc.name, got)
  421. }
  422. }
  423. })
  424. t.Run("FindExtensionByName", func(t *testing.T) {
  425. tests := []struct {
  426. name string
  427. extensionType pref.ExtensionType
  428. wantErr bool
  429. wantNotFound bool
  430. }{{
  431. name: "testprotos.string_field",
  432. extensionType: xt1,
  433. }, {
  434. name: "testprotos.Message4.message_field",
  435. extensionType: xt2,
  436. }, {
  437. name: "testprotos.None",
  438. wantErr: true,
  439. wantNotFound: true,
  440. }, {
  441. name: "testprotos.Message1",
  442. wantErr: true,
  443. }}
  444. for _, tc := range tests {
  445. got, err := registry.FindExtensionByName(pref.FullName(tc.name))
  446. gotErr := err != nil
  447. if gotErr != tc.wantErr {
  448. t.Errorf("FindExtensionByName(%v) = (_, %v), want error? %t", tc.name, err, tc.wantErr)
  449. continue
  450. }
  451. if tc.wantNotFound && err != preg.NotFound {
  452. t.Errorf("FindExtensionByName(%v) got error: %v, want NotFound error", tc.name, err)
  453. continue
  454. }
  455. if got != tc.extensionType {
  456. t.Errorf("FindExtensionByName(%v) got wrong value: %v", tc.name, got)
  457. }
  458. }
  459. })
  460. t.Run("FindExtensionByNumber", func(t *testing.T) {
  461. tests := []struct {
  462. parent string
  463. number int32
  464. extensionType pref.ExtensionType
  465. wantErr bool
  466. wantNotFound bool
  467. }{{
  468. parent: "testprotos.Message1",
  469. number: 11,
  470. extensionType: xt1,
  471. }, {
  472. parent: "testprotos.Message1",
  473. number: 13,
  474. wantErr: true,
  475. wantNotFound: true,
  476. }, {
  477. parent: "testprotos.Message1",
  478. number: 21,
  479. extensionType: xt2,
  480. }, {
  481. parent: "testprotos.Message1",
  482. number: 23,
  483. wantErr: true,
  484. wantNotFound: true,
  485. }, {
  486. parent: "testprotos.NoSuchMessage",
  487. number: 11,
  488. wantErr: true,
  489. wantNotFound: true,
  490. }, {
  491. parent: "testprotos.Message1",
  492. number: 30,
  493. wantErr: true,
  494. wantNotFound: true,
  495. }, {
  496. parent: "testprotos.Message1",
  497. number: 99,
  498. wantErr: true,
  499. wantNotFound: true,
  500. }}
  501. for _, tc := range tests {
  502. got, err := registry.FindExtensionByNumber(pref.FullName(tc.parent), pref.FieldNumber(tc.number))
  503. gotErr := err != nil
  504. if gotErr != tc.wantErr {
  505. t.Errorf("FindExtensionByNumber(%v, %d) = (_, %v), want error? %t", tc.parent, tc.number, err, tc.wantErr)
  506. continue
  507. }
  508. if tc.wantNotFound && err != preg.NotFound {
  509. t.Errorf("FindExtensionByNumber(%v, %d) got error %v, want NotFound error", tc.parent, tc.number, err)
  510. continue
  511. }
  512. if got != tc.extensionType {
  513. t.Errorf("FindExtensionByNumber(%v, %d) got wrong value: %v", tc.parent, tc.number, got)
  514. }
  515. }
  516. })
  517. sortTypes := cmp.Options{
  518. cmpopts.SortSlices(func(x, y pref.EnumType) bool {
  519. return x.Descriptor().FullName() < y.Descriptor().FullName()
  520. }),
  521. cmpopts.SortSlices(func(x, y pref.MessageType) bool {
  522. return x.Descriptor().FullName() < y.Descriptor().FullName()
  523. }),
  524. cmpopts.SortSlices(func(x, y pref.ExtensionType) bool {
  525. return x.TypeDescriptor().FullName() < y.TypeDescriptor().FullName()
  526. }),
  527. }
  528. compare := cmp.Options{
  529. cmp.Comparer(func(x, y pref.EnumType) bool {
  530. return x == y
  531. }),
  532. cmp.Comparer(func(x, y pref.ExtensionType) bool {
  533. return x == y
  534. }),
  535. cmp.Comparer(func(x, y pref.MessageType) bool {
  536. return x == y
  537. }),
  538. }
  539. t.Run("RangeEnums", func(t *testing.T) {
  540. want := []pref.EnumType{et1}
  541. var got []pref.EnumType
  542. var gotCnt int
  543. wantCnt := registry.NumEnums()
  544. registry.RangeEnums(func(et pref.EnumType) bool {
  545. got = append(got, et)
  546. gotCnt++
  547. return true
  548. })
  549. if gotCnt != wantCnt {
  550. t.Errorf("NumEnums() = %v, want %v", gotCnt, wantCnt)
  551. }
  552. if diff := cmp.Diff(want, got, sortTypes, compare); diff != "" {
  553. t.Errorf("RangeEnums() mismatch (-want +got):\n%v", diff)
  554. }
  555. })
  556. t.Run("RangeMessages", func(t *testing.T) {
  557. want := []pref.MessageType{mt1}
  558. var got []pref.MessageType
  559. var gotCnt int
  560. wantCnt := registry.NumMessages()
  561. registry.RangeMessages(func(mt pref.MessageType) bool {
  562. got = append(got, mt)
  563. gotCnt++
  564. return true
  565. })
  566. if gotCnt != wantCnt {
  567. t.Errorf("NumMessages() = %v, want %v", gotCnt, wantCnt)
  568. }
  569. if diff := cmp.Diff(want, got, sortTypes, compare); diff != "" {
  570. t.Errorf("RangeMessages() mismatch (-want +got):\n%v", diff)
  571. }
  572. })
  573. t.Run("RangeExtensions", func(t *testing.T) {
  574. want := []pref.ExtensionType{xt1, xt2}
  575. var got []pref.ExtensionType
  576. var gotCnt int
  577. wantCnt := registry.NumExtensions()
  578. registry.RangeExtensions(func(xt pref.ExtensionType) bool {
  579. got = append(got, xt)
  580. gotCnt++
  581. return true
  582. })
  583. if gotCnt != wantCnt {
  584. t.Errorf("NumExtensions() = %v, want %v", gotCnt, wantCnt)
  585. }
  586. if diff := cmp.Diff(want, got, sortTypes, compare); diff != "" {
  587. t.Errorf("RangeExtensions() mismatch (-want +got):\n%v", diff)
  588. }
  589. })
  590. t.Run("RangeExtensionsByMessage", func(t *testing.T) {
  591. want := []pref.ExtensionType{xt1, xt2}
  592. var got []pref.ExtensionType
  593. var gotCnt int
  594. wantCnt := registry.NumExtensionsByMessage("testprotos.Message1")
  595. registry.RangeExtensionsByMessage("testprotos.Message1", func(xt pref.ExtensionType) bool {
  596. got = append(got, xt)
  597. gotCnt++
  598. return true
  599. })
  600. if gotCnt != wantCnt {
  601. t.Errorf("NumExtensionsByMessage() = %v, want %v", gotCnt, wantCnt)
  602. }
  603. if diff := cmp.Diff(want, got, sortTypes, compare); diff != "" {
  604. t.Errorf("RangeExtensionsByMessage() mismatch (-want +got):\n%v", diff)
  605. }
  606. })
  607. }