浏览代码

policy: tokenAwareHostPolicy host selection bug fix (#1261)

* policy: tokenAwareHostPolicy host selection bug fix

Bug:

`tokenAwareHostPolicy` always selects a primary replica of a token.
When using `tokenAwareHostPolicy` backed by `dcAwareRR`, for tokens
with primary replica in a different DC, this results in selecting
a RANDOM host in a given DC.

This is caused by two issues in `tokenAwareHostPolicy::getReplicas` func.

* `tokenAwareHostPolicy::meta` is always nil unless a keyspace changed
  event is received
* even when metadata is available it uses the calculated token value of
  a partition key to find replicas, however, the replicas are stored
  in a map where the key is the primary replica token aka the `end_token`
  (nodetool describering)

Solution:

This commit extends `tokenRing::GetHostForPartitionKey` to return
the endToken that is later used in the `Pick` function, and forces
a call to `KeyspaceChanged` when session is initialised.

Signed-off-by: Michał Matczuk <michal@scylladb.com>

* policy: tokenAwareHostPolicy update keyspace metadata on host add or remove events

Currently, when a host is added or removed, tokenAwareHostPolicy uses old
replica information. This is fixed by extracting the `KeyspaceChanged`
function body as `updateKeyspaceMetadata` and calling it after the token
ring is updated.

Signed-off-by: Michał Matczuk <michal@scylladb.com>

* policy: tokenAwareHostPolicy type assertion bug fix

Fixes an unchecked type assertion on the loaded token ring value, which panics when that value is nil:

```
--- FAIL: TestTokenAwareConnPool (0.64s)
panic: interface conversion: interface {} is nil, not *gocql.tokenRing [recovered]
	panic: interface conversion: interface {} is nil, not *gocql.tokenRing
goroutine 4058 [running]:
testing.tRunner.func1(0xc42010c0f0)
	/home/travis/.gimme/versions/go1.10.linux.amd64/src/testing/testing.go:742 +0x567
panic(0x886740, 0xc420318cc0)
	/home/travis/.gimme/versions/go1.10.linux.amd64/src/runtime/panic.go:505 +0x24a
github.com/gocql/gocql.(*tokenAwareHostPolicy).updateKeyspaceMetadata(0xc4210f1d80, 0x8dd7aa, 0xa)
	/home/travis/gopath/src/github.com/gocql/gocql/policies.go:445 +0x55c
github.com/gocql/gocql.(*tokenAwareHostPolicy).AddHost(0xc4210f1d80, 0xc420668580)
	/home/travis/gopath/src/github.com/gocql/gocql/policies.go:478 +0xcc
github.com/gocql/gocql.(*Session).addNewNode(0xc420f40380, 0xc420668580)
	/home/travis/gopath/src/github.com/gocql/gocql/events.go:188 +0xee
github.com/gocql/gocql.(*Session).handleNodeUp(0xc420f40380, 0xc420f00630, 0x10, 0x10, 0x2352, 0x0)
	/home/travis/gopath/src/github.com/gocql/gocql/events.go:273 +0x2e9
github.com/gocql/gocql.(*controlConn).setupConn(0xc42113c080, 0xc420782400, 0x3, 0x4)
	/home/travis/gopath/src/github.com/gocql/gocql/control.go:289 +0x1fc
github.com/gocql/gocql.(*controlConn).connect(0xc42113c080, 0xc42113e820, 0x3, 0x4, 0xc42113e820, 0x3)
	/home/travis/gopath/src/github.com/gocql/gocql/control.go:252 +0x218
github.com/gocql/gocql.(*Session).init(0xc420f40380, 0xc420c338f0, 0x0)
	/home/travis/gopath/src/github.com/gocql/gocql/session.go:197 +0x80e
github.com/gocql/gocql.NewSession(0xc420094960, 0x3, 0x3, 0x8dc0cd, 0x5, 0x4, 0xdf8475800, 0x23c34600, 0x2352, 0x8dd7aa, ...)
	/home/travis/gopath/src/github.com/gocql/gocql/session.go:160 +0x1138
github.com/gocql/gocql.(*ClusterConfig).CreateSession(0xc421122000, 0xc420566de0, 0x0, 0xc420046800)
	/home/travis/gopath/src/github.com/gocql/gocql/cluster.go:185 +0x8d
github.com/gocql/gocql.createSessionFromCluster(0xc421122000, 0x9421a0, 0xc42010c0f0, 0x0)
	/home/travis/gopath/src/github.com/gocql/gocql/common_test.go:136 +0x116
github.com/gocql/gocql.TestTokenAwareConnPool(0xc42010c0f0)
	/home/travis/gopath/src/github.com/gocql/gocql/cassandra_test.go:2520 +0x195
testing.tRunner(0xc42010c0f0, 0x8f98a0)
	/home/travis/.gimme/versions/go1.10.linux.amd64/src/testing/testing.go:777 +0x16e
created by testing.(*T).Run
	/home/travis/.gimme/versions/go1.10.linux.amd64/src/testing/testing.go:824 +0x565
exit status 2
```

Signed-off-by: Michał Matczuk <michal@scylladb.com>
Michal Matczuk 6 年之前
父节点
当前提交
ec4793573d
共有 5 个文件被更改,包括 68 次插入53 次删除
  1. 29 22
      policies.go
  2. 6 0
      session.go
  3. 7 7
      token.go
  4. 24 24
      token_test.go
  5. 2 0
      topology.go

+ 29 - 22
policies.go

@@ -424,6 +424,10 @@ func (t *tokenAwareHostPolicy) IsLocal(host *HostInfo) bool {
 }
 
 func (t *tokenAwareHostPolicy) KeyspaceChanged(update KeyspaceUpdateEvent) {
+	t.updateKeyspaceMetadata(update.Keyspace)
+}
+
+func (t *tokenAwareHostPolicy) updateKeyspaceMetadata(keyspace string) {
 	meta, _ := t.keyspaces.Load().(*keyspaceMeta)
 	var size = 1
 	if meta != nil {
@@ -434,18 +438,20 @@ func (t *tokenAwareHostPolicy) KeyspaceChanged(update KeyspaceUpdateEvent) {
 		replicas: make(map[string]map[token][]*HostInfo, size),
 	}
 
-	ks, err := t.session.KeyspaceMetadata(update.Keyspace)
+	ks, err := t.session.KeyspaceMetadata(keyspace)
 	if err == nil {
 		strat := getStrategy(ks)
-		tr := t.tokenRing.Load().(*tokenRing)
-		if tr != nil {
-			newMeta.replicas[update.Keyspace] = strat.replicaMap(t.hosts.get(), tr.tokens)
+		if strat != nil {
+			tr, _ := t.tokenRing.Load().(*tokenRing)
+			if tr != nil {
+				newMeta.replicas[keyspace] = strat.replicaMap(t.hosts.get(), tr.tokens)
+			}
 		}
 	}
 
 	if meta != nil {
 		for ks, replicas := range meta.replicas {
-			if ks != update.Keyspace {
+			if ks != keyspace {
 				newMeta.replicas[ks] = replicas
 			}
 		}
@@ -467,6 +473,20 @@ func (t *tokenAwareHostPolicy) SetPartitioner(partitioner string) {
 }
 
 func (t *tokenAwareHostPolicy) AddHost(host *HostInfo) {
+	t.HostUp(host)
+	if t.session != nil { // disable for unit tests
+		t.updateKeyspaceMetadata(t.session.cfg.Keyspace)
+	}
+}
+
+func (t *tokenAwareHostPolicy) RemoveHost(host *HostInfo) {
+	t.HostDown(host)
+	if t.session != nil { // disable for unit tests
+		t.updateKeyspaceMetadata(t.session.cfg.Keyspace)
+	}
+}
+
+func (t *tokenAwareHostPolicy) HostUp(host *HostInfo) {
 	t.hosts.add(host)
 	t.fallback.AddHost(host)
 
@@ -476,7 +496,7 @@ func (t *tokenAwareHostPolicy) AddHost(host *HostInfo) {
 	t.resetTokenRing(partitioner)
 }
 
-func (t *tokenAwareHostPolicy) RemoveHost(host *HostInfo) {
+func (t *tokenAwareHostPolicy) HostDown(host *HostInfo) {
 	t.hosts.remove(host.ConnectAddress())
 	t.fallback.RemoveHost(host)
 
@@ -486,17 +506,6 @@ func (t *tokenAwareHostPolicy) RemoveHost(host *HostInfo) {
 	t.resetTokenRing(partitioner)
 }
 
-func (t *tokenAwareHostPolicy) HostUp(host *HostInfo) {
-	// TODO: need to avoid doing all the work on AddHost on hostup/down
-	// because it now expensive to calculate the replica map for each
-	// token
-	t.AddHost(host)
-}
-
-func (t *tokenAwareHostPolicy) HostDown(host *HostInfo) {
-	t.RemoveHost(host)
-}
-
 func (t *tokenAwareHostPolicy) resetTokenRing(partitioner string) {
 	if partitioner == "" {
 		// partitioner not yet set
@@ -520,8 +529,8 @@ func (t *tokenAwareHostPolicy) getReplicas(keyspace string, token token) ([]*Hos
 	if meta == nil {
 		return nil, false
 	}
-	tokens, ok := meta.replicas[keyspace][token]
-	return tokens, ok
+	replicas, ok := meta.replicas[keyspace][token]
+	return replicas, ok
 }
 
 func (t *tokenAwareHostPolicy) Pick(qry ExecutableQuery) NextHost {
@@ -541,9 +550,7 @@ func (t *tokenAwareHostPolicy) Pick(qry ExecutableQuery) NextHost {
 		return t.fallback.Pick(qry)
 	}
 
-	token := tr.partitioner.Hash(routingKey)
-	primaryEndpoint := tr.GetHostForToken(token)
-
+	primaryEndpoint, token := tr.GetHostForPartitionKey(routingKey)
 	if primaryEndpoint == nil || token == nil {
 		return t.fallback.Pick(qry)
 	}

+ 6 - 0
session.go

@@ -250,6 +250,12 @@ func (s *Session) init() error {
 		return ErrNoConnectionsStarted
 	}
 
+	// Invoke KeyspaceChanged to let the policy cache the session keyspace
+	// parameters. This is used by tokenAwareHostPolicy to discover replicas.
+	if !s.cfg.disableControlConn && s.cfg.Keyspace != "" {
+		s.policy.KeyspaceChanged(KeyspaceUpdateEvent{Keyspace: s.cfg.Keyspace})
+	}
+
 	return nil
 }
 

+ 7 - 7
token.go

@@ -192,18 +192,17 @@ func (t *tokenRing) String() string {
 	return string(buf.Bytes())
 }
 
-func (t *tokenRing) GetHostForPartitionKey(partitionKey []byte) *HostInfo {
+func (t *tokenRing) GetHostForPartitionKey(partitionKey []byte) (host *HostInfo, endToken token) {
 	if t == nil {
-		return nil
+		return nil, nil
 	}
 
-	token := t.partitioner.Hash(partitionKey)
-	return t.GetHostForToken(token)
+	return t.GetHostForToken(t.partitioner.Hash(partitionKey))
 }
 
-func (t *tokenRing) GetHostForToken(token token) *HostInfo {
+func (t *tokenRing) GetHostForToken(token token) (host *HostInfo, endToken token) {
 	if t == nil || len(t.tokens) == 0 {
-		return nil
+		return nil, nil
 	}
 
 	// find the primary replica
@@ -216,5 +215,6 @@ func (t *tokenRing) GetHostForToken(token token) *HostInfo {
 		ringIndex = 0
 	}
 
-	return t.tokens[ringIndex].host
+	v := t.tokens[ringIndex]
+	return v.host, v.token
 }

+ 24 - 24
token_test.go

@@ -156,43 +156,43 @@ func TestTokenRing_Int(t *testing.T) {
 
 	sort.Sort(ring)
 
-	if ring.GetHostForToken(intToken(0)) != host0 {
+	if host, endToken := ring.GetHostForToken(intToken(0)); host != host0 || endToken != intToken(0) {
 		t.Error("Expected host 0 for token 0")
 	}
-	if ring.GetHostForToken(intToken(1)) != host25 {
+	if host, endToken := ring.GetHostForToken(intToken(1)); host != host25 || endToken != intToken(25) {
 		t.Error("Expected host 25 for token 1")
 	}
-	if ring.GetHostForToken(intToken(24)) != host25 {
+	if host, endToken := ring.GetHostForToken(intToken(24)); host != host25 || endToken != intToken(25) {
 		t.Error("Expected host 25 for token 24")
 	}
-	if ring.GetHostForToken(intToken(25)) != host25 {
+	if host, endToken := ring.GetHostForToken(intToken(25)); host != host25 || endToken != intToken(25) {
 		t.Error("Expected host 25 for token 25")
 	}
-	if ring.GetHostForToken(intToken(26)) != host50 {
+	if host, endToken := ring.GetHostForToken(intToken(26)); host != host50 || endToken != intToken(50) {
 		t.Error("Expected host 50 for token 26")
 	}
-	if ring.GetHostForToken(intToken(49)) != host50 {
+	if host, endToken := ring.GetHostForToken(intToken(49)); host != host50 || endToken != intToken(50) {
 		t.Error("Expected host 50 for token 49")
 	}
-	if ring.GetHostForToken(intToken(50)) != host50 {
+	if host, endToken := ring.GetHostForToken(intToken(50)); host != host50 || endToken != intToken(50) {
 		t.Error("Expected host 50 for token 50")
 	}
-	if ring.GetHostForToken(intToken(51)) != host75 {
+	if host, endToken := ring.GetHostForToken(intToken(51)); host != host75 || endToken != intToken(75) {
 		t.Error("Expected host 75 for token 51")
 	}
-	if ring.GetHostForToken(intToken(74)) != host75 {
+	if host, endToken := ring.GetHostForToken(intToken(74)); host != host75 || endToken != intToken(75) {
 		t.Error("Expected host 75 for token 74")
 	}
-	if ring.GetHostForToken(intToken(75)) != host75 {
+	if host, endToken := ring.GetHostForToken(intToken(75)); host != host75 || endToken != intToken(75) {
 		t.Error("Expected host 75 for token 75")
 	}
-	if ring.GetHostForToken(intToken(76)) != host0 {
+	if host, endToken := ring.GetHostForToken(intToken(76)); host != host0 || endToken != intToken(0) {
 		t.Error("Expected host 0 for token 76")
 	}
-	if ring.GetHostForToken(intToken(99)) != host0 {
+	if host, endToken := ring.GetHostForToken(intToken(99)); host != host0 || endToken != intToken(0) {
 		t.Error("Expected host 0 for token 99")
 	}
-	if ring.GetHostForToken(intToken(100)) != host0 {
+	if host, endToken := ring.GetHostForToken(intToken(100)); host != host0 || endToken != intToken(0) {
 		t.Error("Expected host 0 for token 100")
 	}
 }
@@ -201,10 +201,10 @@ func TestTokenRing_Int(t *testing.T) {
 func TestTokenRing_Nil(t *testing.T) {
 	var ring *tokenRing = nil
 
-	if ring.GetHostForToken(nil) != nil {
+	if host, endToken := ring.GetHostForToken(nil); host != nil || endToken != nil {
 		t.Error("Expected nil for nil token ring")
 	}
-	if ring.GetHostForPartitionKey(nil) != nil {
+	if host, endToken := ring.GetHostForPartitionKey(nil); host != nil || endToken != nil {
 		t.Error("Expected nil for nil token ring")
 	}
 }
@@ -242,19 +242,19 @@ func TestTokenRing_Murmur3(t *testing.T) {
 	p := murmur3Partitioner{}
 
 	for _, host := range hosts {
-		actual := ring.GetHostForToken(p.ParseString(host.tokens[0]))
+		actual, _ := ring.GetHostForToken(p.ParseString(host.tokens[0]))
 		if !actual.ConnectAddress().Equal(host.ConnectAddress()) {
 			t.Errorf("Expected address %v for token %q, but was %v", host.ConnectAddress(),
 				host.tokens[0], actual.ConnectAddress())
 		}
 	}
 
-	actual := ring.GetHostForToken(p.ParseString("12"))
+	actual, _ := ring.GetHostForToken(p.ParseString("12"))
 	if !actual.ConnectAddress().Equal(hosts[1].ConnectAddress()) {
 		t.Errorf("Expected address 1 for token \"12\", but was %s", actual.ConnectAddress())
 	}
 
-	actual = ring.GetHostForToken(p.ParseString("24324545443332"))
+	actual, _ = ring.GetHostForToken(p.ParseString("24324545443332"))
 	if !actual.ConnectAddress().Equal(hosts[0].ConnectAddress()) {
 		t.Errorf("Expected address 0 for token \"24324545443332\", but was %s", actual.ConnectAddress())
 	}
@@ -274,19 +274,19 @@ func TestTokenRing_Ordered(t *testing.T) {
 
 	var actual *HostInfo
 	for _, host := range hosts {
-		actual = ring.GetHostForToken(p.ParseString(host.tokens[0]))
+		actual, _ := ring.GetHostForToken(p.ParseString(host.tokens[0]))
 		if !actual.ConnectAddress().Equal(host.ConnectAddress()) {
 			t.Errorf("Expected address %v for token %q, but was %v", host.ConnectAddress(),
 				host.tokens[0], actual.ConnectAddress())
 		}
 	}
 
-	actual = ring.GetHostForToken(p.ParseString("12"))
+	actual, _ = ring.GetHostForToken(p.ParseString("12"))
 	if !actual.peer.Equal(hosts[1].peer) {
 		t.Errorf("Expected address 1 for token \"12\", but was %s", actual.ConnectAddress())
 	}
 
-	actual = ring.GetHostForToken(p.ParseString("24324545443332"))
+	actual, _ = ring.GetHostForToken(p.ParseString("24324545443332"))
 	if !actual.ConnectAddress().Equal(hosts[1].ConnectAddress()) {
 		t.Errorf("Expected address 1 for token \"24324545443332\", but was %s", actual.ConnectAddress())
 	}
@@ -305,19 +305,19 @@ func TestTokenRing_Random(t *testing.T) {
 
 	var actual *HostInfo
 	for _, host := range hosts {
-		actual = ring.GetHostForToken(p.ParseString(host.tokens[0]))
+		actual, _ := ring.GetHostForToken(p.ParseString(host.tokens[0]))
 		if !actual.ConnectAddress().Equal(host.ConnectAddress()) {
 			t.Errorf("Expected address %v for token %q, but was %v", host.ConnectAddress(),
 				host.tokens[0], actual.ConnectAddress())
 		}
 	}
 
-	actual = ring.GetHostForToken(p.ParseString("12"))
+	actual, _ = ring.GetHostForToken(p.ParseString("12"))
 	if !actual.peer.Equal(hosts[1].peer) {
 		t.Errorf("Expected address 1 for token \"12\", but was %s", actual.ConnectAddress())
 	}
 
-	actual = ring.GetHostForToken(p.ParseString("24324545443332"))
+	actual, _ = ring.GetHostForToken(p.ParseString("24324545443332"))
 	if !actual.ConnectAddress().Equal(hosts[0].ConnectAddress()) {
 		t.Errorf("Expected address 1 for token \"24324545443332\", but was %s", actual.ConnectAddress())
 	}

+ 2 - 0
topology.go

@@ -47,6 +47,8 @@ func getStrategy(ks *KeyspaceMetadata) placementStrategy {
 			dcs[dc] = getReplicationFactorFromOpts(ks.Name+":dc="+dc, rf)
 		}
 		return &networkTopology{dcs: dcs}
+	case strings.Contains(ks.StrategyClass, "LocalStrategy"):
+		return nil
 	default:
 		// TODO: handle unknown replicas and just return the primary host for a token
 		panic(fmt.Sprintf("unsupported strategy class: %v", ks.StrategyClass))