
http2: fix enforcement of max header list size

In the first attempt to enforce the SETTINGS_MAX_HEADER_LIST_SIZE
(https://go-review.googlesource.com/15751), the enforcement happened
in the hpack decoder, and the decoder returned errors from Write and
Close if the limit was violated. This was incorrect because the
decoder is used over the life of the connection, across all
subsequent requests, and could therefore get out of sync.

Instead, this moves the counting of the limit up to the http2 package
in the serverConn type, and replaces the hpack counting mechanism with
a simple on/off switch. When SetEmitEnabled is set false, the header
field emit callbacks will be suppressed and the hpack Decoder will do
less work (less CPU and garbage) if possible, but will still return
nil from Write and Close on valid input, and will still stay in sync
with the stream.

The http2 Server then returns a 431 error if emits were disabled while
processing the HEADERS frame or any CONTINUATION frames.

Fixes golang/go#12843

Change-Id: I3b41aaefc6c6ee6218225f8dc62bba6ae5fe8f2d
Reviewed-on: https://go-review.googlesource.com/15733
Reviewed-by: Andrew Gerrand <adg@golang.org>
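
A minimal sketch (not part of this change) of the behavior the message
describes, assuming the golang.org/x/net/http2/hpack import path: with
emits disabled, Write and Close still return nil on valid input, the
callback never fires, and EmitEnabled reports the current setting.

package main

import (
	"bytes"
	"fmt"
	"log"

	"golang.org/x/net/http2/hpack"
)

func main() {
	// Encode one header field into a header block.
	var buf bytes.Buffer
	enc := hpack.NewEncoder(&buf)
	enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})

	dec := hpack.NewDecoder(4096, func(f hpack.HeaderField) {
		fmt.Printf("emitted: %s: %s\n", f.Name, f.Value)
	})

	// Suppress emits: decoding still succeeds, nothing is printed.
	dec.SetEmitEnabled(false)
	if _, err := dec.Write(buf.Bytes()); err != nil {
		log.Fatal(err)
	}
	if err := dec.Close(); err != nil {
		log.Fatal(err)
	}
	fmt.Println("emit enabled:", dec.EmitEnabled()) // false
}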
Brad Fitzpatrick, 10 years ago
Parent commit: d8f3c68ddd
4 changed files with 107 additions and 80 deletions
  1. http2/hpack/hpack.go (+43 -36)
  2. http2/hpack/hpack_test.go (+22 -34)
  3. http2/server.go (+30 -7)
  4. http2/server_test.go (+12 -3)

http2/hpack/hpack.go (+43 -36)

@@ -64,9 +64,7 @@ type Decoder struct {
 	dynTab dynamicTable
 	emit   func(f HeaderField)
 
-	headerListSize    int64
-	maxHeaderListSize uint32 // 0 means unlimited
-	hitLimit          bool
+	emitEnabled bool // whether calls to emit are enabled
 
 	// buf is the unparsed buffer. It's only written to
 	// saveBuf if it was truncated in the middle of a header
@@ -78,23 +76,29 @@ type Decoder struct {
 
 // NewDecoder returns a new decoder with the provided maximum dynamic
 // table size. The emitFunc will be called for each valid field
-// parsed.
+// parsed, in the same goroutine as calls to Write, before Write returns.
 func NewDecoder(maxDynamicTableSize uint32, emitFunc func(f HeaderField)) *Decoder {
 	d := &Decoder{
-		emit: emitFunc,
+		emit:        emitFunc,
+		emitEnabled: true,
 	}
 	d.dynTab.allowedMaxSize = maxDynamicTableSize
 	d.dynTab.setMaxSize(maxDynamicTableSize)
 	return d
 }
 
-// SetMaxHeaderListSize sets the decoder's SETTINGS_MAX_HEADER_LIST_SIZE.
-// It should be set before any call to Write.
-// The default, 0, means unlimited.
-// If the limit is passed, calls to Write and Close will return ErrMaxHeaderListSize.
-func (d *Decoder) SetMaxHeaderListSize(v uint32) {
-	d.maxHeaderListSize = v
-}
+// SetEmitEnabled controls whether the emitFunc provided to NewDecoder
+// should be called. The default is true.
+//
+// This facility exists to let servers enforce MAX_HEADER_LIST_SIZE
+// while still decoding and keeping in-sync with decoder state, but
+// without doing unnecessary decompression or generating unnecessary
+// garbage for header fields past the limit.
+func (d *Decoder) SetEmitEnabled(v bool) { d.emitEnabled = v }
+
+// EmitEnabled reports whether calls to the emitFunc provided to NewDecoder
+// are currently enabled. The default is true.
+func (d *Decoder) EmitEnabled() bool { return d.emitEnabled }
 
 // TODO: add method *Decoder.Reset(maxSize, emitFunc) to let callers re-use Decoders and their
 // underlying buffers for garbage reasons.
@@ -235,16 +239,11 @@ func (d *Decoder) DecodeFull(p []byte) ([]HeaderField, error) {
 	return hf, nil
 }
 
-var ErrMaxHeaderListSize = errors.New("hpack: max header list size exceeded")
-
 func (d *Decoder) Close() error {
 	if d.saveBuf.Len() > 0 {
 		d.saveBuf.Reset()
 		return DecodingError{errors.New("truncated headers")}
 	}
-	if d.hitLimit {
-		return ErrMaxHeaderListSize
-	}
 	return nil
 }
 
@@ -265,7 +264,7 @@ func (d *Decoder) Write(p []byte) (n int, err error) {
 		d.saveBuf.Reset()
 	}
 
-	for len(d.buf) > 0 && !d.hitLimit {
+	for len(d.buf) > 0 {
 		err = d.parseHeaderFieldRepr()
 		if err != nil {
 			if err == errNeedMore {
@@ -275,9 +274,6 @@ func (d *Decoder) Write(p []byte) (n int, err error) {
 			break
 		}
 	}
-	if err == nil && d.hitLimit {
-		err = ErrMaxHeaderListSize
-	}
 	return len(p), err
 }
 
@@ -359,6 +355,7 @@ func (d *Decoder) parseFieldLiteral(n uint8, it indexType) error {
 	}
 
 	var hf HeaderField
+	wantStr := d.emitEnabled || it.indexed()
 	if nameIdx > 0 {
 		ihf, ok := d.at(nameIdx)
 		if !ok {
@@ -366,12 +363,12 @@ func (d *Decoder) parseFieldLiteral(n uint8, it indexType) error {
 		}
 		hf.Name = ihf.Name
 	} else {
-		hf.Name, buf, err = readString(buf)
+		hf.Name, buf, err = readString(buf, wantStr)
 		if err != nil {
 			return err
 		}
 	}
-	hf.Value, buf, err = readString(buf)
+	hf.Value, buf, err = readString(buf, wantStr)
 	if err != nil {
 		return err
 	}
@@ -385,13 +382,9 @@ func (d *Decoder) parseFieldLiteral(n uint8, it indexType) error {
 }
 
 func (d *Decoder) callEmit(hf HeaderField) {
-	const overheadPerField = 32 // per http2 section 6.5.2, etc
-	d.headerListSize += int64(len(hf.Name)+len(hf.Value)) + overheadPerField
-	if d.maxHeaderListSize != 0 && d.headerListSize > int64(d.maxHeaderListSize) {
-		d.hitLimit = true
-		return
+	if d.emitEnabled {
+		d.emit(hf)
 	}
-	d.emit(hf)
 }
 
 // (same invariants and behavior as parseHeaderFieldRepr)
@@ -452,7 +445,15 @@ func readVarInt(n byte, p []byte) (i uint64, remain []byte, err error) {
 	return 0, origP, errNeedMore
 }
 
-func readString(p []byte) (s string, remain []byte, err error) {
+// readString decodes an hpack string from p.
+//
+// wantStr is whether s will be used. If false, decompression and
+// []byte->string garbage are skipped if s will be ignored
+// anyway. This does mean that huffman decoding errors for non-indexed
+// strings past the MAX_HEADER_LIST_SIZE are ignored, but the server
+// is returning an error anyway, and because they're not indexed, the error
+// won't affect the decoding state.
+func readString(p []byte, wantStr bool) (s string, remain []byte, err error) {
 	if len(p) == 0 {
 		return "", p, errNeedMore
 	}
@@ -465,13 +466,19 @@ func readString(p []byte) (s string, remain []byte, err error) {
 		return "", p, errNeedMore
 	}
 	if !isHuff {
-		return string(p[:strLen]), p[strLen:], nil
+		if wantStr {
+			s = string(p[:strLen])
+		}
+		return s, p[strLen:], nil
 	}
 
-	// TODO: optimize this garbage:
-	var buf bytes.Buffer
-	if _, err := HuffmanDecode(&buf, p[:strLen]); err != nil {
-		return "", nil, err
+	if wantStr {
+		// TODO: optimize this garbage:
+		var buf bytes.Buffer
+		if _, err := HuffmanDecode(&buf, p[:strLen]); err != nil {
+			return "", nil, err
+		}
+		s = buf.String()
 	}
-	return buf.String(), p[strLen:], nil
+	return s, p[strLen:], nil
 }
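
Complementing the sketch above, the wantStr change in parseFieldLiteral
and readString means representations with incremental indexing are still
fully decoded even when emits are off, so the dynamic table stays
consistent. A hedged illustration (again assuming this package revision):
a field decoded silently is still indexed, and a later header block that
refers to it by index decodes correctly once emits are re-enabled.

package main

import (
	"bytes"
	"fmt"
	"log"

	"golang.org/x/net/http2/hpack"
)

func main() {
	var buf bytes.Buffer
	enc := hpack.NewEncoder(&buf)

	// Block 1: literal field with incremental indexing; the decoder
	// adds it to its dynamic table even if emits are disabled.
	enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
	block1 := append([]byte(nil), buf.Bytes()...)
	buf.Reset()

	// Block 2: the encoder now refers to the same field by index.
	enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
	block2 := append([]byte(nil), buf.Bytes()...)

	var got []hpack.HeaderField
	dec := hpack.NewDecoder(4096, func(f hpack.HeaderField) { got = append(got, f) })

	dec.SetEmitEnabled(false) // decode block 1 silently
	if _, err := dec.Write(block1); err != nil {
		log.Fatal(err)
	}
	dec.Close()

	dec.SetEmitEnabled(true) // block 2's indexed reference still resolves
	if _, err := dec.Write(block2); err != nil {
		log.Fatal(err)
	}
	dec.Close()

	for _, f := range got {
		fmt.Printf("%s: %s\n", f.Name, f.Value) // foo: bar
	}
}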

http2/hpack/hpack_test.go (+22 -34)

@@ -647,40 +647,28 @@ func dehex(s string) []byte {
 	return b
 }
 
-func TestMaxHeaderListSize(t *testing.T) {
-	tests := []struct {
-		fields  []HeaderField
-		max     int
-		wantErr bool
-	}{
-		// Plenty of space.
-		{
-			fields: []HeaderField{{Name: "foo", Value: "bar"}},
-			max:    500,
-		},
-		// Exactly right limit.
-		{
-			fields: []HeaderField{{Name: "foo", Value: "bar"}},
-			max:    len("foo") + len("bar") + 32,
-		},
-		// One byte too short.
-		{
-			fields:  []HeaderField{{Name: "foo", Value: "bar"}},
-			max:     len("foo") + len("bar") + 32 - 1,
-			wantErr: true,
-		},
+func TestEmitEnabled(t *testing.T) {
+	var buf bytes.Buffer
+	enc := NewEncoder(&buf)
+	enc.WriteField(HeaderField{Name: "foo", Value: "bar"})
+	enc.WriteField(HeaderField{Name: "foo", Value: "bar"})
+
+	numCallback := 0
+	var dec *Decoder
+	dec = NewDecoder(8<<20, func(HeaderField) {
+		numCallback++
+		dec.SetEmitEnabled(false)
+	})
+	if !dec.EmitEnabled() {
+		t.Errorf("initial emit enabled = false; want true")
 	}
-	for i, tt := range tests {
-		var buf bytes.Buffer
-		enc := NewEncoder(&buf)
-		for _, hf := range tt.fields {
-			enc.WriteField(hf)
-		}
-		dec := NewDecoder(8<<20, func(HeaderField) {})
-		dec.SetMaxHeaderListSize(uint32(tt.max))
-		_, err := dec.Write(buf.Bytes())
-		if (err != nil) != tt.wantErr {
-			t.Errorf("%d. err = %v; want err = %v", i, err, tt.wantErr)
-		}
+	if _, err := dec.Write(buf.Bytes()); err != nil {
+		t.Error(err)
+	}
+	if numCallback != 1 {
+		t.Errorf("num callbacks = %d; want 1", numCallback)
+	}
+	if dec.EmitEnabled() {
+		t.Errorf("emit enabled = true; want false")
 	}
 }

http2/server.go (+30 -7)

@@ -220,7 +220,6 @@ func (srv *Server) handleConn(hs *http.Server, c net.Conn, h http.Handler) {
 	sc.inflow.add(initialWindowSize)
 	sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf)
 	sc.hpackDecoder = hpack.NewDecoder(initialHeaderTableSize, sc.onNewHeaderField)
-	sc.hpackDecoder.SetMaxHeaderListSize(sc.maxHeaderListSize())
 
 	fr := NewFramer(sc.bw, c)
 	fr.SetMaxReadFrameSize(srv.maxReadFrameSize())
@@ -373,7 +372,7 @@ type serverConn struct {
 
 func (sc *serverConn) maxHeaderListSize() uint32 {
 	n := sc.hs.MaxHeaderBytes
-	if n == 0 {
+	if n <= 0 {
 		n = http.DefaultMaxHeaderBytes
 	}
 	// http2's count is in a slightly different unit and includes 32 bytes per pair.
@@ -393,8 +392,9 @@ type requestParam struct {
 	header            http.Header
 	method, path      string
 	scheme, authority string
-	sawRegularHeader  bool // saw a non-pseudo header already
-	invalidHeader     bool // an invalid header was seen
+	sawRegularHeader  bool  // saw a non-pseudo header already
+	invalidHeader     bool  // an invalid header was seen
+	headerListSize    int64 // actually uint32, but easier math this way
 }
 
 // stream represents a stream. This is the minimal metadata needed by
@@ -515,6 +515,11 @@ func (sc *serverConn) onNewHeaderField(f hpack.HeaderField) {
 	default:
 		sc.req.sawRegularHeader = true
 		sc.req.header.Add(sc.canonicalHeader(f.Name), f.Value)
+		const headerFieldOverhead = 32 // per spec
+		sc.req.headerListSize += int64(len(f.Name)) + int64(len(f.Value)) + headerFieldOverhead
+		if sc.req.headerListSize > int64(sc.maxHeaderListSize()) {
+			sc.hpackDecoder.SetEmitEnabled(false)
+		}
 	}
 }
 
@@ -1247,6 +1252,7 @@ func (sc *serverConn) processHeaders(f *HeadersFrame) error {
 		stream: st,
 		header: make(http.Header),
 	}
+	sc.hpackDecoder.SetEmitEnabled(true)
 	return sc.processHeaderBlockFragment(st, f.HeaderBlockFragment(), f.HeadersEnded())
 }
 
@@ -1298,7 +1304,14 @@ func (sc *serverConn) processHeaderBlockFragment(st *stream, frag []byte, end bo
 	}
 	st.body = req.Body.(*requestBody).pipe // may be nil
 	st.declBodyBytes = req.ContentLength
-	go sc.runHandler(rw, req)
+
+	handler := sc.handler.ServeHTTP
+	if !sc.hpackDecoder.EmitEnabled() {
+		// Their header list was too long. Send a 431 error.
+		handler = handleHeaderListTooLong
+	}
+
+	go sc.runHandler(rw, req, handler)
 	return nil
 }
 
@@ -1438,10 +1451,20 @@ func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, err
 }
 
 // Run on its own goroutine.
-func (sc *serverConn) runHandler(rw *responseWriter, req *http.Request) {
+func (sc *serverConn) runHandler(rw *responseWriter, req *http.Request, handler func(http.ResponseWriter, *http.Request)) {
 	defer rw.handlerDone()
 	// TODO: catch panics like net/http.Server
-	sc.handler.ServeHTTP(rw, req)
+	handler(rw, req)
+}
+
+func handleHeaderListTooLong(w http.ResponseWriter, r *http.Request) {
+	// 10.5.1 Limits on Header Block Size:
+	// .. "A server that receives a larger header block than it is
+	// willing to handle can send an HTTP 431 (Request Header Fields Too
+	// Large) status code"
+	const statusRequestHeaderFieldsTooLarge = 431 // only in Go 1.6+
+	w.WriteHeader(statusRequestHeaderFieldsTooLarge)
+	io.WriteString(w, "<h1>HTTP Error 431</h1><p>Request Header Field(s) Too Large</p>")
 }
 
 // called from handler goroutines.
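
The per-field accounting the server now does is plain arithmetic; here
is a standalone sketch (the helper name is illustrative, not code from
this diff) of the same math used in onNewHeaderField, with the 32-byte
per-field overhead from HTTP/2 spec section 6.5.2:

package main

import "fmt"

// headerListSize mirrors the accounting in onNewHeaderField above:
// each field costs len(name) + len(value) + 32 bytes of overhead.
// The function itself is hypothetical, for illustration only.
func headerListSize(fields [][2]string) int64 {
	const headerFieldOverhead = 32 // per spec
	var n int64
	for _, f := range fields {
		n += int64(len(f[0])) + int64(len(f[1])) + headerFieldOverhead
	}
	return n
}

func main() {
	fields := [][2]string{{"foo", "bar"}, {"user-agent", "x"}}
	size := headerListSize(fields)
	fmt.Println(size) // (3+3+32) + (10+1+32) = 81

	// The server disables emits once its running total exceeds
	// maxHeaderListSize(), and later serves a 431 instead of
	// running the user's handler.
	const limit = 64 // illustrative limit
	fmt.Println("over limit:", size > limit) // true
}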

http2/server_test.go (+12 -3)

@@ -2251,9 +2251,18 @@ func TestServerDoS_MaxHeaderListSize(t *testing.T) {
 		st.fr.WriteContinuation(1, len(b) == 0, chunk)
 	}
 
-	fr, err := st.fr.ReadFrame()
-	if err == nil {
-		t.Fatalf("want error; got unexpected frame: %#v", fr)
+	h := st.wantHeaders()
+	if !h.HeadersEnded() {
+		t.Fatalf("Got HEADERS without END_HEADERS set: %v", h)
+	}
+	headers := decodeHeader(t, h.HeaderBlockFragment())
+	want := [][2]string{
+		{":status", "431"},
+		{"content-type", "text/html; charset=utf-8"},
+		{"content-length", "63"},
+	}
+	if !reflect.DeepEqual(headers, want) {
+		t.Errorf("Headers mismatch.\n got: %q\nwant: %q\n", headers, want)
 	}
 }
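
For reference, the content-length of 63 the updated test expects matches
the body that handleHeaderListTooLong writes; a quick check:

package main

import "fmt"

func main() {
	// The body written by handleHeaderListTooLong in server.go above.
	const body = "<h1>HTTP Error 431</h1><p>Request Header Field(s) Too Large</p>"
	fmt.Println(len(body)) // 63, matching the expected content-length
}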