|
|
@@ -24,6 +24,7 @@ import (
|
|
|
"os/exec"
|
|
|
"os/user"
|
|
|
"path/filepath"
|
|
|
+ "sync"
|
|
|
"testing"
|
|
|
"text/template"
|
|
|
"time"
|
|
|
@@ -103,22 +104,23 @@ func clientConfig() *ssh.ClientConfig {
|
|
|
|
|
|
func (s *server) Dial(config *ssh.ClientConfig) *ssh.ClientConn {
|
|
|
s.cmd = exec.Command("sshd", "-f", s.configfile, "-i")
|
|
|
- stdin, err := s.cmd.StdinPipe()
|
|
|
+ r1, w1, err := os.Pipe()
|
|
|
if err != nil {
|
|
|
s.t.Fatal(err)
|
|
|
}
|
|
|
- stdout, err := s.cmd.StdoutPipe()
|
|
|
+ s.cmd.Stdout = w1
|
|
|
+ r2, w2, err := os.Pipe()
|
|
|
if err != nil {
|
|
|
s.t.Fatal(err)
|
|
|
}
|
|
|
- s.cmd.Stderr = os.Stderr // &s.output
|
|
|
- err = s.cmd.Start()
|
|
|
- if err != nil {
|
|
|
+ s.cmd.Stdin = r2
|
|
|
+ s.cmd.Stderr = os.Stderr
|
|
|
+ if err := s.cmd.Start(); err != nil {
|
|
|
s.t.Fail()
|
|
|
s.Shutdown()
|
|
|
s.t.Fatalf("s.cmd.Start: %v", err)
|
|
|
}
|
|
|
- conn, err := ssh.Client(&client{stdin, stdout}, config)
|
|
|
+ conn, err := ssh.Client(&client{wc: w2, r: r1}, config)
|
|
|
if err != nil {
|
|
|
s.t.Fail()
|
|
|
s.Shutdown()
|
|
|
@@ -136,24 +138,81 @@ func (s *server) Shutdown() {
|
|
|
}
|
|
|
if s.t.Failed() {
|
|
|
// log any output from sshd process
|
|
|
- s.t.Log(s.output.String())
|
|
|
+ s.t.Logf("sshd: %q", s.output.String())
|
|
|
}
|
|
|
s.cleanup()
|
|
|
}
|
|
|
|
|
|
// client wraps a pair of Reader/WriteClosers to implement the
|
|
|
-// net.Conn interface.
|
|
|
+// net.Conn interface. Importantly, client also mocks the
|
|
|
+// ability of net.Conn to support concurrent calls to Read/Write
|
|
|
+// and Close. See golang.org/issue/5138 for more details.
|
|
|
type client struct {
|
|
|
- io.WriteCloser
|
|
|
- io.Reader
|
|
|
+ wc io.WriteCloser
|
|
|
+ r io.Reader
|
|
|
+ sync.Mutex // protects refcount and closing
|
|
|
+ refcount int
|
|
|
+ closing bool
|
|
|
}
|
|
|
|
|
|
+var errClosing = errors.New("use of closed network connection")
|
|
|
+
|
|
|
func (c *client) LocalAddr() net.Addr { return nil }
|
|
|
func (c *client) RemoteAddr() net.Addr { return nil }
|
|
|
func (c *client) SetDeadline(time.Time) error { return nil }
|
|
|
func (c *client) SetReadDeadline(time.Time) error { return nil }
|
|
|
func (c *client) SetWriteDeadline(time.Time) error { return nil }
|
|
|
|
|
|
+// incref, decref are copied from the net package (see net/fd_unix.go) to
|
|
|
+// implement the concurrent Close contract that net.Conn implementations
|
|
|
+// from that that package provide.
|
|
|
+
|
|
|
+func (c *client) incRef(closing bool) error {
|
|
|
+ c.Lock()
|
|
|
+ defer c.Unlock()
|
|
|
+ if c.closing {
|
|
|
+ return errClosing
|
|
|
+ }
|
|
|
+ c.refcount++
|
|
|
+ if closing {
|
|
|
+ c.closing = true
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+func (c *client) decRef() {
|
|
|
+ c.Lock()
|
|
|
+ defer c.Unlock()
|
|
|
+ c.refcount--
|
|
|
+ if c.closing && c.refcount == 0 {
|
|
|
+ c.wc.Close()
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (c *client) Close() error {
|
|
|
+ if err := c.incRef(true); err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ c.decRef()
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+func (c *client) Read(b []byte) (int, error) {
|
|
|
+ if err := c.incRef(false); err != nil {
|
|
|
+ return 0, err
|
|
|
+ }
|
|
|
+ defer c.decRef()
|
|
|
+ return c.r.Read(b)
|
|
|
+}
|
|
|
+
|
|
|
+func (c *client) Write(b []byte) (int, error) {
|
|
|
+ if err := c.incRef(false); err != nil {
|
|
|
+ return 0, err
|
|
|
+ }
|
|
|
+ defer c.decRef()
|
|
|
+ return c.wc.Write(b)
|
|
|
+}
|
|
|
+
|
|
|
// newServer returns a new mock ssh server.
|
|
|
func newServer(t *testing.T) *server {
|
|
|
dir, err := ioutil.TempDir("", "sshtest")
|