Browse Source

Write flow control token overflow errors on the writeFrames loop, add tests.

Brad Fitzpatrick 11 years ago
parent
commit
55815ec7b5
4 changed files with 91 additions and 14 deletions
  1. 4 4
      frame.go
  2. 1 1
      frame_test.go
  3. 34 9
      server.go
  4. 52 0
      server_test.go

+ 4 - 4
frame.go

@@ -587,7 +587,7 @@ func (f *Framer) WritePing(ack bool, data [8]byte) error {
 type GoAwayFrame struct {
 type GoAwayFrame struct {
 	FrameHeader
 	FrameHeader
 	LastStreamID uint32
 	LastStreamID uint32
-	ErrCode      uint32
+	ErrCode      ErrCode
 	debugData    []byte
 	debugData    []byte
 }
 }
 
 
@@ -610,7 +610,7 @@ func parseGoAwayFrame(fh FrameHeader, p []byte) (Frame, error) {
 	return &GoAwayFrame{
 	return &GoAwayFrame{
 		FrameHeader:  fh,
 		FrameHeader:  fh,
 		LastStreamID: binary.BigEndian.Uint32(p[:4]) & (1<<31 - 1),
 		LastStreamID: binary.BigEndian.Uint32(p[:4]) & (1<<31 - 1),
-		ErrCode:      binary.BigEndian.Uint32(p[4:8]),
+		ErrCode:      ErrCode(binary.BigEndian.Uint32(p[4:8])),
 		debugData:    p[8:],
 		debugData:    p[8:],
 	}, nil
 	}, nil
 }
 }
