server_test.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. // Copyright 2012 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package ssh
  5. import (
  6. "bytes"
  7. crypto_rand "crypto/rand"
  8. "io"
  9. "math/rand"
  10. "testing"
  11. )
  12. // windowTestBytes is the number of bytes that we'll send to the SSH server.
  13. const windowTestBytes = 16000 * 200
  14. // CopyNRandomly copies n bytes from src to dst. It uses a variable, and random,
  15. // buffer size to exercise more code paths.
  16. func CopyNRandomly(dst io.Writer, src io.Reader, n int) (written int, err error) {
  17. buf := make([]byte, 32*1024)
  18. for written < n {
  19. l := (rand.Intn(30) + 1) * 1024
  20. if d := n - written; d < l {
  21. l = d
  22. }
  23. nr, er := src.Read(buf[0:l])
  24. if nr > 0 {
  25. nw, ew := dst.Write(buf[0:nr])
  26. if nw > 0 {
  27. written += nw
  28. }
  29. if ew != nil {
  30. err = ew
  31. break
  32. }
  33. if nr != nw {
  34. err = io.ErrShortWrite
  35. break
  36. }
  37. }
  38. if er != nil {
  39. err = er
  40. break
  41. }
  42. }
  43. return written, err
  44. }
  45. func TestServerWindow(t *testing.T) {
  46. addr := startSSHServer(t)
  47. runSSHClient(t, addr)
  48. }
  49. // runSSHClient writes random data to the server. The server is expected to echo
  50. // the same data back, which is compared against the original.
  51. func runSSHClient(t *testing.T, addr string) {
  52. conn, err := Dial("tcp", addr, &ClientConfig{})
  53. if err != nil {
  54. t.Fatal(err)
  55. }
  56. session, err := conn.NewSession()
  57. if err != nil {
  58. t.Fatal(err)
  59. }
  60. origBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
  61. echoedBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
  62. io.CopyN(origBuf, crypto_rand.Reader, windowTestBytes)
  63. origBytes := origBuf.Bytes()
  64. wait := make(chan bool)
  65. // Read back the data from the server.
  66. go func() {
  67. defer session.Close()
  68. defer close(wait)
  69. serverStdout, err := session.StdoutPipe()
  70. if err != nil {
  71. t.Fatal(err)
  72. }
  73. n, err := CopyNRandomly(echoedBuf, serverStdout, windowTestBytes)
  74. if err != nil && err != io.EOF {
  75. t.Fatal(err)
  76. }
  77. if n != windowTestBytes {
  78. t.Fatalf("Read only %d bytes from server, expected %d", n, windowTestBytes)
  79. }
  80. }()
  81. serverStdin, err := session.StdinPipe()
  82. if err != nil {
  83. t.Fatal(err)
  84. }
  85. written, err := CopyNRandomly(serverStdin, origBuf, windowTestBytes)
  86. if err != nil {
  87. t.Fatal(err)
  88. }
  89. if written != windowTestBytes {
  90. t.Fatalf("Wrote only %d of %d bytes to server", written, windowTestBytes)
  91. }
  92. <-wait
  93. if !bytes.Equal(origBytes, echoedBuf.Bytes()) {
  94. t.Error("Echoed buffer differed from original")
  95. }
  96. }
  97. func startSSHServer(t *testing.T) (addr string) {
  98. config := &ServerConfig{
  99. NoClientAuth: true,
  100. }
  101. err := config.SetRSAPrivateKey([]byte(testServerPrivateKey))
  102. if err != nil {
  103. t.Fatalf("Failed to parse private key: %s", err.Error())
  104. }
  105. listener, err := Listen("tcp", "127.0.0.1:0", config)
  106. if err != nil {
  107. t.Fatalf("Bind error: %s", err)
  108. }
  109. addr = listener.Addr().String()
  110. go func() {
  111. defer listener.Close()
  112. for {
  113. sConn, err := listener.Accept()
  114. err = sConn.Handshake()
  115. if err != nil {
  116. if err != io.EOF {
  117. t.Fatalf("failed to handshake: %s", err)
  118. }
  119. return
  120. }
  121. go connRun(t, sConn)
  122. }
  123. }()
  124. return
  125. }
  126. func connRun(t *testing.T, sConn *ServerConn) {
  127. defer sConn.Close()
  128. for {
  129. channel, err := sConn.Accept()
  130. if err != nil {
  131. if err == io.EOF {
  132. break
  133. }
  134. t.Fatalf("ServerConn.Accept failed: %s", err)
  135. }
  136. if channel.ChannelType() != "session" {
  137. channel.Reject(UnknownChannelType, "unknown channel type")
  138. continue
  139. }
  140. err = channel.Accept()
  141. if err != nil {
  142. t.Fatalf("Channel.Accept failed: %s", err)
  143. }
  144. go func() {
  145. defer channel.Close()
  146. n, err := CopyNRandomly(channel, channel, windowTestBytes)
  147. if err != nil && err != io.EOF {
  148. if err == io.ErrShortWrite {
  149. t.Fatalf("short write, wrote %d, expected %d", n, windowTestBytes)
  150. }
  151. t.Fatal(err)
  152. }
  153. }()
  154. }
  155. }