Просмотр исходного кода

http2: fix server race

This changes makes sure we never write to *writeData in the ServeHTTP
goroutine until the serve goroutine is done with it.

Also, it makes sure we don't transition the stream to the closed state
on the final DATA frame concurrently with the write.

To fix both, the writeFrameAsync goroutine no longer replies directly back
to the ServeHTTP goroutine with the write result. It's now passed to
the serve goroutine instead, which looks at the frameWriteMsg to
decide how to advance the state machine, then signals the ServeHTTP
goroutine with the result, and then advances the state machine.

Because advancing the state machine could transition it to closed,
which the ServeHTTP goroutine might also be selecting on, make the
ServeHTTP goroutine prefer its frameWriteMsg response channel for errors
over the stream closure in its select.

Various code simplifications and robustness in the process.

Tests now pass reliably even with high -count values, -race on/off,
etc. I've been unable to make h2load be unhappy now either.

Thanks to Tatsuhiro Tsujikawa (Github user @tatsuhiro-t) for the bug
report and debugging clues.

Fixes golang/go#12998

Change-Id: I441c4c9ca928eaba89fd4728d213019606edd899
Reviewed-on: https://go-review.googlesource.com/16063
Reviewed-by: Andrew Gerrand <adg@golang.org>
Brad Fitzpatrick 10 лет назад
Родитель
Сommit
564010564f
3 измененных файлов с 103 добавлено и 58 удалено
  1. 95 58
      http2/server.go
  2. 3 0
      http2/server_test.go
  3. 5 0
      http2/write.go

+ 95 - 58
http2/server.go

