Browse Source

http2: add server-side trailer support

Change-Id: I39dbf0cdeee0123b6c6efff1fc6854bcedb94753
Reviewed-on: https://go-review.googlesource.com/17878
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Blake Mizerany 10 năm trước cách đây
mục cha
commit
b4be494138
3 tập tin đã thay đổi với 121 bổ sung29 xóa
  1. 58 3
      http2/server.go
  2. 28 5
      http2/server_test.go
  3. 35 21
      http2/write.go

+ 58 - 3
http2/server.go

@@ -46,6 +46,7 @@ import (
 	"log"
 	"net"
 	"net/http"
+	"net/textproto"
 	"net/url"
 	"runtime"
 	"strconv"
@@ -1877,6 +1878,7 @@ type responseWriterState struct {
 	// mutated by http.Handler goroutine:
 	handlerHeader http.Header // nil until called
 	snapHeader    http.Header // snapshot of handlerHeader at WriteHeader time
+	trailers      []string    // set in writeChunk
 	status        int         // status code passed to WriteHeader
 	wroteHeader   bool        // WriteHeader called (explicitly or implicitly). Not necessarily sent to user yet.
 	sentHeader    bool        // have we sent the header frame?
@@ -1893,6 +1895,21 @@ type chunkWriter struct{ rws *responseWriterState }
 
 func (cw chunkWriter) Write(p []byte) (n int, err error) { return cw.rws.writeChunk(p) }
 
+func (rws *responseWriterState) hasTrailers() bool { return len(rws.trailers) != 0 }
+
+// declareTrailer is called for each Trailer header when the
+// response header is written. It notes that a header will need to be
+// written in the trailers at the end of the response.
+func (rws *responseWriterState) declareTrailer(k string) {
+	k = http.CanonicalHeaderKey(k)
+	switch k {
+	case "Transfer-Encoding", "Content-Length", "Trailer":
+		// Forbidden by RFC 2616 14.40.
+		return
+	}
+	rws.trailers = append(rws.trailers, k)
+}
+
 // writeChunk writes chunks from the bufio.Writer. But because
 // bufio.Writer may bypass its chunking, sometimes p may be
 // arbitrarily large.
@@ -1903,6 +1920,7 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
 	if !rws.wroteHeader {
 		rws.writeHeader(200)
 	}
+
 	isHeadResp := rws.req.Method == "HEAD"
 	if !rws.sentHeader {
 		rws.sentHeader = true
@@ -1928,7 +1946,12 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
 			// TODO(bradfitz): be faster here, like net/http? measure.
 			date = time.Now().UTC().Format(http.TimeFormat)
 		}
-		endStream := (rws.handlerDone && len(p) == 0) || isHeadResp
+
+		for _, v := range rws.snapHeader["Trailer"] {
+			foreachHeaderElement(v, rws.declareTrailer)
+		}
+
+		endStream := (rws.handlerDone && !rws.hasTrailers() && len(p) == 0) || isHeadResp
 		err = rws.conn.writeHeaders(rws.stream, &writeResHeaders{
 			streamID:      rws.stream.id,
 			httpResCode:   rws.status,
@@ -1952,8 +1975,22 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
 		return 0, nil
 	}
 
-	if err := rws.conn.writeDataFromHandler(rws.stream, p, rws.handlerDone); err != nil {
-		return 0, err
+	endStream := rws.handlerDone && !rws.hasTrailers()
+	if len(p) > 0 || endStream {
+		// only send a 0 byte DATA frame if we're ending the stream.
+		if err := rws.conn.writeDataFromHandler(rws.stream, p, endStream); err != nil {
+			return 0, err
+		}
+	}
+
+	if rws.handlerDone && rws.hasTrailers() {
+		err = rws.conn.writeHeaders(rws.stream, &writeResHeaders{
+			streamID:  rws.stream.id,
+			h:         rws.handlerHeader,
+			trailers:  rws.trailers,
+			endStream: true,
+		})
+		return len(p), err
 	}
 	return len(p), nil
 }
@@ -2083,3 +2120,21 @@ func (w *responseWriter) handlerDone() {
 	w.rws = nil
 	responseWriterStatePool.Put(rws)
 }
+
+// foreachHeaderElement splits v according to the "#rule" construction
+// in RFC 2616 section 2.1 and calls fn for each non-empty element.
+func foreachHeaderElement(v string, fn func(string)) {
+	v = textproto.TrimString(v)
+	if v == "" {
+		return
+	}
+	if !strings.Contains(v, ",") {
+		fn(v)
+		return
+	}
+	for _, f := range strings.Split(v, ",") {
+		if f = textproto.TrimString(f); f != "" {
+			fn(f)
+		}
+	}
+}

+ 28 - 5
http2/server_test.go

@@ -2515,17 +2515,32 @@ func TestServerReadsTrailers(t *testing.T) {
 }
 
 // test that a server handler can send trailers
-func TestServerWritesTrailers(t *testing.T) {
-	t.Skip("known failing test; see golang.org/issue/13557")
+func TestServerWritesTrailers_WithFlush(t *testing.T)    { testServerWritesTrailers(t, true) }
+func TestServerWritesTrailers_WithoutFlush(t *testing.T) { testServerWritesTrailers(t, false) }
+
+func testServerWritesTrailers(t *testing.T, withFlush bool) {
 	// See https://httpwg.github.io/specs/rfc7540.html#rfc.section.8.1.3
 	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
 		w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B")
 		w.Header().Add("Trailer", "Server-Trailer-C")
+
+		// TODO: decide if the server should filter these while
+		// writing the Trailer header in the response. Currently it
+		// appears net/http doesn't do this for http/1.1
+		w.Header().Add("Trailer", "Transfer-Encoding, Content-Length, Trailer") // filtered
 		w.Header().Set("Foo", "Bar")
+		w.Header().Set("Content-Length", "5")
+
 		io.WriteString(w, "Hello")
-		w.(http.Flusher).Flush()
+		if withFlush {
+			w.(http.Flusher).Flush()
+		}
 		w.Header().Set("Server-Trailer-A", "valuea")
 		w.Header().Set("Server-Trailer-C", "valuec") // skipping B
+		w.Header().Set("Server-Surpise", "surprise! this isn't predeclared!")
+		w.Header().Set("Transfer-Encoding", "should not be included; Forbidden by RFC 2616 14.40")
+		w.Header().Set("Content-Length", "should not be included; Forbidden by RFC 2616 14.40")
+		w.Header().Set("Trailer", "should not be included; Forbidden by RFC 2616 14.40")
 		return nil
 	}, func(st *serverTester) {
 		getSlash(st)
@@ -2542,7 +2557,9 @@ func TestServerWritesTrailers(t *testing.T) {
 			{"foo", "Bar"},
 			{"trailer", "Server-Trailer-A, Server-Trailer-B"},
 			{"trailer", "Server-Trailer-C"},
+			{"trailer", "Transfer-Encoding, Content-Length, Trailer"},
 			{"content-type", "text/plain; charset=utf-8"},
+			{"content-length", "5"},
 		}
 		if !reflect.DeepEqual(goth, wanth) {
 			t.Errorf("Header mismatch.\n got: %v\nwant: %v", goth, wanth)
@@ -2561,8 +2578,14 @@ func TestServerWritesTrailers(t *testing.T) {
 		if !tf.HeadersEnded() {
 			t.Fatalf("trailers HEADERS lacked END_HEADERS")
 		}
-		pairs := st.decodeHeader(tf.HeaderBlockFragment())
-		t.Logf("Got: %v", pairs)
+		wanth = [][2]string{
+			{"server-trailer-a", "valuea"},
+			{"server-trailer-c", "valuec"},
+		}
+		goth = st.decodeHeader(tf.HeaderBlockFragment())
+		if !reflect.DeepEqual(goth, wanth) {
+			t.Errorf("Header mismatch.\n got: %v\nwant: %v", goth, wanth)
+		}
 	})
 }
 

+ 35 - 21
http2/write.go

@@ -123,11 +123,12 @@ func (writeSettingsAck) writeFrame(ctx writeContext) error {
 }
 
 // writeResHeaders is a request to write a HEADERS and 0+ CONTINUATION frames
-// for HTTP response headers from a server handler.
+// for HTTP response headers or trailers from a server handler.
 type writeResHeaders struct {
 	streamID    uint32
-	httpResCode int
+	httpResCode int         // 0 means no ":status" line
 	h           http.Header // may be nil
+	trailers    []string    // if non-nil, which keys of h to write. nil means all.
 	endStream   bool
 
 	date          string
@@ -138,26 +139,16 @@ type writeResHeaders struct {
 func (w *writeResHeaders) writeFrame(ctx writeContext) error {
 	enc, buf := ctx.HeaderEncoder()
 	buf.Reset()
-	enc.WriteField(hpack.HeaderField{Name: ":status", Value: httpCodeString(w.httpResCode)})
 
-	// TODO: garbage. pool sorters like http1? hot path for 1 key?
-	keys := make([]string, 0, len(w.h))
-	for k := range w.h {
-		keys = append(keys, k)
-	}
-	sort.Strings(keys)
-	for _, k := range keys {
-		vv := w.h[k]
-		k = lowerHeader(k)
-		isTE := k == "transfer-encoding"
-		for _, v := range vv {
-			// TODO: more of "8.1.2.2 Connection-Specific Header Fields"
-			if isTE && v != "trailers" {
-				continue
-			}
-			enc.WriteField(hpack.HeaderField{Name: k, Value: v})
-		}
+	if w.httpResCode != 0 {
+		enc.WriteField(hpack.HeaderField{
+			Name:  ":status",
+			Value: httpCodeString(w.httpResCode),
+		})
 	}
+
+	encodeHeaders(enc, w.h, w.trailers)
+
 	if w.contentType != "" {
 		enc.WriteField(hpack.HeaderField{Name: "content-type", Value: w.contentType})
 	}
@@ -169,7 +160,7 @@ func (w *writeResHeaders) writeFrame(ctx writeContext) error {
 	}
 
 	headerBlock := buf.Bytes()
-	if len(headerBlock) == 0 {
+	if len(headerBlock) == 0 && w.trailers == nil {
 		panic("unexpected empty hpack")
 	}
 
@@ -232,3 +223,26 @@ type writeWindowUpdate struct {
 func (wu writeWindowUpdate) writeFrame(ctx writeContext) error {
 	return ctx.Framer().WriteWindowUpdate(wu.streamID, wu.n)
 }
+
+func encodeHeaders(enc *hpack.Encoder, h http.Header, keys []string) {
+	// TODO: garbage. pool sorters like http1? hot path for 1 key?
+	if keys == nil {
+		keys = make([]string, 0, len(h))
+		for k := range h {
+			keys = append(keys, k)
+		}
+		sort.Strings(keys)
+	}
+	for _, k := range keys {
+		vv := h[k]
+		k = lowerHeader(k)
+		isTE := k == "transfer-encoding"
+		for _, v := range vv {
+			// TODO: more of "8.1.2.2 Connection-Specific Header Fields"
+			if isTE && v != "trailers" {
+				continue
+			}
+			enc.WriteField(hpack.HeaderField{Name: k, Value: v})
+		}
+	}
+}