瀏覽代碼

ssh: invert algorithm choices on the server

At the protocol level, SSH lets client and server specify different
algorithms for the read and write half of the connection. This has
never worked correctly, as Client-to-Server was always interpreted as
the "write" side, even if we were the server.

This has never been a problem because, apparently, there are no
clients that insist on different algorithm choices running against Go
SSH servers.

Since the SSH package does not expose a mechanism to specify
algorithms for read/write separately, there is end-to-end for this
change, so add a unittest instead.

Change-Id: Ie3aa781630a3bb7a3b0e3754cb67b3ce12581544
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/172538
Reviewed-by: Filippo Valsorda <filippo@golang.org>
Run-TryBot: Filippo Valsorda <filippo@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Han-Wen Nienhuys 6 年之前
父節點
當前提交
df01cb2cc4
共有 3 個文件被更改,包括 192 次插入9 次删除
  1. 13 7
      ssh/common.go
  2. 176 0
      ssh/common_test.go
  3. 3 2
      ssh/handshake.go

+ 13 - 7
ssh/common.go

@@ -109,6 +109,7 @@ func findCommon(what string, client []string, server []string) (common string, e
 	return "", fmt.Errorf("ssh: no common algorithm for %s; client offered: %v, server offered: %v", what, client, server)
 }
 
