Browse Source

Handle 'cookie' headers' special case merging.

And start of rejecting headers with capitals as stream errors.

And more tests and test cleanup.
Brad Fitzpatrick 11 years ago
parent
commit
b90dfb042f
2 changed files with 224 additions and 66 deletions
  1. 44 5
      http2.go
  2. 180 61
      http2_test.go

+ 44 - 5
http2.go

@@ -107,6 +107,7 @@ type serverConn struct {
 	canonHeader       map[string]string // http2-lower-case -> Go-Canonical-Case
 	canonHeader       map[string]string // http2-lower-case -> Go-Canonical-Case
 	method, path      string
 	method, path      string
 	scheme, authority string
 	scheme, authority string
+	invalidHeader     bool
 
 
 	// State related to writing current headers:
 	// State related to writing current headers:
 	hpackEncoder   *hpack.Encoder
 	hpackEncoder   *hpack.Encoder
@@ -161,8 +162,10 @@ func (sc *serverConn) logf(format string, args ...interface{}) {
 }
 }
 
 
 func (sc *serverConn) onNewHeaderField(f hpack.HeaderField) {
 func (sc *serverConn) onNewHeaderField(f hpack.HeaderField) {
-	log.Printf("Header field: +%v", f)
-	if strings.HasPrefix(f.Name, ":") {
+	switch {
+	case !validHeader(f.Name):
+		sc.invalidHeader = true
+	case strings.HasPrefix(f.Name, ":"):
 		switch f.Name {
 		switch f.Name {
 		case ":method":
 		case ":method":
 			sc.method = f.Value
 			sc.method = f.Value
@@ -176,8 +179,15 @@ func (sc *serverConn) onNewHeaderField(f hpack.HeaderField) {
 			log.Printf("Ignoring unknown pseudo-header %q", f.Name)
 			log.Printf("Ignoring unknown pseudo-header %q", f.Name)
 		}
 		}
 		return
 		return
+	case f.Name == "cookie":
+		if s, ok := sc.header["Cookie"]; ok && len(s) == 1 {
+			s[0] = s[0] + "; " + f.Value
+		} else {
+			sc.header.Add("Cookie", f.Value)
+		}
+	default:
+		sc.header.Add(sc.canonicalHeader(f.Name), f.Value)
 	}
 	}
-	sc.header.Add(sc.canonicalHeader(f.Name), f.Value)
 }
 }
 
 
 func (sc *serverConn) canonicalHeader(v string) string {
 func (sc *serverConn) canonicalHeader(v string) string {
@@ -208,7 +218,7 @@ func (sc *serverConn) serve() {
 	defer sc.conn.Close()
 	defer sc.conn.Close()
 	defer close(sc.doneServing)
 	defer close(sc.doneServing)
 
 
-	log.Printf("HTTP/2 connection from %v on %p", sc.conn.RemoteAddr(), sc.hs)
+	sc.logf("HTTP/2 connection from %v on %p", sc.conn.RemoteAddr(), sc.hs)
 
 
 	// Read the client preface
 	// Read the client preface
 	buf := make([]byte, len(ClientPreface))
 	buf := make([]byte, len(ClientPreface))
@@ -283,7 +293,10 @@ func (sc *serverConn) serve() {
 					sc.logf("Disconnection; connection error: %v", err)
 					sc.logf("Disconnection; connection error: %v", err)
 					return
 					return
 				}
 				}
-				// TODO: stream errors, etc
+				if h2e.IsStreamError() {
+					// TODO: stream errors, etc
+					panic("TODO")
+				}
 			}
 			}
 			if err != nil {
 			if err != nil {
 				sc.logf("Disconnection due to other error: %v", err)
 				sc.logf("Disconnection due to other error: %v", err)
@@ -349,6 +362,7 @@ func (sc *serverConn) processHeaders(f *HeadersFrame) error {
 	}
 	}
 	sc.streams[id] = st
 	sc.streams[id] = st
 	sc.header = make(http.Header)
 	sc.header = make(http.Header)
+	sc.invalidHeader = false
 	sc.curHeaderStreamID = id
 	sc.curHeaderStreamID = id
 	sc.curStream = st
 	sc.curStream = st
 	return sc.processHeaderBlockFragment(f.HeaderBlockFragment(), f.HeadersEnded())
 	return sc.processHeaderBlockFragment(f.HeaderBlockFragment(), f.HeadersEnded())
@@ -370,6 +384,14 @@ func (sc *serverConn) processHeaderBlockFragment(frag []byte, end bool) error {
 		// TODO: convert to stream error I assume?
 		// TODO: convert to stream error I assume?
 		return err
 		return err
 	}
 	}
+	if sc.invalidHeader {
+		// See 8.1.2.6 Malformed Requests and Responses:
+		//
+		// Malformed requests or responses that are detected
+		// MUST be treated as a stream error (Section 5.4.2)
+		// of type PROTOCOL_ERROR."
+		return StreamError(ErrCodeProtocol)
+	}
 	curStream := sc.curStream
 	curStream := sc.curStream
 	sc.curHeaderStreamID = 0
 	sc.curHeaderStreamID = 0
 	sc.curStream = nil
 	sc.curStream = nil
@@ -579,3 +601,20 @@ func (w *responseWriter) handlerDone() {
 }
 }
 
 
 var testHookOnConn func() // for testing
 var testHookOnConn func() // for testing
