extensions_test.go 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692
  1. // Go support for Protocol Buffers - Google's data interchange format
  2. //
  3. // Copyright 2014 The Go Authors. All rights reserved.
  4. // https://github.com/golang/protobuf
  5. //
  6. // Redistribution and use in source and binary forms, with or without
  7. // modification, are permitted provided that the following conditions are
  8. // met:
  9. //
  10. // * Redistributions of source code must retain the above copyright
  11. // notice, this list of conditions and the following disclaimer.
  12. // * Redistributions in binary form must reproduce the above
  13. // copyright notice, this list of conditions and the following disclaimer
  14. // in the documentation and/or other materials provided with the
  15. // distribution.
  16. // * Neither the name of Google Inc. nor the names of its
  17. // contributors may be used to endorse or promote products derived from
  18. // this software without specific prior written permission.
  19. //
  20. // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
  21. // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
  22. // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
  23. // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
  24. // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
  25. // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
  26. // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  27. // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
  28. // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  29. // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  30. // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  31. package proto_test
  32. import (
  33. "bytes"
  34. "fmt"
  35. "io"
  36. "reflect"
  37. "sort"
  38. "strings"
  39. "sync"
  40. "testing"
  41. "github.com/golang/protobuf/proto"
  42. pb "github.com/golang/protobuf/proto/test_proto"
  43. )
  44. func TestGetExtensionsWithMissingExtensions(t *testing.T) {
  45. msg := &pb.MyMessage{}
  46. ext1 := &pb.Ext{}
  47. if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil {
  48. t.Fatalf("Could not set ext1: %s", err)
  49. }
  50. exts, err := proto.GetExtensions(msg, []*proto.ExtensionDesc{
  51. pb.E_Ext_More,
  52. pb.E_Ext_Text,
  53. })
  54. if err != nil {
  55. t.Fatalf("GetExtensions() failed: %s", err)
  56. }
  57. if exts[0] != ext1 {
  58. t.Errorf("ext1 not in returned extensions: %T %v", exts[0], exts[0])
  59. }
  60. if exts[1] != nil {
  61. t.Errorf("ext2 in returned extensions: %T %v", exts[1], exts[1])
  62. }
  63. }
  64. func TestGetExtensionWithEmptyBuffer(t *testing.T) {
  65. // Make sure that GetExtension returns an error if its
  66. // undecoded buffer is empty.
  67. msg := &pb.MyMessage{}
  68. proto.SetRawExtension(msg, pb.E_Ext_More.Field, []byte{})
  69. _, err := proto.GetExtension(msg, pb.E_Ext_More)
  70. if want := io.ErrUnexpectedEOF; err != want {
  71. t.Errorf("unexpected error in GetExtension from empty buffer: got %v, want %v", err, want)
  72. }
  73. }
  74. func TestGetExtensionForIncompleteDesc(t *testing.T) {
  75. msg := &pb.MyMessage{Count: proto.Int32(0)}
  76. extdesc1 := &proto.ExtensionDesc{
  77. ExtendedType: (*pb.MyMessage)(nil),
  78. ExtensionType: (*bool)(nil),
  79. Field: 123456789,
  80. Name: "a.b",
  81. Tag: "varint,123456789,opt",
  82. }
  83. ext1 := proto.Bool(true)
  84. if err := proto.SetExtension(msg, extdesc1, ext1); err != nil {
  85. t.Fatalf("Could not set ext1: %s", err)
  86. }
  87. extdesc2 := &proto.ExtensionDesc{
  88. ExtendedType: (*pb.MyMessage)(nil),
  89. ExtensionType: ([]byte)(nil),
  90. Field: 123456790,
  91. Name: "a.c",
  92. Tag: "bytes,123456790,opt",
  93. }
  94. ext2 := []byte{0, 1, 2, 3, 4, 5, 6, 7}
  95. if err := proto.SetExtension(msg, extdesc2, ext2); err != nil {
  96. t.Fatalf("Could not set ext2: %s", err)
  97. }
  98. extdesc3 := &proto.ExtensionDesc{
  99. ExtendedType: (*pb.MyMessage)(nil),
  100. ExtensionType: (*pb.Ext)(nil),
  101. Field: 123456791,
  102. Name: "a.d",
  103. Tag: "bytes,123456791,opt",
  104. }
  105. ext3 := &pb.Ext{Data: proto.String("foo")}
  106. if err := proto.SetExtension(msg, extdesc3, ext3); err != nil {
  107. t.Fatalf("Could not set ext3: %s", err)
  108. }
  109. b, err := proto.Marshal(msg)
  110. if err != nil {
  111. t.Fatalf("Could not marshal msg: %v", err)
  112. }
  113. if err := proto.Unmarshal(b, msg); err != nil {
  114. t.Fatalf("Could not unmarshal into msg: %v", err)
  115. }
  116. var expected proto.Buffer
  117. if err := expected.EncodeVarint(uint64((extdesc1.Field << 3) | proto.WireVarint)); err != nil {
  118. t.Fatalf("failed to compute expected prefix for ext1: %s", err)
  119. }
  120. if err := expected.EncodeVarint(1 /* bool true */); err != nil {
  121. t.Fatalf("failed to compute expected value for ext1: %s", err)
  122. }
  123. if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc1.Field}); err != nil {
  124. t.Fatalf("Failed to get raw value for ext1: %s", err)
  125. } else if !reflect.DeepEqual(b, expected.Bytes()) {
  126. t.Fatalf("Raw value for ext1: got %v, want %v", b, expected.Bytes())
  127. }
  128. expected = proto.Buffer{} // reset
  129. if err := expected.EncodeVarint(uint64((extdesc2.Field << 3) | proto.WireBytes)); err != nil {
  130. t.Fatalf("failed to compute expected prefix for ext2: %s", err)
  131. }
  132. if err := expected.EncodeRawBytes(ext2); err != nil {
  133. t.Fatalf("failed to compute expected value for ext2: %s", err)
  134. }
  135. if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc2.Field}); err != nil {
  136. t.Fatalf("Failed to get raw value for ext2: %s", err)
  137. } else if !reflect.DeepEqual(b, expected.Bytes()) {
  138. t.Fatalf("Raw value for ext2: got %v, want %v", b, expected.Bytes())
  139. }
  140. expected = proto.Buffer{} // reset
  141. if err := expected.EncodeVarint(uint64((extdesc3.Field << 3) | proto.WireBytes)); err != nil {
  142. t.Fatalf("failed to compute expected prefix for ext3: %s", err)
  143. }
  144. if b, err := proto.Marshal(ext3); err != nil {
  145. t.Fatalf("failed to compute expected value for ext3: %s", err)
  146. } else if err := expected.EncodeRawBytes(b); err != nil {
  147. t.Fatalf("failed to compute expected value for ext3: %s", err)
  148. }
  149. if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc3.Field}); err != nil {
  150. t.Fatalf("Failed to get raw value for ext3: %s", err)
  151. } else if !reflect.DeepEqual(b, expected.Bytes()) {
  152. t.Fatalf("Raw value for ext3: got %v, want %v", b, expected.Bytes())
  153. }
  154. }
  155. func TestExtensionDescsWithUnregisteredExtensions(t *testing.T) {
  156. msg := &pb.MyMessage{Count: proto.Int32(0)}
  157. extdesc1 := pb.E_Ext_More
  158. if descs, err := proto.ExtensionDescs(msg); len(descs) != 0 || err != nil {
  159. t.Errorf("proto.ExtensionDescs: got %d descs, error %v; want 0, nil", len(descs), err)
  160. }
  161. ext1 := &pb.Ext{}
  162. if err := proto.SetExtension(msg, extdesc1, ext1); err != nil {
  163. t.Fatalf("Could not set ext1: %s", err)
  164. }
  165. extdesc2 := &proto.ExtensionDesc{
  166. ExtendedType: (*pb.MyMessage)(nil),
  167. ExtensionType: (*bool)(nil),
  168. Field: 123456789,
  169. Name: "a.b",
  170. Tag: "varint,123456789,opt",
  171. }
  172. ext2 := proto.Bool(false)
  173. if err := proto.SetExtension(msg, extdesc2, ext2); err != nil {
  174. t.Fatalf("Could not set ext2: %s", err)
  175. }
  176. b, err := proto.Marshal(msg)
  177. if err != nil {
  178. t.Fatalf("Could not marshal msg: %v", err)
  179. }
  180. if err := proto.Unmarshal(b, msg); err != nil {
  181. t.Fatalf("Could not unmarshal into msg: %v", err)
  182. }
  183. descs, err := proto.ExtensionDescs(msg)
  184. if err != nil {
  185. t.Fatalf("proto.ExtensionDescs: got error %v", err)
  186. }
  187. sortExtDescs(descs)
  188. wantDescs := []*proto.ExtensionDesc{extdesc1, {Field: extdesc2.Field}}
  189. if !reflect.DeepEqual(descs, wantDescs) {
  190. t.Errorf("proto.ExtensionDescs(msg) sorted extension ids: got %+v, want %+v", descs, wantDescs)
  191. }
  192. }
  193. type ExtensionDescSlice []*proto.ExtensionDesc
  194. func (s ExtensionDescSlice) Len() int { return len(s) }
  195. func (s ExtensionDescSlice) Less(i, j int) bool { return s[i].Field < s[j].Field }
  196. func (s ExtensionDescSlice) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
  197. func sortExtDescs(s []*proto.ExtensionDesc) {
  198. sort.Sort(ExtensionDescSlice(s))
  199. }
  200. func TestGetExtensionStability(t *testing.T) {
  201. check := func(m *pb.MyMessage) bool {
  202. ext1, err := proto.GetExtension(m, pb.E_Ext_More)
  203. if err != nil {
  204. t.Fatalf("GetExtension() failed: %s", err)
  205. }
  206. ext2, err := proto.GetExtension(m, pb.E_Ext_More)
  207. if err != nil {
  208. t.Fatalf("GetExtension() failed: %s", err)
  209. }
  210. return ext1 == ext2
  211. }
  212. msg := &pb.MyMessage{Count: proto.Int32(4)}
  213. ext0 := &pb.Ext{}
  214. if err := proto.SetExtension(msg, pb.E_Ext_More, ext0); err != nil {
  215. t.Fatalf("Could not set ext1: %s", ext0)
  216. }
  217. if !check(msg) {
  218. t.Errorf("GetExtension() not stable before marshaling")
  219. }
  220. bb, err := proto.Marshal(msg)
  221. if err != nil {
  222. t.Fatalf("Marshal() failed: %s", err)
  223. }
  224. msg1 := &pb.MyMessage{}
  225. err = proto.Unmarshal(bb, msg1)
  226. if err != nil {
  227. t.Fatalf("Unmarshal() failed: %s", err)
  228. }
  229. if !check(msg1) {
  230. t.Errorf("GetExtension() not stable after unmarshaling")
  231. }
  232. }
  233. func TestGetExtensionDefaults(t *testing.T) {
  234. var setFloat64 float64 = 1
  235. var setFloat32 float32 = 2
  236. var setInt32 int32 = 3
  237. var setInt64 int64 = 4
  238. var setUint32 uint32 = 5
  239. var setUint64 uint64 = 6
  240. var setBool = true
  241. var setBool2 = false
  242. var setString = "Goodnight string"
  243. var setBytes = []byte("Goodnight bytes")
  244. var setEnum = pb.DefaultsMessage_TWO
  245. type testcase struct {
  246. ext *proto.ExtensionDesc // Extension we are testing.
  247. want interface{} // Expected value of extension, or nil (meaning that GetExtension will fail).
  248. def interface{} // Expected value of extension after ClearExtension().
  249. }
  250. tests := []testcase{
  251. {pb.E_NoDefaultDouble, setFloat64, nil},
  252. {pb.E_NoDefaultFloat, setFloat32, nil},
  253. {pb.E_NoDefaultInt32, setInt32, nil},
  254. {pb.E_NoDefaultInt64, setInt64, nil},
  255. {pb.E_NoDefaultUint32, setUint32, nil},
  256. {pb.E_NoDefaultUint64, setUint64, nil},
  257. {pb.E_NoDefaultSint32, setInt32, nil},
  258. {pb.E_NoDefaultSint64, setInt64, nil},
  259. {pb.E_NoDefaultFixed32, setUint32, nil},
  260. {pb.E_NoDefaultFixed64, setUint64, nil},
  261. {pb.E_NoDefaultSfixed32, setInt32, nil},
  262. {pb.E_NoDefaultSfixed64, setInt64, nil},
  263. {pb.E_NoDefaultBool, setBool, nil},
  264. {pb.E_NoDefaultBool, setBool2, nil},
  265. {pb.E_NoDefaultString, setString, nil},
  266. {pb.E_NoDefaultBytes, setBytes, nil},
  267. {pb.E_NoDefaultEnum, setEnum, nil},
  268. {pb.E_DefaultDouble, setFloat64, float64(3.1415)},
  269. {pb.E_DefaultFloat, setFloat32, float32(3.14)},
  270. {pb.E_DefaultInt32, setInt32, int32(42)},
  271. {pb.E_DefaultInt64, setInt64, int64(43)},
  272. {pb.E_DefaultUint32, setUint32, uint32(44)},
  273. {pb.E_DefaultUint64, setUint64, uint64(45)},
  274. {pb.E_DefaultSint32, setInt32, int32(46)},
  275. {pb.E_DefaultSint64, setInt64, int64(47)},
  276. {pb.E_DefaultFixed32, setUint32, uint32(48)},
  277. {pb.E_DefaultFixed64, setUint64, uint64(49)},
  278. {pb.E_DefaultSfixed32, setInt32, int32(50)},
  279. {pb.E_DefaultSfixed64, setInt64, int64(51)},
  280. {pb.E_DefaultBool, setBool, true},
  281. {pb.E_DefaultBool, setBool2, true},
  282. {pb.E_DefaultString, setString, "Hello, string,def=foo"},
  283. {pb.E_DefaultBytes, setBytes, []byte("Hello, bytes")},
  284. {pb.E_DefaultEnum, setEnum, pb.DefaultsMessage_ONE},
  285. }
  286. checkVal := func(test testcase, msg *pb.DefaultsMessage, valWant interface{}) error {
  287. val, err := proto.GetExtension(msg, test.ext)
  288. if err != nil {
  289. if valWant != nil {
  290. return fmt.Errorf("GetExtension(): %s", err)
  291. }
  292. if want := proto.ErrMissingExtension; err != want {
  293. return fmt.Errorf("Unexpected error: got %v, want %v", err, want)
  294. }
  295. return nil
  296. }
  297. // All proto2 extension values are either a pointer to a value or a slice of values.
  298. ty := reflect.TypeOf(val)
  299. tyWant := reflect.TypeOf(test.ext.ExtensionType)
  300. if got, want := ty, tyWant; got != want {
  301. return fmt.Errorf("unexpected reflect.TypeOf(): got %v want %v", got, want)
  302. }
  303. tye := ty.Elem()
  304. tyeWant := tyWant.Elem()
  305. if got, want := tye, tyeWant; got != want {
  306. return fmt.Errorf("unexpected reflect.TypeOf().Elem(): got %v want %v", got, want)
  307. }
  308. // Check the name of the type of the value.
  309. // If it is an enum it will be type int32 with the name of the enum.
  310. if got, want := tye.Name(), tye.Name(); got != want {
  311. return fmt.Errorf("unexpected reflect.TypeOf().Elem().Name(): got %v want %v", got, want)
  312. }
  313. // Check that value is what we expect.
  314. // If we have a pointer in val, get the value it points to.
  315. valExp := val
  316. if ty.Kind() == reflect.Ptr {
  317. valExp = reflect.ValueOf(val).Elem().Interface()
  318. }
  319. if got, want := valExp, valWant; !reflect.DeepEqual(got, want) {
  320. return fmt.Errorf("unexpected reflect.DeepEqual(): got %v want %v", got, want)
  321. }
  322. return nil
  323. }
  324. setTo := func(test testcase) interface{} {
  325. setTo := reflect.ValueOf(test.want)
  326. if typ := reflect.TypeOf(test.ext.ExtensionType); typ.Kind() == reflect.Ptr {
  327. setTo = reflect.New(typ).Elem()
  328. setTo.Set(reflect.New(setTo.Type().Elem()))
  329. setTo.Elem().Set(reflect.ValueOf(test.want))
  330. }
  331. return setTo.Interface()
  332. }
  333. for _, test := range tests {
  334. msg := &pb.DefaultsMessage{}
  335. name := test.ext.Name
  336. // Check the initial value.
  337. if err := checkVal(test, msg, test.def); err != nil {
  338. t.Errorf("%s: %v", name, err)
  339. }
  340. // Set the per-type value and check value.
  341. name = fmt.Sprintf("%s (set to %T %v)", name, test.want, test.want)
  342. if err := proto.SetExtension(msg, test.ext, setTo(test)); err != nil {
  343. t.Errorf("%s: SetExtension(): %v", name, err)
  344. continue
  345. }
  346. if err := checkVal(test, msg, test.want); err != nil {
  347. t.Errorf("%s: %v", name, err)
  348. continue
  349. }
  350. // Set and check the value.
  351. name += " (cleared)"
  352. proto.ClearExtension(msg, test.ext)
  353. if err := checkVal(test, msg, test.def); err != nil {
  354. t.Errorf("%s: %v", name, err)
  355. }
  356. }
  357. }
  358. func TestNilMessage(t *testing.T) {
  359. name := "nil interface"
  360. if got, err := proto.GetExtension(nil, pb.E_Ext_More); err == nil {
  361. t.Errorf("%s: got %T %v, expected to fail", name, got, got)
  362. } else if !strings.Contains(err.Error(), "extendable") {
  363. t.Errorf("%s: got error %v, expected not-extendable error", name, err)
  364. }
  365. // Regression tests: all functions of the Extension API
  366. // used to panic when passed (*M)(nil), where M is a concrete message
  367. // type. Now they handle this gracefully as a no-op or reported error.
  368. var nilMsg *pb.MyMessage
  369. desc := pb.E_Ext_More
  370. isNotExtendable := func(err error) bool {
  371. return strings.Contains(fmt.Sprint(err), "not extendable")
  372. }
  373. if proto.HasExtension(nilMsg, desc) {
  374. t.Error("HasExtension(nil) = true")
  375. }
  376. if _, err := proto.GetExtensions(nilMsg, []*proto.ExtensionDesc{desc}); !isNotExtendable(err) {
  377. t.Errorf("GetExtensions(nil) = %q (wrong error)", err)
  378. }
  379. if _, err := proto.ExtensionDescs(nilMsg); !isNotExtendable(err) {
  380. t.Errorf("ExtensionDescs(nil) = %q (wrong error)", err)
  381. }
  382. if err := proto.SetExtension(nilMsg, desc, nil); !isNotExtendable(err) {
  383. t.Errorf("SetExtension(nil) = %q (wrong error)", err)
  384. }
  385. proto.ClearExtension(nilMsg, desc) // no-op
  386. proto.ClearAllExtensions(nilMsg) // no-op
  387. }
  388. func TestExtensionsRoundTrip(t *testing.T) {
  389. msg := &pb.MyMessage{}
  390. ext1 := &pb.Ext{
  391. Data: proto.String("hi"),
  392. }
  393. ext2 := &pb.Ext{
  394. Data: proto.String("there"),
  395. }
  396. exists := proto.HasExtension(msg, pb.E_Ext_More)
  397. if exists {
  398. t.Error("Extension More present unexpectedly")
  399. }
  400. if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil {
  401. t.Error(err)
  402. }
  403. if err := proto.SetExtension(msg, pb.E_Ext_More, ext2); err != nil {
  404. t.Error(err)
  405. }
  406. e, err := proto.GetExtension(msg, pb.E_Ext_More)
  407. if err != nil {
  408. t.Error(err)
  409. }
  410. x, ok := e.(*pb.Ext)
  411. if !ok {
  412. t.Errorf("e has type %T, expected test_proto.Ext", e)
  413. } else if *x.Data != "there" {
  414. t.Errorf("SetExtension failed to overwrite, got %+v, not 'there'", x)
  415. }
  416. proto.ClearExtension(msg, pb.E_Ext_More)
  417. if _, err = proto.GetExtension(msg, pb.E_Ext_More); err != proto.ErrMissingExtension {
  418. t.Errorf("got %v, expected ErrMissingExtension", e)
  419. }
  420. if _, err := proto.GetExtension(msg, pb.E_X215); err == nil {
  421. t.Error("expected bad extension error, got nil")
  422. }
  423. if err := proto.SetExtension(msg, pb.E_X215, 12); err == nil {
  424. t.Error("expected extension err")
  425. }
  426. if err := proto.SetExtension(msg, pb.E_Ext_More, 12); err == nil {
  427. t.Error("expected some sort of type mismatch error, got nil")
  428. }
  429. }
  430. func TestNilExtension(t *testing.T) {
  431. msg := &pb.MyMessage{
  432. Count: proto.Int32(1),
  433. }
  434. if err := proto.SetExtension(msg, pb.E_Ext_Text, proto.String("hello")); err != nil {
  435. t.Fatal(err)
  436. }
  437. if err := proto.SetExtension(msg, pb.E_Ext_More, (*pb.Ext)(nil)); err == nil {
  438. t.Error("expected SetExtension to fail due to a nil extension")
  439. } else if want := fmt.Sprintf("proto: SetExtension called with nil value of type %T", new(pb.Ext)); err.Error() != want {
  440. t.Errorf("expected error %v, got %v", want, err)
  441. }
  442. // Note: if the behavior of Marshal is ever changed to ignore nil extensions, update
  443. // this test to verify that E_Ext_Text is properly propagated through marshal->unmarshal.
  444. }
  445. func TestMarshalUnmarshalRepeatedExtension(t *testing.T) {
  446. // Add a repeated extension to the result.
  447. tests := []struct {
  448. name string
  449. ext []*pb.ComplexExtension
  450. }{
  451. {
  452. "two fields",
  453. []*pb.ComplexExtension{
  454. {First: proto.Int32(7)},
  455. {Second: proto.Int32(11)},
  456. },
  457. },
  458. {
  459. "repeated field",
  460. []*pb.ComplexExtension{
  461. {Third: []int32{1000}},
  462. {Third: []int32{2000}},
  463. },
  464. },
  465. {
  466. "two fields and repeated field",
  467. []*pb.ComplexExtension{
  468. {Third: []int32{1000}},
  469. {First: proto.Int32(9)},
  470. {Second: proto.Int32(21)},
  471. {Third: []int32{2000}},
  472. },
  473. },
  474. }
  475. for _, test := range tests {
  476. // Marshal message with a repeated extension.
  477. msg1 := new(pb.OtherMessage)
  478. err := proto.SetExtension(msg1, pb.E_RComplex, test.ext)
  479. if err != nil {
  480. t.Fatalf("[%s] Error setting extension: %v", test.name, err)
  481. }
  482. b, err := proto.Marshal(msg1)
  483. if err != nil {
  484. t.Fatalf("[%s] Error marshaling message: %v", test.name, err)
  485. }
  486. // Unmarshal and read the merged proto.
  487. msg2 := new(pb.OtherMessage)
  488. err = proto.Unmarshal(b, msg2)
  489. if err != nil {
  490. t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err)
  491. }
  492. e, err := proto.GetExtension(msg2, pb.E_RComplex)
  493. if err != nil {
  494. t.Fatalf("[%s] Error getting extension: %v", test.name, err)
  495. }
  496. ext := e.([]*pb.ComplexExtension)
  497. if ext == nil {
  498. t.Fatalf("[%s] Invalid extension", test.name)
  499. }
  500. if len(ext) != len(test.ext) {
  501. t.Errorf("[%s] Wrong length of ComplexExtension: got: %v want: %v\n", test.name, len(ext), len(test.ext))
  502. }
  503. for i := range test.ext {
  504. if !proto.Equal(ext[i], test.ext[i]) {
  505. t.Errorf("[%s] Wrong value for ComplexExtension[%d]: got: %v want: %v\n", test.name, i, ext[i], test.ext[i])
  506. }
  507. }
  508. }
  509. }
  510. func TestUnmarshalRepeatingNonRepeatedExtension(t *testing.T) {
  511. // We may see multiple instances of the same extension in the wire
  512. // format. For example, the proto compiler may encode custom options in
  513. // this way. Here, we verify that we merge the extensions together.
  514. tests := []struct {
  515. name string
  516. ext []*pb.ComplexExtension
  517. }{
  518. {
  519. "two fields",
  520. []*pb.ComplexExtension{
  521. {First: proto.Int32(7)},
  522. {Second: proto.Int32(11)},
  523. },
  524. },
  525. {
  526. "repeated field",
  527. []*pb.ComplexExtension{
  528. {Third: []int32{1000}},
  529. {Third: []int32{2000}},
  530. },
  531. },
  532. {
  533. "two fields and repeated field",
  534. []*pb.ComplexExtension{
  535. {Third: []int32{1000}},
  536. {First: proto.Int32(9)},
  537. {Second: proto.Int32(21)},
  538. {Third: []int32{2000}},
  539. },
  540. },
  541. }
  542. for _, test := range tests {
  543. var buf bytes.Buffer
  544. var want pb.ComplexExtension
  545. // Generate a serialized representation of a repeated extension
  546. // by catenating bytes together.
  547. for i, e := range test.ext {
  548. // Merge to create the wanted proto.
  549. proto.Merge(&want, e)
  550. // serialize the message
  551. msg := new(pb.OtherMessage)
  552. err := proto.SetExtension(msg, pb.E_Complex, e)
  553. if err != nil {
  554. t.Fatalf("[%s] Error setting extension %d: %v", test.name, i, err)
  555. }
  556. b, err := proto.Marshal(msg)
  557. if err != nil {
  558. t.Fatalf("[%s] Error marshaling message %d: %v", test.name, i, err)
  559. }
  560. buf.Write(b)
  561. }
  562. // Unmarshal and read the merged proto.
  563. msg2 := new(pb.OtherMessage)
  564. err := proto.Unmarshal(buf.Bytes(), msg2)
  565. if err != nil {
  566. t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err)
  567. }
  568. e, err := proto.GetExtension(msg2, pb.E_Complex)
  569. if err != nil {
  570. t.Fatalf("[%s] Error getting extension: %v", test.name, err)
  571. }
  572. ext := e.(*pb.ComplexExtension)
  573. if ext == nil {
  574. t.Fatalf("[%s] Invalid extension", test.name)
  575. }
  576. if !proto.Equal(ext, &want) {
  577. t.Errorf("[%s] Wrong value for ComplexExtension: got: %s want: %s\n", test.name, ext, &want)
  578. }
  579. }
  580. }
  581. func TestClearAllExtensions(t *testing.T) {
  582. // unregistered extension
  583. desc := &proto.ExtensionDesc{
  584. ExtendedType: (*pb.MyMessage)(nil),
  585. ExtensionType: (*bool)(nil),
  586. Field: 101010100,
  587. Name: "emptyextension",
  588. Tag: "varint,0,opt",
  589. }
  590. m := &pb.MyMessage{}
  591. if proto.HasExtension(m, desc) {
  592. t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m))
  593. }
  594. if err := proto.SetExtension(m, desc, proto.Bool(true)); err != nil {
  595. t.Errorf("proto.SetExtension(m, desc, true): got error %q, want nil", err)
  596. }
  597. if !proto.HasExtension(m, desc) {
  598. t.Errorf("proto.HasExtension(%s): got false, want true", proto.MarshalTextString(m))
  599. }
  600. proto.ClearAllExtensions(m)
  601. if proto.HasExtension(m, desc) {
  602. t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m))
  603. }
  604. }
  605. func TestMarshalRace(t *testing.T) {
  606. ext := &pb.Ext{}
  607. m := &pb.MyMessage{Count: proto.Int32(4)}
  608. if err := proto.SetExtension(m, pb.E_Ext_More, ext); err != nil {
  609. t.Fatalf("proto.SetExtension(m, desc, true): got error %q, want nil", err)
  610. }
  611. b, err := proto.Marshal(m)
  612. if err != nil {
  613. t.Fatalf("Could not marshal message: %v", err)
  614. }
  615. if err := proto.Unmarshal(b, m); err != nil {
  616. t.Fatalf("Could not unmarshal message: %v", err)
  617. }
  618. // after Unmarshal, the extension is in undecoded form.
  619. // GetExtension will decode it lazily. Make sure this does
  620. // not race against Marshal.
  621. wg := sync.WaitGroup{}
  622. errs := make(chan error, 3)
  623. for n := 3; n > 0; n-- {
  624. wg.Add(1)
  625. go func() {
  626. defer wg.Done()
  627. _, err := proto.Marshal(m)
  628. errs <- err
  629. }()
  630. }
  631. wg.Wait()
  632. close(errs)
  633. for err = range errs {
  634. if err != nil {
  635. t.Fatal(err)
  636. }
  637. }
  638. }