|
|
@@ -155,25 +155,26 @@ type Conn struct {
|
|
|
session *Session
|
|
|
|
|
|
closed int32
|
|
|
- quit chan struct{}
|
|
|
+ ctx context.Context
|
|
|
+ cancel context.CancelFunc
|
|
|
|
|
|
timeouts int64
|
|
|
}
|
|
|
|
|
|
// connect establishes a connection to a Cassandra node using session's connection config.
|
|
|
-func (s *Session) connect(host *HostInfo, errorHandler ConnErrorHandler) (*Conn, error) {
|
|
|
- return s.dial(host, s.connCfg, errorHandler)
|
|
|
+func (s *Session) connect(ctx context.Context, host *HostInfo, errorHandler ConnErrorHandler) (*Conn, error) {
|
|
|
+ return s.dial(ctx, host, s.connCfg, errorHandler)
|
|
|
}
|
|
|
|
|
|
// dial establishes a connection to a Cassandra node and notifies the session's connectObserver.
|
|
|
-func (s *Session) dial(host *HostInfo, connConfig *ConnConfig, errorHandler ConnErrorHandler) (*Conn, error) {
|
|
|
+func (s *Session) dial(ctx context.Context, host *HostInfo, connConfig *ConnConfig, errorHandler ConnErrorHandler) (*Conn, error) {
|
|
|
var obs ObservedConnect
|
|
|
if s.connectObserver != nil {
|
|
|
obs.Host = host
|
|
|
obs.Start = time.Now()
|
|
|
}
|
|
|
|
|
|
- conn, err := s.dialWithoutObserver(host, connConfig, errorHandler)
|
|
|
+ conn, err := s.dialWithoutObserver(ctx, host, connConfig, errorHandler)
|
|
|
|
|
|
if s.connectObserver != nil {
|
|
|
obs.End = time.Now()
|
|
|
@@ -187,7 +188,7 @@ func (s *Session) dial(host *HostInfo, connConfig *ConnConfig, errorHandler Conn
|
|
|
// dialWithoutObserver establishes connection to a Cassandra node.
|
|
|
//
|
|
|
// dialWithoutObserver does not notify the connection observer, so you most probably want to call dial() instead.
|
|
|
-func (s *Session) dialWithoutObserver(host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHandler) (*Conn, error) {
|
|
|
+func (s *Session) dialWithoutObserver(ctx context.Context, host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHandler) (*Conn, error) {
|
|
|
ip := host.ConnectAddress()
|
|
|
port := host.port
|
|
|
|
|
|
@@ -198,11 +199,6 @@ func (s *Session) dialWithoutObserver(host *HostInfo, cfg *ConnConfig, errorHand
|
|
|
panic(fmt.Sprintf("host missing port: %v", port))
|
|
|
}
|
|
|
|
|
|
- var (
|
|
|
- err error
|
|
|
- conn net.Conn
|
|
|
- )
|
|
|
-
|
|
|
dialer := &net.Dialer{
|
|
|
Timeout: cfg.ConnectTimeout,
|
|
|
}
|
|
|
@@ -210,18 +206,22 @@ func (s *Session) dialWithoutObserver(host *HostInfo, cfg *ConnConfig, errorHand
|
|
|
dialer.KeepAlive = cfg.Keepalive
|
|
|
}
|
|
|
|
|
|
+ conn, err := dialer.DialContext(ctx, "tcp", host.HostnameAndPort())
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
if cfg.tlsConfig != nil {
|
|
|
// the TLS config is safe to be reused by connections but it must not
|
|
|
// be modified after being used.
|
|
|
- conn, err = tls.DialWithDialer(dialer, "tcp", host.HostnameAndPort(), cfg.tlsConfig)
|
|
|
- } else {
|
|
|
- conn, err = dialer.Dial("tcp", host.HostnameAndPort())
|
|
|
- }
|
|
|
-
|
|
|
- if err != nil {
|
|
|
- return nil, err
|
|
|
+ tconn := tls.Client(conn, cfg.tlsConfig)
|
|
|
+ if err := tconn.Handshake(); err != nil {
|
|
|
+ conn.Close()
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ conn = tconn
|
|
|
}
|
|
|
|
|
|
+ ctx, cancel := context.WithCancel(ctx)
|
|
|
c := &Conn{
|
|
|
conn: conn,
|
|
|
r: bufio.NewReader(conn),
|
|
|
@@ -231,7 +231,6 @@ func (s *Session) dialWithoutObserver(host *HostInfo, cfg *ConnConfig, errorHand
|
|
|
addr: conn.RemoteAddr().String(),
|
|
|
errorHandler: errorHandler,
|
|
|
compressor: cfg.Compressor,
|
|
|
- quit: make(chan struct{}),
|
|
|
session: s,
|
|
|
streams: streams.New(cfg.ProtoVersion),
|
|
|
host: host,
|
|
|
@@ -240,50 +239,51 @@ func (s *Session) dialWithoutObserver(host *HostInfo, cfg *ConnConfig, errorHand
|
|
|
w: conn,
|
|
|
timeout: cfg.Timeout,
|
|
|
},
|
|
|
+ ctx: ctx,
|
|
|
+ cancel: cancel,
|
|
|
}
|
|
|
|
|
|
- if cfg.AuthProvider != nil {
|
|
|
- c.auth, err = cfg.AuthProvider(host)
|
|
|
- if err != nil {
|
|
|
- return nil, err
|
|
|
- }
|
|
|
- } else {
|
|
|
- c.auth = cfg.Authenticator
|
|
|
+ if err := c.init(ctx); err != nil {
|
|
|
+ cancel()
|
|
|
+ c.Close()
|
|
|
+ return nil, err
|
|
|
}
|
|
|
|
|
|
- var (
|
|
|
- ctx context.Context
|
|
|
- cancel func()
|
|
|
- )
|
|
|
- if cfg.ConnectTimeout > 0 {
|
|
|
- ctx, cancel = context.WithTimeout(context.TODO(), cfg.ConnectTimeout)
|
|
|
+ return c, nil
|
|
|
+}
|
|
|
+
|
|
|
+func (c *Conn) init(ctx context.Context) error {
|
|
|
+ if c.session.cfg.AuthProvider != nil {
|
|
|
+ var err error
|
|
|
+ c.auth, err = c.cfg.AuthProvider(c.host)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
} else {
|
|
|
- ctx, cancel = context.WithCancel(context.TODO())
|
|
|
+ c.auth = c.cfg.Authenticator
|
|
|
}
|
|
|
- defer cancel()
|
|
|
|
|
|
startup := &startupCoordinator{
|
|
|
frameTicker: make(chan struct{}),
|
|
|
conn: c,
|
|
|
}
|
|
|
|
|
|
- c.timeout = cfg.ConnectTimeout
|
|
|
+ c.timeout = c.cfg.ConnectTimeout
|
|
|
if err := startup.setupConn(ctx); err != nil {
|
|
|
- c.close()
|
|
|
- return nil, err
|
|
|
+ return err
|
|
|
}
|
|
|
|
|
|
- c.timeout = cfg.Timeout
|
|
|
+ c.timeout = c.cfg.Timeout
|
|
|
|
|
|
// dont coalesce startup frames
|
|
|
- if s.cfg.WriteCoalesceWaitTime > 0 && !cfg.disableCoalesce {
|
|
|
- c.w = newWriteCoalescer(conn, c.timeout, s.cfg.WriteCoalesceWaitTime, c.quit)
|
|
|
+ if c.session.cfg.WriteCoalesceWaitTime > 0 && !c.cfg.disableCoalesce {
|
|
|
+ c.w = newWriteCoalescer(c.conn, c.timeout, c.session.cfg.WriteCoalesceWaitTime, ctx.Done())
|
|
|
}
|
|
|
|
|
|
- go c.serve()
|
|
|
- go c.heartBeat()
|
|
|
+ go c.serve(ctx)
|
|
|
+ go c.heartBeat(ctx)
|
|
|
|
|
|
- return c, nil
|
|
|
+ return nil
|
|
|
}
|
|
|
|
|
|
func (c *Conn) Write(p []byte) (n int, err error) {
|
|
|
@@ -319,10 +319,18 @@ type startupCoordinator struct {
|
|
|
}
|
|
|
|
|
|
func (s *startupCoordinator) setupConn(ctx context.Context) error {
|
|
|
+ var cancel context.CancelFunc
|
|
|
+ if s.conn.timeout > 0 {
|
|
|
+ ctx, cancel = context.WithTimeout(ctx, s.conn.timeout)
|
|
|
+ } else {
|
|
|
+ ctx, cancel = context.WithCancel(ctx)
|
|
|
+ }
|
|
|
+ defer cancel()
|
|
|
+
|
|
|
startupErr := make(chan error)
|
|
|
go func() {
|
|
|
for range s.frameTicker {
|
|
|
- err := s.conn.recv()
|
|
|
+ err := s.conn.recv(ctx)
|
|
|
if err != nil {
|
|
|
select {
|
|
|
case startupErr <- err:
|
|
|
@@ -482,7 +490,7 @@ func (c *Conn) closeWithError(err error) {
|
|
|
}
|
|
|
|
|
|
// if error was nil then unblock the quit channel
|
|
|
- close(c.quit)
|
|
|
+ c.cancel()
|
|
|
cerr := c.close()
|
|
|
|
|
|
if err != nil {
|
|
|
@@ -504,10 +512,10 @@ func (c *Conn) Close() {
|
|
|
// Serve starts the stream multiplexer for this connection, which is required
|
|
|
// to execute any queries. This method runs as long as the connection is
|
|
|
// open and is therefore usually called in a separate goroutine.
|
|
|
-func (c *Conn) serve() {
|
|
|
+func (c *Conn) serve(ctx context.Context) {
|
|
|
var err error
|
|
|
for err == nil {
|
|
|
- err = c.recv()
|
|
|
+ err = c.recv(ctx)
|
|
|
}
|
|
|
|
|
|
c.closeWithError(err)
|
|
|
@@ -532,7 +540,7 @@ func (p *protocolError) Error() string {
|
|
|
return fmt.Sprintf("gocql: received unexpected frame on stream %d: %v", p.frame.Header().stream, p.frame)
|
|
|
}
|
|
|
|
|
|
-func (c *Conn) heartBeat() {
|
|
|
+func (c *Conn) heartBeat(ctx context.Context) {
|
|
|
sleepTime := 1 * time.Second
|
|
|
timer := time.NewTimer(sleepTime)
|
|
|
defer timer.Stop()
|
|
|
@@ -548,7 +556,7 @@ func (c *Conn) heartBeat() {
|
|
|
timer.Reset(sleepTime)
|
|
|
|
|
|
select {
|
|
|
- case <-c.quit:
|
|
|
+ case <-ctx.Done():
|
|
|
return
|
|
|
case <-timer.C:
|
|
|
}
|
|
|
@@ -579,7 +587,7 @@ func (c *Conn) heartBeat() {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func (c *Conn) recv() error {
|
|
|
+func (c *Conn) recv(ctx context.Context) error {
|
|
|
// not safe for concurrent reads
|
|
|
|
|
|
// read a full header, ignore timeouts, as this is being ran in a loop
|
|
|
@@ -663,7 +671,7 @@ func (c *Conn) recv() error {
|
|
|
case call.resp <- err:
|
|
|
case <-call.timeout:
|
|
|
c.releaseStream(call)
|
|
|
- case <-c.quit:
|
|
|
+ case <-ctx.Done():
|
|
|
}
|
|
|
|
|
|
return nil
|
|
|
@@ -919,7 +927,7 @@ func (c *Conn) exec(ctx context.Context, req frameWriter, tracer Tracer) (*frame
|
|
|
case <-ctxDone:
|
|
|
close(call.timeout)
|
|
|
return nil, ctx.Err()
|
|
|
- case <-c.quit:
|
|
|
+ case <-c.ctx.Done():
|
|
|
return nil, ErrConnectionClosed
|
|
|
}
|
|
|
|
|
|
@@ -945,8 +953,8 @@ type preparedStatment struct {
|
|
|
}
|
|
|
|
|
|
type inflightPrepare struct {
|
|
|
- wg sync.WaitGroup
|
|
|
- err error
|
|
|
+ done chan struct{}
|
|
|
+ err error
|
|
|
|
|
|
preparedStatment *preparedStatment
|
|
|
}
|
|
|
@@ -954,69 +962,76 @@ type inflightPrepare struct {
|
|
|
func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer) (*preparedStatment, error) {
|
|
|
stmtCacheKey := c.session.stmtsLRU.keyFor(c.addr, c.currentKeyspace, stmt)
|
|
|
flight, ok := c.session.stmtsLRU.execIfMissing(stmtCacheKey, func(lru *lru.Cache) *inflightPrepare {
|
|
|
- flight := new(inflightPrepare)
|
|
|
- flight.wg.Add(1)
|
|
|
+ flight := &inflightPrepare{
|
|
|
+ done: make(chan struct{}),
|
|
|
+ }
|
|
|
lru.Add(stmtCacheKey, flight)
|
|
|
return flight
|
|
|
})
|
|
|
|
|
|
- if ok {
|
|
|
- flight.wg.Wait()
|
|
|
- return flight.preparedStatment, flight.err
|
|
|
- }
|
|
|
+ if !ok {
|
|
|
+ go func() {
|
|
|
+ defer close(flight.done)
|
|
|
|
|
|
- prep := &writePrepareFrame{
|
|
|
- statement: stmt,
|
|
|
- }
|
|
|
- if c.version > protoVersion4 {
|
|
|
- prep.keyspace = c.currentKeyspace
|
|
|
- }
|
|
|
+ prep := &writePrepareFrame{
|
|
|
+ statement: stmt,
|
|
|
+ }
|
|
|
+ if c.version > protoVersion4 {
|
|
|
+ prep.keyspace = c.currentKeyspace
|
|
|
+ }
|
|
|
|
|
|
- framer, err := c.exec(ctx, prep, tracer)
|
|
|
- if err != nil {
|
|
|
- flight.err = err
|
|
|
- flight.wg.Done()
|
|
|
- c.session.stmtsLRU.remove(stmtCacheKey)
|
|
|
- return nil, err
|
|
|
- }
|
|
|
+ // we won the race to do the load, if our context is canceled we shouldnt
|
|
|
+ // stop the load as other callers are waiting for it but this caller should get
|
|
|
+ // their context cancelled error.
|
|
|
+ framer, err := c.exec(c.ctx, prep, tracer)
|
|
|
+ if err != nil {
|
|
|
+ flight.err = err
|
|
|
+ c.session.stmtsLRU.remove(stmtCacheKey)
|
|
|
+ return
|
|
|
+ }
|
|
|
|
|
|
- frame, err := framer.parseFrame()
|
|
|
- if err != nil {
|
|
|
- flight.err = err
|
|
|
- flight.wg.Done()
|
|
|
- c.session.stmtsLRU.remove(stmtCacheKey)
|
|
|
- return nil, err
|
|
|
- }
|
|
|
+ frame, err := framer.parseFrame()
|
|
|
+ if err != nil {
|
|
|
+ flight.err = err
|
|
|
+ c.session.stmtsLRU.remove(stmtCacheKey)
|
|
|
+ return
|
|
|
+ }
|
|
|
|
|
|
- // TODO(zariel): tidy this up, simplify handling of frame parsing so its not duplicated
|
|
|
- // everytime we need to parse a frame.
|
|
|
- if len(framer.traceID) > 0 && tracer != nil {
|
|
|
- tracer.Trace(framer.traceID)
|
|
|
- }
|
|
|
+ // TODO(zariel): tidy this up, simplify handling of frame parsing so its not duplicated
|
|
|
+ // everytime we need to parse a frame.
|
|
|
+ if len(framer.traceID) > 0 && tracer != nil {
|
|
|
+ tracer.Trace(framer.traceID)
|
|
|
+ }
|
|
|
|
|
|
- switch x := frame.(type) {
|
|
|
- case *resultPreparedFrame:
|
|
|
- flight.preparedStatment = &preparedStatment{
|
|
|
- // defensively copy as we will recycle the underlying buffer after we
|
|
|
- // return.
|
|
|
- id: copyBytes(x.preparedID),
|
|
|
- // the type info's should _not_ have a reference to the framers read buffer,
|
|
|
- // therefore we can just copy them directly.
|
|
|
- request: x.reqMeta,
|
|
|
- response: x.respMeta,
|
|
|
- }
|
|
|
- case error:
|
|
|
- flight.err = x
|
|
|
- default:
|
|
|
- flight.err = NewErrProtocol("Unknown type in response to prepare frame: %s", x)
|
|
|
- }
|
|
|
- flight.wg.Done()
|
|
|
+ switch x := frame.(type) {
|
|
|
+ case *resultPreparedFrame:
|
|
|
+ flight.preparedStatment = &preparedStatment{
|
|
|
+ // defensively copy as we will recycle the underlying buffer after we
|
|
|
+ // return.
|
|
|
+ id: copyBytes(x.preparedID),
|
|
|
+ // the type info's should _not_ have a reference to the framers read buffer,
|
|
|
+ // therefore we can just copy them directly.
|
|
|
+ request: x.reqMeta,
|
|
|
+ response: x.respMeta,
|
|
|
+ }
|
|
|
+ case error:
|
|
|
+ flight.err = x
|
|
|
+ default:
|
|
|
+ flight.err = NewErrProtocol("Unknown type in response to prepare frame: %s", x)
|
|
|
+ }
|
|
|
|
|
|
- if flight.err != nil {
|
|
|
- c.session.stmtsLRU.remove(stmtCacheKey)
|
|
|
+ if flight.err != nil {
|
|
|
+ c.session.stmtsLRU.remove(stmtCacheKey)
|
|
|
+ }
|
|
|
+ }()
|
|
|
}
|
|
|
|
|
|
- return flight.preparedStatment, flight.err
|
|
|
+ select {
|
|
|
+ case <-ctx.Done():
|
|
|
+ return nil, ctx.Err()
|
|
|
+ case <-flight.done:
|
|
|
+ return flight.preparedStatment, flight.err
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
func marshalQueryValue(typ TypeInfo, value interface{}, dst *queryValues) error {
|
|
|
@@ -1072,11 +1087,8 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter {
|
|
|
return &Iter{err: err}
|
|
|
}
|
|
|
|
|
|
- var values []interface{}
|
|
|
-
|
|
|
- if qry.binding == nil {
|
|
|
- values = qry.values
|
|
|
- } else {
|
|
|
+ values := qry.values
|
|
|
+ if qry.binding != nil {
|
|
|
values, err = qry.binding(&QueryInfo{
|
|
|
Id: info.id,
|
|
|
Args: info.request.columns,
|
|
|
@@ -1218,7 +1230,7 @@ func (c *Conn) UseKeyspace(keyspace string) error {
|
|
|
q := &writeQueryFrame{statement: `USE "` + keyspace + `"`}
|
|
|
q.params.consistency = Any
|
|
|
|
|
|
- framer, err := c.exec(context.Background(), q, nil)
|
|
|
+ framer, err := c.exec(c.ctx, q, nil)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|