Преглед на файлове

go.crypto/ssh: ensure {Server,Client}Conn do not expose io.ReadWriter

Transport should not be a ReadWriter. It can only write packets, i.e. no partial reads or writes. Furthermore, you can currently do ClientConn.Write() while the connection is live, which sends raw bytes over the connection. Doing so will confuse the transports because the data is not encrypted.

As a consequence, ClientConn and ServerConn stop being a net.Conn

Finally, ensure that {Server,Client}Conn implement LocalAddr and RemoteAddr methods that previously were exposed by an embedded net.Conn field.

R=hanwen
CC=golang-dev
https://golang.org/cl/16610043
Dave Cheney преди 12 години
родител
ревизия
c0d640c887
променени са 6 файла, в които са добавени 88 реда и са изтрити 39 реда
  1. 25 16
      ssh/client.go
  2. 2 2
      ssh/client_auth.go
  3. 31 0
      ssh/common_test.go
  4. 28 19
      ssh/server.go
  5. 1 1
      ssh/session.go
  6. 1 1
      ssh/tcpip.go

+ 25 - 16
ssh/client.go

@@ -16,7 +16,7 @@ import (
 
 // ClientConn represents the client side of an SSH connection.
 type ClientConn struct {
-	*transport
+	transport   *transport
 	config      *ClientConfig
 	chanList    // channels associated with this connection
 	forwardList // forwarded tcpip connections from the remote side
@@ -47,13 +47,22 @@ func clientWithAddress(c net.Conn, addr string, config *ClientConfig) (*ClientCo
 	}
 
 	if err := conn.handshake(); err != nil {
-		conn.Close()
+		conn.transport.Close()
 		return nil, fmt.Errorf("handshake failed: %v", err)
 	}
 	go conn.mainLoop()
 	return conn, nil
 }
 
+// Close closes the connection.
+func (c *ClientConn) Close() error { return c.transport.Close() }
+
+// LocalAddr returns the local network address.
+func (c *ClientConn) LocalAddr() net.Addr { return c.transport.LocalAddr() }
+
+// RemoteAddr returns the remote network address.
+func (c *ClientConn) RemoteAddr() net.Addr { return c.transport.RemoteAddr() }
+
 // handshake performs the client side key exchange. See RFC 4253 Section 7.
 func (c *ClientConn) handshake() error {
 	clientVersion := []byte(packageVersion)
@@ -78,10 +87,10 @@ func (c *ClientConn) handshake() error {
 		CompressionServerClient: supportedCompressions,
 	}
 	kexInitPacket := marshal(msgKexInit, clientKexInit)
-	if err := c.writePacket(kexInitPacket); err != nil {
+	if err := c.transport.writePacket(kexInitPacket); err != nil {
 		return err
 	}
-	packet, err := c.readPacket()
+	packet, err := c.transport.readPacket()
 	if err != nil {
 		return err
 	}
@@ -99,7 +108,7 @@ func (c *ClientConn) handshake() error {
 	if serverKexInit.FirstKexFollows && algs.kex != serverKexInit.KexAlgos[0] {
 		// The server sent a Kex message for the wrong algorithm,
 		// which we have to ignore.
-		if _, err := c.readPacket(); err != nil {
+		if _, err := c.transport.readPacket(); err != nil {
 			return err
 		}
 	}
@@ -115,7 +124,7 @@ func (c *ClientConn) handshake() error {
 		clientKexInit: kexInitPacket,
 		serverKexInit: packet,
 	}
-	result, err := kex.Client(c, c.config.rand(), &magics)
+	result, err := kex.Client(c.transport, c.config.rand(), &magics)
 	if err != nil {
 		return err
 	}
@@ -126,7 +135,7 @@ func (c *ClientConn) handshake() error {
 	}
 
 	if checker := c.config.HostKeyChecker; checker != nil {
-		err = checker.Check(c.dialAddress, c.RemoteAddr(), algs.hostKey, result.HostKey)
+		err = checker.Check(c.dialAddress, c.transport.RemoteAddr(), algs.hostKey, result.HostKey)
 		if err != nil {
 			return err
 		}
@@ -134,10 +143,10 @@ func (c *ClientConn) handshake() error {
 
 	c.transport.prepareKeyChange(algs, result)
 
-	if err = c.writePacket([]byte{msgNewKeys}); err != nil {
+	if err = c.transport.writePacket([]byte{msgNewKeys}); err != nil {
 		return err
 	}
-	if packet, err = c.readPacket(); err != nil {
+	if packet, err = c.transport.readPacket(); err != nil {
 		return err
 	}
 	if packet[0] != msgNewKeys {
@@ -171,13 +180,13 @@ func verifyHostKeySignature(hostKeyAlgo string, hostKeyBytes []byte, data []byte
 // to their respective ClientChans.
 func (c *ClientConn) mainLoop() {
 	defer func() {
-		c.Close()
+		c.transport.Close()
 		c.chanList.closeAll()
 		c.forwardList.closeAll()
 	}()
 
 	for {
-		packet, err := c.readPacket()
+		packet, err := c.transport.readPacket()
 		if err != nil {
 			break
 		}
@@ -298,7 +307,7 @@ func (c *ClientConn) mainLoop() {
 				// This handles keepalive messages and matches
 				// the behaviour of OpenSSH.
 				if msg.WantReply {
-					c.writePacket(marshal(msgRequestFailure, globalRequestFailureMsg{}))
+					c.transport.writePacket(marshal(msgRequestFailure, globalRequestFailureMsg{}))
 				}
 			case *globalRequestSuccessMsg, *globalRequestFailureMsg:
 				c.globalRequest.response <- msg
@@ -355,7 +364,7 @@ func (c *ClientConn) handleChanOpen(msg *channelOpenMsg) {
 			MaxPacketSize: 1 << 15,
 		}
 
-		c.writePacket(marshal(msgChannelOpenConfirm, m))
+		c.transport.writePacket(marshal(msgChannelOpenConfirm, m))
 		l <- forward{ch, raddr}
 	default:
 		// unknown channel type
@@ -365,7 +374,7 @@ func (c *ClientConn) handleChanOpen(msg *channelOpenMsg) {
 			Message:  fmt.Sprintf("unknown channel type: %v", msg.ChanType),
 			Language: "en_US.UTF-8",
 		}
-		c.writePacket(marshal(msgChannelOpenFailure, m))
+		c.transport.writePacket(marshal(msgChannelOpenFailure, m))
 	}
 }
 
@@ -375,7 +384,7 @@ func (c *ClientConn) handleChanOpen(msg *channelOpenMsg) {
 func (c *ClientConn) sendGlobalRequest(m interface{}) (*globalRequestSuccessMsg, error) {
 	c.globalRequest.Lock()
 	defer c.globalRequest.Unlock()
-	if err := c.writePacket(marshal(msgGlobalRequest, m)); err != nil {
+	if err := c.transport.writePacket(marshal(msgGlobalRequest, m)); err != nil {
 		return nil, err
 	}
 	r := <-c.globalRequest.response
@@ -394,7 +403,7 @@ func (c *ClientConn) sendConnectionFailed(remoteId uint32) error {
 		Message:  "invalid request",
 		Language: "en_US.UTF-8",
 	}
-	return c.writePacket(marshal(msgChannelOpenFailure, m))
+	return c.transport.writePacket(marshal(msgChannelOpenFailure, m))
 }
 
 // parseTCPAddr parses the originating address from the remote into a *net.TCPAddr.

+ 2 - 2
ssh/client_auth.go

@@ -14,10 +14,10 @@ import (
 // authenticate authenticates with the remote server. See RFC 4252.
 func (c *ClientConn) authenticate(session []byte) error {
 	// initiate user auth session
-	if err := c.writePacket(marshal(msgServiceRequest, serviceRequestMsg{serviceUserAuth})); err != nil {
+	if err := c.transport.writePacket(marshal(msgServiceRequest, serviceRequestMsg{serviceUserAuth})); err != nil {
 		return err
 	}
-	packet, err := c.readPacket()
+	packet, err := c.transport.readPacket()
 	if err != nil {
 		return err
 	}

+ 31 - 0
ssh/common_test.go

@@ -5,6 +5,8 @@
 package ssh
 
 import (
+	"io"
+	"net"
 	"testing"
 )
 
@@ -24,3 +26,32 @@ func TestSafeString(t *testing.T) {
 		}
 	}
 }
+
+// Make sure Read/Write are not exposed.
+func TestConnHideRWMethods(t *testing.T) {
+	for _, c := range []interface{}{new(ServerConn), new(ClientConn)} {
+		if _, ok := c.(io.Reader); ok {
+			t.Errorf("%T implements io.Reader", c)
+		}
+		if _, ok := c.(io.Writer); ok {
+			t.Errorf("%T implements io.Writer", c)
+		}
+	}
+}
+
+func TestConnSupportsLocalRemoteMethods(t *testing.T) {
+	type LocalAddr interface {
+		LocalAddr() net.Addr
+	}
+	type RemoteAddr interface {
+		RemoteAddr() net.Addr
+	}
+	for _, c := range []interface{}{new(ServerConn), new(ClientConn)} {
+		if _, ok := c.(LocalAddr); !ok {
+			t.Errorf("%T does not implement LocalAddr", c)
+		}
+		if _, ok := c.(RemoteAddr); !ok {
+			t.Errorf("%T does not implement RemoteAddr", c)
+		}
+	}
+}

+ 28 - 19
ssh/server.go

@@ -97,8 +97,8 @@ const maxCachedPubKeys = 16
 
 // A ServerConn represents an incoming connection.
 type ServerConn struct {
-	*transport
-	config *ServerConfig
+	transport *transport
+	config    *ServerConfig
 
 	channels   map[uint32]*serverChan
 	nextChanId uint32
@@ -147,6 +147,15 @@ func signAndMarshal(k Signer, rand io.Reader, data []byte) ([]byte, error) {
 	return serializeSignature(k.PublicKey().PrivateKeyAlgo(), sig), nil
 }
 
+// Close closes the connection.
+func (s *ServerConn) Close() error { return s.transport.Close() }
+
+// LocalAddr returns the local network address.
+func (c *ServerConn) LocalAddr() net.Addr { return c.transport.LocalAddr() }
+
+// RemoteAddr returns the remote network address.
+func (c *ServerConn) RemoteAddr() net.Addr { return c.transport.RemoteAddr() }
+
 // Handshake performs an SSH transport and client authentication on the given ServerConn.
 func (s *ServerConn) Handshake() error {
 	var err error
@@ -160,7 +169,7 @@ func (s *ServerConn) Handshake() error {
 	}
 
 	var packet []byte
-	if packet, err = s.readPacket(); err != nil {
+	if packet, err = s.transport.readPacket(); err != nil {
 		return err
 	}
 	var serviceRequest serviceRequestMsg
@@ -173,7 +182,7 @@ func (s *ServerConn) Handshake() error {
 	serviceAccept := serviceAcceptMsg{
 		Service: serviceUserAuth,
 	}
-	if err := s.writePacket(marshal(msgServiceAccept, serviceAccept)); err != nil {
+	if err := s.transport.writePacket(marshal(msgServiceAccept, serviceAccept)); err != nil {
 		return err
 	}
 
@@ -199,13 +208,13 @@ func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexIni
 	}
 
 	serverKexInitPacket := marshal(msgKexInit, serverKexInit)
-	if err = s.writePacket(serverKexInitPacket); err != nil {
+	if err = s.transport.writePacket(serverKexInitPacket); err != nil {
 		return
 	}
 
 	if clientKexInitPacket == nil {
 		clientKexInit = new(kexInitMsg)
-		if clientKexInitPacket, err = s.readPacket(); err != nil {
+		if clientKexInitPacket, err = s.transport.readPacket(); err != nil {
 			return
 		}
 		if err = unmarshal(clientKexInit, clientKexInitPacket, msgKexInit); err != nil {
@@ -221,7 +230,7 @@ func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexIni
 	if clientKexInit.FirstKexFollows && algs.kex != clientKexInit.KexAlgos[0] {
 		// The client sent a Kex message for the wrong algorithm,
 		// which we have to ignore.
-		if _, err = s.readPacket(); err != nil {
+		if _, err = s.transport.readPacket(); err != nil {
 			return
 		}
 	}
@@ -244,7 +253,7 @@ func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexIni
 		serverKexInit: marshal(msgKexInit, serverKexInit),
 		clientKexInit: clientKexInitPacket,
 	}
-	result, err := kex.Server(s, s.config.rand(), &magics, hostKey)
+	result, err := kex.Server(s.transport, s.config.rand(), &magics, hostKey)
 	if err != nil {
 		return err
 	}
@@ -253,10 +262,10 @@ func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexIni
 		return err
 	}
 
-	if err = s.writePacket([]byte{msgNewKeys}); err != nil {
+	if err = s.transport.writePacket([]byte{msgNewKeys}); err != nil {
 		return
 	}
-	if packet, err := s.readPacket(); err != nil {
+	if packet, err := s.transport.readPacket(); err != nil {
 		return err
 	} else if packet[0] != msgNewKeys {
 		return UnexpectedMessageError{msgNewKeys, packet[0]}
@@ -308,7 +317,7 @@ func (s *ServerConn) authenticate(H []byte) error {
 
 userAuthLoop:
 	for {
-		if packet, err = s.readPacket(); err != nil {
+		if packet, err = s.transport.readPacket(); err != nil {
 			return err
 		}
 		if err = unmarshal(&userAuthReq, packet, msgUserAuthRequest); err != nil {
@@ -382,7 +391,7 @@ userAuthLoop:
 						Algo:   algo,
 						PubKey: string(pubKey),
 					}
-					if err = s.writePacket(marshal(msgUserAuthPubKeyOk, okMsg)); err != nil {
+					if err = s.transport.writePacket(marshal(msgUserAuthPubKeyOk, okMsg)); err != nil {
 						return err
 					}
 					continue userAuthLoop
@@ -432,13 +441,13 @@ userAuthLoop:
 			return errors.New("ssh: no authentication methods configured but NoClientAuth is also false")
 		}
 
-		if err = s.writePacket(marshal(msgUserAuthFailure, failureMsg)); err != nil {
+		if err = s.transport.writePacket(marshal(msgUserAuthFailure, failureMsg)); err != nil {
 			return err
 		}
 	}
 
 	packet = []byte{msgUserAuthSuccess}
-	if err = s.writePacket(packet); err != nil {
+	if err = s.transport.writePacket(packet); err != nil {
 		return err
 	}
 
@@ -462,7 +471,7 @@ func (c *sshClientKeyboardInteractive) Challenge(user, instruction string, quest
 		prompts = appendBool(prompts, echos[i])
 	}
 
-	if err := c.writePacket(marshal(msgUserAuthInfoRequest, userAuthInfoRequestMsg{
+	if err := c.transport.writePacket(marshal(msgUserAuthInfoRequest, userAuthInfoRequestMsg{
 		Instruction: instruction,
 		NumPrompts:  uint32(len(questions)),
 		Prompts:     prompts,
@@ -470,7 +479,7 @@ func (c *sshClientKeyboardInteractive) Challenge(user, instruction string, quest
 		return nil, err
 	}
 
-	packet, err := c.readPacket()
+	packet, err := c.transport.readPacket()
 	if err != nil {
 		return nil, err
 	}
@@ -511,7 +520,7 @@ func (s *ServerConn) Accept() (Channel, error) {
 	}
 
 	for {
-		packet, err := s.readPacket()
+		packet, err := s.transport.readPacket()
 		if err != nil {
 
 			s.lock.Lock()
@@ -557,7 +566,7 @@ func (s *ServerConn) Accept() (Channel, error) {
 				}
 				c := &serverChan{
 					channel: channel{
-						packetConn: s,
+						packetConn: s.transport,
 						remoteId:   msg.PeersId,
 						remoteWin:  window{Cond: newCond()},
 						maxPacket:  msg.MaxPacketSize,
@@ -619,7 +628,7 @@ func (s *ServerConn) Accept() (Channel, error) {
 
 			case *globalRequestMsg:
 				if msg.WantReply {
-					if err := s.writePacket([]byte{msgRequestFailure}); err != nil {
+					if err := s.transport.writePacket([]byte{msgRequestFailure}); err != nil {
 						return nil, err
 					}
 				}

+ 1 - 1
ssh/session.go

@@ -564,7 +564,7 @@ func (s *Session) StderrPipe() (io.Reader, error) {
 // NewSession returns a new interactive session on the remote host.
 func (c *ClientConn) NewSession() (*Session, error) {
 	ch := c.newChan(c.transport)
-	if err := c.writePacket(marshal(msgChannelOpen, channelOpenMsg{
+	if err := c.transport.writePacket(marshal(msgChannelOpen, channelOpenMsg{
 		ChanType:      "session",
 		PeersId:       ch.localId,
 		PeersWindow:   1 << 14,

+ 1 - 1
ssh/tcpip.go

@@ -296,7 +296,7 @@ type channelOpenDirectMsg struct {
 // strings and are expected to be resolvable at the remote end.
 func (c *ClientConn) dial(laddr string, lport int, raddr string, rport int) (*tcpChan, error) {
 	ch := c.newChan(c.transport)
-	if err := c.writePacket(marshal(msgChannelOpen, channelOpenDirectMsg{
+	if err := c.transport.writePacket(marshal(msgChannelOpen, channelOpenDirectMsg{
 		ChanType:      "direct-tcpip",
 		PeersId:       ch.localId,
 		PeersWindow:   1 << 14,