mux_test.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483
  1. // Copyright 2013 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package ssh
  5. import (
  6. "io"
  7. "io/ioutil"
  8. "sync"
  9. "testing"
  10. )
  11. func muxPair() (*mux, *mux) {
  12. a, b := memPipe()
  13. s := newMux(a)
  14. c := newMux(b)
  15. return s, c
  16. }
  17. // Returns both ends of a channel, and the mux for the the 2nd
  18. // channel.
  19. func channelPair(t *testing.T) (*channel, *channel, *mux) {
  20. c, s := muxPair()
  21. res := make(chan *channel, 1)
  22. go func() {
  23. newCh, ok := <-s.incomingChannels
  24. if !ok {
  25. t.Fatalf("No incoming channel")
  26. }
  27. if newCh.ChannelType() != "chan" {
  28. t.Fatalf("got type %q want chan", newCh.ChannelType())
  29. }
  30. ch, _, err := newCh.Accept()
  31. if err != nil {
  32. t.Fatalf("Accept %v", err)
  33. }
  34. res <- ch.(*channel)
  35. }()
  36. ch, err := c.openChannel("chan", nil)
  37. if err != nil {
  38. t.Fatalf("OpenChannel: %v", err)
  39. }
  40. return <-res, ch, c
  41. }
  42. func TestMuxReadWrite(t *testing.T) {
  43. s, c, mux := channelPair(t)
  44. defer s.Close()
  45. defer c.Close()
  46. defer mux.Close()
  47. magic := "hello world"
  48. magicExt := "hello stderr"
  49. go func() {
  50. _, err := s.Write([]byte(magic))
  51. if err != nil {
  52. t.Fatalf("Write: %v", err)
  53. }
  54. _, err = s.Extended(1).Write([]byte(magicExt))
  55. if err != nil {
  56. t.Fatalf("Write: %v", err)
  57. }
  58. err = s.Close()
  59. if err != nil {
  60. t.Fatalf("Close: %v", err)
  61. }
  62. }()
  63. var buf [1024]byte
  64. n, err := c.Read(buf[:])
  65. if err != nil {
  66. t.Fatalf("server Read: %v", err)
  67. }
  68. got := string(buf[:n])
  69. if got != magic {
  70. t.Fatalf("server: got %q want %q", got, magic)
  71. }
  72. n, err = c.Extended(1).Read(buf[:])
  73. if err != nil {
  74. t.Fatalf("server Read: %v", err)
  75. }
  76. got = string(buf[:n])
  77. if got != magicExt {
  78. t.Fatalf("server: got %q want %q", got, magic)
  79. }
  80. }
  81. func TestMuxChannelOverflow(t *testing.T) {
  82. reader, writer, mux := channelPair(t)
  83. defer reader.Close()
  84. defer writer.Close()
  85. defer mux.Close()
  86. wDone := make(chan int, 1)
  87. go func() {
  88. if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
  89. t.Errorf("could not fill window: %v", err)
  90. }
  91. writer.Write(make([]byte, 1))
  92. wDone <- 1
  93. }()
  94. writer.remoteWin.waitWriterBlocked()
  95. // Send 1 byte.
  96. packet := make([]byte, 1+4+4+1)
  97. packet[0] = msgChannelData
  98. marshalUint32(packet[1:], writer.remoteId)
  99. marshalUint32(packet[5:], uint32(1))
  100. packet[9] = 42
  101. if err := writer.mux.conn.writePacket(packet); err != nil {
  102. t.Errorf("could not send packet")
  103. }
  104. if _, err := reader.SendRequest("hello", true, nil); err == nil {
  105. t.Errorf("SendRequest succeeded.")
  106. }
  107. <-wDone
  108. }
  109. func TestMuxChannelCloseWriteUnblock(t *testing.T) {
  110. reader, writer, mux := channelPair(t)
  111. defer reader.Close()
  112. defer writer.Close()
  113. defer mux.Close()
  114. wDone := make(chan int, 1)
  115. go func() {
  116. if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
  117. t.Errorf("could not fill window: %v", err)
  118. }
  119. if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
  120. t.Errorf("got %v, want EOF for unblock write", err)
  121. }
  122. wDone <- 1
  123. }()
  124. writer.remoteWin.waitWriterBlocked()
  125. reader.Close()
  126. <-wDone
  127. }
  128. func TestMuxConnectionCloseWriteUnblock(t *testing.T) {
  129. reader, writer, mux := channelPair(t)
  130. defer reader.Close()
  131. defer writer.Close()
  132. defer mux.Close()
  133. wDone := make(chan int, 1)
  134. go func() {
  135. if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
  136. t.Errorf("could not fill window: %v", err)
  137. }
  138. if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
  139. t.Errorf("got %v, want EOF for unblock write", err)
  140. }
  141. wDone <- 1
  142. }()
  143. writer.remoteWin.waitWriterBlocked()
  144. mux.Close()
  145. <-wDone
  146. }
  147. func TestMuxReject(t *testing.T) {
  148. client, server := muxPair()
  149. defer server.Close()
  150. defer client.Close()
  151. go func() {
  152. ch, ok := <-server.incomingChannels
  153. if !ok {
  154. t.Fatalf("Accept")
  155. }
  156. if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" {
  157. t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData())
  158. }
  159. ch.Reject(RejectionReason(42), "message")
  160. }()
  161. ch, err := client.openChannel("ch", []byte("extra"))
  162. if ch != nil {
  163. t.Fatal("openChannel not rejected")
  164. }
  165. ocf, ok := err.(*OpenChannelError)
  166. if !ok {
  167. t.Errorf("got %#v want *OpenChannelError", err)
  168. } else if ocf.Reason != 42 || ocf.Message != "message" {
  169. t.Errorf("got %#v, want {Reason: 42, Message: %q}", ocf, "message")
  170. }
  171. want := "ssh: rejected: unknown reason 42 (message)"
  172. if err.Error() != want {
  173. t.Errorf("got %q, want %q", err.Error(), want)
  174. }
  175. }
  176. func TestMuxChannelRequest(t *testing.T) {
  177. client, server, mux := channelPair(t)
  178. defer server.Close()
  179. defer client.Close()
  180. defer mux.Close()
  181. var received int
  182. var wg sync.WaitGroup
  183. wg.Add(1)
  184. go func() {
  185. for r := range server.incomingRequests {
  186. received++
  187. r.Reply(r.Type == "yes", nil)
  188. }
  189. wg.Done()
  190. }()
  191. _, err := client.SendRequest("yes", false, nil)
  192. if err != nil {
  193. t.Fatalf("SendRequest: %v", err)
  194. }
  195. ok, err := client.SendRequest("yes", true, nil)
  196. if err != nil {
  197. t.Fatalf("SendRequest: %v", err)
  198. }
  199. if !ok {
  200. t.Errorf("SendRequest(yes): %v", ok)
  201. }
  202. ok, err = client.SendRequest("no", true, nil)
  203. if err != nil {
  204. t.Fatalf("SendRequest: %v", err)
  205. }
  206. if ok {
  207. t.Errorf("SendRequest(no): %v", ok)
  208. }
  209. client.Close()
  210. wg.Wait()
  211. if received != 3 {
  212. t.Errorf("got %d requests, want %d", received, 3)
  213. }
  214. }
  215. func TestMuxGlobalRequest(t *testing.T) {
  216. clientMux, serverMux := muxPair()
  217. defer serverMux.Close()
  218. defer clientMux.Close()
  219. var seen bool
  220. go func() {
  221. for r := range serverMux.incomingRequests {
  222. seen = seen || r.Type == "peek"
  223. if r.WantReply {
  224. err := r.Reply(r.Type == "yes",
  225. append([]byte(r.Type), r.Payload...))
  226. if err != nil {
  227. t.Errorf("AckRequest: %v", err)
  228. }
  229. }
  230. }
  231. }()
  232. _, _, err := clientMux.SendRequest("peek", false, nil)
  233. if err != nil {
  234. t.Errorf("SendRequest: %v", err)
  235. }
  236. ok, data, err := clientMux.SendRequest("yes", true, []byte("a"))
  237. if !ok || string(data) != "yesa" || err != nil {
  238. t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
  239. ok, data, err)
  240. }
  241. if ok, data, err := clientMux.SendRequest("yes", true, []byte("a")); !ok || string(data) != "yesa" || err != nil {
  242. t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
  243. ok, data, err)
  244. }
  245. if ok, data, err := clientMux.SendRequest("no", true, []byte("a")); ok || string(data) != "noa" || err != nil {
  246. t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v",
  247. ok, data, err)
  248. }
  249. clientMux.Disconnect(0, "")
  250. if !seen {
  251. t.Errorf("never saw 'peek' request")
  252. }
  253. }
  254. func TestMuxGlobalRequestUnblock(t *testing.T) {
  255. clientMux, serverMux := muxPair()
  256. defer serverMux.Close()
  257. defer clientMux.Close()
  258. result := make(chan error, 1)
  259. go func() {
  260. _, _, err := clientMux.SendRequest("hello", true, nil)
  261. result <- err
  262. }()
  263. <-serverMux.incomingRequests
  264. serverMux.conn.Close()
  265. err := <-result
  266. if err != io.EOF {
  267. t.Errorf("want EOF, got %v", io.EOF)
  268. }
  269. }
  270. func TestMuxChannelRequestUnblock(t *testing.T) {
  271. a, b, connB := channelPair(t)
  272. defer a.Close()
  273. defer b.Close()
  274. defer connB.Close()
  275. result := make(chan error, 1)
  276. go func() {
  277. _, err := a.SendRequest("hello", true, nil)
  278. result <- err
  279. }()
  280. <-b.incomingRequests
  281. connB.conn.Close()
  282. err := <-result
  283. if err != io.EOF {
  284. t.Errorf("want EOF, got %v", err)
  285. }
  286. }
  287. func TestMuxDisconnect(t *testing.T) {
  288. a, b := muxPair()
  289. defer a.Close()
  290. defer b.Close()
  291. go func() {
  292. for r := range b.incomingRequests {
  293. r.Reply(true, nil)
  294. }
  295. }()
  296. a.Disconnect(42, "whatever")
  297. ok, _, err := a.SendRequest("hello", true, nil)
  298. if ok || err == nil {
  299. t.Errorf("got reply after disconnecting")
  300. }
  301. err = b.Wait()
  302. if d, ok := err.(*disconnectMsg); !ok || d.Reason != 42 {
  303. t.Errorf("got %#v, want disconnectMsg{Reason:42}", err)
  304. }
  305. }
  306. func TestMuxCloseChannel(t *testing.T) {
  307. r, w, mux := channelPair(t)
  308. defer mux.Close()
  309. defer r.Close()
  310. defer w.Close()
  311. result := make(chan error, 1)
  312. go func() {
  313. var b [1024]byte
  314. _, err := r.Read(b[:])
  315. result <- err
  316. }()
  317. if err := w.Close(); err != nil {
  318. t.Errorf("w.Close: %v", err)
  319. }
  320. if _, err := w.Write([]byte("hello")); err != io.EOF {
  321. t.Errorf("got err %v, want io.EOF after Close", err)
  322. }
  323. if err := <-result; err != io.EOF {
  324. t.Errorf("got %v (%T), want io.EOF", err, err)
  325. }
  326. }
  327. func TestMuxCloseWriteChannel(t *testing.T) {
  328. r, w, mux := channelPair(t)
  329. defer mux.Close()
  330. result := make(chan error, 1)
  331. go func() {
  332. var b [1024]byte
  333. _, err := r.Read(b[:])
  334. result <- err
  335. }()
  336. if err := w.CloseWrite(); err != nil {
  337. t.Errorf("w.CloseWrite: %v", err)
  338. }
  339. if _, err := w.Write([]byte("hello")); err != io.EOF {
  340. t.Errorf("got err %v, want io.EOF after CloseWrite", err)
  341. }
  342. if err := <-result; err != io.EOF {
  343. t.Errorf("got %v (%T), want io.EOF", err, err)
  344. }
  345. }
  346. func TestMuxInvalidRecord(t *testing.T) {
  347. a, b := muxPair()
  348. defer a.Close()
  349. defer b.Close()
  350. packet := make([]byte, 1+4+4+1)
  351. packet[0] = msgChannelData
  352. marshalUint32(packet[1:], 29348723 /* invalid channel id */)
  353. marshalUint32(packet[5:], 1)
  354. packet[9] = 42
  355. a.conn.writePacket(packet)
  356. go a.SendRequest("hello", false, nil)
  357. // 'a' wrote an invalid packet, so 'b' has exited.
  358. req, ok := <-b.incomingRequests
  359. if ok {
  360. t.Errorf("got request %#v after receiving invalid packet", req)
  361. }
  362. }
  363. func TestZeroWindowAdjust(t *testing.T) {
  364. a, b, mux := channelPair(t)
  365. defer a.Close()
  366. defer b.Close()
  367. defer mux.Close()
  368. go func() {
  369. io.WriteString(a, "hello")
  370. // bogus adjust.
  371. a.sendMessage(windowAdjustMsg{})
  372. io.WriteString(a, "world")
  373. a.Close()
  374. }()
  375. want := "helloworld"
  376. c, _ := ioutil.ReadAll(b)
  377. if string(c) != want {
  378. t.Errorf("got %q want %q", c, want)
  379. }
  380. }
  381. func TestMuxMaxPacketSize(t *testing.T) {
  382. a, b, mux := channelPair(t)
  383. defer a.Close()
  384. defer b.Close()
  385. defer mux.Close()
  386. large := make([]byte, a.maxRemotePayload+1)
  387. packet := make([]byte, 1+4+4+1+len(large))
  388. packet[0] = msgChannelData
  389. marshalUint32(packet[1:], a.remoteId)
  390. marshalUint32(packet[5:], uint32(len(large)))
  391. packet[9] = 42
  392. if err := a.mux.conn.writePacket(packet); err != nil {
  393. t.Errorf("could not send packet")
  394. }
  395. go a.SendRequest("hello", false, nil)
  396. _, ok := <-b.incomingRequests
  397. if ok {
  398. t.Errorf("connection still alive after receiving large packet.")
  399. }
  400. }
  401. // Don't ship code with debug=true.
  402. func TestDebug(t *testing.T) {
  403. if debugMux {
  404. t.Error("mux debug switched on")
  405. }
  406. if debugHandshake {
  407. t.Error("handshake debug switched on")
  408. }
  409. }