
topology: replace token map with binary search (#1350)

Instead of keeping every token in a map, binary search a token-sorted
slice to find the closest host which owns the token. This removes the
need to first binary search the tokenRing to find an actual token to
look up in the map.
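
The shape of the new lookup, as a minimal self-contained sketch (int tokens and string host IDs stand in for the driver's token and *HostInfo types; the real replicasFor lives in the topology.go hunk below):

package main

import (
	"fmt"
	"sort"
)

// entry pairs a ring token with the hosts replicating the range it ends.
type entry struct {
	token int
	hosts []string
}

// replicasFor binary searches the token-sorted ring for the first entry
// whose token is >= t, wrapping to the start when t is past the last token.
func replicasFor(ring []entry, t int) entry {
	p := sort.Search(len(ring), func(i int) bool {
		return ring[i].token >= t
	})
	if p == len(ring) {
		p = 0 // wrap around the ring
	}
	return ring[p]
}

func main() {
	ring := []entry{
		{0, []string{"h0", "h1"}},
		{25, []string{"h1", "h2"}},
		{50, []string{"h2", "h3"}},
		{75, []string{"h3", "h0"}},
	}
	fmt.Println(replicasFor(ring, 30).hosts) // [h2 h3]
	fmt.Println(replicasFor(ring, 80).hosts) // [h0 h1], wrapped
}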

Fix SimpleStrategy not placing replicas on unique hosts.
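
With virtual nodes one host owns many ring positions, so walking the ring clockwise can select the same host more than once; the fix records hosts already chosen. A minimal sketch (string host IDs standing in for *HostInfo):

// ringHosts[i] is the host owning the i-th sorted ring position; with
// vnodes the same host appears at several positions.
func simpleReplicas(ringHosts []string, start, rf int) []string {
	replicas := make([]string, 0, rf)
	seen := make(map[string]bool, rf)
	for j := 0; j < len(ringHosts) && len(replicas) < rf; j++ {
		h := ringHosts[(start+j)%len(ringHosts)]
		if !seen[h] { // the fix: never count the same host twice
			seen[h] = true
			replicas = append(replicas, h)
		}
	}
	return replicas
}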

Don't return duplicate hosts from the fallback in TokenAwareHostPolicy.
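
The dedupe keeps a set of hosts already yielded by the token-aware phase and filters the fallback iterator against it. Roughly (a sketch; used plays the same role as in the policies.go hunk below):

// dedupe wraps a host iterator so each host ID is returned at most once;
// used is shared with the replica phase that ran before the fallback.
func dedupe(used map[string]bool, next func() (string, bool)) func() (string, bool) {
	return func() (string, bool) {
		for {
			h, ok := next()
			if !ok {
				return "", false // iterator exhausted
			}
			if !used[h] {
				used[h] = true
				return h, true
			}
		}
	}
}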

go fmt ./...
Chris Bannister, 6 years ago
parent commit 73f35ffa6d
12 changed files with 299 additions and 293 deletions
  1. common_test.go (+10 -5)
  2. conn_test.go (+2 -2)
  3. control.go (+6 -6)
  4. host_source.go (+2 -2)
  5. marshal.go (+8 -8)
  6. marshal_test.go (+28 -28)
  7. policies.go (+27 -28)
  8. policies_test.go (+118 -161)
  9. session.go (+11 -11)
  10. token.go (+8 -5)
  11. topology.go (+59 -17)
  12. topology_test.go (+20 -20)

+ 10 - 5
common_test.go

@@ -224,31 +224,36 @@ func staticAddressTranslator(newAddr net.IP, newPort int) AddressTranslator {
 }
 
 func assertTrue(t *testing.T, description string, value bool) {
+	t.Helper()
 	if !value {
-		t.Errorf("expected %s to be true", description)
+		t.Fatalf("expected %s to be true", description)
 	}
 }
 
 func assertEqual(t *testing.T, description string, expected, actual interface{}) {
+	t.Helper()
 	if expected != actual {
-		t.Errorf("expected %s to be (%+v) but was (%+v) instead", description, expected, actual)
+		t.Fatalf("expected %s to be (%+v) but was (%+v) instead", description, expected, actual)
 	}
 }
 
 func assertDeepEqual(t *testing.T, description string, expected, actual interface{}) {
+	t.Helper()
 	if !reflect.DeepEqual(expected, actual) {
-		t.Errorf("expected %s to be (%+v) but was (%+v) instead", description, expected, actual)
+		t.Fatalf("expected %s to be (%+v) but was (%+v) instead", description, expected, actual)
 	}
 }
 
 func assertNil(t *testing.T, description string, actual interface{}) {
+	t.Helper()
 	if actual != nil {
-		t.Errorf("expected %s to be (nil) but was (%+v) instead", description, actual)
+		t.Fatalf("expected %s to be (nil) but was (%+v) instead", description, actual)
 	}
 }
 
 func assertNotNil(t *testing.T, description string, actual interface{}) {
+	t.Helper()
 	if actual == nil {
-		t.Errorf("expected %s not to be (nil)", description)
+		t.Fatalf("expected %s not to be (nil)", description)
 	}
 }

+ 2 - 2
conn_test.go

@@ -47,8 +47,8 @@ func TestApprove(t *testing.T) {
 
 func TestJoinHostPort(t *testing.T) {
 	tests := map[string]string{
-		"127.0.0.1:0":                                 JoinHostPort("127.0.0.1", 0),
-		"127.0.0.1:1":                                 JoinHostPort("127.0.0.1:1", 9142),
+		"127.0.0.1:0": JoinHostPort("127.0.0.1", 0),
+		"127.0.0.1:1": JoinHostPort("127.0.0.1:1", 9142),
 		"[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:0": JoinHostPort("2001:0db8:85a3:0000:0000:8a2e:0370:7334", 0),
 		"[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:1": JoinHostPort("[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:1", 9142),
 	}

+ 6 - 6
control.go

@@ -149,14 +149,14 @@ func hostInfo(addr string, defaultPort int) ([]*HostInfo, error) {
 }
 
 func shuffleHosts(hosts []*HostInfo) []*HostInfo {
-	mutRandr.Lock()
-	perm := randr.Perm(len(hosts))
-	mutRandr.Unlock()
 	shuffled := make([]*HostInfo, len(hosts))
+	copy(shuffled, hosts)
 
-	for i, host := range hosts {
-		shuffled[perm[i]] = host
-	}
+	mutRandr.Lock()
+	randr.Shuffle(len(hosts), func(i, j int) {
+		shuffled[i], shuffled[j] = shuffled[j], shuffled[i]
+	})
+	mutRandr.Unlock()
 
 	return shuffled
 }

+ 2 - 2
host_source.go

@@ -110,7 +110,7 @@ type HostInfo struct {
 	// TODO(zariel): reduce locking maybe, not all values will change, but to ensure
 	// that we are thread safe use a mutex to access all fields.
 	mu               sync.RWMutex
-	hostname 		 string
+	hostname         string
 	peer             net.IP
 	broadcastAddress net.IP
 	listenAddress    net.IP
@@ -128,7 +128,7 @@ type HostInfo struct {
 	clusterName      string
 	version          cassVersion
 	state            nodeState
-	schemaVersion	 string
+	schemaVersion    string
 	tokens           []string
 }
 

+ 8 - 8
marshal.go

@@ -1262,15 +1262,15 @@ func unmarshalDate(info TypeInfo, data []byte, value interface{}) error {
 		*v = time.Unix(0, timestamp*int64(time.Millisecond)).In(time.UTC)
 		return nil
 	case *string:
-                if len(data) == 0 {
-                        *v = ""
-                        return nil
-                }
-                var origin uint32 = 1 << 31
-                var current uint32 = binary.BigEndian.Uint32(data)
-                timestamp := (int64(current) - int64(origin)) * 86400000
+		if len(data) == 0 {
+			*v = ""
+			return nil
+		}
+		var origin uint32 = 1 << 31
+		var current uint32 = binary.BigEndian.Uint32(data)
+		timestamp := (int64(current) - int64(origin)) * 86400000
 		*v = time.Unix(0, timestamp*int64(time.Millisecond)).In(time.UTC).Format("2006-01-02")
-                return nil
+		return nil
 	}
 	return unmarshalErrorf("can not unmarshal %s into %T", info, value)
 }

+ 28 - 28
marshal_test.go

@@ -1554,58 +1554,58 @@ func TestReadCollectionSize(t *testing.T) {
 	}
 
 	tests := []struct {
-		name string
-		info CollectionType
-		data []byte
-		isError bool
+		name         string
+		info         CollectionType
+		data         []byte
+		isError      bool
 		expectedSize int
 	}{
 		{
-			name: "short read 0 proto 2",
-			info: listV2,
-			data: []byte{},
+			name:    "short read 0 proto 2",
+			info:    listV2,
+			data:    []byte{},
 			isError: true,
 		},
 		{
-			name: "short read 1 proto 2",
-			info: listV2,
-			data: []byte{0x01},
+			name:    "short read 1 proto 2",
+			info:    listV2,
+			data:    []byte{0x01},
 			isError: true,
 		},
 		{
-			name: "good read proto 2",
-			info: listV2,
-			data: []byte{0x01, 0x38},
+			name:         "good read proto 2",
+			info:         listV2,
+			data:         []byte{0x01, 0x38},
 			expectedSize: 0x0138,
 		},
 		{
-			name: "short read 0 proto 3",
-			info: listV3,
-			data: []byte{},
+			name:    "short read 0 proto 3",
+			info:    listV3,
+			data:    []byte{},
 			isError: true,
 		},
 		{
-			name: "short read 1 proto 3",
-			info: listV3,
-			data: []byte{0x01},
+			name:    "short read 1 proto 3",
+			info:    listV3,
+			data:    []byte{0x01},
 			isError: true,
 		},
 		{
-			name: "short read 2 proto 3",
-			info: listV3,
-			data: []byte{0x01, 0x38},
+			name:    "short read 2 proto 3",
+			info:    listV3,
+			data:    []byte{0x01, 0x38},
 			isError: true,
 		},
 		{
-			name: "short read 3 proto 3",
-			info: listV3,
-			data: []byte{0x01, 0x38, 0x42},
+			name:    "short read 3 proto 3",
+			info:    listV3,
+			data:    []byte{0x01, 0x38, 0x42},
 			isError: true,
 		},
 		{
-			name: "good read proto 3",
-			info: listV3,
-			data: []byte{0x01, 0x38, 0x42, 0x22},
+			name:         "good read proto 3",
+			info:         listV3,
+			data:         []byte{0x01, 0x38, 0x42, 0x22},
 			expectedSize: 0x01384222,
 		},
 	}

+ 27 - 28
policies.go

@@ -416,14 +416,14 @@ func TokenAwareHostPolicy(fallback HostSelectionPolicy, opts ...func(*tokenAware
 // and the pointer in clusterMeta updated to point to the new value.
 type clusterMeta struct {
 	// replicas is map[keyspace]map[token]hosts
-	replicas map[string]map[token][]*HostInfo
+	replicas  map[string]tokenRingReplicas
 	tokenRing *tokenRing
 }
 
 type tokenAwareHostPolicy struct {
-	fallback    HostSelectionPolicy
+	fallback            HostSelectionPolicy
 	getKeyspaceMetadata func(keyspace string) (*KeyspaceMetadata, error)
-	getKeyspaceName func() string
+	getKeyspaceName     func() string
 
 	shuffleReplicas          bool
 	nonLocalReplicasFallback bool
@@ -438,7 +438,7 @@ type tokenAwareHostPolicy struct {
 
 func (t *tokenAwareHostPolicy) Init(s *Session) {
 	t.getKeyspaceMetadata = s.KeyspaceMetadata
-	t.getKeyspaceName = func() string {return s.cfg.Keyspace}
+	t.getKeyspaceName = func() string { return s.cfg.Keyspace }
 }
 
 func (t *tokenAwareHostPolicy) IsLocal(host *HostInfo) bool {
@@ -457,15 +457,14 @@ func (t *tokenAwareHostPolicy) KeyspaceChanged(update KeyspaceUpdateEvent) {
 // It must be called with t.mu mutex locked.
 // meta must not be nil and it's replicas field will be updated.
 func (t *tokenAwareHostPolicy) updateReplicas(meta *clusterMeta, keyspace string) {
-	newReplicas := make(map[string]map[token][]*HostInfo, len(meta.replicas))
+	newReplicas := make(map[string]tokenRingReplicas, len(meta.replicas))
 
 	ks, err := t.getKeyspaceMetadata(keyspace)
 	if err == nil {
 		strat := getStrategy(ks)
 		if strat != nil {
 			if meta != nil && meta.tokenRing != nil {
-				hosts := t.hosts.get()
-				newReplicas[keyspace] = strat.replicaMap(hosts, meta.tokenRing.tokens)
+				newReplicas[keyspace] = strat.replicaMap(meta.tokenRing)
 			}
 		}
 	}
@@ -567,14 +566,6 @@ func (m *clusterMeta) resetTokenRing(partitioner string, hosts []*HostInfo) {
 	m.tokenRing = tokenRing
 }
 
-func (m *clusterMeta) getReplicas(keyspace string, token token) ([]*HostInfo, bool) {
-	if m.replicas == nil {
-		return nil, false
-	}
-	replicas, ok := m.replicas[keyspace][token]
-	return replicas, ok
-}
-
 func (t *tokenAwareHostPolicy) Pick(qry ExecutableQuery) NextHost {
 	if qry == nil {
 		return t.fallback.Pick(qry)
@@ -592,22 +583,23 @@ func (t *tokenAwareHostPolicy) Pick(qry ExecutableQuery) NextHost {
 		return t.fallback.Pick(qry)
 	}
 
-	primaryEndpoint, token := meta.tokenRing.GetHostForPartitionKey(routingKey)
-	if primaryEndpoint == nil || token == nil {
-		return t.fallback.Pick(qry)
-	}
+	token := meta.tokenRing.partitioner.Hash(routingKey)
+	ht := meta.replicas[qry.Keyspace()].replicasFor(token)
+	var replicas []*HostInfo
 
-	replicas, ok := meta.getReplicas(qry.Keyspace(), token)
-	if !ok {
-		replicas = []*HostInfo{primaryEndpoint}
+	if ht == nil {
+		host, _ := meta.tokenRing.GetHostForToken(token)
+		replicas = []*HostInfo{host}
 	} else if t.shuffleReplicas {
-		replicas = shuffleHosts(replicas)
+		replicas = shuffleHosts(ht.hosts)
+	} else {
+		replicas = ht.hosts
 	}
 
 	var (
 		fallbackIter NextHost
-		i            int
-		j            int
+		i, j         int
+		remote       []*HostInfo
 	)
 
 	used := make(map[*HostInfo]bool, len(replicas))
@@ -616,18 +608,23 @@ func (t *tokenAwareHostPolicy) Pick(qry ExecutableQuery) NextHost {
 			h := replicas[i]
 			i++
 
-			if h.IsUp() && t.fallback.IsLocal(h) {
+			if !t.fallback.IsLocal(h) {
+				remote = append(remote, h)
+				continue
+			}
+
+			if h.IsUp() {
 				used[h] = true
 				return (*selectedHost)(h)
 			}
 		}
 
 		if t.nonLocalReplicasFallback {
-			for j < len(replicas) {
-				h := replicas[j]
+			for j < len(remote) {
+				h := remote[j]
 				j++
 
-				if h.IsUp() && !t.fallback.IsLocal(h) {
+				if h.IsUp() {
 					used[h] = true
 					return (*selectedHost)(h)
 				}
@@ -642,9 +639,11 @@ func (t *tokenAwareHostPolicy) Pick(qry ExecutableQuery) NextHost {
 		// filter the token aware selected hosts from the fallback hosts
 		for fallbackHost := fallbackIter(); fallbackHost != nil; fallbackHost = fallbackIter() {
 			if !used[fallbackHost.Info()] {
+				used[fallbackHost.Info()] = true
 				return fallbackHost
 			}
 		}
+
 		return nil
 	}
 }

+ 118 - 161
policies_test.go

@@ -55,15 +55,16 @@ func TestHostPolicy_RoundRobin(t *testing.T) {
 // Tests of the token-aware host selection policy implementation with a
 // round-robin host selection policy fallback.
 func TestHostPolicy_TokenAware(t *testing.T) {
+	const keyspace = "myKeyspace"
 	policy := TokenAwareHostPolicy(RoundRobinHostPolicy())
 	policyInternal := policy.(*tokenAwareHostPolicy)
-	policyInternal.getKeyspaceName = func() string {return "myKeyspace"}
+	policyInternal.getKeyspaceName = func() string { return keyspace }
 	policyInternal.getKeyspaceMetadata = func(ks string) (*KeyspaceMetadata, error) {
 		return nil, errors.New("not initalized")
 	}
 
 	query := &Query{}
-	query.getKeyspace = func() string{return "myKeyspace"}
+	query.getKeyspace = func() string { return keyspace }
 
 	iter := policy.Pick(nil)
 	if iter == nil {
@@ -99,28 +100,28 @@ func TestHostPolicy_TokenAware(t *testing.T) {
 	policy.SetPartitioner("OrderedPartitioner")
 
 	policyInternal.getKeyspaceMetadata = func(keyspaceName string) (*KeyspaceMetadata, error) {
-		if keyspaceName != "myKeyspace" {
+		if keyspaceName != keyspace {
 			return nil, fmt.Errorf("unknown keyspace: %s", keyspaceName)
 		}
 		return &KeyspaceMetadata{
-			Name: "myKeyspace",
+			Name:          keyspace,
 			StrategyClass: "SimpleStrategy",
-			StrategyOptions: map[string]interface{} {
-				"class": "SimpleStrategy",
+			StrategyOptions: map[string]interface{}{
+				"class":              "SimpleStrategy",
 				"replication_factor": 2,
 			},
 		}, nil
 	}
-	policy.KeyspaceChanged(KeyspaceUpdateEvent{Keyspace: "myKeyspace"})
+	policy.KeyspaceChanged(KeyspaceUpdateEvent{Keyspace: keyspace})
 
 	// The SimpleStrategy above should generate the following replicas.
 	// It's handy to have as reference here.
-	assertDeepEqual(t, "replicas", map[string]map[token][]*HostInfo{
+	assertDeepEqual(t, "replicas", map[string]tokenRingReplicas{
 		"myKeyspace": {
-			orderedToken("00"): {hosts[0], hosts[1]},
-			orderedToken("25"): {hosts[1], hosts[2]},
-			orderedToken("50"): {hosts[2], hosts[3]},
-			orderedToken("75"): {hosts[3], hosts[0]},
+			{orderedToken("00"), []*HostInfo{hosts[0], hosts[1]}},
+			{orderedToken("25"), []*HostInfo{hosts[1], hosts[2]}},
+			{orderedToken("50"), []*HostInfo{hosts[2], hosts[3]}},
+			{orderedToken("75"), []*HostInfo{hosts[3], hosts[0]}},
 		},
 	}, policyInternal.getMetadataReadOnly().replicas)
 
@@ -211,7 +212,7 @@ func TestHostPolicy_RoundRobin_NilHostInfo(t *testing.T) {
 func TestHostPolicy_TokenAware_NilHostInfo(t *testing.T) {
 	policy := TokenAwareHostPolicy(RoundRobinHostPolicy())
 	policyInternal := policy.(*tokenAwareHostPolicy)
-	policyInternal.getKeyspaceName = func() string {return "myKeyspace"}
+	policyInternal.getKeyspaceName = func() string { return "myKeyspace" }
 	policyInternal.getKeyspaceMetadata = func(ks string) (*KeyspaceMetadata, error) {
 		return nil, errors.New("not initialized")
 	}
@@ -228,7 +229,7 @@ func TestHostPolicy_TokenAware_NilHostInfo(t *testing.T) {
 	policy.SetPartitioner("OrderedPartitioner")
 
 	query := &Query{}
-	query.getKeyspace = func() string {return "myKeyspace"}
+	query.getKeyspace = func() string { return "myKeyspace" }
 	query.RoutingKey([]byte("20"))
 
 	iter := policy.Pick(query)
@@ -400,6 +401,18 @@ func TestDowngradingConsistencyRetryPolicy(t *testing.T) {
 	}
 }
 
+func iterCheck(t *testing.T, iter NextHost, hostID string) {
+	t.Helper()
+
+	host := iter()
+	if host == nil || host.Info() == nil {
+		t.Fatalf("expected hostID %s got nil", hostID)
+	}
+	if host.Info().HostID() != hostID {
+		t.Fatalf("Expected peer %s but was %s", hostID, host.Info().HostID())
+	}
+}
+
 func TestHostPolicy_DCAwareRR(t *testing.T) {
 	p := DCAwareRoundRobinPolicy("local")
 
@@ -448,20 +461,20 @@ func TestHostPolicy_DCAwareRR(t *testing.T) {
 
 }
 
-
 // Tests of the token-aware host selection policy implementation with a
 // DC aware round-robin host selection policy fallback
 // with {"class": "NetworkTopologyStrategy", "a": 1, "b": 1, "c": 1} replication.
 func TestHostPolicy_TokenAware_DCAwareRR(t *testing.T) {
+	const keyspace = "myKeyspace"
 	policy := TokenAwareHostPolicy(DCAwareRoundRobinPolicy("local"))
 	policyInternal := policy.(*tokenAwareHostPolicy)
-	policyInternal.getKeyspaceName = func() string {return "myKeyspace"}
+	policyInternal.getKeyspaceName = func() string { return keyspace }
 	policyInternal.getKeyspaceMetadata = func(ks string) (*KeyspaceMetadata, error) {
 		return nil, errors.New("not initialized")
 	}
 
 	query := &Query{}
-	query.getKeyspace = func() string {return "myKeyspace"}
+	query.getKeyspace = func() string { return keyspace }
 
 	iter := policy.Pick(nil)
 	if iter == nil {
@@ -494,27 +507,26 @@ func TestHostPolicy_TokenAware_DCAwareRR(t *testing.T) {
 	// the token ring is not setup without the partitioner, but the fallback
 	// should work
 	if actual := policy.Pick(nil)(); actual.Info().HostID() != "1" {
-		t.Errorf("Expected host 1 but was %s", actual.Info().HostID())
+		t.Fatalf("Expected host 1 but was %s", actual.Info().HostID())
 	}
 
 	query.RoutingKey([]byte("30"))
 	if actual := policy.Pick(query)(); actual.Info().HostID() != "4" {
-		t.Errorf("Expected peer 4 but was %s", actual.Info().HostID())
+		t.Fatalf("Expected peer 4 but was %s", actual.Info().HostID())
 	}
 
 	policy.SetPartitioner("OrderedPartitioner")
 
-
 	policyInternal.getKeyspaceMetadata = func(keyspaceName string) (*KeyspaceMetadata, error) {
-		if keyspaceName != "myKeyspace" {
+		if keyspaceName != keyspace {
 			return nil, fmt.Errorf("unknown keyspace: %s", keyspaceName)
 		}
 		return &KeyspaceMetadata{
-			Name: "myKeyspace",
+			Name:          keyspace,
 			StrategyClass: "NetworkTopologyStrategy",
-			StrategyOptions: map[string]interface{} {
-				"class": "NetworkTopologyStrategy",
-				"local": 1,
+			StrategyOptions: map[string]interface{}{
+				"class":   "NetworkTopologyStrategy",
+				"local":   1,
 				"remote1": 1,
 				"remote2": 1,
 			},
@@ -524,20 +536,20 @@ func TestHostPolicy_TokenAware_DCAwareRR(t *testing.T) {
 
 	// The NetworkTopologyStrategy above should generate the following replicas.
 	// It's handy to have as reference here.
-	assertDeepEqual(t, "replicas", map[string]map[token][]*HostInfo{
+	assertDeepEqual(t, "replicas", map[string]tokenRingReplicas{
 		"myKeyspace": {
-			orderedToken("05"): {hosts[0], hosts[1], hosts[2]},
-			orderedToken("10"): {hosts[1], hosts[2], hosts[3]},
-			orderedToken("15"): {hosts[2], hosts[3], hosts[4]},
-			orderedToken("20"): {hosts[3], hosts[4], hosts[5]},
-			orderedToken("25"): {hosts[4], hosts[5], hosts[6]},
-			orderedToken("30"): {hosts[5], hosts[6], hosts[7]},
-			orderedToken("35"): {hosts[6], hosts[7], hosts[8]},
-			orderedToken("40"): {hosts[7], hosts[8], hosts[9]},
-			orderedToken("45"): {hosts[8], hosts[9], hosts[10]},
-			orderedToken("50"): {hosts[9], hosts[10], hosts[11]},
-			orderedToken("55"): {hosts[10], hosts[11], hosts[0]},
-			orderedToken("60"): {hosts[11], hosts[0], hosts[1]},
+			{orderedToken("05"), []*HostInfo{hosts[0], hosts[1], hosts[2]}},
+			{orderedToken("10"), []*HostInfo{hosts[1], hosts[2], hosts[3]}},
+			{orderedToken("15"), []*HostInfo{hosts[2], hosts[3], hosts[4]}},
+			{orderedToken("20"), []*HostInfo{hosts[3], hosts[4], hosts[5]}},
+			{orderedToken("25"), []*HostInfo{hosts[4], hosts[5], hosts[6]}},
+			{orderedToken("30"), []*HostInfo{hosts[5], hosts[6], hosts[7]}},
+			{orderedToken("35"), []*HostInfo{hosts[6], hosts[7], hosts[8]}},
+			{orderedToken("40"), []*HostInfo{hosts[7], hosts[8], hosts[9]}},
+			{orderedToken("45"), []*HostInfo{hosts[8], hosts[9], hosts[10]}},
+			{orderedToken("50"), []*HostInfo{hosts[9], hosts[10], hosts[11]}},
+			{orderedToken("55"), []*HostInfo{hosts[10], hosts[11], hosts[0]}},
+			{orderedToken("60"), []*HostInfo{hosts[11], hosts[0], hosts[1]}},
 		},
 	}, policyInternal.getMetadataReadOnly().replicas)
 
@@ -545,29 +557,11 @@ func TestHostPolicy_TokenAware_DCAwareRR(t *testing.T) {
 	query.RoutingKey([]byte("23"))
 	iter = policy.Pick(query)
 	// first should be host with matching token from the local DC
-	if actual := iter(); actual.Info().HostID() != "4" {
-		t.Errorf("Expected peer 4 but was %s", actual.Info().HostID())
-	}
+	iterCheck(t, iter, "4")
 	// rest are according DCAwareRR from local DC only, starting with 7 as the fallback was used twice above
-	if actual := iter(); actual.Info().HostID() != "7" {
-		t.Errorf("Expected peer 7 but was %s", actual.Info().HostID())
-	}
-	if actual := iter(); actual.Info().HostID() != "10" {
-		t.Errorf("Expected peer 10 but was %s", actual.Info().HostID())
-	}
-	if actual := iter(); actual.Info().HostID() != "1" {
-		t.Errorf("Expected peer 1 but was %s", actual.Info().HostID())
-	}
-	// and it starts to repeat now without host 4...
-	if actual := iter(); actual.Info().HostID() != "7" {
-		t.Errorf("Expected peer 7 but was %s", actual.Info().HostID())
-	}
-	if actual := iter(); actual.Info().HostID() != "10" {
-		t.Errorf("Expected peer 10 but was %s", actual.Info().HostID())
-	}
-	if actual := iter(); actual.Info().HostID() != "1" {
-		t.Errorf("Expected peer 1 but was %s", actual.Info().HostID())
-	}
+	iterCheck(t, iter, "7")
+	iterCheck(t, iter, "10")
+	iterCheck(t, iter, "1")
 }
 
 // Tests of the token-aware host selection policy implementation with a
@@ -576,13 +570,13 @@ func TestHostPolicy_TokenAware_DCAwareRR(t *testing.T) {
 func TestHostPolicy_TokenAware_DCAwareRR2(t *testing.T) {
 	policy := TokenAwareHostPolicy(DCAwareRoundRobinPolicy("local"))
 	policyInternal := policy.(*tokenAwareHostPolicy)
-	policyInternal.getKeyspaceName = func() string {return "myKeyspace"}
+	policyInternal.getKeyspaceName = func() string { return "myKeyspace" }
 	policyInternal.getKeyspaceMetadata = func(ks string) (*KeyspaceMetadata, error) {
 		return nil, errors.New("not initialized")
 	}
 
 	query := &Query{}
-	query.getKeyspace = func() string {return "myKeyspace"}
+	query.getKeyspace = func() string { return "myKeyspace" }
 
 	iter := policy.Pick(nil)
 	if iter == nil {
@@ -634,11 +628,11 @@ func TestHostPolicy_TokenAware_DCAwareRR2(t *testing.T) {
 			return nil, fmt.Errorf("unknown keyspace: %s", keyspaceName)
 		}
 		return &KeyspaceMetadata{
-			Name: "myKeyspace",
+			Name:          "myKeyspace",
 			StrategyClass: "NetworkTopologyStrategy",
-			StrategyOptions: map[string]interface{} {
-				"class": "NetworkTopologyStrategy",
-				"local": 2,
+			StrategyOptions: map[string]interface{}{
+				"class":   "NetworkTopologyStrategy",
+				"local":   2,
 				"remote1": 2,
 				"remote2": 2,
 			},
@@ -648,20 +642,20 @@ func TestHostPolicy_TokenAware_DCAwareRR2(t *testing.T) {
 
 	// The NetworkTopologyStrategy above should generate the following replicas.
 	// It's handy to have as reference here.
-	assertDeepEqual(t, "replicas", map[string]map[token][]*HostInfo{
+	assertDeepEqual(t, "replicas", map[string]tokenRingReplicas{
 		"myKeyspace": {
-			orderedToken("05"): {hosts[0], hosts[1], hosts[2], hosts[3], hosts[4], hosts[5]},
-			orderedToken("10"): {hosts[1], hosts[2], hosts[3], hosts[4], hosts[5], hosts[6]},
-			orderedToken("15"): {hosts[2], hosts[3], hosts[4], hosts[5], hosts[6], hosts[7]},
-			orderedToken("20"): {hosts[3], hosts[4], hosts[5], hosts[6], hosts[7], hosts[8]},
-			orderedToken("25"): {hosts[4], hosts[5], hosts[6], hosts[7], hosts[8], hosts[9]},
-			orderedToken("30"): {hosts[5], hosts[6], hosts[7], hosts[8], hosts[9], hosts[10]},
-			orderedToken("35"): {hosts[6], hosts[7], hosts[8], hosts[9], hosts[10], hosts[11]},
-			orderedToken("40"): {hosts[7], hosts[8], hosts[9], hosts[10], hosts[11], hosts[0]},
-			orderedToken("45"): {hosts[8], hosts[9], hosts[10], hosts[11], hosts[0], hosts[1]},
-			orderedToken("50"): {hosts[9], hosts[10], hosts[11], hosts[0], hosts[1], hosts[2]},
-			orderedToken("55"): {hosts[10], hosts[11], hosts[0], hosts[1], hosts[2], hosts[3]},
-			orderedToken("60"): {hosts[11], hosts[0], hosts[1], hosts[2], hosts[3], hosts[4]},
+			{orderedToken("05"), []*HostInfo{hosts[0], hosts[1], hosts[2], hosts[3], hosts[4], hosts[5]}},
+			{orderedToken("10"), []*HostInfo{hosts[1], hosts[2], hosts[3], hosts[4], hosts[5], hosts[6]}},
+			{orderedToken("15"), []*HostInfo{hosts[2], hosts[3], hosts[4], hosts[5], hosts[6], hosts[7]}},
+			{orderedToken("20"), []*HostInfo{hosts[3], hosts[4], hosts[5], hosts[6], hosts[7], hosts[8]}},
+			{orderedToken("25"), []*HostInfo{hosts[4], hosts[5], hosts[6], hosts[7], hosts[8], hosts[9]}},
+			{orderedToken("30"), []*HostInfo{hosts[5], hosts[6], hosts[7], hosts[8], hosts[9], hosts[10]}},
+			{orderedToken("35"), []*HostInfo{hosts[6], hosts[7], hosts[8], hosts[9], hosts[10], hosts[11]}},
+			{orderedToken("40"), []*HostInfo{hosts[7], hosts[8], hosts[9], hosts[10], hosts[11], hosts[0]}},
+			{orderedToken("45"), []*HostInfo{hosts[8], hosts[9], hosts[10], hosts[11], hosts[0], hosts[1]}},
+			{orderedToken("50"), []*HostInfo{hosts[9], hosts[10], hosts[11], hosts[0], hosts[1], hosts[2]}},
+			{orderedToken("55"), []*HostInfo{hosts[10], hosts[11], hosts[0], hosts[1], hosts[2], hosts[3]}},
+			{orderedToken("60"), []*HostInfo{hosts[11], hosts[0], hosts[1], hosts[2], hosts[3], hosts[4]}},
 		},
 	}, policyInternal.getMetadataReadOnly().replicas)
 
@@ -669,40 +663,29 @@ func TestHostPolicy_TokenAware_DCAwareRR2(t *testing.T) {
 	query.RoutingKey([]byte("23"))
 	iter = policy.Pick(query)
 	// first should be hosts with matching token from the local DC
-	if actual := iter(); actual.Info().HostID() != "4" {
-		t.Errorf("Expected peer 4 but was %s", actual.Info().HostID())
-	}
-	if actual := iter(); actual.Info().HostID() != "7" {
-		t.Errorf("Expected peer 7 but was %s", actual.Info().HostID())
-	}
+	iterCheck(t, iter, "4")
+	iterCheck(t, iter, "7")
 	// rest are according DCAwareRR from local DC only, starting with 7 as the fallback was used twice above
-	if actual := iter(); actual.Info().HostID() != "1" {
-		t.Errorf("Expected peer 1 but was %s", actual.Info().HostID())
-	}
-	if actual := iter(); actual.Info().HostID() != "10" {
-		t.Errorf("Expected peer 10 but was %s", actual.Info().HostID())
-	}
-	// and it starts to repeat now without host 4 and 7...
-	if actual := iter(); actual.Info().HostID() != "1" {
-		t.Errorf("Expected peer 1 but was %s", actual.Info().HostID())
-	}
-	if actual := iter(); actual.Info().HostID() != "10" {
-		t.Errorf("Expected peer 10 but was %s", actual.Info().HostID())
-	}
+	iterCheck(t, iter, "1")
+	iterCheck(t, iter, "10")
 }
 
 // Tests of the token-aware host selection policy implementation with a
 // DC aware round-robin host selection policy fallback with NonLocalReplicasFallback option enabled.
 func TestHostPolicy_TokenAware_DCAwareRR_NonLocalFallback(t *testing.T) {
-	policy := TokenAwareHostPolicy(DCAwareRoundRobinPolicy("local"), NonLocalReplicasFallback())
+	const (
+		keyspace = "myKeyspace"
+		localDC  = "local"
+	)
+	policy := TokenAwareHostPolicy(DCAwareRoundRobinPolicy(localDC), NonLocalReplicasFallback())
 	policyInternal := policy.(*tokenAwareHostPolicy)
-	policyInternal.getKeyspaceName = func() string {return "myKeyspace"}
+	policyInternal.getKeyspaceName = func() string { return keyspace }
 	policyInternal.getKeyspaceMetadata = func(ks string) (*KeyspaceMetadata, error) {
 		return nil, errors.New("not initialized")
 	}
 
 	query := &Query{}
-	query.getKeyspace = func() string {return "myKeyspace"}
+	query.getKeyspace = func() string { return keyspace }
 
 	iter := policy.Pick(nil)
 	if iter == nil {
@@ -716,16 +699,16 @@ func TestHostPolicy_TokenAware_DCAwareRR_NonLocalFallback(t *testing.T) {
 	// set the hosts
 	hosts := [...]*HostInfo{
 		{hostId: "0", connectAddress: net.IPv4(10, 0, 0, 1), tokens: []string{"05"}, dataCenter: "remote1"},
-		{hostId: "1", connectAddress: net.IPv4(10, 0, 0, 2), tokens: []string{"10"}, dataCenter: "local"},
-		{hostId: "2", connectAddress: net.IPv4(10, 0, 0, 3), tokens: []string{"15"}, dataCenter: "remote2"},
-		{hostId: "3", connectAddress: net.IPv4(10, 0, 0, 4), tokens: []string{"20"}, dataCenter: "remote1"},
-		{hostId: "4", connectAddress: net.IPv4(10, 0, 0, 5), tokens: []string{"25"}, dataCenter: "local"},
+		{hostId: "1", connectAddress: net.IPv4(10, 0, 0, 2), tokens: []string{"10"}, dataCenter: localDC},
+		{hostId: "2", connectAddress: net.IPv4(10, 0, 0, 3), tokens: []string{"15"}, dataCenter: "remote2"}, // 1
+		{hostId: "3", connectAddress: net.IPv4(10, 0, 0, 4), tokens: []string{"20"}, dataCenter: "remote1"}, // 2
+		{hostId: "4", connectAddress: net.IPv4(10, 0, 0, 5), tokens: []string{"25"}, dataCenter: localDC},   // 3
 		{hostId: "5", connectAddress: net.IPv4(10, 0, 0, 6), tokens: []string{"30"}, dataCenter: "remote2"},
 		{hostId: "6", connectAddress: net.IPv4(10, 0, 0, 7), tokens: []string{"35"}, dataCenter: "remote1"},
-		{hostId: "7", connectAddress: net.IPv4(10, 0, 0, 8), tokens: []string{"40"}, dataCenter: "local"},
+		{hostId: "7", connectAddress: net.IPv4(10, 0, 0, 8), tokens: []string{"40"}, dataCenter: localDC},
 		{hostId: "8", connectAddress: net.IPv4(10, 0, 0, 9), tokens: []string{"45"}, dataCenter: "remote2"},
 		{hostId: "9", connectAddress: net.IPv4(10, 0, 0, 10), tokens: []string{"50"}, dataCenter: "remote1"},
-		{hostId: "10", connectAddress: net.IPv4(10, 0, 0, 11), tokens: []string{"55"}, dataCenter: "local"},
+		{hostId: "10", connectAddress: net.IPv4(10, 0, 0, 11), tokens: []string{"55"}, dataCenter: localDC},
 		{hostId: "11", connectAddress: net.IPv4(10, 0, 0, 12), tokens: []string{"60"}, dataCenter: "remote2"},
 	}
 	for _, host := range hosts {
@@ -734,50 +717,46 @@ func TestHostPolicy_TokenAware_DCAwareRR_NonLocalFallback(t *testing.T) {
 
 	// the token ring is not setup without the partitioner, but the fallback
 	// should work
-	if actual := policy.Pick(nil)(); actual.Info().HostID() != "1" {
-		t.Errorf("Expected host 1 but was %s", actual.Info().HostID())
-	}
+	iterCheck(t, policy.Pick(nil), "1")
 
 	query.RoutingKey([]byte("30"))
-	if actual := policy.Pick(query)(); actual.Info().HostID() != "4" {
-		t.Errorf("Expected peer 4 but was %s", actual.Info().HostID())
-	}
+	iterCheck(t, policy.Pick(query), "4")
 
 	policy.SetPartitioner("OrderedPartitioner")
 
 	policyInternal.getKeyspaceMetadata = func(keyspaceName string) (*KeyspaceMetadata, error) {
-		if keyspaceName != "myKeyspace" {
+		if keyspaceName != keyspace {
 			return nil, fmt.Errorf("unknown keyspace: %s", keyspaceName)
 		}
 		return &KeyspaceMetadata{
-			Name: "myKeyspace",
+			Name:          keyspace,
 			StrategyClass: "NetworkTopologyStrategy",
-			StrategyOptions: map[string]interface{} {
-				"class": "NetworkTopologyStrategy",
-				"local": 1,
+			StrategyOptions: map[string]interface{}{
+				"class":   "NetworkTopologyStrategy",
+				localDC:   1,
 				"remote1": 1,
 				"remote2": 1,
 			},
 		}, nil
 	}
-	policy.KeyspaceChanged(KeyspaceUpdateEvent{Keyspace: "myKeyspace"})
+	policy.KeyspaceChanged(KeyspaceUpdateEvent{Keyspace: keyspace})
 
 	// The NetworkTopologyStrategy above should generate the following replicas.
 	// It's handy to have as reference here.
-	assertDeepEqual(t, "replicas", map[string]map[token][]*HostInfo{
-		"myKeyspace": {
-			orderedToken("05"): {hosts[0], hosts[1], hosts[2]},
-			orderedToken("10"): {hosts[1], hosts[2], hosts[3]},
-			orderedToken("15"): {hosts[2], hosts[3], hosts[4]},
-			orderedToken("20"): {hosts[3], hosts[4], hosts[5]},
-			orderedToken("25"): {hosts[4], hosts[5], hosts[6]},
-			orderedToken("30"): {hosts[5], hosts[6], hosts[7]},
-			orderedToken("35"): {hosts[6], hosts[7], hosts[8]},
-			orderedToken("40"): {hosts[7], hosts[8], hosts[9]},
-			orderedToken("45"): {hosts[8], hosts[9], hosts[10]},
-			orderedToken("50"): {hosts[9], hosts[10], hosts[11]},
-			orderedToken("55"): {hosts[10], hosts[11], hosts[0]},
-			orderedToken("60"): {hosts[11], hosts[0], hosts[1]},
+	assertDeepEqual(t, "replicas", map[string]tokenRingReplicas{
+		keyspace: {
+			{orderedToken("05"), []*HostInfo{hosts[0], hosts[1], hosts[2]}},
+			{orderedToken("10"), []*HostInfo{hosts[1], hosts[2], hosts[3]}},
+			{orderedToken("15"), []*HostInfo{hosts[2], hosts[3], hosts[4]}},
+			{orderedToken("20"), []*HostInfo{hosts[3], hosts[4], hosts[5]}},
+			{orderedToken("25"), []*HostInfo{hosts[4], hosts[5], hosts[6]}},
+			{orderedToken("30"), []*HostInfo{hosts[5], hosts[6], hosts[7]}},
+			{orderedToken("35"), []*HostInfo{hosts[6], hosts[7], hosts[8]}},
+			{orderedToken("40"), []*HostInfo{hosts[7], hosts[8], hosts[9]}},
+			{orderedToken("45"), []*HostInfo{hosts[8], hosts[9], hosts[10]}},
+			{orderedToken("50"), []*HostInfo{hosts[9], hosts[10], hosts[11]}},
+			{orderedToken("55"), []*HostInfo{hosts[10], hosts[11], hosts[0]}},
+			{orderedToken("60"), []*HostInfo{hosts[11], hosts[0], hosts[1]}},
 		},
 	}, policyInternal.getMetadataReadOnly().replicas)
 
@@ -785,34 +764,12 @@ func TestHostPolicy_TokenAware_DCAwareRR_NonLocalFallback(t *testing.T) {
 	query.RoutingKey([]byte("18"))
 	iter = policy.Pick(query)
 	// first should be host with matching token from the local DC
-	if actual := iter(); actual.Info().HostID() != "4" {
-		t.Errorf("Expected peer 4 but was %s", actual.Info().HostID())
-	}
+	iterCheck(t, iter, "4")
 	// rest should be hosts with matching token from remote DCs
-	if actual := iter(); actual.Info().HostID() != "3" {
-		t.Errorf("Expected peer 3 but was %s", actual.Info().HostID())
-	}
-	if actual := iter(); actual.Info().HostID() != "5" {
-		t.Errorf("Expected peer 5 but was %s", actual.Info().HostID())
-	}
+	iterCheck(t, iter, "3")
+	iterCheck(t, iter, "5")
 	// rest are according DCAwareRR from local DC only, starting with 7 as the fallback was used twice above
-	if actual := iter(); actual.Info().HostID() != "7" {
-		t.Errorf("Expected peer 7 but was %s", actual.Info().HostID())
-	}
-	if actual := iter(); actual.Info().HostID() != "10" {
-		t.Errorf("Expected peer 10 but was %s", actual.Info().HostID())
-	}
-	if actual := iter(); actual.Info().HostID() != "1" {
-		t.Errorf("Expected peer 1 but was %s", actual.Info().HostID())
-	}
-	// and it starts to repeat now without host 4...
-	if actual := iter(); actual.Info().HostID() != "7" {
-		t.Errorf("Expected peer 7 but was %s", actual.Info().HostID())
-	}
-	if actual := iter(); actual.Info().HostID() != "10" {
-		t.Errorf("Expected peer 10 but was %s", actual.Info().HostID())
-	}
-	if actual := iter(); actual.Info().HostID() != "1" {
-		t.Errorf("Expected peer 1 but was %s", actual.Info().HostID())
-	}
+	iterCheck(t, iter, "7")
+	iterCheck(t, iter, "10")
+	iterCheck(t, iter, "1")
 }

+ 11 - 11
session.go

@@ -722,7 +722,7 @@ func (qm *queryMetrics) latency() int64 {
 	qm.l.Lock()
 	var (
 		attempts int
-		latency int64
+		latency  int64
 	)
 	for _, metric := range qm.m {
 		attempts += metric.Attempts
@@ -1509,16 +1509,16 @@ func NewBatch(typ BatchType) *Batch {
 func (s *Session) NewBatch(typ BatchType) *Batch {
 	s.mu.RLock()
 	batch := &Batch{
-		Type:              typ,
-		rt:                s.cfg.RetryPolicy,
-		serialCons:        s.cfg.SerialConsistency,
-		observer:          s.batchObserver,
-		session:           s,
-		Cons:              s.cons,
-		defaultTimestamp:  s.cfg.DefaultTimestamp,
-		keyspace:          s.cfg.Keyspace,
-		metrics:           &queryMetrics{m: make(map[string]*hostMetrics)},
-		spec:              &NonSpeculativeExecution{},
+		Type:             typ,
+		rt:               s.cfg.RetryPolicy,
+		serialCons:       s.cfg.SerialConsistency,
+		observer:         s.batchObserver,
+		session:          s,
+		Cons:             s.cons,
+		defaultTimestamp: s.cfg.DefaultTimestamp,
+		keyspace:         s.cfg.Keyspace,
+		metrics:          &queryMetrics{m: make(map[string]*hostMetrics)},
+		spec:             &NonSpeculativeExecution{},
 	}
 
 	s.mu.RUnlock()

+ 8 - 5
token.go

@@ -131,10 +131,13 @@ func (ht hostToken) String() string {
 type tokenRing struct {
 	partitioner partitioner
 	tokens      []hostToken
+	hosts       []*HostInfo
 }
 
 func newTokenRing(partitioner string, hosts []*HostInfo) (*tokenRing, error) {
-	tokenRing := &tokenRing{}
+	tokenRing := &tokenRing{
+		hosts: hosts,
+	}
 
 	if strings.HasSuffix(partitioner, "Murmur3Partitioner") {
 		tokenRing.partitioner = murmur3Partitioner{}
@@ -206,15 +209,15 @@ func (t *tokenRing) GetHostForToken(token token) (host *HostInfo, endToken token
 	}
 
 	// find the primary replica
-	ringIndex := sort.Search(len(t.tokens), func(i int) bool {
+	p := sort.Search(len(t.tokens), func(i int) bool {
 		return !t.tokens[i].token.Less(token)
 	})
 
-	if ringIndex == len(t.tokens) {
+	if p == len(t.tokens) {
 		// wrap around to the first in the ring
-		ringIndex = 0
+		p = 0
 	}
 
-	v := t.tokens[ringIndex]
+	v := t.tokens[p]
 	return v.host, v.token
 }

+ 59 - 17
topology.go

@@ -2,12 +2,40 @@ package gocql
 
 import (
 	"fmt"
+	"sort"
 	"strconv"
 	"strings"
 )
 
+type hostTokens struct {
+	token token
+	hosts []*HostInfo
+}
+
+type tokenRingReplicas []hostTokens
+
+func (h tokenRingReplicas) Less(i, j int) bool { return h[i].token.Less(h[j].token) }
+func (h tokenRingReplicas) Len() int           { return len(h) }
+func (h tokenRingReplicas) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
+
+func (h tokenRingReplicas) replicasFor(t token) *hostTokens {
+	if len(h) == 0 {
+		return nil
+	}
+
+	p := sort.Search(len(h), func(i int) bool {
+		return !h[i].token.Less(t)
+	})
+	if p >= len(h) {
+		// rollover
+		p = 0
+	}
+
+	return &h[p]
+}
+
 type placementStrategy interface {
-	replicaMap(hosts []*HostInfo, tokens []hostToken) map[token][]*HostInfo
+	replicaMap(tokenRing *tokenRing) tokenRingReplicas
 	replicationFactor(dc string) int
 }
 
@@ -63,20 +91,29 @@ func (s *simpleStrategy) replicationFactor(dc string) int {
 	return s.rf
 }
 
-func (s *simpleStrategy) replicaMap(_ []*HostInfo, tokens []hostToken) map[token][]*HostInfo {
-	tokenRing := make(map[token][]*HostInfo, len(tokens))
+func (s *simpleStrategy) replicaMap(tokenRing *tokenRing) tokenRingReplicas {
+	tokens := tokenRing.tokens
+	ring := make(tokenRingReplicas, len(tokens))
 
 	for i, th := range tokens {
 		replicas := make([]*HostInfo, 0, s.rf)
+		seen := make(map[*HostInfo]bool)
+
 		for j := 0; j < len(tokens) && len(replicas) < s.rf; j++ {
 			// TODO: need to ensure we dont add the same hosts twice
 			h := tokens[(i+j)%len(tokens)]
-			replicas = append(replicas, h.host)
+			if !seen[h.host] {
+				replicas = append(replicas, h.host)
+				seen[h.host] = true
+			}
 		}
-		tokenRing[th.token] = replicas
+
+		ring[i] = hostTokens{th.token, replicas}
 	}
 
-	return tokenRing
+	sort.Sort(ring)
+
+	return ring
 }
 
 type networkTopology struct {
@@ -101,10 +138,10 @@ func (n *networkTopology) haveRF(replicaCounts map[string]int) bool {
 	return true
 }
 
-func (n *networkTopology) replicaMap(hosts []*HostInfo, tokens []hostToken) map[token][]*HostInfo {
-	dcRacks := make(map[string]map[string]struct{})
+func (n *networkTopology) replicaMap(tokenRing *tokenRing) tokenRingReplicas {
+	dcRacks := make(map[string]map[string]struct{}, len(n.dcs))
 
-	for _, h := range hosts {
+	for _, h := range tokenRing.hosts {
 		dc := h.DataCenter()
 		rack := h.Rack()
 
@@ -116,14 +153,15 @@ func (n *networkTopology) replicaMap(hosts []*HostInfo, tokens []hostToken) map[
 		racks[rack] = struct{}{}
 	}
 
-	tokenRing := make(map[token][]*HostInfo, len(tokens))
+	tokens := tokenRing.tokens
+	replicaRing := make(tokenRingReplicas, len(tokens))
 
 	var totalRF int
 	for _, rf := range n.dcs {
 		totalRF += rf
 	}
 
-	for i, th := range tokens {
+	for i, th := range tokenRing.tokens {
 		// number of replicas per dc
 		// TODO: recycle these
 		replicasInDC := make(map[string]int, len(n.dcs))
@@ -199,16 +237,20 @@ func (n *networkTopology) replicaMap(hosts []*HostInfo, tokens []hostToken) map[
 			}
 		}
 
-		if len(replicas) == 0 || replicas[0] != th.host {
-			panic("first replica is not the primary replica for the token")
+		if len(replicas) == 0 {
+			panic(fmt.Sprintf("no replicas for token: %v", th.token))
+		} else if !replicas[0].Equal(th.host) {
+			panic(fmt.Sprintf("first replica is not the primary replica for the token: expected %v got %v", replicas[0].ConnectAddress(), th.host.ConnectAddress()))
 		}
 
-		tokenRing[th.token] = replicas
+		replicaRing[i] = hostTokens{th.token, replicas}
 	}
 
-	if len(tokenRing) != len(tokens) {
-		panic(fmt.Sprintf("token map different size to token ring: got %d expected %d", len(tokenRing), len(tokens)))
+	if len(replicaRing) != len(tokens) {
+		panic(fmt.Sprintf("token map different size to token ring: got %d expected %d", len(replicaRing), len(tokens)))
 	}
 
-	return tokenRing
+	sort.Sort(replicaRing)
+
+	return replicaRing
 }

+ 20 - 20
topology_test.go

@@ -12,7 +12,7 @@ func TestPlacementStrategy_SimpleStrategy(t *testing.T) {
 	host50 := &HostInfo{hostId: "50"}
 	host75 := &HostInfo{hostId: "75"}
 
-	tokenRing := []hostToken{
+	tokens := []hostToken{
 		{intToken(0), host0},
 		{intToken(25), host25},
 		{intToken(50), host50},
@@ -22,27 +22,27 @@ func TestPlacementStrategy_SimpleStrategy(t *testing.T) {
 	hosts := []*HostInfo{host0, host25, host50, host75}
 
 	strat := &simpleStrategy{rf: 2}
-	tokenReplicas := strat.replicaMap(hosts, tokenRing)
-	if len(tokenReplicas) != len(tokenRing) {
-		t.Fatalf("expected replica map to have %d items but has %d", len(tokenRing), len(tokenReplicas))
+	tokenReplicas := strat.replicaMap(&tokenRing{hosts: hosts, tokens: tokens})
+	if len(tokenReplicas) != len(tokens) {
+		t.Fatalf("expected replica map to have %d items but has %d", len(tokens), len(tokenReplicas))
 	}
 
-	for token, replicas := range tokenReplicas {
-		if len(replicas) != strat.rf {
-			t.Errorf("expected to have %d replicas got %d for token=%v", strat.rf, len(replicas), token)
+	for _, replicas := range tokenReplicas {
+		if len(replicas.hosts) != strat.rf {
+			t.Errorf("expected to have %d replicas got %d for token=%v", strat.rf, len(replicas.hosts), replicas.token)
 		}
 	}
 
-	for i, token := range tokenRing {
-		replicas, ok := tokenReplicas[token.token]
-		if !ok {
-			t.Errorf("token %v not in replica map", token)
+	for i, token := range tokens {
+		ht := tokenReplicas.replicasFor(token.token)
+		if ht.token != token.token {
+			t.Errorf("token %v not in replica map: %v", token, ht.hosts)
 		}
 
-		for j, replica := range replicas {
-			exp := tokenRing[(i+j)%len(tokenRing)].host
+		for j, replica := range ht.hosts {
+			exp := tokens[(i+j)%len(tokens)].host
 			if exp != replica {
-				t.Errorf("expected host %v to be a replica of %v got %v", exp, token, replica)
+				t.Errorf("expected host %v to be a replica of %v got %v", exp.hostId, token, replica.hostId)
 			}
 		}
 	}
@@ -103,14 +103,14 @@ func TestPlacementStrategy_NetworkStrategy(t *testing.T) {
 		expReplicas += rf
 	}
 
-	tokenReplicas := strat.replicaMap(hosts, tokens)
+	tokenReplicas := strat.replicaMap(&tokenRing{hosts: hosts, tokens: tokens})
 	if len(tokenReplicas) != len(tokens) {
 		t.Fatalf("expected replica map to have %d items but has %d", len(tokens), len(tokenReplicas))
 	}
 
 	for token, replicas := range tokenReplicas {
-		if len(replicas) != expReplicas {
-			t.Fatalf("expected to have %d replicas got %d for token=%v", expReplicas, len(replicas), token)
+		if len(replicas.hosts) != expReplicas {
+			t.Fatalf("expected to have %d replicas got %d for token=%v", expReplicas, len(replicas.hosts), token)
 		}
 	}
 
@@ -118,13 +118,13 @@ func TestPlacementStrategy_NetworkStrategy(t *testing.T) {
 		dcTokens := dcRing[dc]
 		for i, th := range dcTokens {
 			token := th.token
-			allReplicas, ok := tokenReplicas[token]
-			if !ok {
+			allReplicas := tokenReplicas.replicasFor(token)
+			if allReplicas.token != token {
 				t.Fatalf("token %v not in replica map", token)
 			}
 
 			var replicas []*HostInfo
-			for _, replica := range allReplicas {
+			for _, replica := range allReplicas.hosts {
 				if replica.dataCenter == dc {
 					replicas = append(replicas, replica)
 				}