Browse Source

Flow Control work.

This defines the counting semaphore type, processes WINDOW_UPDATE
frames, processes SETTINGS frames more, etc.

Untested, since we don't yet write DATA frames. That's next.
Brad Fitzpatrick 11 years ago
parent
commit
03abeab17b
7 changed files with 256 additions and 21 deletions
  1. 9 0
      errors.go
  2. 56 0
      flow.go
  3. 67 0
      flow_test.go
  4. 10 2
      frame.go
  5. 2 1
      frame_test.go
  6. 110 17
      http2.go
  7. 2 1
      http2_test.go

+ 9 - 0
errors.go

@@ -65,3 +65,12 @@ type StreamError struct {
 func (e StreamError) Error() string {
 func (e StreamError) Error() string {
 	return fmt.Sprintf("stream error: stream ID %d; %v", e.streamID, e.code)
 	return fmt.Sprintf("stream error: stream ID %d; %v", e.streamID, e.code)
 }
 }
+
+// 6.9.1 The Flow Control Window
+// "If a sender receives a WINDOW_UPDATE that causes a flow control
+// window to exceed this maximum it MUST terminate either the stream
+// or the connection, as appropriate. For streams, [...]; for the
+// connection, a GOAWAY frame with a FLOW_CONTROL_ERROR code."
+type goAwayFlowError struct{}
+
+func (goAwayFlowError) Error() string { return "connection exceeded flow control window size" }

+ 56 - 0
flow.go

@@ -0,0 +1,56 @@
+// Copyright 2014 The Go Authors.
+// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
+// Licensed under the same terms as Go itself:
+// https://code.google.com/p/go/source/browse/LICENSE
+
+// Flow control
+
+package http2
+
+import "sync"
+
+// flow is the flow control window's counting semaphore.
+type flow struct {
+	c    *sync.Cond // protects size
+	size int32
+}
+
+func newFlow(n int32) *flow {
+	return &flow{
+		c:    sync.NewCond(new(sync.Mutex)),
+		size: n,
+	}
+}
+
+// acquire decrements the flow control window by n bytes, blocking
+// until they're available in the window.
+// The return value is only interesting for tests.
+func (f *flow) acquire(n int32) (waited int) {
+	if n < 0 {
+		panic("negative acquire")
+	}
+	f.c.L.Lock()
+	defer f.c.L.Unlock()
+	for {
+		if f.size >= n {
+			f.size -= n
+			return
+		}
+		waited++
+		f.c.Wait()
+	}
+}
+
+// add adds n bytes (positive or negative) to the flow control window.
+// It returns false if the sum would exceed 2^31-1.
+func (f *flow) add(n int32) bool {
+	f.c.L.Lock()
+	defer f.c.L.Unlock()
+	remain := (1<<31 - 1) - f.size
+	if n > remain {
+		return false
+	}
+	f.size += n
+	f.c.Broadcast()
+	return true
+}

+ 67 - 0
flow_test.go

@@ -0,0 +1,67 @@
+// Copyright 2014 The Go Authors.
+// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
+// Licensed under the same terms as Go itself:
+// https://code.google.com/p/go/source/browse/LICENSE
+
+package http2
+
+import (
+	"testing"
+	"time"
+)
+
+func (f *flow) cur() int32 {
+	f.c.L.Lock()
+	defer f.c.L.Unlock()
+	return f.size
+}
+
+func TestFlow(t *testing.T) {
+	f := newFlow(10)
+	if got, want := f.cur(), int32(10); got != want {
+		t.Fatalf("size = %d; want %d", got, want)
+	}
+	if waits := f.acquire(1); waits != 0 {
+		t.Errorf("waits = %d; want 0", waits)
+	}
+	if got, want := f.cur(), int32(9); got != want {
+		t.Fatalf("size = %d; want %d", got, want)
+	}
+
+	// Wait for 10, which should block, so start a background goroutine
+	// to refill it.
+	go func() {
+		time.Sleep(50 * time.Millisecond)
+		f.add(50)
+	}()
+	if waits := f.acquire(10); waits != 1 {
+		t.Errorf("waits for 50 = %d; want 0", waits)
+	}
+
+	if got, want := f.cur(), int32(49); got != want {
+		t.Fatalf("size = %d; want %d", got, want)
+	}
+}
+
+func TestFlowAdd(t *testing.T) {
+	f := newFlow(0)
+	if !f.add(1) {
+		t.Fatal("failed to add 1")
+	}
+	if !f.add(-1) {
+		t.Fatal("failed to add -1")
+	}
+	if got, want := f.cur(), int32(0); got != want {
+		t.Fatalf("size = %d; want %d", got, want)
+	}
+	if !f.add(1<<31 - 1) {
+		t.Fatal("failed to add 2^31-1")
+	}
+	if got, want := f.cur(), int32(1<<31-1); got != want {
+		t.Fatalf("size = %d; want %d", got, want)
+	}
+	if f.add(1) {
+		t.Fatal("adding 1 to max shouldn't be allowed")
+	}
+
+}

+ 10 - 2
frame.go

@@ -493,13 +493,21 @@ func (f *SettingsFrame) Value(s SettingID) (v uint32, ok bool) {
 	return 0, false
 	return 0, false
 }
 }
 
 
