Browse Source

go.crypto/ssh: improve test reliability

Fixes golang/go#3989.

Tested for several hours on an 8 core ec2 instance with
random GOMAXPROC values.

Also, rolls server_test.go into session_test using the
existing dial() framework.

R=fullung, agl, kardianos
CC=golang-dev
https://golang.org/cl/6475063
Dave Cheney 12 years ago
parent
commit
06790d30c2
2 changed files with 115 additions and 194 deletions
  1. 0 179
      ssh/server_test.go
  2. 115 15
      ssh/session_test.go

+ 0 - 179
ssh/server_test.go

@@ -1,179 +0,0 @@
-// Copyright 2012 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package ssh
-
-import (
-	"bytes"
-	crypto_rand "crypto/rand"
-	"io"
-	"math/rand"
-	"testing"
-)
-
-// windowTestBytes is the number of bytes that we'll send to the SSH server.
-const windowTestBytes = 16000 * 200
-
-// CopyNRandomly copies n bytes from src to dst. It uses a variable, and random,
-// buffer size to exercise more code paths.
-func CopyNRandomly(dst io.Writer, src io.Reader, n int) (written int, err error) {
-	buf := make([]byte, 32*1024)
-	for written < n {
-		l := (rand.Intn(30) + 1) * 1024
-		if d := n - written; d < l {
-			l = d
-		}
-		nr, er := src.Read(buf[0:l])
-		if nr > 0 {
-			nw, ew := dst.Write(buf[0:nr])
-			if nw > 0 {
-				written += nw
-			}
-			if ew != nil {
-				err = ew
-				break
-			}
-			if nr != nw {
-				err = io.ErrShortWrite
-				break
-			}
-		}
-		if er != nil {
-			err = er
-			break
-		}
-	}
-	return written, err
-}
-
-func TestServerWindow(t *testing.T) {
-	addr := startSSHServer(t)
-	runSSHClient(t, addr)
-}
-
-// runSSHClient writes random data to the server. The server is expected to echo
-// the same data back, which is compared against the original.
-func runSSHClient(t *testing.T, addr string) {
-	conn, err := Dial("tcp", addr, &ClientConfig{})
-	if err != nil {
-		t.Fatal(err)
-	}
-
-	session, err := conn.NewSession()
-	if err != nil {
-		t.Fatal(err)
-	}
-
-	origBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
-	echoedBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
-	io.CopyN(origBuf, crypto_rand.Reader, windowTestBytes)
-	origBytes := origBuf.Bytes()
-
-	wait := make(chan bool)
-
-	// Read back the data from the server.
-	go func() {
-		defer session.Close()
-		defer close(wait)
-		serverStdout, err := session.StdoutPipe()
-		if err != nil {
-			t.Fatal(err)
-		}
-
-		n, err := CopyNRandomly(echoedBuf, serverStdout, windowTestBytes)
-		if err != nil && err != io.EOF {
-			t.Fatal(err)
-		}
-		if n != windowTestBytes {
-			t.Fatalf("Read only %d bytes from server, expected %d", n, windowTestBytes)
-		}
-	}()
-
-	serverStdin, err := session.StdinPipe()
-	if err != nil {
-		t.Fatal(err)
-	}
-
-	written, err := CopyNRandomly(serverStdin, origBuf, windowTestBytes)
-	if err != nil {
-		t.Fatal(err)
-	}
-	if written != windowTestBytes {
-		t.Fatalf("Wrote only %d of %d bytes to server", written, windowTestBytes)
-	}
-
-	<-wait
-
-	if !bytes.Equal(origBytes, echoedBuf.Bytes()) {
-		t.Error("Echoed buffer differed from original")
-	}
-}
-
-func startSSHServer(t *testing.T) (addr string) {
-	config := &ServerConfig{
-		NoClientAuth: true,
-	}
-
-	err := config.SetRSAPrivateKey([]byte(testServerPrivateKey))
-	if err != nil {
-		t.Fatalf("Failed to parse private key: %s", err.Error())
-	}
-
-	listener, err := Listen("tcp", "127.0.0.1:0", config)
-	if err != nil {
-		t.Fatalf("Bind error: %s", err)
-	}
-
-	addr = listener.Addr().String()
-	go func() {
-		defer listener.Close()
-		for {
-			sConn, err := listener.Accept()
-			err = sConn.Handshake()
-			if err != nil {
-				if err != io.EOF {
-					t.Fatalf("failed to handshake: %s", err)
-				}
-				return
-			}
-
-			go connRun(t, sConn)
-		}
-	}()
-
-	return
-}
-
-func connRun(t *testing.T, sConn *ServerConn) {
-	defer sConn.Close()
-	for {
-		channel, err := sConn.Accept()
-		if err != nil {
-			if err == io.EOF {
-				break
-			}
-			t.Fatalf("ServerConn.Accept failed: %s", err)
-		}
-
-		if channel.ChannelType() != "session" {
-			channel.Reject(UnknownChannelType, "unknown channel type")
-			continue
-		}
-		err = channel.Accept()
-		if err != nil {
-			t.Fatalf("Channel.Accept failed: %s", err)
-		}
-
-		go func() {
-			defer channel.Close()
-			n, err := CopyNRandomly(channel, channel, windowTestBytes)
-			if err != nil && err != io.EOF {
-				if err == io.ErrShortWrite {
-					t.Fatalf("short write, wrote %d, expected %d", n, windowTestBytes)
-				}
-				t.Fatal(err)
-			}
-		}()
-	}
-}

