浏览代码

go.crypto/ssh: improve channel max packet handling

This proposal moves the check for max packet into
channel.writePacket. Callers should be aware they cannot
pass a buffer larger than max packet. This is only a
concern to chanWriter.Write and appropriate guards are
already in place.

There was some max packet handling in transport.go but it was
incorrect. This has been removed.

This proposal also cleans up session_test.go.

R=gustav.paul, agl, fullung, huin
CC=golang-dev
https://golang.org/cl/6460075
Dave Cheney 12 年之前
父节点
当前提交
7343d5f584
共有 4 个文件被更改,包括 126 次插入76 次删除
  1. 24 8
      ssh/channel.go
  2. 4 2
      ssh/server.go
  3. 97 56
      ssh/session_test.go
  4. 1 10
      ssh/transport.go

+ 24 - 8
ssh/channel.go

@@ -6,6 +6,7 @@ package ssh
 
 import (
 	"errors"
+	"fmt"
 	"io"
 	"sync"
 )
@@ -14,8 +15,13 @@ import (
 // section 5.2.
 type extendedDataTypeCode uint32
 
-// extendedDataStderr is the extended data type that is used for stderr.
-const extendedDataStderr extendedDataTypeCode = 1
+const (
+	// extendedDataStderr is the extended data type that is used for stderr.
+	extendedDataStderr extendedDataTypeCode = 1
+
+	// minPacketLength defines the smallest valid packet
+	minPacketLength = 9
+)
 
 // A Channel is an ordered, reliable, duplex stream that is multiplexed over an
 // SSH connection. Channel.Read can return a ChannelRequest as an error.
@@ -74,7 +80,7 @@ type channel struct {
 	conn              // the underlying transport
 	localId, remoteId uint32
 	remoteWin         window
-	maxPacketSize     uint32
+	maxPacket         uint32
 
 	theyClosed  bool // indicates the close msg has been received from the remote side
 	weClosed    bool // incidates the close msg has been sent from our side
@@ -114,6 +120,13 @@ func (c *channel) sendChannelOpenFailure(reason RejectionReason, message string)
 	return c.writePacket(marshal(msgChannelOpenFailure, reject))
 }
 
+func (c *channel) writePacket(b []byte) error {
+	if uint32(len(b)) > c.maxPacket {
+		return fmt.Errorf("ssh: cannot write %d bytes, maxPacket is %d bytes", len(b), c.maxPacket)
+	}
+	return c.conn.writePacket(b)
+}
+
 type serverChan struct {
 	channel
 	// immutable once created
@@ -144,7 +157,7 @@ func (c *serverChan) Accept() error {
 		PeersId:       c.remoteId,
 		MyId:          c.localId,
 		MyWindow:      c.myWindow,
-		MaxPacketSize: c.maxPacketSize,
+		MaxPacketSize: c.maxPacket,
 	}
 	return c.writePacket(marshal(msgChannelOpenConfirm, confirm))
 }
@@ -450,10 +463,12 @@ func newClientChan(cc conn, id uint32) *clientChan {
 func (c *clientChan) waitForChannelOpenResponse() error {
 	switch msg := (<-c.msg).(type) {
 	case *channelOpenConfirmMsg:
+		if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
+			return errors.New("ssh: invalid MaxPacketSize from peer")
+		}
 		// fixup remoteId field
 		c.remoteId = msg.MyId
-		// TODO(dfc) asset this is < 2^31.
-		c.maxPacketSize = msg.MaxPacketSize
+		c.maxPacket = msg.MaxPacketSize
 		c.remoteWin.add(msg.MyWindow)
 		return nil
 	case *channelOpenFailureMsg:
@@ -478,10 +493,11 @@ type chanWriter struct {
 
 // Write writes data to the remote process's standard input.
 func (w *chanWriter) Write(data []byte) (written int, err error) {
+	const headerLength = 9 // 1 byte message type, 4 bytes remoteId, 4 bytes data length
 	for len(data) > 0 {
-		// never send more data than maxPacketSize even if
+		// never send more data than maxPacket even if
 		// there is sufficent window.
-		n := min(int(w.maxPacketSize), len(data))
+		n := min(int(w.maxPacket-headerLength), len(data))
 		n = int(w.remoteWin.reserve(uint32(n)))
 		remoteId := w.remoteId
 		packet := []byte{

+ 4 - 2
ssh/server.go

@@ -564,13 +564,15 @@ func (s *ServerConn) Accept() (Channel, error) {
 		default:
 			switch msg := decode(packet).(type) {
 			case *channelOpenMsg:
+				if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
+					return nil, errors.New("ssh: invalid MaxPacketSize from peer")
+				}
 				c := &serverChan{
 					channel: channel{
 						conn:      s,
 						remoteId:  msg.PeersId,
 						remoteWin: window{Cond: newCond()},
-						// TODO(dfc) assert this param is < 2^31.
-						maxPacketSize: msg.MaxPacketSize,
+						maxPacket: msg.MaxPacketSize,
 					},
 					chanType:    msg.ChanType,
 					extraData:   msg.TypeSpecificData,

+ 97 - 56
ssh/session_test.go

@@ -16,7 +16,7 @@ import (
 	"code.google.com/p/go.crypto/ssh/terminal"
 )
 
-type serverType func(*serverChan)
+type serverType func(*serverChan, *testing.T)
 
 // dial constructs a new test server and returns a *ClientConn.
 func dial(handler serverType, t *testing.T) *ClientConn {
@@ -28,7 +28,7 @@ func dial(handler serverType, t *testing.T) *ClientConn {
 
 	l, err := Listen("tcp", "127.0.0.1:0", serverConfig)
 	if err != nil {
-		t.Fatalf("unable to listen: %s", err)
+		t.Fatalf("unable to listen: %v", err)
 	}
 	go func() {
 		defer l.Close()
@@ -60,7 +60,7 @@ func dial(handler serverType, t *testing.T) *ClientConn {
 				continue
 			}
 			ch.Accept()
-			go handler(ch.(*serverChan))
+			go handler(ch.(*serverChan), t)
 		}
 		t.Log("done")
 	}()
@@ -74,7 +74,7 @@ func dial(handler serverType, t *testing.T) *ClientConn {
 
 	c, err := Dial("tcp", l.Addr().String(), config)
 	if err != nil {
-		t.Fatalf("unable to dial remote side: %s", err)
+		t.Fatalf("unable to dial remote side: %v", err)
 	}
 	return c
 }
@@ -85,7 +85,7 @@ func TestSessionShell(t *testing.T) {
 	defer conn.Close()
 	session, err := conn.NewSession()
 	if err != nil {
-		t.Fatalf("Unable to request new session: %s", err)
+		t.Fatalf("Unable to request new session: %v", err)
 	}
 	defer session.Close()
 	stdout := new(bytes.Buffer)
@@ -94,7 +94,7 @@ func TestSessionShell(t *testing.T) {
 		t.Fatalf("Unable to execute command: %s", err)
 	}
 	if err := session.Wait(); err != nil {
-		t.Fatalf("Remote command did not exit cleanly: %s", err)
+		t.Fatalf("Remote command did not exit cleanly: %v", err)
 	}
 	actual := stdout.String()
 	if actual != "golang" {
@@ -110,7 +110,7 @@ func TestSessionStdoutPipe(t *testing.T) {
 	defer conn.Close()
 	session, err := conn.NewSession()
 	if err != nil {
-		t.Fatalf("Unable to request new session: %s", err)
+		t.Fatalf("Unable to request new session: %v", err)
 	}
 	defer session.Close()
 	stdout, err := session.StdoutPipe()
@@ -119,7 +119,7 @@ func TestSessionStdoutPipe(t *testing.T) {
 	}
 	var buf bytes.Buffer
 	if err := session.Shell(); err != nil {
-		t.Fatalf("Unable to execute command: %s", err)
+		t.Fatalf("Unable to execute command: %v", err)
 	}
 	done := make(chan bool, 1)
 	go func() {
@@ -129,7 +129,7 @@ func TestSessionStdoutPipe(t *testing.T) {
 		done <- true
 	}()
 	if err := session.Wait(); err != nil {
-		t.Fatalf("Remote command did not exit cleanly: %s", err)
+		t.Fatalf("Remote command did not exit cleanly: %v", err)
 	}
 	<-done
 	actual := buf.String()
@@ -144,11 +144,11 @@ func TestExitStatusNonZero(t *testing.T) {
 	defer conn.Close()
 	session, err := conn.NewSession()
 	if err != nil {
-		t.Fatalf("Unable to request new session: %s", err)
+		t.Fatalf("Unable to request new session: %v", err)
 	}
 	defer session.Close()
 	if err := session.Shell(); err != nil {
-		t.Fatalf("Unable to execute command: %s", err)
+		t.Fatalf("Unable to execute command: %v", err)
 	}
 	err = session.Wait()
 	if err == nil {
@@ -159,7 +159,7 @@ func TestExitStatusNonZero(t *testing.T) {
 		t.Fatalf("expected *ExitError but got %T", err)
 	}
 	if e.ExitStatus() != 15 {
-		t.Fatalf("expected command to exit with 15 but got %s", e.ExitStatus())
+		t.Fatalf("expected command to exit with 15 but got %v", e.ExitStatus())
 	}
 }
 
@@ -169,16 +169,16 @@ func TestExitStatusZero(t *testing.T) {
 	defer conn.Close()
 	session, err := conn.NewSession()
 	if err != nil {
-		t.Fatalf("Unable to request new session: %s", err)
+		t.Fatalf("Unable to request new session: %v", err)
 	}
 	defer session.Close()
 
 	if err := session.Shell(); err != nil {
-		t.Fatalf("Unable to execute command: %s", err)
+		t.Fatalf("Unable to execute command: %v", err)
 	}
 	err = session.Wait()
 	if err != nil {
-		t.Fatalf("expected nil but got %s", err)
+		t.Fatalf("expected nil but got %v", err)
 	}
 }
 
@@ -188,11 +188,11 @@ func TestExitSignalAndStatus(t *testing.T) {
 	defer conn.Close()
 	session, err := conn.NewSession()
 	if err != nil {
-		t.Fatalf("Unable to request new session: %s", err)
+		t.Fatalf("Unable to request new session: %v", err)
 	}
 	defer session.Close()
 	if err := session.Shell(); err != nil {
-		t.Fatalf("Unable to execute command: %s", err)
+		t.Fatalf("Unable to execute command: %v", err)
 	}
 	err = session.Wait()
 	if err == nil {
@@ -213,11 +213,11 @@ func TestKnownExitSignalOnly(t *testing.T) {
 	defer conn.Close()
 	session, err := conn.NewSession()
 	if err != nil {
-		t.Fatalf("Unable to request new session: %s", err)
+		t.Fatalf("Unable to request new session: %v", err)
 	}
 	defer session.Close()
 	if err := session.Shell(); err != nil {
-		t.Fatalf("Unable to execute command: %s", err)
+		t.Fatalf("Unable to execute command: %v", err)
 	}
 	err = session.Wait()
 	if err == nil {
@@ -238,11 +238,11 @@ func TestUnknownExitSignal(t *testing.T) {
 	defer conn.Close()
 	session, err := conn.NewSession()
 	if err != nil {
-		t.Fatalf("Unable to request new session: %s", err)
+		t.Fatalf("Unable to request new session: %v", err)
 	}
 	defer session.Close()
 	if err := session.Shell(); err != nil {
-		t.Fatalf("Unable to execute command: %s", err)
+		t.Fatalf("Unable to execute command: %v", err)
 	}
 	err = session.Wait()
 	if err == nil {
@@ -263,11 +263,11 @@ func TestExitWithoutStatusOrSignal(t *testing.T) {
 	defer conn.Close()
 	session, err := conn.NewSession()
 	if err != nil {
-		t.Fatalf("Unable to request new session: %s", err)
+		t.Fatalf("Unable to request new session: %v", err)
 	}
 	defer session.Close()
 	if err := session.Shell(); err != nil {
-		t.Fatalf("Unable to execute command: %s", err)
+		t.Fatalf("Unable to execute command: %v", err)
 	}
 	err = session.Wait()
 	if err == nil {
@@ -286,7 +286,7 @@ func TestInvalidServerMessage(t *testing.T) {
 	defer conn.Close()
 	session, err := conn.NewSession()
 	if err != nil {
-		t.Fatalf("Unable to request new session: %s", err)
+		t.Fatalf("Unable to request new session: %v", err)
 	}
 	// Make sure that we closed all the clientChans when the connection
 	// failed.
@@ -302,16 +302,16 @@ func TestClientZeroWindowAdjust(t *testing.T) {
 	defer conn.Close()
 	session, err := conn.NewSession()
 	if err != nil {
-		t.Fatalf("Unable to request new session: %s", err)
+		t.Fatalf("Unable to request new session: %v", err)
 	}
 	defer session.Close()
 
 	if err := session.Shell(); err != nil {
-		t.Fatalf("Unable to execute command: %s", err)
+		t.Fatalf("Unable to execute command: %v", err)
 	}
 	err = session.Wait()
 	if err != nil {
-		t.Fatalf("expected nil but got %s", err)
+		t.Fatalf("expected nil but got %v", err)
 	}
 }
 
@@ -322,12 +322,12 @@ func TestServerZeroWindowAdjust(t *testing.T) {
 	defer conn.Close()
 	session, err := conn.NewSession()
 	if err != nil {
-		t.Fatalf("Unable to request new session: %s", err)
+		t.Fatalf("Unable to request new session: %v", err)
 	}
 	defer session.Close()
 
 	if err := session.Shell(); err != nil {
-		t.Fatalf("Unable to execute command: %s", err)
+		t.Fatalf("Unable to execute command: %v", err)
 	}
 
 	// send a bogus zero sized window update
@@ -335,21 +335,21 @@ func TestServerZeroWindowAdjust(t *testing.T) {
 
 	err = session.Wait()
 	if err != nil {
-		t.Fatalf("expected nil but got %s", err)
+		t.Fatalf("expected nil but got %v", err)
 	}
 }
 
-// Verify that we never send a packet larger than maxpacket.
+// Verify that the client never sends a packet larger than maxpacket.
 func TestClientStdinRespectsMaxPacketSize(t *testing.T) {
 	conn := dial(discardHandler, t)
 	defer conn.Close()
 	session, err := conn.NewSession()
 	if err != nil {
-		t.Fatalf("Unable to request new session: %s", err)
+		t.Fatalf("Unable to request new session: %v", err)
 	}
 	defer session.Close()
 	if err := session.Shell(); err != nil {
-		t.Fatalf("Unable to execute command: %s", err)
+		t.Fatalf("Unable to execute command: %v", err)
 	}
 	// try to stuff 128k of data into a 32k hole.
 	const size = 128 * 1024
@@ -359,6 +359,27 @@ func TestClientStdinRespectsMaxPacketSize(t *testing.T) {
 	}
 }
 
+// Verify that the client never accepts a packet larger than maxpacket.
+func TestServerStdoutRespectsMaxPacketSize(t *testing.T) {
+	conn := dial(largeSendHandler, t)
+	defer conn.Close()
+	session, err := conn.NewSession()
+	if err != nil {
+		t.Fatalf("Unable to request new session: %v", err)
+	}
+	defer session.Close()
+	out, err := session.StdoutPipe()
+	if err != nil {
+		t.Fatalf("Unable to connect to Stdout: %v", err)
+	}
+	if err := session.Shell(); err != nil {
+		t.Fatalf("Unable to execute command: %v", err)
+	}
+	if _, err := ioutil.ReadAll(out); err != nil {
+		t.Fatalf("failed to read: %v", err)
+	}
+}
+
 type exitStatusMsg struct {
 	PeersId   uint32
 	Request   string
@@ -384,68 +405,70 @@ func newServerShell(ch *serverChan, prompt string) *ServerTerminal {
 	}
 }
 
-func exitStatusZeroHandler(ch *serverChan) {
+func exitStatusZeroHandler(ch *serverChan, t *testing.T) {
 	defer ch.Close()
 	// this string is returned to stdout
 	shell := newServerShell(ch, "> ")
 	shell.ReadLine()
-	sendStatus(0, ch)
+	sendStatus(0, ch, t)
 }
 
-func exitStatusNonZeroHandler(ch *serverChan) {
+func exitStatusNonZeroHandler(ch *serverChan, t *testing.T) {
 	defer ch.Close()
 	shell := newServerShell(ch, "> ")
 	shell.ReadLine()
-	sendStatus(15, ch)
+	sendStatus(15, ch, t)
 }
 
-func exitSignalAndStatusHandler(ch *serverChan) {
+func exitSignalAndStatusHandler(ch *serverChan, t *testing.T) {
 	defer ch.Close()
 	shell := newServerShell(ch, "> ")
 	shell.ReadLine()
-	sendStatus(15, ch)
-	sendSignal("TERM", ch)
+	sendStatus(15, ch, t)
+	sendSignal("TERM", ch, t)
 }
 
-func exitSignalHandler(ch *serverChan) {
+func exitSignalHandler(ch *serverChan, t *testing.T) {
 	defer ch.Close()
 	shell := newServerShell(ch, "> ")
 	shell.ReadLine()
-	sendSignal("TERM", ch)
+	sendSignal("TERM", ch, t)
 }
 
-func exitSignalUnknownHandler(ch *serverChan) {
+func exitSignalUnknownHandler(ch *serverChan, t *testing.T) {
 	defer ch.Close()
 	shell := newServerShell(ch, "> ")
 	shell.ReadLine()
-	sendSignal("SYS", ch)
+	sendSignal("SYS", ch, t)
 }
 
-func exitWithoutSignalOrStatus(ch *serverChan) {
+func exitWithoutSignalOrStatus(ch *serverChan, t *testing.T) {
 	defer ch.Close()
 	shell := newServerShell(ch, "> ")
 	shell.ReadLine()
 }
 
-func shellHandler(ch *serverChan) {
+func shellHandler(ch *serverChan, t *testing.T) {
 	defer ch.Close()
 	// this string is returned to stdout
 	shell := newServerShell(ch, "golang")
 	shell.ReadLine()
-	sendStatus(0, ch)
+	sendStatus(0, ch, t)
 }
 
-func sendStatus(status uint32, ch *serverChan) {
+func sendStatus(status uint32, ch *serverChan, t *testing.T) {
 	msg := exitStatusMsg{
 		PeersId:   ch.remoteId,
 		Request:   "exit-status",
 		WantReply: false,
 		Status:    status,
 	}
-	ch.serverConn.writePacket(marshal(msgChannelRequest, msg))
+	if err := ch.writePacket(marshal(msgChannelRequest, msg)); err != nil {
+		t.Errorf("unable to send status: %v", err)
+	}
 }
 
-func sendSignal(signal string, ch *serverChan) {
+func sendSignal(signal string, ch *serverChan, t *testing.T) {
 	sig := exitSignalMsg{
 		PeersId:    ch.remoteId,
 		Request:    "exit-signal",
@@ -455,10 +478,12 @@ func sendSignal(signal string, ch *serverChan) {
 		Errmsg:     "Process terminated",
 		Lang:       "en-GB-oed",
 	}
-	ch.serverConn.writePacket(marshal(msgChannelRequest, sig))
+	if err := ch.writePacket(marshal(msgChannelRequest, sig)); err != nil {
+		t.Errorf("unable to send signal: %v", err)
+	}
 }
 
-func sendInvalidRecord(ch *serverChan) {
+func sendInvalidRecord(ch *serverChan, t *testing.T) {
 	defer ch.Close()
 	packet := make([]byte, 1+4+4+1)
 	packet[0] = msgChannelData
@@ -466,19 +491,21 @@ func sendInvalidRecord(ch *serverChan) {
 	marshalUint32(packet[5:], 1)
 	packet[9] = 42
 
-	ch.serverConn.writePacket(packet)
+	if err := ch.writePacket(packet); err != nil {
+		t.Errorf("unable send invalid record: %v", err)
+	}
 }
 
-func sendZeroWindowAdjust(ch *serverChan) {
+func sendZeroWindowAdjust(ch *serverChan, t *testing.T) {
 	defer ch.Close()
 	// send a bogus zero sized window update
 	ch.sendWindowAdj(0)
 	shell := newServerShell(ch, "> ")
 	shell.ReadLine()
-	sendStatus(0, ch)
+	sendStatus(0, ch, t)
 }
 
-func discardHandler(ch *serverChan) {
+func discardHandler(ch *serverChan, t *testing.T) {
 	defer ch.Close()
 	// grow the window to avoid being fooled by
 	// the initial 1 << 14 window.
@@ -487,3 +514,17 @@ func discardHandler(ch *serverChan) {
 	shell.ReadLine()
 	io.Copy(ioutil.Discard, ch.serverConn)
 }
+
+func largeSendHandler(ch *serverChan, t *testing.T) {
+	defer ch.Close()
+	// grow the window to avoid being fooled by
+	// the initial 1 << 14 window.
+	ch.sendWindowAdj(1024 * 1024)
+	shell := newServerShell(ch, "> ")
+	shell.ReadLine()
+	// try to send more than the 32k window
+	// will allow
+	if err := ch.writePacket(make([]byte, 128*1024)); err == nil {
+		t.Errorf("wrote packet larger than 32k")
+	}
+}

+ 1 - 10
ssh/transport.go

@@ -19,9 +19,6 @@ import (
 
 const (
 	packetSizeMultiple = 16 // TODO(huin) this should be determined by the cipher.
-	minPacketSize      = 16
-	maxPacketSize      = 36000
-	minPaddingSize     = 4 // TODO(huin) should this be configurable?
 )
 
 // conn represents an ssh transport that implements packet based
@@ -97,9 +94,6 @@ func (r *reader) readOnePacket() ([]byte, error) {
 	if length <= paddingLength+1 {
 		return nil, errors.New("ssh: invalid packet length")
 	}
-	if length > maxPacketSize {
-		return nil, errors.New("ssh: packet too large")
-	}
 
 	packet := make([]byte, length-1+macSize)
 	if _, err := io.ReadFull(r, packet); err != nil {
@@ -196,11 +190,8 @@ func (w *writer) writePacket(packet []byte) error {
 		}
 	}
 
-	if err := w.Flush(); err != nil {
-		return err
-	}
 	w.seqNum++
-	return err
+	return w.Flush()
 }
 
 // Send a message to the remote peer