Browse Source

ssh: reset buffered packets after sending

Since encryption messes up the packets, the wrongly retained packets
look like noise and cause application protocol errors or panics in the
SSH library.

This normally triggers very rarely: the mandatory key exchange doesn't
have parallel writes, so this failure condition would be setup on the
first key exchange, take effect only after the second key exchange.

Fortunately, the tests against openssh exercise this. This change adds
also adds a unittest.

Fixes #18850.

Change-Id: I656c8b94bfb265831daa118f4d614a2f0c65d2af
Reviewed-on: https://go-review.googlesource.com/36056
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Han-Wen Nienhuys <hanwen@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Han-Wen Nienhuys 8 years ago
parent
commit
641ab6b320
2 changed files with 45 additions and 16 deletions
  1. 1 1
      ssh/handshake.go
  2. 44 15
      ssh/handshake_test.go

+ 1 - 1
ssh/handshake.go

@@ -314,7 +314,7 @@ write:
 				break
 				break
 			}
 			}
 		}
 		}
-		t.pendingPackets = t.pendingPackets[0:]
+		t.pendingPackets = t.pendingPackets[:0]
 		t.mu.Unlock()
 		t.mu.Unlock()
 	}
 	}
 
 

+ 44 - 15
ssh/handshake_test.go

@@ -125,7 +125,12 @@ func TestHandshakeBasic(t *testing.T) {
 		t.Skip("see golang.org/issue/7237")
 		t.Skip("see golang.org/issue/7237")
 	}
 	}
 
 
