Browse Source

Move enforcement of SETTINGS_MAX_CONCURRENT_STREAMS later.

Before it was enforced when processing the FRAME header, but as pointed
out on http2/http2-spec#649 , that means we were dropping any HPACK decoding
context in subsequent CONTINUATION frames.

Instead, move the check later to processHeaderBlockFragment (which
runs at END_HEADERS, whether in HEADERS or CONTINUATION).

And add a test using a CONTINUATION with decoder state to preserve.

This change also then is able to revert resetStream to its earlier
more paranoid behavior.
Brad Fitzpatrick 11 years ago
parent
commit
b3e0a87fae
2 changed files with 46 additions and 20 deletions
  1. 23 15
      server.go
  2. 23 5
      server_test.go

+ 23 - 15
server.go

@@ -761,17 +761,16 @@ func (sc *serverConn) writeGoAwayFrame(_ uint32, v interface{}) error {
 
 func (sc *serverConn) resetStream(se StreamError) {
 	sc.serveG.check()
+	st, ok := sc.streams[se.StreamID]
+	if !ok {
+		panic("internal package error; resetStream called on non-existent stream")
+	}
 	sc.writeFrame(frameWriteMsg{
 		write: (*serverConn).writeRSTStreamFrame,
 		v:     &se,
 	})
-	// Close the stream if it was open.
-	// It might not even be open or known (e.g. in the case of a HEADERS frame
-	// arriving and violating the max concurrent streams limit)
-	if st, ok := sc.streams[se.StreamID]; ok {
-		st.sentReset = true
-		sc.closeStream(st, se)
-	}
+	st.sentReset = true
+	sc.closeStream(st, se)
 }
 
 func (sc *serverConn) writeRSTStreamFrame(streamID uint32, v interface{}) error {
@@ -1104,13 +1103,6 @@ func (sc *serverConn) processHeaders(f *HeadersFrame) error {
 	if id > sc.maxStreamID {
 		sc.maxStreamID = id
 	}
-	if sc.curOpenStreams == sc.advMaxStreams {
-		// Too many open streams.
-		// TODO: which error code here? Using ErrCodeProtocol for now.
-		// https://github.com/http2/http2-spec/issues/649
-		return StreamError{id, ErrCodeProtocol}
-	}
-	sc.curOpenStreams++
 	st := &stream{
 		conn:  sc,
 		id:    id,
@@ -1122,6 +1114,7 @@ func (sc *serverConn) processHeaders(f *HeadersFrame) error {
 		st.state = stateHalfClosedRemote
 	}
 	sc.streams[id] = st
+	sc.curOpenStreams++
 	sc.req = requestParam{
 		stream: st,
 		header: make(http.Header),
@@ -1151,8 +1144,15 @@ func (sc *serverConn) processHeaderBlockFragment(st *stream, frag []byte, end bo
 		// TODO: convert to stream error I assume?
 		return err
 	}
+	defer sc.resetPendingRequest()
+	if sc.curOpenStreams > sc.advMaxStreams {
+		// Too many open streams.
+		// TODO: which error code here? Using ErrCodeProtocol for now.
+		// https://github.com/http2/http2-spec/issues/649
+		return StreamError{st.id, ErrCodeProtocol}
+	}
+
 	rw, req, err := sc.newWriterAndRequest()
-	sc.req = requestParam{}
 	if err != nil {
 		return err
 	}
@@ -1162,6 +1162,14 @@ func (sc *serverConn) processHeaderBlockFragment(st *stream, frag []byte, end bo
 	return nil
 }
 
+// resetPendingRequest zeros out all state related to a HEADERS frame
+// and its zero or more CONTINUATION frames sent to start a new
+// request.
+func (sc *serverConn) resetPendingRequest() {
+	sc.serveG.check()
+	sc.req = requestParam{}
+}
+
 func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, error) {
 	sc.serveG.check()
 	rp := &sc.req

+ 23 - 5
server_test.go

@@ -1460,10 +1460,16 @@ func TestServer_HandlerWriteErrorOnDisconnect(t *testing.T) {
 }
 
 func TestServer_Rejects_Too_Many_Streams(t *testing.T) {
+	const testPath = "/some/path"
+
 	inHandler := make(chan uint32)
 	leaveHandler := make(chan bool)
 	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
-		inHandler <- w.(*responseWriter).rws.stream.id
+		id := w.(*responseWriter).rws.stream.id
+		inHandler <- id
+		if id == 1+(defaultMaxStreams+1)*2 && r.URL.Path != testPath {
+			t.Errorf("decoded final path as %q; want %q", r.URL.Path, testPath)
+		}
 		<-leaveHandler
 	})
 	defer st.Close()
@@ -1473,10 +1479,10 @@ func TestServer_Rejects_Too_Many_Streams(t *testing.T) {
 		defer func() { nextStreamID += 2 }()
 		return nextStreamID
 	}
-	sendReq := func(id uint32) {
+	sendReq := func(id uint32, headers ...string) {
 		st.writeHeaders(HeadersFrameParam{
 			StreamID:      id,
-			BlockFragment: encodeHeader(st.t),
+			BlockFragment: encodeHeader(st.t, headers...),
 			EndStream:     true,
 			EndHeaders:    true,
 		})
@@ -1492,8 +1498,20 @@ func TestServer_Rejects_Too_Many_Streams(t *testing.T) {
 	}()
 
 	// And this one should cross the limit:
+	// (It's also sent as a CONTINUATION, to verify we still track the decoder context,
+	// even if we're rejecting it)
 	rejectID := streamID()
-	sendReq(rejectID)
+	headerBlock := encodeHeader(st.t, ":path", testPath)
+	frag1, frag2 := headerBlock[:3], headerBlock[3:]
+	st.writeHeaders(HeadersFrameParam{
+		StreamID:      rejectID,
+		BlockFragment: frag1,
+		EndStream:     true,
+		EndHeaders:    false, // CONTINUATION coming
+	})
+	if err := st.fr.WriteContinuation(rejectID, true, frag2); err != nil {
+		t.Fatal(err)
+	}
 	st.wantRSTStream(rejectID, ErrCodeProtocol)
 
 	// But let a handler finish:
@@ -1502,7 +1520,7 @@ func TestServer_Rejects_Too_Many_Streams(t *testing.T) {
 
 	// And now another stream should be able to start:
 	goodID := streamID()
-	sendReq(goodID)
+	sendReq(goodID, ":path", testPath)
 	select {
 	case got := <-inHandler:
 		if got != goodID {