messages_test.go 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. // Copyright 2011 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package ssh
  5. import (
  6. "math/big"
  7. "math/rand"
  8. "reflect"
  9. "testing"
  10. "testing/quick"
  11. )
  12. var intLengthTests = []struct {
  13. val, length int
  14. }{
  15. {0, 4 + 0},
  16. {1, 4 + 1},
  17. {127, 4 + 1},
  18. {128, 4 + 2},
  19. {-1, 4 + 1},
  20. }
  21. func TestIntLength(t *testing.T) {
  22. for _, test := range intLengthTests {
  23. v := new(big.Int).SetInt64(int64(test.val))
  24. length := intLength(v)
  25. if length != test.length {
  26. t.Errorf("For %d, got length %d but expected %d", test.val, length, test.length)
  27. }
  28. }
  29. }
  30. var messageTypes = []interface{}{
  31. &kexInitMsg{},
  32. &kexDHInitMsg{},
  33. &serviceRequestMsg{},
  34. &serviceAcceptMsg{},
  35. &userAuthRequestMsg{},
  36. &channelOpenMsg{},
  37. &channelOpenConfirmMsg{},
  38. &channelOpenFailureMsg{},
  39. &channelRequestMsg{},
  40. &channelRequestSuccessMsg{},
  41. }
  42. func TestMarshalUnmarshal(t *testing.T) {
  43. rand := rand.New(rand.NewSource(0))
  44. for i, iface := range messageTypes {
  45. ty := reflect.ValueOf(iface).Type()
  46. n := 100
  47. if testing.Short() {
  48. n = 5
  49. }
  50. for j := 0; j < n; j++ {
  51. v, ok := quick.Value(ty, rand)
  52. if !ok {
  53. t.Errorf("#%d: failed to create value", i)
  54. break
  55. }
  56. m1 := v.Elem().Interface()
  57. m2 := iface
  58. marshaled := marshal(msgIgnore, m1)
  59. if err := unmarshal(m2, marshaled, msgIgnore); err != nil {
  60. t.Errorf("#%d failed to unmarshal %#v: %s", i, m1, err)
  61. break
  62. }
  63. if !reflect.DeepEqual(v.Interface(), m2) {
  64. t.Errorf("#%d\ngot: %#v\nwant:%#v\n%x", i, m2, m1, marshaled)
  65. break
  66. }
  67. }
  68. }
  69. }
  70. func TestUnmarshalEmptyPacket(t *testing.T) {
  71. var b []byte
  72. var m channelRequestSuccessMsg
  73. err := unmarshal(&m, b, msgChannelRequest)
  74. want := ParseError{msgChannelRequest}
  75. if _, ok := err.(ParseError); !ok {
  76. t.Fatalf("got %T, want %T", err, want)
  77. }
  78. if got := err.(ParseError); want != got {
  79. t.Fatal("got %#v, want %#v", got, want)
  80. }
  81. }
  82. func TestUnmarshalUnexpectedPacket(t *testing.T) {
  83. type S struct {
  84. I uint32
  85. S string
  86. B bool
  87. }
  88. s := S{42, "hello", true}
  89. packet := marshal(42, s)
  90. roundtrip := S{}
  91. err := unmarshal(&roundtrip, packet, 43)
  92. if err == nil {
  93. t.Fatal("expected error, not nil")
  94. }
  95. want := UnexpectedMessageError{43, 42}
  96. if got, ok := err.(UnexpectedMessageError); !ok || want != got {
  97. t.Fatal("expected %q, got %q", want, got)
  98. }
  99. }
  100. func TestBareMarshalUnmarshal(t *testing.T) {
  101. type S struct {
  102. I uint32
  103. S string
  104. B bool
  105. }
  106. s := S{42, "hello", true}
  107. packet := marshal(0, s)
  108. roundtrip := S{}
  109. unmarshal(&roundtrip, packet, 0)
  110. if !reflect.DeepEqual(s, roundtrip) {
  111. t.Errorf("got %#v, want %#v", roundtrip, s)
  112. }
  113. }
  114. func TestBareMarshal(t *testing.T) {
  115. type S2 struct {
  116. I uint32
  117. }
  118. s := S2{42}
  119. packet := marshal(0, s)
  120. i, rest, ok := parseUint32(packet)
  121. if len(rest) > 0 || !ok {
  122. t.Errorf("parseInt(%q): parse error", packet)
  123. }
  124. if i != s.I {
  125. t.Errorf("got %d, want %d", i, s.I)
  126. }
  127. }
  128. func randomBytes(out []byte, rand *rand.Rand) {
  129. for i := 0; i < len(out); i++ {
  130. out[i] = byte(rand.Int31())
  131. }
  132. }
  133. func randomNameList(rand *rand.Rand) []string {
  134. ret := make([]string, rand.Int31()&15)
  135. for i := range ret {
  136. s := make([]byte, 1+(rand.Int31()&15))
  137. for j := range s {
  138. s[j] = 'a' + uint8(rand.Int31()&15)
  139. }
  140. ret[i] = string(s)
  141. }
  142. return ret
  143. }
  144. func randomInt(rand *rand.Rand) *big.Int {
  145. return new(big.Int).SetInt64(int64(int32(rand.Uint32())))
  146. }
  147. func (*kexInitMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  148. ki := &kexInitMsg{}
  149. randomBytes(ki.Cookie[:], rand)
  150. ki.KexAlgos = randomNameList(rand)
  151. ki.ServerHostKeyAlgos = randomNameList(rand)
  152. ki.CiphersClientServer = randomNameList(rand)
  153. ki.CiphersServerClient = randomNameList(rand)
  154. ki.MACsClientServer = randomNameList(rand)
  155. ki.MACsServerClient = randomNameList(rand)
  156. ki.CompressionClientServer = randomNameList(rand)
  157. ki.CompressionServerClient = randomNameList(rand)
  158. ki.LanguagesClientServer = randomNameList(rand)
  159. ki.LanguagesServerClient = randomNameList(rand)
  160. if rand.Int31()&1 == 1 {
  161. ki.FirstKexFollows = true
  162. }
  163. return reflect.ValueOf(ki)
  164. }
  165. func (*kexDHInitMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  166. dhi := &kexDHInitMsg{}
  167. dhi.X = randomInt(rand)
  168. return reflect.ValueOf(dhi)
  169. }
  170. // TODO(dfc) maybe this can be removed in the future if testing/quick can handle
  171. // derived basic types.
  172. func (RejectionReason) Generate(rand *rand.Rand, size int) reflect.Value {
  173. m := RejectionReason(Prohibited)
  174. return reflect.ValueOf(m)
  175. }
  176. var (
  177. _kexInitMsg = new(kexInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface()
  178. _kexDHInitMsg = new(kexDHInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface()
  179. _kexInit = marshal(msgKexInit, _kexInitMsg)
  180. _kexDHInit = marshal(msgKexDHInit, _kexDHInitMsg)
  181. )
  182. func BenchmarkMarshalKexInitMsg(b *testing.B) {
  183. for i := 0; i < b.N; i++ {
  184. marshal(msgKexInit, _kexInitMsg)
  185. }
  186. }
  187. func BenchmarkUnmarshalKexInitMsg(b *testing.B) {
  188. m := new(kexInitMsg)
  189. for i := 0; i < b.N; i++ {
  190. unmarshal(m, _kexInit, msgKexInit)
  191. }
  192. }
  193. func BenchmarkMarshalKexDHInitMsg(b *testing.B) {
  194. for i := 0; i < b.N; i++ {
  195. marshal(msgKexDHInit, _kexDHInitMsg)
  196. }
  197. }
  198. func BenchmarkUnmarshalKexDHInitMsg(b *testing.B) {
  199. m := new(kexDHInitMsg)
  200. for i := 0; i < b.N; i++ {
  201. unmarshal(m, _kexDHInit, msgKexDHInit)
  202. }
  203. }