浏览代码

http2: client & server fixes

Fixes found in the process of adding more A/B tests to net/http,
comparing HTTP/1 and HTTP/2 behaviors.

Most of the new tests are in Gerrit change Id9c45fad44cdf70ac9
in the "go" repo.

Fixes golang/go#13315
Fixes golang/go#13316
Fixes golang/go#13317
Fixes other stuff found in the process too
Updates golang/go#6891 (http2 support in general)

Change-Id: I83b5bfb471047312c0dcb0a0b21d709008f34136
Reviewed-on: https://go-review.googlesource.com/17204
Reviewed-by: Andrew Gerrand <adg@golang.org>
Brad Fitzpatrick 10 年之前
父节点
当前提交
c745c36eab
共有 7 个文件被更改,包括 145 次插入17 次删除
  1. 22 6
      http2/http2.go
  2. 6 0
      http2/http2_test.go
  3. 40 5
      http2/server.go
  4. 18 1
      http2/server_test.go
  5. 48 4
      http2/transport.go
  6. 2 0
      http2/transport_test.go
  7. 9 1
      http2/write.go

+ 22 - 6
http2/http2.go

@@ -4,13 +4,15 @@
 
 
 // Package http2 implements the HTTP/2 protocol.
 // Package http2 implements the HTTP/2 protocol.
 //
 //
-// This is a work in progress. This package is low-level and intended
-// to be used directly by very few people. Most users will use it
-// indirectly through integration with the net/http package. See
-// ConfigureServer. That ConfigureServer call will likely be automatic
-// or available via an empty import in the future.
+// This package is low-level and intended to be used directly by very
+// few people. Most users will use it indirectly through the automatic
+// use by the net/http package (from Go 1.6 and later).
+// For use in earlier Go versions see ConfigureServer. (Transport support
+// requires Go 1.6 or later)
 //
 //
-// See http://http2.github.io/
+// See https://http2.github.io/ for more information on HTTP/2.
+//
+// See https://http2.golang.org/ for a test server running this code.
 package http2
 package http2
 
 
 import (
 import (
@@ -251,3 +253,17 @@ func mustUint31(v int32) uint32 {
 	}
 	}
 	return uint32(v)
 	return uint32(v)
 }
 }
+
+// bodyAllowedForStatus reports whether a given response status code
+// permits a body. See RFC2616, section 4.4.
+func bodyAllowedForStatus(status int) bool {
+	switch {
+	case status >= 100 && status <= 199:
+		return false
+	case status == 204:
+		return false
+	case status == 304:
+		return false
+	}
+	return true
+}

+ 6 - 0
http2/http2_test.go

@@ -166,3 +166,9 @@ func kill(container string) {
 	exec.Command("docker", "kill", container).Run()
 	exec.Command("docker", "kill", container).Run()
 	exec.Command("docker", "rm", container).Run()
 	exec.Command("docker", "rm", container).Run()
 }
 }
+
+func cleanDate(res *http.Response) {
+	if d := res.Header["Date"]; len(d) == 1 {
+		d[0] = "XXX"
+	}
+}

+ 40 - 5
http2/server.go

@@ -1457,6 +1457,11 @@ func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, err
 		// pseudo-header fields"
 		// pseudo-header fields"
 		return nil, nil, StreamError{rp.stream.id, ErrCodeProtocol}
 		return nil, nil, StreamError{rp.stream.id, ErrCodeProtocol}
 	}
 	}
