Browse Source

Send WINDOW_UPDATE frames as Handlers read from their Request.Body.

Brad Fitzpatrick 11 years ago
parent
commit
6a48feb026
2 changed files with 133 additions and 4 deletions
  1. 41 3
      http2.go
  2. 92 1
      http2_test.go

+ 41 - 3
http2.go

@@ -69,12 +69,13 @@ func (srv *Server) handleConn(hs *http.Server, c net.Conn, h http.Handler) {
 		hs:                hs,
 		hs:                hs,
 		conn:              c,
 		conn:              c,
 		handler:           h,
 		handler:           h,
-		framer:            NewFramer(c, c),
+		framer:            NewFramer(c, c), // TODO: write to a (custom?) buffered writer that can alternate when it's in buffered mode.
 		streams:           make(map[uint32]*stream),
 		streams:           make(map[uint32]*stream),
 		canonHeader:       make(map[string]string),
 		canonHeader:       make(map[string]string),
 		readFrameCh:       make(chan frameAndProcessed),
 		readFrameCh:       make(chan frameAndProcessed),
 		readFrameErrCh:    make(chan error, 1),
 		readFrameErrCh:    make(chan error, 1),
 		writeHeaderCh:     make(chan headerWriteReq), // must not be buffered
 		writeHeaderCh:     make(chan headerWriteReq), // must not be buffered
+		windowUpdateCh:    make(chan windowUpdateReq, 8),
 		flow:              newFlow(initialWindowSize),
 		flow:              newFlow(initialWindowSize),
 		doneServing:       make(chan struct{}),
 		doneServing:       make(chan struct{}),
 		maxWriteFrameSize: initialMaxFrameSize,
 		maxWriteFrameSize: initialMaxFrameSize,
@@ -107,8 +108,9 @@ type serverConn struct {
 	readFrameCh    chan frameAndProcessed // written by serverConn.readFrames
 	readFrameCh    chan frameAndProcessed // written by serverConn.readFrames
 	readFrameErrCh chan error
 	readFrameErrCh chan error
 	writeHeaderCh  chan headerWriteReq // must not be buffered
 	writeHeaderCh  chan headerWriteReq // must not be buffered
-	serveG         goroutineLock       // used to verify funcs are on serve()
-	flow           *flow               // the connection-wide one
+	windowUpdateCh chan windowUpdateReq
+	serveG         goroutineLock // used to verify funcs are on serve()
+	flow           *flow         // the connection-wide one
 
 
 	// Everything following is owned by the serve loop; use serveG.check()
 	// Everything following is owned by the serve loop; use serveG.check()
 	maxStreamID       uint32 // max ever seen
 	maxStreamID       uint32 // max ever seen
@@ -339,6 +341,11 @@ func (sc *serverConn) serve() {
 				sc.condlogf(err, "error writing response header: %v", err)
 				sc.condlogf(err, "error writing response header: %v", err)
 				return
 				return
 			}
 			}
+		case wu := <-sc.windowUpdateCh:
+			if err := sc.sendWindowUpdateInLoop(wu); err != nil {
+				sc.condlogf(err, "error writing window update: %v", err)
+				return
+			}
 		case fp, ok := <-sc.readFrameCh:
 		case fp, ok := <-sc.readFrameCh:
 			if !ok {
 			if !ok {
 				err := <-sc.readFrameErrCh
 				err := <-sc.readFrameErrCh
@@ -753,6 +760,36 @@ func (sc *serverConn) writeHeaderInLoop(req headerWriteReq) error {
 	})
 	})
 }
 }
 
 
+type windowUpdateReq struct {
+	streamID uint32
+	n        uint32
+}
+
+// called from handler goroutines
+func (sc *serverConn) sendWindowUpdate(streamID uint32, n int) {
+	const maxUint32 = 2147483647
+	for n >= maxUint32 {
+		sc.windowUpdateCh <- windowUpdateReq{streamID, maxUint32}
+		n -= maxUint32
+	}
+	if n > 0 {
+		sc.windowUpdateCh <- windowUpdateReq{streamID, uint32(n)}
+	}
+}
+
+func (sc *serverConn) sendWindowUpdateInLoop(wu windowUpdateReq) error {
+	sc.serveG.check()
+	// TODO: sc.bufferedOutput.StartBuffering()
+	if err := sc.framer.WriteWindowUpdate(0, wu.n); err != nil {
+		return err
+	}
+	if err := sc.framer.WriteWindowUpdate(wu.streamID, wu.n); err != nil {
+		return err
+	}
+	// TODO: return sc.bufferedOutput.Flush()
+	return nil
+}
+
 // ConfigureServer adds HTTP/2 support to a net/http Server.
 // ConfigureServer adds HTTP/2 support to a net/http Server.
 //
 //
 // The configuration conf may be nil.
 // The configuration conf may be nil.
@@ -810,6 +847,7 @@ func (b *requestBody) Read(p []byte) (n int, err error) {
 	}
 	}
 	n, err = b.pipe.Read(p)
 	n, err = b.pipe.Read(p)
 	if n > 0 {
 	if n > 0 {
+		b.sc.sendWindowUpdate(b.streamID, n)
 		// TODO: tell b.sc to send back 'n' flow control quota credits to the sender
 		// TODO: tell b.sc to send back 'n' flow control quota credits to the sender
 	}
 	}
 	return
 	return

