Browse Source

http2: fix server race

This changes makes sure we never write to *writeData in the ServeHTTP
goroutine until the serve goroutine is done with it.

Also, it makes sure we don't transition the stream to the closed state
on the final DATA frame concurrently with the write.

To fix both, the writeFrameAsync goroutine no longer replies directly back
to the ServeHTTP goroutine with the write result. It's now passed to
the serve goroutine instead, which looks at the frameWriteMsg to
decide how to advance the state machine, then signals the ServeHTTP
goroutine with the result, and then advances the state machine.

Because advancing the state machine could transition it to closed,
which the ServeHTTP goroutine might also be selecting on, make the
ServeHTTP goroutine prefer its frameWriteMsg response channel for errors
over the stream closure in its select.

Various code simplifications and robustness in the process.

Tests now pass reliably even with high -count values, -race on/off,
etc. I've been unable to make h2load be unhappy now either.

Thanks to Tatsuhiro Tsujikawa (Github user @tatsuhiro-t) for the bug
report and debugging clues.

Fixes golang/go#12998

Change-Id: I441c4c9ca928eaba89fd4728d213019606edd899
Reviewed-on: https://go-review.googlesource.com/16063
Reviewed-by: Andrew Gerrand <adg@golang.org>
Brad Fitzpatrick 10 years ago
parent
commit
564010564f
3 changed files with 103 additions and 58 deletions
  1. 95 58
      http2/server.go
  2. 3 0
      http2/server_test.go
  3. 5 0
      http2/write.go

+ 95 - 58
http2/server.go