-	checker := &syncChecker{make(chan int, 10)}
+	checker := &syncChecker{
+		waitCall: make(chan int, 10),
+		called:   make(chan int, 10),
+	}
+
+	checker.waitCall <- 1
 	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
 	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
 	if err != nil {
 	if err != nil {
 		t.Fatalf("handshakePair: %v", err)
 		t.Fatalf("handshakePair: %v", err)
@@ -134,22 +139,25 @@ func TestHandshakeBasic(t *testing.T) {
 	defer trC.Close()
 	defer trC.Close()
 	defer trS.Close()
 	defer trS.Close()
 
 
+	// Let first kex complete normally.
 	<-checker.called
 	<-checker.called
 
 
 	clientDone := make(chan int, 0)
 	clientDone := make(chan int, 0)
 	gotHalf := make(chan int, 0)
 	gotHalf := make(chan int, 0)
+	const N = 20
 
 
 	go func() {
 	go func() {
 		defer close(clientDone)
 		defer close(clientDone)
 		// Client writes a bunch of stuff, and does a key
 		// Client writes a bunch of stuff, and does a key
 		// change in the middle. This should not confuse the
 		// change in the middle. This should not confuse the
-		// handshake in progress
-		for i := 0; i < 10; i++ {
+		// handshake in progress. We do this twice, so we test
+		// that the packet buffer is reset correctly.
+		for i := 0; i < N; i++ {
 			p := []byte{msgRequestSuccess, byte(i)}
 			p := []byte{msgRequestSuccess, byte(i)}
 			if err := trC.writePacket(p); err != nil {
 			if err := trC.writePacket(p); err != nil {
 				t.Fatalf("sendPacket: %v", err)
 				t.Fatalf("sendPacket: %v", err)
 			}
 			}
-			if i == 5 {
+			if (i % 10) == 5 {
 				<-gotHalf
 				<-gotHalf
 				// halfway through, we request a key change.
 				// halfway through, we request a key change.
 				trC.requestKeyExchange()
 				trC.requestKeyExchange()
@@ -159,32 +167,38 @@ func TestHandshakeBasic(t *testing.T) {
 				// write more.
 				// write more.
 				<-checker.called
 				<-checker.called
 			}
 			}
+			if (i % 10) == 7 {
+				// write some packets until the kex
+				// completes, to test buffering of
+				// packets.
+				checker.waitCall <- 1
+			}
 		}
 		}
 	}()
 	}()
 
 
 	// Server checks that client messages come in cleanly
 	// Server checks that client messages come in cleanly
 	i := 0
 	i := 0
 	err = nil
 	err = nil
-	for ; i < 10; i++ {
+	for ; i < N; i++ {
 		var p []byte
 		var p []byte
 		p, err = trS.readPacket()
 		p, err = trS.readPacket()
 		if err != nil {
 		if err != nil {
 			break
 			break
 		}
 		}
-		if i == 5 {
+		if (i % 10) == 5 {
 			gotHalf <- 1
 			gotHalf <- 1
 		}
 		}
 
 
 		want := []byte{msgRequestSuccess, byte(i)}
 		want := []byte{msgRequestSuccess, byte(i)}
 		if bytes.Compare(p, want) != 0 {
 		if bytes.Compare(p, want) != 0 {
-			t.Errorf("message %d: got %q, want %q", i, p, want)
+			t.Errorf("message %d: got %v, want %v", i, p, want)
 		}
 		}
 	}
 	}
 	<-clientDone
 	<-clientDone
 	if err != nil && err != io.EOF {
 	if err != nil && err != io.EOF {
 		t.Fatalf("server error: %v", err)
 		t.Fatalf("server error: %v", err)
 	}
 	}
-	if i != 10 {
+	if i != N {
 		t.Errorf("received %d messages, want 10.", i)
 		t.Errorf("received %d messages, want 10.", i)
 	}
 	}
 
 
@@ -239,7 +253,10 @@ func TestForceFirstKex(t *testing.T) {
 }
 }
 
 
 func TestHandshakeAutoRekeyWrite(t *testing.T) {
 func TestHandshakeAutoRekeyWrite(t *testing.T) {
-	checker := &syncChecker{make(chan int, 10)}
+	checker := &syncChecker{
+		called:   make(chan int, 10),
+		waitCall: nil,
+	}
 	clientConf := &ClientConfig{HostKeyCallback: checker.Check}
 	clientConf := &ClientConfig{HostKeyCallback: checker.Check}
 	clientConf.RekeyThreshold = 500
 	clientConf.RekeyThreshold = 500
 	trC, trS, err := handshakePair(clientConf, "addr", false)
 	trC, trS, err := handshakePair(clientConf, "addr", false)
@@ -249,14 +266,19 @@ func TestHandshakeAutoRekeyWrite(t *testing.T) {
 	defer trC.Close()
 	defer trC.Close()
 	defer trS.Close()
 	defer trS.Close()
 
 
+	input := make([]byte, 251)
+	input[0] = msgRequestSuccess
+
 	done := make(chan int, 1)
 	done := make(chan int, 1)
 	const numPacket = 5
 	const numPacket = 5
 	go func() {
 	go func() {
 		defer close(done)
 		defer close(done)
 		j := 0
 		j := 0
 		for ; j < numPacket; j++ {
 		for ; j < numPacket; j++ {
-			if _, err := trS.readPacket(); err != nil {
+			if p, err := trS.readPacket(); err != nil {
 				break
 				break
+			} else if !bytes.Equal(input, p) {
+				t.Errorf("got packet type %d, want %d", p[0], input[0])
 			}
 			}
 		}
 		}
 
 
@@ -268,9 +290,9 @@ func TestHandshakeAutoRekeyWrite(t *testing.T) {
 	<-checker.called
 	<-checker.called
 
 
 	for i := 0; i < numPacket; i++ {
 	for i := 0; i < numPacket; i++ {
-		packet := make([]byte, 251)
-		packet[0] = msgRequestSuccess
-		if err := trC.writePacket(packet); err != nil {
+		p := make([]byte, len(input))
+		copy(p, input)
+		if err := trC.writePacket(p); err != nil {
 			t.Errorf("writePacket: %v", err)
 			t.Errorf("writePacket: %v", err)
 		}
 		}
 		if i == 2 {
 		if i == 2 {
@@ -283,16 +305,23 @@ func TestHandshakeAutoRekeyWrite(t *testing.T) {
 }
 }
 
 
 type syncChecker struct {
 type syncChecker struct {
-	called chan int
+	waitCall chan int
+	called   chan int
 }
 }
 
 
 func (c *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
 func (c *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
 	c.called <- 1
 	c.called <- 1
+	if c.waitCall != nil {
+		<-c.waitCall
+	}
 	return nil
 	return nil
 }
 }
 
 
 func TestHandshakeAutoRekeyRead(t *testing.T) {
 func TestHandshakeAutoRekeyRead(t *testing.T) {
-	sync := &syncChecker{make(chan int, 2)}
+	sync := &syncChecker{
+		called:   make(chan int, 2),
+		waitCall: nil,
+	}
 	clientConf := &ClientConfig{
 	clientConf := &ClientConfig{
 		HostKeyCallback: sync.Check,
 		HostKeyCallback: sync.Check,
 	}
 	}