Browse Source

Support ResponseWriter.Write from http.Handlers (sending DATA frames)

Brad Fitzpatrick 11 years ago
parent
commit
390047ea2a
2 changed files with 509 additions and 49 deletions
  1. 141 49
      server.go
  2. 368 0
      server_test.go

+ 141 - 49
server.go

@@ -8,6 +8,7 @@
 package http2
 
 import (
+	"bufio"
 	"bytes"
 	"crypto/tls"
 	"errors"
@@ -418,9 +419,13 @@ func (sc *serverConn) scheduleFrameWrite() {
 		// TODO: flush Framer's underlying buffered writer, once that's added
 		return
 	}
+
 	// TODO: proper scheduler
 	wm := sc.writeQueue[0]
-	copy(sc.writeQueue, sc.writeQueue[1:]) // shift it all down. kinda lame. will be removed later anyway.
+	// shift it all down. kinda lame. will be removed later anyway.
+	copy(sc.writeQueue, sc.writeQueue[1:])
+	sc.writeQueue = sc.writeQueue[:len(sc.writeQueue)-1]
+
 	sc.writingFrame = true
 	sc.writeFrameCh <- wm
 }
@@ -740,24 +745,27 @@ func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, err
 	}
 
 	rws := responseWriterStatePool.Get().(*responseWriterState)
-	wbufSave := rws.wbuf
+	bwSave := rws.bw
 	*rws = responseWriterState{} // zero all the fields
-	rws.wbuf = wbufSave
-	rws.wbuf.Reset()
+	rws.bw = bwSave
+	rws.bw.Reset(chunkWriter{rws})
 	rws.sc = sc
 	rws.streamID = rp.stream.id
 	rws.req = req
 	rws.body = body
+	rws.chunkWrittenCh = make(chan error, 1)
 
 	rw := &responseWriter{rws: rws}
 	return rw, req, nil
 }
 
+const handlerChunkWriteSize = 4 << 10
+
 var responseWriterStatePool = sync.Pool{
 	New: func() interface{} {
-		return &responseWriterState{
-			wbuf: new(bytes.Buffer),
-		}
+		rws := &responseWriterState{}
+		rws.bw = bufio.NewWriterSize(chunkWriter{rws}, handlerChunkWriteSize)
+		return rws
 	},
 }
 
@@ -768,13 +776,6 @@ func (sc *serverConn) runHandler(rw *responseWriter, req *http.Request) {
 	sc.handler.ServeHTTP(rw, req)
 }
 