@@ -65,7 +65,7 @@ const (
 var (
 var (
 	errClientDisconnected = errors.New("client disconnected")
 	errClientDisconnected = errors.New("client disconnected")
 	errClosedBody         = errors.New("body closed by handler")
 	errClosedBody         = errors.New("body closed by handler")
-	errStreamBroken       = errors.New("http2: stream broken")
+	errStreamClosed       = errors.New("http2: stream closed")
 )
 )
 
 
 var responseWriterStatePool = sync.Pool{
 var responseWriterStatePool = sync.Pool{
@@ -207,8 +207,8 @@ func (srv *Server) handleConn(hs *http.Server, c net.Conn, h http.Handler) {
 		streams:          make(map[uint32]*stream),
 		streams:          make(map[uint32]*stream),
 		readFrameCh:      make(chan readFrameResult),
 		readFrameCh:      make(chan readFrameResult),
 		wantWriteFrameCh: make(chan frameWriteMsg, 8),
 		wantWriteFrameCh: make(chan frameWriteMsg, 8),
-		wroteFrameCh:     make(chan struct{}, 1), // buffered; one send in reading goroutine
-		bodyReadCh:       make(chan bodyReadMsg), // buffering doesn't matter either way
+		wroteFrameCh:     make(chan frameWriteResult, 1), // buffered; one send in writeFrameAsync
+		bodyReadCh:       make(chan bodyReadMsg),         // buffering doesn't matter either way
 		doneServing:      make(chan struct{}),
 		doneServing:      make(chan struct{}),
 		advMaxStreams:    srv.maxConcurrentStreams(),
 		advMaxStreams:    srv.maxConcurrentStreams(),
 		writeSched: writeScheduler{
 		writeSched: writeScheduler{
@@ -322,15 +322,15 @@ type serverConn struct {
 	handler          http.Handler
 	handler          http.Handler
 	framer           *Framer
 	framer           *Framer
 	hpackDecoder     *hpack.Decoder
 	hpackDecoder     *hpack.Decoder
-	doneServing      chan struct{}        // closed when serverConn.serve ends
-	readFrameCh      chan readFrameResult // written by serverConn.readFrames
-	wantWriteFrameCh chan frameWriteMsg   // from handlers -> serve
-	wroteFrameCh     chan struct{}        // from writeFrameAsync -> serve, tickles more frame writes
-	bodyReadCh       chan bodyReadMsg     // from handlers -> serve
-	testHookCh       chan func(int)       // code to run on the serve loop
-	flow             flow                 // conn-wide (not stream-specific) outbound flow control
-	inflow           flow                 // conn-wide inbound flow control
-	tlsState         *tls.ConnectionState // shared by all handlers, like net/http
+	doneServing      chan struct{}         // closed when serverConn.serve ends
+	readFrameCh      chan readFrameResult  // written by serverConn.readFrames
+	wantWriteFrameCh chan frameWriteMsg    // from handlers -> serve
+	wroteFrameCh     chan frameWriteResult // from writeFrameAsync -> serve, tickles more frame writes
+	bodyReadCh       chan bodyReadMsg      // from handlers -> serve
+	testHookCh       chan func(int)        // code to run on the serve loop
+	flow             flow                  // conn-wide (not stream-specific) outbound flow control
+	inflow           flow                  // conn-wide inbound flow control
+	tlsState         *tls.ConnectionState  // shared by all handlers, like net/http
 	remoteAddrStr    string
 	remoteAddrStr    string
 
 
 	// Everything following is owned by the serve loop; use serveG.check():
 	// Everything following is owned by the serve loop; use serveG.check():
@@ -584,20 +584,19 @@ func (sc *serverConn) readFrames() {
 	}
 	}
 }
 }
 
 
+// frameWriteResult is the message passed from writeFrameAsync to the serve goroutine.
+type frameWriteResult struct {
+	wm  frameWriteMsg // what was written (or attempted)
+	err error         // result of the writeFrame call
+}
+
 // writeFrameAsync runs in its own goroutine and writes a single frame
 // writeFrameAsync runs in its own goroutine and writes a single frame
 // and then reports when it's done.
 // and then reports when it's done.
 // At most one goroutine can be running writeFrameAsync at a time per
 // At most one goroutine can be running writeFrameAsync at a time per
 // serverConn.
 // serverConn.
 func (sc *serverConn) writeFrameAsync(wm frameWriteMsg) {
 func (sc *serverConn) writeFrameAsync(wm frameWriteMsg) {
 	err := wm.write.writeFrame(sc)
 	err := wm.write.writeFrame(sc)
-	if ch := wm.done; ch != nil {
-		select {
-		case ch <- err:
-		default:
-			panic(fmt.Sprintf("unbuffered done channel passed in for type %T", wm.write))
-		}
-	}
-	sc.wroteFrameCh <- struct{}{} // tickle frame selection scheduler
+	sc.wroteFrameCh <- frameWriteResult{wm, err}
 }
 }
 
 
 func (sc *serverConn) closeAllStreamsOnConnClose() {
 func (sc *serverConn) closeAllStreamsOnConnClose() {
@@ -672,12 +671,8 @@ func (sc *serverConn) serve() {
 		select {
 		select {
 		case wm := <-sc.wantWriteFrameCh:
 		case wm := <-sc.wantWriteFrameCh:
 			sc.writeFrame(wm)
 			sc.writeFrame(wm)
-		case <-sc.wroteFrameCh:
-			if sc.writingFrame != true {
-				panic("internal error: expected to be already writing a frame")
-			}
-			sc.writingFrame = false
-			sc.scheduleFrameWrite()
+		case res := <-sc.wroteFrameCh:
+			sc.wroteFrame(res)
 		case res := <-sc.readFrameCh:
 		case res := <-sc.readFrameCh:
 			if !sc.processFrameFromReader(res) {
 			if !sc.processFrameFromReader(res) {
 				return
 				return
@@ -743,20 +738,34 @@ var errChanPool = sync.Pool{
 // scheduling decisions available.
 // scheduling decisions available.
 func (sc *serverConn) writeDataFromHandler(stream *stream, writeData *writeData) error {
 func (sc *serverConn) writeDataFromHandler(stream *stream, writeData *writeData) error {
 	ch := errChanPool.Get().(chan error)
 	ch := errChanPool.Get().(chan error)
-	sc.writeFrameFromHandler(frameWriteMsg{
+	err := sc.writeFrameFromHandler(frameWriteMsg{
 		write:  writeData,
 		write:  writeData,
 		stream: stream,
 		stream: stream,
 		done:   ch,
 		done:   ch,
 	})
 	})
-	select {
-	case err := <-ch:
-		errChanPool.Put(ch)
+	if err != nil {
 		return err
 		return err
+	}
+	select {
+	case err = <-ch:
 	case <-sc.doneServing:
 	case <-sc.doneServing:
 		return errClientDisconnected
 		return errClientDisconnected
 	case <-stream.cw:
 	case <-stream.cw:
-		return errStreamBroken
+		// If both ch and stream.cw were ready (as might
+		// happen on the final Write after an http.Handler
+		// ends), prefer the write result. Otherwise this
+		// might just be us successfully closing the stream.
+		// The writeFrameAsync and serve goroutines guarantee
+		// that the ch send will happen before the stream.cw
+		// close.
+		select {
+		case err = <-ch:
+		default:
+			return errStreamClosed
+		}
 	}
 	}
+	errChanPool.Put(ch)
+	return err
 }
 }
 
 
 // writeFrameFromHandler sends wm to sc.wantWriteFrameCh, but aborts
 // writeFrameFromHandler sends wm to sc.wantWriteFrameCh, but aborts
@@ -766,24 +775,15 @@ func (sc *serverConn) writeDataFromHandler(stream *stream, writeData *writeData)
 // deadlock writing to sc.wantWriteFrameCh (which is only mildly
 // deadlock writing to sc.wantWriteFrameCh (which is only mildly
 // buffered and is read by serve itself). If you're on the serve
 // buffered and is read by serve itself). If you're on the serve
 // goroutine, call writeFrame instead.
 // goroutine, call writeFrame instead.
-func (sc *serverConn) writeFrameFromHandler(wm frameWriteMsg) {
+func (sc *serverConn) writeFrameFromHandler(wm frameWriteMsg) error {
 	sc.serveG.checkNotOn() // NOT
 	sc.serveG.checkNotOn() // NOT
-	var scheduled bool
 	select {
 	select {
 	case sc.wantWriteFrameCh <- wm:
 	case sc.wantWriteFrameCh <- wm:
-		scheduled = true
+		return nil
 	case <-sc.doneServing:
 	case <-sc.doneServing:
+		// Serve loop is gone.
 		// Client has closed their connection to the server.
 		// Client has closed their connection to the server.
-	case <-wm.stream.cw:
-		// Stream closed.
-	}
-	// Don't block writers expecting a reply.
-	if !scheduled && wm.done != nil {
-		select {
-		case wm.done <- errStreamBroken:
-		default:
-			panic("expected buffered channel")
-		}
+		return errClientDisconnected
 	}
 	}
 }
 }
 
 
@@ -809,7 +809,6 @@ func (sc *serverConn) startFrameWrite(wm frameWriteMsg) {
 	if sc.writingFrame {
 	if sc.writingFrame {
 		panic("internal error: can only be writing one frame at a time")
 		panic("internal error: can only be writing one frame at a time")
 	}
 	}
-	sc.writingFrame = true
 
 
 	st := wm.stream
 	st := wm.stream
 	if st != nil {
 	if st != nil {
@@ -818,16 +817,44 @@ func (sc *serverConn) startFrameWrite(wm frameWriteMsg) {
 			panic("internal error: attempt to send frame on half-closed-local stream")
 			panic("internal error: attempt to send frame on half-closed-local stream")
 		case stateClosed:
 		case stateClosed:
 			if st.sentReset || st.gotReset {
 			if st.sentReset || st.gotReset {
-				// Skip this frame. But fake the frame write to reschedule:
-				sc.wroteFrameCh <- struct{}{}
+				// Skip this frame.
+				sc.scheduleFrameWrite()
 				return
 				return
 			}
 			}
 			panic(fmt.Sprintf("internal error: attempt to send a write %v on a closed stream", wm))
 			panic(fmt.Sprintf("internal error: attempt to send a write %v on a closed stream", wm))
 		}
 		}
 	}
 	}
 
 
+	sc.writingFrame = true
 	sc.needsFrameFlush = true
 	sc.needsFrameFlush = true
-	if endsStream(wm.write) {
+	go sc.writeFrameAsync(wm)
+}
+
+// wroteFrame is called on the serve goroutine with the result of
+// whatever happened on writeFrameAsync.
+func (sc *serverConn) wroteFrame(res frameWriteResult) {
+	sc.serveG.check()
+	if !sc.writingFrame {
+		panic("internal error: expected to be already writing a frame")
+	}
+	sc.writingFrame = false
+
+	wm := res.wm
+	st := wm.stream
+
+	closeStream := endsStream(wm.write)
+
+	// Reply (if requested) to the blocked ServeHTTP goroutine.
+	if ch := wm.done; ch != nil {
+		select {
+		case ch <- res.err:
+		default:
+			panic(fmt.Sprintf("unbuffered done channel passed in for type %T", wm.write))
+		}
+	}
+	wm.write = nil // prevent use (assume it's tainted after wm.done send)
+
+	if closeStream {
 		if st == nil {
 		if st == nil {
 			panic("internal error: expecting non-nil stream")
 			panic("internal error: expecting non-nil stream")
 		}
 		}
@@ -848,7 +875,8 @@ func (sc *serverConn) startFrameWrite(wm frameWriteMsg) {
 			sc.closeStream(st, nil)
 			sc.closeStream(st, nil)
 		}
 		}
 	}
 	}
-	go sc.writeFrameAsync(wm)
+
+	sc.scheduleFrameWrite()
 }
 }
 
 
 // scheduleFrameWrite tickles the frame writing scheduler.
 // scheduleFrameWrite tickles the frame writing scheduler.
@@ -1509,7 +1537,7 @@ func handleHeaderListTooLong(w http.ResponseWriter, r *http.Request) {
 
 
 // called from handler goroutines.
 // called from handler goroutines.
 // h may be nil.
 // h may be nil.
-func (sc *serverConn) writeHeaders(st *stream, headerData *writeResHeaders) {
+func (sc *serverConn) writeHeaders(st *stream, headerData *writeResHeaders) error {
 	sc.serveG.checkNotOn() // NOT on
 	sc.serveG.checkNotOn() // NOT on
 	var errc chan error
 	var errc chan error
 	if headerData.h != nil {
 	if headerData.h != nil {
@@ -1519,23 +1547,25 @@ func (sc *serverConn) writeHeaders(st *stream, headerData *writeResHeaders) {
 		// mutates it.
 		// mutates it.
 		errc = errChanPool.Get().(chan error)
 		errc = errChanPool.Get().(chan error)
 	}
 	}
-	sc.writeFrameFromHandler(frameWriteMsg{
+	if err := sc.writeFrameFromHandler(frameWriteMsg{
 		write:  headerData,
 		write:  headerData,
 		stream: st,
 		stream: st,
 		done:   errc,
 		done:   errc,
-	})
+	}); err != nil {
+		return err
+	}
 	if errc != nil {
 	if errc != nil {
 		select {
 		select {
-		case <-errc:
-			// Ignore. Just for synchronization.
-			// Any error will be handled in the writing goroutine.
+		case err := <-errc:
 			errChanPool.Put(errc)
 			errChanPool.Put(errc)
+			return err
 		case <-sc.doneServing:
 		case <-sc.doneServing:
-			// Client has closed the connection.
+			return errClientDisconnected
 		case <-st.cw:
 		case <-st.cw:
-			// Client did RST_STREAM, etc. (but conn still alive)
+			return errStreamClosed
 		}
 		}
 	}
 	}
+	return nil
 }
 }
 
 
 // called from handler goroutines.
 // called from handler goroutines.
@@ -1710,7 +1740,7 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
 			ctype = http.DetectContentType(p)
 			ctype = http.DetectContentType(p)
 		}
 		}
 		endStream := rws.handlerDone && len(p) == 0
 		endStream := rws.handlerDone && len(p) == 0
-		rws.conn.writeHeaders(rws.stream, &writeResHeaders{
+		err = rws.conn.writeHeaders(rws.stream, &writeResHeaders{
 			streamID:      rws.stream.id,
 			streamID:      rws.stream.id,
 			httpResCode:   rws.status,
 			httpResCode:   rws.status,
 			h:             rws.snapHeader,
 			h:             rws.snapHeader,
@@ -1718,6 +1748,9 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
 			contentType:   ctype,
 			contentType:   ctype,
 			contentLength: clen,
 			contentLength: clen,
 		})
 		})
+		if err != nil {
+			return 0, err
+		}
 		if endStream {
 		if endStream {
 			return 0, nil
 			return 0, nil
 		}
 		}
@@ -1725,6 +1758,10 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
 	if len(p) == 0 && !rws.handlerDone {
 	if len(p) == 0 && !rws.handlerDone {
 		return 0, nil
 		return 0, nil
 	}
 	}
+
+	// Reuse curWrite (which as a pointer fits into the
+	// 'writeFramer' interface value) for each write to avoid an
+	// allocation per write.
 	curWrite := &rws.curWrite
 	curWrite := &rws.curWrite
 	curWrite.streamID = rws.stream.id
 	curWrite.streamID = rws.stream.id
 	curWrite.p = p
 	curWrite.p = p

+ 3 - 0
http2/server_test.go

@@ -2207,6 +2207,9 @@ func testServerWithCurl(t *testing.T, permitProhibitedCipherSuites bool) {
 	if runtime.GOOS != "linux" {
 	if runtime.GOOS != "linux" {
 		t.Skip("skipping Docker test when not on Linux; requires --net which won't work with boot2docker anyway")
 		t.Skip("skipping Docker test when not on Linux; requires --net which won't work with boot2docker anyway")
 	}
 	}
+	if testing.Short() {
+		t.Skip("skipping curl test in short mode")
+	}
 	requireCurl(t)
 	requireCurl(t)
 	const msg = "Hello from curl!\n"
 	const msg = "Hello from curl!\n"
 	ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 	ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

+ 5 - 0
http2/write.go

@@ -41,6 +41,11 @@ func endsStream(w writeFramer) bool {
 		return v.endStream
 		return v.endStream
 	case *writeResHeaders:
 	case *writeResHeaders:
 		return v.endStream
 		return v.endStream
+	case nil:
+		// This can only happen if the caller reuses w after it's
+		// been intentionally nil'ed out to prevent use. Keep this
+		// here to catch future refactoring breaking it.
+		panic("endsStream called on nil writeFramer")
 	}
 	}
 	return false
 	return false
 }
 }