Browse Source

Buffer the writing of frames.

Previously each written frame went in its own network packet.

Now we buffer until the frame writer has nothing else to write, at
which point it flushes and discards its write buffers, to minimize
the memory usage of idle connections.
Brad Fitzpatrick 11 years ago
parent
commit
5e4e2dc4f7
2 changed files with 76 additions and 11 deletions
  1. 43 0
      http2.go
  2. 33 11
      server.go

+ 43 - 0
http2.go

@@ -17,7 +17,9 @@
 package http2
 package http2
 
 
 import (
 import (
+	"bufio"
 	"fmt"
 	"fmt"
+	"io"
 	"net/http"
 	"net/http"
 	"strconv"
 	"strconv"
 	"sync"
 	"sync"
@@ -213,3 +215,44 @@ func (cw *closeWaiter) Wait() {
 		cw.c.Wait()
 		cw.c.Wait()
 	}
 	}
 }
 }
+
+// bufferedWriter is a buffered writer that writes to w.
+// Its buffered writer is lazily allocated as needed, to minimize
+// idle memory usage with many connections.
+type bufferedWriter struct {
+	w  io.Writer     // immutable
+	bw *bufio.Writer // non-nil when data is buffered
+}
+
+func newBufferedWriter(w io.Writer) *bufferedWriter {
+	return &bufferedWriter{w: w}
+}
+
+var bufWriterPool = sync.Pool{
+	New: func() interface{} {
+		// TODO: pick something better? this is a bit under
+		// (3 x typical 1500 byte MTU) at least.
+		return bufio.NewWriterSize(nil, 4<<10)
+	},
+}
+
+func (w *bufferedWriter) Write(p []byte) (n int, err error) {
+	if w.bw == nil {
+		bw := bufWriterPool.Get().(*bufio.Writer)
+		bw.Reset(w.w)
+		w.bw = bw
+	}
+	return w.bw.Write(p)
+}
+
+func (w *bufferedWriter) Flush() error {
+	bw := w.bw
+	if bw == nil {
+		return nil
+	}
+	err := bw.Flush()
+	bw.Reset(nil)
+	bufWriterPool.Put(bw)
+	w.bw = nil
+	return err
+}

+ 33 - 11
server.go

