|
|
@@ -9,6 +9,7 @@ import (
|
|
|
"crypto/rand"
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
+ "io"
|
|
|
"net"
|
|
|
"reflect"
|
|
|
"runtime"
|
|
|
@@ -58,14 +59,46 @@ func netPipe() (net.Conn, net.Conn, error) {
|
|
|
return c1, c2, nil
|
|
|
}
|
|
|
|
|
|
-func handshakePair(clientConf *ClientConfig, addr string) (client *handshakeTransport, server *handshakeTransport, err error) {
|
|
|
+// noiseTransport inserts ignore messages to check that the read loop
|
|
|
+// and the key exchange filters out these messages.
|
|
|
+type noiseTransport struct {
|
|
|
+ keyingTransport
|
|
|
+}
|
|
|
+
|
|
|
+func (t *noiseTransport) writePacket(p []byte) error {
|
|
|
+ ignore := []byte{msgIgnore}
|
|
|
+ if err := t.keyingTransport.writePacket(ignore); err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ debug := []byte{msgDebug, 1, 2, 3}
|
|
|
+ if err := t.keyingTransport.writePacket(debug); err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ return t.keyingTransport.writePacket(p)
|
|
|
+}
|
|
|
+
|
|
|
+func addNoiseTransport(t keyingTransport) keyingTransport {
|
|
|
+ return &noiseTransport{t}
|
|
|
+}
|
|
|
+
|
|
|
+// handshakePair creates two handshakeTransports connected with each
|
|
|
+// other. If the noise argument is true, both transports will try to
|
|
|
+// confuse the other side by sending ignore and debug messages.
|
|
|
+func handshakePair(clientConf *ClientConfig, addr string, noise bool) (client *handshakeTransport, server *handshakeTransport, err error) {
|
|
|
a, b, err := netPipe()
|
|
|
if err != nil {
|
|
|
return nil, nil, err
|
|
|
}
|
|
|
|
|
|
- trC := newTransport(a, rand.Reader, true)
|
|
|
- trS := newTransport(b, rand.Reader, false)
|
|
|
+ var trC, trS keyingTransport
|
|
|
+
|
|
|
+ trC = newTransport(a, rand.Reader, true)
|
|
|
+ trS = newTransport(b, rand.Reader, false)
|
|
|
+ if noise {
|
|
|
+ trC = addNoiseTransport(trC)
|
|
|
+ trS = addNoiseTransport(trS)
|
|
|
+ }
|
|
|
clientConf.SetDefaults()
|
|
|
|
|
|
v := []byte("version")
|
|
|
@@ -85,7 +118,7 @@ func TestHandshakeBasic(t *testing.T) {
|
|
|
t.Skip("see golang.org/issue/7237")
|
|
|
}
|
|
|
checker := &testChecker{}
|
|
|
- trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")
|
|
|
+ trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", true)
|
|
|
if err != nil {
|
|
|
t.Fatalf("handshakePair: %v", err)
|
|
|
}
|
|
|
@@ -93,7 +126,9 @@ func TestHandshakeBasic(t *testing.T) {
|
|
|
defer trC.Close()
|
|
|
defer trS.Close()
|
|
|
|
|
|
+ clientDone := make(chan int, 0)
|
|
|
go func() {
|
|
|
+ defer close(clientDone)
|
|
|
// Client writes a bunch of stuff, and does a key
|
|
|
// change in the middle. This should not confuse the
|
|
|
// handshake in progress
|
|
|
@@ -115,8 +150,10 @@ func TestHandshakeBasic(t *testing.T) {
|
|
|
|
|
|
// Server checks that client messages come in cleanly
|
|
|
i := 0
|
|
|
+ err = nil
|
|
|
for {
|
|
|
- p, err := trS.readPacket()
|
|
|
+ var p []byte
|
|
|
+ p, err = trS.readPacket()
|
|
|
if err != nil {
|
|
|
break
|
|
|
}
|
|
|
@@ -129,6 +166,10 @@ func TestHandshakeBasic(t *testing.T) {
|
|
|
}
|
|
|
i++
|
|
|
}
|
|
|
+ <-clientDone
|
|
|
+ if err != nil && err != io.EOF {
|
|
|
+ t.Fatalf("server error: %v", err)
|
|
|
+ }
|
|
|
if i != 10 {
|
|
|
t.Errorf("received %d messages, want 10.", i)
|
|
|
}
|
|
|
@@ -143,11 +184,12 @@ func TestHandshakeBasic(t *testing.T) {
|
|
|
if want != checker.calls[0] {
|
|
|
t.Errorf("got %q want %q for host key check", checker.calls[0], want)
|
|
|
}
|
|
|
+
|
|
|
}
|
|
|
|
|
|
func TestHandshakeError(t *testing.T) {
|
|
|
checker := &testChecker{}
|
|
|
- trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "bad")
|
|
|
+ trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "bad", false)
|
|
|
if err != nil {
|
|
|
t.Fatalf("handshakePair: %v", err)
|
|
|
}
|
|
|
@@ -186,7 +228,7 @@ func TestHandshakeError(t *testing.T) {
|
|
|
|
|
|
func TestForceFirstKex(t *testing.T) {
|
|
|
checker := &testChecker{}
|
|
|
- trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")
|
|
|
+ trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
|
|
|
if err != nil {
|
|
|
t.Fatalf("handshakePair: %v", err)
|
|
|
}
|
|
|
@@ -208,7 +250,7 @@ func TestForceFirstKex(t *testing.T) {
|
|
|
|
|
|
func TestHandshakeTwice(t *testing.T) {
|
|
|
checker := &testChecker{}
|
|
|
- trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")
|
|
|
+ trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
|
|
|
if err != nil {
|
|
|
t.Fatalf("handshakePair: %v", err)
|
|
|
}
|
|
|
@@ -278,7 +320,7 @@ func TestHandshakeAutoRekeyWrite(t *testing.T) {
|
|
|
checker := &testChecker{}
|
|
|
clientConf := &ClientConfig{HostKeyCallback: checker.Check}
|
|
|
clientConf.RekeyThreshold = 500
|
|
|
- trC, trS, err := handshakePair(clientConf, "addr")
|
|
|
+ trC, trS, err := handshakePair(clientConf, "addr", false)
|
|
|
if err != nil {
|
|
|
t.Fatalf("handshakePair: %v", err)
|
|
|
}
|
|
|
@@ -326,7 +368,7 @@ func TestHandshakeAutoRekeyRead(t *testing.T) {
|
|
|
}
|
|
|
clientConf.RekeyThreshold = 500
|
|
|
|
|
|
- trC, trS, err := handshakePair(clientConf, "addr")
|
|
|
+ trC, trS, err := handshakePair(clientConf, "addr", false)
|
|
|
if err != nil {
|
|
|
t.Fatalf("handshakePair: %v", err)
|
|
|
}
|
|
|
@@ -448,7 +490,7 @@ func TestDisconnect(t *testing.T) {
|
|
|
t.Skip("see golang.org/issue/7237")
|
|
|
}
|
|
|
checker := &testChecker{}
|
|
|
- trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")
|
|
|
+ trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
|
|
|
if err != nil {
|
|
|
t.Fatalf("handshakePair: %v", err)
|
|
|
}
|