
Added token aware connection pool implementation

- new implementation of ConnectionPool based on separate host and connection selection policies
- new round-robin and token-aware policies for the connection pool (usage sketch below)
- moved HandleError out of the ConnectionPool interface into a ConnErrorHandler interface to
decouple ConnectionPool from Conn while maintaining backwards compatibility
- modified ring discovery to pass the partitioner value on to the connection pool via
a new optional SetPartitioner interface
- added an integration test for the token-aware connection pool
- added a systems test of connection pool behavior when cluster nodes shut down and restart
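With those pieces in place, switching to the new pool only touches ClusterConfig, much as the integration test in cassandra_test.go below does. A minimal, hedged sketch of end-user code (the seed hosts, keyspace, and table name are placeholders, and the import path is assumed to be the usual github.com/gocql/gocql):

package main

import (
	"log"

	"github.com/gocql/gocql"
)

func main() {
	// placeholder seed hosts and keyspace; substitute your own cluster
	cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2")
	cluster.Keyspace = "example"

	// opt into the new policy-based pool and let ring discovery feed it
	// host and partitioner updates
	cluster.ConnPoolType = gocql.NewTokenAwareConnPool
	cluster.DiscoverHosts = true

	session, err := cluster.CreateSession()
	if err != nil {
		log.Fatal(err)
	}
	defer session.Close()

	// queries whose routing key can be computed are offered to a replica
	// for that key first, falling back to round-robin otherwise
	if err := session.Query(`SELECT data FROM example_table WHERE id = ?`, 42).Exec(); err != nil {
		log.Println("query failed:", err)
	}
}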
Justin Corpron, 11 years ago
parent commit 0b2fd28388
9 changed files with 1097 additions and 39 deletions
  1. cassandra_test.go  + 45 - 0
  2. conn.go  + 18 - 15
  3. connectionpool.go  + 408 - 3
  4. connectionpool_systems_test.go  + 266 - 0
  5. connectionpool_systems_test.sh  + 54 - 0
  6. host_source.go  + 30 - 19
  7. policies.go  + 187 - 0
  8. policies_test.go  + 86 - 0
  9. token.go  + 3 - 2

+ 45 - 0
cassandra_test.go

@@ -1902,3 +1902,48 @@ func TestRoutingKey(t *testing.T) {
 		t.Errorf("Expected cache size to be 2 but was %d", cacheSize)
 	}
 }
+
+func TestTokenAwareConnPool(t *testing.T) {
+	cluster := createCluster()
+	cluster.ConnPoolType = NewTokenAwareConnPool
+	cluster.DiscoverHosts = true
+
+	// Drop and re-create the keyspace once. Different tests should use their own
+	// individual tables, but can assume that the table does not exist before.
+	initOnce.Do(func() {
+		createKeyspace(t, cluster, "gocql_test")
+	})
+
+	cluster.Keyspace = "gocql_test"
+	session, err := cluster.CreateSession()
+	if err != nil {
+		t.Fatal("createSession:", err)
+	}
+	defer session.Close()
+
+	if *clusterSize > 1 {
+		// wait for autodiscovery to update the pool with the list of known hosts
+		time.Sleep(*flagAutoWait)
+	}
+
+	if session.Pool.Size() != cluster.NumConns*len(cluster.Hosts) {
+		t.Errorf("Expected pool size %d but was %d", cluster.NumConns*len(cluster.Hosts), session.Pool.Size())
+	}
+
+	if err := createTable(session, "CREATE TABLE test_token_aware (id int, data text, PRIMARY KEY (id))"); err != nil {
+		t.Fatalf("failed to create test_token_aware table with err: %v", err)
+	}
+	query := session.Query("INSERT INTO test_token_aware (id, data) VALUES (?,?)", 42, "8 * 6 =")
+	if err := query.Exec(); err != nil {
+		t.Fatalf("failed to insert with err: %v", err)
+	}
+	query = session.Query("SELECT data FROM test_token_aware where id = ?", 42).Consistency(One)
+	iter := query.Iter()
+	var data string
+	if !iter.Scan(&data) {
+		t.Error("failed to scan data")
+	}
+	if err := iter.Close(); err != nil {
+		t.Errorf("iter failed with err: %v", err)
+	}
+}

+ 18 - 15
conn.go

@@ -75,6 +75,10 @@ type ConnConfig struct {
 	tlsConfig     *tls.Config
 }
 
+type ConnErrorHandler interface {
+	HandleError(conn *Conn, err error, closed bool)
+}
+
 // Conn is a single connection to a Cassandra node. It can be used to execute
 // queries, but users are usually advised to use a more reliable, higher
 // level API.
@@ -88,7 +92,7 @@ type Conn struct {
 	uniq  chan int
 	calls []callReq
 
-	pool            ConnectionPool
+	errorHandler    ConnErrorHandler
 	compressor      Compressor
 	auth            Authenticator
 	addr            string
@@ -102,7 +106,7 @@ type Conn struct {
 
 // Connect establishes a connection to a Cassandra node.
 // You must also call the Serve method before you can execute any queries.
-func Connect(addr string, cfg ConnConfig, pool ConnectionPool) (*Conn, error) {
+func Connect(addr string, cfg ConnConfig, errorHandler ConnErrorHandler) (*Conn, error) {
 	var (
 		err  error
 		conn net.Conn
@@ -137,18 +141,17 @@ func Connect(addr string, cfg ConnConfig, pool ConnectionPool) (*Conn, error) {
 	}
 
 	c := &Conn{
-		conn:       conn,
-		r:          bufio.NewReader(conn),
-		uniq:       make(chan int, cfg.NumStreams),
-		calls:      make([]callReq, cfg.NumStreams),
-		timeout:    cfg.Timeout,
-		version:    uint8(cfg.ProtoVersion),
-		addr:       conn.RemoteAddr().String(),
-		pool:       pool,
-		compressor: cfg.Compressor,
-		auth:       cfg.Authenticator,
-
-		headerBuf: make([]byte, headerSize),
+		conn:         conn,
+		r:            bufio.NewReader(conn),
+		uniq:         make(chan int, cfg.NumStreams),
+		calls:        make([]callReq, cfg.NumStreams),
+		timeout:      cfg.Timeout,
+		version:      uint8(cfg.ProtoVersion),
+		addr:         conn.RemoteAddr().String(),
+		errorHandler: errorHandler,
+		compressor:   cfg.Compressor,
+		auth:         cfg.Authenticator,
+		headerBuf:    make([]byte, headerSize),
 	}
 
 	if cfg.Keepalive > 0 {
@@ -298,7 +301,7 @@ func (c *Conn) serve() {
 	}
 
 	if c.started {
-		c.pool.HandleError(c, err, true)
+		c.errorHandler.HandleError(c, err, true)
 	}
 }
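Since Conn now reports failures through the narrow ConnErrorHandler interface rather than the whole ConnectionPool, any owner of a connection can receive this callback. A hedged sketch of a standalone handler (the type name and logging behaviour are illustrative, not part of this commit):

package gocql

import "log"

// connWatcher is a hypothetical ConnErrorHandler that only logs closed
// connections; a real pool, such as hostConnPool below, also removes the
// Conn from its slice and triggers a refill.
type connWatcher struct{}

func (w connWatcher) HandleError(conn *Conn, err error, closed bool) {
	if !closed {
		// the connection is still usable, nothing to do
		return
	}
	log.Printf("connection to %s closed: %v", conn.Address(), err)
}

Such a handler can be passed straight to the new Connect signature, e.g. Connect(addr, cfg, connWatcher{}).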
 

+ 408 - 3
connectionpool.go

@@ -7,6 +7,7 @@ import (
 	"fmt"
 	"io/ioutil"
 	"log"
+	"net"
 	"sync"
 	"time"
 )
@@ -90,9 +91,13 @@ To see a more complete example of a ConnectionPool implementation please see the
 type ConnectionPool interface {
 	Pick(*Query) *Conn
 	Size() int
-	HandleError(*Conn, error, bool)
 	Close()
-	SetHosts(host []HostInfo)
+	SetHosts(hosts []HostInfo)
+}
+
+// SetPartitioner is an optional interface a connection pool can implement to receive the cluster's partitioner value from ring discovery
+type SetPartitioner interface {
+	SetPartitioner(partitioner string)
 }
 
 //NewPoolFunc is the type used by ClusterConfig to create a pool of a specific type.
@@ -395,7 +400,6 @@ func (c *SimplePool) SetHosts(hosts []HostInfo) {
 
 	for _, host := range hosts {
 		host := host
-
 		delete(toRemove, host.Peer)
 		// we already have it
 		if _, ok := c.hosts[host.Peer]; ok {
@@ -435,3 +439,404 @@ func (c *SimplePool) removeHostLocked(addr string) {
 		}
 	}
 }
+
+//NewRoundRobinConnPool creates a connection pool which selects hosts by
+//round-robin, and then selects a connection for that host by round-robin.
+func NewRoundRobinConnPool(cfg *ClusterConfig) (ConnectionPool, error) {
+	return NewPolicyConnPool(
+		cfg,
+		NewRoundRobinHostPolicy(),
+		NewRoundRobinConnPolicy,
+	)
+}
+
+//NewTokenAwareConnPool creates a connection pool which selects hosts by
+//a token aware policy, and then selects a connection for that host by
+//round-robin.
+func NewTokenAwareConnPool(cfg *ClusterConfig) (ConnectionPool, error) {
+	return NewPolicyConnPool(
+		cfg,
+		NewTokenAwareHostPolicy(NewRoundRobinHostPolicy()),
+		NewRoundRobinConnPolicy,
+	)
+}
+
+type policyConnPool struct {
+	port     int
+	numConns int
+	connCfg  ConnConfig
+	keyspace string
+
+	mu            sync.RWMutex
+	hostPolicy    HostSelectionPolicy
+	connPolicy    func() ConnSelectionPolicy
+	hostConnPools map[string]*hostConnPool
+}
+
+func NewPolicyConnPool(
+	cfg *ClusterConfig,
+	hostPolicy HostSelectionPolicy,
+	connPolicy func() ConnSelectionPolicy,
+) (ConnectionPool, error) {
+	var err error
+	var tlsConfig *tls.Config
+
+	if cfg.SslOpts != nil {
+		tlsConfig, err = setupTLSConfig(cfg.SslOpts)
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	// create the pool
+	pool := &policyConnPool{
+		port:     cfg.Port,
+		numConns: cfg.NumConns,
+		connCfg: ConnConfig{
+			ProtoVersion:  cfg.ProtoVersion,
+			CQLVersion:    cfg.CQLVersion,
+			Timeout:       cfg.Timeout,
+			NumStreams:    cfg.NumStreams,
+			Compressor:    cfg.Compressor,
+			Authenticator: cfg.Authenticator,
+			Keepalive:     cfg.SocketKeepalive,
+			tlsConfig:     tlsConfig,
+		},
+		keyspace:      cfg.Keyspace,
+		hostPolicy:    hostPolicy,
+		connPolicy:    connPolicy,
+		hostConnPools: map[string]*hostConnPool{},
+	}
+
+	hosts := make([]HostInfo, len(cfg.Hosts))
+	for i, hostAddr := range cfg.Hosts {
+		hosts[i].Peer = hostAddr
+	}
+
+	pool.SetHosts(hosts)
+
+	return pool, nil
+}
+
+func (p *policyConnPool) SetHosts(hosts []HostInfo) {
+	p.mu.Lock()
+
+	toRemove := make(map[string]struct{})
+	for addr := range p.hostConnPools {
+		toRemove[addr] = struct{}{}
+	}
+
+	// TODO connect to hosts in parallel, but wait for pools to be
+	// created before returning
+
+	for i := range hosts {
+		pool, exists := p.hostConnPools[hosts[i].Peer]
+		if !exists {
+			// create a connection pool for the host
+			pool = newHostConnPool(
+				hosts[i].Peer,
+				p.port,
+				p.numConns,
+				p.connCfg,
+				p.keyspace,
+				p.connPolicy(),
+			)
+			p.hostConnPools[hosts[i].Peer] = pool
+		} else {
+			// still have this host, so don't remove it
+			delete(toRemove, hosts[i].Peer)
+		}
+	}
+
+	for addr := range toRemove {
+		pool := p.hostConnPools[addr]
+		delete(p.hostConnPools, addr)
+		pool.Close()
+	}
+
+	// update the policy
+	p.hostPolicy.SetHosts(hosts)
+
+	p.mu.Unlock()
+}
+
+func (p *policyConnPool) SetPartitioner(partitioner string) {
+	p.hostPolicy.SetPartitioner(partitioner)
+}
+
+func (p *policyConnPool) Size() int {
+	p.mu.RLock()
+	count := 0
+	for _, pool := range p.hostConnPools {
+		count += pool.Size()
+	}
+	p.mu.RUnlock()
+
+	return count
+}
+
+func (p *policyConnPool) Pick(qry *Query) *Conn {
+	nextHost := p.hostPolicy.Pick(qry)
+
+	p.mu.RLock()
+	var host *HostInfo
+	var conn *Conn
+	for conn == nil {
+		host = nextHost()
+		if host == nil {
+			break
+		}
+		conn = p.hostConnPools[host.Peer].Pick(qry)
+	}
+	p.mu.RUnlock()
+	return conn
+}
+
+func (p *policyConnPool) Close() {
+	p.mu.Lock()
+
+	// remove the hosts from the policy
+	p.hostPolicy.SetHosts([]HostInfo{})
+
+	// close the pools
+	for addr, pool := range p.hostConnPools {
+		delete(p.hostConnPools, addr)
+		pool.Close()
+	}
+	p.mu.Unlock()
+}
+
+// hostConnPool is a connection pool for a single host.
+// Connection selection is based on a provided ConnSelectionPolicy
+type hostConnPool struct {
+	host     string
+	port     int
+	addr     string
+	size     int
+	connCfg  ConnConfig
+	keyspace string
+	policy   ConnSelectionPolicy
+	// protection for conns, closed, filling
+	mu      sync.Mutex
+	conns   []*Conn
+	closed  bool
+	filling bool
+}
+
+func newHostConnPool(
+	host string,
+	port int,
+	size int,
+	connCfg ConnConfig,
+	keyspace string,
+	policy ConnSelectionPolicy,
+) *hostConnPool {
+
+	pool := &hostConnPool{
+		host:     host,
+		port:     port,
+		addr:     JoinHostPort(host, port),
+		size:     size,
+		connCfg:  connCfg,
+		keyspace: keyspace,
+		policy:   policy,
+		conns:    make([]*Conn, 0, size),
+		filling:  false,
+		closed:   false,
+	}
+
+	// fill the pool with the initial connections before returning
+	pool.fill()
+
+	return pool
+}
+
+// Pick a connection from this connection pool for the given query.
+func (pool *hostConnPool) Pick(qry *Query) *Conn {
+	pool.mu.Lock()
+	if pool.closed {
+		pool.mu.Unlock()
+		return nil
+	}
+
+	empty := len(pool.conns) == 0
+	pool.mu.Unlock()
+
+	if empty {
+		// try to fill the empty pool
+		pool.fill()
+	}
+
+	return pool.policy.Pick(qry)
+}
+
+//Size returns the number of connections currently active in the pool
+func (pool *hostConnPool) Size() int {
+	pool.mu.Lock()
+	defer pool.mu.Unlock()
+
+	return len(pool.conns)
+}
+
+//Close the connection pool
+func (pool *hostConnPool) Close() {
+	pool.mu.Lock()
+	defer pool.mu.Unlock()
+
+	if pool.closed {
+		return
+	}
+	pool.closed = true
+
+	// drain, but don't wait
+	go pool.drain()
+}
+
+// Fill the connection pool
+func (pool *hostConnPool) fill() {
+	pool.mu.Lock()
+	// avoid filling a closed pool, or concurrent filling
+	if pool.closed || pool.filling {
+		pool.mu.Unlock()
+		return
+	}
+
+	// determine the filling work to be done
+	startCount := len(pool.conns)
+	fillCount := pool.size - startCount
+
+	// avoid filling a full (or overfull) pool
+	if fillCount <= 0 {
+		pool.mu.Unlock()
+		return
+	}
+
+	// ok fill the pool
+	pool.filling = true
+
+	// allow others to access the pool while filling
+	pool.mu.Unlock()
+
+	// fill only the first connection synchronously
+	if startCount == 0 {
+		err := pool.connect()
+		if opErr, ok := err.(*net.OpError); ok && opErr.Op == "read" {
+			// connection refused
+			// these are typical during a node outage so avoid log spam.
+		} else if err != nil {
+			log.Printf("error: failed to connect to %s - %v", pool.addr, err)
+		}
+
+		if err != nil {
+			// probably unreachable host
+			go pool.fillingStopped()
+			return
+		}
+
+		// filled one
+		fillCount--
+	}
+
+	// fill the rest of the pool asynchronously
+	go func() {
+		for fillCount > 0 {
+			err := pool.connect()
+			if opErr, ok := err.(*net.OpError); ok && opErr.Op == "read" {
+				// connection refused
+				// these are typical during a node outage so avoid log spam.
+			} else if err != nil {
+				log.Printf("error: failed to connect to %s - %v", pool.addr, err)
+			}
+
+			// decrement, even on error
+			fillCount--
+		}
+
+		// mark the end of filling
+		pool.fillingStopped()
+	}()
+}
+
+// transition back to a not-filling state.
+func (pool *hostConnPool) fillingStopped() {
+	// wait for some time to avoid back-to-back filling
+	time.Sleep(100 * time.Millisecond)
+
+	pool.mu.Lock()
+	pool.filling = false
+	pool.mu.Unlock()
+}
+
+// create a new connection to the host and add it to the pool
+func (pool *hostConnPool) connect() error {
+	// try to connect
+	conn, err := Connect(pool.addr, pool.connCfg, pool)
+	if err != nil {
+		return err
+	}
+
+	if pool.keyspace != "" {
+		// set the keyspace
+		if err := conn.UseKeyspace(pool.keyspace); err != nil {
+			conn.Close()
+			return err
+		}
+	}
+
+	// add the Conn to the pool
+	pool.mu.Lock()
+	defer pool.mu.Unlock()
+
+	if pool.closed {
+		conn.Close()
+		return nil
+	}
+
+	pool.conns = append(pool.conns, conn)
+	pool.policy.SetConns(pool.conns)
+	return nil
+}
+
+// handle any error from a Conn
+func (pool *hostConnPool) HandleError(conn *Conn, err error, closed bool) {
+	if !closed {
+		// still an open connection, so continue using it
+		return
+	}
+
+	pool.mu.Lock()
+	defer pool.mu.Unlock()
+
+	// find the connection index
+	for i, candidate := range pool.conns {
+		if candidate == conn {
+			// remove the connection, not preserving order
+			pool.conns[i], pool.conns = pool.conns[len(pool.conns)-1], pool.conns[:len(pool.conns)-1]
+
+			// update the policy
+			pool.policy.SetConns(pool.conns)
+
+			// lost a connection, so fill the pool
+			go pool.fill()
+			break
+		}
+	}
+}
+
+// removes and closes all connections from the pool
+func (pool *hostConnPool) drain() {
+	pool.mu.Lock()
+	defer pool.mu.Unlock()
+
+	// empty the pool
+	conns := pool.conns
+	pool.conns = pool.conns[:0]
+
+	// update the policy
+	pool.policy.SetConns(pool.conns)
+
+	// close the connections
+	for _, conn := range conns {
+		conn.Close()
+	}
+}

+ 266 - 0
connectionpool_systems_test.go

@@ -0,0 +1,266 @@
+// Copyright (c) 2015 The gocql Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+// +build conn_pool
+
+package gocql
+
+import (
+	"flag"
+	"fmt"
+	"log"
+	"os/exec"
+	"strconv"
+	"strings"
+	"sync"
+	"testing"
+	"time"
+)
+
+// Connection pool behavior tests for when cluster nodes are shut down and restarted.
+// To run these tests, see connectionpool_systems_test.sh.
+
+var (
+	flagCluster  = flag.String("cluster", "127.0.0.1", "a comma-separated list of host:port tuples")
+	flagProto    = flag.Int("proto", 2, "protcol version")
+	flagCQL      = flag.String("cql", "3.0.0", "CQL version")
+	flagRF       = flag.Int("rf", 1, "replication factor for test keyspace")
+	clusterSize  = flag.Int("clusterSize", 1, "the expected size of the cluster")
+	nodesShut    = flag.Int("nodesShut", 1, "the number of nodes to shutdown during the test")
+	flagRetry    = flag.Int("retries", 5, "number of times to retry queries")
+	flagRunSsl   = flag.Bool("runssl", false, "Set to true to run ssl test")
+	clusterHosts []string
+)
+var initOnce sync.Once
+
+func init() {
+	flag.Parse()
+	clusterHosts = strings.Split(*flagCluster, ",")
+	log.SetFlags(log.Lshortfile | log.LstdFlags)
+}
+
+func createTable(s *Session, table string) error {
+	err := s.Query(table).Consistency(All).Exec()
+	if *clusterSize > 1 {
+		// wait for table definition to propagate
+		time.Sleep(250 * time.Millisecond)
+	}
+	return err
+}
+
+func createCluster() *ClusterConfig {
+	cluster := NewCluster(clusterHosts...)
+	cluster.ProtoVersion = *flagProto
+	cluster.CQLVersion = *flagCQL
+	cluster.Timeout = 5 * time.Second
+	cluster.Consistency = Quorum
+	if *flagRetry > 0 {
+		cluster.RetryPolicy = &SimpleRetryPolicy{NumRetries: *flagRetry}
+	}
+	if *flagRunSsl {
+		cluster.SslOpts = &SslOptions{
+			CertPath:               "testdata/pki/gocql.crt",
+			KeyPath:                "testdata/pki/gocql.key",
+			CaPath:                 "testdata/pki/ca.crt",
+			EnableHostVerification: false,
+		}
+	}
+	return cluster
+}
+
+func createKeyspace(t *testing.T, cluster *ClusterConfig, keyspace string) {
+	session, err := cluster.CreateSession()
+	if err != nil {
+		t.Fatal("createSession:", err)
+	}
+	if err = session.Query(`DROP KEYSPACE ` + keyspace).Exec(); err != nil {
+		t.Log("drop keyspace:", err)
+	}
+	err = session.Query(
+		fmt.Sprintf(
+			`
+			CREATE KEYSPACE %s
+			WITH replication = {
+				'class' : 'SimpleStrategy',
+				'replication_factor' : %d
+			}
+			`,
+			keyspace,
+			*flagRF,
+		),
+	).Consistency(All).Exec()
+	if err != nil {
+		t.Fatalf("error creating keyspace %s: %v", keyspace, err)
+	}
+	t.Logf("Created keyspace %s", keyspace)
+	session.Close()
+}
+
+func createSession(t *testing.T) *Session {
+	cluster := createCluster()
+
+	// Drop and re-create the keyspace once. Different tests should use their own
+	// individual tables, but can assume that the table does not exist before.
+	initOnce.Do(func() {
+		createKeyspace(t, cluster, "gocql_test")
+	})
+
+	cluster.Keyspace = "gocql_test"
+	session, err := cluster.CreateSession()
+	if err != nil {
+		t.Fatal("createSession:", err)
+	}
+
+	return session
+}
+
+func TestSimplePool(t *testing.T) {
+	testConnPool(t, NewSimplePool)
+}
+
+func TestRRPolicyConnPool(t *testing.T) {
+	testConnPool(t, NewRoundRobinConnPool)
+}
+
+func TestTAPolicyConnPool(t *testing.T) {
+	testConnPool(t, NewTokenAwareConnPool)
+}
+
+func testConnPool(t *testing.T, connPoolType NewPoolFunc) {
+	var out []byte
+	var err error
+	log.SetFlags(log.Ltime)
+
+	// make sure the cluster is running
+	out, err = exec.Command("ccm", "start").CombinedOutput()
+	if err != nil {
+		fmt.Printf("ccm output:\n%s", string(out))
+		t.Fatalf("Error running ccm command: %v", err)
+	}
+
+	time.Sleep(time.Duration(*clusterSize) * 1000 * time.Millisecond)
+
+	// fire up a session (no discovery)
+	cluster := createCluster()
+	cluster.ConnPoolType = connPoolType
+	cluster.DiscoverHosts = false
+	session, err := cluster.CreateSession()
+	if err != nil {
+		t.Fatalf("Error connecting to cluster: %v", err)
+	}
+	defer session.Close()
+
+	time.Sleep(time.Duration(*clusterSize) * 1000 * time.Millisecond)
+
+	if session.Pool.Size() != (*clusterSize)*cluster.NumConns {
+		t.Logf(
+			"WARN: Expected %d pool size, but was %d",
+			(*clusterSize)*cluster.NumConns,
+			session.Pool.Size(),
+		)
+	}
+
+	// start some connection monitoring
+	nilCheckStop := false
+	nilCount := 0
+	nilCheck := func() {
+		// assert that all connections returned by the pool are non-nil
+		for !nilCheckStop {
+			actual := session.Pool.Pick(nil)
+			if actual == nil {
+				nilCount++
+			}
+		}
+	}
+	go nilCheck()
+
+	// shutdown some hosts
+	log.Println("shutdown some hosts")
+	for i := 0; i < *nodesShut; i++ {
+		out, err = exec.Command("ccm", "node"+strconv.Itoa(i+1), "stop").CombinedOutput()
+		if err != nil {
+			fmt.Printf("ccm output:\n%s", string(out))
+			t.Fatalf("Error running ccm command: %v", err)
+		}
+		time.Sleep(1500 * time.Millisecond)
+	}
+	time.Sleep(500 * time.Millisecond)
+
+	if session.Pool.Size() != ((*clusterSize)-(*nodesShut))*cluster.NumConns {
+		t.Logf(
+			"WARN: Expected %d pool size, but was %d",
+			((*clusterSize)-(*nodesShut))*cluster.NumConns,
+			session.Pool.Size(),
+		)
+	}
+
+	// bring up the shutdown hosts
+	log.Println("bring up the shutdown hosts")
+	for i := 0; i < *nodesShut; i++ {
+		out, err = exec.Command("ccm", "node"+strconv.Itoa(i+1), "start").CombinedOutput()
+		if err != nil {
+			fmt.Printf("ccm output:\n%s", string(out))
+			t.Fatalf("Error running ccm command: %v", err)
+		}
+		time.Sleep(1500 * time.Millisecond)
+	}
+	time.Sleep(500 * time.Millisecond)
+
+	if session.Pool.Size() != (*clusterSize)*cluster.NumConns {
+		t.Logf(
+			"WARN: Expected %d pool size, but was %d",
+			(*clusterSize)*cluster.NumConns,
+			session.Pool.Size(),
+		)
+	}
+
+	// assert that all connections returned by the pool are non-nil
+	if nilCount > 0 {
+		t.Errorf("%d nil connections returned from %T", nilCount, session.Pool)
+	}
+
+	// shutdown cluster
+	log.Println("shutdown cluster")
+	out, err = exec.Command("ccm", "stop").CombinedOutput()
+	if err != nil {
+		fmt.Printf("ccm output:\n%s", string(out))
+		t.Fatalf("Error running ccm command: %v", err)
+	}
+	time.Sleep(2500 * time.Millisecond)
+
+	if session.Pool.Size() != 0 {
+		t.Logf(
+			"WARN: Expected %d pool size, but was %d",
+			0,
+			session.Pool.Size(),
+		)
+	}
+
+	// start cluster
+	log.Println("start cluster")
+	out, err = exec.Command("ccm", "start").CombinedOutput()
+	if err != nil {
+		fmt.Printf("ccm output:\n%s", string(out))
+		t.Fatalf("Error running ccm command: %v", err)
+	}
+	time.Sleep(500 * time.Millisecond)
+
+	// reset the count
+	nilCount = 0
+
+	time.Sleep(3000 * time.Millisecond)
+
+	if session.Pool.Size() != (*clusterSize)*cluster.NumConns {
+		t.Logf(
+			"WARN: Expected %d pool size, but was %d",
+			(*clusterSize)*cluster.NumConns,
+			session.Pool.Size(),
+		)
+	}
+
+	// assert that all connections returned by the pool are non-nil
+	if nilCount > 0 {
+		t.Errorf("%d nil connections returned from %T", nilCount, session.Pool)
+	}
+	nilCheckStop = true
+}

+ 54 - 0
connectionpool_systems_test.sh

@@ -0,0 +1,54 @@
+#!/bin/bash
+
+set -e
+
+function run_tests() {
+	local clusterSize=5
+	local nodesShut=2
+	local version=$1
+
+	ccm remove test || true
+
+	local keypath="$(pwd)/testdata/pki"
+
+	local conf=(
+	    "client_encryption_options.enabled: true"
+	    "client_encryption_options.keystore: $keypath/.keystore"
+	    "client_encryption_options.keystore_password: cassandra"
+	    "client_encryption_options.require_client_auth: true"
+	    "client_encryption_options.truststore: $keypath/.truststore"
+	    "client_encryption_options.truststore_password: cassandra"
+	    "concurrent_reads: 2"
+	    "concurrent_writes: 2"
+	    "rpc_server_type: sync"
+	    "rpc_min_threads: 2"
+	    "rpc_max_threads: 2"
+	    "write_request_timeout_in_ms: 5000"
+	    "read_request_timeout_in_ms: 5000"
+	)
+
+	ccm create test -v binary:$version -n $clusterSize -d --vnodes --jvm_arg="-Xmx256m"
+	ccm updateconf "${conf[@]}"
+	ccm start -v
+	ccm status
+	ccm node1 nodetool status
+
+	local proto=2
+	if [[ $version == 1.2.* ]]; then
+		proto=1
+	fi
+
+	go test -timeout 15m -tags conn_pool -v -runssl -proto=$proto -rf=3 -cluster=$(ccm liveset) -clusterSize=$clusterSize -nodesShut=$nodesShut ./... | tee results.txt
+
+	if [ ${PIPESTATUS[0]} -ne 0 ]; then
+		echo "--- FAIL: ccm status follows:"
+		ccm status
+		ccm node1 nodetool status
+		ccm node1 showlog > status.log
+		cat status.log
+		echo "--- FAIL: Received a non-zero exit code from the go test execution, please investigate this"
+		exit 1
+	fi
+	ccm remove
+}
+run_tests $1

+ 30 - 19
host_source.go

@@ -16,28 +16,33 @@ type HostInfo struct {
 
 // Polls system.peers at a specific interval to find new hosts
 type ringDescriber struct {
-	dcFilter   string
-	rackFilter string
-	previous   []HostInfo
-	session    *Session
+	dcFilter        string
+	rackFilter      string
+	prevHosts       []HostInfo
+	prevPartitioner string
+	session         *Session
 }
 
-func (r *ringDescriber) GetHosts() ([]HostInfo, error) {
+func (r *ringDescriber) GetHosts() (
+	hosts []HostInfo,
+	partitioner string,
+	err error,
+) {
 	// we need conn to be the same because we need to query system.peers and system.local
 	// on the same node to get the whole cluster
 	conn := r.session.Pool.Pick(nil)
 	if conn == nil {
-		return r.previous, nil
+		return r.prevHosts, r.prevPartitioner, nil
 	}
 
-	query := r.session.Query("SELECT data_center, rack, host_id, tokens FROM system.local")
+	query := r.session.Query("SELECT data_center, rack, host_id, tokens, partitioner FROM system.local")
 	iter := conn.executeQuery(query)
 
-	host := &HostInfo{}
-	iter.Scan(&host.DataCenter, &host.Rack, &host.HostId, &host.Tokens)
+	host := HostInfo{}
+	iter.Scan(&host.DataCenter, &host.Rack, &host.HostId, &host.Tokens, &partitioner)
 
-	if err := iter.Close(); err != nil {
-		return nil, err
+	if err = iter.Close(); err != nil {
+		return nil, "", err
 	}
 
 	addr, _, err := net.SplitHostPort(conn.Address())
@@ -49,24 +54,27 @@ func (r *ringDescriber) GetHosts() ([]HostInfo, error) {
 
 	host.Peer = addr
 
-	hosts := []HostInfo{*host}
+	hosts = []HostInfo{host}
 
 	query = r.session.Query("SELECT peer, data_center, rack, host_id, tokens FROM system.peers")
 	iter = conn.executeQuery(query)
 
+	host = HostInfo{}
 	for iter.Scan(&host.Peer, &host.DataCenter, &host.Rack, &host.HostId, &host.Tokens) {
-		if r.matchFilter(host) {
-			hosts = append(hosts, *host)
+		if r.matchFilter(&host) {
+			hosts = append(hosts, host)
 		}
+		host = HostInfo{}
 	}
 
-	if err := iter.Close(); err != nil {
-		return nil, err
+	if err = iter.Close(); err != nil {
+		return nil, "", err
 	}
 
-	r.previous = hosts
+	r.prevHosts = hosts
+	r.prevPartitioner = partitioner
 
-	return hosts, nil
+	return hosts, partitioner, nil
 }
 
 func (r *ringDescriber) matchFilter(host *HostInfo) bool {
@@ -92,11 +100,14 @@ func (h *ringDescriber) run(sleep time.Duration) {
 		// attempt to reconnect to the cluster otherwise we would never find
 		// downed hosts again, could possibly have an optimisation to only
 		// try to add new hosts if GetHosts didnt error and the hosts didnt change.
-		hosts, err := h.GetHosts()
+		hosts, partitioner, err := h.GetHosts()
 		if err != nil {
 			log.Println("RingDescriber: unable to get ring topology:", err)
 		} else {
 			h.session.Pool.SetHosts(hosts)
+			if v, ok := h.session.Pool.(SetPartitioner); ok {
+				v.SetPartitioner(partitioner)
+			}
 		}
 
 		time.Sleep(sleep)
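Because ring discovery only type-asserts the pool for the optional SetPartitioner interface, existing ConnectionPool implementations keep working unchanged; a pool opts in simply by adding the method. A rough sketch under that assumption (the wrapper type is hypothetical):

package gocql

import "sync"

// partitionerAwarePool sketches how a custom pool opts into partitioner
// updates: it embeds an existing ConnectionPool for Pick/Size/Close/SetHosts
// and adds the single optional method that ringDescriber looks for.
type partitionerAwarePool struct {
	ConnectionPool

	mu          sync.Mutex
	partitioner string
}

func (p *partitionerAwarePool) SetPartitioner(partitioner string) {
	p.mu.Lock()
	p.partitioner = partitioner
	p.mu.Unlock()
}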

+ 187 - 0
policies.go

@@ -4,6 +4,12 @@
 //This file will be the future home for more policies
 package gocql
 
+import (
+	"log"
+	"sync"
+	"sync/atomic"
+)
+
 //RetryableQuery is an interface that represents a query or batch statement that
 //exposes the correct functions for the retry policy logic to evaluate correctly.
 type RetryableQuery interface {
@@ -42,3 +48,184 @@ type SimpleRetryPolicy struct {
 func (s *SimpleRetryPolicy) Attempt(q RetryableQuery) bool {
 	return q.Attempts() <= s.NumRetries
 }
+
+//HostSelectionPolicy is an interface for selecting
+//the most appropriate host to execute a given query.
+type HostSelectionPolicy interface {
+	//SetHosts notifies this policy of the current hosts in the cluster
+	SetHosts(hosts []HostInfo)
+	//SetPartitioner notifies this policy of the current token partitioner
+	SetPartitioner(partitioner string)
+	//Pick returns an iteration function over selected hosts
+	Pick(*Query) NextHost
+}
+
+//NextHost is an iteration function over picked hosts
+type NextHost func() *HostInfo
+
+//NewRoundRobinHostPolicy is a round-robin load balancing policy
+func NewRoundRobinHostPolicy() HostSelectionPolicy {
+	return &roundRobinHostPolicy{hosts: []HostInfo{}}
+}
+
+type roundRobinHostPolicy struct {
+	hosts []HostInfo
+	pos   uint32
+	mu    sync.RWMutex
+}
+
+func (r *roundRobinHostPolicy) SetHosts(hosts []HostInfo) {
+	r.mu.Lock()
+	r.hosts = hosts
+	r.mu.Unlock()
+}
+
+func (r *roundRobinHostPolicy) SetPartitioner(partitioner string) {
+	// noop
+}
+
+func (r *roundRobinHostPolicy) Pick(qry *Query) NextHost {
+	pos := atomic.AddUint32(&r.pos, 1)
+	var i uint32 = 0
+	return func() *HostInfo {
+		var host *HostInfo
+		r.mu.RLock()
+		if len(r.hosts) > 0 && int(i) < len(r.hosts) {
+			host = &r.hosts[(pos+i)%uint32(len(r.hosts))]
+			i++
+		}
+		r.mu.RUnlock()
+		return host
+	}
+}
+
+//NewTokenAwareHostPolicy is a token aware host selection policy
+func NewTokenAwareHostPolicy(fallback HostSelectionPolicy) HostSelectionPolicy {
+	return &tokenAwareHostPolicy{fallback: fallback, hosts: []HostInfo{}}
+}
+
+type tokenAwareHostPolicy struct {
+	mu          sync.RWMutex
+	hosts       []HostInfo
+	partitioner string
+	tokenRing   *tokenRing
+	fallback    HostSelectionPolicy
+}
+
+func (t *tokenAwareHostPolicy) SetHosts(hosts []HostInfo) {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+
+	// always update the fallback
+	t.fallback.SetHosts(hosts)
+	t.hosts = hosts
+
+	t.resetTokenRing()
+}
+
+func (t *tokenAwareHostPolicy) SetPartitioner(partitioner string) {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+
+	t.fallback.SetPartitioner(partitioner)
+	t.partitioner = partitioner
+
+	t.resetTokenRing()
+}
+
+func (t *tokenAwareHostPolicy) resetTokenRing() {
+	if t.partitioner == "" {
+		// partitioner not yet set
+		return
+	}
+
+	// create a new token ring
+	tokenRing, err := newTokenRing(t.partitioner, t.hosts)
+	if err != nil {
+		log.Printf("Unable to update the token ring due to error: %s", err)
+		return
+	}
+
+	// replace the token ring
+	t.tokenRing = tokenRing
+}
+
+func (t *tokenAwareHostPolicy) Pick(qry *Query) NextHost {
+	if qry == nil {
+		return t.fallback.Pick(qry)
+	}
+
+	routingKey, err := qry.GetRoutingKey()
+	if err != nil {
+		return t.fallback.Pick(qry)
+	}
+	if routingKey == nil {
+		return t.fallback.Pick(qry)
+	}
+
+	var host *HostInfo
+
+	t.mu.RLock()
+	if t.tokenRing != nil {
+		host = t.tokenRing.GetHostForPartitionKey(routingKey)
+	}
+	t.mu.RUnlock()
+
+	if host == nil {
+		return t.fallback.Pick(qry)
+	}
+
+	var hostReturned bool = false
+	var once sync.Once
+	var fallbackIter NextHost
+	return func() *HostInfo {
+		if !hostReturned {
+			hostReturned = true
+			return host
+		}
+
+		// fallback
+		once.Do(func() { fallbackIter = t.fallback.Pick(qry) })
+
+		fallbackHost := fallbackIter()
+		if fallbackHost == host {
+			fallbackHost = fallbackIter()
+		}
+
+		return fallbackHost
+	}
+}
+
+//ConnSelectionPolicy is an interface for selecting an
+//appropriate connection for executing a query
+type ConnSelectionPolicy interface {
+	SetConns(conns []*Conn)
+	Pick(*Query) *Conn
+}
+
+type roundRobinConnPolicy struct {
+	conns []*Conn
+	pos   uint32
+	mu    sync.RWMutex
+}
+
+func NewRoundRobinConnPolicy() ConnSelectionPolicy {
+	return &roundRobinConnPolicy{conns: []*Conn{}}
+}
+
+func (r *roundRobinConnPolicy) SetConns(conns []*Conn) {
+	r.mu.Lock()
+	r.conns = conns
+	r.mu.Unlock()
+}
+
+func (r *roundRobinConnPolicy) Pick(qry *Query) *Conn {
+	pos := atomic.AddUint32(&r.pos, 1)
+	var conn *Conn
+	r.mu.RLock()
+	if len(r.conns) > 0 {
+		conn = r.conns[pos%uint32(len(r.conns))]
+	}
+	r.mu.RUnlock()
+	return conn
+}
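The ConnSelectionPolicy contract is deliberately small, so alternative strategies are easy to plug in. A hypothetical policy that always returns the first connection it was given, shown only to illustrate the interface:

package gocql

import "sync"

// firstConnPolicy is a hypothetical ConnSelectionPolicy that always returns
// the first connection in its slice, or nil when the pool is empty.
type firstConnPolicy struct {
	mu    sync.RWMutex
	conns []*Conn
}

func NewFirstConnPolicy() ConnSelectionPolicy {
	return &firstConnPolicy{}
}

func (p *firstConnPolicy) SetConns(conns []*Conn) {
	p.mu.Lock()
	p.conns = conns
	p.mu.Unlock()
}

func (p *firstConnPolicy) Pick(qry *Query) *Conn {
	p.mu.RLock()
	defer p.mu.RUnlock()
	if len(p.conns) == 0 {
		return nil
	}
	return p.conns[0]
}

Such a policy composes into a pool the same way the built-in ones do, e.g. NewPolicyConnPool(cfg, NewRoundRobinHostPolicy(), NewFirstConnPolicy).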

+ 86 - 0
policies_test.go

@@ -0,0 +1,86 @@
+// Copyright (c) 2015 The gocql Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gocql
+
+import "testing"
+
+func TestRoundRobinHostPolicy(t *testing.T) {
+	policy := NewRoundRobinHostPolicy()
+
+	hosts := []HostInfo{
+		HostInfo{HostId: "0"},
+		HostInfo{HostId: "1"},
+	}
+
+	policy.SetHosts(hosts)
+
+	// the first host selected is actually at [1], but this is ok for RR
+	iter := policy.Pick(nil)
+	if actual := iter(); actual != &hosts[1] {
+		t.Errorf("Expected hosts[1] but was hosts[%s]", actual.HostId)
+	}
+	if actual := iter(); actual != &hosts[0] {
+		t.Errorf("Expected hosts[0] but was hosts[%s]", actual.HostId)
+	}
+	iter = policy.Pick(nil)
+	if actual := iter(); actual != &hosts[0] {
+		t.Errorf("Expected hosts[0] but was hosts[%s]", actual.HostId)
+	}
+	if actual := iter(); actual != &hosts[1] {
+		t.Errorf("Expected hosts[1] but was hosts[%s]", actual.HostId)
+	}
+	iter = policy.Pick(nil)
+	if actual := iter(); actual != &hosts[1] {
+		t.Errorf("Expected hosts[1] but was hosts[%s]", actual.HostId)
+	}
+	if actual := iter(); actual != &hosts[0] {
+		t.Errorf("Expected hosts[0] but was hosts[%s]", actual.HostId)
+	}
+}
+
+func TestTokenAwareHostPolicy(t *testing.T) {
+	policy := NewTokenAwareHostPolicy(NewRoundRobinHostPolicy())
+
+	hosts := []HostInfo{
+		HostInfo{HostId: "0", Peer: "0", Tokens: []string{"00"}},
+		HostInfo{HostId: "1", Peer: "1", Tokens: []string{"25"}},
+		HostInfo{HostId: "2", Peer: "2", Tokens: []string{"50"}},
+		HostInfo{HostId: "3", Peer: "3", Tokens: []string{"75"}},
+	}
+
+	policy.SetHosts(hosts)
+	policy.SetPartitioner("OrderedPartitioner")
+
+	query := &Query{}
+	query.RoutingKey([]byte("30"))
+
+	if actual := policy.Pick(query)(); actual != &hosts[2] {
+		t.Errorf("Expected hosts[2] but was hosts[%s]", actual.HostId)
+	}
+}
+
+func TestRoundRobinConnPolicy(t *testing.T) {
+	policy := NewRoundRobinConnPolicy()
+
+	conn0 := &Conn{}
+	conn1 := &Conn{}
+	conn := []*Conn{
+		conn0,
+		conn1,
+	}
+
+	policy.SetConns(conn)
+
+	// the first conn selected is actually at [1], but this is ok for RR
+	if actual := policy.Pick(nil); actual != conn1 {
+		t.Error("Expected conn1")
+	}
+	if actual := policy.Pick(nil); actual != conn0 {
+		t.Error("Expected conn0")
+	}
+	if actual := policy.Pick(nil); actual != conn1 {
+		t.Error("Expected conn1")
+	}
+}

+ 3 - 2
token.go

@@ -239,7 +239,7 @@ type tokenRing struct {
 	hosts       []*HostInfo
 }
 
-func newTokenRing(partitioner string, hosts []*HostInfo) (*tokenRing, error) {
+func newTokenRing(partitioner string, hosts []HostInfo) (*tokenRing, error) {
 	tokenRing := &tokenRing{
 		tokens: []token{},
 		hosts:  []*HostInfo{},
@@ -255,7 +255,8 @@ func newTokenRing(partitioner string, hosts []*HostInfo) (*tokenRing, error) {
 		return nil, fmt.Errorf("Unsupported partitioner '%s'", partitioner)
 	}
 
-	for _, host := range hosts {
+	for i := range hosts {
+		host := &hosts[i]
 		for _, strToken := range host.Tokens {
 			token := tokenRing.partitioner.ParseString(strToken)
 			tokenRing.tokens = append(tokenRing.tokens, token)