Browse Source

Refactor frame writing in prep for the write scheduler and client support.

Brad Fitzpatrick 11 years ago
parent
commit
23564bf81c
4 changed files with 225 additions and 177 deletions
  1. 0 10
      http2.go
  2. 31 167
      server.go
  3. 172 0
      write.go
  4. 22 0
      writesched.go

+ 0 - 10
http2.go

@@ -258,13 +258,3 @@ func (w *bufferedWriter) Flush() error {
 	w.bw = nil
 	return err
 }
-
-type goAwayParams struct {
-	maxStreamID uint32
-	code        ErrCode
-}
-
-type dataWriteParams struct {
-	p   []byte
-	end bool
-}

+ 31 - 167
server.go

@@ -236,8 +236,7 @@ type serverConn struct {
 	shutdownTimerCh       <-chan time.Time // nil until used
 	shutdownTimer         *time.Timer      // nil until used
 
-	// Owned by the writeFrameAsync goroutine; use writeG.check():
-	writeG         goroutineLock // used to verify things running on writeFrameAsync
+	// Owned by the writeFrameAsync goroutine:
 	headerWriteBuf bytes.Buffer
 	hpackEncoder   *hpack.Encoder
 }
@@ -279,6 +278,13 @@ type stream struct {
 	gotReset      bool  // only true once detacted from streams map
 }
 
