Browse Source

http2: send client trailers

Change-Id: I9cb50eeb3f183f4237d7ba123b8123582fd37882
Reviewed-on: https://go-review.googlesource.com/17912
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Blake Mizerany 10 years ago
parent
commit
d2ecd08ab7
1 changed files with 103 additions and 33 deletions
  1. 103 33
      http2/transport.go

+ 103 - 33
http2/transport.go

@@ -18,6 +18,7 @@ import (
 	"log"
 	"net"
 	"net/http"
+	"sort"
 	"strconv"
 	"strings"
 	"sync"
@@ -488,9 +489,33 @@ func (cc *ClientConn) putFrameScratchBuffer(buf []byte) {
 // exported. At least they'll be DeepEqual for h1-vs-h2 comparisons tests.
 var errRequestCanceled = errors.New("net/http: request canceled")
 
+func commaSeparatedTrailers(req *http.Request) (string, error) {
+	keys := make([]string, 0, len(req.Trailer))
+	for k := range req.Trailer {
+		k = http.CanonicalHeaderKey(k)
+		switch k {
+		case "Transfer-Encoding", "Trailer", "Content-Length":
+			return "", &badStringError{"invalid Trailer key", k}
+		}
+		keys = append(keys, k)
+	}
+	if len(keys) > 0 {
+		sort.Strings(keys)
+		// TODO: could do better allocation-wise here, but trailers are rare,
+		// so being lazy for now.
+		return strings.Join(keys, ","), nil
+	}
+	return "", nil
+}
+
 func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
-	cc.mu.Lock()
+	trailers, err := commaSeparatedTrailers(req)
+	if err != nil {
+		return nil, err
+	}
+	hasTrailers := trailers != ""
 
+	cc.mu.Lock()
 	if cc.closed || !cc.canTakeNewRequestLocked() {
 		cc.mu.Unlock()
 		return nil, errClientConnUnusable
@@ -521,36 +546,10 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
 	}
 
 	// we send: HEADERS{1}, CONTINUATION{0,} + DATA{0,}
-	hdrs := cc.encodeHeaders(req, cs.requestedGzip)
-	first := true // first frame written (HEADERS is first, then CONTINUATION)
-
+	hdrs := cc.encodeHeaders(req, cs.requestedGzip, trailers)
 	cc.wmu.Lock()
-	frameSize := int(cc.maxFrameSize)
-	for len(hdrs) > 0 && cc.werr == nil {
-		chunk := hdrs
-		if len(chunk) > frameSize {
-			chunk = chunk[:frameSize]
-		}
-		hdrs = hdrs[len(chunk):]
-		endHeaders := len(hdrs) == 0
-		if first {
-			cc.fr.WriteHeaders(HeadersFrameParam{
-				StreamID:      cs.ID,
-				BlockFragment: chunk,
-				EndStream:     !hasBody,
-				EndHeaders:    endHeaders,
-			})
-			first = false
-		} else {
-			cc.fr.WriteContinuation(cs.ID, endHeaders, chunk)
-		}
-	}
-	// TODO(bradfitz): this Flush could potentially block (as
-	// could the WriteHeaders call(s) above), which means they
-	// wouldn't respond to Request.Cancel being readable. That's
-	// rare, but this should probably be in a goroutine.
-	cc.bw.Flush()
-	werr := cc.werr
+	endStream := !hasBody && !hasTrailers
+	werr := cc.writeHeaders(cs.ID, endStream, hdrs)
 	cc.wmu.Unlock()
 	cc.mu.Unlock()
 
@@ -601,6 +600,37 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
 	}
 }
 
