Browse Source

Interrupt Request.Body.Read on RSTStream or connection close.

So we don't leak goroutines blocked forever in Read.
Brad Fitzpatrick 11 years ago
parent
commit
6d3aa4f311
5 changed files with 138 additions and 16 deletions
  1. 15 2
      flow.go
  2. 21 0
      flow_test.go
  3. 2 2
      frame.go
  4. 35 11
      server.go
  5. 65 1
      server_test.go

+ 15 - 2
flow.go

@@ -11,8 +11,9 @@ import "sync"
 
 // flow is the flow control window's counting semaphore.
 type flow struct {
-	c    *sync.Cond // protects size
-	size int32
+	c      *sync.Cond // protects size
+	size   int32
+	closed bool
 }
 
 func newFlow(n int32) *flow {
@@ -42,6 +43,9 @@ func (f *flow) acquire(n int32) (waited int) {
 	f.c.L.Lock()
 	defer f.c.L.Unlock()
 	for {
+		if f.closed {
+			return
+		}
 		if f.size >= n {
 			f.size -= n
 			return
@@ -64,3 +68,12 @@ func (f *flow) add(n int32) bool {
 	f.c.Broadcast()
 	return true
 }
+
+// close marks the flow as closed, meaning everybody gets all the
+// tokens they want, because everything else will fail anyway.
+func (f *flow) close() {
+	f.c.L.Lock()
+	defer f.c.L.Unlock()
+	f.closed = true
+	f.c.Broadcast()
+}

+ 21 - 0
flow_test.go

@@ -59,3 +59,24 @@ func TestFlowAdd(t *testing.T) {
 	}
 
 }
+
+func TestFlowClose(t *testing.T) {
+	f := newFlow(0)
+
+	// Wait for 10, which should block, so start a background goroutine
+	// to refill it.
+	go func() {
+		time.Sleep(50 * time.Millisecond)
+		f.close()
+	}()
+	donec := make(chan bool)
+	go func() {
+		defer close(donec)
+		f.acquire(10)
+	}()
+	select {
+	case <-donec:
+	case <-time.After(2 * time.Second):
+		t.Error("timeout")
+	}
+}

+ 2 - 2
frame.go

@@ -883,7 +883,7 @@ func (f *Framer) WritePriority(streamID uint32, p PriorityParam) error {
 // See http://http2.github.io/http2-spec/#rfc.section.6.4
 type RSTStreamFrame struct {
 	FrameHeader
-	ErrCode uint32
+	ErrCode ErrCode
 }
 
 func parseRSTStreamFrame(fh FrameHeader, p []byte) (Frame, error) {
@@ -893,7 +893,7 @@ func parseRSTStreamFrame(fh FrameHeader, p []byte) (Frame, error) {
 	if fh.StreamID == 0 {
 		return nil, ConnectionError(ErrCodeProtocol)
 	}
-	return &RSTStreamFrame{fh, binary.BigEndian.Uint32(p[:4])}, nil
+	return &RSTStreamFrame{fh, ErrCode(binary.BigEndian.Uint32(p[:4]))}, nil
 }
 
 // WriteRSTStream writes a RST_STREAM frame.

+ 35 - 11
server.go

@@ -42,12 +42,6 @@ const (
 // be in-flight and then the frame scheduler in the serve goroutine
 // will be responsible for splitting things.
 
-// TODO: test/handle a client sending a POST with potential data, get
-// stuck in the handler in a Read, then client sends RST_STREAM, and
-// we should verify the Read then unblocks, rather than being stuck
-// forever and leaking a goroutine. and it should return an error from
-// the Read.
-
 // Server is an HTTP/2 server.
 type Server struct {
 	// MaxStreams optionally ...
@@ -175,7 +169,7 @@ type requestParam struct {
 
 type stream struct {
 	id    uint32
-	state streamState // owned by serverConn's processing loop
+	state streamState // owned by serverConn's serve loop
 	flow  *flow       // limits writing from Handler to client
 	body  *pipe       // non-nil if expecting DATA frames
 
@@ -330,10 +324,19 @@ func (sc *serverConn) writeFrames() {
 	}
 }
 
+func (sc *serverConn) stopServing() {
+	sc.serveG.check()
+	close(sc.writeFrameCh) // stop the writeFrames loop
+	err := errors.New("client disconnected")
+	for id := range sc.streams {
+		sc.closeStream(id, err)
+	}
+}
+
 func (sc *serverConn) serve() {
 	sc.serveG.check()
 	defer sc.conn.Close()
-	defer close(sc.doneServing)
+	defer sc.stopServing()
 
 	sc.vlogf("HTTP/2 connection from %v on %p", sc.conn.RemoteAddr(), sc.hs)
 
@@ -344,9 +347,8 @@ func (sc *serverConn) serve() {
 		return
 	}
 
-	go sc.readFrames() // closed by defer sc.conn.Close above
-	go sc.writeFrames()
-	defer close(sc.writeFrameCh) // shuts down writeFrames loop
+	go sc.readFrames()  // closed by defer sc.conn.Close above
+	go sc.writeFrames() // source closed in stopServing
 
 	settingsTimer := time.NewTimer(firstSettingsTimeout)
 
@@ -591,6 +593,8 @@ func (sc *serverConn) processFrame(f Frame) error {
 		return sc.processPing(f)
 	case *DataFrame:
 		return sc.processData(f)
+	case *RSTStreamFrame:
+		return sc.processResetStream(f)
 	default:
 		log.Printf("Ignoring unknown frame %#v", f)
 		return nil
@@ -649,6 +653,26 @@ func (sc *serverConn) processWindowUpdate(f *WindowUpdateFrame) error {
 	return nil
 }
 
+func (sc *serverConn) processResetStream(f *RSTStreamFrame) error {
+	sc.serveG.check()
+	sc.closeStream(f.StreamID, StreamError{f.StreamID, f.ErrCode})
+	return nil
+}
+
+func (sc *serverConn) closeStream(streamID uint32, err error) {
+	sc.serveG.check()
+	st, ok := sc.streams[streamID]
+	if !ok {
+		return
+	}
+	st.state = stateClosed // kinda useless
+	delete(sc.streams, streamID)
+	st.flow.close()
+	if p := st.body; p != nil {
+		p.Close(err)
+	}
+}
+
 func (sc *serverConn) processSettings(f *SettingsFrame) error {
 	sc.serveG.check()
 	if f.IsAck() {

+ 65 - 1
server_test.go

@@ -20,6 +20,7 @@ import (
 	"net/http/httptest"
 	"os"
 	"reflect"
+	"runtime"
 	"strconv"
 	"strings"
 	"sync/atomic"
@@ -220,7 +221,7 @@ func (st *serverTester) wantRSTStream(streamID uint32, errCode ErrCode) {
 	if rs.FrameHeader.StreamID != streamID {
 		st.t.Fatalf("RSTStream StreamID = %d; want %d", rs.FrameHeader.StreamID, streamID)
 	}
-	if rs.ErrCode != uint32(errCode) {
+	if rs.ErrCode != errCode {
 		st.t.Fatalf("RSTStream ErrCode = %d (%s); want %d (%s)", rs.ErrCode, rs.ErrCode, errCode, errCode)
 	}
 }
@@ -780,6 +781,66 @@ func TestServer_Send_RstStream_After_Bogus_WindowUpdate(t *testing.T) {
 	st.wantRSTStream(1, ErrCodeFlowControl)
 }
 
+func TestServer_RSTStream_Unblocks_Read(t *testing.T) {
+	inHandler := make(chan bool)
+	errc := make(chan error, 1)
+	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+		inHandler <- true
+		_, err := r.Body.Read(make([]byte, 1))
+		errc <- err
+	})
+	st.greet()
+	st.writeHeaders(HeadersFrameParam{
+		StreamID:      1,
+		BlockFragment: encodeHeader(st.t, ":method", "POST"),
+		EndStream:     false, // keep it open
+		EndHeaders:    true,
+	})
+	<-inHandler
+	if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil {
+		t.Fatal(err)
+	}
+	select {
+	case err := <-errc:
+		if err == nil {
+			t.Fatal("unexpected nil error from Read")
+		}
+		t.Logf("Read = %v", err)
+		st.Close()
+	case <-time.After(5 * time.Second):
+		t.Fatal("timeout waiting for Handler's Body.Read to error out")
+	}
+}
+
+func TestServer_DeadConn_Unblocks_Read(t *testing.T) {
+	inHandler := make(chan bool)
+	errc := make(chan error, 1)
+	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+		inHandler <- true
+		_, err := r.Body.Read(make([]byte, 1))
+		errc <- err
+	})
+	st.greet()
+	st.writeHeaders(HeadersFrameParam{
+		StreamID:      1,
+		BlockFragment: encodeHeader(st.t, ":method", "POST"),
+		EndStream:     false, // keep it open
+		EndHeaders:    true,
+	})
+	<-inHandler
+	st.cc.Close() // hard-close the network connection
+	select {
+	case err := <-errc:
+		if err == nil {
+			t.Fatal("unexpected nil error from Read")
+		}
+		t.Logf("Read = %v", err)
+		st.Close()
+	case <-time.After(5 * time.Second):
+		t.Fatal("timeout waiting for Handler's Body.Read to error out")
+	}
+}
+
 // TODO: test HEADERS w/o EndHeaders + another HEADERS (should get rejected)
 // TODO: test HEADERS w/ EndHeaders + a continuation HEADERS (should get rejected)
 
@@ -1221,6 +1282,9 @@ func testServerResponse(t *testing.T,
 }
 
 func TestServerWithCurl(t *testing.T) {
+	if runtime.GOOS == "darwin" {
+		t.Skip("skipping Docker test on Darwin; requires --net which won't work with boot2docker anyway")
+	}
 	requireCurl(t)
 	const msg = "Hello from curl!\n"
 	ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {