소스 검색

ssh: invert algorithm choices on the server

At the protocol level, SSH lets client and server specify different
algorithms for the read and write half of the connection. This has
never worked correctly, as Client-to-Server was always interpreted as
the "write" side, even if we were the server.

This has never been a problem because, apparently, there are no
clients that insist on different algorithm choices running against Go
SSH servers.

Since the SSH package does not expose a mechanism to specify
algorithms for read/write separately, there is end-to-end for this
change, so add a unittest instead.

Change-Id: Ie3aa781630a3bb7a3b0e3754cb67b3ce12581544
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/172538
Reviewed-by: Filippo Valsorda <filippo@golang.org>
Run-TryBot: Filippo Valsorda <filippo@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Han-Wen Nienhuys 6 년 전
부모
커밋
df01cb2cc4
3개의 변경된 파일192개의 추가작업 그리고 9개의 파일을 삭제
  1. 13 7
      ssh/common.go
  2. 176 0
      ssh/common_test.go
  3. 3 2
      ssh/handshake.go

+ 13 - 7
ssh/common.go

@@ -109,6 +109,7 @@ func findCommon(what string, client []string, server []string) (common string, e
 	return "", fmt.Errorf("ssh: no common algorithm for %s; client offered: %v, server offered: %v", what, client, server)
 }
 
