Browse Source

Handle 'cookie' headers' special case merging.

And start of rejecting headers with capitals as stream errors.

And more tests and test cleanup.
Brad Fitzpatrick 11 years ago
parent
commit
b90dfb042f
2 changed files with 224 additions and 66 deletions
  1. 44 5
      http2.go
  2. 180 61
      http2_test.go

+ 44 - 5
http2.go

@@ -107,6 +107,7 @@ type serverConn struct {
 	canonHeader       map[string]string // http2-lower-case -> Go-Canonical-Case
 	method, path      string
 	scheme, authority string
+	invalidHeader     bool
 
 	// State related to writing current headers:
 	hpackEncoder   *hpack.Encoder
@@ -161,8 +162,10 @@ func (sc *serverConn) logf(format string, args ...interface{}) {
 }
 
 func (sc *serverConn) onNewHeaderField(f hpack.HeaderField) {
-	log.Printf("Header field: +%v", f)
-	if strings.HasPrefix(f.Name, ":") {
+	switch {
+	case !validHeader(f.Name):
+		sc.invalidHeader = true
+	case strings.HasPrefix(f.Name, ":"):
 		switch f.Name {
 		case ":method":
 			sc.method = f.Value
@@ -176,8 +179,15 @@ func (sc *serverConn) onNewHeaderField(f hpack.HeaderField) {
 			log.Printf("Ignoring unknown pseudo-header %q", f.Name)
 		}
 		return
+	case f.Name == "cookie":
+		if s, ok := sc.header["Cookie"]; ok && len(s) == 1 {
+			s[0] = s[0] + "; " + f.Value
+		} else {
+			sc.header.Add("Cookie", f.Value)
+		}
+	default:
+		sc.header.Add(sc.canonicalHeader(f.Name), f.Value)
 	}
-	sc.header.Add(sc.canonicalHeader(f.Name), f.Value)
 }
 
 func (sc *serverConn) canonicalHeader(v string) string {
@@ -208,7 +218,7 @@ func (sc *serverConn) serve() {
 	defer sc.conn.Close()
 	defer close(sc.doneServing)
 
-	log.Printf("HTTP/2 connection from %v on %p", sc.conn.RemoteAddr(), sc.hs)
+	sc.logf("HTTP/2 connection from %v on %p", sc.conn.RemoteAddr(), sc.hs)
 
 	// Read the client preface
 	buf := make([]byte, len(ClientPreface))
@@ -283,7 +293,10 @@ func (sc *serverConn) serve() {
 					sc.logf("Disconnection; connection error: %v", err)
 					return
 				}
-				// TODO: stream errors, etc
+				if h2e.IsStreamError() {
+					// TODO: stream errors, etc
+					panic("TODO")
+				}
 			}
 			if err != nil {
 				sc.logf("Disconnection due to other error: %v", err)
@@ -349,6 +362,7 @@ func (sc *serverConn) processHeaders(f *HeadersFrame) error {
 	}
 	sc.streams[id] = st
 	sc.header = make(http.Header)
+	sc.invalidHeader = false
 	sc.curHeaderStreamID = id
 	sc.curStream = st
 	return sc.processHeaderBlockFragment(f.HeaderBlockFragment(), f.HeadersEnded())
@@ -370,6 +384,14 @@ func (sc *serverConn) processHeaderBlockFragment(frag []byte, end bool) error {
 		// TODO: convert to stream error I assume?
 		return err
 	}
+	if sc.invalidHeader {
+		// See 8.1.2.6 Malformed Requests and Responses:
+		//
+		// Malformed requests or responses that are detected
+		// MUST be treated as a stream error (Section 5.4.2)
+		// of type PROTOCOL_ERROR."
+		return StreamError(ErrCodeProtocol)
+	}
 	curStream := sc.curStream
 	sc.curHeaderStreamID = 0
 	sc.curStream = nil
@@ -579,3 +601,20 @@ func (w *responseWriter) handlerDone() {
 }
 
 var testHookOnConn func() // for testing
+
+func validHeader(v string) bool {
+	if len(v) == 0 {
+		return false
+	}
+	for _, r := range v {
+		// "Just as in HTTP/1.x, header field names are
+		// strings of ASCII characters that are compared in a
+		// case-insensitive fashion. However, header field
+		// names MUST be converted to lowercase prior to their
+		// encoding in HTTP/2. "
+		if r >= 127 || ('A' <= r && r <= 'Z') {
+			return false
+		}
+	}
+	return true
+}

+ 180 - 61
http2_test.go

@@ -30,16 +30,19 @@ import (
 )
 
 type serverTester struct {
-	cc net.Conn // client conn
-	t  *testing.T
-	ts *httptest.Server
-	fr *Framer
+	cc     net.Conn // client conn
+	t      *testing.T
+	ts     *httptest.Server
+	fr     *Framer
+	logBuf *bytes.Buffer
 }
 
 func newServerTester(t *testing.T, handler http.HandlerFunc) *serverTester {
+	logBuf := new(bytes.Buffer)
 	ts := httptest.NewUnstartedServer(handler)
 	ConfigureServer(ts.Config, &Server{})
 	ts.TLS = ts.Config.TLSConfig // the httptest.Server has its own copy of this TLS config
+	ts.Config.ErrorLog = log.New(io.MultiWriter(twriter{t: t}, logBuf), "", log.LstdFlags)
 	ts.StartTLS()
 
 	t.Logf("Running test server at: %s", ts.URL)
@@ -52,10 +55,11 @@ func newServerTester(t *testing.T, handler http.HandlerFunc) *serverTester {
 	}
 	log.SetOutput(twriter{t})
 	return &serverTester{
-		t:  t,
-		ts: ts,
-		cc: cc,
-		fr: NewFramer(cc, cc),
+		t:      t,
+		ts:     ts,
+		cc:     cc,
+		fr:     NewFramer(cc, cc),
+		logBuf: logBuf,
 	}
 }
 
@@ -65,6 +69,16 @@ func (st *serverTester) Close() {
 	log.SetOutput(os.Stderr)
 }
 
+// greet initiates the client's HTTP/2 connection into a state where
+// frames may be sent.
+func (st *serverTester) greet() {
+	st.writePreface()
+	st.writeInitialSettings()
+	st.wantSettings()
+	st.writeSettingsAck()
+	st.wantSettingsAck()
+}
+
 func (st *serverTester) writePreface() {
 	n, err := st.cc.Write(clientPreface)
 	if err != nil {
@@ -93,10 +107,26 @@ func (st *serverTester) writeHeaders(p HeadersFrameParam) {
 	}
 }
 
+// bodylessReq1 writes a HEADERS frames with StreamID 1 and EndStream and EndHeaders set.
+func (st *serverTester) bodylessReq1(headers ...string) {
+	st.writeHeaders(HeadersFrameParam{
+		StreamID:      1, // clients send odd numbers
+		BlockFragment: encodeHeader(st.t, headers...),
+		EndStream:     true,
+		EndHeaders:    true,
+	})
+}
+
+func (st *serverTester) writeData(streamID uint32, endStream bool, data []byte) {
+	if err := st.fr.WriteData(streamID, endStream, data); err != nil {
+		st.t.Fatalf("Error writing DATA: %v", err)
+	}
+}
+
 func (st *serverTester) wantSettings() *SettingsFrame {
 	f, err := st.fr.ReadFrame()
 	if err != nil {
-		st.t.Fatal(err)
+		st.t.Fatalf("Error while expecting a SETTINGS frame: %v", err)
 	}
 	sf, ok := f.(*SettingsFrame)
 	if !ok {
@@ -105,6 +135,23 @@ func (st *serverTester) wantSettings() *SettingsFrame {
 	return sf
 }
 
+func (st *serverTester) wantRSTStream(streamID uint32, errCode ErrCode) {
+	f, err := st.fr.ReadFrame()
+	if err != nil {
+		st.t.Fatalf("Error while expecting an RSTStream frame: %v", err)
+	}
+	rs, ok := f.(*RSTStreamFrame)
+	if !ok {
+		st.t.Fatalf("got a %T; want *RSTStream", f)
+	}
+	if rs.FrameHeader.StreamID != streamID {
+		st.t.Fatalf("RSTStream StreamID = %d; want %d", rs.FrameHeader.StreamID, streamID)
+	}
+	if rs.ErrCode != uint32(errCode) {
+		st.t.Fatalf("RSTStream ErrCode = %d (%s); want %d (%s)", rs.ErrCode, rs.ErrCode, errCode, errCode)
+	}
+}
+
 func (st *serverTester) wantSettingsAck() {
 	f, err := st.fr.ReadFrame()
 	if err != nil {
@@ -144,14 +191,10 @@ func TestServer(t *testing.T) {
 	st.wantSettingsAck()
 
 	st.writeHeaders(HeadersFrameParam{
-		StreamID: 1, // clients send odd numbers
-		BlockFragment: encodeHeader(t,
-			":method", "GET",
-			":path", "/",
-			":scheme", "https",
-		),
-		EndStream:  true, // no DATA frames
-		EndHeaders: true,
+		StreamID:      1, // clients send odd numbers
+		BlockFragment: encodeHeader(t),
+		EndStream:     true, // no DATA frames
+		EndHeaders:    true,
 	})
 
 	select {
@@ -164,18 +207,12 @@ func TestServer(t *testing.T) {
 func TestServer_Request_Get(t *testing.T) {
 	testServerRequest(t, func(st *serverTester) {
 		st.writeHeaders(HeadersFrameParam{
-			StreamID: 1, // clients send odd numbers
-			BlockFragment: encodeHeader(t,
-				":method", "GET",
-				":path", "/",
-				":scheme", "https",
-				"foo-bar", "some-value",
-			),
-			EndStream:  true, // no DATA frames
-			EndHeaders: true,
+			StreamID:      1, // clients send odd numbers
+			BlockFragment: encodeHeader(t, "foo-bar", "some-value"),
+			EndStream:     true, // no DATA frames
+			EndHeaders:    true,
 		})
 	}, func(r *http.Request) {
-		t.Logf("GOT %#v", r)
 		if r.Method != "GET" {
 			t.Errorf("Method = %q; want GET", r.Method)
 		}
@@ -203,20 +240,63 @@ func TestServer_Request_Get(t *testing.T) {
 	})
 }
 
+// TODO: add a test with EndStream=true on the HEADERS but setting a
+// Content-Length anyway.  Should we just omit it and force it to
+// zero?
+
+func TestServer_Request_Post_NoContentLength_EndStream(t *testing.T) {
+	testServerRequest(t, func(st *serverTester) {
+		st.writeHeaders(HeadersFrameParam{
+			StreamID:      1, // clients send odd numbers
+			BlockFragment: encodeHeader(t, ":method", "POST"),
+			EndStream:     true,
+			EndHeaders:    true,
+		})
+	}, func(r *http.Request) {
+		if r.Method != "POST" {
+			t.Errorf("Method = %q; want POST", r.Method)
+		}
+		if r.ContentLength != 0 {
+			t.Errorf("ContentLength = %v; want 0", r.ContentLength)
+		}
+		if n, err := r.Body.Read([]byte(" ")); err != io.EOF || n != 0 {
+			t.Errorf("Read = %d, %v; want 0, EOF", n, err)
+		}
+	})
+}
+
+func TestServer_Request_Post_Body(t *testing.T) {
+	t.Skip("TODO: post bodies not yet implemented")
+	testServerRequest(t, func(st *serverTester) {
+		st.writeHeaders(HeadersFrameParam{
+			StreamID:      1, // clients send odd numbers
+			BlockFragment: encodeHeader(t, ":method", "POST"),
+			EndStream:     false, // migth be DATA frames
+			EndHeaders:    true,
+		})
+		st.writeData(1, true, nil)
+	}, func(r *http.Request) {
+		if r.Method != "POST" {
+			t.Errorf("Method = %q; want POST", r.Method)
+		}
+		if r.ContentLength != -1 {
+			t.Errorf("ContentLength = %v; want -1", r.ContentLength)
+		}
+		if n, err := r.Body.Read([]byte(" ")); err != io.EOF || n != 0 {
+			t.Errorf("Read = %d, %v; want 0, EOF", n, err)
+		}
+	})
+}
+
 // Using a Host header, instead of :authority
 func TestServer_Request_Get_Host(t *testing.T) {
 	const host = "example.com"
 	testServerRequest(t, func(st *serverTester) {
 		st.writeHeaders(HeadersFrameParam{
-			StreamID: 1, // clients send odd numbers
-			BlockFragment: encodeHeader(t,
-				":method", "GET",
-				":path", "/",
-				":scheme", "https",
-				"host", host,
-			),
-			EndStream:  true,
-			EndHeaders: true,
+			StreamID:      1, // clients send odd numbers
+			BlockFragment: encodeHeader(t, "host", host),
+			EndStream:     true,
+			EndHeaders:    true,
 		})
 	}, func(r *http.Request) {
 		if r.Host != host {
@@ -230,15 +310,10 @@ func TestServer_Request_Get_Authority(t *testing.T) {
 	const host = "example.com"
 	testServerRequest(t, func(st *serverTester) {
 		st.writeHeaders(HeadersFrameParam{
-			StreamID: 1, // clients send odd numbers
-			BlockFragment: encodeHeader(t,
-				":method", "GET",
-				":path", "/",
-				":scheme", "https",
-				":authority", host,
-			),
-			EndStream:  true,
-			EndHeaders: true,
+			StreamID:      1, // clients send odd numbers
+			BlockFragment: encodeHeader(t, ":authority", host),
+			EndStream:     true,
+			EndHeaders:    true,
 		})
 	}, func(r *http.Request) {
 		if r.Host != host {
@@ -255,9 +330,6 @@ func TestServer_Request_WithContinuation(t *testing.T) {
 	}
 	testServerRequest(t, func(st *serverTester) {
 		fullHeaders := encodeHeader(t,
-			":method", "GET",
-			":path", "/",
-			":scheme", "https",
 			"foo-one", "value-one",
 			"foo-two", "value-two",
 			"foo-three", "value-three",
@@ -297,6 +369,36 @@ func TestServer_Request_WithContinuation(t *testing.T) {
 	})
 }
 
+// Concatenated cookie headers. ("8.1.2.5 Compressing the Cookie Header Field")
+func TestServer_Request_CookieConcat(t *testing.T) {
+	const host = "example.com"
+	testServerRequest(t, func(st *serverTester) {
+		st.bodylessReq1(
+			":authority", host,
+			"cookie", "a=b",
+			"cookie", "c=d",
+			"cookie", "e=f",
+		)
+	}, func(r *http.Request) {
+		const want = "a=b; c=d; e=f"
+		if got := r.Header.Get("Cookie"); got != want {
+			t.Errorf("Cookie = %q; want %q", got, want)
+		}
+	})
+}
+
+func TestServer_Request_RejectCapitalHeader(t *testing.T) {
+	t.Skip("TODO: not handling stream errors properly yet in http2.go: if h2e.IsStreamError stuff")
+	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+		t.Fatal("server request made it to handler; should've been rejected")
+	})
+	defer st.Close()
+
+	st.greet()
+	st.bodylessReq1("UPPER", "v")
+	st.wantRSTStream(1, ErrCodeProtocol)
+}
+
 // testServerRequest sets up an idle HTTP/2 connection and lets you
 // write a single request with writeReq, and then verify that the
 // *http.Request is built correctly in checkReq.
@@ -311,12 +413,7 @@ func testServerRequest(t *testing.T, writeReq func(*serverTester), checkReq func
 	})
 	defer st.Close()
 
-	st.writePreface()
-	st.writeInitialSettings()
-	st.wantSettings()
-	st.writeSettingsAck()
-	st.wantSettingsAck()
-
+	st.greet()
 	writeReq(st)
 
 	select {
@@ -433,17 +530,39 @@ func (w twriter) Write(p []byte) (n int, err error) {
 	return len(p), nil
 }
 
-func encodeHeader(t *testing.T, kv ...string) []byte {
-	if len(kv)%2 == 1 {
+// encodeHeader encodes headers and returns their HPACK bytes. headers
+// must contain an even number of key/value pairs.  There may be
+// multiple pairs for keys (e.g. "cookie").  The :method, :path, and
+// :scheme headers default to GET, / and https.
+func encodeHeader(t *testing.T, headers ...string) []byte {
+	if len(headers)%2 == 1 {
 		panic("odd number of kv args")
 	}
+	keys := []string{":method", ":path", ":scheme"}
+	vals := map[string][]string{
+		":method": {"GET"},
+		":path":   {"/"},
+		":scheme": {"https"},
+	}
+	for len(headers) > 0 {
+		k, v := headers[0], headers[1]
+		headers = headers[2:]
+		if _, ok := vals[k]; !ok {
+			keys = append(keys, k)
+		}
+		if strings.HasPrefix(k, ":") {
+			vals[k] = []string{v}
+		} else {
+			vals[k] = append(vals[k], v)
+		}
+	}
 	var buf bytes.Buffer
 	enc := hpack.NewEncoder(&buf)
-	for len(kv) > 0 {
-		k, v := kv[0], kv[1]
-		kv = kv[2:]
-		if err := enc.WriteField(hpack.HeaderField{Name: k, Value: v}); err != nil {
-			t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err)
+	for _, k := range keys {
+		for _, v := range vals[k] {
+			if err := enc.WriteField(hpack.HeaderField{Name: k, Value: v}); err != nil {
+				t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err)
+			}
 		}
 	}
 	return buf.Bytes()