浏览代码

http2: don't leave hanging server goroutines after RST_STREAM from client

In general, clean up and simplify the handling of frame writing from
handler goroutines.  Always select on streams closing, and don't try
to pass around and re-use channels. It was too confusing. Instead,
reuse channels in a very local manner that's easy to reason about.

Thanks to GitHub user @pabbott0 (who has signed the Google CLA) for
the initial bug report and test cases.

Fixes bradfitz/http2#45

Change-Id: Ib72a87cb6e33a4bb118ae23d765ba594e9182ade
Reviewed-on: https://go-review.googlesource.com/15820
Reviewed-by: Andrew Gerrand <adg@golang.org>
Brad Fitzpatrick 10 年之前
父节点
当前提交
c94bffa210
共有 2 个文件被更改,包括 109 次插入16 次删除
  1. 30 13
      http2/server.go
  2. 79 3
      http2/server_test.go

+ 30 - 13
http2/server.go

@@ -322,7 +322,7 @@ type serverConn struct {
 	wantWriteFrameCh chan frameWriteMsg   // from handlers -> serve
 	wroteFrameCh     chan struct{}        // from writeFrameAsync -> serve, tickles more frame writes
 	bodyReadCh       chan bodyReadMsg     // from handlers -> serve
-	testHookCh       chan func()          // code to run on the serve loop
+	testHookCh       chan func(int)       // code to run on the serve loop
 	flow             flow                 // conn-wide (not stream-specific) outbound flow control
 	inflow           flow                 // conn-wide inbound flow control
 	tlsState         *tls.ConnectionState // shared by all handlers, like net/http
@@ -636,7 +636,9 @@ func (sc *serverConn) serve() {
 	go sc.readFrames() // closed by defer sc.conn.Close above
 
 	settingsTimer := time.NewTimer(firstSettingsTimeout)
+	loopNum := 0
 	for {
+		loopNum++
 		select {
 		case wm := <-sc.wantWriteFrameCh:
 			sc.writeFrame(wm)
@@ -664,7 +666,7 @@ func (sc *serverConn) serve() {
 			sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr())
 			return
 		case fn := <-sc.testHookCh:
-			fn()
+			fn(loopNum)
 		}
 	}
 }
@@ -697,19 +699,20 @@ func (sc *serverConn) readPreface() error {
 	}
 }
 
+var errChanPool = sync.Pool{
+	New: func() interface{} { return make(chan error, 1) },
+}
+
 // writeDataFromHandler writes the data described in req to stream.id.
 //
-// The provided ch is used to avoid allocating new channels for each
-// write operation. It's expected that the caller reuses writeData and ch
-// over time.
-//
 // The flow control currently happens in the Handler where it waits
 // for 1 or more bytes to be available to then write here.  So at this
 // point we know that we have flow control. But this might have to
 // change when priority is implemented, so the serve goroutine knows
 // the total amount of bytes waiting to be sent and can have more
 // scheduling decisions available.