@@ -65,7 +65,7 @@ const (
 var (
 	errClientDisconnected = errors.New("client disconnected")
 	errClosedBody         = errors.New("body closed by handler")
-	errStreamBroken       = errors.New("http2: stream broken")
+	errStreamClosed       = errors.New("http2: stream closed")
 )
 
 var responseWriterStatePool = sync.Pool{
@@ -207,8 +207,8 @@ func (srv *Server) handleConn(hs *http.Server, c net.Conn, h http.Handler) {
 		streams:          make(map[uint32]*stream),
 		readFrameCh:      make(chan readFrameResult),
 		wantWriteFrameCh: make(chan frameWriteMsg, 8),
-		wroteFrameCh:     make(chan struct{}, 1), // buffered; one send in reading goroutine
-		bodyReadCh:       make(chan bodyReadMsg), // buffering doesn't matter either way
+		wroteFrameCh:     make(chan frameWriteResult, 1), // buffered; one send in writeFrameAsync
+		bodyReadCh:       make(chan bodyReadMsg),         // buffering doesn't matter either way
 		doneServing:      make(chan struct{}),
 		advMaxStreams:    srv.maxConcurrentStreams(),
 		writeSched: writeScheduler{
@@ -322,15 +322,15 @@ type serverConn struct {
 	handler          http.Handler
 	framer           *Framer
 	hpackDecoder     *hpack.Decoder
-	doneServing      chan struct{}        // closed when serverConn.serve ends
-	readFrameCh      chan readFrameResult // written by serverConn.readFrames
-	wantWriteFrameCh chan frameWriteMsg   // from handlers -> serve
-	wroteFrameCh     chan struct{}        // from writeFrameAsync -> serve, tickles more frame writes
-	bodyReadCh       chan bodyReadMsg     // from handlers -> serve
-	testHookCh       chan func(int)       // code to run on the serve loop
-	flow             flow                 // conn-wide (not stream-specific) outbound flow control
-	inflow           flow                 // conn-wide inbound flow control
-	tlsState         *tls.ConnectionState // shared by all handlers, like net/http
+	doneServing      chan struct{}         // closed when serverConn.serve ends
+	readFrameCh      chan readFrameResult  // written by serverConn.readFrames
+	wantWriteFrameCh chan frameWriteMsg    // from handlers -> serve
+	wroteFrameCh     chan frameWriteResult // from writeFrameAsync -> serve, tickles more frame writes
+	bodyReadCh       chan bodyReadMsg      // from handlers -> serve
+	testHookCh       chan func(int)        // code to run on the serve loop
+	flow             flow                  // conn-wide (not stream-specific) outbound flow control
+	inflow           flow                  // conn-wide inbound flow control
+	tlsState         *tls.ConnectionState  // shared by all handlers, like net/http
 	remoteAddrStr    string
 
 	// Everything following is owned by the serve loop; use serveG.check():
@@ -584,20 +584,19 @@ func (sc *serverConn) readFrames() {
 	}
 }
 
+// frameWriteResult is the message passed from writeFrameAsync to the serve goroutine.
+type frameWriteResult struct {
+	wm  frameWriteMsg // what was written (or attempted)
+	err error         // result of the writeFrame call
+}
+
 // writeFrameAsync runs in its own goroutine and writes a single frame
 // and then reports when it's done.
 // At most one goroutine can be running writeFrameAsync at a time per
 // serverConn.
 func (sc *serverConn) writeFrameAsync(wm frameWriteMsg) {
 	err := wm.write.writeFrame(sc)
-	if ch := wm.done; ch != nil {
-		select {
-		case ch <- err:
-		default:
-			panic(fmt.Sprintf("unbuffered done channel passed in for type %T", wm.write))
-		}
-	}
-	sc.wroteFrameCh <- struct{}{} // tickle frame selection scheduler
+	sc.wroteFrameCh <- frameWriteResult{wm, err}
 }
 
 func (sc *serverConn) closeAllStreamsOnConnClose() {
@@ -672,12 +671,8 @@ func (sc *serverConn) serve() {
 		select {
 		case wm := <-sc.wantWriteFrameCh:
 			sc.writeFrame(wm)
-		case <-sc.wroteFrameCh:
-			if sc.writingFrame != true {
-				panic("internal error: expected to be already writing a frame")
-			}
-			sc.writingFrame = false
-			sc.scheduleFrameWrite()
+		case res := <-sc.wroteFrameCh:
+			sc.wroteFrame(res)
 		case res := <-sc.readFrameCh:
 			if !sc.processFrameFromReader(res) {
 				return
@@ -743,20 +738,34 @@ var errChanPool = sync.Pool{
 // scheduling decisions available.
 func (sc *serverConn) writeDataFromHandler(stream *stream, writeData *writeData) error {
 	ch := errChanPool.Get().(chan error)
-	sc.writeFrameFromHandler(frameWriteMsg{
+	err := sc.writeFrameFromHandler(frameWriteMsg{
 		write:  writeData,
 		stream: stream,
 		done:   ch,
 	})
-	select {
-	case err := <-ch:
-		errChanPool.Put(ch)
+	if err != nil {
 		return err
+	}
+	select {
+	case err = <-ch:
 	case <-sc.doneServing:
 		return errClientDisconnected
 	case <-stream.cw:
-		return errStreamBroken
+		// If both ch and stream.cw were ready (as might
+		// happen on the final Write after an http.Handler
+		// ends), prefer the write result. Otherwise this
+		// might just be us successfully closing the stream.
+		// The writeFrameAsync and serve goroutines guarantee
+		// that the ch send will happen before the stream.cw
+		// close.
+		select {
+		case err = <-ch:
+		default:
+			return errStreamClosed
+		}
 	}
+	errChanPool.Put(ch)
+	return err
 }
 
 // writeFrameFromHandler sends wm to sc.wantWriteFrameCh, but aborts
@@ -766,24 +775,15 @@ func (sc *serverConn) writeDataFromHandler(stream *stream, writeData *writeData)
 // deadlock writing to sc.wantWriteFrameCh (which is only mildly
 // buffered and is read by serve itself). If you're on the serve
 // goroutine, call writeFrame instead.
-func (sc *serverConn) writeFrameFromHandler(wm frameWriteMsg) {
+func (sc *serverConn) writeFrameFromHandler(wm frameWriteMsg) error {
 	sc.serveG.checkNotOn() // NOT
-	var scheduled bool
 	select {
 	case sc.wantWriteFrameCh <- wm:
-		scheduled = true
+		return nil
 	case <-sc.doneServing:
+		// Serve loop is gone.
 		// Client has closed their connection to the server.
-	case <-wm.stream.cw:
-		// Stream closed.
-	}
-	// Don't block writers expecting a reply.
-	if !scheduled && wm.done != nil {
-		select {
-		case wm.done <- errStreamBroken:
-		default:
-			panic("expected buffered channel")
-		}
+		return errClientDisconnected
 	}
 }
 
@@ -809,7 +809,6 @@ func (sc *serverConn) startFrameWrite(wm frameWriteMsg) {
 	if sc.writingFrame {
 		panic("internal error: can only be writing one frame at a time")
 	}
-	sc.writingFrame = true
 
 	st := wm.stream
 	if st != nil {
@@ -818,16 +817,44 @@ func (sc *serverConn) startFrameWrite(wm frameWriteMsg) {
 			panic("internal error: attempt to send frame on half-closed-local stream")
 		case stateClosed:
 			if st.sentReset || st.gotReset {
-				// Skip this frame. But fake the frame write to reschedule:
-				sc.wroteFrameCh <- struct{}{}
+				// Skip this frame.
+				sc.scheduleFrameWrite()
 				return
 			}
 			panic(fmt.Sprintf("internal error: attempt to send a write %v on a closed stream", wm))
 		}
 	}
 
+	sc.writingFrame = true
 	sc.needsFrameFlush = true
-	if endsStream(wm.write) {
+	go sc.writeFrameAsync(wm)
+}
+
+// wroteFrame is called on the serve goroutine with the result of
+// whatever happened on writeFrameAsync.
+func (sc *serverConn) wroteFrame(res frameWriteResult) {
+	sc.serveG.check()
+	if !sc.writingFrame {
+		panic("internal error: expected to be already writing a frame")
+	}
+	sc.writingFrame = false
+
+	wm := res.wm
+	st := wm.stream
+
+	closeStream := endsStream(wm.write)
+
+	// Reply (if requested) to the blocked ServeHTTP goroutine.
+	if ch := wm.done; ch != nil {
+		select {
+		case ch <- res.err:
+		default:
+			panic(fmt.Sprintf("unbuffered done channel passed in for type %T", wm.write))
+		}
+	}
+	wm.write = nil // prevent use (assume it's tainted after wm.done send)
+
+	if closeStream {
 		if st == nil {
 			panic("internal error: expecting non-nil stream")
 		}
@@ -848,7 +875,8 @@ func (sc *serverConn) startFrameWrite(wm frameWriteMsg) {
 			sc.closeStream(st, nil)
 		}
 	}
-	go sc.writeFrameAsync(wm)
+
+	sc.scheduleFrameWrite()
 }
 
 // scheduleFrameWrite tickles the frame writing scheduler.
@@ -1509,7 +1537,7 @@ func handleHeaderListTooLong(w http.ResponseWriter, r *http.Request) {
 
 // called from handler goroutines.
 // h may be nil.
-func (sc *serverConn) writeHeaders(st *stream, headerData *writeResHeaders) {
+func (sc *serverConn) writeHeaders(st *stream, headerData *writeResHeaders) error {
 	sc.serveG.checkNotOn() // NOT on
 	var errc chan error
 	if headerData.h != nil {
@@ -1519,23 +1547,25 @@ func (sc *serverConn) writeHeaders(st *stream, headerData *writeResHeaders) {
 		// mutates it.
 		errc = errChanPool.Get().(chan error)
 	}
-	sc.writeFrameFromHandler(frameWriteMsg{
+	if err := sc.writeFrameFromHandler(frameWriteMsg{
 		write:  headerData,
 		stream: st,
 		done:   errc,
-	})
+	}); err != nil {
+		return err
+	}
 	if errc != nil {
 		select {
-		case <-errc:
-			// Ignore. Just for synchronization.
-			// Any error will be handled in the writing goroutine.
+		case err := <-errc:
 			errChanPool.Put(errc)
+			return err
 		case <-sc.doneServing:
-			// Client has closed the connection.
+			return errClientDisconnected
 		case <-st.cw:
-			// Client did RST_STREAM, etc. (but conn still alive)
+			return errStreamClosed
 		}
 	}
+	return nil
 }
 
 // called from handler goroutines.
@@ -1710,7 +1740,7 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
 			ctype = http.DetectContentType(p)
 		}
 		endStream := rws.handlerDone && len(p) == 0
-		rws.conn.writeHeaders(rws.stream, &writeResHeaders{
+		err = rws.conn.writeHeaders(rws.stream, &writeResHeaders{
 			streamID:      rws.stream.id,
 			httpResCode:   rws.status,
 			h:             rws.snapHeader,
@@ -1718,6 +1748,9 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
 			contentType:   ctype,
 			contentLength: clen,
 		})
+		if err != nil {
+			return 0, err
+		}
 		if endStream {
 			return 0, nil
 		}
@@ -1725,6 +1758,10 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
 	if len(p) == 0 && !rws.handlerDone {
 		return 0, nil
 	}
+
+	// Reuse curWrite (which as a pointer fits into the
+	// 'writeFramer' interface value) for each write to avoid an
+	// allocation per write.
 	curWrite := &rws.curWrite
 	curWrite.streamID = rws.stream.id
 	curWrite.p = p

+ 3 - 0
http2/server_test.go

@@ -2207,6 +2207,9 @@ func testServerWithCurl(t *testing.T, permitProhibitedCipherSuites bool) {
 	if runtime.GOOS != "linux" {
 		t.Skip("skipping Docker test when not on Linux; requires --net which won't work with boot2docker anyway")
 	}
+	if testing.Short() {
+		t.Skip("skipping curl test in short mode")
+	}
 	requireCurl(t)
 	const msg = "Hello from curl!\n"
 	ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

+ 5 - 0
http2/write.go

@@ -41,6 +41,11 @@ func endsStream(w writeFramer) bool {
 		return v.endStream
 	case *writeResHeaders:
 		return v.endStream
+	case nil:
+		// This can only happen if the caller reuses w after it's
+		// been intentionally nil'ed out to prevent use. Keep this
+		// here to catch future refactoring breaking it.
+		panic("endsStream called on nil writeFramer")
 	}
 	return false
 }