浏览代码

go.crypto/ssh: add a error return to decode(), and avoid casting decode() output.

R=dave, kardianos, agl
CC=gobot, golang-dev
https://golang.org/cl/9738053
Han-Wen Nienhuys 12 年之前
父节点
当前提交
36bf31eb71
共有 4 个文件被更改,包括 27 次插入13 次删除
  1. 9 2
      ssh/client.go
  2. 8 2
      ssh/client_auth.go
  3. 5 5
      ssh/messages.go
  4. 5 4
      ssh/server.go

+ 9 - 2
ssh/client.go

@@ -249,8 +249,15 @@ func (c *ClientConn) mainLoop() {
 				ch.stderr.write(packet)
 			}
 		default:
-			msg := decode(packet)
-			switch msg := msg.(type) {
+			decoded, err := decode(packet)
+			if err != nil {
+				if _, ok := err.(UnexpectedMessageError); ok {
+					fmt.Printf("mainLoop: unexpected message: %v\n", err)
+					continue
+				}
+				return
+			}
+			switch msg := decoded.(type) {
 			case *channelOpenMsg:
 				c.handleChanOpen(msg)
 			case *channelOpenConfirmMsg:

+ 8 - 2
ssh/client_auth.go

@@ -272,7 +272,10 @@ func (p *publickeyAuth) confirmKeyAck(key interface{}, t *transport) (bool, erro
 		case msgUserAuthBanner:
 			// TODO(gpaul): add callback to present the banner to the user
 		case msgUserAuthPubKeyOk:
-			msg := decode(packet).(*userAuthPubKeyOkMsg)
+			msg := userAuthPubKeyOkMsg{}
+			if err := unmarshal(&msg, packet, msgUserAuthPubKeyOk); err != nil {
+				return false, err
+			}
 			if msg.Algo != algoname || msg.PubKey != string(pubkey) {
 				return false, nil
 			}
@@ -309,7 +312,10 @@ func handleAuthResponse(t *transport) (bool, []string, error) {
 		case msgUserAuthBanner:
 			// TODO: add callback to present the banner to the user
 		case msgUserAuthFailure:
-			msg := decode(packet).(*userAuthFailureMsg)
+			msg := userAuthFailureMsg{}
+			if err := unmarshal(&msg, packet, msgUserAuthFailure); err != nil {
+				return false, nil, err
+			}
 			return false, msg.Methods, nil
 		case msgUserAuthSuccess:
 			return true, nil, nil

+ 5 - 5
ssh/messages.go

@@ -568,8 +568,8 @@ func marshalString(to []byte, s []byte) []byte {
 
 var bigIntType = reflect.TypeOf((*big.Int)(nil))
 
-// Decode a packet into it's corresponding message.
-func decode(packet []byte) interface{} {
+// Decode a packet into its corresponding message.
+func decode(packet []byte) (interface{}, error) {
 	var msg interface{}
 	switch packet[0] {
 	case msgDisconnect:
@@ -615,10 +615,10 @@ func decode(packet []byte) interface{} {
 	case msgChannelFailure:
 		msg = new(channelRequestFailureMsg)
 	default:
-		return UnexpectedMessageError{0, packet[0]}
+		return nil, UnexpectedMessageError{0, packet[0]}
 	}
 	if err := unmarshal(msg, packet, packet[0]); err != nil {
-		return err
+		return nil, err
 	}
-	return msg
+	return msg, nil
 }

+ 5 - 4
ssh/server.go

@@ -562,7 +562,11 @@ func (s *ServerConn) Accept() (Channel, error) {
 			}
 			s.lock.Unlock()
 		default:
-			switch msg := decode(packet).(type) {
+			decoded, err := decode(packet)
+			if err != nil {
+				return nil, err
+			}
+			switch msg := decoded.(type) {
 			case *channelOpenMsg:
 				if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
 					return nil, errors.New("ssh: invalid MaxPacketSize from peer")
@@ -643,9 +647,6 @@ func (s *ServerConn) Accept() (Channel, error) {
 					return nil, err
 				}
 				s.lock.Unlock()
-
-			case UnexpectedMessageError:
-				return nil, msg
 			case *disconnectMsg:
 				return nil, io.EOF
 			default: