瀏覽代碼

go.crypto/ssh: fix race on mock ssh network connection

Fixes golang/go#5138.
Fixes golang/go#4703.

This appears to pass my stress tests with and without the -race detector, but I'd like to see others hit it with their machines.

R=golang-dev, fullung, huin, kardianos, agl
CC=golang-dev
https://golang.org/cl/9929043
Dave Cheney 12 年之前
父節點
當前提交
e8889f5e72
共有 1 個文件被更改,包括 69 次插入10 次删除
  1. 69 10
      ssh/test/test_unix_test.go

+ 69 - 10
ssh/test/test_unix_test.go

@@ -24,6 +24,7 @@ import (
 	"os/exec"
 	"os/exec"
 	"os/user"
 	"os/user"
 	"path/filepath"
 	"path/filepath"
+	"sync"
 	"testing"
 	"testing"
 	"text/template"
 	"text/template"
 	"time"
 	"time"
@@ -103,22 +104,23 @@ func clientConfig() *ssh.ClientConfig {
 
 
 func (s *server) Dial(config *ssh.ClientConfig) *ssh.ClientConn {
 func (s *server) Dial(config *ssh.ClientConfig) *ssh.ClientConn {
 	s.cmd = exec.Command("sshd", "-f", s.configfile, "-i")
 	s.cmd = exec.Command("sshd", "-f", s.configfile, "-i")
-	stdin, err := s.cmd.StdinPipe()
+	r1, w1, err := os.Pipe()
 	if err != nil {
 	if err != nil {
 		s.t.Fatal(err)
 		s.t.Fatal(err)
 	}
 	}
-	stdout, err := s.cmd.StdoutPipe()
+	s.cmd.Stdout = w1
+	r2, w2, err := os.Pipe()
 	if err != nil {
 	if err != nil {
 		s.t.Fatal(err)
 		s.t.Fatal(err)
 	}
 	}
-	s.cmd.Stderr = os.Stderr // &s.output
-	err = s.cmd.Start()
-	if err != nil {
+	s.cmd.Stdin = r2
+	s.cmd.Stderr = os.Stderr
+	if err := s.cmd.Start(); err != nil {
 		s.t.Fail()
 		s.t.Fail()
 		s.Shutdown()
 		s.Shutdown()
 		s.t.Fatalf("s.cmd.Start: %v", err)
 		s.t.Fatalf("s.cmd.Start: %v", err)
 	}
 	}
-	conn, err := ssh.Client(&client{stdin, stdout}, config)
+	conn, err := ssh.Client(&client{wc: w2, r: r1}, config)
 	if err != nil {
 	if err != nil {
 		s.t.Fail()
 		s.t.Fail()
 		s.Shutdown()
 		s.Shutdown()
@@ -136,24 +138,81 @@ func (s *server) Shutdown() {
 	}
 	}
 	if s.t.Failed() {
 	if s.t.Failed() {
 		// log any output from sshd process
 		// log any output from sshd process
-		s.t.Log(s.output.String())
+		s.t.Logf("sshd: %q", s.output.String())
 	}
 	}
 	s.cleanup()
 	s.cleanup()
 }
 }
 
 
 // client wraps a pair of Reader/WriteClosers to implement the
 // client wraps a pair of Reader/WriteClosers to implement the
-// net.Conn interface.
+// net.Conn interface. Importantly, client also mocks the
+// ability of net.Conn to support concurrent calls to Read/Write
+// and Close. See golang.org/issue/5138 for more details.
 type client struct {
 type client struct {
-	io.WriteCloser
-	io.Reader
+	wc         io.WriteCloser
+	r          io.Reader
+	sync.Mutex // protects refcount and closing
+	refcount   int
+	closing    bool
 }
 }
 
 
+var errClosing = errors.New("use of closed network connection")
+
 func (c *client) LocalAddr() net.Addr              { return nil }
 func (c *client) LocalAddr() net.Addr              { return nil }
 func (c *client) RemoteAddr() net.Addr             { return nil }
 func (c *client) RemoteAddr() net.Addr             { return nil }
 func (c *client) SetDeadline(time.Time) error      { return nil }
 func (c *client) SetDeadline(time.Time) error      { return nil }
 func (c *client) SetReadDeadline(time.Time) error  { return nil }
 func (c *client) SetReadDeadline(time.Time) error  { return nil }
 func (c *client) SetWriteDeadline(time.Time) error { return nil }
 func (c *client) SetWriteDeadline(time.Time) error { return nil }
 
 
+// incref, decref are copied from the net package (see net/fd_unix.go) to
+// implement the concurrent Close contract that net.Conn implementations
+// from that that package provide.
+
+func (c *client) incRef(closing bool) error {
+	c.Lock()
+	defer c.Unlock()
+	if c.closing {
+		return errClosing
+	}
+	c.refcount++
+	if closing {
+		c.closing = true
+	}
+	return nil
+}
+
+func (c *client) decRef() {
+	c.Lock()
+	defer c.Unlock()
+	c.refcount--
+	if c.closing && c.refcount == 0 {
+		c.wc.Close()
+	}
+}
+
+func (c *client) Close() error {
+	if err := c.incRef(true); err != nil {
+		return err
+	}
+	c.decRef()
+	return nil
+}
+
+func (c *client) Read(b []byte) (int, error) {
+	if err := c.incRef(false); err != nil {
+		return 0, err
+	}
+	defer c.decRef()
+	return c.r.Read(b)
+}
+
+func (c *client) Write(b []byte) (int, error) {
+	if err := c.incRef(false); err != nil {
+		return 0, err
+	}
+	defer c.decRef()
+	return c.wc.Write(b)
+}
+
 // newServer returns a new mock ssh server.
 // newServer returns a new mock ssh server.
 func newServer(t *testing.T) *server {
 func newServer(t *testing.T) *server {
 	dir, err := ioutil.TempDir("", "sshtest")
 	dir, err := ioutil.TempDir("", "sshtest")