瀏覽代碼

go.crypto/ssh: improve TestServerWindow robustness

Fix a few resource leaks and prevent the test from
hanging if an error occurs reading from the remote
server.

R=agl, gustav.paul, kardianos
CC=golang-dev
https://golang.org/cl/6423065
Dave Cheney 13 年之前
父節點
當前提交
e751d5236a
共有 1 個文件被更改,包括 7 次插入8 次删除
  1. 7 8
      ssh/server_test.go

+ 7 - 8
ssh/server_test.go

@@ -17,18 +17,18 @@ const windowTestBytes = 16000 * 200
 
 // CopyNRandomly copies n bytes from src to dst. It uses a variable, and random,
 // buffer size to exercise more code paths.
-func CopyNRandomly(dst io.Writer, src io.Reader, n int64) (written int64, err error) {
+func CopyNRandomly(dst io.Writer, src io.Reader, n int) (written int, err error) {
 	buf := make([]byte, 32*1024)
 	for written < n {
 		l := (rand.Intn(30) + 1) * 1024
-		if d := n - written; d < int64(l) {
-			l = int(d)
+		if d := n - written; d < l {
+			l = d
 		}
 		nr, er := src.Read(buf[0:l])
 		if nr > 0 {
 			nw, ew := dst.Write(buf[0:nr])
 			if nw > 0 {
-				written += int64(nw)
+				written += nw
 			}
 			if ew != nil {
 				err = ew
@@ -75,6 +75,7 @@ func runSSHClient(t *testing.T, addr string) {
 	// Read back the data from the server.
 	go func() {
 		defer session.Close()
+		defer close(wait)
 		serverStdout, err := session.StdoutPipe()
 		if err != nil {
 			t.Fatal(err)
@@ -87,7 +88,6 @@ func runSSHClient(t *testing.T, addr string) {
 		if n != windowTestBytes {
 			t.Fatalf("Read only %d bytes from server, expected %d", n, windowTestBytes)
 		}
-		wait <- true
 	}()
 
 	serverStdin, err := session.StdinPipe()
@@ -126,11 +126,10 @@ func startSSHServer(t *testing.T) (addr string) {
 	}
 
 	addr = listener.Addr().String()
-
 	go func() {
+		defer listener.Close()
 		for {
 			sConn, err := listener.Accept()
-
 			err = sConn.Handshake()
 			if err != nil {
 				if err != io.EOF {
@@ -147,6 +146,7 @@ func startSSHServer(t *testing.T) (addr string) {
 }
 
 func connRun(t *testing.T, sConn *ServerConn) {
+	defer sConn.Close()
 	for {
 		channel, err := sConn.Accept()
 		if err != nil {
@@ -167,7 +167,6 @@ func connRun(t *testing.T, sConn *ServerConn) {
 
 		go func() {
 			defer channel.Close()
-
 			n, err := CopyNRandomly(channel, channel, windowTestBytes)
 			if err != nil && err != io.EOF {
 				if err == io.ErrShortWrite {