+
+func validHeader(v string) bool {
+	if len(v) == 0 {
+		return false
+	}
+	for _, r := range v {
+		// "Just as in HTTP/1.x, header field names are
+		// strings of ASCII characters that are compared in a
+		// case-insensitive fashion. However, header field
+		// names MUST be converted to lowercase prior to their
+		// encoding in HTTP/2. "
+		if r >= 127 || ('A' <= r && r <= 'Z') {
+			return false
+		}
+	}
+	return true
+}

+ 180 - 61
http2_test.go

@@ -30,16 +30,19 @@ import (
 )
 )
 
 
 type serverTester struct {
 type serverTester struct {
-	cc net.Conn // client conn
-	t  *testing.T
-	ts *httptest.Server
-	fr *Framer
+	cc     net.Conn // client conn
+	t      *testing.T
+	ts     *httptest.Server
+	fr     *Framer
+	logBuf *bytes.Buffer
 }
 }
 
 
 func newServerTester(t *testing.T, handler http.HandlerFunc) *serverTester {
 func newServerTester(t *testing.T, handler http.HandlerFunc) *serverTester {
+	logBuf := new(bytes.Buffer)
 	ts := httptest.NewUnstartedServer(handler)
 	ts := httptest.NewUnstartedServer(handler)
 	ConfigureServer(ts.Config, &Server{})
 	ConfigureServer(ts.Config, &Server{})
 	ts.TLS = ts.Config.TLSConfig // the httptest.Server has its own copy of this TLS config
 	ts.TLS = ts.Config.TLSConfig // the httptest.Server has its own copy of this TLS config
+	ts.Config.ErrorLog = log.New(io.MultiWriter(twriter{t: t}, logBuf), "", log.LstdFlags)
 	ts.StartTLS()
 	ts.StartTLS()
 
 
 	t.Logf("Running test server at: %s", ts.URL)
 	t.Logf("Running test server at: %s", ts.URL)
@@ -52,10 +55,11 @@ func newServerTester(t *testing.T, handler http.HandlerFunc) *serverTester {
 	}
 	}
 	log.SetOutput(twriter{t})
 	log.SetOutput(twriter{t})
 	return &serverTester{
 	return &serverTester{
-		t:  t,
-		ts: ts,
-		cc: cc,
-		fr: NewFramer(cc, cc),
+		t:      t,
+		ts:     ts,
+		cc:     cc,
+		fr:     NewFramer(cc, cc),
+		logBuf: logBuf,
 	}
 	}
 }
 }
 
 
@@ -65,6 +69,16 @@ func (st *serverTester) Close() {
 	log.SetOutput(os.Stderr)
 	log.SetOutput(os.Stderr)
 }
 }
 
 
