Переглянути джерело

Add a failing state transition test.

We're not cleaning up old streams either.
Brad Fitzpatrick 11 роки тому
батько
коміт
0db6d6557b
3 змінених файлів з 101 додано та 0 видалено
  1. 14 0
      http2.go
  2. 8 0
      server.go
  3. 79 0
      server_test.go

+ 14 - 0
http2.go

@@ -56,6 +56,20 @@ const (
 	stateClosed
 )
 
+var stateName = [...]string{
+	stateIdle:             "Idle",
+	stateOpen:             "Open",
+	stateHalfClosedLocal:  "HalfClosedLocal",
+	stateHalfClosedRemote: "HalfClosedRemote",
+	stateResvLocal:        "ResvLocal",
+	stateResvRemote:       "ResvRemote",
+	stateClosed:           "Closed",
+}
+
+func (st streamState) String() string {
+	return stateName[st]
+}
+
 func validHeader(v string) bool {
 	if len(v) == 0 {
 		return false

+ 8 - 0
server.go

@@ -84,6 +84,8 @@ func ConfigureServer(s *http.Server, conf *Server) {
 	}
 }
 
+var testHookGetServerConn func(*serverConn)
+
 func (srv *Server) handleConn(hs *http.Server, c net.Conn, h http.Handler) {
 	sc := &serverConn{
 		hs:                hs,
@@ -104,6 +106,9 @@ func (srv *Server) handleConn(hs *http.Server, c net.Conn, h http.Handler) {
 	}
 	sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf)
 	sc.hpackDecoder = hpack.NewDecoder(initialHeaderTableSize, sc.onNewHeaderField)
+	if hook := testHookGetServerConn; hook != nil {
+		hook(sc)
+	}
 	sc.serve()
 }
 
@@ -130,6 +135,7 @@ type serverConn struct {
 	wantWriteFrameCh chan frameWriteMsg // from handlers -> serve
 	writeFrameCh     chan frameWriteMsg // from serve -> writeFrames
 	wroteFrameCh     chan struct{}      // from writeFrames -> serve, tickles more sends on writeFrameCh
+	testHookCh       chan func()        // code to run on the serve loop
 
 	serveG goroutineLock // used to verify funcs are on serve()
 	writeG goroutineLock // used to verify things running on writeLoop
@@ -370,6 +376,8 @@ func (sc *serverConn) serve() {
 		case <-settingsTimer.C:
 			sc.logf("timeout waiting for SETTINGS frames from %v", sc.conn.RemoteAddr())
 			return
+		case fn := <-sc.testHookCh:
+			fn()
 		}
 	}
 }

+ 79 - 0
server_test.go

@@ -36,6 +36,7 @@ type serverTester struct {
 	ts     *httptest.Server
 	fr     *Framer
 	logBuf *bytes.Buffer
+	sc     *serverConn
 }
 
 func newServerTester(t *testing.T, handler http.HandlerFunc) *serverTester {
@@ -49,6 +50,11 @@ func newServerTester(t *testing.T, handler http.HandlerFunc) *serverTester {
 	if VerboseLogs {
 		t.Logf("Running test server at: %s", ts.URL)
 	}
+	var sc *serverConn
+	testHookGetServerConn = func(v *serverConn) {
+		sc = v
+		sc.testHookCh = make(chan func())
+	}
 	cc, err := tls.Dial("tcp", ts.Listener.Addr().String(), &tls.Config{
 		InsecureSkipVerify: true,
 		NextProtos:         []string{npnProto},
@@ -63,9 +69,26 @@ func newServerTester(t *testing.T, handler http.HandlerFunc) *serverTester {
 		cc:     cc,
 		fr:     NewFramer(cc, cc),
 		logBuf: logBuf,
+		sc:     sc,
 	}
 }
 
+func (st *serverTester) stream(id uint32) *stream {
+	ch := make(chan *stream, 1)
+	st.sc.testHookCh <- func() {
+		ch <- st.sc.streams[id]
+	}
+	return <-ch
+}
+
+func (st *serverTester) streamState(id uint32) streamState {
+	ch := make(chan streamState, 1)
+	st.sc.testHookCh <- func() {
+		ch <- st.sc.state(id)
+	}
+	return <-ch
+}
+
 func (st *serverTester) Close() {
 	st.ts.Close()
 	st.cc.Close()
@@ -841,6 +864,62 @@ func TestServer_DeadConn_Unblocks_Read(t *testing.T) {
 	}
 }
 
+func TestServer_StateTransitions(t *testing.T) {
+	t.Skip("TODO: failing test. fix")
+	var st *serverTester
+	inHandler := make(chan bool)
+	writeData := make(chan bool)
+	leaveHandler := make(chan bool)
+	st = newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+		inHandler <- true
+		if st.stream(1) == nil {
+			t.Errorf("nil stream 1 in handler")
+		}
+		if got := st.streamState(1); got != stateOpen {
+			t.Errorf("in handler, state is %v; want OPEN", got)
+		}
+		writeData <- true
+		if n, err := r.Body.Read(make([]byte, 1)); n != 0 || err != io.EOF {
+			t.Errorf("body read = %d, %v; want 0, EOF", n, err)
+		}
+		if got, want := st.streamState(1), stateHalfClosedRemote; got != want {
+			t.Errorf("in handler, state is %v; want %v", got, want)
+		}
+
+		<-leaveHandler
+	})
+	st.greet()
+	if st.stream(1) != nil {
+		t.Fatal("stream 1 should be empty")
+	}
+	if got := st.streamState(1); got != stateIdle {
+		t.Fatalf("stream 1 should be idle; got %v", got)
+	}
+
+	st.writeHeaders(HeadersFrameParam{
+		StreamID:      1,
+		BlockFragment: encodeHeader(st.t, ":method", "POST"),
+		EndStream:     false, // keep it open
+		EndHeaders:    true,
+	})
+	<-inHandler
+	<-writeData
+	st.writeData(1, true, nil)
+
+	leaveHandler <- true
+	hf := st.wantHeaders()
+	if !hf.StreamEnded() {
+		t.Fatal("expected END_STREAM flag")
+	}
+
+	if got, want := st.streamState(1), stateClosed; got != want {
+		t.Errorf("at end, state is %v; want %v", got, want)
+	}
+	if st.stream(1) != nil {
+		t.Fatal("at end, stream 1 should be gone")
+	}
+}
+
 // TODO: test HEADERS w/o EndHeaders + another HEADERS (should get rejected)
 // TODO: test HEADERS w/ EndHeaders + a continuation HEADERS (should get rejected)