-func (f *SettingsFrame) ForeachSetting(fn func(Setting)) {
+// ForeachSetting runs fn for each setting.
+// It stops and returns the first error.
+func (f *SettingsFrame) ForeachSetting(fn func(Setting) error) error {
 	f.checkValid()
 	f.checkValid()
 	buf := f.p
 	buf := f.p
 	for len(buf) > 0 {
 	for len(buf) > 0 {
-		fn(Setting{SettingID(binary.BigEndian.Uint16(buf[:2])), binary.BigEndian.Uint32(buf[2:6])})
+		if err := fn(Setting{
+			SettingID(binary.BigEndian.Uint16(buf[:2])),
+			binary.BigEndian.Uint32(buf[2:6]),
+		}); err != nil {
+			return err
+		}
 		buf = buf[6:]
 		buf = buf[6:]
 	}
 	}
+	return nil
 }
 }
 
 
 // Setting is a setting parameter: which setting it is, and its value.
 // Setting is a setting parameter: which setting it is, and its value.

+ 2 - 1
frame_test.go

@@ -332,12 +332,13 @@ func TestWriteSettings(t *testing.T) {
 		t.Fatalf("Got a %T; want a SettingsFrame", f)
 		t.Fatalf("Got a %T; want a SettingsFrame", f)
 	}
 	}
 	var got []Setting
 	var got []Setting
-	sf.ForeachSetting(func(s Setting) {
+	sf.ForeachSetting(func(s Setting) error {
 		got = append(got, s)
 		got = append(got, s)
 		valBack, ok := sf.Value(s.ID)
 		valBack, ok := sf.Value(s.ID)
 		if !ok || valBack != s.Val {
 		if !ok || valBack != s.Val {
 			t.Errorf("Value(%d) = %v, %v; want %v, true", s.ID, valBack, ok)
 			t.Errorf("Value(%d) = %v, %v; want %v, true", s.ID, valBack, ok)
 		}
 		}
+		return nil
 	})
 	})
 	if !reflect.DeepEqual(settings, got) {
 	if !reflect.DeepEqual(settings, got) {
 		t.Errorf("Read settings %+v != written settings %+v", got, settings)
 		t.Errorf("Read settings %+v != written settings %+v", got, settings)

+ 110 - 17
http2.go

@@ -16,6 +16,9 @@
 // This package currently targets draft-14. See http://http2.github.io/
 // This package currently targets draft-14. See http://http2.github.io/
 package http2
 package http2
 
 
+// TODO: finish GOAWAY support. Consider each incoming frame type and whether
+// it should be ignored during a shutdown race.
+
 import (
 import (
 	"bytes"
 	"bytes"
 	"crypto/tls"
 	"crypto/tls"
@@ -41,17 +44,17 @@ const (
 	// SETTINGS_MAX_FRAME_SIZE default
 	// SETTINGS_MAX_FRAME_SIZE default
 	// http://http2.github.io/http2-spec/#rfc.section.6.5.2
 	// http://http2.github.io/http2-spec/#rfc.section.6.5.2
 	initialMaxFrameSize = 16384
 	initialMaxFrameSize = 16384
-)
-
-var (
-	clientPreface = []byte(ClientPreface)
-)
 
 
-const (
 	npnProto = "h2-14"
 	npnProto = "h2-14"
 
 
 	// http://http2.github.io/http2-spec/#SettingValues
 	// http://http2.github.io/http2-spec/#SettingValues
 	initialHeaderTableSize = 4096
 	initialHeaderTableSize = 4096
+
+	initialWindowSize = 65535 // 6.9.2 Initial Flow Control Window Size
+)
+
+var (
+	clientPreface = []byte(ClientPreface)
 )
 )
 
 
 // Server is an HTTP2 server.
 // Server is an HTTP2 server.
@@ -71,8 +74,10 @@ func (srv *Server) handleConn(hs *http.Server, c net.Conn, h http.Handler) {
 		readFrameCh:       make(chan frameAndProcessed),
 		readFrameCh:       make(chan frameAndProcessed),
 		readFrameErrCh:    make(chan error, 1),
 		readFrameErrCh:    make(chan error, 1),
 		writeHeaderCh:     make(chan headerWriteReq), // must not be buffered
 		writeHeaderCh:     make(chan headerWriteReq), // must not be buffered
+		flow:              newFlow(initialWindowSize),
 		doneServing:       make(chan struct{}),
 		doneServing:       make(chan struct{}),
 		maxWriteFrameSize: initialMaxFrameSize,
 		maxWriteFrameSize: initialMaxFrameSize,
+		initialWindowSize: initialWindowSize,
 		serveG:            newGoroutineLock(),
 		serveG:            newGoroutineLock(),
 	}
 	}
 	sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf)
 	sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf)
