extensions_test.go 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665
  1. // Copyright 2014 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package proto_test
  5. import (
  6. "bytes"
  7. "fmt"
  8. "io"
  9. "reflect"
  10. "sort"
  11. "strings"
  12. "sync"
  13. "testing"
  14. "github.com/golang/protobuf/proto"
  15. pb "github.com/golang/protobuf/proto/test_proto"
  16. )
  17. func TestGetExtensionsWithMissingExtensions(t *testing.T) {
  18. msg := &pb.MyMessage{}
  19. ext1 := &pb.Ext{}
  20. if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil {
  21. t.Fatalf("Could not set ext1: %s", err)
  22. }
  23. exts, err := proto.GetExtensions(msg, []*proto.ExtensionDesc{
  24. pb.E_Ext_More,
  25. pb.E_Ext_Text,
  26. })
  27. if err != nil {
  28. t.Fatalf("GetExtensions() failed: %s", err)
  29. }
  30. if exts[0] != ext1 {
  31. t.Errorf("ext1 not in returned extensions: %T %v", exts[0], exts[0])
  32. }
  33. if exts[1] != nil {
  34. t.Errorf("ext2 in returned extensions: %T %v", exts[1], exts[1])
  35. }
  36. }
  37. func TestGetExtensionWithEmptyBuffer(t *testing.T) {
  38. // Make sure that GetExtension returns an error if its
  39. // undecoded buffer is empty.
  40. msg := &pb.MyMessage{}
  41. proto.SetRawExtension(msg, pb.E_Ext_More.Field, []byte{})
  42. _, err := proto.GetExtension(msg, pb.E_Ext_More)
  43. if want := io.ErrUnexpectedEOF; err != want {
  44. t.Errorf("unexpected error in GetExtension from empty buffer: got %v, want %v", err, want)
  45. }
  46. }
  47. func TestGetExtensionForIncompleteDesc(t *testing.T) {
  48. msg := &pb.MyMessage{Count: proto.Int32(0)}
  49. extdesc1 := &proto.ExtensionDesc{
  50. ExtendedType: (*pb.MyMessage)(nil),
  51. ExtensionType: (*bool)(nil),
  52. Field: 123456789,
  53. Name: "a.b",
  54. Tag: "varint,123456789,opt",
  55. }
  56. ext1 := proto.Bool(true)
  57. if err := proto.SetExtension(msg, extdesc1, ext1); err != nil {
  58. t.Fatalf("Could not set ext1: %s", err)
  59. }
  60. extdesc2 := &proto.ExtensionDesc{
  61. ExtendedType: (*pb.MyMessage)(nil),
  62. ExtensionType: ([]byte)(nil),
  63. Field: 123456790,
  64. Name: "a.c",
  65. Tag: "bytes,123456790,opt",
  66. }
  67. ext2 := []byte{0, 1, 2, 3, 4, 5, 6, 7}
  68. if err := proto.SetExtension(msg, extdesc2, ext2); err != nil {
  69. t.Fatalf("Could not set ext2: %s", err)
  70. }
  71. extdesc3 := &proto.ExtensionDesc{
  72. ExtendedType: (*pb.MyMessage)(nil),
  73. ExtensionType: (*pb.Ext)(nil),
  74. Field: 123456791,
  75. Name: "a.d",
  76. Tag: "bytes,123456791,opt",
  77. }
  78. ext3 := &pb.Ext{Data: proto.String("foo")}
  79. if err := proto.SetExtension(msg, extdesc3, ext3); err != nil {
  80. t.Fatalf("Could not set ext3: %s", err)
  81. }
  82. b, err := proto.Marshal(msg)
  83. if err != nil {
  84. t.Fatalf("Could not marshal msg: %v", err)
  85. }
  86. if err := proto.Unmarshal(b, msg); err != nil {
  87. t.Fatalf("Could not unmarshal into msg: %v", err)
  88. }
  89. var expected proto.Buffer
  90. if err := expected.EncodeVarint(uint64((extdesc1.Field << 3) | proto.WireVarint)); err != nil {
  91. t.Fatalf("failed to compute expected prefix for ext1: %s", err)
  92. }
  93. if err := expected.EncodeVarint(1 /* bool true */); err != nil {
  94. t.Fatalf("failed to compute expected value for ext1: %s", err)
  95. }
  96. if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc1.Field}); err != nil {
  97. t.Fatalf("Failed to get raw value for ext1: %s", err)
  98. } else if !reflect.DeepEqual(b, expected.Bytes()) {
  99. t.Fatalf("Raw value for ext1: got %v, want %v", b, expected.Bytes())
  100. }
  101. expected = proto.Buffer{} // reset
  102. if err := expected.EncodeVarint(uint64((extdesc2.Field << 3) | proto.WireBytes)); err != nil {
  103. t.Fatalf("failed to compute expected prefix for ext2: %s", err)
  104. }
  105. if err := expected.EncodeRawBytes(ext2); err != nil {
  106. t.Fatalf("failed to compute expected value for ext2: %s", err)
  107. }
  108. if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc2.Field}); err != nil {
  109. t.Fatalf("Failed to get raw value for ext2: %s", err)
  110. } else if !reflect.DeepEqual(b, expected.Bytes()) {
  111. t.Fatalf("Raw value for ext2: got %v, want %v", b, expected.Bytes())
  112. }
  113. expected = proto.Buffer{} // reset
  114. if err := expected.EncodeVarint(uint64((extdesc3.Field << 3) | proto.WireBytes)); err != nil {
  115. t.Fatalf("failed to compute expected prefix for ext3: %s", err)
  116. }
  117. if b, err := proto.Marshal(ext3); err != nil {
  118. t.Fatalf("failed to compute expected value for ext3: %s", err)
  119. } else if err := expected.EncodeRawBytes(b); err != nil {
  120. t.Fatalf("failed to compute expected value for ext3: %s", err)
  121. }
  122. if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc3.Field}); err != nil {
  123. t.Fatalf("Failed to get raw value for ext3: %s", err)
  124. } else if !reflect.DeepEqual(b, expected.Bytes()) {
  125. t.Fatalf("Raw value for ext3: got %v, want %v", b, expected.Bytes())
  126. }
  127. }
  128. func TestExtensionDescsWithUnregisteredExtensions(t *testing.T) {
  129. msg := &pb.MyMessage{Count: proto.Int32(0)}
  130. extdesc1 := pb.E_Ext_More
  131. if descs, err := proto.ExtensionDescs(msg); len(descs) != 0 || err != nil {
  132. t.Errorf("proto.ExtensionDescs: got %d descs, error %v; want 0, nil", len(descs), err)
  133. }
  134. ext1 := &pb.Ext{}
  135. if err := proto.SetExtension(msg, extdesc1, ext1); err != nil {
  136. t.Fatalf("Could not set ext1: %s", err)
  137. }
  138. extdesc2 := &proto.ExtensionDesc{
  139. ExtendedType: (*pb.MyMessage)(nil),
  140. ExtensionType: (*bool)(nil),
  141. Field: 123456789,
  142. Name: "a.b",
  143. Tag: "varint,123456789,opt",
  144. }
  145. ext2 := proto.Bool(false)
  146. if err := proto.SetExtension(msg, extdesc2, ext2); err != nil {
  147. t.Fatalf("Could not set ext2: %s", err)
  148. }
  149. b, err := proto.Marshal(msg)
  150. if err != nil {
  151. t.Fatalf("Could not marshal msg: %v", err)
  152. }
  153. if err := proto.Unmarshal(b, msg); err != nil {
  154. t.Fatalf("Could not unmarshal into msg: %v", err)
  155. }
  156. descs, err := proto.ExtensionDescs(msg)
  157. if err != nil {
  158. t.Fatalf("proto.ExtensionDescs: got error %v", err)
  159. }
  160. sortExtDescs(descs)
  161. wantDescs := []*proto.ExtensionDesc{extdesc1, {Field: extdesc2.Field}}
  162. if !reflect.DeepEqual(descs, wantDescs) {
  163. t.Errorf("proto.ExtensionDescs(msg) sorted extension ids: got %+v, want %+v", descs, wantDescs)
  164. }
  165. }
  166. type ExtensionDescSlice []*proto.ExtensionDesc
  167. func (s ExtensionDescSlice) Len() int { return len(s) }
  168. func (s ExtensionDescSlice) Less(i, j int) bool { return s[i].Field < s[j].Field }
  169. func (s ExtensionDescSlice) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
  170. func sortExtDescs(s []*proto.ExtensionDesc) {
  171. sort.Sort(ExtensionDescSlice(s))
  172. }
  173. func TestGetExtensionStability(t *testing.T) {
  174. check := func(m *pb.MyMessage) bool {
  175. ext1, err := proto.GetExtension(m, pb.E_Ext_More)
  176. if err != nil {
  177. t.Fatalf("GetExtension() failed: %s", err)
  178. }
  179. ext2, err := proto.GetExtension(m, pb.E_Ext_More)
  180. if err != nil {
  181. t.Fatalf("GetExtension() failed: %s", err)
  182. }
  183. return ext1 == ext2
  184. }
  185. msg := &pb.MyMessage{Count: proto.Int32(4)}
  186. ext0 := &pb.Ext{}
  187. if err := proto.SetExtension(msg, pb.E_Ext_More, ext0); err != nil {
  188. t.Fatalf("Could not set ext1: %s", ext0)
  189. }
  190. if !check(msg) {
  191. t.Errorf("GetExtension() not stable before marshaling")
  192. }
  193. bb, err := proto.Marshal(msg)
  194. if err != nil {
  195. t.Fatalf("Marshal() failed: %s", err)
  196. }
  197. msg1 := &pb.MyMessage{}
  198. err = proto.Unmarshal(bb, msg1)
  199. if err != nil {
  200. t.Fatalf("Unmarshal() failed: %s", err)
  201. }
  202. if !check(msg1) {
  203. t.Errorf("GetExtension() not stable after unmarshaling")
  204. }
  205. }
  206. func TestGetExtensionDefaults(t *testing.T) {
  207. var setFloat64 float64 = 1
  208. var setFloat32 float32 = 2
  209. var setInt32 int32 = 3
  210. var setInt64 int64 = 4
  211. var setUint32 uint32 = 5
  212. var setUint64 uint64 = 6
  213. var setBool = true
  214. var setBool2 = false
  215. var setString = "Goodnight string"
  216. var setBytes = []byte("Goodnight bytes")
  217. var setEnum = pb.DefaultsMessage_TWO
  218. type testcase struct {
  219. ext *proto.ExtensionDesc // Extension we are testing.
  220. want interface{} // Expected value of extension, or nil (meaning that GetExtension will fail).
  221. def interface{} // Expected value of extension after ClearExtension().
  222. }
  223. tests := []testcase{
  224. {pb.E_NoDefaultDouble, setFloat64, nil},
  225. {pb.E_NoDefaultFloat, setFloat32, nil},
  226. {pb.E_NoDefaultInt32, setInt32, nil},
  227. {pb.E_NoDefaultInt64, setInt64, nil},
  228. {pb.E_NoDefaultUint32, setUint32, nil},
  229. {pb.E_NoDefaultUint64, setUint64, nil},
  230. {pb.E_NoDefaultSint32, setInt32, nil},
  231. {pb.E_NoDefaultSint64, setInt64, nil},
  232. {pb.E_NoDefaultFixed32, setUint32, nil},
  233. {pb.E_NoDefaultFixed64, setUint64, nil},
  234. {pb.E_NoDefaultSfixed32, setInt32, nil},
  235. {pb.E_NoDefaultSfixed64, setInt64, nil},
  236. {pb.E_NoDefaultBool, setBool, nil},
  237. {pb.E_NoDefaultBool, setBool2, nil},
  238. {pb.E_NoDefaultString, setString, nil},
  239. {pb.E_NoDefaultBytes, setBytes, nil},
  240. {pb.E_NoDefaultEnum, setEnum, nil},
  241. {pb.E_DefaultDouble, setFloat64, float64(3.1415)},
  242. {pb.E_DefaultFloat, setFloat32, float32(3.14)},
  243. {pb.E_DefaultInt32, setInt32, int32(42)},
  244. {pb.E_DefaultInt64, setInt64, int64(43)},
  245. {pb.E_DefaultUint32, setUint32, uint32(44)},
  246. {pb.E_DefaultUint64, setUint64, uint64(45)},
  247. {pb.E_DefaultSint32, setInt32, int32(46)},
  248. {pb.E_DefaultSint64, setInt64, int64(47)},
  249. {pb.E_DefaultFixed32, setUint32, uint32(48)},
  250. {pb.E_DefaultFixed64, setUint64, uint64(49)},
  251. {pb.E_DefaultSfixed32, setInt32, int32(50)},
  252. {pb.E_DefaultSfixed64, setInt64, int64(51)},
  253. {pb.E_DefaultBool, setBool, true},
  254. {pb.E_DefaultBool, setBool2, true},
  255. {pb.E_DefaultString, setString, "Hello, string,def=foo"},
  256. {pb.E_DefaultBytes, setBytes, []byte("Hello, bytes")},
  257. {pb.E_DefaultEnum, setEnum, pb.DefaultsMessage_ONE},
  258. }
  259. checkVal := func(test testcase, msg *pb.DefaultsMessage, valWant interface{}) error {
  260. val, err := proto.GetExtension(msg, test.ext)
  261. if err != nil {
  262. if valWant != nil {
  263. return fmt.Errorf("GetExtension(): %s", err)
  264. }
  265. if want := proto.ErrMissingExtension; err != want {
  266. return fmt.Errorf("Unexpected error: got %v, want %v", err, want)
  267. }
  268. return nil
  269. }
  270. // All proto2 extension values are either a pointer to a value or a slice of values.
  271. ty := reflect.TypeOf(val)
  272. tyWant := reflect.TypeOf(test.ext.ExtensionType)
  273. if got, want := ty, tyWant; got != want {
  274. return fmt.Errorf("unexpected reflect.TypeOf(): got %v want %v", got, want)
  275. }
  276. tye := ty.Elem()
  277. tyeWant := tyWant.Elem()
  278. if got, want := tye, tyeWant; got != want {
  279. return fmt.Errorf("unexpected reflect.TypeOf().Elem(): got %v want %v", got, want)
  280. }
  281. // Check the name of the type of the value.
  282. // If it is an enum it will be type int32 with the name of the enum.
  283. if got, want := tye.Name(), tye.Name(); got != want {
  284. return fmt.Errorf("unexpected reflect.TypeOf().Elem().Name(): got %v want %v", got, want)
  285. }
  286. // Check that value is what we expect.
  287. // If we have a pointer in val, get the value it points to.
  288. valExp := val
  289. if ty.Kind() == reflect.Ptr {
  290. valExp = reflect.ValueOf(val).Elem().Interface()
  291. }
  292. if got, want := valExp, valWant; !reflect.DeepEqual(got, want) {
  293. return fmt.Errorf("unexpected reflect.DeepEqual(): got %v want %v", got, want)
  294. }
  295. return nil
  296. }
  297. setTo := func(test testcase) interface{} {
  298. setTo := reflect.ValueOf(test.want)
  299. if typ := reflect.TypeOf(test.ext.ExtensionType); typ.Kind() == reflect.Ptr {
  300. setTo = reflect.New(typ).Elem()
  301. setTo.Set(reflect.New(setTo.Type().Elem()))
  302. setTo.Elem().Set(reflect.ValueOf(test.want))
  303. }
  304. return setTo.Interface()
  305. }
  306. for _, test := range tests {
  307. msg := &pb.DefaultsMessage{}
  308. name := test.ext.Name
  309. // Check the initial value.
  310. if err := checkVal(test, msg, test.def); err != nil {
  311. t.Errorf("%s: %v", name, err)
  312. }
  313. // Set the per-type value and check value.
  314. name = fmt.Sprintf("%s (set to %T %v)", name, test.want, test.want)
  315. if err := proto.SetExtension(msg, test.ext, setTo(test)); err != nil {
  316. t.Errorf("%s: SetExtension(): %v", name, err)
  317. continue
  318. }
  319. if err := checkVal(test, msg, test.want); err != nil {
  320. t.Errorf("%s: %v", name, err)
  321. continue
  322. }
  323. // Set and check the value.
  324. name += " (cleared)"
  325. proto.ClearExtension(msg, test.ext)
  326. if err := checkVal(test, msg, test.def); err != nil {
  327. t.Errorf("%s: %v", name, err)
  328. }
  329. }
  330. }
  331. func TestNilMessage(t *testing.T) {
  332. name := "nil interface"
  333. if got, err := proto.GetExtension(nil, pb.E_Ext_More); err == nil {
  334. t.Errorf("%s: got %T %v, expected to fail", name, got, got)
  335. } else if !strings.Contains(err.Error(), "extendable") {
  336. t.Errorf("%s: got error %v, expected not-extendable error", name, err)
  337. }
  338. // Regression tests: all functions of the Extension API
  339. // used to panic when passed (*M)(nil), where M is a concrete message
  340. // type. Now they handle this gracefully as a no-op or reported error.
  341. var nilMsg *pb.MyMessage
  342. desc := pb.E_Ext_More
  343. isNotExtendable := func(err error) bool {
  344. return strings.Contains(fmt.Sprint(err), "not an extendable")
  345. }
  346. if proto.HasExtension(nilMsg, desc) {
  347. t.Error("HasExtension(nil) = true")
  348. }
  349. if _, err := proto.GetExtensions(nilMsg, []*proto.ExtensionDesc{desc}); !isNotExtendable(err) {
  350. t.Errorf("GetExtensions(nil) = %q (wrong error)", err)
  351. }
  352. if _, err := proto.ExtensionDescs(nilMsg); !isNotExtendable(err) {
  353. t.Errorf("ExtensionDescs(nil) = %q (wrong error)", err)
  354. }
  355. if err := proto.SetExtension(nilMsg, desc, nil); !isNotExtendable(err) {
  356. t.Errorf("SetExtension(nil) = %q (wrong error)", err)
  357. }
  358. proto.ClearExtension(nilMsg, desc) // no-op
  359. proto.ClearAllExtensions(nilMsg) // no-op
  360. }
  361. func TestExtensionsRoundTrip(t *testing.T) {
  362. msg := &pb.MyMessage{}
  363. ext1 := &pb.Ext{
  364. Data: proto.String("hi"),
  365. }
  366. ext2 := &pb.Ext{
  367. Data: proto.String("there"),
  368. }
  369. exists := proto.HasExtension(msg, pb.E_Ext_More)
  370. if exists {
  371. t.Error("Extension More present unexpectedly")
  372. }
  373. if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil {
  374. t.Error(err)
  375. }
  376. if err := proto.SetExtension(msg, pb.E_Ext_More, ext2); err != nil {
  377. t.Error(err)
  378. }
  379. e, err := proto.GetExtension(msg, pb.E_Ext_More)
  380. if err != nil {
  381. t.Error(err)
  382. }
  383. x, ok := e.(*pb.Ext)
  384. if !ok {
  385. t.Errorf("e has type %T, expected test_proto.Ext", e)
  386. } else if *x.Data != "there" {
  387. t.Errorf("SetExtension failed to overwrite, got %+v, not 'there'", x)
  388. }
  389. proto.ClearExtension(msg, pb.E_Ext_More)
  390. if _, err = proto.GetExtension(msg, pb.E_Ext_More); err != proto.ErrMissingExtension {
  391. t.Errorf("got %v, expected ErrMissingExtension", e)
  392. }
  393. if _, err := proto.GetExtension(msg, pb.E_X215); err == nil {
  394. t.Error("expected bad extension error, got nil")
  395. }
  396. if err := proto.SetExtension(msg, pb.E_X215, 12); err == nil {
  397. t.Error("expected extension err")
  398. }
  399. if err := proto.SetExtension(msg, pb.E_Ext_More, 12); err == nil {
  400. t.Error("expected some sort of type mismatch error, got nil")
  401. }
  402. }
  403. func TestNilExtension(t *testing.T) {
  404. msg := &pb.MyMessage{
  405. Count: proto.Int32(1),
  406. }
  407. if err := proto.SetExtension(msg, pb.E_Ext_Text, proto.String("hello")); err != nil {
  408. t.Fatal(err)
  409. }
  410. if err := proto.SetExtension(msg, pb.E_Ext_More, (*pb.Ext)(nil)); err == nil {
  411. t.Error("expected SetExtension to fail due to a nil extension")
  412. } else if want := fmt.Sprintf("proto: SetExtension called with nil value of type %T", new(pb.Ext)); err.Error() != want {
  413. t.Errorf("expected error %v, got %v", want, err)
  414. }
  415. // Note: if the behavior of Marshal is ever changed to ignore nil extensions, update
  416. // this test to verify that E_Ext_Text is properly propagated through marshal->unmarshal.
  417. }
  418. func TestMarshalUnmarshalRepeatedExtension(t *testing.T) {
  419. // Add a repeated extension to the result.
  420. tests := []struct {
  421. name string
  422. ext []*pb.ComplexExtension
  423. }{
  424. {
  425. "two fields",
  426. []*pb.ComplexExtension{
  427. {First: proto.Int32(7)},
  428. {Second: proto.Int32(11)},
  429. },
  430. },
  431. {
  432. "repeated field",
  433. []*pb.ComplexExtension{
  434. {Third: []int32{1000}},
  435. {Third: []int32{2000}},
  436. },
  437. },
  438. {
  439. "two fields and repeated field",
  440. []*pb.ComplexExtension{
  441. {Third: []int32{1000}},
  442. {First: proto.Int32(9)},
  443. {Second: proto.Int32(21)},
  444. {Third: []int32{2000}},
  445. },
  446. },
  447. }
  448. for _, test := range tests {
  449. // Marshal message with a repeated extension.
  450. msg1 := new(pb.OtherMessage)
  451. err := proto.SetExtension(msg1, pb.E_RComplex, test.ext)
  452. if err != nil {
  453. t.Fatalf("[%s] Error setting extension: %v", test.name, err)
  454. }
  455. b, err := proto.Marshal(msg1)
  456. if err != nil {
  457. t.Fatalf("[%s] Error marshaling message: %v", test.name, err)
  458. }
  459. // Unmarshal and read the merged proto.
  460. msg2 := new(pb.OtherMessage)
  461. err = proto.Unmarshal(b, msg2)
  462. if err != nil {
  463. t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err)
  464. }
  465. e, err := proto.GetExtension(msg2, pb.E_RComplex)
  466. if err != nil {
  467. t.Fatalf("[%s] Error getting extension: %v", test.name, err)
  468. }
  469. ext := e.([]*pb.ComplexExtension)
  470. if ext == nil {
  471. t.Fatalf("[%s] Invalid extension", test.name)
  472. }
  473. if len(ext) != len(test.ext) {
  474. t.Errorf("[%s] Wrong length of ComplexExtension: got: %v want: %v\n", test.name, len(ext), len(test.ext))
  475. }
  476. for i := range test.ext {
  477. if !proto.Equal(ext[i], test.ext[i]) {
  478. t.Errorf("[%s] Wrong value for ComplexExtension[%d]: got: %v want: %v\n", test.name, i, ext[i], test.ext[i])
  479. }
  480. }
  481. }
  482. }
  483. func TestUnmarshalRepeatingNonRepeatedExtension(t *testing.T) {
  484. // We may see multiple instances of the same extension in the wire
  485. // format. For example, the proto compiler may encode custom options in
  486. // this way. Here, we verify that we merge the extensions together.
  487. tests := []struct {
  488. name string
  489. ext []*pb.ComplexExtension
  490. }{
  491. {
  492. "two fields",
  493. []*pb.ComplexExtension{
  494. {First: proto.Int32(7)},
  495. {Second: proto.Int32(11)},
  496. },
  497. },
  498. {
  499. "repeated field",
  500. []*pb.ComplexExtension{
  501. {Third: []int32{1000}},
  502. {Third: []int32{2000}},
  503. },
  504. },
  505. {
  506. "two fields and repeated field",
  507. []*pb.ComplexExtension{
  508. {Third: []int32{1000}},
  509. {First: proto.Int32(9)},
  510. {Second: proto.Int32(21)},
  511. {Third: []int32{2000}},
  512. },
  513. },
  514. }
  515. for _, test := range tests {
  516. var buf bytes.Buffer
  517. var want pb.ComplexExtension
  518. // Generate a serialized representation of a repeated extension
  519. // by catenating bytes together.
  520. for i, e := range test.ext {
  521. // Merge to create the wanted proto.
  522. proto.Merge(&want, e)
  523. // serialize the message
  524. msg := new(pb.OtherMessage)
  525. err := proto.SetExtension(msg, pb.E_Complex, e)
  526. if err != nil {
  527. t.Fatalf("[%s] Error setting extension %d: %v", test.name, i, err)
  528. }
  529. b, err := proto.Marshal(msg)
  530. if err != nil {
  531. t.Fatalf("[%s] Error marshaling message %d: %v", test.name, i, err)
  532. }
  533. buf.Write(b)
  534. }
  535. // Unmarshal and read the merged proto.
  536. msg2 := new(pb.OtherMessage)
  537. err := proto.Unmarshal(buf.Bytes(), msg2)
  538. if err != nil {
  539. t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err)
  540. }
  541. e, err := proto.GetExtension(msg2, pb.E_Complex)
  542. if err != nil {
  543. t.Fatalf("[%s] Error getting extension: %v", test.name, err)
  544. }
  545. ext := e.(*pb.ComplexExtension)
  546. if ext == nil {
  547. t.Fatalf("[%s] Invalid extension", test.name)
  548. }
  549. if !proto.Equal(ext, &want) {
  550. t.Errorf("[%s] Wrong value for ComplexExtension: got: %s want: %s\n", test.name, ext, &want)
  551. }
  552. }
  553. }
  554. func TestClearAllExtensions(t *testing.T) {
  555. // unregistered extension
  556. desc := &proto.ExtensionDesc{
  557. ExtendedType: (*pb.MyMessage)(nil),
  558. ExtensionType: (*bool)(nil),
  559. Field: 101010100,
  560. Name: "emptyextension",
  561. Tag: "varint,0,opt",
  562. }
  563. m := &pb.MyMessage{}
  564. if proto.HasExtension(m, desc) {
  565. t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m))
  566. }
  567. if err := proto.SetExtension(m, desc, proto.Bool(true)); err != nil {
  568. t.Errorf("proto.SetExtension(m, desc, true): got error %q, want nil", err)
  569. }
  570. if !proto.HasExtension(m, desc) {
  571. t.Errorf("proto.HasExtension(%s): got false, want true", proto.MarshalTextString(m))
  572. }
  573. proto.ClearAllExtensions(m)
  574. if proto.HasExtension(m, desc) {
  575. t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m))
  576. }
  577. }
  578. func TestMarshalRace(t *testing.T) {
  579. ext := &pb.Ext{}
  580. m := &pb.MyMessage{Count: proto.Int32(4)}
  581. if err := proto.SetExtension(m, pb.E_Ext_More, ext); err != nil {
  582. t.Fatalf("proto.SetExtension(m, desc, true): got error %q, want nil", err)
  583. }
  584. b, err := proto.Marshal(m)
  585. if err != nil {
  586. t.Fatalf("Could not marshal message: %v", err)
  587. }
  588. if err := proto.Unmarshal(b, m); err != nil {
  589. t.Fatalf("Could not unmarshal message: %v", err)
  590. }
  591. // after Unmarshal, the extension is in undecoded form.
  592. // GetExtension will decode it lazily. Make sure this does
  593. // not race against Marshal.
  594. wg := sync.WaitGroup{}
  595. errs := make(chan error, 3)
  596. for n := 3; n > 0; n-- {
  597. wg.Add(1)
  598. go func() {
  599. defer wg.Done()
  600. _, err := proto.Marshal(m)
  601. errs <- err
  602. }()
  603. }
  604. wg.Wait()
  605. close(errs)
  606. for err = range errs {
  607. if err != nil {
  608. t.Fatal(err)
  609. }
  610. }
  611. }