+func (sc *serverConn) Framer() *Framer  { return sc.framer }
+func (sc *serverConn) CloseConn() error { return sc.conn.Close() }
+func (sc *serverConn) Flush() error     { return sc.bw.Flush() }
+func (sc *serverConn) HeaderEncoder() (*hpack.Encoder, *bytes.Buffer) {
+	return sc.hpackEncoder, &sc.headerWriteBuf
+}
+
 func (sc *serverConn) state(streamID uint32) streamState {
 	sc.serveG.check()
 	// http://http2.github.io/http2-spec/#rfc.section.5.1
@@ -418,12 +424,7 @@ func (sc *serverConn) readFrames() {
 // At most one goroutine can be running writeFrameAsync at a time per
 // serverConn.
 func (sc *serverConn) writeFrameAsync(wm frameWriteMsg) {
-	sc.writeG = newGoroutineLock()
-	var streamID uint32
-	if wm.stream != nil {
-		streamID = wm.stream.id
-	}
-	err := wm.write(sc, streamID, wm.v)
+	err := wm.write(sc, wm.v)
 	if ch := wm.done; ch != nil {
 		select {
 		case ch <- err:
@@ -434,11 +435,6 @@ func (sc *serverConn) writeFrameAsync(wm frameWriteMsg) {
 	sc.wroteFrameCh <- struct{}{} // tickle frame selection scheduler
 }
 
-func (sc *serverConn) flushFrameWriter(uint32, interface{}) error {
-	sc.writeG.check()
-	return sc.bw.Flush() // may block on the network
-}
-
 func (sc *serverConn) closeAllStreamsOnConnClose() {
 	sc.serveG.check()
 	for _, st := range sc.streams {
@@ -462,7 +458,14 @@ func (sc *serverConn) serve() {
 
 	sc.vlogf("HTTP/2 connection from %v on %p", sc.conn.RemoteAddr(), sc.hs)
 
-	sc.writeFrame(frameWriteMsg{write: (*serverConn).sendInitialSettings})
+	sc.writeFrame(frameWriteMsg{
+		write: writeSettings,
+		v: []Setting{
+			{SettingMaxFrameSize, sc.srv.maxReadFrameSize()},
+			{SettingMaxConcurrentStreams, sc.advMaxStreams},
+			/* TODO: more actual settings */
+		},
+	})
 
 	if err := sc.readPreface(); err != nil {
 		sc.condlogf(err, "error reading preface from client %v: %v", sc.conn.RemoteAddr(), err)
@@ -502,15 +505,6 @@ func (sc *serverConn) serve() {
 	}
 }
 
-func (sc *serverConn) sendInitialSettings(uint32, interface{}) error {
-	sc.writeG.check()
-	return sc.framer.WriteSettings(
-		Setting{SettingMaxFrameSize, sc.srv.maxReadFrameSize()},
-		Setting{SettingMaxConcurrentStreams, sc.advMaxStreams},
-		/* TODO: more actual settings */
-	)
-}
-
 // readPreface reads the ClientPreface greeting from the peer
 // or returns an error on timeout or an invalid greeting.
 func (sc *serverConn) readPreface() error {
@@ -554,7 +548,7 @@ func (sc *serverConn) readPreface() error {
 func (sc *serverConn) writeData(stream *stream, data *dataWriteParams, ch chan error) error {
 	sc.serveG.checkNotOn() // NOT on; otherwise could deadlock in sc.writeFrame
 	sc.writeFrameFromHandler(frameWriteMsg{
-		write:     (*serverConn).writeDataFrame,
+		write:     writeDataFrame,
 		cost:      uint32(len(data.p)),
 		stream:    stream,
 		endStream: data.end,
@@ -661,7 +655,7 @@ func (sc *serverConn) scheduleFrameWrite() {
 	if sc.needToSendGoAway {
 		sc.needToSendGoAway = false
 		sc.startFrameWrite(frameWriteMsg{
-			write: (*serverConn).writeGoAwayFrame,
+			write: writeGoAwayFrame,
 			v: &goAwayParams{
 				maxStreamID: sc.maxStreamID,
 				code:        sc.goAwayCode,
@@ -670,7 +664,7 @@ func (sc *serverConn) scheduleFrameWrite() {
 		return
 	}
 	if sc.writeSched.empty() && sc.needsFrameFlush {
-		sc.startFrameWrite(frameWriteMsg{write: (*serverConn).flushFrameWriter})
+		sc.startFrameWrite(frameWriteMsg{write: flushFrameWriter})
 		sc.needsFrameFlush = false // after startFrameWrite, since it sets this true
 		return
 	}
@@ -680,7 +674,7 @@ func (sc *serverConn) scheduleFrameWrite() {
 	}
 	if sc.needToSendSettingsAck {
 		sc.needToSendSettingsAck = false
-		sc.startFrameWrite(frameWriteMsg{write: (*serverConn).writeSettingsAck})
+		sc.startFrameWrite(frameWriteMsg{write: writeSettingsAck})
 		return
 	}
 	if sc.writeSched.empty() {
@@ -716,18 +710,6 @@ func (sc *serverConn) shutDownIn(d time.Duration) {
 	sc.shutdownTimerCh = sc.shutdownTimer.C
 }
 
-func (sc *serverConn) writeGoAwayFrame(_ uint32, v interface{}) error {
-	sc.writeG.check()
-	p := v.(*goAwayParams)
-	err := sc.framer.WriteGoAway(p.maxStreamID, p.code, nil)
-	if p.code != 0 {
-		sc.bw.Flush() // ignore error: we're hanging up on them anyway
-		time.Sleep(50 * time.Millisecond)
-		sc.conn.Close()
-	}
-	return err
-}
-
 func (sc *serverConn) resetStream(se StreamError) {
 	sc.serveG.check()
 	st, ok := sc.streams[se.StreamID]
@@ -735,19 +717,13 @@ func (sc *serverConn) resetStream(se StreamError) {
 		panic("internal package error; resetStream called on non-existent stream")
 	}
 	sc.writeFrame(frameWriteMsg{
-		write: (*serverConn).writeRSTStreamFrame,
+		write: writeRSTStreamFrame,
 		v:     &se,
 	})
 	st.sentReset = true
 	sc.closeStream(st, se)
 }
 
-func (sc *serverConn) writeRSTStreamFrame(streamID uint32, v interface{}) error {
-	sc.writeG.check()
-	se := v.(*StreamError)
-	return sc.framer.WriteRSTStream(se.StreamID, se.Code)
-}
-
 // curHeaderStreamID returns the stream ID of the header block we're
 // currently in the middle of reading. If this returns non-zero, the
 // next frame must be a CONTINUATION with this stream id.
@@ -871,18 +847,12 @@ func (sc *serverConn) processPing(f *PingFrame) error {
 		return ConnectionError(ErrCodeProtocol)
 	}
 	sc.writeFrame(frameWriteMsg{
-		write: (*serverConn).writePingAck,
+		write: writePingAck,
 		v:     f,
 	})
 	return nil
 }
 
-func (sc *serverConn) writePingAck(_ uint32, v interface{}) error {
-	sc.writeG.check()
-	pf := v.(*PingFrame) // contains the data we need to write back
-	return sc.framer.WritePing(true, pf.Data)
-}
-
 func (sc *serverConn) processWindowUpdate(f *WindowUpdateFrame) error {
 	sc.serveG.check()
 	switch {
@@ -961,11 +931,6 @@ func (sc *serverConn) processSettings(f *SettingsFrame) error {
 	return nil
 }
 
-func (sc *serverConn) writeSettingsAck(uint32, interface{}) error {
-	sc.writeG.check()
-	return sc.framer.WriteSettingsAck()
-}
-
 func (sc *serverConn) processSetting(s Setting) error {
 	sc.serveG.check()
 	if err := s.Valid(); err != nil {
@@ -1263,21 +1228,6 @@ func (sc *serverConn) runHandler(rw *responseWriter, req *http.Request) {
 	sc.handler.ServeHTTP(rw, req)
 }
 
-type frameWriteMsg struct {
-	// write runs on the writeFrameAsync goroutine.
-	write func(sc *serverConn, streamID uint32, v interface{}) error
-
-	v         interface{} // passed to write
-	cost      uint32      // number of flow control bytes required
-	stream    *stream     // used for prioritization
-	endStream bool        // streamID is being closed locally
-
-	// done, if non-nil, must be a buffered channel with space for
-	// 1 message and is sent the return value from write (or an
-	// earlier error) when the frame has been written.
-	done chan error
-}
-
 // headerWriteReq is a request to write an HTTP response header from a server Handler.
 type headerWriteReq struct {
 	stream      *stream
@@ -1302,7 +1252,7 @@ func (sc *serverConn) writeHeaders(req headerWriteReq, tempCh chan error) {
 		errc = tempCh
 	}
 	sc.writeFrameFromHandler(frameWriteMsg{
-		write:     (*serverConn).writeHeadersFrame,
+		write:     writeHeadersFrame,
 		v:         req,
 		stream:    req.stream,
 		done:      errc,
@@ -1319,91 +1269,16 @@ func (sc *serverConn) writeHeaders(req headerWriteReq, tempCh chan error) {
 	}
 }
 
-func (sc *serverConn) writeHeadersFrame(streamID uint32, v interface{}) error {
-	sc.writeG.check()
-	req := v.(headerWriteReq)
-
-	sc.headerWriteBuf.Reset()
-	sc.hpackEncoder.WriteField(hpack.HeaderField{Name: ":status", Value: httpCodeString(req.httpResCode)})
-	for k, vv := range req.h {
-		k = lowerHeader(k)
-		for _, v := range vv {
-			// TODO: more of "8.1.2.2 Connection-Specific Header Fields"
-			if k == "transfer-encoding" && v != "trailers" {
-				continue
-			}
-			sc.hpackEncoder.WriteField(hpack.HeaderField{Name: k, Value: v})
-		}
-	}
-	if req.contentType != "" {
-		sc.hpackEncoder.WriteField(hpack.HeaderField{Name: "content-type", Value: req.contentType})
-	}
-	if req.contentLength != "" {
-		sc.hpackEncoder.WriteField(hpack.HeaderField{Name: "content-length", Value: req.contentLength})
-	}
-
-	headerBlock := sc.headerWriteBuf.Bytes()
-	if len(headerBlock) == 0 {
-		panic("unexpected empty hpack")
-	}
-	first := true
-	for len(headerBlock) > 0 {
-		frag := headerBlock
-		if len(frag) > int(sc.maxWriteFrameSize) {
-			frag = frag[:sc.maxWriteFrameSize]
-		}
-		headerBlock = headerBlock[len(frag):]
-		endHeaders := len(headerBlock) == 0
-		var err error
-		if first {
-			first = false
-			err = sc.framer.WriteHeaders(HeadersFrameParam{
-				StreamID:      req.stream.id,
-				BlockFragment: frag,
-				EndStream:     req.endStream,
-				EndHeaders:    endHeaders,
-			})
-		} else {
-			err = sc.framer.WriteContinuation(req.stream.id, endHeaders, frag)
-		}
-		if err != nil {
-			return err
-		}
-	}
-	return nil
-}
-
 // called from handler goroutines.
 func (sc *serverConn) write100ContinueHeaders(st *stream) {
 	sc.serveG.checkNotOn() // NOT
 	sc.writeFrameFromHandler(frameWriteMsg{
-		write:  (*serverConn).write100ContinueHeadersFrame,
+		write:  write100ContinueHeadersFrame,
+		v:      st,
 		stream: st,
 	})
 }
 
-func (sc *serverConn) write100ContinueHeadersFrame(streamID uint32, _ interface{}) error {
-	sc.writeG.check()
-	sc.headerWriteBuf.Reset()
-	sc.hpackEncoder.WriteField(hpack.HeaderField{Name: ":status", Value: "100"})
-	return sc.framer.WriteHeaders(HeadersFrameParam{
-		StreamID:      streamID,
-		BlockFragment: sc.headerWriteBuf.Bytes(),
-		EndStream:     false,
-		EndHeaders:    true,
-	})
-}
-
-func (sc *serverConn) writeDataFrame(streamID uint32, v interface{}) error {
-	sc.writeG.check()
-	req := v.(*dataWriteParams)
-	return sc.framer.WriteData(streamID, req.end, req.p)
-}
-
-type windowUpdateReq struct {
-	n uint32
-}
-
 // called from handler goroutines
 func (sc *serverConn) sendWindowUpdate(st *stream, n int) {
 	sc.serveG.checkNotOn() // NOT
@@ -1413,33 +1288,21 @@ func (sc *serverConn) sendWindowUpdate(st *stream, n int) {
 	const maxUint32 = 2147483647
 	for n >= maxUint32 {
 		sc.writeFrameFromHandler(frameWriteMsg{
-			write:  (*serverConn).sendWindowUpdateInLoop,
-			v:      windowUpdateReq{maxUint32},
+			write:  writeWindowUpdate,
+			v:      windowUpdateReq{streamID: st.id, n: maxUint32},
 			stream: st,
 		})
 		n -= maxUint32
 	}
 	if n > 0 {
 		sc.writeFrameFromHandler(frameWriteMsg{
-			write:  (*serverConn).sendWindowUpdateInLoop,
-			v:      windowUpdateReq{uint32(n)},
+			write:  writeWindowUpdate,
+			v:      windowUpdateReq{streamID: st.id, n: uint32(n)},
 			stream: st,
 		})
 	}
 }
 
-func (sc *serverConn) sendWindowUpdateInLoop(streamID uint32, v interface{}) error {
-	sc.writeG.check()
-	wu := v.(windowUpdateReq)
-	if err := sc.framer.WriteWindowUpdate(0, wu.n); err != nil {
-		return err
-	}
-	if err := sc.framer.WriteWindowUpdate(streamID, wu.n); err != nil {
-		return err
-	}
-	return nil
-}
-
 type requestBody struct {
 	stream        *stream
 	closed        bool
@@ -1511,6 +1374,7 @@ type responseWriterState struct {
 }
 
 func (rws *responseWriterState) writeData(p []byte, end bool) error {
+	rws.curWrite.streamID = rws.stream.id
 	rws.curWrite.p = p
 	rws.curWrite.end = end
 	return rws.stream.conn.writeData(rws.stream, &rws.curWrite, rws.frameWriteCh)

+ 172 - 0
write.go

@@ -0,0 +1,172 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
+// Licensed under the same terms as Go itself:
+// https://code.google.com/p/go/source/browse/LICENSE
+
+package http2
+
+import (
+	"bytes"
+	"time"
+
+	"github.com/bradfitz/http2/hpack"
+)
+
+// writeContext is the interface needed by the various frame writing
+// functions below. All the functions below are scheduled via the
+// frame writing scheduler (see writeScheduler in writesched.go).
+//
+// This interface is implemented by *serverConn.
+// TODO: use it from the client code, once it exists.
+type writeContext interface {
+	Framer() *Framer
+	Flush() error
+	CloseConn() error
+	// HeaderEncoder returns an HPACK encoder that writes to the
+	// returned buffer.
+	HeaderEncoder() (*hpack.Encoder, *bytes.Buffer)
+}
+
+func flushFrameWriter(ctx writeContext, _ interface{}) error {
+	return ctx.Flush()
+}
+
+func writeSettings(ctx writeContext, v interface{}) error {
+	settings := v.([]Setting)
+	return ctx.Framer().WriteSettings(settings...)
+}
+
+type goAwayParams struct {
+	maxStreamID uint32
+	code        ErrCode
+}
+
+func writeGoAwayFrame(ctx writeContext, v interface{}) error {
+	p := v.(*goAwayParams)
+	err := ctx.Framer().WriteGoAway(p.maxStreamID, p.code, nil)
+	if p.code != 0 {
+		ctx.Flush() // ignore error: we're hanging up on them anyway
+		time.Sleep(50 * time.Millisecond)
+		ctx.CloseConn()
+	}
+	return err
+}
+
+type dataWriteParams struct {
+	streamID uint32
+	p        []byte
+	end      bool
+}
+
+func writeRSTStreamFrame(ctx writeContext, v interface{}) error {
+	se := v.(*StreamError)
+	return ctx.Framer().WriteRSTStream(se.StreamID, se.Code)
+}
+
+func writePingAck(ctx writeContext, v interface{}) error {
+	pf := v.(*PingFrame) // contains the data we need to write back
+	return ctx.Framer().WritePing(true, pf.Data)
+}
+
+func writeSettingsAck(ctx writeContext, _ interface{}) error {
+	return ctx.Framer().WriteSettingsAck()
+}
+
+func writeHeadersFrame(ctx writeContext, v interface{}) error {
+	req := v.(headerWriteReq)
+	enc, buf := ctx.HeaderEncoder()
+	buf.Reset()
+	enc.WriteField(hpack.HeaderField{Name: ":status", Value: httpCodeString(req.httpResCode)})
+	for k, vv := range req.h {
+		k = lowerHeader(k)
+		for _, v := range vv {
+			// TODO: more of "8.1.2.2 Connection-Specific Header Fields"
+			if k == "transfer-encoding" && v != "trailers" {
+				continue
+			}
+			enc.WriteField(hpack.HeaderField{Name: k, Value: v})
+		}
+	}
+	if req.contentType != "" {
+		enc.WriteField(hpack.HeaderField{Name: "content-type", Value: req.contentType})
+	}
+	if req.contentLength != "" {
+		enc.WriteField(hpack.HeaderField{Name: "content-length", Value: req.contentLength})
+	}
+
+	headerBlock := buf.Bytes()
+	if len(headerBlock) == 0 {
+		panic("unexpected empty hpack")
+	}
+
+	// For now we're lazy and just pick the minimum MAX_FRAME_SIZE
+	// that all peers must support (16KB). Later we could care
+	// more and send larger frames if the peer advertised it, but
+	// there's little point. Most headers are small anyway (so we
+	// generally won't have CONTINUATION frames), and extra frames
+	// only waste 9 bytes anyway.
+	const maxFrameSize = 16384
+
+	first := true
+	for len(headerBlock) > 0 {
+		frag := headerBlock
+		if len(frag) > maxFrameSize {
+			frag = frag[:maxFrameSize]
+		}
+		headerBlock = headerBlock[len(frag):]
+		endHeaders := len(headerBlock) == 0
+		var err error
+		if first {
+			first = false
+			err = ctx.Framer().WriteHeaders(HeadersFrameParam{
+				StreamID:      req.stream.id,
+				BlockFragment: frag,
+				EndStream:     req.endStream,
+				EndHeaders:    endHeaders,
+			})
+		} else {
+			err = ctx.Framer().WriteContinuation(req.stream.id, endHeaders, frag)
+		}
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func write100ContinueHeadersFrame(ctx writeContext, v interface{}) error {
+	st := v.(*stream)
+	enc, buf := ctx.HeaderEncoder()
+	buf.Reset()
+	enc.WriteField(hpack.HeaderField{Name: ":status", Value: "100"})
+	return ctx.Framer().WriteHeaders(HeadersFrameParam{
+		StreamID:      st.id,
+		BlockFragment: buf.Bytes(),
+		EndStream:     false,
+		EndHeaders:    true,
+	})
+}
+
+func writeDataFrame(ctx writeContext, v interface{}) error {
+	req := v.(*dataWriteParams)
+	return ctx.Framer().WriteData(req.streamID, req.end, req.p)
+}
+
+type windowUpdateReq struct {
+	streamID uint32
+	n        uint32
+}
+
+func writeWindowUpdate(ctx writeContext, v interface{}) error {
+	wu := v.(windowUpdateReq)
+	fr := ctx.Framer()
+	if err := fr.WriteWindowUpdate(0, wu.n); err != nil {
+		return err
+	}
+	if err := fr.WriteWindowUpdate(wu.streamID, wu.n); err != nil {
+		return err
+	}
+	return nil
+}

+ 22 - 0
writesched.go

@@ -7,6 +7,28 @@
 
 package http2
 
+// frameWriteMsg is a request to write a frame.
+type frameWriteMsg struct {
+	// write is the function that does the writing, once the
+	// writeScheduler (below) has decided to select this frame
+	// to write. The write functions are all defined in write.go.
+	write func(ctx writeContext, v interface{}) error
+
+	// v is the argument passed to the write function. See each
+	// function in write.go to see which type they should be,
+	// depending on what write is.
+	v interface{}
+
+	cost      uint32  // if DATA, number of flow control bytes required
+	stream    *stream // used for prioritization
+	endStream bool    // stream is being closed locally
+
+	// done, if non-nil, must be a buffered channel with space for
+	// 1 message and is sent the return value from write (or an
+	// earlier error) when the frame has been written.
+	done chan error
+}
+
 // writeScheduler tracks pending frames to write, priorities, and decides
 // the next one to use. It is not thread-safe.
 type writeScheduler struct {