
Merge pull request #440 from Zariel/prevent-duplicate-stream-use

prevent duplicate stream usage
Chris Bannister, 10 years ago
parent commit 211b00de76
3 changed files with 82 additions and 54 deletions:
  1. cassandra_test.go (+6, -3)
  2. conn.go (+36, -51)
  3. stress_test.go (+40, -0)
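
Before the diffs, a quick orientation: the change removes the per-call waiting flag from callReq and leaves stream ownership entirely to the uniq free-list channel, so a stream ID is held by at most one request at a time, and releaseStream now waits on the quit channel instead of silently dropping the ID when the free list is full. Below is a minimal, self-contained sketch of that free-list pattern; only the uniq and quit names mirror conn.go, the surrounding types are illustrative and not gocql's actual implementation.

package main

import (
	"errors"
	"fmt"
)

// streamPool sketches the free-list idea: a stream ID sits on the uniq
// channel while it is free and is owned by exactly one caller while it is
// off the channel, so two requests can never share a stream.
type streamPool struct {
	uniq chan int      // free stream IDs
	quit chan struct{} // closed when the connection shuts down
}

func newStreamPool(n int) *streamPool {
	p := &streamPool{uniq: make(chan int, n), quit: make(chan struct{})}
	for i := 0; i < n; i++ {
		p.uniq <- i
	}
	return p
}

// acquire hands out a free stream ID, or fails once the connection is closed.
func (p *streamPool) acquire() (int, error) {
	select {
	case id := <-p.uniq:
		return id, nil
	case <-p.quit:
		return 0, errors.New("connection closed")
	}
}

// release returns an ID to the pool, or gives up when the connection is
// shutting down -- the same shape as the new releaseStream, which selects
// on quit instead of dropping the ID in a default case.
func (p *streamPool) release(id int) {
	select {
	case p.uniq <- id:
	case <-p.quit:
	}
}

func main() {
	p := newStreamPool(2)
	id, _ := p.acquire()
	fmt.Println("acquired stream", id)
	p.release(id)
}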

cassandra_test.go (+6, -3)

@@ -105,9 +105,7 @@ func createKeyspace(tb testing.TB, cluster *ClusterConfig, keyspace string) {
 	tb.Logf("Created keyspace %s", keyspace)
 }
 
-func createSession(tb testing.TB) *Session {
-	cluster := createCluster()
-
+func createSessionFromCluster(cluster *ClusterConfig, tb testing.TB) *Session {
 	// Drop and re-create the keyspace once. Different tests should use their own
 	// individual tables, but can assume that the table does not exist before.
 	initOnce.Do(func() {
@@ -123,6 +121,11 @@ func createSession(tb testing.TB) *Session {
 	return session
 }
 
+func createSession(tb testing.TB) *Session {
+	cluster := createCluster()
+	return createSessionFromCluster(cluster, tb)
+}
+
 // TestAuthentication verifies that gocql will work with a host configured to only accept authenticated connections
 func TestAuthentication(t *testing.T) {
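
The split above lets a test tweak the ClusterConfig before connecting; the new stress_test.go below uses it to force a single connection. A hypothetical example of the same pattern in another integration test (it would live in the gocql package test files next to the helpers above; the test name is made up, NumConns is the field the benchmark below also sets):

func TestSingleConnExample(t *testing.T) {
	// createCluster and createSessionFromCluster are the package test helpers
	// from the diff above; *testing.T satisfies testing.TB.
	cluster := createCluster()
	cluster.NumConns = 1 // multiplex every query over one connection
	session := createSessionFromCluster(cluster, t)
	defer session.Close()

	// ... run queries that should all share the single connection
}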
 

conn.go (+36, -51)

@@ -293,6 +293,40 @@ func (c *Conn) authenticateHandshake(authFrame *authenticateFrame) error {
 	}
 }
 
+func (c *Conn) closeWithError(err error) {
+	if !atomic.CompareAndSwapInt32(&c.closed, 0, 1) {
+		return
+	}
+
+	if err != nil {
+		// we should attempt to deliver the error back to the caller if it
+		// exists
+		for id := 0; id < len(c.calls); id++ {
+			req := &c.calls[id]
+			// we need to send the error to all waiting queries; the conn has
+			// already been marked closed above, so it can not execute any queries.
+			if err != nil {
+				select {
+				case req.resp <- err:
+				default:
+				}
+			}
+		}
+	}
+
+	// if error was nil then unblock the quit channel
+	close(c.quit)
+	c.conn.Close()
+
+	if c.started && err != nil {
+		c.errorHandler.HandleError(c, err, true)
+	}
+}
+
+func (c *Conn) Close() {
+	c.closeWithError(nil)
+}
+
 // Serve starts the stream multiplexer for this connection, which is required
 // to execute any queries. This method runs as long as the connection is
 // open and is therefore usually called in a separate goroutine.
@@ -336,13 +370,6 @@ func (c *Conn) recv() error {
 		}
 	}
 
-	if !atomic.CompareAndSwapInt32(&call.waiting, 1, 0) {
-		// the waiting thread timed out and is no longer waiting, the stream has
-		// not yet been readded to the chan so it cant be used again,
-		c.releaseStream(head.stream)
-		return nil
-	}
-
 	// we either return a response to the caller, the caller timed out, or the
 	// connection has closed. Either way we should never block indefinitely here
 	select {
@@ -359,7 +386,6 @@ type callReq struct {
 	// could use a waitgroup but this allows us to do timeouts on the read/send
 	resp    chan error
 	framer  *framer
-	waiting int32
 	timeout chan struct{} // indicates to recv() that a call has timed out
 }
 
@@ -370,7 +396,7 @@ func (c *Conn) releaseStream(stream int) {
 
 	select {
 	case c.uniq <- stream:
-	default:
+	case <-c.quit:
 	}
 }
 
@@ -389,9 +415,9 @@ func (c *Conn) exec(req frameWriter, tracer Tracer) (frame, error) {
 		return nil, ErrConnectionClosed
 	}
 
-	call := &c.calls[stream]
 	// resp is basically a waiting semaphore protecting the framer
 	framer := newFramer(c, c, c.compressor, c.version)
+	call := &c.calls[stream]
 	call.framer = framer
 	call.timeout = make(chan struct{})
 
@@ -399,11 +425,6 @@ func (c *Conn) exec(req frameWriter, tracer Tracer) (frame, error) {
 		framer.trace()
 	}
 
-	if !atomic.CompareAndSwapInt32(&call.waiting, 0, 1) {
-		return nil, errors.New("gocql: stream is busy or closed")
-	}
-	defer atomic.StoreInt32(&call.waiting, 0)
-
 	err := req.writeFrame(framer, stream)
 	if err != nil {
 		return nil, err
@@ -617,42 +638,6 @@ func (c *Conn) Closed() bool {
 	return atomic.LoadInt32(&c.closed) == 1
 }
 
-func (c *Conn) closeWithError(err error) {
-	if !atomic.CompareAndSwapInt32(&c.closed, 0, 1) {
-		return
-	}
-
-	if err != nil {
-		// we should attempt to deliver the error back to the caller if it
-		// exists
-		for id := 0; id < len(c.calls); id++ {
-			req := &c.calls[id]
-			// we need to send the error to all waiting queries, put the state
-			// of this conn into not active so that it can not execute any queries.
-			atomic.StoreInt32(&req.waiting, -1)
-
-			if err != nil {
-				select {
-				case req.resp <- err:
-				default:
-				}
-			}
-		}
-	}
-
-	// if error was nil then unblock the quit channel
-	close(c.quit)
-	c.conn.Close()
-
-	if c.started && err != nil {
-		c.errorHandler.HandleError(c, err, true)
-	}
-}
-
-func (c *Conn) Close() {
-	c.closeWithError(nil)
-}
-
 func (c *Conn) Address() string {
 	return c.addr
 }
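
One detail worth calling out in the conn.go changes: closeWithError keeps the non-blocking send (the default case) because shutdown must never block on a caller that has already gone away, while releaseStream drops its default case so a released stream ID is handed back to uniq rather than lost while the connection is still alive. A standalone sketch of the non-blocking error fan-out, with illustrative names rather than gocql's:

package main

import "fmt"

// deliverError fans an error out to per-call response channels with a
// non-blocking send, the same shape as closeWithError above: if a caller is
// no longer receiving, the send is skipped so shutdown never blocks.
func deliverError(calls []chan error, err error) {
	for _, resp := range calls {
		select {
		case resp <- err:
		default:
		}
	}
}

func main() {
	waiting := make(chan error, 1) // a caller still waiting for its response
	abandoned := make(chan error)  // a caller that gave up; nobody is receiving

	deliverError([]chan error{waiting, abandoned}, fmt.Errorf("connection closed"))

	fmt.Println(<-waiting) // the live caller observes the error
}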

stress_test.go (+40, -0)

@@ -0,0 +1,40 @@
+// +build all integration
+
+package gocql
+
+import (
+	"sync/atomic"
+
+	"testing"
+)
+
+func BenchmarkConnStress(b *testing.B) {
+	const workers = 16
+
+	cluster := createCluster()
+	cluster.NumConns = 1
+	cluster.NumStreams = workers
+	session := createSessionFromCluster(cluster, b)
+	defer session.Close()
+
+	if err := createTable(session, "CREATE TABLE IF NOT EXISTS conn_stress (id int primary key)"); err != nil {
+		b.Fatal(err)
+	}
+
+	var seed uint64
+	writer := func(pb *testing.PB) {
+		seed := atomic.AddUint64(&seed, 1)
+		var i uint64 = 0
+		for pb.Next() {
+			if err := session.Query("insert into conn_stress (id) values (?)", i*seed).Exec(); err != nil {
+				b.Error(err)
+				return
+			}
+			i++
+		}
+	}
+
+	b.SetParallelism(workers)
+	b.RunParallel(writer)
+
+}
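
The build tag at the top of this file keeps the benchmark out of plain unit-test runs; with a Cassandra node available it can be run along the lines of go test -tags integration -run XXX -bench ConnStress, using whatever cluster flags the existing integration tests already take (nothing new is introduced here).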