Browse Source

Move enforcement of SETTINGS_MAX_CONCURRENT_STREAMS later.

Before it was enforced when processing the FRAME header, but as pointed
out on http2/http2-spec#649 , that means we were dropping any HPACK decoding
context in subsequent CONTINUATION frames.

Instead, move the check later to processHeaderBlockFragment (which
runs at END_HEADERS, whether in HEADERS or CONTINUATION).

And add a test using a CONTINUATION with decoder state to preserve.

This change also then is able to revert resetStream to its earlier
more paranoid behavior.
Brad Fitzpatrick 11 năm trước cách đây
mục cha
commit
b3e0a87fae
2 tập tin đã thay đổi với 46 bổ sung20 xóa
  1. 23 15
      server.go
  2. 23 5
      server_test.go

+ 23 - 15
server.go

@@ -761,17 +761,16 @@ func (sc *serverConn) writeGoAwayFrame(_ uint32, v interface{}) error {
 
 func (sc *serverConn) resetStream(se StreamError) {
 	sc.serveG.check()
+	st, ok := sc.streams[se.StreamID]
+	if !ok {
+		panic("internal package error; resetStream called on non-existent stream")
+	}
 	sc.writeFrame(frameWriteMsg{
 		write: (*serverConn).writeRSTStreamFrame,
 		v:     &se,
 	})
-	// Close the stream if it was open.
-	// It might not even be open or known (e.g. in the case of a HEADERS frame
-	// arriving and violating the max concurrent streams limit)
-	if st, ok := sc.streams[se.StreamID]; ok {
-		st.sentReset = true
-		sc.closeStream(st, se)
-	}
+	st.sentReset = true
+	sc.closeStream(st, se)
 }
 
 func (sc *serverConn) writeRSTStreamFrame(streamID uint32, v interface{}) error {
@@ -1104,13 +1103,6 @@ func (sc *serverConn) processHeaders(f *HeadersFrame) error {
 	if id > sc.maxStreamID {
 		sc.maxStreamID = id
 	}
-	if sc.curOpenStreams == sc.advMaxStreams {
-		// Too many open streams.
-		// TODO: which error code here? Using ErrCodeProtocol for now.
-		// https://github.com/http2/http2-spec/issues/649
-		return StreamError{id, ErrCodeProtocol}
-	}
-	sc.curOpenStreams++
 	st := &stream{
 		conn:  sc,
 		id:    id,
@@ -1122,6 +1114,7 @@ func (sc *serverConn) processHeaders(f *HeadersFrame) error {
 		st.state = stateHalfClosedRemote
 	}
 	sc.streams[id] = st
+	sc.curOpenStreams++
 	sc.req = requestParam{
 		stream: st,
 		header: make(http.Header),
@@ -1151,8 +1144,15 @@ func (sc *serverConn) processHeaderBlockFragment(st *stream, frag []byte, end bo
 		// TODO: convert to stream error I assume?
 		return err
 	}
+	defer sc.resetPendingRequest()
+	if sc.curOpenStreams > sc.advMaxStreams {
+		// Too many open streams.
+		// TODO: which error code here? Using ErrCodeProtocol for now.
+		// https://github.com/http2/http2-spec/issues/649
+		return StreamError{st.id, ErrCodeProtocol}
+	}
+
 	rw, req, err := sc.newWriterAndRequest()
-	sc.req = requestParam{}
 	if err != nil {
 		return err
 	}
@@ -1162,6 +1162,14 @@ func (sc *serverConn) processHeaderBlockFragment(st *stream, frag []byte, end bo
 	return nil
 }
 
+// resetPendingRequest zeros out all state related to a HEADERS frame
+// and its zero or more CONTINUATION frames sent to start a new
+// request.
+func (sc *serverConn) resetPendingRequest() {
+	sc.serveG.check()
+	sc.req = requestParam{}
+}
+
 func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, error) {
 	sc.serveG.check()
 	rp := &sc.req

+ 23 - 5
server_test.go

@@ -1460,10 +1460,16 @@ func TestServer_HandlerWriteErrorOnDisconnect(t *testing.T) {
 }
 
 func TestServer_Rejects_Too_Many_Streams(t *testing.T) {
+	const testPath = "/some/path"
+
 	inHandler := make(chan uint32)
 	leaveHandler := make(chan bool)
 	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
-		inHandler <- w.(*responseWriter).rws.stream.id
+		id := w.(*responseWriter).rws.stream.id
+		inHandler <- id
+		if id == 1+(defaultMaxStreams+1)*2 && r.URL.Path != testPath {
+			t.Errorf("decoded final path as %q; want %q", r.URL.Path, testPath)
+		}
 		<-leaveHandler
 	})
 	defer st.Close()
@@ -1473,10 +1479,10 @@ func TestServer_Rejects_Too_Many_Streams(t *testing.T) {
 		defer func() { nextStreamID += 2 }()
 		return nextStreamID
 	}
-	sendReq := func(id uint32) {
+	sendReq := func(id uint32, headers ...string) {
 		st.writeHeaders(HeadersFrameParam{
 			StreamID:      id,
-			BlockFragment: encodeHeader(st.t),
+			BlockFragment: encodeHeader(st.t, headers...),
 			EndStream:     true,
 			EndHeaders:    true,
 		})
@@ -1492,8 +1498,20 @@ func TestServer_Rejects_Too_Many_Streams(t *testing.T) {
 	}()
 
 	// And this one should cross the limit:
+	// (It's also sent as a CONTINUATION, to verify we still track the decoder context,
+	// even if we're rejecting it)
 	rejectID := streamID()
-	sendReq(rejectID)
+	headerBlock := encodeHeader(st.t, ":path", testPath)
+	frag1, frag2 := headerBlock[:3], headerBlock[3:]
+	st.writeHeaders(HeadersFrameParam{
+		StreamID:      rejectID,
+		BlockFragment: frag1,
+		EndStream:     true,
+		EndHeaders:    false, // CONTINUATION coming
+	})
+	if err := st.fr.WriteContinuation(rejectID, true, frag2); err != nil {
+		t.Fatal(err)
+	}
 	st.wantRSTStream(rejectID, ErrCodeProtocol)
 
 	// But let a handler finish:
@@ -1502,7 +1520,7 @@ func TestServer_Rejects_Too_Many_Streams(t *testing.T) {
 
 	// And now another stream should be able to start:
 	goodID := streamID()
-	sendReq(goodID)
+	sendReq(goodID, ":path", testPath)
 	select {
 	case got := <-inHandler:
 		if got != goodID {