@@ -45,6 +45,9 @@ const (
 // TODO: send PING frames to idle clients and disconnect them if no
 // TODO: send PING frames to idle clients and disconnect them if no
 // reply
 // reply
 
 
+// TODO: don't keep the writeFrames goroutine active. turn it off when no frames
+// are enqueued.
+
 // Server is an HTTP/2 server.
 // Server is an HTTP/2 server.
 type Server struct {
 type Server struct {
 	// MaxStreams optionally ...
 	// MaxStreams optionally ...
@@ -102,17 +105,12 @@ func ConfigureServer(s *http.Server, conf *Server) {
 var testHookGetServerConn func(*serverConn)
 var testHookGetServerConn func(*serverConn)
 
 
 func (srv *Server) handleConn(hs *http.Server, c net.Conn, h http.Handler) {
 func (srv *Server) handleConn(hs *http.Server, c net.Conn, h http.Handler) {
-	// TODO: write to a (custom?) buffered writer that can
-	// alternate when it's in buffered mode.
-	fr := NewFramer(c, c)
-	fr.SetMaxReadFrameSize(srv.maxReadFrameSize())
-
 	sc := &serverConn{
 	sc := &serverConn{
 		srv:                  srv,
 		srv:                  srv,
 		hs:                   hs,
 		hs:                   hs,
 		conn:                 c,
 		conn:                 c,
+		bw:                   newBufferedWriter(c),
 		handler:              h,
 		handler:              h,
-		framer:               fr,
 		streams:              make(map[uint32]*stream),
 		streams:              make(map[uint32]*stream),
 		readFrameCh:          make(chan frameAndGate),
 		readFrameCh:          make(chan frameAndGate),
 		readFrameErrCh:       make(chan error, 1), // must be buffered for 1
 		readFrameErrCh:       make(chan error, 1), // must be buffered for 1
@@ -130,6 +128,11 @@ func (srv *Server) handleConn(hs *http.Server, c net.Conn, h http.Handler) {
 	}
 	}
 	sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf)
 	sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf)
 	sc.hpackDecoder = hpack.NewDecoder(initialHeaderTableSize, sc.onNewHeaderField)
 	sc.hpackDecoder = hpack.NewDecoder(initialHeaderTableSize, sc.onNewHeaderField)
+
+	fr := NewFramer(sc.bw, c)
+	fr.SetMaxReadFrameSize(srv.maxReadFrameSize())
+	sc.framer = fr
+
 	if hook := testHookGetServerConn; hook != nil {
 	if hook := testHookGetServerConn; hook != nil {
 		hook(sc)
 		hook(sc)
 	}
 	}
@@ -151,6 +154,7 @@ type serverConn struct {
 	srv              *Server
 	srv              *Server
 	hs               *http.Server
 	hs               *http.Server
 	conn             net.Conn
 	conn             net.Conn
+	bw               *bufferedWriter // writing to conn
 	handler          http.Handler
 	handler          http.Handler
 	framer           *Framer
 	framer           *Framer
 	hpackDecoder     *hpack.Decoder
 	hpackDecoder     *hpack.Decoder
@@ -180,6 +184,7 @@ type serverConn struct {
 	canonHeader           map[string]string // http2-lower-case -> Go-Canonical-Case
 	canonHeader           map[string]string // http2-lower-case -> Go-Canonical-Case
 	req                   requestParam      // non-zero while reading request headers
 	req                   requestParam      // non-zero while reading request headers
 	writingFrame          bool              // sent on writeFrameCh but haven't heard back on wroteFrameCh yet
 	writingFrame          bool              // sent on writeFrameCh but haven't heard back on wroteFrameCh yet
+	needsFrameFlush       bool              // last frame to writeFrameCh wasn't a flush
 	writeQueue            []frameWriteMsg   // TODO: proper scheduler, not a queue
 	writeQueue            []frameWriteMsg   // TODO: proper scheduler, not a queue
 	inGoAway              bool              // we've started to or sent GOAWAY
 	inGoAway              bool              // we've started to or sent GOAWAY
 	needToSendGoAway      bool              // we need to schedule a GOAWAY frame write
 	needToSendGoAway      bool              // we need to schedule a GOAWAY frame write
@@ -376,6 +381,11 @@ func (sc *serverConn) writeFrames() {
 	}
 	}
 }
 }
 
 
+func (sc *serverConn) flushFrameWriter(_ interface{}) error {
+	sc.writeG.check()
+	return sc.bw.Flush() // may block on the network
+}
+
 var errClientDisconnected = errors.New("client disconnected")
 var errClientDisconnected = errors.New("client disconnected")
 
 
 func (sc *serverConn) closeAllStreamsOnConnClose() {
 func (sc *serverConn) closeAllStreamsOnConnClose() {
@@ -528,6 +538,7 @@ func (sc *serverConn) sendFrameWrite(wm frameWriteMsg) {
 	}
 	}
 
 
 	sc.writingFrame = true
 	sc.writingFrame = true
+	sc.needsFrameFlush = true
 	if wm.endStream {
 	if wm.endStream {
 		if st == nil {
 		if st == nil {
 			panic("nil stream with endStream set")
 			panic("nil stream with endStream set")
@@ -542,11 +553,21 @@ func (sc *serverConn) sendFrameWrite(wm frameWriteMsg) {
 	sc.writeFrameCh <- wm
 	sc.writeFrameCh <- wm
 }
 }
 
 
+func (sc *serverConn) sendFrameWriteFlush() {
+	sc.serveG.check()
+	if sc.writingFrame {
+		panic("invariant")
+	}
+	sc.writingFrame = true
+	sc.needsFrameFlush = false
+	sc.writeFrameCh <- frameWriteMsg{write: (*serverConn).flushFrameWriter}
+}
+
 func (sc *serverConn) enqueueSettingsAck() {
 func (sc *serverConn) enqueueSettingsAck() {
 	sc.serveG.check()
 	sc.serveG.check()
 	if !sc.writingFrame {
 	if !sc.writingFrame {
 		sc.needToSendSettingsAck = false
 		sc.needToSendSettingsAck = false
-		sc.writeFrameCh <- frameWriteMsg{write: (*serverConn).writeSettingsAck}
+		sc.sendFrameWrite(frameWriteMsg{write: (*serverConn).writeSettingsAck})
 		return
 		return
 	}
 	}
 	sc.needToSendSettingsAck = true
 	sc.needToSendSettingsAck = true
@@ -568,6 +589,10 @@ func (sc *serverConn) scheduleFrameWrite() {
 		})
 		})
 		return
 		return
 	}
 	}
+	if len(sc.writeQueue) == 0 && sc.needsFrameFlush {
+		sc.sendFrameWriteFlush()
+		return
+	}
 	if sc.inGoAway {
 	if sc.inGoAway {
 		// No more frames after we've sent GOAWAY.
 		// No more frames after we've sent GOAWAY.
 		return
 		return
@@ -577,7 +602,6 @@ func (sc *serverConn) scheduleFrameWrite() {
 		return
 		return
 	}
 	}
 	if len(sc.writeQueue) == 0 {
 	if len(sc.writeQueue) == 0 {
-		// TODO: flush Framer's underlying buffered writer, once that's added
 		return
 		return
 	}
 	}
 
 
@@ -627,9 +651,7 @@ func (sc *serverConn) writeGoAwayFrame(v interface{}) error {
 	p := v.(*goAwayParams)
 	p := v.(*goAwayParams)
 	err := sc.framer.WriteGoAway(p.maxStreamID, p.code, nil)
 	err := sc.framer.WriteGoAway(p.maxStreamID, p.code, nil)
 	if p.code != 0 {
 	if p.code != 0 {
-		// TODO: flush any buffer, if we add a buffering writing
-		// Sleep a bit to give the peer a bit of time to read the
-		// GOAWAY before potentially getting a TCP RST packet:
+		sc.bw.Flush() // ignore error: we're hanging up on them anyway
 		time.Sleep(50 * time.Millisecond)
 		time.Sleep(50 * time.Millisecond)
 		sc.conn.Close()
 		sc.conn.Close()
 	}
 	}