-// called from handler goroutines
-func (sc *serverConn) writeData(streamID uint32, p []byte) (n int, err error) {
-	// TODO: implement
-	log.Printf("WRITE on %d: %q", streamID, p)
-	return len(p), nil
-}
-
 type frameWriteMsg struct {
 	// write runs on the writeFrames goroutine.
 	write func(sc *serverConn, v interface{}) error
@@ -786,7 +787,7 @@ type frameWriteMsg struct {
 	// done, if non-nil, must be a buffered channel with space for
 	// 1 message and is sent the return value from write (or an
 	// earlier error) when the frame has been written.
-	done chan<- error
+	done chan error
 }
 
 // headerWriteReq is a request to write an HTTP response header from a server Handler.
@@ -803,10 +804,22 @@ type headerWriteReq struct {
 // called from handler goroutines.
 // h may be nil.
 func (sc *serverConn) writeHeader(req headerWriteReq) {
+	var errc chan error
+	if req.h != nil {
+		// If there's a header map (which we don't own), so we have to block on
+		// waiting for this frame to be written, so an http.Flush mid-handler
+		// writes out the correct value of keys, before a handler later potentially
+		// mutates it.
+		errc = make(chan error, 1)
+	}
 	sc.wantWriteFrameCh <- frameWriteMsg{
 		write:    (*serverConn).writeHeaderInLoop,
 		v:        req,
 		streamID: req.streamID,
+		done:     errc,
+	}
+	if errc != nil {
+		<-errc
 	}
 }
 
@@ -818,6 +831,10 @@ func (sc *serverConn) writeHeaderInLoop(v interface{}) error {
 	sc.hpackEncoder.WriteField(hpack.HeaderField{Name: ":status", Value: httpCodeString(req.httpResCode)})
 	for k, vv := range req.h {
 		for _, v := range vv {
+			// TODO: more of "8.1.2.2 Connection-Specific Header Fields"
+			if k == "Transfer-Encoding" && v != "trailers" {
+				continue
+			}
 			// TODO: for gargage, cache lowercase copies of headers at
 			// least for common ones and/or popular recent ones for
 			// this serverConn. LRU?
@@ -844,6 +861,12 @@ func (sc *serverConn) writeHeaderInLoop(v interface{}) error {
 	})
 }
 
+func (sc *serverConn) writeDataInLoop(v interface{}) error {
+	sc.writeG.check()
+	rws := v.(*responseWriterState)
+	return sc.framer.WriteData(rws.streamID, rws.curChunkIsFinal, rws.curChunk)
+}
+
 type windowUpdateReq struct {
 	streamID uint32
 	n        uint32
@@ -920,6 +943,13 @@ type responseWriter struct {
 	rws *responseWriterState
 }
 
+// Optional http.ResponseWriter interfaces implemented.
+var (
+	_ http.Flusher = (*responseWriter)(nil)
+	_ stringWriter = (*responseWriter)(nil)
+	// TODO: hijacker for websockets?
+)
+
 type responseWriterState struct {
 	// immutable within a request:
 	sc       *serverConn
@@ -928,52 +958,101 @@ type responseWriterState struct {
 	body     *requestBody // to close at end of request, if DATA frames didn't
 
 	// TODO: adjust buffer writing sizes based on server config, frame size updates from peer, etc
-	wbuf *bytes.Buffer
+	bw *bufio.Writer // writing to a chunkWriter{this *responseWriterState}
 
 	// mutated by http.Handler goroutine:
-	h             http.Header // h goes from maybe-nil to non-nil; contents changed by http.Handler goroutine
+	handlerHeader http.Header // nil until called
+	snapHeader    http.Header // snapshot of handlerHeader at WriteHeader time
 	wroteHeader   bool        // WriteHeader called (explicitly or implicitly). Not necessarily sent to user yet.
 	status        int         // status code passed to WriteHeader
 	wroteContinue bool        // 100 Continue response was written
-	calledHeader  bool
-	sentHeader    bool // have we sent the header frame?
-	handlerDone   bool // handler has finished.
+	sentHeader    bool        // have we sent the header frame?
+	handlerDone   bool        // handler has finished
+
+	curChunk        []byte // current chunk we're writing
+	curChunkIsFinal bool
+	chunkWrittenCh  chan error
 }
 
-// Optional http.ResponseWriter interfaces implemented.
-var (
-	_ http.Flusher = (*responseWriter)(nil)
-	_ stringWriter = (*responseWriter)(nil)
-	// TODO: hijacker for websockets
-)
+type chunkWriter struct{ rws *responseWriterState }
 
-func (w *responseWriter) Flush() {
-	rws := w.rws
-	if rws == nil {
-		panic("Header called after Handler finished")
+// chunkWriter.Write is called from bufio.Writer. Because bufio.Writer passes through large
+// writes, we break them up here if they're too big.
+func (cw chunkWriter) Write(p []byte) (n int, err error) {
+	for len(p) > 0 {
+		chunk := p
+		if len(chunk) > handlerChunkWriteSize {
+			chunk = chunk[:handlerChunkWriteSize]
+		}
+		_, err = cw.rws.writeChunk(chunk)
+		if err != nil {
+			return
+		}
+		n += len(chunk)
+		p = p[len(chunk):]
 	}
+	return n, nil
+}
+
+// writeChunk writes small (max 4k, or handlerChunkWriteSize) chunks.
+// It's also responsible for sending the HEADER response.
+func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
 	if !rws.wroteHeader {
-		w.WriteHeader(200)
+		rws.writeHeader(200)
 	}
 	if !rws.sentHeader {
 		rws.sentHeader = true
 		var ctype, clen string // implicit ones, if we can calculate it
-		if rws.handlerDone && rws.h.Get("Content-Length") == "" {
-			clen = strconv.Itoa(rws.wbuf.Len())
+		if rws.handlerDone && rws.snapHeader.Get("Content-Length") == "" {
+			clen = strconv.Itoa(len(p))
 		}
-		if rws.h.Get("Content-Type") == "" {
-			ctype = http.DetectContentType(rws.wbuf.Bytes())
+		if rws.snapHeader.Get("Content-Type") == "" {
+			ctype = http.DetectContentType(p)
 		}
 		rws.sc.writeHeader(headerWriteReq{
 			streamID:      rws.streamID,
 			httpResCode:   rws.status,
-			h:             rws.h,
-			endStream:     rws.wbuf.Len() == 0,
+			h:             rws.snapHeader,
+			endStream:     rws.handlerDone && len(p) == 0,
 			contentType:   ctype,
 			contentLength: clen,
 		})
 	}
+	if len(p) == 0 && !rws.handlerDone {
+		return
+	}
+	rws.curChunk = p
+	rws.curChunkIsFinal = rws.handlerDone
 
+	// TODO: await flow control tokens for both stream and conn
+	rws.sc.wantWriteFrameCh <- frameWriteMsg{
+		cost:     uint32(len(p)),
+		streamID: rws.streamID,
+		write:    (*serverConn).writeDataInLoop,
+		done:     rws.chunkWrittenCh,
+		v:        rws, // writeDataInLoop uses only rws.curChunk and rws.curChunkIsFinal
+	}
+	err = <-rws.chunkWrittenCh // block until it's written
+	return len(p), err
+}
+
+func (w *responseWriter) Flush() {
+	rws := w.rws
+	if rws == nil {
+		panic("Header called after Handler finished")
+	}
+	if rws.bw.Buffered() > 0 {
+		if err := rws.bw.Flush(); err != nil {
+			// Ignore the error. The frame writer already knows.
+			return
+		}
+	} else {
+		// The bufio.Writer won't call chunkWriter.Write
+		// (writeChunk with zero bytes, so we have to do it
+		// ourselves to force the HTTP response header and/or
+		// final DATA frame (with END_STREAM) to be sent.
+		rws.writeChunk(nil)
+	}
 }
 
 func (w *responseWriter) Header() http.Header {
@@ -981,11 +1060,10 @@ func (w *responseWriter) Header() http.Header {
 	if rws == nil {
 		panic("Header called after Handler finished")
 	}
-	if rws.h == nil {
-		rws.h = make(http.Header)
+	if rws.handlerHeader == nil {
+		rws.handlerHeader = make(http.Header)
 	}
-	rws.calledHeader = true
-	return rws.h
+	return rws.handlerHeader
 }
 
 func (w *responseWriter) WriteHeader(code int) {
@@ -993,11 +1071,27 @@ func (w *responseWriter) WriteHeader(code int) {
 	if rws == nil {
 		panic("WriteHeader called after Handler finished")
 	}
-	if rws.wroteHeader {
-		return
+	rws.writeHeader(code)
+}
+
+func (rws *responseWriterState) writeHeader(code int) {
+	if !rws.wroteHeader {
+		rws.wroteHeader = true
+		rws.status = code
+		if len(rws.handlerHeader) > 0 {
+			rws.snapHeader = cloneHeader(rws.handlerHeader)
+		}
+	}
+}
+
+func cloneHeader(h http.Header) http.Header {
+	h2 := make(http.Header, len(h))
+	for k, vv := range h {
+		vv2 := make([]string, len(vv))
+		copy(vv2, vv)
+		h2[k] = vv2
 	}
-	rws.wroteHeader = true
-	rws.status = code
+	return h2
 }
 
 // The Life Of A Write is like this:
@@ -1020,13 +1114,11 @@ func (w *responseWriter) write(lenData int, dataB []byte, dataS string) (n int,
 	if !rws.wroteHeader {
 		w.WriteHeader(200)
 	}
-	// TODO: write to a bufio.Writer instead like the
 	if dataB != nil {
-		rws.wbuf.Write(dataB)
+		return rws.bw.Write(dataB)
 	} else {
-		rws.wbuf.WriteString(dataS)
+		return rws.bw.WriteString(dataS)
 	}
-	return lenData, nil
 }
 
 func (w *responseWriter) handlerDone() {

+ 368 - 0
server_test.go

@@ -11,6 +11,7 @@ import (
 	"bytes"
 	"crypto/tls"
 	"errors"
+	"fmt"
 	"io"
 	"io/ioutil"
 	"log"
@@ -147,6 +148,30 @@ func (st *serverTester) readFrame() (Frame, error) {
 	}
 }
 
+func (st *serverTester) wantHeaders() *HeadersFrame {
+	f, err := st.readFrame()
+	if err != nil {
+		st.t.Fatalf("Error while expecting a HEADERS frame: %v", err)
+	}
+	hf, ok := f.(*HeadersFrame)
+	if !ok {
+		st.t.Fatalf("got a %T; want *HeadersFrame", f)
+	}
+	return hf
+}
+
+func (st *serverTester) wantData() *DataFrame {
+	f, err := st.readFrame()
+	if err != nil {
+		st.t.Fatalf("Error while expecting a DATA frame: %v", err)
+	}
+	df, ok := f.(*DataFrame)
+	if !ok {
+		st.t.Fatalf("got a %T; want *DataFrame", f)
+	}
+	return df
+}
+
 func (st *serverTester) wantSettings() *SettingsFrame {
 	f, err := st.readFrame()
 	if err != nil {
@@ -709,6 +734,349 @@ func testServerRequest(t *testing.T, writeReq func(*serverTester), checkReq func
 	}
 }
 
+func getSlash(st *serverTester) { st.bodylessReq1() }
+
+func TestServer_Response_NoData(t *testing.T) {
+	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
+		// Nothing.
+		return nil
+	}, func(st *serverTester) {
+		getSlash(st)
+		hf := st.wantHeaders()
+		if !hf.StreamEnded() {
+			t.Fatal("want END_STREAM flag")
+		}
+		if !hf.HeadersEnded() {
+			t.Fatal("want END_HEADERS flag")
+		}
+	})
+}
+
+func TestServer_Response_NoData_Header_FooBar(t *testing.T) {
+	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
+		w.Header().Set("Foo-Bar", "some-value")
+		return nil
+	}, func(st *serverTester) {
+		getSlash(st)
+		hf := st.wantHeaders()
+		if !hf.StreamEnded() {
+			t.Fatal("want END_STREAM flag")
+		}
+		if !hf.HeadersEnded() {
+			t.Fatal("want END_HEADERS flag")
+		}
+		goth := decodeHeader(t, hf.HeaderBlockFragment())
+		wanth := [][2]string{
+			{":status", "200"},
+			{"foo-bar", "some-value"},
+			{"content-type", "text/plain; charset=utf-8"},
+			{"content-length", "0"},
+		}
+		if !reflect.DeepEqual(goth, wanth) {
+			t.Errorf("Got headers %v; want %v", goth, wanth)
+		}
+	})
+}
+
+func TestServer_Response_Data_Sniff_DoesntOverride(t *testing.T) {
+	const msg = "<html>this is HTML."
+	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
+		w.Header().Set("Content-Type", "foo/bar")
+		io.WriteString(w, msg)
+		return nil
+	}, func(st *serverTester) {
+		getSlash(st)
+		hf := st.wantHeaders()
+		if hf.StreamEnded() {
+			t.Fatal("don't want END_STREAM, expecting data")
+		}
+		if !hf.HeadersEnded() {
+			t.Fatal("want END_HEADERS flag")
+		}
+		goth := decodeHeader(t, hf.HeaderBlockFragment())
+		wanth := [][2]string{
+			{":status", "200"},
+			{"content-type", "foo/bar"},
+			{"content-length", strconv.Itoa(len(msg))},
+		}
+		if !reflect.DeepEqual(goth, wanth) {
+			t.Errorf("Got headers %v; want %v", goth, wanth)
+		}
+		df := st.wantData()
+		if !df.StreamEnded() {
+			t.Error("expected DATA to have END_STREAM flag")
+		}
+		if got := string(df.Data()); got != msg {
+			t.Errorf("got DATA %q; want %q", got, msg)
+		}
+	})
+}
+
+func TestServer_Response_TransferEncoding_chunked(t *testing.T) {
+	const msg = "hi"
+	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
+		w.Header().Set("Transfer-Encoding", "chunked") // should be stripped
+		io.WriteString(w, msg)
+		return nil
+	}, func(st *serverTester) {
+		getSlash(st)
+		hf := st.wantHeaders()
+		goth := decodeHeader(t, hf.HeaderBlockFragment())
+		wanth := [][2]string{
+			{":status", "200"},
+			{"content-type", "text/plain; charset=utf-8"},
+			{"content-length", strconv.Itoa(len(msg))},
+		}
+		if !reflect.DeepEqual(goth, wanth) {
+			t.Errorf("Got headers %v; want %v", goth, wanth)
+		}
+	})
+}
+
+// Header accessed only after the initial write.
+func TestServer_Response_Data_IgnoreHeaderAfterWrite_After(t *testing.T) {
+	const msg = "<html>this is HTML."
+	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
+		io.WriteString(w, msg)
+		w.Header().Set("foo", "should be ignored")
+		return nil
+	}, func(st *serverTester) {
+		getSlash(st)
+		hf := st.wantHeaders()
+		if hf.StreamEnded() {
+			t.Fatal("unexpected END_STREAM")
+		}
+		if !hf.HeadersEnded() {
+			t.Fatal("want END_HEADERS flag")
+		}
+		goth := decodeHeader(t, hf.HeaderBlockFragment())
+		wanth := [][2]string{
+			{":status", "200"},
+			{"content-type", "text/html; charset=utf-8"},
+			{"content-length", strconv.Itoa(len(msg))},
+		}
+		if !reflect.DeepEqual(goth, wanth) {
+			t.Errorf("Got headers %v; want %v", goth, wanth)
+		}
+	})
+}
+
+// Header accessed before the initial write and later mutated.
+func TestServer_Response_Data_IgnoreHeaderAfterWrite_Overwrite(t *testing.T) {
+	const msg = "<html>this is HTML."
+	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
+		w.Header().Set("foo", "proper value")
+		io.WriteString(w, msg)
+		w.Header().Set("foo", "should be ignored")
+		return nil
+	}, func(st *serverTester) {
+		getSlash(st)
+		hf := st.wantHeaders()
+		if hf.StreamEnded() {
+			t.Fatal("unexpected END_STREAM")
+		}
+		if !hf.HeadersEnded() {
+			t.Fatal("want END_HEADERS flag")
+		}
+		goth := decodeHeader(t, hf.HeaderBlockFragment())
+		wanth := [][2]string{
+			{":status", "200"},
+			{"foo", "proper value"},
+			{"content-type", "text/html; charset=utf-8"},
+			{"content-length", strconv.Itoa(len(msg))},
+		}
+		if !reflect.DeepEqual(goth, wanth) {
+			t.Errorf("Got headers %v; want %v", goth, wanth)
+		}
+	})
+}
+
+func TestServer_Response_Data_SniffLenType(t *testing.T) {
+	const msg = "<html>this is HTML."
+	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
+		io.WriteString(w, msg)
+		return nil
+	}, func(st *serverTester) {
+		getSlash(st)
+		hf := st.wantHeaders()
+		if hf.StreamEnded() {
+			t.Fatal("don't want END_STREAM, expecting data")
+		}
+		if !hf.HeadersEnded() {
+			t.Fatal("want END_HEADERS flag")
+		}
+		goth := decodeHeader(t, hf.HeaderBlockFragment())
+		wanth := [][2]string{
+			{":status", "200"},
+			{"content-type", "text/html; charset=utf-8"},
+			{"content-length", strconv.Itoa(len(msg))},
+		}
+		if !reflect.DeepEqual(goth, wanth) {
+			t.Errorf("Got headers %v; want %v", goth, wanth)
+		}
+		df := st.wantData()
+		if !df.StreamEnded() {
+			t.Error("expected DATA to have END_STREAM flag")
+		}
+		if got := string(df.Data()); got != msg {
+			t.Errorf("got DATA %q; want %q", got, msg)
+		}
+	})
+}
+
+func TestServer_Response_Header_Flush_MidWrite(t *testing.T) {
+	const msg = "<html>this is HTML"
+	const msg2 = ", and this is the next chunk"
+	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
+		io.WriteString(w, msg)
+		w.(http.Flusher).Flush()
+		io.WriteString(w, msg2)
+		return nil
+	}, func(st *serverTester) {
+		getSlash(st)
+		hf := st.wantHeaders()
+		if hf.StreamEnded() {
+			t.Fatal("unexpected END_STREAM flag")
+		}
+		if !hf.HeadersEnded() {
+			t.Fatal("want END_HEADERS flag")
+		}
+		goth := decodeHeader(t, hf.HeaderBlockFragment())
+		wanth := [][2]string{
+			{":status", "200"},
+			{"content-type", "text/html; charset=utf-8"}, // sniffed
+			// and no content-length
+		}
+		if !reflect.DeepEqual(goth, wanth) {
+			t.Errorf("Got headers %v; want %v", goth, wanth)
+		}
+		{
+			df := st.wantData()
+			if df.StreamEnded() {
+				t.Error("unexpected END_STREAM flag")
+			}
+			if got := string(df.Data()); got != msg {
+				t.Errorf("got DATA %q; want %q", got, msg)
+			}
+		}
+		{
+			df := st.wantData()
+			if !df.StreamEnded() {
+				t.Error("wanted END_STREAM flag on last data chunk")
+			}
+			if got := string(df.Data()); got != msg2 {
+				t.Errorf("got DATA %q; want %q", got, msg2)
+			}
+		}
+	})
+}
+
+func TestServer_Response_LargeWrite(t *testing.T) {
+	const size = 1 << 20
+	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
+		n, err := w.Write(bytes.Repeat([]byte("a"), size))
+		if err != nil {
+			return fmt.Errorf("Write error: %v", err)
+		}
+		if n != size {
+			return fmt.Errorf("wrong size %d from Write", n)
+		}
+		return nil
+	}, func(st *serverTester) {
+		getSlash(st) // make the single request
+		hf := st.wantHeaders()
+		if hf.StreamEnded() {
+			t.Fatal("unexpected END_STREAM flag")
+		}
+		if !hf.HeadersEnded() {
+			t.Fatal("want END_HEADERS flag")
+		}
+		goth := decodeHeader(t, hf.HeaderBlockFragment())
+		wanth := [][2]string{
+			{":status", "200"},
+			{"content-type", "text/plain; charset=utf-8"}, // sniffed
+			// and no content-length
+		}
+		if !reflect.DeepEqual(goth, wanth) {
+			t.Errorf("Got headers %v; want %v", goth, wanth)
+		}
+		var bytes, frames int
+		for {
+			df := st.wantData()
+			bytes += len(df.Data())
+			frames++
+			// TODO: send WINDOW_UPDATE frames at the server to keep it from stalling
+			for _, b := range df.Data() {
+				if b != 'a' {
+					t.Fatal("non-'a' byte seen in DATA")
+				}
+			}
+			if df.StreamEnded() {
+				break
+			}
+		}
+		if bytes != size {
+			t.Errorf("Got %d bytes; want %d", bytes, size)
+		}
+		if want := 257; frames != want {
+			t.Errorf("Got %d frames; want %d", frames, size)
+		}
+	})
+}
+
+func decodeHeader(t *testing.T, headerBlock []byte) (pairs [][2]string) {
+	d := hpack.NewDecoder(initialHeaderTableSize, func(f hpack.HeaderField) {
+		pairs = append(pairs, [2]string{f.Name, f.Value})
+	})
+	if _, err := d.Write(headerBlock); err != nil {
+		t.Fatalf("hpack decoding error: %v", err)
+	}
+	if err := d.Close(); err != nil {
+		t.Fatalf("hpack decoding error: %v", err)
+	}
+	return
+}
+
+// testServerResponse sets up an idle HTTP/2 connection and lets you
+// write a single request with writeReq, and then reply to it in some way with the provided handler,
+// and then verify the output with the serverTester again (assuming the handler returns nil)
+func testServerResponse(t *testing.T,
+	handler func(http.ResponseWriter, *http.Request) error,
+	client func(*serverTester),
+) {
+	errc := make(chan error, 1)
+	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+		if r.Body == nil {
+			t.Fatal("nil Body")
+		}
+		errc <- handler(w, r)
+	})
+	defer st.Close()
+
+	donec := make(chan bool)
+	go func() {
+		defer close(donec)
+		st.greet()
+		client(st)
+	}()
+
+	select {
+	case <-donec:
+		return
+	case <-time.After(5 * time.Second):
+		t.Fatal("timeout")
+	}
+
+	select {
+	case err := <-errc:
+		if err != nil {
+			t.Fatalf("Error in handler: %v", err)
+		}
+	case <-time.After(2 * time.Second):
+		t.Error("timeout waiting for handler to finish")
+	}
+}
+
 func TestServerWithCurl(t *testing.T) {
 	requireCurl(t)