|
|
@@ -11,6 +11,7 @@ import (
|
|
|
"bytes"
|
|
|
"crypto/tls"
|
|
|
"errors"
|
|
|
+ "fmt"
|
|
|
"io"
|
|
|
"io/ioutil"
|
|
|
"log"
|
|
|
@@ -147,6 +148,30 @@ func (st *serverTester) readFrame() (Frame, error) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+func (st *serverTester) wantHeaders() *HeadersFrame {
|
|
|
+ f, err := st.readFrame()
|
|
|
+ if err != nil {
|
|
|
+ st.t.Fatalf("Error while expecting a HEADERS frame: %v", err)
|
|
|
+ }
|
|
|
+ hf, ok := f.(*HeadersFrame)
|
|
|
+ if !ok {
|
|
|
+ st.t.Fatalf("got a %T; want *HeadersFrame", f)
|
|
|
+ }
|
|
|
+ return hf
|
|
|
+}
|
|
|
+
|
|
|
+func (st *serverTester) wantData() *DataFrame {
|
|
|
+ f, err := st.readFrame()
|
|
|
+ if err != nil {
|
|
|
+ st.t.Fatalf("Error while expecting a DATA frame: %v", err)
|
|
|
+ }
|
|
|
+ df, ok := f.(*DataFrame)
|
|
|
+ if !ok {
|
|
|
+ st.t.Fatalf("got a %T; want *DataFrame", f)
|
|
|
+ }
|
|
|
+ return df
|
|
|
+}
|
|
|
+
|
|
|
func (st *serverTester) wantSettings() *SettingsFrame {
|
|
|
f, err := st.readFrame()
|
|
|
if err != nil {
|
|
|
@@ -709,6 +734,349 @@ func testServerRequest(t *testing.T, writeReq func(*serverTester), checkReq func
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+func getSlash(st *serverTester) { st.bodylessReq1() }
|
|
|
+
|
|
|
+func TestServer_Response_NoData(t *testing.T) {
|
|
|
+ testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
|
|
|
+ // Nothing.
|
|
|
+ return nil
|
|
|
+ }, func(st *serverTester) {
|
|
|
+ getSlash(st)
|
|
|
+ hf := st.wantHeaders()
|
|
|
+ if !hf.StreamEnded() {
|
|
|
+ t.Fatal("want END_STREAM flag")
|
|
|
+ }
|
|
|
+ if !hf.HeadersEnded() {
|
|
|
+ t.Fatal("want END_HEADERS flag")
|
|
|
+ }
|
|
|
+ })
|
|
|
+}
|
|
|
+
|
|
|
+func TestServer_Response_NoData_Header_FooBar(t *testing.T) {
|
|
|
+ testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
|
|
|
+ w.Header().Set("Foo-Bar", "some-value")
|
|
|
+ return nil
|
|
|
+ }, func(st *serverTester) {
|
|
|
+ getSlash(st)
|
|
|
+ hf := st.wantHeaders()
|
|
|
+ if !hf.StreamEnded() {
|
|
|
+ t.Fatal("want END_STREAM flag")
|
|
|
+ }
|
|
|
+ if !hf.HeadersEnded() {
|
|
|
+ t.Fatal("want END_HEADERS flag")
|
|
|
+ }
|
|
|
+ goth := decodeHeader(t, hf.HeaderBlockFragment())
|
|
|
+ wanth := [][2]string{
|
|
|
+ {":status", "200"},
|
|
|
+ {"foo-bar", "some-value"},
|
|
|
+ {"content-type", "text/plain; charset=utf-8"},
|
|
|
+ {"content-length", "0"},
|
|
|
+ }
|
|
|
+ if !reflect.DeepEqual(goth, wanth) {
|
|
|
+ t.Errorf("Got headers %v; want %v", goth, wanth)
|
|
|
+ }
|
|
|
+ })
|
|
|
+}
|
|
|
+
|
|
|
+func TestServer_Response_Data_Sniff_DoesntOverride(t *testing.T) {
|
|
|
+ const msg = "<html>this is HTML."
|
|
|
+ testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
|
|
|
+ w.Header().Set("Content-Type", "foo/bar")
|
|
|
+ io.WriteString(w, msg)
|
|
|
+ return nil
|
|
|
+ }, func(st *serverTester) {
|
|
|
+ getSlash(st)
|
|
|
+ hf := st.wantHeaders()
|
|
|
+ if hf.StreamEnded() {
|
|
|
+ t.Fatal("don't want END_STREAM, expecting data")
|
|
|
+ }
|
|
|
+ if !hf.HeadersEnded() {
|
|
|
+ t.Fatal("want END_HEADERS flag")
|
|
|
+ }
|
|
|
+ goth := decodeHeader(t, hf.HeaderBlockFragment())
|
|
|
+ wanth := [][2]string{
|
|
|
+ {":status", "200"},
|
|
|
+ {"content-type", "foo/bar"},
|
|
|
+ {"content-length", strconv.Itoa(len(msg))},
|
|
|
+ }
|
|
|
+ if !reflect.DeepEqual(goth, wanth) {
|
|
|
+ t.Errorf("Got headers %v; want %v", goth, wanth)
|
|
|
+ }
|
|
|
+ df := st.wantData()
|
|
|
+ if !df.StreamEnded() {
|
|
|
+ t.Error("expected DATA to have END_STREAM flag")
|
|
|
+ }
|
|
|
+ if got := string(df.Data()); got != msg {
|
|
|
+ t.Errorf("got DATA %q; want %q", got, msg)
|
|
|
+ }
|
|
|
+ })
|
|
|
+}
|
|
|
+
|
|
|
+func TestServer_Response_TransferEncoding_chunked(t *testing.T) {
|
|
|
+ const msg = "hi"
|
|
|
+ testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
|
|
|
+ w.Header().Set("Transfer-Encoding", "chunked") // should be stripped
|
|
|
+ io.WriteString(w, msg)
|
|
|
+ return nil
|
|
|
+ }, func(st *serverTester) {
|
|
|
+ getSlash(st)
|
|
|
+ hf := st.wantHeaders()
|
|
|
+ goth := decodeHeader(t, hf.HeaderBlockFragment())
|
|
|
+ wanth := [][2]string{
|
|
|
+ {":status", "200"},
|
|
|
+ {"content-type", "text/plain; charset=utf-8"},
|
|
|
+ {"content-length", strconv.Itoa(len(msg))},
|
|
|
+ }
|
|
|
+ if !reflect.DeepEqual(goth, wanth) {
|
|
|
+ t.Errorf("Got headers %v; want %v", goth, wanth)
|
|
|
+ }
|
|
|
+ })
|
|
|
+}
|
|
|
+
|
|
|
+// Header accessed only after the initial write.
|
|
|
+func TestServer_Response_Data_IgnoreHeaderAfterWrite_After(t *testing.T) {
|
|
|
+ const msg = "<html>this is HTML."
|
|
|
+ testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
|
|
|
+ io.WriteString(w, msg)
|
|
|
+ w.Header().Set("foo", "should be ignored")
|
|
|
+ return nil
|
|
|
+ }, func(st *serverTester) {
|
|
|
+ getSlash(st)
|
|
|
+ hf := st.wantHeaders()
|
|
|
+ if hf.StreamEnded() {
|
|
|
+ t.Fatal("unexpected END_STREAM")
|
|
|
+ }
|
|
|
+ if !hf.HeadersEnded() {
|
|
|
+ t.Fatal("want END_HEADERS flag")
|
|
|
+ }
|
|
|
+ goth := decodeHeader(t, hf.HeaderBlockFragment())
|
|
|
+ wanth := [][2]string{
|
|
|
+ {":status", "200"},
|
|
|
+ {"content-type", "text/html; charset=utf-8"},
|
|
|
+ {"content-length", strconv.Itoa(len(msg))},
|
|
|
+ }
|
|
|
+ if !reflect.DeepEqual(goth, wanth) {
|
|
|
+ t.Errorf("Got headers %v; want %v", goth, wanth)
|
|
|
+ }
|
|
|
+ })
|
|
|
+}
|
|
|
+
|
|
|
+// Header accessed before the initial write and later mutated.
|
|
|
+func TestServer_Response_Data_IgnoreHeaderAfterWrite_Overwrite(t *testing.T) {
|
|
|
+ const msg = "<html>this is HTML."
|
|
|
+ testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
|
|
|
+ w.Header().Set("foo", "proper value")
|
|
|
+ io.WriteString(w, msg)
|
|
|
+ w.Header().Set("foo", "should be ignored")
|
|
|
+ return nil
|
|
|
+ }, func(st *serverTester) {
|
|
|
+ getSlash(st)
|
|
|
+ hf := st.wantHeaders()
|
|
|
+ if hf.StreamEnded() {
|
|
|
+ t.Fatal("unexpected END_STREAM")
|
|
|
+ }
|
|
|
+ if !hf.HeadersEnded() {
|
|
|
+ t.Fatal("want END_HEADERS flag")
|
|
|
+ }
|
|
|
+ goth := decodeHeader(t, hf.HeaderBlockFragment())
|
|
|
+ wanth := [][2]string{
|
|
|
+ {":status", "200"},
|
|
|
+ {"foo", "proper value"},
|
|
|
+ {"content-type", "text/html; charset=utf-8"},
|
|
|
+ {"content-length", strconv.Itoa(len(msg))},
|
|
|
+ }
|
|
|
+ if !reflect.DeepEqual(goth, wanth) {
|
|
|
+ t.Errorf("Got headers %v; want %v", goth, wanth)
|
|
|
+ }
|
|
|
+ })
|
|
|
+}
|
|
|
+
|
|
|
+func TestServer_Response_Data_SniffLenType(t *testing.T) {
|
|
|
+ const msg = "<html>this is HTML."
|
|
|
+ testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
|
|
|
+ io.WriteString(w, msg)
|
|
|
+ return nil
|
|
|
+ }, func(st *serverTester) {
|
|
|
+ getSlash(st)
|
|
|
+ hf := st.wantHeaders()
|
|
|
+ if hf.StreamEnded() {
|
|
|
+ t.Fatal("don't want END_STREAM, expecting data")
|
|
|
+ }
|
|
|
+ if !hf.HeadersEnded() {
|
|
|
+ t.Fatal("want END_HEADERS flag")
|
|
|
+ }
|
|
|
+ goth := decodeHeader(t, hf.HeaderBlockFragment())
|
|
|
+ wanth := [][2]string{
|
|
|
+ {":status", "200"},
|
|
|
+ {"content-type", "text/html; charset=utf-8"},
|
|
|
+ {"content-length", strconv.Itoa(len(msg))},
|
|
|
+ }
|
|
|
+ if !reflect.DeepEqual(goth, wanth) {
|
|
|
+ t.Errorf("Got headers %v; want %v", goth, wanth)
|
|
|
+ }
|
|
|
+ df := st.wantData()
|
|
|
+ if !df.StreamEnded() {
|
|
|
+ t.Error("expected DATA to have END_STREAM flag")
|
|
|
+ }
|
|
|
+ if got := string(df.Data()); got != msg {
|
|
|
+ t.Errorf("got DATA %q; want %q", got, msg)
|
|
|
+ }
|
|
|
+ })
|
|
|
+}
|
|
|
+
|
|
|
+func TestServer_Response_Header_Flush_MidWrite(t *testing.T) {
|
|
|
+ const msg = "<html>this is HTML"
|
|
|
+ const msg2 = ", and this is the next chunk"
|
|
|
+ testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
|
|
|
+ io.WriteString(w, msg)
|
|
|
+ w.(http.Flusher).Flush()
|
|
|
+ io.WriteString(w, msg2)
|
|
|
+ return nil
|
|
|
+ }, func(st *serverTester) {
|
|
|
+ getSlash(st)
|
|
|
+ hf := st.wantHeaders()
|
|
|
+ if hf.StreamEnded() {
|
|
|
+ t.Fatal("unexpected END_STREAM flag")
|
|
|
+ }
|
|
|
+ if !hf.HeadersEnded() {
|
|
|
+ t.Fatal("want END_HEADERS flag")
|
|
|
+ }
|
|
|
+ goth := decodeHeader(t, hf.HeaderBlockFragment())
|
|
|
+ wanth := [][2]string{
|
|
|
+ {":status", "200"},
|
|
|
+ {"content-type", "text/html; charset=utf-8"}, // sniffed
|
|
|
+ // and no content-length
|
|
|
+ }
|
|
|
+ if !reflect.DeepEqual(goth, wanth) {
|
|
|
+ t.Errorf("Got headers %v; want %v", goth, wanth)
|
|
|
+ }
|
|
|
+ {
|
|
|
+ df := st.wantData()
|
|
|
+ if df.StreamEnded() {
|
|
|
+ t.Error("unexpected END_STREAM flag")
|
|
|
+ }
|
|
|
+ if got := string(df.Data()); got != msg {
|
|
|
+ t.Errorf("got DATA %q; want %q", got, msg)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ {
|
|
|
+ df := st.wantData()
|
|
|
+ if !df.StreamEnded() {
|
|
|
+ t.Error("wanted END_STREAM flag on last data chunk")
|
|
|
+ }
|
|
|
+ if got := string(df.Data()); got != msg2 {
|
|
|
+ t.Errorf("got DATA %q; want %q", got, msg2)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ })
|
|
|
+}
|
|
|
+
|
|
|
+func TestServer_Response_LargeWrite(t *testing.T) {
|
|
|
+ const size = 1 << 20
|
|
|
+ testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
|
|
|
+ n, err := w.Write(bytes.Repeat([]byte("a"), size))
|
|
|
+ if err != nil {
|
|
|
+ return fmt.Errorf("Write error: %v", err)
|
|
|
+ }
|
|
|
+ if n != size {
|
|
|
+ return fmt.Errorf("wrong size %d from Write", n)
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+ }, func(st *serverTester) {
|
|
|
+ getSlash(st) // make the single request
|
|
|
+ hf := st.wantHeaders()
|
|
|
+ if hf.StreamEnded() {
|
|
|
+ t.Fatal("unexpected END_STREAM flag")
|
|
|
+ }
|
|
|
+ if !hf.HeadersEnded() {
|
|
|
+ t.Fatal("want END_HEADERS flag")
|
|
|
+ }
|
|
|
+ goth := decodeHeader(t, hf.HeaderBlockFragment())
|
|
|
+ wanth := [][2]string{
|
|
|
+ {":status", "200"},
|
|
|
+ {"content-type", "text/plain; charset=utf-8"}, // sniffed
|
|
|
+ // and no content-length
|
|
|
+ }
|
|
|
+ if !reflect.DeepEqual(goth, wanth) {
|
|
|
+ t.Errorf("Got headers %v; want %v", goth, wanth)
|
|
|
+ }
|
|
|
+ var bytes, frames int
|
|
|
+ for {
|
|
|
+ df := st.wantData()
|
|
|
+ bytes += len(df.Data())
|
|
|
+ frames++
|
|
|
+ // TODO: send WINDOW_UPDATE frames at the server to keep it from stalling
|
|
|
+ for _, b := range df.Data() {
|
|
|
+ if b != 'a' {
|
|
|
+ t.Fatal("non-'a' byte seen in DATA")
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if df.StreamEnded() {
|
|
|
+ break
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if bytes != size {
|
|
|
+ t.Errorf("Got %d bytes; want %d", bytes, size)
|
|
|
+ }
|
|
|
+ if want := 257; frames != want {
|
|
|
+ t.Errorf("Got %d frames; want %d", frames, size)
|
|
|
+ }
|
|
|
+ })
|
|
|
+}
|
|
|
+
|
|
|
+func decodeHeader(t *testing.T, headerBlock []byte) (pairs [][2]string) {
|
|
|
+ d := hpack.NewDecoder(initialHeaderTableSize, func(f hpack.HeaderField) {
|
|
|
+ pairs = append(pairs, [2]string{f.Name, f.Value})
|
|
|
+ })
|
|
|
+ if _, err := d.Write(headerBlock); err != nil {
|
|
|
+ t.Fatalf("hpack decoding error: %v", err)
|
|
|
+ }
|
|
|
+ if err := d.Close(); err != nil {
|
|
|
+ t.Fatalf("hpack decoding error: %v", err)
|
|
|
+ }
|
|
|
+ return
|
|
|
+}
|
|
|
+
|
|
|
+// testServerResponse sets up an idle HTTP/2 connection and lets you
|
|
|
+// write a single request with writeReq, and then reply to it in some way with the provided handler,
|
|
|
+// and then verify the output with the serverTester again (assuming the handler returns nil)
|
|
|
+func testServerResponse(t *testing.T,
|
|
|
+ handler func(http.ResponseWriter, *http.Request) error,
|
|
|
+ client func(*serverTester),
|
|
|
+) {
|
|
|
+ errc := make(chan error, 1)
|
|
|
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
|
|
|
+ if r.Body == nil {
|
|
|
+ t.Fatal("nil Body")
|
|
|
+ }
|
|
|
+ errc <- handler(w, r)
|
|
|
+ })
|
|
|
+ defer st.Close()
|
|
|
+
|
|
|
+ donec := make(chan bool)
|
|
|
+ go func() {
|
|
|
+ defer close(donec)
|
|
|
+ st.greet()
|
|
|
+ client(st)
|
|
|
+ }()
|
|
|
+
|
|
|
+ select {
|
|
|
+ case <-donec:
|
|
|
+ return
|
|
|
+ case <-time.After(5 * time.Second):
|
|
|
+ t.Fatal("timeout")
|
|
|
+ }
|
|
|
+
|
|
|
+ select {
|
|
|
+ case err := <-errc:
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("Error in handler: %v", err)
|
|
|
+ }
|
|
|
+ case <-time.After(2 * time.Second):
|
|
|
+ t.Error("timeout waiting for handler to finish")
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
func TestServerWithCurl(t *testing.T) {
|
|
|
requireCurl(t)
|
|
|
|