Browse Source

Automatic 100-continue support

Brad Fitzpatrick 11 years ago
parent
commit
9d63ade81d
2 changed files with 110 additions and 9 deletions
  1. 40 9
      server.go
  2. 70 0
      server_test.go

+ 40 - 9
server.go

@@ -31,8 +31,6 @@ const (
 	firstSettingsTimeout = 2 * time.Second // should be in-flight with preface anyway
 )
 
-// TODO: automatic 100-continue
-
 // TODO: finish GOAWAY support. Consider each incoming frame type and
 // whether it should be ignored during a shutdown race.
 
@@ -817,10 +815,15 @@ func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, err
 	if authority == "" {
 		authority = rp.header.Get("Host")
 	}
+	needsContinue := rp.header.Get("Expect") == "100-continue"
+	if needsContinue {
+		rp.header.Del("Expect")
+	}
 	bodyOpen := rp.stream.state == stateOpen
 	body := &requestBody{
-		sc:       sc,
-		streamID: rp.stream.id,
+		sc:            sc,
+		streamID:      rp.stream.id,
+		needsContinue: needsContinue,
 	}
 	url, err := url.ParseRequestURI(rp.path)
 	if err != nil {
@@ -969,6 +972,30 @@ func (sc *serverConn) writeHeadersFrame(v interface{}) error {
 	})
 }
 
+// called from handler goroutines.
+// h may be nil.
+func (sc *serverConn) write100ContinueHeaders(streamID uint32) {
+	sc.serveG.checkNotOn()
+	sc.writeFrame(frameWriteMsg{
+		write:    (*serverConn).write100ContinueHeadersFrame,
+		v:        &streamID,
+		streamID: streamID,
+	})
+}
+
+func (sc *serverConn) write100ContinueHeadersFrame(v interface{}) error {
+	sc.writeG.check()
+	streamID := *(v.(*uint32))
+	sc.headerWriteBuf.Reset()
+	sc.hpackEncoder.WriteField(hpack.HeaderField{Name: ":status", Value: "100"})
+	return sc.framer.WriteHeaders(HeadersFrameParam{
+		StreamID:      streamID,
+		BlockFragment: sc.headerWriteBuf.Bytes(),
+		EndStream:     false,
+		EndHeaders:    true,
+	})
+}
+
 func (sc *serverConn) writeDataFrame(v interface{}) error {
 	sc.writeG.check()
 	rws := v.(*responseWriterState)
@@ -1013,10 +1040,11 @@ func (sc *serverConn) sendWindowUpdateInLoop(v interface{}) error {
 }
 
 type requestBody struct {
-	sc       *serverConn
-	streamID uint32
-	closed   bool
-	pipe     *pipe // non-nil if we have a HTTP entity message body
+	sc            *serverConn
+	streamID      uint32
+	closed        bool
+	pipe          *pipe // non-nil if we have a HTTP entity message body
+	needsContinue bool  // need to send a 100-continue
 }
 
 var errClosedBody = errors.New("body closed by handler")
@@ -1030,6 +1058,10 @@ func (b *requestBody) Close() error {
 }
 
 func (b *requestBody) Read(p []byte) (n int, err error) {
+	if b.needsContinue {
+		b.needsContinue = false
+		b.sc.write100ContinueHeaders(b.streamID)
+	}
 	if b.pipe == nil {
 		return 0, io.EOF
 	}
@@ -1073,7 +1105,6 @@ type responseWriterState struct {
 	snapHeader    http.Header // snapshot of handlerHeader at WriteHeader time
 	wroteHeader   bool        // WriteHeader called (explicitly or implicitly). Not necessarily sent to user yet.
 	status        int         // status code passed to WriteHeader
-	wroteContinue bool        // 100 Continue response was written
 	sentHeader    bool        // have we sent the header frame?
 	handlerDone   bool        // handler has finished
 

+ 70 - 0
server_test.go

@@ -1045,6 +1045,76 @@ func TestServer_Response_LargeWrite(t *testing.T) {
 	})
 }
 
+func TestServer_Response_Automatic100Continue(t *testing.T) {
+	const msg = "foo"
+	const reply = "bar"
+	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
+		if v := r.Header.Get("Expect"); v != "" {
+			t.Errorf("Expect header = %q; want empty", v)
+		}
+		buf := make([]byte, len(msg))
+		// This read should trigger the 100-continue being sent.
+		if n, err := io.ReadFull(r.Body, buf); err != nil || n != len(msg) || string(buf) != msg {
+			return fmt.Errorf("ReadFull = %q, %v; want %q, nil", buf[:n], err, msg)
+		}
+		_, err := io.WriteString(w, reply)
+		return err
+	}, func(st *serverTester) {
+		st.writeHeaders(HeadersFrameParam{
+			StreamID:      1, // clients send odd numbers
+			BlockFragment: encodeHeader(st.t, ":method", "POST", "expect", "100-continue"),
+			EndStream:     false,
+			EndHeaders:    true,
+		})
+		hf := st.wantHeaders()
+		if hf.StreamEnded() {
+			t.Fatal("unexpected END_STREAM flag")
+		}
+		if !hf.HeadersEnded() {
+			t.Fatal("want END_HEADERS flag")
+		}
+		goth := decodeHeader(t, hf.HeaderBlockFragment())
+		wanth := [][2]string{
+			{":status", "100"},
+		}
+		if !reflect.DeepEqual(goth, wanth) {
+			t.Fatalf("Got headers %v; want %v", goth, wanth)
+		}
+
+		// Okay, they sent status 100, so we can send our
+		// gigantic and/or sensitive "foo" payload now.
+		st.writeData(1, true, []byte(msg))
+
+		st.wantWindowUpdate(0, uint32(len(msg)))
+		st.wantWindowUpdate(1, uint32(len(msg)))
+
+		hf = st.wantHeaders()
+		if hf.StreamEnded() {
+			t.Fatal("expected data to follow")
+		}
+		if !hf.HeadersEnded() {
+			t.Fatal("want END_HEADERS flag")
+		}
+		goth = decodeHeader(t, hf.HeaderBlockFragment())
+		wanth = [][2]string{
+			{":status", "200"},
+			{"content-type", "text/plain; charset=utf-8"},
+			{"content-length", strconv.Itoa(len(reply))},
+		}
+		if !reflect.DeepEqual(goth, wanth) {
+			t.Errorf("Got headers %v; want %v", goth, wanth)
+		}
+
+		df := st.wantData()
+		if string(df.Data()) != reply {
+			t.Errorf("Client read %q; want %q", df.Data(), reply)
+		}
+		if !df.StreamEnded() {
+			t.Errorf("expect data stream end")
+		}
+	})
+}
+
 func decodeHeader(t *testing.T, headerBlock []byte) (pairs [][2]string) {
 	d := hpack.NewDecoder(initialHeaderTableSize, func(f hpack.HeaderField) {
 		pairs = append(pairs, [2]string{f.Name, f.Value})