소스 검색

go.crypto/ssh: close channel feeding tcpListener.

Close both on closing the listener, and on closing the
connection. Test the former case.

R=dave
CC=golang-dev
https://golang.org/cl/11349043
Han-Wen Nienhuys 12 년 전
부모
커밋
7f7cbbf18e
3개의 변경된 파일68개의 추가작업 그리고 8개의 파일을 삭제
  1. 2 1
      ssh/client.go
  2. 13 0
      ssh/tcpip.go
  3. 53 7
      ssh/test/forward_unix_test.go

+ 2 - 1
ssh/client.go

@@ -210,7 +210,8 @@ func (c *ClientConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
 func (c *ClientConn) mainLoop() {
 	defer func() {
 		c.Close()
-		c.closeAll()
+		c.chanList.closeAll()
+		c.forwardList.closeAll()
 	}()
 
 	for {

+ 13 - 0
ssh/tcpip.go

@@ -101,17 +101,30 @@ func (l *forwardList) add(addr net.TCPAddr) chan forward {
 	return f.c
 }
 
+// remove removes the forward entry, and the channel feeding its
+// listener.
 func (l *forwardList) remove(addr net.TCPAddr) {
 	l.Lock()
 	defer l.Unlock()
 	for i, f := range l.entries {
 		if addr.IP.Equal(f.laddr.IP) && addr.Port == f.laddr.Port {
 			l.entries = append(l.entries[:i], l.entries[i+1:]...)
+			close(f.c)
 			return
 		}
 	}
 }
 
+// closeAll closes and clears all forwards.
+func (l *forwardList) closeAll() {
+	l.Lock()
+	defer l.Unlock()
+	for _, f := range l.entries {
+		close(f.c)
+	}
+	l.entries = nil
+}
+
 func (l *forwardList) lookup(addr net.TCPAddr) (chan forward, bool) {
 	l.Lock()
 	defer l.Unlock()

+ 53 - 7
ssh/test/forward_unix_test.go

@@ -14,14 +14,12 @@ import (
 	"math/rand"
 	"net"
 	"testing"
-)
+	"time"
 
-func TestPortForward(t *testing.T) {
-	server := newServer(t)
-	defer server.Shutdown()
-	conn := server.Dial(clientConfig())
-	defer conn.Close()
+	"code.google.com/p/go.crypto/ssh"
+)
 
+func listenSSHAuto(conn *ssh.ClientConn) (net.Listener, error) {
 	var sshListener net.Listener
 	var err error
 	tries := 10
@@ -38,7 +36,21 @@ func TestPortForward(t *testing.T) {
 	}
 
 	if err != nil {
-		t.Fatalf("conn.Listen failed: %v (after %d tries)", err, tries)
+		return nil, fmt.Errorf("conn.Listen failed: %v (after %d tries)", err, tries)
+	}
+
+	return sshListener, nil
+}
+
+func TestPortForward(t *testing.T) {
+	server := newServer(t)
+	defer server.Shutdown()
+	conn := server.Dial(clientConfig())
+	defer conn.Close()
+
+	sshListener, err := listenSSHAuto(conn)
+	if err != nil {
+		t.Fatal(err)
 	}
 
 	go func() {
@@ -106,3 +118,37 @@ func TestPortForward(t *testing.T) {
 		t.Errorf("still listening to %s after closing", forwardedAddr)
 	}
 }
+
+func TestAcceptClose(t *testing.T) {
+	server := newServer(t)
+	defer server.Shutdown()
+	conn := server.Dial(clientConfig())
+
+	sshListener, err := listenSSHAuto(conn)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	quit := make(chan error, 1)
+	go func() {
+		for {
+			c, err := sshListener.Accept()
+			if err != nil {
+				quit <- err
+				break
+			}
+			c.Close()
+		}
+	}()
+	sshListener.Close()
+
+	select {
+	case <-time.After(1 * time.Second):
+		t.Errorf("timeout: listener did not close.")
+	case err := <-quit:
+		t.Logf("quit as expected (error %v)", err)
+	}
+}
+
+// TODO(hanwen): test that closing the connection also
+// exits the listeners.