Browse Source

http2: move merging of HEADERS and CONTINUATION into Framer

HEADERS and CONTINUATION frames are special in that they must appear
contiguous on the wire and there are lots of annoying details to
verify while working through its state machine, including the handling
of hpack header list size limits and DoS vectors.

We now have three implementations of this merging (Server, Transport,
and grpc), and grpc's is not complete. The Transport's was also
partially incomplete.

Move this up to the Framer (opt-in, for compatibility) and remove the
support from the Server and Transport. I can fix grpc later to use
this.

Recommended reviewing order:

* hpack.go exports the HeaderField.Size method and adds an IsPseudo
  method.

* errors.go adds some new unexported error types, for testing.

* frame.go adds the new type MetaHeadersFrame.

* frame.go adds new fields on Framer for controlling how ReadFrame
  behaves

* frame.go Framer.ReadFrame now calls the new Framer.readMetaFrame
  method

* frame_test.go adds a bunch of tests. these are largely redundant
  with the existing tests which were in server and transport
  before. They really belong with frame_test.go, but I also don't want
  to delete tests in a CL like this. I probably won't remove them
  later either.

* server.go and transport.go can be reviewed in either order at this
  point. Both are the fun part of this change: deleting lots of hairy
  state machine code (which was redundant in at least 6 ways: server
  headers, server trailers, client headers, client trailers, grpc
  headers, grpc trailers...). Both server and transport.go have the
  general following form:

  - set Framer.ReadMetaHeaders
  - stop handling *HeadersFrame and *ContinuationFrame; handle
    *MetaHeadersFrame instead.
  - delete all the state machine + hpack parsing callback hell

The diffstat numbers look like a wash once you exclude the new tests,
but this pays for itself by far when you consider the grpc savings as
well, and the increased simplicity.

Change-Id: If348cf585165b528b7d3ab2e5f86b49a03fbb0d2
Reviewed-on: https://go-review.googlesource.com/19726
Reviewed-by: Blake Mizerany <blake.mizerany@gmail.com>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
Brad Fitzpatrick 10 years ago
parent
commit
9e1fb3c175
9 changed files with 723 additions and 453 deletions
  1. 33 1
      http2/errors.go
  2. 218 2
      http2/frame.go
  3. 242 2
      http2/frame_test.go
  4. 1 1
      http2/hpack/encode.go
  5. 12 3
      http2/hpack/hpack.go
  6. 60 219
      http2/server.go
  7. 24 4
      http2/server_test.go
  8. 127 215
      http2/transport.go
  9. 6 6
      http2/transport_test.go

+ 33 - 1
http2/errors.go

@@ -4,7 +4,10 @@
 
 package http2
 
-import "fmt"
+import (
+	"errors"
+	"fmt"
+)
 
 // An ErrCode is an unsigned 32-bit error code as defined in the HTTP/2 spec.
 type ErrCode uint32
