瀏覽代碼

go.crypto/ssh: Use net.UnixConn for connecting client and sshd.

This obviates custom code to emulate a thread-safe connection.

Use this for testing that listeners close if the connection breaks.

R=dave, agl, fullung
CC=golang-dev
https://golang.org/cl/11781043
Han-Wen Nienhuys 12 年之前
父節點
當前提交
a93ee0c91a
共有 2 個文件被更改,包括 78 次插入83 次删除
  1. 34 2
      ssh/test/forward_unix_test.go
  2. 44 81
      ssh/test/test_unix_test.go

+ 34 - 2
ssh/test/forward_unix_test.go

@@ -124,5 +124,37 @@ func TestAcceptClose(t *testing.T) {
 	}
 }
 
-// TODO(hanwen): test that closing the connection also
-// exits the listeners.
+// Check that listeners exit if the underlying client transport dies.
+func TestPortForwardConnectionClose(t *testing.T) {
+	server := newServer(t)
+	defer server.Shutdown()
+	conn := server.Dial(clientConfig())
+
+	sshListener, err := conn.Listen("tcp", "localhost:0")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	quit := make(chan error, 1)
+	go func() {
+		for {
+			c, err := sshListener.Accept()
+			if err != nil {
+				quit <- err
+				break
+			}
+			c.Close()
+		}
+	}()
+
+	// It would be even nicer if we closed the server side, but it
+	// is more involved as the fd for that side is dup()ed.
+	server.clientConn.Close()
+
+	select {
+	case <-time.After(1 * time.Second):
+		t.Errorf("timeout: listener did not close.")
+	case err := <-quit:
+		t.Logf("quit as expected (error %v)", err)
+	}
+}

+ 44 - 81
ssh/test/test_unix_test.go

@@ -24,10 +24,8 @@ import (
 	"os/exec"
 	"os/user"
 	"path/filepath"
-	"sync"
 	"testing"
 	"text/template"
-	"time"
 
 	"code.google.com/p/go.crypto/ssh"
 )
@@ -82,6 +80,9 @@ type server struct {
 	configfile string
 	cmd        *exec.Cmd
 	output     bytes.Buffer // holds stderr from sshd process
+
+	// Client half of the network connection.
+	clientConn net.Conn
 }
 
 func username() string {
@@ -135,30 +136,62 @@ func clientConfig() *ssh.ClientConfig {
 	return config
 }
 
+// unixConnection creates two halves of a connected net.UnixConn.  It
+// is used for connecting the Go SSH client with sshd without opening
+// ports.
+func unixConnection() (*net.UnixConn, *net.UnixConn, error) {
+	dir, err := ioutil.TempDir("", "unixConnection")
+	if err != nil {
+		return nil, nil, err
+	}
+	defer os.Remove(dir)
+
+	addr := filepath.Join(dir, "ssh")
+	listener, err := net.Listen("unix", addr)
+	if err != nil {
+		return nil, nil, err
+	}
+	defer listener.Close()
+	c1, err := net.Dial("unix", addr)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	c2, err := listener.Accept()
+	if err != nil {
+		c1.Close()
+		return nil, nil, err
+	}
+
+	return c1.(*net.UnixConn), c2.(*net.UnixConn), nil
+}
+
 func (s *server) TryDial(config *ssh.ClientConfig) (*ssh.ClientConn, error) {
 	sshd, err := exec.LookPath("sshd")
 	if err != nil {
 		s.t.Skipf("skipping test: %v", err)
 	}
-	s.cmd = exec.Command(sshd, "-f", s.configfile, "-i", "-e")
-	r1, w1, err := os.Pipe()
+
+	c1, c2, err := unixConnection()
 	if err != nil {
-		s.t.Fatal(err)
+		s.t.Fatalf("unixConnection: %v", err)
 	}
-	s.cmd.Stdout = w1
-	r2, w2, err := os.Pipe()
+
+	s.cmd = exec.Command(sshd, "-f", s.configfile, "-i", "-e")
+	f, err := c2.File()
 	if err != nil {
-		s.t.Fatal(err)
+		s.t.Fatalf("UnixConn.File: %v", err)
 	}
-	s.cmd.Stdin = r2
+	s.cmd.Stdin = f
+	s.cmd.Stdout = f
 	s.cmd.Stderr = os.Stderr
 	if err := s.cmd.Start(); err != nil {
 		s.t.Fail()
 		s.Shutdown()
 		s.t.Fatalf("s.cmd.Start: %v", err)
 	}
-
-	return ssh.Client(&client{wc: w2, r: r1}, config)
+	s.clientConn = c1
+	return ssh.Client(c1, config)
 }
 
 func (s *server) Dial(config *ssh.ClientConfig) *ssh.ClientConn {
@@ -185,76 +218,6 @@ func (s *server) Shutdown() {
 	s.cleanup()
 }
 
-// client wraps a pair of Reader/WriteClosers to implement the
-// net.Conn interface. Importantly, client also mocks the
-// ability of net.Conn to support concurrent calls to Read/Write
-// and Close. See golang.org/issue/5138 for more details.
-type client struct {
-	wc         io.WriteCloser
-	r          io.Reader
-	sync.Mutex // protects refcount and closing
-	refcount   int
-	closing    bool
-}
-
-var errClosing = errors.New("use of closed network connection")
-
-func (c *client) LocalAddr() net.Addr              { return nil }
-func (c *client) RemoteAddr() net.Addr             { return nil }
-func (c *client) SetDeadline(time.Time) error      { return nil }
-func (c *client) SetReadDeadline(time.Time) error  { return nil }
-func (c *client) SetWriteDeadline(time.Time) error { return nil }
-
-// incref, decref are copied from the net package (see net/fd_unix.go) to
-// implement the concurrent Close contract that net.Conn implementations
-// from that that package provide.
-
-func (c *client) incRef(closing bool) error {
-	c.Lock()
-	defer c.Unlock()
-	if c.closing {
-		return errClosing
-	}
-	c.refcount++
-	if closing {
-		c.closing = true
-	}
-	return nil
-}
-
-func (c *client) decRef() {
-	c.Lock()
-	defer c.Unlock()
-	c.refcount--
-	if c.closing && c.refcount == 0 {
-		c.wc.Close()
-	}
-}
-
-func (c *client) Close() error {
-	if err := c.incRef(true); err != nil {
-		return err
-	}
-	c.decRef()
-	return nil
-}
-
-func (c *client) Read(b []byte) (int, error) {
-	if err := c.incRef(false); err != nil {
-		return 0, err
-	}
-	defer c.decRef()
-	return c.r.Read(b)
-}
-
-func (c *client) Write(b []byte) (int, error) {
-	if err := c.incRef(false); err != nil {
-		return 0, err
-	}
-	defer c.decRef()
-	return c.wc.Write(b)
-}
-
 // newServer returns a new mock ssh server.
 func newServer(t *testing.T) *server {
 	dir, err := ioutil.TempDir("", "sshtest")