merge_test.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441
  1. // Copyright 2019 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package proto_test
  5. import (
  6. "sync"
  7. "testing"
  8. "google.golang.org/protobuf/internal/encoding/pack"
  9. "google.golang.org/protobuf/proto"
  10. testpb "google.golang.org/protobuf/internal/testprotos/test"
  11. )
  12. func TestMerge(t *testing.T) {
  13. dst := new(testpb.TestAllTypes)
  14. src := (*testpb.TestAllTypes)(nil)
  15. proto.Merge(dst, src)
  16. // Mutating the source should not affect dst.
  17. tests := []struct {
  18. desc string
  19. dst proto.Message
  20. src proto.Message
  21. want proto.Message
  22. mutator func(proto.Message) // if provided, is run on src after merging
  23. }{{
  24. desc: "merge from nil message",
  25. dst: new(testpb.TestAllTypes),
  26. src: (*testpb.TestAllTypes)(nil),
  27. want: new(testpb.TestAllTypes),
  28. }, {
  29. desc: "clone a large message",
  30. dst: new(testpb.TestAllTypes),
  31. src: &testpb.TestAllTypes{
  32. OptionalInt64: proto.Int64(0),
  33. OptionalNestedEnum: testpb.TestAllTypes_NestedEnum(1).Enum(),
  34. OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{
  35. A: proto.Int32(100),
  36. },
  37. RepeatedSfixed32: []int32{1, 2, 3},
  38. RepeatedNestedMessage: []*testpb.TestAllTypes_NestedMessage{
  39. {A: proto.Int32(200)},
  40. {A: proto.Int32(300)},
  41. },
  42. MapStringNestedEnum: map[string]testpb.TestAllTypes_NestedEnum{
  43. "fizz": 400,
  44. "buzz": 500,
  45. },
  46. MapStringNestedMessage: map[string]*testpb.TestAllTypes_NestedMessage{
  47. "foo": {A: proto.Int32(600)},
  48. "bar": {A: proto.Int32(700)},
  49. },
  50. OneofField: &testpb.TestAllTypes_OneofNestedMessage{
  51. &testpb.TestAllTypes_NestedMessage{
  52. A: proto.Int32(800),
  53. },
  54. },
  55. },
  56. want: &testpb.TestAllTypes{
  57. OptionalInt64: proto.Int64(0),
  58. OptionalNestedEnum: testpb.TestAllTypes_NestedEnum(1).Enum(),
  59. OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{
  60. A: proto.Int32(100),
  61. },
  62. RepeatedSfixed32: []int32{1, 2, 3},
  63. RepeatedNestedMessage: []*testpb.TestAllTypes_NestedMessage{
  64. {A: proto.Int32(200)},
  65. {A: proto.Int32(300)},
  66. },
  67. MapStringNestedEnum: map[string]testpb.TestAllTypes_NestedEnum{
  68. "fizz": 400,
  69. "buzz": 500,
  70. },
  71. MapStringNestedMessage: map[string]*testpb.TestAllTypes_NestedMessage{
  72. "foo": {A: proto.Int32(600)},
  73. "bar": {A: proto.Int32(700)},
  74. },
  75. OneofField: &testpb.TestAllTypes_OneofNestedMessage{
  76. &testpb.TestAllTypes_NestedMessage{
  77. A: proto.Int32(800),
  78. },
  79. },
  80. },
  81. mutator: func(mi proto.Message) {
  82. m := mi.(*testpb.TestAllTypes)
  83. *m.OptionalInt64++
  84. *m.OptionalNestedEnum++
  85. *m.OptionalNestedMessage.A++
  86. m.RepeatedSfixed32[0]++
  87. *m.RepeatedNestedMessage[0].A++
  88. delete(m.MapStringNestedEnum, "fizz")
  89. *m.MapStringNestedMessage["foo"].A++
  90. *m.OneofField.(*testpb.TestAllTypes_OneofNestedMessage).OneofNestedMessage.A++
  91. },
  92. }, {
  93. desc: "merge bytes",
  94. dst: &testpb.TestAllTypes{
  95. OptionalBytes: []byte{1, 2, 3},
  96. RepeatedBytes: [][]byte{{1, 2}, {3, 4}},
  97. MapStringBytes: map[string][]byte{"alpha": {1, 2, 3}},
  98. },
  99. src: &testpb.TestAllTypes{
  100. OptionalBytes: []byte{4, 5, 6},
  101. RepeatedBytes: [][]byte{{5, 6}, {7, 8}},
  102. MapStringBytes: map[string][]byte{"alpha": {4, 5, 6}, "bravo": {1, 2, 3}},
  103. },
  104. want: &testpb.TestAllTypes{
  105. OptionalBytes: []byte{4, 5, 6},
  106. RepeatedBytes: [][]byte{{1, 2}, {3, 4}, {5, 6}, {7, 8}},
  107. MapStringBytes: map[string][]byte{"alpha": {4, 5, 6}, "bravo": {1, 2, 3}},
  108. },
  109. mutator: func(mi proto.Message) {
  110. m := mi.(*testpb.TestAllTypes)
  111. m.OptionalBytes[0]++
  112. m.RepeatedBytes[0][0]++
  113. m.MapStringBytes["alpha"][0]++
  114. },
  115. }, {
  116. desc: "merge singular fields",
  117. dst: &testpb.TestAllTypes{
  118. OptionalInt32: proto.Int32(1),
  119. OptionalInt64: proto.Int64(1),
  120. OptionalNestedEnum: testpb.TestAllTypes_NestedEnum(10).Enum(),
  121. OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{
  122. A: proto.Int32(100),
  123. Corecursive: &testpb.TestAllTypes{
  124. OptionalInt64: proto.Int64(1000),
  125. },
  126. },
  127. },
  128. src: &testpb.TestAllTypes{
  129. OptionalInt64: proto.Int64(2),
  130. OptionalNestedEnum: testpb.TestAllTypes_NestedEnum(20).Enum(),
  131. OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{
  132. A: proto.Int32(200),
  133. },
  134. },
  135. want: &testpb.TestAllTypes{
  136. OptionalInt32: proto.Int32(1),
  137. OptionalInt64: proto.Int64(2),
  138. OptionalNestedEnum: testpb.TestAllTypes_NestedEnum(20).Enum(),
  139. OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{
  140. A: proto.Int32(200),
  141. Corecursive: &testpb.TestAllTypes{
  142. OptionalInt64: proto.Int64(1000),
  143. },
  144. },
  145. },
  146. mutator: func(mi proto.Message) {
  147. m := mi.(*testpb.TestAllTypes)
  148. *m.OptionalInt64++
  149. *m.OptionalNestedEnum++
  150. *m.OptionalNestedMessage.A++
  151. },
  152. }, {
  153. desc: "merge list fields",
  154. dst: &testpb.TestAllTypes{
  155. RepeatedSfixed32: []int32{1, 2, 3},
  156. RepeatedNestedMessage: []*testpb.TestAllTypes_NestedMessage{
  157. {A: proto.Int32(100)},
  158. {A: proto.Int32(200)},
  159. },
  160. },
  161. src: &testpb.TestAllTypes{
  162. RepeatedSfixed32: []int32{4, 5, 6},
  163. RepeatedNestedMessage: []*testpb.TestAllTypes_NestedMessage{
  164. {A: proto.Int32(300)},
  165. {A: proto.Int32(400)},
  166. },
  167. },
  168. want: &testpb.TestAllTypes{
  169. RepeatedSfixed32: []int32{1, 2, 3, 4, 5, 6},
  170. RepeatedNestedMessage: []*testpb.TestAllTypes_NestedMessage{
  171. {A: proto.Int32(100)},
  172. {A: proto.Int32(200)},
  173. {A: proto.Int32(300)},
  174. {A: proto.Int32(400)},
  175. },
  176. },
  177. mutator: func(mi proto.Message) {
  178. m := mi.(*testpb.TestAllTypes)
  179. m.RepeatedSfixed32[0]++
  180. *m.RepeatedNestedMessage[0].A++
  181. },
  182. }, {
  183. desc: "merge map fields",
  184. dst: &testpb.TestAllTypes{
  185. MapStringNestedEnum: map[string]testpb.TestAllTypes_NestedEnum{
  186. "fizz": 100,
  187. "buzz": 200,
  188. "guzz": 300,
  189. },
  190. MapStringNestedMessage: map[string]*testpb.TestAllTypes_NestedMessage{
  191. "foo": {A: proto.Int32(400)},
  192. },
  193. },
  194. src: &testpb.TestAllTypes{
  195. MapStringNestedEnum: map[string]testpb.TestAllTypes_NestedEnum{
  196. "fizz": 1000,
  197. "buzz": 2000,
  198. },
  199. MapStringNestedMessage: map[string]*testpb.TestAllTypes_NestedMessage{
  200. "foo": {A: proto.Int32(3000)},
  201. "bar": {},
  202. },
  203. },
  204. want: &testpb.TestAllTypes{
  205. MapStringNestedEnum: map[string]testpb.TestAllTypes_NestedEnum{
  206. "fizz": 1000,
  207. "buzz": 2000,
  208. "guzz": 300,
  209. },
  210. MapStringNestedMessage: map[string]*testpb.TestAllTypes_NestedMessage{
  211. "foo": {A: proto.Int32(3000)},
  212. "bar": {},
  213. },
  214. },
  215. mutator: func(mi proto.Message) {
  216. m := mi.(*testpb.TestAllTypes)
  217. delete(m.MapStringNestedEnum, "fizz")
  218. m.MapStringNestedMessage["bar"].A = proto.Int32(1)
  219. },
  220. }, {
  221. desc: "merge oneof message fields",
  222. dst: &testpb.TestAllTypes{
  223. OneofField: &testpb.TestAllTypes_OneofNestedMessage{
  224. &testpb.TestAllTypes_NestedMessage{
  225. A: proto.Int32(100),
  226. },
  227. },
  228. },
  229. src: &testpb.TestAllTypes{
  230. OneofField: &testpb.TestAllTypes_OneofNestedMessage{
  231. &testpb.TestAllTypes_NestedMessage{
  232. Corecursive: &testpb.TestAllTypes{
  233. OptionalInt64: proto.Int64(1000),
  234. },
  235. },
  236. },
  237. },
  238. want: &testpb.TestAllTypes{
  239. OneofField: &testpb.TestAllTypes_OneofNestedMessage{
  240. &testpb.TestAllTypes_NestedMessage{
  241. A: proto.Int32(100),
  242. Corecursive: &testpb.TestAllTypes{
  243. OptionalInt64: proto.Int64(1000),
  244. },
  245. },
  246. },
  247. },
  248. mutator: func(mi proto.Message) {
  249. m := mi.(*testpb.TestAllTypes)
  250. *m.OneofField.(*testpb.TestAllTypes_OneofNestedMessage).OneofNestedMessage.Corecursive.OptionalInt64++
  251. },
  252. }, {
  253. desc: "merge oneof scalar fields",
  254. dst: &testpb.TestAllTypes{
  255. OneofField: &testpb.TestAllTypes_OneofUint32{100},
  256. },
  257. src: &testpb.TestAllTypes{
  258. OneofField: &testpb.TestAllTypes_OneofFloat{3.14152},
  259. },
  260. want: &testpb.TestAllTypes{
  261. OneofField: &testpb.TestAllTypes_OneofFloat{3.14152},
  262. },
  263. mutator: func(mi proto.Message) {
  264. m := mi.(*testpb.TestAllTypes)
  265. m.OneofField.(*testpb.TestAllTypes_OneofFloat).OneofFloat++
  266. },
  267. }, {
  268. desc: "merge extension fields",
  269. dst: func() proto.Message {
  270. m := new(testpb.TestAllExtensions)
  271. proto.SetExtension(m, testpb.E_OptionalInt32Extension, int32(32))
  272. proto.SetExtension(m, testpb.E_OptionalNestedMessageExtension,
  273. &testpb.TestAllTypes_NestedMessage{
  274. A: proto.Int32(50),
  275. },
  276. )
  277. proto.SetExtension(m, testpb.E_RepeatedFixed32Extension, []uint32{1, 2, 3})
  278. return m
  279. }(),
  280. src: func() proto.Message {
  281. m := new(testpb.TestAllExtensions)
  282. proto.SetExtension(m, testpb.E_OptionalInt64Extension, int64(64))
  283. proto.SetExtension(m, testpb.E_OptionalNestedMessageExtension,
  284. &testpb.TestAllTypes_NestedMessage{
  285. Corecursive: &testpb.TestAllTypes{
  286. OptionalInt64: proto.Int64(1000),
  287. },
  288. },
  289. )
  290. proto.SetExtension(m, testpb.E_RepeatedFixed32Extension, []uint32{4, 5, 6})
  291. return m
  292. }(),
  293. want: func() proto.Message {
  294. m := new(testpb.TestAllExtensions)
  295. proto.SetExtension(m, testpb.E_OptionalInt32Extension, int32(32))
  296. proto.SetExtension(m, testpb.E_OptionalInt64Extension, int64(64))
  297. proto.SetExtension(m, testpb.E_OptionalNestedMessageExtension,
  298. &testpb.TestAllTypes_NestedMessage{
  299. A: proto.Int32(50),
  300. Corecursive: &testpb.TestAllTypes{
  301. OptionalInt64: proto.Int64(1000),
  302. },
  303. },
  304. )
  305. proto.SetExtension(m, testpb.E_RepeatedFixed32Extension, []uint32{1, 2, 3, 4, 5, 6})
  306. return m
  307. }(),
  308. }, {
  309. desc: "merge unknown fields",
  310. dst: func() proto.Message {
  311. m := new(testpb.TestAllTypes)
  312. m.ProtoReflect().SetUnknown(pack.Message{
  313. pack.Tag{Number: 50000, Type: pack.VarintType}, pack.Svarint(-5),
  314. }.Marshal())
  315. return m
  316. }(),
  317. src: func() proto.Message {
  318. m := new(testpb.TestAllTypes)
  319. m.ProtoReflect().SetUnknown(pack.Message{
  320. pack.Tag{Number: 500000, Type: pack.VarintType}, pack.Svarint(-50),
  321. }.Marshal())
  322. return m
  323. }(),
  324. want: func() proto.Message {
  325. m := new(testpb.TestAllTypes)
  326. m.ProtoReflect().SetUnknown(pack.Message{
  327. pack.Tag{Number: 50000, Type: pack.VarintType}, pack.Svarint(-5),
  328. pack.Tag{Number: 500000, Type: pack.VarintType}, pack.Svarint(-50),
  329. }.Marshal())
  330. return m
  331. }(),
  332. }}
  333. for _, tt := range tests {
  334. t.Run(tt.desc, func(t *testing.T) {
  335. // Merge should be semantically equivalent to unmarshaling the
  336. // encoded form of src into the current dst.
  337. b1, err := proto.MarshalOptions{AllowPartial: true}.Marshal(tt.dst)
  338. if err != nil {
  339. t.Fatalf("Marshal(dst) error: %v", err)
  340. }
  341. b2, err := proto.MarshalOptions{AllowPartial: true}.Marshal(tt.src)
  342. if err != nil {
  343. t.Fatalf("Marshal(src) error: %v", err)
  344. }
  345. dst := tt.dst.ProtoReflect().New().Interface()
  346. err = proto.UnmarshalOptions{AllowPartial: true}.Unmarshal(append(b1, b2...), dst)
  347. if err != nil {
  348. t.Fatalf("Unmarshal() error: %v", err)
  349. }
  350. if !proto.Equal(dst, tt.want) {
  351. t.Fatalf("Unmarshal(Marshal(dst)+Marshal(src)) mismatch: got %v, want %v", dst, tt.want)
  352. }
  353. proto.Merge(tt.dst, tt.src)
  354. if tt.mutator != nil {
  355. tt.mutator(tt.src) // should not be observable by dst
  356. }
  357. if !proto.Equal(tt.dst, tt.want) {
  358. t.Fatalf("Merge() mismatch:\n got %v\nwant %v", tt.dst, tt.want)
  359. }
  360. })
  361. }
  362. }
  363. func TestMergeRace(t *testing.T) {
  364. dst := new(testpb.TestAllTypes)
  365. srcs := []*testpb.TestAllTypes{
  366. {OptionalInt32: proto.Int32(1)},
  367. {OptionalString: proto.String("hello")},
  368. {RepeatedInt32: []int32{2, 3, 4}},
  369. {RepeatedString: []string{"goodbye"}},
  370. {MapStringString: map[string]string{"key": "value"}},
  371. {OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{
  372. A: proto.Int32(5),
  373. }},
  374. func() *testpb.TestAllTypes {
  375. m := new(testpb.TestAllTypes)
  376. m.ProtoReflect().SetUnknown(pack.Message{
  377. pack.Tag{Number: 50000, Type: pack.VarintType}, pack.Svarint(-5),
  378. }.Marshal())
  379. return m
  380. }(),
  381. }
  382. // It should be safe to concurrently merge non-overlapping fields.
  383. var wg sync.WaitGroup
  384. defer wg.Wait()
  385. for _, src := range srcs {
  386. wg.Add(1)
  387. go func(src proto.Message) {
  388. defer wg.Done()
  389. proto.Merge(dst, src)
  390. }(src)
  391. }
  392. }
  393. func TestMergeSelf(t *testing.T) {
  394. got := &testpb.TestAllTypes{
  395. OptionalInt32: proto.Int32(1),
  396. OptionalString: proto.String("hello"),
  397. RepeatedInt32: []int32{2, 3, 4},
  398. RepeatedString: []string{"goodbye"},
  399. MapStringString: map[string]string{"key": "value"},
  400. OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{
  401. A: proto.Int32(5),
  402. },
  403. }
  404. got.ProtoReflect().SetUnknown(pack.Message{
  405. pack.Tag{Number: 50000, Type: pack.VarintType}, pack.Svarint(-5),
  406. }.Marshal())
  407. proto.Merge(got, got)
  408. // The main impact of merging to self is that repeated fields and
  409. // unknown fields are doubled.
  410. want := &testpb.TestAllTypes{
  411. OptionalInt32: proto.Int32(1),
  412. OptionalString: proto.String("hello"),
  413. RepeatedInt32: []int32{2, 3, 4, 2, 3, 4},
  414. RepeatedString: []string{"goodbye", "goodbye"},
  415. MapStringString: map[string]string{"key": "value"},
  416. OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{
  417. A: proto.Int32(5),
  418. },
  419. }
  420. want.ProtoReflect().SetUnknown(pack.Message{
  421. pack.Tag{Number: 50000, Type: pack.VarintType}, pack.Svarint(-5),
  422. pack.Tag{Number: 50000, Type: pack.VarintType}, pack.Svarint(-5),
  423. }.Marshal())
  424. if !proto.Equal(got, want) {
  425. t.Errorf("Equal mismatch:\ngot %v\nwant %v", got, want)
  426. }
  427. }