Forráskód Böngészése

http2: add Server support for reading trailers from clients

Updates golang/go#13557

Change-Id: I95bbb15d9abbbbc4dc6c3a22cd965d8dcef53fb8
Reviewed-on: https://go-review.googlesource.com/17891
Reviewed-by: Blake Mizerany <blake.mizerany@gmail.com>
Brad Fitzpatrick 10 éve
szülő
commit
c24de9d546
5 módosított fájl, 256 hozzáadás és 45 törlés
  1. 1 0
      http2/headermap.go
  2. 7 0
      http2/hpack/hpack.go
  3. 27 14
      http2/pipe.go
  4. 159 28
      http2/server.go
  5. 62 3
      http2/server_test.go

+ 1 - 0
http2/headermap.go

@@ -57,6 +57,7 @@ func init() {
 		"server",
 		"server",
 		"set-cookie",
 		"set-cookie",
 		"strict-transport-security",
 		"strict-transport-security",
+		"trailer",
 		"transfer-encoding",
 		"transfer-encoding",
 		"user-agent",
 		"user-agent",
 		"vary",
 		"vary",

+ 7 - 0
http2/hpack/hpack.go

@@ -102,6 +102,13 @@ func (d *Decoder) SetMaxStringLength(n int) {
 	d.maxStrLen = n
 	d.maxStrLen = n
 }
 }
 
 
+// SetEmitFunc changes the callback used when new header fields
+// are decoded.
+// It must be non-nil. It does not affect EmitEnabled.
+func (d *Decoder) SetEmitFunc(emitFunc func(f HeaderField)) {
+	d.emit = emitFunc
+}
+
 // SetEmitEnabled controls whether the emitFunc provided to NewDecoder
 // SetEmitEnabled controls whether the emitFunc provided to NewDecoder
 // should be called. The default is true.
 // should be called. The default is true.
 //
 //

+ 27 - 14
http2/pipe.go

@@ -14,11 +14,12 @@ import (
 // io.Pipe except there are no PipeReader/PipeWriter halves, and the
 // io.Pipe except there are no PipeReader/PipeWriter halves, and the
 // underlying buffer is an interface. (io.Pipe is always unbuffered)
 // underlying buffer is an interface. (io.Pipe is always unbuffered)
 type pipe struct {
 type pipe struct {
-	mu    sync.Mutex
-	c     sync.Cond // c.L must point to
-	b     pipeBuffer
-	err   error         // read error once empty. non-nil means closed.
-	donec chan struct{} // closed on error
+	mu     sync.Mutex
+	c      sync.Cond // c.L must point to
+	b      pipeBuffer
+	err    error         // read error once empty. non-nil means closed.
+	donec  chan struct{} // closed on error
+	readFn func()        // optional code to run in Read before error
 }
 }
 
 
 type pipeBuffer interface {
 type pipeBuffer interface {
@@ -40,6 +41,10 @@ func (p *pipe) Read(d []byte) (n int, err error) {
 			return p.b.Read(d)
 			return p.b.Read(d)
 		}
 		}
 		if p.err != nil {
 		if p.err != nil {
+			if p.readFn != nil {
+				p.readFn()     // e.g. copy trailers
+				p.readFn = nil // not sticky like p.err
+			}
 			return 0, p.err
 			return 0, p.err
 		}
 		}
 		p.c.Wait()
 		p.c.Wait()
@@ -63,13 +68,18 @@ func (p *pipe) Write(d []byte) (n int, err error) {
 	return p.b.Write(d)
 	return p.b.Write(d)
 }
 }
 
 
-// CloseWithError causes Reads to wake up and return the
-// provided err after all data has been read.
+// CloseWithError causes the next Read (waking up a current blocked
+// Read if needed) to return the provided err after all data has been
+// read.
 //
 //
 // The error must be non-nil.
 // The error must be non-nil.
-func (p *pipe) CloseWithError(err error) {
+func (p *pipe) CloseWithError(err error) { p.closeWithErrorAndCode(err, nil) }
+
+// closeWithErrorAndCode is like CloseWithError but also sets some code to run
+// in the caller's goroutine before returning the error.
+func (p *pipe) closeWithErrorAndCode(err error, fn func()) {
 	if err == nil {
 	if err == nil {
-		panic("CloseWithError must be non-nil")
+		panic("CloseWithError err must be non-nil")
 	}
 	}
 	p.mu.Lock()
 	p.mu.Lock()
 	defer p.mu.Unlock()
 	defer p.mu.Unlock()
@@ -77,11 +87,14 @@ func (p *pipe) CloseWithError(err error) {
 		p.c.L = &p.mu
 		p.c.L = &p.mu
 	}
 	}
 	defer p.c.Signal()
 	defer p.c.Signal()
-	if p.err == nil {
-		p.err = err
-		if p.donec != nil {
-			close(p.donec)
-		}
+	if p.err != nil {
+		// Already been done.
+		return
+	}
+	p.readFn = fn
+	p.err = err
+	if p.donec != nil {
+		close(p.donec)
 	}
 	}
 }
 }
 
 

