Browse Source

More tests, clean up test log noise, fix a GOAWAY bug.

The bug, found via new tests: if the ReadFrames goroutine read a frame
from the Framer resulting in an error of type ConnectionError, the
code specific to handling ConnectionError wasn't getting to handle it.
We always assumed that all errors coming from the Framer were related
to connection errors (io.EOF, etc). So now fall through past the
normal frame-processing logic and handle all errors the same way. That
means we now respond with a GOAWAY frame to clients who sent us bogus
frames rejected by the framer level too.
Brad Fitzpatrick 11 years ago
parent
commit
8ec321e392
4 changed files with 206 additions and 44 deletions
  1. 3 0
      frame.go
  2. 24 1
      http2_test.go
  3. 27 19
      server.go
  4. 152 24
      server_test.go

+ 3 - 0
frame.go

@@ -779,6 +779,9 @@ type HeadersFrameParam struct {
 // It will perform exactly one Write to the underlying Writer.
 // It is the caller's responsibility to not call other Write methods concurrently.
 func (f *Framer) WriteHeaders(p HeadersFrameParam) error {
+	if !validStreamID(p.StreamID) && !f.AllowIllegalWrites {
+		return errStreamID
+	}
 	var flags Flags
 	if p.PadLength != 0 {
 		flags |= FlagHeadersPadded

+ 24 - 1
http2_test.go

@@ -42,10 +42,19 @@ func TestSettingString(t *testing.T) {
 }
 
 type twriter struct {
-	t testing.TB
+	t  testing.TB
+	st *serverTester // optional
 }
 
 func (w twriter) Write(p []byte) (n int, err error) {
+	if w.st != nil {
+		ps := string(p)
+		for _, phrase := range w.st.logFilter {
+			if strings.Contains(ps, phrase) {
+				return len(p), nil // no logging
+			}
+		}
+	}
 	w.t.Logf("%s", p)
 	return len(p), nil
 }
@@ -95,6 +104,20 @@ func encodeHeader(t *testing.T, headers ...string) []byte {
 	return buf.Bytes()
 }
 
+// like encodeHeader, but don't add implicit psuedo headers.
+func encodeHeaderNoImplicit(t *testing.T, headers ...string) []byte {
+	var buf bytes.Buffer
+	enc := hpack.NewEncoder(&buf)
+	for len(headers) > 0 {
+		k, v := headers[0], headers[1]
+		headers = headers[2:]
+		if err := enc.WriteField(hpack.HeaderField{Name: k, Value: v}); err != nil {
+			t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err)
+		}
+	}
+	return buf.Bytes()
+}
+
 // Verify that curl has http2.
 func requireCurl(t *testing.T) {
 	out, err := dockerLogs(curl(t, "--version"))

+ 27 - 19
server.go

@@ -403,7 +403,7 @@ func (sc *serverConn) readFrames() {
 	for {
 		f, err := sc.framer.ReadFrame()
 		if err != nil {
-			sc.readFrameErrCh <- err // BEFORE the close
+			sc.readFrameErrCh <- err
 			close(sc.readFrameCh)
 			return
 		}
@@ -788,29 +788,33 @@ func (sc *serverConn) curHeaderStreamID() uint32 {
 // processFrameFromReader returns whether the connection should be kept open.
 func (sc *serverConn) processFrameFromReader(fg frameAndGate, fgValid bool) bool {
 	sc.serveG.check()
+	var clientGone bool
+	var err error
 	if !fgValid {
-		err := <-sc.readFrameErrCh
+		err = <-sc.readFrameErrCh
 		if err == ErrFrameTooLarge {
 			sc.goAway(ErrCodeFrameSize)
 			return true // goAway will close the loop
 		}
-		if err != io.EOF {
-			errstr := err.Error()
-			if !strings.Contains(errstr, "use of closed network connection") {
-				sc.logf("client %s stopped sending frames: %v", sc.conn.RemoteAddr(), errstr)
-			}
+		clientGone = err == io.EOF || strings.Contains(err.Error(), "use of closed network connection")
+		if clientGone {
+			// TODO: could we also get into this state if
+			// the peer does a half close
+			// (e.g. CloseWrite) because they're done
+			// sending frames but they're still wanting
+			// our open replies?  Investigate.
+			return false
+		}
+	}
+
+	if fgValid {
+		f := fg.f
+		sc.vlogf("got %v: %#v", f.Header(), f)
+		err = sc.processFrame(f)
+		fg.g.Done() // unblock the readFrames goroutine
+		if err == nil {
+			return true
 		}
-		// TODO: could we also get into this state if the peer does a half close (e.g. CloseWrite)
-		// because they're done sending frames but they're still wanting our open replies?
-		// Investigate.
-		return false
-	}
-	f := fg.f
-	sc.vlogf("got %v: %#v", f.Header(), f)
-	err := sc.processFrame(f)
-	fg.g.Done() // unblock the readFrames goroutine
-	if err == nil {
-		return true
 	}
 
 	switch ev := err.(type) {
@@ -825,7 +829,11 @@ func (sc *serverConn) processFrameFromReader(fg frameAndGate, fgValid bool) bool
 		sc.goAway(ErrCode(ev))
 		return true // goAway will handle shutdown
 	default:
-		sc.logf("disconnection due to other error: %v", err)
+		if !fgValid {
+			sc.logf("disconnecting; error reading frame from client %s: %v", sc.conn.RemoteAddr(), err)
+		} else {
+			sc.logf("disconnection due to other error: %v", err)
+		}
 	}
 	return false
 }

+ 152 - 24
server_test.go

@@ -32,20 +32,28 @@ import (
 )
 
 type serverTester struct {
-	cc     net.Conn // client conn
-	t      *testing.T
-	ts     *httptest.Server
-	fr     *Framer
-	logBuf *bytes.Buffer
-	sc     *serverConn
+	cc        net.Conn // client conn
+	t         *testing.T
+	ts        *httptest.Server
+	fr        *Framer
+	logBuf    *bytes.Buffer
+	sc        *serverConn
+	logFilter []string // substrings to filter out
 }
 
 func newServerTester(t *testing.T, handler http.HandlerFunc) *serverTester {
 	logBuf := new(bytes.Buffer)
 	ts := httptest.NewUnstartedServer(handler)
 	ConfigureServer(ts.Config, &Server{})
+
+	st := &serverTester{
+		t:      t,
+		ts:     ts,
+		logBuf: logBuf,
+	}
+
 	ts.TLS = ts.Config.TLSConfig // the httptest.Server has its own copy of this TLS config
-	ts.Config.ErrorLog = log.New(io.MultiWriter(twriter{t: t}, logBuf), "", log.LstdFlags)
+	ts.Config.ErrorLog = log.New(io.MultiWriter(twriter{t: t, st: st}, logBuf), "", log.LstdFlags)
 	ts.StartTLS()
 
 	if VerboseLogs {
@@ -68,16 +76,19 @@ func newServerTester(t *testing.T, handler http.HandlerFunc) *serverTester {
 	if err != nil {
 		t.Fatal(err)
 	}
-	log.SetOutput(twriter{t})
+	log.SetOutput(twriter{t: t, st: st})
+
+	st.cc = cc
+	st.fr = NewFramer(cc, cc)
+
 	mu.Lock()
-	return &serverTester{
-		t:      t,
-		ts:     ts,
-		cc:     cc,
-		fr:     NewFramer(cc, cc),
-		logBuf: logBuf,
-		sc:     sc,
-	}
+	st.sc = sc
+	mu.Unlock() // unnecessary, but looks weird without.
+	return st
+}
+
+func (st *serverTester) addLogFilter(phrase string) {
+	st.logFilter = append(st.logFilter, phrase)
 }
 
 func (st *serverTester) stream(id uint32) *stream {
@@ -304,10 +315,7 @@ func TestServer(t *testing.T) {
 
 	st.writePreface()
 	st.writeInitialSettings()
-	st.wantSettings().ForeachSetting(func(s Setting) error {
-		t.Logf("Server sent setting %v = %v", s.ID, s.Val)
-		return nil
-	})
+	st.wantSettings()
 	st.writeSettingsAck()
 	st.wantSettingsAck()
 
@@ -647,7 +655,10 @@ func TestServer_Request_Reject_Pseudo_Missing_method(t *testing.T) {
 func TestServer_Request_Reject_Pseudo_ExactlyOne(t *testing.T) {
 	// 8.1.2.3 Request Pseudo-Header Fields
 	// "All HTTP/2 requests MUST include exactly one valid value" ...
-	testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":method", "GET", ":method", "POST") })
+	testRejectRequest(t, func(st *serverTester) {
+		st.addLogFilter("duplicate pseudo-header")
+		st.bodylessReq1(":method", "GET", ":method", "POST")
+	})
 }
 
 func TestServer_Request_Reject_Pseudo_AfterRegular(t *testing.T) {
@@ -658,6 +669,7 @@ func TestServer_Request_Reject_Pseudo_AfterRegular(t *testing.T) {
 	// block after a regular header field MUST be treated as
 	// malformed (Section 8.1.2.6)."
 	testRejectRequest(t, func(st *serverTester) {
+		st.addLogFilter("pseudo-header after regular header")
 		var buf bytes.Buffer
 		enc := hpack.NewEncoder(&buf)
 		enc.WriteField(hpack.HeaderField{Name: ":method", Value: "GET"})
@@ -686,7 +698,10 @@ func TestServer_Request_Reject_Pseudo_scheme_invalid(t *testing.T) {
 }
 
 func TestServer_Request_Reject_Pseudo_Unknown(t *testing.T) {
-	testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":unknown_thing", "") })
+	testRejectRequest(t, func(st *serverTester) {
+		st.addLogFilter(`invalid pseudo-header ":unknown_thing"`)
+		st.bodylessReq1(":unknown_thing", "")
+	})
 }
 
 func testRejectRequest(t *testing.T, send func(*serverTester)) {
@@ -977,8 +992,121 @@ func TestServer_StateTransitions(t *testing.T) {
 	}
 }
 
-// TODO: test HEADERS w/o EndHeaders + another HEADERS (should get rejected)
-// TODO: test HEADERS w/ EndHeaders + a continuation HEADERS (should get rejected)
+// test HEADERS w/o EndHeaders + another HEADERS (should get rejected)
+func TestServer_Rejects_HeadersNoEnd_Then_Headers(t *testing.T) {
+	testServerRejects(t, func(st *serverTester) {
+		st.writeHeaders(HeadersFrameParam{
+			StreamID:      1,
+			BlockFragment: encodeHeader(st.t),
+			EndStream:     true,
+			EndHeaders:    false,
+		})
+		st.writeHeaders(HeadersFrameParam{ // Not a continuation.
+			StreamID:      3, // different stream.
+			BlockFragment: encodeHeader(st.t),
+			EndStream:     true,
+			EndHeaders:    true,
+		})
+	})
+}
+
+// test HEADERS w/o EndHeaders + PING (should get rejected)
+func TestServer_Rejects_HeadersNoEnd_Then_Ping(t *testing.T) {
+	testServerRejects(t, func(st *serverTester) {
+		st.writeHeaders(HeadersFrameParam{
+			StreamID:      1,
+			BlockFragment: encodeHeader(st.t),
+			EndStream:     true,
+			EndHeaders:    false,
+		})
+		if err := st.fr.WritePing(false, [8]byte{}); err != nil {
+			t.Fatal(err)
+		}
+	})
+}
+
+// test HEADERS w/ EndHeaders + a continuation HEADERS (should get rejected)
+func TestServer_Rejects_HeadersEnd_Then_Continuation(t *testing.T) {
+	testServerRejects(t, func(st *serverTester) {
+		st.writeHeaders(HeadersFrameParam{
+			StreamID:      1,
+			BlockFragment: encodeHeader(st.t),
+			EndStream:     true,
+			EndHeaders:    true,
+		})
+		st.wantHeaders()
+		if err := st.fr.WriteContinuation(1, true, encodeHeaderNoImplicit(t, "foo", "bar")); err != nil {
+			t.Fatal(err)
+		}
+	})
+}
+
+// test HEADERS w/o EndHeaders + a continuation HEADERS on wrong stream ID
+func TestServer_Rejects_HeadersNoEnd_Then_ContinuationWrongStream(t *testing.T) {
+	testServerRejects(t, func(st *serverTester) {
+		st.writeHeaders(HeadersFrameParam{
+			StreamID:      1,
+			BlockFragment: encodeHeader(st.t),
+			EndStream:     true,
+			EndHeaders:    false,
+		})
+		if err := st.fr.WriteContinuation(3, true, encodeHeaderNoImplicit(t, "foo", "bar")); err != nil {
+			t.Fatal(err)
+		}
+	})
+}
+
+// No HEADERS on stream 0.
+func TestServer_Rejects_Headers0(t *testing.T) {
+	testServerRejects(t, func(st *serverTester) {
+		st.fr.AllowIllegalWrites = true
+		st.writeHeaders(HeadersFrameParam{
+			StreamID:      0,
+			BlockFragment: encodeHeader(st.t),
+			EndStream:     true,
+			EndHeaders:    true,
+		})
+	})
+}
+
+// No CONTINUATION on stream 0.
+func TestServer_Rejects_Continuation0(t *testing.T) {
+	testServerRejects(t, func(st *serverTester) {
+		st.fr.AllowIllegalWrites = true
+		if err := st.fr.WriteContinuation(0, true, encodeHeader(t)); err != nil {
+			t.Fatal(err)
+		}
+	})
+}
+
+// testServerRejects tests that the server hangs up with a GOAWAY
+// frame and a server close after the client does something
+// deserving a CONNECTION_ERROR.
+func testServerRejects(t *testing.T, writeReq func(*serverTester)) {
+	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {})
+	st.addLogFilter("connection error: PROTOCOL_ERROR")
+	defer st.Close()
+	st.greet()
+	writeReq(st)
+
+	st.wantGoAway()
+	errc := make(chan error, 1)
+	go func() {
+		fr, err := st.fr.ReadFrame()
+		if err == nil {
+			err = fmt.Errorf("got frame of type %T", fr)
+		}
+		errc <- err
+	}()
+	select {
+	case err := <-errc:
+		if err != io.EOF {
+			t.Errorf("ReadFrame = %v; want io.EOF", err)
+		}
+	case <-time.After(2 * time.Second):
+		t.Error("timeout waiting for disconnect")
+	}
+}
 
 // testServerRequest sets up an idle HTTP/2 connection and lets you
 // write a single request with writeReq, and then verify that the