-func (sc *serverConn) writeDataFromHandler(stream *stream, writeData *writeData, ch chan error) error {
+func (sc *serverConn) writeDataFromHandler(stream *stream, writeData *writeData) error {
+	ch := errChanPool.Get().(chan error)
 	sc.writeFrameFromHandler(frameWriteMsg{
 		write:  writeData,
 		stream: stream,
@@ -717,6 +720,7 @@ func (sc *serverConn) writeDataFromHandler(stream *stream, writeData *writeData,
 	})
 	select {
 	case err := <-ch:
+		errChanPool.Put(ch)
 		return err
 	case <-sc.doneServing:
 		return errClientDisconnected
@@ -734,10 +738,22 @@ func (sc *serverConn) writeDataFromHandler(stream *stream, writeData *writeData,
 // goroutine, call writeFrame instead.
 func (sc *serverConn) writeFrameFromHandler(wm frameWriteMsg) {
 	sc.serveG.checkNotOn() // NOT
+	var scheduled bool
 	select {
 	case sc.wantWriteFrameCh <- wm:
+		scheduled = true
 	case <-sc.doneServing:
 		// Client has closed their connection to the server.
+	case <-wm.stream.cw:
+		// Stream closed.
+	}
+	// Don't block writers expecting a reply.
+	if !scheduled && wm.done != nil {
+		select {
+		case wm.done <- errStreamBroken:
+		default:
+			panic("expected buffered channel")
+		}
 	}
 }
 
@@ -1435,7 +1451,6 @@ func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, err
 	rws.stream = rp.stream
 	rws.req = req
 	rws.body = body
-	rws.frameWriteCh = make(chan error, 1)
 
 	rw := &responseWriter{rws: rws}
 	return rw, req, nil
@@ -1460,7 +1475,7 @@ func handleHeaderListTooLong(w http.ResponseWriter, r *http.Request) {
 
 // called from handler goroutines.
 // h may be nil.
-func (sc *serverConn) writeHeaders(st *stream, headerData *writeResHeaders, tempCh chan error) {
+func (sc *serverConn) writeHeaders(st *stream, headerData *writeResHeaders) {
 	sc.serveG.checkNotOn() // NOT on
 	var errc chan error
 	if headerData.h != nil {
@@ -1468,7 +1483,7 @@ func (sc *serverConn) writeHeaders(st *stream, headerData *writeResHeaders, temp
 		// waiting for this frame to be written, so an http.Flush mid-handler
 		// writes out the correct value of keys, before a handler later potentially
 		// mutates it.
-		errc = tempCh
+		errc = errChanPool.Get().(chan error)
 	}
 	sc.writeFrameFromHandler(frameWriteMsg{
 		write:  headerData,
@@ -1480,8 +1495,11 @@ func (sc *serverConn) writeHeaders(st *stream, headerData *writeResHeaders, temp
 		case <-errc:
 			// Ignore. Just for synchronization.
 			// Any error will be handled in the writing goroutine.
+			errChanPool.Put(errc)
 		case <-sc.doneServing:
 			// Client has closed the connection.
+		case <-st.cw:
+			// Client did RST_STREAM, etc. (but conn still alive)
 		}
 	}
 }
@@ -1629,7 +1647,6 @@ type responseWriterState struct {
 	sentHeader    bool        // have we sent the header frame?
 	handlerDone   bool        // handler has finished
 	curWrite      writeData
-	frameWriteCh  chan error // re-used whenever we need to block on a frame being written
 
 	closeNotifierMu sync.Mutex // guards closeNotifierCh
 	closeNotifierCh chan bool  // nil until first used
@@ -1666,7 +1683,7 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
 			endStream:     endStream,
 			contentType:   ctype,
 			contentLength: clen,
-		}, rws.frameWriteCh)
+		})
 		if endStream {
 			return 0, nil
 		}
@@ -1678,7 +1695,7 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
 	curWrite.streamID = rws.stream.id
 	curWrite.p = p
 	curWrite.endStream = rws.handlerDone
-	if err := rws.conn.writeDataFromHandler(rws.stream, curWrite, rws.frameWriteCh); err != nil {
+	if err := rws.conn.writeDataFromHandler(rws.stream, curWrite); err != nil {
 		return 0, err
 	}
 	return len(p), nil

+ 79 - 3
http2/server_test.go

@@ -125,7 +125,7 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{}
 		st.scMu.Lock()
 		defer st.scMu.Unlock()
 		st.sc = v
-		st.sc.testHookCh = make(chan func())
+		st.sc.testHookCh = make(chan func(int))
 	}
 	log.SetOutput(io.MultiWriter(stderrv(), twriter{t: t, st: st}))
 	if !onlyServer {
@@ -152,7 +152,7 @@ func (st *serverTester) addLogFilter(phrase string) {
 
 func (st *serverTester) stream(id uint32) *stream {
 	ch := make(chan *stream, 1)
-	st.sc.testHookCh <- func() {
+	st.sc.testHookCh <- func(int) {
 		ch <- st.sc.streams[id]
 	}
 	return <-ch
@@ -160,13 +160,39 @@ func (st *serverTester) stream(id uint32) *stream {
 
 func (st *serverTester) streamState(id uint32) streamState {
 	ch := make(chan streamState, 1)
-	st.sc.testHookCh <- func() {
+	st.sc.testHookCh <- func(int) {
 		state, _ := st.sc.state(id)
 		ch <- state
 	}
 	return <-ch
 }
 
+// loopNum reports how many times this conn's select loop has gone around.
+func (st *serverTester) loopNum() int {
+	lastc := make(chan int, 1)
+	st.sc.testHookCh <- func(loopNum int) {
+		lastc <- loopNum
+	}
+	return <-lastc
+}
+
+// awaitIdle heuristically waits for the server conn's select loop to be idle.
+// The heuristic is that the server connection's serve loop must schedule
+// 50 times in a row without any channel sends or receives occurring.
+func (st *serverTester) awaitIdle() {
+	remain := 50
+	last := st.loopNum()
+	for remain > 0 {
+		n := st.loopNum()
+		if n == last+1 {
+			remain--
+		} else {
+			remain = 50
+		}
+		last = n
+	}
+}
+
 func (st *serverTester) Close() {
 	st.ts.Close()
 	if st.cc != nil {
@@ -1028,6 +1054,56 @@ func TestServer_RSTStream_Unblocks_Read(t *testing.T) {
 	)
 }
 
+func TestServer_RSTStream_Unblocks_Header_Write(t *testing.T) {
+	// Run this test a bunch, because it doesn't always
+	// deadlock. But with a bunch, it did.
+	n := 50
+	if testing.Short() {
+		n = 5
+	}
+	for i := 0; i < n; i++ {
+		testServer_RSTStream_Unblocks_Header_Write(t)
+	}
+}
+
+func testServer_RSTStream_Unblocks_Header_Write(t *testing.T) {
+	inHandler := make(chan bool, 1)
+	unblockHandler := make(chan bool, 1)
+	headerWritten := make(chan bool, 1)
+	wroteRST := make(chan bool, 1)
+
+	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+		inHandler <- true
+		<-wroteRST
+		w.Header().Set("foo", "bar")
+		w.WriteHeader(200)
+		w.(http.Flusher).Flush()
+		headerWritten <- true
+		<-unblockHandler
+	})
+	defer st.Close()
+
+	st.greet()
+	st.writeHeaders(HeadersFrameParam{
+		StreamID:      1,
+		BlockFragment: st.encodeHeader(":method", "POST"),
+		EndStream:     false, // keep it open
+		EndHeaders:    true,
+	})
+	<-inHandler
+	if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil {
+		t.Fatal(err)
+	}
+	wroteRST <- true
+	st.awaitIdle()
+	select {
+	case <-headerWritten:
+	case <-time.After(2 * time.Second):
+		t.Error("timeout waiting for header write")
+	}
+	unblockHandler <- true
+}
+
 func TestServer_DeadConn_Unblocks_Read(t *testing.T) {
 	testServerPostUnblock(t,
 		func(w http.ResponseWriter, r *http.Request) (err error) {