encode_test.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. package proto_test
  2. import (
  3. "bytes"
  4. "fmt"
  5. "reflect"
  6. "testing"
  7. protoV1 "github.com/golang/protobuf/proto"
  8. "github.com/golang/protobuf/v2/proto"
  9. pref "github.com/golang/protobuf/v2/reflect/protoreflect"
  10. "github.com/google/go-cmp/cmp"
  11. test3pb "github.com/golang/protobuf/v2/internal/testprotos/test3"
  12. )
  13. func TestEncode(t *testing.T) {
  14. for _, test := range testProtos {
  15. for _, want := range test.decodeTo {
  16. t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
  17. opts := proto.MarshalOptions{
  18. AllowPartial: test.partial,
  19. }
  20. wire, err := opts.Marshal(want)
  21. if err != nil {
  22. t.Fatalf("Marshal error: %v\nMessage:\n%v", err, marshalText(want))
  23. }
  24. size := proto.Size(want)
  25. if size != len(wire) {
  26. t.Errorf("Size and marshal disagree: Size(m)=%v; len(Marshal(m))=%v\nMessage:\n%v", size, len(wire), marshalText(want))
  27. }
  28. got := newMessage(want)
  29. uopts := proto.UnmarshalOptions{
  30. AllowPartial: test.partial,
  31. }
  32. if err := uopts.Unmarshal(wire, got); err != nil {
  33. t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, protoV1.MarshalTextString(want.(protoV1.Message)))
  34. return
  35. }
  36. if test.invalidExtensions {
  37. // Equal doesn't work on messages containing invalid extension data.
  38. return
  39. }
  40. if !proto.Equal(got, want) {
  41. t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", protoV1.MarshalTextString(got.(protoV1.Message)), protoV1.MarshalTextString(want.(protoV1.Message)))
  42. }
  43. })
  44. }
  45. }
  46. }
  47. func TestEncodeDeterministic(t *testing.T) {
  48. for _, test := range testProtos {
  49. for _, want := range test.decodeTo {
  50. t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
  51. opts := proto.MarshalOptions{
  52. Deterministic: true,
  53. AllowPartial: test.partial,
  54. }
  55. wire, err := opts.Marshal(want)
  56. if err != nil {
  57. t.Fatalf("Marshal error: %v\nMessage:\n%v", err, marshalText(want))
  58. }
  59. wire2, err := opts.Marshal(want)
  60. if err != nil {
  61. t.Fatalf("Marshal error: %v\nMessage:\n%v", err, marshalText(want))
  62. }
  63. if !bytes.Equal(wire, wire2) {
  64. t.Fatalf("deterministic marshal returned varying results:\n%v", cmp.Diff(wire, wire2))
  65. }
  66. got := newMessage(want)
  67. uopts := proto.UnmarshalOptions{
  68. AllowPartial: test.partial,
  69. }
  70. if err := uopts.Unmarshal(wire, got); err != nil {
  71. t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, marshalText(want))
  72. return
  73. }
  74. if test.invalidExtensions {
  75. // Equal doesn't work on messages containing invalid extension data.
  76. return
  77. }
  78. if !proto.Equal(got, want) {
  79. t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(got), marshalText(want))
  80. }
  81. })
  82. }
  83. }
  84. }
  85. func TestEncodeInvalidUTF8(t *testing.T) {
  86. for _, test := range invalidUTF8TestProtos {
  87. for _, want := range test.decodeTo {
  88. t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
  89. wire, err := proto.Marshal(want)
  90. if !isErrInvalidUTF8(err) {
  91. t.Errorf("Marshal did not return expected error for invalid UTF8: %v\nMessage:\n%v", err, marshalText(want))
  92. }
  93. got := newMessage(want)
  94. if err := proto.Unmarshal(wire, got); !isErrInvalidUTF8(err) {
  95. t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, marshalText(want))
  96. return
  97. }
  98. if !proto.Equal(got, want) {
  99. t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(got), marshalText(want))
  100. }
  101. })
  102. }
  103. }
  104. }
  105. func TestEncodeRequiredFieldChecks(t *testing.T) {
  106. for _, test := range testProtos {
  107. if !test.partial {
  108. continue
  109. }
  110. for _, m := range test.decodeTo {
  111. t.Run(fmt.Sprintf("%s (%T)", test.desc, m), func(t *testing.T) {
  112. _, err := proto.Marshal(m)
  113. if err == nil {
  114. t.Fatalf("Marshal succeeded (want error)\nMessage:\n%v", marshalText(m))
  115. }
  116. })
  117. }
  118. }
  119. }
  120. func TestMarshalAppend(t *testing.T) {
  121. want := []byte("prefix")
  122. got := append([]byte(nil), want...)
  123. got, err := proto.MarshalOptions{}.MarshalAppend(got, &test3pb.TestAllTypes{
  124. OptionalString: "value",
  125. })
  126. if err != nil {
  127. t.Fatal(err)
  128. }
  129. if !bytes.HasPrefix(got, want) {
  130. t.Fatalf("MarshalAppend modified prefix: got %v, want prefix %v", got, want)
  131. }
  132. }
  133. // newMessage returns a new message with the same type and extension fields as m.
  134. func newMessage(m proto.Message) proto.Message {
  135. n := reflect.New(reflect.TypeOf(m).Elem()).Interface().(proto.Message)
  136. m.ProtoReflect().KnownFields().ExtensionTypes().Range(func(xt pref.ExtensionType) bool {
  137. n.ProtoReflect().KnownFields().ExtensionTypes().Register(xt)
  138. return true
  139. })
  140. return n
  141. }