+ 159 - 28
http2/server.go

@@ -224,7 +224,7 @@ func (srv *Server) handleConn(hs *http.Server, c net.Conn, h http.Handler) {
 	sc.flow.add(initialWindowSize)
 	sc.flow.add(initialWindowSize)
 	sc.inflow.add(initialWindowSize)
 	sc.inflow.add(initialWindowSize)
 	sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf)
 	sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf)
-	sc.hpackDecoder = hpack.NewDecoder(initialHeaderTableSize, sc.onNewHeaderField)
+	sc.hpackDecoder = hpack.NewDecoder(initialHeaderTableSize, nil)
 	sc.hpackDecoder.SetMaxStringLength(sc.maxHeaderStringLen())
 	sc.hpackDecoder.SetMaxStringLength(sc.maxHeaderStringLen())
 
 
 	fr := NewFramer(sc.bw, c)
 	fr := NewFramer(sc.bw, c)
@@ -411,20 +411,26 @@ type requestParam struct {
 // responseWriter's state field.
 // responseWriter's state field.
 type stream struct {
 type stream struct {
 	// immutable:
 	// immutable:
+	sc   *serverConn
 	id   uint32
 	id   uint32
 	body *pipe       // non-nil if expecting DATA frames
 	body *pipe       // non-nil if expecting DATA frames
 	cw   closeWaiter // closed wait stream transitions to closed state
 	cw   closeWaiter // closed wait stream transitions to closed state
 
 
 	// owned by serverConn's serve loop:
 	// owned by serverConn's serve loop:
-	bodyBytes     int64   // body bytes seen so far
-	declBodyBytes int64   // or -1 if undeclared
-	flow          flow    // limits writing from Handler to client
-	inflow        flow    // what the client is allowed to POST/etc to us
-	parent        *stream // or nil
-	weight        uint8
-	state         streamState
-	sentReset     bool // only true once detached from streams map
-	gotReset      bool // only true once detacted from streams map
+	bodyBytes        int64   // body bytes seen so far
+	declBodyBytes    int64   // or -1 if undeclared
+	flow             flow    // limits writing from Handler to client
+	inflow           flow    // what the client is allowed to POST/etc to us
+	parent           *stream // or nil
+	numTrailerValues int64
+	weight           uint8
+	state            streamState
+	sentReset        bool // only true once detached from streams map
+	gotReset         bool // only true once detacted from streams map
+	gotTrailerHeader bool // HEADER frame for trailers was seen
+
+	trailer    http.Header // accumulated trailers
+	reqTrailer http.Header // handler's Request.Trailer
 }
 }
 
 
 func (sc *serverConn) Framer() *Framer  { return sc.framer }
 func (sc *serverConn) Framer() *Framer  { return sc.framer }
@@ -537,6 +543,37 @@ func (sc *serverConn) onNewHeaderField(f hpack.HeaderField) {
 	}
 	}
 }
 }
 
 
