Browse Source

ssh: fix protocol version exchange (for multi-line)

Fixes golang/go#23194

During SSH Protocol Version Exchange, a client may send metadata lines
prior to sending the SSH version string. To conform to the RFC, all SSH
implementations must support this (minimally, clients can ignore the
metadata lines).

For example, this is valid:
some-metadata
SSH-2.0-OpenSSH

The current Go implementation takes the first line it sees as
the version string (in this case, some-metadata). Then, it uses
the next line (SSH-2.0-OpenSSH) as part of key exchange, which
is guaranteed to fail.

Unfortunately, this SSH feature is used by some vendors and is part
of the official RFC: https://tools.ietf.org/html/rfc4253#section-4.2

Change-Id: I7be61700a07756353875bf43aad09a580ba533ff
Reviewed-on: https://go-review.googlesource.com/86675
Reviewed-by: Han-Wen Nienhuys <hanwen@google.com>
Run-TryBot: Han-Wen Nienhuys <hanwen@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Justin Gracenin 8 năm trước cách đây
mục cha
commit
5f55bce93a
3 tập tin đã thay đổi với 82 bổ sung34 xóa
  1. 64 28
      ssh/client_test.go
  2. 9 1
      ssh/transport.go
  3. 9 5
      ssh/transport_test.go

+ 64 - 28
ssh/client_test.go

@@ -5,41 +5,77 @@
 package ssh
 
 import (
-	"net"
 	"strings"
 	"testing"
 )
 
-func testClientVersion(t *testing.T, config *ClientConfig, expected string) {
-	clientConn, serverConn := net.Pipe()
-	defer clientConn.Close()
-	receivedVersion := make(chan string, 1)
-	config.HostKeyCallback = InsecureIgnoreHostKey()
-	go func() {
-		version, err := readVersion(serverConn)
-		if err != nil {
-			receivedVersion <- ""
-		} else {
-			receivedVersion <- string(version)
-		}
-		serverConn.Close()
-	}()
-	NewClientConn(clientConn, "", config)
-	actual := <-receivedVersion
-	if actual != expected {
-		t.Fatalf("got %s; want %s", actual, expected)
+func TestClientVersion(t *testing.T) {
+	for _, tt := range []struct {
+		name      string
+		version   string
+		multiLine string
+		wantErr   bool
+	}{
+		{
+			name:    "default version",
+			version: packageVersion,
+		},
+		{
+			name:    "custom version",
+			version: "SSH-2.0-CustomClientVersionString",
+		},
+		{
+			name:      "good multi line version",
+			version:   packageVersion,
+			multiLine: strings.Repeat("ignored\r\n", 20),
+		},
+		{
+			name:      "bad multi line version",
+			version:   packageVersion,
+			multiLine: "bad multi line version",
+			wantErr:   true,
+		},
+		{
+			name:      "long multi line version",
+			version:   packageVersion,
+			multiLine: strings.Repeat("long multi line version\r\n", 50)[:256],
+			wantErr:   true,
+		},
+	} {
+		t.Run(tt.name, func(t *testing.T) {
+			c1, c2, err := netPipe()
+			if err != nil {
+				t.Fatalf("netPipe: %v", err)
+			}
+			defer c1.Close()
+			defer c2.Close()
+			go func() {
+				if tt.multiLine != "" {
+					c1.Write([]byte(tt.multiLine))
+				}
+				NewClientConn(c1, "", &ClientConfig{
+					ClientVersion:   tt.version,
+					HostKeyCallback: InsecureIgnoreHostKey(),
+				})
+				c1.Close()
+			}()
+			conf := &ServerConfig{NoClientAuth: true}
+			conf.AddHostKey(testSigners["rsa"])
+			conn, _, _, err := NewServerConn(c2, conf)
+			if err == nil == tt.wantErr {
+				t.Fatalf("got err %v; wantErr %t", err, tt.wantErr)
+			}
+			if tt.wantErr {
+				// Don't verify the version on an expected error.
+				return
+			}
+			if got := string(conn.ClientVersion()); got != tt.version {
+				t.Fatalf("got %q; want %q", got, tt.version)
+			}
+		})
 	}
 }
 
-func TestCustomClientVersion(t *testing.T) {
-	version := "Test-Client-Version-0.0"
-	testClientVersion(t, &ClientConfig{ClientVersion: version}, version)
-}
-
-func TestDefaultClientVersion(t *testing.T) {
-	testClientVersion(t, &ClientConfig{}, packageVersion)
-}
-
 func TestHostKeyCheck(t *testing.T) {
 	for _, tt := range []struct {
 		name      string

+ 9 - 1
ssh/transport.go

@@ -6,6 +6,7 @@ package ssh
 
 import (
 	"bufio"
+	"bytes"
 	"errors"
 	"io"
 	"log"
@@ -342,7 +343,7 @@ func readVersion(r io.Reader) ([]byte, error) {
 	var ok bool
 	var buf [1]byte
 
-	for len(versionString) < maxVersionStringBytes {
+	for length := 0; length < maxVersionStringBytes; length++ {
 		_, err := io.ReadFull(r, buf[:])
 		if err != nil {
 			return nil, err
@@ -350,6 +351,13 @@ func readVersion(r io.Reader) ([]byte, error) {
 		// The RFC says that the version should be terminated with \r\n
 		// but several SSH servers actually only send a \n.
 		if buf[0] == '\n' {
+			if !bytes.HasPrefix(versionString, []byte("SSH-")) {
+				// RFC 4253 says we need to ignore all version string lines
+				// except the one containing the SSH version (provided that
+				// all the lines do not exceed 255 bytes in total).
+				versionString = versionString[:0]
+				continue
+			}
 			ok = true
 			break
 		}

+ 9 - 5
ssh/transport_test.go

@@ -13,11 +13,13 @@ import (
 )
 
 func TestReadVersion(t *testing.T) {
-	longversion := strings.Repeat("SSH-2.0-bla", 50)[:253]
+	longVersion := strings.Repeat("SSH-2.0-bla", 50)[:253]
+	multiLineVersion := strings.Repeat("ignored\r\n", 20) + "SSH-2.0-bla\r\n"
 	cases := map[string]string{
 		"SSH-2.0-bla\r\n":    "SSH-2.0-bla",
 		"SSH-2.0-bla\n":      "SSH-2.0-bla",
-		longversion + "\r\n": longversion,
+		multiLineVersion:     "SSH-2.0-bla",
+		longVersion + "\r\n": longVersion,
 	}
 
 	for in, want := range cases {
@@ -33,9 +35,11 @@ func TestReadVersion(t *testing.T) {
 }
 
 func TestReadVersionError(t *testing.T) {
-	longversion := strings.Repeat("SSH-2.0-bla", 50)[:253]
+	longVersion := strings.Repeat("SSH-2.0-bla", 50)[:253]
+	multiLineVersion := strings.Repeat("ignored\r\n", 50) + "SSH-2.0-bla\r\n"
 	cases := []string{
-		longversion + "too-long\r\n",
+		longVersion + "too-long\r\n",
+		multiLineVersion,
 	}
 	for _, in := range cases {
 		if _, err := readVersion(bytes.NewBufferString(in)); err == nil {
@@ -60,7 +64,7 @@ func TestExchangeVersionsBasic(t *testing.T) {
 func TestExchangeVersions(t *testing.T) {
 	cases := []string{
 		"not\x000allowed",
-		"not allowed\n",
+		"not allowed\x01\r\n",
 	}
 	for _, c := range cases {
 		buf := bytes.NewBufferString("SSH-2.0-bla\r\n")