|
|
@@ -5,21 +5,29 @@
|
|
|
package http2
|
|
|
|
|
|
import (
|
|
|
+ "bufio"
|
|
|
+ "bytes"
|
|
|
"crypto/tls"
|
|
|
+ "errors"
|
|
|
"flag"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
"io/ioutil"
|
|
|
+ "log"
|
|
|
"math/rand"
|
|
|
"net"
|
|
|
"net/http"
|
|
|
"net/url"
|
|
|
"os"
|
|
|
"reflect"
|
|
|
+ "strconv"
|
|
|
"strings"
|
|
|
"sync"
|
|
|
+ "sync/atomic"
|
|
|
"testing"
|
|
|
"time"
|
|
|
+
|
|
|
+ "golang.org/x/net/http2/hpack"
|
|
|
)
|
|
|
|
|
|
var (
|
|
|
@@ -182,6 +190,8 @@ func TestTransportGroupsPendingDials(t *testing.T) {
|
|
|
if !ok {
|
|
|
return fmt.Errorf("Conn pool is %T; want *clientConnPool", tr.connPool())
|
|
|
}
|
|
|
+ cp.mu.Lock()
|
|
|
+ defer cp.mu.Unlock()
|
|
|
if len(cp.dialing) != 0 {
|
|
|
return fmt.Errorf("dialing map = %v; want empty", cp.dialing)
|
|
|
}
|
|
|
@@ -456,3 +466,296 @@ func TestConfigureTransport(t *testing.T) {
|
|
|
t.Errorf("body = %q; want %q", got, want)
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+type capitalizeReader struct {
|
|
|
+ r io.Reader
|
|
|
+}
|
|
|
+
|
|
|
+func (cr capitalizeReader) Read(p []byte) (n int, err error) {
|
|
|
+ n, err = cr.r.Read(p)
|
|
|
+ for i, b := range p[:n] {
|
|
|
+ if b >= 'a' && b <= 'z' {
|
|
|
+ p[i] = b - ('a' - 'A')
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return
|
|
|
+}
|
|
|
+
|
|
|
+type flushWriter struct {
|
|
|
+ w io.Writer
|
|
|
+}
|
|
|
+
|
|
|
+func (fw flushWriter) Write(p []byte) (n int, err error) {
|
|
|
+ n, err = fw.w.Write(p)
|
|
|
+ if f, ok := fw.w.(http.Flusher); ok {
|
|
|
+ f.Flush()
|
|
|
+ }
|
|
|
+ return
|
|
|
+}
|
|
|
+
|
|
|
+type clientTester struct {
|
|
|
+ t *testing.T
|
|
|
+ tr *Transport
|
|
|
+ sc, cc net.Conn // server and client conn
|
|
|
+ fr *Framer // server's framer
|
|
|
+ client func() error
|
|
|
+ server func() error
|
|
|
+}
|
|
|
+
|
|
|
+func newClientTester(t *testing.T) *clientTester {
|
|
|
+ var dialOnce struct {
|
|
|
+ sync.Mutex
|
|
|
+ dialed bool
|
|
|
+ }
|
|
|
+ ct := &clientTester{
|
|
|
+ t: t,
|
|
|
+ }
|
|
|
+ ct.tr = &Transport{
|
|
|
+ TLSClientConfig: tlsConfigInsecure,
|
|
|
+ DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
|
|
|
+ dialOnce.Lock()
|
|
|
+ defer dialOnce.Unlock()
|
|
|
+ if dialOnce.dialed {
|
|
|
+ return nil, errors.New("only one dial allowed in test mode")
|
|
|
+ }
|
|
|
+ dialOnce.dialed = true
|
|
|
+ return ct.cc, nil
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ ln := newLocalListener(t)
|
|
|
+ cc, err := net.Dial("tcp", ln.Addr().String())
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+
|
|
|
+ }
|
|
|
+ sc, err := ln.Accept()
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ ln.Close()
|
|
|
+ ct.cc = cc
|
|
|
+ ct.sc = sc
|
|
|
+ ct.fr = NewFramer(sc, sc)
|
|
|
+ return ct
|
|
|
+}
|
|
|
+
|
|
|
+func newLocalListener(t *testing.T) net.Listener {
|
|
|
+ ln, err := net.Listen("tcp4", "127.0.0.1:0")
|
|
|
+ if err == nil {
|
|
|
+ return ln
|
|
|
+ }
|
|
|
+ ln, err = net.Listen("tcp6", "[::1]:0")
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ return ln
|
|
|
+}
|
|
|
+
|
|
|
+func (ct *clientTester) greet() {
|
|
|
+ buf := make([]byte, len(ClientPreface))
|
|
|
+ _, err := io.ReadFull(ct.sc, buf)
|
|
|
+ if err != nil {
|
|
|
+ ct.t.Fatalf("reading client preface: %v", err)
|
|
|
+ }
|
|
|
+ f, err := ct.fr.ReadFrame()
|
|
|
+ if err != nil {
|
|
|
+ ct.t.Fatalf("Reading client settings frame: %v", err)
|
|
|
+ }
|
|
|
+ if sf, ok := f.(*SettingsFrame); !ok {
|
|
|
+ ct.t.Fatalf("Wanted client settings frame; got %v", f)
|
|
|
+ _ = sf // stash it away?
|
|
|
+ }
|
|
|
+ if err := ct.fr.WriteSettings(); err != nil {
|
|
|
+ ct.t.Fatal(err)
|
|
|
+ }
|
|
|
+ if err := ct.fr.WriteSettingsAck(); err != nil {
|
|
|
+ ct.t.Fatal(err)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (ct *clientTester) run() {
|
|
|
+ errc := make(chan error, 2)
|
|
|
+ ct.start("client", errc, ct.client)
|
|
|
+ ct.start("server", errc, ct.server)
|
|
|
+ for i := 0; i < 2; i++ {
|
|
|
+ if err := <-errc; err != nil {
|
|
|
+ ct.t.Error(err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (ct *clientTester) start(which string, errc chan<- error, fn func() error) {
|
|
|
+ go func() {
|
|
|
+ finished := false
|
|
|
+ var err error
|
|
|
+ defer func() {
|
|
|
+ if !finished {
|
|
|
+ err = fmt.Errorf("%s goroutine didn't finish.", which)
|
|
|
+ } else if err != nil {
|
|
|
+ err = fmt.Errorf("%s: %v", which, err)
|
|
|
+ }
|
|
|
+ errc <- err
|
|
|
+ }()
|
|
|
+ err = fn()
|
|
|
+ finished = true
|
|
|
+ }()
|
|
|
+}
|
|
|
+
|
|
|
+type countingReader struct {
|
|
|
+ n *int64
|
|
|
+}
|
|
|
+
|
|
|
+func (r countingReader) Read(p []byte) (n int, err error) {
|
|
|
+ for i := range p {
|
|
|
+ p[i] = byte(i)
|
|
|
+ }
|
|
|
+ atomic.AddInt64(r.n, int64(len(p)))
|
|
|
+ return len(p), err
|
|
|
+}
|
|
|
+
|
|
|
+func TestTransportReqBodyAfterResponse_200(t *testing.T) { testTransportReqBodyAfterResponse(t, 200) }
|
|
|
+func TestTransportReqBodyAfterResponse_403(t *testing.T) { testTransportReqBodyAfterResponse(t, 403) }
|
|
|
+
|
|
|
+func testTransportReqBodyAfterResponse(t *testing.T, status int) {
|
|
|
+ const bodySize = 10 << 20
|
|
|
+ ct := newClientTester(t)
|
|
|
+ ct.client = func() error {
|
|
|
+ var n int64 // atomic
|
|
|
+ req, err := http.NewRequest("PUT", "https://dummy.tld/", io.LimitReader(countingReader{&n}, bodySize))
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ res, err := ct.tr.RoundTrip(req)
|
|
|
+ if err != nil {
|
|
|
+ return fmt.Errorf("RoundTrip: %v", err)
|
|
|
+ }
|
|
|
+ defer res.Body.Close()
|
|
|
+ if res.StatusCode != status {
|
|
|
+ return fmt.Errorf("status code = %v; want %v", res.StatusCode, status)
|
|
|
+ }
|
|
|
+ slurp, err := ioutil.ReadAll(res.Body)
|
|
|
+ if err != nil {
|
|
|
+ return fmt.Errorf("Slurp: %v", err)
|
|
|
+ }
|
|
|
+ if len(slurp) > 0 {
|
|
|
+ return fmt.Errorf("unexpected body: %q", slurp)
|
|
|
+ }
|
|
|
+ if status == 200 {
|
|
|
+ if got := atomic.LoadInt64(&n); got != bodySize {
|
|
|
+ return fmt.Errorf("For 200 response, Transport wrote %d bytes; want %d", got, bodySize)
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ if got := atomic.LoadInt64(&n); got == 0 || got >= bodySize {
|
|
|
+ return fmt.Errorf("For %d response, Transport wrote %d bytes; want (0,%d) exclusive", status, got, bodySize)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ ct.server = func() error {
|
|
|
+ ct.greet()
|
|
|
+ var buf bytes.Buffer
|
|
|
+ enc := hpack.NewEncoder(&buf)
|
|
|
+ var dataRecv int64
|
|
|
+ var closed bool
|
|
|
+ for {
|
|
|
+ f, err := ct.fr.ReadFrame()
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ //println(fmt.Sprintf("server got frame: %v", f))
|
|
|
+ switch f := f.(type) {
|
|
|
+ case *WindowUpdateFrame, *SettingsFrame:
|
|
|
+ case *HeadersFrame:
|
|
|
+ if !f.HeadersEnded() {
|
|
|
+ return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
|
|
|
+ }
|
|
|
+ if f.StreamEnded() {
|
|
|
+ return fmt.Errorf("headers contains END_STREAM unexpectedly: %v", f)
|
|
|
+ }
|
|
|
+ time.Sleep(50 * time.Millisecond) // let client send body
|
|
|
+ enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)})
|
|
|
+ ct.fr.WriteHeaders(HeadersFrameParam{
|
|
|
+ StreamID: f.StreamID,
|
|
|
+ EndHeaders: true,
|
|
|
+ EndStream: false,
|
|
|
+ BlockFragment: buf.Bytes(),
|
|
|
+ })
|
|
|
+ case *DataFrame:
|
|
|
+ dataLen := len(f.Data())
|
|
|
+ dataRecv += int64(dataLen)
|
|
|
+ if dataLen > 0 {
|
|
|
+ if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if !closed && ((status != 200 && dataRecv > 0) ||
|
|
|
+ (status == 200 && dataRecv == bodySize)) {
|
|
|
+ closed = true
|
|
|
+ if err := ct.fr.WriteData(f.StreamID, true, nil); err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ default:
|
|
|
+ return fmt.Errorf("Unexpected client frame %v", f)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ ct.run()
|
|
|
+}
|
|
|
+
|
|
|
+// See golang.org/issue/13444
|
|
|
+func TestTransportFullDuplex(t *testing.T) {
|
|
|
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
|
|
|
+ w.WriteHeader(200) // redundant but for clarity
|
|
|
+ w.(http.Flusher).Flush()
|
|
|
+ io.Copy(flushWriter{w}, capitalizeReader{r.Body})
|
|
|
+ fmt.Fprintf(w, "bye.\n")
|
|
|
+ }, optOnlyServer)
|
|
|
+ defer st.Close()
|
|
|
+
|
|
|
+ tr := &Transport{TLSClientConfig: tlsConfigInsecure}
|
|
|
+ defer tr.CloseIdleConnections()
|
|
|
+ c := &http.Client{Transport: tr}
|
|
|
+
|
|
|
+ pr, pw := io.Pipe()
|
|
|
+ req, err := http.NewRequest("PUT", st.ts.URL, ioutil.NopCloser(pr))
|
|
|
+ if err != nil {
|
|
|
+ log.Fatal(err)
|
|
|
+ }
|
|
|
+ res, err := c.Do(req)
|
|
|
+ if err != nil {
|
|
|
+ log.Fatal(err)
|
|
|
+ }
|
|
|
+ defer res.Body.Close()
|
|
|
+ if res.StatusCode != 200 {
|
|
|
+ t.Fatalf("StatusCode = %v; want %v", res.StatusCode, 200)
|
|
|
+ }
|
|
|
+ bs := bufio.NewScanner(res.Body)
|
|
|
+ want := func(v string) {
|
|
|
+ if !bs.Scan() {
|
|
|
+ t.Fatalf("wanted to read %q but Scan() = false, err = %v", v, bs.Err())
|
|
|
+ }
|
|
|
+ }
|
|
|
+ write := func(v string) {
|
|
|
+ _, err := io.WriteString(pw, v)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("pipe write: %v", err)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ write("foo\n")
|
|
|
+ want("FOO")
|
|
|
+ write("bar\n")
|
|
|
+ want("BAR")
|
|
|
+ pw.Close()
|
|
|
+ want("bye.")
|
|
|
+ if err := bs.Err(); err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+}
|