+	bodyOpen := rp.stream.state == stateOpen
+	if rp.method == "HEAD" && bodyOpen {
+		// HEAD requests can't have bodies
+		return nil, nil, StreamError{rp.stream.id, ErrCodeProtocol}
+	}
 	var tlsState *tls.ConnectionState // nil if not scheme https
 	var tlsState *tls.ConnectionState // nil if not scheme https
 	if rp.scheme == "https" {
 	if rp.scheme == "https" {
 		tlsState = sc.tlsState
 		tlsState = sc.tlsState
@@ -1473,7 +1478,6 @@ func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, err
 	if cookies := rp.header["Cookie"]; len(cookies) > 1 {
 	if cookies := rp.header["Cookie"]; len(cookies) > 1 {
 		rp.header.Set("Cookie", strings.Join(cookies, "; "))
 		rp.header.Set("Cookie", strings.Join(cookies, "; "))
 	}
 	}
-	bodyOpen := rp.stream.state == stateOpen
 	body := &requestBody{
 	body := &requestBody{
 		conn:          sc,
 		conn:          sc,
 		stream:        rp.stream,
 		stream:        rp.stream,
@@ -1720,6 +1724,9 @@ type responseWriterState struct {
 	sentHeader    bool        // have we sent the header frame?
 	sentHeader    bool        // have we sent the header frame?
 	handlerDone   bool        // handler has finished
 	handlerDone   bool        // handler has finished
 
 
+	sentContentLen int64 // non-zero if handler set a Content-Length header
+	wroteBytes     int64
+
 	closeNotifierMu sync.Mutex // guards closeNotifierCh
 	closeNotifierMu sync.Mutex // guards closeNotifierCh
 	closeNotifierCh chan bool  // nil until first used
 	closeNotifierCh chan bool  // nil until first used
 }
 }
@@ -1738,16 +1745,31 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
 	if !rws.wroteHeader {
 	if !rws.wroteHeader {
 		rws.writeHeader(200)
 		rws.writeHeader(200)
 	}
 	}
+	isHeadResp := rws.req.Method == "HEAD"
 	if !rws.sentHeader {
 	if !rws.sentHeader {
 		rws.sentHeader = true
 		rws.sentHeader = true
-		var ctype, clen string // implicit ones, if we can calculate it
-		if rws.handlerDone && rws.snapHeader.Get("Content-Length") == "" {
+		var ctype, clen string
+		if clen = rws.snapHeader.Get("Content-Length"); clen != "" {
+			rws.snapHeader.Del("Content-Length")
+			clen64, err := strconv.ParseInt(clen, 10, 64)
+			if err == nil && clen64 >= 0 {
+				rws.sentContentLen = clen64
+			} else {
+				clen = ""
+			}
+		}
+		if clen == "" && rws.handlerDone && bodyAllowedForStatus(rws.status) {
 			clen = strconv.Itoa(len(p))
 			clen = strconv.Itoa(len(p))
 		}
 		}
-		if rws.snapHeader.Get("Content-Type") == "" {
+		if rws.snapHeader.Get("Content-Type") == "" && bodyAllowedForStatus(rws.status) {
 			ctype = http.DetectContentType(p)
 			ctype = http.DetectContentType(p)
 		}
 		}
-		endStream := rws.handlerDone && len(p) == 0
+		var date string
+		if _, ok := rws.snapHeader["Date"]; !ok {
+			// TODO(bradfitz): be faster here, like net/http? measure.
+			date = time.Now().UTC().Format(http.TimeFormat)
+		}
+		endStream := (rws.handlerDone && len(p) == 0) || isHeadResp
 		err = rws.conn.writeHeaders(rws.stream, &writeResHeaders{
 		err = rws.conn.writeHeaders(rws.stream, &writeResHeaders{
 			streamID:      rws.stream.id,
 			streamID:      rws.stream.id,
 			httpResCode:   rws.status,
 			httpResCode:   rws.status,
@@ -1755,6 +1777,7 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
 			endStream:     endStream,
 			endStream:     endStream,
 			contentType:   ctype,
 			contentType:   ctype,
 			contentLength: clen,
 			contentLength: clen,
+			date:          date,
 		})
 		})
 		if err != nil {
 		if err != nil {
 			return 0, err
 			return 0, err
@@ -1763,6 +1786,9 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
 			return 0, nil
 			return 0, nil
 		}
 		}
 	}
 	}
+	if isHeadResp {
+		return len(p), nil
+	}
 	if len(p) == 0 && !rws.handlerDone {
 	if len(p) == 0 && !rws.handlerDone {
 		return 0, nil
 		return 0, nil
 	}
 	}
@@ -1875,6 +1901,15 @@ func (w *responseWriter) write(lenData int, dataB []byte, dataS string) (n int,
 	if !rws.wroteHeader {
 	if !rws.wroteHeader {
 		w.WriteHeader(200)
 		w.WriteHeader(200)
 	}
 	}
+	if !bodyAllowedForStatus(rws.status) {
+		return 0, http.ErrBodyNotAllowed
+	}
+	rws.wroteBytes += int64(len(dataB)) + int64(len(dataS)) // only one can be set
+	if rws.sentContentLen != 0 && rws.wroteBytes > rws.sentContentLen {
+		// TODO: send a RST_STREAM
+		return 0, errors.New("http2: handler wrote more than declared Content-Length")
+	}
+
 	if dataB != nil {
 	if dataB != nil {
 		return rws.bw.Write(dataB)
 		return rws.bw.Write(dataB)
 	} else {
 	} else {

+ 18 - 1
http2/server_test.go

@@ -134,7 +134,6 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{}
 		st.cc = cc
 		st.cc = cc
 		st.fr = NewFramer(cc, cc)
 		st.fr = NewFramer(cc, cc)
 	}
 	}
-
 	return st
 	return st
 }
 }
 
 
@@ -2129,6 +2128,9 @@ func TestServer_Advertises_Common_Cipher(t *testing.T) {
 // creating a new decoder each time.
 // creating a new decoder each time.
 func decodeHeader(t *testing.T, headerBlock []byte) (pairs [][2]string) {
 func decodeHeader(t *testing.T, headerBlock []byte) (pairs [][2]string) {
 	d := hpack.NewDecoder(initialHeaderTableSize, func(f hpack.HeaderField) {
 	d := hpack.NewDecoder(initialHeaderTableSize, func(f hpack.HeaderField) {
+		if f.Name == "date" {
+			return
+		}
 		pairs = append(pairs, [2]string{f.Name, f.Value})
 		pairs = append(pairs, [2]string{f.Name, f.Value})
 	})
 	})
 	if _, err := d.Write(headerBlock); err != nil {
 	if _, err := d.Write(headerBlock); err != nil {
@@ -2620,3 +2622,18 @@ func TestConfigureServer(t *testing.T) {
 		}
 		}
 	}
 	}
 }
 }
+
+func TestServerRejectHeadWithBody(t *testing.T) {
+	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+		// No response body.
+	})
+	defer st.Close()
+	st.greet()
+	st.writeHeaders(HeadersFrameParam{
+		StreamID:      1, // clients send odd numbers
+		BlockFragment: st.encodeHeader(":method", "HEAD"),
+		EndStream:     false, // what we're testing, a bogus HEAD request with body
+		EndHeaders:    true,
+	})
+	st.wantRSTStream(1, ErrCodeProtocol)
+}

