Browse Source

Implement request bodies.

Few TODOs yet.
Brad Fitzpatrick 11 years ago
parent
commit
ff6db8eca6
2 changed files with 202 additions and 25 deletions
  1. 84 15
      http2.go
  2. 118 10
      http2_test.go

+ 84 - 15
http2.go

@@ -23,6 +23,7 @@ import (
 	"bytes"
 	"crypto/tls"
 	"errors"
+	"fmt"
 	"io"
 	"log"
 	"net"
@@ -149,7 +150,11 @@ const (
 type stream struct {
 	id    uint32
 	state streamState // owned by serverConn's processing loop
-	flow  *flow
+	flow  *flow       // limits writing from Handler to client
+	body  *pipe       // non-nil if expecting DATA frames
+
+	bodyBytes     int64 // body bytes seen so far
+	declBodyBytes int64 // or -1 if undeclared
 }
 
 func (sc *serverConn) state(streamID uint32) streamState {
@@ -419,6 +424,8 @@ func (sc *serverConn) processFrame(f Frame) error {
 		return sc.processWindowUpdate(f)
 	case *PingFrame:
 		return sc.processPing(f)
+	case *DataFrame:
+		return sc.processData(f)
 	default:
 		log.Printf("Ignoring unknown frame %#v", f)
 		return nil
@@ -516,6 +523,48 @@ func (sc *serverConn) processSettingInitialWindowSize(val uint32) error {
 	return nil
 }
 
+func (sc *serverConn) processData(f *DataFrame) error {
+	sc.serveG.check()
+	// "If a DATA frame is received whose stream is not in "open"
+	// or "half closed (local)" state, the recipient MUST respond
+	// with a stream error (Section 5.4.2) of type STREAM_CLOSED."
+	id := f.Header().StreamID
+	st, ok := sc.streams[id]
+	if !ok || (st.state != stateOpen && st.state != stateHalfClosedLocal) {
+		return StreamError{id, ErrCodeStreamClosed}
+	}
+	if st.body == nil {
+		// Not expecting data.
+		// TODO: which error code?
+		return StreamError{id, ErrCodeStreamClosed}
+	}
+	data := f.Data()
+
+	// Sender sending more than they'd declared?
+	if st.declBodyBytes != -1 && st.bodyBytes+int64(len(data)) > st.declBodyBytes {
+		st.body.Close(fmt.Errorf("Sender tried to send more than declared Content-Length of %d bytes", st.declBodyBytes))
+		return StreamError{id, ErrCodeStreamClosed}
+	}
+	if len(data) > 0 {
+		// TODO: verify they're allowed to write with the flow control
+		// window we'd advertised to them.
+		// TODO: verify n from Write
+		if _, err := st.body.Write(data); err != nil {
+			return StreamError{id, ErrCodeStreamClosed}
+		}
+		st.bodyBytes += int64(len(data))
+	}
+	if f.Header().Flags.Has(FlagDataEndStream) {
+		if st.declBodyBytes != -1 && st.declBodyBytes != st.bodyBytes {
+			st.body.Close(fmt.Errorf("Request declared a Content-Length of %d but only wrote %d bytes",
+				st.declBodyBytes, st.bodyBytes))
+		} else {
+			st.body.Close(io.EOF)
+		}
+	}
+	return nil
+}
+
 func (sc *serverConn) processHeaders(f *HeadersFrame) error {
 	sc.serveG.check()
 	id := f.Header().StreamID
@@ -550,19 +599,19 @@ func (sc *serverConn) processHeaders(f *HeadersFrame) error {
 		stream: st,
 		header: make(http.Header),
 	}
-	return sc.processHeaderBlockFragment(id, f.HeaderBlockFragment(), f.HeadersEnded())
+	return sc.processHeaderBlockFragment(st, f.HeaderBlockFragment(), f.HeadersEnded())
 }
 
 func (sc *serverConn) processContinuation(f *ContinuationFrame) error {
 	sc.serveG.check()
-	id := f.Header().StreamID
-	if sc.curHeaderStreamID() != id {
+	st := sc.streams[f.Header().StreamID]
+	if st == nil || sc.curHeaderStreamID() != st.id {
 		return ConnectionError(ErrCodeProtocol)
 	}
-	return sc.processHeaderBlockFragment(id, f.HeaderBlockFragment(), f.HeadersEnded())
+	return sc.processHeaderBlockFragment(st, f.HeaderBlockFragment(), f.HeadersEnded())
 }
 
-func (sc *serverConn) processHeaderBlockFragment(streamID uint32, frag []byte, end bool) error {
+func (sc *serverConn) processHeaderBlockFragment(st *stream, frag []byte, end bool) error {
 	sc.serveG.check()
 	if _, err := sc.hpackDecoder.Write(frag); err != nil {
 		// TODO: convert to stream error I assume?
@@ -580,6 +629,8 @@ func (sc *serverConn) processHeaderBlockFragment(streamID uint32, frag []byte, e
 	if err != nil {
 		return err
 	}
+	st.body = req.Body.(*requestBody).pipe // may be nil
+	st.declBodyBytes = req.ContentLength
 	go sc.runHandler(rw, req)
 	return nil
 }
@@ -611,6 +662,10 @@ func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, err
 		authority = rp.header.Get("Host")
 	}
 	bodyOpen := rp.stream.state == stateOpen
+	body := &requestBody{
+		sc:       sc,
+		streamID: rp.stream.id,
+	}
 	req := &http.Request{
 		Method:     rp.method,
 		URL:        &url.URL{},
@@ -622,13 +677,14 @@ func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, err
 		ProtoMinor: 0,
 		TLS:        tlsState,
 		Host:       authority,
-		Body: &requestBody{
-			sc:       sc,
-			streamID: rp.stream.id,
-			hasBody:  bodyOpen,
-		},
+		Body:       body,
 	}
 	if bodyOpen {
+		body.pipe = &pipe{
+			b: buffer{buf: make([]byte, 65536)}, // TODO: share/remove
+		}
+		body.pipe.c.L = &body.pipe.m
+
 		if vv, ok := rp.header["Content-Length"]; ok {
 			req.ContentLength, _ = strconv.ParseInt(vv[0], 10, 64)
 		} else {
@@ -638,6 +694,8 @@ func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, err
 	rw := &responseWriter{
 		sc:       sc,
 		streamID: rp.stream.id,
+		req:      req,
+		body:     body,
 	}
 	return rw, req, nil
 }
@@ -732,21 +790,29 @@ func ConfigureServer(s *http.Server, conf *Server) {
 type requestBody struct {
 	sc       *serverConn
 	streamID uint32
-	hasBody  bool
 	closed   bool
+	pipe     *pipe // non-nil if we have a HTTP entity message body
 }
 
+var errClosedBody = errors.New("body closed by handler")
+
 func (b *requestBody) Close() error {
+	if b.pipe != nil {
+		b.pipe.Close(errClosedBody)
+	}
 	b.closed = true
 	return nil
 }
 
 func (b *requestBody) Read(p []byte) (n int, err error) {
-	if !b.hasBody {
+	if b.pipe == nil {
 		return 0, io.EOF
 	}
-	// TODO: implement
-	return 0, errors.New("TODO: we don't handle request bodies yet")
+	n, err = b.pipe.Read(p)
+	if n > 0 {
+		// TODO: tell b.sc to send back 'n' flow control quota credits to the sender
+	}
+	return
 }
 
 type responseWriter struct {
@@ -754,6 +820,9 @@ type responseWriter struct {
 	streamID     uint32
 	wroteHeaders bool
 	h            http.Header
+
+	req  *http.Request
+	body *requestBody // to close at end of request, if DATA frames didn't
 }
 
 // TODO: bufio writing of responseWriter. add Flush, add pools of

+ 118 - 10
http2_test.go

@@ -14,6 +14,7 @@ import (
 	"flag"
 	"fmt"
 	"io"
+	"io/ioutil"
 	"log"
 	"net"
 	"net/http"
@@ -308,25 +309,132 @@ func TestServer_Request_Post_NoContentLength_EndStream(t *testing.T) {
 	})
 }
 
-func TestServer_Request_Post_Body(t *testing.T) {
-	t.Skip("TODO: post bodies not yet implemented")
-	testServerRequest(t, func(st *serverTester) {
+func TestServer_Request_Post_Body_ImmediateEOF(t *testing.T) {
+	testBodyContents(t, -1, "", func(st *serverTester) {
 		st.writeHeaders(HeadersFrameParam{
 			StreamID:      1, // clients send odd numbers
 			BlockFragment: encodeHeader(t, ":method", "POST"),
-			EndStream:     false, // migth be DATA frames
+			EndStream:     false, // to say DATA frames are coming
 			EndHeaders:    true,
 		})
-		st.writeData(1, true, nil)
-	}, func(r *http.Request) {
+		st.writeData(1, true, nil) // just kidding. empty body.
+	})
+}
+
+func TestServer_Request_Post_Body_OneData(t *testing.T) {
+	const content = "Some content"
+	testBodyContents(t, -1, content, func(st *serverTester) {
+		st.writeHeaders(HeadersFrameParam{
+			StreamID:      1, // clients send odd numbers
+			BlockFragment: encodeHeader(t, ":method", "POST"),
+			EndStream:     false, // to say DATA frames are coming
+			EndHeaders:    true,
+		})
+		st.writeData(1, true, []byte(content))
+	})
+}
+
+func TestServer_Request_Post_Body_TwoData(t *testing.T) {
+	const content = "Some content"
+	testBodyContents(t, -1, content, func(st *serverTester) {
+		st.writeHeaders(HeadersFrameParam{
+			StreamID:      1, // clients send odd numbers
+			BlockFragment: encodeHeader(t, ":method", "POST"),
+			EndStream:     false, // to say DATA frames are coming
+			EndHeaders:    true,
+		})
+		st.writeData(1, false, []byte(content[:5]))
+		st.writeData(1, true, []byte(content[5:]))
+	})
+}
+
+func TestServer_Request_Post_Body_ContentLength_Correct(t *testing.T) {
+	const content = "Some content"
+	testBodyContents(t, int64(len(content)), content, func(st *serverTester) {
+		st.writeHeaders(HeadersFrameParam{
+			StreamID: 1, // clients send odd numbers
+			BlockFragment: encodeHeader(t,
+				":method", "POST",
+				"content-length", strconv.Itoa(len(content)),
+			),
+			EndStream:  false, // to say DATA frames are coming
+			EndHeaders: true,
+		})
+		st.writeData(1, true, []byte(content))
+	})
+}
+
+func TestServer_Request_Post_Body_ContentLength_TooLarge(t *testing.T) {
+	testBodyContentsFail(t, 3, "Request declared a Content-Length of 3 but only wrote 2 bytes",
+		func(st *serverTester) {
+			st.writeHeaders(HeadersFrameParam{
+				StreamID: 1, // clients send odd numbers
+				BlockFragment: encodeHeader(t,
+					":method", "POST",
+					"content-length", "3",
+				),
+				EndStream:  false, // to say DATA frames are coming
+				EndHeaders: true,
+			})
+			st.writeData(1, true, []byte("12"))
+		})
+}
+
+func TestServer_Request_Post_Body_ContentLength_TooSmall(t *testing.T) {
+	testBodyContentsFail(t, 4, "Sender tried to send more than declared Content-Length of 4 bytes",
+		func(st *serverTester) {
+			st.writeHeaders(HeadersFrameParam{
+				StreamID: 1, // clients send odd numbers
+				BlockFragment: encodeHeader(t,
+					":method", "POST",
+					"content-length", "4",
+				),
+				EndStream:  false, // to say DATA frames are coming
+				EndHeaders: true,
+			})
+			st.writeData(1, true, []byte("12345"))
+		})
+}
+
+func testBodyContents(t *testing.T, wantContentLength int64, wantBody string, write func(st *serverTester)) {
+	testServerRequest(t, write, func(r *http.Request) {
 		if r.Method != "POST" {
 			t.Errorf("Method = %q; want POST", r.Method)
 		}
-		if r.ContentLength != -1 {
-			t.Errorf("ContentLength = %v; want -1", r.ContentLength)
+		if r.ContentLength != wantContentLength {
+			t.Errorf("ContentLength = %v; want %d", r.ContentLength, wantContentLength)
 		}
-		if n, err := r.Body.Read([]byte(" ")); err != io.EOF || n != 0 {
-			t.Errorf("Read = %d, %v; want 0, EOF", n, err)
+		all, err := ioutil.ReadAll(r.Body)
+		if err != nil {
+			t.Fatal(err)
+		}
+		if string(all) != wantBody {
+			t.Errorf("Read = %q; want %q", all, wantBody)
+		}
+		if err := r.Body.Close(); err != nil {
+			t.Fatalf("Close: %v", err)
+		}
+	})
+}
+
+func testBodyContentsFail(t *testing.T, wantContentLength int64, wantReadError string, write func(st *serverTester)) {
+	testServerRequest(t, write, func(r *http.Request) {
+		if r.Method != "POST" {
+			t.Errorf("Method = %q; want POST", r.Method)
+		}
+		if r.ContentLength != wantContentLength {
+			t.Errorf("ContentLength = %v; want %d", r.ContentLength, wantContentLength)
+		}
+		all, err := ioutil.ReadAll(r.Body)
+		if err == nil {
+			t.Fatalf("expected an error (%q) reading from the body. Successfully read %q instead.",
+				wantReadError, all)
+		}
+		if !strings.Contains(err.Error(), wantReadError) {
+			t.Fatalf("Body.Read = %v; want substring %q", err, wantReadError)
+		}
+		if err := r.Body.Close(); err != nil {
+			t.Fatalf("Close: %v", err)
 		}
 	})
 }