|
@@ -9,6 +9,7 @@ import (
|
|
|
"crypto/tls"
|
|
"crypto/tls"
|
|
|
"errors"
|
|
"errors"
|
|
|
"fmt"
|
|
"fmt"
|
|
|
|
|
+ "golang.org/x/net/context"
|
|
|
"io"
|
|
"io"
|
|
|
"io/ioutil"
|
|
"io/ioutil"
|
|
|
"log"
|
|
"log"
|
|
@@ -284,7 +285,7 @@ func (c *Conn) startup(frameTicker chan struct{}) error {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
frameTicker <- struct{}{}
|
|
frameTicker <- struct{}{}
|
|
|
- framer, err := c.exec(&writeStartupFrame{opts: m}, nil)
|
|
|
|
|
|
|
+ framer, err := c.exec(context.Background(), &writeStartupFrame{opts: m}, nil)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return err
|
|
return err
|
|
|
}
|
|
}
|
|
@@ -320,7 +321,7 @@ func (c *Conn) authenticateHandshake(authFrame *authenticateFrame, frameTicker c
|
|
|
|
|
|
|
|
for {
|
|
for {
|
|
|
frameTicker <- struct{}{}
|
|
frameTicker <- struct{}{}
|
|
|
- framer, err := c.exec(req, nil)
|
|
|
|
|
|
|
+ framer, err := c.exec(context.Background(), req, nil)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return err
|
|
return err
|
|
|
}
|
|
}
|
|
@@ -531,7 +532,7 @@ type callReq struct {
|
|
|
timer *time.Timer
|
|
timer *time.Timer
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-func (c *Conn) exec(req frameWriter, tracer Tracer) (*framer, error) {
|
|
|
|
|
|
|
+func (c *Conn) exec(ctx context.Context, req frameWriter, tracer Tracer) (*framer, error) {
|
|
|
// TODO: move tracer onto conn
|
|
// TODO: move tracer onto conn
|
|
|
stream, ok := c.streams.GetStream()
|
|
stream, ok := c.streams.GetStream()
|
|
|
if !ok {
|
|
if !ok {
|
|
@@ -593,6 +594,11 @@ func (c *Conn) exec(req frameWriter, tracer Tracer) (*framer, error) {
|
|
|
timeoutCh = call.timer.C
|
|
timeoutCh = call.timer.C
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ var ctxDone <-chan struct{}
|
|
|
|
|
+ if ctx != nil {
|
|
|
|
|
+ ctxDone = ctx.Done()
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
select {
|
|
select {
|
|
|
case err := <-call.resp:
|
|
case err := <-call.resp:
|
|
|
close(call.timeout)
|
|
close(call.timeout)
|
|
@@ -610,6 +616,9 @@ func (c *Conn) exec(req frameWriter, tracer Tracer) (*framer, error) {
|
|
|
close(call.timeout)
|
|
close(call.timeout)
|
|
|
c.handleTimeout()
|
|
c.handleTimeout()
|
|
|
return nil, ErrTimeoutNoResponse
|
|
return nil, ErrTimeoutNoResponse
|
|
|
|
|
+ case <-ctxDone:
|
|
|
|
|
+ close(call.timeout)
|
|
|
|
|
+ return nil, ctx.Err()
|
|
|
case <-c.quit:
|
|
case <-c.quit:
|
|
|
return nil, ErrConnectionClosed
|
|
return nil, ErrConnectionClosed
|
|
|
}
|
|
}
|
|
@@ -642,7 +651,7 @@ type inflightPrepare struct {
|
|
|
preparedStatment *preparedStatment
|
|
preparedStatment *preparedStatment
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-func (c *Conn) prepareStatement(stmt string, tracer Tracer) (*preparedStatment, error) {
|
|
|
|
|
|
|
+func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer) (*preparedStatment, error) {
|
|
|
stmtCacheKey := c.session.stmtsLRU.keyFor(c.addr, c.currentKeyspace, stmt)
|
|
stmtCacheKey := c.session.stmtsLRU.keyFor(c.addr, c.currentKeyspace, stmt)
|
|
|
flight, ok := c.session.stmtsLRU.execIfMissing(stmtCacheKey, func(lru *lru.Cache) *inflightPrepare {
|
|
flight, ok := c.session.stmtsLRU.execIfMissing(stmtCacheKey, func(lru *lru.Cache) *inflightPrepare {
|
|
|
flight := new(inflightPrepare)
|
|
flight := new(inflightPrepare)
|
|
@@ -660,7 +669,7 @@ func (c *Conn) prepareStatement(stmt string, tracer Tracer) (*preparedStatment,
|
|
|
statement: stmt,
|
|
statement: stmt,
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- framer, err := c.exec(prep, tracer)
|
|
|
|
|
|
|
+ framer, err := c.exec(ctx, prep, tracer)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
flight.err = err
|
|
flight.err = err
|
|
|
flight.wg.Done()
|
|
flight.wg.Done()
|
|
@@ -732,7 +741,7 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
|
|
|
if qry.shouldPrepare() {
|
|
if qry.shouldPrepare() {
|
|
|
// Prepare all DML queries. Other queries can not be prepared.
|
|
// Prepare all DML queries. Other queries can not be prepared.
|
|
|
var err error
|
|
var err error
|
|
|
- info, err = c.prepareStatement(qry.stmt, qry.trace)
|
|
|
|
|
|
|
+ info, err = c.prepareStatement(qry.context, qry.stmt, qry.trace)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return &Iter{err: err}
|
|
return &Iter{err: err}
|
|
|
}
|
|
}
|
|
@@ -783,7 +792,7 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- framer, err := c.exec(frame, qry.trace)
|
|
|
|
|
|
|
+ framer, err := c.exec(qry.context, frame, qry.trace)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return &Iter{err: err}
|
|
return &Iter{err: err}
|
|
|
}
|
|
}
|
|
@@ -883,7 +892,7 @@ func (c *Conn) UseKeyspace(keyspace string) error {
|
|
|
q := &writeQueryFrame{statement: `USE "` + keyspace + `"`}
|
|
q := &writeQueryFrame{statement: `USE "` + keyspace + `"`}
|
|
|
q.params.consistency = Any
|
|
q.params.consistency = Any
|
|
|
|
|
|
|
|
- framer, err := c.exec(q, nil)
|
|
|
|
|
|
|
+ framer, err := c.exec(context.Background(), q, nil)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return err
|
|
return err
|
|
|
}
|
|
}
|
|
@@ -926,7 +935,7 @@ func (c *Conn) executeBatch(batch *Batch) *Iter {
|
|
|
entry := &batch.Entries[i]
|
|
entry := &batch.Entries[i]
|
|
|
b := &req.statements[i]
|
|
b := &req.statements[i]
|
|
|
if len(entry.Args) > 0 || entry.binding != nil {
|
|
if len(entry.Args) > 0 || entry.binding != nil {
|
|
|
- info, err := c.prepareStatement(entry.Stmt, nil)
|
|
|
|
|
|
|
+ info, err := c.prepareStatement(batch.context, entry.Stmt, nil)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return &Iter{err: err}
|
|
return &Iter{err: err}
|
|
|
}
|
|
}
|
|
@@ -970,7 +979,7 @@ func (c *Conn) executeBatch(batch *Batch) *Iter {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// TODO: should batch support tracing?
|
|
// TODO: should batch support tracing?
|
|
|
- framer, err := c.exec(req, nil)
|
|
|
|
|
|
|
+ framer, err := c.exec(batch.context, req, nil)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return &Iter{err: err}
|
|
return &Iter{err: err}
|
|
|
}
|
|
}
|