+// requires cc.wmu be held
+func (cc *ClientConn) writeHeaders(streamID uint32, endStream bool, hdrs []byte) error {
+	first := true // first frame written (HEADERS is first, then CONTINUATION)
+	frameSize := int(cc.maxFrameSize)
+	for len(hdrs) > 0 && cc.werr == nil {
+		chunk := hdrs
+		if len(chunk) > frameSize {
+			chunk = chunk[:frameSize]
+		}
+		hdrs = hdrs[len(chunk):]
+		endHeaders := len(hdrs) == 0
+		if first {
+			cc.fr.WriteHeaders(HeadersFrameParam{
+				StreamID:      streamID,
+				BlockFragment: chunk,
+				EndStream:     endStream,
+				EndHeaders:    endHeaders,
+			})
+			first = false
+		} else {
+			cc.fr.WriteContinuation(streamID, endHeaders, chunk)
+		}
+	}
+	// TODO(bradfitz): this Flush could potentially block (as
+	// could the WriteHeaders call(s) above), which means they
+	// wouldn't respond to Request.Cancel being readable. That's
+	// rare, but this should probably be in a goroutine.
+	cc.bw.Flush()
+	return cc.werr
+}
+
 // errAbortReqBodyWrite is an internal error value.
 // It doesn't escape to callers.
 var errAbortReqBodyWrite = errors.New("http2: aborting request body write")
@@ -622,6 +652,9 @@ func (cs *clientStream) writeRequestBody(body io.ReadCloser) (err error) {
 		}
 	}()
 
+	req := cs.req
+	hasTrailers := req.Trailer != nil
+
 	var sawEOF bool
 	for !sawEOF {
 		n, err := body.Read(buf)
@@ -642,7 +675,7 @@ func (cs *clientStream) writeRequestBody(body io.ReadCloser) (err error) {
 			cc.wmu.Lock()
 			data := remain[:allowed]
 			remain = remain[allowed:]
-			sentEnd = sawEOF && len(remain) == 0
+			sentEnd = sawEOF && len(remain) == 0 && !hasTrailers
 			err = cc.fr.WriteData(cs.ID, sentEnd, data)
 			if err == nil {
 				// TODO(bradfitz): this flush is for latency, not bandwidth.
@@ -661,7 +694,20 @@ func (cs *clientStream) writeRequestBody(body io.ReadCloser) (err error) {
 
 	cc.wmu.Lock()
 	if !sentEnd {
-		err = cc.fr.WriteData(cs.ID, true, nil)
+		var trls []byte
+		if hasTrailers {
+			cc.mu.Lock()
+			trls = cc.encodeTrailers(req)
+			cc.mu.Unlock()
+		}
+
+		// Avoid forgetting to send an END_STREAM if the encoded
+		// trailers are 0 bytes. Both results produce and END_STREAM.
+		if len(trls) > 0 {
+			err = cc.writeHeaders(cs.ID, true, trls)
+		} else {
+			err = cc.fr.WriteData(cs.ID, true, nil)
+		}
 	}
 	if ferr := cc.bw.Flush(); ferr != nil && err == nil {
 		err = ferr
@@ -705,8 +751,15 @@ func (cs *clientStream) awaitFlowControl(maxBytes int) (taken int32, err error)
 	}
 }
 
+type badStringError struct {
+	what string
+	str  string
+}
+
+func (e *badStringError) Error() string { return fmt.Sprintf("%s %q", e.what, e.str) }
+
 // requires cc.mu be held.
-func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool) []byte {
+func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string) []byte {
 	cc.hbuf.Reset()
 
 	// TODO(bradfitz): figure out :authority-vs-Host stuff between http2 and Go
@@ -724,6 +777,9 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool) []byt
 	cc.writeHeader(":method", req.Method)
 	cc.writeHeader(":path", req.URL.RequestURI())
 	cc.writeHeader(":scheme", "https")
+	if trailers != "" {
+		cc.writeHeader("trailer", trailers)
+	}
 
 	for k, vv := range req.Header {
 		lowKey := strings.ToLower(k)
@@ -740,6 +796,20 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool) []byt
 	return cc.hbuf.Bytes()
 }
 
+// requires cc.mu be held.
+func (cc *ClientConn) encodeTrailers(req *http.Request) []byte {
+	cc.hbuf.Reset()
+	for k, vv := range req.Trailer {
+		// Transfer-Encoding, etc.. have already been filter at the
+		// start of RoundTrip
+		lowKey := strings.ToLower(k)
+		for _, v := range vv {
+			cc.writeHeader(lowKey, v)
+		}
+	}
+	return cc.hbuf.Bytes()
+}
+
 func (cc *ClientConn) writeHeader(name, value string) {
 	cc.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
 }