encode_test.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. // Copyright 2019 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package proto_test
  5. import (
  6. "bytes"
  7. "fmt"
  8. "reflect"
  9. "testing"
  10. protoV1 "github.com/golang/protobuf/proto"
  11. "github.com/google/go-cmp/cmp"
  12. "google.golang.org/protobuf/proto"
  13. pref "google.golang.org/protobuf/reflect/protoreflect"
  14. test3pb "google.golang.org/protobuf/internal/testprotos/test3"
  15. )
  16. func TestEncode(t *testing.T) {
  17. for _, test := range testProtos {
  18. for _, want := range test.decodeTo {
  19. t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
  20. opts := proto.MarshalOptions{
  21. AllowPartial: test.partial,
  22. }
  23. wire, err := opts.Marshal(want)
  24. if err != nil {
  25. t.Fatalf("Marshal error: %v\nMessage:\n%v", err, marshalText(want))
  26. }
  27. size := proto.Size(want)
  28. if size != len(wire) {
  29. t.Errorf("Size and marshal disagree: Size(m)=%v; len(Marshal(m))=%v\nMessage:\n%v", size, len(wire), marshalText(want))
  30. }
  31. got := newMessage(want)
  32. uopts := proto.UnmarshalOptions{
  33. AllowPartial: test.partial,
  34. }
  35. if err := uopts.Unmarshal(wire, got); err != nil {
  36. t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, protoV1.MarshalTextString(want.(protoV1.Message)))
  37. return
  38. }
  39. if test.invalidExtensions {
  40. // Equal doesn't work on messages containing invalid extension data.
  41. return
  42. }
  43. if !proto.Equal(got, want) {
  44. t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", protoV1.MarshalTextString(got.(protoV1.Message)), protoV1.MarshalTextString(want.(protoV1.Message)))
  45. }
  46. })
  47. }
  48. }
  49. }
  50. func TestEncodeDeterministic(t *testing.T) {
  51. for _, test := range testProtos {
  52. for _, want := range test.decodeTo {
  53. t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
  54. opts := proto.MarshalOptions{
  55. Deterministic: true,
  56. AllowPartial: test.partial,
  57. }
  58. wire, err := opts.Marshal(want)
  59. if err != nil {
  60. t.Fatalf("Marshal error: %v\nMessage:\n%v", err, marshalText(want))
  61. }
  62. wire2, err := opts.Marshal(want)
  63. if err != nil {
  64. t.Fatalf("Marshal error: %v\nMessage:\n%v", err, marshalText(want))
  65. }
  66. if !bytes.Equal(wire, wire2) {
  67. t.Fatalf("deterministic marshal returned varying results:\n%v", cmp.Diff(wire, wire2))
  68. }
  69. got := newMessage(want)
  70. uopts := proto.UnmarshalOptions{
  71. AllowPartial: test.partial,
  72. }
  73. if err := uopts.Unmarshal(wire, got); err != nil {
  74. t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, marshalText(want))
  75. return
  76. }
  77. if test.invalidExtensions {
  78. // Equal doesn't work on messages containing invalid extension data.
  79. return
  80. }
  81. if !proto.Equal(got, want) {
  82. t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(got), marshalText(want))
  83. }
  84. })
  85. }
  86. }
  87. }
  88. func TestEncodeInvalidUTF8(t *testing.T) {
  89. for _, test := range invalidUTF8TestProtos {
  90. for _, want := range test.decodeTo {
  91. t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
  92. wire, err := proto.Marshal(want)
  93. if !isErrInvalidUTF8(err) {
  94. t.Errorf("Marshal did not return expected error for invalid UTF8: %v\nMessage:\n%v", err, marshalText(want))
  95. }
  96. got := newMessage(want)
  97. if err := proto.Unmarshal(wire, got); !isErrInvalidUTF8(err) {
  98. t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, marshalText(want))
  99. return
  100. }
  101. if !proto.Equal(got, want) {
  102. t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(got), marshalText(want))
  103. }
  104. })
  105. }
  106. }
  107. }
  108. func TestEncodeRequiredFieldChecks(t *testing.T) {
  109. for _, test := range testProtos {
  110. if !test.partial {
  111. continue
  112. }
  113. for _, m := range test.decodeTo {
  114. t.Run(fmt.Sprintf("%s (%T)", test.desc, m), func(t *testing.T) {
  115. _, err := proto.Marshal(m)
  116. if err == nil {
  117. t.Fatalf("Marshal succeeded (want error)\nMessage:\n%v", marshalText(m))
  118. }
  119. })
  120. }
  121. }
  122. }
  123. func TestMarshalAppend(t *testing.T) {
  124. want := []byte("prefix")
  125. got := append([]byte(nil), want...)
  126. got, err := proto.MarshalOptions{}.MarshalAppend(got, &test3pb.TestAllTypes{
  127. OptionalString: "value",
  128. })
  129. if err != nil {
  130. t.Fatal(err)
  131. }
  132. if !bytes.HasPrefix(got, want) {
  133. t.Fatalf("MarshalAppend modified prefix: got %v, want prefix %v", got, want)
  134. }
  135. }
  136. // newMessage returns a new message with the same type and extension fields as m.
  137. func newMessage(m proto.Message) proto.Message {
  138. n := reflect.New(reflect.TypeOf(m).Elem()).Interface().(proto.Message)
  139. m.ProtoReflect().KnownFields().ExtensionTypes().Range(func(xt pref.ExtensionType) bool {
  140. n.ProtoReflect().KnownFields().ExtensionTypes().Register(xt)
  141. return true
  142. })
  143. return n
  144. }