
Merge pull request #348 from retailnext/policy_conn_pool

Policy-based ConnectionPool with token-aware and round-robin policy implementations
Ben Hood committed 10 years ago (commit d4c587dfad)
11 changed files with 1507 additions and 126 deletions
  1. cassandra_test.go (+53 -0)
  2. conn.go (+18 -15)
  3. conn_test.go (+99 -3)
  4. connectionpool.go (+447 -3)
  5. host_source.go (+30 -19)
  6. metadata.go (+20 -8)
  7. metadata_test.go (+159 -34)
  8. policies.go (+200 -0)
  9. policies_test.go (+125 -0)
  10. token.go (+41 -14)
  11. token_test.go (+315 -30)

+ 53 - 0
cassandra_test.go

@@ -1508,6 +1508,7 @@ func TestEmptyTimestamp(t *testing.T) {
 	}
 }
 
+// Integration test of just querying for data from the system.schema_keyspace table
 func TestGetKeyspaceMetadata(t *testing.T) {
 	session := createSession(t)
 	defer session.Close()
@@ -1541,6 +1542,7 @@ func TestGetKeyspaceMetadata(t *testing.T) {
 	}
 }
 
+// Integration test of just querying for data from the system.schema_columnfamilies table
 func TestGetTableMetadata(t *testing.T) {
 	session := createSession(t)
 	defer session.Close()
@@ -1622,6 +1624,7 @@ func TestGetTableMetadata(t *testing.T) {
 	}
 }
 
+// Integration test of just querying for data from the system.schema_columns table
 func TestGetColumnMetadata(t *testing.T) {
 	session := createSession(t)
 	defer session.Close()
@@ -1723,6 +1726,7 @@ func TestGetColumnMetadata(t *testing.T) {
 	}
 }
 
+// Integration test of querying and composing the keyspace metadata
 func TestKeyspaceMetadata(t *testing.T) {
 	session := createSession(t)
 	defer session.Close()
@@ -1785,6 +1789,7 @@ func TestKeyspaceMetadata(t *testing.T) {
 	}
 }
 
+// Integration test of the routing key calculation
 func TestRoutingKey(t *testing.T) {
 	session := createSession(t)
 	defer session.Close()
@@ -1902,3 +1907,51 @@ func TestRoutingKey(t *testing.T) {
 		t.Errorf("Expected cache size to be 2 but was %d", cacheSize)
 	}
 }
+
+// Integration test of the token-aware policy-based connection pool
+func TestTokenAwareConnPool(t *testing.T) {
+	cluster := createCluster()
+	cluster.ConnPoolType = NewTokenAwareConnPool
+	cluster.DiscoverHosts = true
+
+	// Drop and re-create the keyspace once. Different tests should use their own
+	// individual tables, but can assume that the table does not exist before.
+	initOnce.Do(func() {
+		createKeyspace(t, cluster, "gocql_test")
+	})
+
+	cluster.Keyspace = "gocql_test"
+	session, err := cluster.CreateSession()
+	if err != nil {
+		t.Fatal("createSession:", err)
+	}
+	defer session.Close()
+
+	if *clusterSize > 1 {
+		// wait for autodiscovery to update the pool with the list of known hosts
+		time.Sleep(*flagAutoWait)
+	}
+
+	if session.Pool.Size() != cluster.NumConns*len(cluster.Hosts) {
+		t.Errorf("Expected pool size %d but was %d", cluster.NumConns*len(cluster.Hosts), session.Pool.Size())
+	}
+
+	if err := createTable(session, "CREATE TABLE test_token_aware (id int, data text, PRIMARY KEY (id))"); err != nil {
+		t.Fatalf("failed to create test_token_aware table with err: %v", err)
+	}
+	query := session.Query("INSERT INTO test_token_aware (id, data) VALUES (?,?)", 42, "8 * 6 =")
+	if err := query.Exec(); err != nil {
+		t.Fatalf("failed to insert with err: %v", err)
+	}
+	query = session.Query("SELECT data FROM test_token_aware where id = ?", 42).Consistency(One)
+	iter := query.Iter()
+	var data string
+	if !iter.Scan(&data) {
+		t.Error("failed to scan data")
+	}
+	if err := iter.Close(); err != nil {
+		t.Errorf("iter failed with err: %v", err)
+	}
+
+	// TODO add verification that the query went to the correct host
+}
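
The integration test above leaves verification of replica selection as a TODO. For context, a minimal sketch of how an application would opt into the new pool; the import path and host addresses are illustrative assumptions, while NewCluster, ConnPoolType, DiscoverHosts, Keyspace, and CreateSession all appear in this change or the existing API:

package main

import "github.com/gocql/gocql"

func main() {
	// assumed import path and hosts; mirrors TestTokenAwareConnPool above
	cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2")
	cluster.ConnPoolType = gocql.NewTokenAwareConnPool // policy-based, token-aware pool
	cluster.DiscoverHosts = true                       // keep the host list and token ring updated
	cluster.Keyspace = "gocql_test"

	session, err := cluster.CreateSession()
	if err != nil {
		panic(err)
	}
	defer session.Close()

	// when a routing key can be derived for a query (see GetRoutingKey in the
	// token-aware policy below), it is routed to the host owning that token
	if err := session.Query(
		"INSERT INTO test_token_aware (id, data) VALUES (?, ?)", 42, "8 * 6 =",
	).Exec(); err != nil {
		panic(err)
	}
}
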

+ 18 - 15
conn.go

@@ -75,6 +75,10 @@ type ConnConfig struct {
 	tlsConfig     *tls.Config
 }
 
+type ConnErrorHandler interface {
+	HandleError(conn *Conn, err error, closed bool)
+}
+
 // Conn is a single connection to a Cassandra node. It can be used to execute
 // queries, but users are usually advised to use a more reliable, higher
 // level API.
@@ -88,7 +92,7 @@ type Conn struct {
 	uniq  chan int
 	calls []callReq
 
-	pool            ConnectionPool
+	errorHandler    ConnErrorHandler
 	compressor      Compressor
 	auth            Authenticator
 	addr            string
@@ -102,7 +106,7 @@ type Conn struct {
 
 // Connect establishes a connection to a Cassandra node.
 // You must also call the Serve method before you can execute any queries.
-func Connect(addr string, cfg ConnConfig, pool ConnectionPool) (*Conn, error) {
+func Connect(addr string, cfg ConnConfig, errorHandler ConnErrorHandler) (*Conn, error) {
 	var (
 		err  error
 		conn net.Conn
@@ -137,18 +141,17 @@ func Connect(addr string, cfg ConnConfig, pool ConnectionPool) (*Conn, error) {
 	}
 
 	c := &Conn{
-		conn:       conn,
-		r:          bufio.NewReader(conn),
-		uniq:       make(chan int, cfg.NumStreams),
-		calls:      make([]callReq, cfg.NumStreams),
-		timeout:    cfg.Timeout,
-		version:    uint8(cfg.ProtoVersion),
-		addr:       conn.RemoteAddr().String(),
-		pool:       pool,
-		compressor: cfg.Compressor,
-		auth:       cfg.Authenticator,
-
-		headerBuf: make([]byte, headerSize),
+		conn:         conn,
+		r:            bufio.NewReader(conn),
+		uniq:         make(chan int, cfg.NumStreams),
+		calls:        make([]callReq, cfg.NumStreams),
+		timeout:      cfg.Timeout,
+		version:      uint8(cfg.ProtoVersion),
+		addr:         conn.RemoteAddr().String(),
+		errorHandler: errorHandler,
+		compressor:   cfg.Compressor,
+		auth:         cfg.Authenticator,
+		headerBuf:    make([]byte, headerSize),
 	}
 
 	if cfg.Keepalive > 0 {
@@ -298,7 +301,7 @@ func (c *Conn) serve() {
 	}
 
 	if c.started {
-		c.pool.HandleError(c, err, true)
+		c.errorHandler.HandleError(c, err, true)
 	}
 }
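
The ConnErrorHandler interface introduced above decouples Conn from any particular pool type: anything with a HandleError method can own connections. As a minimal sketch (in package gocql, assuming the standard log package; the logOnlyHandler name is illustrative and not part of this change), a handler could look like:

// logOnlyHandler is an illustrative ConnErrorHandler that only reports closed
// connections; hostConnPool later in this change implements the same interface
// to evict a broken connection and refill its pool.
type logOnlyHandler struct{}

func (logOnlyHandler) HandleError(conn *Conn, err error, closed bool) {
	if !closed {
		// the connection is still usable, nothing to do
		return
	}
	log.Printf("connection to %s closed: %v", conn.Address(), err)
}
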
 

+ 99 - 3
conn_test.go

@@ -1,3 +1,6 @@
+// Copyright (c) 2012 The gocql Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
 // +build all unit
 
 package gocql
@@ -177,7 +180,7 @@ func TestSlowQuery(t *testing.T) {
 	}
 }
 
-func TestRoundRobin(t *testing.T) {
+func TestSimplePoolRoundRobin(t *testing.T) {
 	servers := make([]*TestServer, 5)
 	addrs := make([]string, len(servers))
 	for n := 0; n < len(servers); n++ {
@@ -223,7 +226,7 @@ func TestRoundRobin(t *testing.T) {
 	}
 
 	if diff > 0 {
-		t.Fatal("diff:", diff)
+		t.Errorf("Expected 0 difference in usage but was %d", diff)
 	}
 }
 
@@ -258,7 +261,7 @@ func TestConnClosing(t *testing.T) {
 	conns := pool.Size()
 
 	if conns != numConns {
-		t.Fatalf("Expected to have %d connections but have %d", numConns, conns)
+		t.Errorf("Expected to have %d connections but have %d", numConns, conns)
 	}
 }
 
@@ -373,6 +376,99 @@ func BenchmarkProtocolV3(b *testing.B) {
 	}
 }
 
+func TestRoundRobinConnPoolRoundRobin(t *testing.T) {
+	// create 5 test servers
+	servers := make([]*TestServer, 5)
+	addrs := make([]string, len(servers))
+	for n := 0; n < len(servers); n++ {
+		servers[n] = NewTestServer(t, defaultProto)
+		addrs[n] = servers[n].Address
+		defer servers[n].Stop()
+	}
+
+	// create a new cluster using the policy-based round robin conn pool
+	cluster := NewCluster(addrs...)
+	cluster.ConnPoolType = NewRoundRobinConnPool
+
+	db, err := cluster.CreateSession()
+	if err != nil {
+		t.Fatalf("failed to create a new session: %v", err)
+	}
+
+	// Sleep to allow the pool to fill
+	time.Sleep(100 * time.Millisecond)
+
+	// run concurrent queries against the pool, server usage should
+	// be even
+	var wg sync.WaitGroup
+	wg.Add(5)
+	for n := 0; n < 5; n++ {
+		go func() {
+			for j := 0; j < 5; j++ {
+				if err := db.Query("void").Exec(); err != nil {
+					t.Errorf("Query failed with error: %v", err)
+				}
+			}
+			wg.Done()
+		}()
+	}
+	wg.Wait()
+
+	db.Close()
+
+	// wait for the pool to drain
+	time.Sleep(100 * time.Millisecond)
+	size := db.Pool.Size()
+	if size != 0 {
+		t.Errorf("connection pool did not drain, still contains %d connections", size)
+	}
+
+	// verify that server usage is even
+	diff := 0
+	for n := 1; n < len(servers); n++ {
+		d := 0
+		if servers[n].nreq > servers[n-1].nreq {
+			d = int(servers[n].nreq - servers[n-1].nreq)
+		} else {
+			d = int(servers[n-1].nreq - servers[n].nreq)
+		}
+		if d > diff {
+			diff = d
+		}
+	}
+
+	if diff > 0 {
+		t.Errorf("expected 0 difference in usage but was %d", diff)
+	}
+}
+
+// This tests that the policy connection pool handles SSL correctly
+func TestPolicyConnPoolSSL(t *testing.T) {
+	srv := NewSSLTestServer(t, defaultProto)
+	defer srv.Stop()
+
+	cluster := createTestSslCluster(srv.Address, defaultProto)
+	cluster.ConnPoolType = NewRoundRobinConnPool
+
+	db, err := cluster.CreateSession()
+	if err != nil {
+		t.Fatalf("failed to create new session: %v", err)
+	}
+
+	if err := db.Query("void").Exec(); err != nil {
+		t.Errorf("query failed due to error: %v", err)
+	}
+
+	db.Close()
+
+	// wait for the pool to drain
+	time.Sleep(100 * time.Millisecond)
+	size := db.Pool.Size()
+	if size != 0 {
+		t.Errorf("connection pool did not drain, still contains %d connections", size)
+	}
+}
+
 func NewTestServer(t testing.TB, protocol uint8) *TestServer {
 	laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
 	if err != nil {

+ 447 - 3
connectionpool.go

@@ -1,3 +1,7 @@
+// Copyright (c) 2012 The gocql Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
 package gocql
 
 import (
@@ -7,6 +11,8 @@ import (
 	"fmt"
 	"io/ioutil"
 	"log"
+	"math/rand"
+	"net"
 	"sync"
 	"time"
 )
@@ -88,11 +94,20 @@ this type as the connection pool to use you would assign it to the ClusterConfig
 To see a more complete example of a ConnectionPool implementation please see the SimplePool type.
 */
 type ConnectionPool interface {
+	SetHosts
 	Pick(*Query) *Conn
 	Size() int
-	HandleError(*Conn, error, bool)
 	Close()
-	SetHosts(host []HostInfo)
+}
+
+// SetHosts is implemented by types that need to receive the current list of hosts
+type SetHosts interface {
+	SetHosts(hosts []HostInfo)
+}
+
+// SetPartitioner is implemented by types that need to receive the cluster partitioner value
+type SetPartitioner interface {
+	SetPartitioner(partitioner string)
 }
 
 //NewPoolFunc is the type used by ClusterConfig to create a pool of a specific type.
@@ -395,7 +410,6 @@ func (c *SimplePool) SetHosts(hosts []HostInfo) {
 
 	for _, host := range hosts {
 		host := host
-
 		delete(toRemove, host.Peer)
 		// we already have it
 		if _, ok := c.hosts[host.Peer]; ok {
@@ -435,3 +449,433 @@ func (c *SimplePool) removeHostLocked(addr string) {
 		}
 	}
 }
+
+//NewRoundRobinConnPool creates a connection pool which selects hosts by
+//round-robin, and then selects a connection for that host by round-robin.
+func NewRoundRobinConnPool(cfg *ClusterConfig) (ConnectionPool, error) {
+	return NewPolicyConnPool(
+		cfg,
+		NewRoundRobinHostPolicy(),
+		NewRoundRobinConnPolicy,
+	)
+}
+
+//NewTokenAwareConnPool creates a connection pool which selects hosts by
+//a token aware policy, and then selects a connection for that host by
+//round-robin.
+func NewTokenAwareConnPool(cfg *ClusterConfig) (ConnectionPool, error) {
+	return NewPolicyConnPool(
+		cfg,
+		NewTokenAwareHostPolicy(NewRoundRobinHostPolicy()),
+		NewRoundRobinConnPolicy,
+	)
+}
+
+type policyConnPool struct {
+	port     int
+	numConns int
+	connCfg  ConnConfig
+	keyspace string
+
+	mu            sync.RWMutex
+	hostPolicy    HostSelectionPolicy
+	connPolicy    func() ConnSelectionPolicy
+	hostConnPools map[string]*hostConnPool
+}
+
+//NewPolicyConnPool creates a policy-based connection pool. This func isn't
+//meant to be used directly as a NewPoolFunc in ClusterConfig; instead, create
+//a func which satisfies the NewPoolFunc type and calls this func with the
+//desired hostPolicy and connPolicy. See NewRoundRobinConnPool or
+//NewTokenAwareConnPool for examples.
+func NewPolicyConnPool(
+	cfg *ClusterConfig,
+	hostPolicy HostSelectionPolicy,
+	connPolicy func() ConnSelectionPolicy,
+) (ConnectionPool, error) {
+	var err error
+	var tlsConfig *tls.Config
+
+	if cfg.SslOpts != nil {
+		tlsConfig, err = setupTLSConfig(cfg.SslOpts)
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	// create the pool
+	pool := &policyConnPool{
+		port:     cfg.Port,
+		numConns: cfg.NumConns,
+		connCfg: ConnConfig{
+			ProtoVersion:  cfg.ProtoVersion,
+			CQLVersion:    cfg.CQLVersion,
+			Timeout:       cfg.Timeout,
+			NumStreams:    cfg.NumStreams,
+			Compressor:    cfg.Compressor,
+			Authenticator: cfg.Authenticator,
+			Keepalive:     cfg.SocketKeepalive,
+			tlsConfig:     tlsConfig,
+		},
+		keyspace:      cfg.Keyspace,
+		hostPolicy:    hostPolicy,
+		connPolicy:    connPolicy,
+		hostConnPools: map[string]*hostConnPool{},
+	}
+
+	hosts := make([]HostInfo, len(cfg.Hosts))
+	for i, hostAddr := range cfg.Hosts {
+		hosts[i].Peer = hostAddr
+	}
+
+	pool.SetHosts(hosts)
+
+	return pool, nil
+}
+
+func (p *policyConnPool) SetHosts(hosts []HostInfo) {
+	p.mu.Lock()
+
+	toRemove := make(map[string]struct{})
+	for addr := range p.hostConnPools {
+		toRemove[addr] = struct{}{}
+	}
+
+	// TODO connect to hosts in parallel, but wait for pools to be
+	// created before returning
+
+	for i := range hosts {
+		pool, exists := p.hostConnPools[hosts[i].Peer]
+		if !exists {
+			// create a connection pool for the host
+			pool = newHostConnPool(
+				hosts[i].Peer,
+				p.port,
+				p.numConns,
+				p.connCfg,
+				p.keyspace,
+				p.connPolicy(),
+			)
+			p.hostConnPools[hosts[i].Peer] = pool
+		} else {
+			// still have this host, so don't remove it
+			delete(toRemove, hosts[i].Peer)
+		}
+	}
+
+	for addr := range toRemove {
+		pool := p.hostConnPools[addr]
+		delete(p.hostConnPools, addr)
+		pool.Close()
+	}
+
+	// update the policy
+	p.hostPolicy.SetHosts(hosts)
+
+	p.mu.Unlock()
+}
+
+func (p *policyConnPool) SetPartitioner(partitioner string) {
+	p.hostPolicy.SetPartitioner(partitioner)
+}
+
+func (p *policyConnPool) Size() int {
+	p.mu.RLock()
+	count := 0
+	for _, pool := range p.hostConnPools {
+		count += pool.Size()
+	}
+	p.mu.RUnlock()
+
+	return count
+}
+
+func (p *policyConnPool) Pick(qry *Query) *Conn {
+	nextHost := p.hostPolicy.Pick(qry)
+
+	p.mu.RLock()
+	var host *HostInfo
+	var conn *Conn
+	for conn == nil {
+		host = nextHost()
+		if host == nil {
+			break
+		}
+		conn = p.hostConnPools[host.Peer].Pick(qry)
+	}
+	p.mu.RUnlock()
+	return conn
+}
+
+func (p *policyConnPool) Close() {
+	p.mu.Lock()
+
+	// remove the hosts from the policy
+	p.hostPolicy.SetHosts([]HostInfo{})
+
+	// close the pools
+	for addr, pool := range p.hostConnPools {
+		delete(p.hostConnPools, addr)
+		pool.Close()
+	}
+	p.mu.Unlock()
+}
+
+// hostConnPool is a connection pool for a single host.
+// Connection selection is based on a provided ConnSelectionPolicy
+type hostConnPool struct {
+	host     string
+	port     int
+	addr     string
+	size     int
+	connCfg  ConnConfig
+	keyspace string
+	policy   ConnSelectionPolicy
+	// protection for conns, closed, filling
+	mu      sync.RWMutex
+	conns   []*Conn
+	closed  bool
+	filling bool
+}
+
+func newHostConnPool(
+	host string,
+	port int,
+	size int,
+	connCfg ConnConfig,
+	keyspace string,
+	policy ConnSelectionPolicy,
+) *hostConnPool {
+
+	pool := &hostConnPool{
+		host:     host,
+		port:     port,
+		addr:     JoinHostPort(host, port),
+		size:     size,
+		connCfg:  connCfg,
+		keyspace: keyspace,
+		policy:   policy,
+		conns:    make([]*Conn, 0, size),
+		filling:  false,
+		closed:   false,
+	}
+
+	// fill the pool with the initial connections before returning
+	pool.fill()
+
+	return pool
+}
+
+// Pick a connection from this connection pool for the given query.
+func (pool *hostConnPool) Pick(qry *Query) *Conn {
+	pool.mu.RLock()
+	if pool.closed {
+		pool.mu.RUnlock()
+		return nil
+	}
+
+	empty := len(pool.conns) == 0
+	pool.mu.RUnlock()
+
+	if empty {
+		// try to fill the empty pool
+		pool.fill()
+	}
+
+	return pool.policy.Pick(qry)
+}
+
+//Size returns the number of connections currently active in the pool
+func (pool *hostConnPool) Size() int {
+	pool.mu.RLock()
+	defer pool.mu.RUnlock()
+
+	return len(pool.conns)
+}
+
+//Close the connection pool
+func (pool *hostConnPool) Close() {
+	pool.mu.Lock()
+	defer pool.mu.Unlock()
+
+	if pool.closed {
+		return
+	}
+	pool.closed = true
+
+	// drain, but don't wait
+	go pool.drain()
+}
+
+// Fill the connection pool
+func (pool *hostConnPool) fill() {
+	pool.mu.RLock()
+	// avoid filling a closed pool, or concurrent filling
+	if pool.closed || pool.filling {
+		pool.mu.RUnlock()
+		return
+	}
+
+	// determine the filling work to be done
+	startCount := len(pool.conns)
+	fillCount := pool.size - startCount
+
+	// avoid filling a full (or overfull) pool
+	if fillCount <= 0 {
+		pool.mu.RUnlock()
+		return
+	}
+
+	// switch from read to write lock
+	pool.mu.RUnlock()
+	pool.mu.Lock()
+
+	// double check everything since the lock was released
+	startCount = len(pool.conns)
+	fillCount = pool.size - startCount
+	if pool.closed || pool.filling || fillCount <= 0 {
+		// looks like another goroutine already beat this
+		// goroutine to the filling
+		pool.mu.Unlock()
+		return
+	}
+
+	// ok fill the pool
+	pool.filling = true
+
+	// allow others to access the pool while filling
+	pool.mu.Unlock()
+	// only this goroutine should make calls to fill/empty the pool at this
+	// point until after this routine or its subordinates calls
+	// fillingStopped
+
+	// fill only the first connection synchronously
+	if startCount == 0 {
+		err := pool.connect()
+		pool.logConnectErr(err)
+
+		if err != nil {
+			// probably unreachable host
+			go pool.fillingStopped()
+			return
+		}
+
+		// filled one
+		fillCount--
+	}
+
+	// fill the rest of the pool asynchronously
+	go func() {
+		for fillCount > 0 {
+			err := pool.connect()
+			pool.logConnectErr(err)
+
+			// decrement, even on error
+			fillCount--
+		}
+
+		// mark the end of filling
+		pool.fillingStopped()
+	}()
+}
+
+func (pool *hostConnPool) logConnectErr(err error) {
+	if opErr, ok := err.(*net.OpError); ok && (opErr.Op == "dial" || opErr.Op == "read") {
+		// connection refused
+		// these are typical during a node outage so avoid log spam.
+	} else if err != nil {
+		// unexpected error
+		log.Printf("error: failed to connect to %s due to error: %v", pool.addr, err)
+	}
+}
+
+// transition back to a not-filling state.
+func (pool *hostConnPool) fillingStopped() {
+	// wait for some time to avoid back-to-back filling
+	// this provides some time between failed attempts
+	// to fill the pool for the host to recover
+	time.Sleep(time.Duration(rand.Int31n(100)+31) * time.Millisecond)
+
+	pool.mu.Lock()
+	pool.filling = false
+	pool.mu.Unlock()
+}
+
+// create a new connection to the host and add it to the pool
+func (pool *hostConnPool) connect() error {
+	// try to connect
+	conn, err := Connect(pool.addr, pool.connCfg, pool)
+	if err != nil {
+		return err
+	}
+
+	if pool.keyspace != "" {
+		// set the keyspace
+		if err := conn.UseKeyspace(pool.keyspace); err != nil {
+			conn.Close()
+			return err
+		}
+	}
+
+	// add the Conn to the pool
+	pool.mu.Lock()
+	defer pool.mu.Unlock()
+
+	if pool.closed {
+		conn.Close()
+		return nil
+	}
+
+	pool.conns = append(pool.conns, conn)
+	pool.policy.SetConns(pool.conns)
+	return nil
+}
+
+// handle any error from a Conn
+func (pool *hostConnPool) HandleError(conn *Conn, err error, closed bool) {
+	if !closed {
+		// still an open connection, so continue using it
+		return
+	}
+
+	pool.mu.Lock()
+	defer pool.mu.Unlock()
+
+	if pool.closed {
+		// pool closed
+		return
+	}
+
+	// find the connection index
+	for i, candidate := range pool.conns {
+		if candidate == conn {
+			// remove the connection, not preserving order
+			pool.conns[i], pool.conns = pool.conns[len(pool.conns)-1], pool.conns[:len(pool.conns)-1]
+
+			// update the policy
+			pool.policy.SetConns(pool.conns)
+
+			// lost a connection, so fill the pool
+			go pool.fill()
+			break
+		}
+	}
+}
+
+// removes and closes all connections from the pool
+func (pool *hostConnPool) drain() {
+	pool.mu.Lock()
+	defer pool.mu.Unlock()
+
+	// empty the pool
+	conns := pool.conns
+	pool.conns = pool.conns[:0]
+
+	// update the policy
+	pool.policy.SetConns(pool.conns)
+
+	// close the connections
+	for _, conn := range conns {
+		conn.Close()
+	}
+}
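
NewPolicyConnPool above is the generic constructor; the intended pattern is to wrap it in a func matching NewPoolFunc, exactly as NewRoundRobinConnPool and NewTokenAwareConnPool do. A minimal sketch of such a wrapper and its use (the name newCustomConnPool and the host address are illustrative; with only the policies in this change it reproduces NewTokenAwareConnPool, but the same shape allows swapping in other HostSelectionPolicy or ConnSelectionPolicy implementations):

// newCustomConnPool satisfies NewPoolFunc and wires a token-aware host policy
// (with round-robin fallback) to the round-robin connection policy.
func newCustomConnPool(cfg *ClusterConfig) (ConnectionPool, error) {
	return NewPolicyConnPool(
		cfg,
		NewTokenAwareHostPolicy(NewRoundRobinHostPolicy()),
		NewRoundRobinConnPolicy,
	)
}

// usage (illustrative):
//	cluster := NewCluster("10.0.0.1")
//	cluster.ConnPoolType = newCustomConnPool
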

+ 30 - 19
host_source.go

@@ -16,28 +16,33 @@ type HostInfo struct {
 
 // Polls system.peers at a specific interval to find new hosts
 type ringDescriber struct {
-	dcFilter   string
-	rackFilter string
-	previous   []HostInfo
-	session    *Session
+	dcFilter        string
+	rackFilter      string
+	prevHosts       []HostInfo
+	prevPartitioner string
+	session         *Session
 }
 
-func (r *ringDescriber) GetHosts() ([]HostInfo, error) {
+func (r *ringDescriber) GetHosts() (
+	hosts []HostInfo,
+	partitioner string,
+	err error,
+) {
 	// we need conn to be the same because we need to query system.peers and system.local
 	// on the same node to get the whole cluster
 	conn := r.session.Pool.Pick(nil)
 	if conn == nil {
-		return r.previous, nil
+		return r.prevHosts, r.prevPartitioner, nil
 	}
 
-	query := r.session.Query("SELECT data_center, rack, host_id, tokens FROM system.local")
+	query := r.session.Query("SELECT data_center, rack, host_id, tokens, partitioner FROM system.local")
 	iter := conn.executeQuery(query)
 
-	host := &HostInfo{}
-	iter.Scan(&host.DataCenter, &host.Rack, &host.HostId, &host.Tokens)
+	host := HostInfo{}
+	iter.Scan(&host.DataCenter, &host.Rack, &host.HostId, &host.Tokens, &partitioner)
 
-	if err := iter.Close(); err != nil {
-		return nil, err
+	if err = iter.Close(); err != nil {
+		return nil, "", err
 	}
 
 	addr, _, err := net.SplitHostPort(conn.Address())
@@ -49,24 +54,27 @@ func (r *ringDescriber) GetHosts() ([]HostInfo, error) {
 
 	host.Peer = addr
 
-	hosts := []HostInfo{*host}
+	hosts = []HostInfo{host}
 
 	query = r.session.Query("SELECT peer, data_center, rack, host_id, tokens FROM system.peers")
 	iter = conn.executeQuery(query)
 
+	host = HostInfo{}
 	for iter.Scan(&host.Peer, &host.DataCenter, &host.Rack, &host.HostId, &host.Tokens) {
-		if r.matchFilter(host) {
-			hosts = append(hosts, *host)
+		if r.matchFilter(&host) {
+			hosts = append(hosts, host)
 		}
+		host = HostInfo{}
 	}
 
-	if err := iter.Close(); err != nil {
-		return nil, err
+	if err = iter.Close(); err != nil {
+		return nil, "", err
 	}
 
-	r.previous = hosts
+	r.prevHosts = hosts
+	r.prevPartitioner = partitioner
 
-	return hosts, nil
+	return hosts, partitioner, nil
 }
 
 func (r *ringDescriber) matchFilter(host *HostInfo) bool {
@@ -92,11 +100,14 @@ func (h *ringDescriber) run(sleep time.Duration) {
 		// attempt to reconnect to the cluster otherwise we would never find
 		// downed hosts again, could possibly have an optimisation to only
 		// try to add new hosts if GetHosts didnt error and the hosts didnt change.
-		hosts, err := h.GetHosts()
+		hosts, partitioner, err := h.GetHosts()
 		if err != nil {
 			log.Println("RingDescriber: unable to get ring topology:", err)
 		} else {
 			h.session.Pool.SetHosts(hosts)
+			if v, ok := h.session.Pool.(SetPartitioner); ok {
+				v.SetPartitioner(partitioner)
+			}
 		}
 
 		time.Sleep(sleep)

+ 20 - 8
metadata.go

@@ -87,6 +87,8 @@ type schemaDescriber struct {
 	cache map[string]*KeyspaceMetadata
 }
 
+// creates a session bound schema describer which will query and cache
+// keyspace metadata
 func newSchemaDescriber(session *Session) *schemaDescriber {
 	return &schemaDescriber{
 		session: session,
@@ -94,6 +96,8 @@ func newSchemaDescriber(session *Session) *schemaDescriber {
 	}
 }
 
+// returns the cached KeyspaceMetadata held by the describer for the named
+// keyspace.
 func (s *schemaDescriber) getSchema(keyspaceName string) (*KeyspaceMetadata, error) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
@@ -114,6 +118,8 @@ func (s *schemaDescriber) getSchema(keyspaceName string) (*KeyspaceMetadata, err
 	return metadata, nil
 }
 
+// forcibly updates the current KeyspaceMetadata held by the schema describer
+// for a given named keyspace.
 func (s *schemaDescriber) refreshSchema(keyspaceName string) error {
 	var err error
 
@@ -141,9 +147,11 @@ func (s *schemaDescriber) refreshSchema(keyspaceName string) error {
 	return nil
 }
 
-// "compiles" keyspace, table, and column metadata for a keyspace together
-// linking the metadata objects together and calculating the partition key
-// and clustering key.
+// "compiles" derived information about keyspace, table, and column metadata
+// for a keyspace from the basic queried metadata objects returned by
+// getKeyspaceMetadata, getTableMetadata, and getColumnMetadata respectively.
+// It links the metadata objects together and derives the column composition of
+// the partition key and clustering key for a table.
 func compileMetadata(
 	protoVersion int,
 	keyspace *KeyspaceMetadata,
@@ -178,8 +186,11 @@ func compileMetadata(
 	}
 }
 
-// V1 protocol does not return as much column metadata as V2+ so determining
-// PartitionKey and ClusterColumns is more complex
+// Compiles derived information from TableMetadata that have already had their
+// ColumnMetadata attached. The V1 protocol does not return as much
+// column metadata as V2+ (because V1 doesn't support the "type" column in the
+// system.schema_columns table), so determining PartitionKey and ClusteringColumns
+// is more complex.
 func compileV1Metadata(tables []TableMetadata) {
 	for i := range tables {
 		table := &tables[i]
@@ -308,6 +319,7 @@ func compileV2Metadata(tables []TableMetadata) {
 	}
 }
 
+// returns the count of columns with the given "kind" value.
 func countColumnsOfKind(columns map[string]*ColumnMetadata, kind string) int {
 	count := 0
 	for _, column := range columns {
@@ -318,7 +330,7 @@ func countColumnsOfKind(columns map[string]*ColumnMetadata, kind string) int {
 	return count
 }
 
-// query only for the keyspace metadata for the specified keyspace
+// query only for the keyspace metadata for the specified keyspace from system.schema_keyspace
 func getKeyspaceMetadata(
 	session *Session,
 	keyspaceName string,
@@ -358,7 +370,7 @@ func getKeyspaceMetadata(
 	return keyspace, nil
 }
 
-// query for only the table metadata in the specified keyspace
+// query for only the table metadata in the specified keyspace from system.schema_columnfamilies
 func getTableMetadata(
 	session *Session,
 	keyspaceName string,
@@ -437,7 +449,7 @@ func getTableMetadata(
 	return tables, nil
 }
 
-// query for only the table metadata in the specified keyspace
+// query for only the column metadata in the specified keyspace from system.schema_columns
 func getColumnMetadata(
 	session *Session,
 	keyspaceName string,

+ 159 - 34
metadata_test.go

@@ -9,6 +9,8 @@ import (
 	"testing"
 )
 
+// Tests V1 and V2 metadata "compilation" from example data which might be returned
+// from metadata schema queries (see getKeyspaceMetadata, getTableMetadata, and getColumnMetadata)
 func TestCompileMetadata(t *testing.T) {
 	// V1 tests - these are all based on real examples from the integration test ccm cluster
 	keyspace := &KeyspaceMetadata{
@@ -53,7 +55,7 @@ func TestCompileMetadata(t *testing.T) {
 			Keyspace:         "V1Keyspace",
 			Name:             "IndexInfo",
 			KeyValidator:     "org.apache.cassandra.db.marshal.UTF8Type",
-			Comparator:       "org.apache.cassandra.db.marshal.UTF8Type",
+			Comparator:       "org.apache.cassandra.db.marshal.ReversedType(org.apache.cassandra.db.marshal.UTF8Type)",
 			DefaultValidator: "org.apache.cassandra.db.marshal.BytesType",
 			KeyAliases:       []string{"table_name"},
 			ColumnAliases:    []string{"index_name"},
@@ -70,6 +72,17 @@ func TestCompileMetadata(t *testing.T) {
 			ColumnAliases:    []string{"revid"},
 			ValueAlias:       "",
 		},
+		TableMetadata{
+			// This is a made up example with multiple unnamed aliases
+			Keyspace:         "V1Keyspace",
+			Name:             "no_names",
+			KeyValidator:     "org.apache.cassandra.db.marshal.CompositeType(org.apache.cassandra.db.marshal.UUIDType,org.apache.cassandra.db.marshal.UUIDType)",
+			Comparator:       "org.apache.cassandra.db.marshal.CompositeType(org.apache.cassandra.db.marshal.Int32Type,org.apache.cassandra.db.marshal.Int32Type,org.apache.cassandra.db.marshal.Int32Type)",
+			DefaultValidator: "org.apache.cassandra.db.marshal.BytesType",
+			KeyAliases:       []string{},
+			ColumnAliases:    []string{},
+			ValueAlias:       "",
+		},
 	}
 	columns := []ColumnMetadata{
 		// Here are the regular columns from the peers table for testing regular columns
@@ -182,7 +195,7 @@ func TestCompileMetadata(t *testing.T) {
 						&ColumnMetadata{
 							Name:  "index_name",
 							Type:  NativeType{typ: TypeVarchar},
-							Order: ASC,
+							Order: DESC,
 						},
 					},
 					Columns: map[string]*ColumnMetadata{
@@ -192,9 +205,10 @@ func TestCompileMetadata(t *testing.T) {
 							Kind: PARTITION_KEY,
 						},
 						"index_name": &ColumnMetadata{
-							Name: "index_name",
-							Type: NativeType{typ: TypeVarchar},
-							Kind: CLUSTERING_KEY,
+							Name:  "index_name",
+							Type:  NativeType{typ: TypeVarchar},
+							Order: DESC,
+							Kind:  CLUSTERING_KEY,
 						},
 						"value": &ColumnMetadata{
 							Name: "value",
@@ -230,6 +244,70 @@ func TestCompileMetadata(t *testing.T) {
 						},
 					},
 				},
+				"no_names": &TableMetadata{
+					PartitionKey: []*ColumnMetadata{
+						&ColumnMetadata{
+							Name: "key",
+							Type: NativeType{typ: TypeUUID},
+						},
+						&ColumnMetadata{
+							Name: "key2",
+							Type: NativeType{typ: TypeUUID},
+						},
+					},
+					ClusteringColumns: []*ColumnMetadata{
+						&ColumnMetadata{
+							Name:  "column",
+							Type:  NativeType{typ: TypeInt},
+							Order: ASC,
+						},
+						&ColumnMetadata{
+							Name:  "column2",
+							Type:  NativeType{typ: TypeInt},
+							Order: ASC,
+						},
+						&ColumnMetadata{
+							Name:  "column3",
+							Type:  NativeType{typ: TypeInt},
+							Order: ASC,
+						},
+					},
+					Columns: map[string]*ColumnMetadata{
+						"key": &ColumnMetadata{
+							Name: "key",
+							Type: NativeType{typ: TypeUUID},
+							Kind: PARTITION_KEY,
+						},
+						"key2": &ColumnMetadata{
+							Name: "key2",
+							Type: NativeType{typ: TypeUUID},
+							Kind: PARTITION_KEY,
+						},
+						"column": &ColumnMetadata{
+							Name:  "column",
+							Type:  NativeType{typ: TypeInt},
+							Order: ASC,
+							Kind:  CLUSTERING_KEY,
+						},
+						"column2": &ColumnMetadata{
+							Name:  "column2",
+							Type:  NativeType{typ: TypeInt},
+							Order: ASC,
+							Kind:  CLUSTERING_KEY,
+						},
+						"column3": &ColumnMetadata{
+							Name:  "column3",
+							Type:  NativeType{typ: TypeInt},
+							Order: ASC,
+							Kind:  CLUSTERING_KEY,
+						},
+						"value": &ColumnMetadata{
+							Name: "value",
+							Type: NativeType{typ: TypeBlob},
+							Kind: REGULAR,
+						},
+					},
+				},
 			},
 		},
 	)
@@ -250,30 +328,41 @@ func TestCompileMetadata(t *testing.T) {
 	}
 	columns = []ColumnMetadata{
 		ColumnMetadata{
-			Keyspace:  "V2Keyspace",
-			Table:     "Table1",
-			Name:      "Key1",
-			Kind:      PARTITION_KEY,
-			Validator: "org.apache.cassandra.db.marshal.UTF8Type",
+			Keyspace:       "V2Keyspace",
+			Table:          "Table1",
+			Name:           "Key1",
+			Kind:           PARTITION_KEY,
+			ComponentIndex: 0,
+			Validator:      "org.apache.cassandra.db.marshal.UTF8Type",
 		},
 		ColumnMetadata{
-			Keyspace:  "V2Keyspace",
-			Table:     "Table2",
-			Name:      "Column1",
-			Kind:      PARTITION_KEY,
-			Validator: "org.apache.cassandra.db.marshal.UTF8Type",
+			Keyspace:       "V2Keyspace",
+			Table:          "Table2",
+			Name:           "Column1",
+			Kind:           PARTITION_KEY,
+			ComponentIndex: 0,
+			Validator:      "org.apache.cassandra.db.marshal.UTF8Type",
 		},
 		ColumnMetadata{
-			Keyspace:  "V2Keyspace",
-			Table:     "Table2",
-			Name:      "Column2",
-			Kind:      CLUSTERING_KEY,
-			Validator: "org.apache.cassandra.db.marshal.UTF8Type",
+			Keyspace:       "V2Keyspace",
+			Table:          "Table2",
+			Name:           "Column2",
+			Kind:           CLUSTERING_KEY,
+			ComponentIndex: 0,
+			Validator:      "org.apache.cassandra.db.marshal.UTF8Type",
+		},
+		ColumnMetadata{
+			Keyspace:       "V2Keyspace",
+			Table:          "Table2",
+			Name:           "Column3",
+			Kind:           CLUSTERING_KEY,
+			ComponentIndex: 1,
+			Validator:      "org.apache.cassandra.db.marshal.ReversedType(org.apache.cassandra.db.marshal.UTF8Type)",
 		},
 		ColumnMetadata{
 			Keyspace:  "V2Keyspace",
 			Table:     "Table2",
-			Name:      "Column3",
+			Name:      "Column4",
 			Kind:      REGULAR,
 			Validator: "org.apache.cassandra.db.marshal.UTF8Type",
 		},
@@ -310,8 +399,14 @@ func TestCompileMetadata(t *testing.T) {
 					},
 					ClusteringColumns: []*ColumnMetadata{
 						&ColumnMetadata{
-							Name: "Column2",
-							Type: NativeType{typ: TypeVarchar},
+							Name:  "Column2",
+							Type:  NativeType{typ: TypeVarchar},
+							Order: ASC,
+						},
+						&ColumnMetadata{
+							Name:  "Column3",
+							Type:  NativeType{typ: TypeVarchar},
+							Order: DESC,
 						},
 					},
 					Columns: map[string]*ColumnMetadata{
@@ -321,12 +416,19 @@ func TestCompileMetadata(t *testing.T) {
 							Kind: PARTITION_KEY,
 						},
 						"Column2": &ColumnMetadata{
-							Name: "Column2",
-							Type: NativeType{typ: TypeVarchar},
-							Kind: CLUSTERING_KEY,
+							Name:  "Column2",
+							Type:  NativeType{typ: TypeVarchar},
+							Order: ASC,
+							Kind:  CLUSTERING_KEY,
 						},
 						"Column3": &ColumnMetadata{
-							Name: "Column3",
+							Name:  "Column3",
+							Type:  NativeType{typ: TypeVarchar},
+							Order: DESC,
+							Kind:  CLUSTERING_KEY,
+						},
+						"Column4": &ColumnMetadata{
+							Name: "Column4",
 							Type: NativeType{typ: TypeVarchar},
 							Kind: REGULAR,
 						},
@@ -337,6 +439,7 @@ func TestCompileMetadata(t *testing.T) {
 	)
 }
 
+// Helper function for asserting that actual metadata returned was as expected
 func assertKeyspaceMetadata(t *testing.T, actual, expected *KeyspaceMetadata) {
 	if len(expected.Tables) != len(actual.Tables) {
 		t.Errorf("Expected len(%s.Tables) to be %v but was %v", expected.Name, len(expected.Tables), len(actual.Tables))
@@ -379,6 +482,9 @@ func assertKeyspaceMetadata(t *testing.T, actual, expected *KeyspaceMetadata) {
 				t.Errorf("Expected len(%s.Tables[%s].ClusteringColumns) to be %v but was %v", expected.Name, keyT, len(et.ClusteringColumns), len(at.ClusteringColumns))
 			} else {
 				for i := range et.ClusteringColumns {
+					if at.ClusteringColumns[i] == nil {
+						t.Fatalf("Unexpected nil value: %s.Tables[%s].ClusteringColumns[%d]", expected.Name, keyT, i)
+					}
 					if et.ClusteringColumns[i].Name != at.ClusteringColumns[i].Name {
 						t.Errorf("Expected %s.Tables[%s].ClusteringColumns[%d].Name to be '%v' but was '%v'", expected.Name, keyT, i, et.ClusteringColumns[i].Name, at.ClusteringColumns[i].Name)
 					}
@@ -445,6 +551,7 @@ func assertKeyspaceMetadata(t *testing.T, actual, expected *KeyspaceMetadata) {
 	}
 }
 
+// Tests the cassandra type definition parser
 func TestTypeParser(t *testing.T) {
 	// native type
 	assertParseNonCompositeType(
@@ -470,10 +577,20 @@ func TestTypeParser(t *testing.T) {
 		},
 	)
 
+	// list
+	assertParseNonCompositeType(
+		t,
+		"org.apache.cassandra.db.marshal.ListType(org.apache.cassandra.db.marshal.TimeUUIDType)",
+		assertTypeInfo{
+			Type: TypeList,
+			Elem: &assertTypeInfo{Type: TypeTimeUUID},
+		},
+	)
+
 	// map
 	assertParseNonCompositeType(
 		t,
-		"org.apache.cassandra.db.marshal.MapType(org.apache.cassandra.db.marshal.UUIDType,org.apache.cassandra.db.marshal.BytesType)",
+		" org.apache.cassandra.db.marshal.MapType( org.apache.cassandra.db.marshal.UUIDType , org.apache.cassandra.db.marshal.BytesType ) ",
 		assertTypeInfo{
 			Type: TypeMap,
 			Key:  &assertTypeInfo{Type: TypeUUID},
@@ -482,6 +599,11 @@ func TestTypeParser(t *testing.T) {
 	)
 
 	// custom
+	assertParseNonCompositeType(
+		t,
+		"org.apache.cassandra.db.marshal.UserType(sandbox,61646472657373,737472656574:org.apache.cassandra.db.marshal.UTF8Type,63697479:org.apache.cassandra.db.marshal.UTF8Type,7a6970:org.apache.cassandra.db.marshal.Int32Type)",
+		assertTypeInfo{Type: TypeCustom, Custom: "org.apache.cassandra.db.marshal.UserType(sandbox,61646472657373,737472656574:org.apache.cassandra.db.marshal.UTF8Type,63697479:org.apache.cassandra.db.marshal.UTF8Type,7a6970:org.apache.cassandra.db.marshal.Int32Type)"},
+	)
 	assertParseNonCompositeType(
 		t,
 		"org.apache.cassandra.db.marshal.DynamicCompositeType(u=>org.apache.cassandra.db.marshal.UUIDType,d=>org.apache.cassandra.db.marshal.DateType,t=>org.apache.cassandra.db.marshal.TimeUUIDType,b=>org.apache.cassandra.db.marshal.BytesType,s=>org.apache.cassandra.db.marshal.UTF8Type,B=>org.apache.cassandra.db.marshal.BooleanType,a=>org.apache.cassandra.db.marshal.AsciiType,l=>org.apache.cassandra.db.marshal.LongType,i=>org.apache.cassandra.db.marshal.IntegerType,x=>org.apache.cassandra.db.marshal.LexicalUUIDType)",
@@ -499,9 +621,9 @@ func TestTypeParser(t *testing.T) {
 	)
 	assertParseCompositeType(
 		t,
-		"org.apache.cassandra.db.marshal.CompositeType(org.apache.cassandra.db.marshal.DateType,org.apache.cassandra.db.marshal.UTF8Type)",
+		"org.apache.cassandra.db.marshal.CompositeType(org.apache.cassandra.db.marshal.ReversedType(org.apache.cassandra.db.marshal.DateType),org.apache.cassandra.db.marshal.UTF8Type)",
 		[]assertTypeInfo{
-			assertTypeInfo{Type: TypeTimestamp},
+			assertTypeInfo{Type: TypeTimestamp, Reversed: true},
 			assertTypeInfo{Type: TypeVarchar},
 		},
 		nil,
@@ -522,10 +644,7 @@ func TestTypeParser(t *testing.T) {
 	)
 }
 
-//---------------------------------------
-// some code to assert the parser result
-//---------------------------------------
-
+// expected data holder
 type assertTypeInfo struct {
 	Type     Type
 	Reversed bool
@@ -534,6 +653,8 @@ type assertTypeInfo struct {
 	Custom   string
 }
 
+// Helper function for asserting that the type parser returns the expected
+// results for the given definition
 func assertParseNonCompositeType(
 	t *testing.T,
 	def string,
@@ -561,6 +682,8 @@ func assertParseNonCompositeType(
 	}
 }
 
+// Helper function for asserting that the type parser returns the expected
+// results for the given definition
 func assertParseCompositeType(
 	t *testing.T,
 	def string,
@@ -612,6 +735,8 @@ func assertParseCompositeType(
 	}
 }
 
+// Helper function for asserting that the type parser returns the expected
+// results for the given definition
 func assertParseNonCompositeTypes(
 	t *testing.T,
 	context string,

+ 200 - 0
policies.go

@@ -4,6 +4,12 @@
 //This file will be the future home for more policies
 package gocql
 
+import (
+	"log"
+	"sync"
+	"sync/atomic"
+)
+
 //RetryableQuery is an interface that represents a query or batch statement that
 //exposes the correct functions for the retry policy logic to evaluate correctly.
 type RetryableQuery interface {
@@ -42,3 +48,197 @@ type SimpleRetryPolicy struct {
 func (s *SimpleRetryPolicy) Attempt(q RetryableQuery) bool {
 	return q.Attempts() <= s.NumRetries
 }
+
+//HostSelectionPolicy is an interface for selecting
+//the most appropriate host to execute a given query.
+type HostSelectionPolicy interface {
+	SetHosts
+	SetPartitioner
+	//Pick returns an iteration function over selected hosts
+	Pick(*Query) NextHost
+}
+
+//NextHost is an iteration function over picked hosts
+type NextHost func() *HostInfo
+
+//NewRoundRobinHostPolicy is a round-robin load balancing policy
+func NewRoundRobinHostPolicy() HostSelectionPolicy {
+	return &roundRobinHostPolicy{hosts: []HostInfo{}}
+}
+
+type roundRobinHostPolicy struct {
+	hosts []HostInfo
+	pos   uint32
+	mu    sync.RWMutex
+}
+
+func (r *roundRobinHostPolicy) SetHosts(hosts []HostInfo) {
+	r.mu.Lock()
+	r.hosts = hosts
+	r.mu.Unlock()
+}
+
+func (r *roundRobinHostPolicy) SetPartitioner(partitioner string) {
+	// noop
+}
+
+func (r *roundRobinHostPolicy) Pick(qry *Query) NextHost {
+	// i is used to limit the number of attempts to find a host
+	// to the number of hosts known to this policy
+	var i uint32 = 0
+	return func() *HostInfo {
+		if len(r.hosts) == 0 {
+			return nil
+		}
+
+		var host *HostInfo
+		r.mu.RLock()
+		// always increment pos to evenly distribute traffic in case of
+		// failures
+		pos := atomic.AddUint32(&r.pos, 1)
+		if int(i) < len(r.hosts) {
+			host = &r.hosts[(pos)%uint32(len(r.hosts))]
+			i++
+		}
+		r.mu.RUnlock()
+		return host
+	}
+}
+
+//NewTokenAwareHostPolicy is a token aware host selection policy
+func NewTokenAwareHostPolicy(fallback HostSelectionPolicy) HostSelectionPolicy {
+	return &tokenAwareHostPolicy{fallback: fallback, hosts: []HostInfo{}}
+}
+
+type tokenAwareHostPolicy struct {
+	mu          sync.RWMutex
+	hosts       []HostInfo
+	partitioner string
+	tokenRing   *tokenRing
+	fallback    HostSelectionPolicy
+}
+
+func (t *tokenAwareHostPolicy) SetHosts(hosts []HostInfo) {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+
+	// always update the fallback
+	t.fallback.SetHosts(hosts)
+	t.hosts = hosts
+
+	t.resetTokenRing()
+}
+
+func (t *tokenAwareHostPolicy) SetPartitioner(partitioner string) {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+
+	if t.partitioner != partitioner {
+		t.fallback.SetPartitioner(partitioner)
+		t.partitioner = partitioner
+
+		t.resetTokenRing()
+	}
+}
+
+func (t *tokenAwareHostPolicy) resetTokenRing() {
+	if t.partitioner == "" {
+		// partitioner not yet set
+		return
+	}
+
+	// create a new token ring
+	tokenRing, err := newTokenRing(t.partitioner, t.hosts)
+	if err != nil {
+		log.Printf("Unable to update the token ring due to error: %s", err)
+		return
+	}
+
+	// replace the token ring
+	t.tokenRing = tokenRing
+}
+
+func (t *tokenAwareHostPolicy) Pick(qry *Query) NextHost {
+	if qry == nil {
+		return t.fallback.Pick(qry)
+	}
+
+	routingKey, err := qry.GetRoutingKey()
+	if err != nil {
+		return t.fallback.Pick(qry)
+	}
+	if routingKey == nil {
+		return t.fallback.Pick(qry)
+	}
+
+	var host *HostInfo
+
+	t.mu.RLock()
+	// TODO retrieve a list of hosts based on the replication strategy
+	host = t.tokenRing.GetHostForPartitionKey(routingKey)
+	t.mu.RUnlock()
+
+	if host == nil {
+		return t.fallback.Pick(qry)
+	}
+
+	// scope these variables for the same lifetime as the iterator function
+	var (
+		hostReturned bool
+		fallbackIter NextHost
+	)
+	return func() *HostInfo {
+		if !hostReturned {
+			hostReturned = true
+			return host
+		}
+
+		// fallback
+		if fallbackIter == nil {
+			fallbackIter = t.fallback.Pick(qry)
+		}
+
+		fallbackHost := fallbackIter()
+
+		// filter the token aware selected hosts from the fallback hosts
+		if fallbackHost == host {
+			fallbackHost = fallbackIter()
+		}
+
+		return fallbackHost
+	}
+}
+
+//ConnSelectionPolicy is an interface for selecting an
+//appropriate connection for executing a query
+type ConnSelectionPolicy interface {
+	SetConns(conns []*Conn)
+	Pick(*Query) *Conn
+}
+
+type roundRobinConnPolicy struct {
+	conns []*Conn
+	pos   uint32
+	mu    sync.RWMutex
+}
+
+func NewRoundRobinConnPolicy() ConnSelectionPolicy {
+	return &roundRobinConnPolicy{}
+}
+
+func (r *roundRobinConnPolicy) SetConns(conns []*Conn) {
+	r.mu.Lock()
+	r.conns = conns
+	r.mu.Unlock()
+}
+
+func (r *roundRobinConnPolicy) Pick(qry *Query) *Conn {
+	pos := atomic.AddUint32(&r.pos, 1)
+	var conn *Conn
+	r.mu.RLock()
+	if len(r.conns) > 0 {
+		conn = r.conns[pos%uint32(len(r.conns))]
+	}
+	r.mu.RUnlock()
+	return conn
+}
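
A HostSelectionPolicy hands back a NextHost iterator rather than a single host, so a caller can keep asking for alternatives until one yields a usable connection; this is the pattern policyConnPool.Pick in connectionpool.go follows. A minimal standalone sketch of that consumption loop (pickConn and selectConn are illustrative names, not part of this change):

// pickConn walks the hosts produced by the policy until selectConn returns a
// usable connection, or the iterator is exhausted.
func pickConn(policy HostSelectionPolicy, qry *Query, selectConn func(*HostInfo) *Conn) *Conn {
	next := policy.Pick(qry)
	for host := next(); host != nil; host = next() {
		if conn := selectConn(host); conn != nil {
			return conn
		}
	}
	// no host produced a usable connection
	return nil
}
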

+ 125 - 0
policies_test.go

@@ -0,0 +1,125 @@
+// Copyright (c) 2015 The gocql Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gocql
+
+import "testing"
+
+// Tests of the round-robin host selection policy implementation
+func TestRoundRobinHostPolicy(t *testing.T) {
+	policy := NewRoundRobinHostPolicy()
+
+	hosts := []HostInfo{
+		HostInfo{HostId: "0"},
+		HostInfo{HostId: "1"},
+	}
+
+	policy.SetHosts(hosts)
+
+	// the first host selected is actually at [1], but this is ok for RR
+	// interleaved iteration should always increment the host
+	iterA := policy.Pick(nil)
+	if actual := iterA(); actual != &hosts[1] {
+		t.Errorf("Expected hosts[1] but was hosts[%s]", actual.HostId)
+	}
+	iterB := policy.Pick(nil)
+	if actual := iterB(); actual != &hosts[0] {
+		t.Errorf("Expected hosts[0] but was hosts[%s]", actual.HostId)
+	}
+	if actual := iterB(); actual != &hosts[1] {
+		t.Errorf("Expected hosts[1] but was hosts[%s]", actual.HostId)
+	}
+	if actual := iterA(); actual != &hosts[0] {
+		t.Errorf("Expected hosts[0] but was hosts[%s]", actual.HostId)
+	}
+
+	iterC := policy.Pick(nil)
+	if actual := iterC(); actual != &hosts[1] {
+		t.Errorf("Expected hosts[1] but was hosts[%s]", actual.HostId)
+	}
+	if actual := iterC(); actual != &hosts[0] {
+		t.Errorf("Expected hosts[0] but was hosts[%s]", actual.HostId)
+	}
+}
+
+// Tests of the token-aware host selection policy implementation with a
+// round-robin host selection policy fallback.
+func TestTokenAwareHostPolicy(t *testing.T) {
+	policy := NewTokenAwareHostPolicy(NewRoundRobinHostPolicy())
+
+	query := &Query{}
+
+	iter := policy.Pick(nil)
+	if iter == nil {
+		t.Fatal("host iterator was nil")
+	}
+	actual := iter()
+	if actual != nil {
+		t.Fatalf("expected nil from iterator, but was %v", actual)
+	}
+
+	// set the hosts
+	hosts := []HostInfo{
+		HostInfo{Peer: "0", Tokens: []string{"00"}},
+		HostInfo{Peer: "1", Tokens: []string{"25"}},
+		HostInfo{Peer: "2", Tokens: []string{"50"}},
+		HostInfo{Peer: "3", Tokens: []string{"75"}},
+	}
+	policy.SetHosts(hosts)
+
+	// the token ring is not setup without the partitioner, but the fallback
+	// should work
+	if actual := policy.Pick(nil)(); actual.Peer != "1" {
+		t.Errorf("Expected peer 1 but was %s", actual.Peer)
+	}
+
+	query.RoutingKey([]byte("30"))
+	if actual := policy.Pick(query)(); actual.Peer != "2" {
+		t.Errorf("Expected peer 2 but was %s", actual.Peer)
+	}
+
+	policy.SetPartitioner("OrderedPartitioner")
+
+	// now the token ring is configured
+	query.RoutingKey([]byte("20"))
+	iter = policy.Pick(query)
+	if actual := iter(); actual.Peer != "1" {
+		t.Errorf("Expected peer 1 but was %s", actual.Peer)
+	}
+	// rest are round robin
+	if actual := iter(); actual.Peer != "3" {
+		t.Errorf("Expected peer 3 but was %s", actual.Peer)
+	}
+	if actual := iter(); actual.Peer != "0" {
+		t.Errorf("Expected peer 0 but was %s", actual.Peer)
+	}
+	if actual := iter(); actual.Peer != "2" {
+		t.Errorf("Expected peer 2 but was %s", actual.Peer)
+	}
+}
+
+// Tests of the round-robin connection selection policy implementation
+func TestRoundRobinConnPolicy(t *testing.T) {
+	policy := NewRoundRobinConnPolicy()
+
+	conn0 := &Conn{}
+	conn1 := &Conn{}
+	conn := []*Conn{
+		conn0,
+		conn1,
+	}
+
+	policy.SetConns(conn)
+
+	// the first conn selected is actually at [1], but this is ok for RR
+	if actual := policy.Pick(nil); actual != conn1 {
+		t.Error("Expected conn1")
+	}
+	if actual := policy.Pick(nil); actual != conn0 {
+		t.Error("Expected conn0")
+	}
+	if actual := policy.Pick(nil); actual != conn1 {
+		t.Error("Expected conn1")
+	}
+}

+ 41 - 14
token.go

@@ -17,6 +17,7 @@ import (
 
 // a token partitioner
 type partitioner interface {
+	Name() string
 	Hash([]byte) token
 	ParseString(string) token
 }
@@ -31,6 +32,10 @@ type token interface {
 type murmur3Partitioner struct{}
 type murmur3Token int64
 
+func (p murmur3Partitioner) Name() string {
+	return "Murmur3Partitioner"
+}
+
 func (p murmur3Partitioner) Hash(partitionKey []byte) token {
 	h1 := murmur3H1(partitionKey)
 	return murmur3Token(int64(h1))
@@ -183,30 +188,38 @@ func (m murmur3Token) Less(token token) bool {
 }
 
 // order preserving partitioner and token
-type orderPreservingPartitioner struct{}
-type orderPreservingToken []byte
+type orderedPartitioner struct{}
+type orderedToken []byte
+
+func (p orderedPartitioner) Name() string {
+	return "OrderedPartitioner"
+}
 
-func (p orderPreservingPartitioner) Hash(partitionKey []byte) token {
+func (p orderedPartitioner) Hash(partitionKey []byte) token {
 	// the partition key is the token
-	return orderPreservingToken(partitionKey)
+	return orderedToken(partitionKey)
 }
 
-func (p orderPreservingPartitioner) ParseString(str string) token {
-	return orderPreservingToken([]byte(str))
+func (p orderedPartitioner) ParseString(str string) token {
+	return orderedToken([]byte(str))
 }
 
-func (o orderPreservingToken) String() string {
+func (o orderedToken) String() string {
 	return string([]byte(o))
 }
 
-func (o orderPreservingToken) Less(token token) bool {
-	return -1 == bytes.Compare(o, token.(orderPreservingToken))
+func (o orderedToken) Less(token token) bool {
+	return -1 == bytes.Compare(o, token.(orderedToken))
 }
 
 // random partitioner and token
 type randomPartitioner struct{}
 type randomToken big.Int
 
+func (r randomPartitioner) Name() string {
+	return "RandomPartitioner"
+}
+
 func (p randomPartitioner) Hash(partitionKey []byte) token {
 	hash := md5.New()
 	sum := hash.Sum(partitionKey)
@@ -239,7 +252,7 @@ type tokenRing struct {
 	hosts       []*HostInfo
 }
 
-func newTokenRing(partitioner string, hosts []*HostInfo) (*tokenRing, error) {
+func newTokenRing(partitioner string, hosts []HostInfo) (*tokenRing, error) {
 	tokenRing := &tokenRing{
 		tokens: []token{},
 		hosts:  []*HostInfo{},
@@ -248,14 +261,15 @@ func newTokenRing(partitioner string, hosts []*HostInfo) (*tokenRing, error) {
 	if strings.HasSuffix(partitioner, "Murmur3Partitioner") {
 		tokenRing.partitioner = murmur3Partitioner{}
 	} else if strings.HasSuffix(partitioner, "OrderedPartitioner") {
-		tokenRing.partitioner = orderPreservingPartitioner{}
+		tokenRing.partitioner = orderedPartitioner{}
 	} else if strings.HasSuffix(partitioner, "RandomPartitioner") {
 		tokenRing.partitioner = randomPartitioner{}
 	} else {
 		return nil, fmt.Errorf("Unsupported partitioner '%s'", partitioner)
 	}
 
-	for _, host := range hosts {
+	for i := range hosts {
+		host := &hosts[i]
 		for _, strToken := range host.Tokens {
 			token := tokenRing.partitioner.ParseString(strToken)
 			tokenRing.tokens = append(tokenRing.tokens, token)
@@ -282,8 +296,13 @@ func (t *tokenRing) Swap(i, j int) {
 }
 
 func (t *tokenRing) String() string {
+
 	buf := &bytes.Buffer{}
-	buf.WriteString("TokenRing={")
+	buf.WriteString("TokenRing(")
+	if t.partitioner != nil {
+		buf.WriteString(t.partitioner.Name())
+	}
+	buf.WriteString("){")
 	sep := ""
 	for i := range t.tokens {
 		buf.WriteString(sep)
@@ -300,12 +319,20 @@ func (t *tokenRing) String() string {
 }
 
 func (t *tokenRing) GetHostForPartitionKey(partitionKey []byte) *HostInfo {
+	if t == nil {
+		return nil
+	}
+
 	token := t.partitioner.Hash(partitionKey)
 	return t.GetHostForToken(token)
 }
 
 func (t *tokenRing) GetHostForToken(token token) *HostInfo {
-	// find the primary repica
+	if t == nil {
+		return nil
+	}
+
+	// find the primary replica
 	ringIndex := sort.Search(
 		len(t.tokens),
 		func(i int) bool {

+ 315 - 30
token_test.go

@@ -5,23 +5,56 @@
 package gocql
 
 import (
+	"bytes"
 	"math/big"
+	"sort"
 	"strconv"
 	"testing"
 )
 
+// Test the implementation of murmur3
 func TestMurmur3H1(t *testing.T) {
-	assertMurmur3H1(t, []byte{}, 0x000000000000000)
-	assertMurmur3H1(t, []byte{0}, 0x4610abe56eff5cb5)
-	assertMurmur3H1(t, []byte{0, 1}, 0x7cb3f5c58dab264c)
-	assertMurmur3H1(t, []byte{0, 1, 2}, 0xb872a12fef53e6be)
-	assertMurmur3H1(t, []byte{0, 1, 2, 3}, 0xe1c594ae0ddfaf10)
+	// these examples are based on adding an index number to a sample string in
+	// a loop. The expected values were generated by the java datastax murmur3
+	// implementation. The examples here, of increasing lengths, ensure test
+	// coverage of all tail-length branches in the murmur3 algorithm
+	seriesExpected := [...]uint64{
+		0x0000000000000000, // ""
+		0x2ac9debed546a380, // "0"
+		0x649e4eaa7fc1708e, // "01"
+		0xce68f60d7c353bdb, // "012"
+		0x0f95757ce7f38254, // "0123"
+		0x0f04e459497f3fc1, // "01234"
+		0x88c0a92586be0a27, // "012345"
+		0x13eb9fb82606f7a6, // "0123456"
+		0x8236039b7387354d, // "01234567"
+		0x4c1e87519fe738ba, // "012345678"
+		0x3f9652ac3effeb24, // "0123456789"
+		0x3f33760ded9006c6, // "01234567890"
+		0xaed70a6631854cb1, // "012345678901"
+		0x8a299a8f8e0e2da7, // "0123456789012"
+		0x624b675c779249a6, // "01234567890123"
+		0xa4b203bb1d90b9a3, // "012345678901234"
+		0xa3293ad698ecb99a, // "0123456789012345"
+		0xbc740023dbd50048, // "01234567890123456"
+		0x3fe5ab9837d25cdd, // "012345678901234567"
+		0x2d0338c1ca87d132, // "0123456789012345678"
+	}
+	sample := ""
+	for i, expected := range seriesExpected {
+		assertMurmur3H1(t, []byte(sample), expected)
+
+		sample = sample + strconv.Itoa(i%10)
+	}
+
+	// Here are some test examples from other driver implementations
 	assertMurmur3H1(t, []byte("hello"), 0xcbd8a7b341bd9b02)
 	assertMurmur3H1(t, []byte("hello, world"), 0x342fac623a5ebc8e)
 	assertMurmur3H1(t, []byte("19 Jan 2038 at 3:14:07 AM"), 0xb89e5988b737affc)
 	assertMurmur3H1(t, []byte("The quick brown fox jumps over the lazy dog."), 0xcd99481f9ee902c9)
 }
 
+// helper function for testing the murmur3 implementation
 func assertMurmur3H1(t *testing.T, data []byte, expected uint64) {
 	actual := murmur3H1(data)
 	if actual != expected {
@@ -29,6 +62,7 @@ func assertMurmur3H1(t *testing.T, data []byte, expected uint64) {
 	}
 }
 
+// Benchmark of the performance of the murmur3 implementation
 func BenchmarkMurmur3H1(b *testing.B) {
 	var h1 uint64
 	var data [1024]byte
@@ -42,6 +76,7 @@ func BenchmarkMurmur3H1(b *testing.B) {
 	}
 }
 
+// Tests of the murmur3Partitioner
 func TestMurmur3Partitioner(t *testing.T) {
 	token := murmur3Partitioner{}.ParseString("-1053604476080545076")
 
@@ -58,6 +93,7 @@ func TestMurmur3Partitioner(t *testing.T) {
 	}
 }
 
+// Tests of the murmur3Token
 func TestMurmur3Token(t *testing.T) {
 	if murmur3Token(42).Less(murmur3Token(42)) {
 		t.Errorf("Expected Less to return false, but was true")
@@ -70,38 +106,66 @@ func TestMurmur3Token(t *testing.T) {
 	}
 }
 
-func TestOrderPreservingPartitioner(t *testing.T) {
+// Tests of the orderedPartitioner
+func TestOrderedPartitioner(t *testing.T) {
 	// at least verify that the partitioner
 	// doesn't return nil
+	p := orderedPartitioner{}
 	pk, _ := marshalInt(nil, 1)
-	token := orderPreservingPartitioner{}.Hash(pk)
+	token := p.Hash(pk)
 	if token == nil {
 		t.Fatal("token was nil")
 	}
+
+	str := token.String()
+	parsedToken := p.ParseString(str)
+
+	if !bytes.Equal([]byte(token.(orderedToken)), []byte(parsedToken.(orderedToken))) {
+		t.Errorf("Failed to convert to and from a string %s expected %x but was %x",
+			str,
+			[]byte(token.(orderedToken)),
+			[]byte(parsedToken.(orderedToken)),
+		)
+	}
 }
 
-func TestOrderPreservingToken(t *testing.T) {
-	if orderPreservingToken([]byte{0, 0, 4, 2}).Less(orderPreservingToken([]byte{0, 0, 4, 2})) {
+// Tests of the orderedToken
+func TestOrderedToken(t *testing.T) {
+	if orderedToken([]byte{0, 0, 4, 2}).Less(orderedToken([]byte{0, 0, 4, 2})) {
 		t.Errorf("Expected Less to return false, but was true")
 	}
-	if !orderPreservingToken([]byte{0, 0, 3}).Less(orderPreservingToken([]byte{0, 0, 4, 2})) {
+	if !orderedToken([]byte{0, 0, 3}).Less(orderedToken([]byte{0, 0, 4, 2})) {
 		t.Errorf("Expected Less to return true, but was false")
 	}
-	if orderPreservingToken([]byte{0, 0, 4, 2}).Less(orderPreservingToken([]byte{0, 0, 3})) {
+	if orderedToken([]byte{0, 0, 4, 2}).Less(orderedToken([]byte{0, 0, 3})) {
 		t.Errorf("Expected Less to return false, but was true")
 	}
 }
 
+// Tests of the randomPartitioner
 func TestRandomPartitioner(t *testing.T) {
 	// at least verify that the partitioner
 	// doesn't return nil
+	p := randomPartitioner{}
 	pk, _ := marshalInt(nil, 1)
-	token := randomPartitioner{}.Hash(pk)
+	token := p.Hash(pk)
 	if token == nil {
 		t.Fatal("token was nil")
 	}
+
+	str := token.String()
+	parsedToken := p.ParseString(str)
+
+	if (*big.Int)(token.(*randomToken)).Cmp((*big.Int)(parsedToken.(*randomToken))) != 0 {
+		t.Errorf("Failed to convert to and from a string %s expected %v but was %v",
+			str,
+			token,
+			parsedToken,
+		)
+	}
 }
 
+// Tests of the randomToken
 func TestRandomToken(t *testing.T) {
 	if ((*randomToken)(big.NewInt(42))).Less((*randomToken)(big.NewInt(42))) {
 		t.Errorf("Expected Less to return false, but was true")
@@ -124,66 +188,287 @@ func (i intToken) Less(token token) bool {
 	return i < token.(intToken)
 }
 
+// Test of the token ring implementation, based on the example at the start of
+// this page of documentation:
+// http://www.datastax.com/docs/0.8/cluster_architecture/partitioning
 func TestIntTokenRing(t *testing.T) {
-	// test based on example at the start of this page of documentation:
-	// http://www.datastax.com/docs/0.8/cluster_architecture/partitioning
 	host0 := &HostInfo{}
 	host25 := &HostInfo{}
 	host50 := &HostInfo{}
 	host75 := &HostInfo{}
-	tokenRing := &tokenRing{
+	ring := &tokenRing{
 		partitioner: nil,
+		// these tokens and hosts are out of order to test sorting
 		tokens: []token{
 			intToken(0),
-			intToken(25),
 			intToken(50),
 			intToken(75),
+			intToken(25),
 		},
 		hosts: []*HostInfo{
 			host0,
-			host25,
 			host50,
 			host75,
+			host25,
 		},
 	}
 
-	if tokenRing.GetHostForToken(intToken(0)) != host0 {
+	sort.Sort(ring)
+
+	if ring.GetHostForToken(intToken(0)) != host0 {
 		t.Error("Expected host 0 for token 0")
 	}
-	if tokenRing.GetHostForToken(intToken(1)) != host25 {
+	if ring.GetHostForToken(intToken(1)) != host25 {
 		t.Error("Expected host 25 for token 1")
 	}
-	if tokenRing.GetHostForToken(intToken(24)) != host25 {
+	if ring.GetHostForToken(intToken(24)) != host25 {
 		t.Error("Expected host 25 for token 24")
 	}
-	if tokenRing.GetHostForToken(intToken(25)) != host25 {
+	if ring.GetHostForToken(intToken(25)) != host25 {
 		t.Error("Expected host 25 for token 25")
 	}
-	if tokenRing.GetHostForToken(intToken(26)) != host50 {
+	if ring.GetHostForToken(intToken(26)) != host50 {
 		t.Error("Expected host 50 for token 26")
 	}
-	if tokenRing.GetHostForToken(intToken(49)) != host50 {
+	if ring.GetHostForToken(intToken(49)) != host50 {
 		t.Error("Expected host 50 for token 49")
 	}
-	if tokenRing.GetHostForToken(intToken(50)) != host50 {
+	if ring.GetHostForToken(intToken(50)) != host50 {
 		t.Error("Expected host 50 for token 50")
 	}
-	if tokenRing.GetHostForToken(intToken(51)) != host75 {
+	if ring.GetHostForToken(intToken(51)) != host75 {
 		t.Error("Expected host 75 for token 51")
 	}
-	if tokenRing.GetHostForToken(intToken(74)) != host75 {
+	if ring.GetHostForToken(intToken(74)) != host75 {
 		t.Error("Expected host 75 for token 74")
 	}
-	if tokenRing.GetHostForToken(intToken(75)) != host75 {
+	if ring.GetHostForToken(intToken(75)) != host75 {
 		t.Error("Expected host 75 for token 75")
 	}
-	if tokenRing.GetHostForToken(intToken(76)) != host0 {
+	if ring.GetHostForToken(intToken(76)) != host0 {
 		t.Error("Expected host 0 for token 76")
 	}
-	if tokenRing.GetHostForToken(intToken(99)) != host0 {
+	if ring.GetHostForToken(intToken(99)) != host0 {
 		t.Error("Expected host 0 for token 99")
 	}
-	if tokenRing.GetHostForToken(intToken(100)) != host0 {
+	if ring.GetHostForToken(intToken(100)) != host0 {
 		t.Error("Expected host 0 for token 100")
 	}
 }
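+
+// The expectations above amount to: a token is owned by the host whose ring
+// token is the smallest one not less than it, wrapping around to the first
+// host when the token falls past the end of the ring. A minimal illustrative
+// sketch of that lookup using sort.Search over the already-sorted tokens (an
+// assumption about how GetHostForToken could be implemented, not a claim
+// about the library's actual code; the function name is hypothetical):
+func exampleRingLookup(ring *tokenRing, tok token) *HostInfo {
+	// find the first ring token that is >= tok
+	i := sort.Search(len(ring.tokens), func(i int) bool {
+		return !ring.tokens[i].Less(tok)
+	})
+	if i == len(ring.tokens) {
+		// wrap around: tokens past the last ring token belong to the first host
+		i = 0
+	}
+	return ring.hosts[i]
+}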
+
+// Test for the behavior of a nil pointer to tokenRing
+func TestNilTokenRing(t *testing.T) {
+	var ring *tokenRing
+
+	if ring.GetHostForToken(nil) != nil {
+		t.Error("Expected nil for nil token ring")
+	}
+	if ring.GetHostForPartitionKey(nil) != nil {
+		t.Error("Expected nil for nil token ring")
+	}
+}
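+
+// The nil-ring test above relies on the Go idiom of nil-safe pointer
+// receivers: a method on *tokenRing may be called on a nil receiver as long
+// as it checks for nil before touching any fields. A hedged sketch of the
+// pattern (the method name below is hypothetical, not part of the library):
+func (t *tokenRing) exampleNilSafeHostCount() int {
+	if t == nil {
+		return 0
+	}
+	return len(t.hosts)
+}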
+
+// Test that an unrecognized partitioner class name is rejected
+func TestUnknownTokenRing(t *testing.T) {
+	_, err := newTokenRing("UnknownPartitioner", nil)
+	if err == nil {
+		t.Error("Expected error for unknown partitioner value, but was nil")
+	}
+}
+
+// Test of the tokenRing with the Murmur3Partitioner
+func TestMurmur3TokenRing(t *testing.T) {
+	// Note: these token strings are parsed directly as int64 values; they are not murmur3 hashed
+	hosts := []HostInfo{
+		HostInfo{
+			Peer:   "0",
+			Tokens: []string{"0"},
+		},
+		HostInfo{
+			Peer:   "1",
+			Tokens: []string{"25"},
+		},
+		HostInfo{
+			Peer:   "2",
+			Tokens: []string{"50"},
+		},
+		HostInfo{
+			Peer:   "3",
+			Tokens: []string{"75"},
+		},
+	}
+	ring, err := newTokenRing("Murmur3Partitioner", hosts)
+	if err != nil {
+		t.Fatalf("Failed to create token ring due to error: %v", err)
+	}
+
+	p := murmur3Partitioner{}
+
+	actual := ring.GetHostForToken(p.ParseString("0"))
+	if actual.Peer != "0" {
+		t.Errorf("Expected peer 0 for token \"0\", but was %s", actual.Peer)
+	}
+
+	actual = ring.GetHostForToken(p.ParseString("25"))
+	if actual.Peer != "1" {
+		t.Errorf("Expected peer 1 for token \"25\", but was %s", actual.Peer)
+	}
+
+	actual = ring.GetHostForToken(p.ParseString("50"))
+	if actual.Peer != "2" {
+		t.Errorf("Expected peer 2 for token \"50\", but was %s", actual.Peer)
+	}
+
+	actual = ring.GetHostForToken(p.ParseString("75"))
+	if actual.Peer != "3" {
+		t.Errorf("Expected peer 3 for token \"01\", but was %s", actual.Peer)
+	}
+
+	actual = ring.GetHostForToken(p.ParseString("12"))
+	if actual.Peer != "1" {
+		t.Errorf("Expected peer 1 for token \"12\", but was %s", actual.Peer)
+	}
+
+	actual = ring.GetHostForToken(p.ParseString("24324545443332"))
+	if actual.Peer != "0" {
+		t.Errorf("Expected peer 0 for token \"24324545443332\", but was %s", actual.Peer)
+	}
+}
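+
+// Usage sketch tying the two lookups together: a token-aware caller would
+// hash a partition key with the ring's partitioner and resolve the owning
+// host, which is presumably what GetHostForPartitionKey does internally (an
+// assumption based on its name; the function below is illustrative only).
+func exampleTokenAwareRouting() *HostInfo {
+	hosts := []HostInfo{
+		{Peer: "0", Tokens: []string{"0"}},
+	}
+	ring, err := newTokenRing("Murmur3Partitioner", hosts)
+	if err != nil {
+		return nil
+	}
+	pk, _ := marshalInt(nil, 42)
+	// with a single host in the ring, every partition key routes to peer "0"
+	return ring.GetHostForPartitionKey(pk)
+}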
+
+// Test of the tokenRing with the OrderedPartitioner
+func TestOrderedTokenRing(t *testing.T) {
+	// The tokens here follow roughly the same layout as the int tokens above,
+	// because each numeric character maps to a byte with a consistent offset.
+	hosts := []HostInfo{
+		HostInfo{
+			Peer: "0",
+			Tokens: []string{
+				"00",
+			},
+		},
+		HostInfo{
+			Peer: "1",
+			Tokens: []string{
+				"25",
+			},
+		},
+		HostInfo{
+			Peer: "2",
+			Tokens: []string{
+				"50",
+			},
+		},
+		HostInfo{
+			Peer: "3",
+			Tokens: []string{
+				"75",
+			},
+		},
+	}
+	ring, err := newTokenRing("OrderedPartitioner", hosts)
+	if err != nil {
+		t.Fatalf("Failed to create token ring due to error: %v", err)
+	}
+
+	p := orderedPartitioner{}
+
+	actual := ring.GetHostForToken(p.ParseString("0"))
+	if actual.Peer != "0" {
+		t.Errorf("Expected peer 0 for token \"0\", but was %s", actual.Peer)
+	}
+
+	actual = ring.GetHostForToken(p.ParseString("25"))
+	if actual.Peer != "1" {
+		t.Errorf("Expected peer 1 for token \"25\", but was %s", actual.Peer)
+	}
+
+	actual = ring.GetHostForToken(p.ParseString("50"))
+	if actual.Peer != "2" {
+		t.Errorf("Expected peer 2 for token \"50\", but was %s", actual.Peer)
+	}
+
+	actual = ring.GetHostForToken(p.ParseString("75"))
+	if actual.Peer != "3" {
+		t.Errorf("Expected peer 3 for token \"01\", but was %s", actual.Peer)
+	}
+
+	actual = ring.GetHostForToken(p.ParseString("12"))
+	if actual.Peer != "1" {
+		t.Errorf("Expected peer 1 for token \"12\", but was %s", actual.Peer)
+	}
+
+	actual = ring.GetHostForToken(p.ParseString("24324545443332"))
+	if actual.Peer != "1" {
+		t.Errorf("Expected peer 1 for token \"24324545443332\", but was %s", actual.Peer)
+	}
+}
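+
+// A short note on why "24324545443332" resolves to peer 1 above: the
+// expectations here only hold if ordered tokens compare byte-wise, so that
+// "24..." sorts after "00" and before "25" regardless of numeric magnitude.
+// A tiny illustrative helper (hypothetical, not library code) that builds
+// orderedTokens from strings and compares them:
+func exampleOrderedTokenLess(a, b string) bool {
+	return orderedToken([]byte(a)).Less(orderedToken([]byte(b)))
+}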
+
+// Test of the tokenRing with the RandomPartitioner
+func TestRandomTokenRing(t *testing.T) {
+	// String tokens are parsed into big.Int in base 10
+	hosts := []HostInfo{
+		HostInfo{
+			Peer: "0",
+			Tokens: []string{
+				"00",
+			},
+		},
+		HostInfo{
+			Peer: "1",
+			Tokens: []string{
+				"25",
+			},
+		},
+		HostInfo{
+			Peer: "2",
+			Tokens: []string{
+				"50",
+			},
+		},
+		HostInfo{
+			Peer: "3",
+			Tokens: []string{
+				"75",
+			},
+		},
+	}
+	ring, err := newTokenRing("RandomPartitioner", hosts)
+	if err != nil {
+		t.Fatalf("Failed to create token ring due to error: %v", err)
+	}
+
+	p := randomPartitioner{}
+
+	actual := ring.GetHostForToken(p.ParseString("0"))
+	if actual.Peer != "0" {
+		t.Errorf("Expected peer 0 for token \"0\", but was %s", actual.Peer)
+	}
+
+	actual = ring.GetHostForToken(p.ParseString("25"))
+	if actual.Peer != "1" {
+		t.Errorf("Expected peer 1 for token \"25\", but was %s", actual.Peer)
+	}
+
+	actual = ring.GetHostForToken(p.ParseString("50"))
+	if actual.Peer != "2" {
+		t.Errorf("Expected peer 2 for token \"50\", but was %s", actual.Peer)
+	}
+
+	actual = ring.GetHostForToken(p.ParseString("75"))
+	if actual.Peer != "3" {
+		t.Errorf("Expected peer 3 for token \"01\", but was %s", actual.Peer)
+	}
+
+	actual = ring.GetHostForToken(p.ParseString("12"))
+	if actual.Peer != "1" {
+		t.Errorf("Expected peer 1 for token \"12\", but was %s", actual.Peer)
+	}
+
+	actual = ring.GetHostForToken(p.ParseString("24324545443332"))
+	if actual.Peer != "0" {
+		t.Errorf("Expected peer 0 for token \"24324545443332\", but was %s", actual.Peer)
+	}
+}
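+
+// As the comment in TestRandomTokenRing notes, random-partitioner tokens are
+// base-10 big.Int values. A minimal sketch of that parsing, assuming
+// *randomToken is simply a named *big.Int as the conversions in
+// TestRandomToken suggest (illustrative only, not the library's ParseString):
+func exampleParseRandomToken(s string) *randomToken {
+	i, ok := new(big.Int).SetString(s, 10)
+	if !ok {
+		return nil
+	}
+	return (*randomToken)(i)
+}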