Explorar o código

send frame should not block (#1062)

* make send frame async

* refactor

* fix goroutine leak

* add author

* fix test

* fix double creation of timeoutCh

* exit loop on err

* pull loop in to writeToConn

* timeout is shared between read and write

* move resetTimeout to callReq

* merge frameWriteArg into callReq
Zhixin Wen %!s(int64=7) %!d(string=hai) anos
pai
achega
acd6671d04
Modificáronse 2 ficheiros con 108 adicións e 57 borrados
  1. 2 1
      AUTHORS
  2. 106 56
      conn.go

+ 2 - 1
AUTHORS

@@ -100,4 +100,5 @@ Vivian Mathews <vivian.mathews.3@gmail.com>
 Sascha Steinbiss <satta@debian.org>
 Seth Rosenblum <seth.t.rosenblum@gmail.com>
 Javier Zunzunegui <javier.zunzunegui.b@gmail.com>
-Luke Hines <lukehines@protonmail.com>
+Luke Hines <lukehines@protonmail.com>
+Zhixin Wen <john.wenzhixin@hotmail.com>

+ 106 - 56
conn.go

@@ -147,6 +147,8 @@ type Conn struct {
 	quit   chan struct{}
 
 	timeouts int64
+
+	frameWriteArgChan chan *callReq
 }
 
 // Connect establishes a connection to a Cassandra node.
@@ -183,19 +185,20 @@ func (s *Session) dial(ip net.IP, port int, cfg *ConnConfig, errorHandler ConnEr
 	}
 
 	c := &Conn{
-		conn:         conn,
-		r:            bufio.NewReader(conn),
-		cfg:          cfg,
-		calls:        make(map[int]*callReq),
-		timeout:      cfg.Timeout,
-		version:      uint8(cfg.ProtoVersion),
-		addr:         conn.RemoteAddr().String(),
-		errorHandler: errorHandler,
-		compressor:   cfg.Compressor,
-		auth:         cfg.Authenticator,
-		quit:         make(chan struct{}),
-		session:      s,
-		streams:      streams.New(cfg.ProtoVersion),
+		conn:              conn,
+		r:                 bufio.NewReader(conn),
+		cfg:               cfg,
+		calls:             make(map[int]*callReq),
+		timeout:           cfg.Timeout,
+		version:           uint8(cfg.ProtoVersion),
+		addr:              conn.RemoteAddr().String(),
+		errorHandler:      errorHandler,
+		compressor:        cfg.Compressor,
+		auth:              cfg.Authenticator,
+		quit:              make(chan struct{}),
+		session:           s,
+		streams:           streams.New(cfg.ProtoVersion),
+		frameWriteArgChan: make(chan *callReq),
 	}
 
 	if cfg.Keepalive > 0 {
@@ -215,6 +218,9 @@ func (s *Session) dial(ip net.IP, port int, cfg *ConnConfig, errorHandler ConnEr
 
 	frameTicker := make(chan struct{}, 1)
 	startupErr := make(chan error)
+
+	go c.writeToConn()
+
 	go func() {
 		for range frameTicker {
 			err := c.recv()
@@ -436,6 +442,27 @@ func (c *Conn) discardFrame(head frameHeader) error {
 	return nil
 }
 
+// writeToConn() processes writing to the connection, which is required before any write
+// to Conn and is usually called in a separate goroutine.
+func (c *Conn) writeToConn() {
+	for {
+		select {
+		case call := <-c.frameWriteArgChan:
+			if err := call.req.writeFrame(call.framer, call.streamID); err != nil {
+				// I think this is the correct thing to do, im not entirely sure. It is not
+				// ideal as readers might still get some data, but they probably wont.
+				// Here we need to be careful as the stream is not available and if all
+				// writes just timeout or fail then the pool might use this connection to
+				// send a frame on, with all the streams used up and not returned.
+				c.closeWithError(err)
+				return
+			}
+		case <-c.quit:
+			return
+		}
+	}
+}
+
 type protocolError struct {
 	frame frame
 }
@@ -561,10 +588,32 @@ type callReq struct {
 	framer   *framer
 	timeout  chan struct{} // indicates to recv() that a call has timedout
 	streamID int           // current stream in use
+	req      frameWriter
 
 	timer *time.Timer
 }
 
+func (c *callReq) resetTimeout(timeout time.Duration) <-chan time.Time {
+	var timeoutCh <-chan time.Time
+	if timeout > 0 {
+		if c.timer == nil {
+			c.timer = time.NewTimer(0)
+			<-c.timer.C
+		} else {
+			if !c.timer.Stop() {
+				select {
+				case <-c.timer.C:
+				default:
+				}
+			}
+		}
+
+		c.timer.Reset(timeout)
+		timeoutCh = c.timer.C
+	}
+	return timeoutCh
+}
+
 func (c *Conn) exec(ctx context.Context, req frameWriter, tracer Tracer) (*framer, error) {
 	// TODO: move tracer onto conn
 	stream, ok := c.streams.GetStream()
@@ -588,45 +637,38 @@ func (c *Conn) exec(ctx context.Context, req frameWriter, tracer Tracer) (*frame
 	call.framer = framer
 	call.timeout = make(chan struct{})
 	call.streamID = stream
+	call.req = req
 	c.mu.Unlock()
 
 	if tracer != nil {
 		framer.trace()
 	}
 
-	err := req.writeFrame(framer, stream)
-	if err != nil {
-		// closeWithError will block waiting for this stream to either receive a response
-		// or for us to timeout, close the timeout chan here. Im not entirely sure
-		// but we should not get a response after an error on the write side.
-		close(call.timeout)
-		// I think this is the correct thing to do, im not entirely sure. It is not
-		// ideal as readers might still get some data, but they probably wont.
-		// Here we need to be careful as the stream is not available and if all
-		// writes just timeout or fail then the pool might use this connection to
-		// send a frame on, with all the streams used up and not returned.
-		c.closeWithError(err)
+	timeoutCh := call.resetTimeout(c.timeout)
+	if err := c.sendFrame(ctx, call, timeoutCh); err != nil {
 		return nil, err
 	}
 
-	var timeoutCh <-chan time.Time
-	if c.timeout > 0 {
-		if call.timer == nil {
-			call.timer = time.NewTimer(0)
-			<-call.timer.C
-		} else {
-			if !call.timer.Stop() {
-				select {
-				case <-call.timer.C:
-				default:
-				}
-			}
-		}
+	if err := c.getResp(ctx, call, timeoutCh); err != nil {
+		return nil, err
+	}
+
+	// dont release the stream if detect a timeout as another request can reuse
+	// that stream and get a response for the old request, which we have no
+	// easy way of detecting.
+	//
+	// Ensure that the stream is not released if there are potentially outstanding
+	// requests on the stream to prevent nil pointer dereferences in recv().
+	defer c.releaseStream(stream)
 
-		call.timer.Reset(c.timeout)
-		timeoutCh = call.timer.C
+	if v := framer.header.version.version(); v != c.version {
+		return nil, NewErrProtocol("unexpected protocol version in response: got %d expected %d", v, c.version)
 	}
 
+	return framer, nil
+}
+
+func (c *Conn) getResp(ctx context.Context, call *callReq, timeoutCh <-chan time.Time) error {
 	var ctxDone <-chan struct{}
 	if ctx != nil {
 		ctxDone = ctx.Done()
@@ -641,34 +683,42 @@ func (c *Conn) exec(ctx context.Context, req frameWriter, tracer Tracer) (*frame
 				// this is because the request is still outstanding and we have
 				// been handed another error from another stream which caused the
 				// connection to close.
-				c.releaseStream(stream)
+				c.releaseStream(call.streamID)
 			}
-			return nil, err
+			return err
 		}
+		return nil
 	case <-timeoutCh:
 		close(call.timeout)
 		c.handleTimeout()
-		return nil, ErrTimeoutNoResponse
+		return ErrTimeoutNoResponse
 	case <-ctxDone:
 		close(call.timeout)
-		return nil, ctx.Err()
+		return ctx.Err()
 	case <-c.quit:
-		return nil, ErrConnectionClosed
+		return ErrConnectionClosed
 	}
+}
 
-	// dont release the stream if detect a timeout as another request can reuse
-	// that stream and get a response for the old request, which we have no
-	// easy way of detecting.
-	//
-	// Ensure that the stream is not released if there are potentially outstanding
-	// requests on the stream to prevent nil pointer dereferences in recv().
-	defer c.releaseStream(stream)
-
-	if v := framer.header.version.version(); v != c.version {
-		return nil, NewErrProtocol("unexpected protocol version in response: got %d expected %d", v, c.version)
+func (c *Conn) sendFrame(ctx context.Context, call *callReq, timeoutCh <-chan time.Time) error {
+	var ctxDone <-chan struct{}
+	if ctx != nil {
+		ctxDone = ctx.Done()
 	}
 
-	return framer, nil
+	select {
+	case c.frameWriteArgChan <- call:
+		return nil
+	case <-timeoutCh:
+		close(call.timeout)
+		c.handleTimeout()
+		return ErrTimeoutNoResponse
+	case <-ctxDone:
+		close(call.timeout)
+		return ctx.Err()
+	case <-c.quit:
+		return ErrConnectionClosed
+	}
 }
 
 type preparedStatment struct {