浏览代码

crypto/ssh: fix deadlock during error condition.

Unblock writers if a read error occurs while writers are blocked on a
pending key change.

Add test to check for deadlocks in error paths in handshake.go

Fixes golang/go#11992.

Change-Id: Id098bd9fec3d4fe83daeb2b7f935e5647c19afd3
Reviewed-on: https://go-review.googlesource.com/13594
Reviewed-by: Adam Langley <agl@golang.org>
Han-Wen Nienhuys 10 年之前
父节点
当前提交
c169681727
共有 2 个文件被更改,包括 108 次插入1 次删除
  1. 9 1
      ssh/handshake.go
  2. 99 0
      ssh/handshake_test.go

+ 9 - 1
ssh/handshake.go

@@ -153,6 +153,14 @@ func (t *handshakeTransport) readLoop() {
 		}
 		t.incoming <- p
 	}
+
+	// If we can't read, declare the writing part dead too.
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	if t.writeError == nil {
+		t.writeError = t.readError
+	}
+	t.cond.Broadcast()
 }
 
 func (t *handshakeTransport) readOnePacket() ([]byte, error) {
@@ -270,7 +278,7 @@ func (t *handshakeTransport) writePacket(p []byte) error {
 	if t.writtenSinceKex > t.config.RekeyThreshold {
 		t.sendKexInitLocked()
 	}
-	for t.sentInitMsg != nil {
+	for t.sentInitMsg != nil && t.writeError == nil {
 		t.cond.Wait()
 	}
 	if t.writeError != nil {

+ 99 - 0
ssh/handshake_test.go

@@ -7,9 +7,12 @@ package ssh
 import (
 	"bytes"
 	"crypto/rand"
+	"errors"
 	"fmt"
 	"net"
 	"runtime"
+	"strings"
+	"sync"
 	"testing"
 )
 
@@ -314,3 +317,99 @@ func TestHandshakeAutoRekeyRead(t *testing.T) {
 
 	<-sync.called
 }
+
+// errorKeyingTransport generates errors after a given number of
+// read/write operations.
+type errorKeyingTransport struct {
+	packetConn
+	readLeft, writeLeft int
+}
+
+func (n *errorKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error {
+	return nil
+}
+func (n *errorKeyingTransport) getSessionID() []byte {
+	return nil
+}
+
+func (n *errorKeyingTransport) writePacket(packet []byte) error {
+	if n.writeLeft == 0 {
+		n.Close()
+		return errors.New("barf")
+	}
+
+	n.writeLeft--
+	return n.packetConn.writePacket(packet)
+}
+
+func (n *errorKeyingTransport) readPacket() ([]byte, error) {
+	if n.readLeft == 0 {
+		n.Close()
+		return nil, errors.New("barf")
+	}
+
+	n.readLeft--
+	return n.packetConn.readPacket()
+}
+
+func TestHandshakeErrorHandlingRead(t *testing.T) {
+	for i := 0; i < 20; i++ {
+		testHandshakeErrorHandlingN(t, i, -1)
+	}
+}
+
+func TestHandshakeErrorHandlingWrite(t *testing.T) {
+	for i := 0; i < 20; i++ {
+		testHandshakeErrorHandlingN(t, -1, i)
+	}
+}
+
+// testHandshakeErrorHandlingN runs handshakes, injecting errors. If
+// handshakeTransport deadlocks, the go runtime will detect it and
+// panic.
+func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int) {
+	msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)})
+
+	a, b := memPipe()
+	defer a.Close()
+	defer b.Close()
+
+	key := testSigners["ecdsa"]
+	serverConf := Config{RekeyThreshold: minRekeyThreshold}
+	serverConf.SetDefaults()
+	serverConn := newHandshakeTransport(&errorKeyingTransport{a, readLimit, writeLimit}, &serverConf, []byte{'a'}, []byte{'b'})
+	serverConn.hostKeys = []Signer{key}
+	go serverConn.readLoop()
+
+	clientConf := Config{RekeyThreshold: 10 * minRekeyThreshold}
+	clientConf.SetDefaults()
+	clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'})
+	clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()}
+	go clientConn.readLoop()
+
+	var wg sync.WaitGroup
+	wg.Add(4)
+
+	for _, hs := range []packetConn{serverConn, clientConn} {
+		go func(c packetConn) {
+			for {
+				err := c.writePacket(msg)
+				if err != nil {
+					break
+				}
+			}
+			wg.Done()
+		}(hs)
+		go func(c packetConn) {
+			for {
+				_, err := c.readPacket()
+				if err != nil {
+					break
+				}
+			}
+			wg.Done()
+		}(hs)
+	}
+
+	wg.Wait()
+}