|
|
@@ -125,7 +125,12 @@ func TestHandshakeBasic(t *testing.T) {
|
|
|
t.Skip("see golang.org/issue/7237")
|
|
|
}
|
|
|
|
|
|
- checker := &syncChecker{make(chan int, 10)}
|
|
|
+ checker := &syncChecker{
|
|
|
+ waitCall: make(chan int, 10),
|
|
|
+ called: make(chan int, 10),
|
|
|
+ }
|
|
|
+
|
|
|
+ checker.waitCall <- 1
|
|
|
trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
|
|
|
if err != nil {
|
|
|
t.Fatalf("handshakePair: %v", err)
|
|
|
@@ -134,22 +139,25 @@ func TestHandshakeBasic(t *testing.T) {
|
|
|
defer trC.Close()
|
|
|
defer trS.Close()
|
|
|
|
|
|
+ // Let first kex complete normally.
|
|
|
<-checker.called
|
|
|
|
|
|
clientDone := make(chan int, 0)
|
|
|
gotHalf := make(chan int, 0)
|
|
|
+ const N = 20
|
|
|
|
|
|
go func() {
|
|
|
defer close(clientDone)
|
|
|
// Client writes a bunch of stuff, and does a key
|
|
|
// change in the middle. This should not confuse the
|
|
|
- // handshake in progress
|
|
|
- for i := 0; i < 10; i++ {
|
|
|
+ // handshake in progress. We do this twice, so we test
|
|
|
+ // that the packet buffer is reset correctly.
|
|
|
+ for i := 0; i < N; i++ {
|
|
|
p := []byte{msgRequestSuccess, byte(i)}
|
|
|
if err := trC.writePacket(p); err != nil {
|
|
|
t.Fatalf("sendPacket: %v", err)
|
|
|
}
|
|
|
- if i == 5 {
|
|
|
+ if (i % 10) == 5 {
|
|
|
<-gotHalf
|
|
|
// halfway through, we request a key change.
|
|
|
trC.requestKeyExchange()
|
|
|
@@ -159,32 +167,38 @@ func TestHandshakeBasic(t *testing.T) {
|
|
|
// write more.
|
|
|
<-checker.called
|
|
|
}
|
|
|
+ if (i % 10) == 7 {
|
|
|
+ // write some packets until the kex
|
|
|
+ // completes, to test buffering of
|
|
|
+ // packets.
|
|
|
+ checker.waitCall <- 1
|
|
|
+ }
|
|
|
}
|
|
|
}()
|
|
|
|
|
|
// Server checks that client messages come in cleanly
|
|
|
i := 0
|
|
|
err = nil
|
|
|
- for ; i < 10; i++ {
|
|
|
+ for ; i < N; i++ {
|
|
|
var p []byte
|
|
|
p, err = trS.readPacket()
|
|
|
if err != nil {
|
|
|
break
|
|
|
}
|
|
|
- if i == 5 {
|
|
|
+ if (i % 10) == 5 {
|
|
|
gotHalf <- 1
|
|
|
}
|
|
|
|
|
|
want := []byte{msgRequestSuccess, byte(i)}
|
|
|
if bytes.Compare(p, want) != 0 {
|
|
|
- t.Errorf("message %d: got %q, want %q", i, p, want)
|
|
|
+ t.Errorf("message %d: got %v, want %v", i, p, want)
|
|
|
}
|
|
|
}
|
|
|
<-clientDone
|
|
|
if err != nil && err != io.EOF {
|
|
|
t.Fatalf("server error: %v", err)
|
|
|
}
|
|
|
- if i != 10 {
|
|
|
+ if i != N {
|
|
|
t.Errorf("received %d messages, want 10.", i)
|
|
|
}
|
|
|
|
|
|
@@ -239,7 +253,10 @@ func TestForceFirstKex(t *testing.T) {
|
|
|
}
|
|
|
|
|
|
func TestHandshakeAutoRekeyWrite(t *testing.T) {
|
|
|
- checker := &syncChecker{make(chan int, 10)}
|
|
|
+ checker := &syncChecker{
|
|
|
+ called: make(chan int, 10),
|
|
|
+ waitCall: nil,
|
|
|
+ }
|
|
|
clientConf := &ClientConfig{HostKeyCallback: checker.Check}
|
|
|
clientConf.RekeyThreshold = 500
|
|
|
trC, trS, err := handshakePair(clientConf, "addr", false)
|
|
|
@@ -249,14 +266,19 @@ func TestHandshakeAutoRekeyWrite(t *testing.T) {
|
|
|
defer trC.Close()
|
|
|
defer trS.Close()
|
|
|
|
|
|
+ input := make([]byte, 251)
|
|
|
+ input[0] = msgRequestSuccess
|
|
|
+
|
|
|
done := make(chan int, 1)
|
|
|
const numPacket = 5
|
|
|
go func() {
|
|
|
defer close(done)
|
|
|
j := 0
|
|
|
for ; j < numPacket; j++ {
|
|
|
- if _, err := trS.readPacket(); err != nil {
|
|
|
+ if p, err := trS.readPacket(); err != nil {
|
|
|
break
|
|
|
+ } else if !bytes.Equal(input, p) {
|
|
|
+ t.Errorf("got packet type %d, want %d", p[0], input[0])
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -268,9 +290,9 @@ func TestHandshakeAutoRekeyWrite(t *testing.T) {
|
|
|
<-checker.called
|
|
|
|
|
|
for i := 0; i < numPacket; i++ {
|
|
|
- packet := make([]byte, 251)
|
|
|
- packet[0] = msgRequestSuccess
|
|
|
- if err := trC.writePacket(packet); err != nil {
|
|
|
+ p := make([]byte, len(input))
|
|
|
+ copy(p, input)
|
|
|
+ if err := trC.writePacket(p); err != nil {
|
|
|
t.Errorf("writePacket: %v", err)
|
|
|
}
|
|
|
if i == 2 {
|
|
|
@@ -283,16 +305,23 @@ func TestHandshakeAutoRekeyWrite(t *testing.T) {
|
|
|
}
|
|
|
|
|
|
type syncChecker struct {
|
|
|
- called chan int
|
|
|
+ waitCall chan int
|
|
|
+ called chan int
|
|
|
}
|
|
|
|
|
|
func (c *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
|
|
|
c.called <- 1
|
|
|
+ if c.waitCall != nil {
|
|
|
+ <-c.waitCall
|
|
|
+ }
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
func TestHandshakeAutoRekeyRead(t *testing.T) {
|
|
|
- sync := &syncChecker{make(chan int, 2)}
|
|
|
+ sync := &syncChecker{
|
|
|
+ called: make(chan int, 2),
|
|
|
+ waitCall: nil,
|
|
|
+ }
|
|
|
clientConf := &ClientConfig{
|
|
|
HostKeyCallback: sync.Check,
|
|
|
}
|