Ver código fonte

go.crypto/ssh: put version exchange in function

R=golang-dev, dave, jpsugar, agl
CC=golang-dev
https://golang.org/cl/14641044
Han-Wen Nienhuys 12 anos atrás
pai
commit
f5f25bdad0
5 arquivos alterados com 107 adições e 77 exclusões
  1. 5 18
      ssh/client.go
  2. 1 1
      ssh/client_test.go
  3. 19 24
      ssh/server.go
  4. 38 8
      ssh/transport.go
  5. 44 26
      ssh/transport_test.go

+ 5 - 18
ssh/client.go

@@ -14,9 +14,6 @@ import (
 	"sync"
 )
 
-// clientVersion is the default identification string that the client will use.
-var clientVersion = []byte("SSH-2.0-Go")
-
 // ClientConn represents the client side of an SSH connection.
 type ClientConn struct {
 	*transport
@@ -59,22 +56,12 @@ func clientWithAddress(c net.Conn, addr string, config *ClientConfig) (*ClientCo
 
 // handshake performs the client side key exchange. See RFC 4253 Section 7.
 func (c *ClientConn) handshake() error {
-	var myVersion []byte
-	if len(c.config.ClientVersion) > 0 {
-		myVersion = []byte(c.config.ClientVersion)
-	} else {
-		myVersion = clientVersion
-	}
-
-	if _, err := c.Write(append(myVersion, '\r', '\n')); err != nil {
-		return err
-	}
-	if err := c.Flush(); err != nil {
-		return err
+	clientVersion := []byte(packageVersion)
+	if c.config.ClientVersion != "" {
+		clientVersion = []byte(c.config.ClientVersion)
 	}
 
-	// read remote server version
-	serverVersion, err := readVersion(c)
+	serverVersion, err := exchangeVersions(c.transport.Conn, clientVersion)
 	if err != nil {
 		return err
 	}
@@ -123,7 +110,7 @@ func (c *ClientConn) handshake() error {
 	}
 
 	magics := handshakeMagics{
-		clientVersion: myVersion,
+		clientVersion: clientVersion,
 		serverVersion: serverVersion,
 		clientKexInit: kexInitPacket,
 		serverKexInit: packet,

+ 1 - 1
ssh/client_test.go

@@ -30,5 +30,5 @@ func TestCustomClientVersion(t *testing.T) {
 }
 
 func TestDefaultClientVersion(t *testing.T) {
-	testClientVersion(t, &ClientConfig{}, string(clientVersion))
+	testClientVersion(t, &ClientConfig{}, packageVersion)
 }

+ 19 - 24
ssh/server.go

@@ -121,6 +121,9 @@ type ServerConn struct {
 	// ClientVersion is the client's version, populated after
 	// Handshake is called. It should not be modified.
 	ClientVersion []byte
+
+	// Our version.
+	serverVersion []byte
 }
 
 // Server returns a new SSH server connection
@@ -144,33 +147,25 @@ func signAndMarshal(k Signer, rand io.Reader, data []byte) ([]byte, error) {
 	return serializeSignature(k.PublicKey().PrivateKeyAlgo(), sig), nil
 }
 
-// serverVersion is the fixed identification string that Server will use.
-var serverVersion = []byte("SSH-2.0-Go\r\n")
-
 // Handshake performs an SSH transport and client authentication on the given ServerConn.
-func (s *ServerConn) Handshake() (err error) {
-	if _, err = s.Write(serverVersion); err != nil {
-		return
-	}
-	if err := s.Flush(); err != nil {
-		return err
-	}
-
-	s.ClientVersion, err = readVersion(s)
+func (s *ServerConn) Handshake() error {
+	var err error
+	s.serverVersion = []byte(packageVersion)
+	s.ClientVersion, err = exchangeVersions(s.transport.Conn, s.serverVersion)
 	if err != nil {
-		return
+		return err
 	}
-	if err = s.clientInitHandshake(nil, nil); err != nil {
-		return
+	if err := s.clientInitHandshake(nil, nil); err != nil {
+		return err
 	}
 
 	var packet []byte
 	if packet, err = s.readPacket(); err != nil {
-		return
+		return err
 	}
 	var serviceRequest serviceRequestMsg
-	if err = unmarshal(&serviceRequest, packet, msgServiceRequest); err != nil {
-		return
+	if err := unmarshal(&serviceRequest, packet, msgServiceRequest); err != nil {
+		return err
 	}
 	if serviceRequest.Service != serviceUserAuth {
 		return errors.New("ssh: requested service '" + serviceRequest.Service + "' before authenticating")
@@ -178,14 +173,14 @@ func (s *ServerConn) Handshake() (err error) {
 	serviceAccept := serviceAcceptMsg{
 		Service: serviceUserAuth,
 	}
-	if err = s.writePacket(marshal(msgServiceAccept, serviceAccept)); err != nil {
-		return
+	if err := s.writePacket(marshal(msgServiceAccept, serviceAccept)); err != nil {
+		return err
 	}
 
-	if err = s.authenticate(s.transport.sessionID); err != nil {
-		return
+	if err := s.authenticate(s.transport.sessionID); err != nil {
+		return err
 	}
-	return
+	return err
 }
 
 func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexInitPacket []byte) (err error) {
@@ -244,7 +239,7 @@ func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexIni
 	}
 
 	magics := handshakeMagics{
-		serverVersion: serverVersion[:len(serverVersion)-2],
+		serverVersion: s.serverVersion,
 		clientVersion: s.ClientVersion,
 		serverKexInit: marshal(msgKexInit, serverKexInit),
 		clientKexInit: clientKexInitPacket,

+ 38 - 8
ssh/transport.go

@@ -358,18 +358,41 @@ func generateKeyMaterial(out, tag []byte, K, H, sessionId []byte, h hash.Hash) {
 	}
 }
 
-// maxVersionStringBytes is the maximum number of bytes that we'll accept as a
-// version string. In the event that the client is talking a different protocol
-// we need to set a limit otherwise we will keep using more and more memory
-// while searching for the end of the version handshake.
-const maxVersionStringBytes = 1024
+const packageVersion = "SSH-2.0-Go"
+
+// Sends and receives a version line.  The versionLine string should
+// be US ASCII, start with "SSH-2.0-", and should not include a
+// newline. exchangeVersions returns the other side's version line.
+func exchangeVersions(rw io.ReadWriter, versionLine []byte) (them []byte, err error) {
+	// Contrary to the RFC, we do not ignore lines that don't
+	// start with "SSH-2.0-" to make the library usable with
+	// nonconforming servers.
+	for _, c := range versionLine {
+		// The spec disallows non US-ASCII chars, and
+		// specifically forbids null chars.
+		if c < 32 {
+			return nil, errors.New("ssh: junk character in version line")
+		}
+	}
+	if _, err = rw.Write(append(versionLine, '\r', '\n')); err != nil {
+		return
+	}
+
+	them, err = readVersion(rw)
+	return them, err
+}
+
+// maxVersionStringBytes is the maximum number of bytes that we'll
+// accept as a version string. RFC 4253 section 4.2 limits this at 255
+// chars
+const maxVersionStringBytes = 255
 
 // Read version string as specified by RFC 4253, section 4.2.
 func readVersion(r io.Reader) ([]byte, error) {
 	versionString := make([]byte, 0, 64)
 	var ok bool
 	var buf [1]byte
-forEachByte:
+
 	for len(versionString) < maxVersionStringBytes {
 		_, err := io.ReadFull(r, buf[:])
 		if err != nil {
@@ -379,13 +402,20 @@ forEachByte:
 		// but several SSH servers actually only send a \n.
 		if buf[0] == '\n' {
 			ok = true
-			break forEachByte
+			break
 		}
+
+		// non ASCII chars are disallowed, but we are lenient,
+		// since Go doesn't use null-terminated strings.
+
+		// The RFC allows a comment after a space, however,
+		// all of it (version and comments) goes into the
+		// session hash.
 		versionString = append(versionString, buf[0])
 	}
 
 	if !ok {
-		return nil, errors.New("ssh: failed to read version string")
+		return nil, errors.New("ssh: overflow reading version string")
 	}
 
 	// There might be a '\r' on the end which we should remove.

+ 44 - 26
ssh/transport_test.go

@@ -5,47 +5,65 @@
 package ssh
 
 import (
-	"bufio"
 	"bytes"
+	"strings"
 	"testing"
 )
 
 func TestReadVersion(t *testing.T) {
-	buf := serverVersion
-	result, err := readVersion(bufio.NewReader(bytes.NewBuffer(buf)))
-	if err != nil {
-		t.Errorf("readVersion didn't read version correctly: %s", err)
+	longversion := strings.Repeat("SSH-2.0-bla", 50)[:253]
+	cases := map[string]string{
+		"SSH-2.0-bla\r\n":    "SSH-2.0-bla",
+		"SSH-2.0-bla\n":      "SSH-2.0-bla",
+		longversion + "\r\n": longversion,
 	}
-	if !bytes.Equal(buf[:len(buf)-2], result) {
-		t.Error("version read did not match expected")
+
+	for in, want := range cases {
+		result, err := readVersion(bytes.NewBufferString(in))
+		if err != nil {
+			t.Errorf("readVersion(%q): %s", in, err)
+		}
+		got := string(result)
+		if got != want {
+			t.Errorf("got %q, want %q", got, want)
+		}
 	}
 }
 
-func TestReadVersionWithJustLF(t *testing.T) {
-	var buf []byte
-	buf = append(buf, serverVersion...)
-	buf = buf[:len(buf)-1]
-	buf[len(buf)-1] = '\n'
-	result, err := readVersion(bufio.NewReader(bytes.NewBuffer(buf)))
-	if err != nil {
-		t.Error("readVersion failed to handle just a \n")
+func TestReadVersionError(t *testing.T) {
+	longversion := strings.Repeat("SSH-2.0-bla", 50)[:253]
+	cases := []string{
+		longversion + "too-long\r\n",
 	}
-	if !bytes.Equal(buf[:len(buf)-1], result) {
-		t.Errorf("version read did not match expected: got %x, want %x", result, buf[:len(buf)-1])
+	for _, in := range cases {
+		if _, err := readVersion(bytes.NewBufferString(in)); err == nil {
+			t.Errorf("readVersion(%q) should have failed", in)
+		}
 	}
 }
 
-func TestReadVersionTooLong(t *testing.T) {
-	buf := make([]byte, maxVersionStringBytes+1)
-	if _, err := readVersion(bufio.NewReader(bytes.NewBuffer(buf))); err == nil {
-		t.Errorf("readVersion consumed %d bytes without error", len(buf))
+func TestExchangeVersionsBasic(t *testing.T) {
+	v := "SSH-2.0-bla"
+	buf := bytes.NewBufferString(v + "\r\n")
+	them, err := exchangeVersions(buf, []byte("xyz"))
+	if err != nil {
+		t.Errorf("exchangeVersions: %v", err)
+	}
+
+	if want := "SSH-2.0-bla"; string(them) != want {
+		t.Errorf("got %q want %q for our version", them, want)
 	}
 }
 
-func TestReadVersionWithoutCRLF(t *testing.T) {
-	buf := serverVersion
-	buf = buf[:len(buf)-1]
-	if _, err := readVersion(bufio.NewReader(bytes.NewBuffer(buf))); err == nil {
-		t.Error("readVersion did not notice \\n was missing")
+func TestExchangeVersions(t *testing.T) {
+	cases := []string{
+		"not\x000allowed",
+		"not allowed\n",
+	}
+	for _, c := range cases {
+		buf := bytes.NewBufferString("SSH-2.0-bla\r\n")
+		if _, err := exchangeVersions(buf, []byte(c)); err == nil {
+			t.Errorf("exchangeVersions(%q): should have failed", c)
+		}
 	}
 }