123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519 |
- // Copyright 2016 The Go Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style
- // license that can be found in the LICENSE file.
- package http2
- import (
- "errors"
- "fmt"
- "io"
- "io/ioutil"
- "net/http"
- "reflect"
- "strconv"
- "sync"
- "testing"
- "time"
- )
- func TestServer_Push_Success(t *testing.T) {
- const (
- mainBody = "<html>index page</html>"
- pushedBody = "<html>pushed page</html>"
- userAgent = "testagent"
- cookie = "testcookie"
- )
- var stURL string
- checkPromisedReq := func(r *http.Request, wantMethod string, wantH http.Header) error {
- if got, want := r.Method, wantMethod; got != want {
- return fmt.Errorf("promised Req.Method=%q, want %q", got, want)
- }
- if got, want := r.Header, wantH; !reflect.DeepEqual(got, want) {
- return fmt.Errorf("promised Req.Header=%q, want %q", got, want)
- }
- if got, want := "https://"+r.Host, stURL; got != want {
- return fmt.Errorf("promised Req.Host=%q, want %q", got, want)
- }
- if r.Body == nil {
- return fmt.Errorf("nil Body")
- }
- if buf, err := ioutil.ReadAll(r.Body); err != nil || len(buf) != 0 {
- return fmt.Errorf("ReadAll(Body)=%q,%v, want '',nil", buf, err)
- }
- return nil
- }
- errc := make(chan error, 3)
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
- switch r.URL.RequestURI() {
- case "/":
- // Push "/pushed?get" as a GET request, using an absolute URL.
- opt := &http.PushOptions{
- Header: http.Header{
- "User-Agent": {userAgent},
- },
- }
- if err := w.(http.Pusher).Push(stURL+"/pushed?get", opt); err != nil {
- errc <- fmt.Errorf("error pushing /pushed?get: %v", err)
- return
- }
- // Push "/pushed?head" as a HEAD request, using a path.
- opt = &http.PushOptions{
- Method: "HEAD",
- Header: http.Header{
- "User-Agent": {userAgent},
- "Cookie": {cookie},
- },
- }
- if err := w.(http.Pusher).Push("/pushed?head", opt); err != nil {
- errc <- fmt.Errorf("error pushing /pushed?head: %v", err)
- return
- }
- w.Header().Set("Content-Type", "text/html")
- w.Header().Set("Content-Length", strconv.Itoa(len(mainBody)))
- w.WriteHeader(200)
- io.WriteString(w, mainBody)
- errc <- nil
- case "/pushed?get":
- wantH := http.Header{}
- wantH.Set("User-Agent", userAgent)
- if err := checkPromisedReq(r, "GET", wantH); err != nil {
- errc <- fmt.Errorf("/pushed?get: %v", err)
- return
- }
- w.Header().Set("Content-Type", "text/html")
- w.Header().Set("Content-Length", strconv.Itoa(len(pushedBody)))
- w.WriteHeader(200)
- io.WriteString(w, pushedBody)
- errc <- nil
- case "/pushed?head":
- wantH := http.Header{}
- wantH.Set("User-Agent", userAgent)
- wantH.Set("Cookie", cookie)
- if err := checkPromisedReq(r, "HEAD", wantH); err != nil {
- errc <- fmt.Errorf("/pushed?head: %v", err)
- return
- }
- w.WriteHeader(204)
- errc <- nil
- default:
- errc <- fmt.Errorf("unknown RequestURL %q", r.URL.RequestURI())
- }
- })
- stURL = st.ts.URL
- // Send one request, which should push two responses.
- st.greet()
- getSlash(st)
- for k := 0; k < 3; k++ {
- select {
- case <-time.After(2 * time.Second):
- t.Errorf("timeout waiting for handler %d to finish", k)
- case err := <-errc:
- if err != nil {
- t.Fatal(err)
- }
- }
- }
- checkPushPromise := func(f Frame, promiseID uint32, wantH [][2]string) error {
- pp, ok := f.(*PushPromiseFrame)
- if !ok {
- return fmt.Errorf("got a %T; want *PushPromiseFrame", f)
- }
- if !pp.HeadersEnded() {
- return fmt.Errorf("want END_HEADERS flag in PushPromiseFrame")
- }
- if got, want := pp.PromiseID, promiseID; got != want {
- return fmt.Errorf("got PromiseID %v; want %v", got, want)
- }
- gotH := st.decodeHeader(pp.HeaderBlockFragment())
- if !reflect.DeepEqual(gotH, wantH) {
- return fmt.Errorf("got promised headers %v; want %v", gotH, wantH)
- }
- return nil
- }
- checkHeaders := func(f Frame, wantH [][2]string) error {
- hf, ok := f.(*HeadersFrame)
- if !ok {
- return fmt.Errorf("got a %T; want *HeadersFrame", f)
- }
- gotH := st.decodeHeader(hf.HeaderBlockFragment())
- if !reflect.DeepEqual(gotH, wantH) {
- return fmt.Errorf("got response headers %v; want %v", gotH, wantH)
- }
- return nil
- }
- checkData := func(f Frame, wantData string) error {
- df, ok := f.(*DataFrame)
- if !ok {
- return fmt.Errorf("got a %T; want *DataFrame", f)
- }
- if gotData := string(df.Data()); gotData != wantData {
- return fmt.Errorf("got response data %q; want %q", gotData, wantData)
- }
- return nil
- }
- // Stream 1 has 2 PUSH_PROMISE + HEADERS + DATA
- // Stream 2 has HEADERS + DATA
- // Stream 4 has HEADERS
- expected := map[uint32][]func(Frame) error{
- 1: {
- func(f Frame) error {
- return checkPushPromise(f, 2, [][2]string{
- {":method", "GET"},
- {":scheme", "https"},
- {":authority", st.ts.Listener.Addr().String()},
- {":path", "/pushed?get"},
- {"user-agent", userAgent},
- })
- },
- func(f Frame) error {
- return checkPushPromise(f, 4, [][2]string{
- {":method", "HEAD"},
- {":scheme", "https"},
- {":authority", st.ts.Listener.Addr().String()},
- {":path", "/pushed?head"},
- {"cookie", cookie},
- {"user-agent", userAgent},
- })
- },
- func(f Frame) error {
- return checkHeaders(f, [][2]string{
- {":status", "200"},
- {"content-type", "text/html"},
- {"content-length", strconv.Itoa(len(mainBody))},
- })
- },
- func(f Frame) error {
- return checkData(f, mainBody)
- },
- },
- 2: {
- func(f Frame) error {
- return checkHeaders(f, [][2]string{
- {":status", "200"},
- {"content-type", "text/html"},
- {"content-length", strconv.Itoa(len(pushedBody))},
- })
- },
- func(f Frame) error {
- return checkData(f, pushedBody)
- },
- },
- 4: {
- func(f Frame) error {
- return checkHeaders(f, [][2]string{
- {":status", "204"},
- })
- },
- },
- }
- consumed := map[uint32]int{}
- for k := 0; len(expected) > 0; k++ {
- f, err := st.readFrame()
- if err != nil {
- for id, left := range expected {
- t.Errorf("stream %d: missing %d frames", id, len(left))
- }
- t.Fatalf("readFrame %d: %v", k, err)
- }
- id := f.Header().StreamID
- label := fmt.Sprintf("stream %d, frame %d", id, consumed[id])
- if len(expected[id]) == 0 {
- t.Fatalf("%s: unexpected frame %#+v", label, f)
- }
- check := expected[id][0]
- expected[id] = expected[id][1:]
- if len(expected[id]) == 0 {
- delete(expected, id)
- }
- if err := check(f); err != nil {
- t.Fatalf("%s: %v", label, err)
- }
- consumed[id]++
- }
- }
- func TestServer_Push_SuccessNoRace(t *testing.T) {
- // Regression test for issue #18326. Ensure the request handler can mutate
- // pushed request headers without racing with the PUSH_PROMISE write.
- errc := make(chan error, 2)
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
- switch r.URL.RequestURI() {
- case "/":
- opt := &http.PushOptions{
- Header: http.Header{"User-Agent": {"testagent"}},
- }
- if err := w.(http.Pusher).Push("/pushed", opt); err != nil {
- errc <- fmt.Errorf("error pushing: %v", err)
- return
- }
- w.WriteHeader(200)
- errc <- nil
- case "/pushed":
- // Update request header, ensure there is no race.
- r.Header.Set("User-Agent", "newagent")
- r.Header.Set("Cookie", "cookie")
- w.WriteHeader(200)
- errc <- nil
- default:
- errc <- fmt.Errorf("unknown RequestURL %q", r.URL.RequestURI())
- }
- })
- // Send one request, which should push one response.
- st.greet()
- getSlash(st)
- for k := 0; k < 2; k++ {
- select {
- case <-time.After(2 * time.Second):
- t.Errorf("timeout waiting for handler %d to finish", k)
- case err := <-errc:
- if err != nil {
- t.Fatal(err)
- }
- }
- }
- }
- func TestServer_Push_RejectRecursivePush(t *testing.T) {
- // Expect two requests, but might get three if there's a bug and the second push succeeds.
- errc := make(chan error, 3)
- handler := func(w http.ResponseWriter, r *http.Request) error {
- baseURL := "https://" + r.Host
- switch r.URL.Path {
- case "/":
- if err := w.(http.Pusher).Push(baseURL+"/push1", nil); err != nil {
- return fmt.Errorf("first Push()=%v, want nil", err)
- }
- return nil
- case "/push1":
- if got, want := w.(http.Pusher).Push(baseURL+"/push2", nil), ErrRecursivePush; got != want {
- return fmt.Errorf("Push()=%v, want %v", got, want)
- }
- return nil
- default:
- return fmt.Errorf("unexpected path: %q", r.URL.Path)
- }
- }
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
- errc <- handler(w, r)
- })
- defer st.Close()
- st.greet()
- getSlash(st)
- if err := <-errc; err != nil {
- t.Errorf("First request failed: %v", err)
- }
- if err := <-errc; err != nil {
- t.Errorf("Second request failed: %v", err)
- }
- }
- func testServer_Push_RejectSingleRequest(t *testing.T, doPush func(http.Pusher, *http.Request) error, settings ...Setting) {
- // Expect one request, but might get two if there's a bug and the push succeeds.
- errc := make(chan error, 2)
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
- errc <- doPush(w.(http.Pusher), r)
- })
- defer st.Close()
- st.greet()
- if err := st.fr.WriteSettings(settings...); err != nil {
- st.t.Fatalf("WriteSettings: %v", err)
- }
- st.wantSettingsAck()
- getSlash(st)
- if err := <-errc; err != nil {
- t.Error(err)
- }
- // Should not get a PUSH_PROMISE frame.
- hf := st.wantHeaders()
- if !hf.StreamEnded() {
- t.Error("stream should end after headers")
- }
- }
- func TestServer_Push_RejectIfDisabled(t *testing.T) {
- testServer_Push_RejectSingleRequest(t,
- func(p http.Pusher, r *http.Request) error {
- if got, want := p.Push("https://"+r.Host+"/pushed", nil), http.ErrNotSupported; got != want {
- return fmt.Errorf("Push()=%v, want %v", got, want)
- }
- return nil
- },
- Setting{SettingEnablePush, 0})
- }
- func TestServer_Push_RejectWhenNoConcurrentStreams(t *testing.T) {
- testServer_Push_RejectSingleRequest(t,
- func(p http.Pusher, r *http.Request) error {
- if got, want := p.Push("https://"+r.Host+"/pushed", nil), ErrPushLimitReached; got != want {
- return fmt.Errorf("Push()=%v, want %v", got, want)
- }
- return nil
- },
- Setting{SettingMaxConcurrentStreams, 0})
- }
- func TestServer_Push_RejectWrongScheme(t *testing.T) {
- testServer_Push_RejectSingleRequest(t,
- func(p http.Pusher, r *http.Request) error {
- if err := p.Push("http://"+r.Host+"/pushed", nil); err == nil {
- return errors.New("Push() should have failed (push target URL is http)")
- }
- return nil
- })
- }
- func TestServer_Push_RejectMissingHost(t *testing.T) {
- testServer_Push_RejectSingleRequest(t,
- func(p http.Pusher, r *http.Request) error {
- if err := p.Push("https:pushed", nil); err == nil {
- return errors.New("Push() should have failed (push target URL missing host)")
- }
- return nil
- })
- }
- func TestServer_Push_RejectRelativePath(t *testing.T) {
- testServer_Push_RejectSingleRequest(t,
- func(p http.Pusher, r *http.Request) error {
- if err := p.Push("../test", nil); err == nil {
- return errors.New("Push() should have failed (push target is a relative path)")
- }
- return nil
- })
- }
- func TestServer_Push_RejectForbiddenMethod(t *testing.T) {
- testServer_Push_RejectSingleRequest(t,
- func(p http.Pusher, r *http.Request) error {
- if err := p.Push("https://"+r.Host+"/pushed", &http.PushOptions{Method: "POST"}); err == nil {
- return errors.New("Push() should have failed (cannot promise a POST)")
- }
- return nil
- })
- }
- func TestServer_Push_RejectForbiddenHeader(t *testing.T) {
- testServer_Push_RejectSingleRequest(t,
- func(p http.Pusher, r *http.Request) error {
- header := http.Header{
- "Content-Length": {"10"},
- "Content-Encoding": {"gzip"},
- "Trailer": {"Foo"},
- "Te": {"trailers"},
- "Host": {"test.com"},
- ":authority": {"test.com"},
- }
- if err := p.Push("https://"+r.Host+"/pushed", &http.PushOptions{Header: header}); err == nil {
- return errors.New("Push() should have failed (forbidden headers)")
- }
- return nil
- })
- }
- func TestServer_Push_StateTransitions(t *testing.T) {
- const body = "foo"
- gotPromise := make(chan bool)
- finishedPush := make(chan bool)
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
- switch r.URL.RequestURI() {
- case "/":
- if err := w.(http.Pusher).Push("/pushed", nil); err != nil {
- t.Errorf("Push error: %v", err)
- }
- // Don't finish this request until the push finishes so we don't
- // nondeterministically interleave output frames with the push.
- <-finishedPush
- case "/pushed":
- <-gotPromise
- }
- w.Header().Set("Content-Type", "text/html")
- w.Header().Set("Content-Length", strconv.Itoa(len(body)))
- w.WriteHeader(200)
- io.WriteString(w, body)
- })
- defer st.Close()
- st.greet()
- if st.stream(2) != nil {
- t.Fatal("stream 2 should be empty")
- }
- if got, want := st.streamState(2), stateIdle; got != want {
- t.Fatalf("streamState(2)=%v, want %v", got, want)
- }
- getSlash(st)
- // After the PUSH_PROMISE is sent, the stream should be stateHalfClosedRemote.
- st.wantPushPromise()
- if got, want := st.streamState(2), stateHalfClosedRemote; got != want {
- t.Fatalf("streamState(2)=%v, want %v", got, want)
- }
- // We stall the HTTP handler for "/pushed" until the above check. If we don't
- // stall the handler, then the handler might write HEADERS and DATA and finish
- // the stream before we check st.streamState(2) -- should that happen, we'll
- // see stateClosed and fail the above check.
- close(gotPromise)
- st.wantHeaders()
- if df := st.wantData(); !df.StreamEnded() {
- t.Fatal("expected END_STREAM flag on DATA")
- }
- if got, want := st.streamState(2), stateClosed; got != want {
- t.Fatalf("streamState(2)=%v, want %v", got, want)
- }
- close(finishedPush)
- }
- func TestServer_Push_RejectAfterGoAway(t *testing.T) {
- var readyOnce sync.Once
- ready := make(chan struct{})
- errc := make(chan error, 2)
- st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
- select {
- case <-ready:
- case <-time.After(5 * time.Second):
- errc <- fmt.Errorf("timeout waiting for GOAWAY to be processed")
- }
- if got, want := w.(http.Pusher).Push("https://"+r.Host+"/pushed", nil), http.ErrNotSupported; got != want {
- errc <- fmt.Errorf("Push()=%v, want %v", got, want)
- }
- errc <- nil
- })
- defer st.Close()
- st.greet()
- getSlash(st)
- // Send GOAWAY and wait for it to be processed.
- st.fr.WriteGoAway(1, ErrCodeNo, nil)
- go func() {
- for {
- select {
- case <-ready:
- return
- default:
- }
- st.sc.serveMsgCh <- func(loopNum int) {
- if !st.sc.pushEnabled {
- readyOnce.Do(func() { close(ready) })
- }
- }
- }
- }()
- if err := <-errc; err != nil {
- t.Error(err)
- }
- }
|