Selaa lähdekoodia

go.crypto/ssh: fix and test port forwarding.

Set maxPacket in forwarded connection, and use the requested port
number as key in forwardList.

R=golang-dev, agl, dave
CC=golang-dev
https://golang.org/cl/9753044
Han-Wen Nienhuys 12 vuotta sitten
vanhempi
commit
0d8dc3cd6a
3 muutettua tiedostoa jossa 131 lisäystä ja 19 poistoa
  1. 15 5
      ssh/client.go
  2. 29 14
      ssh/tcpip.go
  3. 87 0
      ssh/test/forward_test.go

+ 15 - 5
ssh/client.go

@@ -335,6 +335,10 @@ func (c *ClientConn) mainLoop() {
 
 // Handle channel open messages from the remote side.
 func (c *ClientConn) handleChanOpen(msg *channelOpenMsg) {
+	if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
+		c.sendConnectionFailed(msg.PeersId)
+	}
+
 	switch msg.ChanType {
 	case "forwarded-tcpip":
 		laddr, rest, ok := parseTCPAddr(msg.TypeSpecificData)
@@ -343,8 +347,10 @@ func (c *ClientConn) handleChanOpen(msg *channelOpenMsg) {
 			c.sendConnectionFailed(msg.PeersId)
 			return
 		}
-		l, ok := c.forwardList.lookup(laddr)
+
+		l, ok := c.forwardList.lookup(*laddr)
 		if !ok {
+			// TODO: print on a more structured log.
 			fmt.Println("could not find forward list entry for", laddr)
 			// Section 7.2, implementations MUST reject suprious incoming
 			// connections.
@@ -360,13 +366,17 @@ func (c *ClientConn) handleChanOpen(msg *channelOpenMsg) {
 		ch := c.newChan(c.transport)
 		ch.remoteId = msg.PeersId
 		ch.remoteWin.add(msg.PeersWindow)
+		ch.maxPacket = msg.MaxPacketSize
 
 		m := channelOpenConfirmMsg{
-			PeersId:       ch.remoteId,
-			MyId:          ch.localId,
-			MyWindow:      1 << 14,
-			MaxPacketSize: 1 << 15, // RFC 4253 6.1
+			PeersId:  ch.remoteId,
+			MyId:     ch.localId,
+			MyWindow: 1 << 14,
+
+			// As per RFC 4253 6.1, 32k is also the minimum.
+			MaxPacketSize: 1 << 15,
 		}
+
 		c.writePacket(marshal(msgChannelOpenConfirm, m))
 		l <- forward{ch, raddr}
 	default:

+ 29 - 14
ssh/tcpip.go

@@ -47,8 +47,16 @@ func (c *ClientConn) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) {
 	if err != nil {
 		return nil, err
 	}
-	// fixup laddr. If the original port was 0, then the remote side will
-	// supply one in the resp.
+
+	// Register this forward, using the port number we requested.
+	// If we requested port 0 (auto allocated port), we have to
+	// register under 0, since the channelOpenMsg will list 0
+	// rather than the allocated port number.
+	ch := c.forwardList.add(*laddr)
+
+	// If the original port was 0, then the remote side will
+	// supply a real port number in the response.
+	origPort := uint32(laddr.Port)
 	if laddr.Port == 0 {
 		port, _, ok := parseUint32(resp.Data)
 		if !ok {
@@ -57,9 +65,7 @@ func (c *ClientConn) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) {
 		laddr.Port = int(port)
 	}
 
-	// register this forward
-	ch := c.forwardList.add(laddr)
-	return &tcpListener{laddr, c, ch}, nil
+	return &tcpListener{laddr, origPort, c, ch}, nil
 }
 
 // forwardList stores a mapping between remote
@@ -72,17 +78,19 @@ type forwardList struct {
 // forwardEntry represents an established mapping of a laddr on a
 // remote ssh server to a channel connected to a tcpListener.
 type forwardEntry struct {
-	laddr *net.TCPAddr
+	laddr net.TCPAddr
 	c     chan forward
 }
 
-// forward represents an incoming forwarded tcpip connection
+// forward represents an incoming forwarded tcpip connection. The
+// arguments to add/remove/lookup should be address as specified in
+// the original forward-request.
 type forward struct {
 	c     *clientChan  // the ssh client channel underlying this forward
 	raddr *net.TCPAddr // the raddr of the incoming connection
 }
 
-func (l *forwardList) add(addr *net.TCPAddr) chan forward {
+func (l *forwardList) add(addr net.TCPAddr) chan forward {
 	l.Lock()
 	defer l.Unlock()
 	f := forwardEntry{
@@ -93,7 +101,7 @@ func (l *forwardList) add(addr *net.TCPAddr) chan forward {
 	return f.c
 }
 
-func (l *forwardList) remove(addr *net.TCPAddr) {
+func (l *forwardList) remove(addr net.TCPAddr) {
 	l.Lock()
 	defer l.Unlock()
 	for i, f := range l.entries {
@@ -104,7 +112,7 @@ func (l *forwardList) remove(addr *net.TCPAddr) {
 	}
 }
 
-func (l *forwardList) lookup(addr *net.TCPAddr) (chan forward, bool) {
+func (l *forwardList) lookup(addr net.TCPAddr) (chan forward, bool) {
 	l.Lock()
 	defer l.Unlock()
 	for _, f := range l.entries {
@@ -117,8 +125,11 @@ func (l *forwardList) lookup(addr *net.TCPAddr) (chan forward, bool) {
 
 type tcpListener struct {
 	laddr *net.TCPAddr
-	conn  *ClientConn
-	in    <-chan forward
+
+	// The port with which we made the request, which can be 0.
+	origPort uint32
+	conn     *ClientConn
+	in       <-chan forward
 }
 
 // Accept waits for and returns the next connection to the listener.
@@ -144,9 +155,13 @@ func (l *tcpListener) Close() error {
 		"cancel-tcpip-forward",
 		true,
 		l.laddr.IP.String(),
-		uint32(l.laddr.Port),
+		l.origPort,
+	}
+	origAddr := net.TCPAddr{
+		IP:   l.laddr.IP,
+		Port: int(l.origPort),
 	}
-	l.conn.forwardList.remove(l.laddr)
+	l.conn.forwardList.remove(origAddr)
 	if _, err := l.conn.sendGlobalRequest(m); err != nil {
 		return err
 	}

+ 87 - 0
ssh/test/forward_test.go

@@ -0,0 +1,87 @@
+package test
+
+import (
+	"bytes"
+	"io"
+	"io/ioutil"
+	"math/rand"
+	"net"
+	"testing"
+)
+
+func TestPortForward(t *testing.T) {
+	server := newServer(t)
+	defer server.Shutdown()
+	conn := server.Dial(clientConfig())
+	defer conn.Close()
+
+	sshListener, err := conn.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatalf("conn.Listen failed: %v", err)
+	}
+
+	go func() {
+		sshConn, err := sshListener.Accept()
+		if err != nil {
+			t.Fatalf("listen.Accept failed: %v", err)
+		}
+
+		_, err = io.Copy(sshConn, sshConn)
+		if err != nil && err != io.EOF {
+			t.Fatalf("ssh client copy: %v", err)
+		}
+		sshConn.Close()
+	}()
+
+	forwardedAddr := sshListener.Addr().String()
+	tcpConn, err := net.Dial("tcp", forwardedAddr)
+	if err != nil {
+		t.Fatalf("TCP dial failed: %v", err)
+	}
+
+	readChan := make(chan []byte)
+	go func() {
+		data, _ := ioutil.ReadAll(tcpConn)
+		readChan <- data
+	}()
+
+	// Invent some data.
+	data := make([]byte, 100*1000)
+	for i := range data {
+		data[i] = byte(i % 255)
+	}
+
+	var sent []byte
+	for len(sent) < 1000*1000 {
+		// Send random sized chunks
+		m := rand.Intn(len(data))
+		n, err := tcpConn.Write(data[:m])
+		if err != nil {
+			break
+		}
+		sent = append(sent, data[:n]...)
+	}
+	if err := tcpConn.(*net.TCPConn).CloseWrite(); err != nil {
+		t.Errorf("tcpConn.CloseWrite: %v", err)
+	}
+
+	read := <-readChan
+
+	if len(sent) != len(read) {
+		t.Fatalf("got %d bytes, want %d", len(read), len(sent))
+	}
+	if bytes.Compare(sent, read) != 0 {
+		t.Fatalf("read back data does not match")
+	}
+
+	if err := sshListener.Close(); err != nil {
+		t.Fatalf("sshListener.Close: %v", err)
+	}
+
+	// Check that the forward disappeared.
+	tcpConn, err = net.Dial("tcp", forwardedAddr)
+	if err == nil {
+		tcpConn.Close()
+		t.Errorf("still listening to %s after closing", forwardedAddr)
+	}
+}