
Merge remote-tracking branch 'gocql/master'

mark 9 years ago
commit 6ec2700078
13 changed files with 183 additions and 94 deletions
  1. .travis.yml (+1 −1)
  2. AUTHORS (+2 −0)
  3. cassandra_test.go (+8 −2)
  4. cluster.go (+7 −7)
  5. conn_test.go (+121 −61)
  6. frame.go (+16 −0)
  7. frame_test.go (+7 −0)
  8. helpers.go (+1 −1)
  9. host_source.go (+1 −1)
  10. policies.go (+12 −14)
  11. policies_test.go (+5 −5)
  12. query_executor.go (+1 −1)
  13. session.go (+1 −1)

+ 1 - 1
.travis.yml

@@ -25,8 +25,8 @@ env:
       AUTH=false
 
 go:
-  - 1.5.3
   - 1.6
+  - 1.7
 
 install:
   - pip install --user cql PyYAML six

+ 2 - 0
AUTHORS

@@ -75,3 +75,5 @@ Caleb Doxsey <caleb@datadoghq.com>
 Frederic Hemery <frederic.hemery@datadoghq.com>
 Pekka Enberg <penberg@scylladb.com>
 Mark M <m.mim95@gmail.com>
+Bartosz Burclaf <burclaf@gmail.com>
+Marcus King <marcusking01@gmail.com>

+ 8 - 2
cassandra_test.go

@@ -545,9 +545,15 @@ func TestNotEnoughQueryArgs(t *testing.T) {
 // TestCreateSessionTimeout tests to make sure the CreateSession function timeouts out correctly
 // and prevents an infinite loop of connection retries.
 func TestCreateSessionTimeout(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
 	go func() {
-		<-time.After(2 * time.Second)
-		t.Error("no startup timeout")
+		select {
+		case <-time.After(2 * time.Second):
+			t.Error("no startup timeout")
+		case <-ctx.Done():
+		}
 	}()
 
 	cluster := createCluster()

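The hunk above replaces a bare <-time.After watchdog goroutine with a select that also waits on a cancellable context, so the watchdog exits cleanly once the test body finishes instead of leaking and firing later. A minimal standalone sketch of that pattern, with illustrative names (not part of the commit):

package main

import (
	"context"
	"fmt"
	"time"
)

func main() {
	ctx, cancel := context.WithCancel(context.Background())

	// Watchdog: reports only if the work outlives the deadline.
	go func() {
		select {
		case <-time.After(2 * time.Second):
			fmt.Println("watchdog: no startup timeout")
		case <-ctx.Done():
			// the work finished first; exit silently
		}
	}()

	// ... the work under test runs here ...
	cancel() // releases the watchdog goroutine

	time.Sleep(10 * time.Millisecond) // give it a moment to unwind
}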
+ 7 - 7
cluster.go

@@ -10,7 +10,7 @@ import (
 )
 
 // PoolConfig configures the connection pool used by the driver, it defaults to
-// using a round robbin host selection policy and a round robbin connection selection
+// using a round-robin host selection policy and a round-robin connection selection
 // policy for each host.
 type PoolConfig struct {
 	// HostSelectionPolicy sets the policy for selecting which host to use for a
@@ -23,9 +23,9 @@ func (p PoolConfig) buildPool(session *Session) *policyConnPool {
 }
 
 type DiscoveryConfig struct {
-	// If not empty will filter all discoverred hosts to a single Data Centre (default: "")
+	// If not empty will filter all discovered hosts to a single Data Centre (default: "")
 	DcFilter string
-	// If not empty will filter all discoverred hosts to a single Rack (default: "")
+	// If not empty will filter all discovered hosts to a single Rack (default: "")
 	RackFilter string
 	// ignored
 	Sleep time.Duration
@@ -44,8 +44,8 @@ func (d DiscoveryConfig) matchFilter(host *HostInfo) bool {
 }
 
 // ClusterConfig is a struct to configure the default cluster implementation
-// of gocoql. It has a varity of attributes that can be used to modify the
-// behavior to fit the most common use cases. Applications that requre a
+// of gocoql. It has a variety of attributes that can be used to modify the
+// behavior to fit the most common use cases. Applications that require a
 // different setup must implement their own cluster.
 type ClusterConfig struct {
 	Hosts             []string          // addresses for the initial connections
@@ -79,7 +79,7 @@ type ClusterConfig struct {
 	// receiving a schema change frame. (deault: 60s)
 	MaxWaitSchemaAgreement time.Duration
 
-	// HostFilter will filter all incoming events for host, any which dont pass
+	// HostFilter will filter all incoming events for host, any which don't pass
 	// the filter will be ignored. If set will take precedence over any options set
 	// via Discovery
 	HostFilter HostFilter
@@ -113,7 +113,7 @@ type ClusterConfig struct {
 	// DisableSkipMetadata will override the internal result metadata cache so that the driver does not
 	// send skip_metadata for queries, this means that the result will always contain
 	// the metadata to parse the rows and will not reuse the metadata from the prepared
-	// staement.
+	// statement.
 	//
 	// See https://issues.apache.org/jira/browse/CASSANDRA-10786
 	DisableSkipMetadata bool

+ 121 - 61
conn_test.go

@@ -9,7 +9,6 @@ import (
 	"crypto/tls"
 	"crypto/tls"
 	"crypto/x509"
 	"crypto/x509"
 	"fmt"
 	"fmt"
-	"golang.org/x/net/context"
 	"io"
 	"io"
 	"io/ioutil"
 	"io/ioutil"
 	"net"
 	"net"
@@ -18,6 +17,8 @@ import (
 	"sync/atomic"
 	"sync/atomic"
 	"testing"
 	"testing"
 	"time"
 	"time"
+
+	"golang.org/x/net/context"
 )
 )
 
 
 const (
 const (
@@ -59,7 +60,7 @@ func testCluster(addr string, proto protoVersion) *ClusterConfig {
 }
 
 func TestSimple(t *testing.T) {
-	srv := NewTestServer(t, defaultProto)
+	srv := NewTestServer(t, defaultProto, context.Background())
 	defer srv.Stop()
 
 	cluster := testCluster(srv.Address, defaultProto)
@@ -74,7 +75,7 @@ func TestSimple(t *testing.T) {
 }
 
 func TestSSLSimple(t *testing.T) {
-	srv := NewSSLTestServer(t, defaultProto)
+	srv := NewSSLTestServer(t, defaultProto, context.Background())
 	defer srv.Stop()
 
 	db, err := createTestSslCluster(srv.Address, defaultProto, true).CreateSession()
@@ -88,7 +89,7 @@ func TestSSLSimple(t *testing.T) {
 }
 
 func TestSSLSimpleNoClientCert(t *testing.T) {
-	srv := NewSSLTestServer(t, defaultProto)
+	srv := NewSSLTestServer(t, defaultProto, context.Background())
 	defer srv.Stop()
 
 	db, err := createTestSslCluster(srv.Address, defaultProto, false).CreateSession()
@@ -120,7 +121,7 @@ func createTestSslCluster(addr string, proto protoVersion, useClientCert bool) *
 func TestClosed(t *testing.T) {
 	t.Skip("Skipping the execution of TestClosed for now to try to concentrate on more important test failures on Travis")
 
-	srv := NewTestServer(t, defaultProto)
+	srv := NewTestServer(t, defaultProto, context.Background())
 	defer srv.Stop()
 
 	session, err := newTestSession(srv.Address, defaultProto)
@@ -140,7 +141,9 @@ func newTestSession(addr string, proto protoVersion) (*Session, error) {
 }
 
 func TestTimeout(t *testing.T) {
-	srv := NewTestServer(t, defaultProto)
+	ctx, cancel := context.WithCancel(context.Background())
+
+	srv := NewTestServer(t, defaultProto, ctx)
 	defer srv.Stop()
 
 	db, err := newTestSession(srv.Address, defaultProto)
@@ -149,20 +152,34 @@ func TestTimeout(t *testing.T) {
 	}
 	defer db.Close()
 
+	var wg sync.WaitGroup
+	wg.Add(1)
+
 	go func() {
-		<-time.After(2 * time.Second)
-		t.Errorf("no timeout")
+		defer wg.Done()
+
+		select {
+		case <-time.After(5 * time.Second):
+			t.Errorf("no timeout")
+		case <-ctx.Done():
+		}
 	}()
 
-	if err := db.Query("kill").Exec(); err == nil {
-		t.Errorf("expected error")
+	if err := db.Query("kill").WithContext(ctx).Exec(); err == nil {
+		t.Fatal("expected error got nil")
 	}
+	cancel()
+
+	wg.Wait()
 }
 
 // TestQueryRetry will test to make sure that gocql will execute
 // the exact amount of retry queries designated by the user.
 func TestQueryRetry(t *testing.T) {
-	srv := NewTestServer(t, defaultProto)
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	srv := NewTestServer(t, defaultProto, ctx)
 	defer srv.Stop()
 
 	db, err := newTestSession(srv.Address, defaultProto)
@@ -172,9 +189,14 @@ func TestQueryRetry(t *testing.T) {
 	defer db.Close()
 
 	go func() {
-		<-time.After(5 * time.Second)
-		t.Fatalf("no timeout")
+		select {
+		case <-ctx.Done():
+			return
+		case <-time.After(5 * time.Second):
+			t.Errorf("no timeout")
+		}
 	}()
+
 	rt := &SimpleRetryPolicy{NumRetries: 1}
 
 	qry := db.Query("kill").RetryPolicy(rt)
@@ -195,7 +217,7 @@ func TestQueryRetry(t *testing.T) {
 }
 
 func TestStreams_Protocol1(t *testing.T) {
-	srv := NewTestServer(t, protoVersion1)
+	srv := NewTestServer(t, protoVersion1, context.Background())
 	defer srv.Stop()
 
 	// TODO: these are more like session tests and should instead operate
@@ -227,7 +249,7 @@ func TestStreams_Protocol1(t *testing.T) {
 }
 
 func TestStreams_Protocol3(t *testing.T) {
-	srv := NewTestServer(t, protoVersion3)
+	srv := NewTestServer(t, protoVersion3, context.Background())
 	defer srv.Stop()
 
 	// TODO: these are more like session tests and should instead operate
@@ -254,7 +276,7 @@ func TestStreams_Protocol3(t *testing.T) {
 }
 
 func BenchmarkProtocolV3(b *testing.B) {
-	srv := NewTestServer(b, protoVersion3)
+	srv := NewTestServer(b, protoVersion3, context.Background())
 	defer srv.Stop()
 
 	// TODO: these are more like session tests and should instead operate
@@ -280,7 +302,7 @@ func BenchmarkProtocolV3(b *testing.B) {
 
 // This tests that the policy connection pool handles SSL correctly
 func TestPolicyConnPoolSSL(t *testing.T) {
-	srv := NewSSLTestServer(t, defaultProto)
+	srv := NewSSLTestServer(t, defaultProto, context.Background())
 	defer srv.Stop()
 
 	cluster := createTestSslCluster(srv.Address, defaultProto, true)
@@ -305,7 +327,7 @@ func TestPolicyConnPoolSSL(t *testing.T) {
 }
 
 func TestQueryTimeout(t *testing.T) {
-	srv := NewTestServer(t, defaultProto)
+	srv := NewTestServer(t, defaultProto, context.Background())
 	defer srv.Stop()
 
 	cluster := testCluster(srv.Address, defaultProto)
@@ -341,33 +363,8 @@ func TestQueryTimeout(t *testing.T) {
 	}
 }
 
-func TestQueryTimeoutMany(t *testing.T) {
-	srv := NewTestServer(t, 3)
-	defer srv.Stop()
-
-	cluster := testCluster(srv.Address, 3)
-	// Set the timeout arbitrarily low so that the query hits the timeout in a
-	// timely manner.
-	cluster.Timeout = 5 * time.Millisecond
-	cluster.NumConns = 1
-
-	db, err := cluster.CreateSession()
-	if err != nil {
-		t.Fatalf("NewCluster: %v", err)
-	}
-	defer db.Close()
-
-	for i := 0; i < 128; i++ {
-		err := db.Query("void").Exec()
-		if err != nil {
-			t.Error(err)
-			return
-		}
-	}
-}
-
 func BenchmarkSingleConn(b *testing.B) {
-	srv := NewTestServer(b, 3)
+	srv := NewTestServer(b, 3, context.Background())
 	defer srv.Stop()
 
 	cluster := testCluster(srv.Address, 3)
@@ -398,7 +395,7 @@ func TestQueryTimeoutReuseStream(t *testing.T) {
 	// TODO(zariel): move this to conn test, we really just want to check what
 	// happens when a conn is
 
-	srv := NewTestServer(t, defaultProto)
+	srv := NewTestServer(t, defaultProto, context.Background())
 	defer srv.Stop()
 
 	cluster := testCluster(srv.Address, defaultProto)
@@ -422,7 +419,7 @@ func TestQueryTimeoutReuseStream(t *testing.T) {
 }
 
 func TestQueryTimeoutClose(t *testing.T) {
-	srv := NewTestServer(t, defaultProto)
+	srv := NewTestServer(t, defaultProto, context.Background())
 	defer srv.Stop()
 
 	cluster := testCluster(srv.Address, defaultProto)
@@ -459,12 +456,20 @@ func TestQueryTimeoutClose(t *testing.T) {
 func TestStream0(t *testing.T) {
 	const expErr = "gocql: received frame on stream 0"
 
-	srv := NewTestServer(t, defaultProto)
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	srv := NewTestServer(t, defaultProto, ctx)
 	defer srv.Stop()
 
 	errorHandler := connErrorHandlerFn(func(conn *Conn, err error, closed bool) {
 		if !srv.isClosed() && !strings.HasPrefix(err.Error(), expErr) {
-			t.Errorf("expected to get error prefix %q got %q", expErr, err.Error())
+			select {
+			case <-ctx.Done():
+				return
+			default:
+				t.Errorf("expected to get error prefix %q got %q", expErr, err.Error())
+			}
 		}
 	})
 
@@ -480,7 +485,7 @@ func TestStream0(t *testing.T) {
 	})
 
 	// need to write out an invalid frame, which we need a connection to do
-	framer, err := conn.exec(context.Background(), writer, nil)
+	framer, err := conn.exec(ctx, writer, nil)
 	if err == nil {
 		t.Fatal("expected to get an error on stream 0")
 	} else if !strings.HasPrefix(err.Error(), expErr) {
@@ -498,7 +503,7 @@ func TestConnClosedBlocked(t *testing.T) {
 	// issue 664
 	const proto = 3
 
-	srv := NewTestServer(t, proto)
+	srv := NewTestServer(t, proto, context.Background())
 	defer srv.Stop()
 	errorHandler := connErrorHandlerFn(func(conn *Conn, err error, closed bool) {
 		t.Log(err)
@@ -522,7 +527,7 @@ func TestConnClosedBlocked(t *testing.T) {
 }
 
 func TestContext_Timeout(t *testing.T) {
-	srv := NewTestServer(t, defaultProto)
+	srv := NewTestServer(t, defaultProto, context.Background())
 	defer srv.Stop()
 
 	cluster := testCluster(srv.Address, defaultProto)
@@ -541,7 +546,7 @@ func TestContext_Timeout(t *testing.T) {
 	}
 }
 
-func NewTestServer(t testing.TB, protocol uint8) *TestServer {
+func NewTestServer(t testing.TB, protocol uint8, ctx context.Context) *TestServer {
 	laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
 	if err != nil {
 		t.Fatal(err)
@@ -557,21 +562,24 @@ func NewTestServer(t testing.TB, protocol uint8) *TestServer {
 		headerSize = 9
 	}
 
+	ctx, cancel := context.WithCancel(ctx)
 	srv := &TestServer{
 		Address:    listen.Addr().String(),
 		listen:     listen,
 		t:          t,
 		protocol:   protocol,
 		headerSize: headerSize,
-		quit:       make(chan struct{}),
+		ctx:        ctx,
+		cancel:     cancel,
 	}
 
+	go srv.closeWatch()
 	go srv.serve()
 
 	return srv
 }
 
-func NewSSLTestServer(t testing.TB, protocol uint8) *TestServer {
+func NewSSLTestServer(t testing.TB, protocol uint8, ctx context.Context) *TestServer {
 	pem, err := ioutil.ReadFile("testdata/pki/ca.crt")
 	certPool := x509.NewCertPool()
 	if !certPool.AppendCertsFromPEM(pem) {
@@ -595,14 +603,18 @@ func NewSSLTestServer(t testing.TB, protocol uint8) *TestServer {
 		headerSize = 9
 	}
 
+	ctx, cancel := context.WithCancel(ctx)
 	srv := &TestServer{
 		Address:    listen.Addr().String(),
 		listen:     listen,
 		t:          t,
 		protocol:   protocol,
 		headerSize: headerSize,
-		quit:       make(chan struct{}),
+		ctx:        ctx,
+		cancel:     cancel,
 	}
+
+	go srv.closeWatch()
 	go srv.serve()
 	return srv
 }
@@ -617,28 +629,58 @@ type TestServer struct {
 
 	protocol   byte
 	headerSize int
+	ctx        context.Context
+	cancel     context.CancelFunc
 
 	quit   chan struct{}
 	mu     sync.Mutex
 	closed bool
 }
 
+func (srv *TestServer) closeWatch() {
+	<-srv.ctx.Done()
+
+	srv.mu.Lock()
+	defer srv.mu.Unlock()
+
+	srv.closeLocked()
+}
+
 func (srv *TestServer) serve() {
 	defer srv.listen.Close()
 	for {
+		select {
+		case <-srv.ctx.Done():
+			return
+		default:
+		}
+
 		conn, err := srv.listen.Accept()
 		if err != nil {
 			break
 		}
+
 		go func(conn net.Conn) {
 			defer conn.Close()
 			for {
+				select {
+				case <-srv.ctx.Done():
+					return
+				default:
+				}
+
 				framer, err := srv.readFrame(conn)
 				if err != nil {
 					if err == io.EOF {
 						return
 					}
 
+					select {
+					case <-srv.ctx.Done():
+						return
+					default:
+					}
+
 					srv.t.Error(err)
 					return
 				}
@@ -657,21 +699,32 @@ func (srv *TestServer) isClosed() bool {
 	return srv.closed
 }
 
-func (srv *TestServer) Stop() {
-	srv.mu.Lock()
-	defer srv.mu.Unlock()
+func (srv *TestServer) closeLocked() {
 	if srv.closed {
 		return
 	}
+
 	srv.closed = true
 
 	srv.listen.Close()
-	close(srv.quit)
+	srv.cancel()
+}
+
+func (srv *TestServer) Stop() {
+	srv.mu.Lock()
+	defer srv.mu.Unlock()
+	srv.closeLocked()
 }
 }
 
 func (srv *TestServer) process(f *framer) {
 	head := f.header
 	if head == nil {
+		case <-srv.ctx.Done():
+			return
+		default:
+		}
+
 		srv.t.Error("process frame with a nil header")
 		srv.t.Error("process frame with a nil header")
 		return
 	}
@@ -701,7 +754,7 @@ func (srv *TestServer) process(f *framer) {
 			f.writeHeader(0, opResult, head.stream)
 			f.writeInt(resultKindVoid)
 		case "timeout":
-			<-srv.quit
+			<-srv.ctx.Done()
 			return
 		case "slow":
 			go func() {
@@ -709,7 +762,8 @@ func (srv *TestServer) process(f *framer) {
 				f.writeInt(resultKindVoid)
 				f.wbuf[0] = srv.protocol | 0x80
 				select {
-				case <-srv.quit:
+				case <-srv.ctx.Done():
+					return
 				case <-time.After(50 * time.Millisecond):
 					f.finishWrite()
 				}
@@ -731,6 +785,12 @@ func (srv *TestServer) process(f *framer) {
 	f.wbuf[0] = srv.protocol | 0x80
 
 	if err := f.finishWrite(); err != nil {
+		select {
+		case <-srv.ctx.Done():
+			return
+		default:
+		}
+
 		srv.t.Error(err)
 	}
 }

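Most of this diff migrates the TestServer from a quit channel to a context.Context/context.CancelFunc pair: NewTestServer derives a child context from the caller's, closeWatch bridges cancellation to closing the listener, and both Stop and the watcher funnel through closeLocked so shutdown is idempotent. A standalone sketch of that shape, with illustrative names rather than the actual gocql test types:

package main

import (
	"context"
	"net"
)

type server struct {
	listen net.Listener
	ctx    context.Context
	cancel context.CancelFunc
}

func newServer(parent context.Context, l net.Listener) *server {
	// Deriving from the parent lets either the caller or Stop trigger shutdown.
	ctx, cancel := context.WithCancel(parent)
	s := &server{listen: l, ctx: ctx, cancel: cancel}
	go s.closeWatch()
	return s
}

func (s *server) closeWatch() {
	<-s.ctx.Done()
	s.listen.Close() // unblocks any Accept in progress
}

func (s *server) Stop() { s.cancel() }

func main() {
	l, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		panic(err)
	}
	s := newServer(context.Background(), l)
	s.Stop() // cancels the context; closeWatch then closes the listener
}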
+ 16 - 0
frame.go

@@ -205,6 +205,22 @@ func ParseConsistency(s string) Consistency {
 	}
 }
 
+// ParseConsistencyWrapper wraps gocql.ParseConsistency to provide an err
+// return instead of a panic
+func ParseConsistencyWrapper(s string) (consistency Consistency, err error) {
+	defer func() {
+		if r := recover(); r != nil {
+			var ok bool
+			err, ok = r.(error)
+			if !ok {
+				err = fmt.Errorf("ParseConsistencyWrapper: %v", r)
+			}
+		}
+	}()
+	consistency = ParseConsistency(s)
+	return consistency, nil
+}
+
 type SerialConsistency uint16
 
 const (

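ParseConsistencyWrapper turns the panic raised by ParseConsistency into an ordinary error by recovering in a deferred function and writing into the named return value. The same pattern in a self-contained form, where parse stands in for the panicking ParseConsistency:

package main

import "fmt"

// parse stands in for a function that panics on bad input.
func parse(s string) int {
	if s != "QUORUM" {
		panic(fmt.Errorf("invalid consistency: %q", s))
	}
	return 4
}

// parseWrapper converts the panic into an error via the named return value.
func parseWrapper(s string) (c int, err error) {
	defer func() {
		if r := recover(); r != nil {
			var ok bool
			if err, ok = r.(error); !ok {
				err = fmt.Errorf("parseWrapper: %v", r)
			}
		}
	}()
	return parse(s), nil
}

func main() {
	if _, err := parseWrapper("TEST"); err != nil {
		fmt.Println("got an error instead of a panic:", err)
	}
}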
+ 7 - 0
frame_test.go

@@ -97,3 +97,10 @@ func TestFrameReadTooLong(t *testing.T) {
 		t.Fatalf("expected to get header %v got %v", opReady, head.op)
 	}
 }
+
+func TestParseConsistencyErrorInsteadOfPanic(t *testing.T) {
+	_, err := ParseConsistencyWrapper("TEST")
+	if err == nil {
+		t.Fatal("expected ParseConsistencyWrapper error got nil")
+	}
+}

+ 1 - 1
helpers.go

@@ -247,7 +247,7 @@ func (iter *Iter) SliceMap() ([]map[string]interface{}, error) {
 }
 
 // MapScan takes a map[string]interface{} and populates it with a row
-// That is returned from cassandra.
+// that is returned from cassandra.
 func (iter *Iter) MapScan(m map[string]interface{}) bool {
 	if iter.err != nil {
 		return false

+ 1 - 1
host_source.go

@@ -234,7 +234,7 @@ func (h *HostInfo) update(from *HostInfo) {
 }
 
 func (h *HostInfo) IsUp() bool {
-	return h.State() == NodeUp
+	return h != nil && h.State() == NodeUp
 }
 
 func (h *HostInfo) String() string {

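The nil guard works because Go allows a method with a pointer receiver to be invoked on a nil pointer, so IsUp can defend itself rather than requiring every caller to check first. A tiny standalone sketch:

package main

import "fmt"

type host struct{ up bool }

// IsUp is safe to call on a nil *host: methods with pointer receivers
// may be invoked on nil pointers, and the body can check for nil itself.
func (h *host) IsUp() bool { return h != nil && h.up }

func main() {
	var h *host           // nil pointer
	fmt.Println(h.IsUp()) // prints false, no panic
}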
+ 12 - 14
policies.go

@@ -13,7 +13,7 @@ import (
 	"github.com/hailocab/go-hostpool"
 	"github.com/hailocab/go-hostpool"
 )
 )
 
 
-// cowHostList implements a copy on write host list, its equivilent type is []*HostInfo
+// cowHostList implements a copy on write host list, its equivalent type is []*HostInfo
 type cowHostList struct {
 type cowHostList struct {
 	list atomic.Value
 	list atomic.Value
 	mu   sync.Mutex
 	mu   sync.Mutex
@@ -263,9 +263,6 @@ type tokenAwareHostPolicy struct {
 }
 
 func (t *tokenAwareHostPolicy) SetPartitioner(partitioner string) {
-	t.mu.Lock()
-	defer t.mu.Unlock()
-
 	if t.partitioner != partitioner {
 		t.fallback.SetPartitioner(partitioner)
 		t.partitioner = partitioner
@@ -278,18 +275,14 @@ func (t *tokenAwareHostPolicy) AddHost(host *HostInfo) {
 	t.hosts.add(host)
 	t.fallback.AddHost(host)
 
-	t.mu.Lock()
 	t.resetTokenRing()
-	t.mu.Unlock()
 }
 
 func (t *tokenAwareHostPolicy) RemoveHost(addr string) {
 	t.hosts.remove(addr)
 	t.fallback.RemoveHost(addr)
 
-	t.mu.Lock()
 	t.resetTokenRing()
-	t.mu.Unlock()
 }
 
 func (t *tokenAwareHostPolicy) HostUp(host *HostInfo) {
@@ -301,6 +294,9 @@ func (t *tokenAwareHostPolicy) HostDown(addr string) {
 }
 
 func (t *tokenAwareHostPolicy) resetTokenRing() {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+
 	if t.partitioner == "" {
 		// partitioner not yet set
 		return
@@ -377,7 +373,7 @@ func (t *tokenAwareHostPolicy) Pick(qry ExecutableQuery) NextHost {
 //     // Create host selection policy using a simple host pool
 //     cluster.PoolConfig.HostSelectionPolicy = HostPoolHostPolicy(hostpool.New(nil))
 //
-//     // Create host selection policy using an epsilon greddy pool
+//     // Create host selection policy using an epsilon greedy pool
 //     cluster.PoolConfig.HostSelectionPolicy = HostPoolHostPolicy(
 //         hostpool.NewEpsilonGreedy(nil, 0, &hostpool.LinearEpsilonValueCalculator{}),
 //     )
@@ -411,18 +407,20 @@ func (r *hostPoolHostPolicy) AddHost(host *HostInfo) {
 	r.mu.Lock()
 	defer r.mu.Unlock()
 
-	if _, ok := r.hostMap[host.Peer()]; ok {
+	// If the host addr is present and isn't nil return
+	if h, ok := r.hostMap[host.Peer()]; ok && h != nil{
 		return
 	}
-
-	hosts := make([]string, 0, len(r.hostMap)+1)
+	// otherwise, add the host to the map
+	r.hostMap[host.Peer()] = host
+	// and construct a new peer list to give to the HostPool
+	hosts := make([]string, 0, len(r.hostMap))
 	for addr := range r.hostMap {
 		hosts = append(hosts, addr)
 	}
-	hosts = append(hosts, host.Peer())
 
 	r.hp.SetHosts(hosts)
-	r.hostMap[host.Peer()] = host
+
 }
 
 func (r *hostPoolHostPolicy) RemoveHost(addr string) {

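The tokenAwareHostPolicy hunks move the mutex out of the callers (SetPartitioner, AddHost, RemoveHost) and into resetTokenRing itself, so the ring rebuild is locked no matter which path triggers it. A compact sketch of that refactor, with illustrative types (not the actual gocql policy):

package main

import (
	"fmt"
	"sync"
)

type policy struct {
	mu   sync.Mutex
	ring string
}

// resetRing takes the lock itself; callers no longer wrap it in Lock/Unlock.
func (p *policy) resetRing(n int) {
	p.mu.Lock()
	defer p.mu.Unlock()
	p.ring = fmt.Sprintf("ring rebuilt over %d hosts", n)
}

func (p *policy) AddHost()    { p.resetRing(1) } // was: mu.Lock(); resetRing(); mu.Unlock()
func (p *policy) RemoveHost() { p.resetRing(0) }

func main() {
	p := &policy{}
	p.AddHost()
	fmt.Println(p.ring)
}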
+ 5 - 5
policies_test.go

@@ -111,14 +111,14 @@ func TestTokenAwareHostPolicy(t *testing.T) {
 func TestHostPoolHostPolicy(t *testing.T) {
 	policy := HostPoolHostPolicy(hostpool.New(nil))
 
-	hosts := [...]*HostInfo{
+	hosts := []*HostInfo{
 		{hostId: "0", peer: "0"},
 		{hostId: "1", peer: "1"},
 	}
 
-	for _, host := range hosts {
-		policy.AddHost(host)
-	}
+	// Using set host to control the ordering of the hosts as calling "AddHost" iterates the map
+	// which will result in an unpredictable ordering
+	policy.(*hostPoolHostPolicy).SetHosts(hosts)
 
 	// the first host selected is actually at [1], but this is ok for RR
 	// interleaved iteration should always increment the host
@@ -161,7 +161,7 @@ func TestRoundRobinNilHostInfo(t *testing.T) {
 	} else if v := next.Info(); v == nil {
 		t.Fatal("got nil HostInfo")
 	} else if v.HostID() != host.HostID() {
-		t.Fatalf("expected host %v got %v", host, *v)
+		t.Fatalf("expected host %v got %v", host, v)
 	}
 
 	next = iter()

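The switch from AddHost to SetHosts matters because hostPoolHostPolicy.AddHost rebuilds its peer list by ranging over a map, and Go deliberately randomizes map iteration order, so the two hosts could reach the pool in either order from run to run. The randomization is easy to observe:

package main

import "fmt"

func main() {
	m := map[string]int{"a": 0, "b": 1, "c": 2}
	// Iteration order over a Go map is deliberately randomized,
	// so this may print the keys in a different order on each run.
	for k := range m {
		fmt.Println(k)
	}
}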
+ 1 - 1
query_executor.go

@@ -24,7 +24,7 @@ func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
 	var iter *Iter
 	for hostResponse := hostIter(); hostResponse != nil; hostResponse = hostIter() {
 		host := hostResponse.Info()
-		if !host.IsUp() {
+		if host == nil || !host.IsUp() {
 			continue
 		}
 

+ 1 - 1
session.go

@@ -1148,7 +1148,7 @@ func (iter *Iter) PageState() []byte {
 }
 
 // NumRows returns the number of rows in this pagination, it will update when new
-// pages are fetcehd, it is not the value of the total number of rows this iter
+// pages are fetched, it is not the value of the total number of rows this iter
 // will return unless there is only a single page returned.
 func (iter *Iter) NumRows() int {
 	return iter.numRows