|
|
@@ -7,9 +7,12 @@ package ssh
|
|
|
import (
|
|
|
"bytes"
|
|
|
"crypto/rand"
|
|
|
+ "errors"
|
|
|
"fmt"
|
|
|
"net"
|
|
|
"runtime"
|
|
|
+ "strings"
|
|
|
+ "sync"
|
|
|
"testing"
|
|
|
)
|
|
|
|
|
|
@@ -314,3 +317,99 @@ func TestHandshakeAutoRekeyRead(t *testing.T) {
|
|
|
|
|
|
<-sync.called
|
|
|
}
|
|
|
+
|
|
|
+// errorKeyingTransport generates errors after a given number of
|
|
|
+// read/write operations.
|
|
|
+type errorKeyingTransport struct {
|
|
|
+ packetConn
|
|
|
+ readLeft, writeLeft int
|
|
|
+}
|
|
|
+
|
|
|
+func (n *errorKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error {
|
|
|
+ return nil
|
|
|
+}
|
|
|
+func (n *errorKeyingTransport) getSessionID() []byte {
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+func (n *errorKeyingTransport) writePacket(packet []byte) error {
|
|
|
+ if n.writeLeft == 0 {
|
|
|
+ n.Close()
|
|
|
+ return errors.New("barf")
|
|
|
+ }
|
|
|
+
|
|
|
+ n.writeLeft--
|
|
|
+ return n.packetConn.writePacket(packet)
|
|
|
+}
|
|
|
+
|
|
|
+func (n *errorKeyingTransport) readPacket() ([]byte, error) {
|
|
|
+ if n.readLeft == 0 {
|
|
|
+ n.Close()
|
|
|
+ return nil, errors.New("barf")
|
|
|
+ }
|
|
|
+
|
|
|
+ n.readLeft--
|
|
|
+ return n.packetConn.readPacket()
|
|
|
+}
|
|
|
+
|
|
|
+func TestHandshakeErrorHandlingRead(t *testing.T) {
|
|
|
+ for i := 0; i < 20; i++ {
|
|
|
+ testHandshakeErrorHandlingN(t, i, -1)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestHandshakeErrorHandlingWrite(t *testing.T) {
|
|
|
+ for i := 0; i < 20; i++ {
|
|
|
+ testHandshakeErrorHandlingN(t, -1, i)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// testHandshakeErrorHandlingN runs handshakes, injecting errors. If
|
|
|
+// handshakeTransport deadlocks, the go runtime will detect it and
|
|
|
+// panic.
|
|
|
+func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int) {
|
|
|
+ msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)})
|
|
|
+
|
|
|
+ a, b := memPipe()
|
|
|
+ defer a.Close()
|
|
|
+ defer b.Close()
|
|
|
+
|
|
|
+ key := testSigners["ecdsa"]
|
|
|
+ serverConf := Config{RekeyThreshold: minRekeyThreshold}
|
|
|
+ serverConf.SetDefaults()
|
|
|
+ serverConn := newHandshakeTransport(&errorKeyingTransport{a, readLimit, writeLimit}, &serverConf, []byte{'a'}, []byte{'b'})
|
|
|
+ serverConn.hostKeys = []Signer{key}
|
|
|
+ go serverConn.readLoop()
|
|
|
+
|
|
|
+ clientConf := Config{RekeyThreshold: 10 * minRekeyThreshold}
|
|
|
+ clientConf.SetDefaults()
|
|
|
+ clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'})
|
|
|
+ clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()}
|
|
|
+ go clientConn.readLoop()
|
|
|
+
|
|
|
+ var wg sync.WaitGroup
|
|
|
+ wg.Add(4)
|
|
|
+
|
|
|
+ for _, hs := range []packetConn{serverConn, clientConn} {
|
|
|
+ go func(c packetConn) {
|
|
|
+ for {
|
|
|
+ err := c.writePacket(msg)
|
|
|
+ if err != nil {
|
|
|
+ break
|
|
|
+ }
|
|
|
+ }
|
|
|
+ wg.Done()
|
|
|
+ }(hs)
|
|
|
+ go func(c packetConn) {
|
|
|
+ for {
|
|
|
+ _, err := c.readPacket()
|
|
|
+ if err != nil {
|
|
|
+ break
|
|
|
+ }
|
|
|
+ }
|
|
|
+ wg.Done()
|
|
|
+ }(hs)
|
|
|
+ }
|
|
|
+
|
|
|
+ wg.Wait()
|
|
|
+}
|