فهرست منبع

Interrupt Request.Body.Read on RSTStream or connection close.

So we don't leak goroutines blocked forever in Read.
Brad Fitzpatrick 11 سال پیش
والد
کامیت
6d3aa4f311
5فایلهای تغییر یافته به همراه138 افزوده شده و 16 حذف شده
  1. 15 2
      flow.go
  2. 21 0
      flow_test.go
  3. 2 2
      frame.go
  4. 35 11
      server.go
  5. 65 1
      server_test.go

+ 15 - 2
flow.go

@@ -11,8 +11,9 @@ import "sync"
 
 // flow is the flow control window's counting semaphore.
 type flow struct {
-	c    *sync.Cond // protects size
-	size int32
+	c      *sync.Cond // protects size
+	size   int32
+	closed bool
 }
 
 func newFlow(n int32) *flow {
@@ -42,6 +43,9 @@ func (f *flow) acquire(n int32) (waited int) {
 	f.c.L.Lock()
 	defer f.c.L.Unlock()
 	for {
+		if f.closed {
+			return
+		}
 		if f.size >= n {
 			f.size -= n
 			return
@@ -64,3 +68,12 @@ func (f *flow) add(n int32) bool {
 	f.c.Broadcast()
 	return true
 }
+
+// close marks the flow as closed, meaning everybody gets all the
+// tokens they want, because everything else will fail anyway.
+func (f *flow) close() {
+	f.c.L.Lock()
+	defer f.c.L.Unlock()
+	f.closed = true
+	f.c.Broadcast()
+}

+ 21 - 0
flow_test.go

@@ -59,3 +59,24 @@ func TestFlowAdd(t *testing.T) {
 	}
 
 }
+
+func TestFlowClose(t *testing.T) {
+	f := newFlow(0)
+
+	// Wait for 10, which should block, so start a background goroutine
+	// to refill it.
+	go func() {
+		time.Sleep(50 * time.Millisecond)
+		f.close()
+	}()
+	donec := make(chan bool)
+	go func() {
+		defer close(donec)
+		f.acquire(10)
+	}()
+	select {
+	case <-donec:
+	case <-time.After(2 * time.Second):
+		t.Error("timeout")
+	}
+}

+ 2 - 2
frame.go

@@ -883,7 +883,7 @@ func (f *Framer) WritePriority(streamID uint32, p PriorityParam) error {
 // See http://http2.github.io/http2-spec/#rfc.section.6.4
 type RSTStreamFrame struct {
 	FrameHeader
-	ErrCode uint32
+	ErrCode ErrCode
 }
 
 func parseRSTStreamFrame(fh FrameHeader, p []byte) (Frame, error) {
@@ -893,7 +893,7 @@ func parseRSTStreamFrame(fh FrameHeader, p []byte) (Frame, error) {
 	if fh.StreamID == 0 {
 		return nil, ConnectionError(ErrCodeProtocol)
 	}
-	return &RSTStreamFrame{fh, binary.BigEndian.Uint32(p[:4])}, nil
+	return &RSTStreamFrame{fh, ErrCode(binary.BigEndian.Uint32(p[:4]))}, nil
 }
 
 // WriteRSTStream writes a RST_STREAM frame.

+ 35 - 11
server.go

@@ -42,12 +42,6 @@ const (
 // be in-flight and then the frame scheduler in the serve goroutine
 // will be responsible for splitting things.
 
-// TODO: test/handle a client sending a POST with potential data, get
-// stuck in the handler in a Read, then client sends RST_STREAM, and
-// we should verify the Read then unblocks, rather than being stuck
-// forever and leaking a goroutine. and it should return an error from
-// the Read.
-
 // Server is an HTTP/2 server.
 type Server struct {
 	// MaxStreams optionally ...
@@ -175,7 +169,7 @@ type requestParam struct {
 
 type stream struct {
 	id    uint32
-	state streamState // owned by serverConn's processing loop
+	state streamState // owned by serverConn's serve loop
 	flow  *flow       // limits writing from Handler to client
 	body  *pipe       // non-nil if expecting DATA frames
 
@@ -330,10 +324,19 @@ func (sc *serverConn) writeFrames() {
 	}
 }
 
+func (sc *serverConn) stopServing() {
+	sc.serveG.check()
+	close(sc.writeFrameCh) // stop the writeFrames loop
+	err := errors.New("client disconnected")
+	for id := range sc.streams {
+		sc.closeStream(id, err)
+	}
+}
+
 func (sc *serverConn) serve() {
 	sc.serveG.check()
 	defer sc.conn.Close()
-	defer close(sc.doneServing)
+	defer sc.stopServing()
 
 	sc.vlogf("HTTP/2 connection from %v on %p", sc.conn.RemoteAddr(), sc.hs)
 
@@ -344,9 +347,8 @@ func (sc *serverConn) serve() {
 		return
 	}
 
-	go sc.readFrames() // closed by defer sc.conn.Close above
-	go sc.writeFrames()
-	defer close(sc.writeFrameCh) // shuts down writeFrames loop
+	go sc.readFrames()  // closed by defer sc.conn.Close above
+	go sc.writeFrames() // source closed in stopServing
 
 	settingsTimer := time.NewTimer(firstSettingsTimeout)
 
@@ -591,6 +593,8 @@ func (sc *serverConn) processFrame(f Frame) error {
 		return sc.processPing(f)
 	case *DataFrame:
 		return sc.processData(f)
+	case *RSTStreamFrame:
+		return sc.processResetStream(f)
 	default:
 		log.Printf("Ignoring unknown frame %#v", f)
 		return nil
@@ -649,6 +653,26 @@ func (sc *serverConn) processWindowUpdate(f *WindowUpdateFrame) error {
 	return nil
 }
 
+func (sc *serverConn) processResetStream(f *RSTStreamFrame) error {
+	sc.serveG.check()
+	sc.closeStream(f.StreamID, StreamError{f.StreamID, f.ErrCode})
+	return nil
+}
+
+func (sc *serverConn) closeStream(streamID uint32, err error) {
+	sc.serveG.check()
+	st, ok := sc.streams[streamID]
+	if !ok {
+		return
+	}
+	st.state = stateClosed // kinda useless
+	delete(sc.streams, streamID)
+	st.flow.close()
+	if p := st.body; p != nil {
+		p.Close(err)
+	}
+}
+
 func (sc *serverConn) processSettings(f *SettingsFrame) error {
 	sc.serveG.check()
 	if f.IsAck() {

+ 65 - 1
server_test.go

@@ -20,6 +20,7 @@ import (
 	"net/http/httptest"
 	"os"
 	"reflect"
+	"runtime"
 	"strconv"
 	"strings"
 	"sync/atomic"
@@ -220,7 +221,7 @@ func (st *serverTester) wantRSTStream(streamID uint32, errCode ErrCode) {
 	if rs.FrameHeader.StreamID != streamID {
 		st.t.Fatalf("RSTStream StreamID = %d; want %d", rs.FrameHeader.StreamID, streamID)
 	}
-	if rs.ErrCode != uint32(errCode) {
+	if rs.ErrCode != errCode {
 		st.t.Fatalf("RSTStream ErrCode = %d (%s); want %d (%s)", rs.ErrCode, rs.ErrCode, errCode, errCode)
 	}
 }
@@ -780,6 +781,66 @@ func TestServer_Send_RstStream_After_Bogus_WindowUpdate(t *testing.T) {
 	st.wantRSTStream(1, ErrCodeFlowControl)
 }
 
+func TestServer_RSTStream_Unblocks_Read(t *testing.T) {
+	inHandler := make(chan bool)
+	errc := make(chan error, 1)
+	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+		inHandler <- true
+		_, err := r.Body.Read(make([]byte, 1))
+		errc <- err
+	})
+	st.greet()
+	st.writeHeaders(HeadersFrameParam{
+		StreamID:      1,
+		BlockFragment: encodeHeader(st.t, ":method", "POST"),
+		EndStream:     false, // keep it open
+		EndHeaders:    true,
+	})
+	<-inHandler
+	if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil {
+		t.Fatal(err)
+	}
+	select {
+	case err := <-errc:
+		if err == nil {
+			t.Fatal("unexpected nil error from Read")
+		}
+		t.Logf("Read = %v", err)
+		st.Close()
+	case <-time.After(5 * time.Second):
+		t.Fatal("timeout waiting for Handler's Body.Read to error out")
+	}
+}
+
+func TestServer_DeadConn_Unblocks_Read(t *testing.T) {
+	inHandler := make(chan bool)
+	errc := make(chan error, 1)
+	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+		inHandler <- true
+		_, err := r.Body.Read(make([]byte, 1))
+		errc <- err
+	})
+	st.greet()
+	st.writeHeaders(HeadersFrameParam{
+		StreamID:      1,
+		BlockFragment: encodeHeader(st.t, ":method", "POST"),
+		EndStream:     false, // keep it open
+		EndHeaders:    true,
+	})
+	<-inHandler
+	st.cc.Close() // hard-close the network connection
+	select {
+	case err := <-errc:
+		if err == nil {
+			t.Fatal("unexpected nil error from Read")
+		}
+		t.Logf("Read = %v", err)
+		st.Close()
+	case <-time.After(5 * time.Second):
+		t.Fatal("timeout waiting for Handler's Body.Read to error out")
+	}
+}
+
 // TODO: test HEADERS w/o EndHeaders + another HEADERS (should get rejected)
 // TODO: test HEADERS w/ EndHeaders + a continuation HEADERS (should get rejected)
 
@@ -1221,6 +1282,9 @@ func testServerResponse(t *testing.T,
 }
 
 func TestServerWithCurl(t *testing.T) {
+	if runtime.GOOS == "darwin" {
+		t.Skip("skipping Docker test on Darwin; requires --net which won't work with boot2docker anyway")
+	}
 	requireCurl(t)
 	const msg = "Hello from curl!\n"
 	ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {