Jelajahi Sumber

ssh: reset buffered packets after sending

Since encryption messes up the packets, the wrongly retained packets
look like noise and cause application protocol errors or panics in the
SSH library.

This normally triggers very rarely: the mandatory key exchange doesn't
have parallel writes, so this failure condition would be setup on the
first key exchange, take effect only after the second key exchange.

Fortunately, the tests against openssh exercise this. This change adds
also adds a unittest.

Fixes #18850.

Change-Id: I656c8b94bfb265831daa118f4d614a2f0c65d2af
Reviewed-on: https://go-review.googlesource.com/36056
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Han-Wen Nienhuys <hanwen@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Han-Wen Nienhuys 8 tahun lalu
induk
melakukan
641ab6b320
2 mengubah file dengan 45 tambahan dan 16 penghapusan
  1. 1 1
      ssh/handshake.go
  2. 44 15
      ssh/handshake_test.go

+ 1 - 1
ssh/handshake.go

@@ -314,7 +314,7 @@ write:
 				break
 			}
 		}
-		t.pendingPackets = t.pendingPackets[0:]
+		t.pendingPackets = t.pendingPackets[:0]
 		t.mu.Unlock()
 	}
 

+ 44 - 15
ssh/handshake_test.go

@@ -125,7 +125,12 @@ func TestHandshakeBasic(t *testing.T) {
 		t.Skip("see golang.org/issue/7237")
 	}
 
-	checker := &syncChecker{make(chan int, 10)}
+	checker := &syncChecker{
+		waitCall: make(chan int, 10),
+		called:   make(chan int, 10),
+	}
+
+	checker.waitCall <- 1
 	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
 	if err != nil {
 		t.Fatalf("handshakePair: %v", err)
@@ -134,22 +139,25 @@ func TestHandshakeBasic(t *testing.T) {
 	defer trC.Close()
 	defer trS.Close()
 
+	// Let first kex complete normally.
 	<-checker.called
 
 	clientDone := make(chan int, 0)
 	gotHalf := make(chan int, 0)
+	const N = 20
 
 	go func() {
 		defer close(clientDone)
 		// Client writes a bunch of stuff, and does a key
 		// change in the middle. This should not confuse the
-		// handshake in progress
-		for i := 0; i < 10; i++ {
+		// handshake in progress. We do this twice, so we test
+		// that the packet buffer is reset correctly.
+		for i := 0; i < N; i++ {
 			p := []byte{msgRequestSuccess, byte(i)}
 			if err := trC.writePacket(p); err != nil {
 				t.Fatalf("sendPacket: %v", err)
 			}
-			if i == 5 {
+			if (i % 10) == 5 {
 				<-gotHalf
 				// halfway through, we request a key change.
 				trC.requestKeyExchange()
@@ -159,32 +167,38 @@ func TestHandshakeBasic(t *testing.T) {
 				// write more.
 				<-checker.called
 			}
+			if (i % 10) == 7 {
+				// write some packets until the kex
+				// completes, to test buffering of
+				// packets.
+				checker.waitCall <- 1
+			}
 		}
 	}()
 
 	// Server checks that client messages come in cleanly
 	i := 0
 	err = nil
-	for ; i < 10; i++ {
+	for ; i < N; i++ {
 		var p []byte
 		p, err = trS.readPacket()
 		if err != nil {
 			break
 		}
-		if i == 5 {
+		if (i % 10) == 5 {
 			gotHalf <- 1
 		}
 
 		want := []byte{msgRequestSuccess, byte(i)}
 		if bytes.Compare(p, want) != 0 {
-			t.Errorf("message %d: got %q, want %q", i, p, want)
+			t.Errorf("message %d: got %v, want %v", i, p, want)
 		}
 	}
 	<-clientDone
 	if err != nil && err != io.EOF {
 		t.Fatalf("server error: %v", err)
 	}
-	if i != 10 {
+	if i != N {
 		t.Errorf("received %d messages, want 10.", i)
 	}
 
@@ -239,7 +253,10 @@ func TestForceFirstKex(t *testing.T) {
 }
 
 func TestHandshakeAutoRekeyWrite(t *testing.T) {
-	checker := &syncChecker{make(chan int, 10)}
+	checker := &syncChecker{
+		called:   make(chan int, 10),
+		waitCall: nil,
+	}
 	clientConf := &ClientConfig{HostKeyCallback: checker.Check}
 	clientConf.RekeyThreshold = 500
 	trC, trS, err := handshakePair(clientConf, "addr", false)
@@ -249,14 +266,19 @@ func TestHandshakeAutoRekeyWrite(t *testing.T) {
 	defer trC.Close()
 	defer trS.Close()
 
+	input := make([]byte, 251)
+	input[0] = msgRequestSuccess
+
 	done := make(chan int, 1)
 	const numPacket = 5
 	go func() {
 		defer close(done)
 		j := 0
 		for ; j < numPacket; j++ {
-			if _, err := trS.readPacket(); err != nil {
+			if p, err := trS.readPacket(); err != nil {
 				break
+			} else if !bytes.Equal(input, p) {
+				t.Errorf("got packet type %d, want %d", p[0], input[0])
 			}
 		}
 
@@ -268,9 +290,9 @@ func TestHandshakeAutoRekeyWrite(t *testing.T) {
 	<-checker.called
 
 	for i := 0; i < numPacket; i++ {
-		packet := make([]byte, 251)
-		packet[0] = msgRequestSuccess
-		if err := trC.writePacket(packet); err != nil {
+		p := make([]byte, len(input))
+		copy(p, input)
+		if err := trC.writePacket(p); err != nil {
 			t.Errorf("writePacket: %v", err)
 		}
 		if i == 2 {
@@ -283,16 +305,23 @@ func TestHandshakeAutoRekeyWrite(t *testing.T) {
 }
 
 type syncChecker struct {
-	called chan int
+	waitCall chan int
+	called   chan int
 }
 
 func (c *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
 	c.called <- 1
+	if c.waitCall != nil {
+		<-c.waitCall
+	}
 	return nil
 }
 
 func TestHandshakeAutoRekeyRead(t *testing.T) {
-	sync := &syncChecker{make(chan int, 2)}
+	sync := &syncChecker{
+		called:   make(chan int, 2),
+		waitCall: nil,
+	}
 	clientConf := &ClientConfig{
 		HostKeyCallback: sync.Check,
 	}