Selaa lähdekoodia

apply a timeout for every query at the connection

Start a timeout when waiting for the response from cassandra
in the call to exec() so that if no header is read the query
will not hang for ever.

If there are more than a limit of timeouts then the connection will
be closed. This may need tuning in the future if there are queries
which take a long time.

fixes #416
Chris Bannister 10 vuotta sitten
vanhempi
commit
45865a5076
2 muutettua tiedostoa jossa 83 lisäystä ja 3 poistoa
  1. 38 3
      conn.go
  2. 45 0
      conn_test.go

+ 38 - 3
conn.go

@@ -15,6 +15,7 @@ import (
 	"strconv"
 	"strings"
 	"sync"
+	"sync/atomic"
 	"time"
 )
 
@@ -82,6 +83,11 @@ type ConnErrorHandler interface {
 	HandleError(conn *Conn, err error, closed bool)
 }
 
+// How many timeouts we will allow to occur before the connection is closed
+// and restarted. This is to prevent a single query timeout from killing a connection
+// which may be serving more queries just fine.
+const timeoutLimit = 10
+
 // Conn is a single connection to a Cassandra node. It can be used to execute
 // queries, but users are usually advised to use a more reliable, higher
 // level API.
@@ -105,6 +111,8 @@ type Conn struct {
 
 	closedMu sync.RWMutex
 	isClosed bool
+
+	timeouts int64
 }
 
 // Connect establishes a connection to a Cassandra node.
@@ -296,7 +304,16 @@ func (c *Conn) serve() {
 		}
 	}
 
+	c.closeWithError(err)
+}
+
+func (c *Conn) closeWithError(err error) {
+	if c.Closed() {
+		return
+	}
+
 	c.Close()
+
 	for id := 0; id < len(c.calls); id++ {
 		req := &c.calls[id]
 		// we need to send the error to all waiting queries, put the state
@@ -337,7 +354,11 @@ func (c *Conn) recv() error {
 
 	// once we get to here we know that the caller must be waiting and that there
 	// is no error.
-	call.resp <- nil
+	select {
+	case call.resp <- nil:
+	default:
+		// in case the caller timedout
+	}
 
 	return nil
 }
@@ -355,6 +376,12 @@ func (c *Conn) releaseStream(stream int) {
 	}
 }
 
+func (c *Conn) handleTimeout() {
+	if atomic.AddInt64(&c.timeouts, 1) > timeoutLimit {
+		c.closeWithError(ErrTooManyTimeouts)
+	}
+}
+
 func (c *Conn) exec(req frameWriter, tracer Tracer) (frame, error) {
 	// TODO: move tracer onto conn
 	stream := <-c.uniq
@@ -374,7 +401,13 @@ func (c *Conn) exec(req frameWriter, tracer Tracer) (frame, error) {
 		return nil, err
 	}
 
-	err = <-call.resp
+	select {
+	case err = <-call.resp:
+	case <-time.After(c.timeout):
+		c.handleTimeout()
+		return nil, ErrTimeoutNoResponse
+	}
+
 	if err != nil {
 		return nil, err
 	}
@@ -728,5 +761,7 @@ type inflightPrepare struct {
 }
 
 var (
-	ErrQueryArgLength = errors.New("query argument length mismatch")
+	ErrQueryArgLength    = errors.New("query argument length mismatch")
+	ErrTimeoutNoResponse = errors.New("gocql: no response recieved from cassandra within timeout period")
+	ErrTooManyTimeouts   = errors.New("gocql: too many query timeouts on the connection")
 )

+ 45 - 0
conn_test.go

@@ -486,6 +486,43 @@ func TestPolicyConnPoolSSL(t *testing.T) {
 	}
 }
 
+func TestQueryTimeout(t *testing.T) {
+	srv := NewTestServer(t, protoVersion2)
+	defer srv.Stop()
+
+	cluster := NewCluster(srv.Address)
+	// Set the timeout arbitrarily low so that the query hits the timeout in a
+	// timely manner.
+	cluster.Timeout = 1 * time.Millisecond
+
+	db, err := cluster.CreateSession()
+	if err != nil {
+		t.Errorf("NewCluster: %v", err)
+	}
+	defer db.Close()
+
+	ch := make(chan error, 1)
+
+	go func() {
+		err := db.Query("timeout").Exec()
+		if err != nil {
+			ch <- err
+			return
+		}
+		t.Errorf("err was nil, expected to get a timeout after %v", db.cfg.Timeout)
+	}()
+
+	select {
+	case err := <-ch:
+		if err != ErrTimeoutNoResponse {
+			t.Fatalf("expected to get %v for timeout got %v", ErrTimeoutNoResponse, err)
+		}
+	case <-time.After(10*time.Millisecond + db.cfg.Timeout):
+		// ensure that the query goroutines have been scheduled
+		t.Fatalf("query did not timeout after %v", db.cfg.Timeout)
+	}
+}
+
 func NewTestServer(t testing.TB, protocol uint8) *TestServer {
 	laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
 	if err != nil {
@@ -508,6 +545,7 @@ func NewTestServer(t testing.TB, protocol uint8) *TestServer {
 		t:          t,
 		protocol:   protocol,
 		headerSize: headerSize,
+		quit:       make(chan struct{}),
 	}
 
 	go srv.serve()
@@ -545,6 +583,7 @@ func NewSSLTestServer(t testing.TB, protocol uint8) *TestServer {
 		t:          t,
 		protocol:   protocol,
 		headerSize: headerSize,
+		quit:       make(chan struct{}),
 	}
 	go srv.serve()
 	return srv
@@ -560,6 +599,8 @@ type TestServer struct {
 
 	protocol   byte
 	headerSize int
+
+	quit chan struct{}
 }
 
 func (srv *TestServer) serve() {
@@ -592,6 +633,7 @@ func (srv *TestServer) serve() {
 
 func (srv *TestServer) Stop() {
 	srv.listen.Close()
+	close(srv.quit)
 }
 
 func (srv *TestServer) process(f *framer) {
@@ -637,6 +679,9 @@ func (srv *TestServer) process(f *framer) {
 		case "void":
 			f.writeHeader(0, opResult, head.stream)
 			f.writeInt(resultKindVoid)
+		case "timeout":
+			<-srv.quit
+			return
 		default:
 			f.writeHeader(0, opResult, head.stream)
 			f.writeInt(resultKindVoid)