forward_test.go 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. package test
import (
	"bytes"
	"fmt"
	"io"
	"io/ioutil"
	"math/rand"
	"net"
	"testing"
)
  10. func TestPortForward(t *testing.T) {
  11. server := newServer(t)
  12. defer server.Shutdown()
  13. conn := server.Dial(clientConfig())
  14. defer conn.Close()
  15. sshListener, err := conn.Listen("tcp", "127.0.0.1:0")
  16. if err != nil {
  17. t.Fatalf("conn.Listen failed: %v", err)
  18. }
  19. go func() {
  20. sshConn, err := sshListener.Accept()
  21. if err != nil {
  22. t.Fatalf("listen.Accept failed: %v", err)
  23. }
  24. _, err = io.Copy(sshConn, sshConn)
  25. if err != nil && err != io.EOF {
  26. t.Fatalf("ssh client copy: %v", err)
  27. }
  28. sshConn.Close()
  29. }()
  30. forwardedAddr := sshListener.Addr().String()
  31. tcpConn, err := net.Dial("tcp", forwardedAddr)
  32. if err != nil {
  33. t.Fatalf("TCP dial failed: %v", err)
  34. }
  35. readChan := make(chan []byte)
  36. go func() {
  37. data, _ := ioutil.ReadAll(tcpConn)
  38. readChan <- data
  39. }()
  40. // Invent some data.
  41. data := make([]byte, 100*1000)
  42. for i := range data {
  43. data[i] = byte(i % 255)
  44. }
  45. var sent []byte
  46. for len(sent) < 1000*1000 {
  47. // Send random sized chunks
  48. m := rand.Intn(len(data))
  49. n, err := tcpConn.Write(data[:m])
  50. if err != nil {
  51. break
  52. }
  53. sent = append(sent, data[:n]...)
  54. }
  55. if err := tcpConn.(*net.TCPConn).CloseWrite(); err != nil {
  56. t.Errorf("tcpConn.CloseWrite: %v", err)
  57. }
  58. read := <-readChan
  59. if len(sent) != len(read) {
  60. t.Fatalf("got %d bytes, want %d", len(read), len(sent))
  61. }
  62. if bytes.Compare(sent, read) != 0 {
  63. t.Fatalf("read back data does not match")
  64. }
  65. if err := sshListener.Close(); err != nil {
  66. t.Fatalf("sshListener.Close: %v", err)
  67. }
  68. // Check that the forward disappeared.
  69. tcpConn, err = net.Dial("tcp", forwardedAddr)
  70. if err == nil {
  71. tcpConn.Close()
  72. t.Errorf("still listening to %s after closing", forwardedAddr)
  73. }
  74. }