Przeglądaj źródła

conn: dont propogate context cancelled in prepare (#1368)

* conn: dont propogate context cancelled in prepare

If a callers context is cancelled and they won the race to prepare a
statement, dont stop the prepare and fail it with context failed error
as other callers may be waiting for us to finish.

Achieve this by replacing the waitgroup used to wait for prepared to
finish with a channel which is closed by the preparer. Do the preparing
in a goroutine and use context.Background in the exec call so that the
prepare will always finish. All callers now use the same flow to wait
for the result of the prepare waiting against their own context so they
can bail out individually.

fixes #1341

* ensure closing a connection kills inflight prepares
Chris Bannister 6 lat temu
rodzic
commit
ae2f7fc85f
5 zmienionych plików z 139 dodań i 122 usunięć
  1. 123 111
      conn.go
  2. 1 1
      conn_test.go
  3. 1 1
      connectionpool.go
  4. 4 4
      control.go
  5. 10 5
      session.go

+ 123 - 111
conn.go

@@ -155,25 +155,26 @@ type Conn struct {
 	session *Session
 
 	closed int32
-	quit   chan struct{}
+	ctx    context.Context
+	cancel context.CancelFunc
 
 	timeouts int64
 }
 
 // connect establishes a connection to a Cassandra node using session's connection config.
-func (s *Session) connect(host *HostInfo, errorHandler ConnErrorHandler) (*Conn, error) {
-	return s.dial(host, s.connCfg, errorHandler)
+func (s *Session) connect(ctx context.Context, host *HostInfo, errorHandler ConnErrorHandler) (*Conn, error) {
+	return s.dial(ctx, host, s.connCfg, errorHandler)
 }
 
 // dial establishes a connection to a Cassandra node and notifies the session's connectObserver.
-func (s *Session) dial(host *HostInfo, connConfig *ConnConfig, errorHandler ConnErrorHandler) (*Conn, error) {
+func (s *Session) dial(ctx context.Context, host *HostInfo, connConfig *ConnConfig, errorHandler ConnErrorHandler) (*Conn, error) {
 	var obs ObservedConnect
 	if s.connectObserver != nil {
 		obs.Host = host
 		obs.Start = time.Now()
 	}
 
-	conn, err := s.dialWithoutObserver(host, connConfig, errorHandler)
+	conn, err := s.dialWithoutObserver(ctx, host, connConfig, errorHandler)
 
 	if s.connectObserver != nil {
 		obs.End = time.Now()
@@ -187,7 +188,7 @@ func (s *Session) dial(host *HostInfo, connConfig *ConnConfig, errorHandler Conn
 // dialWithoutObserver establishes connection to a Cassandra node.
 //
 // dialWithoutObserver does not notify the connection observer, so you most probably want to call dial() instead.
-func (s *Session) dialWithoutObserver(host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHandler) (*Conn, error) {
+func (s *Session) dialWithoutObserver(ctx context.Context, host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHandler) (*Conn, error) {
 	ip := host.ConnectAddress()
 	port := host.port
 
@@ -198,11 +199,6 @@ func (s *Session) dialWithoutObserver(host *HostInfo, cfg *ConnConfig, errorHand
 		panic(fmt.Sprintf("host missing port: %v", port))
 	}
 
-	var (
-		err  error
-		conn net.Conn
-	)
-
 	dialer := &net.Dialer{
 		Timeout: cfg.ConnectTimeout,
 	}
@@ -210,18 +206,22 @@ func (s *Session) dialWithoutObserver(host *HostInfo, cfg *ConnConfig, errorHand
 		dialer.KeepAlive = cfg.Keepalive
 	}
 
+	conn, err := dialer.DialContext(ctx, "tcp", host.HostnameAndPort())
+	if err != nil {
+		return nil, err
+	}
 	if cfg.tlsConfig != nil {
 		// the TLS config is safe to be reused by connections but it must not
 		// be modified after being used.
-		conn, err = tls.DialWithDialer(dialer, "tcp", host.HostnameAndPort(), cfg.tlsConfig)
-	} else {
-		conn, err = dialer.Dial("tcp", host.HostnameAndPort())
-	}
-
-	if err != nil {
-		return nil, err
+		tconn := tls.Client(conn, cfg.tlsConfig)
+		if err := tconn.Handshake(); err != nil {
+			conn.Close()
+			return nil, err
+		}
+		conn = tconn
 	}
 
+	ctx, cancel := context.WithCancel(ctx)
 	c := &Conn{
 		conn:          conn,
 		r:             bufio.NewReader(conn),
@@ -231,7 +231,6 @@ func (s *Session) dialWithoutObserver(host *HostInfo, cfg *ConnConfig, errorHand
 		addr:          conn.RemoteAddr().String(),
 		errorHandler:  errorHandler,
 		compressor:    cfg.Compressor,
-		quit:          make(chan struct{}),
 		session:       s,
 		streams:       streams.New(cfg.ProtoVersion),
 		host:          host,
@@ -240,50 +239,51 @@ func (s *Session) dialWithoutObserver(host *HostInfo, cfg *ConnConfig, errorHand
 			w:       conn,
 			timeout: cfg.Timeout,
 		},
+		ctx:    ctx,
+		cancel: cancel,
 	}
 
-	if cfg.AuthProvider != nil {
-		c.auth, err = cfg.AuthProvider(host)
-		if err != nil {
-			return nil, err
-		}
-	} else {
-		c.auth = cfg.Authenticator
+	if err := c.init(ctx); err != nil {
+		cancel()
+		c.Close()
+		return nil, err
 	}
 
-	var (
-		ctx    context.Context
-		cancel func()
-	)
-	if cfg.ConnectTimeout > 0 {
-		ctx, cancel = context.WithTimeout(context.TODO(), cfg.ConnectTimeout)
+	return c, nil
+}
+
+func (c *Conn) init(ctx context.Context) error {
+	if c.session.cfg.AuthProvider != nil {
+		var err error
+		c.auth, err = c.cfg.AuthProvider(c.host)
+		if err != nil {
+			return err
+		}
 	} else {
-		ctx, cancel = context.WithCancel(context.TODO())
+		c.auth = c.cfg.Authenticator
 	}
-	defer cancel()
 
 	startup := &startupCoordinator{
 		frameTicker: make(chan struct{}),
 		conn:        c,
 	}
 
-	c.timeout = cfg.ConnectTimeout
+	c.timeout = c.cfg.ConnectTimeout
 	if err := startup.setupConn(ctx); err != nil {
-		c.close()
-		return nil, err
+		return err
 	}
 
-	c.timeout = cfg.Timeout
+	c.timeout = c.cfg.Timeout
 
 	// dont coalesce startup frames
-	if s.cfg.WriteCoalesceWaitTime > 0 && !cfg.disableCoalesce {
-		c.w = newWriteCoalescer(conn, c.timeout, s.cfg.WriteCoalesceWaitTime, c.quit)
+	if c.session.cfg.WriteCoalesceWaitTime > 0 && !c.cfg.disableCoalesce {
+		c.w = newWriteCoalescer(c.conn, c.timeout, c.session.cfg.WriteCoalesceWaitTime, ctx.Done())
 	}
 
-	go c.serve()
-	go c.heartBeat()
+	go c.serve(ctx)
+	go c.heartBeat(ctx)
 
-	return c, nil
+	return nil
 }
 
 func (c *Conn) Write(p []byte) (n int, err error) {
@@ -319,10 +319,18 @@ type startupCoordinator struct {
 }
 
 func (s *startupCoordinator) setupConn(ctx context.Context) error {
+	var cancel context.CancelFunc
+	if s.conn.timeout > 0 {
+		ctx, cancel = context.WithTimeout(ctx, s.conn.timeout)
+	} else {
+		ctx, cancel = context.WithCancel(ctx)
+	}
+	defer cancel()
+
 	startupErr := make(chan error)
 	go func() {
 		for range s.frameTicker {
-			err := s.conn.recv()
+			err := s.conn.recv(ctx)
 			if err != nil {
 				select {
 				case startupErr <- err:
@@ -482,7 +490,7 @@ func (c *Conn) closeWithError(err error) {
 	}
 
 	// if error was nil then unblock the quit channel
-	close(c.quit)
+	c.cancel()
 	cerr := c.close()
 
 	if err != nil {
@@ -504,10 +512,10 @@ func (c *Conn) Close() {
 // Serve starts the stream multiplexer for this connection, which is required
 // to execute any queries. This method runs as long as the connection is
 // open and is therefore usually called in a separate goroutine.
-func (c *Conn) serve() {
+func (c *Conn) serve(ctx context.Context) {
 	var err error
 	for err == nil {
-		err = c.recv()
+		err = c.recv(ctx)
 	}
 
 	c.closeWithError(err)
@@ -532,7 +540,7 @@ func (p *protocolError) Error() string {
 	return fmt.Sprintf("gocql: received unexpected frame on stream %d: %v", p.frame.Header().stream, p.frame)
 }
 
-func (c *Conn) heartBeat() {
+func (c *Conn) heartBeat(ctx context.Context) {
 	sleepTime := 1 * time.Second
 	timer := time.NewTimer(sleepTime)
 	defer timer.Stop()
@@ -548,7 +556,7 @@ func (c *Conn) heartBeat() {
 		timer.Reset(sleepTime)
 
 		select {
-		case <-c.quit:
+		case <-ctx.Done():
 			return
 		case <-timer.C:
 		}
@@ -579,7 +587,7 @@ func (c *Conn) heartBeat() {
 	}
 }
 
-func (c *Conn) recv() error {
+func (c *Conn) recv(ctx context.Context) error {
 	// not safe for concurrent reads
 
 	// read a full header, ignore timeouts, as this is being ran in a loop
@@ -663,7 +671,7 @@ func (c *Conn) recv() error {
 	case call.resp <- err:
 	case <-call.timeout:
 		c.releaseStream(call)
-	case <-c.quit:
+	case <-ctx.Done():
 	}
 
 	return nil
@@ -919,7 +927,7 @@ func (c *Conn) exec(ctx context.Context, req frameWriter, tracer Tracer) (*frame
 	case <-ctxDone:
 		close(call.timeout)
 		return nil, ctx.Err()
-	case <-c.quit:
+	case <-c.ctx.Done():
 		return nil, ErrConnectionClosed
 	}
 
@@ -945,8 +953,8 @@ type preparedStatment struct {
 }
 
 type inflightPrepare struct {
-	wg  sync.WaitGroup
-	err error
+	done chan struct{}
+	err  error
 
 	preparedStatment *preparedStatment
 }
@@ -954,69 +962,76 @@ type inflightPrepare struct {
 func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer) (*preparedStatment, error) {
 	stmtCacheKey := c.session.stmtsLRU.keyFor(c.addr, c.currentKeyspace, stmt)
 	flight, ok := c.session.stmtsLRU.execIfMissing(stmtCacheKey, func(lru *lru.Cache) *inflightPrepare {
-		flight := new(inflightPrepare)
-		flight.wg.Add(1)
+		flight := &inflightPrepare{
+			done: make(chan struct{}),
+		}
 		lru.Add(stmtCacheKey, flight)
 		return flight
 	})
 
-	if ok {
-		flight.wg.Wait()
-		return flight.preparedStatment, flight.err
-	}
+	if !ok {
+		go func() {
+			defer close(flight.done)
 
-	prep := &writePrepareFrame{
-		statement: stmt,
-	}
-	if c.version > protoVersion4 {
-		prep.keyspace = c.currentKeyspace
-	}
+			prep := &writePrepareFrame{
+				statement: stmt,
+			}
+			if c.version > protoVersion4 {
+				prep.keyspace = c.currentKeyspace
+			}
 
-	framer, err := c.exec(ctx, prep, tracer)
-	if err != nil {
-		flight.err = err
-		flight.wg.Done()
-		c.session.stmtsLRU.remove(stmtCacheKey)
-		return nil, err
-	}
+			// we won the race to do the load, if our context is canceled we shouldnt
+			// stop the load as other callers are waiting for it but this caller should get
+			// their context cancelled error.
+			framer, err := c.exec(c.ctx, prep, tracer)
+			if err != nil {
+				flight.err = err
+				c.session.stmtsLRU.remove(stmtCacheKey)
+				return
+			}
 
-	frame, err := framer.parseFrame()
-	if err != nil {
-		flight.err = err
-		flight.wg.Done()
-		c.session.stmtsLRU.remove(stmtCacheKey)
-		return nil, err
-	}
+			frame, err := framer.parseFrame()
+			if err != nil {
+				flight.err = err
+				c.session.stmtsLRU.remove(stmtCacheKey)
+				return
+			}
 
-	// TODO(zariel): tidy this up, simplify handling of frame parsing so its not duplicated
-	// everytime we need to parse a frame.
-	if len(framer.traceID) > 0 && tracer != nil {
-		tracer.Trace(framer.traceID)
-	}
+			// TODO(zariel): tidy this up, simplify handling of frame parsing so its not duplicated
+			// everytime we need to parse a frame.
+			if len(framer.traceID) > 0 && tracer != nil {
+				tracer.Trace(framer.traceID)
+			}
 
-	switch x := frame.(type) {
-	case *resultPreparedFrame:
-		flight.preparedStatment = &preparedStatment{
-			// defensively copy as we will recycle the underlying buffer after we
-			// return.
-			id: copyBytes(x.preparedID),
-			// the type info's should _not_ have a reference to the framers read buffer,
-			// therefore we can just copy them directly.
-			request:  x.reqMeta,
-			response: x.respMeta,
-		}
-	case error:
-		flight.err = x
-	default:
-		flight.err = NewErrProtocol("Unknown type in response to prepare frame: %s", x)
-	}
-	flight.wg.Done()
+			switch x := frame.(type) {
+			case *resultPreparedFrame:
+				flight.preparedStatment = &preparedStatment{
+					// defensively copy as we will recycle the underlying buffer after we
+					// return.
+					id: copyBytes(x.preparedID),
+					// the type info's should _not_ have a reference to the framers read buffer,
+					// therefore we can just copy them directly.
+					request:  x.reqMeta,
+					response: x.respMeta,
+				}
+			case error:
+				flight.err = x
+			default:
+				flight.err = NewErrProtocol("Unknown type in response to prepare frame: %s", x)
+			}
 
-	if flight.err != nil {
-		c.session.stmtsLRU.remove(stmtCacheKey)
+			if flight.err != nil {
+				c.session.stmtsLRU.remove(stmtCacheKey)
+			}
+		}()
 	}
 
-	return flight.preparedStatment, flight.err
+	select {
+	case <-ctx.Done():
+		return nil, ctx.Err()
+	case <-flight.done:
+		return flight.preparedStatment, flight.err
+	}
 }
 
 func marshalQueryValue(typ TypeInfo, value interface{}, dst *queryValues) error {
@@ -1072,11 +1087,8 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter {
 			return &Iter{err: err}
 		}
 
-		var values []interface{}
-
-		if qry.binding == nil {
-			values = qry.values
-		} else {
+		values := qry.values
+		if qry.binding != nil {
 			values, err = qry.binding(&QueryInfo{
 				Id:          info.id,
 				Args:        info.request.columns,
@@ -1218,7 +1230,7 @@ func (c *Conn) UseKeyspace(keyspace string) error {
 	q := &writeQueryFrame{statement: `USE "` + keyspace + `"`}
 	q.params.consistency = Any
 
-	framer, err := c.exec(context.Background(), q, nil)
+	framer, err := c.exec(c.ctx, q, nil)
 	if err != nil {
 		return err
 	}

+ 1 - 1
conn_test.go

@@ -773,7 +773,7 @@ func TestStream0(t *testing.T) {
 		streams: streams.New(protoVersion4),
 	}
 
-	err := conn.recv()
+	err := conn.recv(context.Background())
 	if err == nil {
 		t.Fatal("expected to get an error on stream 0")
 	} else if !strings.HasPrefix(err.Error(), expErr) {

+ 1 - 1
connectionpool.go

@@ -506,7 +506,7 @@ func (pool *hostConnPool) connect() (err error) {
 	var conn *Conn
 	reconnectionPolicy := pool.session.cfg.ReconnectionPolicy
 	for i := 0; i < reconnectionPolicy.GetMaxRetries(); i++ {
-		conn, err = pool.session.connect(pool.host, pool)
+		conn, err = pool.session.connect(pool.session.ctx, pool.host, pool)
 		if err == nil {
 			break
 		}

+ 4 - 4
control.go

@@ -172,7 +172,7 @@ func (c *controlConn) shuffleDial(endpoints []*HostInfo) (*Conn, error) {
 	var err error
 	for _, host := range shuffled {
 		var conn *Conn
-		conn, err = c.session.dial(host, &cfg, c)
+		conn, err = c.session.dial(c.session.ctx, host, &cfg, c)
 		if err == nil {
 			return conn, nil
 		}
@@ -221,7 +221,7 @@ func (c *controlConn) discoverProtocol(hosts []*HostInfo) (int, error) {
 	var err error
 	for _, host := range hosts {
 		var conn *Conn
-		conn, err = c.session.dial(host, &connCfg, handler)
+		conn, err = c.session.dial(c.session.ctx, host, &connCfg, handler)
 		if conn != nil {
 			conn.Close()
 		}
@@ -343,7 +343,7 @@ func (c *controlConn) reconnect(refreshring bool) {
 	var newConn *Conn
 	if host != nil {
 		// try to connect to the old host
-		conn, err := c.session.connect(host, c)
+		conn, err := c.session.connect(c.session.ctx, host, c)
 		if err != nil {
 			// host is dead
 			// TODO: this is replicated in a few places
@@ -365,7 +365,7 @@ func (c *controlConn) reconnect(refreshring bool) {
 		}
 
 		var err error
-		newConn, err = c.session.connect(host, c)
+		newConn, err = c.session.connect(c.session.ctx, host, c)
 		if err != nil {
 			// TODO: add log handler for things like this
 			return

+ 10 - 5
session.go

@@ -68,7 +68,8 @@ type Session struct {
 
 	cfg ClusterConfig
 
-	quit chan struct{}
+	ctx    context.Context
+	cancel context.CancelFunc
 
 	closeMu  sync.RWMutex
 	isClosed bool
@@ -113,14 +114,18 @@ func NewSession(cfg ClusterConfig) (*Session, error) {
 		return nil, errors.New("Can't use both Authenticator and AuthProvider in cluster config.")
 	}
 
+	// TODO: we should take a context in here at some point
+	ctx, cancel := context.WithCancel(context.TODO())
+
 	s := &Session{
 		cons:            cfg.Consistency,
 		prefetch:        0.25,
 		cfg:             cfg,
 		pageSize:        cfg.PageSize,
 		stmtsLRU:        &preparedLRU{lru: lru.New(cfg.MaxPreparedStmts)},
-		quit:            make(chan struct{}),
 		connectObserver: cfg.ConnectObserver,
+		ctx:             ctx,
+		cancel:          cancel,
 	}
 
 	s.schemaDescriber = newSchemaDescriber(s)
@@ -302,7 +307,7 @@ func (s *Session) reconnectDownedHosts(intv time.Duration) {
 				}
 				s.handleNodeUp(h.ConnectAddress(), h.Port(), true)
 			}
-		case <-s.quit:
+		case <-s.ctx.Done():
 			return
 		}
 	}
@@ -405,8 +410,8 @@ func (s *Session) Close() {
 		s.schemaEvents.stop()
 	}
 
-	if s.quit != nil {
-		close(s.quit)
+	if s.cancel != nil {
+		s.cancel()
 	}
 }