Browse Source

Merge pull request #417 from Zariel/timeout-queries

apply a timeout for every query at the connection
Ben Hood 10 years ago
parent
commit
5792e6ba59
2 changed files with 84 additions and 30 deletions
  1. 39 3
      conn.go
  2. 45 27
      conn_test.go

+ 39 - 3
conn.go

@@ -15,6 +15,7 @@ import (
 	"strconv"
 	"strings"
 	"sync"
+	"sync/atomic"
 	"time"
 )
 
@@ -84,6 +85,12 @@ type ConnErrorHandler interface {
 	HandleError(conn *Conn, err error, closed bool)
 }
 
+// TimeoutLimit is the number of query timeouts allowed to occur on a connection
+// before it is closed and restarted. This prevents a single query timeout from
+// killing a connection which may be serving other queries just fine.
+// The default is 10. It should not be changed concurrently with queries.
+var TimeoutLimit int64 = 10
+
 // Conn is a single connection to a Cassandra node. It can be used to execute
 // queries, but users are usually advised to use a more reliable, higher
 // level API.
@@ -107,6 +114,8 @@ type Conn struct {
 
 	closedMu sync.RWMutex
 	isClosed bool
+
+	timeouts int64
 }
 
 // Connect establishes a connection to a Cassandra node.
@@ -298,7 +307,16 @@ func (c *Conn) serve() {
 		}
 	}
 
