Browse Source

Require pseudo headers in requests. Clear state between requests.

Brad Fitzpatrick 11 years ago
parent
commit
c4d60a2b87
2 changed files with 68 additions and 18 deletions
  1. 41 13
      http2.go
  2. 27 5
      http2_test.go

+ 41 - 13
http2.go

@@ -95,6 +95,8 @@ type serverConn struct {
 	conn           net.Conn
 	handler        http.Handler
 	framer         *Framer
+	hpackDecoder   *hpack.Decoder
+	hpackEncoder   *hpack.Encoder
 	doneServing    chan struct{}          // closed when serverConn.serve ends
 	readFrameCh    chan frameAndProcessed // written by serverConn.readFrames
 	readFrameErrCh chan error
@@ -109,21 +111,19 @@ type serverConn struct {
 	maxWriteFrameSize uint32 // TODO: update this when settings come in
 
 	// State related to parsing current headers:
-	hpackDecoder      *hpack.Decoder
 	header            http.Header
 	canonHeader       map[string]string // http2-lower-case -> Go-Canonical-Case
 	method, path      string
 	scheme, authority string
 	invalidHeader     bool
 
-	// State related to writing current headers:
-	hpackEncoder   *hpack.Encoder
-	headerWriteBuf bytes.Buffer
-
-	// curHeaderStreamID is non-zero if we're in the middle
-	// of parsing headers that span multiple frames.
+	// curHeaderStreamID and curStream are non-zero if we're in
+	// the middle of parsing headers that span multiple frames.
 	curHeaderStreamID uint32
 	curStream         *stream
+
+	// State related to writing current headers:
+	headerWriteBuf bytes.Buffer // written/accessed from serve goroutine
 }
 
 type streamState int
@@ -175,6 +175,19 @@ func (sc *serverConn) logf(format string, args ...interface{}) {
 	}
 }
 
