|
|
@@ -35,6 +35,7 @@ package transport
|
|
|
|
|
|
import (
|
|
|
"bytes"
|
|
|
+ "fmt"
|
|
|
"io"
|
|
|
"math"
|
|
|
"net"
|
|
|
@@ -71,6 +72,9 @@ type http2Client struct {
|
|
|
shutdownChan chan struct{}
|
|
|
// errorChan is closed to notify the I/O error to the caller.
|
|
|
errorChan chan struct{}
|
|
|
+ // goAway is closed to notify the upper layer (i.e., addrConn.transportMonitor)
|
|
|
+ // that the server sent GoAway on this transport.
|
|
|
+ goAway chan struct{}
|
|
|
|
|
|
framer *framer
|
|
|
hBuf *bytes.Buffer // the buffer for HPACK encoding
|
|
|
@@ -97,41 +101,44 @@ type http2Client struct {
|
|
|
maxStreams int
|
|
|
// the per-stream outbound flow control window size set by the peer.
|
|
|
streamSendQuota uint32
|
|
|
+ // goAwayID records the Last-Stream-ID in the GoAway frame from the server.
|
|
|
+ goAwayID uint32
|
|
|
+ // prevGoAway ID records the Last-Stream-ID in the previous GOAway frame.
|
|
|
+ prevGoAwayID uint32
|
|
|
+}
|
|
|
+
|
|
|
+func dial(fn func(context.Context, string) (net.Conn, error), ctx context.Context, addr string) (net.Conn, error) {
|
|
|
+ if fn != nil {
|
|
|
+ return fn(ctx, addr)
|
|
|
+ }
|
|
|
+ return dialContext(ctx, "tcp", addr)
|
|
|
}
|
|
|
|
|
|
// newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2
|
|
|
// and starts to receive messages on it. Non-nil error returns if construction
|
|
|
// fails.
|
|
|
-func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err error) {
|
|
|
- if opts.Dialer == nil {
|
|
|
- // Set the default Dialer.
|
|
|
- opts.Dialer = func(addr string, timeout time.Duration) (net.Conn, error) {
|
|
|
- return net.DialTimeout("tcp", addr, timeout)
|
|
|
- }
|
|
|
- }
|
|
|
+func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ ClientTransport, err error) {
|
|
|
scheme := "http"
|
|
|
- startT := time.Now()
|
|
|
- timeout := opts.Timeout
|
|
|
- conn, connErr := opts.Dialer(addr, timeout)
|
|
|
+ conn, connErr := dial(opts.Dialer, ctx, addr)
|
|
|
if connErr != nil {
|
|
|
- return nil, ConnectionErrorf("transport: %v", connErr)
|
|
|
+ return nil, ConnectionErrorf(true, connErr, "transport: %v", connErr)
|
|
|
}
|
|
|
+ // Any further errors will close the underlying connection
|
|
|
+ defer func(conn net.Conn) {
|
|
|
+ if err != nil {
|
|
|
+ conn.Close()
|
|
|
+ }
|
|
|
+ }(conn)
|
|
|
var authInfo credentials.AuthInfo
|
|
|
- if opts.TransportCredentials != nil {
|
|
|
+ if creds := opts.TransportCredentials; creds != nil {
|
|
|
scheme = "https"
|
|
|
- if timeout > 0 {
|
|
|
- timeout -= time.Since(startT)
|
|
|
- }
|
|
|
- conn, authInfo, connErr = opts.TransportCredentials.ClientHandshake(addr, conn, timeout)
|
|
|
+ conn, authInfo, connErr = creds.ClientHandshake(ctx, addr, conn)
|
|
|
}
|
|
|
if connErr != nil {
|
|
|
- return nil, ConnectionErrorf("transport: %v", connErr)
|
|
|
+ // Credentials handshake error is not a temporary error (unless the error
|
|
|
+ // was the connection closing).
|
|
|
+ return nil, ConnectionErrorf(connErr == io.EOF, connErr, "transport: %v", connErr)
|
|
|
}
|
|
|
- defer func() {
|
|
|
- if err != nil {
|
|
|
- conn.Close()
|
|
|
- }
|
|
|
- }()
|
|
|
ua := primaryUA
|
|
|
if opts.UserAgent != "" {
|
|
|
ua = opts.UserAgent + " " + ua
|
|
|
@@ -147,6 +154,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
|
|
|
writableChan: make(chan int, 1),
|
|
|
shutdownChan: make(chan struct{}),
|
|
|
errorChan: make(chan struct{}),
|
|
|
+ goAway: make(chan struct{}),
|
|
|
framer: newFramer(conn),
|
|
|
hBuf: &buf,
|
|
|
hEnc: hpack.NewEncoder(&buf),
|
|
|
@@ -168,11 +176,11 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
|
|
|
n, err := t.conn.Write(clientPreface)
|
|
|
if err != nil {
|
|
|
t.Close()
|
|
|
- return nil, ConnectionErrorf("transport: %v", err)
|
|
|
+ return nil, ConnectionErrorf(true, err, "transport: %v", err)
|
|
|
}
|
|
|
if n != len(clientPreface) {
|
|
|
t.Close()
|
|
|
- return nil, ConnectionErrorf("transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface))
|
|
|
+ return nil, ConnectionErrorf(true, err, "transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface))
|
|
|
}
|
|
|
if initialWindowSize != defaultWindowSize {
|
|
|
err = t.framer.writeSettings(true, http2.Setting{
|
|
|
@@ -184,13 +192,13 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
|
|
|
}
|
|
|
if err != nil {
|
|
|
t.Close()
|
|
|
- return nil, ConnectionErrorf("transport: %v", err)
|
|
|
+ return nil, ConnectionErrorf(true, err, "transport: %v", err)
|
|
|
}
|
|
|
// Adjust the connection flow control window if needed.
|
|
|
if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 {
|
|
|
if err := t.framer.writeWindowUpdate(true, 0, delta); err != nil {
|
|
|
t.Close()
|
|
|
- return nil, ConnectionErrorf("transport: %v", err)
|
|
|
+ return nil, ConnectionErrorf(true, err, "transport: %v", err)
|
|
|
}
|
|
|
}
|
|
|
go t.controller()
|
|
|
@@ -202,6 +210,8 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
|
|
|
// TODO(zhaoq): Handle uint32 overflow of Stream.id.
|
|
|
s := &Stream{
|
|
|
id: t.nextID,
|
|
|
+ done: make(chan struct{}),
|
|
|
+ goAway: make(chan struct{}),
|
|
|
method: callHdr.Method,
|
|
|
sendCompress: callHdr.SendCompress,
|
|
|
buf: newRecvBuffer(),
|
|
|
@@ -216,8 +226,9 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
|
|
|
// Make a stream be able to cancel the pending operations by itself.
|
|
|
s.ctx, s.cancel = context.WithCancel(ctx)
|
|
|
s.dec = &recvBufferReader{
|
|
|
- ctx: s.ctx,
|
|
|
- recv: s.buf,
|
|
|
+ ctx: s.ctx,
|
|
|
+ goAway: s.goAway,
|
|
|
+ recv: s.buf,
|
|
|
}
|
|
|
return s
|
|
|
}
|
|
|
@@ -271,6 +282,10 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
|
|
|
t.mu.Unlock()
|
|
|
return nil, ErrConnClosing
|
|
|
}
|
|
|
+ if t.state == draining {
|
|
|
+ t.mu.Unlock()
|
|
|
+ return nil, ErrStreamDrain
|
|
|
+ }
|
|
|
if t.state != reachable {
|
|
|
t.mu.Unlock()
|
|
|
return nil, ErrConnClosing
|
|
|
@@ -278,7 +293,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
|
|
|
checkStreamsQuota := t.streamsQuota != nil
|
|
|
t.mu.Unlock()
|
|
|
if checkStreamsQuota {
|
|
|
- sq, err := wait(ctx, t.shutdownChan, t.streamsQuota.acquire())
|
|
|
+ sq, err := wait(ctx, nil, nil, t.shutdownChan, t.streamsQuota.acquire())
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
@@ -287,7 +302,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
|
|
|
t.streamsQuota.add(sq - 1)
|
|
|
}
|
|
|
}
|
|
|
- if _, err := wait(ctx, t.shutdownChan, t.writableChan); err != nil {
|
|
|
+ if _, err := wait(ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
|
|
|
// Return the quota back now because there is no stream returned to the caller.
|
|
|
if _, ok := err.(StreamError); ok && checkStreamsQuota {
|
|
|
t.streamsQuota.add(1)
|
|
|
@@ -295,6 +310,15 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
|
|
|
return nil, err
|
|
|
}
|
|
|
t.mu.Lock()
|
|
|
+ if t.state == draining {
|
|
|
+ t.mu.Unlock()
|
|
|
+ if checkStreamsQuota {
|
|
|
+ t.streamsQuota.add(1)
|
|
|
+ }
|
|
|
+ // Need to make t writable again so that the rpc in flight can still proceed.
|
|
|
+ t.writableChan <- 0
|
|
|
+ return nil, ErrStreamDrain
|
|
|
+ }
|
|
|
if t.state != reachable {
|
|
|
t.mu.Unlock()
|
|
|
return nil, ErrConnClosing
|
|
|
@@ -329,7 +353,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
|
|
|
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress})
|
|
|
}
|
|
|
if timeout > 0 {
|
|
|
- t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: timeoutEncode(timeout)})
|
|
|
+ t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: encodeTimeout(timeout)})
|
|
|
}
|
|
|
for k, v := range authData {
|
|
|
// Capital header names are illegal in HTTP/2.
|
|
|
@@ -384,7 +408,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
|
|
|
}
|
|
|
if err != nil {
|
|
|
t.notifyError(err)
|
|
|
- return nil, ConnectionErrorf("transport: %v", err)
|
|
|
+ return nil, ConnectionErrorf(true, err, "transport: %v", err)
|
|
|
}
|
|
|
}
|
|
|
t.writableChan <- 0
|
|
|
@@ -403,22 +427,17 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
|
|
|
if t.streamsQuota != nil {
|
|
|
updateStreams = true
|
|
|
}
|
|
|
- if t.state == draining && len(t.activeStreams) == 1 {
|
|
|
+ delete(t.activeStreams, s.id)
|
|
|
+ if t.state == draining && len(t.activeStreams) == 0 {
|
|
|
// The transport is draining and s is the last live stream on t.
|
|
|
t.mu.Unlock()
|
|
|
t.Close()
|
|
|
return
|
|
|
}
|
|
|
- delete(t.activeStreams, s.id)
|
|
|
t.mu.Unlock()
|
|
|
if updateStreams {
|
|
|
t.streamsQuota.add(1)
|
|
|
}
|
|
|
- // In case stream sending and receiving are invoked in separate
|
|
|
- // goroutines (e.g., bi-directional streaming), the caller needs
|
|
|
- // to call cancel on the stream to interrupt the blocking on
|
|
|
- // other goroutines.
|
|
|
- s.cancel()
|
|
|
s.mu.Lock()
|
|
|
if q := s.fc.resetPendingData(); q > 0 {
|
|
|
if n := t.fc.onRead(q); n > 0 {
|
|
|
@@ -445,13 +464,13 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
|
|
|
// accessed any more.
|
|
|
func (t *http2Client) Close() (err error) {
|
|
|
t.mu.Lock()
|
|
|
- if t.state == reachable {
|
|
|
- close(t.errorChan)
|
|
|
- }
|
|
|
if t.state == closing {
|
|
|
t.mu.Unlock()
|
|
|
return
|
|
|
}
|
|
|
+ if t.state == reachable || t.state == draining {
|
|
|
+ close(t.errorChan)
|
|
|
+ }
|
|
|
t.state = closing
|
|
|
t.mu.Unlock()
|
|
|
close(t.shutdownChan)
|
|
|
@@ -475,10 +494,35 @@ func (t *http2Client) Close() (err error) {
|
|
|
|
|
|
func (t *http2Client) GracefulClose() error {
|
|
|
t.mu.Lock()
|
|
|
- if t.state == closing {
|
|
|
+ switch t.state {
|
|
|
+ case unreachable:
|
|
|
+ // The server may close the connection concurrently. t is not available for
|
|
|
+ // any streams. Close it now.
|
|
|
+ t.mu.Unlock()
|
|
|
+ t.Close()
|
|
|
+ return nil
|
|
|
+ case closing:
|
|
|
t.mu.Unlock()
|
|
|
return nil
|
|
|
}
|
|
|
+ // Notify the streams which were initiated after the server sent GOAWAY.
|
|
|
+ select {
|
|
|
+ case <-t.goAway:
|
|
|
+ n := t.prevGoAwayID
|
|
|
+ if n == 0 && t.nextID > 1 {
|
|
|
+ n = t.nextID - 2
|
|
|
+ }
|
|
|
+ m := t.goAwayID + 2
|
|
|
+ if m == 2 {
|
|
|
+ m = 1
|
|
|
+ }
|
|
|
+ for i := m; i <= n; i += 2 {
|
|
|
+ if s, ok := t.activeStreams[i]; ok {
|
|
|
+ close(s.goAway)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ default:
|
|
|
+ }
|
|
|
if t.state == draining {
|
|
|
t.mu.Unlock()
|
|
|
return nil
|
|
|
@@ -504,15 +548,15 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
|
|
|
size := http2MaxFrameLen
|
|
|
s.sendQuotaPool.add(0)
|
|
|
// Wait until the stream has some quota to send the data.
|
|
|
- sq, err := wait(s.ctx, t.shutdownChan, s.sendQuotaPool.acquire())
|
|
|
+ sq, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, s.sendQuotaPool.acquire())
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
t.sendQuotaPool.add(0)
|
|
|
// Wait until the transport has some quota to send the data.
|
|
|
- tq, err := wait(s.ctx, t.shutdownChan, t.sendQuotaPool.acquire())
|
|
|
+ tq, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, t.sendQuotaPool.acquire())
|
|
|
if err != nil {
|
|
|
- if _, ok := err.(StreamError); ok {
|
|
|
+ if _, ok := err.(StreamError); ok || err == io.EOF {
|
|
|
t.sendQuotaPool.cancel()
|
|
|
}
|
|
|
return err
|
|
|
@@ -544,8 +588,8 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
|
|
|
// Indicate there is a writer who is about to write a data frame.
|
|
|
t.framer.adjustNumWriters(1)
|
|
|
// Got some quota. Try to acquire writing privilege on the transport.
|
|
|
- if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil {
|
|
|
- if _, ok := err.(StreamError); ok {
|
|
|
+ if _, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, t.writableChan); err != nil {
|
|
|
+ if _, ok := err.(StreamError); ok || err == io.EOF {
|
|
|
// Return the connection quota back.
|
|
|
t.sendQuotaPool.add(len(p))
|
|
|
}
|
|
|
@@ -578,7 +622,7 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
|
|
|
// invoked.
|
|
|
if err := t.framer.writeData(forceFlush, s.id, endStream, p); err != nil {
|
|
|
t.notifyError(err)
|
|
|
- return ConnectionErrorf("transport: %v", err)
|
|
|
+ return ConnectionErrorf(true, err, "transport: %v", err)
|
|
|
}
|
|
|
if t.framer.adjustNumWriters(-1) == 0 {
|
|
|
t.framer.flushWrite()
|
|
|
@@ -593,11 +637,7 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
|
|
|
}
|
|
|
s.mu.Lock()
|
|
|
if s.state != streamDone {
|
|
|
- if s.state == streamReadDone {
|
|
|
- s.state = streamDone
|
|
|
- } else {
|
|
|
- s.state = streamWriteDone
|
|
|
- }
|
|
|
+ s.state = streamWriteDone
|
|
|
}
|
|
|
s.mu.Unlock()
|
|
|
return nil
|
|
|
@@ -630,7 +670,7 @@ func (t *http2Client) updateWindow(s *Stream, n uint32) {
|
|
|
func (t *http2Client) handleData(f *http2.DataFrame) {
|
|
|
size := len(f.Data())
|
|
|
if err := t.fc.onData(uint32(size)); err != nil {
|
|
|
- t.notifyError(ConnectionErrorf("%v", err))
|
|
|
+ t.notifyError(ConnectionErrorf(true, err, "%v", err))
|
|
|
return
|
|
|
}
|
|
|
// Select the right stream to dispatch.
|
|
|
@@ -655,6 +695,7 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
|
|
|
s.state = streamDone
|
|
|
s.statusCode = codes.Internal
|
|
|
s.statusDesc = err.Error()
|
|
|
+ close(s.done)
|
|
|
s.mu.Unlock()
|
|
|
s.write(recvMsg{err: io.EOF})
|
|
|
t.controlBuf.put(&resetStream{s.id, http2.ErrCodeFlowControl})
|
|
|
@@ -672,13 +713,14 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
|
|
|
// the read direction is closed, and set the status appropriately.
|
|
|
if f.FrameHeader.Flags.Has(http2.FlagDataEndStream) {
|
|
|
s.mu.Lock()
|
|
|
- if s.state == streamWriteDone {
|
|
|
- s.state = streamDone
|
|
|
- } else {
|
|
|
- s.state = streamReadDone
|
|
|
+ if s.state == streamDone {
|
|
|
+ s.mu.Unlock()
|
|
|
+ return
|
|
|
}
|
|
|
+ s.state = streamDone
|
|
|
s.statusCode = codes.Internal
|
|
|
s.statusDesc = "server closed the stream without sending trailers"
|
|
|
+ close(s.done)
|
|
|
s.mu.Unlock()
|
|
|
s.write(recvMsg{err: io.EOF})
|
|
|
}
|
|
|
@@ -704,6 +746,8 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) {
|
|
|
grpclog.Println("transport: http2Client.handleRSTStream found no mapped gRPC status for the received http2 error ", f.ErrCode)
|
|
|
s.statusCode = codes.Unknown
|
|
|
}
|
|
|
+ s.statusDesc = fmt.Sprintf("stream terminated by RST_STREAM with error code: %d", f.ErrCode)
|
|
|
+ close(s.done)
|
|
|
s.mu.Unlock()
|
|
|
s.write(recvMsg{err: io.EOF})
|
|
|
}
|
|
|
@@ -728,7 +772,32 @@ func (t *http2Client) handlePing(f *http2.PingFrame) {
|
|
|
}
|
|
|
|
|
|
func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
|
|
|
- // TODO(zhaoq): GoAwayFrame handler to be implemented
|
|
|
+ t.mu.Lock()
|
|
|
+ if t.state == reachable || t.state == draining {
|
|
|
+ if f.LastStreamID > 0 && f.LastStreamID%2 != 1 {
|
|
|
+ t.mu.Unlock()
|
|
|
+ t.notifyError(ConnectionErrorf(true, nil, "received illegal http2 GOAWAY frame: stream ID %d is even", f.LastStreamID))
|
|
|
+ return
|
|
|
+ }
|
|
|
+ select {
|
|
|
+ case <-t.goAway:
|
|
|
+ id := t.goAwayID
|
|
|
+ // t.goAway has been closed (i.e.,multiple GoAways).
|
|
|
+ if id < f.LastStreamID {
|
|
|
+ t.mu.Unlock()
|
|
|
+ t.notifyError(ConnectionErrorf(true, nil, "received illegal http2 GOAWAY frame: previously recv GOAWAY frame with LastStramID %d, currently recv %d", id, f.LastStreamID))
|
|
|
+ return
|
|
|
+ }
|
|
|
+ t.prevGoAwayID = id
|
|
|
+ t.goAwayID = f.LastStreamID
|
|
|
+ t.mu.Unlock()
|
|
|
+ return
|
|
|
+ default:
|
|
|
+ }
|
|
|
+ t.goAwayID = f.LastStreamID
|
|
|
+ close(t.goAway)
|
|
|
+ }
|
|
|
+ t.mu.Unlock()
|
|
|
}
|
|
|
|
|
|
func (t *http2Client) handleWindowUpdate(f *http2.WindowUpdateFrame) {
|
|
|
@@ -780,11 +849,11 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
|
|
|
if len(state.mdata) > 0 {
|
|
|
s.trailer = state.mdata
|
|
|
}
|
|
|
- s.state = streamDone
|
|
|
s.statusCode = state.statusCode
|
|
|
s.statusDesc = state.statusDesc
|
|
|
+ close(s.done)
|
|
|
+ s.state = streamDone
|
|
|
s.mu.Unlock()
|
|
|
-
|
|
|
s.write(recvMsg{err: io.EOF})
|
|
|
}
|
|
|
|
|
|
@@ -937,13 +1006,22 @@ func (t *http2Client) Error() <-chan struct{} {
|
|
|
return t.errorChan
|
|
|
}
|
|
|
|
|
|
+func (t *http2Client) GoAway() <-chan struct{} {
|
|
|
+ return t.goAway
|
|
|
+}
|
|
|
+
|
|
|
func (t *http2Client) notifyError(err error) {
|
|
|
t.mu.Lock()
|
|
|
- defer t.mu.Unlock()
|
|
|
// make sure t.errorChan is closed only once.
|
|
|
+ if t.state == draining {
|
|
|
+ t.mu.Unlock()
|
|
|
+ t.Close()
|
|
|
+ return
|
|
|
+ }
|
|
|
if t.state == reachable {
|
|
|
t.state = unreachable
|
|
|
close(t.errorChan)
|
|
|
grpclog.Printf("transport: http2Client.notifyError got notified that the client transport was broken %v.", err)
|
|
|
}
|
|
|
+ t.mu.Unlock()
|
|
|
}
|