+func (st *stream) onNewTrailerField(f hpack.HeaderField) {
+	sc := st.sc
+	sc.serveG.check()
+	sc.vlogf("got trailer field %+v", f)
+	switch {
+	case !validHeader(f.Name):
+		// TODO: change hpack signature so this can return
+		// errors?  Or stash an error somewhere on st or sc
+		// for processHeaderBlockFragment etc to pick up and
+		// return after the hpack Write/Close.  For now just
+		// ignore.
+		return
+	case strings.HasPrefix(f.Name, ":"):
+		// TODO: same TODO as above.
+		return
+	default:
+		key := sc.canonicalHeader(f.Name)
+		if st.trailer != nil {
+			vv := append(st.trailer[key], f.Value)
+			st.trailer[key] = vv
+
+			// arbitrary; TODO: read spec about header list size limits wrt trailers
+			const tooBig = 1000
+			if len(vv) >= tooBig {
+				sc.hpackDecoder.SetEmitEnabled(false)
+			}
+
+		}
+	}
+}
+
 func (sc *serverConn) canonicalHeader(v string) string {
 func (sc *serverConn) canonicalHeader(v string) string {
 	sc.serveG.check()
 	sc.serveG.check()
 	cv, ok := commonCanonHeader[v]
 	cv, ok := commonCanonHeader[v]
@@ -1249,7 +1286,7 @@ func (sc *serverConn) processData(f *DataFrame) error {
 	// with a stream error (Section 5.4.2) of type STREAM_CLOSED."
 	// with a stream error (Section 5.4.2) of type STREAM_CLOSED."
 	id := f.Header().StreamID
 	id := f.Header().StreamID
 	st, ok := sc.streams[id]
 	st, ok := sc.streams[id]
-	if !ok || st.state != stateOpen {
+	if !ok || st.state != stateOpen || st.gotTrailerHeader {
 		// This includes sending a RST_STREAM if the stream is
 		// This includes sending a RST_STREAM if the stream is
 		// in stateHalfClosedLocal (which currently means that
 		// in stateHalfClosedLocal (which currently means that
 		// the http.Handler returned, so it's done reading &
 		// the http.Handler returned, so it's done reading &
@@ -1283,17 +1320,38 @@ func (sc *serverConn) processData(f *DataFrame) error {
 		st.bodyBytes += int64(len(data))
 		st.bodyBytes += int64(len(data))
 	}
 	}
 	if f.StreamEnded() {
 	if f.StreamEnded() {
-		if st.declBodyBytes != -1 && st.declBodyBytes != st.bodyBytes {
-			st.body.CloseWithError(fmt.Errorf("request declared a Content-Length of %d but only wrote %d bytes",
-				st.declBodyBytes, st.bodyBytes))
-		} else {
-			st.body.CloseWithError(io.EOF)
-		}
-		st.state = stateHalfClosedRemote
+		st.endStream()
 	}
 	}
 	return nil
 	return nil
 }
 }
 
 
+// endStream closes a Request.Body's pipe. It is called when a DATA
+// frame says a request body is over (or after trailers).
+func (st *stream) endStream() {
+	sc := st.sc
+	sc.serveG.check()
+
+	if st.declBodyBytes != -1 && st.declBodyBytes != st.bodyBytes {
+		st.body.CloseWithError(fmt.Errorf("request declared a Content-Length of %d but only wrote %d bytes",
+			st.declBodyBytes, st.bodyBytes))
+	} else {
+		st.body.closeWithErrorAndCode(io.EOF, st.copyTrailersToHandlerRequest)
+		st.body.CloseWithError(io.EOF)
+	}
+	st.state = stateHalfClosedRemote
+}
+
+// copyTrailersToHandlerRequest is run in the Handler's goroutine in
+// its Request.Body.Read just before it gets io.EOF.
+func (st *stream) copyTrailersToHandlerRequest() {
+	for k, vv := range st.trailer {
+		if _, ok := st.reqTrailer[k]; ok {
+			// Only copy it over it was pre-declared.
+			st.reqTrailer[k] = vv
+		}
+	}
+}
+
 func (sc *serverConn) processHeaders(f *HeadersFrame) error {
 func (sc *serverConn) processHeaders(f *HeadersFrame) error {
 	sc.serveG.check()
 	sc.serveG.check()
 	id := f.Header().StreamID
 	id := f.Header().StreamID
@@ -1302,20 +1360,36 @@ func (sc *serverConn) processHeaders(f *HeadersFrame) error {
 		return nil
 		return nil
 	}
 	}
 	// http://http2.github.io/http2-spec/#rfc.section.5.1.1
 	// http://http2.github.io/http2-spec/#rfc.section.5.1.1
-	if id%2 != 1 || id <= sc.maxStreamID || sc.req.stream != nil {
-		// Streams initiated by a client MUST use odd-numbered
-		// stream identifiers. [...] The identifier of a newly
-		// established stream MUST be numerically greater than all
-		// streams that the initiating endpoint has opened or
-		// reserved. [...]  An endpoint that receives an unexpected
-		// stream identifier MUST respond with a connection error
-		// (Section 5.4.1) of type PROTOCOL_ERROR.
+	// Streams initiated by a client MUST use odd-numbered stream
+	// identifiers. [...] An endpoint that receives an unexpected
+	// stream identifier MUST respond with a connection error
+	// (Section 5.4.1) of type PROTOCOL_ERROR.
+	if id%2 != 1 {
 		return ConnectionError(ErrCodeProtocol)
 		return ConnectionError(ErrCodeProtocol)
 	}
 	}
+	// A HEADERS frame can be used to create a new stream or
+	// send a trailer for an open one. If we already have a stream
+	// open, let it process its own HEADERS frame (trailers at this
+	// point, if it's valid).
+	st := sc.streams[f.Header().StreamID]
+	if st != nil {
+		return st.processTrailerHeaders(f)
+	}
+
+	// [...] The identifier of a newly established stream MUST be
+	// numerically greater than all streams that the initiating
+	// endpoint has opened or reserved. [...]  An endpoint that
+	// receives an unexpected stream identifier MUST respond with
+	// a connection error (Section 5.4.1) of type PROTOCOL_ERROR.
+	if id <= sc.maxStreamID || sc.req.stream != nil {
+		return ConnectionError(ErrCodeProtocol)
+	}
+
 	if id > sc.maxStreamID {
 	if id > sc.maxStreamID {
 		sc.maxStreamID = id
 		sc.maxStreamID = id
 	}
 	}
-	st := &stream{
+	st = &stream{
+		sc:    sc,
 		id:    id,
 		id:    id,
 		state: stateOpen,
 		state: stateOpen,
 	}
 	}
@@ -1341,16 +1415,30 @@ func (sc *serverConn) processHeaders(f *HeadersFrame) error {
 		stream: st,
 		stream: st,
 		header: make(http.Header),
 		header: make(http.Header),
 	}
 	}
+	sc.hpackDecoder.SetEmitFunc(sc.onNewHeaderField)
 	sc.hpackDecoder.SetEmitEnabled(true)
 	sc.hpackDecoder.SetEmitEnabled(true)
 	return sc.processHeaderBlockFragment(st, f.HeaderBlockFragment(), f.HeadersEnded())
 	return sc.processHeaderBlockFragment(st, f.HeaderBlockFragment(), f.HeadersEnded())
 }
 }
 
 
