瀏覽代碼

go.net/spdy: disallow stream id 0

Per 2.3.2 of draft-mbelshe-httpbis-spdy-00.

R=mikioh.mikioh, bradfitz
CC=adg, golang-dev
https://golang.org/cl/6781053
Jeff Hodges 13 年之前
父節點
當前提交
6fefb5ef81
共有 4 個文件被更改,包括 107 次插入1 次删除
  1. 18 0
      spdy/read.go
  2. 70 1
      spdy/spdy_test.go
  3. 1 0
      spdy/types.go
  4. 18 0
      spdy/write.go

+ 18 - 0
spdy/read.go

@@ -28,6 +28,9 @@ func (frame *RstStreamFrame) read(h ControlFrameHeader, f *Framer) error {
 	if err := binary.Read(f.r, binary.BigEndian, &frame.Status); err != nil {
 		return err
 	}
+	if frame.StreamId == 0 {
+		return &Error{ZeroStreamId, 0}
+	}
 	return nil
 }
 
@@ -61,6 +64,9 @@ func (frame *PingFrame) read(h ControlFrameHeader, f *Framer) error {
 	if err := binary.Read(f.r, binary.BigEndian, &frame.Id); err != nil {
 		return err
 	}
+	if frame.Id == 0 {
+		return &Error{ZeroStreamId, 0}
+	}
 	return nil
 }
 
@@ -222,6 +228,9 @@ func (f *Framer) readSynStreamFrame(h ControlFrameHeader, frame *SynStreamFrame)
 			}
 		}
 	}
+	if frame.StreamId == 0 {
+		return &Error{ZeroStreamId, 0}
+	}
 	return nil
 }
 
@@ -258,6 +267,9 @@ func (f *Framer) readSynReplyFrame(h ControlFrameHeader, frame *SynReplyFrame) e
 			}
 		}
 	}
+	if frame.StreamId == 0 {
+		return &Error{ZeroStreamId, 0}
+	}
 	return nil
 }
 
@@ -301,6 +313,9 @@ func (f *Framer) readHeadersFrame(h ControlFrameHeader, frame *HeadersFrame) err
 			}
 		}
 	}
+	if frame.StreamId == 0 {
+		return &Error{ZeroStreamId, 0}
+	}
 	return nil
 }
 
@@ -317,5 +332,8 @@ func (f *Framer) parseDataFrame(streamId uint32) (*DataFrame, error) {
 	if _, err := io.ReadFull(f.r, frame.Data); err != nil {
 		return nil, err
 	}
+	if frame.StreamId == 0 {
+		return nil, &Error{ZeroStreamId, 0}
+	}
 	return &frame, nil
 }

+ 70 - 1
spdy/spdy_test.go

