Browse Source

go.crypto/ssh: ensure {Server,Client}Conn do not expose io.ReadWriter

Transport should not be a ReadWriter. It can only write packets, i.e. no partial reads or writes. Furthermore, you can currently do ClientConn.Write() while the connection is live, which sends raw bytes over the connection. Doing so will confuse the transports because the data is not encrypted.

As a consequence, ClientConn and ServerConn stop being a net.Conn

Finally, ensure that {Server,Client}Conn implement LocalAddr and RemoteAddr methods that previously were exposed by an embedded net.Conn field.

R=hanwen
CC=golang-dev
https://golang.org/cl/16610043
Dave Cheney 12 years ago
parent
commit
c0d640c887
6 changed files with 88 additions and 39 deletions
  1. 25 16
      ssh/client.go
  2. 2 2
      ssh/client_auth.go
  3. 31 0
      ssh/common_test.go
  4. 28 19
      ssh/server.go
  5. 1 1
      ssh/session.go
  6. 1 1
      ssh/tcpip.go

+ 25 - 16
ssh/client.go

@@ -16,7 +16,7 @@ import (
 
 
 // ClientConn represents the client side of an SSH connection.
 // ClientConn represents the client side of an SSH connection.
 type ClientConn struct {
 type ClientConn struct {
-	*transport
+	transport   *transport
 	config      *ClientConfig
 	config      *ClientConfig
 	chanList    // channels associated with this connection
 	chanList    // channels associated with this connection
 	forwardList // forwarded tcpip connections from the remote side
 	forwardList // forwarded tcpip connections from the remote side
@@ -47,13 +47,22 @@ func clientWithAddress(c net.Conn, addr string, config *ClientConfig) (*ClientCo
 	}
 	}
 
 
 	if err := conn.handshake(); err != nil {
 	if err := conn.handshake(); err != nil {
-		conn.Close()
+		conn.transport.Close()
 		return nil, fmt.Errorf("handshake failed: %v", err)
 		return nil, fmt.Errorf("handshake failed: %v", err)
 	}
 	}
 	go conn.mainLoop()
 	go conn.mainLoop()
 	return conn, nil
 	return conn, nil
 }
 }
 
 
