|
@@ -8,8 +8,10 @@ package ssh
|
|
|
|
|
|
import (
|
|
|
"bytes"
|
|
|
+ crypto_rand "crypto/rand"
|
|
|
"io"
|
|
|
"io/ioutil"
|
|
|
+ "math/rand"
|
|
|
"net"
|
|
|
"testing"
|
|
|
|
|
@@ -42,6 +44,7 @@ func dial(handler serverType, t *testing.T) *ClientConn {
|
|
|
t.Errorf("Unable to handshake: %v", err)
|
|
|
return
|
|
|
}
|
|
|
+ done := make(chan struct{})
|
|
|
for {
|
|
|
ch, err := conn.Accept()
|
|
|
if err == io.EOF {
|
|
@@ -60,9 +63,12 @@ func dial(handler serverType, t *testing.T) *ClientConn {
|
|
|
continue
|
|
|
}
|
|
|
ch.Accept()
|
|
|
- go handler(ch.(*serverChan), t)
|
|
|
+ go func() {
|
|
|
+ defer close(done)
|
|
|
+ handler(ch.(*serverChan), t)
|
|
|
+ }()
|
|
|
}
|
|
|
- t.Log("done")
|
|
|
+ <-done
|
|
|
}()
|
|
|
|
|
|
config := &ClientConfig{
|
|
@@ -345,17 +351,19 @@ func TestClientStdinRespectsMaxPacketSize(t *testing.T) {
|
|
|
defer conn.Close()
|
|
|
session, err := conn.NewSession()
|
|
|
if err != nil {
|
|
|
- t.Fatalf("Unable to request new session: %v", err)
|
|
|
+ t.Fatalf("failed to request new session: %v", err)
|
|
|
}
|
|
|
defer session.Close()
|
|
|
- if err := session.Shell(); err != nil {
|
|
|
- t.Fatalf("Unable to execute command: %v", err)
|
|
|
+ stdin, err := session.StdinPipe()
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("failed to obtain stdinpipe: %v", err)
|
|
|
}
|
|
|
-
|
|
|
- const size = 128 * 1024
|
|
|
- n, err := session.clientChan.stdin.Write(make([]byte, size))
|
|
|
- if n != size || err != nil {
|
|
|
- t.Fatalf("failed to write: %d, %v", n, err)
|
|
|
+ const size = 100 * 1000
|
|
|
+ for i := 0; i < 10; i++ {
|
|
|
+ n, err := stdin.Write(make([]byte, size))
|
|
|
+ if n != size || err != nil {
|
|
|
+ t.Fatalf("failed to write: %d, %v", n, err)
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -385,7 +393,7 @@ func TestClientCannotSendAfterEOF(t *testing.T) {
|
|
|
t.Logf("test skipped")
|
|
|
return
|
|
|
|
|
|
- conn := dial(shellHandler, t)
|
|
|
+ conn := dial(exitWithoutSignalOrStatus, t)
|
|
|
defer conn.Close()
|
|
|
session, err := conn.NewSession()
|
|
|
if err != nil {
|
|
@@ -431,6 +439,59 @@ func TestClientCannotSendAfterClose(t *testing.T) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+
|
|
|
+const windowTestBytes = 16000 * 200
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+func TestServerWindow(t *testing.T) {
|
|
|
+ origBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
|
|
|
+ io.CopyN(origBuf, crypto_rand.Reader, windowTestBytes)
|
|
|
+ origBytes := origBuf.Bytes()
|
|
|
+
|
|
|
+ conn := dial(echoHandler, t)
|
|
|
+ defer conn.Close()
|
|
|
+ session, err := conn.NewSession()
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ defer session.Close()
|
|
|
+ result := make(chan []byte)
|
|
|
+
|
|
|
+ go func() {
|
|
|
+ defer close(result)
|
|
|
+ echoedBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
|
|
|
+ serverStdout, err := session.StdoutPipe()
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("StdoutPipe failed: %v", err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ n, err := copyNRandomly("stdout", echoedBuf, serverStdout, windowTestBytes)
|
|
|
+ if err != nil && err != io.EOF {
|
|
|
+ t.Errorf("Read only %d bytes from server, expected %d: %v", n, windowTestBytes, err)
|
|
|
+ }
|
|
|
+ result <- echoedBuf.Bytes()
|
|
|
+ }()
|
|
|
+
|
|
|
+ serverStdin, err := session.StdinPipe()
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("StdinPipe failed: %v", err)
|
|
|
+ }
|
|
|
+ written, err := copyNRandomly("stdin", serverStdin, origBuf, windowTestBytes)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("falied to copy origBuf to serverStdin: %v", err)
|
|
|
+ }
|
|
|
+ if written != windowTestBytes {
|
|
|
+ t.Fatalf("Wrote only %d of %d bytes to server", written, windowTestBytes)
|
|
|
+ }
|
|
|
+
|
|
|
+ echoedBytes := <-result
|
|
|
+
|
|
|
+ if !bytes.Equal(origBytes, echoedBytes) {
|
|
|
+ t.Fatalf("Echoed buffer differed from original, orig %d, echoed %d", len(origBytes), len(echoedBytes))
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
type exitStatusMsg struct {
|
|
|
PeersId uint32
|
|
|
Request string
|
|
@@ -509,7 +570,7 @@ func shellHandler(ch *serverChan, t *testing.T) {
|
|
|
|
|
|
func readLine(shell *ServerTerminal, t *testing.T) {
|
|
|
if _, err := shell.ReadLine(); err != nil && err != io.EOF {
|
|
|
- t.Fatalf("unable to read line: %v", err)
|
|
|
+ t.Errorf("unable to read line: %v", err)
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -567,9 +628,11 @@ func discardHandler(ch *serverChan, t *testing.T) {
|
|
|
|
|
|
|
|
|
ch.sendWindowAdj(1024 * 1024)
|
|
|
- shell := newServerShell(ch, "> ")
|
|
|
- readLine(shell, t)
|
|
|
- io.Copy(ioutil.Discard, ch.serverConn)
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+ io.Copy(ioutil.Discard, ch)
|
|
|
}
|
|
|
|
|
|
func largeSendHandler(ch *serverChan, t *testing.T) {
|
|
@@ -585,3 +648,40 @@ func largeSendHandler(ch *serverChan, t *testing.T) {
|
|
|
t.Errorf("wrote packet larger than 32k")
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+func echoHandler(ch *serverChan, t *testing.T) {
|
|
|
+ defer ch.Close()
|
|
|
+ if n, err := copyNRandomly("echohandler", ch, ch, windowTestBytes); err != nil {
|
|
|
+ t.Errorf("short write, wrote %d, expected %d: %v ", n, windowTestBytes, err)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+func copyNRandomly(title string, dst io.Writer, src io.Reader, n int) (int, error) {
|
|
|
+ var (
|
|
|
+ buf = make([]byte, 32*1024)
|
|
|
+ written int
|
|
|
+ remaining = n
|
|
|
+ )
|
|
|
+ for remaining > 0 {
|
|
|
+ l := rand.Intn(1 << 15)
|
|
|
+ if remaining < l {
|
|
|
+ l = remaining
|
|
|
+ }
|
|
|
+ nr, er := src.Read(buf[:l])
|
|
|
+ nw, ew := dst.Write(buf[:nr])
|
|
|
+ remaining -= nw
|
|
|
+ written += nw
|
|
|
+ if ew != nil {
|
|
|
+ return written, ew
|
|
|
+ }
|
|
|
+ if nr != nw {
|
|
|
+ return written, io.ErrShortWrite
|
|
|
+ }
|
|
|
+ if er != nil && er != io.EOF {
|
|
|
+ return written, er
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return written, nil
|
|
|
+}
|