Browse Source

From handler goroutines, don't assume the serve loop goroutine is still active.

From running the production demo site (http2.golang.org), I noticed
three places where handler goroutines would block forever on a
channel operation after the client had disconnected, thus leaking
those goroutines.
Brad Fitzpatrick 11 years ago
parent
commit
9b41faf85f
2 changed files with 62 additions and 7 deletions
  1. 27 7
      server.go
  2. 35 0
      server_test.go

+ 27 - 7
server.go

@@ -42,6 +42,9 @@ const (
 // be in-flight and then the frame scheduler in the serve goroutine
 // be in-flight and then the frame scheduler in the serve goroutine
 // will be responsible for splitting things.
 // will be responsible for splitting things.
 
 
+// TODO: send PING frames to idle clients and disconnect them if no
+// reply
+
 // Server is an HTTP/2 server.
 // Server is an HTTP/2 server.
 type Server struct {
 type Server struct {
 	// MaxStreams optionally ...
 	// MaxStreams optionally ...
@@ -352,9 +355,8 @@ func (sc *serverConn) writeFrames() {
 
 
 var errClientDisconnected = errors.New("client disconnected")
 var errClientDisconnected = errors.New("client disconnected")
 
 
-func (sc *serverConn) stopServing() {
+func (sc *serverConn) closeAllStreamsOnConnClose() {
 	sc.serveG.check()
 	sc.serveG.check()
-	close(sc.writeFrameCh) // stop the writeFrames loop
 	for _, st := range sc.streams {
 	for _, st := range sc.streams {
 		sc.closeStream(st, errClientDisconnected)
 		sc.closeStream(st, errClientDisconnected)
 	}
 	}
@@ -363,7 +365,9 @@ func (sc *serverConn) stopServing() {
 func (sc *serverConn) serve() {
 func (sc *serverConn) serve() {
 	sc.serveG.check()
 	sc.serveG.check()
 	defer sc.conn.Close()
 	defer sc.conn.Close()
-	defer sc.stopServing()
+	defer sc.closeAllStreamsOnConnClose()
+	defer close(sc.doneServing)  // unblocks handlers trying to send
+	defer close(sc.writeFrameCh) // stop the writeFrames loop
 
 
 	sc.vlogf("HTTP/2 connection from %v on %p", sc.conn.RemoteAddr(), sc.hs)
 	sc.vlogf("HTTP/2 connection from %v on %p", sc.conn.RemoteAddr(), sc.hs)
 
 
@@ -444,8 +448,12 @@ func (sc *serverConn) readPreface() error {
 // should be called from non-serve() goroutines, otherwise the ends may deadlock
 // should be called from non-serve() goroutines, otherwise the ends may deadlock
 // the serve loop. (it's only buffered a little bit).
 // the serve loop. (it's only buffered a little bit).
 func (sc *serverConn) writeFrame(wm frameWriteMsg) {
 func (sc *serverConn) writeFrame(wm frameWriteMsg) {
-	sc.serveG.checkNotOn() // note the "NOT"
-	sc.wantWriteFrameCh <- wm
+	sc.serveG.checkNotOn() // NOT
+	select {
+	case sc.wantWriteFrameCh <- wm:
+	case <-sc.doneServing:
+		// Client has closed their connection to the server.
+	}
 }
 }
 
 
 func (sc *serverConn) enqueueFrameWrite(wm frameWriteMsg) {
 func (sc *serverConn) enqueueFrameWrite(wm frameWriteMsg) {
@@ -1113,7 +1121,13 @@ func (sc *serverConn) writeHeaders(req headerWriteReq) {
 		endStream: req.endStream,
 		endStream: req.endStream,
 	})
 	})
 	if errc != nil {
 	if errc != nil {
-		<-errc
+		select {
+		case <-errc:
+			// Ignore. Just for synchronization.
+			// Any error will be handled in the writing goroutine.
+		case <-sc.doneServing:
+			// Client has closed the connection.
+		}
 	}
 	}
 }
 }
 
 
@@ -1367,7 +1381,13 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
 			done:      rws.chunkWrittenCh,
 			done:      rws.chunkWrittenCh,
 			v:         rws, // writeDataInLoop uses only rws.curChunk and rws.curChunkIsFinal
 			v:         rws, // writeDataInLoop uses only rws.curChunk and rws.curChunkIsFinal
 		})
 		})
-		err = <-rws.chunkWrittenCh // block until it's written
+		// Block until it's written, or if the client disconnects.
+		select {
+		case err = <-rws.chunkWrittenCh:
+		case <-rws.stream.conn.doneServing:
+			// Client disconnected.
+			err = errClientDisconnected
+		}
 		if err != nil {
 		if err != nil {
 			break
 			break
 		}
 		}

+ 35 - 0
server_test.go

@@ -1348,6 +1348,41 @@ func TestServer_Response_Automatic100Continue(t *testing.T) {
 	})
 	})
 }
 }
 
 
+func TestServer_HandlerWriteErrorOnDisconnect(t *testing.T) {
+	errc := make(chan error, 1)
+	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
+		p := []byte("some data.\n")
+		for {
+			_, err := w.Write(p)
+			if err != nil {
+				errc <- err
+				return nil
+			}
+		}
+	}, func(st *serverTester) {
+		st.writeHeaders(HeadersFrameParam{
+			StreamID:      1,
+			BlockFragment: encodeHeader(st.t),
+			EndStream:     false,
+			EndHeaders:    true,
+		})
+		hf := st.wantHeaders()
+		if hf.StreamEnded() {
+			t.Fatal("unexpected END_STREAM flag")
+		}
+		if !hf.HeadersEnded() {
+			t.Fatal("want END_HEADERS flag")
+		}
+		// Close the connection and wait for the handler to (hopefully) notice.
+		st.cc.Close()
+		select {
+		case <-errc:
+		case <-time.After(5 * time.Second):
+			t.Error("timeout")
+		}
+	})
+}
+
 func decodeHeader(t *testing.T, headerBlock []byte) (pairs [][2]string) {
 func decodeHeader(t *testing.T, headerBlock []byte) (pairs [][2]string) {
 	d := hpack.NewDecoder(initialHeaderTableSize, func(f hpack.HeaderField) {
 	d := hpack.NewDecoder(initialHeaderTableSize, func(f hpack.HeaderField) {
 		pairs = append(pairs, [2]string{f.Name, f.Value})
 		pairs = append(pairs, [2]string{f.Name, f.Value})