+func (st *stream) processTrailerHeaders(f *HeadersFrame) error {
+	sc := st.sc
+	sc.serveG.check()
+	if st.gotTrailerHeader {
+		return ConnectionError(ErrCodeProtocol)
+	}
+	st.gotTrailerHeader = true
+	return st.processTrailerHeaderBlockFragment(f.HeaderBlockFragment(), f.HeadersEnded())
+}
+
 func (sc *serverConn) processContinuation(f *ContinuationFrame) error {
 func (sc *serverConn) processContinuation(f *ContinuationFrame) error {
 	sc.serveG.check()
 	sc.serveG.check()
 	st := sc.streams[f.Header().StreamID]
 	st := sc.streams[f.Header().StreamID]
 	if st == nil || sc.curHeaderStreamID() != st.id {
 	if st == nil || sc.curHeaderStreamID() != st.id {
 		return ConnectionError(ErrCodeProtocol)
 		return ConnectionError(ErrCodeProtocol)
 	}
 	}
+	if st.gotTrailerHeader {
+		return st.processTrailerHeaderBlockFragment(f.HeaderBlockFragment(), f.HeadersEnded())
+	}
 	return sc.processHeaderBlockFragment(st, f.HeaderBlockFragment(), f.HeadersEnded())
 	return sc.processHeaderBlockFragment(st, f.HeaderBlockFragment(), f.HeadersEnded())
 }
 }
 
 
