|
|
@@ -10,6 +10,7 @@ import (
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
"net"
|
|
|
+ "reflect"
|
|
|
"runtime"
|
|
|
"strings"
|
|
|
"sync"
|
|
|
@@ -413,3 +414,45 @@ func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int) {
|
|
|
|
|
|
wg.Wait()
|
|
|
}
|
|
|
+
|
|
|
+func TestDisconnect(t *testing.T) {
|
|
|
+ if runtime.GOOS == "plan9" {
|
|
|
+ t.Skip("see golang.org/issue/7237")
|
|
|
+ }
|
|
|
+ checker := &testChecker{}
|
|
|
+ trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("handshakePair: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ defer trC.Close()
|
|
|
+ defer trS.Close()
|
|
|
+
|
|
|
+ trC.writePacket([]byte{msgRequestSuccess, 0, 0})
|
|
|
+ errMsg := &disconnectMsg{
|
|
|
+ Reason: 42,
|
|
|
+ Message: "such is life",
|
|
|
+ }
|
|
|
+ trC.writePacket(Marshal(errMsg))
|
|
|
+ trC.writePacket([]byte{msgRequestSuccess, 0, 0})
|
|
|
+
|
|
|
+ packet, err := trS.readPacket()
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("readPacket 1: %v", err)
|
|
|
+ }
|
|
|
+ if packet[0] != msgRequestSuccess {
|
|
|
+ t.Errorf("got packet %v, want packet type %d", packet, msgRequestSuccess)
|
|
|
+ }
|
|
|
+
|
|
|
+ _, err = trS.readPacket()
|
|
|
+ if err == nil {
|
|
|
+ t.Errorf("readPacket 2 succeeded")
|
|
|
+ } else if !reflect.DeepEqual(err, errMsg) {
|
|
|
+ t.Errorf("got error %#v, want %#v", err, errMsg)
|
|
|
+ }
|
|
|
+
|
|
|
+ _, err = trS.readPacket()
|
|
|
+ if err == nil {
|
|
|
+ t.Errorf("readPacket 3 succeeded")
|
|
|
+ }
|
|
|
+}
|