+	c.closeWithError(err)
+}
+
+func (c *Conn) closeWithError(err error) {
+	if c.Closed() {
+		return
+	}
+
 	c.Close()
+
 	for id := 0; id < len(c.calls); id++ {
 		req := &c.calls[id]
 		// we need to send the error to all waiting queries, put the state
@@ -339,7 +357,11 @@ func (c *Conn) recv() error {
 
 	// once we get to here we know that the caller must be waiting and that there
 	// is no error.
-	call.resp <- nil
+	select {
+	case call.resp <- nil:
+	default:
+		// in case the caller timed out
+	}
 
 	return nil
 }
@@ -357,6 +379,12 @@ func (c *Conn) releaseStream(stream int) {
 	}
 }
 
+// handleTimeout records one more query timeout on the connection. Once the
+// running total exceeds TimeoutLimit the connection is closed with
+// ErrTooManyTimeouts; totals at or below the limit leave it untouched.
+func (c *Conn) handleTimeout() {
+	total := atomic.AddInt64(&c.timeouts, 1)
+	if total <= TimeoutLimit {
+		return
+	}
+
+	c.closeWithError(ErrTooManyTimeouts)
+}
+
 func (c *Conn) exec(req frameWriter, tracer Tracer) (frame, error) {
 	// TODO: move tracer onto conn
 	stream := <-c.uniq
@@ -376,7 +404,13 @@ func (c *Conn) exec(req frameWriter, tracer Tracer) (frame, error) {
 		return nil, err
 	}
 
-	err = <-call.resp
+	select {
+	case err = <-call.resp:
+	case <-time.After(c.timeout):
+		c.handleTimeout()
+		return nil, ErrTimeoutNoResponse
+	}
+
 	if err != nil {
 		return nil, err
 	}
@@ -730,5 +764,7 @@ type inflightPrepare struct {
 }
 
 var (
 	ErrQueryArgLength = errors.New("query argument length mismatch")
+
+	// ErrTimeoutNoResponse is returned by Conn.exec when no response arrives
+	// on the call's channel within the connection timeout.
+	ErrTimeoutNoResponse = errors.New("gocql: no response received from cassandra within timeout period")
+
+	// ErrTooManyTimeouts is the error a connection is closed with once its
+	// timeout count exceeds TimeoutLimit.
+	ErrTooManyTimeouts = errors.New("gocql: too many query timeouts on the connection")
 )

+ 45 - 27
conn_test.go

@@ -182,21 +182,6 @@ func TestQueryRetry(t *testing.T) {
 	}
 }
 
-func TestSlowQuery(t *testing.T) {
-	srv := NewTestServer(t, defaultProto)
-	defer srv.Stop()
-
-	db, err := newTestSession(srv.Address, defaultProto)
-	if err != nil {
-		t.Errorf("NewCluster: %v", err)
-		return
-	}
-
-	if err := db.Query("slow").Exec(); err != nil {
-		t.Fatal(err)
-	}
-}
-
 func TestSimplePoolRoundRobin(t *testing.T) {
 	servers := make([]*TestServer, 5)
 	addrs := make([]string, len(servers))
@@ -486,6 +471,43 @@ func TestPolicyConnPoolSSL(t *testing.T) {
 	}
 }
 
+// TestQueryTimeout checks that a query whose response never arrives fails
+// with ErrTimeoutNoResponse once the cluster timeout has elapsed. The test
+// server parks "timeout" queries until it is stopped, so no reply is written
+// while the query is outstanding.
+func TestQueryTimeout(t *testing.T) {
+	srv := NewTestServer(t, protoVersion2)
+	defer srv.Stop()
+
+	cluster := NewCluster(srv.Address)
+	// Set the timeout arbitrarily low so that the query hits the timeout in a
+	// timely manner.
+	cluster.Timeout = 1 * time.Millisecond
+
+	db, err := cluster.CreateSession()
+	if err != nil {
+		// Without a session there is nothing to test; stop here instead of
+		// carrying on and using db below.
+		t.Fatalf("NewCluster: %v", err)
+	}
+	defer db.Close()
+
+	// Buffered so the query goroutine can always deliver its result and exit,
+	// even if the select below has already given up waiting.
+	ch := make(chan error, 1)
+
+	go func() {
+		// Hand the outcome (nil included) to the test goroutine, which does
+		// all of the asserting; nothing touches t from this goroutine.
+		ch <- db.Query("timeout").Exec()
+	}()
+
+	select {
+	case err := <-ch:
+		if err == nil {
+			t.Fatalf("err was nil, expected to get a timeout after %v", db.cfg.Timeout)
+		}
+		if err != ErrTimeoutNoResponse {
+			t.Fatalf("expected to get %v for timeout got %v", ErrTimeoutNoResponse, err)
+		}
+	case <-time.After(10*time.Millisecond + db.cfg.Timeout):
+		// the extra 10ms allows for the query goroutine being scheduled
+		t.Fatalf("query did not timeout after %v", db.cfg.Timeout)
+	}
+}
+
 func NewTestServer(t testing.TB, protocol uint8) *TestServer {
 	laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
 	if err != nil {
@@ -508,6 +530,7 @@ func NewTestServer(t testing.TB, protocol uint8) *TestServer {
 		t:          t,
 		protocol:   protocol,
 		headerSize: headerSize,
+		quit:       make(chan struct{}),
 	}
 
 	go srv.serve()
@@ -545,6 +568,7 @@ func NewSSLTestServer(t testing.TB, protocol uint8) *TestServer {
 		t:          t,
 		protocol:   protocol,
 		headerSize: headerSize,
+		quit:       make(chan struct{}),
 	}
 	go srv.serve()
 	return srv
@@ -560,6 +584,8 @@ type TestServer struct {
 
 	protocol   byte
 	headerSize int
+
+	quit chan struct{}
 }
 
 func (srv *TestServer) serve() {
@@ -592,6 +618,7 @@ func (srv *TestServer) serve() {
 
 func (srv *TestServer) Stop() {
 	srv.listen.Close()
+	close(srv.quit)
 }
 
 func (srv *TestServer) process(f *framer) {
@@ -619,24 +646,15 @@ func (srv *TestServer) process(f *framer) {
 			f.writeHeader(0, opError, head.stream)
 			f.writeInt(0x1001)
 			f.writeString("query killed")
-		case "slow":
-			go func() {
-				<-time.After(1 * time.Second)
-				f.writeHeader(0, opResult, head.stream)
-				f.wbuf[0] = srv.protocol | 0x80
-				f.writeInt(resultKindVoid)
-				if err := f.finishWrite(); err != nil {
-					srv.t.Error(err)
-				}
-			}()
-
-			return
 		case "use":
 			f.writeInt(resultKindKeyspace)
 			f.writeString(strings.TrimSpace(query[3:]))
 		case "void":
 			f.writeHeader(0, opResult, head.stream)
 			f.writeInt(resultKindVoid)
+		case "timeout":
+			<-srv.quit
+			return
 		default:
 			f.writeHeader(0, opResult, head.stream)
 			f.writeInt(resultKindVoid)