Bläddra i källkod

go.crypto/ssh: (un)marshal data without type byte prefix.

This helps manipulating data in global and channel request
payloads.

R=agl, dave, jpsugar
CC=golang-dev
https://golang.org/cl/14438068
Han-Wen Nienhuys 12 år sedan
förälder
incheckning
49702c17cc
2 ändrade filer med 47 tillägg och 9 borttagningar
  1. 15 9
      ssh/messages.go
  2. 32 0
      ssh/messages_test.go

+ 15 - 9
ssh/messages.go

@@ -227,17 +227,20 @@ type userAuthPubKeyOkMsg struct {
 	PubKey string
 	PubKey string
 }
 }
 
 
-// unmarshal parses the SSH wire data in packet into out using reflection.
-// expectedType is the expected SSH message type. It either returns nil on
-// success, or a ParseError or UnexpectedMessageError on error.
+// unmarshal parses the SSH wire data in packet into out using
+// reflection. expectedType, if non-zero, is the SSH message type that
+// the packet is expected to start with.  unmarshal either returns nil
+// on success, or a ParseError or UnexpectedMessageError on error.
 func unmarshal(out interface{}, packet []byte, expectedType uint8) error {
 func unmarshal(out interface{}, packet []byte, expectedType uint8) error {
 	if len(packet) == 0 {
 	if len(packet) == 0 {
 		return ParseError{expectedType}
 		return ParseError{expectedType}
 	}
 	}
-	if packet[0] != expectedType {
-		return UnexpectedMessageError{expectedType, packet[0]}
+	if expectedType > 0 {
+		if packet[0] != expectedType {
+			return UnexpectedMessageError{expectedType, packet[0]}
+		}
+		packet = packet[1:]
 	}
 	}
-	packet = packet[1:]
 
 
 	v := reflect.ValueOf(out).Elem()
 	v := reflect.ValueOf(out).Elem()
 	structType := v.Type()
 	structType := v.Type()
@@ -319,10 +322,13 @@ func unmarshal(out interface{}, packet []byte, expectedType uint8) error {
 	return nil
 	return nil
 }
 }
 
 
-// marshal serializes the message in msg, using the given message type.
+// marshal serializes the message in msg. The given message type is
+// prepended if it is non-zero.
 func marshal(msgType uint8, msg interface{}) []byte {
 func marshal(msgType uint8, msg interface{}) []byte {
-	out := make([]byte, 1, 64)
-	out[0] = msgType
+	out := make([]byte, 0, 64)
+	if msgType > 0 {
+		out = append(out, msgType)
+	}
 
 
 	v := reflect.ValueOf(msg)
 	v := reflect.ValueOf(msg)
 	for i, n := 0, v.NumField(); i < n; i++ {
 	for i, n := 0, v.NumField(); i < n; i++ {

+ 32 - 0
ssh/messages_test.go

@@ -78,6 +78,38 @@ func TestMarshalUnmarshal(t *testing.T) {
 	}
 	}
 }
 }
 
 
+func TestBareMarshalUnmarshal(t *testing.T) {
+	type S struct {
+		I uint32
+		S string
+		B bool
+	}
+
+	s := S{42, "hello", true}
+	packet := marshal(0, s)
+	roundtrip := S{}
+	unmarshal(&roundtrip, packet, 0)
+
+	if !reflect.DeepEqual(s, roundtrip) {
+		t.Errorf("got %#v, want %#v", roundtrip, s)
+	}
+}
+
+func TestBareMarshal(t *testing.T) {
+	type S2 struct {
+		I uint32
+	}
+	s := S2{42}
+	packet := marshal(0, s)
+	i, rest, ok := parseUint32(packet)
+	if len(rest) > 0 || !ok {
+		t.Errorf("parseInt(%q): parse error", packet)
+	}
+	if i != s.I {
+		t.Errorf("got %d, want %d", i, s.I)
+	}
+}
+
 func randomBytes(out []byte, rand *rand.Rand) {
 func randomBytes(out []byte, rand *rand.Rand) {
 	for i := 0; i < len(out); i++ {
 	for i := 0; i < len(out); i++ {
 		out[i] = byte(rand.Int31())
 		out[i] = byte(rand.Int31())