+ 92 - 1
http2_test.go

@@ -186,7 +186,7 @@ func (st *serverTester) wantRSTStream(streamID uint32, errCode ErrCode) {
 	}
 	}
 	rs, ok := f.(*RSTStreamFrame)
 	rs, ok := f.(*RSTStreamFrame)
 	if !ok {
 	if !ok {
-		st.t.Fatalf("got a %T; want *RSTStream", f)
+		st.t.Fatalf("got a %T; want *RSTStreamFrame", f)
 	}
 	}
 	if rs.FrameHeader.StreamID != streamID {
 	if rs.FrameHeader.StreamID != streamID {
 		st.t.Fatalf("RSTStream StreamID = %d; want %d", rs.FrameHeader.StreamID, streamID)
 		st.t.Fatalf("RSTStream StreamID = %d; want %d", rs.FrameHeader.StreamID, streamID)
@@ -196,6 +196,23 @@ func (st *serverTester) wantRSTStream(streamID uint32, errCode ErrCode) {
 	}
 	}
 }
 }
 
 
+func (st *serverTester) wantWindowUpdate(streamID, incr uint32) {
+	f, err := st.readFrame()
+	if err != nil {
+		st.t.Fatalf("Error while expecting an RSTStream frame: %v", err)
+	}
+	wu, ok := f.(*WindowUpdateFrame)
+	if !ok {
+		st.t.Fatalf("got a %T; want *WindowUpdateFrame", f)
+	}
+	if wu.FrameHeader.StreamID != streamID {
+		st.t.Fatalf("WindowUpdate StreamID = %d; want %d", wu.FrameHeader.StreamID, streamID)
+	}
+	if wu.Increment != incr {
+		st.t.Fatalf("WindowUpdate increment = %d; want %d", wu.Increment, incr)
+	}
+}
+
 func (st *serverTester) wantSettingsAck() {
 func (st *serverTester) wantSettingsAck() {
 	f, err := st.readFrame()
 	f, err := st.readFrame()
 	if err != nil {
 	if err != nil {
@@ -628,6 +645,51 @@ func TestServer_Ping(t *testing.T) {
 	}
 	}
 }
 }
 
 
+func TestServer_Handler_Sends_WindowUpdate(t *testing.T) {
+	puppet := newHandlerPuppet()
+	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+		puppet.act(w, r)
+	})
+	defer st.Close()
+	defer puppet.done()
+
+	st.greet()
+
+	st.writeHeaders(HeadersFrameParam{
+		StreamID:      1, // clients send odd numbers
+		BlockFragment: encodeHeader(t, ":method", "POST"),
+		EndStream:     false, // data coming
+		EndHeaders:    true,
+	})
+	st.writeData(1, true, []byte("abcdef"))
+	puppet.do(func(w http.ResponseWriter, r *http.Request) {
+		buf := make([]byte, 3)
+		_, err := io.ReadFull(r.Body, buf)
+		if err != nil {
+			t.Error(err)
+			return
+		}
+		if string(buf) != "abc" {
+			t.Errorf("read %q; want abc", buf)
+		}
+	})
+	st.wantWindowUpdate(0, 3)
+	st.wantWindowUpdate(1, 3)
+	puppet.do(func(w http.ResponseWriter, r *http.Request) {
+		buf := make([]byte, 3)
+		_, err := io.ReadFull(r.Body, buf)
+		if err != nil {
+			t.Error(err)
+			return
+		}
+		if string(buf) != "def" {
+			t.Errorf("read %q; want abc", buf)
+		}
+	})
+	st.wantWindowUpdate(0, 3)
+	st.wantWindowUpdate(1, 3)
+}
+
 // TODO: test HEADERS w/o EndHeaders + another HEADERS (should get rejected)
 // TODO: test HEADERS w/o EndHeaders + another HEADERS (should get rejected)
 // TODO: test HEADERS w/ EndHeaders + a continuation HEADERS (should get rejected)
 // TODO: test HEADERS w/ EndHeaders + a continuation HEADERS (should get rejected)
 
 
@@ -806,3 +868,32 @@ func encodeHeader(t *testing.T, headers ...string) []byte {
 	}
 	}
 	return buf.Bytes()
 	return buf.Bytes()
 }
 }
+
+type puppetCommand struct {
+	fn   func(w http.ResponseWriter, r *http.Request)
+	done chan<- bool
+}
+
+type handlerPuppet struct {
+	ch chan puppetCommand
+}
+
+func newHandlerPuppet() *handlerPuppet {
+	return &handlerPuppet{
+		ch: make(chan puppetCommand),
+	}
+}
+
+func (p *handlerPuppet) act(w http.ResponseWriter, r *http.Request) {
+	for cmd := range p.ch {
+		cmd.fn(w, r)
+		cmd.done <- true
+	}
+}
+
+func (p *handlerPuppet) done() { close(p.ch) }
+func (p *handlerPuppet) do(fn func(http.ResponseWriter, *http.Request)) {
+	done := make(chan bool)
+	p.ch <- puppetCommand{fn, done}
+	<-done
+}