Browse Source

Add MaxConcurrentStreams streams tunable and enforce it.

But open question is which stream error code to use http2/http2-spec#649
Brad Fitzpatrick 11 years ago
parent
commit
6ec1731884
2 changed files with 118 additions and 33 deletions
  1. 64 33
      server.go
  2. 54 0
      server_test.go

+ 64 - 33
server.go

@@ -30,6 +30,7 @@ const (
 	prefaceTimeout        = 5 * time.Second
 	firstSettingsTimeout  = 2 * time.Second // should be in-flight with preface anyway
 	handlerChunkWriteSize = 4 << 10
+	defaultMaxStreams     = 250
 )
 
 var (
@@ -83,12 +84,24 @@ var (
 
 // Server is an HTTP/2 server.
 type Server struct {
-	// MaxStreams optionally ...
-	MaxStreams int
+	// MaxHandlers limits the number of http.Handler ServeHTTP goroutines
+	// which may run at a time over all connections.
+	// Negative or zero no limit.
+	// TODO: implement
+	MaxHandlers int
+
+	// MaxConcurrentStreams optionally specifies the number of
+	// concurrent streams that each client may have open at a
+	// time. This is unrelated to the number of http.Handler goroutines
+	// which may be active globally, which is MaxHandlers.
+	// If zero, MaxConcurrentStreams defaults to at least 100, per
+	// the HTTP/2 spec's recommendations.
+	MaxConcurrentStreams uint32
 
 	// MaxReadFrameSize optionally specifies the largest frame
 	// this server is willing to read. A valid value is between
-	// 16k and 16M, inclusive.
+	// 16k and 16M, inclusive. If zero or otherwise invalid, a
+	// default value is used.
 	MaxReadFrameSize uint32
 }
 
@@ -99,6 +112,13 @@ func (s *Server) maxReadFrameSize() uint32 {
 	return defaultMaxReadFrameSize
 }
 
+func (s *Server) maxConcurrentStreams() uint32 {
+	if v := s.MaxConcurrentStreams; v > 0 {
+		return v
+	}
+	return defaultMaxStreams
+}
+
 // ConfigureServer adds HTTP/2 support to a net/http Server.
 //
 // The configuration conf may be nil.
@@ -135,25 +155,25 @@ func ConfigureServer(s *http.Server, conf *Server) {
 
 func (srv *Server) handleConn(hs *http.Server, c net.Conn, h http.Handler) {
 	sc := &serverConn{
-		srv:                  srv,
-		hs:                   hs,
-		conn:                 c,
-		bw:                   newBufferedWriter(c),
-		handler:              h,
-		streams:              make(map[uint32]*stream),
-		readFrameCh:          make(chan frameAndGate),
-		readFrameErrCh:       make(chan error, 1), // must be buffered for 1
-		wantWriteFrameCh:     make(chan frameWriteMsg, 8),
-		writeFrameCh:         make(chan frameWriteMsg, 1), // may be 0 or 1, but more is useless. (max 1 in flight)
-		wroteFrameCh:         make(chan struct{}, 1),      // TODO: consider 0. will deadlock currently in sendFrameWrite in sentReset case
-		flow:                 newFlow(initialWindowSize),
-		doneServing:          make(chan struct{}),
-		maxWriteFrameSize:    initialMaxFrameSize,
-		initialWindowSize:    initialWindowSize,
-		headerTableSize:      initialHeaderTableSize,
-		maxConcurrentStreams: -1, // no limit
-		serveG:               newGoroutineLock(),
-		pushEnabled:          true,
+		srv:               srv,
+		hs:                hs,
+		conn:              c,
+		bw:                newBufferedWriter(c),
+		handler:           h,
+		streams:           make(map[uint32]*stream),
+		readFrameCh:       make(chan frameAndGate),
+		readFrameErrCh:    make(chan error, 1), // must be buffered for 1
+		wantWriteFrameCh:  make(chan frameWriteMsg, 8),
+		writeFrameCh:      make(chan frameWriteMsg, 1), // may be 0 or 1, but more is useless. (max 1 in flight)
+		wroteFrameCh:      make(chan struct{}, 1),      // TODO: consider 0. will deadlock currently in sendFrameWrite in sentReset case
+		flow:              newFlow(initialWindowSize),
+		doneServing:       make(chan struct{}),
+		advMaxStreams:     srv.maxConcurrentStreams(),
+		maxWriteFrameSize: initialMaxFrameSize,
+		initialWindowSize: initialWindowSize,
+		headerTableSize:   initialHeaderTableSize,
+		serveG:            newGoroutineLock(),
+		pushEnabled:       true,
 	}
 	sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf)
 	sc.hpackDecoder = hpack.NewDecoder(initialHeaderTableSize, sc.onNewHeaderField)
@@ -203,13 +223,15 @@ type serverConn struct {
 	pushEnabled           bool
 	sawFirstSettings      bool // got the initial SETTINGS frame after the preface
 	needToSendSettingsAck bool
+	clientMaxStreams      uint32 // SETTINGS_MAX_CONCURRENT_STREAMS from client (our PUSH_PROMISE limit)
+	advMaxStreams         uint32 // our SETTINGS_MAX_CONCURRENT_STREAMS advertised the client
+	curOpenStreams        uint32 // client's number of open streams
 	maxStreamID           uint32 // max ever seen
 	streams               map[uint32]*stream
 	maxWriteFrameSize     uint32
 	initialWindowSize     int32
 	headerTableSize       uint32
 	maxHeaderListSize     uint32            // zero means unknown (default)
-	maxConcurrentStreams  int64             // negative means no limit.
 	canonHeader           map[string]string // http2-lower-case -> Go-Canonical-Case
 	req                   requestParam      // non-zero while reading request headers
 	writingFrame          bool              // sent on writeFrameCh but haven't heard back on wroteFrameCh yet
@@ -488,7 +510,7 @@ func (sc *serverConn) sendInitialSettings(uint32, interface{}) error {
 	sc.writeG.check()
 	return sc.framer.WriteSettings(
 		Setting{SettingMaxFrameSize, sc.srv.maxReadFrameSize()},
-		Setting{SettingMaxConcurrentStreams, 250}, // TODO: tunable?
+		Setting{SettingMaxConcurrentStreams, sc.advMaxStreams},
 		/* TODO: more actual settings */
 	)
 }
@@ -737,18 +759,19 @@ func (sc *serverConn) writeGoAwayFrame(_ uint32, v interface{}) error {
 	return err
 }
 
-func (sc *serverConn) resetStreamInLoop(se StreamError) {
+func (sc *serverConn) resetStream(se StreamError) {
 	sc.serveG.check()
-	st, ok := sc.streams[se.StreamID]
-	if !ok {
-		panic(fmt.Sprintf("invariant. closing non-open stream %d", se.StreamID))
-	}
 	sc.writeFrame(frameWriteMsg{
 		write: (*serverConn).writeRSTStreamFrame,
 		v:     &se,
 	})
-	st.sentReset = true
-	sc.closeStream(st, se)
+	// Close the stream if it was open.
+	// It might not even be open or known (e.g. in the case of a HEADERS frame
+	// arriving and violating the max concurrent streams limit)
+	if st, ok := sc.streams[se.StreamID]; ok {
+		st.sentReset = true
+		sc.closeStream(st, se)
+	}
 }
 
 func (sc *serverConn) writeRSTStreamFrame(streamID uint32, v interface{}) error {
@@ -798,7 +821,7 @@ func (sc *serverConn) processFrameFromReader(fg frameAndGate, fgValid bool) bool
 
 	switch ev := err.(type) {
 	case StreamError:
-		sc.resetStreamInLoop(ev)
+		sc.resetStream(ev)
 		return true
 	case goAwayFlowError:
 		sc.goAway(ErrCodeFlowControl)
@@ -929,6 +952,7 @@ func (sc *serverConn) closeStream(st *stream, err error) {
 		panic("invariant")
 	}
 	st.state = stateClosed
+	sc.curOpenStreams--
 	delete(sc.streams, st.id)
 	st.flow.close()
 	if p := st.body; p != nil {
@@ -968,7 +992,7 @@ func (sc *serverConn) processSetting(s Setting) error {
 	case SettingEnablePush:
 		sc.pushEnabled = s.Val != 0
 	case SettingMaxConcurrentStreams:
-		sc.maxConcurrentStreams = int64(s.Val)
+		sc.clientMaxStreams = s.Val
 	case SettingInitialWindowSize:
 		return sc.processSettingInitialWindowSize(s.Val)
 	case SettingMaxFrameSize:
@@ -1080,6 +1104,13 @@ func (sc *serverConn) processHeaders(f *HeadersFrame) error {
 	if id > sc.maxStreamID {
 		sc.maxStreamID = id
 	}
+	if sc.curOpenStreams == sc.advMaxStreams {
+		// Too many open streams.
+		// TODO: which error code here? Using ErrCodeProtocol for now.
+		// https://github.com/http2/http2-spec/issues/649
+		return StreamError{id, ErrCodeProtocol}
+	}
+	sc.curOpenStreams++
 	st := &stream{
 		conn:  sc,
 		id:    id,

+ 54 - 0
server_test.go

@@ -1459,6 +1459,60 @@ func TestServer_HandlerWriteErrorOnDisconnect(t *testing.T) {
 	})
 }
 
+func TestServer_Rejects_Too_Many_Streams(t *testing.T) {
+	inHandler := make(chan uint32)
+	leaveHandler := make(chan bool)
+	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+		inHandler <- w.(*responseWriter).rws.stream.id
+		<-leaveHandler
+	})
+	defer st.Close()
+	st.greet()
+	nextStreamID := uint32(1)
+	streamID := func() uint32 {
+		defer func() { nextStreamID += 2 }()
+		return nextStreamID
+	}
+	sendReq := func(id uint32) {
+		st.writeHeaders(HeadersFrameParam{
+			StreamID:      id,
+			BlockFragment: encodeHeader(st.t),
+			EndStream:     true,
+			EndHeaders:    true,
+		})
+	}
+	for i := 0; i < defaultMaxStreams; i++ {
+		sendReq(streamID())
+		<-inHandler
+	}
+	defer func() {
+		for i := 0; i < defaultMaxStreams; i++ {
+			leaveHandler <- true
+		}
+	}()
+
+	// And this one should cross the limit:
+	rejectID := streamID()
+	sendReq(rejectID)
+	st.wantRSTStream(rejectID, ErrCodeProtocol)
+
+	// But let a handler finish:
+	leaveHandler <- true
+	st.wantHeaders()
+
+	// And now another stream should be able to start:
+	goodID := streamID()
+	sendReq(goodID)
+	select {
+	case got := <-inHandler:
+		if got != goodID {
+			t.Errorf("Got stream %d; want %d", got, goodID)
+		}
+	case <-time.After(3 * time.Second):
+		t.Error("timeout waiting for handler")
+	}
+}
+
 func decodeHeader(t *testing.T, headerBlock []byte) (pairs [][2]string) {
 	d := hpack.NewDecoder(initialHeaderTableSize, func(f hpack.HeaderField) {
 		pairs = append(pairs, [2]string{f.Name, f.Value})