@@ -900,12 +900,12 @@ func parseRSTStreamFrame(fh FrameHeader, p []byte) (Frame, error) {
 //
 //
 // It will perform exactly one Write to the underlying Writer.
 // It will perform exactly one Write to the underlying Writer.
 // It is the caller's responsibility to not call other Write methods concurrently.
 // It is the caller's responsibility to not call other Write methods concurrently.
-func (f *Framer) WriteRSTStream(streamID, errCode uint32) error {
+func (f *Framer) WriteRSTStream(streamID uint32, code ErrCode) error {
 	if !validStreamID(streamID) && !f.AllowIllegalWrites {
 	if !validStreamID(streamID) && !f.AllowIllegalWrites {
 		return errStreamID
 		return errStreamID
 	}
 	}
 	f.startWrite(FrameRSTStream, 0, streamID)
 	f.startWrite(FrameRSTStream, 0, streamID)
-	f.writeUint32(errCode)
+	f.writeUint32(uint32(code))
 	return f.endWrite()
 	return f.endWrite()
 }
 }
 
 

+ 1 - 1
frame_test.go

@@ -19,7 +19,7 @@ func TestWriteRST(t *testing.T) {
 	fr, buf := testFramer()
 	fr, buf := testFramer()
 	var streamID uint32 = 1<<24 + 2<<16 + 3<<8 + 4
 	var streamID uint32 = 1<<24 + 2<<16 + 3<<8 + 4
 	var errCode uint32 = 7<<24 + 6<<16 + 5<<8 + 4
 	var errCode uint32 = 7<<24 + 6<<16 + 5<<8 + 4
-	fr.WriteRSTStream(streamID, errCode)
+	fr.WriteRSTStream(streamID, ErrCode(errCode))
 	const wantEnc = "\x00\x00\x04\x03\x00\x01\x02\x03\x04\x07\x06\x05\x04"
 	const wantEnc = "\x00\x00\x04\x03\x00\x01\x02\x03\x04\x07\x06\x05\x04"
 	if buf.String() != wantEnc {
 	if buf.String() != wantEnc {
 		t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
 		t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)

+ 34 - 9
server.go

@@ -457,21 +457,49 @@ func (sc *serverConn) scheduleFrameWrite() {
 	sc.writeFrameCh <- wm
 	sc.writeFrameCh <- wm
 }
 }
 
 
-func (sc *serverConn) goAway(code ErrCode) error {
+func (sc *serverConn) goAway(code ErrCode) {
 	sc.serveG.check()
 	sc.serveG.check()
+	if sc.sentGoAway {
+		return
+	}
 	sc.sentGoAway = true
 	sc.sentGoAway = true
-	return sc.framer.WriteGoAway(sc.maxStreamID, code, nil)
+	// TODO: set a timer to see if they're gone at some point?
+	sc.enqueueFrameWrite(frameWriteMsg{
+		write: (*serverConn).writeGoAwayFrame,
+		v: &goAwayParams{
+			maxStreamID: sc.maxStreamID,
+			code:        code,
+		},
+	})
+}
+
+type goAwayParams struct {
+	maxStreamID uint32
+	code        ErrCode
+}
+
+func (sc *serverConn) writeGoAwayFrame(v interface{}) error {
+	sc.writeG.check()
+	p := v.(*goAwayParams)
+	return sc.framer.WriteGoAway(p.maxStreamID, p.code, nil)
 }
 }
 
 
 func (sc *serverConn) resetStreamInLoop(se StreamError) error {
 func (sc *serverConn) resetStreamInLoop(se StreamError) error {
 	sc.serveG.check()
 	sc.serveG.check()
-	if err := sc.framer.WriteRSTStream(se.streamID, uint32(se.code)); err != nil {
-		return err
-	}
 	delete(sc.streams, se.streamID)
 	delete(sc.streams, se.streamID)
+	sc.enqueueFrameWrite(frameWriteMsg{
+		write: (*serverConn).writeRSTStreamFrame,
+		v:     &se,
+	})
 	return nil
 	return nil
 }
 }
 
 
+func (sc *serverConn) writeRSTStreamFrame(v interface{}) error {
+	sc.writeG.check()
+	se := v.(*StreamError)
+	return sc.framer.WriteRSTStream(se.streamID, se.code)
+}
+
 func (sc *serverConn) curHeaderStreamID() uint32 {
 func (sc *serverConn) curHeaderStreamID() uint32 {
 	sc.serveG.check()
 	sc.serveG.check()
 	st := sc.req.stream
 	st := sc.req.stream
@@ -515,10 +543,7 @@ func (sc *serverConn) processFrameFromReader(fg frameAndGate, fgValid bool) bool
 		}
 		}
 		return true
 		return true
 	case goAwayFlowError:
 	case goAwayFlowError:
-		if err := sc.goAway(ErrCodeFlowControl); err != nil {
-			sc.condlogf(err, "failed to GOAWAY: %v", err)
-			return false
-		}
+		sc.goAway(ErrCodeFlowControl)
 		return true
 		return true
 	case ConnectionError:
 	case ConnectionError:
 		sc.logf("disconnecting; %v", ev)
 		sc.logf("disconnecting; %v", ev)

+ 52 - 0
server_test.go

@@ -196,6 +196,18 @@ func (st *serverTester) wantPing() *PingFrame {
 	return pf
 	return pf
 }
 }
 
 
+func (st *serverTester) wantGoAway() *GoAwayFrame {
+	f, err := st.readFrame()
+	if err != nil {
+		st.t.Fatalf("Error while expecting a PING frame: %v", err)
+	}
+	gf, ok := f.(*GoAwayFrame)
+	if !ok {
+		st.t.Fatalf("got a %T; want *GoAwayFrame", f)
+	}
+	return gf
+}
+
 func (st *serverTester) wantRSTStream(streamID uint32, errCode ErrCode) {
 func (st *serverTester) wantRSTStream(streamID uint32, errCode ErrCode) {
 	f, err := st.readFrame()
 	f, err := st.readFrame()
 	if err != nil {
 	if err != nil {
@@ -728,6 +740,46 @@ func TestServer_Handler_Sends_WindowUpdate(t *testing.T) {
 	st.wantWindowUpdate(1, 3)
 	st.wantWindowUpdate(1, 3)
 }
 }
 
 
+func TestServer_Send_GoAway_After_Bogus_WindowUpdate(t *testing.T) {
+	st := newServerTester(t, nil)
+	defer st.Close()
+	st.greet()
+	if err := st.fr.WriteWindowUpdate(0, 1<<31-1); err != nil {
+		t.Fatal(err)
+	}
+	gf := st.wantGoAway()
+	if gf.ErrCode != ErrCodeFlowControl {
+		t.Errorf("GOAWAY err = %v; want %v", gf.ErrCode, ErrCodeFlowControl)
+	}
+	if gf.LastStreamID != 0 {
+		t.Errorf("GOAWAY last stream ID = %v; want %v", gf.LastStreamID, 0)
+	}
+}
+
+func TestServer_Send_RstStream_After_Bogus_WindowUpdate(t *testing.T) {
+	inHandler := make(chan bool)
+	blockHandler := make(chan bool)
+	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+		inHandler <- true
+		<-blockHandler
+	})
+	defer st.Close()
+	defer close(blockHandler)
+	st.greet()
+	st.writeHeaders(HeadersFrameParam{
+		StreamID:      1,
+		BlockFragment: encodeHeader(st.t, ":method", "POST"),
+		EndStream:     false, // keep it open
+		EndHeaders:    true,
+	})
+	<-inHandler
+	// Send a bogus window update:
+	if err := st.fr.WriteWindowUpdate(1, 1<<31-1); err != nil {
+		t.Fatal(err)
+	}
+	st.wantRSTStream(1, ErrCodeFlowControl)
+}
+
 // TODO: test HEADERS w/o EndHeaders + another HEADERS (should get rejected)
 // TODO: test HEADERS w/o EndHeaders + another HEADERS (should get rejected)
 // TODO: test HEADERS w/ EndHeaders + a continuation HEADERS (should get rejected)
 // TODO: test HEADERS w/ EndHeaders + a continuation HEADERS (should get rejected)