Ver Fonte

x/crypto/ssh: interpret disconnect message as error in the transport layer.

This ensures that higher level parts (e.g. the client authentication
loop) never have to deal with disconnect messages.

Fixes https://github.com/coreos/fleet/issues/565.

Change-Id: Ie164b6c4b0982c7ed9af6d3bf91697a78a911a20
Reviewed-on: https://go-review.googlesource.com/20801
Reviewed-by: Anton Khramov <anton@endocode.com>
Reviewed-by: Adam Langley <agl@golang.org>
Han-Wen Nienhuys há 9 anos atrás
pai
commit
9e7f5dc375
6 ficheiros alterados com 64 adições e 57 exclusões
  1. 0 2
      ssh/client_auth.go
  2. 43 0
      ssh/handshake_test.go
  3. 1 1
      ssh/messages.go
  4. 0 26
      ssh/mux.go
  5. 0 23
      ssh/mux_test.go
  6. 20 5
      ssh/transport.go

+ 0 - 2
ssh/client_auth.go

@@ -321,8 +321,6 @@ func handleAuthResponse(c packetConn) (bool, []string, error) {
 			return false, msg.Methods, nil
 			return false, msg.Methods, nil
 		case msgUserAuthSuccess:
 		case msgUserAuthSuccess:
 			return true, nil, nil
 			return true, nil, nil
-		case msgDisconnect:
-			return false, nil, io.EOF
 		default:
 		default:
 			return false, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0])
 			return false, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0])
 		}
 		}

+ 43 - 0
ssh/handshake_test.go

@@ -10,6 +10,7 @@ import (
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"net"
 	"net"
+	"reflect"
 	"runtime"
 	"runtime"
 	"strings"
 	"strings"
 	"sync"
 	"sync"
@@ -413,3 +414,45 @@ func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int) {
 
 
 	wg.Wait()
 	wg.Wait()
 }
 }
+
+func TestDisconnect(t *testing.T) {
+	if runtime.GOOS == "plan9" {
+		t.Skip("see golang.org/issue/7237")
+	}
+	checker := &testChecker{}
+	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")
+	if err != nil {
+		t.Fatalf("handshakePair: %v", err)
+	}
+
+	defer trC.Close()
+	defer trS.Close()
+
+	trC.writePacket([]byte{msgRequestSuccess, 0, 0})
+	errMsg := &disconnectMsg{
+		Reason: 42,
+		Message: "such is life",
+	}
+	trC.writePacket(Marshal(errMsg))
+	trC.writePacket([]byte{msgRequestSuccess, 0, 0})
+
+	packet, err := trS.readPacket()
+	if err != nil {
+		t.Fatalf("readPacket 1: %v", err)
+	}
+	if packet[0] != msgRequestSuccess {
+		t.Errorf("got packet %v, want packet type %d", packet,  msgRequestSuccess)
+	}
+
+	_, err = trS.readPacket()
+	if err == nil {
+		t.Errorf("readPacket 2 succeeded")
+	} else if !reflect.DeepEqual(err, errMsg) {
+		t.Errorf("got error %#v, want %#v", err, errMsg)
+	}
+
+	_, err = trS.readPacket()
+	if err == nil {
+		t.Errorf("readPacket 3 succeeded")
+	}
+}

+ 1 - 1
ssh/messages.go

@@ -47,7 +47,7 @@ type disconnectMsg struct {
 }
 }
 
 
 func (d *disconnectMsg) Error() string {
 func (d *disconnectMsg) Error() string {
-	return fmt.Sprintf("ssh: disconnect reason %d: %s", d.Reason, d.Message)
+	return fmt.Sprintf("ssh: disconnect, reason %d: %s", d.Reason, d.Message)
 }
 }
 
 
 // See RFC 4253, section 7.1.
 // See RFC 4253, section 7.1.

+ 0 - 26
ssh/mux.go

@@ -175,18 +175,6 @@ func (m *mux) ackRequest(ok bool, data []byte) error {
 	return m.sendMessage(globalRequestFailureMsg{Data: data})
 	return m.sendMessage(globalRequestFailureMsg{Data: data})
 }
 }
 
 
