Browse Source

http2: Respect peer's SETTINGS_MAX_HEADER_LIST_SIZE in ClientConn

Add a new peerMaxHeaderListSize member to ClientConn which records the
SETTINGS_MAX_HEADER_LIST_SIZE requested by the client's peer, and
respect this limit in (*ClientConn) encodeHeaders / encodeTrailers.

Attempting to send more than peerMaxHeaderListSize bytes of headers or
trailers will result in RoundTrip returning errRequestHeaderListSize.

Updates golang/go#13959

Change-Id: Ic707179782acdf8ae543429ea1af7f4f30e67e59
Reviewed-on: https://go-review.googlesource.com/29243
Run-TryBot: Tom Bergan <tombergan@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Tom Bergan <tombergan@google.com>
Mike Appleby 9 years ago
parent
commit
66aacef3dd
2 changed files with 426 additions and 69 deletions
  1. 114 68
      http2/transport.go
  2. 312 1
      http2/transport_test.go

+ 114 - 68
http2/transport.go

@@ -87,7 +87,7 @@ type Transport struct {
 
 	// MaxHeaderListSize is the http2 SETTINGS_MAX_HEADER_LIST_SIZE to
 	// send in the initial settings frame. It is how many bytes
-	// of response headers are allow. Unlike the http2 spec, zero here
+	// of response headers are allowed. Unlike the http2 spec, zero here
 	// means to use a default limit (currently 10MB). If you actually
 	// want to advertise an ulimited value to the peer, Transport
 	// interprets the highest possible value here (0xffffffff or 1<<32-1)
@@ -172,9 +172,10 @@ type ClientConn struct {
 	fr              *Framer
 	lastActive      time.Time
 	// Settings from peer: (also guarded by mu)
-	maxFrameSize         uint32
-	maxConcurrentStreams uint32
-	initialWindowSize    uint32
+	maxFrameSize          uint32
+	maxConcurrentStreams  uint32
+	peerMaxHeaderListSize uint64
+	initialWindowSize     uint32
 
 	hbuf    bytes.Buffer // HPACK encoder writes into this
 	henc    *hpack.Encoder
@@ -519,17 +520,18 @@ func (t *Transport) NewClientConn(c net.Conn) (*ClientConn, error) {
 
 func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, error) {
 	cc := &ClientConn{
-		t:                    t,
-		tconn:                c,
-		readerDone:           make(chan struct{}),
-		nextStreamID:         1,
-		maxFrameSize:         16 << 10, // spec default
-		initialWindowSize:    65535,    // spec default
-		maxConcurrentStreams: 1000,     // "infinite", per spec. 1000 seems good enough.
-		streams:              make(map[uint32]*clientStream),
-		singleUse:            singleUse,
-		wantSettingsAck:      true,
-		pings:                make(map[[8]byte]chan struct{}),
+		t:                     t,
+		tconn:                 c,
+		readerDone:            make(chan struct{}),
+		nextStreamID:          1,
+		maxFrameSize:          16 << 10,           // spec default
+		initialWindowSize:     65535,              // spec default
+		maxConcurrentStreams:  1000,               // "infinite", per spec. 1000 seems good enough.
+		peerMaxHeaderListSize: 0xffffffffffffffff, // "infinite", per spec. Use 2^64-1 instead.
+		streams:               make(map[uint32]*clientStream),
+		singleUse:             singleUse,
+		wantSettingsAck:       true,
+		pings:                 make(map[[8]byte]chan struct{}),
 	}
 	if d := t.idleConnTimeout(); d != 0 {
 		cc.idleTimeout = d
@@ -1085,8 +1087,13 @@ func (cs *clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) (
 	var trls []byte
 	if hasTrailers {
 		cc.mu.Lock()
-		defer cc.mu.Unlock()
-		trls = cc.encodeTrailers(req)
+		trls, err = cc.encodeTrailers(req)
+		cc.mu.Unlock()
+		if err != nil {
+			cc.writeStreamReset(cs.ID, ErrCodeInternal, err)
+			cc.forgetStreamID(cs.ID)
+			return err
+		}
 	}
 
 	cc.wmu.Lock()
@@ -1189,62 +1196,86 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail
 		}
 	}
 
-	// 8.1.2.3 Request Pseudo-Header Fields
-	// The :path pseudo-header field includes the path and query parts of the
-	// target URI (the path-absolute production and optionally a '?' character
-	// followed by the query production (see Sections 3.3 and 3.4 of
-	// [RFC3986]).
-	cc.writeHeader(":authority", host)
-	cc.writeHeader(":method", req.Method)
-	if req.Method != "CONNECT" {
-		cc.writeHeader(":path", path)
-		cc.writeHeader(":scheme", req.URL.Scheme)
-	}
-	if trailers != "" {
-		cc.writeHeader("trailer", trailers)
-	}
+	enumerateHeaders := func(f func(name, value string)) {
+		// 8.1.2.3 Request Pseudo-Header Fields
+		// The :path pseudo-header field includes the path and query parts of the
+		// target URI (the path-absolute production and optionally a '?' character
+		// followed by the query production (see Sections 3.3 and 3.4 of
+		// [RFC3986]).
+		f(":authority", host)
+		f(":method", req.Method)
+		if req.Method != "CONNECT" {
+			f(":path", path)
+			f(":scheme", req.URL.Scheme)
+		}
+		if trailers != "" {
+			f("trailer", trailers)
+		}
 
-	var didUA bool
-	for k, vv := range req.Header {
-		lowKey := strings.ToLower(k)
-		switch lowKey {
-		case "host", "content-length":
-			// Host is :authority, already sent.
-			// Content-Length is automatic, set below.
-			continue
-		case "connection", "proxy-connection", "transfer-encoding", "upgrade", "keep-alive":
-			// Per 8.1.2.2 Connection-Specific Header
-			// Fields, don't send connection-specific
-			// fields. We have already checked if any
-			// are error-worthy so just ignore the rest.
-			continue
-		case "user-agent":
-			// Match Go's http1 behavior: at most one
-			// User-Agent. If set to nil or empty string,
-			// then omit it. Otherwise if not mentioned,
-			// include the default (below).
-			didUA = true
-			if len(vv) < 1 {
+		var didUA bool
+		for k, vv := range req.Header {
+			if strings.EqualFold(k, "host") || strings.EqualFold(k, "content-length") {
+				// Host is :authority, already sent.
+				// Content-Length is automatic, set below.
 				continue
-			}
-			vv = vv[:1]
-			if vv[0] == "" {
+			} else if strings.EqualFold(k, "connection") || strings.EqualFold(k, "proxy-connection") ||
+				strings.EqualFold(k, "transfer-encoding") || strings.EqualFold(k, "upgrade") ||
+				strings.EqualFold(k, "keep-alive") {
+				// Per 8.1.2.2 Connection-Specific Header
+				// Fields, don't send connection-specific
+				// fields. We have already checked if any
+				// are error-worthy so just ignore the rest.
 				continue
+			} else if strings.EqualFold(k, "user-agent") {
+				// Match Go's http1 behavior: at most one
+				// User-Agent. If set to nil or empty string,
+				// then omit it. Otherwise if not mentioned,
+				// include the default (below).
+				didUA = true
+				if len(vv) < 1 {
+					continue
+				}
+				vv = vv[:1]
+				if vv[0] == "" {
+					continue
+				}
+
+			}
+
+			for _, v := range vv {
+				f(k, v)
 			}
 		}
-		for _, v := range vv {
-			cc.writeHeader(lowKey, v)
+		if shouldSendReqContentLength(req.Method, contentLength) {
+			f("content-length", strconv.FormatInt(contentLength, 10))
+		}
+		if addGzipHeader {
+			f("accept-encoding", "gzip")
+		}
+		if !didUA {
+			f("user-agent", defaultUserAgent)
 		}
 	}
-	if shouldSendReqContentLength(req.Method, contentLength) {
-		cc.writeHeader("content-length", strconv.FormatInt(contentLength, 10))
-	}
-	if addGzipHeader {
-		cc.writeHeader("accept-encoding", "gzip")
-	}
-	if !didUA {
-		cc.writeHeader("user-agent", defaultUserAgent)
+
+	// Do a first pass over the headers counting bytes to ensure
+	// we don't exceed cc.peerMaxHeaderListSize. This is done as a
+	// separate pass before encoding the headers to prevent
+	// modifying the hpack state.
+	hlSize := uint64(0)
+	enumerateHeaders(func(name, value string) {
+		hf := hpack.HeaderField{Name: name, Value: value}
+		hlSize += uint64(hf.Size())
+	})
+
+	if hlSize > cc.peerMaxHeaderListSize {
+		return nil, errRequestHeaderListSize
 	}
+
+	// Header list size is ok. Write the headers.
+	enumerateHeaders(func(name, value string) {
+		cc.writeHeader(strings.ToLower(name), value)
+	})
+
 	return cc.hbuf.Bytes(), nil
 }
 
@@ -1271,17 +1302,29 @@ func shouldSendReqContentLength(method string, contentLength int64) bool {
 }
 
 // requires cc.mu be held.
-func (cc *ClientConn) encodeTrailers(req *http.Request) []byte {
+func (cc *ClientConn) encodeTrailers(req *http.Request) ([]byte, error) {
 	cc.hbuf.Reset()
+
+	hlSize := uint64(0)
+	for k, vv := range req.Trailer {
+		for _, v := range vv {
+			hf := hpack.HeaderField{Name: k, Value: v}
+			hlSize += uint64(hf.Size())
+		}
+	}
+	if hlSize > cc.peerMaxHeaderListSize {
+		return nil, errRequestHeaderListSize
+	}
+
 	for k, vv := range req.Trailer {
-		// Transfer-Encoding, etc.. have already been filter at the
+		// Transfer-Encoding, etc.. have already been filtered at the
 		// start of RoundTrip
 		lowKey := strings.ToLower(k)
 		for _, v := range vv {
 			cc.writeHeader(lowKey, v)
 		}
 	}
-	return cc.hbuf.Bytes()
+	return cc.hbuf.Bytes(), nil
 }
 
 func (cc *ClientConn) writeHeader(name, value string) {
@@ -1911,6 +1954,8 @@ func (rl *clientConnReadLoop) processSettings(f *SettingsFrame) error {
 			cc.maxFrameSize = s.Val
 		case SettingMaxConcurrentStreams:
 			cc.maxConcurrentStreams = s.Val
+		case SettingMaxHeaderListSize:
+			cc.peerMaxHeaderListSize = uint64(s.Val)
 		case SettingInitialWindowSize:
 			// Values above the maximum flow-control
 			// window size of 2^31-1 MUST be treated as a
@@ -2077,6 +2122,7 @@ func (cc *ClientConn) writeStreamReset(streamID uint32, code ErrCode, err error)
 
 var (
 	errResponseHeaderListSize = errors.New("http2: response header list larger than advertised limit")
+	errRequestHeaderListSize  = errors.New("http2: request header list larger than peer's advertised limit")
 	errPseudoTrailers         = errors.New("http2: invalid pseudo header in trailers")
 )
 

+ 312 - 1
http2/transport_test.go

@@ -1371,6 +1371,269 @@ func testInvalidTrailer(t *testing.T, trailers headerType, wantErr error, writeT
 	ct.run()
 }
 
+// headerListSize returns the HTTP2 header list size of h.
+//   http://httpwg.org/specs/rfc7540.html#SETTINGS_MAX_HEADER_LIST_SIZE
+//   http://httpwg.org/specs/rfc7540.html#MaxHeaderBlock
+func headerListSize(h http.Header) (size uint32) {
+	for k, vv := range h {
+		for _, v := range vv {
+			hf := hpack.HeaderField{Name: k, Value: v}
+			size += hf.Size()
+		}
+	}
+	return size
+}
+
+// padHeaders adds data to an http.Header until headerListSize(h) ==
+// limit. Due to the way header list sizes are calculated, padHeaders
+// cannot add fewer than len("Pad-Headers") + 32 bytes to h, and will
+// call t.Fatal if asked to do so. PadHeaders first reserves enough
+// space for an empty "Pad-Headers" key, then adds as many copies of
+// filler as possible. Any remaining bytes necessary to push the
+// header list size up to limit are added to h["Pad-Headers"].
+func padHeaders(t *testing.T, h http.Header, limit uint64, filler string) {
+	if limit > 0xffffffff {
+		t.Fatalf("padHeaders: refusing to pad to more than 2^32-1 bytes. limit = %v", limit)
+	}
+	hf := hpack.HeaderField{Name: "Pad-Headers", Value: ""}
+	minPadding := uint64(hf.Size())
+	size := uint64(headerListSize(h))
+
+	minlimit := size + minPadding
+	if limit < minlimit {
+		t.Fatalf("padHeaders: limit %v < %v", limit, minlimit)
+	}
+
+	// Use a fixed-width format for name so that fieldSize
+	// remains constant.
+	nameFmt := "Pad-Headers-%06d"
+	hf = hpack.HeaderField{Name: fmt.Sprintf(nameFmt, 1), Value: filler}
+	fieldSize := uint64(hf.Size())
+
+	// Add as many complete filler values as possible, leaving
+	// room for at least one empty "Pad-Headers" key.
+	limit = limit - minPadding
+	for i := 0; size+fieldSize < limit; i++ {
+		name := fmt.Sprintf(nameFmt, i)
+		h.Add(name, filler)
+		size += fieldSize
+	}
+
+	// Add enough bytes to reach limit.
+	remain := limit - size
+	lastValue := strings.Repeat("*", int(remain))
+	h.Add("Pad-Headers", lastValue)
+}
+
+func TestPadHeaders(t *testing.T) {
+	check := func(h http.Header, limit uint32, fillerLen int) {
+		if h == nil {
+			h = make(http.Header)
+		}
+		filler := strings.Repeat("f", fillerLen)
+		padHeaders(t, h, uint64(limit), filler)
+		gotSize := headerListSize(h)
+		if gotSize != limit {
+			t.Errorf("Got size = %v; want %v", gotSize, limit)
+		}
+	}
+	// Try all possible combinations for small fillerLen and limit.
+	hf := hpack.HeaderField{Name: "Pad-Headers", Value: ""}
+	minLimit := hf.Size()
+	for limit := minLimit; limit <= 128; limit++ {
+		for fillerLen := 0; uint32(fillerLen) <= limit; fillerLen++ {
+			check(nil, limit, fillerLen)
+		}
+	}
+
+	// Try a few tests with larger limits, plus cumulative
+	// tests. Since these tests are cumulative, tests[i+1].limit
+	// must be >= tests[i].limit + minLimit. See the comment on
+	// padHeaders for more info on why the limit arg has this
+	// restriction.
+	tests := []struct {
+		fillerLen int
+		limit     uint32
+	}{
+		{
+			fillerLen: 64,
+			limit:     1024,
+		},
+		{
+			fillerLen: 1024,
+			limit:     1286,
+		},
+		{
+			fillerLen: 256,
+			limit:     2048,
+		},
+		{
+			fillerLen: 1024,
+			limit:     10 * 1024,
+		},
+		{
+			fillerLen: 1023,
+			limit:     11 * 1024,
+		},
+	}
+	h := make(http.Header)
+	for _, tc := range tests {
+		check(nil, tc.limit, tc.fillerLen)
+		check(h, tc.limit, tc.fillerLen)
+	}
+}
+
+func TestTransportChecksRequestHeaderListSize(t *testing.T) {
+	st := newServerTester(t,
+		func(w http.ResponseWriter, r *http.Request) {
+			// Consume body & force client to send
+			// trailers before writing response.
+			// ioutil.ReadAll returns non-nil err for
+			// requests that attempt to send greater than
+			// maxHeaderListSize bytes of trailers, since
+			// those requests generate a stream reset.
+			ioutil.ReadAll(r.Body)
+			r.Body.Close()
+		},
+		func(ts *httptest.Server) {
+			ts.Config.MaxHeaderBytes = 16 << 10
+		},
+		optOnlyServer,
+		optQuiet,
+	)
+	defer st.Close()
+
+	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
+	defer tr.CloseIdleConnections()
+
+	checkRoundTrip := func(req *http.Request, wantErr error, desc string) {
+		res, err := tr.RoundTrip(req)
+		if err != wantErr {
+			if res != nil {
+				res.Body.Close()
+			}
+			t.Errorf("%v: RoundTrip err = %v; want %v", desc, err, wantErr)
+			return
+		}
+		if err == nil {
+			if res == nil {
+				t.Errorf("%v: response nil; want non-nil.", desc)
+				return
+			}
+			defer res.Body.Close()
+			if res.StatusCode != http.StatusOK {
+				t.Errorf("%v: response status = %v; want %v", desc, res.StatusCode, http.StatusOK)
+			}
+			return
+		}
+		if res != nil {
+			t.Errorf("%v: RoundTrip err = %v but response non-nil", desc, err)
+		}
+	}
+	headerListSizeForRequest := func(req *http.Request) (size uint64) {
+		contentLen := actualContentLength(req)
+		trailers, err := commaSeparatedTrailers(req)
+		if err != nil {
+			t.Fatalf("headerListSizeForRequest: %v", err)
+		}
+		cc := &ClientConn{peerMaxHeaderListSize: 0xffffffffffffffff}
+		cc.henc = hpack.NewEncoder(&cc.hbuf)
+		cc.mu.Lock()
+		hdrs, err := cc.encodeHeaders(req, true, trailers, contentLen)
+		cc.mu.Unlock()
+		if err != nil {
+			t.Fatalf("headerListSizeForRequest: %v", err)
+		}
+		hpackDec := hpack.NewDecoder(initialHeaderTableSize, func(hf hpack.HeaderField) {
+			size += uint64(hf.Size())
+		})
+		if len(hdrs) > 0 {
+			if _, err := hpackDec.Write(hdrs); err != nil {
+				t.Fatalf("headerListSizeForRequest: %v", err)
+			}
+		}
+		return size
+	}
+	// Create a new Request for each test, rather than reusing the
+	// same Request, to avoid a race when modifying req.Headers.
+	// See https://github.com/golang/go/issues/21316
+	newRequest := func() *http.Request {
+		// Body must be non-nil to enable writing trailers.
+		body := strings.NewReader("hello")
+		req, err := http.NewRequest("POST", st.ts.URL, body)
+		if err != nil {
+			t.Fatalf("newRequest: NewRequest: %v", err)
+		}
+		return req
+	}
+
+	// Make an arbitrary request to ensure we get the server's
+	// settings frame and initialize peerMaxHeaderListSize.
+	req := newRequest()
+	checkRoundTrip(req, nil, "Initial request")
+
+	// Get the ClientConn associated with the request and validate
+	// peerMaxHeaderListSize.
+	addr := authorityAddr(req.URL.Scheme, req.URL.Host)
+	cc, err := tr.connPool().GetClientConn(req, addr)
+	if err != nil {
+		t.Fatalf("GetClientConn: %v", err)
+	}
+	cc.mu.Lock()
+	peerSize := cc.peerMaxHeaderListSize
+	cc.mu.Unlock()
+	st.scMu.Lock()
+	wantSize := uint64(st.sc.maxHeaderListSize())
+	st.scMu.Unlock()
+	if peerSize != wantSize {
+		t.Errorf("peerMaxHeaderListSize = %v; want %v", peerSize, wantSize)
+	}
+
+	// Sanity check peerSize. (*serverConn) maxHeaderListSize adds
+	// 320 bytes of padding.
+	wantHeaderBytes := uint64(st.ts.Config.MaxHeaderBytes) + 320
+	if peerSize != wantHeaderBytes {
+		t.Errorf("peerMaxHeaderListSize = %v; want %v.", peerSize, wantHeaderBytes)
+	}
+
+	// Pad headers & trailers, but stay under peerSize.
+	req = newRequest()
+	req.Header = make(http.Header)
+	req.Trailer = make(http.Header)
+	filler := strings.Repeat("*", 1024)
+	padHeaders(t, req.Trailer, peerSize, filler)
+	// cc.encodeHeaders adds some default headers to the request,
+	// so we need to leave room for those.
+	defaultBytes := headerListSizeForRequest(req)
+	padHeaders(t, req.Header, peerSize-defaultBytes, filler)
+	checkRoundTrip(req, nil, "Headers & Trailers under limit")
+
+	// Add enough header bytes to push us over peerSize.
+	req = newRequest()
+	req.Header = make(http.Header)
+	padHeaders(t, req.Header, peerSize, filler)
+	checkRoundTrip(req, errRequestHeaderListSize, "Headers over limit")
+
+	// Push trailers over the limit.
+	req = newRequest()
+	req.Trailer = make(http.Header)
+	padHeaders(t, req.Trailer, peerSize+1, filler)
+	checkRoundTrip(req, errRequestHeaderListSize, "Trailers over limit")
+
+	// Send headers with a single large value.
+	req = newRequest()
+	filler = strings.Repeat("*", int(peerSize))
+	req.Header = make(http.Header)
+	req.Header.Set("Big", filler)
+	checkRoundTrip(req, errRequestHeaderListSize, "Single large header")
+
+	// Send trailers with a single large value.
+	req = newRequest()
+	req.Trailer = make(http.Header)
+	req.Trailer.Set("Big", filler)
+	checkRoundTrip(req, errRequestHeaderListSize, "Single large trailer")
+}
+
 func TestTransportChecksResponseHeaderListSize(t *testing.T) {
 	ct := newClientTester(t)
 	ct.client = func() error {
@@ -2663,7 +2926,7 @@ func TestTransportRequestPathPseudo(t *testing.T) {
 		},
 	}
 	for i, tt := range tests {
-		cc := &ClientConn{}
+		cc := &ClientConn{peerMaxHeaderListSize: 0xffffffffffffffff}
 		cc.henc = hpack.NewEncoder(&cc.hbuf)
 		cc.mu.Lock()
 		hdrs, err := cc.encodeHeaders(tt.req, false, "", -1)
@@ -3373,3 +3636,51 @@ func TestTransportNoBodyMeansNoDATA(t *testing.T) {
 	}
 	ct.run()
 }
+
+func benchSimpleRoundTrip(b *testing.B, nHeaders int) {
+	defer disableGoroutineTracking()()
+	b.ReportAllocs()
+	st := newServerTester(b,
+		func(w http.ResponseWriter, r *http.Request) {
+		},
+		optOnlyServer,
+		optQuiet,
+	)
+	defer st.Close()
+
+	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
+	defer tr.CloseIdleConnections()
+
+	req, err := http.NewRequest("GET", st.ts.URL, nil)
+	if err != nil {
+		b.Fatal(err)
+	}
+
+	for i := 0; i < nHeaders; i++ {
+		name := fmt.Sprint("A-", i)
+		req.Header.Set(name, "*")
+	}
+
+	b.ResetTimer()
+
+	for i := 0; i < b.N; i++ {
+		res, err := tr.RoundTrip(req)
+		if err != nil {
+			if res != nil {
+				res.Body.Close()
+			}
+			b.Fatalf("RoundTrip err = %v; want nil", err)
+		}
+		res.Body.Close()
+		if res.StatusCode != http.StatusOK {
+			b.Fatalf("Response code = %v; want %v", res.StatusCode, http.StatusOK)
+		}
+	}
+}
+
+func BenchmarkClientRequestHeaders(b *testing.B) {
+	b.Run("   0 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0) })
+	b.Run("  10 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 10) })
+	b.Run(" 100 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 100) })
+	b.Run("1000 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 1000) })
+}