+ 48 - 4
http2/transport.go

@@ -125,12 +125,15 @@ type ClientConn struct {
 // is created for each Transport.RoundTrip call.
 // is created for each Transport.RoundTrip call.
 type clientStream struct {
 type clientStream struct {
 	cc      *ClientConn
 	cc      *ClientConn
+	req     *http.Request
 	ID      uint32
 	ID      uint32
 	resc    chan resAndError
 	resc    chan resAndError
 	bufPipe pipe // buffered pipe with the flow-controlled response payload
 	bufPipe pipe // buffered pipe with the flow-controlled response payload
 
 
-	flow   flow // guarded by cc.mu
-	inflow flow // guarded by cc.mu
+	flow        flow  // guarded by cc.mu
+	inflow      flow  // guarded by cc.mu
+	bytesRemain int64 // -1 means unknown; owned by transportResponseBody.Read
+	readErr     error // sticky read error; owned by transportResponseBody.Read
 
 
 	peerReset chan struct{} // closed on peer reset
 	peerReset chan struct{} // closed on peer reset
 	resetErr  error         // populated before peerReset is closed
 	resetErr  error         // populated before peerReset is closed
@@ -435,6 +438,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
 	}
 	}
 
 
 	cs := cc.newStream()
 	cs := cc.newStream()
+	cs.req = req
 	hasBody := req.Body != nil
 	hasBody := req.Body != nil
 
 
 	// we send: HEADERS{1}, CONTINUATION{0,} + DATA{0,}
 	// we send: HEADERS{1}, CONTINUATION{0,} + DATA{0,}
@@ -826,13 +830,31 @@ func (rl *clientConnReadLoop) processHeaderBlockFragment(frag []byte, streamID u
 	}
 	}
 
 
 	res := rl.nextRes
 	res := rl.nextRes
+
+	if !streamEnded || cs.req.Method == "HEAD" {
+		res.ContentLength = -1
+		if clens := res.Header["Content-Length"]; len(clens) == 1 {
+			if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil {
+				res.ContentLength = clen64
+			} else {
+				// TODO: care? unlike http/1, it won't mess up our framing, so it's
+				// more safe smuggling-wise to ignore.
+			}
+		} else if len(clens) > 1 {
+			// TODO: care? unlike http/1, it won't mess up our framing, so it's
+			// more safe smuggling-wise to ignore.
+		}
+	}
+
 	if streamEnded {
 	if streamEnded {
 		res.Body = noBody
 		res.Body = noBody
 	} else {
 	} else {
 		buf := new(bytes.Buffer) // TODO(bradfitz): recycle this garbage
 		buf := new(bytes.Buffer) // TODO(bradfitz): recycle this garbage
 		cs.bufPipe = pipe{b: buf}
 		cs.bufPipe = pipe{b: buf}
+		cs.bytesRemain = res.ContentLength
 		res.Body = transportResponseBody{cs}
 		res.Body = transportResponseBody{cs}
 	}
 	}
