Browse Source

Merge pull request #429 from Zariel/reinstate-query-timeouts

Reinstate query timeouts
Chris Bannister 10 years ago
parent
commit
a35e624b6f
2 changed files with 99 additions and 26 deletions
  1. 61 24
      conn.go
  2. 38 2
      conn_test.go

+ 61 - 24
conn.go

@@ -113,6 +113,7 @@ type Conn struct {
 	started         bool
 	started         bool
 
 
 	closed int32
 	closed int32
+	quit   chan struct{}
 
 
 	timeouts int64
 	timeouts int64
 }
 }
@@ -171,6 +172,7 @@ func Connect(addr string, cfg ConnConfig, errorHandler ConnErrorHandler) (*Conn,
 		compressor:   cfg.Compressor,
 		compressor:   cfg.Compressor,
 		auth:         cfg.Authenticator,
 		auth:         cfg.Authenticator,
 		headerBuf:    make([]byte, headerSize),
 		headerBuf:    make([]byte, headerSize),
+		quit:         make(chan struct{}),
 	}
 	}
 
 
 	if cfg.Keepalive > 0 {
 	if cfg.Keepalive > 0 {
@@ -178,7 +180,7 @@ func Connect(addr string, cfg ConnConfig, errorHandler ConnErrorHandler) (*Conn,
 	}
 	}
 
 
 	for i := 0; i < cfg.NumStreams; i++ {
 	for i := 0; i < cfg.NumStreams; i++ {
-		c.calls[i].resp = make(chan error, 1)
+		c.calls[i].resp = make(chan error)
 		c.uniq <- i
 		c.uniq <- i
 	}
 	}
 
 
@@ -327,15 +329,27 @@ func (c *Conn) recv() error {
 	call := &c.calls[head.stream]
 	call := &c.calls[head.stream]
 	err = call.framer.readFrame(&head)
 	err = call.framer.readFrame(&head)
 	if err != nil {
 	if err != nil {
-		return err
+		// only net errors should cause the connection to be closed. Though
+		// cassandra returning corrupt frames will be returned here as well.
+		if _, ok := err.(net.Error); ok {
+			return err
+		}
 	}
 	}
 
 
-	// once we get to here we know that the caller must be waiting and that there
-	// is no error.
+	if !atomic.CompareAndSwapInt32(&call.waiting, 1, 0) {
+		// the waiting thread timed out and is no longer waiting, the stream has
+		// not yet been readded to the chan so it cant be used again,
+		c.releaseStream(head.stream)
+		return nil
+	}
+
+	// we either, return a response to the caller, the caller timedout, or the
+	// connection has closed. Either way we should never block indefinatly here
 	select {
 	select {
-	case call.resp <- nil:
-	default:
-		// in case the caller timedout
+	case call.resp <- err:
+	case <-call.timeout:
+		c.releaseStream(head.stream)
+	case <-c.quit:
 	}
 	}
 
 
 	return nil
 	return nil
@@ -343,11 +357,17 @@ func (c *Conn) recv() error {
 
 
 type callReq struct {
 type callReq struct {
 	// could use a waitgroup but this allows us to do timeouts on the read/send
 	// could use a waitgroup but this allows us to do timeouts on the read/send
-	resp   chan error
-	framer *framer
+	resp    chan error
+	framer  *framer
+	waiting int32
+	timeout chan struct{} // indicates to recv() that a call has timedout
 }
 }
 
 
 func (c *Conn) releaseStream(stream int) {
 func (c *Conn) releaseStream(stream int) {
+	call := &c.calls[stream]
+	framerPool.Put(call.framer)
+	call.framer = nil
+
 	select {
 	select {
 	case c.uniq <- stream:
 	case c.uniq <- stream:
 	default:
 	default:
@@ -362,27 +382,49 @@ func (c *Conn) handleTimeout() {
 
 
 func (c *Conn) exec(req frameWriter, tracer Tracer) (frame, error) {
 func (c *Conn) exec(req frameWriter, tracer Tracer) (frame, error) {
 	// TODO: move tracer onto conn
 	// TODO: move tracer onto conn
-	stream := <-c.uniq
-	defer c.releaseStream(stream)
+	var stream int
+	select {
+	case stream = <-c.uniq:
+	case <-c.quit:
+		return nil, ErrConnectionClosed
+	}
 
 
 	call := &c.calls[stream]
 	call := &c.calls[stream]
 	// resp is basically a waiting semaphore protecting the framer
 	// resp is basically a waiting semaphore protecting the framer
 	framer := newFramer(c, c, c.compressor, c.version)
 	framer := newFramer(c, c, c.compressor, c.version)
 	call.framer = framer
 	call.framer = framer
+	call.timeout = make(chan struct{})
 
 
 	if tracer != nil {
 	if tracer != nil {
 		framer.trace()
 		framer.trace()
 	}
 	}
 
 
+	if !atomic.CompareAndSwapInt32(&call.waiting, 0, 1) {
+		return nil, errors.New("gocql: stream is busy or closed")
+	}
+	defer atomic.StoreInt32(&call.waiting, 0)
+
 	err := req.writeFrame(framer, stream)
 	err := req.writeFrame(framer, stream)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
-	err = <-call.resp
+	select {
+	case err := <-call.resp:
+		// dont release the stream if detect a timeout as another request can reuse
+		// that stream and get a response for the old request, which we have no
+		// easy way of detecting.
+		defer c.releaseStream(stream)
 
 
-	if err != nil {
-		return nil, err
+		if err != nil {
+			return nil, err
+		}
+	case <-time.After(c.timeout):
+		close(call.timeout)
+		c.handleTimeout()
+		return nil, ErrTimeoutNoResponse
+	case <-c.quit:
+		return nil, ErrConnectionClosed
 	}
 	}
 
 
 	if v := framer.header.version.version(); v != c.version {
 	if v := framer.header.version.version(); v != c.version {
@@ -398,9 +440,6 @@ func (c *Conn) exec(req frameWriter, tracer Tracer) (frame, error) {
 		tracer.Trace(framer.traceID)
 		tracer.Trace(framer.traceID)
 	}
 	}
 
 
-	framerPool.Put(framer)
-	call.framer = nil
-
 	return frame, nil
 	return frame, nil
 }
 }
 
 
@@ -583,16 +622,13 @@ func (c *Conn) closeWithError(err error) {
 		return
 		return
 	}
 	}
 
 
+	close(c.quit)
+
 	for id := 0; id < len(c.calls); id++ {
 	for id := 0; id < len(c.calls); id++ {
 		req := &c.calls[id]
 		req := &c.calls[id]
 		// we need to send the error to all waiting queries, put the state
 		// we need to send the error to all waiting queries, put the state
 		// of this conn into not active so that it can not execute any queries.
 		// of this conn into not active so that it can not execute any queries.
-		select {
-		case req.resp <- err:
-		default:
-		}
-
-		close(req.resp)
+		atomic.StoreInt32(&req.waiting, -1)
 	}
 	}
 
 
 	c.conn.Close()
 	c.conn.Close()
@@ -747,7 +783,8 @@ type inflightPrepare struct {
 }
 }
 
 
 var (
 var (
-	ErrQueryArgLength    = errors.New("query argument length mismatch")
+	ErrQueryArgLength    = errors.New("gocql: query argument length mismatch")
 	ErrTimeoutNoResponse = errors.New("gocql: no response recieved from cassandra within timeout period")
 	ErrTimeoutNoResponse = errors.New("gocql: no response recieved from cassandra within timeout period")
 	ErrTooManyTimeouts   = errors.New("gocql: too many query timeouts on the connection")
 	ErrTooManyTimeouts   = errors.New("gocql: too many query timeouts on the connection")
+	ErrConnectionClosed  = errors.New("gocql: connection closed waiting for response")
 )
 )

+ 38 - 2
conn_test.go

@@ -472,8 +472,7 @@ func TestPolicyConnPoolSSL(t *testing.T) {
 }
 }
 
 
 func TestQueryTimeout(t *testing.T) {
 func TestQueryTimeout(t *testing.T) {
-	t.Skip("skipping until query timeouts are enabled")
-	srv := NewTestServer(t, protoVersion2)
+	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 	defer srv.Stop()
 
 
 	cluster := NewCluster(srv.Address)
 	cluster := NewCluster(srv.Address)
@@ -509,6 +508,31 @@ func TestQueryTimeout(t *testing.T) {
 	}
 	}
 }
 }
 
 
+func TestQueryTimeoutReuseStream(t *testing.T) {
+	srv := NewTestServer(t, defaultProto)
+	defer srv.Stop()
+
+	cluster := NewCluster(srv.Address)
+	// Set the timeout arbitrarily low so that the query hits the timeout in a
+	// timely manner.
+	cluster.Timeout = 1 * time.Millisecond
+	cluster.NumConns = 1
+	cluster.NumStreams = 1
+
+	db, err := cluster.CreateSession()
+	if err != nil {
+		t.Fatalf("NewCluster: %v", err)
+	}
+	defer db.Close()
+
+	db.Query("slow").Exec()
+
+	err = db.Query("void").Exec()
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
 func NewTestServer(t testing.TB, protocol uint8) *TestServer {
 func NewTestServer(t testing.TB, protocol uint8) *TestServer {
 	laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
 	laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
 	if err != nil {
 	if err != nil {
@@ -656,6 +680,18 @@ func (srv *TestServer) process(f *framer) {
 		case "timeout":
 		case "timeout":
 			<-srv.quit
 			<-srv.quit
 			return
 			return
+		case "slow":
+			go func() {
+				f.writeHeader(0, opResult, head.stream)
+				f.writeInt(resultKindVoid)
+				f.wbuf[0] = srv.protocol | 0x80
+				select {
+				case <-srv.quit:
+				case <-time.After(50 * time.Millisecond):
+					f.finishWrite()
+				}
+			}()
+			return
 		default:
 		default:
 			f.writeHeader(0, opResult, head.stream)
 			f.writeHeader(0, opResult, head.stream)
 			f.writeInt(resultKindVoid)
 			f.writeInt(resultKindVoid)