Преглед изворни кода

go.crypto/ssh: Read returns all unread bytes before returning io.EOF.

Fixes golang/go#4158.

R=dave, agl
CC=golang-dev
https://golang.org/cl/6586060
Gustav Paul пре 13 година
родитељ
комит
1aa39d6262
2 измењених фајлова са 60 додато и 4 уклоњено
  1. 11 4
      ssh/channel.go
  2. 49 0
      ssh/client_func_test.go

+ 11 - 4
ssh/channel.go

@@ -434,7 +434,7 @@ func (c *serverChan) ExtraData() []byte {
 	return c.extraData
 }
 
-// A clientChan represents a single RFC 4254 channel multiplexed 
+// A clientChan represents a single RFC 4254 channel multiplexed
 // over a SSH connection.
 type clientChan struct {
 	channel
@@ -502,8 +502,8 @@ func (c *clientChan) Close() error {
 // A chanWriter represents the stdin of a remote process.
 type chanWriter struct {
 	*channel
-	// indicates the writer has been closed. eof is owned by the 
-	// caller of Write/Close. 
+	// indicates the writer has been closed. eof is owned by the
+	// caller of Write/Close.
 	eof bool
 }
 
@@ -562,5 +562,12 @@ func (r *chanReader) Read(buf []byte) (int, error) {
 		}
 		return 0, err
 	}
-	return n, r.sendWindowAdj(n)
+	err = r.sendWindowAdj(n)
+	if err == io.EOF && n > 0 {
+		// sendWindowAdjust can return io.EOF if the remote peer has
+		// closed the connection, however we want to defer forwarding io.EOF to the
+		// caller of Read until the buffer has been drained.
+		err = nil
+	}
+	return n, err
 }

+ 49 - 0
ssh/client_func_test.go

@@ -10,7 +10,9 @@ package ssh
 // -ssh.user and -ssh.pass must be passed to gotest.
 
 import (
+	"bytes"
 	"flag"
+	"io"
 	"testing"
 )
 
@@ -59,3 +61,50 @@ func TestFuncPublickeyAuth(t *testing.T) {
 	}
 	defer conn.Close()
 }
+
+func TestFuncLargeRead(t *testing.T) {
+	if *sshuser == "" {
+		t.Log("ssh.user not defined, skipping test")
+		return
+	}
+	kc := new(keychain)
+	if err := kc.loadPEM(*sshprivkey); err != nil {
+		t.Fatalf("unable to load private key: %s", err)
+	}
+	config := &ClientConfig{
+		User: *sshuser,
+		Auth: []ClientAuth{
+			ClientAuthKeyring(kc),
+		},
+	}
+	conn, err := Dial("tcp", "localhost:22", config)
+	if err != nil {
+		t.Fatalf("unable to connect: %s", err)
+	}
+	defer conn.Close()
+
+	session, err := conn.NewSession()
+	if err != nil {
+		t.Fatalf("unable to create new session: %s", err)
+	}
+
+	stdout, err := session.StdoutPipe()
+	if err != nil {
+		t.Fatalf("unable to acquire stdout pipe: %s", err)
+	}
+
+	err = session.Start("dd if=/dev/urandom bs=2048 count=1")
+	if err != nil {
+		t.Fatalf("unable to execute remote command: %s", err)
+	}
+
+	buf := new(bytes.Buffer)
+	n, err := io.Copy(buf, stdout)
+	if err != nil {
+		t.Fatalf("error reading from remote stdout: %s", err)
+	}
+
+	if n != 2048 {
+		t.Fatalf("Expected %d bytes but read only %d from remote command", 2048, n)
+	}
+}