Parcourir la source

Handle 'cookie' headers' special case merging.

And start of rejecting headers with capitals as stream errors.

And more tests and test cleanup.
Brad Fitzpatrick il y a 11 ans
Parent
commit
b90dfb042f
2 fichiers modifiés avec 224 ajouts et 66 suppressions
  1. 44 5
      http2.go
  2. 180 61
      http2_test.go

+ 44 - 5
http2.go

@@ -107,6 +107,7 @@ type serverConn struct {
 	canonHeader       map[string]string // http2-lower-case -> Go-Canonical-Case
 	method, path      string
 	scheme, authority string
+	invalidHeader     bool
 
 	// State related to writing current headers:
 	hpackEncoder   *hpack.Encoder
@@ -161,8 +162,10 @@ func (sc *serverConn) logf(format string, args ...interface{}) {
 }
 
 func (sc *serverConn) onNewHeaderField(f hpack.HeaderField) {
-	log.Printf("Header field: +%v", f)
-	if strings.HasPrefix(f.Name, ":") {
+	switch {
+	case !validHeader(f.Name):
+		sc.invalidHeader = true
+	case strings.HasPrefix(f.Name, ":"):
 		switch f.Name {
 		case ":method":
 			sc.method = f.Value
@@ -176,8 +179,15 @@ func (sc *serverConn) onNewHeaderField(f hpack.HeaderField) {
 			log.Printf("Ignoring unknown pseudo-header %q", f.Name)
 		}
 		return
+	case f.Name == "cookie":
+		if s, ok := sc.header["Cookie"]; ok && len(s) == 1 {
+			s[0] = s[0] + "; " + f.Value
+		} else {
+			sc.header.Add("Cookie", f.Value)
+		}
+	default:
+		sc.header.Add(sc.canonicalHeader(f.Name), f.Value)
 	}
-	sc.header.Add(sc.canonicalHeader(f.Name), f.Value)
 }
 
 func (sc *serverConn) canonicalHeader(v string) string {
@@ -208,7 +218,7 @@ func (sc *serverConn) serve() {
 	defer sc.conn.Close()
 	defer close(sc.doneServing)
 
-	log.Printf("HTTP/2 connection from %v on %p", sc.conn.RemoteAddr(), sc.hs)
+	sc.logf("HTTP/2 connection from %v on %p", sc.conn.RemoteAddr(), sc.hs)
 
 	// Read the client preface
 	buf := make([]byte, len(ClientPreface))
@@ -283,7 +293,10 @@ func (sc *serverConn) serve() {
 					sc.logf("Disconnection; connection error: %v", err)
 					return
 				}
-				// TODO: stream errors, etc
+				if h2e.IsStreamError() {
+					// TODO: stream errors, etc
+					panic("TODO")
+				}
 			}
 			if err != nil {
 				sc.logf("Disconnection due to other error: %v", err)
@@ -349,6 +362,7 @@ func (sc *serverConn) processHeaders(f *HeadersFrame) error {
 	}
 	sc.streams[id] = st
 	sc.header = make(http.Header)
+	sc.invalidHeader = false
 	sc.curHeaderStreamID = id
 	sc.curStream = st
 	return sc.processHeaderBlockFragment(f.HeaderBlockFragment(), f.HeadersEnded())
@@ -370,6 +384,14 @@ func (sc *serverConn) processHeaderBlockFragment(frag []byte, end bool) error {
 		// TODO: convert to stream error I assume?
 		return err
 	}
+	if sc.invalidHeader {
+		// See 8.1.2.6 Malformed Requests and Responses:
+		//
+		// Malformed requests or responses that are detected
+		// MUST be treated as a stream error (Section 5.4.2)
+		// of type PROTOCOL_ERROR."
+		return StreamError(ErrCodeProtocol)
+	}
 	curStream := sc.curStream
 	sc.curHeaderStreamID = 0
 	sc.curStream = nil
@@ -579,3 +601,20 @@ func (w *responseWriter) handlerDone() {
 }
 
 var testHookOnConn func() // for testing
+
+func validHeader(v string) bool {
+	if len(v) == 0 {
+		return false
+	}
+	for _, r := range v {
+		// "Just as in HTTP/1.x, header field names are
+		// strings of ASCII characters that are compared in a
+		// case-insensitive fashion. However, header field
+		// names MUST be converted to lowercase prior to their
+		// encoding in HTTP/2. "
+		if r >= 127 || ('A' <= r && r <= 'Z') {
+			return false
+		}
+	}
+	return true
+}

+ 180 - 61
http2_test.go

@@ -30,16 +30,19 @@ import (
 )
 
 type serverTester struct {
-	cc net.Conn // client conn
-	t  *testing.T
-	ts *httptest.Server
-	fr *Framer
+	cc     net.Conn // client conn
+	t      *testing.T
+	ts     *httptest.Server
+	fr     *Framer
+	logBuf *bytes.Buffer
 }
 
 func newServerTester(t *testing.T, handler http.HandlerFunc) *serverTester {
+	logBuf := new(bytes.Buffer)
 	ts := httptest.NewUnstartedServer(handler)
 	ConfigureServer(ts.Config, &Server{})
 	ts.TLS = ts.Config.TLSConfig // the httptest.Server has its own copy of this TLS config
+	ts.Config.ErrorLog = log.New(io.MultiWriter(twriter{t: t}, logBuf), "", log.LstdFlags)
 	ts.StartTLS()
 
 	t.Logf("Running test server at: %s", ts.URL)
@@ -52,10 +55,11 @@ func newServerTester(t *testing.T, handler http.HandlerFunc) *serverTester {
 	}
 	log.SetOutput(twriter{t})
 	return &serverTester{
-		t:  t,
-		ts: ts,
-		cc: cc,
-		fr: NewFramer(cc, cc),
+		t:      t,
+		ts:     ts,
+		cc:     cc,
+		fr:     NewFramer(cc, cc),
+		logBuf: logBuf,
 	}
 }
 
@@ -65,6 +69,16 @@ func (st *serverTester) Close() {
 	log.SetOutput(os.Stderr)
 }
 
+// greet initiates the client's HTTP/2 connection into a state where
+// frames may be sent.
+func (st *serverTester) greet() {
+	st.writePreface()
+	st.writeInitialSettings()
+	st.wantSettings()
+	st.writeSettingsAck()
+	st.wantSettingsAck()
+}
+
 func (st *serverTester) writePreface() {
 	n, err := st.cc.Write(clientPreface)
 	if err != nil {
@@ -93,10 +107,26 @@ func (st *serverTester) writeHeaders(p HeadersFrameParam) {
 	}
 }
 
+// bodylessReq1 writes a HEADERS frames with StreamID 1 and EndStream and EndHeaders set.
+func (st *serverTester) bodylessReq1(headers ...string) {
+	st.writeHeaders(HeadersFrameParam{
+		StreamID:      1, // clients send odd numbers
+		BlockFragment: encodeHeader(st.t, headers...),
+		EndStream:     true,
+		EndHeaders:    true,
+	})
+}
+
+func (st *serverTester) writeData(streamID uint32, endStream bool, data []byte) {
+	if err := st.fr.WriteData(streamID, endStream, data); err != nil {
+		st.t.Fatalf("Error writing DATA: %v", err)
+	}
+}
+
 func (st *serverTester) wantSettings() *SettingsFrame {
 	f, err := st.fr.ReadFrame()
 	if err != nil {
-		st.t.Fatal(err)
+		st.t.Fatalf("Error while expecting a SETTINGS frame: %v", err)
 	}
 	sf, ok := f.(*SettingsFrame)
 	if !ok {
@@ -105,6 +135,23 @@ func (st *serverTester) wantSettings() *SettingsFrame {
 	return sf
 }
 
+func (st *serverTester) wantRSTStream(streamID uint32, errCode ErrCode) {
+	f, err := st.fr.ReadFrame()
+	if err != nil {
+		st.t.Fatalf("Error while expecting an RSTStream frame: %v", err)
+	}
+	rs, ok := f.(*RSTStreamFrame)
+	if !ok {
+		st.t.Fatalf("got a %T; want *RSTStream", f)
+	}
+	if rs.FrameHeader.StreamID != streamID {
+		st.t.Fatalf("RSTStream StreamID = %d; want %d", rs.FrameHeader.StreamID, streamID)
+	}
+	if rs.ErrCode != uint32(errCode) {
+		st.t.Fatalf("RSTStream ErrCode = %d (%s); want %d (%s)", rs.ErrCode, rs.ErrCode, errCode, errCode)
+	}
+}
+
 func (st *serverTester) wantSettingsAck() {
 	f, err := st.fr.ReadFrame()
 	if err != nil {
@@ -144,14 +191,10 @@ func TestServer(t *testing.T) {
 	st.wantSettingsAck()
 
 	st.writeHeaders(HeadersFrameParam{
-		StreamID: 1, // clients send odd numbers
-		BlockFragment: encodeHeader(t,
-			":method", "GET",
-			":path", "/",
-			":scheme", "https",
-		),
-		EndStream:  true, // no DATA frames
-		EndHeaders: true,
+		StreamID:      1, // clients send odd numbers
+		BlockFragment: encodeHeader(t),
+		EndStream:     true, // no DATA frames
+		EndHeaders:    true,
 	})
 
 	select {
@@ -164,18 +207,12 @@ func TestServer(t *testing.T) {
 func TestServer_Request_Get(t *testing.T) {
 	testServerRequest(t, func(st *serverTester) {
 		st.writeHeaders(HeadersFrameParam{
-			StreamID: 1, // clients send odd numbers
-			BlockFragment: encodeHeader(t,
-				":method", "GET",
-				":path", "/",
-				":scheme", "https",
-				"foo-bar", "some-value",
-			),
-			EndStream:  true, // no DATA frames
-			EndHeaders: true,
+			StreamID:      1, // clients send odd numbers
+			BlockFragment: encodeHeader(t, "foo-bar", "some-value"),
+			EndStream:     true, // no DATA frames
+			EndHeaders:    true,
 		})
 	}, func(r *http.Request) {
-		t.Logf("GOT %#v", r)
 		if r.Method != "GET" {
 			t.Errorf("Method = %q; want GET", r.Method)
 		}
@@ -203,20 +240,63 @@ func TestServer_Request_Get(t *testing.T) {
 	})
 }
 
+// TODO: add a test with EndStream=true on the HEADERS but setting a
+// Content-Length anyway.  Should we just omit it and force it to
+// zero?
+
+func TestServer_Request_Post_NoContentLength_EndStream(t *testing.T) {
+	testServerRequest(t, func(st *serverTester) {
+		st.writeHeaders(HeadersFrameParam{
+			StreamID:      1, // clients send odd numbers
+			BlockFragment: encodeHeader(t, ":method", "POST"),
+			EndStream:     true,
+			EndHeaders:    true,
+		})
+	}, func(r *http.Request) {
+		if r.Method != "POST" {
+			t.Errorf("Method = %q; want POST", r.Method)
+		}
+		if r.ContentLength != 0 {
+			t.Errorf("ContentLength = %v; want 0", r.ContentLength)
+		}
+		if n, err := r.Body.Read([]byte(" ")); err != io.EOF || n != 0 {
+			t.Errorf("Read = %d, %v; want 0, EOF", n, err)
+		}
+	})
+}
+
+func TestServer_Request_Post_Body(t *testing.T) {
+	t.Skip("TODO: post bodies not yet implemented")
+	testServerRequest(t, func(st *serverTester) {
+		st.writeHeaders(HeadersFrameParam{
+			StreamID:      1, // clients send odd numbers
+			BlockFragment: encodeHeader(t, ":method", "POST"),
+			EndStream:     false, // migth be DATA frames
+			EndHeaders:    true,
+		})
+		st.writeData(1, true, nil)
+	}, func(r *http.Request) {
+		if r.Method != "POST" {
+			t.Errorf("Method = %q; want POST", r.Method)
+		}
+		if r.ContentLength != -1 {
+			t.Errorf("ContentLength = %v; want -1", r.ContentLength)
+		}
+		if n, err := r.Body.Read([]byte(" ")); err != io.EOF || n != 0 {
+			t.Errorf("Read = %d, %v; want 0, EOF", n, err)
+		}
+	})
+}
+
 // Using a Host header, instead of :authority
 func TestServer_Request_Get_Host(t *testing.T) {
 	const host = "example.com"
 	testServerRequest(t, func(st *serverTester) {
 		st.writeHeaders(HeadersFrameParam{
-			StreamID: 1, // clients send odd numbers
-			BlockFragment: encodeHeader(t,
-				":method", "GET",
-				":path", "/",
-				":scheme", "https",
-				"host", host,
-			),
-			EndStream:  true,
-			EndHeaders: true,
+			StreamID:      1, // clients send odd numbers
+			BlockFragment: encodeHeader(t, "host", host),
+			EndStream:     true,
+			EndHeaders:    true,
 		})
 	}, func(r *http.Request) {
 		if r.Host != host {
@@ -230,15 +310,10 @@ func TestServer_Request_Get_Authority(t *testing.T) {
 	const host = "example.com"
 	testServerRequest(t, func(st *serverTester) {
 		st.writeHeaders(HeadersFrameParam{
-			StreamID: 1, // clients send odd numbers
-			BlockFragment: encodeHeader(t,
-				":method", "GET",
-				":path", "/",
-				":scheme", "https",
-				":authority", host,
-			),
-			EndStream:  true,
-			EndHeaders: true,
+			StreamID:      1, // clients send odd numbers
+			BlockFragment: encodeHeader(t, ":authority", host),
+			EndStream:     true,
+			EndHeaders:    true,
 		})
 	}, func(r *http.Request) {
 		if r.Host != host {
@@ -255,9 +330,6 @@ func TestServer_Request_WithContinuation(t *testing.T) {
 	}
 	testServerRequest(t, func(st *serverTester) {
 		fullHeaders := encodeHeader(t,
-			":method", "GET",
-			":path", "/",
-			":scheme", "https",
 			"foo-one", "value-one",
 			"foo-two", "value-two",
 			"foo-three", "value-three",
@@ -297,6 +369,36 @@ func TestServer_Request_WithContinuation(t *testing.T) {
 	})
 }
 
+// Concatenated cookie headers. ("8.1.2.5 Compressing the Cookie Header Field")
+func TestServer_Request_CookieConcat(t *testing.T) {
+	const host = "example.com"
+	testServerRequest(t, func(st *serverTester) {
+		st.bodylessReq1(
+			":authority", host,
+			"cookie", "a=b",
+			"cookie", "c=d",
+			"cookie", "e=f",
+		)
+	}, func(r *http.Request) {
+		const want = "a=b; c=d; e=f"
+		if got := r.Header.Get("Cookie"); got != want {
+			t.Errorf("Cookie = %q; want %q", got, want)
+		}
+	})
+}
+
+func TestServer_Request_RejectCapitalHeader(t *testing.T) {
+	t.Skip("TODO: not handling stream errors properly yet in http2.go: if h2e.IsStreamError stuff")
+	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+		t.Fatal("server request made it to handler; should've been rejected")
+	})
+	defer st.Close()
+
+	st.greet()
+	st.bodylessReq1("UPPER", "v")
+	st.wantRSTStream(1, ErrCodeProtocol)
+}
+
 // testServerRequest sets up an idle HTTP/2 connection and lets you
 // write a single request with writeReq, and then verify that the
 // *http.Request is built correctly in checkReq.
@@ -311,12 +413,7 @@ func testServerRequest(t *testing.T, writeReq func(*serverTester), checkReq func
 	})
 	defer st.Close()
 
-	st.writePreface()
-	st.writeInitialSettings()
-	st.wantSettings()
-	st.writeSettingsAck()
-	st.wantSettingsAck()
-
+	st.greet()
 	writeReq(st)
 
 	select {
@@ -433,17 +530,39 @@ func (w twriter) Write(p []byte) (n int, err error) {
 	return len(p), nil
 }
 
-func encodeHeader(t *testing.T, kv ...string) []byte {
-	if len(kv)%2 == 1 {
+// encodeHeader encodes headers and returns their HPACK bytes. headers
+// must contain an even number of key/value pairs.  There may be
+// multiple pairs for keys (e.g. "cookie").  The :method, :path, and
+// :scheme headers default to GET, / and https.
+func encodeHeader(t *testing.T, headers ...string) []byte {
+	if len(headers)%2 == 1 {
 		panic("odd number of kv args")
 	}
+	keys := []string{":method", ":path", ":scheme"}
+	vals := map[string][]string{
+		":method": {"GET"},
+		":path":   {"/"},
+		":scheme": {"https"},
+	}
+	for len(headers) > 0 {
+		k, v := headers[0], headers[1]
+		headers = headers[2:]
+		if _, ok := vals[k]; !ok {
+			keys = append(keys, k)
+		}
+		if strings.HasPrefix(k, ":") {
+			vals[k] = []string{v}
+		} else {
+			vals[k] = append(vals[k], v)
+		}
+	}
 	var buf bytes.Buffer
 	enc := hpack.NewEncoder(&buf)
-	for len(kv) > 0 {
-		k, v := kv[0], kv[1]
-		kv = kv[2:]
-		if err := enc.WriteField(hpack.HeaderField{Name: k, Value: v}); err != nil {
-			t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err)
+	for _, k := range keys {
+		for _, v := range vals[k] {
+			if err := enc.WriteField(hpack.HeaderField{Name: k, Value: v}); err != nil {
+				t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err)
+			}
 		}
 	}
 	return buf.Bytes()