+ 115 - 15
ssh/session_test.go

@@ -8,8 +8,10 @@ package ssh
 
 import (
 	"bytes"
+	crypto_rand "crypto/rand"
 	"io"
 	"io/ioutil"
+	"math/rand"
 	"net"
 	"testing"
 
@@ -42,6 +44,7 @@ func dial(handler serverType, t *testing.T) *ClientConn {
 			t.Errorf("Unable to handshake: %v", err)
 			return
 		}
+		done := make(chan struct{})
 		for {
 			ch, err := conn.Accept()
 			if err == io.EOF {
@@ -60,9 +63,12 @@ func dial(handler serverType, t *testing.T) *ClientConn {
 				continue
 			}
 			ch.Accept()
-			go handler(ch.(*serverChan), t)
+			go func() {
+				defer close(done)
+				handler(ch.(*serverChan), t)
+			}()
 		}
-		t.Log("done")
+		<-done
 	}()
 
 	config := &ClientConfig{
@@ -345,17 +351,19 @@ func TestClientStdinRespectsMaxPacketSize(t *testing.T) {
 	defer conn.Close()
 	session, err := conn.NewSession()
 	if err != nil {
-		t.Fatalf("Unable to request new session: %v", err)
+		t.Fatalf("failed to request new session: %v", err)
 	}
 	defer session.Close()
-	if err := session.Shell(); err != nil {
-		t.Fatalf("Unable to execute command: %v", err)
+	stdin, err := session.StdinPipe()
+	if err != nil {
+		t.Fatalf("failed to obtain stdinpipe: %v", err)
 	}
-	// try to stuff 128k of data into a 32k hole.
-	const size = 128 * 1024
-	n, err := session.clientChan.stdin.Write(make([]byte, size))
-	if n != size || err != nil {
-		t.Fatalf("failed to write: %d, %v", n, err)
+	const size = 100 * 1000
+	for i := 0; i < 10; i++ {
+		n, err := stdin.Write(make([]byte, size))
+		if n != size || err != nil {
+			t.Fatalf("failed to write: %d, %v", n, err)
+		}
 	}
 }
 
@@ -385,7 +393,7 @@ func TestClientCannotSendAfterEOF(t *testing.T) {
 	t.Logf("test skipped")
 	return
 
-	conn := dial(shellHandler, t)
+	conn := dial(exitWithoutSignalOrStatus, t)
 	defer conn.Close()
 	session, err := conn.NewSession()
 	if err != nil {
@@ -431,6 +439,59 @@ func TestClientCannotSendAfterClose(t *testing.T) {
 	}
 }
 
+// windowTestBytes is the number of bytes that we'll send to the SSH server.
+const windowTestBytes = 16000 * 200
+
+// TestServerWindow writes random data to the server. The server is expected to echo
+// the same data back, which is compared against the original.
+func TestServerWindow(t *testing.T) {
+	origBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
+	io.CopyN(origBuf, crypto_rand.Reader, windowTestBytes)
+	origBytes := origBuf.Bytes()
+
+	conn := dial(echoHandler, t)
+	defer conn.Close()
+	session, err := conn.NewSession()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer session.Close()
+	result := make(chan []byte)
+
+	go func() {
+		defer close(result)
+		echoedBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
+		serverStdout, err := session.StdoutPipe()
+		if err != nil {
+			t.Errorf("StdoutPipe failed: %v", err)
+			return
+		}
+		n, err := copyNRandomly("stdout", echoedBuf, serverStdout, windowTestBytes)
+		if err != nil && err != io.EOF {
+			t.Errorf("Read only %d bytes from server, expected %d: %v", n, windowTestBytes, err)
+		}
+		result <- echoedBuf.Bytes()
+	}()
+
+	serverStdin, err := session.StdinPipe()
+	if err != nil {
+		t.Fatalf("StdinPipe failed: %v", err)
+	}
+	written, err := copyNRandomly("stdin", serverStdin, origBuf, windowTestBytes)
+	if err != nil {
+		t.Fatalf("falied to copy origBuf to serverStdin: %v", err)
+	}
+	if written != windowTestBytes {
+		t.Fatalf("Wrote only %d of %d bytes to server", written, windowTestBytes)
+	}
+
+	echoedBytes := <-result
+
+	if !bytes.Equal(origBytes, echoedBytes) {
+		t.Fatalf("Echoed buffer differed from original, orig %d, echoed %d", len(origBytes), len(echoedBytes))
+	}
+}
+
 type exitStatusMsg struct {
 	PeersId   uint32
 	Request   string
@@ -509,7 +570,7 @@ func shellHandler(ch *serverChan, t *testing.T) {
 
 func readLine(shell *ServerTerminal, t *testing.T) {
 	if _, err := shell.ReadLine(); err != nil && err != io.EOF {
-		t.Fatalf("unable to read line: %v", err)
+		t.Errorf("unable to read line: %v", err)
 	}
 }
 
@@ -567,9 +628,11 @@ func discardHandler(ch *serverChan, t *testing.T) {
 	// grow the window to avoid being fooled by
 	// the initial 1 << 14 window.
 	ch.sendWindowAdj(1024 * 1024)
-	shell := newServerShell(ch, "> ")
-	readLine(shell, t)
-	io.Copy(ioutil.Discard, ch.serverConn)
+	// TODO(dfc) io.Copy can return a non EOF error here
+	// because closed channel errors can leak here if the
+	// read from ch causes a window adjustment after the 
+	// remote has signaled close.
+	io.Copy(ioutil.Discard, ch)
 }
 
 func largeSendHandler(ch *serverChan, t *testing.T) {
@@ -585,3 +648,40 @@ func largeSendHandler(ch *serverChan, t *testing.T) {
 		t.Errorf("wrote packet larger than 32k")
 	}
 }
+
+func echoHandler(ch *serverChan, t *testing.T) {
+	defer ch.Close()
+	if n, err := copyNRandomly("echohandler", ch, ch, windowTestBytes); err != nil {
+		t.Errorf("short write, wrote %d, expected %d: %v ", n, windowTestBytes, err)
+	}
+}
+
+// copyNRandomly copies n bytes from src to dst. It uses a variable, and random,
+// buffer size to exercise more code paths.
+func copyNRandomly(title string, dst io.Writer, src io.Reader, n int) (int, error) {
+	var (
+		buf       = make([]byte, 32*1024)
+		written   int
+		remaining = n
+	)
+	for remaining > 0 {
+		l := rand.Intn(1 << 15)
+		if remaining < l {
+			l = remaining
+		}
+		nr, er := src.Read(buf[:l])
+		nw, ew := dst.Write(buf[:nr])
+		remaining -= nw
+		written += nw
+		if ew != nil {
+			return written, ew
+		}
+		if nr != nw {
+			return written, io.ErrShortWrite
+		}
+		if er != nil && er != io.EOF {
+			return written, er
+		}
+	}
+	return written, nil
+}