+
 	rl.activeRes[cs.ID] = cs
 	rl.activeRes[cs.ID] = cs
 	cs.resc <- resAndError{res: res}
 	cs.resc <- resAndError{res: res}
 	rl.nextRes = nil // unused now; will be reset next HEADERS frame
 	rl.nextRes = nil // unused now; will be reset next HEADERS frame
@@ -847,13 +869,35 @@ type transportResponseBody struct {
 }
 }
 
 
 func (b transportResponseBody) Read(p []byte) (n int, err error) {
 func (b transportResponseBody) Read(p []byte) (n int, err error) {
+	cs := b.cs
+	cc := cs.cc
+
+	if cs.readErr != nil {
+		return 0, cs.readErr
+	}
 	n, err = b.cs.bufPipe.Read(p)
 	n, err = b.cs.bufPipe.Read(p)
+	if cs.bytesRemain != -1 {
+		if int64(n) > cs.bytesRemain {
+			n = int(cs.bytesRemain)
+			if err == nil {
+				err = errors.New("net/http: server replied with more than declared Content-Length; truncated")
+				cc.writeStreamReset(cs.ID, ErrCodeProtocol, err)
+			}
+			cs.readErr = err
+			return int(cs.bytesRemain), err
+		}
+		cs.bytesRemain -= int64(n)
+		if err == io.EOF && cs.bytesRemain > 0 {
+			err = io.ErrUnexpectedEOF
+			cs.readErr = err
+			return n, err
+		}
+	}
 	if n == 0 {
 	if n == 0 {
+		// No flow control tokens to send back.
 		return
 		return
 	}
 	}
 
 
-	cs := b.cs
-	cc := cs.cc
 	cc.mu.Lock()
 	cc.mu.Lock()
 	defer cc.mu.Unlock()
 	defer cc.mu.Unlock()
 
 

+ 2 - 0
http2/transport_test.go

@@ -73,7 +73,9 @@ func TestTransport(t *testing.T) {
 	wantHeader := http.Header{
 	wantHeader := http.Header{
 		"Content-Length": []string{"3"},
 		"Content-Length": []string{"3"},
 		"Content-Type":   []string{"text/plain; charset=utf-8"},
 		"Content-Type":   []string{"text/plain; charset=utf-8"},
+		"Date":           []string{"XXX"}, // see cleanDate
 	}
 	}
+	cleanDate(res)
 	if !reflect.DeepEqual(res.Header, wantHeader) {
 	if !reflect.DeepEqual(res.Header, wantHeader) {
 		t.Errorf("res Header = %v; want %v", res.Header, wantHeader)
 		t.Errorf("res Header = %v; want %v", res.Header, wantHeader)
 	}
 	}

+ 9 - 1
http2/write.go

@@ -23,7 +23,11 @@ type writeFramer interface {
 // frame writing scheduler (see writeScheduler in writesched.go).
 // frame writing scheduler (see writeScheduler in writesched.go).
 //
 //
 // This interface is implemented by *serverConn.
 // This interface is implemented by *serverConn.
-// TODO: use it from the client code too, once it exists.
+//
+// TODO: decide whether to a) use this in the client code (which didn't
+// end up using this yet, because it has a simpler design, not
+// currently implementing priorities), or b) delete this and
+// make the server code a bit more concrete.
 type writeContext interface {
 type writeContext interface {
 	Framer() *Framer
 	Framer() *Framer
 	Flush() error
 	Flush() error
@@ -115,6 +119,7 @@ type writeResHeaders struct {
 	h           http.Header // may be nil
 	h           http.Header // may be nil
 	endStream   bool
 	endStream   bool
 
 
+	date          string
 	contentType   string
 	contentType   string
 	contentLength string
 	contentLength string
 }
 }
@@ -139,6 +144,9 @@ func (w *writeResHeaders) writeFrame(ctx writeContext) error {
 	if w.contentLength != "" {
 	if w.contentLength != "" {
 		enc.WriteField(hpack.HeaderField{Name: "content-length", Value: w.contentLength})
 		enc.WriteField(hpack.HeaderField{Name: "content-length", Value: w.contentLength})
 	}
 	}
+	if w.date != "" {
+		enc.WriteField(hpack.HeaderField{Name: "date", Value: w.date})
+	}
 
 
 	headerBlock := buf.Bytes()
 	headerBlock := buf.Bytes()
 	if len(headerBlock) == 0 {
 	if len(headerBlock) == 0 {