+func (sc *serverConn) condlogf(err error, format string, args ...interface{}) {
+	if err == nil {
+		return
+	}
+	str := err.Error()
+	if strings.Contains(str, "use of closed network connection") {
+		// Boring, expected errors.
+		sc.vlogf(format, args...)
+	} else {
+		sc.logf(format, args...)
+	}
+}
+
 func (sc *serverConn) onNewHeaderField(f hpack.HeaderField) {
 	sc.serveG.check()
 	switch {
@@ -191,7 +204,13 @@ func (sc *serverConn) onNewHeaderField(f hpack.HeaderField) {
 		case ":authority":
 			sc.authority = f.Value
 		default:
-			log.Printf("Ignoring unknown pseudo-header %q", f.Name)
+			// 8.1.2.1 Pseudo-Header Fields
+			// "Endpoints MUST treat a request or response
+			// that contains undefined or invalid
+			// pseudo-header fields as malformed (Section
+			// 8.1.2.6)."
+			sc.logf("invalid pseudo-header %q", f.Name)
+			sc.invalidHeader = true
 		}
 		return
 	case f.Name == "cookie":
@@ -237,7 +256,7 @@ func (sc *serverConn) serve() {
 	defer sc.conn.Close()
 	defer close(sc.doneServing)
 
-	sc.logf("HTTP/2 connection from %v on %p", sc.conn.RemoteAddr(), sc.hs)
+	sc.vlogf("HTTP/2 connection from %v on %p", sc.conn.RemoteAddr(), sc.hs)
 
 	// Read the client preface
 	buf := make([]byte, len(ClientPreface))
@@ -250,7 +269,7 @@ func (sc *serverConn) serve() {
 		sc.logf("bogus greeting from client: %q", buf)
 		return
 	}
-	sc.logf("client %v said hello", sc.conn.RemoteAddr())
+	sc.vlogf("client %v said hello", sc.conn.RemoteAddr())
 
 	f, err := sc.framer.ReadFrame()
 	if err != nil {
@@ -288,8 +307,7 @@ func (sc *serverConn) serve() {
 		select {
 		case hr := <-sc.writeHeaderCh:
 			if err := sc.writeHeaderInLoop(hr); err != nil {
-				// TODO: diff error handling?
-				sc.logf("error writing response header: %v", err)
+				sc.condlogf(err, "error writing response header: %v", err)
 				return
 			}
 		case fp, ok := <-sc.readFrameCh:
@@ -394,7 +412,12 @@ func (sc *serverConn) processHeaders(f *HeadersFrame) error {
 		st.state = stateHalfClosedRemote
 	}
 	sc.streams[id] = st
+
 	sc.header = make(http.Header)
+	sc.method = ""
+	sc.path = ""
+	sc.scheme = ""
+	sc.authority = ""
 	sc.invalidHeader = false
 	sc.curHeaderStreamID = id
 	sc.curStream = st
@@ -423,12 +446,17 @@ func (sc *serverConn) processHeaderBlockFragment(streamID uint32, frag []byte, e
 		// TODO: convert to stream error I assume?
 		return err
 	}
-	if sc.invalidHeader {
+	if sc.invalidHeader || sc.method == "" || sc.path == "" || sc.scheme == "" {
 		// See 8.1.2.6 Malformed Requests and Responses:
 		//
 		// Malformed requests or responses that are detected
 		// MUST be treated as a stream error (Section 5.4.2)
 		// of type PROTOCOL_ERROR."
+		//
+		// 8.1.2.3 Request Pseudo-Header Fields
+		// "All HTTP/2 requests MUST include exactly one valid
+		// value for the :method, :scheme, and :path
+		// pseudo-header fields"
 		return StreamError{streamID, ErrCodeProtocol}
 	}
 	curStream := sc.curStream

+ 27 - 5
http2_test.go

@@ -11,6 +11,7 @@ import (
 	"bytes"
 	"crypto/tls"
 	"errors"
+	"flag"
 	"fmt"
 	"io"
 	"log"
@@ -30,8 +31,8 @@ import (
 )
 
 func init() {
-	VerboseLogs = true
 	DebugGoroutines = true
+	flag.BoolVar(&VerboseLogs, "verboseh2", false, "Verbose HTTP/2 debug logging")
 }
 
 type serverTester struct {
@@ -50,7 +51,9 @@ func newServerTester(t *testing.T, handler http.HandlerFunc) *serverTester {
 	ts.Config.ErrorLog = log.New(io.MultiWriter(twriter{t: t}, logBuf), "", log.LstdFlags)
 	ts.StartTLS()
 
-	t.Logf("Running test server at: %s", ts.URL)
+	if VerboseLogs {
+		t.Logf("Running test server at: %s", ts.URL)
+	}
 	cc, err := tls.Dial("tcp", ts.Listener.Addr().String(), &tls.Config{
 		InsecureSkipVerify: true,
 		NextProtos:         []string{npnProto},
@@ -176,7 +179,6 @@ func TestServer(t *testing.T) {
 	gotReq := make(chan bool, 1)
 	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
 		w.Header().Set("Foo", "Bar")
-		t.Logf("GOT REQUEST %#v", r)
 		gotReq <- true
 	})
 	defer st.Close()
@@ -392,14 +394,34 @@ func TestServer_Request_CookieConcat(t *testing.T) {
 	})
 }
 
-func TestServer_Request_RejectCapitalHeader(t *testing.T) {
+func TestServer_Request_Reject_CapitalHeader(t *testing.T) {
+	testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("UPPER", "v") })
+}
+
+func TestServer_Request_Reject_Pseudo_Missing_method(t *testing.T) {
+	testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":method", "") })
+}
+
+func TestServer_Request_Reject_Pseudo_Missing_path(t *testing.T) {
+	testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":path", "") })
+}
+
+func TestServer_Request_Reject_Pseudo_Missing_scheme(t *testing.T) {
+	testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":scheme", "") })
+}
+
+func TestServer_Request_Reject_Pseudo_Unknown(t *testing.T) {
+	testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":unknown_thing", "") })
+}
+
+func testRejectRequest(t *testing.T, send func(*serverTester)) {
 	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
 		t.Fatal("server request made it to handler; should've been rejected")
 	})
 	defer st.Close()
 
 	st.greet()
-	st.bodylessReq1("UPPER", "v")
+	send(st)
 	st.wantRSTStream(1, ErrCodeProtocol)
 }