extensions_test.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508
  1. // Go support for Protocol Buffers - Google's data interchange format
  2. //
  3. // Copyright 2014 The Go Authors. All rights reserved.
  4. // https://github.com/golang/protobuf
  5. //
  6. // Redistribution and use in source and binary forms, with or without
  7. // modification, are permitted provided that the following conditions are
  8. // met:
  9. //
  10. // * Redistributions of source code must retain the above copyright
  11. // notice, this list of conditions and the following disclaimer.
  12. // * Redistributions in binary form must reproduce the above
  13. // copyright notice, this list of conditions and the following disclaimer
  14. // in the documentation and/or other materials provided with the
  15. // distribution.
  16. // * Neither the name of Google Inc. nor the names of its
  17. // contributors may be used to endorse or promote products derived from
  18. // this software without specific prior written permission.
  19. //
  20. // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
  21. // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
  22. // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
  23. // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
  24. // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
  25. // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
  26. // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  27. // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
  28. // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  29. // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  30. // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  31. package proto_test
  32. import (
  33. "bytes"
  34. "fmt"
  35. "reflect"
  36. "sort"
  37. "testing"
  38. "github.com/golang/protobuf/proto"
  39. pb "github.com/golang/protobuf/proto/testdata"
  40. )
  41. func TestGetExtensionsWithMissingExtensions(t *testing.T) {
  42. msg := &pb.MyMessage{}
  43. ext1 := &pb.Ext{}
  44. if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil {
  45. t.Fatalf("Could not set ext1: %s", err)
  46. }
  47. exts, err := proto.GetExtensions(msg, []*proto.ExtensionDesc{
  48. pb.E_Ext_More,
  49. pb.E_Ext_Text,
  50. })
  51. if err != nil {
  52. t.Fatalf("GetExtensions() failed: %s", err)
  53. }
  54. if exts[0] != ext1 {
  55. t.Errorf("ext1 not in returned extensions: %T %v", exts[0], exts[0])
  56. }
  57. if exts[1] != nil {
  58. t.Errorf("ext2 in returned extensions: %T %v", exts[1], exts[1])
  59. }
  60. }
  61. func TestExtensionDescsWithMissingExtensions(t *testing.T) {
  62. msg := &pb.MyMessage{Count: proto.Int32(0)}
  63. extdesc1 := pb.E_Ext_More
  64. if descs, err := proto.ExtensionDescs(msg); len(descs) != 0 || err != nil {
  65. t.Errorf("proto.ExtensionDescs: got %d descs, error %v; want 0, nil", len(descs), err)
  66. }
  67. ext1 := &pb.Ext{}
  68. if err := proto.SetExtension(msg, extdesc1, ext1); err != nil {
  69. t.Fatalf("Could not set ext1: %s", err)
  70. }
  71. extdesc2 := &proto.ExtensionDesc{
  72. ExtendedType: (*pb.MyMessage)(nil),
  73. ExtensionType: (*bool)(nil),
  74. Field: 123456789,
  75. Name: "a.b",
  76. Tag: "varint,123456789,opt",
  77. }
  78. ext2 := proto.Bool(false)
  79. if err := proto.SetExtension(msg, extdesc2, ext2); err != nil {
  80. t.Fatalf("Could not set ext2: %s", err)
  81. }
  82. b, err := proto.Marshal(msg)
  83. if err != nil {
  84. t.Fatalf("Could not marshal msg: %v", err)
  85. }
  86. if err := proto.Unmarshal(b, msg); err != nil {
  87. t.Fatalf("Could not unmarshal into msg: %v", err)
  88. }
  89. descs, err := proto.ExtensionDescs(msg)
  90. if err != nil {
  91. t.Fatalf("proto.ExtensionDescs: got error %v", err)
  92. }
  93. sortExtDescs(descs)
  94. wantDescs := []*proto.ExtensionDesc{extdesc1, &proto.ExtensionDesc{Field: extdesc2.Field}}
  95. if !reflect.DeepEqual(descs, wantDescs) {
  96. t.Errorf("proto.ExtensionDescs(msg) sorted extension ids: got %+v, want %+v", descs, wantDescs)
  97. }
  98. }
  99. type ExtensionDescSlice []*proto.ExtensionDesc
  100. func (s ExtensionDescSlice) Len() int { return len(s) }
  101. func (s ExtensionDescSlice) Less(i, j int) bool { return s[i].Field < s[j].Field }
  102. func (s ExtensionDescSlice) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
  103. func sortExtDescs(s []*proto.ExtensionDesc) {
  104. sort.Sort(ExtensionDescSlice(s))
  105. }
  106. func TestGetExtensionStability(t *testing.T) {
  107. check := func(m *pb.MyMessage) bool {
  108. ext1, err := proto.GetExtension(m, pb.E_Ext_More)
  109. if err != nil {
  110. t.Fatalf("GetExtension() failed: %s", err)
  111. }
  112. ext2, err := proto.GetExtension(m, pb.E_Ext_More)
  113. if err != nil {
  114. t.Fatalf("GetExtension() failed: %s", err)
  115. }
  116. return ext1 == ext2
  117. }
  118. msg := &pb.MyMessage{Count: proto.Int32(4)}
  119. ext0 := &pb.Ext{}
  120. if err := proto.SetExtension(msg, pb.E_Ext_More, ext0); err != nil {
  121. t.Fatalf("Could not set ext1: %s", ext0)
  122. }
  123. if !check(msg) {
  124. t.Errorf("GetExtension() not stable before marshaling")
  125. }
  126. bb, err := proto.Marshal(msg)
  127. if err != nil {
  128. t.Fatalf("Marshal() failed: %s", err)
  129. }
  130. msg1 := &pb.MyMessage{}
  131. err = proto.Unmarshal(bb, msg1)
  132. if err != nil {
  133. t.Fatalf("Unmarshal() failed: %s", err)
  134. }
  135. if !check(msg1) {
  136. t.Errorf("GetExtension() not stable after unmarshaling")
  137. }
  138. }
  139. func TestGetExtensionDefaults(t *testing.T) {
  140. var setFloat64 float64 = 1
  141. var setFloat32 float32 = 2
  142. var setInt32 int32 = 3
  143. var setInt64 int64 = 4
  144. var setUint32 uint32 = 5
  145. var setUint64 uint64 = 6
  146. var setBool = true
  147. var setBool2 = false
  148. var setString = "Goodnight string"
  149. var setBytes = []byte("Goodnight bytes")
  150. var setEnum = pb.DefaultsMessage_TWO
  151. type testcase struct {
  152. ext *proto.ExtensionDesc // Extension we are testing.
  153. want interface{} // Expected value of extension, or nil (meaning that GetExtension will fail).
  154. def interface{} // Expected value of extension after ClearExtension().
  155. }
  156. tests := []testcase{
  157. {pb.E_NoDefaultDouble, setFloat64, nil},
  158. {pb.E_NoDefaultFloat, setFloat32, nil},
  159. {pb.E_NoDefaultInt32, setInt32, nil},
  160. {pb.E_NoDefaultInt64, setInt64, nil},
  161. {pb.E_NoDefaultUint32, setUint32, nil},
  162. {pb.E_NoDefaultUint64, setUint64, nil},
  163. {pb.E_NoDefaultSint32, setInt32, nil},
  164. {pb.E_NoDefaultSint64, setInt64, nil},
  165. {pb.E_NoDefaultFixed32, setUint32, nil},
  166. {pb.E_NoDefaultFixed64, setUint64, nil},
  167. {pb.E_NoDefaultSfixed32, setInt32, nil},
  168. {pb.E_NoDefaultSfixed64, setInt64, nil},
  169. {pb.E_NoDefaultBool, setBool, nil},
  170. {pb.E_NoDefaultBool, setBool2, nil},
  171. {pb.E_NoDefaultString, setString, nil},
  172. {pb.E_NoDefaultBytes, setBytes, nil},
  173. {pb.E_NoDefaultEnum, setEnum, nil},
  174. {pb.E_DefaultDouble, setFloat64, float64(3.1415)},
  175. {pb.E_DefaultFloat, setFloat32, float32(3.14)},
  176. {pb.E_DefaultInt32, setInt32, int32(42)},
  177. {pb.E_DefaultInt64, setInt64, int64(43)},
  178. {pb.E_DefaultUint32, setUint32, uint32(44)},
  179. {pb.E_DefaultUint64, setUint64, uint64(45)},
  180. {pb.E_DefaultSint32, setInt32, int32(46)},
  181. {pb.E_DefaultSint64, setInt64, int64(47)},
  182. {pb.E_DefaultFixed32, setUint32, uint32(48)},
  183. {pb.E_DefaultFixed64, setUint64, uint64(49)},
  184. {pb.E_DefaultSfixed32, setInt32, int32(50)},
  185. {pb.E_DefaultSfixed64, setInt64, int64(51)},
  186. {pb.E_DefaultBool, setBool, true},
  187. {pb.E_DefaultBool, setBool2, true},
  188. {pb.E_DefaultString, setString, "Hello, string"},
  189. {pb.E_DefaultBytes, setBytes, []byte("Hello, bytes")},
  190. {pb.E_DefaultEnum, setEnum, pb.DefaultsMessage_ONE},
  191. }
  192. checkVal := func(test testcase, msg *pb.DefaultsMessage, valWant interface{}) error {
  193. val, err := proto.GetExtension(msg, test.ext)
  194. if err != nil {
  195. if valWant != nil {
  196. return fmt.Errorf("GetExtension(): %s", err)
  197. }
  198. if want := proto.ErrMissingExtension; err != want {
  199. return fmt.Errorf("Unexpected error: got %v, want %v", err, want)
  200. }
  201. return nil
  202. }
  203. // All proto2 extension values are either a pointer to a value or a slice of values.
  204. ty := reflect.TypeOf(val)
  205. tyWant := reflect.TypeOf(test.ext.ExtensionType)
  206. if got, want := ty, tyWant; got != want {
  207. return fmt.Errorf("unexpected reflect.TypeOf(): got %v want %v", got, want)
  208. }
  209. tye := ty.Elem()
  210. tyeWant := tyWant.Elem()
  211. if got, want := tye, tyeWant; got != want {
  212. return fmt.Errorf("unexpected reflect.TypeOf().Elem(): got %v want %v", got, want)
  213. }
  214. // Check the name of the type of the value.
  215. // If it is an enum it will be type int32 with the name of the enum.
  216. if got, want := tye.Name(), tye.Name(); got != want {
  217. return fmt.Errorf("unexpected reflect.TypeOf().Elem().Name(): got %v want %v", got, want)
  218. }
  219. // Check that value is what we expect.
  220. // If we have a pointer in val, get the value it points to.
  221. valExp := val
  222. if ty.Kind() == reflect.Ptr {
  223. valExp = reflect.ValueOf(val).Elem().Interface()
  224. }
  225. if got, want := valExp, valWant; !reflect.DeepEqual(got, want) {
  226. return fmt.Errorf("unexpected reflect.DeepEqual(): got %v want %v", got, want)
  227. }
  228. return nil
  229. }
  230. setTo := func(test testcase) interface{} {
  231. setTo := reflect.ValueOf(test.want)
  232. if typ := reflect.TypeOf(test.ext.ExtensionType); typ.Kind() == reflect.Ptr {
  233. setTo = reflect.New(typ).Elem()
  234. setTo.Set(reflect.New(setTo.Type().Elem()))
  235. setTo.Elem().Set(reflect.ValueOf(test.want))
  236. }
  237. return setTo.Interface()
  238. }
  239. for _, test := range tests {
  240. msg := &pb.DefaultsMessage{}
  241. name := test.ext.Name
  242. // Check the initial value.
  243. if err := checkVal(test, msg, test.def); err != nil {
  244. t.Errorf("%s: %v", name, err)
  245. }
  246. // Set the per-type value and check value.
  247. name = fmt.Sprintf("%s (set to %T %v)", name, test.want, test.want)
  248. if err := proto.SetExtension(msg, test.ext, setTo(test)); err != nil {
  249. t.Errorf("%s: SetExtension(): %v", name, err)
  250. continue
  251. }
  252. if err := checkVal(test, msg, test.want); err != nil {
  253. t.Errorf("%s: %v", name, err)
  254. continue
  255. }
  256. // Set and check the value.
  257. name += " (cleared)"
  258. proto.ClearExtension(msg, test.ext)
  259. if err := checkVal(test, msg, test.def); err != nil {
  260. t.Errorf("%s: %v", name, err)
  261. }
  262. }
  263. }
  264. func TestExtensionsRoundTrip(t *testing.T) {
  265. msg := &pb.MyMessage{}
  266. ext1 := &pb.Ext{
  267. Data: proto.String("hi"),
  268. }
  269. ext2 := &pb.Ext{
  270. Data: proto.String("there"),
  271. }
  272. exists := proto.HasExtension(msg, pb.E_Ext_More)
  273. if exists {
  274. t.Error("Extension More present unexpectedly")
  275. }
  276. if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil {
  277. t.Error(err)
  278. }
  279. if err := proto.SetExtension(msg, pb.E_Ext_More, ext2); err != nil {
  280. t.Error(err)
  281. }
  282. e, err := proto.GetExtension(msg, pb.E_Ext_More)
  283. if err != nil {
  284. t.Error(err)
  285. }
  286. x, ok := e.(*pb.Ext)
  287. if !ok {
  288. t.Errorf("e has type %T, expected testdata.Ext", e)
  289. } else if *x.Data != "there" {
  290. t.Errorf("SetExtension failed to overwrite, got %+v, not 'there'", x)
  291. }
  292. proto.ClearExtension(msg, pb.E_Ext_More)
  293. if _, err = proto.GetExtension(msg, pb.E_Ext_More); err != proto.ErrMissingExtension {
  294. t.Errorf("got %v, expected ErrMissingExtension", e)
  295. }
  296. if _, err := proto.GetExtension(msg, pb.E_X215); err == nil {
  297. t.Error("expected bad extension error, got nil")
  298. }
  299. if err := proto.SetExtension(msg, pb.E_X215, 12); err == nil {
  300. t.Error("expected extension err")
  301. }
  302. if err := proto.SetExtension(msg, pb.E_Ext_More, 12); err == nil {
  303. t.Error("expected some sort of type mismatch error, got nil")
  304. }
  305. }
  306. func TestNilExtension(t *testing.T) {
  307. msg := &pb.MyMessage{
  308. Count: proto.Int32(1),
  309. }
  310. if err := proto.SetExtension(msg, pb.E_Ext_Text, proto.String("hello")); err != nil {
  311. t.Fatal(err)
  312. }
  313. if err := proto.SetExtension(msg, pb.E_Ext_More, (*pb.Ext)(nil)); err == nil {
  314. t.Error("expected SetExtension to fail due to a nil extension")
  315. } else if want := "proto: SetExtension called with nil value of type *testdata.Ext"; err.Error() != want {
  316. t.Errorf("expected error %v, got %v", want, err)
  317. }
  318. // Note: if the behavior of Marshal is ever changed to ignore nil extensions, update
  319. // this test to verify that E_Ext_Text is properly propagated through marshal->unmarshal.
  320. }
  321. func TestMarshalUnmarshalRepeatedExtension(t *testing.T) {
  322. // Add a repeated extension to the result.
  323. tests := []struct {
  324. name string
  325. ext []*pb.ComplexExtension
  326. }{
  327. {
  328. "two fields",
  329. []*pb.ComplexExtension{
  330. {First: proto.Int32(7)},
  331. {Second: proto.Int32(11)},
  332. },
  333. },
  334. {
  335. "repeated field",
  336. []*pb.ComplexExtension{
  337. {Third: []int32{1000}},
  338. {Third: []int32{2000}},
  339. },
  340. },
  341. {
  342. "two fields and repeated field",
  343. []*pb.ComplexExtension{
  344. {Third: []int32{1000}},
  345. {First: proto.Int32(9)},
  346. {Second: proto.Int32(21)},
  347. {Third: []int32{2000}},
  348. },
  349. },
  350. }
  351. for _, test := range tests {
  352. // Marshal message with a repeated extension.
  353. msg1 := new(pb.OtherMessage)
  354. err := proto.SetExtension(msg1, pb.E_RComplex, test.ext)
  355. if err != nil {
  356. t.Fatalf("[%s] Error setting extension: %v", test.name, err)
  357. }
  358. b, err := proto.Marshal(msg1)
  359. if err != nil {
  360. t.Fatalf("[%s] Error marshaling message: %v", test.name, err)
  361. }
  362. // Unmarshal and read the merged proto.
  363. msg2 := new(pb.OtherMessage)
  364. err = proto.Unmarshal(b, msg2)
  365. if err != nil {
  366. t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err)
  367. }
  368. e, err := proto.GetExtension(msg2, pb.E_RComplex)
  369. if err != nil {
  370. t.Fatalf("[%s] Error getting extension: %v", test.name, err)
  371. }
  372. ext := e.([]*pb.ComplexExtension)
  373. if ext == nil {
  374. t.Fatalf("[%s] Invalid extension", test.name)
  375. }
  376. if !reflect.DeepEqual(ext, test.ext) {
  377. t.Errorf("[%s] Wrong value for ComplexExtension: got: %v want: %v\n", test.name, ext, test.ext)
  378. }
  379. }
  380. }
  381. func TestUnmarshalRepeatingNonRepeatedExtension(t *testing.T) {
  382. // We may see multiple instances of the same extension in the wire
  383. // format. For example, the proto compiler may encode custom options in
  384. // this way. Here, we verify that we merge the extensions together.
  385. tests := []struct {
  386. name string
  387. ext []*pb.ComplexExtension
  388. }{
  389. {
  390. "two fields",
  391. []*pb.ComplexExtension{
  392. {First: proto.Int32(7)},
  393. {Second: proto.Int32(11)},
  394. },
  395. },
  396. {
  397. "repeated field",
  398. []*pb.ComplexExtension{
  399. {Third: []int32{1000}},
  400. {Third: []int32{2000}},
  401. },
  402. },
  403. {
  404. "two fields and repeated field",
  405. []*pb.ComplexExtension{
  406. {Third: []int32{1000}},
  407. {First: proto.Int32(9)},
  408. {Second: proto.Int32(21)},
  409. {Third: []int32{2000}},
  410. },
  411. },
  412. }
  413. for _, test := range tests {
  414. var buf bytes.Buffer
  415. var want pb.ComplexExtension
  416. // Generate a serialized representation of a repeated extension
  417. // by catenating bytes together.
  418. for i, e := range test.ext {
  419. // Merge to create the wanted proto.
  420. proto.Merge(&want, e)
  421. // serialize the message
  422. msg := new(pb.OtherMessage)
  423. err := proto.SetExtension(msg, pb.E_Complex, e)
  424. if err != nil {
  425. t.Fatalf("[%s] Error setting extension %d: %v", test.name, i, err)
  426. }
  427. b, err := proto.Marshal(msg)
  428. if err != nil {
  429. t.Fatalf("[%s] Error marshaling message %d: %v", test.name, i, err)
  430. }
  431. buf.Write(b)
  432. }
  433. // Unmarshal and read the merged proto.
  434. msg2 := new(pb.OtherMessage)
  435. err := proto.Unmarshal(buf.Bytes(), msg2)
  436. if err != nil {
  437. t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err)
  438. }
  439. e, err := proto.GetExtension(msg2, pb.E_Complex)
  440. if err != nil {
  441. t.Fatalf("[%s] Error getting extension: %v", test.name, err)
  442. }
  443. ext := e.(*pb.ComplexExtension)
  444. if ext == nil {
  445. t.Fatalf("[%s] Invalid extension", test.name)
  446. }
  447. if !reflect.DeepEqual(*ext, want) {
  448. t.Errorf("[%s] Wrong value for ComplexExtension: got: %s want: %s\n", test.name, ext, want)
  449. }
  450. }
  451. }
  452. func TestClearAllExtensions(t *testing.T) {
  453. // unregistered extension
  454. desc := &proto.ExtensionDesc{
  455. ExtendedType: (*pb.MyMessage)(nil),
  456. ExtensionType: (*bool)(nil),
  457. Field: 101010100,
  458. Name: "emptyextension",
  459. Tag: "varint,0,opt",
  460. }
  461. m := &pb.MyMessage{}
  462. if proto.HasExtension(m, desc) {
  463. t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m))
  464. }
  465. if err := proto.SetExtension(m, desc, proto.Bool(true)); err != nil {
  466. t.Errorf("proto.SetExtension(m, desc, true): got error %q, want nil", err)
  467. }
  468. if !proto.HasExtension(m, desc) {
  469. t.Errorf("proto.HasExtension(%s): got false, want true", proto.MarshalTextString(m))
  470. }
  471. proto.ClearAllExtensions(m)
  472. if proto.HasExtension(m, desc) {
  473. t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m))
  474. }
  475. }