
TokenAwarePolicy: use token replicas per placement strategy (#1039)

* TokenAwarePolicy: use token replicas per placement strategy

Add support for finding token replicas based on the placement strategy
for each keyspace.

Update host policies to receive keyspace change updates and to allow
the session to be set on them once they are created.

* rf is a string, parse it and provide helpful error messages

* fix panic when loading ks meta

* fix vet
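
Not part of the commit, but as a usage sketch: with this change applied, a caller could opt into token-aware routing based on the keyspace's placement strategy (and the optional replica shuffling added here) roughly as follows. The contact point and the "example" keyspace are assumptions for illustration.

package main

import (
	"log"

	"github.com/gocql/gocql"
)

func main() {
	cluster := gocql.NewCluster("127.0.0.1") // hypothetical contact point
	cluster.Keyspace = "example"             // hypothetical keyspace

	// Route queries to the replicas computed from the keyspace's
	// placement strategy, falling back to round-robin when routing
	// information is not available. ShuffleReplicas is the option
	// introduced by this change.
	cluster.PoolConfig.HostSelectionPolicy = gocql.TokenAwareHostPolicy(
		gocql.RoundRobinHostPolicy(),
		gocql.ShuffleReplicas(),
	)

	session, err := cluster.CreateSession()
	if err != nil {
		log.Fatal(err)
	}
	defer session.Close()
}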
Chris Bannister 8 years ago
parent
commit
3688d5fd49
14 changed files with 652 additions and 156 deletions
  1. batch_test.go  +8 -5
  2. cassandra_test.go  +1 -0
  3. common_test.go  +1 -0
  4. conn.go  +1 -1
  5. events.go  +15 -9
  6. frame.go  +40 -2
  7. policies.go  +134 -53
  8. policies_test.go  +6 -6
  9. query_executor.go  +1 -0
  10. session.go  +32 -20
  11. token.go  +29 -36
  12. token_test.go  +13 -24
  13. topology.go  +208 -0
  14. topology_test.go  +163 -0

+ 8 - 5
batch_test.go

@@ -9,12 +9,15 @@ import (
 
 func TestBatch_Errors(t *testing.T) {
 	if *flagProto == 1 {
-		t.Skip("atomic batches not supported. Please use Cassandra >= 2.0")
 	}
 
 	session := createSession(t)
 	defer session.Close()
 
+	if session.cfg.ProtoVersion < protoVersion2 {
+		t.Skip("atomic batches not supported. Please use Cassandra >= 2.0")
+	}
+
 	if err := createTable(session, `CREATE TABLE gocql_test.batch_errors (id int primary key, val inet)`); err != nil {
 		t.Fatal(err)
 	}
@@ -27,13 +30,13 @@ func TestBatch_Errors(t *testing.T) {
 }
 
 func TestBatch_WithTimestamp(t *testing.T) {
-	if *flagProto < protoVersion3 {
-		t.Skip("Batch timestamps are only available on protocol >= 3")
-	}
-
 	session := createSession(t)
 	defer session.Close()
 
+	if session.cfg.ProtoVersion < protoVersion3 {
+		t.Skip("Batch timestamps are only available on protocol >= 3")
+	}
+
 	if err := createTable(session, `CREATE TABLE gocql_test.batch_ts (id int primary key, val text)`); err != nil {
 		t.Fatal(err)
 	}

+ 1 - 0
cassandra_test.go

@@ -83,6 +83,7 @@ func TestEmptyHosts(t *testing.T) {
 }
 
 func TestInvalidPeerEntry(t *testing.T) {
+	t.Skip("dont mutate system tables, rewrite this to test what we mean to test")
 	session := createSession(t)
 
 	// rack, release_version, schema_version, tokens are all null

+ 1 - 0
common_test.go

@@ -94,6 +94,7 @@ func createCluster() *ClusterConfig {
 }
 
 func createKeyspace(tb testing.TB, cluster *ClusterConfig, keyspace string) {
+	// TODO: tb.Helper()
 	c := *cluster
 	c.Keyspace = "system"
 	c.Timeout = 30 * time.Second

+ 1 - 1
conn.go

@@ -895,7 +895,7 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 		return iter
 	case *resultKeyspaceFrame:
 		return &Iter{framer: framer}
-	case *schemaChangeKeyspace, *schemaChangeTable, *schemaChangeFunction:
+	case *schemaChangeKeyspace, *schemaChangeTable, *schemaChangeFunction, *schemaChangeAggregate, *schemaChangeType:
 		iter := &Iter{framer: framer}
 		if err := c.awaitSchemaAgreement(); err != nil {
 			// TODO: should have this behind a flag

+ 15 - 9
events.go

@@ -80,7 +80,6 @@ func (e *eventDebouncer) debounce(frame frame) {
 }
 
 func (s *Session) handleEvent(framer *framer) {
-	// TODO(zariel): need to debounce events frames, and possible also events
 	defer framerPool.Put(framer)
 
 	frame, err := framer.parseFrame()
@@ -94,9 +93,10 @@ func (s *Session) handleEvent(framer *framer) {
 		Logger.Printf("gocql: handling frame: %v\n", frame)
 	}
 
-	// TODO: handle medatadata events
 	switch f := frame.(type) {
-	case *schemaChangeKeyspace, *schemaChangeFunction, *schemaChangeTable:
+	case *schemaChangeKeyspace, *schemaChangeFunction,
+		*schemaChangeTable, *schemaChangeAggregate, *schemaChangeType:
+
 		s.schemaEvents.debounce(frame)
 	case *topologyChangeEventFrame, *statusChangeEventFrame:
 		s.nodeEvents.debounce(frame)
@@ -106,22 +106,28 @@ func (s *Session) handleEvent(framer *framer) {
 }
 
 func (s *Session) handleSchemaEvent(frames []frame) {
-	s.mu.RLock()
-	defer s.mu.RUnlock()
-
-	if s.schemaDescriber == nil {
-		return
-	}
+	// TODO: debounce events
 	for _, frame := range frames {
 		switch f := frame.(type) {
 		case *schemaChangeKeyspace:
 			s.schemaDescriber.clearSchema(f.keyspace)
+			s.handleKeyspaceChange(f.keyspace, f.change)
 		case *schemaChangeTable:
 			s.schemaDescriber.clearSchema(f.keyspace)
+		case *schemaChangeAggregate:
+			s.schemaDescriber.clearSchema(f.keyspace)
+		case *schemaChangeFunction:
+			s.schemaDescriber.clearSchema(f.keyspace)
+		case *schemaChangeType:
+			s.schemaDescriber.clearSchema(f.keyspace)
 		}
 	}
 }
 
+func (s *Session) handleKeyspaceChange(keyspace, change string) {
+	s.policy.KeyspaceChanged(KeyspaceUpdateEvent{Keyspace: keyspace, Change: change})
+}
+
 func (s *Session) handleNodeEvent(frames []frame) {
 	type nodeEvent struct {
 		change string

+ 40 - 2
frame.go

@@ -1112,6 +1112,14 @@ func (f schemaChangeTable) String() string {
 	return fmt.Sprintf("[event schema_change change=%q keyspace=%q object=%q]", f.change, f.keyspace, f.object)
 }
 
+type schemaChangeType struct {
+	frameHeader
+
+	change   string
+	keyspace string
+	object   string
+}
+
 type schemaChangeFunction struct {
 	frameHeader
 
@@ -1121,6 +1129,15 @@ type schemaChangeFunction struct {
 	args     []string
 }
 
+type schemaChangeAggregate struct {
+	frameHeader
+
+	change   string
+	keyspace string
+	name     string
+	args     []string
+}
+
 func (f *framer) parseResultSchemaChange() frame {
 	if f.proto <= protoVersion2 {
 		change := f.readString()
@@ -1156,7 +1173,7 @@ func (f *framer) parseResultSchemaChange() frame {
 			frame.keyspace = f.readString()
 
 			return frame
-		case "TABLE", "TYPE":
+		case "TABLE":
 			frame := &schemaChangeTable{
 				frameHeader: *f.header,
 				change:      change,
@@ -1166,7 +1183,17 @@ func (f *framer) parseResultSchemaChange() frame {
 			frame.object = f.readString()
 
 			return frame
-		case "FUNCTION", "AGGREGATE":
+		case "TYPE":
+			frame := &schemaChangeType{
+				frameHeader: *f.header,
+				change:      change,
+			}
+
+			frame.keyspace = f.readString()
+			frame.object = f.readString()
+
+			return frame
+		case "FUNCTION":
 			frame := &schemaChangeFunction{
 				frameHeader: *f.header,
 				change:      change,
@@ -1176,6 +1203,17 @@ func (f *framer) parseResultSchemaChange() frame {
 			frame.name = f.readString()
 			frame.args = f.readStringList()
 
+			return frame
+		case "AGGREGATE":
+			frame := &schemaChangeAggregate{
+				frameHeader: *f.header,
+				change:      change,
+			}
+
+			frame.keyspace = f.readString()
+			frame.name = f.readString()
+			frame.args = f.readStringList()
+
 			return frame
 		default:
 			panic(fmt.Errorf("gocql: unknown SCHEMA_CHANGE target: %q change: %q", target, change))

+ 134 - 53
policies.go

@@ -200,11 +200,18 @@ type HostStateNotifier interface {
 	HostDown(host *HostInfo)
 }
 
+type KeyspaceUpdateEvent struct {
+	Keyspace string
+	Change   string
+}
+
 // HostSelectionPolicy is an interface for selecting
 // the most appropriate host to execute a given query.
 type HostSelectionPolicy interface {
 	HostStateNotifier
 	SetPartitioner
+	KeyspaceChanged(KeyspaceUpdateEvent)
+	Init(*Session)
 	//Pick returns an iteration function over selected hosts
 	Pick(ExecutableQuery) NextHost
 }
@@ -239,9 +246,9 @@ type roundRobinHostPolicy struct {
 	mu    sync.RWMutex
 }
 
-func (r *roundRobinHostPolicy) SetPartitioner(partitioner string) {
-	// noop
-}
+func (r *roundRobinHostPolicy) KeyspaceChanged(KeyspaceUpdateEvent) {}
+func (r *roundRobinHostPolicy) SetPartitioner(partitioner string)   {}
+func (r *roundRobinHostPolicy) Init(*Session)                       {}
 
 func (r *roundRobinHostPolicy) Pick(qry ExecutableQuery) NextHost {
 	// i is used to limit the number of attempts to find a host
@@ -281,19 +288,69 @@ func (r *roundRobinHostPolicy) HostDown(host *HostInfo) {
 	r.RemoveHost(host)
 }
 
+func ShuffleReplicas() func(*tokenAwareHostPolicy) {
+	return func(t *tokenAwareHostPolicy) {
+		t.shuffleReplicas = true
+	}
+}
+
 // TokenAwareHostPolicy is a token aware host selection policy, where hosts are
 // selected based on the partition key, so queries are sent to the host which
 // owns the partition. Fallback is used when routing information is not available.
-func TokenAwareHostPolicy(fallback HostSelectionPolicy) HostSelectionPolicy {
-	return &tokenAwareHostPolicy{fallback: fallback}
+func TokenAwareHostPolicy(fallback HostSelectionPolicy, opts ...func(*tokenAwareHostPolicy)) HostSelectionPolicy {
+	p := &tokenAwareHostPolicy{fallback: fallback}
+	for _, opt := range opts {
+		opt(p)
+	}
+	return p
+}
+
+type keyspaceMeta struct {
+	replicas map[string]map[token][]*HostInfo
 }
 
 type tokenAwareHostPolicy struct {
 	hosts       cowHostList
 	mu          sync.RWMutex
 	partitioner string
-	tokenRing   *tokenRing
 	fallback    HostSelectionPolicy
+	session     *Session
+
+	tokenRing atomic.Value // *tokenRing
+	keyspaces atomic.Value // *keyspaceMeta
+
+	shuffleReplicas bool
+}
+
+func (t *tokenAwareHostPolicy) Init(s *Session) {
+	t.session = s
+}
+
+func (t *tokenAwareHostPolicy) KeyspaceChanged(update KeyspaceUpdateEvent) {
+	meta, _ := t.keyspaces.Load().(*keyspaceMeta)
+	// TODO: avoid recaulating things which havnt changed
+	newMeta := &keyspaceMeta{
+		replicas: make(map[string]map[token][]*HostInfo, len(meta.replicas)),
+	}
+
+	ks, err := t.session.KeyspaceMetadata(update.Keyspace)
+	if err == nil {
+		strat := getStrategy(ks)
+		tr := t.tokenRing.Load().(*tokenRing)
+		if tr != nil {
+			newMeta.replicas[update.Keyspace] = strat.replicaMap(t.hosts.get(), tr.tokens)
+		}
+	}
+
+	if meta != nil {
+		for ks, replicas := range meta.replicas {
+			if ks != update.Keyspace {
+				newMeta.replicas[ks] = replicas
+			}
+		}
+	}
+
+	t.keyspaces.Store(newMeta)
 }
 
 func (t *tokenAwareHostPolicy) SetPartitioner(partitioner string) {
@@ -304,31 +361,34 @@ func (t *tokenAwareHostPolicy) SetPartitioner(partitioner string) {
 		t.fallback.SetPartitioner(partitioner)
 		t.partitioner = partitioner
 
-		t.resetTokenRing()
+		t.resetTokenRing(partitioner)
 	}
 }
 
 func (t *tokenAwareHostPolicy) AddHost(host *HostInfo) {
-	t.mu.Lock()
-	defer t.mu.Unlock()
-
 	t.hosts.add(host)
 	t.fallback.AddHost(host)
 
-	t.resetTokenRing()
+	t.mu.RLock()
+	partitioner := t.partitioner
+	t.mu.RUnlock()
+	t.resetTokenRing(partitioner)
 }
 
 func (t *tokenAwareHostPolicy) RemoveHost(host *HostInfo) {
-	t.mu.Lock()
-	defer t.mu.Unlock()
-
 	t.hosts.remove(host.ConnectAddress())
 	t.fallback.RemoveHost(host)
 
-	t.resetTokenRing()
+	t.mu.RLock()
+	partitioner := t.partitioner
+	t.mu.RUnlock()
+	t.resetTokenRing(partitioner)
 }
 
 func (t *tokenAwareHostPolicy) HostUp(host *HostInfo) {
+	// TODO: need to avoid doing all the work on AddHost on hostup/down
+	// because it now expensive to calculate the replica map for each
+	// token
 	t.AddHost(host)
 }
 
@@ -336,22 +396,31 @@ func (t *tokenAwareHostPolicy) HostDown(host *HostInfo) {
 	t.RemoveHost(host)
 }
 
-func (t *tokenAwareHostPolicy) resetTokenRing() {
-	if t.partitioner == "" {
+func (t *tokenAwareHostPolicy) resetTokenRing(partitioner string) {
+	if partitioner == "" {
 		// partitioner not yet set
 		return
 	}
 
 	// create a new token ring
 	hosts := t.hosts.get()
-	tokenRing, err := newTokenRing(t.partitioner, hosts)
+	tokenRing, err := newTokenRing(partitioner, hosts)
 	if err != nil {
 		Logger.Printf("Unable to update the token ring due to error: %s", err)
 		return
 	}
 
 	// replace the token ring
-	t.tokenRing = tokenRing
+	t.tokenRing.Store(tokenRing)
+}
+
+func (t *tokenAwareHostPolicy) getReplicas(keyspace string, token token) ([]*HostInfo, bool) {
+	meta, _ := t.keyspaces.Load().(*keyspaceMeta)
+	if meta == nil {
+		return nil, false
+	}
+	tokens, ok := meta.replicas[keyspace][token]
+	return tokens, ok
 }
 
 func (t *tokenAwareHostPolicy) Pick(qry ExecutableQuery) NextHost {
@@ -362,45 +431,62 @@ func (t *tokenAwareHostPolicy) Pick(qry ExecutableQuery) NextHost {
 	routingKey, err := qry.GetRoutingKey()
 	if err != nil {
 		return t.fallback.Pick(qry)
+	} else if routingKey == nil {
+		return t.fallback.Pick(qry)
 	}
-	if routingKey == nil {
+
+	tr, _ := t.tokenRing.Load().(*tokenRing)
+	if tr == nil {
 		return t.fallback.Pick(qry)
 	}
 
-	t.mu.RLock()
-	// TODO retrieve a list of hosts based on the replication strategy
-	host := t.tokenRing.GetHostForPartitionKey(routingKey)
-	t.mu.RUnlock()
+	token := tr.partitioner.Hash(routingKey)
+	primaryEndpoint := tr.GetHostForToken(token)
 
-	if host == nil {
+	if primaryEndpoint == nil || token == nil {
 		return t.fallback.Pick(qry)
 	}
 
-	// scope these variables for the same lifetime as the iterator function
+	replicas, ok := t.getReplicas(qry.Keyspace(), token)
+	if !ok {
+		replicas = []*HostInfo{primaryEndpoint}
+	} else if t.shuffleReplicas {
+		replicas = shuffleHosts(replicas)
+	}
+
 	var (
-		hostReturned bool
 		fallbackIter NextHost
+		i            int
 	)
 
+	used := make(map[*HostInfo]bool)
 	return func() SelectedHost {
-		if !hostReturned {
-			hostReturned = true
-			return (*selectedHost)(host)
+		for i < len(replicas) {
+			h := replicas[i]
+			i++
+
+			if !h.IsUp() {
+				// TODO: need a way to handle host distance, as we may want to not
+				// use hosts in specific DC's
+				continue
+			}
+			used[h] = true
+
+			return (*selectedHost)(h)
 		}
 
-		// fallback
 		if fallbackIter == nil {
+			// fallback
 			fallbackIter = t.fallback.Pick(qry)
 		}
 
-		fallbackHost := fallbackIter()
-
 		// filter the token aware selected hosts from the fallback hosts
-		if fallbackHost != nil && fallbackHost.Info() == host {
-			fallbackHost = fallbackIter()
+		for fallbackHost := fallbackIter(); fallbackHost != nil; fallbackHost = fallbackIter() {
+			if !used[fallbackHost.Info()] {
+				return fallbackHost
+			}
 		}
-
-		return fallbackHost
+		return nil
 	}
 }
 
@@ -428,6 +514,10 @@ type hostPoolHostPolicy struct {
 	hostMap map[string]*HostInfo
 }
 
+func (r *hostPoolHostPolicy) Init(*Session)                       {}
+func (r *hostPoolHostPolicy) KeyspaceChanged(KeyspaceUpdateEvent) {}
+func (r *hostPoolHostPolicy) SetPartitioner(string)               {}
+
 func (r *hostPoolHostPolicy) SetHosts(hosts []*HostInfo) {
 	peers := make([]string, len(hosts))
 	hostMap := make(map[string]*HostInfo, len(hosts))
@@ -492,10 +582,6 @@ func (r *hostPoolHostPolicy) HostDown(host *HostInfo) {
 	r.RemoveHost(host)
 }
 
-func (r *hostPoolHostPolicy) SetPartitioner(partitioner string) {
-	// noop
-}
-
 func (r *hostPoolHostPolicy) Pick(qry ExecutableQuery) NextHost {
 	return func() SelectedHost {
 		r.mu.RLock()
@@ -557,11 +643,13 @@ type dcAwareRR struct {
 // return hosts which are in the local datacentre before returning hosts in all
 // other datercentres
 func DCAwareRoundRobinPolicy(localDC string) HostSelectionPolicy {
-	return &dcAwareRR{
-		local: localDC,
-	}
+	return &dcAwareRR{local: localDC}
 }
 
+func (r *dcAwareRR) Init(*Session)                       {}
+func (r *dcAwareRR) KeyspaceChanged(KeyspaceUpdateEvent) {}
+func (d *dcAwareRR) SetPartitioner(p string)             {}
+
 func (d *dcAwareRR) AddHost(host *HostInfo) {
 	if host.DataCenter() == d.local {
 		d.localHosts.add(host)
@@ -578,15 +666,8 @@ func (d *dcAwareRR) RemoveHost(host *HostInfo) {
 	}
 }
 
-func (d *dcAwareRR) HostUp(host *HostInfo) {
-	d.AddHost(host)
-}
-
-func (d *dcAwareRR) HostDown(host *HostInfo) {
-	d.RemoveHost(host)
-}
-
-func (d *dcAwareRR) SetPartitioner(p string) {}
+func (d *dcAwareRR) HostUp(host *HostInfo)   { d.AddHost(host) }
+func (d *dcAwareRR) HostDown(host *HostInfo) { d.RemoveHost(host) }
 
 func (d *dcAwareRR) Pick(q ExecutableQuery) NextHost {
 	var i int
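
A side effect of the changes above is that HostSelectionPolicy gains two methods, Init and KeyspaceChanged, so any custom policy outside this package has to add them. A minimal sketch of staying compatible is to embed an existing policy and delegate; the loggingPolicy name and its logging behaviour are purely illustrative, not part of the commit.

package main

import (
	"log"

	"github.com/gocql/gocql"
)

// loggingPolicy is a hypothetical wrapper policy. Embedding another
// HostSelectionPolicy means the methods added by this change (Init and
// KeyspaceChanged) are satisfied by delegation; only Pick is overridden,
// to log which host was selected.
type loggingPolicy struct {
	gocql.HostSelectionPolicy
}

func (p loggingPolicy) Pick(q gocql.ExecutableQuery) gocql.NextHost {
	next := p.HostSelectionPolicy.Pick(q)
	return func() gocql.SelectedHost {
		h := next()
		if h != nil {
			log.Printf("picked host %v", h.Info().ConnectAddress())
		}
		return h
	}
}

func main() {
	cluster := gocql.NewCluster("127.0.0.1") // hypothetical contact point
	cluster.PoolConfig.HostSelectionPolicy = loggingPolicy{
		HostSelectionPolicy: gocql.TokenAwareHostPolicy(gocql.RoundRobinHostPolicy()),
	}
	_ = cluster
}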

+ 6 - 6
policies_test.go

@@ -14,7 +14,7 @@ import (
 )
 
 // Tests of the round-robin host selection policy implementation
-func TestRoundRobinHostPolicy(t *testing.T) {
+func TestHostPolicy_RoundRobin(t *testing.T) {
 	policy := RoundRobinHostPolicy()
 
 	hosts := [...]*HostInfo{
@@ -53,7 +53,7 @@ func TestRoundRobinHostPolicy(t *testing.T) {
 
 // Tests of the token-aware host selection policy implementation with a
 // round-robin host selection policy fallback.
-func TestTokenAwareHostPolicy(t *testing.T) {
+func TestHostPolicy_TokenAware(t *testing.T) {
 	policy := TokenAwareHostPolicy(RoundRobinHostPolicy())
 
 	query := &Query{}
@@ -110,7 +110,7 @@ func TestTokenAwareHostPolicy(t *testing.T) {
 }
 
 // Tests of the host pool host selection policy implementation
-func TestHostPoolHostPolicy(t *testing.T) {
+func TestHostPolicy_HostPool(t *testing.T) {
 	policy := HostPoolHostPolicy(hostpool.New(nil))
 
 	hosts := []*HostInfo{
@@ -150,7 +150,7 @@ func TestHostPoolHostPolicy(t *testing.T) {
 	actualD.Mark(nil)
 }
 
-func TestRoundRobinNilHostInfo(t *testing.T) {
+func TestHostPolicy_RoundRobin_NilHostInfo(t *testing.T) {
 	policy := RoundRobinHostPolicy()
 
 	host := &HostInfo{hostId: "host-1"}
@@ -175,7 +175,7 @@ func TestRoundRobinNilHostInfo(t *testing.T) {
 	}
 }
 
-func TestTokenAwareNilHostInfo(t *testing.T) {
+func TestHostPolicy_TokenAware_NilHostInfo(t *testing.T) {
 	policy := TokenAwareHostPolicy(RoundRobinHostPolicy())
 
 	hosts := [...]*HostInfo{
@@ -302,7 +302,7 @@ func TestExponentialBackoffPolicy(t *testing.T) {
 	}
 }
 
-func TestDCAwareRR(t *testing.T) {
+func TestHostPolicy_DCAwareRR(t *testing.T) {
 	p := DCAwareRoundRobinPolicy("local")
 
 	hosts := [...]*HostInfo{

+ 1 - 0
query_executor.go

@@ -9,6 +9,7 @@ type ExecutableQuery interface {
 	attempt(time.Duration)
 	retryPolicy() RetryPolicy
 	GetRoutingKey() ([]byte, error)
+	Keyspace() string
 	RetryableQuery
 }
 

+ 32 - 20
session.go

@@ -112,14 +112,14 @@ func NewSession(cfg ClusterConfig) (*Session, error) {
 		quit:     make(chan struct{}),
 	}
 
+	s.schemaDescriber = newSchemaDescriber(s)
+
 	s.nodeEvents = newEventDebouncer("NodeEvents", s.handleNodeEvent)
 	s.schemaEvents = newEventDebouncer("SchemaEvents", s.handleSchemaEvent)
 
 	s.routingKeyInfoCache.lru = lru.New(cfg.MaxRoutingKeyInfo)
 
-	s.hostSource = &ringDescriber{
-		session: s,
-	}
+	s.hostSource = &ringDescriber{session: s}
 
 	if cfg.PoolConfig.HostSelectionPolicy == nil {
 		cfg.PoolConfig.HostSelectionPolicy = RoundRobinHostPolicy()
@@ -127,6 +127,8 @@ func NewSession(cfg ClusterConfig) (*Session, error) {
 	s.pool = cfg.PoolConfig.buildPool(s)
 
 	s.policy = cfg.PoolConfig.HostSelectionPolicy
+	s.policy.Init(s)
+
 	s.executor = &queryExecutor{
 		pool:   s.pool,
 		policy: cfg.PoolConfig.HostSelectionPolicy,
@@ -409,25 +411,15 @@ func (s *Session) KeyspaceMetadata(keyspace string) (*KeyspaceMetadata, error) {
 	// fail fast
 	if s.Closed() {
 		return nil, ErrSessionClosed
-	}
-
-	if keyspace == "" {
+	} else if keyspace == "" {
 		return nil, ErrNoKeyspace
 	}
 
-	s.mu.Lock()
-	// lazy-init schemaDescriber
-	if s.schemaDescriber == nil {
-		s.schemaDescriber = newSchemaDescriber(s)
-	}
-	s.mu.Unlock()
-
 	return s.schemaDescriber.getSchema(keyspace)
 }
 
 func (s *Session) getConn() *Conn {
 	hosts := s.ring.allHosts()
-	var conn *Conn
 	for _, host := range hosts {
 		if !host.IsUp() {
 			continue
@@ -436,10 +428,7 @@ func (s *Session) getConn() *Conn {
 		pool, ok := s.pool.getPool(host)
 		if !ok {
 			continue
-		}
-
-		conn = pool.Pick()
-		if conn != nil {
+		} else if conn := pool.Pick(); conn != nil {
 			return conn
 		}
 	}
@@ -780,6 +769,16 @@ func (q *Query) retryPolicy() RetryPolicy {
 	return q.rt
 }
 
+// Keyspace returns the keyspace the query will be executed against.
+func (q *Query) Keyspace() string {
+	if q.session == nil {
+		return ""
+	}
+	// TODO(chbannis): this should be parsed from the query or we should let
+	// this be set by users.
+	return q.session.cfg.Keyspace
+}
+
 // GetRoutingKey gets the routing key to use for routing this query. If
 // a routing key has not been explicitly set, then the routing key will
 // be constructed if possible using the keyspace's schema and the query
@@ -1341,9 +1340,12 @@ type Batch struct {
 	defaultTimestamp      bool
 	defaultTimestampValue int64
 	context               context.Context
+	keyspace              string
 }
 
 // NewBatch creates a new batch operation without defaults from the cluster
+//
+// Depreicated: use session.NewBatch instead
 func NewBatch(typ BatchType) *Batch {
 	return &Batch{Type: typ}
 }
@@ -1351,12 +1353,22 @@ func NewBatch(typ BatchType) *Batch {
 // NewBatch creates a new batch operation using defaults defined in the cluster
 func (s *Session) NewBatch(typ BatchType) *Batch {
 	s.mu.RLock()
-	batch := &Batch{Type: typ, rt: s.cfg.RetryPolicy, serialCons: s.cfg.SerialConsistency,
-		Cons: s.cons, defaultTimestamp: s.cfg.DefaultTimestamp}
+	batch := &Batch{
+		Type:             typ,
+		rt:               s.cfg.RetryPolicy,
+		serialCons:       s.cfg.SerialConsistency,
+		Cons:             s.cons,
+		defaultTimestamp: s.cfg.DefaultTimestamp,
+		keyspace:         s.cfg.Keyspace,
+	}
 	s.mu.RUnlock()
 	return batch
 }
 
+func (b *Batch) Keyspace() string {
+	return b.keyspace
+}
+
 // Attempts returns the number of attempts made to execute the batch.
 func (b *Batch) Attempts() int {
 	return b.attempts

+ 29 - 36
token.go

@@ -58,7 +58,7 @@ func (m murmur3Token) Less(token token) bool {
 
 // order preserving partitioner and token
 type orderedPartitioner struct{}
-type orderedToken []byte
+type orderedToken string
 
 func (p orderedPartitioner) Name() string {
 	return "OrderedPartitioner"
@@ -70,15 +70,15 @@ func (p orderedPartitioner) Hash(partitionKey []byte) token {
 }
 
 func (p orderedPartitioner) ParseString(str string) token {
-	return orderedToken([]byte(str))
+	return orderedToken(str)
 }
 
 func (o orderedToken) String() string {
-	return string([]byte(o))
+	return string(o)
 }
 
 func (o orderedToken) Less(token token) bool {
-	return -1 == bytes.Compare(o, token.(orderedToken))
+	return o < token.(orderedToken)
 }
 
 // random partitioner and token
@@ -118,18 +118,23 @@ func (r *randomToken) Less(token token) bool {
 	return -1 == (*big.Int)(r).Cmp((*big.Int)(token.(*randomToken)))
 }
 
+type hostToken struct {
+	token token
+	host  *HostInfo
+}
+
+func (ht hostToken) String() string {
+	return fmt.Sprintf("{token=%v host=%v}", ht.token, ht.host.HostID())
+}
+
 // a data structure for organizing the relationship between tokens and hosts
 type tokenRing struct {
 	partitioner partitioner
-	tokens      []token
-	hosts       []*HostInfo
+	tokens      []hostToken
 }
 
 func newTokenRing(partitioner string, hosts []*HostInfo) (*tokenRing, error) {
-	tokenRing := &tokenRing{
-		tokens: []token{},
-		hosts:  []*HostInfo{},
-	}
+	tokenRing := &tokenRing{}
 
 	if strings.HasSuffix(partitioner, "Murmur3Partitioner") {
 		tokenRing.partitioner = murmur3Partitioner{}
@@ -144,8 +149,7 @@ func newTokenRing(partitioner string, hosts []*HostInfo) (*tokenRing, error) {
 	for _, host := range hosts {
 		for _, strToken := range host.Tokens() {
 			token := tokenRing.partitioner.ParseString(strToken)
-			tokenRing.tokens = append(tokenRing.tokens, token)
-			tokenRing.hosts = append(tokenRing.hosts, host)
+			tokenRing.tokens = append(tokenRing.tokens, hostToken{token, host})
 		}
 	}
 
@@ -159,16 +163,14 @@ func (t *tokenRing) Len() int {
 }
 
 func (t *tokenRing) Less(i, j int) bool {
-	return t.tokens[i].Less(t.tokens[j])
+	return t.tokens[i].token.Less(t.tokens[j].token)
 }
 
 func (t *tokenRing) Swap(i, j int) {
-	t.tokens[i], t.hosts[i], t.tokens[j], t.hosts[j] =
-		t.tokens[j], t.hosts[j], t.tokens[i], t.hosts[i]
+	t.tokens[i], t.tokens[j] = t.tokens[j], t.tokens[i]
 }
 
 func (t *tokenRing) String() string {
-
 	buf := &bytes.Buffer{}
 	buf.WriteString("TokenRing(")
 	if t.partitioner != nil {
@@ -176,15 +178,15 @@ func (t *tokenRing) String() string {
 	}
 	buf.WriteString("){")
 	sep := ""
-	for i := range t.tokens {
+	for i, th := range t.tokens {
 		buf.WriteString(sep)
 		sep = ","
 		buf.WriteString("\n\t[")
 		buf.WriteString(strconv.Itoa(i))
 		buf.WriteString("]")
-		buf.WriteString(t.tokens[i].String())
+		buf.WriteString(th.token.String())
 		buf.WriteString(":")
-		buf.WriteString(t.hosts[i].ConnectAddress().String())
+		buf.WriteString(th.host.ConnectAddress().String())
 	}
 	buf.WriteString("\n}")
 	return string(buf.Bytes())
@@ -200,28 +202,19 @@ func (t *tokenRing) GetHostForPartitionKey(partitionKey []byte) *HostInfo {
 }
 
 func (t *tokenRing) GetHostForToken(token token) *HostInfo {
-	if t == nil {
-		return nil
-	}
-
-	l := len(t.tokens)
-	// no host tokens, no available hosts
-	if l == 0 {
+	if t == nil || len(t.tokens) == 0 {
 		return nil
 	}
 
 	// find the primary replica
-	ringIndex := sort.Search(
-		l,
-		func(i int) bool {
-			return !t.tokens[i].Less(token)
-		},
-	)
-
-	if ringIndex == l {
+	ringIndex := sort.Search(len(t.tokens), func(i int) bool {
+		return !t.tokens[i].token.Less(token)
+	})
+
+	if ringIndex == len(t.tokens) {
 		// wrap around to the first in the ring
 		ringIndex = 0
 	}
-	host := t.hosts[ringIndex]
-	return host
+
+	return t.tokens[ringIndex].host
 }

+ 13 - 24
token_test.go

@@ -132,18 +132,13 @@ func TestRandomToken(t *testing.T) {
 
 type intToken int
 
-func (i intToken) String() string {
-	return strconv.Itoa(int(i))
-}
-
-func (i intToken) Less(token token) bool {
-	return i < token.(intToken)
-}
+func (i intToken) String() string        { return strconv.Itoa(int(i)) }
+func (i intToken) Less(token token) bool { return i < token.(intToken) }
 
 // Test of the token ring implementation based on example at the start of this
 // page of documentation:
 // http://www.datastax.com/docs/0.8/cluster_architecture/partitioning
-func TestIntTokenRing(t *testing.T) {
+func TestTokenRing_Int(t *testing.T) {
 	host0 := &HostInfo{}
 	host25 := &HostInfo{}
 	host50 := &HostInfo{}
@@ -151,17 +146,11 @@ func TestIntTokenRing(t *testing.T) {
 	ring := &tokenRing{
 		partitioner: nil,
 		// these tokens and hosts are out of order to test sorting
-		tokens: []token{
-			intToken(0),
-			intToken(50),
-			intToken(75),
-			intToken(25),
-		},
-		hosts: []*HostInfo{
-			host0,
-			host50,
-			host75,
-			host25,
+		tokens: []hostToken{
+			{intToken(0), host0},
+			{intToken(50), host50},
+			{intToken(75), host75},
+			{intToken(25), host25},
 		},
 	}
 
@@ -209,7 +198,7 @@ func TestIntTokenRing(t *testing.T) {
 }
 
 // Test for the behavior of a nil pointer to tokenRing
-func TestNilTokenRing(t *testing.T) {
+func TestTokenRing_Nil(t *testing.T) {
 	var ring *tokenRing = nil
 
 	if ring.GetHostForToken(nil) != nil {
@@ -221,7 +210,7 @@ func TestNilTokenRing(t *testing.T) {
 }
 
 // Test of the recognition of the partitioner class
-func TestUnknownTokenRing(t *testing.T) {
+func TestTokenRing_UnknownPartition(t *testing.T) {
 	_, err := newTokenRing("UnknownPartitioner", nil)
 	if err == nil {
 		t.Error("Expected error for unknown partitioner value, but was nil")
@@ -242,7 +231,7 @@ func hostsForTests(n int) []*HostInfo {
 }
 
 // Test of the tokenRing with the Murmur3Partitioner
-func TestMurmur3TokenRing(t *testing.T) {
+func TestTokenRing_Murmur3(t *testing.T) {
 	// Note, strings are parsed directly to int64, they are not murmur3 hashed
 	hosts := hostsForTests(4)
 	ring, err := newTokenRing("Murmur3Partitioner", hosts)
@@ -272,7 +261,7 @@ func TestMurmur3TokenRing(t *testing.T) {
 }
 
 // Test of the tokenRing with the OrderedPartitioner
-func TestOrderedTokenRing(t *testing.T) {
+func TestTokenRing_Ordered(t *testing.T) {
 	// Tokens here more or less are similar layout to the int tokens above due
 	// to each numeric character translating to a consistently offset byte.
 	hosts := hostsForTests(4)
@@ -304,7 +293,7 @@ func TestOrderedTokenRing(t *testing.T) {
 }
 
 // Test of the tokenRing with the RandomPartitioner
-func TestRandomTokenRing(t *testing.T) {
+func TestTokenRing_Random(t *testing.T) {
 	// String tokens are parsed into big.Int in base 10
 	hosts := hostsForTests(4)
 	ring, err := newTokenRing("RandomPartitioner", hosts)

+ 208 - 0
topology.go

@@ -0,0 +1,208 @@
+package gocql
+
+import (
+	"fmt"
+	"strconv"
+	"strings"
+)
+
+type placementStrategy interface {
+	replicaMap(hosts []*HostInfo, tokens []hostToken) map[token][]*HostInfo
+	replicationFactor(dc string) int
+}
+
+func getReplicationFactorFromOpts(keyspace string, val interface{}) int {
+	// TODO: dont really want to panic here, but is better
+	// than spamming
+	switch v := val.(type) {
+	case int:
+		if v <= 0 {
+			panic(fmt.Sprintf("invalid replication_factor %d. Is the %q keyspace configured correctly?", v, keyspace))
+		}
+		return v
+	case string:
+		n, err := strconv.Atoi(v)
+		if err != nil {
+			panic(fmt.Sprintf("invalid replication_factor. Is the %q keyspace configured correctly? %v", keyspace, err))
+		} else if n <= 0 {
+			panic(fmt.Sprintf("invalid replication_factor %d. Is the %q keyspace configured correctly?", n, keyspace))
+		}
+		return n
+	default:
+		panic(fmt.Sprintf("unkown replication_factor type %T", v))
+	}
+}
+
+func getStrategy(ks *KeyspaceMetadata) placementStrategy {
+	switch {
+	case strings.Contains(ks.StrategyClass, "SimpleStrategy"):
+		return &simpleStrategy{rf: getReplicationFactorFromOpts(ks.Name, ks.StrategyOptions["replication_factor"])}
+	case strings.Contains(ks.StrategyClass, "NetworkTopologyStrategy"):
+		dcs := make(map[string]int)
+		for dc, rf := range ks.StrategyOptions {
+			dcs[dc] = getReplicationFactorFromOpts(ks.Name+":dc="+dc, rf)
+		}
+		return &networkTopology{dcs: dcs}
+	default:
+		// TODO: handle unknown replicas and just return the primary host for a token
+		panic(fmt.Sprintf("unsupported strategy class: %v", ks.StrategyClass))
+	}
+}
+
+type simpleStrategy struct {
+	rf int
+}
+
+func (s *simpleStrategy) replicationFactor(dc string) int {
+	return s.rf
+}
+
+func (s *simpleStrategy) replicaMap(_ []*HostInfo, tokens []hostToken) map[token][]*HostInfo {
+	tokenRing := make(map[token][]*HostInfo, len(tokens))
+
+	for i, th := range tokens {
+		replicas := make([]*HostInfo, 0, s.rf)
+		for j := 0; j < len(tokens) && len(replicas) < s.rf; j++ {
+			// TODO: need to ensure we dont add the same hosts twice
+			h := tokens[(i+j)%len(tokens)]
+			replicas = append(replicas, h.host)
+		}
+		tokenRing[th.token] = replicas
+	}
+
+	return tokenRing
+}
+
+type networkTopology struct {
+	dcs map[string]int
+}
+
+func (n *networkTopology) replicationFactor(dc string) int {
+	return n.dcs[dc]
+}
+
+func (n *networkTopology) haveRF(replicaCounts map[string]int) bool {
+	if len(replicaCounts) != len(n.dcs) {
+		return false
+	}
+
+	for dc, rf := range n.dcs {
+		if rf != replicaCounts[dc] {
+			return false
+		}
+	}
+
+	return true
+}
+
+func (n *networkTopology) replicaMap(hosts []*HostInfo, tokens []hostToken) map[token][]*HostInfo {
+	dcRacks := make(map[string]map[string]struct{})
+
+	for _, h := range hosts {
+		dc := h.DataCenter()
+		rack := h.Rack()
+
+		racks, ok := dcRacks[dc]
+		if !ok {
+			racks = make(map[string]struct{})
+			dcRacks[dc] = racks
+		}
+		racks[rack] = struct{}{}
+	}
+
+	tokenRing := make(map[token][]*HostInfo, len(tokens))
+
+	var totalRF int
+	for _, rf := range n.dcs {
+		totalRF += rf
+	}
+
+	for i, th := range tokens {
+		// number of replicas per dc
+		// TODO: recycle these
+		replicasInDC := make(map[string]int, len(n.dcs))
+		// dc -> racks
+		seenDCRacks := make(map[string]map[string]struct{}, len(n.dcs))
+		// skipped hosts in a dc
+		skipped := make(map[string][]*HostInfo, len(n.dcs))
+
+		replicas := make([]*HostInfo, 0, totalRF)
+		for j := 0; j < len(tokens) && !n.haveRF(replicasInDC); j++ {
+			// TODO: ensure we dont add the same host twice
+			h := tokens[(i+j)%len(tokens)].host
+
+			dc := h.DataCenter()
+			rack := h.Rack()
+
+			rf, ok := n.dcs[dc]
+			if !ok {
+				// skip this DC, dont know about it
+				continue
+			} else if replicasInDC[dc] >= rf {
+				if replicasInDC[dc] > rf {
+					panic(fmt.Sprintf("replica overflow. rf=%d have=%d in dc %q", rf, replicasInDC[dc], dc))
+				}
+
+				// have enough replicas in this DC
+				continue
+			} else if _, ok := dcRacks[dc][rack]; !ok {
+				// dont know about this rack
+				continue
+			} else if len(replicas) >= totalRF {
+				if replicasInDC[dc] > rf {
+					panic(fmt.Sprintf("replica overflow. total rf=%d have=%d", totalRF, len(replicas)))
+				}
+
+				// we now have enough replicas
+				break
+			}
+
+			racks := seenDCRacks[dc]
+			if _, ok := racks[rack]; ok && len(racks) == len(dcRacks[dc]) {
+				// we have been through all the racks and dont have RF yet, add this
+				replicas = append(replicas, h)
+				replicasInDC[dc]++
+			} else if !ok {
+				if racks == nil {
+					racks = make(map[string]struct{}, 1)
+					seenDCRacks[dc] = racks
+				}
+
+				// new rack
+				racks[rack] = struct{}{}
+				replicas = append(replicas, h)
+				replicasInDC[dc]++
+
+				if len(racks) == len(dcRacks[dc]) {
+					// if we have been through all the racks, drain the rest of the skipped
+					// hosts until we have RF. The next iteration will skip in the block
+					// above
+					skippedHosts := skipped[dc]
+					var k int
+					for ; k < len(skippedHosts) && replicasInDC[dc] < rf; k++ {
+						sh := skippedHosts[k]
+						replicas = append(replicas, sh)
+						replicasInDC[dc]++
+					}
+					skipped[dc] = skippedHosts[k:]
+				}
+			} else {
+				// already seen this rack, keep hold of this host incase
+				// we dont get enough for rf
+				skipped[dc] = append(skipped[dc], h)
+			}
+		}
+
+		if len(replicas) == 0 || replicas[0] != th.host {
+			panic("first replica is not the primary replica for the token")
+		}
+
+		tokenRing[th.token] = replicas
+	}
+
+	if len(tokenRing) != len(tokens) {
+		panic(fmt.Sprintf("token map different size to token ring: got %d expected %d", len(tokenRing), len(tokens)))
+	}
+
+	return tokenRing
+}
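
To illustrate the SimpleStrategy path above (again an in-package sketch reusing the intToken helper from token_test.go, not part of the commit): each token's replica set is its primary host plus the next rf-1 hosts walking clockwise around the ring, wrapping at the end.

func sketchSimpleReplicas() map[token][]*HostInfo {
	hostA := &HostInfo{hostId: "a"}
	hostB := &HostInfo{hostId: "b"}
	hostC := &HostInfo{hostId: "c"}

	ring := []hostToken{
		{intToken(0), hostA},
		{intToken(100), hostB},
		{intToken(200), hostC},
	}

	strat := &simpleStrategy{rf: 2}
	// with rf=2 the resulting map is:
	//   intToken(0)   -> [hostA, hostB]
	//   intToken(100) -> [hostB, hostC]
	//   intToken(200) -> [hostC, hostA]   (wraps around the ring)
	return strat.replicaMap([]*HostInfo{hostA, hostB, hostC}, ring)
}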

+ 163 - 0
topology_test.go

@@ -0,0 +1,163 @@
+package gocql
+
+import (
+	"fmt"
+	"sort"
+	"testing"
+)
+
+func TestPlacementStrategy_SimpleStrategy(t *testing.T) {
+	host0 := &HostInfo{hostId: "0"}
+	host25 := &HostInfo{hostId: "25"}
+	host50 := &HostInfo{hostId: "50"}
+	host75 := &HostInfo{hostId: "75"}
+
+	tokenRing := []hostToken{
+		{intToken(0), host0},
+		{intToken(25), host25},
+		{intToken(50), host50},
+		{intToken(75), host75},
+	}
+
+	hosts := []*HostInfo{host0, host25, host50, host75}
+
+	strat := &simpleStrategy{rf: 2}
+	tokenReplicas := strat.replicaMap(hosts, tokenRing)
+	if len(tokenReplicas) != len(tokenRing) {
+		t.Fatalf("expected replica map to have %d items but has %d", len(tokenRing), len(tokenReplicas))
+	}
+
+	for token, replicas := range tokenReplicas {
+		if len(replicas) != strat.rf {
+			t.Errorf("expected to have %d replicas got %d for token=%v", strat.rf, len(replicas), token)
+		}
+	}
+
+	for i, token := range tokenRing {
+		replicas, ok := tokenReplicas[token.token]
+		if !ok {
+			t.Errorf("token %v not in replica map", token)
+		}
+
+		for j, replica := range replicas {
+			exp := tokenRing[(i+j)%len(tokenRing)].host
+			if exp != replica {
+				t.Errorf("expected host %v to be a replica of %v got %v", exp, token, replica)
+			}
+		}
+	}
+}
+
+func TestPlacementStrategy_NetworkStrategy(t *testing.T) {
+	var (
+		hosts  []*HostInfo
+		tokens []hostToken
+	)
+
+	const (
+		totalDCs   = 3
+		racksPerDC = 3
+		hostsPerDC = 5
+	)
+
+	dcRing := make(map[string][]hostToken, totalDCs)
+	for i := 0; i < totalDCs; i++ {
+		var dcTokens []hostToken
+		dc := fmt.Sprintf("dc%d", i+1)
+
+		for j := 0; j < hostsPerDC; j++ {
+			rack := fmt.Sprintf("rack%d", (j%racksPerDC)+1)
+
+			h := &HostInfo{hostId: fmt.Sprintf("%s:%s:%d", dc, rack, j), dataCenter: dc, rack: rack}
+
+			token := hostToken{
+				token: orderedToken([]byte(h.hostId)),
+				host:  h,
+			}
+
+			tokens = append(tokens, token)
+			dcTokens = append(dcTokens, token)
+
+			hosts = append(hosts, h)
+		}
+
+		sort.Sort(&tokenRing{tokens: dcTokens})
+		dcRing[dc] = dcTokens
+	}
+
+	if len(tokens) != hostsPerDC*totalDCs {
+		t.Fatalf("expected %d tokens in the ring got %d", hostsPerDC*totalDCs, len(tokens))
+	}
+	sort.Sort(&tokenRing{tokens: tokens})
+
+	strat := &networkTopology{
+		dcs: map[string]int{
+			"dc1": 1,
+			"dc2": 2,
+			"dc3": 3,
+		},
+	}
+
+	var expReplicas int
+	for _, rf := range strat.dcs {
+		expReplicas += rf
+	}
+
+	tokenReplicas := strat.replicaMap(hosts, tokens)
+	if len(tokenReplicas) != len(tokens) {
+		t.Fatalf("expected replica map to have %d items but has %d", len(tokens), len(tokenReplicas))
+	}
+
+	for token, replicas := range tokenReplicas {
+		if len(replicas) != expReplicas {
+			t.Fatalf("expected to have %d replicas got %d for token=%v", expReplicas, len(replicas), token)
+		}
+	}
+
+	for dc, rf := range strat.dcs {
+		dcTokens := dcRing[dc]
+		for i, th := range dcTokens {
+			token := th.token
+			allReplicas, ok := tokenReplicas[token]
+			if !ok {
+				t.Fatalf("token %v not in replica map", token)
+			}
+
+			var replicas []*HostInfo
+			for _, replica := range allReplicas {
+				if replica.dataCenter == dc {
+					replicas = append(replicas, replica)
+				}
+			}
+
+			if len(replicas) != rf {
+				t.Fatalf("expected %d replicas in dc %q got %d", rf, dc, len(replicas))
+			}
+
+			var lastRack string
+			for j, replica := range replicas {
+				// expected is in the next rack
+				var exp *HostInfo
+				if lastRack == "" {
+					// primary, first replica
+					exp = dcTokens[(i+j)%len(dcTokens)].host
+				} else {
+					for k := 0; k < len(dcTokens); k++ {
+						// walk around the ring from i + j to find the next host the
+						// next rack
+						p := (i + j + k) % len(dcTokens)
+						h := dcTokens[p].host
+						if h.rack != lastRack {
+							exp = h
+							break
+						}
+					}
+					if exp.rack == lastRack {
+						panic("no more racks")
+					}
+				}
+				lastRack = replica.rack
+			}
+		}
+	}
+}