+// greet initiates the client's HTTP/2 connection into a state where
+// frames may be sent.
+func (st *serverTester) greet() {
+	st.writePreface()
+	st.writeInitialSettings()
+	st.wantSettings()
+	st.writeSettingsAck()
+	st.wantSettingsAck()
+}
+
 func (st *serverTester) writePreface() {
 func (st *serverTester) writePreface() {
 	n, err := st.cc.Write(clientPreface)
 	n, err := st.cc.Write(clientPreface)
 	if err != nil {
 	if err != nil {
@@ -93,10 +107,26 @@ func (st *serverTester) writeHeaders(p HeadersFrameParam) {
 	}
 	}
 }
 }
 
 
+// bodylessReq1 writes a HEADERS frames with StreamID 1 and EndStream and EndHeaders set.
+func (st *serverTester) bodylessReq1(headers ...string) {
+	st.writeHeaders(HeadersFrameParam{
+		StreamID:      1, // clients send odd numbers
+		BlockFragment: encodeHeader(st.t, headers...),
+		EndStream:     true,
+		EndHeaders:    true,
+	})
+}
+
+func (st *serverTester) writeData(streamID uint32, endStream bool, data []byte) {
+	if err := st.fr.WriteData(streamID, endStream, data); err != nil {
+		st.t.Fatalf("Error writing DATA: %v", err)
+	}
+}
+
 func (st *serverTester) wantSettings() *SettingsFrame {
 func (st *serverTester) wantSettings() *SettingsFrame {
 	f, err := st.fr.ReadFrame()
 	f, err := st.fr.ReadFrame()
 	if err != nil {
 	if err != nil {
-		st.t.Fatal(err)
+		st.t.Fatalf("Error while expecting a SETTINGS frame: %v", err)
 	}
 	}
 	sf, ok := f.(*SettingsFrame)
 	sf, ok := f.(*SettingsFrame)
 	if !ok {
 	if !ok {
@@ -105,6 +135,23 @@ func (st *serverTester) wantSettings() *SettingsFrame {
 	return sf
 	return sf
 }
 }
 
 
+func (st *serverTester) wantRSTStream(streamID uint32, errCode ErrCode) {
+	f, err := st.fr.ReadFrame()
+	if err != nil {
+		st.t.Fatalf("Error while expecting an RSTStream frame: %v", err)
+	}
+	rs, ok := f.(*RSTStreamFrame)
+	if !ok {
+		st.t.Fatalf("got a %T; want *RSTStream", f)
+	}
+	if rs.FrameHeader.StreamID != streamID {
+		st.t.Fatalf("RSTStream StreamID = %d; want %d", rs.FrameHeader.StreamID, streamID)
+	}
+	if rs.ErrCode != uint32(errCode) {
+		st.t.Fatalf("RSTStream ErrCode = %d (%s); want %d (%s)", rs.ErrCode, rs.ErrCode, errCode, errCode)
+	}
+}
+
 func (st *serverTester) wantSettingsAck() {
 func (st *serverTester) wantSettingsAck() {
 	f, err := st.fr.ReadFrame()
 	f, err := st.fr.ReadFrame()
 	if err != nil {
 	if err != nil {
@@ -144,14 +191,10 @@ func TestServer(t *testing.T) {
 	st.wantSettingsAck()
 	st.wantSettingsAck()
 
 
 	st.writeHeaders(HeadersFrameParam{
 	st.writeHeaders(HeadersFrameParam{
-		StreamID: 1, // clients send odd numbers
-		BlockFragment: encodeHeader(t,
-			":method", "GET",
-			":path", "/",
-			":scheme", "https",
-		),
-		EndStream:  true, // no DATA frames
-		EndHeaders: true,
+		StreamID:      1, // clients send odd numbers
+		BlockFragment: encodeHeader(t),
+		EndStream:     true, // no DATA frames
+		EndHeaders:    true,
 	})
 	})
 
 
 	select {
 	select {
@@ -164,18 +207,12 @@ func TestServer(t *testing.T) {
 func TestServer_Request_Get(t *testing.T) {
 func TestServer_Request_Get(t *testing.T) {
 	testServerRequest(t, func(st *serverTester) {
 	testServerRequest(t, func(st *serverTester) {
 		st.writeHeaders(HeadersFrameParam{
 		st.writeHeaders(HeadersFrameParam{
-			StreamID: 1, // clients send odd numbers
-			BlockFragment: encodeHeader(t,
-				":method", "GET",
-				":path", "/",
-				":scheme", "https",
-				"foo-bar", "some-value",
-			),
-			EndStream:  true, // no DATA frames
-			EndHeaders: true,
+			StreamID:      1, // clients send odd numbers
+			BlockFragment: encodeHeader(t, "foo-bar", "some-value"),
+			EndStream:     true, // no DATA frames
+			EndHeaders:    true,
 		})
 		})
 	}, func(r *http.Request) {
 	}, func(r *http.Request) {
-		t.Logf("GOT %#v", r)
 		if r.Method != "GET" {
 		if r.Method != "GET" {
 			t.Errorf("Method = %q; want GET", r.Method)
 			t.Errorf("Method = %q; want GET", r.Method)
 		}
 		}
@@ -203,20 +240,63 @@ func TestServer_Request_Get(t *testing.T) {
 	})
 	})
 }
 }
 
 
+// TODO: add a test with EndStream=true on the HEADERS but setting a
+// Content-Length anyway.  Should we just omit it and force it to
+// zero?
+
+func TestServer_Request_Post_NoContentLength_EndStream(t *testing.T) {
+	testServerRequest(t, func(st *serverTester) {
+		st.writeHeaders(HeadersFrameParam{
+			StreamID:      1, // clients send odd numbers
+			BlockFragment: encodeHeader(t, ":method", "POST"),
+			EndStream:     true,
+			EndHeaders:    true,
+		})
+	}, func(r *http.Request) {
+		if r.Method != "POST" {
+			t.Errorf("Method = %q; want POST", r.Method)
+		}
+		if r.ContentLength != 0 {
+			t.Errorf("ContentLength = %v; want 0", r.ContentLength)
+		}
+		if n, err := r.Body.Read([]byte(" ")); err != io.EOF || n != 0 {
+			t.Errorf("Read = %d, %v; want 0, EOF", n, err)
+		}
+	})
+}
+
+func TestServer_Request_Post_Body(t *testing.T) {
+	t.Skip("TODO: post bodies not yet implemented")
+	testServerRequest(t, func(st *serverTester) {
+		st.writeHeaders(HeadersFrameParam{
+			StreamID:      1, // clients send odd numbers
+			BlockFragment: encodeHeader(t, ":method", "POST"),
+			EndStream:     false, // migth be DATA frames
+			EndHeaders:    true,
+		})
+		st.writeData(1, true, nil)
+	}, func(r *http.Request) {
+		if r.Method != "POST" {
+			t.Errorf("Method = %q; want POST", r.Method)
+		}
+		if r.ContentLength != -1 {
+			t.Errorf("ContentLength = %v; want -1", r.ContentLength)
+		}
+		if n, err := r.Body.Read([]byte(" ")); err != io.EOF || n != 0 {
+			t.Errorf("Read = %d, %v; want 0, EOF", n, err)
+		}
+	})
+}
+
 // Using a Host header, instead of :authority
 // Using a Host header, instead of :authority
 func TestServer_Request_Get_Host(t *testing.T) {
 func TestServer_Request_Get_Host(t *testing.T) {
 	const host = "example.com"
 	const host = "example.com"
 	testServerRequest(t, func(st *serverTester) {
 	testServerRequest(t, func(st *serverTester) {
 		st.writeHeaders(HeadersFrameParam{
 		st.writeHeaders(HeadersFrameParam{
-			StreamID: 1, // clients send odd numbers
-			BlockFragment: encodeHeader(t,
-				":method", "GET",
-				":path", "/",
-				":scheme", "https",
-				"host", host,
-			),
-			EndStream:  true,
-			EndHeaders: true,
+			StreamID:      1, // clients send odd numbers
+			BlockFragment: encodeHeader(t, "host", host),
+			EndStream:     true,
+			EndHeaders:    true,
 		})
 		})
 	}, func(r *http.Request) {
 	}, func(r *http.Request) {
 		if r.Host != host {
 		if r.Host != host {
@@ -230,15 +310,10 @@ func TestServer_Request_Get_Authority(t *testing.T) {
 	const host = "example.com"
 	const host = "example.com"
 	testServerRequest(t, func(st *serverTester) {
 	testServerRequest(t, func(st *serverTester) {
 		st.writeHeaders(HeadersFrameParam{
 		st.writeHeaders(HeadersFrameParam{
-			StreamID: 1, // clients send odd numbers
-			BlockFragment: encodeHeader(t,
-				":method", "GET",
-				":path", "/",
-				":scheme", "https",
-				":authority", host,
-			),
-			EndStream:  true,
-			EndHeaders: true,
+			StreamID:      1, // clients send odd numbers
+			BlockFragment: encodeHeader(t, ":authority", host),
+			EndStream:     true,
+			EndHeaders:    true,
 		})
 		})
 	}, func(r *http.Request) {
 	}, func(r *http.Request) {
 		if r.Host != host {
 		if r.Host != host {
@@ -255,9 +330,6 @@ func TestServer_Request_WithContinuation(t *testing.T) {
 	}
 	}
 	testServerRequest(t, func(st *serverTester) {
 	testServerRequest(t, func(st *serverTester) {
 		fullHeaders := encodeHeader(t,
 		fullHeaders := encodeHeader(t,
-			":method", "GET",
-			":path", "/",
-			":scheme", "https",
 			"foo-one", "value-one",
 			"foo-one", "value-one",
 			"foo-two", "value-two",
 			"foo-two", "value-two",
 			"foo-three", "value-three",
 			"foo-three", "value-three",
@@ -297,6 +369,36 @@ func TestServer_Request_WithContinuation(t *testing.T) {
 	})
 	})
 }
 }
 
 
+// Concatenated cookie headers. ("8.1.2.5 Compressing the Cookie Header Field")
+func TestServer_Request_CookieConcat(t *testing.T) {
+	const host = "example.com"
+	testServerRequest(t, func(st *serverTester) {
+		st.bodylessReq1(
+			":authority", host,
+			"cookie", "a=b",
+			"cookie", "c=d",
+			"cookie", "e=f",
+		)
+	}, func(r *http.Request) {
+		const want = "a=b; c=d; e=f"
+		if got := r.Header.Get("Cookie"); got != want {
+			t.Errorf("Cookie = %q; want %q", got, want)
+		}
+	})
+}
+
+func TestServer_Request_RejectCapitalHeader(t *testing.T) {
+	t.Skip("TODO: not handling stream errors properly yet in http2.go: if h2e.IsStreamError stuff")
+	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+		t.Fatal("server request made it to handler; should've been rejected")
+	})
+	defer st.Close()
+
+	st.greet()
+	st.bodylessReq1("UPPER", "v")
+	st.wantRSTStream(1, ErrCodeProtocol)
+}
+
 // testServerRequest sets up an idle HTTP/2 connection and lets you
 // testServerRequest sets up an idle HTTP/2 connection and lets you
 // write a single request with writeReq, and then verify that the
 // write a single request with writeReq, and then verify that the
 // *http.Request is built correctly in checkReq.
 // *http.Request is built correctly in checkReq.
@@ -311,12 +413,7 @@ func testServerRequest(t *testing.T, writeReq func(*serverTester), checkReq func
 	})
 	})
 	defer st.Close()
 	defer st.Close()
 
 
-	st.writePreface()
-	st.writeInitialSettings()
-	st.wantSettings()
-	st.writeSettingsAck()
-	st.wantSettingsAck()
-
+	st.greet()
 	writeReq(st)
 	writeReq(st)
 
 
 	select {
 	select {
@@ -433,17 +530,39 @@ func (w twriter) Write(p []byte) (n int, err error) {
 	return len(p), nil
 	return len(p), nil
 }
 }
 
 
-func encodeHeader(t *testing.T, kv ...string) []byte {
-	if len(kv)%2 == 1 {
+// encodeHeader encodes headers and returns their HPACK bytes. headers
+// must contain an even number of key/value pairs.  There may be
+// multiple pairs for keys (e.g. "cookie").  The :method, :path, and
+// :scheme headers default to GET, / and https.
+func encodeHeader(t *testing.T, headers ...string) []byte {
+	if len(headers)%2 == 1 {
 		panic("odd number of kv args")
 		panic("odd number of kv args")
 	}
 	}
+	keys := []string{":method", ":path", ":scheme"}
+	vals := map[string][]string{
+		":method": {"GET"},
+		":path":   {"/"},
+		":scheme": {"https"},
+	}
+	for len(headers) > 0 {
+		k, v := headers[0], headers[1]
+		headers = headers[2:]
+		if _, ok := vals[k]; !ok {
+			keys = append(keys, k)
+		}
+		if strings.HasPrefix(k, ":") {
+			vals[k] = []string{v}
+		} else {
+			vals[k] = append(vals[k], v)
+		}
+	}
 	var buf bytes.Buffer
 	var buf bytes.Buffer
 	enc := hpack.NewEncoder(&buf)
 	enc := hpack.NewEncoder(&buf)
-	for len(kv) > 0 {
-		k, v := kv[0], kv[1]
-		kv = kv[2:]
-		if err := enc.WriteField(hpack.HeaderField{Name: k, Value: v}); err != nil {
-			t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err)
+	for _, k := range keys {
+		for _, v := range vals[k] {
+			if err := enc.WriteField(hpack.HeaderField{Name: k, Value: v}); err != nil {
+				t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err)
+			}
 		}
 		}
 	}
 	}
 	return buf.Bytes()
 	return buf.Bytes()