@@ -1389,6 +1477,10 @@ func (sc *serverConn) processHeaderBlockFragment(st *stream, frag []byte, end bo
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
+	st.reqTrailer = req.Trailer
+	if st.reqTrailer != nil {
+		st.trailer = make(http.Header)
+	}
 	st.body = req.Body.(*requestBody).pipe // may be nil
 	st.body = req.Body.(*requestBody).pipe // may be nil
 	st.declBodyBytes = req.ContentLength
 	st.declBodyBytes = req.ContentLength
 
 
@@ -1402,6 +1494,24 @@ func (sc *serverConn) processHeaderBlockFragment(st *stream, frag []byte, end bo
 	return nil
 	return nil
 }
 }
 
 
+func (st *stream) processTrailerHeaderBlockFragment(frag []byte, end bool) error {
+	sc := st.sc
+	sc.serveG.check()
+	sc.hpackDecoder.SetEmitFunc(st.onNewTrailerField)
+	if _, err := sc.hpackDecoder.Write(frag); err != nil {
+		return ConnectionError(ErrCodeCompression)
+	}
+	if !end {
+		return nil
+	}
+	err := sc.hpackDecoder.Close()
+	st.endStream()
+	if err != nil {
+		return ConnectionError(ErrCodeCompression)
+	}
+	return nil
+}
+
 func (sc *serverConn) processPriority(f *PriorityFrame) error {
 func (sc *serverConn) processPriority(f *PriorityFrame) error {
 	adjustStreamPriority(sc.streams, f.StreamID, f.PriorityParam)
 	adjustStreamPriority(sc.streams, f.StreamID, f.PriorityParam)
 	return nil
 	return nil
@@ -1489,6 +1599,26 @@ func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, err
 	if cookies := rp.header["Cookie"]; len(cookies) > 1 {
 	if cookies := rp.header["Cookie"]; len(cookies) > 1 {
 		rp.header.Set("Cookie", strings.Join(cookies, "; "))
 		rp.header.Set("Cookie", strings.Join(cookies, "; "))
 	}
 	}
+
+	// Setup Trailers
+	var trailer http.Header
+	for _, v := range rp.header["Trailer"] {
+		for _, key := range strings.Split(v, ",") {
+			key = http.CanonicalHeaderKey(strings.TrimSpace(key))
+			switch key {
+			case "Transfer-Encoding", "Trailer", "Content-Length":
+				// Bogus. (copy of http1 rules)
+				// Ignore.
+			default:
+				if trailer == nil {
+					trailer = make(http.Header)
+				}
+				trailer[key] = nil
+			}
+		}
+	}
+	delete(rp.header, "Trailer")
+
 	body := &requestBody{
 	body := &requestBody{
 		conn:          sc,
 		conn:          sc,
 		stream:        rp.stream,
 		stream:        rp.stream,
@@ -1512,10 +1642,11 @@ func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, err
 		TLS:        tlsState,
 		TLS:        tlsState,
 		Host:       authority,
 		Host:       authority,
 		Body:       body,
 		Body:       body,
+		Trailer:    trailer,
 	}
 	}
 	if bodyOpen {
 	if bodyOpen {
 		body.pipe = &pipe{
 		body.pipe = &pipe{
-			b: &fixedBuffer{buf: make([]byte, initialWindowSize)}, // TODO: share/remove XXX
+			b: &fixedBuffer{buf: make([]byte, initialWindowSize)}, // TODO: garbage
 		}
 		}
 
 
 		if vv, ok := rp.header["Content-Length"]; ok {
 		if vv, ok := rp.header["Content-Length"]; ok {

+ 62 - 3
http2/server_test.go

@@ -246,6 +246,21 @@ func (st *serverTester) encodeHeaderField(k, v string) {
 	}
 	}
 }
 }
 
 
+// encodeHeaderRaw is the magic-free version of encodeHeader.
+// It takes 0 or more (k, v) pairs and encodes them.
+func (st *serverTester) encodeHeaderRaw(headers ...string) []byte {
+	if len(headers)%2 == 1 {
+		panic("odd number of kv args")
+	}
+	st.headerBuf.Reset()
+	for len(headers) > 0 {
+		k, v := headers[0], headers[1]
+		st.encodeHeaderField(k, v)
+		headers = headers[2:]
+	}
+	return st.headerBuf.Bytes()
+}
+
 // encodeHeader encodes headers and returns their HPACK bytes. headers
 // encodeHeader encodes headers and returns their HPACK bytes. headers
 // must contain an even number of key/value pairs.  There may be
 // must contain an even number of key/value pairs.  There may be
 // multiple pairs for keys (e.g. "cookie").  The :method, :path, and
 // multiple pairs for keys (e.g. "cookie").  The :method, :path, and
@@ -299,7 +314,6 @@ func (st *serverTester) encodeHeader(headers ...string) []byte {
 			vals[k] = append(vals[k], v)
 			vals[k] = append(vals[k], v)
 		}
 		}
 	}
 	}