+// directionAlgorithms records algorithm choices in one direction (either read or write)
 type directionAlgorithms struct {
 	Cipher      string
 	MAC         string
@@ -137,7 +138,7 @@ type algorithms struct {
 	r       directionAlgorithms
 }
 
-func findAgreedAlgorithms(clientKexInit, serverKexInit *kexInitMsg) (algs *algorithms, err error) {
+func findAgreedAlgorithms(isClient bool, clientKexInit, serverKexInit *kexInitMsg) (algs *algorithms, err error) {
 	result := &algorithms{}
 
 	result.kex, err = findCommon("key exchange", clientKexInit.KexAlgos, serverKexInit.KexAlgos)
@@ -150,32 +151,37 @@ func findAgreedAlgorithms(clientKexInit, serverKexInit *kexInitMsg) (algs *algor
 		return
 	}
 
-	result.w.Cipher, err = findCommon("client to server cipher", clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer)
+	stoc, ctos := &result.w, &result.r
+	if isClient {
+		ctos, stoc = stoc, ctos
+	}
+
+	ctos.Cipher, err = findCommon("client to server cipher", clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer)
 	if err != nil {
 		return
 	}
 
-	result.r.Cipher, err = findCommon("server to client cipher", clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient)
+	stoc.Cipher, err = findCommon("server to client cipher", clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient)
 	if err != nil {
 		return
 	}
 
-	result.w.MAC, err = findCommon("client to server MAC", clientKexInit.MACsClientServer, serverKexInit.MACsClientServer)
+	ctos.MAC, err = findCommon("client to server MAC", clientKexInit.MACsClientServer, serverKexInit.MACsClientServer)
 	if err != nil {
 		return
 	}
 
-	result.r.MAC, err = findCommon("server to client MAC", clientKexInit.MACsServerClient, serverKexInit.MACsServerClient)
+	stoc.MAC, err = findCommon("server to client MAC", clientKexInit.MACsServerClient, serverKexInit.MACsServerClient)
 	if err != nil {
 		return
 	}
 
-	result.w.Compression, err = findCommon("client to server compression", clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer)
+	ctos.Compression, err = findCommon("client to server compression", clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer)
 	if err != nil {
 		return
 	}
 
-	result.r.Compression, err = findCommon("server to client compression", clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient)
+	stoc.Compression, err = findCommon("server to client compression", clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient)
 	if err != nil {
 		return
 	}

+ 176 - 0
ssh/common_test.go

@@ -0,0 +1,176 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+	"reflect"
+	"testing"
+)
+
+func TestFindAgreedAlgorithms(t *testing.T) {
+	initKex := func(k *kexInitMsg) {
+		if k.KexAlgos == nil {
+			k.KexAlgos = []string{"kex1"}
+		}
+		if k.ServerHostKeyAlgos == nil {
+			k.ServerHostKeyAlgos = []string{"hostkey1"}
+		}
+		if k.CiphersClientServer == nil {
+			k.CiphersClientServer = []string{"cipher1"}
+
+		}
+		if k.CiphersServerClient == nil {
+			k.CiphersServerClient = []string{"cipher1"}
+
+		}
+		if k.MACsClientServer == nil {
+			k.MACsClientServer = []string{"mac1"}
+
+		}
+		if k.MACsServerClient == nil {
+			k.MACsServerClient = []string{"mac1"}
+
+		}
+		if k.CompressionClientServer == nil {
+			k.CompressionClientServer = []string{"compression1"}
+
+		}
+		if k.CompressionServerClient == nil {
+			k.CompressionServerClient = []string{"compression1"}
+
+		}
+		if k.LanguagesClientServer == nil {
+			k.LanguagesClientServer = []string{"language1"}
+
+		}
+		if k.LanguagesServerClient == nil {
+			k.LanguagesServerClient = []string{"language1"}
+
+		}
+	}
+
+	initDirAlgs := func(a *directionAlgorithms) {
+		if a.Cipher == "" {
+			a.Cipher = "cipher1"
+		}
+		if a.MAC == "" {
+			a.MAC = "mac1"
+		}
+		if a.Compression == "" {
+			a.Compression = "compression1"
+		}
+	}
+
+	initAlgs := func(a *algorithms) {
+		if a.kex == "" {
+			a.kex = "kex1"
+		}
+		if a.hostKey == "" {
+			a.hostKey = "hostkey1"
+		}
+		initDirAlgs(&a.r)
+		initDirAlgs(&a.w)
+	}
+
+	type testcase struct {
+		name                   string
+		clientIn, serverIn     kexInitMsg
+		wantClient, wantServer algorithms
+		wantErr                bool
+	}
+
+	cases := []testcase{
+		testcase{
+			name: "standard",
+		},
+
+		testcase{
+			name: "no common hostkey",
+			serverIn: kexInitMsg{
+				ServerHostKeyAlgos: []string{"hostkey2"},
+			},
+			wantErr: true,
+		},
+
+		testcase{
+			name: "no common kex",
+			serverIn: kexInitMsg{
+				KexAlgos: []string{"kex2"},
+			},
+			wantErr: true,
+		},
+
+		testcase{
+			name: "no common cipher",
+			serverIn: kexInitMsg{
+				CiphersClientServer: []string{"cipher2"},
+			},
+			wantErr: true,
+		},
+
+		testcase{
+			name: "client decides cipher",
+			serverIn: kexInitMsg{
+				CiphersClientServer: []string{"cipher1", "cipher2"},
+				CiphersServerClient: []string{"cipher2", "cipher3"},
+			},
+			clientIn: kexInitMsg{
+				CiphersClientServer: []string{"cipher2", "cipher1"},
+				CiphersServerClient: []string{"cipher3", "cipher2"},
+			},
+			wantClient: algorithms{
+				r: directionAlgorithms{
+					Cipher: "cipher3",
+				},
+				w: directionAlgorithms{
+					Cipher: "cipher2",
+				},
+			},
+			wantServer: algorithms{
+				w: directionAlgorithms{
+					Cipher: "cipher3",
+				},
+				r: directionAlgorithms{
+					Cipher: "cipher2",
+				},
+			},
+		},
+
+		// TODO(hanwen): fix and add tests for AEAD ignoring
+		// the MACs field
+	}
+
+	for i := range cases {
+		initKex(&cases[i].clientIn)
+		initKex(&cases[i].serverIn)
+		initAlgs(&cases[i].wantClient)
+		initAlgs(&cases[i].wantServer)
+	}
+
+	for _, c := range cases {
+		t.Run(c.name, func(t *testing.T) {
+			serverAlgs, serverErr := findAgreedAlgorithms(false, &c.clientIn, &c.serverIn)
+			clientAlgs, clientErr := findAgreedAlgorithms(true, &c.clientIn, &c.serverIn)
+
+			serverHasErr := serverErr != nil
+			clientHasErr := clientErr != nil
+			if c.wantErr != serverHasErr || c.wantErr != clientHasErr {
+				t.Fatalf("got client/server error (%v, %v), want hasError %v",
+					clientErr, serverErr, c.wantErr)
+
+			}
+			if c.wantErr {
+				return
+			}
+
+			if !reflect.DeepEqual(serverAlgs, &c.wantServer) {
+				t.Errorf("server: got algs %#v, want %#v", serverAlgs, &c.wantServer)
+			}
+			if !reflect.DeepEqual(clientAlgs, &c.wantClient) {
+				t.Errorf("server: got algs %#v, want %#v", clientAlgs, &c.wantClient)
+			}
+		})
+	}
+}

+ 3 - 2
ssh/handshake.go

@@ -543,7 +543,8 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
 
 	clientInit := otherInit
 	serverInit := t.sentInitMsg
-	if len(t.hostKeys) == 0 {
+	isClient := len(t.hostKeys) == 0
+	if isClient {
 		clientInit, serverInit = serverInit, clientInit
 
 		magics.clientKexInit = t.sentInitPacket
@@ -551,7 +552,7 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
 	}
 
 	var err error
-	t.algorithms, err = findAgreedAlgorithms(clientInit, serverInit)
+	t.algorithms, err = findAgreedAlgorithms(isClient, clientInit, serverInit)
 	if err != nil {
 		return err
 	}