Browse Source

policy: round-robin should try all up hosts (#1351)

Instead of keeping a global position to iterate around a ring of hosts,
just shuffle all the hosts and return those which are up.

Fix finding the correct token in the token ring for host selection.

Simplify the tests so that they only test the things they are intended to
test; let other tests verify that DCAwareRR works, for example.
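
A NextHost iterator returned by Pick is drained until it yields nil, which is how the reworked tests below consume it. A minimal sketch of caller-side use, where execute and qry are hypothetical:

	next := policy.Pick(qry) // qry: a hypothetical ExecutableQuery
	for selected := next(); selected != nil; selected = next() {
		err := execute(selected.Info()) // execute: a hypothetical helper
		if err == nil {
			break // first up host that serves the query wins
		}
		selected.Mark(err) // report the failure to the policy
	}

Each call to Pick now shuffles its own copy of the host list, so concurrent iterators see different orders while each still visits every up host at most once.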
Chris Bannister 6 years ago
parent
commit
7b17705d75
3 changed files with 135 additions and 268 deletions
  1. policies.go (+56 -45)
  2. policies_test.go (+68 -222)
  3. topology.go (+11 -1)

+ 56 - 45
policies.go

@@ -6,6 +6,8 @@ package gocql
 
 import (
 	"context"
+	crand "crypto/rand"
+	"encoding/binary"
 	"errors"
 	"fmt"
 	"math"
@@ -334,8 +336,6 @@ func RoundRobinHostPolicy() HostSelectionPolicy {
 
 type roundRobinHostPolicy struct {
 	hosts cowHostList
-	pos   uint32
-	mu    sync.RWMutex
 }
 
 func (r *roundRobinHostPolicy) IsLocal(*HostInfo) bool              { return true }
@@ -344,25 +344,16 @@ func (r *roundRobinHostPolicy) SetPartitioner(partitioner string)   {}
 func (r *roundRobinHostPolicy) Init(*Session)                       {}
 
 func (r *roundRobinHostPolicy) Pick(qry ExecutableQuery) NextHost {
-	// i is used to limit the number of attempts to find a host
-	// to the number of hosts known to this policy
-	var i int
-	return func() SelectedHost {
-		hosts := r.hosts.get()
-		if len(hosts) == 0 {
-			return nil
-		}
+	src := r.hosts.get()
+	hosts := make([]*HostInfo, len(src))
+	copy(hosts, src)
 
-		// always increment pos to evenly distribute traffic in case of
-		// failures
-		pos := atomic.AddUint32(&r.pos, 1) - 1
-		if i >= len(hosts) {
-			return nil
-		}
-		host := hosts[(pos)%uint32(len(hosts))]
-		i++
-		return (*selectedHost)(host)
-	}
+	rand := rand.New(randSource())
+	rand.Shuffle(len(hosts), func(i, j int) {
+		hosts[i], hosts[j] = hosts[j], hosts[i]
+	})
+
+	return roundRobbin(hosts)
 }
 
 func (r *roundRobinHostPolicy) AddHost(host *HostInfo) {
@@ -585,8 +576,8 @@ func (t *tokenAwareHostPolicy) Pick(qry ExecutableQuery) NextHost {
 
 	token := meta.tokenRing.partitioner.Hash(routingKey)
 	ht := meta.replicas[qry.Keyspace()].replicasFor(token)
-	var replicas []*HostInfo
 
+	var replicas []*HostInfo
 	if ht == nil {
 		host, _ := meta.tokenRing.GetHostForToken(token)
 		replicas = []*HostInfo{host}
@@ -792,8 +783,6 @@ func (host selectedHostPoolHost) Mark(err error) {
 
 type dcAwareRR struct {
 	local       string
-	pos         uint32
-	mu          sync.RWMutex
 	localHosts  cowHostList
 	remoteHosts cowHostList
 }
@@ -814,7 +803,7 @@ func (d *dcAwareRR) IsLocal(host *HostInfo) bool {
 }
 
 func (d *dcAwareRR) AddHost(host *HostInfo) {
-	if host.DataCenter() == d.local {
+	if d.IsLocal(host) {
 		d.localHosts.add(host)
 	} else {
 		d.remoteHosts.add(host)
@@ -822,7 +811,7 @@ func (d *dcAwareRR) AddHost(host *HostInfo) {
 }
 
 func (d *dcAwareRR) RemoveHost(host *HostInfo) {
-	if host.DataCenter() == d.local {
+	if d.IsLocal(host) {
 		d.localHosts.remove(host.ConnectAddress())
 	} else {
 		d.remoteHosts.remove(host.ConnectAddress())
@@ -832,33 +821,55 @@ func (d *dcAwareRR) RemoveHost(host *HostInfo) {
 func (d *dcAwareRR) HostUp(host *HostInfo)   { d.AddHost(host) }
 func (d *dcAwareRR) HostDown(host *HostInfo) { d.RemoveHost(host) }
 
-func (d *dcAwareRR) Pick(q ExecutableQuery) NextHost {
+var randSeed int64
+
+func init() {
+	p := make([]byte, 8)
+	if _, err := crand.Read(p); err != nil {
+		panic(err)
+	}
+	randSeed = int64(binary.BigEndian.Uint64(p))
+}
+
+func randSource() rand.Source {
+	return rand.NewSource(atomic.AddInt64(&randSeed, 1))
+}
+
+func roundRobbin(hosts []*HostInfo) NextHost {
 	var i int
 	return func() SelectedHost {
-		var hosts []*HostInfo
-		localHosts := d.localHosts.get()
-		remoteHosts := d.remoteHosts.get()
-		if len(localHosts) != 0 {
-			hosts = localHosts
-		} else {
-			hosts = remoteHosts
-		}
-		if len(hosts) == 0 {
-			return nil
-		}
+		for i < len(hosts) {
+			h := hosts[i]
+			i++
 
-		// always increment pos to evenly distribute traffic in case of
-		// failures
-		pos := atomic.AddUint32(&d.pos, 1) - 1
-		if i >= len(localHosts)+len(remoteHosts) {
-			return nil
+			if h.IsUp() {
+				return (*selectedHost)(h)
+			}
 		}
-		host := hosts[(pos)%uint32(len(hosts))]
-		i++
-		return (*selectedHost)(host)
+
+		return nil
 	}
 }
 
+func (d *dcAwareRR) Pick(q ExecutableQuery) NextHost {
+	local := d.localHosts.get()
+	remote := d.remoteHosts.get()
+	hosts := make([]*HostInfo, len(local)+len(remote))
+	n := copy(hosts, local)
+	copy(hosts[n:], remote)
+
+	// TODO: use random choose-2 but that will require plumbing information
+	// about connection/host load to here
+	r := rand.New(randSource())
+	for _, l := range [][]*HostInfo{local, remote} {
+		r.Shuffle(len(l), func(i, j int) {
+			l[i], l[j] = l[j], l[i]
+		})
+	}
+
+	return roundRobbin(hosts)
+}
+
 // ConvictionPolicy interface is used by gocql to determine if a host should be
 // marked as DOWN based on the error and host info
 type ConvictionPolicy interface {
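
The crypto-seeded counter above gives every Pick its own deterministic rand.Source, since a shared *rand.Rand is not safe for concurrent use. A standalone sketch of the same seeding pattern, using hypothetical names:

	package main

	import (
		crand "crypto/rand"
		"encoding/binary"
		"fmt"
		"math/rand"
		"sync/atomic"
	)

	var seed int64

	func init() {
		// Seed the counter once from crypto/rand so runs are not predictable.
		var b [8]byte
		if _, err := crand.Read(b[:]); err != nil {
			panic(err)
		}
		seed = int64(binary.BigEndian.Uint64(b[:]))
	}

	// newSource returns a distinct source per call by atomically bumping
	// the seed, avoiding a mutex around one shared Rand.
	func newSource() rand.Source {
		return rand.NewSource(atomic.AddInt64(&seed, 1))
	}

	func main() {
		hosts := []string{"a", "b", "c", "d"}
		r := rand.New(newSource())
		r.Shuffle(len(hosts), func(i, j int) { hosts[i], hosts[j] = hosts[j], hosts[i] })
		fmt.Println(hosts) // e.g. [c a d b]
	}

Note that dcAwareRR shuffles the local and remote slices independently before concatenating them, so shuffled local hosts are still offered before any remote host.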

+ 68 - 222
policies_test.go

@@ -15,7 +15,7 @@ import (
 )
 
 // Tests of the round-robin host selection policy implementation
-func TestHostPolicy_RoundRobin(t *testing.T) {
+func TestRoundRobbin(t *testing.T) {
 	policy := RoundRobinHostPolicy()
 
 	hosts := [...]*HostInfo{
@@ -27,34 +27,23 @@ func TestHostPolicy_RoundRobin(t *testing.T) {
 		policy.AddHost(host)
 	}
 
-	// interleaved iteration should always increment the host
-	iterA := policy.Pick(nil)
-	if actual := iterA(); actual.Info() != hosts[0] {
-		t.Errorf("Expected hosts[0] but was hosts[%s]", actual.Info().HostID())
-	}
-	iterB := policy.Pick(nil)
-	if actual := iterB(); actual.Info() != hosts[1] {
-		t.Errorf("Expected hosts[1] but was hosts[%s]", actual.Info().HostID())
-	}
-	if actual := iterB(); actual.Info() != hosts[0] {
-		t.Errorf("Expected hosts[0] but was hosts[%s]", actual.Info().HostID())
-	}
-	if actual := iterA(); actual.Info() != hosts[1] {
-		t.Errorf("Expected hosts[1] but was hosts[%s]", actual.Info().HostID())
-	}
-
-	iterC := policy.Pick(nil)
-	if actual := iterC(); actual.Info() != hosts[0] {
-		t.Errorf("Expected hosts[0] but was hosts[%s]", actual.Info().HostID())
+	got := make(map[string]bool)
+	it := policy.Pick(nil)
+	for h := it(); h != nil; h = it() {
+		id := h.Info().hostId
+		if got[id] {
+			t.Fatalf("got duplicate host: %v", id)
+		}
+		got[id] = true
 	}
-	if actual := iterC(); actual.Info() != hosts[1] {
-		t.Errorf("Expected hosts[1] but was hosts[%s]", actual.Info().HostID())
+	if len(got) != len(hosts) {
+		t.Fatalf("expected %d hosts got %d", len(hosts), len(got))
 	}
 }
 
 // Tests of the token-aware host selection policy implementation with a
 // round-robin host selection policy fallback.
-func TestHostPolicy_TokenAware(t *testing.T) {
+func TestHostPolicy_TokenAware_SimpleStrategy(t *testing.T) {
 	const keyspace = "myKeyspace"
 	policy := TokenAwareHostPolicy(RoundRobinHostPolicy())
 	policyInternal := policy.(*tokenAwareHostPolicy)
@@ -77,26 +66,15 @@ func TestHostPolicy_TokenAware(t *testing.T) {
 
 	// set the hosts
 	hosts := [...]*HostInfo{
-		{connectAddress: net.IPv4(10, 0, 0, 1), tokens: []string{"00"}},
-		{connectAddress: net.IPv4(10, 0, 0, 2), tokens: []string{"25"}},
-		{connectAddress: net.IPv4(10, 0, 0, 3), tokens: []string{"50"}},
-		{connectAddress: net.IPv4(10, 0, 0, 4), tokens: []string{"75"}},
+		{hostId: "0", connectAddress: net.IPv4(10, 0, 0, 1), tokens: []string{"00"}},
+		{hostId: "1", connectAddress: net.IPv4(10, 0, 0, 2), tokens: []string{"25"}},
+		{hostId: "2", connectAddress: net.IPv4(10, 0, 0, 3), tokens: []string{"50"}},
+		{hostId: "3", connectAddress: net.IPv4(10, 0, 0, 4), tokens: []string{"75"}},
 	}
-	for _, host := range hosts {
+	for _, host := range &hosts {
 		policy.AddHost(host)
 	}
 
-	// the token ring is not setup without the partitioner, but the fallback
-	// should work
-	if actual := policy.Pick(nil)(); !actual.Info().ConnectAddress().Equal(hosts[0].ConnectAddress()) {
-		t.Errorf("Expected peer 0 but was %s", actual.Info().ConnectAddress())
-	}
-
-	query.RoutingKey([]byte("30"))
-	if actual := policy.Pick(query)(); !actual.Info().ConnectAddress().Equal(hosts[1].ConnectAddress()) {
-		t.Errorf("Expected peer 1 but was %s", actual.Info().ConnectAddress())
-	}
-
 	policy.SetPartitioner("OrderedPartitioner")
 
 	policyInternal.getKeyspaceMetadata = func(keyspaceName string) (*KeyspaceMetadata, error) {
@@ -128,19 +106,8 @@ func TestHostPolicy_TokenAware(t *testing.T) {
 	// now the token ring is configured
 	query.RoutingKey([]byte("20"))
 	iter = policy.Pick(query)
-	if actual := iter(); !actual.Info().ConnectAddress().Equal(hosts[1].ConnectAddress()) {
-		t.Errorf("Expected peer 1 but was %s", actual.Info().ConnectAddress())
-	}
-	// rest are round robin
-	if actual := iter(); !actual.Info().ConnectAddress().Equal(hosts[2].ConnectAddress()) {
-		t.Errorf("Expected peer 2 but was %s", actual.Info().ConnectAddress())
-	}
-	if actual := iter(); !actual.Info().ConnectAddress().Equal(hosts[3].ConnectAddress()) {
-		t.Errorf("Expected peer 3 but was %s", actual.Info().ConnectAddress())
-	}
-	if actual := iter(); !actual.Info().ConnectAddress().Equal(hosts[0].ConnectAddress()) {
-		t.Errorf("Expected peer 0 but was %s", actual.Info().ConnectAddress())
-	}
+	iterCheck(t, iter, "0")
+	iterCheck(t, iter, "1")
 }
 
 // Tests of the host pool host selection policy implementation
@@ -427,36 +394,34 @@ func TestHostPolicy_DCAwareRR(t *testing.T) {
 		p.AddHost(host)
 	}
 
-	// interleaved iteration should always increment the host
-	iterA := p.Pick(nil)
-	if actual := iterA(); actual.Info() != hosts[0] {
-		t.Errorf("Expected hosts[0] but was hosts[%s]", actual.Info().HostID())
-	}
-	iterB := p.Pick(nil)
-	if actual := iterB(); actual.Info() != hosts[1] {
-		t.Errorf("Expected hosts[1] but was hosts[%s]", actual.Info().HostID())
-	}
-	if actual := iterB(); actual.Info() != hosts[0] {
-		t.Errorf("Expected hosts[0] but was hosts[%s]", actual.Info().HostID())
-	}
-	if actual := iterA(); actual.Info() != hosts[1] {
-		t.Errorf("Expected hosts[1] but was hosts[%s]", actual.Info().HostID())
-	}
-	iterC := p.Pick(nil)
-	if actual := iterC(); actual.Info() != hosts[0] {
-		t.Errorf("Expected hosts[0] but was hosts[%s]", actual.Info().HostID())
-	}
-	p.RemoveHost(hosts[0])
-	if actual := iterC(); actual.Info() != hosts[1] {
-		t.Errorf("Expected hosts[1] but was hosts[%s]", actual.Info().HostID())
+	got := make(map[string]bool, len(hosts))
+	var dcs []string
+
+	it := p.Pick(nil)
+	for h := it(); h != nil; h = it() {
+		id := h.Info().hostId
+		dc := h.Info().dataCenter
+
+		if got[id] {
+			t.Fatalf("got duplicate host %s", id)
+		}
+		got[id] = true
+		dcs = append(dcs, dc)
 	}
-	p.RemoveHost(hosts[1])
-	iterD := p.Pick(nil)
-	if actual := iterD(); actual.Info() != hosts[2] {
-		t.Errorf("Expected hosts[2] but was hosts[%s]", actual.Info().HostID())
+
+	if len(got) != len(hosts) {
+		t.Fatalf("expected %d hosts got %d", len(hosts), len(got))
 	}
-	if actual := iterD(); actual.Info() != hosts[3] {
-		t.Errorf("Expected hosts[3] but was hosts[%s]", actual.Info().HostID())
+
+	var remote bool
+	for _, dc := range dcs {
+		if dc == "local" {
+			if remote {
+				t.Fatalf("got local dc after remote: %v", dcs)
+			}
+		} else {
+			remote = true
+		}
 	}
 
 }
@@ -464,7 +429,7 @@ func TestHostPolicy_DCAwareRR(t *testing.T) {
 // Tests of the token-aware host selection policy implementation with a
 // DC aware round-robin host selection policy fallback
 // with {"class": "NetworkTopologyStrategy", "a": 1, "b": 1, "c": 1} replication.
-func TestHostPolicy_TokenAware_DCAwareRR(t *testing.T) {
+func TestHostPolicy_TokenAware(t *testing.T) {
 	const keyspace = "myKeyspace"
 	policy := TokenAwareHostPolicy(DCAwareRoundRobinPolicy("local"))
 	policyInternal := policy.(*tokenAwareHostPolicy)
@@ -506,13 +471,13 @@ func TestHostPolicy_TokenAware_DCAwareRR(t *testing.T) {
 
 	// the token ring is not setup without the partitioner, but the fallback
 	// should work
-	if actual := policy.Pick(nil)(); actual.Info().HostID() != "1" {
-		t.Fatalf("Expected host 1 but was %s", actual.Info().HostID())
+	if actual := policy.Pick(nil)(); actual == nil {
+		t.Fatal("expected to get host from fallback got nil")
 	}
 
 	query.RoutingKey([]byte("30"))
-	if actual := policy.Pick(query)(); actual.Info().HostID() != "4" {
-		t.Fatalf("Expected peer 4 but was %s", actual.Info().HostID())
+	if actual := policy.Pick(query)(); actual == nil {
+		t.Fatal("expected to get host from fallback got nil")
 	}
 
 	policy.SetPartitioner("OrderedPartitioner")
@@ -558,25 +523,23 @@ func TestHostPolicy_TokenAware_DCAwareRR(t *testing.T) {
 	iter = policy.Pick(query)
 	// first should be host with matching token from the local DC
 	iterCheck(t, iter, "4")
-	// rest are according DCAwareRR from local DC only, starting with 7 as the fallback was used twice above
-	iterCheck(t, iter, "7")
-	iterCheck(t, iter, "10")
-	iterCheck(t, iter, "1")
+	// next are in non-deterministic order
 }
 
 // Tests of the token-aware host selection policy implementation with a
 // DC aware round-robin host selection policy fallback
 // with {"class": "NetworkTopologyStrategy", "a": 2, "b": 2, "c": 2} replication.
-func TestHostPolicy_TokenAware_DCAwareRR2(t *testing.T) {
-	policy := TokenAwareHostPolicy(DCAwareRoundRobinPolicy("local"))
+func TestHostPolicy_TokenAware_NetworkStrategy(t *testing.T) {
+	const keyspace = "myKeyspace"
+	policy := TokenAwareHostPolicy(DCAwareRoundRobinPolicy("local"), NonLocalReplicasFallback())
 	policyInternal := policy.(*tokenAwareHostPolicy)
-	policyInternal.getKeyspaceName = func() string { return "myKeyspace" }
+	policyInternal.getKeyspaceName = func() string { return keyspace }
 	policyInternal.getKeyspaceMetadata = func(ks string) (*KeyspaceMetadata, error) {
 		return nil, errors.New("not initialized")
 	}
 
 	query := &Query{}
-	query.getKeyspace = func() string { return "myKeyspace" }
+	query.getKeyspace = func() string { return keyspace }
 
 	iter := policy.Pick(nil)
 	if iter == nil {
@@ -592,12 +555,12 @@ func TestHostPolicy_TokenAware_DCAwareRR2(t *testing.T) {
 		{hostId: "0", connectAddress: net.IPv4(10, 0, 0, 1), tokens: []string{"05"}, dataCenter: "remote1"},
 		{hostId: "1", connectAddress: net.IPv4(10, 0, 0, 2), tokens: []string{"10"}, dataCenter: "local"},
 		{hostId: "2", connectAddress: net.IPv4(10, 0, 0, 3), tokens: []string{"15"}, dataCenter: "remote2"},
-		{hostId: "3", connectAddress: net.IPv4(10, 0, 0, 4), tokens: []string{"20"}, dataCenter: "remote1"},
-		{hostId: "4", connectAddress: net.IPv4(10, 0, 0, 5), tokens: []string{"25"}, dataCenter: "local"},
-		{hostId: "5", connectAddress: net.IPv4(10, 0, 0, 6), tokens: []string{"30"}, dataCenter: "remote2"},
-		{hostId: "6", connectAddress: net.IPv4(10, 0, 0, 7), tokens: []string{"35"}, dataCenter: "remote1"},
-		{hostId: "7", connectAddress: net.IPv4(10, 0, 0, 8), tokens: []string{"40"}, dataCenter: "local"},
-		{hostId: "8", connectAddress: net.IPv4(10, 0, 0, 9), tokens: []string{"45"}, dataCenter: "remote2"},
+		{hostId: "3", connectAddress: net.IPv4(10, 0, 0, 4), tokens: []string{"20"}, dataCenter: "remote1"}, // 1
+		{hostId: "4", connectAddress: net.IPv4(10, 0, 0, 5), tokens: []string{"25"}, dataCenter: "local"},   // 2
+		{hostId: "5", connectAddress: net.IPv4(10, 0, 0, 6), tokens: []string{"30"}, dataCenter: "remote2"}, // 3
+		{hostId: "6", connectAddress: net.IPv4(10, 0, 0, 7), tokens: []string{"35"}, dataCenter: "remote1"}, // 4
+		{hostId: "7", connectAddress: net.IPv4(10, 0, 0, 8), tokens: []string{"40"}, dataCenter: "local"},   // 5
+		{hostId: "8", connectAddress: net.IPv4(10, 0, 0, 9), tokens: []string{"45"}, dataCenter: "remote2"}, // 6
 		{hostId: "9", connectAddress: net.IPv4(10, 0, 0, 10), tokens: []string{"50"}, dataCenter: "remote1"},
 		{hostId: "10", connectAddress: net.IPv4(10, 0, 0, 11), tokens: []string{"55"}, dataCenter: "local"},
 		{hostId: "11", connectAddress: net.IPv4(10, 0, 0, 12), tokens: []string{"60"}, dataCenter: "remote2"},
@@ -606,29 +569,14 @@ func TestHostPolicy_TokenAware_DCAwareRR2(t *testing.T) {
 		policy.AddHost(host)
 	}
 
-	// the token ring is not setup without the partitioner, but the fallback
-	// should work
-	if actual := policy.Pick(nil)(); actual.Info().HostID() != "1" {
-		t.Errorf("Expected host 1 but was %s", actual.Info().HostID())
-	}
-
-	query.RoutingKey([]byte("30"))
-	if actual := policy.Pick(query)(); actual.Info().HostID() != "4" {
-		t.Errorf("Expected peer 4 but was %s", actual.Info().HostID())
-	}
-
-	// advance the index in round-robin so that the next expected value does not overlap with the one selected by token.
-	policy.Pick(query)()
-	policy.Pick(query)()
-
 	policy.SetPartitioner("OrderedPartitioner")
 
 	policyInternal.getKeyspaceMetadata = func(keyspaceName string) (*KeyspaceMetadata, error) {
-		if keyspaceName != "myKeyspace" {
+		if keyspaceName != keyspace {
 			return nil, fmt.Errorf("unknown keyspace: %s", keyspaceName)
 		}
 		return &KeyspaceMetadata{
-			Name:          "myKeyspace",
+			Name:          keyspace,
 			StrategyClass: "NetworkTopologyStrategy",
 			StrategyOptions: map[string]interface{}{
 				"class":   "NetworkTopologyStrategy",
@@ -638,12 +586,12 @@ func TestHostPolicy_TokenAware_DCAwareRR2(t *testing.T) {
 			},
 		}, nil
 	}
-	policy.KeyspaceChanged(KeyspaceUpdateEvent{Keyspace: "myKeyspace"})
+	policy.KeyspaceChanged(KeyspaceUpdateEvent{Keyspace: keyspace})
 
 	// The NetworkTopologyStrategy above should generate the following replicas.
 	// It's handy to have as reference here.
 	assertDeepEqual(t, "replicas", map[string]tokenRingReplicas{
-		"myKeyspace": {
+		keyspace: {
 			{orderedToken("05"), []*HostInfo{hosts[0], hosts[1], hosts[2], hosts[3], hosts[4], hosts[5]}},
 			{orderedToken("10"), []*HostInfo{hosts[1], hosts[2], hosts[3], hosts[4], hosts[5], hosts[6]}},
 			{orderedToken("15"), []*HostInfo{hosts[2], hosts[3], hosts[4], hosts[5], hosts[6], hosts[7]}},
@@ -665,111 +613,9 @@ func TestHostPolicy_TokenAware_DCAwareRR2(t *testing.T) {
 	// first should be hosts with matching token from the local DC
 	iterCheck(t, iter, "4")
 	iterCheck(t, iter, "7")
-	// rest are according DCAwareRR from local DC only, starting with 7 as the fallback was used twice above
-	iterCheck(t, iter, "1")
-	iterCheck(t, iter, "10")
-}
-
-// Tests of the token-aware host selection policy implementation with a
-// DC aware round-robin host selection policy fallback with NonLocalReplicasFallback option enabled.
-func TestHostPolicy_TokenAware_DCAwareRR_NonLocalFallback(t *testing.T) {
-	const (
-		keyspace = "myKeyspace"
-		localDC  = "local"
-	)
-	policy := TokenAwareHostPolicy(DCAwareRoundRobinPolicy(localDC), NonLocalReplicasFallback())
-	policyInternal := policy.(*tokenAwareHostPolicy)
-	policyInternal.getKeyspaceName = func() string { return keyspace }
-	policyInternal.getKeyspaceMetadata = func(ks string) (*KeyspaceMetadata, error) {
-		return nil, errors.New("not initialized")
-	}
-
-	query := &Query{}
-	query.getKeyspace = func() string { return keyspace }
-
-	iter := policy.Pick(nil)
-	if iter == nil {
-		t.Fatal("host iterator was nil")
-	}
-	actual := iter()
-	if actual != nil {
-		t.Fatalf("expected nil from iterator, but was %v", actual)
-	}
-
-	// set the hosts
-	hosts := [...]*HostInfo{
-		{hostId: "0", connectAddress: net.IPv4(10, 0, 0, 1), tokens: []string{"05"}, dataCenter: "remote1"},
-		{hostId: "1", connectAddress: net.IPv4(10, 0, 0, 2), tokens: []string{"10"}, dataCenter: localDC},
-		{hostId: "2", connectAddress: net.IPv4(10, 0, 0, 3), tokens: []string{"15"}, dataCenter: "remote2"}, // 1
-		{hostId: "3", connectAddress: net.IPv4(10, 0, 0, 4), tokens: []string{"20"}, dataCenter: "remote1"}, // 2
-		{hostId: "4", connectAddress: net.IPv4(10, 0, 0, 5), tokens: []string{"25"}, dataCenter: localDC},   // 3
-		{hostId: "5", connectAddress: net.IPv4(10, 0, 0, 6), tokens: []string{"30"}, dataCenter: "remote2"},
-		{hostId: "6", connectAddress: net.IPv4(10, 0, 0, 7), tokens: []string{"35"}, dataCenter: "remote1"},
-		{hostId: "7", connectAddress: net.IPv4(10, 0, 0, 8), tokens: []string{"40"}, dataCenter: localDC},
-		{hostId: "8", connectAddress: net.IPv4(10, 0, 0, 9), tokens: []string{"45"}, dataCenter: "remote2"},
-		{hostId: "9", connectAddress: net.IPv4(10, 0, 0, 10), tokens: []string{"50"}, dataCenter: "remote1"},
-		{hostId: "10", connectAddress: net.IPv4(10, 0, 0, 11), tokens: []string{"55"}, dataCenter: localDC},
-		{hostId: "11", connectAddress: net.IPv4(10, 0, 0, 12), tokens: []string{"60"}, dataCenter: "remote2"},
-	}
-	for _, host := range hosts {
-		policy.AddHost(host)
-	}
-
-	// the token ring is not setup without the partitioner, but the fallback
-	// should work
-	iterCheck(t, policy.Pick(nil), "1")
-
-	query.RoutingKey([]byte("30"))
-	iterCheck(t, policy.Pick(query), "4")
-
-	policy.SetPartitioner("OrderedPartitioner")
-
-	policyInternal.getKeyspaceMetadata = func(keyspaceName string) (*KeyspaceMetadata, error) {
-		if keyspaceName != keyspace {
-			return nil, fmt.Errorf("unknown keyspace: %s", keyspaceName)
-		}
-		return &KeyspaceMetadata{
-			Name:          keyspace,
-			StrategyClass: "NetworkTopologyStrategy",
-			StrategyOptions: map[string]interface{}{
-				"class":   "NetworkTopologyStrategy",
-				localDC:   1,
-				"remote1": 1,
-				"remote2": 1,
-			},
-		}, nil
-	}
-	policy.KeyspaceChanged(KeyspaceUpdateEvent{Keyspace: keyspace})
-
-	// The NetworkTopologyStrategy above should generate the following replicas.
-	// It's handy to have as reference here.
-	assertDeepEqual(t, "replicas", map[string]tokenRingReplicas{
-		keyspace: {
-			{orderedToken("05"), []*HostInfo{hosts[0], hosts[1], hosts[2]}},
-			{orderedToken("10"), []*HostInfo{hosts[1], hosts[2], hosts[3]}},
-			{orderedToken("15"), []*HostInfo{hosts[2], hosts[3], hosts[4]}},
-			{orderedToken("20"), []*HostInfo{hosts[3], hosts[4], hosts[5]}},
-			{orderedToken("25"), []*HostInfo{hosts[4], hosts[5], hosts[6]}},
-			{orderedToken("30"), []*HostInfo{hosts[5], hosts[6], hosts[7]}},
-			{orderedToken("35"), []*HostInfo{hosts[6], hosts[7], hosts[8]}},
-			{orderedToken("40"), []*HostInfo{hosts[7], hosts[8], hosts[9]}},
-			{orderedToken("45"), []*HostInfo{hosts[8], hosts[9], hosts[10]}},
-			{orderedToken("50"), []*HostInfo{hosts[9], hosts[10], hosts[11]}},
-			{orderedToken("55"), []*HostInfo{hosts[10], hosts[11], hosts[0]}},
-			{orderedToken("60"), []*HostInfo{hosts[11], hosts[0], hosts[1]}},
-		},
-	}, policyInternal.getMetadataReadOnly().replicas)
-
-	// now the token ring is configured
-	query.RoutingKey([]byte("18"))
-	iter = policy.Pick(query)
-	// first should be host with matching token from the local DC
-	iterCheck(t, iter, "4")
 	// rest should be hosts with matching token from remote DCs
 	iterCheck(t, iter, "3")
 	iterCheck(t, iter, "5")
-	// rest are according DCAwareRR from local DC only, starting with 7 as the fallback was used twice above
-	iterCheck(t, iter, "7")
-	iterCheck(t, iter, "10")
-	iterCheck(t, iter, "1")
+	iterCheck(t, iter, "6")
+	iterCheck(t, iter, "8")
 }
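
iterCheck is a helper defined elsewhere in policies_test.go; a plausible minimal sketch, assuming it asserts the ID of the next host yielded by the iterator:

	// Hypothetical sketch of the iterCheck helper used by the tests above;
	// the real implementation lives elsewhere in policies_test.go.
	func iterCheck(t *testing.T, iter NextHost, hostID string) {
		t.Helper()
		host := iter()
		if host == nil {
			t.Fatalf("expected host %s, got nil", hostID)
		}
		if id := host.Info().HostID(); id != hostID {
			t.Fatalf("expected host %s, got %s", hostID, id)
		}
	}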

+ 11 - 1
topology.go

@@ -26,9 +26,20 @@ func (h tokenRingReplicas) replicasFor(t token) *hostTokens {
 	p := sort.Search(len(h), func(i int) bool {
 		return !h[i].token.Less(t)
 	})
+
+	// TODO: simplify this
+	if p < len(h) && h[p].token == t {
+		return &h[p]
+	}
+
+	p--
+
 	if p >= len(h) {
 		// rollover
 		p = 0
+	} else if p < 0 {
+		// rollunder
+		p = len(h) - 1
 	}
 
 	return &h[p]
@@ -100,7 +111,6 @@ func (s *simpleStrategy) replicaMap(tokenRing *tokenRing) tokenRingReplicas {
 		seen := make(map[*HostInfo]bool)
 
 		for j := 0; j < len(tokens) && len(replicas) < s.rf; j++ {
-			// TODO: need to ensure we dont add the same hosts twice
 			h := tokens[(i+j)%len(tokens)]
 			if !seen[h.host] {
 				replicas = append(replicas, h.host)
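
The fixed lookup returns an exact token match directly and otherwise falls back to the preceding entry, wrapping to the last entry for tokens below the ring minimum. A simplified standalone sketch of the same search over plain string tokens (the rollover branch, which cannot trigger after the decrement, is omitted):

	package main

	import (
		"fmt"
		"sort"
	)

	// ownerFor mirrors the fixed replicasFor search: sort.Search finds the
	// first ring entry whose token is >= t.
	func ownerFor(ring []string, t string) string {
		p := sort.Search(len(ring), func(i int) bool { return ring[i] >= t })
		if p < len(ring) && ring[p] == t {
			return ring[p] // exact token match
		}
		p--
		if p < 0 {
			p = len(ring) - 1 // rollunder: below the minimum wraps to the last entry
		}
		return ring[p]
	}

	func main() {
		ring := []string{"00", "25", "50", "75"}
		for _, t := range []string{"25", "30", "80"} {
			fmt.Printf("token %s -> owner %s\n", t, ownerFor(ring, t))
		}
		// token 25 -> owner 25 (exact match)
		// token 30 -> owner 25 (preceding entry owns the range)
		// token 80 -> owner 75 (past the maximum token)
	}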