Browse Source

go.crypto/ssh: add support for client side global requests

* Add support for RFC4254 section 4 global requests.
* Improve clientConn.Listen to process responses properly.

R=agl, gustav.paul
CC=golang-dev
https://golang.org/cl/6130050
Dave Cheney 13 years ago
parent
commit
b4b42222af
3 changed files with 101 additions and 51 deletions
  1. 60 21
      ssh/client.go
  2. 12 2
      ssh/messages.go
  3. 29 28
      ssh/tcpip.go

+ 60 - 21
ssh/client.go

@@ -23,14 +23,21 @@ type ClientConn struct {
 	*transport
 	config      *ClientConfig
 	chanlist    // channels associated with this connection
-	forwardList // forwared tcpip connections from the remote side
+	forwardList // forwarded tcpip connections from the remote side
+	globalRequest
+}
+
+type globalRequest struct {
+	sync.Mutex
+	response chan interface{}
 }
 
 // Client returns a new SSH client connection using c as the underlying transport.
 func Client(c net.Conn, config *ClientConfig) (*ClientConn, error) {
 	conn := &ClientConn{
-		transport: newTransport(c, config.rand()),
-		config:    config,
+		transport:     newTransport(c, config.rand()),
+		config:        config,
+		globalRequest: globalRequest{response: make(chan interface{}, 1)},
 	}
 	if err := conn.handshake(); err != nil {
 		conn.Close()
@@ -273,6 +280,8 @@ func (c *ClientConn) mainLoop() {
 					// invalid window update
 					return
 				}
+			case *globalRequestSuccessMsg, *globalRequestFailureMsg:
+				c.globalRequest.response <- msg
 			case *disconnectMsg:
 				return
 			default:
@@ -286,22 +295,24 @@ func (c *ClientConn) mainLoop() {
 func (c *ClientConn) handleChanOpen(msg *channelOpenMsg) {
 	switch msg.ChanType {
 	case "forwarded-tcpip":
-		addr, err := parseAddr(msg.TypeSpecificData)
-		if err != nil {
+		laddr, rest, ok := parseTCPAddr(msg.TypeSpecificData)
+		if !ok {
 			// invalid request
-			m := channelOpenFailureMsg{
-				PeersId:  msg.PeersId,
-				Reason:   ConnectionFailed,
-				Message:  fmt.Sprintf("invalid request: %v", err),
-				Language: "en_US.UTF-8",
-			}
-			c.writePacket(marshal(msgChannelOpenFailure, m))
+			c.sendConnectionFailed(msg.PeersId)
 			return
 		}
-		l, ok := c.forwardList.Lookup(addr)
+		l, ok := c.forwardList.Lookup(laddr)
 		if !ok {
+			fmt.Println("could not find forward list entry for", laddr)
 			// Section 7.2, implementations MUST reject suprious incoming
 			// connections.
+			c.sendConnectionFailed(msg.PeersId)
+			return
+		}
+		raddr, rest, ok := parseTCPAddr(rest)
+		if !ok {
+			// invalid request
+			c.sendConnectionFailed(msg.PeersId)
 			return
 		}
 		ch := c.newChan(c.transport)
@@ -315,7 +326,7 @@ func (c *ClientConn) handleChanOpen(msg *channelOpenMsg) {
 			MaxPacketSize: 1 << 15, // RFC 4253 6.1
 		}
 		c.writePacket(marshal(msgChannelOpenConfirm, m))
-		l <- forward{ch, addr}
+		l <- forward{ch, raddr}
 	default:
 		// unknown channel type
 		m := channelOpenFailureMsg{
@@ -328,23 +339,51 @@ func (c *ClientConn) handleChanOpen(msg *channelOpenMsg) {
 	}
 }
 
-// parseAddr parses the originating address from the remote into a *net.TCPAddr. 
+// sendGlobalRequest sends a global request message as specified
+// in RFC4254 section 4. To correctly synchronise messages, a lock
+// is held internally until a response is returned.
+func (c *ClientConn) sendGlobalRequest(m interface{}) (*globalRequestSuccessMsg, error) {
+	c.globalRequest.Lock()
+	defer c.globalRequest.Unlock()
+	if err := c.writePacket(marshal(msgGlobalRequest, m)); err != nil {
+		return nil, err
+	}
+	r := <-c.globalRequest.response
+	if r, ok := r.(*globalRequestSuccessMsg); ok {
+		return r, nil
+	}
+	return nil, errors.New("request failed")
+}
+
+// sendConnectionFailed rejects an incoming channel identified 
+// by peersId.
+func (c *ClientConn) sendConnectionFailed(peersId uint32) error {
+	m := channelOpenFailureMsg{
+		PeersId:  peersId,
+		Reason:   ConnectionFailed,
+		Message:  "invalid request",
+		Language: "en_US.UTF-8",
+	}
+	return c.writePacket(marshal(msgChannelOpenFailure, m))
+}
+
+// parseTCPAddr parses the originating address from the remote into a *net.TCPAddr. 
 // RFC 4254 section 7.2 is mute on what to do if parsing fails but the forwardlist
 // requires a valid *net.TCPAddr to operate, so we enforce that restriction here.
-func parseAddr(b []byte) (*net.TCPAddr, error) {
+func parseTCPAddr(b []byte) (*net.TCPAddr, []byte, bool) {
 	addr, b, ok := parseString(b)
 	if !ok {
-		return nil, ParseError{msgChannelOpen}
+		return nil, b, false
 	}
-	port, _, ok := parseUint32(b)
+	port, b, ok := parseUint32(b)
 	if !ok {
-		return nil, ParseError{msgChannelOpen}
+		return nil, b, false
 	}
 	ip := net.ParseIP(string(addr))
 	if ip == nil {
-		return nil, ParseError{msgChannelOpen}
+		return nil, b, false
 	}
-	return &net.TCPAddr{ip, int(port)}, nil
+	return &net.TCPAddr{ip, int(port)}, b, true
 }
 
 // Dial connects to the given network address using net.Dial and

+ 12 - 2
ssh/messages.go

@@ -177,6 +177,16 @@ type globalRequestMsg struct {
 	WantReply bool
 }
 
+// See RFC 4254, section 4
+type globalRequestSuccessMsg struct {
+	Data []byte `ssh:"rest"`
+}
+
+// See RFC 4254, section 4
+type globalRequestFailureMsg struct {
+	Data []byte `ssh:"rest"`
+}
+
 // See RFC 4254, section 5.2
 type windowAdjustMsg struct {
 	PeersId         uint32
@@ -584,9 +594,9 @@ func decode(packet []byte) interface{} {
 	case msgGlobalRequest:
 		msg = new(globalRequestMsg)
 	case msgRequestSuccess:
-		msg = new(channelRequestSuccessMsg)
+		msg = new(globalRequestSuccessMsg)
 	case msgRequestFailure:
-		msg = new(channelRequestFailureMsg)
+		msg = new(globalRequestFailureMsg)
 	case msgChannelOpen:
 		msg = new(channelOpenMsg)
 	case msgChannelOpenConfirm:

+ 29 - 28
ssh/tcpip.go

@@ -13,30 +13,15 @@ import (
 	"time"
 )
 
-var (
-	// TODO(dfc) relax this restriction
-	errNoPort = errors.New("A port number must be supplied")
-)
-
 // Listen requests the remote peer open a listening socket 
 // on addr. Incoming connections will be available by calling
 // Accept on the returned net.Listener.
 func (c *ClientConn) Listen(n, addr string) (net.Listener, error) {
-	raddr, err := net.ResolveTCPAddr(n, addr)
+	laddr, err := net.ResolveTCPAddr(n, addr)
 	if err != nil {
 		return nil, err
 	}
-	return c.ListenTCP(raddr)
-}
-
-// ListenTCP requests the remote peer open a listening socket 
-// on raddr. Incoming connections will be available by calling
-// Accept on the returned net.Listener.
-func (c *ClientConn) ListenTCP(raddr *net.TCPAddr) (net.Listener, error) {
-	if raddr.Port == 0 {
-		return nil, errNoPort
-	}
-	return c.listen(raddr)
+	return c.ListenTCP(laddr)
 }
 
 // RFC 4254 7.1
@@ -47,21 +32,34 @@ type channelForwardMsg struct {
 	rport     uint32
 }
 
-func (c *ClientConn) listen(addr *net.TCPAddr) (net.Listener, error) {
+// ListenTCP requests the remote peer open a listening socket 
+// on laddr. Incoming connections will be available by calling
+// Accept on the returned net.Listener.
+func (c *ClientConn) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) {
 	m := channelForwardMsg{
 		"tcpip-forward",
-		false, // can't handle reply message from remote yet
-		addr.IP.String(),
-		uint32(addr.Port),
+		true, // sendGlobalRequest waits for a reply
+		laddr.IP.String(),
+		uint32(laddr.Port),
 	}
-	// register this forward
-	ch := c.forwardList.Add(addr)
 	// send message
-	if err := c.writePacket(marshal(msgGlobalRequest, m)); err != nil {
-		c.forwardList.Remove(addr)
+	resp, err := c.sendGlobalRequest(m)
+	if err != nil {
 		return nil, err
 	}
-	return &tcpListener{addr, c, ch}, nil
+	// fixup laddr. If the original port was 0, then the remote side will
+	// supply one in the resp.
+	if laddr.Port == 0 {
+		port, _, ok := parseUint32(resp.Data)
+		if !ok {
+			return nil, errors.New("unable to parse response")
+		}
+		laddr.Port = int(port)
+	}
+
+	// register this forward
+	ch := c.forwardList.Add(laddr)
+	return &tcpListener{laddr, c, ch}, nil
 }
 
 // forwardList stores a mapping between remote 
@@ -144,12 +142,15 @@ func (l *tcpListener) Accept() (net.Conn, error) {
 func (l *tcpListener) Close() error {
 	m := channelForwardMsg{
 		"cancel-tcpip-forward",
-		false, // TODO(dfc) process reply
+		true,
 		l.laddr.IP.String(),
 		uint32(l.laddr.Port),
 	}
 	l.conn.forwardList.Remove(l.laddr)
-	return l.conn.writePacket(marshal(msgGlobalRequest, m))
+	if _, err := l.conn.sendGlobalRequest(m); err != nil {
+		return err
+	}
+	return nil
 }
 
 // Addr returns the listener's network address.