Browse Source

Merge pull request #527 from Zariel/hostinfo-nil

policies: ensure NextHost is nil
Chris Bannister 10 years ago
parent
commit
f8da9dfadd
2 changed files with 31 additions and 8 deletions
  1. 6 8
      policies.go
  2. 25 0
      policies_test.go

+ 6 - 8
policies.go

@@ -94,7 +94,7 @@ func (r *roundRobinHostPolicy) SetPartitioner(partitioner string) {
 func (r *roundRobinHostPolicy) Pick(qry *Query) NextHost {
 	// i is used to limit the number of attempts to find a host
 	// to the number of hosts known to this policy
-	var i uint32
+	var i int
 	return func() SelectedHost {
 		r.mu.RLock()
 		defer r.mu.RUnlock()
@@ -102,14 +102,14 @@ func (r *roundRobinHostPolicy) Pick(qry *Query) NextHost {
 			return nil
 		}
 
-		var host *HostInfo
 		// always increment pos to evenly distribute traffic in case of
 		// failures
 		pos := atomic.AddUint32(&r.pos, 1)
-		if int(i) < len(r.hosts) {
-			host = &r.hosts[(pos)%uint32(len(r.hosts))]
-			i++
+		if i >= len(r.hosts) {
+			return nil
 		}
+		host := &r.hosts[(pos)%uint32(len(r.hosts))]
+		i++
 		return selectedRoundRobinHost{host}
 	}
 }
@@ -201,11 +201,9 @@ func (t *tokenAwareHostPolicy) Pick(qry *Query) NextHost {
 		return t.fallback.Pick(qry)
 	}
 
-	var host *HostInfo
-
 	t.mu.RLock()
 	// TODO retrieve a list of hosts based on the replication strategy
-	host = t.tokenRing.GetHostForPartitionKey(routingKey)
+	host := t.tokenRing.GetHostForPartitionKey(routingKey)
 	t.mu.RUnlock()
 
 	if host == nil {

+ 25 - 0
policies_test.go

@@ -167,3 +167,28 @@ func TestRoundRobinConnPolicy(t *testing.T) {
 		t.Error("Expected conn1")
 	}
 }
+
+func TestRoundRobinNilHostInfo(t *testing.T) {
+	policy := RoundRobinHostPolicy()
+
+	host := HostInfo{HostId: "host-1"}
+	policy.SetHosts([]HostInfo{host})
+
+	iter := policy.Pick(nil)
+	next := iter()
+	if next == nil {
+		t.Fatal("got nil host")
+	} else if v := next.Info(); v == nil {
+		t.Fatal("got nil HostInfo")
+	} else if v.HostId != host.HostId {
+		t.Fatalf("expected host %v got %v", host, *v)
+	}
+
+	next = iter()
+	if next != nil {
+		t.Errorf("expected to get nil host got %+v", next)
+		if next.Info() == nil {
+			t.Fatalf("HostInfo is nil")
+		}
+	}
+}