@@ -9,6 +9,7 @@ import (
 	"compress/zlib"
 	"encoding/base64"
 	"io"
+	"io/ioutil"
 	"net/http"
 	"reflect"
 	"testing"
@@ -47,6 +48,7 @@ func TestCreateParseSynStreamFrame(t *testing.T) {
 			version:   Version,
 			frameType: TypeSynStream,
 		},
+		StreamId: 2,
 		Headers: http.Header{
 			"Url":     []string{"http://www.google.com/"},
 			"Method":  []string{"get"},
@@ -103,6 +105,7 @@ func TestCreateParseSynReplyFrame(t *testing.T) {
 			version:   Version,
 			frameType: TypeSynReply,
 		},
+		StreamId: 2,
 		Headers: http.Header{
 			"Url":     []string{"http://www.google.com/"},
 			"Method":  []string{"get"},
@@ -307,6 +310,7 @@ func TestCreateParseHeadersFrame(t *testing.T) {
 			version:   Version,
 			frameType: TypeHeaders,
 		},
+		StreamId: 2,
 	}
 	headersFrame.Headers = http.Header{
 		"Url":     []string{"http://www.google.com/"},
@@ -384,6 +388,7 @@ func TestCompressionContextAcrossFrames(t *testing.T) {
 			version:   Version,
 			frameType: TypeHeaders,
 		},
+		StreamId: 2,
 		Headers: http.Header{
 			"Url":     []string{"http://www.google.com/"},
 			"Method":  []string{"get"},
@@ -393,7 +398,7 @@ func TestCompressionContextAcrossFrames(t *testing.T) {
 	if err := framer.WriteFrame(&headersFrame); err != nil {
 		t.Fatal("WriteFrame (HEADERS):", err)
 	}
-	synStreamFrame := SynStreamFrame{ControlFrameHeader{Version, TypeSynStream, 0, 0}, 0, 0, 0, nil}
+	synStreamFrame := SynStreamFrame{ControlFrameHeader{Version, TypeSynStream, 0, 0}, 2, 0, 0, nil}
 	synStreamFrame.Headers = http.Header{
 		"Url":     []string{"http://www.google.com/"},
 		"Method":  []string{"get"},
@@ -445,6 +450,7 @@ func TestMultipleSPDYFrames(t *testing.T) {
 			version:   Version,
 			frameType: TypeHeaders,
 		},
+		StreamId: 2,
 		Headers: http.Header{
 			"Url":     []string{"http://www.google.com/"},
 			"Method":  []string{"get"},
@@ -456,6 +462,7 @@ func TestMultipleSPDYFrames(t *testing.T) {
 			version:   Version,
 			frameType: TypeSynStream,
 		},
+		StreamId: 2,
 		Headers: http.Header{
 			"Url":     []string{"http://www.google.com/"},
 			"Method":  []string{"get"},
@@ -522,3 +529,65 @@ func TestReadMalformedZlibHeader(t *testing.T) {
 		}
 	}
 }
+
+type zeroStream struct {
+	frame   Frame
+	encoded string
+}
+
+var streamIdZeroFrames = map[string]zeroStream{
+	"SynStreamFrame": {
+		&SynStreamFrame{StreamId: 0},
+		"gAIAAQAAABgAAAAAAAAAAAAAePnfolGyYmAAAAAA//8=",
+	},
+	"SynReplyFrame": {
+		&SynReplyFrame{StreamId: 0},
+		"gAIAAgAAABQAAAAAAAB4+d+iUbJiYAAAAAD//w==",
+	},
+	"RstStreamFrame": {
+		&RstStreamFrame{StreamId: 0},
+		"gAIAAwAAAAgAAAAAAAAAAA==",
+	},
+	"HeadersFrame": {
+		&HeadersFrame{StreamId: 0},
+		"gAIACAAAABQAAAAAAAB4+d+iUbJiYAAAAAD//w==",
+	},
+	"DataFrame": {
+		&DataFrame{StreamId: 0},
+		"AAAAAAAAAAA=",
+	},
+	"PingFrame": {
+		&PingFrame{Id: 0},
+		"gAIABgAAAAQAAAAA",
+	},
+}
+
+func TestNoZeroStreamId(t *testing.T) {
+	for name, f := range streamIdZeroFrames {
+		b, err := base64.StdEncoding.DecodeString(f.encoded)
+		if err != nil {
+			t.Errorf("Unable to decode base64 encoded frame %s: %v", f, err)
+			continue
+		}
+		framer, err := NewFramer(ioutil.Discard, bytes.NewReader(b))
+		if err != nil {
+			t.Fatalf("NewFramer: %v", err)
+		}
+		err = framer.WriteFrame(f.frame)
+		checkZeroStreamId(t, name, "WriteFrame", err)
+
+		_, err = framer.ReadFrame()
+		checkZeroStreamId(t, name, "ReadFrame", err)
+	}
+}
+
+func checkZeroStreamId(t *testing.T, frame string, method string, err error) {
+	if err == nil {
+		t.Errorf("%s ZeroStreamId, no error on %s", method, frame)
+		return
+	}
+	eerr, ok := err.(*Error)
+	if !ok || eerr.Err != ZeroStreamId {
+		t.Errorf("%s ZeroStreamId, incorrect error %#v, frame %s", method, eerr, frame)
+	}
+}

+ 1 - 0
spdy/types.go

@@ -315,6 +315,7 @@ const (
 	InvalidControlFrame        ErrorCode = "invalid control frame"
 	InvalidDataFrame           ErrorCode = "invalid data frame"
 	InvalidHeaderPresent       ErrorCode = "frame contained invalid header"
+	ZeroStreamId               ErrorCode = "stream id zero is disallowed"
 )
 
 // Error contains both the type of error and additional values. StreamId is 0

+ 18 - 0
spdy/write.go

@@ -20,6 +20,9 @@ func (frame *SynReplyFrame) write(f *Framer) error {
 }
 
 func (frame *RstStreamFrame) write(f *Framer) (err error) {
+	if frame.StreamId == 0 {
+		return &Error{ZeroStreamId, 0}
+	}
 	frame.CFHeader.version = Version
 	frame.CFHeader.frameType = TypeRstStream
 	frame.CFHeader.length = 8
@@ -70,6 +73,9 @@ func (frame *NoopFrame) write(f *Framer) error {
 }
 
 func (frame *PingFrame) write(f *Framer) (err error) {
+	if frame.Id == 0 {
+		return &Error{ZeroStreamId, 0}
+	}
 	frame.CFHeader.version = Version
 	frame.CFHeader.frameType = TypePing
 	frame.CFHeader.length = 4
@@ -100,10 +106,16 @@ func (frame *GoAwayFrame) write(f *Framer) (err error) {
 }
 
 func (frame *HeadersFrame) write(f *Framer) error {
+	if frame.StreamId == 0 {
+		return &Error{ZeroStreamId, 0}
+	}
 	return f.writeHeadersFrame(frame)
 }
 
 func (frame *DataFrame) write(f *Framer) error {
+	if frame.StreamId == 0 {
+		return &Error{ZeroStreamId, 0}
+	}
 	return f.writeDataFrame(frame)
 }
 
@@ -156,6 +168,9 @@ func writeHeaderValueBlock(w io.Writer, h http.Header) (n int, err error) {
 }
 
 func (f *Framer) writeSynStreamFrame(frame *SynStreamFrame) (err error) {
+	if frame.StreamId == 0 {
+		return &Error{ZeroStreamId, 0}
+	}
 	// Marshal the headers.
 	var writer io.Writer = f.headerBuf
 	if !f.headerCompressionDisabled {
@@ -194,6 +209,9 @@ func (f *Framer) writeSynStreamFrame(frame *SynStreamFrame) (err error) {
 }
 
 func (f *Framer) writeSynReplyFrame(frame *SynReplyFrame) (err error) {
+	if frame.StreamId == 0 {
+		return &Error{ZeroStreamId, 0}
+	}
 	// Marshal the headers.
 	var writer io.Writer = f.headerBuf
 	if !f.headerCompressionDisabled {