extensions_test.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430
  1. // Go support for Protocol Buffers - Google's data interchange format
  2. //
  3. // Copyright 2014 The Go Authors. All rights reserved.
  4. // https://github.com/golang/protobuf
  5. //
  6. // Redistribution and use in source and binary forms, with or without
  7. // modification, are permitted provided that the following conditions are
  8. // met:
  9. //
  10. // * Redistributions of source code must retain the above copyright
  11. // notice, this list of conditions and the following disclaimer.
  12. // * Redistributions in binary form must reproduce the above
  13. // copyright notice, this list of conditions and the following disclaimer
  14. // in the documentation and/or other materials provided with the
  15. // distribution.
  16. // * Neither the name of Google Inc. nor the names of its
  17. // contributors may be used to endorse or promote products derived from
  18. // this software without specific prior written permission.
  19. //
  20. // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
  21. // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
  22. // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
  23. // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
  24. // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
  25. // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
  26. // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  27. // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
  28. // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  29. // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  30. // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  31. package proto_test
  32. import (
  33. "bytes"
  34. "fmt"
  35. "reflect"
  36. "testing"
  37. "github.com/golang/protobuf/proto"
  38. pb "github.com/golang/protobuf/proto/testdata"
  39. )
  40. func TestGetExtensionsWithMissingExtensions(t *testing.T) {
  41. msg := &pb.MyMessage{}
  42. ext1 := &pb.Ext{}
  43. if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil {
  44. t.Fatalf("Could not set ext1: %s", ext1)
  45. }
  46. exts, err := proto.GetExtensions(msg, []*proto.ExtensionDesc{
  47. pb.E_Ext_More,
  48. pb.E_Ext_Text,
  49. })
  50. if err != nil {
  51. t.Fatalf("GetExtensions() failed: %s", err)
  52. }
  53. if exts[0] != ext1 {
  54. t.Errorf("ext1 not in returned extensions: %T %v", exts[0], exts[0])
  55. }
  56. if exts[1] != nil {
  57. t.Errorf("ext2 in returned extensions: %T %v", exts[1], exts[1])
  58. }
  59. }
  60. func TestGetExtensionStability(t *testing.T) {
  61. check := func(m *pb.MyMessage) bool {
  62. ext1, err := proto.GetExtension(m, pb.E_Ext_More)
  63. if err != nil {
  64. t.Fatalf("GetExtension() failed: %s", err)
  65. }
  66. ext2, err := proto.GetExtension(m, pb.E_Ext_More)
  67. if err != nil {
  68. t.Fatalf("GetExtension() failed: %s", err)
  69. }
  70. return ext1 == ext2
  71. }
  72. msg := &pb.MyMessage{Count: proto.Int32(4)}
  73. ext0 := &pb.Ext{}
  74. if err := proto.SetExtension(msg, pb.E_Ext_More, ext0); err != nil {
  75. t.Fatalf("Could not set ext1: %s", ext0)
  76. }
  77. if !check(msg) {
  78. t.Errorf("GetExtension() not stable before marshaling")
  79. }
  80. bb, err := proto.Marshal(msg)
  81. if err != nil {
  82. t.Fatalf("Marshal() failed: %s", err)
  83. }
  84. msg1 := &pb.MyMessage{}
  85. err = proto.Unmarshal(bb, msg1)
  86. if err != nil {
  87. t.Fatalf("Unmarshal() failed: %s", err)
  88. }
  89. if !check(msg1) {
  90. t.Errorf("GetExtension() not stable after unmarshaling")
  91. }
  92. }
  93. func TestGetExtensionDefaults(t *testing.T) {
  94. var setFloat64 float64 = 1
  95. var setFloat32 float32 = 2
  96. var setInt32 int32 = 3
  97. var setInt64 int64 = 4
  98. var setUint32 uint32 = 5
  99. var setUint64 uint64 = 6
  100. var setBool = true
  101. var setBool2 = false
  102. var setString = "Goodnight string"
  103. var setBytes = []byte("Goodnight bytes")
  104. var setEnum = pb.DefaultsMessage_TWO
  105. type testcase struct {
  106. ext *proto.ExtensionDesc // Extension we are testing.
  107. want interface{} // Expected value of extension, or nil (meaning that GetExtension will fail).
  108. def interface{} // Expected value of extension after ClearExtension().
  109. }
  110. tests := []testcase{
  111. {pb.E_NoDefaultDouble, setFloat64, nil},
  112. {pb.E_NoDefaultFloat, setFloat32, nil},
  113. {pb.E_NoDefaultInt32, setInt32, nil},
  114. {pb.E_NoDefaultInt64, setInt64, nil},
  115. {pb.E_NoDefaultUint32, setUint32, nil},
  116. {pb.E_NoDefaultUint64, setUint64, nil},
  117. {pb.E_NoDefaultSint32, setInt32, nil},
  118. {pb.E_NoDefaultSint64, setInt64, nil},
  119. {pb.E_NoDefaultFixed32, setUint32, nil},
  120. {pb.E_NoDefaultFixed64, setUint64, nil},
  121. {pb.E_NoDefaultSfixed32, setInt32, nil},
  122. {pb.E_NoDefaultSfixed64, setInt64, nil},
  123. {pb.E_NoDefaultBool, setBool, nil},
  124. {pb.E_NoDefaultBool, setBool2, nil},
  125. {pb.E_NoDefaultString, setString, nil},
  126. {pb.E_NoDefaultBytes, setBytes, nil},
  127. {pb.E_NoDefaultEnum, setEnum, nil},
  128. {pb.E_DefaultDouble, setFloat64, float64(3.1415)},
  129. {pb.E_DefaultFloat, setFloat32, float32(3.14)},
  130. {pb.E_DefaultInt32, setInt32, int32(42)},
  131. {pb.E_DefaultInt64, setInt64, int64(43)},
  132. {pb.E_DefaultUint32, setUint32, uint32(44)},
  133. {pb.E_DefaultUint64, setUint64, uint64(45)},
  134. {pb.E_DefaultSint32, setInt32, int32(46)},
  135. {pb.E_DefaultSint64, setInt64, int64(47)},
  136. {pb.E_DefaultFixed32, setUint32, uint32(48)},
  137. {pb.E_DefaultFixed64, setUint64, uint64(49)},
  138. {pb.E_DefaultSfixed32, setInt32, int32(50)},
  139. {pb.E_DefaultSfixed64, setInt64, int64(51)},
  140. {pb.E_DefaultBool, setBool, true},
  141. {pb.E_DefaultBool, setBool2, true},
  142. {pb.E_DefaultString, setString, "Hello, string"},
  143. {pb.E_DefaultBytes, setBytes, []byte("Hello, bytes")},
  144. {pb.E_DefaultEnum, setEnum, pb.DefaultsMessage_ONE},
  145. }
  146. checkVal := func(test testcase, msg *pb.DefaultsMessage, valWant interface{}) error {
  147. val, err := proto.GetExtension(msg, test.ext)
  148. if err != nil {
  149. if valWant != nil {
  150. return fmt.Errorf("GetExtension(): %s", err)
  151. }
  152. if want := proto.ErrMissingExtension; err != want {
  153. return fmt.Errorf("Unexpected error: got %v, want %v", err, want)
  154. }
  155. return nil
  156. }
  157. // All proto2 extension values are either a pointer to a value or a slice of values.
  158. ty := reflect.TypeOf(val)
  159. tyWant := reflect.TypeOf(test.ext.ExtensionType)
  160. if got, want := ty, tyWant; got != want {
  161. return fmt.Errorf("unexpected reflect.TypeOf(): got %v want %v", got, want)
  162. }
  163. tye := ty.Elem()
  164. tyeWant := tyWant.Elem()
  165. if got, want := tye, tyeWant; got != want {
  166. return fmt.Errorf("unexpected reflect.TypeOf().Elem(): got %v want %v", got, want)
  167. }
  168. // Check the name of the type of the value.
  169. // If it is an enum it will be type int32 with the name of the enum.
  170. if got, want := tye.Name(), tye.Name(); got != want {
  171. return fmt.Errorf("unexpected reflect.TypeOf().Elem().Name(): got %v want %v", got, want)
  172. }
  173. // Check that value is what we expect.
  174. // If we have a pointer in val, get the value it points to.
  175. valExp := val
  176. if ty.Kind() == reflect.Ptr {
  177. valExp = reflect.ValueOf(val).Elem().Interface()
  178. }
  179. if got, want := valExp, valWant; !reflect.DeepEqual(got, want) {
  180. return fmt.Errorf("unexpected reflect.DeepEqual(): got %v want %v", got, want)
  181. }
  182. return nil
  183. }
  184. setTo := func(test testcase) interface{} {
  185. setTo := reflect.ValueOf(test.want)
  186. if typ := reflect.TypeOf(test.ext.ExtensionType); typ.Kind() == reflect.Ptr {
  187. setTo = reflect.New(typ).Elem()
  188. setTo.Set(reflect.New(setTo.Type().Elem()))
  189. setTo.Elem().Set(reflect.ValueOf(test.want))
  190. }
  191. return setTo.Interface()
  192. }
  193. for _, test := range tests {
  194. msg := &pb.DefaultsMessage{}
  195. name := test.ext.Name
  196. // Check the initial value.
  197. if err := checkVal(test, msg, test.def); err != nil {
  198. t.Errorf("%s: %v", name, err)
  199. }
  200. // Set the per-type value and check value.
  201. name = fmt.Sprintf("%s (set to %T %v)", name, test.want, test.want)
  202. if err := proto.SetExtension(msg, test.ext, setTo(test)); err != nil {
  203. t.Errorf("%s: SetExtension(): %v", name, err)
  204. continue
  205. }
  206. if err := checkVal(test, msg, test.want); err != nil {
  207. t.Errorf("%s: %v", name, err)
  208. continue
  209. }
  210. // Set and check the value.
  211. name += " (cleared)"
  212. proto.ClearExtension(msg, test.ext)
  213. if err := checkVal(test, msg, test.def); err != nil {
  214. t.Errorf("%s: %v", name, err)
  215. }
  216. }
  217. }
  218. func TestExtensionsRoundTrip(t *testing.T) {
  219. msg := &pb.MyMessage{}
  220. ext1 := &pb.Ext{
  221. Data: proto.String("hi"),
  222. }
  223. ext2 := &pb.Ext{
  224. Data: proto.String("there"),
  225. }
  226. exists := proto.HasExtension(msg, pb.E_Ext_More)
  227. if exists {
  228. t.Error("Extension More present unexpectedly")
  229. }
  230. if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil {
  231. t.Error(err)
  232. }
  233. if err := proto.SetExtension(msg, pb.E_Ext_More, ext2); err != nil {
  234. t.Error(err)
  235. }
  236. e, err := proto.GetExtension(msg, pb.E_Ext_More)
  237. if err != nil {
  238. t.Error(err)
  239. }
  240. x, ok := e.(*pb.Ext)
  241. if !ok {
  242. t.Errorf("e has type %T, expected testdata.Ext", e)
  243. } else if *x.Data != "there" {
  244. t.Errorf("SetExtension failed to overwrite, got %+v, not 'there'", x)
  245. }
  246. proto.ClearExtension(msg, pb.E_Ext_More)
  247. if _, err = proto.GetExtension(msg, pb.E_Ext_More); err != proto.ErrMissingExtension {
  248. t.Errorf("got %v, expected ErrMissingExtension", e)
  249. }
  250. if _, err := proto.GetExtension(msg, pb.E_X215); err == nil {
  251. t.Error("expected bad extension error, got nil")
  252. }
  253. if err := proto.SetExtension(msg, pb.E_X215, 12); err == nil {
  254. t.Error("expected extension err")
  255. }
  256. if err := proto.SetExtension(msg, pb.E_Ext_More, 12); err == nil {
  257. t.Error("expected some sort of type mismatch error, got nil")
  258. }
  259. }
  260. func TestNilExtension(t *testing.T) {
  261. msg := &pb.MyMessage{
  262. Count: proto.Int32(1),
  263. }
  264. if err := proto.SetExtension(msg, pb.E_Ext_Text, proto.String("hello")); err != nil {
  265. t.Fatal(err)
  266. }
  267. if err := proto.SetExtension(msg, pb.E_Ext_More, (*pb.Ext)(nil)); err == nil {
  268. t.Error("expected SetExtension to fail due to a nil extension")
  269. } else if want := "proto: SetExtension called with nil value of type *testdata.Ext"; err.Error() != want {
  270. t.Errorf("expected error %v, got %v", want, err)
  271. }
  272. // Note: if the behavior of Marshal is ever changed to ignore nil extensions, update
  273. // this test to verify that E_Ext_Text is properly propagated through marshal->unmarshal.
  274. }
  275. func TestMarshalUnmarshalRepeatedExtension(t *testing.T) {
  276. // Add a repeated extension to the result.
  277. tests := []struct {
  278. name string
  279. ext []*pb.ComplexExtension
  280. }{
  281. {
  282. "two fields",
  283. []*pb.ComplexExtension{
  284. {First: proto.Int32(7)},
  285. {Second: proto.Int32(11)},
  286. },
  287. },
  288. {
  289. "repeated field",
  290. []*pb.ComplexExtension{
  291. {Third: []int32{1000}},
  292. {Third: []int32{2000}},
  293. },
  294. },
  295. {
  296. "two fields and repeated field",
  297. []*pb.ComplexExtension{
  298. {Third: []int32{1000}},
  299. {First: proto.Int32(9)},
  300. {Second: proto.Int32(21)},
  301. {Third: []int32{2000}},
  302. },
  303. },
  304. }
  305. for _, test := range tests {
  306. // Marshal message with a repeated extension.
  307. msg1 := new(pb.OtherMessage)
  308. err := proto.SetExtension(msg1, pb.E_RComplex, test.ext)
  309. if err != nil {
  310. t.Fatalf("[%s] Error setting extension: %v", test.name, err)
  311. }
  312. b, err := proto.Marshal(msg1)
  313. if err != nil {
  314. t.Fatalf("[%s] Error marshaling message: %v", test.name, err)
  315. }
  316. // Unmarshal and read the merged proto.
  317. msg2 := new(pb.OtherMessage)
  318. err = proto.Unmarshal(b, msg2)
  319. if err != nil {
  320. t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err)
  321. }
  322. e, err := proto.GetExtension(msg2, pb.E_RComplex)
  323. if err != nil {
  324. t.Fatalf("[%s] Error getting extension: %v", test.name, err)
  325. }
  326. ext := e.([]*pb.ComplexExtension)
  327. if ext == nil {
  328. t.Fatalf("[%s] Invalid extension", test.name)
  329. }
  330. if !reflect.DeepEqual(ext, test.ext) {
  331. t.Errorf("[%s] Wrong value for ComplexExtension: got: %v want: %v\n", test.name, ext, test.ext)
  332. }
  333. }
  334. }
  335. func TestUnmarshalRepeatingNonRepeatedExtension(t *testing.T) {
  336. // We may see multiple instances of the same extension in the wire
  337. // format. For example, the proto compiler may encode custom options in
  338. // this way. Here, we verify that we merge the extensions together.
  339. tests := []struct {
  340. name string
  341. ext []*pb.ComplexExtension
  342. }{
  343. {
  344. "two fields",
  345. []*pb.ComplexExtension{
  346. {First: proto.Int32(7)},
  347. {Second: proto.Int32(11)},
  348. },
  349. },
  350. {
  351. "repeated field",
  352. []*pb.ComplexExtension{
  353. {Third: []int32{1000}},
  354. {Third: []int32{2000}},
  355. },
  356. },
  357. {
  358. "two fields and repeated field",
  359. []*pb.ComplexExtension{
  360. {Third: []int32{1000}},
  361. {First: proto.Int32(9)},
  362. {Second: proto.Int32(21)},
  363. {Third: []int32{2000}},
  364. },
  365. },
  366. }
  367. for _, test := range tests {
  368. var buf bytes.Buffer
  369. var want pb.ComplexExtension
  370. // Generate a serialized representation of a repeated extension
  371. // by catenating bytes together.
  372. for i, e := range test.ext {
  373. // Merge to create the wanted proto.
  374. proto.Merge(&want, e)
  375. // serialize the message
  376. msg := new(pb.OtherMessage)
  377. err := proto.SetExtension(msg, pb.E_Complex, e)
  378. if err != nil {
  379. t.Fatalf("[%s] Error setting extension %d: %v", test.name, i, err)
  380. }
  381. b, err := proto.Marshal(msg)
  382. if err != nil {
  383. t.Fatalf("[%s] Error marshaling message %d: %v", test.name, i, err)
  384. }
  385. buf.Write(b)
  386. }
  387. // Unmarshal and read the merged proto.
  388. msg2 := new(pb.OtherMessage)
  389. err := proto.Unmarshal(buf.Bytes(), msg2)
  390. if err != nil {
  391. t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err)
  392. }
  393. e, err := proto.GetExtension(msg2, pb.E_Complex)
  394. if err != nil {
  395. t.Fatalf("[%s] Error getting extension: %v", test.name, err)
  396. }
  397. ext := e.(*pb.ComplexExtension)
  398. if ext == nil {
  399. t.Fatalf("[%s] Invalid extension", test.name)
  400. }
  401. if !reflect.DeepEqual(*ext, want) {
  402. t.Errorf("[%s] Wrong value for ComplexExtension: got: %s want: %s\n", test.name, ext, want)
  403. }
  404. }
  405. }