-	st.headerBuf.Reset()
 	for _, k := range keys {
 	for _, k := range keys {
 		for _, v := range vals[k] {
 		for _, v := range vals[k] {
 			st.encodeHeaderField(k, v)
 			st.encodeHeaderField(k, v)
@@ -2451,8 +2465,53 @@ func TestCompressionErrorOnClose(t *testing.T) {
 
 
 // test that a server handler can read trailers from a client
 // test that a server handler can read trailers from a client
 func TestServerReadsTrailers(t *testing.T) {
 func TestServerReadsTrailers(t *testing.T) {
-	// TODO: use testBodyContents or testServerRequest
-	t.Skip("unimplemented")
+	const testBody = "some test body"
+	writeReq := func(st *serverTester) {
+		st.writeHeaders(HeadersFrameParam{
+			StreamID:      1, // clients send odd numbers
+			BlockFragment: st.encodeHeader("trailer", "Foo, Bar", "trailer", "Baz"),
+			EndStream:     false,
+			EndHeaders:    true,
+		})
+		st.writeData(1, false, []byte(testBody))
+		st.writeHeaders(HeadersFrameParam{
+			StreamID: 1, // clients send odd numbers
+			BlockFragment: st.encodeHeaderRaw(
+				"foo", "foov",
+				"bar", "barv",
+				"baz", "bazv",
+				"surprise", "wasn't declared; shouldn't show up",
+			),
+			EndStream:  true,
+			EndHeaders: true,
+		})
+	}
+	checkReq := func(r *http.Request) {
+		wantTrailer := http.Header{
+			"Foo": nil,
+			"Bar": nil,
+			"Baz": nil,
+		}
+		if !reflect.DeepEqual(r.Trailer, wantTrailer) {
+			t.Errorf("initial Trailer = %v; want %v", r.Trailer, wantTrailer)
+		}
+		slurp, err := ioutil.ReadAll(r.Body)
+		if string(slurp) != testBody {
+			t.Errorf("read body %q; want %q", slurp, testBody)
+		}
+		if err != nil {
+			t.Fatalf("Body slurp: %v", err)
+		}
+		wantTrailerAfter := http.Header{
+			"Foo": {"foov"},
+			"Bar": {"barv"},
+			"Baz": {"bazv"},
+		}
+		if !reflect.DeepEqual(r.Trailer, wantTrailerAfter) {
+			t.Errorf("final Trailer = %v; want %v", r.Trailer, wantTrailerAfter)
+		}
+	}
+	testServerRequest(t, writeReq, checkReq)
 }
 }
 
 
 // test that a server handler can send trailers
 // test that a server handler can send trailers