Merge branch 'master' into 2.1_v2

Conflicts:
	cassandra_test.go
Ben Hood, 11 years ago
parent commit d94bf88a1b
13 changed files with 529 additions and 107 deletions
  1. AUTHORS (+2 -0)
  2. README.md (+9 -1)
  3. cassandra_test.go (+193 -30)
  4. cluster.go (+20 -22)
  5. conn.go (+19 -11)
  6. conn_test.go (+9 -5)
  7. connectionpool.go (+95 -17)
  8. host_source.go (+104 -0)
  9. integration.sh (+6 -3)
  10. marshal.go (+4 -0)
  11. marshal_test.go (+10 -0)
  12. session.go (+57 -17)
  13. wiki_test.go (+1 -1)

+ 2 - 0
AUTHORS

@@ -28,3 +28,5 @@ Ben Frye <benfrye@gmail.com>
 Fred McCann <fred@sharpnoodles.com>
 Dan Simmons <dan@simmons.io>
 Muir Manders <muir@retailnext.net>
+Sankar P <sankar.curiosity@gmail.com>
+Julien Da Silva <julien.dasilva@gmail.com>

+ 9 - 1
README.md

@@ -43,6 +43,7 @@ Features
   * Round robin distribution of queries to different connections on a host
   * Each connection can execute up to 128 concurrent queries
   * Optional automatic discovery of nodes
+  * Optional support for periodic node discovery via system.peers
 * Iteration over paged results with configurable page size
 * Optional frame compression (using snappy)
 * Automatic query preparation
@@ -85,6 +86,11 @@ Example
 -------
 
 ```go
+/* Before you execute the program, launch `cqlsh` and execute:
+create keyspace example with replication = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 };
+create table example.tweet(timeline text, id UUID, text text, PRIMARY KEY(id));
+create index on example.tweet(timeline);
+*/
 package main
 
 import (
@@ -111,7 +117,9 @@ func main() {
 	var id gocql.UUID
 	var text string
 
-	// select a single tweet
+	/* Search for a specific set of records whose 'timeline' column matches
+	 * the value 'me'. The secondary index that we created earlier will be
+	 * used for optimizing the search */
 	if err := session.Query(`SELECT id, text FROM tweet WHERE timeline = ? LIMIT 1`,
 		"me").Consistency(gocql.One).Scan(&id, &text); err != nil {
 		log.Fatal(err)

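The comment block added to the README sets up the schema through cqlsh. For completeness, the same schema can be created from Go with the gocql API used throughout this commit; a minimal sketch, assuming the github.com/gocql/gocql import path, a node on 127.0.0.1, and that the keyspace, table and index do not exist yet:

```go
package main

import (
	"log"

	"github.com/gocql/gocql"
)

func main() {
	cluster := gocql.NewCluster("127.0.0.1")
	session, err := cluster.CreateSession()
	if err != nil {
		log.Fatal(err)
	}
	defer session.Close()

	// Mirrors the cqlsh commands from the README comment above.
	stmts := []string{
		`CREATE KEYSPACE example WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}`,
		`CREATE TABLE example.tweet(timeline text, id UUID, text text, PRIMARY KEY(id))`,
		`CREATE INDEX ON example.tweet(timeline)`,
	}
	for _, stmt := range stmts {
		if err := session.Query(stmt).Exec(); err != nil {
			log.Fatal(err)
		}
	}
}
```
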
+ 193 - 30
cassandra_test.go

@@ -27,8 +27,9 @@ var (
 	flagProto    = flag.Int("proto", 2, "protocol version")
 	flagCQL      = flag.String("cql", "3.0.0", "CQL version")
 	flagRF       = flag.Int("rf", 1, "replication factor for test keyspace")
+	clusterSize  = flag.Int("clusterSize", 1, "the expected size of the cluster")
 	flagRetry    = flag.Int("retries", 5, "number of times to retry queries")
-	clusterSize  = *flag.Int("clusterSize", 1, "the expected size of the cluster")
+	flagAutoWait = flag.Duration("autowait", 1000*time.Millisecond, "time to wait for autodiscovery to fill the hosts pool")
 	clusterHosts []string
 )
 
@@ -36,7 +37,6 @@ func init() {
 
 	flag.Parse()
 	clusterHosts = strings.Split(*flagCluster, ",")
-	clusterSize = len(clusterHosts)
 	log.SetFlags(log.Lshortfile | log.LstdFlags)
 }
 
@@ -44,14 +44,14 @@ var initOnce sync.Once
 
 func createTable(s *Session, table string) error {
 	err := s.Query(table).Consistency(All).Exec()
-	if clusterSize > 1 {
+	if *clusterSize > 1 {
 		// wait for table definition to propagate
 		time.Sleep(250 * time.Millisecond)
 	}
 	return err
 }
 
-func createSession(tb testing.TB) *Session {
+func createCluster() *ClusterConfig {
 	cluster := NewCluster(clusterHosts...)
 	cluster.ProtoVersion = *flagProto
 	cluster.CQLVersion = *flagCQL
@@ -59,26 +59,37 @@ func createSession(tb testing.TB) *Session {
 	cluster.Consistency = Quorum
 	cluster.RetryPolicy.NumRetries = *flagRetry
 
+	return cluster
+}
+
+func createKeyspace(tb testing.TB, cluster *ClusterConfig, keyspace string) {
+	session, err := cluster.CreateSession()
+	if err != nil {
+		tb.Fatal("createSession:", err)
+	}
+	if err = session.Query(`DROP KEYSPACE ` + keyspace).Exec(); err != nil {
+		tb.Log("drop keyspace:", err)
+	}
+	if err := session.Query(fmt.Sprintf(`CREATE KEYSPACE %s
+	WITH replication = {
+		'class' : 'SimpleStrategy',
+		'replication_factor' : %d
+	}`, keyspace, *flagRF)).Consistency(All).Exec(); err != nil {
+		tb.Fatalf("error creating keyspace %s: %v", keyspace, err)
+	}
+	tb.Logf("Created keyspace %s", keyspace)
+	session.Close()
+}
+
+func createSession(tb testing.TB) *Session {
+	cluster := createCluster()
+
+	// Drop and re-create the keyspace once. Different tests should use their own
+	// individual tables, but can assume that the table does not exist before.
 	initOnce.Do(func() {
-		session, err := cluster.CreateSession()
-		if err != nil {
-			tb.Fatal("createSession:", err)
-		}
-		// Drop and re-create the keyspace once. Different tests should use their own
-		// individual tables, but can assume that the table does not exist before.
-		if err := session.Query(`DROP KEYSPACE gocql_test`).Exec(); err != nil {
-			tb.Log("drop keyspace:", err)
-		}
-		if err := session.Query(fmt.Sprintf(`CREATE KEYSPACE gocql_test
-			WITH replication = {
-				'class' : 'SimpleStrategy',
-				'replication_factor' : %d
-			}`, *flagRF)).Consistency(All).Exec(); err != nil {
-			tb.Fatal("create keyspace:", err)
-		}
-		tb.Log("Created keyspace")
-		session.Close()
+		createKeyspace(tb, cluster, "gocql_test")
 	})
+
 	cluster.Keyspace = "gocql_test"
 	session, err := cluster.CreateSession()
 	if err != nil {
@@ -88,6 +99,36 @@ func createSession(tb testing.TB) *Session {
 	return session
 }
 
+//TestRingDiscovery makes sure that you can autodiscover other cluster members when you seed a cluster config with just one node
+func TestRingDiscovery(t *testing.T) {
+
+	cluster := NewCluster(clusterHosts[0])
+	cluster.ProtoVersion = *flagProto
+	cluster.CQLVersion = *flagCQL
+	cluster.Timeout = 5 * time.Second
+	cluster.Consistency = Quorum
+	cluster.RetryPolicy.NumRetries = *flagRetry
+	cluster.DiscoverHosts = true
+
+	session, err := cluster.CreateSession()
+	if err != nil {
+		t.Errorf("got error connecting to the cluster %v", err)
+	}
+
+	if *clusterSize > 1 {
+		// wait for autodiscovery to update the pool with the list of known hosts
+		time.Sleep(*flagAutoWait)
+	}
+
+	size := len(session.Pool.(*SimplePool).connPool)
+
+	if *clusterSize != size {
+		t.Fatalf("Expected a cluster size of %d, but actual size was %d", *clusterSize, size)
+	}
+
+	session.Close()
+}
+
 func TestEmptyHosts(t *testing.T) {
 	cluster := NewCluster()
 	if session, err := cluster.CreateSession(); err == nil {
@@ -381,6 +422,7 @@ func TestSliceMap(t *testing.T) {
 			testdouble     double,
 			testint        int,
 			testdecimal    decimal,
+			testlist       list<text>,
 			testset        set<int>,
 			testmap        map<varchar, varchar>,
 			testvarint     varint
@@ -404,12 +446,13 @@ func TestSliceMap(t *testing.T) {
 	m["testdouble"] = float64(4.815162342)
 	m["testint"] = 2343
 	m["testdecimal"] = inf.NewDec(100, 0)
+	m["testlist"] = []string{"quux", "foo", "bar", "baz", "quux"}
 	m["testset"] = []int{1, 2, 3, 4, 5, 6, 7, 8, 9}
 	m["testmap"] = map[string]string{"field1": "val1", "field2": "val2", "field3": "val3"}
 	m["testvarint"] = bigInt
 	sliceMap := []map[string]interface{}{m}
-	if err := session.Query(`INSERT INTO slice_map_table (testuuid, testtimestamp, testvarchar, testbigint, testblob, testbool, testfloat, testdouble, testint, testdecimal, testset, testmap, testvarint) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
-		m["testuuid"], m["testtimestamp"], m["testvarchar"], m["testbigint"], m["testblob"], m["testbool"], m["testfloat"], m["testdouble"], m["testint"], m["testdecimal"], m["testset"], m["testmap"], m["testvarint"]).Exec(); err != nil {
+	if err := session.Query(`INSERT INTO slice_map_table (testuuid, testtimestamp, testvarchar, testbigint, testblob, testbool, testfloat, testdouble, testint, testdecimal, testlist, testset, testmap, testvarint) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
+		m["testuuid"], m["testtimestamp"], m["testvarchar"], m["testbigint"], m["testblob"], m["testbool"], m["testfloat"], m["testdouble"], m["testint"], m["testdecimal"], m["testlist"], m["testset"], m["testmap"], m["testvarint"]).Exec(); err != nil {
 		t.Fatal("insert:", err)
 	}
 	if returned, retErr := session.Query(`SELECT * FROM slice_map_table`).Iter().SliceMap(); retErr != nil {
@@ -449,6 +492,9 @@ func TestSliceMap(t *testing.T) {
 		if expectedDecimal.Cmp(returnedDecimal) != 0 {
 			t.Fatal("returned testdecimal did not match")
 		}
+		if !reflect.DeepEqual(sliceMap[0]["testlist"], returned[0]["testlist"]) {
+			t.Fatal("returned testlist did not match")
+		}
 		if !reflect.DeepEqual(sliceMap[0]["testset"], returned[0]["testset"]) {
 			t.Fatal("returned testset did not match")
 		}
@@ -503,6 +549,9 @@ func TestSliceMap(t *testing.T) {
 		t.Fatal("returned testdecimal did not match")
 	}
 
+	if !reflect.DeepEqual(sliceMap[0]["testlist"], testMap["testlist"]) {
+		t.Fatal("returned testlist did not match")
+	}
 	if !reflect.DeepEqual(sliceMap[0]["testset"], testMap["testset"]) {
 		t.Fatal("returned testset did not match")
 	}
@@ -941,19 +990,19 @@ func TestPreparedCacheEviction(t *testing.T) {
 	//Walk through all the configured hosts and test cache retention and eviction
 	var selFound, insFound, updFound, delFound, selEvict bool
 	for i := range session.cfg.Hosts {
-		_, ok := stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042SELECT id,mod FROM prepcachetest WHERE id = 1")
+		_, ok := stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042gocql_testSELECT id,mod FROM prepcachetest WHERE id = 1")
 		selFound = selFound || ok
 
-		_, ok = stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042INSERT INTO prepcachetest (id,mod) VALUES (?, ?)")
+		_, ok = stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042gocql_testINSERT INTO prepcachetest (id,mod) VALUES (?, ?)")
 		insFound = insFound || ok
 
-		_, ok = stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042UPDATE prepcachetest SET mod = ? WHERE id = ?")
+		_, ok = stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042gocql_testUPDATE prepcachetest SET mod = ? WHERE id = ?")
 		updFound = updFound || ok
 
-		_, ok = stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042DELETE FROM prepcachetest WHERE id = ?")
+		_, ok = stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042gocql_testDELETE FROM prepcachetest WHERE id = ?")
 		delFound = delFound || ok
 
-		_, ok = stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042SELECT id,mod FROM prepcachetest WHERE id = 0")
+		_, ok = stmtsLRU.lru.Get(session.cfg.Hosts[i] + ":9042gocql_testSELECT id,mod FROM prepcachetest WHERE id = 0")
 		selEvict = selEvict || !ok
 
 	}
@@ -974,6 +1023,53 @@ func TestPreparedCacheEviction(t *testing.T) {
 	}
 }
 
+func TestPreparedCacheKey(t *testing.T) {
+	session := createSession(t)
+	defer session.Close()
+
+	// create a second keyspace
+	cluster2 := createCluster()
+	createKeyspace(t, cluster2, "gocql_test2")
+	cluster2.Keyspace = "gocql_test2"
+	session2, err := cluster2.CreateSession()
+	if err != nil {
+		t.Fatal("create session:", err)
+	}
+	defer session2.Close()
+
+	// both keyspaces have a table named "test_stmt_cache_key"
+	if err := createTable(session, "CREATE TABLE test_stmt_cache_key (id varchar primary key, field varchar)"); err != nil {
+		t.Fatal("create table:", err)
+	}
+	if err := createTable(session2, "CREATE TABLE test_stmt_cache_key (id varchar primary key, field varchar)"); err != nil {
+		t.Fatal("create table:", err)
+	}
+
+	// both tables have a single row with the same partition key but different column value
+	if err = session.Query(`INSERT INTO test_stmt_cache_key (id, field) VALUES (?, ?)`, "key", "one").Exec(); err != nil {
+		t.Fatal("insert:", err)
+	}
+	if err = session2.Query(`INSERT INTO test_stmt_cache_key (id, field) VALUES (?, ?)`, "key", "two").Exec(); err != nil {
+		t.Fatal("insert:", err)
+	}
+
+	// should be able to see different values in each keyspace
+	var value string
+	if err = session.Query("SELECT field FROM test_stmt_cache_key WHERE id = ?", "key").Scan(&value); err != nil {
+		t.Fatal("select:", err)
+	}
+	if value != "one" {
+		t.Errorf("Expected one, got %s", value)
+	}
+
+	if err = session2.Query("SELECT field FROM test_stmt_cache_key WHERE id = ?", "key").Scan(&value); err != nil {
+		t.Fatal("select:", err)
+	}
+	if value != "two" {
+		t.Errorf("Expected two, got %s", value)
+	}
+}
+
 //TestMarshalFloat64Ptr tests to see that a pointer to a float64 is marshalled correctly.
 func TestMarshalFloat64Ptr(t *testing.T) {
 	session := createSession(t)
@@ -1051,7 +1147,7 @@ func TestVarint(t *testing.T) {
 
 	err := session.Query("SELECT test FROM varint_test").Scan(&result64)
 	if err == nil || strings.Index(err.Error(), "out of range") == -1 {
-		t.Errorf("expected our of range error since value is too big for int64")
+		t.Errorf("expected out of range error since value is too big for int64")
 	}
 
 	// value not set in cassandra, leave bind variable empty
@@ -1073,3 +1169,70 @@ func TestVarint(t *testing.T) {
 		t.Errorf("Expected %v, was %v", nil, *resultBig)
 	}
 }
+
+//TestQueryStats confirms that the stats are returning valid data. Accuracy may be questionable.
+func TestQueryStats(t *testing.T) {
+	session := createSession(t)
+	defer session.Close()
+	qry := session.Query("SELECT * FROM system.peers")
+	if err := qry.Exec(); err != nil {
+		t.Fatalf("query failed. %v", err)
+	} else {
+		if qry.Attempts() < 1 {
+			t.Fatal("expected at least 1 attempt, but got 0")
+		}
+		if qry.Latency() <= 0 {
+			t.Fatalf("expected latency to be greater than 0, but got %v instead.", qry.Latency())
+		}
+	}
+}
+
+//TestBatchStats confirms that the stats are returning valid data. Accuracy may be questionable.
+func TestBatchStats(t *testing.T) {
+	if *flagProto == 1 {
+		t.Skip("atomic batches not supported. Please use Cassandra >= 2.0")
+	}
+	session := createSession(t)
+	defer session.Close()
+
+	if err := createTable(session, "CREATE TABLE batchStats (id int, PRIMARY KEY (id))"); err != nil {
+		t.Fatalf("failed to create table with error '%v'", err)
+	}
+
+	b := session.NewBatch(LoggedBatch)
+	b.Query("INSERT INTO batchStats (id) VALUES (?)", 1)
+	b.Query("INSERT INTO batchStats (id) VALUES (?)", 2)
+
+	if err := session.ExecuteBatch(b); err != nil {
+		t.Fatalf("query failed. %v", err)
+	} else {
+		if b.Attempts() < 1 {
+			t.Fatal("expected at least 1 attempt, but got 0")
+		}
+		if b.Latency() <= 0 {
+			t.Fatalf("expected latency to be greater than 0, but got %v instead.", b.Latency())
+		}
+	}
+}
+
+//TestNilInQuery tests to see that a nil value passed to a query is handled by Cassandra
+//TODO validate the nil value by reading back the nil. Need to fix Unmarshalling.
+func TestNilInQuery(t *testing.T) {
+	session := createSession(t)
+	defer session.Close()
+
+	if err := createTable(session, "CREATE TABLE testNilInsert (id int, count int, PRIMARY KEY (id))"); err != nil {
+		t.Fatalf("failed to create table with error '%v'", err)
+	}
+	if err := session.Query("INSERT INTO testNilInsert (id,count) VALUES (?,?)", 1, nil).Exec(); err != nil {
+		t.Fatalf("failed to insert with err: %v", err)
+	}
+
+	var id int
+
+	if err := session.Query("SELECT id FROM testNilInsert").Scan(&id); err != nil {
+		t.Fatalf("failed to select with err: %v", err)
+	} else if id != 1 {
+		t.Fatalf("expected id to be 1, got %v", id)
+	}
+}

+ 20 - 22
cluster.go

@@ -6,9 +6,10 @@ package gocql
 
 import (
 	"errors"
-	"github.com/golang/groupcache/lru"
 	"sync"
 	"time"
+
+	"github.com/golang/groupcache/lru"
 )
 
 //Package global reference to Prepared Statements LRU
@@ -29,6 +30,16 @@ func (p *preparedLRU) Max(max int) {
 	p.lru.MaxEntries = max
 }
 
+// To enable periodic node discovery, enable DiscoverHosts in ClusterConfig
+type DiscoveryConfig struct {
+	// If not empty will filter all discovered hosts to a single Data Centre (default: "")
+	DcFilter string
+	// If not empty will filter all discovered hosts to a single Rack (default: "")
+	RackFilter string
+	// The interval to check for new hosts (default: 30s)
+	Sleep time.Duration
+}
+
 // ClusterConfig is a struct to configure the default cluster implementation
 // of gocql. It has a variety of attributes that can be used to modify the
 // behavior to fit the most common use cases. Applications that require a
@@ -38,7 +49,7 @@ type ClusterConfig struct {
 	CQLVersion       string        // CQL version (default: 3.0.0)
 	ProtoVersion     int           // version of the native protocol (default: 2)
 	Timeout          time.Duration // connection timeout (default: 600ms)
-	DefaultPort      int           // default port (default: 9042)
+	Port             int           // port (default: 9042)
 	Keyspace         string        // initial keyspace (optional)
 	NumConns         int           // number of connections per host (default: 2)
 	NumStreams       int           // number of streams per connection (default: 128)
@@ -50,6 +61,7 @@ type ClusterConfig struct {
 	ConnPoolType     NewPoolFunc   // The function used to create the connection pool for the session (default: NewSimplePool)
 	DiscoverHosts    bool          // If set, gocql will attempt to automatically discover other members of the Cassandra cluster (default: false)
 	MaxPreparedStmts int           // Sets the maximum cache size for prepared statements globally for gocql (default: 1000)
+	Discovery        DiscoveryConfig
 }
 
 // NewCluster generates a new config for the default cluster implementation.
@@ -59,7 +71,7 @@ func NewCluster(hosts ...string) *ClusterConfig {
 		CQLVersion:       "3.0.0",
 		ProtoVersion:     2,
 		Timeout:          600 * time.Millisecond,
-		DefaultPort:      9042,
+		Port:             9042,
 		NumConns:         2,
 		NumStreams:       128,
 		Consistency:      Quorum,
@@ -95,26 +107,13 @@ func (cfg *ClusterConfig) CreateSession() (*Session, error) {
 		s.SetConsistency(cfg.Consistency)
 
 		if cfg.DiscoverHosts {
-			//Fill out cfg.Hosts
-			query := "SELECT peer FROM system.peers"
-			peers := s.Query(query).Iter()
-
-			var ip string
-			for peers.Scan(&ip) {
-				exists := false
-				for ii := 0; ii < len(cfg.Hosts); ii++ {
-					if cfg.Hosts[ii] == ip {
-						exists = true
-					}
-				}
-				if !exists {
-					cfg.Hosts = append(cfg.Hosts, ip)
-				}
+			hostSource := &ringDescriber{
+				session:    s,
+				dcFilter:   cfg.Discovery.DcFilter,
+				rackFilter: cfg.Discovery.RackFilter,
 			}
 
-			if err := peers.Close(); err != nil {
-				return s, ErrHostQueryFailed
-			}
+			go hostSource.run(cfg.Discovery.Sleep)
 		}
 
 		return s, nil
@@ -122,7 +121,6 @@ func (cfg *ClusterConfig) CreateSession() (*Session, error) {
 
 	pool.Close()
 	return nil, ErrNoConnectionsStarted
-
 }
 
 var (

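The new DiscoverHosts flag and DiscoveryConfig struct are meant to be used together: when DiscoverHosts is set, CreateSession starts a background ring describer configured from Discovery. A hedged sketch using only the fields visible in this diff; the seed host and filter values are illustrative:

```go
package main

import (
	"log"
	"time"

	"github.com/gocql/gocql"
)

func main() {
	cluster := gocql.NewCluster("192.168.1.1") // single seed host; discovery finds the rest
	cluster.DiscoverHosts = true
	cluster.Discovery = gocql.DiscoveryConfig{
		DcFilter:   "DC1",            // only add hosts from this data centre (empty = no filter)
		RackFilter: "",               // no rack filtering
		Sleep:      30 * time.Second, // poll system.peers at this interval (30s when zero)
	}

	session, err := cluster.CreateSession()
	if err != nil {
		log.Fatal(err)
	}
	defer session.Close()
}
```
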
+ 19 - 11
conn.go

@@ -66,11 +66,12 @@ type Conn struct {
 	calls []callReq
 	nwait int32
 
-	pool       ConnectionPool
-	compressor Compressor
-	auth       Authenticator
-	addr       string
-	version    uint8
+	pool            ConnectionPool
+	compressor      Compressor
+	auth            Authenticator
+	addr            string
+	version         uint8
+	currentKeyspace string
 
 	closedMu sync.RWMutex
 	isClosed bool
@@ -310,7 +311,10 @@ func (c *Conn) ping() error {
 
 func (c *Conn) prepareStatement(stmt string, trace Tracer) (*QueryInfo, error) {
 	stmtsLRU.mu.Lock()
-	if val, ok := stmtsLRU.lru.Get(c.addr + stmt); ok {
+
+	stmtCacheKey := c.addr + c.currentKeyspace + stmt
+
+	if val, ok := stmtsLRU.lru.Get(stmtCacheKey); ok {
 		flight := val.(*inflightPrepare)
 		stmtsLRU.mu.Unlock()
 		flight.wg.Wait()
@@ -319,7 +323,7 @@ func (c *Conn) prepareStatement(stmt string, trace Tracer) (*QueryInfo, error) {
 
 	flight := new(inflightPrepare)
 	flight.wg.Add(1)
-	stmtsLRU.lru.Add(c.addr+stmt, flight)
+	stmtsLRU.lru.Add(stmtCacheKey, flight)
 	stmtsLRU.mu.Unlock()
 
 	resp, err := c.exec(&prepareFrame{Stmt: stmt}, trace)
@@ -345,7 +349,7 @@ func (c *Conn) prepareStatement(stmt string, trace Tracer) (*QueryInfo, error) {
 
 	if err != nil {
 		stmtsLRU.mu.Lock()
-		stmtsLRU.lru.Remove(c.addr + stmt)
+		stmtsLRU.lru.Remove(stmtCacheKey)
 		stmtsLRU.mu.Unlock()
 	}
 
@@ -414,8 +418,9 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 		return &Iter{}
 	case RequestErrUnprepared:
 		stmtsLRU.mu.Lock()
-		if _, ok := stmtsLRU.lru.Get(c.addr + qry.stmt); ok {
-			stmtsLRU.lru.Remove(c.addr + qry.stmt)
+		stmtCacheKey := c.addr + c.currentKeyspace + qry.stmt
+		if _, ok := stmtsLRU.lru.Get(stmtCacheKey); ok {
+			stmtsLRU.lru.Remove(stmtCacheKey)
 			stmtsLRU.mu.Unlock()
 			return c.executeQuery(qry)
 		}
@@ -470,6 +475,9 @@ func (c *Conn) UseKeyspace(keyspace string) error {
 	default:
 		return NewErrProtocol("Unknown type in response to USE: %s", x)
 	}
+
+	c.currentKeyspace = keyspace
+
 	return nil
 }
 
@@ -537,7 +545,7 @@ func (c *Conn) executeBatch(batch *Batch) error {
 		stmt, found := stmts[string(x.StatementId)]
 		if found {
 			stmtsLRU.mu.Lock()
-			stmtsLRU.lru.Remove(c.addr + stmt)
+			stmtsLRU.lru.Remove(c.addr + c.currentKeyspace + stmt)
 			stmtsLRU.mu.Unlock()
 		}
 		if found {

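With this change the prepared statement cache is keyed on connection address, current keyspace and statement text, so the same statement prepared against two keyspaces no longer collides (TestPreparedCacheKey above exercises this). A small illustration of the key shape; the values are made up:

```go
package main

import "fmt"

func main() {
	// The cache key is a plain concatenation: addr + keyspace + statement.
	addr := "192.168.1.1:9042"
	keyspace := "gocql_test"
	stmt := "SELECT id, text FROM tweet WHERE timeline = ?"
	fmt.Println(addr + keyspace + stmt)
	// 192.168.1.1:9042gocql_testSELECT id, text FROM tweet WHERE timeline = ?
}
```
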
+ 9 - 5
conn_test.go

@@ -86,13 +86,17 @@ func TestQueryRetry(t *testing.T) {
 		t.Fatal("no timeout")
 	}()
 	rt := RetryPolicy{NumRetries: 1}
-
-	if err := db.Query("kill").RetryPolicy(rt).Exec(); err == nil {
+	qry := db.Query("kill").RetryPolicy(rt)
+	if err := qry.Exec(); err == nil {
 		t.Fatal("expected error")
 	}
-	//Minus 1 from the nKillReq variable since there is the initial query attempt
-	if srv.nKillReq-1 != uint64(rt.NumRetries) {
-		t.Fatalf("failed to retry the query %v time(s). Query executed %v times", rt.NumRetries, srv.nKillReq-1)
+	requests := srv.nKillReq
+	if requests != uint64(qry.Attempts()) {
+		t.Fatalf("expected requests %v to match query attemps %v", requests, qry.Attempts())
+	}
+	//Minus 1 from the requests variable since there is the initial query attempt
+	if requests-1 != uint64(rt.NumRetries) {
+		t.Fatalf("failed to retry the query %v time(s). Query executed %v times", rt.NumRetries, requests-1)
 	}
 }
 

+ 95 - 17
connectionpool.go

@@ -30,7 +30,7 @@ Example of Single Connection Pool:
 	func NewSingleConnection(cfg *ClusterConfig) ConnectionPool {
 		addr := strings.TrimSpace(cfg.Hosts[0])
 		if strings.Index(addr, ":") < 0 {
-			addr = fmt.Sprintf("%s:%d", addr, cfg.DefaultPort)
+			addr = fmt.Sprintf("%s:%d", addr, cfg.Port)
 		}
 		connCfg := ConnConfig{
 			ProtoVersion:  cfg.ProtoVersion,
@@ -91,6 +91,7 @@ type ConnectionPool interface {
 	Size() int
 	HandleError(*Conn, error, bool)
 	Close()
+	SetHosts(host []HostInfo)
 }
 
 //NewPoolFunc is the type used by ClusterConfig to create a pool of a specific type.
@@ -105,7 +106,13 @@ type SimplePool struct {
 	connPool map[string]*RoundRobin
 	conns    map[*Conn]struct{}
 	keyspace string
-	mu       sync.Mutex
+
+	hostMu sync.RWMutex
+	// this is the set of current hosts which the pool will attempt to connect to
+	hosts map[string]*HostInfo
+
+	// protects hostPool, connPool, conns, quit
+	mu sync.Mutex
 
 	cFillingPool chan int
 
@@ -117,7 +124,7 @@ type SimplePool struct {
 //NewSimplePool is the function used by gocql to create the simple connection pool.
 //This is the default if no other pool type is specified.
 func NewSimplePool(cfg *ClusterConfig) ConnectionPool {
-	pool := SimplePool{
+	pool := &SimplePool{
 		cfg:          cfg,
 		hostPool:     NewRoundRobin(),
 		connPool:     make(map[string]*RoundRobin),
@@ -125,25 +132,35 @@ func NewSimplePool(cfg *ClusterConfig) ConnectionPool {
 		quitWait:     make(chan bool),
 		cFillingPool: make(chan int, 1),
 		keyspace:     cfg.Keyspace,
+		hosts:        make(map[string]*HostInfo),
+	}
+
+	for _, host := range cfg.Hosts {
+		// seed hosts have unknown topology
+		// TODO: Handle populating this during SetHosts
+		pool.hosts[host] = &HostInfo{Peer: host}
 	}
+
 	//Walk through connecting to hosts. As soon as one host connects
 	//defer the remaining connections to cluster.fillPool()
 	for i := 0; i < len(cfg.Hosts); i++ {
 		addr := strings.TrimSpace(cfg.Hosts[i])
 		if strings.Index(addr, ":") < 0 {
-			addr = fmt.Sprintf("%s:%d", addr, cfg.DefaultPort)
+			addr = fmt.Sprintf("%s:%d", addr, cfg.Port)
 		}
+
 		if pool.connect(addr) == nil {
 			pool.cFillingPool <- 1
 			go pool.fillPool()
 			break
 		}
-
 	}
-	return &pool
+
+	return pool
 }
 
 func (c *SimplePool) connect(addr string) error {
+
 	cfg := ConnConfig{
 		ProtoVersion:  c.cfg.ProtoVersion,
 		CQLVersion:    c.cfg.CQLVersion,
@@ -154,14 +171,13 @@ func (c *SimplePool) connect(addr string) error {
 		Keepalive:     c.cfg.SocketKeepalive,
 	}
 
-	for {
-		conn, err := Connect(addr, cfg, c)
-		if err != nil {
-			log.Printf("failed to connect to %q: %v", addr, err)
-			return err
-		}
-		return c.addConn(conn)
+	conn, err := Connect(addr, cfg, c)
+	if err != nil {
+		log.Printf("connect: failed to connect to %q: %v", addr, err)
+		return err
 	}
+
+	return c.addConn(conn)
 }
 
 func (c *SimplePool) addConn(conn *Conn) error {
@@ -171,6 +187,7 @@ func (c *SimplePool) addConn(conn *Conn) error {
 		conn.Close()
 		return nil
 	}
+
 	//Set the connection's keyspace if any before adding it to the pool
 	if c.keyspace != "" {
 		if err := conn.UseKeyspace(c.keyspace); err != nil {
@@ -179,14 +196,17 @@ func (c *SimplePool) addConn(conn *Conn) error {
 			return err
 		}
 	}
+
 	connPool := c.connPool[conn.Address()]
 	if connPool == nil {
 		connPool = NewRoundRobin()
 		c.connPool[conn.Address()] = connPool
 		c.hostPool.AddNode(connPool)
 	}
+
 	connPool.AddNode(conn)
 	c.conns[conn] = struct{}{}
+
 	return nil
 }
 
@@ -209,13 +229,17 @@ func (c *SimplePool) fillPool() {
 	if isClosed {
 		return
 	}
+
+	c.hostMu.RLock()
+
 	//Walk through list of defined hosts
-	for i := 0; i < len(c.cfg.Hosts); i++ {
-		addr := strings.TrimSpace(c.cfg.Hosts[i])
+	for host := range c.hosts {
+		addr := strings.TrimSpace(host)
 		if strings.Index(addr, ":") < 0 {
-			addr = fmt.Sprintf("%s:%d", addr, c.cfg.DefaultPort)
+			addr = fmt.Sprintf("%s:%d", addr, c.cfg.Port)
 		}
-		var numConns int = 1
+
+		numConns := 1
 		//See if the host already has connections in the pool
 		c.mu.Lock()
 		conns, ok := c.connPool[addr]
@@ -233,6 +257,7 @@ func (c *SimplePool) fillPool() {
 				continue
 			}
 		}
+
 		//This is reached if the host is responsive and needs more connections
 		//Create connections for host synchronously to mitigate flooding the host.
 		go func(a string, conns int) {
@@ -241,6 +266,8 @@ func (c *SimplePool) fillPool() {
 			}
 		}(addr, numConns)
 	}
+
+	c.hostMu.RUnlock()
 }
 
 // Should only be called if c.mu is locked
@@ -313,3 +340,54 @@ func (c *SimplePool) Close() {
 		}
 	})
 }
+
+func (c *SimplePool) SetHosts(hosts []HostInfo) {
+
+	c.hostMu.Lock()
+	toRemove := make(map[string]struct{})
+	for k := range c.hosts {
+		toRemove[k] = struct{}{}
+	}
+
+	for _, host := range hosts {
+		host := host
+
+		delete(toRemove, host.Peer)
+		// we already have it
+		if _, ok := c.hosts[host.Peer]; ok {
+			// TODO: Check rack, dc, token range is consistent, trigger topology change
+			// update stored host
+			continue
+		}
+
+		c.hosts[host.Peer] = &host
+	}
+
+	// can we hold c.mu whilst iterating this loop?
+	for addr := range toRemove {
+		c.removeHostLocked(addr)
+	}
+	c.hostMu.Unlock()
+
+	c.fillPool()
+}
+
+func (c *SimplePool) removeHostLocked(addr string) {
+	if _, ok := c.hosts[addr]; !ok {
+		return
+	}
+	delete(c.hosts, addr)
+
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	if _, ok := c.connPool[addr]; !ok {
+		return
+	}
+
+	for conn := range c.conns {
+		if conn.Address() == addr {
+			c.removeConnLocked(conn)
+		}
+	}
+}

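SetHosts reconciles the pool's host set with a freshly discovered list: hosts missing from the new list are removed together with their connections, new hosts are added, and fillPool opens connections for them. It is normally driven by the ring describer in host_source.go (below); this sketch calls it directly, assuming a reachable node at the example address:

```go
package main

import "github.com/gocql/gocql"

func main() {
	cfg := gocql.NewCluster("192.168.1.1")
	pool := gocql.NewSimplePool(cfg)

	// Replace the pool's view of the ring; any previously known host missing
	// from this slice has its connections closed and removed.
	pool.SetHosts([]gocql.HostInfo{
		{Peer: "192.168.1.1", DataCenter: "DC1", Rack: "RAC1"},
		{Peer: "192.168.1.2", DataCenter: "DC1", Rack: "RAC1"},
	})

	pool.Close()
}
```
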
+ 104 - 0
host_source.go

@@ -0,0 +1,104 @@
+package gocql
+
+import (
+	"log"
+	"net"
+	"time"
+)
+
+type HostInfo struct {
+	Peer       string
+	DataCenter string
+	Rack       string
+	HostId     string
+	Tokens     []string
+}
+
+// Polls system.peers at a specific interval to find new hosts
+type ringDescriber struct {
+	dcFilter   string
+	rackFilter string
+	previous   []HostInfo
+	session    *Session
+}
+
+func (r *ringDescriber) GetHosts() ([]HostInfo, error) {
+	// we need conn to be the same because we need to query system.peers and system.local
+	// on the same node to get the whole cluster
+	conn := r.session.Pool.Pick(nil)
+	if conn == nil {
+		return r.previous, nil
+	}
+
+	query := r.session.Query("SELECT data_center, rack, host_id, tokens FROM system.local")
+	iter := conn.executeQuery(query)
+
+	host := &HostInfo{}
+	iter.Scan(&host.DataCenter, &host.Rack, &host.HostId, &host.Tokens)
+
+	if err := iter.Close(); err != nil {
+		return nil, err
+	}
+
+	addr, _, err := net.SplitHostPort(conn.Address())
+	if err != nil {
+		// this should not happen, ever, as this is the address that was dialed by conn, here
+		// a panic makes sense, please report a bug if it occurs.
+		panic(err)
+	}
+
+	host.Peer = addr
+
+	hosts := []HostInfo{*host}
+
+	query = r.session.Query("SELECT peer, data_center, rack, host_id, tokens FROM system.peers")
+	iter = conn.executeQuery(query)
+
+	for iter.Scan(&host.Peer, &host.DataCenter, &host.Rack, &host.HostId, &host.Tokens) {
+		if r.matchFilter(host) {
+			hosts = append(hosts, *host)
+		}
+	}
+
+	if err := iter.Close(); err != nil {
+		return nil, err
+	}
+
+	r.previous = hosts
+
+	return hosts, nil
+}
+
+func (r *ringDescriber) matchFilter(host *HostInfo) bool {
+
+	if r.dcFilter != "" && r.dcFilter != host.DataCenter {
+		return false
+	}
+
+	if r.rackFilter != "" && r.rackFilter != host.Rack {
+		return false
+	}
+
+	return true
+}
+
+func (h *ringDescriber) run(sleep time.Duration) {
+	if sleep == 0 {
+		sleep = 30 * time.Second
+	}
+
+	for {
+		// if we have 0 hosts this will return the previous list of hosts to
+		// attempt to reconnect to the cluster; otherwise we would never find
+		// downed hosts again. We could possibly have an optimisation to only
+		// try to add new hosts if GetHosts didn't error and the hosts didn't change.
+		hosts, err := h.GetHosts()
+		if err != nil {
+			log.Println("RingDescriber: unable to get ring topology:", err)
+		} else {
+			h.session.Pool.SetHosts(hosts)
+		}
+
+		time.Sleep(sleep)
+	}
+}

+ 6 - 3
integration.sh

@@ -3,17 +3,20 @@
 set -e
 
 function run_tests() {
+	local clusterSize=3
 	local version=$1
-	ccm create test -v binary:$version -n 3 -s -d --vnodes
-	ccm status
+
+	ccm create test -v binary:$version -n $clusterSize -d --vnodes
 	ccm updateconf 'concurrent_reads: 8' 'concurrent_writes: 32' 'rpc_server_type: sync' 'rpc_min_threads: 2' 'rpc_max_threads: 8' 'write_request_timeout_in_ms: 5000' 'read_request_timeout_in_ms: 5000'
+	ccm start
+	ccm status
 
 	local proto=2
 	if [[ $version == 1.2.* ]]; then
 		proto=1
 	fi
 
-	go test -v -proto=$proto -rf=3 -cluster=$(ccm liveset) ./...
+	go test -v -proto=$proto -rf=3 -cluster=$(ccm liveset) -clusterSize=$clusterSize -autowait=2000ms ./...
 
 	ccm clear
 }

+ 4 - 0
marshal.go

@@ -35,6 +35,10 @@ type Unmarshaler interface {
 // Marshal returns the CQL encoding of the value for the Cassandra
 // internal type described by the info parameter.
 func Marshal(info *TypeInfo, value interface{}) ([]byte, error) {
+	if value == nil {
+		return nil, nil
+	}
+
 	if v, ok := value.(Marshaler); ok {
 		return v.MarshalCQL(info)
 	}

+ 10 - 0
marshal_test.go

@@ -297,6 +297,16 @@ func TestMarshal(t *testing.T) {
 	}
 }
 
+func TestMarshalNil(t *testing.T) {
+	data, err := Marshal(&TypeInfo{Type: TypeInt}, nil)
+	if err != nil {
+		t.Errorf("failed to marshal nil with err: %v", err)
+	}
+	if data != nil {
+		t.Errorf("expected nil, got %v", data)
+	}
+}
+
 func TestUnmarshal(t *testing.T) {
 	for i, test := range marshalTests {
 		v := reflect.New(reflect.TypeOf(test.Value))

+ 57 - 17
session.go

@@ -134,7 +134,9 @@ func (s *Session) executeQuery(qry *Query) *Iter {
 	}
 
 	var iter *Iter
-	for count := 0; count <= qry.rt.NumRetries; count++ {
+	qry.attempts = 0
+	qry.totalLatency = 0
+	for qry.attempts <= qry.rt.NumRetries {
 		conn := s.Pool.Pick(qry)
 
 		//Assign the error unavailable to the iterator
@@ -143,7 +145,11 @@ func (s *Session) executeQuery(qry *Query) *Iter {
 			break
 		}
 
+		t := time.Now()
 		iter = conn.executeQuery(qry)
+		qry.totalLatency += time.Now().Sub(t).Nanoseconds()
+		qry.attempts++
+
 		//Exit for loop if the query was successful
 		if iter.err == nil {
 			break
@@ -169,7 +175,9 @@ func (s *Session) ExecuteBatch(batch *Batch) error {
 	}
 
 	var err error
-	for count := 0; count <= batch.rt.NumRetries; count++ {
+	batch.attempts = 0
+	batch.totalLatency = 0
+	for batch.attempts <= batch.rt.NumRetries {
 		conn := s.Pool.Pick(nil)
 
 		//Assign the error unavailable and break loop
@@ -177,8 +185,10 @@ func (s *Session) ExecuteBatch(batch *Batch) error {
 			err = ErrNoConnections
 			break
 		}
-
+		t := time.Now()
 		err = conn.executeBatch(batch)
+		batch.totalLatency += time.Now().Sub(t).Nanoseconds()
+		batch.attempts++
 		//Exit loop if operation executed correctly
 		if err == nil {
 			return nil
@@ -190,16 +200,31 @@ func (s *Session) ExecuteBatch(batch *Batch) error {
 
 // Query represents a CQL statement that can be executed.
 type Query struct {
-	stmt      string
-	values    []interface{}
-	cons      Consistency
-	pageSize  int
-	pageState []byte
-	prefetch  float64
-	trace     Tracer
-	session   *Session
-	rt        RetryPolicy
-	binding   func(q *QueryInfo) ([]interface{}, error)
+	stmt         string
+	values       []interface{}
+	cons         Consistency
+	pageSize     int
+	pageState    []byte
+	prefetch     float64
+	trace        Tracer
+	session      *Session
+	rt           RetryPolicy
+	binding      func(q *QueryInfo) ([]interface{}, error)
+	attempts     int
+	totalLatency int64
+}
+
+//Attempts returns the number of times the query was executed.
+func (q *Query) Attempts() int {
+	return q.attempts
+}
+
+//Latency returns the average number of nanoseconds per attempt of the query.
+func (q *Query) Latency() int64 {
+	if q.attempts > 0 {
+		return q.totalLatency / int64(q.attempts)
+	}
+	return 0
 }
 
 // Consistency sets the consistency level for this query. If no consistency
@@ -397,10 +422,12 @@ func (n *nextIter) fetch() *Iter {
 }
 
 type Batch struct {
-	Type    BatchType
-	Entries []BatchEntry
-	Cons    Consistency
-	rt      RetryPolicy
+	Type         BatchType
+	Entries      []BatchEntry
+	Cons         Consistency
+	rt           RetryPolicy
+	attempts     int
+	totalLatency int64
 }
 
 // NewBatch creates a new batch operation without defaults from the cluster
@@ -413,6 +440,19 @@ func (s *Session) NewBatch(typ BatchType) *Batch {
 	return &Batch{Type: typ, rt: s.cfg.RetryPolicy}
 }
 
+// Attempts returns the number of attempts made to execute the batch.
+func (b *Batch) Attempts() int {
+	return b.attempts
+}
+
+//Latency returns the average number of nanoseconds to execute a single attempt of the batch.
+func (b *Batch) Latency() int64 {
+	if b.attempts > 0 {
+		return b.totalLatency / int64(b.attempts)
+	}
+	return 0
+}
+
 // Query adds the query to the batch operation
 func (b *Batch) Query(stmt string, args ...interface{}) {
 	b.Entries = append(b.Entries, BatchEntry{Stmt: stmt, Args: args})

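Queries and batches now track how many attempts were made and the total time spent executing them, exposed through Attempts() and Latency() (Batch has the same accessors). A short sketch of reading the stats after execution; the seed host is illustrative:

```go
package main

import (
	"log"

	"github.com/gocql/gocql"
)

func main() {
	cluster := gocql.NewCluster("192.168.1.1")
	session, err := cluster.CreateSession()
	if err != nil {
		log.Fatal(err)
	}
	defer session.Close()

	qry := session.Query("SELECT peer FROM system.peers")
	if err := qry.Exec(); err != nil {
		log.Fatal(err)
	}
	// Attempts counts executions including retries; Latency is the average
	// time per attempt in nanoseconds.
	log.Printf("attempts=%d avg latency=%dns", qry.Attempts(), qry.Latency())
}
```
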
+ 1 - 1
wiki_test.go

@@ -70,7 +70,7 @@ func (w *WikiTest) CreateSchema() {
 			attachments map<varchar, blob>,
 			PRIMARY KEY (title, revid)
 		)`)
-	if clusterSize > 1 {
+	if *clusterSize > 1 {
 		// wait for table definition to propagate
 		time.Sleep(250 * time.Millisecond)
 	}