Browse Source

Send RST_STREAM on stream errors. Reject capital headers with stream errors.

Brad Fitzpatrick 11 years ago
parent
commit
d43f8f3b33
4 changed files with 50 additions and 20 deletions
  1. 8 3
      errors.go
  2. 2 2
      frame.go
  3. 35 14
      http2.go
  4. 5 1
      http2_test.go

+ 8 - 3
errors.go

@@ -67,10 +67,15 @@ func (e ConnectionError) Error() string           { return fmt.Sprintf("connecti
 
 // StreamError is an error that only affects one stream within an
 // HTTP/2 connection.
-type StreamError uint32
+type StreamError struct {
+	streamID uint32
+	code     ErrCode
+}
 
-var _ Error = StreamError(0)
+var _ Error = StreamError{}
 
 func (e StreamError) IsStreamError() bool     { return true }
 func (e StreamError) IsConnectionError() bool { return false }
-func (e StreamError) Error() string           { return fmt.Sprintf("stream error: stream ID = %d", uint32(e)) }
+func (e StreamError) Error() string {
+	return fmt.Sprintf("stream error: stream ID %d; %v", e.streamID, e.code)
+}

+ 2 - 2
frame.go

@@ -630,7 +630,7 @@ func parseWindowUpdateFrame(fh FrameHeader, p []byte) (Frame, error) {
 		if fh.StreamID == 0 {
 			return nil, ConnectionError(ErrCodeProtocol)
 		}
-		return nil, StreamError(ErrCodeProtocol)
+		return nil, StreamError{fh.StreamID, ErrCodeProtocol}
 	}
 	return &WindowUpdateFrame{
 		FrameHeader: fh,
@@ -703,7 +703,7 @@ func parseHeadersFrame(fh FrameHeader, p []byte) (_ Frame, err error) {
 		}
 	}
 	if len(p)-int(padLength) <= 0 {
-		return nil, StreamError(fh.StreamID)
+		return nil, StreamError{fh.StreamID, ErrCodeProtocol}
 	}
 	hf.headerFragBuf = p[:len(p)-int(padLength)]
 	return hf, nil

+ 35 - 14
http2.go

@@ -31,6 +31,8 @@ import (
 	"github.com/bradfitz/http2/hpack"
 )
 
+var VerboseLogs = false
+
 const (
 	// ClientPreface is the string that must be sent by new
 	// connections from clients.
@@ -153,6 +155,12 @@ func (sc *serverConn) state(streamID uint32) streamState {
 	return stateIdle
 }
 
+func (sc *serverConn) vlogf(format string, args ...interface{}) {
+	if VerboseLogs {
+		sc.logf(format, args...)
+	}
+}
+
 func (sc *serverConn) logf(format string, args ...interface{}) {
 	if lg := sc.hs.ErrorLog; lg != nil {
 		lg.Printf(format, args...)
@@ -285,20 +293,21 @@ func (sc *serverConn) serve() {
 				return
 			}
 			f := fp.f
-			log.Printf("got %v: %#v", f.Header(), f)
+			sc.vlogf("got %v: %#v", f.Header(), f)
 			err := sc.processFrame(f)
 			fp.processed <- struct{}{} // let readFrames proceed
-			if h2e, ok := err.(Error); ok {
-				if h2e.IsConnectionError() {
-					sc.logf("Disconnection; connection error: %v", err)
+			switch ev := err.(type) {
+			case nil:
+				// nothing.
+			case StreamError:
+				if err := sc.resetStreamInLoop(ev); err != nil {
+					sc.logf("Error writing RSTSTream: %v", err)
 					return
 				}
-				if h2e.IsStreamError() {
-					// TODO: stream errors, etc
-					panic("TODO")
-				}
-			}
-			if err != nil {
+			case ConnectionError:
+				sc.logf("Disconnecting; %v", ev)
+				return
+			default:
 				sc.logf("Disconnection due to other error: %v", err)
 				return
 			}
@@ -306,6 +315,14 @@ func (sc *serverConn) serve() {
 	}
 }
 
+func (sc *serverConn) resetStreamInLoop(se StreamError) error {
+	if err := sc.framer.WriteRSTStream(se.streamID, uint32(se.code)); err != nil {
+		return err
+	}
+	delete(sc.streams, se.streamID)
+	return nil
+}
+
 func (sc *serverConn) processFrame(f Frame) error {
 	if s := sc.curHeaderStreamID; s != 0 {
 		if cf, ok := f.(*ContinuationFrame); !ok {
@@ -365,14 +382,18 @@ func (sc *serverConn) processHeaders(f *HeadersFrame) error {
 	sc.invalidHeader = false
 	sc.curHeaderStreamID = id
 	sc.curStream = st
-	return sc.processHeaderBlockFragment(f.HeaderBlockFragment(), f.HeadersEnded())
+	return sc.processHeaderBlockFragment(id, f.HeaderBlockFragment(), f.HeadersEnded())
 }
 
 func (sc *serverConn) processContinuation(f *ContinuationFrame) error {
-	return sc.processHeaderBlockFragment(f.HeaderBlockFragment(), f.HeadersEnded())
+	id := f.Header().StreamID
+	if sc.curHeaderStreamID != id {
+		return ConnectionError(ErrCodeProtocol)
+	}
+	return sc.processHeaderBlockFragment(id, f.HeaderBlockFragment(), f.HeadersEnded())
 }
 
-func (sc *serverConn) processHeaderBlockFragment(frag []byte, end bool) error {
+func (sc *serverConn) processHeaderBlockFragment(streamID uint32, frag []byte, end bool) error {
 	if _, err := sc.hpackDecoder.Write(frag); err != nil {
 		// TODO: convert to stream error I assume?
 		return err
@@ -390,7 +411,7 @@ func (sc *serverConn) processHeaderBlockFragment(frag []byte, end bool) error {
 		// Malformed requests or responses that are detected
 		// MUST be treated as a stream error (Section 5.4.2)
 		// of type PROTOCOL_ERROR."
-		return StreamError(ErrCodeProtocol)
+		return StreamError{streamID, ErrCodeProtocol}
 	}
 	curStream := sc.curStream
 	sc.curHeaderStreamID = 0

+ 5 - 1
http2_test.go

@@ -29,6 +29,8 @@ import (
 	"github.com/bradfitz/http2/hpack"
 )
 
+func init() { VerboseLogs = true }
+
 type serverTester struct {
 	cc     net.Conn // client conn
 	t      *testing.T
@@ -388,7 +390,6 @@ func TestServer_Request_CookieConcat(t *testing.T) {
 }
 
 func TestServer_Request_RejectCapitalHeader(t *testing.T) {
-	t.Skip("TODO: not handling stream errors properly yet in http2.go: if h2e.IsStreamError stuff")
 	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
 		t.Fatal("server request made it to handler; should've been rejected")
 	})
@@ -399,6 +400,9 @@ func TestServer_Request_RejectCapitalHeader(t *testing.T) {
 	st.wantRSTStream(1, ErrCodeProtocol)
 }
 
+// TODO: test HEADERS w/o EndHeaders + another HEADERS (should get rejected)
+// TODO: test HEADERS w/ EndHeaders + a continuation HEADERS (should get rejected)
+
 // testServerRequest sets up an idle HTTP/2 connection and lets you
 // write a single request with writeReq, and then verify that the
 // *http.Request is built correctly in checkReq.