+// Close closes the connection.
+func (c *ClientConn) Close() error { return c.transport.Close() }
+
+// LocalAddr returns the local network address.
+func (c *ClientConn) LocalAddr() net.Addr { return c.transport.LocalAddr() }
+
+// RemoteAddr returns the remote network address.
+func (c *ClientConn) RemoteAddr() net.Addr { return c.transport.RemoteAddr() }
+
 // handshake performs the client side key exchange. See RFC 4253 Section 7.
 // handshake performs the client side key exchange. See RFC 4253 Section 7.
 func (c *ClientConn) handshake() error {
 func (c *ClientConn) handshake() error {
 	clientVersion := []byte(packageVersion)
 	clientVersion := []byte(packageVersion)
@@ -78,10 +87,10 @@ func (c *ClientConn) handshake() error {
 		CompressionServerClient: supportedCompressions,
 		CompressionServerClient: supportedCompressions,
 	}
 	}
 	kexInitPacket := marshal(msgKexInit, clientKexInit)
 	kexInitPacket := marshal(msgKexInit, clientKexInit)
-	if err := c.writePacket(kexInitPacket); err != nil {
+	if err := c.transport.writePacket(kexInitPacket); err != nil {
 		return err
 		return err
 	}
 	}
-	packet, err := c.readPacket()
+	packet, err := c.transport.readPacket()
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -99,7 +108,7 @@ func (c *ClientConn) handshake() error {
 	if serverKexInit.FirstKexFollows && algs.kex != serverKexInit.KexAlgos[0] {
 	if serverKexInit.FirstKexFollows && algs.kex != serverKexInit.KexAlgos[0] {
 		// The server sent a Kex message for the wrong algorithm,
 		// The server sent a Kex message for the wrong algorithm,
 		// which we have to ignore.
 		// which we have to ignore.
-		if _, err := c.readPacket(); err != nil {
+		if _, err := c.transport.readPacket(); err != nil {
 			return err
 			return err
 		}
 		}
 	}
 	}
@@ -115,7 +124,7 @@ func (c *ClientConn) handshake() error {
 		clientKexInit: kexInitPacket,
 		clientKexInit: kexInitPacket,
 		serverKexInit: packet,
 		serverKexInit: packet,
 	}
 	}
-	result, err := kex.Client(c, c.config.rand(), &magics)
+	result, err := kex.Client(c.transport, c.config.rand(), &magics)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -126,7 +135,7 @@ func (c *ClientConn) handshake() error {
 	}
 	}
 
 
 	if checker := c.config.HostKeyChecker; checker != nil {
 	if checker := c.config.HostKeyChecker; checker != nil {
-		err = checker.Check(c.dialAddress, c.RemoteAddr(), algs.hostKey, result.HostKey)
+		err = checker.Check(c.dialAddress, c.transport.RemoteAddr(), algs.hostKey, result.HostKey)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
@@ -134,10 +143,10 @@ func (c *ClientConn) handshake() error {
 
 
 	c.transport.prepareKeyChange(algs, result)
 	c.transport.prepareKeyChange(algs, result)
 
 
-	if err = c.writePacket([]byte{msgNewKeys}); err != nil {
+	if err = c.transport.writePacket([]byte{msgNewKeys}); err != nil {
 		return err
 		return err
 	}
 	}
-	if packet, err = c.readPacket(); err != nil {
+	if packet, err = c.transport.readPacket(); err != nil {
 		return err
 		return err
 	}
 	}
 	if packet[0] != msgNewKeys {
 	if packet[0] != msgNewKeys {
@@ -171,13 +180,13 @@ func verifyHostKeySignature(hostKeyAlgo string, hostKeyBytes []byte, data []byte
 // to their respective ClientChans.
 // to their respective ClientChans.
 func (c *ClientConn) mainLoop() {
 func (c *ClientConn) mainLoop() {
 	defer func() {
 	defer func() {
-		c.Close()
+		c.transport.Close()
 		c.chanList.closeAll()
 		c.chanList.closeAll()
 		c.forwardList.closeAll()
 		c.forwardList.closeAll()
 	}()
 	}()
 
 
 	for {
 	for {
-		packet, err := c.readPacket()
+		packet, err := c.transport.readPacket()
 		if err != nil {
 		if err != nil {
 			break
 			break
 		}
 		}
@@ -298,7 +307,7 @@ func (c *ClientConn) mainLoop() {
 				// This handles keepalive messages and matches
 				// This handles keepalive messages and matches
 				// the behaviour of OpenSSH.
 				// the behaviour of OpenSSH.
 				if msg.WantReply {
 				if msg.WantReply {
-					c.writePacket(marshal(msgRequestFailure, globalRequestFailureMsg{}))
+					c.transport.writePacket(marshal(msgRequestFailure, globalRequestFailureMsg{}))
 				}
 				}
 			case *globalRequestSuccessMsg, *globalRequestFailureMsg:
 			case *globalRequestSuccessMsg, *globalRequestFailureMsg:
 				c.globalRequest.response <- msg
 				c.globalRequest.response <- msg
@@ -355,7 +364,7 @@ func (c *ClientConn) handleChanOpen(msg *channelOpenMsg) {
 			MaxPacketSize: 1 << 15,
 			MaxPacketSize: 1 << 15,
 		}
 		}
 
 
-		c.writePacket(marshal(msgChannelOpenConfirm, m))
+		c.transport.writePacket(marshal(msgChannelOpenConfirm, m))
 		l <- forward{ch, raddr}
 		l <- forward{ch, raddr}
 	default:
 	default:
 		// unknown channel type
 		// unknown channel type
@@ -365,7 +374,7 @@ func (c *ClientConn) handleChanOpen(msg *channelOpenMsg) {
 			Message:  fmt.Sprintf("unknown channel type: %v", msg.ChanType),
 			Message:  fmt.Sprintf("unknown channel type: %v", msg.ChanType),
 			Language: "en_US.UTF-8",
 			Language: "en_US.UTF-8",
 		}
 		}
-		c.writePacket(marshal(msgChannelOpenFailure, m))
+		c.transport.writePacket(marshal(msgChannelOpenFailure, m))
 	}
 	}
 }
 }
 
 
@@ -375,7 +384,7 @@ func (c *ClientConn) handleChanOpen(msg *channelOpenMsg) {
 func (c *ClientConn) sendGlobalRequest(m interface{}) (*globalRequestSuccessMsg, error) {
 func (c *ClientConn) sendGlobalRequest(m interface{}) (*globalRequestSuccessMsg, error) {
 	c.globalRequest.Lock()
 	c.globalRequest.Lock()
 	defer c.globalRequest.Unlock()
 	defer c.globalRequest.Unlock()
-	if err := c.writePacket(marshal(msgGlobalRequest, m)); err != nil {
+	if err := c.transport.writePacket(marshal(msgGlobalRequest, m)); err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
 	r := <-c.globalRequest.response
 	r := <-c.globalRequest.response
@@ -394,7 +403,7 @@ func (c *ClientConn) sendConnectionFailed(remoteId uint32) error {
 		Message:  "invalid request",
 		Message:  "invalid request",
 		Language: "en_US.UTF-8",
 		Language: "en_US.UTF-8",
 	}
 	}
-	return c.writePacket(marshal(msgChannelOpenFailure, m))
+	return c.transport.writePacket(marshal(msgChannelOpenFailure, m))
 }
 }
 
 
 // parseTCPAddr parses the originating address from the remote into a *net.TCPAddr.
 // parseTCPAddr parses the originating address from the remote into a *net.TCPAddr.

+ 2 - 2
ssh/client_auth.go

@@ -14,10 +14,10 @@ import (
 // authenticate authenticates with the remote server. See RFC 4252.
 // authenticate authenticates with the remote server. See RFC 4252.
 func (c *ClientConn) authenticate(session []byte) error {
 func (c *ClientConn) authenticate(session []byte) error {
 	// initiate user auth session
 	// initiate user auth session
-	if err := c.writePacket(marshal(msgServiceRequest, serviceRequestMsg{serviceUserAuth})); err != nil {
+	if err := c.transport.writePacket(marshal(msgServiceRequest, serviceRequestMsg{serviceUserAuth})); err != nil {
 		return err
 		return err
 	}
 	}
-	packet, err := c.readPacket()
+	packet, err := c.transport.readPacket()
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}

+ 31 - 0
ssh/common_test.go

@@ -5,6 +5,8 @@
 package ssh
 package ssh
 
 
 import (
 import (
+	"io"
+	"net"
 	"testing"
 	"testing"
 )
 )
 
 
@@ -24,3 +26,32 @@ func TestSafeString(t *testing.T) {
 		}
 		}
 	}
 	}
 }
 }
+
+// Make sure Read/Write are not exposed.
+func TestConnHideRWMethods(t *testing.T) {
+	for _, c := range []interface{}{new(ServerConn), new(ClientConn)} {
+		if _, ok := c.(io.Reader); ok {
+			t.Errorf("%T implements io.Reader", c)
+		}
+		if _, ok := c.(io.Writer); ok {
+			t.Errorf("%T implements io.Writer", c)
+		}
+	}
+}
+
+func TestConnSupportsLocalRemoteMethods(t *testing.T) {
+	type LocalAddr interface {
+		LocalAddr() net.Addr
+	}
+	type RemoteAddr interface {
+		RemoteAddr() net.Addr
+	}
+	for _, c := range []interface{}{new(ServerConn), new(ClientConn)} {
+		if _, ok := c.(LocalAddr); !ok {
+			t.Errorf("%T does not implement LocalAddr", c)
+		}
+		if _, ok := c.(RemoteAddr); !ok {
+			t.Errorf("%T does not implement RemoteAddr", c)
+		}
+	}
+}

+ 28 - 19
ssh/server.go

@@ -97,8 +97,8 @@ const maxCachedPubKeys = 16
 
 
 // A ServerConn represents an incoming connection.
 // A ServerConn represents an incoming connection.
 type ServerConn struct {
 type ServerConn struct {
-	*transport
-	config *ServerConfig
+	transport *transport
+	config    *ServerConfig
 
 
 	channels   map[uint32]*serverChan
 	channels   map[uint32]*serverChan
 	nextChanId uint32
 	nextChanId uint32
@@ -147,6 +147,15 @@ func signAndMarshal(k Signer, rand io.Reader, data []byte) ([]byte, error) {
 	return serializeSignature(k.PublicKey().PrivateKeyAlgo(), sig), nil
 	return serializeSignature(k.PublicKey().PrivateKeyAlgo(), sig), nil
 }
 }
 
 
+// Close closes the connection.
+func (s *ServerConn) Close() error { return s.transport.Close() }
+
+// LocalAddr returns the local network address.
+func (c *ServerConn) LocalAddr() net.Addr { return c.transport.LocalAddr() }
+
+// RemoteAddr returns the remote network address.
+func (c *ServerConn) RemoteAddr() net.Addr { return c.transport.RemoteAddr() }
+
 // Handshake performs an SSH transport and client authentication on the given ServerConn.
 // Handshake performs an SSH transport and client authentication on the given ServerConn.
 func (s *ServerConn) Handshake() error {
 func (s *ServerConn) Handshake() error {
 	var err error
 	var err error
@@ -160,7 +169,7 @@ func (s *ServerConn) Handshake() error {
 	}
 	}
 
 
 	var packet []byte
 	var packet []byte
-	if packet, err = s.readPacket(); err != nil {
+	if packet, err = s.transport.readPacket(); err != nil {
 		return err
 		return err
 	}
 	}
 	var serviceRequest serviceRequestMsg
 	var serviceRequest serviceRequestMsg
@@ -173,7 +182,7 @@ func (s *ServerConn) Handshake() error {
 	serviceAccept := serviceAcceptMsg{
 	serviceAccept := serviceAcceptMsg{
 		Service: serviceUserAuth,
 		Service: serviceUserAuth,
 	}
 	}
-	if err := s.writePacket(marshal(msgServiceAccept, serviceAccept)); err != nil {
+	if err := s.transport.writePacket(marshal(msgServiceAccept, serviceAccept)); err != nil {
 		return err
 		return err
 	}
 	}
 
 
@@ -199,13 +208,13 @@ func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexIni
 	}
 	}
 
 
 	serverKexInitPacket := marshal(msgKexInit, serverKexInit)
 	serverKexInitPacket := marshal(msgKexInit, serverKexInit)
-	if err = s.writePacket(serverKexInitPacket); err != nil {
+	if err = s.transport.writePacket(serverKexInitPacket); err != nil {
 		return
 		return
 	}
 	}
 
 
 	if clientKexInitPacket == nil {
 	if clientKexInitPacket == nil {
 		clientKexInit = new(kexInitMsg)
 		clientKexInit = new(kexInitMsg)
-		if clientKexInitPacket, err = s.readPacket(); err != nil {
+		if clientKexInitPacket, err = s.transport.readPacket(); err != nil {
 			return
 			return
 		}
 		}
 		if err = unmarshal(clientKexInit, clientKexInitPacket, msgKexInit); err != nil {
 		if err = unmarshal(clientKexInit, clientKexInitPacket, msgKexInit); err != nil {
@@ -221,7 +230,7 @@ func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexIni
 	if clientKexInit.FirstKexFollows && algs.kex != clientKexInit.KexAlgos[0] {
 	if clientKexInit.FirstKexFollows && algs.kex != clientKexInit.KexAlgos[0] {
 		// The client sent a Kex message for the wrong algorithm,
 		// The client sent a Kex message for the wrong algorithm,
 		// which we have to ignore.
 		// which we have to ignore.
-		if _, err = s.readPacket(); err != nil {
+		if _, err = s.transport.readPacket(); err != nil {
 			return
 			return
 		}
 		}
 	}
 	}
@@ -244,7 +253,7 @@ func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexIni
 		serverKexInit: marshal(msgKexInit, serverKexInit),
 		serverKexInit: marshal(msgKexInit, serverKexInit),
 		clientKexInit: clientKexInitPacket,
 		clientKexInit: clientKexInitPacket,
 	}
 	}
-	result, err := kex.Server(s, s.config.rand(), &magics, hostKey)
+	result, err := kex.Server(s.transport, s.config.rand(), &magics, hostKey)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -253,10 +262,10 @@ func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexIni
 		return err
 		return err
 	}
 	}
 
 
-	if err = s.writePacket([]byte{msgNewKeys}); err != nil {
+	if err = s.transport.writePacket([]byte{msgNewKeys}); err != nil {
 		return
 		return
 	}
 	}
-	if packet, err := s.readPacket(); err != nil {
+	if packet, err := s.transport.readPacket(); err != nil {
 		return err
 		return err
 	} else if packet[0] != msgNewKeys {
 	} else if packet[0] != msgNewKeys {
 		return UnexpectedMessageError{msgNewKeys, packet[0]}
 		return UnexpectedMessageError{msgNewKeys, packet[0]}
@@ -308,7 +317,7 @@ func (s *ServerConn) authenticate(H []byte) error {
 
 
 userAuthLoop:
 userAuthLoop:
 	for {
 	for {
-		if packet, err = s.readPacket(); err != nil {
+		if packet, err = s.transport.readPacket(); err != nil {
 			return err
 			return err
 		}
 		}
 		if err = unmarshal(&userAuthReq, packet, msgUserAuthRequest); err != nil {
 		if err = unmarshal(&userAuthReq, packet, msgUserAuthRequest); err != nil {
@@ -382,7 +391,7 @@ userAuthLoop:
 						Algo:   algo,
 						Algo:   algo,
 						PubKey: string(pubKey),
 						PubKey: string(pubKey),
 					}
 					}
-					if err = s.writePacket(marshal(msgUserAuthPubKeyOk, okMsg)); err != nil {
+					if err = s.transport.writePacket(marshal(msgUserAuthPubKeyOk, okMsg)); err != nil {
 						return err
 						return err
 					}
 					}
 					continue userAuthLoop
 					continue userAuthLoop
@@ -432,13 +441,13 @@ userAuthLoop:
 			return errors.New("ssh: no authentication methods configured but NoClientAuth is also false")
 			return errors.New("ssh: no authentication methods configured but NoClientAuth is also false")
 		}
 		}
 
 
-		if err = s.writePacket(marshal(msgUserAuthFailure, failureMsg)); err != nil {
+		if err = s.transport.writePacket(marshal(msgUserAuthFailure, failureMsg)); err != nil {
 			return err
 			return err
 		}
 		}
 	}
 	}
 
 
 	packet = []byte{msgUserAuthSuccess}
 	packet = []byte{msgUserAuthSuccess}
-	if err = s.writePacket(packet); err != nil {
+	if err = s.transport.writePacket(packet); err != nil {
 		return err
 		return err
 	}
 	}
 
 
@@ -462,7 +471,7 @@ func (c *sshClientKeyboardInteractive) Challenge(user, instruction string, quest
 		prompts = appendBool(prompts, echos[i])
 		prompts = appendBool(prompts, echos[i])
 	}
 	}
 
 
-	if err := c.writePacket(marshal(msgUserAuthInfoRequest, userAuthInfoRequestMsg{
+	if err := c.transport.writePacket(marshal(msgUserAuthInfoRequest, userAuthInfoRequestMsg{
 		Instruction: instruction,
 		Instruction: instruction,
 		NumPrompts:  uint32(len(questions)),
 		NumPrompts:  uint32(len(questions)),
 		Prompts:     prompts,
 		Prompts:     prompts,
@@ -470,7 +479,7 @@ func (c *sshClientKeyboardInteractive) Challenge(user, instruction string, quest
 		return nil, err
 		return nil, err
 	}
 	}
 
 
-	packet, err := c.readPacket()
+	packet, err := c.transport.readPacket()
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
@@ -511,7 +520,7 @@ func (s *ServerConn) Accept() (Channel, error) {
 	}
 	}
 
 
 	for {
 	for {
-		packet, err := s.readPacket()
+		packet, err := s.transport.readPacket()
 		if err != nil {
 		if err != nil {
 
 
 			s.lock.Lock()
 			s.lock.Lock()
@@ -557,7 +566,7 @@ func (s *ServerConn) Accept() (Channel, error) {
 				}
 				}
 				c := &serverChan{
 				c := &serverChan{
 					channel: channel{
 					channel: channel{
-						packetConn: s,
+						packetConn: s.transport,
 						remoteId:   msg.PeersId,
 						remoteId:   msg.PeersId,
 						remoteWin:  window{Cond: newCond()},
 						remoteWin:  window{Cond: newCond()},
 						maxPacket:  msg.MaxPacketSize,
 						maxPacket:  msg.MaxPacketSize,
@@ -619,7 +628,7 @@ func (s *ServerConn) Accept() (Channel, error) {
 
 
 			case *globalRequestMsg:
 			case *globalRequestMsg:
 				if msg.WantReply {
 				if msg.WantReply {
-					if err := s.writePacket([]byte{msgRequestFailure}); err != nil {
+					if err := s.transport.writePacket([]byte{msgRequestFailure}); err != nil {
 						return nil, err
 						return nil, err
 					}
 					}
 				}
 				}

+ 1 - 1
ssh/session.go

@@ -564,7 +564,7 @@ func (s *Session) StderrPipe() (io.Reader, error) {
 // NewSession returns a new interactive session on the remote host.
 // NewSession returns a new interactive session on the remote host.
 func (c *ClientConn) NewSession() (*Session, error) {
 func (c *ClientConn) NewSession() (*Session, error) {
 	ch := c.newChan(c.transport)
 	ch := c.newChan(c.transport)
-	if err := c.writePacket(marshal(msgChannelOpen, channelOpenMsg{
+	if err := c.transport.writePacket(marshal(msgChannelOpen, channelOpenMsg{
 		ChanType:      "session",
 		ChanType:      "session",
 		PeersId:       ch.localId,
 		PeersId:       ch.localId,
 		PeersWindow:   1 << 14,
 		PeersWindow:   1 << 14,

+ 1 - 1
ssh/tcpip.go

@@ -296,7 +296,7 @@ type channelOpenDirectMsg struct {
 // strings and are expected to be resolvable at the remote end.
 // strings and are expected to be resolvable at the remote end.
 func (c *ClientConn) dial(laddr string, lport int, raddr string, rport int) (*tcpChan, error) {
 func (c *ClientConn) dial(laddr string, lport int, raddr string, rport int) (*tcpChan, error) {
 	ch := c.newChan(c.transport)
 	ch := c.newChan(c.transport)
-	if err := c.writePacket(marshal(msgChannelOpen, channelOpenDirectMsg{
+	if err := c.transport.writePacket(marshal(msgChannelOpen, channelOpenDirectMsg{
 		ChanType:      "direct-tcpip",
 		ChanType:      "direct-tcpip",
 		PeersId:       ch.localId,
 		PeersId:       ch.localId,
 		PeersWindow:   1 << 14,
 		PeersWindow:   1 << 14,