// Copyright 2015 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package http2 import ( "bufio" "bytes" "crypto/tls" "errors" "flag" "fmt" "io" "io/ioutil" "log" "math/rand" "net" "net/http" "net/url" "os" "reflect" "strconv" "strings" "sync" "sync/atomic" "testing" "time" "golang.org/x/net/http2/hpack" ) var ( extNet = flag.Bool("extnet", false, "do external network tests") transportHost = flag.String("transporthost", "http2.golang.org", "hostname to use for TestTransport") insecure = flag.Bool("insecure", false, "insecure TLS dials") // TODO: dead code. remove? ) var tlsConfigInsecure = &tls.Config{InsecureSkipVerify: true} func TestTransportExternal(t *testing.T) { if !*extNet { t.Skip("skipping external network test") } req, _ := http.NewRequest("GET", "https://"+*transportHost+"/", nil) rt := &Transport{TLSClientConfig: tlsConfigInsecure} res, err := rt.RoundTrip(req) if err != nil { t.Fatalf("%v", err) } res.Write(os.Stdout) } func TestTransport(t *testing.T) { const body = "sup" st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { io.WriteString(w, body) }, optOnlyServer) defer st.Close() tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() req, err := http.NewRequest("GET", st.ts.URL, nil) if err != nil { t.Fatal(err) } res, err := tr.RoundTrip(req) if err != nil { t.Fatal(err) } defer res.Body.Close() t.Logf("Got res: %+v", res) if g, w := res.StatusCode, 200; g != w { t.Errorf("StatusCode = %v; want %v", g, w) } if g, w := res.Status, "200 OK"; g != w { t.Errorf("Status = %q; want %q", g, w) } wantHeader := http.Header{ "Content-Length": []string{"3"}, "Content-Type": []string{"text/plain; charset=utf-8"}, "Date": []string{"XXX"}, // see cleanDate } cleanDate(res) if !reflect.DeepEqual(res.Header, wantHeader) { t.Errorf("res Header = %v; want %v", res.Header, wantHeader) } if res.Request != req { t.Errorf("Response.Request = %p; want %p", res.Request, req) } if res.TLS == nil { t.Error("Response.TLS = nil; want non-nil") } slurp, err := ioutil.ReadAll(res.Body) if err != nil { t.Errorf("Body read: %v", err) } else if string(slurp) != body { t.Errorf("Body = %q; want %q", slurp, body) } } func TestTransportReusesConns(t *testing.T) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { io.WriteString(w, r.RemoteAddr) }, optOnlyServer) defer st.Close() tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() get := func() string { req, err := http.NewRequest("GET", st.ts.URL, nil) if err != nil { t.Fatal(err) } res, err := tr.RoundTrip(req) if err != nil { t.Fatal(err) } defer res.Body.Close() slurp, err := ioutil.ReadAll(res.Body) if err != nil { t.Fatalf("Body read: %v", err) } addr := strings.TrimSpace(string(slurp)) if addr == "" { t.Fatalf("didn't get an addr in response") } return addr } first := get() second := get() if first != second { t.Errorf("first and second responses were on different connections: %q vs %q", first, second) } } // Tests that the Transport only keeps one pending dial open per destination address. // https://golang.org/issue/13397 func TestTransportGroupsPendingDials(t *testing.T) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { io.WriteString(w, r.RemoteAddr) }, optOnlyServer) defer st.Close() tr := &Transport{ TLSClientConfig: tlsConfigInsecure, } defer tr.CloseIdleConnections() var ( mu sync.Mutex dials = map[string]int{} ) var wg sync.WaitGroup for i := 0; i < 10; i++ { wg.Add(1) go func() { defer wg.Done() req, err := http.NewRequest("GET", st.ts.URL, nil) if err != nil { t.Error(err) return } res, err := tr.RoundTrip(req) if err != nil { t.Error(err) return } defer res.Body.Close() slurp, err := ioutil.ReadAll(res.Body) if err != nil { t.Errorf("Body read: %v", err) } addr := strings.TrimSpace(string(slurp)) if addr == "" { t.Errorf("didn't get an addr in response") } mu.Lock() dials[addr]++ mu.Unlock() }() } wg.Wait() if len(dials) != 1 { t.Errorf("saw %d dials; want 1: %v", len(dials), dials) } tr.CloseIdleConnections() if err := retry(50, 10*time.Millisecond, func() error { cp, ok := tr.connPool().(*clientConnPool) if !ok { return fmt.Errorf("Conn pool is %T; want *clientConnPool", tr.connPool()) } cp.mu.Lock() defer cp.mu.Unlock() if len(cp.dialing) != 0 { return fmt.Errorf("dialing map = %v; want empty", cp.dialing) } if len(cp.conns) != 0 { return fmt.Errorf("conns = %v; want empty", cp.conns) } if len(cp.keys) != 0 { return fmt.Errorf("keys = %v; want empty", cp.keys) } return nil }); err != nil { t.Error("State of pool after CloseIdleConnections: %v", err) } } func retry(tries int, delay time.Duration, fn func() error) error { var err error for i := 0; i < tries; i++ { err = fn() if err == nil { return nil } time.Sleep(delay) } return err } func TestTransportAbortClosesPipes(t *testing.T) { shutdown := make(chan struct{}) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { w.(http.Flusher).Flush() <-shutdown }, optOnlyServer, ) defer st.Close() defer close(shutdown) // we must shutdown before st.Close() to avoid hanging done := make(chan struct{}) requestMade := make(chan struct{}) go func() { defer close(done) tr := &Transport{TLSClientConfig: tlsConfigInsecure} req, err := http.NewRequest("GET", st.ts.URL, nil) if err != nil { t.Fatal(err) } res, err := tr.RoundTrip(req) if err != nil { t.Fatal(err) } defer res.Body.Close() close(requestMade) _, err = ioutil.ReadAll(res.Body) if err == nil { t.Error("expected error from res.Body.Read") } }() <-requestMade // Now force the serve loop to end, via closing the connection. st.closeConn() // deadlock? that's a bug. select { case <-done: case <-time.After(3 * time.Second): t.Fatal("timeout") } } // TODO: merge this with TestTransportBody to make TestTransportRequest? This // could be a table-driven test with extra goodies. func TestTransportPath(t *testing.T) { gotc := make(chan *url.URL, 1) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { gotc <- r.URL }, optOnlyServer, ) defer st.Close() tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() const ( path = "/testpath" query = "q=1" ) surl := st.ts.URL + path + "?" + query req, err := http.NewRequest("POST", surl, nil) if err != nil { t.Fatal(err) } c := &http.Client{Transport: tr} res, err := c.Do(req) if err != nil { t.Fatal(err) } defer res.Body.Close() got := <-gotc if got.Path != path { t.Errorf("Read Path = %q; want %q", got.Path, path) } if got.RawQuery != query { t.Errorf("Read RawQuery = %q; want %q", got.RawQuery, query) } } func randString(n int) string { rnd := rand.New(rand.NewSource(int64(n))) b := make([]byte, n) for i := range b { b[i] = byte(rnd.Intn(256)) } return string(b) } var bodyTests = []struct { body string noContentLen bool }{ {body: "some message"}, {body: "some message", noContentLen: true}, {body: ""}, {body: "", noContentLen: true}, {body: strings.Repeat("a", 1<<20), noContentLen: true}, {body: strings.Repeat("a", 1<<20)}, {body: randString(16<<10 - 1)}, {body: randString(16 << 10)}, {body: randString(16<<10 + 1)}, {body: randString(512<<10 - 1)}, {body: randString(512 << 10)}, {body: randString(512<<10 + 1)}, {body: randString(1<<20 - 1)}, {body: randString(1 << 20)}, {body: randString(1<<20 + 2)}, } func TestTransportBody(t *testing.T) { gotc := make(chan interface{}, 1) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { slurp, err := ioutil.ReadAll(r.Body) if err != nil { gotc <- err } else { gotc <- string(slurp) } }, optOnlyServer, ) defer st.Close() for i, tt := range bodyTests { tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() var body io.Reader = strings.NewReader(tt.body) if tt.noContentLen { body = struct{ io.Reader }{body} // just a Reader, hiding concrete type and other methods } req, err := http.NewRequest("POST", st.ts.URL, body) if err != nil { t.Fatalf("#%d: %v", i, err) } c := &http.Client{Transport: tr} res, err := c.Do(req) if err != nil { t.Fatalf("#%d: %v", i, err) } defer res.Body.Close() got := <-gotc if err, ok := got.(error); ok { t.Fatalf("#%d: %v", i, err) } else if got.(string) != tt.body { got := got.(string) t.Errorf("#%d: Read body mismatch.\n got: %q (len %d)\nwant: %q (len %d)", i, shortString(got), len(got), shortString(tt.body), len(tt.body)) } } } func shortString(v string) string { const maxLen = 100 if len(v) <= maxLen { return v } return fmt.Sprintf("%v[...%d bytes omitted...]%v", v[:maxLen/2], len(v)-maxLen, v[len(v)-maxLen/2:]) } func TestTransportDialTLS(t *testing.T) { var mu sync.Mutex // guards following var gotReq, didDial bool ts := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { mu.Lock() gotReq = true mu.Unlock() }, optOnlyServer, ) defer ts.Close() tr := &Transport{ DialTLS: func(netw, addr string, cfg *tls.Config) (net.Conn, error) { mu.Lock() didDial = true mu.Unlock() cfg.InsecureSkipVerify = true c, err := tls.Dial(netw, addr, cfg) if err != nil { return nil, err } return c, c.Handshake() }, } defer tr.CloseIdleConnections() client := &http.Client{Transport: tr} res, err := client.Get(ts.ts.URL) if err != nil { t.Fatal(err) } res.Body.Close() mu.Lock() if !gotReq { t.Error("didn't get request") } if !didDial { t.Error("didn't use dial hook") } } func TestConfigureTransport(t *testing.T) { t1 := &http.Transport{} err := ConfigureTransport(t1) if err == errTransportVersion { t.Skip(err) } if err != nil { t.Fatal(err) } if got := fmt.Sprintf("%#v", *t1); !strings.Contains(got, `"h2"`) { // Laziness, to avoid buildtags. t.Errorf("stringification of HTTP/1 transport didn't contain \"h2\": %v", got) } if t1.TLSClientConfig == nil { t.Errorf("nil t1.TLSClientConfig") } else if !reflect.DeepEqual(t1.TLSClientConfig.NextProtos, []string{"h2"}) { t.Errorf("TLSClientConfig.NextProtos = %q; want just 'h2'", t1.TLSClientConfig.NextProtos) } if err := ConfigureTransport(t1); err == nil { t.Error("unexpected success on second call to ConfigureTransport") } // And does it work? st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { io.WriteString(w, r.Proto) }, optOnlyServer) defer st.Close() t1.TLSClientConfig.InsecureSkipVerify = true c := &http.Client{Transport: t1} res, err := c.Get(st.ts.URL) if err != nil { t.Fatal(err) } slurp, err := ioutil.ReadAll(res.Body) if err != nil { t.Fatal(err) } if got, want := string(slurp), "HTTP/2.0"; got != want { t.Errorf("body = %q; want %q", got, want) } } type capitalizeReader struct { r io.Reader } func (cr capitalizeReader) Read(p []byte) (n int, err error) { n, err = cr.r.Read(p) for i, b := range p[:n] { if b >= 'a' && b <= 'z' { p[i] = b - ('a' - 'A') } } return } type flushWriter struct { w io.Writer } func (fw flushWriter) Write(p []byte) (n int, err error) { n, err = fw.w.Write(p) if f, ok := fw.w.(http.Flusher); ok { f.Flush() } return } type clientTester struct { t *testing.T tr *Transport sc, cc net.Conn // server and client conn fr *Framer // server's framer client func() error server func() error } func newClientTester(t *testing.T) *clientTester { var dialOnce struct { sync.Mutex dialed bool } ct := &clientTester{ t: t, } ct.tr = &Transport{ TLSClientConfig: tlsConfigInsecure, DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { dialOnce.Lock() defer dialOnce.Unlock() if dialOnce.dialed { return nil, errors.New("only one dial allowed in test mode") } dialOnce.dialed = true return ct.cc, nil }, } ln := newLocalListener(t) cc, err := net.Dial("tcp", ln.Addr().String()) if err != nil { t.Fatal(err) } sc, err := ln.Accept() if err != nil { t.Fatal(err) } ln.Close() ct.cc = cc ct.sc = sc ct.fr = NewFramer(sc, sc) return ct } func newLocalListener(t *testing.T) net.Listener { ln, err := net.Listen("tcp4", "127.0.0.1:0") if err == nil { return ln } ln, err = net.Listen("tcp6", "[::1]:0") if err != nil { t.Fatal(err) } return ln } func (ct *clientTester) greet() { buf := make([]byte, len(ClientPreface)) _, err := io.ReadFull(ct.sc, buf) if err != nil { ct.t.Fatalf("reading client preface: %v", err) } f, err := ct.fr.ReadFrame() if err != nil { ct.t.Fatalf("Reading client settings frame: %v", err) } if sf, ok := f.(*SettingsFrame); !ok { ct.t.Fatalf("Wanted client settings frame; got %v", f) _ = sf // stash it away? } if err := ct.fr.WriteSettings(); err != nil { ct.t.Fatal(err) } if err := ct.fr.WriteSettingsAck(); err != nil { ct.t.Fatal(err) } } func (ct *clientTester) run() { errc := make(chan error, 2) ct.start("client", errc, ct.client) ct.start("server", errc, ct.server) for i := 0; i < 2; i++ { if err := <-errc; err != nil { ct.t.Error(err) return } } } func (ct *clientTester) start(which string, errc chan<- error, fn func() error) { go func() { finished := false var err error defer func() { if !finished { err = fmt.Errorf("%s goroutine didn't finish.", which) } else if err != nil { err = fmt.Errorf("%s: %v", which, err) } errc <- err }() err = fn() finished = true }() } type countingReader struct { n *int64 } func (r countingReader) Read(p []byte) (n int, err error) { for i := range p { p[i] = byte(i) } atomic.AddInt64(r.n, int64(len(p))) return len(p), err } func TestTransportReqBodyAfterResponse_200(t *testing.T) { testTransportReqBodyAfterResponse(t, 200) } func TestTransportReqBodyAfterResponse_403(t *testing.T) { testTransportReqBodyAfterResponse(t, 403) } func testTransportReqBodyAfterResponse(t *testing.T, status int) { const bodySize = 10 << 20 ct := newClientTester(t) ct.client = func() error { var n int64 // atomic req, err := http.NewRequest("PUT", "https://dummy.tld/", io.LimitReader(countingReader{&n}, bodySize)) if err != nil { return err } res, err := ct.tr.RoundTrip(req) if err != nil { return fmt.Errorf("RoundTrip: %v", err) } defer res.Body.Close() if res.StatusCode != status { return fmt.Errorf("status code = %v; want %v", res.StatusCode, status) } slurp, err := ioutil.ReadAll(res.Body) if err != nil { return fmt.Errorf("Slurp: %v", err) } if len(slurp) > 0 { return fmt.Errorf("unexpected body: %q", slurp) } if status == 200 { if got := atomic.LoadInt64(&n); got != bodySize { return fmt.Errorf("For 200 response, Transport wrote %d bytes; want %d", got, bodySize) } } else { if got := atomic.LoadInt64(&n); got == 0 || got >= bodySize { return fmt.Errorf("For %d response, Transport wrote %d bytes; want (0,%d) exclusive", status, got, bodySize) } } return nil } ct.server = func() error { ct.greet() var buf bytes.Buffer enc := hpack.NewEncoder(&buf) var dataRecv int64 var closed bool for { f, err := ct.fr.ReadFrame() if err != nil { return err } //println(fmt.Sprintf("server got frame: %v", f)) switch f := f.(type) { case *WindowUpdateFrame, *SettingsFrame: case *HeadersFrame: if !f.HeadersEnded() { return fmt.Errorf("headers should have END_HEADERS be ended: %v", f) } if f.StreamEnded() { return fmt.Errorf("headers contains END_STREAM unexpectedly: %v", f) } time.Sleep(50 * time.Millisecond) // let client send body enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)}) ct.fr.WriteHeaders(HeadersFrameParam{ StreamID: f.StreamID, EndHeaders: true, EndStream: false, BlockFragment: buf.Bytes(), }) case *DataFrame: dataLen := len(f.Data()) dataRecv += int64(dataLen) if dataLen > 0 { if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil { return err } if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil { return err } } if !closed && ((status != 200 && dataRecv > 0) || (status == 200 && dataRecv == bodySize)) { closed = true if err := ct.fr.WriteData(f.StreamID, true, nil); err != nil { return err } return nil } default: return fmt.Errorf("Unexpected client frame %v", f) } } return nil } ct.run() } // See golang.org/issue/13444 func TestTransportFullDuplex(t *testing.T) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) // redundant but for clarity w.(http.Flusher).Flush() io.Copy(flushWriter{w}, capitalizeReader{r.Body}) fmt.Fprintf(w, "bye.\n") }, optOnlyServer) defer st.Close() tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() c := &http.Client{Transport: tr} pr, pw := io.Pipe() req, err := http.NewRequest("PUT", st.ts.URL, ioutil.NopCloser(pr)) if err != nil { log.Fatal(err) } res, err := c.Do(req) if err != nil { log.Fatal(err) } defer res.Body.Close() if res.StatusCode != 200 { t.Fatalf("StatusCode = %v; want %v", res.StatusCode, 200) } bs := bufio.NewScanner(res.Body) want := func(v string) { if !bs.Scan() { t.Fatalf("wanted to read %q but Scan() = false, err = %v", v, bs.Err()) } } write := func(v string) { _, err := io.WriteString(pw, v) if err != nil { t.Fatalf("pipe write: %v", err) } } write("foo\n") want("FOO") write("bar\n") want("BAR") pw.Close() want("bye.") if err := bs.Err(); err != nil { t.Fatal(err) } }