浏览代码

transport: start of re-using client connections

Brad Fitzpatrick 11 年之前
父节点
当前提交
ddf20046ea
共有 1 个文件被更改,包括 59 次插入19 次删除
  1. 59 19
      transport.go

+ 59 - 19
transport.go

@@ -30,16 +30,12 @@ type Transport struct {
 }
 
 type clientConn struct {
-	tconn *tls.Conn
-	bw    *bufio.Writer
-	br    *bufio.Reader
-	fr    *Framer
+	tconn    *tls.Conn
+	tlsState *tls.ConnectionState
 
 	readerDone chan struct{} // closed on error
 	readerErr  error         // set before readerDone is closed
 
-	werr error // first write error that has occurred
-
 	hbuf bytes.Buffer // HPACK encoder writes into this
 	henc *hpack.Encoder
 
@@ -53,11 +49,15 @@ type clientConn struct {
 	mu           sync.Mutex
 	streams      map[uint32]*clientStream
 	nextStreamID uint32
+	bw           *bufio.Writer
+	werr         error // first write error that has occurred
+	br           *bufio.Reader
+	fr           *Framer
 }
 
 type clientStream struct {
 	ID   uint32
-	resc chan *http.Response
+	resc chan resAndError
 	pw   *io.PipeWriter
 	pr   *io.PipeReader
 }
@@ -89,6 +89,30 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
 		host = req.URL.Host
 		port = "443"
 	}
+
+	for {
+		cc, err := t.getClientConn(host, port)
+		if err != nil {
+			return nil, err
+		}
+		res, err := cc.roundTrip(req)
+		if isShutdownError(err) {
+			continue
+		}
+		if err != nil {
+			return nil, err
+		}
+		return res, nil
+	}
+}
+
+func isShutdownError(err error) bool {
+	// TODO: implement
+	return false
+}
+
+func (t *Transport) getClientConn(host, port string) (*clientConn, error) {
+	// TODO: cache these
 	cfg := &tls.Config{
 		ServerName:         host,
 		NextProtos:         []string{NextProtoTLS},
@@ -120,6 +144,7 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
 
 	cc := &clientConn{
 		tconn:        tconn,
+		tlsState:     &state,
 		readerDone:   make(chan struct{}),
 		nextStreamID: 1,
 		streams:      make(map[uint32]*clientStream),
@@ -163,6 +188,11 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
 	cc.hdec = hpack.NewDecoder(initialHeaderTableSize, cc.onNewHeaderField)
 
 	go cc.readLoop()
+	return cc, nil
+}
+
+func (cc *clientConn) roundTrip(req *http.Request) (*http.Response, error) {
+	cc.mu.Lock()
 
 	cs := cc.newStream()
 	hasBody := false // TODO
@@ -190,16 +220,24 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
 		}
 	}
 	cc.bw.Flush()
-	if cc.werr != nil {
-		return nil, cc.werr
+	werr := cc.werr
+	cc.mu.Unlock()
+
+	if werr != nil {
+		return nil, werr
 	}
 
-	resp := <-cs.resc
-	resp.Request = req
-	resp.TLS = &state
-	return resp, nil
+	re := <-cs.resc
+	if re.err != nil {
+		return nil, re.err
+	}
+	res := re.res
+	res.Request = req
+	res.TLS = cc.tlsState
+	return res, nil
 }
 
+// requires cc.mu be held.
 func (cc *clientConn) encodeHeaders(req *http.Request) []byte {
 	cc.hbuf.Reset()
 
@@ -236,17 +274,19 @@ func (cc *clientConn) writeHeader(name, value string) {
 	cc.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
 }
 
-func (cc *clientConn) newStream() *clientStream {
-	cc.mu.Lock()
-	defer cc.mu.Unlock()
+type resAndError struct {
+	res *http.Response
+	err error
+}
 
+// requires cc.mu be held.
+func (cc *clientConn) newStream() *clientStream {
 	cs := &clientStream{
 		ID:   cc.nextStreamID,
-		resc: make(chan *http.Response, 1),
+		resc: make(chan resAndError, 1),
 	}
 	cc.nextStreamID += 2
 	cc.streams[cs.ID] = cs
-
 	return cs
 }
 
@@ -307,7 +347,7 @@ func (cc *clientConn) readLoop() {
 				panic("couldn't find stream") // TODO be graceful
 			}
 			cc.nextRes.Body = cs.pr
-			cs.resc <- cc.nextRes
+			cs.resc <- resAndError{res: cc.nextRes}
 		}
 	}
 }