-// TODO(hanwen): Disconnect is a transport layer message. We should
-// probably send and receive Disconnect somewhere in the transport
-// code.
-
-// Disconnect sends a disconnect message.
-func (m *mux) Disconnect(reason uint32, message string) error {
-	return m.sendMessage(disconnectMsg{
-		Reason:  reason,
-		Message: message,
-	})
-}
-
 func (m *mux) Close() error {
 func (m *mux) Close() error {
 	return m.conn.Close()
 	return m.conn.Close()
 }
 }
@@ -239,8 +227,6 @@ func (m *mux) onePacket() error {
 	case msgNewKeys:
 	case msgNewKeys:
 		// Ignore notification of key change.
 		// Ignore notification of key change.
 		return nil
 		return nil
-	case msgDisconnect:
-		return m.handleDisconnect(packet)
 	case msgChannelOpen:
 	case msgChannelOpen:
 		return m.handleChannelOpen(packet)
 		return m.handleChannelOpen(packet)
 	case msgGlobalRequest, msgRequestSuccess, msgRequestFailure:
 	case msgGlobalRequest, msgRequestSuccess, msgRequestFailure:
@@ -260,18 +246,6 @@ func (m *mux) onePacket() error {
 	return ch.handlePacket(packet)
 	return ch.handlePacket(packet)
 }
 }
 
 
-func (m *mux) handleDisconnect(packet []byte) error {
-	var d disconnectMsg
-	if err := Unmarshal(packet, &d); err != nil {
-		return err
-	}
-
-	if debugMux {
-		log.Printf("caught disconnect: %v", d)
-	}
-	return &d
-}
-
 func (m *mux) handleGlobalPacket(packet []byte) error {
 func (m *mux) handleGlobalPacket(packet []byte) error {
 	msg, err := decode(packet)
 	msg, err := decode(packet)
 	if err != nil {
 	if err != nil {

+ 0 - 23
ssh/mux_test.go

@@ -331,7 +331,6 @@ func TestMuxGlobalRequest(t *testing.T) {
 			ok, data, err)
 			ok, data, err)
 	}
 	}
 
 
-	clientMux.Disconnect(0, "")
 	if !seen {
 	if !seen {
 		t.Errorf("never saw 'peek' request")
 		t.Errorf("never saw 'peek' request")
 	}
 	}
@@ -378,28 +377,6 @@ func TestMuxChannelRequestUnblock(t *testing.T) {
 	}
 	}
 }
 }
 
 
-func TestMuxDisconnect(t *testing.T) {
-	a, b := muxPair()
-	defer a.Close()
-	defer b.Close()
-
-	go func() {
-		for r := range b.incomingRequests {
-			r.Reply(true, nil)
-		}
-	}()
-
-	a.Disconnect(42, "whatever")
-	ok, _, err := a.SendRequest("hello", true, nil)
-	if ok || err == nil {
-		t.Errorf("got reply after disconnecting")
-	}
-	err = b.Wait()
-	if d, ok := err.(*disconnectMsg); !ok || d.Reason != 42 {
-		t.Errorf("got %#v, want disconnectMsg{Reason:42}", err)
-	}
-}
-
 func TestMuxCloseChannel(t *testing.T) {
 func TestMuxCloseChannel(t *testing.T) {
 	r, w, mux := channelPair(t)
 	r, w, mux := channelPair(t)
 	defer mux.Close()
 	defer mux.Close()

+ 20 - 5
ssh/transport.go

@@ -114,12 +114,27 @@ func (s *connectionState) readPacket(r *bufio.Reader) ([]byte, error) {
 		err = errors.New("ssh: zero length packet")
 		err = errors.New("ssh: zero length packet")
 	}
 	}
 
 
-	if len(packet) > 0 && packet[0] == msgNewKeys {
-		select {
-		case cipher := <-s.pendingKeyChange:
+	if len(packet) > 0 {
+		switch packet[0] {
+		case msgNewKeys:
+			select {
+			case cipher := <-s.pendingKeyChange:
 			s.packetCipher = cipher
 			s.packetCipher = cipher
-		default:
-			return nil, errors.New("ssh: got bogus newkeys message.")
+			default:
+				return nil, errors.New("ssh: got bogus newkeys message.")
+			}
+
+		case msgDisconnect:
+			// Transform a disconnect message into an
+			// error. Since this is lowest level at which
+			// we interpret message types, doing it here
+			// ensures that we don't have to handle it
+			// elsewhere.
+			var msg disconnectMsg
+			if err := Unmarshal(packet, &msg); err != nil {
+				return nil, err
+			}
+			return nil, &msg
 		}
 		}
 	}
 	}