@@ -88,3 +91,32 @@ type connError struct {
 func (e connError) Error() string {
 	return fmt.Sprintf("http2: connection error: %v: %v", e.Code, e.Reason)
 }
+
+type pseudoHeaderError string
+
+func (e pseudoHeaderError) Error() string {
+	return fmt.Sprintf("invalid pseudo-header %q", string(e))
+}
+
+type duplicatePseudoHeaderError string
+
+func (e duplicatePseudoHeaderError) Error() string {
+	return fmt.Sprintf("duplicate pseudo-header %q", string(e))
+}
+
+type headerFieldNameError string
+
+func (e headerFieldNameError) Error() string {
+	return fmt.Sprintf("invalid header field name %q", string(e))
+}
+
+type headerFieldValueError string
+
+func (e headerFieldValueError) Error() string {
+	return fmt.Sprintf("invalid header field value %q", string(e))
+}
+
+var (
+	errMixPseudoHeaderTypes = errors.New("mix of request and response pseudo headers")
+	errPseudoAfterRegular   = errors.New("pseudo header field after regular")
+)

+ 218 - 2
http2/frame.go

@@ -11,7 +11,10 @@ import (
 	"fmt"
 	"io"
 	"log"
+	"strings"
 	"sync"
+
+	"golang.org/x/net/http2/hpack"
 )
 
 const frameHeaderLen = 9
@@ -261,7 +264,7 @@ type Frame interface {
 type Framer struct {
 	r         io.Reader
 	lastFrame Frame
-	errReason string
+	errDetail error
 
 	// lastHeaderStream is non-zero if the last frame was an
 	// unfinished HEADERS/CONTINUATION.
@@ -293,8 +296,20 @@ type Framer struct {
 	// to return non-compliant frames or frame orders.
 	// This is for testing and permits using the Framer to test
 	// other HTTP/2 implementations' conformance to the spec.
+	// It is not compatible with ReadMetaHeaders.
 	AllowIllegalReads bool
 
+	// ReadMetaHeaders if non-nil causes ReadFrame to merge
+	// HEADERS and CONTINUATION frames together and return
+	// MetaHeadersFrame instead.
+	ReadMetaHeaders *hpack.Decoder
+
+	// MaxHeaderListSize is the http2 MAX_HEADER_LIST_SIZE.
+	// It's used only if ReadMetaHeaders is set; 0 means a sane default
+	// (currently 16MB)
+	// If the limit is hit, MetaHeadersFrame.Truncated is set true.
+	MaxHeaderListSize uint32
+
 	// TODO: track which type of frame & with which flags was sent
 	// last.  Then return an error (unless AllowIllegalWrites) if
 	// we're in the middle of a header block and a
@@ -307,6 +322,13 @@ type Framer struct {
 	debugFramerBuf *bytes.Buffer
 }
 
+func (fr *Framer) maxHeaderListSize() uint32 {
+	if fr.MaxHeaderListSize == 0 {
+		return 16 << 20 // sane default, per docs
+	}
+	return fr.MaxHeaderListSize
+}
+
 func (f *Framer) startWrite(ftype FrameType, flags Flags, streamID uint32) {
 	// Write the FrameHeader.
 	f.wbuf = append(f.wbuf[:0],
@@ -423,6 +445,7 @@ func terminalReadFrameError(err error) bool {
 // ConnectionError, StreamError, or anything else from from the underlying
 // reader.
 func (fr *Framer) ReadFrame() (Frame, error) {
+	fr.errDetail = nil
 	if fr.lastFrame != nil {
 		fr.lastFrame.invalidate()
 	}
@@ -450,6 +473,9 @@ func (fr *Framer) ReadFrame() (Frame, error) {
 	if fr.logReads {
 		log.Printf("http2: Framer %p: read %v", fr, summarizeFrame(f))
 	}
+	if fh.Type == FrameHeaders && fr.ReadMetaHeaders != nil {
+		return fr.readMetaFrame(f.(*HeadersFrame))
+	}
 	return f, nil
 }
 
@@ -458,7 +484,7 @@ func (fr *Framer) ReadFrame() (Frame, error) {
 // to the peer before hanging up on them. This might help others debug
 // their implementations.
 func (fr *Framer) connError(code ErrCode, reason string) error {
-	fr.errReason = reason
+	fr.errDetail = errors.New(reason)
 	return ConnectionError(code)
 }
 
@@ -1225,6 +1251,196 @@ type headersEnder interface {
 	HeadersEnded() bool
 }
 
+type headersOrContinuation interface {
+	headersEnder
+	HeaderBlockFragment() []byte
+}
+
+// A MetaHeadersFrame is the representation of one HEADERS frame and
+// zero or more contiguous CONTINUATION frames and the decoding of
+// their HPACK-encoded contents.
+//
+// This type of frame does not appear on the wire and is only returned
+// by the Framer when Framer.ReadMetaHeaders is set.
+type MetaHeadersFrame struct {
+	*HeadersFrame
+
+	// Fields are the fields contained in the HEADERS and
+	// CONTINUATION frames. The underlying slice is owned by the
+	// Framer and must not be retained after the next call to
+	// ReadFrame.
+	//
+	// Fields are guaranteed to be in the correct http2 order and
+	// not have unknown pseudo header fields or invalid header
+	// field names or values. Required pseudo header fields may be
+	// missing, however. Use the MetaHeadersFrame.Pseudo accessor
+	// method access pseudo headers.
+	Fields []hpack.HeaderField
+
+	// Truncated is whether the max header list size limit was hit
+	// and Fields is incomplete. The hpack decoder state is still
+	// valid, however.
+	Truncated bool
+}
+
+// PseudoValue returns the given pseudo header field's value.
+// The provided pseudo field should not contain the leading colon.
+func (mh *MetaHeadersFrame) PseudoValue(pseudo string) string {
+	for _, hf := range mh.Fields {
+		if !hf.IsPseudo() {
+			return ""
+		}
+		if hf.Name[1:] == pseudo {
+			return hf.Value
+		}
+	}
+	return ""
+}
+
+// RegularFields returns the regular (non-pseudo) header fields of mh.
+// The caller does not own the returned slice.
+func (mh *MetaHeadersFrame) RegularFields() []hpack.HeaderField {
+	for i, hf := range mh.Fields {
+		if !hf.IsPseudo() {
+			return mh.Fields[i:]
+		}
+	}
+	return nil
+}
+
+// PseudoFields returns the pseudo header fields of mh.
+// The caller does not own the returned slice.
+func (mh *MetaHeadersFrame) PseudoFields() []hpack.HeaderField {
+	for i, hf := range mh.Fields {
+		if !hf.IsPseudo() {
+			return mh.Fields[:i]
+		}
+	}
+	return mh.Fields
+}
+
+func (mh *MetaHeadersFrame) checkPseudos() error {
+	var isRequest, isResponse bool
+	pf := mh.PseudoFields()
+	for i, hf := range pf {
+		switch hf.Name {
+		case ":method", ":path", ":scheme", ":authority":
+			isRequest = true
+		case ":status":
+			isResponse = true
+		default:
+			return pseudoHeaderError(hf.Name)
+		}
+		// Check for duplicates.
+		// This would be a bad algorithm, but N is 4.
+		// And this doesn't allocate.
+		for _, hf2 := range pf[:i] {
+			if hf.Name == hf2.Name {
+				return duplicatePseudoHeaderError(hf.Name)
+			}
+		}
+	}
+	if isRequest && isResponse {
+		return errMixPseudoHeaderTypes
+	}
+	return nil
+}
+
+func (fr *Framer) maxHeaderStringLen() int {
+	v := fr.maxHeaderListSize()
+	if uint32(int(v)) == v {
+		return int(v)
+	}
+	// They had a crazy big number for MaxHeaderBytes anyway,
+	// so give them unlimited header lengths:
+	return 0
+}
+
+// readMetaFrame returns 0 or more CONTINUATION frames from fr and
+// merge them into into the provided hf and returns a MetaHeadersFrame
+// with the decoded hpack values.
+func (fr *Framer) readMetaFrame(hf *HeadersFrame) (*MetaHeadersFrame, error) {
+	if fr.AllowIllegalReads {
+		return nil, errors.New("illegal use of AllowIllegalReads with ReadMetaHeaders")
+	}
+	mh := &MetaHeadersFrame{
+		HeadersFrame: hf,
+	}
+	var remainSize = fr.maxHeaderListSize()
+	var sawRegular bool
+
+	var invalid error // pseudo header field errors
+	hdec := fr.ReadMetaHeaders
+	hdec.SetEmitEnabled(true)
+	hdec.SetMaxStringLength(fr.maxHeaderStringLen())
+	hdec.SetEmitFunc(func(hf hpack.HeaderField) {
+		if !validHeaderFieldValue(hf.Value) {
+			invalid = headerFieldValueError(hf.Value)
+		}
+		isPseudo := strings.HasPrefix(hf.Name, ":")
+		if isPseudo {
+			if sawRegular {
+				invalid = errPseudoAfterRegular
+			}
+		} else {
+			sawRegular = true
+			if !validHeaderFieldName(hf.Name) {
+				invalid = headerFieldNameError(hf.Name)
+			}
+		}
+
+		if invalid != nil {
+			hdec.SetEmitEnabled(false)
+			return
+		}
+
+		size := hf.Size()
+		if size > remainSize {
+			hdec.SetEmitEnabled(false)
+			mh.Truncated = true
+			return
+		}
+		remainSize -= size
+
+		mh.Fields = append(mh.Fields, hf)
+	})
+	// Lose reference to MetaHeadersFrame:
+	defer hdec.SetEmitFunc(func(hf hpack.HeaderField) {})
+
+	var hc headersOrContinuation = hf
+	for {
+		frag := hc.HeaderBlockFragment()
+		if _, err := hdec.Write(frag); err != nil {
+			return nil, ConnectionError(ErrCodeCompression)
+		}
+
+		if hc.HeadersEnded() {
+			break
+		}
+		if f, err := fr.ReadFrame(); err != nil {
+			return nil, err
+		} else {
+			hc = f.(*ContinuationFrame) // guaranteed by checkFrameOrder
+		}
+	}
+
+	mh.HeadersFrame.headerFragBuf = nil
+	mh.HeadersFrame.invalidate()
+
+	if err := hdec.Close(); err != nil {
+		return nil, ConnectionError(ErrCodeCompression)
+	}
+	if invalid != nil {
+		fr.errDetail = invalid
+		return nil, StreamError{mh.StreamID, ErrCodeProtocol}
+	}
+	if err := mh.checkPseudos(); err != nil {
+		fr.errDetail = err
+		return nil, StreamError{mh.StreamID, ErrCodeProtocol}
+	}
+	return mh, nil
+}
+
 func summarizeFrame(f Frame) string {
 	var buf bytes.Buffer
 	f.Header().writeDebug(&buf)

+ 242 - 2
http2/frame_test.go

@@ -12,6 +12,8 @@ import (
 	"strings"
 	"testing"
 	"unsafe"
+
+	"golang.org/x/net/http2/hpack"
 )
 
 func testFramer() (*Framer, *bytes.Buffer) {
@@ -725,11 +727,249 @@ func TestReadFrameOrder(t *testing.T) {
 			t.Errorf("%d. after %d good frames, ReadFrame = %v; want ConnectionError(ErrCodeProtocol)\n%s", i, n, err, log.Bytes())
 			continue
 		}
-		if f.errReason != tt.wantErr {
-			t.Errorf("%d. framer eror = %q; want %q\n%s", i, f.errReason, tt.wantErr, log.Bytes())
+		if !((f.errDetail == nil && tt.wantErr == "") || (fmt.Sprint(f.errDetail) == tt.wantErr)) {
+			t.Errorf("%d. framer eror = %q; want %q\n%s", i, f.errDetail, tt.wantErr, log.Bytes())
 		}
 		if n < tt.atLeast {
 			t.Errorf("%d. framer only read %d frames; want at least %d\n%s", i, n, tt.atLeast, log.Bytes())
 		}
 	}
 }
+
+func TestMetaFrameHeader(t *testing.T) {
+	write := func(f *Framer, frags ...[]byte) {
+		for i, frag := range frags {
+			end := (i == len(frags)-1)
+			if i == 0 {
+				f.WriteHeaders(HeadersFrameParam{
+					StreamID:      1,
+					BlockFragment: frag,
+					EndHeaders:    end,
+				})
+			} else {
+				f.WriteContinuation(1, end, frag)
+			}
+		}
+	}
+
+	want := func(flags Flags, length uint32, pairs ...string) *MetaHeadersFrame {
+		mh := &MetaHeadersFrame{
+			HeadersFrame: &HeadersFrame{
+				FrameHeader: FrameHeader{
+					Type:     FrameHeaders,
+					Flags:    flags,
+					Length:   length,
+					StreamID: 1,
+				},
+			},
+			Fields: []hpack.HeaderField(nil),
+		}
+		for len(pairs) > 0 {
+			mh.Fields = append(mh.Fields, hpack.HeaderField{
+				Name:  pairs[0],
+				Value: pairs[1],
+			})
+			pairs = pairs[2:]
+		}
+		return mh
+	}
+	truncated := func(mh *MetaHeadersFrame) *MetaHeadersFrame {
+		mh.Truncated = true
+		return mh
+	}
+
+	const noFlags Flags = 0
+
+	oneKBString := strings.Repeat("a", 1<<10)
+
+	tests := [...]struct {
+		name              string
+		w                 func(*Framer)
+		want              interface{} // *MetaHeaderFrame or error
+		wantErrReason     string
+		maxHeaderListSize uint32
+	}{
+		0: {
+			name: "single_headers",
+			w: func(f *Framer) {
+				var he hpackEncoder
+				all := he.encodeHeaderRaw(t, ":method", "GET", ":path", "/")
+				write(f, all)
+			},
+			want: want(FlagHeadersEndHeaders, 2, ":method", "GET", ":path", "/"),
+		},
+		1: {
+			name: "with_continuation",
+			w: func(f *Framer) {
+				var he hpackEncoder
+				all := he.encodeHeaderRaw(t, ":method", "GET", ":path", "/", "foo", "bar")
+				write(f, all[:1], all[1:])
+			},
+			want: want(noFlags, 1, ":method", "GET", ":path", "/", "foo", "bar"),
+		},
+		2: {
+			name: "with_two_continuation",
+			w: func(f *Framer) {
+				var he hpackEncoder
+				all := he.encodeHeaderRaw(t, ":method", "GET", ":path", "/", "foo", "bar")
+				write(f, all[:2], all[2:4], all[4:])
+			},
+			want: want(noFlags, 2, ":method", "GET", ":path", "/", "foo", "bar"),
+		},
+		3: {
+			name: "big_string_okay",
+			w: func(f *Framer) {
+				var he hpackEncoder
+				all := he.encodeHeaderRaw(t, ":method", "GET", ":path", "/", "foo", oneKBString)
+				write(f, all[:2], all[2:])
+			},
+			want: want(noFlags, 2, ":method", "GET", ":path", "/", "foo", oneKBString),
+		},
+		4: {
+			name: "big_string_error",
+			w: func(f *Framer) {
+				var he hpackEncoder
+				all := he.encodeHeaderRaw(t, ":method", "GET", ":path", "/", "foo", oneKBString)
+				write(f, all[:2], all[2:])
+			},
+			maxHeaderListSize: (1 << 10) / 2,
+			want:              ConnectionError(ErrCodeCompression),
+		},
+		5: {
+			name: "max_header_list_truncated",
+			w: func(f *Framer) {
+				var he hpackEncoder
+				var pairs = []string{":method", "GET", ":path", "/"}
+				for i := 0; i < 100; i++ {
+					pairs = append(pairs, "foo", "bar")
+				}
+				all := he.encodeHeaderRaw(t, pairs...)
+				write(f, all[:2], all[2:])
+			},
+			maxHeaderListSize: (1 << 10) / 2,
+			want: truncated(want(noFlags, 2,
+				":method", "GET",
+				":path", "/",
+				"foo", "bar",
+				"foo", "bar",
+				"foo", "bar",
+				"foo", "bar",
+				"foo", "bar",
+				"foo", "bar",
+				"foo", "bar",
+				"foo", "bar",
+				"foo", "bar",
+				"foo", "bar",
+				"foo", "bar", // 11
+			)),
+		},
+		6: {
+			name: "pseudo_order",
+			w: func(f *Framer) {
+				write(f, encodeHeaderRaw(t,
+					":method", "GET",
+					"foo", "bar",
+					":path", "/", // bogus
+				))
+			},
+			want:          StreamError{1, ErrCodeProtocol},
+			wantErrReason: "pseudo header field after regular",
+		},
+		7: {
+			name: "pseudo_unknown",
+			w: func(f *Framer) {
+				write(f, encodeHeaderRaw(t,
+					":unknown", "foo", // bogus
+					"foo", "bar",
+				))
+			},
+			want:          StreamError{1, ErrCodeProtocol},
+			wantErrReason: "invalid pseudo-header \":unknown\"",
+		},
+		8: {
+			name: "pseudo_mix_request_response",
+			w: func(f *Framer) {
+				write(f, encodeHeaderRaw(t,
+					":method", "GET",
+					":status", "100",
+				))
+			},
+			want:          StreamError{1, ErrCodeProtocol},
+			wantErrReason: "mix of request and response pseudo headers",
+		},
+		9: {
+			name: "pseudo_dup",
+			w: func(f *Framer) {
+				write(f, encodeHeaderRaw(t,
+					":method", "GET",
+					":method", "POST",
+				))
+			},
+			want:          StreamError{1, ErrCodeProtocol},
+			wantErrReason: "duplicate pseudo-header \":method\"",
+		},
+		10: {
+			name: "trailer_okay_no_pseudo",
+			w:    func(f *Framer) { write(f, encodeHeaderRaw(t, "foo", "bar")) },
+			want: want(FlagHeadersEndHeaders, 8, "foo", "bar"),
+		},
+		11: {
+			name:          "invalid_field_name",
+			w:             func(f *Framer) { write(f, encodeHeaderRaw(t, "CapitalBad", "x")) },
+			want:          StreamError{1, ErrCodeProtocol},
+			wantErrReason: "invalid header field name \"CapitalBad\"",
+		},
+		12: {
+			name:          "invalid_field_value",
+			w:             func(f *Framer) { write(f, encodeHeaderRaw(t, "key", "bad_null\x00")) },
+			want:          StreamError{1, ErrCodeProtocol},
+			wantErrReason: "invalid header field value \"bad_null\\x00\"",
+		},
+	}
+	for i, tt := range tests {
+		buf := new(bytes.Buffer)
+		f := NewFramer(buf, buf)
+		f.ReadMetaHeaders = hpack.NewDecoder(initialHeaderTableSize, nil)
+		f.MaxHeaderListSize = tt.maxHeaderListSize
+		tt.w(f)
+
+		name := tt.name
+		if name == "" {
+			name = fmt.Sprintf("test index %d", i)
+		}
+
+		var got interface{}
+		var err error
+		got, err = f.ReadFrame()
+		if err != nil {
+			got = err
+		}
+		if !reflect.DeepEqual(got, tt.want) {
+			if mhg, ok := got.(*MetaHeadersFrame); ok {
+				if mhw, ok := tt.want.(*MetaHeadersFrame); ok {
+					hg := mhg.HeadersFrame
+					hw := mhw.HeadersFrame
+					if hg != nil && hw != nil && !reflect.DeepEqual(*hg, *hw) {
+						t.Errorf("%s: headers differ:\n got: %+v\nwant: %+v\n", name, *hg, *hw)
+					}
+				}
+			}
+			str := func(v interface{}) string {
+				if _, ok := v.(error); ok {
+					return fmt.Sprintf("error %v", v)
+				} else {
+					return fmt.Sprintf("value %#v", v)
+				}
+			}
+			t.Errorf("%s:\n got: %v\nwant: %s", name, str(got), str(tt.want))
+		}
+		if tt.wantErrReason != "" && tt.wantErrReason != fmt.Sprint(f.errDetail) {
+			t.Errorf("%s: got error reason %q; want %q", name, f.errDetail, tt.wantErrReason)
+		}
+	}
+}
+
+func encodeHeaderRaw(t *testing.T, pairs ...string) []byte {
+	var he hpackEncoder
+	return he.encodeHeaderRaw(t, pairs...)
+}

+ 1 - 1
http2/hpack/encode.go

@@ -144,7 +144,7 @@ func (e *Encoder) SetMaxDynamicTableSizeLimit(v uint32) {
 
 // shouldIndex reports whether f should be indexed.
 func (e *Encoder) shouldIndex(f HeaderField) bool {
-	return !f.Sensitive && f.size() <= e.dynTab.maxSize
+	return !f.Sensitive && f.Size() <= e.dynTab.maxSize
 }
 
 // appendIndexed appends index i, as encoded in "Indexed Header Field"

+ 12 - 3
http2/hpack/hpack.go

@@ -41,6 +41,14 @@ type HeaderField struct {
 	Sensitive bool
 }
 
+// IsPseudo reports whether the header field is an http2 pseudo header.
+// That is, it reports whether it starts with a colon.
+// It is not otherwise guaranteed to be a valid psuedo header field,
+// though.
+func (hf HeaderField) IsPseudo() bool {
+	return len(hf.Name) != 0 && hf.Name[0] == ':'
+}
+
 func (hf HeaderField) String() string {
 	var suffix string
 	if hf.Sensitive {
@@ -49,7 +57,8 @@ func (hf HeaderField) String() string {
 	return fmt.Sprintf("header field %q = %q%s", hf.Name, hf.Value, suffix)
 }
 
-func (hf *HeaderField) size() uint32 {
+// Size returns the size of an entry per RFC 7540 section 5.2.
+func (hf HeaderField) Size() uint32 {
 	// http://http2.github.io/http2-spec/compression.html#rfc.section.4.1
 	// "The size of the dynamic table is the sum of the size of
 	// its entries.  The size of an entry is the sum of its name's
@@ -171,7 +180,7 @@ func (dt *dynamicTable) setMaxSize(v uint32) {
 
 func (dt *dynamicTable) add(f HeaderField) {
 	dt.ents = append(dt.ents, f)
-	dt.size += f.size()
+	dt.size += f.Size()
 	dt.evict()
 }
 
@@ -179,7 +188,7 @@ func (dt *dynamicTable) add(f HeaderField) {
 func (dt *dynamicTable) evict() {
 	base := dt.ents // keep base pointer of slice
 	for dt.size > dt.maxSize {
-		dt.size -= dt.ents[0].size()
+		dt.size -= dt.ents[0].Size()
 		dt.ents = dt.ents[1:]
 	}
 

+ 60 - 219
http2/server.go

@@ -276,10 +276,10 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
 	sc.flow.add(initialWindowSize)
 	sc.inflow.add(initialWindowSize)
 	sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf)
-	sc.hpackDecoder = hpack.NewDecoder(initialHeaderTableSize, nil)
-	sc.hpackDecoder.SetMaxStringLength(sc.maxHeaderStringLen())
 
 	fr := NewFramer(sc.bw, c)
+	fr.ReadMetaHeaders = hpack.NewDecoder(initialHeaderTableSize, nil)
+	fr.MaxHeaderListSize = sc.maxHeaderListSize()
 	fr.SetMaxReadFrameSize(s.maxReadFrameSize())
 	sc.framer = fr
 
@@ -375,7 +375,6 @@ type serverConn struct {
 	bw               *bufferedWriter // writing to conn
 	handler          http.Handler
 	framer           *Framer
-	hpackDecoder     *hpack.Decoder
 	doneServing      chan struct{}         // closed when serverConn.serve ends
 	readFrameCh      chan readFrameResult  // written by serverConn.readFrames
 	wantWriteFrameCh chan frameWriteMsg    // from handlers -> serve
@@ -402,7 +401,6 @@ type serverConn struct {
 	headerTableSize       uint32
 	peerMaxHeaderListSize uint32            // zero means unknown (default)
 	canonHeader           map[string]string // http2-lower-case -> Go-Canonical-Case
-	req                   requestParam      // non-zero while reading request headers
 	writingFrame          bool              // started write goroutine but haven't heard back on wroteFrameCh
 	needsFrameFlush       bool              // last frame write wasn't a flush
 	writeSched            writeScheduler
@@ -417,16 +415,6 @@ type serverConn struct {
 	hpackEncoder   *hpack.Encoder
 }
 
-func (sc *serverConn) maxHeaderStringLen() int {
-	v := sc.maxHeaderListSize()
-	if uint32(int(v)) == v {
-		return int(v)
-	}
-	// They had a crazy big number for MaxHeaderBytes anyway,
-	// so give them unlimited header lengths:
-	return 0
-}
-
 func (sc *serverConn) maxHeaderListSize() uint32 {
 	n := sc.hs.MaxHeaderBytes
 	if n <= 0 {
@@ -439,21 +427,6 @@ func (sc *serverConn) maxHeaderListSize() uint32 {
 	return uint32(n + typicalHeaders*perFieldOverhead)
 }
 
-// requestParam is the state of the next request, initialized over
-// potentially several frames HEADERS + zero or more CONTINUATION
-// frames.
-type requestParam struct {
-	// stream is non-nil if we're reading (HEADER or CONTINUATION)
-	// frames for a request (but not DATA).
-	stream            *stream
-	header            http.Header
-	method, path      string
-	scheme, authority string
-	sawRegularHeader  bool  // saw a non-pseudo header already
-	invalidHeader     bool  // an invalid header was seen
-	headerListSize    int64 // actually uint32, but easier math this way
-}
-
 // stream represents a stream. This is the minimal metadata needed by
 // the serve goroutine. Most of the actual stream state is owned by
 // the http.Handler's goroutine in the responseWriter. Because the
@@ -589,87 +562,6 @@ func (sc *serverConn) condlogf(err error, format string, args ...interface{}) {
 	}
 }
 
-func (sc *serverConn) onNewHeaderField(f hpack.HeaderField) {
-	sc.serveG.check()
-	if VerboseLogs {
-		sc.vlogf("http2: server decoded %v", f)
-	}
-	switch {
-	case !validHeaderFieldValue(f.Value): // f.Name checked _after_ pseudo check, since ':' is invalid
-		sc.req.invalidHeader = true
-	case strings.HasPrefix(f.Name, ":"):
-		if sc.req.sawRegularHeader {
-			sc.logf("pseudo-header after regular header")
-			sc.req.invalidHeader = true
-			return
-		}
-		var dst *string
-		switch f.Name {
-		case ":method":
-			dst = &sc.req.method
-		case ":path":
-			dst = &sc.req.path
-		case ":scheme":
-			dst = &sc.req.scheme
-		case ":authority":
-			dst = &sc.req.authority
-		default:
-			// 8.1.2.1 Pseudo-Header Fields
-			// "Endpoints MUST treat a request or response
-			// that contains undefined or invalid
-			// pseudo-header fields as malformed (Section
-			// 8.1.2.6)."
-			sc.logf("invalid pseudo-header %q", f.Name)
-			sc.req.invalidHeader = true
-			return
-		}
-		if *dst != "" {
-			sc.logf("duplicate pseudo-header %q sent", f.Name)
-			sc.req.invalidHeader = true
-			return
-		}
-		*dst = f.Value
-	case !validHeaderFieldName(f.Name):
-		sc.req.invalidHeader = true
-	default:
-		sc.req.sawRegularHeader = true
-		sc.req.header.Add(sc.canonicalHeader(f.Name), f.Value)
-		const headerFieldOverhead = 32 // per spec
-		sc.req.headerListSize += int64(len(f.Name)) + int64(len(f.Value)) + headerFieldOverhead
-		if sc.req.headerListSize > int64(sc.maxHeaderListSize()) {
-			sc.hpackDecoder.SetEmitEnabled(false)
-		}
-	}
-}
-
-func (st *stream) onNewTrailerField(f hpack.HeaderField) {
-	sc := st.sc
-	sc.serveG.check()
-	if VerboseLogs {
-		sc.vlogf("http2: server decoded trailer %v", f)
-	}
-	switch {
-	case strings.HasPrefix(f.Name, ":"):
-		sc.req.invalidHeader = true
-		return
-	case !validHeaderFieldName(f.Name) || !validHeaderFieldValue(f.Value):
-		sc.req.invalidHeader = true
-		return
-	default:
-		key := sc.canonicalHeader(f.Name)
-		if st.trailer != nil {
-			vv := append(st.trailer[key], f.Value)
-			st.trailer[key] = vv
-
-			// arbitrary; TODO: read spec about header list size limits wrt trailers
-			const tooBig = 1000
-			if len(vv) >= tooBig {
-				sc.hpackDecoder.SetEmitEnabled(false)
-			}
-		}
-	}
-}
-
 func (sc *serverConn) canonicalHeader(v string) string {
 	sc.serveG.check()
 	cv, ok := commonCanonHeader[v]
@@ -1183,10 +1075,8 @@ func (sc *serverConn) processFrame(f Frame) error {
 	switch f := f.(type) {
 	case *SettingsFrame:
 		return sc.processSettings(f)
-	case *HeadersFrame:
+	case *MetaHeadersFrame:
 		return sc.processHeaders(f)
-	case *ContinuationFrame:
-		return sc.processContinuation(f)
 	case *WindowUpdateFrame:
 		return sc.processWindowUpdate(f)
 	case *PingFrame:
@@ -1442,7 +1332,7 @@ func (st *stream) copyTrailersToHandlerRequest() {
 	}
 }
 
-func (sc *serverConn) processHeaders(f *HeadersFrame) error {
+func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error {
 	sc.serveG.check()
 	id := f.Header().StreamID
 	if sc.inGoAway {
@@ -1471,13 +1361,11 @@ func (sc *serverConn) processHeaders(f *HeadersFrame) error {
 	// endpoint has opened or reserved. [...]  An endpoint that
 	// receives an unexpected stream identifier MUST respond with
 	// a connection error (Section 5.4.1) of type PROTOCOL_ERROR.
-	if id <= sc.maxStreamID || sc.req.stream != nil {
+	if id <= sc.maxStreamID {
 		return ConnectionError(ErrCodeProtocol)
 	}
+	sc.maxStreamID = id
 
-	if id > sc.maxStreamID {
-		sc.maxStreamID = id
-	}
 	st = &stream{
 		sc:    sc,
 		id:    id,
@@ -1501,50 +1389,6 @@ func (sc *serverConn) processHeaders(f *HeadersFrame) error {
 	if sc.curOpenStreams == 1 {
 		sc.setConnState(http.StateActive)
 	}
-	sc.req = requestParam{
-		stream: st,
-		header: make(http.Header),
-	}
-	sc.hpackDecoder.SetEmitFunc(sc.onNewHeaderField)
-	sc.hpackDecoder.SetEmitEnabled(true)
-	return sc.processHeaderBlockFragment(st, f.HeaderBlockFragment(), f.HeadersEnded())
-}
-
-func (st *stream) processTrailerHeaders(f *HeadersFrame) error {
-	sc := st.sc
-	sc.serveG.check()
-	if st.gotTrailerHeader {
-		return ConnectionError(ErrCodeProtocol)
-	}
-	st.gotTrailerHeader = true
-	if !f.StreamEnded() {
-		return StreamError{st.id, ErrCodeProtocol}
-	}
-	sc.resetPendingRequest() // we use invalidHeader from it for trailers
-	return st.processTrailerHeaderBlockFragment(f.HeaderBlockFragment(), f.HeadersEnded())
-}
-
-func (sc *serverConn) processContinuation(f *ContinuationFrame) error {
-	sc.serveG.check()
-	st := sc.streams[f.Header().StreamID]
-	if st.gotTrailerHeader {
-		return st.processTrailerHeaderBlockFragment(f.HeaderBlockFragment(), f.HeadersEnded())
-	}
-	return sc.processHeaderBlockFragment(st, f.HeaderBlockFragment(), f.HeadersEnded())
-}
-
-func (sc *serverConn) processHeaderBlockFragment(st *stream, frag []byte, end bool) error {
-	sc.serveG.check()
-	if _, err := sc.hpackDecoder.Write(frag); err != nil {
-		return ConnectionError(ErrCodeCompression)
-	}
-	if !end {
-		return nil
-	}
-	if err := sc.hpackDecoder.Close(); err != nil {
-		return ConnectionError(ErrCodeCompression)
-	}
-	defer sc.resetPendingRequest()
 	if sc.curOpenStreams > sc.advMaxStreams {
 		// "Endpoints MUST NOT exceed the limit set by their
 		// peer. An endpoint that receives a HEADERS frame
@@ -1564,7 +1408,7 @@ func (sc *serverConn) processHeaderBlockFragment(st *stream, frag []byte, end bo
 		return StreamError{st.id, ErrCodeRefusedStream}
 	}
 
-	rw, req, err := sc.newWriterAndRequest()
+	rw, req, err := sc.newWriterAndRequest(st, f)
 	if err != nil {
 		return err
 	}
@@ -1576,7 +1420,7 @@ func (sc *serverConn) processHeaderBlockFragment(st *stream, frag []byte, end bo
 	st.declBodyBytes = req.ContentLength
 
 	handler := sc.handler.ServeHTTP
-	if !sc.hpackDecoder.EmitEnabled() {
+	if f.Truncated {
 		// Their header list was too long. Send a 431 error.
 		handler = handleHeaderListTooLong
 	}
@@ -1585,27 +1429,27 @@ func (sc *serverConn) processHeaderBlockFragment(st *stream, frag []byte, end bo
 	return nil
 }
 
-func (st *stream) processTrailerHeaderBlockFragment(frag []byte, end bool) error {
+func (st *stream) processTrailerHeaders(f *MetaHeadersFrame) error {
 	sc := st.sc
 	sc.serveG.check()
-	sc.hpackDecoder.SetEmitFunc(st.onNewTrailerField)
-	if _, err := sc.hpackDecoder.Write(frag); err != nil {
-		return ConnectionError(ErrCodeCompression)
+	if st.gotTrailerHeader {
+		return ConnectionError(ErrCodeProtocol)
 	}
-	if !end {
-		return nil
+	st.gotTrailerHeader = true
+	if !f.StreamEnded() {
+		return StreamError{st.id, ErrCodeProtocol}
 	}
 
-	rp := &sc.req
-	if rp.invalidHeader {
-		return StreamError{rp.stream.id, ErrCodeProtocol}
+	if len(f.PseudoFields()) > 0 {
+		return StreamError{st.id, ErrCodeProtocol}
 	}
-
-	err := sc.hpackDecoder.Close()
-	st.endStream()
-	if err != nil {
-		return ConnectionError(ErrCodeCompression)
+	if st.trailer != nil {
+		for _, hf := range f.RegularFields() {
+			key := sc.canonicalHeader(hf.Name)
+			st.trailer[key] = append(st.trailer[key], hf.Value)
+		}
 	}
+	st.endStream()
 	return nil
 }
 
@@ -1650,29 +1494,21 @@ func adjustStreamPriority(streams map[uint32]*stream, streamID uint32, priority
 	}
 }
 
-// resetPendingRequest zeros out all state related to a HEADERS frame
-// and its zero or more CONTINUATION frames sent to start a new
-// request.
-func (sc *serverConn) resetPendingRequest() {
-	sc.serveG.check()
-	sc.req = requestParam{}
-}
-
-func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, error) {
+func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*responseWriter, *http.Request, error) {
 	sc.serveG.check()
-	rp := &sc.req
 
-	if rp.invalidHeader {
-		return nil, nil, StreamError{rp.stream.id, ErrCodeProtocol}
-	}
+	method := f.PseudoValue("method")
+	path := f.PseudoValue("path")
+	scheme := f.PseudoValue("scheme")
+	authority := f.PseudoValue("authority")
 
-	isConnect := rp.method == "CONNECT"
+	isConnect := method == "CONNECT"
 	if isConnect {
-		if rp.path != "" || rp.scheme != "" || rp.authority == "" {
-			return nil, nil, StreamError{rp.stream.id, ErrCodeProtocol}
+		if path != "" || scheme != "" || authority == "" {
+			return nil, nil, StreamError{f.StreamID, ErrCodeProtocol}
 		}
-	} else if rp.method == "" || rp.path == "" ||
-		(rp.scheme != "https" && rp.scheme != "http") {
+	} else if method == "" || path == "" ||
+		(scheme != "https" && scheme != "http") {
 		// See 8.1.2.6 Malformed Requests and Responses:
 		//
 		// Malformed requests or responses that are detected
@@ -1683,35 +1519,40 @@ func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, err
 		// "All HTTP/2 requests MUST include exactly one valid
 		// value for the :method, :scheme, and :path
 		// pseudo-header fields"
-		return nil, nil, StreamError{rp.stream.id, ErrCodeProtocol}
+		return nil, nil, StreamError{f.StreamID, ErrCodeProtocol}
 	}
 
-	bodyOpen := rp.stream.state == stateOpen
-	if rp.method == "HEAD" && bodyOpen {
+	bodyOpen := !f.StreamEnded()
+	if method == "HEAD" && bodyOpen {
 		// HEAD requests can't have bodies
-		return nil, nil, StreamError{rp.stream.id, ErrCodeProtocol}
+		return nil, nil, StreamError{f.StreamID, ErrCodeProtocol}
 	}
 	var tlsState *tls.ConnectionState // nil if not scheme https
 
-	if rp.scheme == "https" {
+	if scheme == "https" {
 		tlsState = sc.tlsState
 	}
-	authority := rp.authority
+
+	header := make(http.Header)
+	for _, hf := range f.RegularFields() {
+		header.Add(sc.canonicalHeader(hf.Name), hf.Value)
+	}
+
 	if authority == "" {
-		authority = rp.header.Get("Host")
+		authority = header.Get("Host")
 	}
-	needsContinue := rp.header.Get("Expect") == "100-continue"
+	needsContinue := header.Get("Expect") == "100-continue"
 	if needsContinue {
-		rp.header.Del("Expect")
+		header.Del("Expect")
 	}
 	// Merge Cookie headers into one "; "-delimited value.
-	if cookies := rp.header["Cookie"]; len(cookies) > 1 {
-		rp.header.Set("Cookie", strings.Join(cookies, "; "))
+	if cookies := header["Cookie"]; len(cookies) > 1 {
+		header.Set("Cookie", strings.Join(cookies, "; "))
 	}
 
 	// Setup Trailers
 	var trailer http.Header
-	for _, v := range rp.header["Trailer"] {
+	for _, v := range header["Trailer"] {
 		for _, key := range strings.Split(v, ",") {
 			key = http.CanonicalHeaderKey(strings.TrimSpace(key))
 			switch key {
@@ -1726,31 +1567,31 @@ func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, err
 			}
 		}
 	}
-	delete(rp.header, "Trailer")
+	delete(header, "Trailer")
 
 	body := &requestBody{
 		conn:          sc,
-		stream:        rp.stream,
+		stream:        st,
 		needsContinue: needsContinue,
 	}
 	var url_ *url.URL
 	var requestURI string
 	if isConnect {
-		url_ = &url.URL{Host: rp.authority}
-		requestURI = rp.authority // mimic HTTP/1 server behavior
+		url_ = &url.URL{Host: authority}
+		requestURI = authority // mimic HTTP/1 server behavior
 	} else {
 		var err error
-		url_, err = url.ParseRequestURI(rp.path)
+		url_, err = url.ParseRequestURI(path)
 		if err != nil {
-			return nil, nil, StreamError{rp.stream.id, ErrCodeProtocol}
+			return nil, nil, StreamError{f.StreamID, ErrCodeProtocol}
 		}
-		requestURI = rp.path
+		requestURI = path
 	}
 	req := &http.Request{
-		Method:     rp.method,
+		Method:     method,
 		URL:        url_,
 		RemoteAddr: sc.remoteAddrStr,
-		Header:     rp.header,
+		Header:     header,
 		RequestURI: requestURI,
 		Proto:      "HTTP/2.0",
 		ProtoMajor: 2,
@@ -1765,7 +1606,7 @@ func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, err
 			b: &fixedBuffer{buf: make([]byte, initialWindowSize)}, // TODO: garbage
 		}
 
-		if vv, ok := rp.header["Content-Length"]; ok {
+		if vv, ok := header["Content-Length"]; ok {
 			req.ContentLength, _ = strconv.ParseInt(vv[0], 10, 64)
 		} else {
 			req.ContentLength = -1
@@ -1778,7 +1619,7 @@ func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, err
 	rws.conn = sc
 	rws.bw = bwSave
 	rws.bw.Reset(chunkWriter{rws})
-	rws.stream = rp.stream
+	rws.stream = st
 	rws.req = req
 	rws.body = body
 

+ 24 - 4
http2/server_test.go

@@ -2515,7 +2515,7 @@ func TestCompressionErrorOnWrite(t *testing.T) {
 	defer st.Close()
 	st.greet()
 
-	maxAllowed := st.sc.maxHeaderStringLen()
+	maxAllowed := st.sc.framer.maxHeaderStringLen()
 
 	// Crank this up, now that we have a conn connected with the
 	// hpack.Decoder's max string length set has been initialized
@@ -2524,8 +2524,12 @@ func TestCompressionErrorOnWrite(t *testing.T) {
 	// the max string size.
 	serverConfig.MaxHeaderBytes = 1 << 20
 
-	// First a request with a header that's exactly the max allowed size.
+	// First a request with a header that's exactly the max allowed size
+	// for the hpack compression. It's still too long for the header list
+	// size, so we'll get the 431 error, but that keeps the compression
+	// context still valid.
 	hbf := st.encodeHeader("foo", strings.Repeat("a", maxAllowed))
+
 	st.writeHeaders(HeadersFrameParam{
 		StreamID:      1,
 		BlockFragment: hbf,
@@ -2533,8 +2537,24 @@ func TestCompressionErrorOnWrite(t *testing.T) {
 		EndHeaders:    true,
 	})
 	h := st.wantHeaders()
-	if !h.HeadersEnded() || !h.StreamEnded() {
-		t.Errorf("Unexpected HEADER frame %v", h)
+	if !h.HeadersEnded() {
+		t.Fatalf("Got HEADERS without END_HEADERS set: %v", h)
+	}
+	headers := st.decodeHeader(h.HeaderBlockFragment())
+	want := [][2]string{
+		{":status", "431"},
+		{"content-type", "text/html; charset=utf-8"},
+		{"content-length", "63"},
+	}
+	if !reflect.DeepEqual(headers, want) {
+		t.Errorf("Headers mismatch.\n got: %q\nwant: %q\n", headers, want)
+	}
+	df := st.wantData()
+	if !strings.Contains(string(df.Data()), "HTTP Error 431") {
+		t.Errorf("Unexpected data body: %q", df.Data())
+	}
+	if !df.StreamEnded() {
+		t.Fatalf("expect data stream end")
 	}
 
 	// And now send one that's just one byte too big.

+ 127 - 215
http2/transport.go

@@ -187,8 +187,8 @@ type clientStream struct {
 	done chan struct{} // closed when stream remove from cc.streams map; close calls guarded by cc.mu
 
 	// owned by clientConnReadLoop:
-	pastHeaders  bool // got HEADERS w/ END_HEADERS
-	pastTrailers bool // got second HEADERS frame w/ END_HEADERS
+	pastHeaders  bool // got first MetaHeadersFrame (actual headers)
+	pastTrailers bool // got optional second MetaHeadersFrame (trailers)
 
 	trailer    http.Header  // accumulated trailers
 	resTrailer *http.Header // client's Response.Trailer
@@ -401,6 +401,8 @@ func (t *Transport) NewClientConn(c net.Conn) (*ClientConn, error) {
 	cc.bw = bufio.NewWriter(stickyErrWriter{c, &cc.werr})
 	cc.br = bufio.NewReader(c)
 	cc.fr = NewFramer(cc.bw, cc.br)
+	cc.fr.ReadMetaHeaders = hpack.NewDecoder(initialHeaderTableSize, nil)
+	cc.fr.MaxHeaderListSize = t.maxHeaderListSize()
 
 	// TODO: SetMaxDynamicTableSize, SetMaxDynamicTableSizeLimit on
 	// henc in response to SETTINGS frames?
@@ -1064,15 +1066,6 @@ type clientConnReadLoop struct {
 	cc            *ClientConn
 	activeRes     map[uint32]*clientStream // keyed by streamID
 	closeWhenIdle bool
-
-	hdec *hpack.Decoder
-
-	// Fields reset on each HEADERS:
-	nextRes              *http.Response
-	sawRegHeader         bool  // saw non-pseudo header
-	reqMalformed         error // non-nil once known to be malformed
-	lastHeaderEndsStream bool
-	headerListSize       int64 // actually uint32, but easier math this way
 }
 
 // readLoop runs in its own goroutine and reads and dispatches frames.
@@ -1081,7 +1074,6 @@ func (cc *ClientConn) readLoop() {
 		cc:        cc,
 		activeRes: make(map[uint32]*clientStream),
 	}
-	rl.hdec = hpack.NewDecoder(initialHeaderTableSize, rl.onNewHeaderField)
 
 	defer rl.cleanup()
 	cc.readerErr = rl.run()
@@ -1131,8 +1123,10 @@ func (rl *clientConnReadLoop) run() error {
 			cc.vlogf("Transport readFrame error: (%T) %v", err, err)
 		}
 		if se, ok := err.(StreamError); ok {
-			// TODO: deal with stream errors from the framer.
-			return se
+			if cs := cc.streamByID(se.StreamID, true /*ended; remove it*/); cs != nil {
+				rl.endStreamError(cs, cc.fr.errDetail)
+			}
+			continue
 		} else if err != nil {
 			return err
 		}
@@ -1142,13 +1136,10 @@ func (rl *clientConnReadLoop) run() error {
 		maybeIdle := false // whether frame might transition us to idle
 
 		switch f := f.(type) {
-		case *HeadersFrame:
+		case *MetaHeadersFrame:
 			err = rl.processHeaders(f)
 			maybeIdle = true
 			gotReply = true
-		case *ContinuationFrame:
-			err = rl.processContinuation(f)
-			maybeIdle = true
 		case *DataFrame:
 			err = rl.processData(f)
 			maybeIdle = true
@@ -1178,91 +1169,96 @@ func (rl *clientConnReadLoop) run() error {
 	}
 }
 
-func (rl *clientConnReadLoop) processHeaders(f *HeadersFrame) error {
-	rl.sawRegHeader = false
-	rl.reqMalformed = nil
-	rl.lastHeaderEndsStream = f.StreamEnded()
-	rl.headerListSize = 0
-	rl.nextRes = &http.Response{
-		Proto:      "HTTP/2.0",
-		ProtoMajor: 2,
-		Header:     make(http.Header),
-	}
-	rl.hdec.SetEmitEnabled(true)
-	return rl.processHeaderBlockFragment(f.HeaderBlockFragment(), f.StreamID, f.HeadersEnded())
-}
-
-func (rl *clientConnReadLoop) processContinuation(f *ContinuationFrame) error {
-	return rl.processHeaderBlockFragment(f.HeaderBlockFragment(), f.StreamID, f.HeadersEnded())
-}
-
-func (rl *clientConnReadLoop) processHeaderBlockFragment(frag []byte, streamID uint32, finalFrag bool) error {
+func (rl *clientConnReadLoop) processHeaders(f *MetaHeadersFrame) error {
 	cc := rl.cc
-	streamEnded := rl.lastHeaderEndsStream
-	cs := cc.streamByID(streamID, streamEnded && finalFrag)
+	cs := cc.streamByID(f.StreamID, f.StreamEnded())
 	if cs == nil {
 		// We'd get here if we canceled a request while the
-		// server was mid-way through replying with its
-		// headers. (The case of a CONTINUATION arriving
-		// without HEADERS would be rejected earlier by the
-		// Framer). So if this was just something we canceled,
-		// ignore it.
+		// server had its response still in flight. So if this
+		// was just something we canceled, ignore it.
 		return nil
 	}
-	if cs.pastHeaders {
-		rl.hdec.SetEmitFunc(func(f hpack.HeaderField) { rl.onNewTrailerField(cs, f) })
+	if !cs.pastHeaders {
+		cs.pastHeaders = true
 	} else {
-		rl.hdec.SetEmitFunc(rl.onNewHeaderField)
+		return rl.processTrailers(cs, f)
 	}
-	_, err := rl.hdec.Write(frag)
+
+	res, err := rl.handleResponse(cs, f)
 	if err != nil {
-		return ConnectionError(ErrCodeCompression)
-	}
-	if finalFrag {
-		if err := rl.hdec.Close(); err != nil {
-			return ConnectionError(ErrCodeCompression)
+		if _, ok := err.(ConnectionError); ok {
+			return err
 		}
+		// Any other error type is a stream error.
+		cs.cc.writeStreamReset(f.StreamID, ErrCodeProtocol, err)
+		cs.resc <- resAndError{err: err}
+		return nil // return nil from process* funcs to keep conn alive
 	}
-
-	if !finalFrag {
+	if res == nil {
+		// (nil, nil) special case. See handleResponse docs.
 		return nil
 	}
-
-	if !cs.pastHeaders {
-		cs.pastHeaders = true
-	} else {
-		// We're dealing with trailers. (and specifically the
-		// final frame of headers)
-		if cs.pastTrailers {
-			// Too many HEADERS frames for this stream.
-			return ConnectionError(ErrCodeProtocol)
-		}
-		cs.pastTrailers = true
-		if !streamEnded {
-			// We expect that any header block fragment
-			// frame for trailers with END_HEADERS also
-			// has END_STREAM.
-			return ConnectionError(ErrCodeProtocol)
-		}
-		rl.endStream(cs)
-		return nil
+	if res.Body != noBody {
+		rl.activeRes[cs.ID] = cs
 	}
+	cs.resTrailer = &res.Trailer
+	cs.resc <- resAndError{res: res}
+	return nil
+}
 
-	if rl.reqMalformed != nil {
-		cs.resc <- resAndError{err: rl.reqMalformed}
-		rl.cc.writeStreamReset(cs.ID, ErrCodeProtocol, rl.reqMalformed)
-		return nil
+// may return error types nil, or ConnectionError. Any other error value
+// is a StreamError of type ErrCodeProtocol. The returned error in that case
+// is the detail.
+//
+// As a special case, handleResponse may return (nil, nil) to skip the
+// frame (currently only used for 100 expect continue). This special
+// case is going away after Issue 13851 is fixed.
+func (rl *clientConnReadLoop) handleResponse(cs *clientStream, f *MetaHeadersFrame) (*http.Response, error) {
+	if f.Truncated {
+		return nil, errResponseHeaderListSize
 	}
 
-	res := rl.nextRes
+	status := f.PseudoValue("status")
+	if status == "" {
+		return nil, errors.New("missing status pseudo header")
+	}
+	statusCode, err := strconv.Atoi(status)
+	if err != nil {
+		return nil, errors.New("malformed non-numeric status pseudo header")
+	}
 
-	if res.StatusCode == 100 {
+	if statusCode == 100 {
 		// Just skip 100-continue response headers for now.
 		// TODO: golang.org/issue/13851 for doing it properly.
 		cs.pastHeaders = false // do it all again
-		return nil
+		return nil, nil
 	}
 
+	header := make(http.Header)
+	res := &http.Response{
+		Proto:      "HTTP/2.0",
+		ProtoMajor: 2,
+		Header:     header,
+		StatusCode: statusCode,
+		Status:     status + " " + http.StatusText(statusCode),
+	}
+	for _, hf := range f.RegularFields() {
+		key := http.CanonicalHeaderKey(hf.Name)
+		if key == "Trailer" {
+			t := res.Trailer
+			if t == nil {
+				t = make(http.Header)
+				res.Trailer = t
+			}
+			foreachHeaderElement(hf.Value, func(v string) {
+				t[http.CanonicalHeaderKey(v)] = nil
+			})
+		} else {
+			header[key] = append(header[key], hf.Value)
+		}
+	}
+
+	streamEnded := f.StreamEnded()
 	if !streamEnded || cs.req.Method == "HEAD" {
 		res.ContentLength = -1
 		if clens := res.Header["Content-Length"]; len(clens) == 1 {
@@ -1280,25 +1276,49 @@ func (rl *clientConnReadLoop) processHeaderBlockFragment(frag []byte, streamID u
 
 	if streamEnded {
 		res.Body = noBody
-	} else {
-		buf := new(bytes.Buffer) // TODO(bradfitz): recycle this garbage
-		cs.bufPipe = pipe{b: buf}
-		cs.bytesRemain = res.ContentLength
-		res.Body = transportResponseBody{cs}
-		go cs.awaitRequestCancel(requestCancel(cs.req))
-
-		if cs.requestedGzip && res.Header.Get("Content-Encoding") == "gzip" {
-			res.Header.Del("Content-Encoding")
-			res.Header.Del("Content-Length")
-			res.ContentLength = -1
-			res.Body = &gzipReader{body: res.Body}
-		}
-		rl.activeRes[cs.ID] = cs
+		return res, nil
 	}
 
-	cs.resTrailer = &res.Trailer
-	cs.resc <- resAndError{res: res}
-	rl.nextRes = nil // unused now; will be reset next HEADERS frame
+	buf := new(bytes.Buffer) // TODO(bradfitz): recycle this garbage
+	cs.bufPipe = pipe{b: buf}
+	cs.bytesRemain = res.ContentLength
+	res.Body = transportResponseBody{cs}
+	go cs.awaitRequestCancel(requestCancel(cs.req))
+
+	if cs.requestedGzip && res.Header.Get("Content-Encoding") == "gzip" {
+		res.Header.Del("Content-Encoding")
+		res.Header.Del("Content-Length")
+		res.ContentLength = -1
+		res.Body = &gzipReader{body: res.Body}
+	}
+	return res, nil
+}
+
+func (rl *clientConnReadLoop) processTrailers(cs *clientStream, f *MetaHeadersFrame) error {
+	if cs.pastTrailers {
+		// Too many HEADERS frames for this stream.
+		return ConnectionError(ErrCodeProtocol)
+	}
+	cs.pastTrailers = true
+	if !f.StreamEnded() {
+		// We expect that any headers for trailers also
+		// has END_STREAM.
+		return ConnectionError(ErrCodeProtocol)
+	}
+	if len(f.PseudoFields()) > 0 {
+		// No pseudo header fields are defined for trailers.
+		// TODO: ConnectionError might be overly harsh? Check.
+		return ConnectionError(ErrCodeProtocol)
+	}
+
+	trailer := make(http.Header)
+	for _, hf := range f.RegularFields() {
+		key := http.CanonicalHeaderKey(hf.Name)
+		trailer[key] = append(trailer[key], hf.Value)
+	}
+	cs.trailer = trailer
+
+	rl.endStream(cs)
 	return nil
 }
 
@@ -1416,6 +1436,7 @@ func (rl *clientConnReadLoop) processData(f *DataFrame) error {
 		cc.mu.Unlock()
 
 		if _, err := cs.bufPipe.Write(data); err != nil {
+			rl.endStreamError(cs, err)
 			return err
 		}
 	}
@@ -1431,11 +1452,14 @@ var errInvalidTrailers = errors.New("http2: invalid trailers")
 func (rl *clientConnReadLoop) endStream(cs *clientStream) {
 	// TODO: check that any declared content-length matches, like
 	// server.go's (*stream).endStream method.
-	err := io.EOF
-	code := cs.copyTrailers
-	if rl.reqMalformed != nil {
-		err = rl.reqMalformed
-		code = nil
+	rl.endStreamError(cs, nil)
+}
+
+func (rl *clientConnReadLoop) endStreamError(cs *clientStream, err error) {
+	var code func()
+	if err == nil {
+		err = io.EOF
+		code = cs.copyTrailers
 	}
 	cs.bufPipe.closeWithErrorAndCode(err, code)
 	delete(rl.activeRes, cs.ID)
@@ -1574,118 +1598,6 @@ var (
 	errPseudoTrailers         = errors.New("http2: invalid pseudo header in trailers")
 )
 
-func (rl *clientConnReadLoop) checkHeaderField(f hpack.HeaderField) bool {
-	if rl.reqMalformed != nil {
-		return false
-	}
-
-	const headerFieldOverhead = 32 // per spec
-	rl.headerListSize += int64(len(f.Name)) + int64(len(f.Value)) + headerFieldOverhead
-	if max := rl.cc.t.maxHeaderListSize(); max != 0 && rl.headerListSize > int64(max) {
-		rl.hdec.SetEmitEnabled(false)
-		rl.reqMalformed = errResponseHeaderListSize
-		return false
-	}
-
-	if !validHeaderFieldValue(f.Value) {
-		rl.reqMalformed = errInvalidHeaderFieldValue
-		return false
-	}
-
-	isPseudo := strings.HasPrefix(f.Name, ":")
-	if isPseudo {
-		if rl.sawRegHeader {
-			rl.reqMalformed = errors.New("http2: invalid pseudo header after regular header")
-			return false
-		}
-	} else {
-		if !validHeaderFieldName(f.Name) {
-			rl.reqMalformed = errInvalidHeaderFieldName
-			return false
-		}
-		rl.sawRegHeader = true
-	}
-
-	return true
-}
-
-// onNewHeaderField runs on the readLoop goroutine whenever a new
-// hpack header field is decoded.
-func (rl *clientConnReadLoop) onNewHeaderField(f hpack.HeaderField) {
-	cc := rl.cc
-	if VerboseLogs {
-		cc.logf("http2: Transport decoded %v", f)
-	}
-
-	if !rl.checkHeaderField(f) {
-		return
-	}
-
-	isPseudo := strings.HasPrefix(f.Name, ":")
-	if isPseudo {
-		switch f.Name {
-		case ":status":
-			code, err := strconv.Atoi(f.Value)
-			if err != nil {
-				rl.reqMalformed = errors.New("http2: invalid :status")
-				return
-			}
-			rl.nextRes.Status = f.Value + " " + http.StatusText(code)
-			rl.nextRes.StatusCode = code
-		default:
-			// "Endpoints MUST NOT generate pseudo-header
-			// fields other than those defined in this
-			// document."
-			rl.reqMalformed = fmt.Errorf("http2: unknown response pseudo header %q", f.Name)
-		}
-		return
-	}
-
-	key := http.CanonicalHeaderKey(f.Name)
-	if key == "Trailer" {
-		t := rl.nextRes.Trailer
-		if t == nil {
-			t = make(http.Header)
-			rl.nextRes.Trailer = t
-		}
-		foreachHeaderElement(f.Value, func(v string) {
-			t[http.CanonicalHeaderKey(v)] = nil
-		})
-	} else {
-		rl.nextRes.Header.Add(key, f.Value)
-	}
-}
-
-func (rl *clientConnReadLoop) onNewTrailerField(cs *clientStream, f hpack.HeaderField) {
-	if VerboseLogs {
-		rl.cc.logf("http2: Transport decoded trailer %v", f)
-	}
-	if !rl.checkHeaderField(f) {
-		return
-	}
-	if strings.HasPrefix(f.Name, ":") {
-		// Pseudo-header fields MUST NOT appear in
-		// trailers. Endpoints MUST treat a request or
-		// response that contains undefined or invalid
-		// pseudo-header fields as malformed.
-		rl.reqMalformed = errPseudoTrailers
-		return
-	}
-
-	key := http.CanonicalHeaderKey(f.Name)
-
-	// The spec says one must predeclare their trailers but in practice
-	// popular users (which is to say the only user we found) do not so we
-	// violate the spec and accept all of them.
-	const acceptAllTrailers = true
-	if _, ok := (*cs.resTrailer)[key]; ok || acceptAllTrailers {
-		if cs.trailer == nil {
-			cs.trailer = make(http.Header)
-		}
-		cs.trailer[key] = append(cs.trailer[key], f.Value)
-	}
-}
-
 func (cc *ClientConn) logf(format string, args ...interface{}) {
 	cc.t.logf(format, args...)
 }

+ 6 - 6
http2/transport_test.go

@@ -1104,7 +1104,7 @@ func TestTransportInvalidTrailer_Pseudo2(t *testing.T) {
 	testTransportInvalidTrailer_Pseudo(t, splitHeader)
 }
 func testTransportInvalidTrailer_Pseudo(t *testing.T, trailers headerType) {
-	testInvalidTrailer(t, trailers, errPseudoTrailers, func(enc *hpack.Encoder) {
+	testInvalidTrailer(t, trailers, pseudoHeaderError(":colon"), func(enc *hpack.Encoder) {
 		enc.WriteField(hpack.HeaderField{Name: ":colon", Value: "foo"})
 		enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
 	})
@@ -1117,19 +1117,19 @@ func TestTransportInvalidTrailer_Capital2(t *testing.T) {
 	testTransportInvalidTrailer_Capital(t, splitHeader)
 }
 func testTransportInvalidTrailer_Capital(t *testing.T, trailers headerType) {
-	testInvalidTrailer(t, trailers, errInvalidHeaderFieldName, func(enc *hpack.Encoder) {
+	testInvalidTrailer(t, trailers, headerFieldNameError("Capital"), func(enc *hpack.Encoder) {
 		enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
 		enc.WriteField(hpack.HeaderField{Name: "Capital", Value: "bad"})
 	})
 }
 func TestTransportInvalidTrailer_EmptyFieldName(t *testing.T) {
-	testInvalidTrailer(t, oneHeader, errInvalidHeaderFieldName, func(enc *hpack.Encoder) {
+	testInvalidTrailer(t, oneHeader, headerFieldNameError(""), func(enc *hpack.Encoder) {
 		enc.WriteField(hpack.HeaderField{Name: "", Value: "bad"})
 	})
 }
 func TestTransportInvalidTrailer_BinaryFieldValue(t *testing.T) {
-	testInvalidTrailer(t, oneHeader, errInvalidHeaderFieldValue, func(enc *hpack.Encoder) {
-		enc.WriteField(hpack.HeaderField{Name: "", Value: "has\nnewline"})
+	testInvalidTrailer(t, oneHeader, headerFieldValueError("has\nnewline"), func(enc *hpack.Encoder) {
+		enc.WriteField(hpack.HeaderField{Name: "x", Value: "has\nnewline"})
 	})
 }
 
@@ -1147,7 +1147,7 @@ func testInvalidTrailer(t *testing.T, trailers headerType, wantErr error, writeT
 		}
 		slurp, err := ioutil.ReadAll(res.Body)
 		if err != wantErr {
-			return fmt.Errorf("res.Body ReadAll error = %q, %v; want %v", slurp, err, wantErr)
+			return fmt.Errorf("res.Body ReadAll error = %q, %#v; want %T of %#v", slurp, err, wantErr, wantErr)
 		}
 		if len(slurp) > 0 {
 			return fmt.Errorf("body = %q; want nothing", slurp)