@@ -102,6 +107,7 @@ type serverConn struct {
 	readFrameErrCh chan error
 	readFrameErrCh chan error
 	writeHeaderCh  chan headerWriteReq // must not be buffered
 	writeHeaderCh  chan headerWriteReq // must not be buffered
 	serveG         goroutineLock       // used to verify funcs are on serve()
 	serveG         goroutineLock       // used to verify funcs are on serve()
+	flow           *flow               // the connection-wide one
 
 
 	// Everything following is owned by the serve loop; use serveG.check()
 	// Everything following is owned by the serve loop; use serveG.check()
 
 
@@ -109,6 +115,8 @@ type serverConn struct {
 	streams     map[uint32]*stream
 	streams     map[uint32]*stream
 
 
 	maxWriteFrameSize uint32 // TODO: update this when settings come in
 	maxWriteFrameSize uint32 // TODO: update this when settings come in
+	initialWindowSize int32
+	sentGoAway        bool
 
 
 	// State related to parsing current headers:
 	// State related to parsing current headers:
 	header            http.Header
 	header            http.Header
@@ -142,6 +150,7 @@ const (
 type stream struct {
 type stream struct {
 	id    uint32
 	id    uint32
 	state streamState // owned by serverConn's processing loop
 	state streamState // owned by serverConn's processing loop
+	flow  *flow
 }
 }
 
 
 func (sc *serverConn) state(streamID uint32) streamState {
 func (sc *serverConn) state(streamID uint32) streamState {
@@ -296,9 +305,10 @@ func (sc *serverConn) serve() {
 		sc.logf("invalid initial frame type %T received from client", f)
 		sc.logf("invalid initial frame type %T received from client", f)
 		return
 		return
 	}
 	}
-	sf.ForeachSetting(func(s Setting) {
-		// TODO: process, record
-	})
+	if err := sf.ForeachSetting(sc.processSetting); err != nil {
+		sc.logf("initial settings error: %v", err)
+		return
+	}
 
 
 	// TODO: don't send two network packets for our SETTINGS + our
 	// TODO: don't send two network packets for our SETTINGS + our
 	// ACK of their settings.  But if we make framer write to a
 	// ACK of their settings.  But if we make framer write to a
@@ -351,6 +361,11 @@ func (sc *serverConn) serve() {
 			case ConnectionError:
 			case ConnectionError:
 				sc.logf("Disconnecting; %v", ev)
 				sc.logf("Disconnecting; %v", ev)
 				return
 				return
+			case goAwayFlowError:
+				if err := sc.goAway(ErrCodeFlowControl); err != nil {
+					sc.condlogf(err, "failed to GOAWAY: %v", err)
+					return
+				}
 			default:
 			default:
 				sc.logf("Disconnection due to other error: %v", err)
 				sc.logf("Disconnection due to other error: %v", err)
 				return
 				return
@@ -359,6 +374,12 @@ func (sc *serverConn) serve() {
 	}
 	}
 }
 }
 
 
+func (sc *serverConn) goAway(code ErrCode) error {
+	sc.serveG.check()
+	sc.sentGoAway = true
+	return sc.framer.WriteGoAway(sc.maxStreamID, code, nil)
+}
+
 func (sc *serverConn) resetStreamInLoop(se StreamError) error {
 func (sc *serverConn) resetStreamInLoop(se StreamError) error {
 	sc.serveG.check()
 	sc.serveG.check()
 	if err := sc.framer.WriteRSTStream(se.streamID, uint32(se.code)); err != nil {
 	if err := sc.framer.WriteRSTStream(se.streamID, uint32(se.code)); err != nil {
@@ -386,6 +407,8 @@ func (sc *serverConn) processFrame(f Frame) error {
 		return sc.processHeaders(f)
 		return sc.processHeaders(f)
 	case *ContinuationFrame:
 	case *ContinuationFrame:
 		return sc.processContinuation(f)
 		return sc.processContinuation(f)
+	case *WindowUpdateFrame:
+		return sc.processWindowUpdate(f)
 	case *PingFrame:
 	case *PingFrame:
 		return sc.processPing(f)
 		return sc.processPing(f)
 	default:
 	default:
@@ -397,32 +420,101 @@ func (sc *serverConn) processFrame(f Frame) error {
 func (sc *serverConn) processPing(f *PingFrame) error {
 func (sc *serverConn) processPing(f *PingFrame) error {
 	sc.serveG.check()
 	sc.serveG.check()
 	if f.Flags.Has(FlagSettingsAck) {
 	if f.Flags.Has(FlagSettingsAck) {
-		// 6.7 PING: " An endpoint MUST NOT respond to PING frames containing this flag."
+		// 6.7 PING: " An endpoint MUST NOT respond to PING frames
+		// containing this flag."
 		return nil
 		return nil
 	}
 	}
 	if f.StreamID != 0 {
 	if f.StreamID != 0 {
 		// "PING frames are not associated with any individual
 		// "PING frames are not associated with any individual
 		// stream. If a PING frame is received with a stream
 		// stream. If a PING frame is received with a stream
-		// identifier field value other than 0x0, the
-		// recipient MUST respond with a connection error
-		// (Section 5.4.1) of type PROTOCOL_ERROR."
+		// identifier field value other than 0x0, the recipient MUST
+		// respond with a connection error (Section 5.4.1) of type
+		// PROTOCOL_ERROR."
 		return ConnectionError(ErrCodeProtocol)
 		return ConnectionError(ErrCodeProtocol)
 	}
 	}
 	return sc.framer.WritePing(true, f.Data)
 	return sc.framer.WritePing(true, f.Data)
 }
 }
 
 
+func (sc *serverConn) processWindowUpdate(f *WindowUpdateFrame) error {
+	sc.serveG.check()
+	switch {
+	case f.StreamID != 0: // stream-level flow control
+		st := sc.streams[f.StreamID]
+		if st == nil {
+			// "WINDOW_UPDATE can be sent by a peer that has sent a
+			// frame bearing the END_STREAM flag. This means that a
+			// receiver could receive a WINDOW_UPDATE frame on a "half
+			// closed (remote)" or "closed" stream. A receiver MUST
+			// NOT treat this as an error, see Section 5.1."
+			return nil
+		}
+		if !st.flow.add(int32(f.Increment)) {
+			return StreamError{f.StreamID, ErrCodeFlowControl}
+		}
+	default: // connection-level flow control
+		if !sc.flow.add(int32(f.Increment)) {
+			return goAwayFlowError{}
+		}
+	}
+	return nil
+}
+
 func (sc *serverConn) processSettings(f *SettingsFrame) error {
 func (sc *serverConn) processSettings(f *SettingsFrame) error {
 	sc.serveG.check()
 	sc.serveG.check()
-	f.ForeachSetting(func(s Setting) {
-		log.Printf("  setting %s = %v", s.ID, s.Val)
-	})
+	return f.ForeachSetting(sc.processSetting)
+}
+
+func (sc *serverConn) processSetting(s Setting) error {
+	sc.serveG.check()
+	sc.vlogf("processing setting %v", s)
+	switch s.ID {
+	case SettingInitialWindowSize:
+		return sc.processSettingInitialWindowSize(s.Val)
+	}
+	log.Printf("TODO: handle %v", s)
+	return nil
+}
+
+func (sc *serverConn) processSettingInitialWindowSize(val uint32) error {
+	sc.serveG.check()
+	if val > (1<<31 - 1) {
+		// 6.5.2 Defined SETTINGS Parameters
+		// "Values above the maximum flow control window size of
+		// 231-1 MUST be treated as a connection error (Section
+		// 5.4.1) of type FLOW_CONTROL_ERROR."
+		return ConnectionError(ErrCodeFlowControl)
+	}
+
+	// "A SETTINGS frame can alter the initial flow control window
+	// size for all current streams. When the value of
+	// SETTINGS_INITIAL_WINDOW_SIZE changes, a receiver MUST
+	// adjust the size of all stream flow control windows that it
+	// maintains by the difference between the new value and the
+	// old value."
+	old := sc.initialWindowSize
+	sc.initialWindowSize = int32(val)
+	growth := sc.initialWindowSize - old // may be negative
+	for _, st := range sc.streams {
+		if !st.flow.add(growth) {
+			// 6.9.2 Initial Flow Control Window Size
+			// "An endpoint MUST treat a change to
+			// SETTINGS_INITIAL_WINDOW_SIZE that causes any flow
+			// control window to exceed the maximum size as a
+			// connection error (Section 5.4.1) of type
+			// FLOW_CONTROL_ERROR."
+			return ConnectionError(ErrCodeFlowControl)
+		}
+	}
 	return nil
 	return nil
 }
 }
 
 
 func (sc *serverConn) processHeaders(f *HeadersFrame) error {
 func (sc *serverConn) processHeaders(f *HeadersFrame) error {
 	sc.serveG.check()
 	sc.serveG.check()
 	id := f.Header().StreamID
 	id := f.Header().StreamID
-
+	if sc.sentGoAway {
+		// Ignore.
+		return nil
+	}
 	// http://http2.github.io/http2-spec/#rfc.section.5.1.1
 	// http://http2.github.io/http2-spec/#rfc.section.5.1.1
 	if id%2 != 1 || id <= sc.maxStreamID {
 	if id%2 != 1 || id <= sc.maxStreamID {
 		// Streams initiated by a client MUST use odd-numbered
 		// Streams initiated by a client MUST use odd-numbered
@@ -441,6 +533,7 @@ func (sc *serverConn) processHeaders(f *HeadersFrame) error {
 	st := &stream{
 	st := &stream{
 		id:    id,
 		id:    id,
 		state: stateOpen,
 		state: stateOpen,
+		flow:  newFlow(sc.initialWindowSize),
 	}
 	}
 	if f.Header().Flags.Has(FlagHeadersEndStream) {
 	if f.Header().Flags.Has(FlagHeadersEndStream) {
 		st.state = stateHalfClosedRemote
 		st.state = stateHalfClosedRemote

+ 2 - 1
http2_test.go

@@ -226,8 +226,9 @@ func TestServer(t *testing.T) {
 
 
 	st.writePreface()
 	st.writePreface()
 	st.writeInitialSettings()
 	st.writeInitialSettings()
-	st.wantSettings().ForeachSetting(func(s Setting) {
+	st.wantSettings().ForeachSetting(func(s Setting) error {
 		t.Logf("Server sent setting %v = %v", s.ID, s.Val)
 		t.Logf("Server sent setting %v = %v", s.ID, s.Val)
+		return nil
 	})
 	})
 	st.writeSettingsAck()
 	st.writeSettingsAck()
 	st.wantSettingsAck()
 	st.wantSettingsAck()