+// directionAlgorithms records algorithm choices in one direction (either read or write)
 type directionAlgorithms struct {
 	Cipher      string
 	MAC         string
@@ -137,7 +138,7 @@ type algorithms struct {
 	r       directionAlgorithms
 }
 
-func findAgreedAlgorithms(clientKexInit, serverKexInit *kexInitMsg) (algs *algorithms, err error) {
+func findAgreedAlgorithms(isClient bool, clientKexInit, serverKexInit *kexInitMsg) (algs *algorithms, err error) {
 	result := &algorithms{}
 
 	result.kex, err = findCommon("key exchange", clientKexInit.KexAlgos, serverKexInit.KexAlgos)
@@ -150,32 +151,37 @@ func findAgreedAlgorithms(clientKexInit, serverKexInit *kexInitMsg) (algs *algor
 		return
 	}
 
-	result.w.Cipher, err = findCommon("client to server cipher", clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer)
+	stoc, ctos := &result.w, &result.r
+	if isClient {
+		ctos, stoc = stoc, ctos
+	}
+
+	ctos.Cipher, err = findCommon("client to server cipher", clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer)
 	if err != nil {
 		return
 	}
 
-	result.r.Cipher, err = findCommon("server to client cipher", clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient)
+	stoc.Cipher, err = findCommon("server to client cipher", clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient)
 	if err != nil {
 		return
 	}
 
-	result.w.MAC, err = findCommon("client to server MAC", clientKexInit.MACsClientServer, serverKexInit.MACsClientServer)
+	ctos.MAC, err = findCommon("client to server MAC", clientKexInit.MACsClientServer, serverKexInit.MACsClientServer)
 	if err != nil {
 		return
 	}
 
-	result.r.MAC, err = findCommon("server to client MAC", clientKexInit.MACsServerClient, serverKexInit.MACsServerClient)
+	stoc.MAC, err = findCommon("server to client MAC", clientKexInit.MACsServerClient, serverKexInit.MACsServerClient)
 	if err != nil {
 		return
 	}
 
-	result.w.Compression, err = findCommon("client to server compression", clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer)
+	ctos.Compression, err = findCommon("client to server compression", clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer)
 	if err != nil {
 		return
 	}
 
-	result.r.Compression, err = findCommon("server to client compression", clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient)
+	stoc.Compression, err = findCommon("server to client compression", clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient)
 	if err != nil {
 		return
 	}

+ 176 - 0
ssh/common_test.go

@@ -0,0 +1,176 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+	"reflect"
+	"testing"
+)
+
+func TestFindAgreedAlgorithms(t *testing.T) {
+	initKex := func(k *kexInitMsg) {
+		if k.KexAlgos == nil {
+			k.KexAlgos = []string{"kex1"}
+		}
+		if k.ServerHostKeyAlgos == nil {
+			k.ServerHostKeyAlgos = []string{"hostkey1"}
+		}
+		if k.CiphersClientServer == nil {
+			k.CiphersClientServer = []string{"cipher1"}
+
+		}
+		if k.CiphersServerClient == nil {
+			k.CiphersServerClient = []string{"cipher1"}
+
+		}
+		if k.MACsClientServer == nil {
+			k.MACsClientServer = []string{"mac1"}
+
+		}
+		if k.MACsServerClient == nil {
+			k.MACsServerClient = []string{"mac1"}
+
+		}
+		if k.CompressionClientServer == nil {
+			k.CompressionClientServer = []string{"compression1"}
+
+		}
+		if k.CompressionServerClient == nil {
+			k.CompressionServerClient = []string{"compression1"}
+
+		}
+		if k.LanguagesClientServer == nil {
+			k.LanguagesClientServer = []string{"language1"}
+
+		}
+		if k.LanguagesServerClient == nil {
+			k.LanguagesServerClient = []string{"language1"}
+
+		}
+	}
+
+	initDirAlgs := func(a *directionAlgorithms) {
+		if a.Cipher == "" {
+			a.Cipher = "cipher1"
+		}
+		if a.MAC == "" {
+			a.MAC = "mac1"
+		}
+		if a.Compression == "" {
+			a.Compression = "compression1"
+		}
+	}
+
+	initAlgs := func(a *algorithms) {
+		if a.kex == "" {
+			a.kex = "kex1"
+		}
+		if a.hostKey == "" {
+			a.hostKey = "hostkey1"
+		}
+		initDirAlgs(&a.r)
+		initDirAlgs(&a.w)
+	}
+
+	type testcase struct {
+		name                   string
+		clientIn, serverIn     kexInitMsg
+		wantClient, wantServer algorithms
+		wantErr                bool
+	}
+
+	cases := []testcase{
+		testcase{
+			name: "standard",
+		},
+
+		testcase{
+			name: "no common hostkey",
+			serverIn: kexInitMsg{
+				ServerHostKeyAlgos: []string{"hostkey2"},
+			},
+			wantErr: true,
+		},
+
+		testcase{
+			name: "no common kex",
+			serverIn: kexInitMsg{
+				KexAlgos: []string{"kex2"},
+			},
+			wantErr: true,
+		},
+
+		testcase{
+			name: "no common cipher",
+			serverIn: kexInitMsg{
+				CiphersClientServer: []string{"cipher2"},
+			},
+			wantErr: true,
+		},
+
+		testcase{
+			name: "client decides cipher",
+			serverIn: kexInitMsg{
+				CiphersClientServer: []string{"cipher1", "cipher2"},
+				CiphersServerClient: []string{"cipher2", "cipher3"},
+			},
+			clientIn: kexInitMsg{
+				CiphersClientServer: []string{"cipher2", "cipher1"},
+				CiphersServerClient: []string{"cipher3", "cipher2"},
+			},
+			wantClient: algorithms{
+				r: directionAlgorithms{
+					Cipher: "cipher3",
+				},
+				w: directionAlgorithms{
+					Cipher: "cipher2",
+				},
+			},
+			wantServer: algorithms{
+				w: directionAlgorithms{
+					Cipher: "cipher3",
+				},
+				r: directionAlgorithms{
+					Cipher: "cipher2",
+				},
+			},
+		},
+
+		// TODO(hanwen): fix and add tests for AEAD ignoring
+		// the MACs field
+	}
+
+	for i := range cases {
+		initKex(&cases[i].clientIn)
+		initKex(&cases[i].serverIn)
+		initAlgs(&cases[i].wantClient)
+		initAlgs(&cases[i].wantServer)
+	}
+
+	for _, c := range cases {
+		t.Run(c.name, func(t *testing.T) {
+			serverAlgs, serverErr := findAgreedAlgorithms(false, &c.clientIn, &c.serverIn)
+			clientAlgs, clientErr := findAgreedAlgorithms(true, &c.clientIn, &c.serverIn)
+
+			serverHasErr := serverErr != nil
+			clientHasErr := clientErr != nil
+			if c.wantErr != serverHasErr || c.wantErr != clientHasErr {
+				t.Fatalf("got client/server error (%v, %v), want hasError %v",
+					clientErr, serverErr, c.wantErr)
+
+			}
+			if c.wantErr {
+				return
+			}
+
+			if !reflect.DeepEqual(serverAlgs, &c.wantServer) {
+				t.Errorf("server: got algs %#v, want %#v", serverAlgs, &c.wantServer)
+			}
+			if !reflect.DeepEqual(clientAlgs, &c.wantClient) {
+				t.Errorf("server: got algs %#v, want %#v", clientAlgs, &c.wantClient)
+			}
+		})
+	}
+}

+ 3 - 2
ssh/handshake.go

@@ -543,7 +543,8 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
 
 	clientInit := otherInit
 	serverInit := t.sentInitMsg
-	if len(t.hostKeys) == 0 {
+	isClient := len(t.hostKeys) == 0
+	if isClient {
 		clientInit, serverInit = serverInit, clientInit
 
 		magics.clientKexInit = t.sentInitPacket
@@ -551,7 +552,7 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
 	}
 
 	var err error
-	t.algorithms, err = findAgreedAlgorithms(clientInit, serverInit)
+	t.algorithms, err = findAgreedAlgorithms(isClient, clientInit, serverInit)
 	if err != nil {
 		return err
 	}