Browse Source

change HostInfo.Peer to be an IP

When Cassandra returns us a hosts info the peer address is defined as an
inet both at protocol for events and the schema for the peer
information.

Previously we stored this as a string, and used it to
connect to hosts and also to index hosts by. This is different to what
use for user supplied address endpoints, we keep the potentially DNS
name as the peer address. This means that we can end up having duplicate
host pools, duplicate host info in the ring.

Fix this by making everything rely on a hosts address being an IP
address instead of either a DNS name or an IP.
Chris Bannister 9 years ago
parent
commit
e0a2f2ca85
19 changed files with 367 additions and 332 deletions
  1. 21 8
      cassandra_test.go
  2. 15 1
      cluster.go
  3. 12 2
      conn.go
  4. 10 4
      conn_test.go
  5. 19 15
      connectionpool.go
  6. 50 22
      control.go
  7. 31 0
      control_test.go
  8. 21 27
      events.go
  9. 12 4
      filters.go
  10. 17 14
      filters_test.go
  11. 9 14
      host_source.go
  12. 33 26
      policies.go
  13. 24 23
      policies_test.go
  14. 1 1
      query_executor.go
  15. 28 11
      ring.go
  16. 7 4
      ring_test.go
  17. 15 22
      session.go
  18. 1 1
      token.go
  19. 41 133
      token_test.go

+ 21 - 8
cassandra_test.go

@@ -62,11 +62,15 @@ func TestRingDiscovery(t *testing.T) {
 	}
 	}
 
 
 	session.pool.mu.RLock()
 	session.pool.mu.RLock()
+	defer session.pool.mu.RUnlock()
 	size := len(session.pool.hostConnPools)
 	size := len(session.pool.hostConnPools)
-	session.pool.mu.RUnlock()
 
 
 	if *clusterSize != size {
 	if *clusterSize != size {
-		t.Fatalf("Expected a cluster size of %d, but actual size was %d", *clusterSize, size)
+		for p, pool := range session.pool.hostConnPools {
+			t.Logf("p=%q host=%v ips=%s", p, pool.host, pool.host.Peer().String())
+
+		}
+		t.Errorf("Expected a cluster size of %d, but actual size was %d", *clusterSize, size)
 	}
 	}
 }
 }
 
 
@@ -573,7 +577,7 @@ func TestReconnection(t *testing.T) {
 	defer session.Close()
 	defer session.Close()
 
 
 	h := session.ring.allHosts()[0]
 	h := session.ring.allHosts()[0]
-	session.handleNodeDown(net.ParseIP(h.Peer()), h.Port())
+	session.handleNodeDown(h.Peer(), h.Port())
 
 
 	if h.State() != NodeDown {
 	if h.State() != NodeDown {
 		t.Fatal("Host should be NodeDown but not.")
 		t.Fatal("Host should be NodeDown but not.")
@@ -2477,17 +2481,26 @@ func TestSchemaReset(t *testing.T) {
 }
 }
 
 
 func TestCreateSession_DontSwallowError(t *testing.T) {
 func TestCreateSession_DontSwallowError(t *testing.T) {
+	t.Skip("This test is bad, and the resultant error from cassandra changes between versions")
 	cluster := createCluster()
 	cluster := createCluster()
-	cluster.ProtoVersion = 100
+	cluster.ProtoVersion = 0x100
 	session, err := cluster.CreateSession()
 	session, err := cluster.CreateSession()
 	if err == nil {
 	if err == nil {
 		session.Close()
 		session.Close()
 
 
 		t.Fatal("expected to get an error for unsupported protocol")
 		t.Fatal("expected to get an error for unsupported protocol")
 	}
 	}
-	// TODO: we should get a distinct error type here which include the underlying
-	// cassandra error about the protocol version, for now check this here.
-	if !strings.Contains(err.Error(), "Invalid or unsupported protocol version") {
-		t.Fatalf(`expcted to get error "unsupported protocol version" got: %q`, err)
+
+	if flagCassVersion.Major < 3 {
+		// TODO: we should get a distinct error type here which include the underlying
+		// cassandra error about the protocol version, for now check this here.
+		if !strings.Contains(err.Error(), "Invalid or unsupported protocol version") {
+			t.Fatalf(`expcted to get error "unsupported protocol version" got: %q`, err)
+		}
+	} else {
+		if !strings.Contains(err.Error(), "unsupported response version") {
+			t.Fatalf(`expcted to get error "unsupported response version" got: %q`, err)
+		}
 	}
 	}
+
 }
 }

+ 15 - 1
cluster.go

@@ -27,7 +27,13 @@ func (p PoolConfig) buildPool(session *Session) *policyConnPool {
 // behavior to fit the most common use cases. Applications that require a
 // behavior to fit the most common use cases. Applications that require a
 // different setup must implement their own cluster.
 // different setup must implement their own cluster.
 type ClusterConfig struct {
 type ClusterConfig struct {
-	Hosts             []string          // addresses for the initial connections
+	// addresses for the initial connections. It is recomended to use the value set in
+	// the Cassandra config for broadcast_address or listen_address, an IP address not
+	// a domain name. This is because events from Cassandra will use the configured IP
+	// address, which is used to index connected hosts. If the domain name specified
+	// resolves to more than 1 IP address then the driver may connect multiple times to
+	// the same host, and will not mark the node being down or up from events.
+	Hosts             []string
 	CQLVersion        string            // CQL version (default: 3.0.0)
 	CQLVersion        string            // CQL version (default: 3.0.0)
 	ProtoVersion      int               // version of the native protocol (default: 2)
 	ProtoVersion      int               // version of the native protocol (default: 2)
 	Timeout           time.Duration     // connection timeout (default: 600ms)
 	Timeout           time.Duration     // connection timeout (default: 600ms)
@@ -100,6 +106,14 @@ type ClusterConfig struct {
 }
 }
 
 
 // NewCluster generates a new config for the default cluster implementation.
 // NewCluster generates a new config for the default cluster implementation.
+//
+// The supplied hosts are used to initially connect to the cluster then the rest of
+// the ring will be automatically discovered. It is recomended to use the value set in
+// the Cassandra config for broadcast_address or listen_address, an IP address not
+// a domain name. This is because events from Cassandra will use the configured IP
+// address, which is used to index connected hosts. If the domain name specified
+// resolves to more than 1 IP address then the driver may connect multiple times to
+// the same host, and will not mark the node being down or up from events.
 func NewCluster(hosts ...string) *ClusterConfig {
 func NewCluster(hosts ...string) *ClusterConfig {
 	cfg := &ClusterConfig{
 	cfg := &ClusterConfig{
 		Hosts:                  hosts,
 		Hosts:                  hosts,

+ 12 - 2
conn.go

@@ -152,8 +152,15 @@ type Conn struct {
 }
 }
 
 
 // Connect establishes a connection to a Cassandra node.
 // Connect establishes a connection to a Cassandra node.
-func Connect(host *HostInfo, addr string, cfg *ConnConfig,
-	errorHandler ConnErrorHandler, session *Session) (*Conn, error) {
+func Connect(host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHandler, session *Session) (*Conn, error) {
+	// TODO(zariel): remove these
+	if host == nil {
+		panic("host is nil")
+	} else if len(host.Peer()) == 0 {
+		panic("host missing peer ip address")
+	} else if host.Port() == 0 {
+		panic("host missing port")
+	}
 
 
 	var (
 	var (
 		err  error
 		err  error
@@ -164,6 +171,9 @@ func Connect(host *HostInfo, addr string, cfg *ConnConfig,
 		Timeout: cfg.Timeout,
 		Timeout: cfg.Timeout,
 	}
 	}
 
 
+	// TODO(zariel): handle ipv6 zone
+	addr := (&net.TCPAddr{IP: host.Peer(), Port: host.Port()}).String()
+
 	if cfg.tlsConfig != nil {
 	if cfg.tlsConfig != nil {
 		// the TLS config is safe to be reused by connections but it must not
 		// the TLS config is safe to be reused by connections but it must not
 		// be modified after being used.
 		// be modified after being used.

+ 10 - 4
conn_test.go

@@ -473,8 +473,7 @@ func TestStream0(t *testing.T) {
 		}
 		}
 	})
 	})
 
 
-	host := &HostInfo{peer: srv.Address}
-	conn, err := Connect(host, srv.Address, &ConnConfig{ProtoVersion: int(srv.protocol)}, errorHandler, nil)
+	conn, err := Connect(srv.host(), &ConnConfig{ProtoVersion: int(srv.protocol)}, errorHandler, nil)
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
@@ -509,8 +508,7 @@ func TestConnClosedBlocked(t *testing.T) {
 		t.Log(err)
 		t.Log(err)
 	})
 	})
 
 
-	host := &HostInfo{peer: srv.Address}
-	conn, err := Connect(host, srv.Address, &ConnConfig{ProtoVersion: int(srv.protocol)}, errorHandler, nil)
+	conn, err := Connect(srv.host(), &ConnConfig{ProtoVersion: int(srv.protocol)}, errorHandler, nil)
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
@@ -637,6 +635,14 @@ type TestServer struct {
 	closed bool
 	closed bool
 }
 }
 
 
+func (srv *TestServer) host() *HostInfo {
+	host, err := hostInfo(srv.Address, 9042)
+	if err != nil {
+		srv.t.Fatal(err)
+	}
+	return host
+}
+
 func (srv *TestServer) closeWatch() {
 func (srv *TestServer) closeWatch() {
 	<-srv.ctx.Done()
 	<-srv.ctx.Done()
 
 

+ 19 - 15
connectionpool.go

@@ -130,9 +130,10 @@ func (p *policyConnPool) SetHosts(hosts []*HostInfo) {
 			// don't create a connection pool for a down host
 			// don't create a connection pool for a down host
 			continue
 			continue
 		}
 		}
-		if _, exists := p.hostConnPools[host.Peer()]; exists {
+		ip := host.Peer().String()
+		if _, exists := p.hostConnPools[ip]; exists {
 			// still have this host, so don't remove it
 			// still have this host, so don't remove it
-			delete(toRemove, host.Peer())
+			delete(toRemove, ip)
 			continue
 			continue
 		}
 		}
 
 
@@ -155,7 +156,7 @@ func (p *policyConnPool) SetHosts(hosts []*HostInfo) {
 		createCount--
 		createCount--
 		if pool.Size() > 0 {
 		if pool.Size() > 0 {
 			// add pool onyl if there a connections available
 			// add pool onyl if there a connections available
-			p.hostConnPools[pool.host.Peer()] = pool
+			p.hostConnPools[string(pool.host.Peer())] = pool
 		}
 		}
 	}
 	}
 
 
@@ -177,9 +178,10 @@ func (p *policyConnPool) Size() int {
 	return count
 	return count
 }
 }
 
 
-func (p *policyConnPool) getPool(addr string) (pool *hostConnPool, ok bool) {
+func (p *policyConnPool) getPool(host *HostInfo) (pool *hostConnPool, ok bool) {
+	ip := host.Peer().String()
 	p.mu.RLock()
 	p.mu.RLock()
-	pool, ok = p.hostConnPools[addr]
+	pool, ok = p.hostConnPools[ip]
 	p.mu.RUnlock()
 	p.mu.RUnlock()
 	return
 	return
 }
 }
@@ -196,8 +198,9 @@ func (p *policyConnPool) Close() {
 }
 }
 
 
 func (p *policyConnPool) addHost(host *HostInfo) {
 func (p *policyConnPool) addHost(host *HostInfo) {
+	ip := host.Peer().String()
 	p.mu.Lock()
 	p.mu.Lock()
-	pool, ok := p.hostConnPools[host.Peer()]
+	pool, ok := p.hostConnPools[ip]
 	if !ok {
 	if !ok {
 		pool = newHostConnPool(
 		pool = newHostConnPool(
 			p.session,
 			p.session,
@@ -207,22 +210,23 @@ func (p *policyConnPool) addHost(host *HostInfo) {
 			p.keyspace,
 			p.keyspace,
 		)
 		)
 
 
-		p.hostConnPools[host.Peer()] = pool
+		p.hostConnPools[ip] = pool
 	}
 	}
 	p.mu.Unlock()
 	p.mu.Unlock()
 
 
 	pool.fill()
 	pool.fill()
 }
 }
 
 
-func (p *policyConnPool) removeHost(addr string) {
+func (p *policyConnPool) removeHost(ip net.IP) {
+	k := ip.String()
 	p.mu.Lock()
 	p.mu.Lock()
-	pool, ok := p.hostConnPools[addr]
+	pool, ok := p.hostConnPools[k]
 	if !ok {
 	if !ok {
 		p.mu.Unlock()
 		p.mu.Unlock()
 		return
 		return
 	}
 	}
 
 
-	delete(p.hostConnPools, addr)
+	delete(p.hostConnPools, k)
 	p.mu.Unlock()
 	p.mu.Unlock()
 
 
 	go pool.Close()
 	go pool.Close()
@@ -234,10 +238,10 @@ func (p *policyConnPool) hostUp(host *HostInfo) {
 	p.addHost(host)
 	p.addHost(host)
 }
 }
 
 
-func (p *policyConnPool) hostDown(addr string) {
+func (p *policyConnPool) hostDown(ip net.IP) {
 	// TODO(zariel): mark host as down so we can try to connect to it later, for
 	// TODO(zariel): mark host as down so we can try to connect to it later, for
 	// now just treat it has removed.
 	// now just treat it has removed.
-	p.removeHost(addr)
+	p.removeHost(ip)
 }
 }
 
 
 // hostConnPool is a connection pool for a single host.
 // hostConnPool is a connection pool for a single host.
@@ -272,7 +276,7 @@ func newHostConnPool(session *Session, host *HostInfo, port, size int,
 		session:  session,
 		session:  session,
 		host:     host,
 		host:     host,
 		port:     port,
 		port:     port,
-		addr:     JoinHostPort(host.Peer(), port),
+		addr:     (&net.TCPAddr{IP: host.Peer(), Port: host.Port()}).String(),
 		size:     size,
 		size:     size,
 		keyspace: keyspace,
 		keyspace: keyspace,
 		conns:    make([]*Conn, 0, size),
 		conns:    make([]*Conn, 0, size),
@@ -396,7 +400,7 @@ func (pool *hostConnPool) fill() {
 
 
 			// this is calle with the connetion pool mutex held, this call will
 			// this is calle with the connetion pool mutex held, this call will
 			// then recursivly try to lock it again. FIXME
 			// then recursivly try to lock it again. FIXME
-			go pool.session.handleNodeDown(net.ParseIP(pool.host.Peer()), pool.port)
+			go pool.session.handleNodeDown(pool.host.Peer(), pool.port)
 			return
 			return
 		}
 		}
 
 
@@ -477,7 +481,7 @@ func (pool *hostConnPool) connect() (err error) {
 	// try to connect
 	// try to connect
 	var conn *Conn
 	var conn *Conn
 	for i := 0; i < maxAttempts; i++ {
 	for i := 0; i < maxAttempts; i++ {
-		conn, err = pool.session.connect(pool.addr, pool, pool.host)
+		conn, err = pool.session.connect(pool.host, pool)
 		if err == nil {
 		if err == nil {
 			break
 			break
 		}
 		}

+ 50 - 22
control.go

@@ -4,13 +4,14 @@ import (
 	crand "crypto/rand"
 	crand "crypto/rand"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
-	"golang.org/x/net/context"
 	"log"
 	"log"
 	"math/rand"
 	"math/rand"
 	"net"
 	"net"
 	"strconv"
 	"strconv"
 	"sync/atomic"
 	"sync/atomic"
 	"time"
 	"time"
+
+	"golang.org/x/net/context"
 )
 )
 
 
 var (
 var (
@@ -89,6 +90,8 @@ func (c *controlConn) heartBeat() {
 	}
 	}
 }
 }
 
 
+var hostLookupPreferV4 = false
+
 func hostInfo(addr string, defaultPort int) (*HostInfo, error) {
 func hostInfo(addr string, defaultPort int) (*HostInfo, error) {
 	var port int
 	var port int
 	host, portStr, err := net.SplitHostPort(addr)
 	host, portStr, err := net.SplitHostPort(addr)
@@ -102,10 +105,37 @@ func hostInfo(addr string, defaultPort int) (*HostInfo, error) {
 		}
 		}
 	}
 	}
 
 
-	return &HostInfo{peer: host, port: port}, nil
+	ip := net.ParseIP(host)
+	if ip == nil {
+		ips, err := net.LookupIP(host)
+		if err != nil {
+			return nil, err
+		} else if len(ips) == 0 {
+			return nil, fmt.Errorf("No IP's returned from DNS lookup for %q", addr)
+		}
+
+		if hostLookupPreferV4 {
+			for _, v := range ips {
+				if v4 := v.To4(); v4 != nil {
+					ip = v4
+					break
+				}
+			}
+			if ip == nil {
+				ip = ips[0]
+			}
+		} else {
+			// TODO(zariel): should we check that we can connect to any of the ips?
+			ip = ips[0]
+		}
+
+	}
+
+	return &HostInfo{peer: ip, port: port}, nil
 }
 }
 
 
 func (c *controlConn) shuffleDial(endpoints []string) (conn *Conn, err error) {
 func (c *controlConn) shuffleDial(endpoints []string) (conn *Conn, err error) {
+	// TODO: accept a []*HostInfo
 	perm := randr.Perm(len(endpoints))
 	perm := randr.Perm(len(endpoints))
 	shuffled := make([]string, len(endpoints))
 	shuffled := make([]string, len(endpoints))
 
 
@@ -130,7 +160,7 @@ func (c *controlConn) shuffleDial(endpoints []string) (conn *Conn, err error) {
 		}
 		}
 
 
 		hostInfo, _ := c.session.ring.addHostIfMissing(host)
 		hostInfo, _ := c.session.ring.addHostIfMissing(host)
-		conn, err = c.session.connect(addr, c, hostInfo)
+		conn, err = c.session.connect(hostInfo, c)
 		if err == nil {
 		if err == nil {
 			return conn, err
 			return conn, err
 		}
 		}
@@ -229,22 +259,21 @@ func (c *controlConn) reconnect(refreshring bool) {
 	// TODO: simplify this function, use session.ring to get hosts instead of the
 	// TODO: simplify this function, use session.ring to get hosts instead of the
 	// connection pool
 	// connection pool
 
 
-	addr := c.addr()
+	var host *HostInfo
 	oldConn := c.conn.Load().(*Conn)
 	oldConn := c.conn.Load().(*Conn)
 	if oldConn != nil {
 	if oldConn != nil {
+		host = oldConn.host
 		oldConn.Close()
 		oldConn.Close()
 	}
 	}
 
 
 	var newConn *Conn
 	var newConn *Conn
-	if addr != "" {
+	if host != nil {
 		// try to connect to the old host
 		// try to connect to the old host
-		conn, err := c.session.connect(addr, c, oldConn.host)
+		conn, err := c.session.connect(host, c)
 		if err != nil {
 		if err != nil {
 			// host is dead
 			// host is dead
 			// TODO: this is replicated in a few places
 			// TODO: this is replicated in a few places
-			ip, portStr, _ := net.SplitHostPort(addr)
-			port, _ := strconv.Atoi(portStr)
-			c.session.handleNodeDown(net.ParseIP(ip), port)
+			c.session.handleNodeDown(host.Peer(), host.Port())
 		} else {
 		} else {
 			newConn = conn
 			newConn = conn
 		}
 		}
@@ -260,7 +289,7 @@ func (c *controlConn) reconnect(refreshring bool) {
 		}
 		}
 
 
 		var err error
 		var err error
-		newConn, err = c.session.connect(host.Peer(), c, host)
+		newConn, err = c.session.connect(host, c)
 		if err != nil {
 		if err != nil {
 			// TODO: add log handler for things like this
 			// TODO: add log handler for things like this
 			return
 			return
@@ -350,29 +379,28 @@ func (c *controlConn) query(statement string, values ...interface{}) (iter *Iter
 	return
 	return
 }
 }
 
 
-func (c *controlConn) fetchHostInfo(addr net.IP, port int) (*HostInfo, error) {
+func (c *controlConn) fetchHostInfo(ip net.IP, port int) (*HostInfo, error) {
 	// TODO(zariel): we should probably move this into host_source or atleast
 	// TODO(zariel): we should probably move this into host_source or atleast
 	// share code with it.
 	// share code with it.
-	hostname, _, err := net.SplitHostPort(c.addr())
-	if err != nil {
-		return nil, fmt.Errorf("unable to fetch host info, invalid conn addr: %q: %v", c.addr(), err)
+	localHost := c.host()
+	if localHost == nil {
+		return nil, errors.New("unable to fetch host info, invalid conn host")
 	}
 	}
 
 
-	isLocal := hostname == addr.String()
+	isLocal := localHost.Peer().Equal(ip)
 
 
 	var fn func(*HostInfo) error
 	var fn func(*HostInfo) error
 
 
+	// TODO(zariel): fetch preferred_ip address (is it >3.x only?)
 	if isLocal {
 	if isLocal {
 		fn = func(host *HostInfo) error {
 		fn = func(host *HostInfo) error {
-			// TODO(zariel): should we fetch rpc_address from here?
 			iter := c.query("SELECT data_center, rack, host_id, tokens, release_version FROM system.local WHERE key='local'")
 			iter := c.query("SELECT data_center, rack, host_id, tokens, release_version FROM system.local WHERE key='local'")
 			iter.Scan(&host.dataCenter, &host.rack, &host.hostId, &host.tokens, &host.version)
 			iter.Scan(&host.dataCenter, &host.rack, &host.hostId, &host.tokens, &host.version)
 			return iter.Close()
 			return iter.Close()
 		}
 		}
 	} else {
 	} else {
 		fn = func(host *HostInfo) error {
 		fn = func(host *HostInfo) error {
-			// TODO(zariel): should we fetch rpc_address from here?
-			iter := c.query("SELECT data_center, rack, host_id, tokens, release_version FROM system.peers WHERE peer=?", addr)
+			iter := c.query("SELECT data_center, rack, host_id, tokens, release_version FROM system.peers WHERE peer=?", ip)
 			iter.Scan(&host.dataCenter, &host.rack, &host.hostId, &host.tokens, &host.version)
 			iter.Scan(&host.dataCenter, &host.rack, &host.hostId, &host.tokens, &host.version)
 			return iter.Close()
 			return iter.Close()
 		}
 		}
@@ -380,12 +408,12 @@ func (c *controlConn) fetchHostInfo(addr net.IP, port int) (*HostInfo, error) {
 
 
 	host := &HostInfo{
 	host := &HostInfo{
 		port: port,
 		port: port,
+		peer: ip,
 	}
 	}
 
 
 	if err := fn(host); err != nil {
 	if err := fn(host); err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
-	host.peer = addr.String()
 
 
 	return host, nil
 	return host, nil
 }
 }
@@ -396,12 +424,12 @@ func (c *controlConn) awaitSchemaAgreement() error {
 	}).err
 	}).err
 }
 }
 
 
-func (c *controlConn) addr() string {
+func (c *controlConn) host() *HostInfo {
 	conn := c.conn.Load().(*Conn)
 	conn := c.conn.Load().(*Conn)
 	if conn == nil {
 	if conn == nil {
-		return ""
+		return nil
 	}
 	}
-	return conn.addr
+	return conn.host
 }
 }
 
 
 func (c *controlConn) close() {
 func (c *controlConn) close() {

+ 31 - 0
control_test.go

@@ -0,0 +1,31 @@
+package gocql
+
+import (
+	"net"
+	"testing"
+)
+
+func TestHostInfo_Lookup(t *testing.T) {
+	hostLookupPreferV4 = true
+	defer func() { hostLookupPreferV4 = false }()
+
+	tests := [...]struct {
+		addr string
+		ip   net.IP
+	}{
+		{"127.0.0.1", net.IPv4(127, 0, 0, 1)},
+		{"localhost", net.IPv4(127, 0, 0, 1)}, // TODO: this may be host dependant
+	}
+
+	for i, test := range tests {
+		host, err := hostInfo(test.addr, 1)
+		if err != nil {
+			t.Errorf("%d: %v", i, err)
+			continue
+		}
+
+		if !host.peer.Equal(test.ip) {
+			t.Errorf("expected ip %v got %v for addr %q", test.ip, host.peer, test.addr)
+		}
+	}
+}

+ 21 - 27
events.go

@@ -171,25 +171,21 @@ func (s *Session) handleNodeEvent(frames []frame) {
 	}
 	}
 }
 }
 
 
-func (s *Session) handleNewNode(host net.IP, port int, waitForBinary bool) {
-	// TODO(zariel): need to be able to filter discovered nodes
-
+func (s *Session) handleNewNode(ip net.IP, port int, waitForBinary bool) {
 	var hostInfo *HostInfo
 	var hostInfo *HostInfo
 	if s.control != nil && !s.cfg.IgnorePeerAddr {
 	if s.control != nil && !s.cfg.IgnorePeerAddr {
 		var err error
 		var err error
-		hostInfo, err = s.control.fetchHostInfo(host, port)
+		hostInfo, err = s.control.fetchHostInfo(ip, port)
 		if err != nil {
 		if err != nil {
-			log.Printf("gocql: events: unable to fetch host info for %v: %v\n", host, err)
+			log.Printf("gocql: events: unable to fetch host info for (%s:%d): %v\n", ip, port, err)
 			return
 			return
 		}
 		}
-
 	} else {
 	} else {
-		hostInfo = &HostInfo{peer: host.String(), port: port, state: NodeUp}
+		hostInfo = &HostInfo{peer: ip, port: port}
 	}
 	}
 
 
-	addr := host.String()
-	if s.cfg.IgnorePeerAddr && hostInfo.Peer() != addr {
-		hostInfo.setPeer(addr)
+	if s.cfg.IgnorePeerAddr && hostInfo.Peer().Equal(ip) {
+		hostInfo.setPeer(ip)
 	}
 	}
 
 
 	if s.cfg.HostFilter != nil && !s.cfg.HostFilter.Accept(hostInfo) {
 	if s.cfg.HostFilter != nil && !s.cfg.HostFilter.Accept(hostInfo) {
@@ -217,11 +213,9 @@ func (s *Session) handleNewNode(host net.IP, port int, waitForBinary bool) {
 
 
 func (s *Session) handleRemovedNode(ip net.IP, port int) {
 func (s *Session) handleRemovedNode(ip net.IP, port int) {
 	// we remove all nodes but only add ones which pass the filter
 	// we remove all nodes but only add ones which pass the filter
-	addr := ip.String()
-
-	host := s.ring.getHost(addr)
+	host := s.ring.getHost(ip)
 	if host == nil {
 	if host == nil {
-		host = &HostInfo{peer: addr}
+		host = &HostInfo{peer: ip, port: port}
 	}
 	}
 
 
 	if s.cfg.HostFilter != nil && !s.cfg.HostFilter.Accept(host) {
 	if s.cfg.HostFilter != nil && !s.cfg.HostFilter.Accept(host) {
@@ -229,9 +223,9 @@ func (s *Session) handleRemovedNode(ip net.IP, port int) {
 	}
 	}
 
 
 	host.setState(NodeDown)
 	host.setState(NodeDown)
-	s.policy.RemoveHost(addr)
-	s.pool.removeHost(addr)
-	s.ring.removeHost(addr)
+	s.policy.RemoveHost(host)
+	s.pool.removeHost(ip)
+	s.ring.removeHost(ip)
 
 
 	if !s.cfg.IgnorePeerAddr {
 	if !s.cfg.IgnorePeerAddr {
 		s.hostSource.refreshRing()
 		s.hostSource.refreshRing()
@@ -242,11 +236,12 @@ func (s *Session) handleNodeUp(ip net.IP, port int, waitForBinary bool) {
 	if gocqlDebug {
 	if gocqlDebug {
 		log.Printf("gocql: Session.handleNodeUp: %s:%d\n", ip.String(), port)
 		log.Printf("gocql: Session.handleNodeUp: %s:%d\n", ip.String(), port)
 	}
 	}
-	addr := ip.String()
-	host := s.ring.getHost(addr)
+
+	host := s.ring.getHost(ip)
 	if host != nil {
 	if host != nil {
-		if s.cfg.IgnorePeerAddr && host.Peer() != addr {
-			host.setPeer(addr)
+		if s.cfg.IgnorePeerAddr && host.Peer().Equal(ip) {
+			// TODO: how can this ever be true?
+			host.setPeer(ip)
 		}
 		}
 
 
 		if s.cfg.HostFilter != nil && !s.cfg.HostFilter.Accept(host) {
 		if s.cfg.HostFilter != nil && !s.cfg.HostFilter.Accept(host) {
@@ -257,7 +252,6 @@ func (s *Session) handleNodeUp(ip net.IP, port int, waitForBinary bool) {
 			time.Sleep(t)
 			time.Sleep(t)
 		}
 		}
 
 
-		host.setPort(port)
 		s.pool.hostUp(host)
 		s.pool.hostUp(host)
 		s.policy.HostUp(host)
 		s.policy.HostUp(host)
 		host.setState(NodeUp)
 		host.setState(NodeUp)
@@ -271,10 +265,10 @@ func (s *Session) handleNodeDown(ip net.IP, port int) {
 	if gocqlDebug {
 	if gocqlDebug {
 		log.Printf("gocql: Session.handleNodeDown: %s:%d\n", ip.String(), port)
 		log.Printf("gocql: Session.handleNodeDown: %s:%d\n", ip.String(), port)
 	}
 	}
-	addr := ip.String()
-	host := s.ring.getHost(addr)
+
+	host := s.ring.getHost(ip)
 	if host == nil {
 	if host == nil {
-		host = &HostInfo{peer: addr}
+		host = &HostInfo{peer: ip, port: port}
 	}
 	}
 
 
 	if s.cfg.HostFilter != nil && !s.cfg.HostFilter.Accept(host) {
 	if s.cfg.HostFilter != nil && !s.cfg.HostFilter.Accept(host) {
@@ -282,6 +276,6 @@ func (s *Session) handleNodeDown(ip net.IP, port int) {
 	}
 	}
 
 
 	host.setState(NodeDown)
 	host.setState(NodeDown)
-	s.policy.HostDown(addr)
-	s.pool.hostDown(addr)
+	s.policy.HostDown(host)
+	s.pool.hostDown(ip)
 }
 }

+ 12 - 4
filters.go

@@ -1,5 +1,7 @@
 package gocql
 package gocql
 
 
+import "fmt"
+
 // HostFilter interface is used when a host is discovered via server sent events.
 // HostFilter interface is used when a host is discovered via server sent events.
 type HostFilter interface {
 type HostFilter interface {
 	// Called when a new host is discovered, returning true will cause the host
 	// Called when a new host is discovered, returning true will cause the host
@@ -38,12 +40,18 @@ func DataCentreHostFilter(dataCentre string) HostFilter {
 // WhiteListHostFilter filters incoming hosts by checking that their address is
 // WhiteListHostFilter filters incoming hosts by checking that their address is
 // in the initial hosts whitelist.
 // in the initial hosts whitelist.
 func WhiteListHostFilter(hosts ...string) HostFilter {
 func WhiteListHostFilter(hosts ...string) HostFilter {
-	m := make(map[string]bool, len(hosts))
-	for _, host := range hosts {
-		m[host] = true
+	hostInfos, err := addrsToHosts(hosts, 9042)
+	if err != nil {
+		// dont want to panic here, but rather not break the API
+		panic(fmt.Errorf("unable to lookup host info from address: %v", err))
+	}
+
+	m := make(map[string]bool, len(hostInfos))
+	for _, host := range hostInfos {
+		m[string(host.peer)] = true
 	}
 	}
 
 
 	return HostFilterFunc(func(host *HostInfo) bool {
 	return HostFilterFunc(func(host *HostInfo) bool {
-		return m[host.Peer()]
+		return m[string(host.Peer())]
 	})
 	})
 }
 }

+ 17 - 14
filters_test.go

@@ -1,16 +1,19 @@
 package gocql
 package gocql
 
 
-import "testing"
+import (
+	"net"
+	"testing"
+)
 
 
 func TestFilter_WhiteList(t *testing.T) {
 func TestFilter_WhiteList(t *testing.T) {
-	f := WhiteListHostFilter("addr1", "addr2")
+	f := WhiteListHostFilter("127.0.0.1", "127.0.0.2")
 	tests := [...]struct {
 	tests := [...]struct {
-		addr   string
+		addr   net.IP
 		accept bool
 		accept bool
 	}{
 	}{
-		{"addr1", true},
-		{"addr2", true},
-		{"addr3", false},
+		{net.ParseIP("127.0.0.1"), true},
+		{net.ParseIP("127.0.0.2"), true},
+		{net.ParseIP("127.0.0.3"), false},
 	}
 	}
 
 
 	for i, test := range tests {
 	for i, test := range tests {
@@ -27,12 +30,12 @@ func TestFilter_WhiteList(t *testing.T) {
 func TestFilter_AllowAll(t *testing.T) {
 func TestFilter_AllowAll(t *testing.T) {
 	f := AcceptAllFilter()
 	f := AcceptAllFilter()
 	tests := [...]struct {
 	tests := [...]struct {
-		addr   string
+		addr   net.IP
 		accept bool
 		accept bool
 	}{
 	}{
-		{"addr1", true},
-		{"addr2", true},
-		{"addr3", true},
+		{net.ParseIP("127.0.0.1"), true},
+		{net.ParseIP("127.0.0.2"), true},
+		{net.ParseIP("127.0.0.3"), true},
 	}
 	}
 
 
 	for i, test := range tests {
 	for i, test := range tests {
@@ -49,12 +52,12 @@ func TestFilter_AllowAll(t *testing.T) {
 func TestFilter_DenyAll(t *testing.T) {
 func TestFilter_DenyAll(t *testing.T) {
 	f := DenyAllFilter()
 	f := DenyAllFilter()
 	tests := [...]struct {
 	tests := [...]struct {
-		addr   string
+		addr   net.IP
 		accept bool
 		accept bool
 	}{
 	}{
-		{"addr1", false},
-		{"addr2", false},
-		{"addr3", false},
+		{net.ParseIP("127.0.0.1"), false},
+		{net.ParseIP("127.0.0.2"), false},
+		{net.ParseIP("127.0.0.3"), false},
 	}
 	}
 
 
 	for i, test := range tests {
 	for i, test := range tests {

+ 9 - 14
host_source.go

@@ -100,7 +100,7 @@ type HostInfo struct {
 	// TODO(zariel): reduce locking maybe, not all values will change, but to ensure
 	// TODO(zariel): reduce locking maybe, not all values will change, but to ensure
 	// that we are thread safe use a mutex to access all fields.
 	// that we are thread safe use a mutex to access all fields.
 	mu         sync.RWMutex
 	mu         sync.RWMutex
-	peer       string
+	peer       net.IP
 	port       int
 	port       int
 	dataCenter string
 	dataCenter string
 	rack       string
 	rack       string
@@ -116,16 +116,16 @@ func (h *HostInfo) Equal(host *HostInfo) bool {
 	host.mu.RLock()
 	host.mu.RLock()
 	defer host.mu.RUnlock()
 	defer host.mu.RUnlock()
 
 
-	return h.peer == host.peer && h.hostId == host.hostId
+	return h.peer.Equal(host.peer)
 }
 }
 
 
-func (h *HostInfo) Peer() string {
+func (h *HostInfo) Peer() net.IP {
 	h.mu.RLock()
 	h.mu.RLock()
 	defer h.mu.RUnlock()
 	defer h.mu.RUnlock()
 	return h.peer
 	return h.peer
 }
 }
 
 
-func (h *HostInfo) setPeer(peer string) *HostInfo {
+func (h *HostInfo) setPeer(peer net.IP) *HostInfo {
 	h.mu.Lock()
 	h.mu.Lock()
 	defer h.mu.Unlock()
 	defer h.mu.Unlock()
 	h.peer = peer
 	h.peer = peer
@@ -314,7 +314,11 @@ func (r *ringDescriber) GetHosts() (hosts []*HostInfo, partitioner string, err e
 			return nil, "", err
 			return nil, "", err
 		}
 		}
 	} else {
 	} else {
-		iter := r.session.control.query(legacyLocalQuery)
+		iter := r.session.control.withConn(func(c *Conn) *Iter {
+			localHost = c.host
+			return c.query(legacyLocalQuery)
+		})
+
 		if iter == nil {
 		if iter == nil {
 			return r.prevHosts, r.prevPartitioner, nil
 			return r.prevHosts, r.prevPartitioner, nil
 		}
 		}
@@ -324,15 +328,6 @@ func (r *ringDescriber) GetHosts() (hosts []*HostInfo, partitioner string, err e
 		if err = iter.Close(); err != nil {
 		if err = iter.Close(); err != nil {
 			return nil, "", err
 			return nil, "", err
 		}
 		}
-
-		addr, _, err := net.SplitHostPort(r.session.control.addr())
-		if err != nil {
-			// this should not happen, ever, as this is the address that was dialed by conn, here
-			// a panic makes sense, please report a bug if it occurs.
-			panic(err)
-		}
-
-		localHost.peer = addr
 	}
 	}
 
 
 	localHost.port = r.session.cfg.Port
 	localHost.port = r.session.cfg.Port

+ 33 - 26
policies.go

@@ -7,6 +7,7 @@ package gocql
 import (
 import (
 	"fmt"
 	"fmt"
 	"log"
 	"log"
+	"net"
 	"sync"
 	"sync"
 	"sync/atomic"
 	"sync/atomic"
 
 
@@ -90,7 +91,7 @@ func (c *cowHostList) update(host *HostInfo) {
 	c.mu.Unlock()
 	c.mu.Unlock()
 }
 }
 
 
-func (c *cowHostList) remove(addr string) bool {
+func (c *cowHostList) remove(ip net.IP) bool {
 	c.mu.Lock()
 	c.mu.Lock()
 	l := c.get()
 	l := c.get()
 	size := len(l)
 	size := len(l)
@@ -102,7 +103,7 @@ func (c *cowHostList) remove(addr string) bool {
 	found := false
 	found := false
 	newL := make([]*HostInfo, 0, size)
 	newL := make([]*HostInfo, 0, size)
 	for i := 0; i < len(l); i++ {
 	for i := 0; i < len(l); i++ {
-		if l[i].Peer() != addr {
+		if !l[i].Peer().Equal(ip) {
 			newL = append(newL, l[i])
 			newL = append(newL, l[i])
 		} else {
 		} else {
 			found = true
 			found = true
@@ -161,9 +162,9 @@ func (s *SimpleRetryPolicy) Attempt(q RetryableQuery) bool {
 
 
 type HostStateNotifier interface {
 type HostStateNotifier interface {
 	AddHost(host *HostInfo)
 	AddHost(host *HostInfo)
-	RemoveHost(addr string)
+	RemoveHost(host *HostInfo)
 	HostUp(host *HostInfo)
 	HostUp(host *HostInfo)
-	HostDown(addr string)
+	HostDown(host *HostInfo)
 }
 }
 
 
 // HostSelectionPolicy is an interface for selecting
 // HostSelectionPolicy is an interface for selecting
@@ -235,16 +236,16 @@ func (r *roundRobinHostPolicy) AddHost(host *HostInfo) {
 	r.hosts.add(host)
 	r.hosts.add(host)
 }
 }
 
 
-func (r *roundRobinHostPolicy) RemoveHost(addr string) {
-	r.hosts.remove(addr)
+func (r *roundRobinHostPolicy) RemoveHost(host *HostInfo) {
+	r.hosts.remove(host.Peer())
 }
 }
 
 
 func (r *roundRobinHostPolicy) HostUp(host *HostInfo) {
 func (r *roundRobinHostPolicy) HostUp(host *HostInfo) {
 	r.AddHost(host)
 	r.AddHost(host)
 }
 }
 
 
-func (r *roundRobinHostPolicy) HostDown(addr string) {
-	r.RemoveHost(addr)
+func (r *roundRobinHostPolicy) HostDown(host *HostInfo) {
+	r.RemoveHost(host)
 }
 }
 
 
 // TokenAwareHostPolicy is a token aware host selection policy, where hosts are
 // TokenAwareHostPolicy is a token aware host selection policy, where hosts are
@@ -278,9 +279,9 @@ func (t *tokenAwareHostPolicy) AddHost(host *HostInfo) {
 	t.resetTokenRing()
 	t.resetTokenRing()
 }
 }
 
 
-func (t *tokenAwareHostPolicy) RemoveHost(addr string) {
-	t.hosts.remove(addr)
-	t.fallback.RemoveHost(addr)
+func (t *tokenAwareHostPolicy) RemoveHost(host *HostInfo) {
+	t.hosts.remove(host.Peer())
+	t.fallback.RemoveHost(host)
 
 
 	t.resetTokenRing()
 	t.resetTokenRing()
 }
 }
@@ -289,8 +290,8 @@ func (t *tokenAwareHostPolicy) HostUp(host *HostInfo) {
 	t.AddHost(host)
 	t.AddHost(host)
 }
 }
 
 
-func (t *tokenAwareHostPolicy) HostDown(addr string) {
-	t.RemoveHost(addr)
+func (t *tokenAwareHostPolicy) HostDown(host *HostInfo) {
+	t.RemoveHost(host)
 }
 }
 
 
 func (t *tokenAwareHostPolicy) resetTokenRing() {
 func (t *tokenAwareHostPolicy) resetTokenRing() {
@@ -393,8 +394,9 @@ func (r *hostPoolHostPolicy) SetHosts(hosts []*HostInfo) {
 	hostMap := make(map[string]*HostInfo, len(hosts))
 	hostMap := make(map[string]*HostInfo, len(hosts))
 
 
 	for i, host := range hosts {
 	for i, host := range hosts {
-		peers[i] = host.Peer()
-		hostMap[host.Peer()] = host
+		ip := host.Peer().String()
+		peers[i] = ip
+		hostMap[ip] = host
 	}
 	}
 
 
 	r.mu.Lock()
 	r.mu.Lock()
@@ -404,15 +406,17 @@ func (r *hostPoolHostPolicy) SetHosts(hosts []*HostInfo) {
 }
 }
 
 
 func (r *hostPoolHostPolicy) AddHost(host *HostInfo) {
 func (r *hostPoolHostPolicy) AddHost(host *HostInfo) {
+	ip := host.Peer().String()
+
 	r.mu.Lock()
 	r.mu.Lock()
 	defer r.mu.Unlock()
 	defer r.mu.Unlock()
 
 
 	// If the host addr is present and isn't nil return
 	// If the host addr is present and isn't nil return
-	if h, ok := r.hostMap[host.Peer()]; ok && h != nil{
+	if h, ok := r.hostMap[ip]; ok && h != nil {
 		return
 		return
 	}
 	}
 	// otherwise, add the host to the map
 	// otherwise, add the host to the map
-	r.hostMap[host.Peer()] = host
+	r.hostMap[ip] = host
 	// and construct a new peer list to give to the HostPool
 	// and construct a new peer list to give to the HostPool
 	hosts := make([]string, 0, len(r.hostMap))
 	hosts := make([]string, 0, len(r.hostMap))
 	for addr := range r.hostMap {
 	for addr := range r.hostMap {
@@ -420,21 +424,22 @@ func (r *hostPoolHostPolicy) AddHost(host *HostInfo) {
 	}
 	}
 
 
 	r.hp.SetHosts(hosts)
 	r.hp.SetHosts(hosts)
-
 }
 }
 
 
-func (r *hostPoolHostPolicy) RemoveHost(addr string) {
+func (r *hostPoolHostPolicy) RemoveHost(host *HostInfo) {
+	ip := host.Peer().String()
+
 	r.mu.Lock()
 	r.mu.Lock()
 	defer r.mu.Unlock()
 	defer r.mu.Unlock()
 
 
-	if _, ok := r.hostMap[addr]; !ok {
+	if _, ok := r.hostMap[ip]; !ok {
 		return
 		return
 	}
 	}
 
 
-	delete(r.hostMap, addr)
+	delete(r.hostMap, ip)
 	hosts := make([]string, 0, len(r.hostMap))
 	hosts := make([]string, 0, len(r.hostMap))
-	for addr := range r.hostMap {
-		hosts = append(hosts, addr)
+	for _, host := range r.hostMap {
+		hosts = append(hosts, host.Peer().String())
 	}
 	}
 
 
 	r.hp.SetHosts(hosts)
 	r.hp.SetHosts(hosts)
@@ -444,8 +449,8 @@ func (r *hostPoolHostPolicy) HostUp(host *HostInfo) {
 	r.AddHost(host)
 	r.AddHost(host)
 }
 }
 
 
-func (r *hostPoolHostPolicy) HostDown(addr string) {
-	r.RemoveHost(addr)
+func (r *hostPoolHostPolicy) HostDown(host *HostInfo) {
+	r.RemoveHost(host)
 }
 }
 
 
 func (r *hostPoolHostPolicy) SetPartitioner(partitioner string) {
 func (r *hostPoolHostPolicy) SetPartitioner(partitioner string) {
@@ -488,10 +493,12 @@ func (host selectedHostPoolHost) Info() *HostInfo {
 }
 }
 
 
 func (host selectedHostPoolHost) Mark(err error) {
 func (host selectedHostPoolHost) Mark(err error) {
+	ip := host.info.Peer().String()
+
 	host.policy.mu.RLock()
 	host.policy.mu.RLock()
 	defer host.policy.mu.RUnlock()
 	defer host.policy.mu.RUnlock()
 
 
-	if _, ok := host.policy.hostMap[host.info.Peer()]; !ok {
+	if _, ok := host.policy.hostMap[ip]; !ok {
 		// host was removed between pick and mark
 		// host was removed between pick and mark
 		return
 		return
 	}
 	}

+ 24 - 23
policies_test.go

@@ -6,6 +6,7 @@ package gocql
 
 
 import (
 import (
 	"fmt"
 	"fmt"
+	"net"
 	"testing"
 	"testing"
 
 
 	"github.com/hailocab/go-hostpool"
 	"github.com/hailocab/go-hostpool"
@@ -16,8 +17,8 @@ func TestRoundRobinHostPolicy(t *testing.T) {
 	policy := RoundRobinHostPolicy()
 	policy := RoundRobinHostPolicy()
 
 
 	hosts := [...]*HostInfo{
 	hosts := [...]*HostInfo{
-		{hostId: "0"},
-		{hostId: "1"},
+		{hostId: "0", peer: net.IPv4(0, 0, 0, 1)},
+		{hostId: "1", peer: net.IPv4(0, 0, 0, 2)},
 	}
 	}
 
 
 	for _, host := range hosts {
 	for _, host := range hosts {
@@ -67,10 +68,10 @@ func TestTokenAwareHostPolicy(t *testing.T) {
 
 
 	// set the hosts
 	// set the hosts
 	hosts := [...]*HostInfo{
 	hosts := [...]*HostInfo{
-		{peer: "0", tokens: []string{"00"}},
-		{peer: "1", tokens: []string{"25"}},
-		{peer: "2", tokens: []string{"50"}},
-		{peer: "3", tokens: []string{"75"}},
+		{peer: net.IPv4(10, 0, 0, 1), tokens: []string{"00"}},
+		{peer: net.IPv4(10, 0, 0, 2), tokens: []string{"25"}},
+		{peer: net.IPv4(10, 0, 0, 3), tokens: []string{"50"}},
+		{peer: net.IPv4(10, 0, 0, 4), tokens: []string{"75"}},
 	}
 	}
 	for _, host := range hosts {
 	for _, host := range hosts {
 		policy.AddHost(host)
 		policy.AddHost(host)
@@ -78,12 +79,12 @@ func TestTokenAwareHostPolicy(t *testing.T) {
 
 
 	// the token ring is not setup without the partitioner, but the fallback
 	// the token ring is not setup without the partitioner, but the fallback
 	// should work
 	// should work
-	if actual := policy.Pick(nil)(); actual.Info().Peer() != "0" {
+	if actual := policy.Pick(nil)(); !actual.Info().Peer().Equal(hosts[0].peer) {
 		t.Errorf("Expected peer 0 but was %s", actual.Info().Peer())
 		t.Errorf("Expected peer 0 but was %s", actual.Info().Peer())
 	}
 	}
 
 
 	query.RoutingKey([]byte("30"))
 	query.RoutingKey([]byte("30"))
-	if actual := policy.Pick(query)(); actual.Info().Peer() != "1" {
+	if actual := policy.Pick(query)(); !actual.Info().Peer().Equal(hosts[1].peer) {
 		t.Errorf("Expected peer 1 but was %s", actual.Info().Peer())
 		t.Errorf("Expected peer 1 but was %s", actual.Info().Peer())
 	}
 	}
 
 
@@ -92,17 +93,17 @@ func TestTokenAwareHostPolicy(t *testing.T) {
 	// now the token ring is configured
 	// now the token ring is configured
 	query.RoutingKey([]byte("20"))
 	query.RoutingKey([]byte("20"))
 	iter = policy.Pick(query)
 	iter = policy.Pick(query)
-	if actual := iter(); actual.Info().Peer() != "1" {
+	if actual := iter(); !actual.Info().Peer().Equal(hosts[1].peer) {
 		t.Errorf("Expected peer 1 but was %s", actual.Info().Peer())
 		t.Errorf("Expected peer 1 but was %s", actual.Info().Peer())
 	}
 	}
 	// rest are round robin
 	// rest are round robin
-	if actual := iter(); actual.Info().Peer() != "2" {
+	if actual := iter(); !actual.Info().Peer().Equal(hosts[2].peer) {
 		t.Errorf("Expected peer 2 but was %s", actual.Info().Peer())
 		t.Errorf("Expected peer 2 but was %s", actual.Info().Peer())
 	}
 	}
-	if actual := iter(); actual.Info().Peer() != "3" {
+	if actual := iter(); !actual.Info().Peer().Equal(hosts[3].peer) {
 		t.Errorf("Expected peer 3 but was %s", actual.Info().Peer())
 		t.Errorf("Expected peer 3 but was %s", actual.Info().Peer())
 	}
 	}
-	if actual := iter(); actual.Info().Peer() != "0" {
+	if actual := iter(); !actual.Info().Peer().Equal(hosts[0].peer) {
 		t.Errorf("Expected peer 0 but was %s", actual.Info().Peer())
 		t.Errorf("Expected peer 0 but was %s", actual.Info().Peer())
 	}
 	}
 }
 }
@@ -112,8 +113,8 @@ func TestHostPoolHostPolicy(t *testing.T) {
 	policy := HostPoolHostPolicy(hostpool.New(nil))
 	policy := HostPoolHostPolicy(hostpool.New(nil))
 
 
 	hosts := []*HostInfo{
 	hosts := []*HostInfo{
-		{hostId: "0", peer: "0"},
-		{hostId: "1", peer: "1"},
+		{hostId: "0", peer: net.IPv4(10, 0, 0, 0)},
+		{hostId: "1", peer: net.IPv4(10, 0, 0, 1)},
 	}
 	}
 
 
 	// Using set host to control the ordering of the hosts as calling "AddHost" iterates the map
 	// Using set host to control the ordering of the hosts as calling "AddHost" iterates the map
@@ -177,10 +178,10 @@ func TestTokenAwareNilHostInfo(t *testing.T) {
 	policy := TokenAwareHostPolicy(RoundRobinHostPolicy())
 	policy := TokenAwareHostPolicy(RoundRobinHostPolicy())
 
 
 	hosts := [...]*HostInfo{
 	hosts := [...]*HostInfo{
-		{peer: "0", tokens: []string{"00"}},
-		{peer: "1", tokens: []string{"25"}},
-		{peer: "2", tokens: []string{"50"}},
-		{peer: "3", tokens: []string{"75"}},
+		{peer: net.IPv4(10, 0, 0, 0), tokens: []string{"00"}},
+		{peer: net.IPv4(10, 0, 0, 1), tokens: []string{"25"}},
+		{peer: net.IPv4(10, 0, 0, 2), tokens: []string{"50"}},
+		{peer: net.IPv4(10, 0, 0, 3), tokens: []string{"75"}},
 	}
 	}
 	for _, host := range hosts {
 	for _, host := range hosts {
 		policy.AddHost(host)
 		policy.AddHost(host)
@@ -196,13 +197,13 @@ func TestTokenAwareNilHostInfo(t *testing.T) {
 		t.Fatal("got nil host")
 		t.Fatal("got nil host")
 	} else if v := next.Info(); v == nil {
 	} else if v := next.Info(); v == nil {
 		t.Fatal("got nil HostInfo")
 		t.Fatal("got nil HostInfo")
-	} else if v.Peer() != "1" {
+	} else if !v.Peer().Equal(hosts[1].peer) {
 		t.Fatalf("expected peer 1 got %v", v.Peer())
 		t.Fatalf("expected peer 1 got %v", v.Peer())
 	}
 	}
 
 
 	// Empty the hosts to trigger the panic when using the fallback.
 	// Empty the hosts to trigger the panic when using the fallback.
 	for _, host := range hosts {
 	for _, host := range hosts {
-		policy.RemoveHost(host.Peer())
+		policy.RemoveHost(host)
 	}
 	}
 
 
 	next = iter()
 	next = iter()
@@ -217,7 +218,7 @@ func TestTokenAwareNilHostInfo(t *testing.T) {
 func TestCOWList_Add(t *testing.T) {
 func TestCOWList_Add(t *testing.T) {
 	var cow cowHostList
 	var cow cowHostList
 
 
-	toAdd := [...]string{"peer1", "peer2", "peer3"}
+	toAdd := [...]net.IP{net.IPv4(0, 0, 0, 0), net.IPv4(1, 0, 0, 0), net.IPv4(2, 0, 0, 0)}
 
 
 	for _, addr := range toAdd {
 	for _, addr := range toAdd {
 		if !cow.add(&HostInfo{peer: addr}) {
 		if !cow.add(&HostInfo{peer: addr}) {
@@ -232,11 +233,11 @@ func TestCOWList_Add(t *testing.T) {
 
 
 	set := make(map[string]bool)
 	set := make(map[string]bool)
 	for _, host := range hosts {
 	for _, host := range hosts {
-		set[host.Peer()] = true
+		set[string(host.Peer())] = true
 	}
 	}
 
 
 	for _, addr := range toAdd {
 	for _, addr := range toAdd {
-		if !set[addr] {
+		if !set[string(addr)] {
 			t.Errorf("addr was not in the host list: %q", addr)
 			t.Errorf("addr was not in the host list: %q", addr)
 		}
 		}
 	}
 	}

+ 1 - 1
query_executor.go

@@ -28,7 +28,7 @@ func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
 			continue
 			continue
 		}
 		}
 
 
-		pool, ok := q.pool.getPool(host.Peer())
+		pool, ok := q.pool.getPool(host)
 		if !ok {
 		if !ok {
 			continue
 			continue
 		}
 		}

+ 28 - 11
ring.go

@@ -1,6 +1,7 @@
 package gocql
 package gocql
 
 
 import (
 import (
+	"net"
 	"sync"
 	"sync"
 	"sync/atomic"
 	"sync/atomic"
 )
 )
@@ -34,9 +35,9 @@ func (r *ring) rrHost() *HostInfo {
 	return r.hostList[pos%len(r.hostList)]
 	return r.hostList[pos%len(r.hostList)]
 }
 }
 
 
-func (r *ring) getHost(addr string) *HostInfo {
+func (r *ring) getHost(ip net.IP) *HostInfo {
 	r.mu.RLock()
 	r.mu.RLock()
-	host := r.hosts[addr]
+	host := r.hosts[ip.String()]
 	r.mu.RUnlock()
 	r.mu.RUnlock()
 	return host
 	return host
 }
 }
@@ -52,42 +53,58 @@ func (r *ring) allHosts() []*HostInfo {
 }
 }
 
 
 func (r *ring) addHost(host *HostInfo) bool {
 func (r *ring) addHost(host *HostInfo) bool {
+	ip := host.Peer().String()
+
 	r.mu.Lock()
 	r.mu.Lock()
 	if r.hosts == nil {
 	if r.hosts == nil {
 		r.hosts = make(map[string]*HostInfo)
 		r.hosts = make(map[string]*HostInfo)
 	}
 	}
 
 
-	addr := host.Peer()
-	_, ok := r.hosts[addr]
-	r.hosts[addr] = host
+	_, ok := r.hosts[ip]
+	if !ok {
+		r.hostList = append(r.hostList, host)
+	}
+
+	r.hosts[ip] = host
 	r.mu.Unlock()
 	r.mu.Unlock()
 	return ok
 	return ok
 }
 }
 
 
 func (r *ring) addHostIfMissing(host *HostInfo) (*HostInfo, bool) {
 func (r *ring) addHostIfMissing(host *HostInfo) (*HostInfo, bool) {
+	ip := host.Peer().String()
+
 	r.mu.Lock()
 	r.mu.Lock()
 	if r.hosts == nil {
 	if r.hosts == nil {
 		r.hosts = make(map[string]*HostInfo)
 		r.hosts = make(map[string]*HostInfo)
 	}
 	}
 
 
-	addr := host.Peer()
-	existing, ok := r.hosts[addr]
+	existing, ok := r.hosts[ip]
 	if !ok {
 	if !ok {
-		r.hosts[addr] = host
+		r.hosts[ip] = host
 		existing = host
 		existing = host
+		r.hostList = append(r.hostList, host)
 	}
 	}
 	r.mu.Unlock()
 	r.mu.Unlock()
 	return existing, ok
 	return existing, ok
 }
 }
 
 
-func (r *ring) removeHost(addr string) bool {
+func (r *ring) removeHost(ip net.IP) bool {
 	r.mu.Lock()
 	r.mu.Lock()
 	if r.hosts == nil {
 	if r.hosts == nil {
 		r.hosts = make(map[string]*HostInfo)
 		r.hosts = make(map[string]*HostInfo)
 	}
 	}
 
 
-	_, ok := r.hosts[addr]
-	delete(r.hosts, addr)
+	k := ip.String()
+	_, ok := r.hosts[k]
+	if ok {
+		for i, host := range r.hostList {
+			if host.Peer().Equal(ip) {
+				r.hostList = append(r.hostList[:i], r.hostList[i+1:]...)
+				break
+			}
+		}
+	}
+	delete(r.hosts, k)
 	r.mu.Unlock()
 	r.mu.Unlock()
 	return ok
 	return ok
 }
 }

+ 7 - 4
ring_test.go

@@ -1,11 +1,14 @@
 package gocql
 package gocql
 
 
-import "testing"
+import (
+	"net"
+	"testing"
+)
 
 
 func TestRing_AddHostIfMissing_Missing(t *testing.T) {
 func TestRing_AddHostIfMissing_Missing(t *testing.T) {
 	ring := &ring{}
 	ring := &ring{}
 
 
-	host := &HostInfo{peer: "test1"}
+	host := &HostInfo{peer: net.IPv4(1, 1, 1, 1)}
 	h1, ok := ring.addHostIfMissing(host)
 	h1, ok := ring.addHostIfMissing(host)
 	if ok {
 	if ok {
 		t.Fatal("host was reported as already existing")
 		t.Fatal("host was reported as already existing")
@@ -19,10 +22,10 @@ func TestRing_AddHostIfMissing_Missing(t *testing.T) {
 func TestRing_AddHostIfMissing_Existing(t *testing.T) {
 func TestRing_AddHostIfMissing_Existing(t *testing.T) {
 	ring := &ring{}
 	ring := &ring{}
 
 
-	host := &HostInfo{peer: "test1"}
+	host := &HostInfo{peer: net.IPv4(1, 1, 1, 1)}
 	ring.addHostIfMissing(host)
 	ring.addHostIfMissing(host)
 
 
-	h2 := &HostInfo{peer: "test1"}
+	h2 := &HostInfo{peer: net.IPv4(1, 1, 1, 1)}
 
 
 	h1, ok := ring.addHostIfMissing(h2)
 	h1, ok := ring.addHostIfMissing(h2)
 	if !ok {
 	if !ok {

+ 15 - 22
session.go

@@ -11,8 +11,6 @@ import (
 	"fmt"
 	"fmt"
 	"io"
 	"io"
 	"log"
 	"log"
-	"net"
-	"strconv"
 	"strings"
 	"strings"
 	"sync"
 	"sync"
 	"sync/atomic"
 	"sync/atomic"
@@ -81,18 +79,12 @@ var queryPool = &sync.Pool{
 func addrsToHosts(addrs []string, defaultPort int) ([]*HostInfo, error) {
 func addrsToHosts(addrs []string, defaultPort int) ([]*HostInfo, error) {
 	hosts := make([]*HostInfo, len(addrs))
 	hosts := make([]*HostInfo, len(addrs))
 	for i, hostport := range addrs {
 	for i, hostport := range addrs {
-		// TODO: remove duplication
-		addr, portStr, err := net.SplitHostPort(JoinHostPort(hostport, defaultPort))
+		host, err := hostInfo(hostport, defaultPort)
 		if err != nil {
 		if err != nil {
-			return nil, fmt.Errorf("NewSession: unable to parse hostport of addr %q: %v", hostport, err)
-		}
-
-		port, err := strconv.Atoi(portStr)
-		if err != nil {
-			return nil, fmt.Errorf("NewSession: invalid port for hostport of addr %q: %v", hostport, err)
+			return nil, err
 		}
 		}
 
 
-		hosts[i] = &HostInfo{peer: addr, port: port, state: NodeUp}
+		hosts[i] = host
 	}
 	}
 
 
 	return hosts, nil
 	return hosts, nil
@@ -156,7 +148,6 @@ func NewSession(cfg ClusterConfig) (*Session, error) {
 		localHasRPCAddr, _ := checkSystemLocal(s.control)
 		localHasRPCAddr, _ := checkSystemLocal(s.control)
 		s.hostSource.localHasRpcAddr = localHasRPCAddr
 		s.hostSource.localHasRpcAddr = localHasRPCAddr
 
 
-		var err error
 		if cfg.DisableInitialHostLookup {
 		if cfg.DisableInitialHostLookup {
 			// TODO: we could look at system.local to get token and other metadata
 			// TODO: we could look at system.local to get token and other metadata
 			// in this case.
 			// in this case.
@@ -165,22 +156,23 @@ func NewSession(cfg ClusterConfig) (*Session, error) {
 			hosts, _, err = s.hostSource.GetHosts()
 			hosts, _, err = s.hostSource.GetHosts()
 		}
 		}
 
 
-		if err != nil {
-			s.Close()
-			return nil, fmt.Errorf("gocql: unable to create session: %v", err)
-		}
 	} else {
 	} else {
 		// we dont get host info
 		// we dont get host info
 		hosts, err = addrsToHosts(cfg.Hosts, cfg.Port)
 		hosts, err = addrsToHosts(cfg.Hosts, cfg.Port)
 	}
 	}
 
 
+	if err != nil {
+		s.Close()
+		return nil, fmt.Errorf("gocql: unable to create session: %v", err)
+	}
+
 	for _, host := range hosts {
 	for _, host := range hosts {
 		if s.cfg.HostFilter == nil || s.cfg.HostFilter.Accept(host) {
 		if s.cfg.HostFilter == nil || s.cfg.HostFilter.Accept(host) {
 			if existingHost, ok := s.ring.addHostIfMissing(host); ok {
 			if existingHost, ok := s.ring.addHostIfMissing(host); ok {
 				existingHost.update(host)
 				existingHost.update(host)
 			}
 			}
 
 
-			s.handleNodeUp(net.ParseIP(host.Peer()), host.Port(), false)
+			s.handleNodeUp(host.Peer(), host.Port(), false)
 		}
 		}
 	}
 	}
 
 
@@ -203,6 +195,7 @@ func NewSession(cfg ClusterConfig) (*Session, error) {
 	// connection is disable, we really have no choice, so we just make our
 	// connection is disable, we really have no choice, so we just make our
 	// best guess...
 	// best guess...
 	if !cfg.disableControlConn && cfg.DisableInitialHostLookup {
 	if !cfg.disableControlConn && cfg.DisableInitialHostLookup {
+		// TODO(zariel): we dont need to do this twice
 		newer, _ := checkSystemSchema(s.control)
 		newer, _ := checkSystemSchema(s.control)
 		s.useSystemSchema = newer
 		s.useSystemSchema = newer
 	} else {
 	} else {
@@ -225,7 +218,7 @@ func (s *Session) reconnectDownedHosts(intv time.Duration) {
 			if gocqlDebug {
 			if gocqlDebug {
 				buf := bytes.NewBufferString("Session.ring:")
 				buf := bytes.NewBufferString("Session.ring:")
 				for _, h := range hosts {
 				for _, h := range hosts {
-					buf.WriteString("[" + h.Peer() + ":" + h.State().String() + "]")
+					buf.WriteString("[" + h.Peer().String() + ":" + h.State().String() + "]")
 				}
 				}
 				log.Println(buf.String())
 				log.Println(buf.String())
 			}
 			}
@@ -234,7 +227,7 @@ func (s *Session) reconnectDownedHosts(intv time.Duration) {
 				if h.IsUp() {
 				if h.IsUp() {
 					continue
 					continue
 				}
 				}
-				s.handleNodeUp(net.ParseIP(h.Peer()), h.Port(), true)
+				s.handleNodeUp(h.Peer(), h.Port(), true)
 			}
 			}
 		case <-s.quit:
 		case <-s.quit:
 			return
 			return
@@ -409,7 +402,7 @@ func (s *Session) getConn() *Conn {
 			continue
 			continue
 		}
 		}
 
 
-		pool, ok := s.pool.getPool(host.Peer())
+		pool, ok := s.pool.getPool(host)
 		if !ok {
 		if !ok {
 			continue
 			continue
 		}
 		}
@@ -628,8 +621,8 @@ func (s *Session) MapExecuteBatchCAS(batch *Batch, dest map[string]interface{})
 	return applied, iter, iter.err
 	return applied, iter, iter.err
 }
 }
 
 
-func (s *Session) connect(addr string, errorHandler ConnErrorHandler, host *HostInfo) (*Conn, error) {
-	return Connect(host, addr, s.connCfg, errorHandler, s)
+func (s *Session) connect(host *HostInfo, errorHandler ConnErrorHandler) (*Conn, error) {
+	return Connect(host, s.connCfg, errorHandler, s)
 }
 }
 
 
 // Query represents a CQL statement that can be executed.
 // Query represents a CQL statement that can be executed.

+ 1 - 1
token.go

@@ -184,7 +184,7 @@ func (t *tokenRing) String() string {
 		buf.WriteString("]")
 		buf.WriteString("]")
 		buf.WriteString(t.tokens[i].String())
 		buf.WriteString(t.tokens[i].String())
 		buf.WriteString(":")
 		buf.WriteString(":")
-		buf.WriteString(t.hosts[i].Peer())
+		buf.WriteString(t.hosts[i].Peer().String())
 	}
 	}
 	buf.WriteString("\n}")
 	buf.WriteString("\n}")
 	return string(buf.Bytes())
 	return string(buf.Bytes())

+ 41 - 133
token_test.go

@@ -6,7 +6,9 @@ package gocql
 
 
 import (
 import (
 	"bytes"
 	"bytes"
+	"fmt"
 	"math/big"
 	"math/big"
+	"net"
 	"sort"
 	"sort"
 	"strconv"
 	"strconv"
 	"testing"
 	"testing"
@@ -226,27 +228,23 @@ func TestUnknownTokenRing(t *testing.T) {
 	}
 	}
 }
 }
 
 
+func hostsForTests(n int) []*HostInfo {
+	hosts := make([]*HostInfo, n)
+	for i := 0; i < n; i++ {
+		host := &HostInfo{
+			peer:   net.IPv4(1, 1, 1, byte(n)),
+			tokens: []string{fmt.Sprintf("%d", n)},
+		}
+
+		hosts[i] = host
+	}
+	return hosts
+}
+
 // Test of the tokenRing with the Murmur3Partitioner
 // Test of the tokenRing with the Murmur3Partitioner
 func TestMurmur3TokenRing(t *testing.T) {
 func TestMurmur3TokenRing(t *testing.T) {
 	// Note, strings are parsed directly to int64, they are not murmur3 hashed
 	// Note, strings are parsed directly to int64, they are not murmur3 hashed
-	hosts := []*HostInfo{
-		{
-			peer:   "0",
-			tokens: []string{"0"},
-		},
-		{
-			peer:   "1",
-			tokens: []string{"25"},
-		},
-		{
-			peer:   "2",
-			tokens: []string{"50"},
-		},
-		{
-			peer:   "3",
-			tokens: []string{"75"},
-		},
-	}
+	hosts := hostsForTests(4)
 	ring, err := newTokenRing("Murmur3Partitioner", hosts)
 	ring, err := newTokenRing("Murmur3Partitioner", hosts)
 	if err != nil {
 	if err != nil {
 		t.Fatalf("Failed to create token ring due to error: %v", err)
 		t.Fatalf("Failed to create token ring due to error: %v", err)
@@ -254,34 +252,20 @@ func TestMurmur3TokenRing(t *testing.T) {
 
 
 	p := murmur3Partitioner{}
 	p := murmur3Partitioner{}
 
 
-	var actual *HostInfo
-	actual = ring.GetHostForToken(p.ParseString("0"))
-	if actual.Peer() != "0" {
-		t.Errorf("Expected peer 0 for token \"0\", but was %s", actual.Peer())
+	for _, host := range hosts {
+		actual := ring.GetHostForToken(p.ParseString(host.tokens[0]))
+		if !actual.Peer().Equal(host.peer) {
+			t.Errorf("Expected peer %v for token %q, but was %v", host.peer, host.tokens[0], actual.peer)
+		}
 	}
 	}
 
 
-	actual = ring.GetHostForToken(p.ParseString("25"))
-	if actual.Peer() != "1" {
-		t.Errorf("Expected peer 1 for token \"25\", but was %s", actual.Peer())
-	}
-
-	actual = ring.GetHostForToken(p.ParseString("50"))
-	if actual.Peer() != "2" {
-		t.Errorf("Expected peer 2 for token \"50\", but was %s", actual.Peer())
-	}
-
-	actual = ring.GetHostForToken(p.ParseString("75"))
-	if actual.Peer() != "3" {
-		t.Errorf("Expected peer 3 for token \"01\", but was %s", actual.Peer())
-	}
-
-	actual = ring.GetHostForToken(p.ParseString("12"))
-	if actual.Peer() != "1" {
+	actual := ring.GetHostForToken(p.ParseString("12"))
+	if !actual.Peer().Equal(hosts[1].peer) {
 		t.Errorf("Expected peer 1 for token \"12\", but was %s", actual.Peer())
 		t.Errorf("Expected peer 1 for token \"12\", but was %s", actual.Peer())
 	}
 	}
 
 
 	actual = ring.GetHostForToken(p.ParseString("24324545443332"))
 	actual = ring.GetHostForToken(p.ParseString("24324545443332"))
-	if actual.Peer() != "0" {
+	if !actual.Peer().Equal(hosts[0].peer) {
 		t.Errorf("Expected peer 0 for token \"24324545443332\", but was %s", actual.Peer())
 		t.Errorf("Expected peer 0 for token \"24324545443332\", but was %s", actual.Peer())
 	}
 	}
 }
 }
@@ -290,32 +274,7 @@ func TestMurmur3TokenRing(t *testing.T) {
 func TestOrderedTokenRing(t *testing.T) {
 func TestOrderedTokenRing(t *testing.T) {
 	// Tokens here more or less are similar layout to the int tokens above due
 	// Tokens here more or less are similar layout to the int tokens above due
 	// to each numeric character translating to a consistently offset byte.
 	// to each numeric character translating to a consistently offset byte.
-	hosts := []*HostInfo{
-		{
-			peer: "0",
-			tokens: []string{
-				"00",
-			},
-		},
-		{
-			peer: "1",
-			tokens: []string{
-				"25",
-			},
-		},
-		{
-			peer: "2",
-			tokens: []string{
-				"50",
-			},
-		},
-		{
-			peer: "3",
-			tokens: []string{
-				"75",
-			},
-		},
-	}
+	hosts := hostsForTests(4)
 	ring, err := newTokenRing("OrderedPartitioner", hosts)
 	ring, err := newTokenRing("OrderedPartitioner", hosts)
 	if err != nil {
 	if err != nil {
 		t.Fatalf("Failed to create token ring due to error: %v", err)
 		t.Fatalf("Failed to create token ring due to error: %v", err)
@@ -324,33 +283,20 @@ func TestOrderedTokenRing(t *testing.T) {
 	p := orderedPartitioner{}
 	p := orderedPartitioner{}
 
 
 	var actual *HostInfo
 	var actual *HostInfo
-	actual = ring.GetHostForToken(p.ParseString("0"))
-	if actual.Peer() != "0" {
-		t.Errorf("Expected peer 0 for token \"0\", but was %s", actual.Peer())
-	}
-
-	actual = ring.GetHostForToken(p.ParseString("25"))
-	if actual.Peer() != "1" {
-		t.Errorf("Expected peer 1 for token \"25\", but was %s", actual.Peer())
-	}
-
-	actual = ring.GetHostForToken(p.ParseString("50"))
-	if actual.Peer() != "2" {
-		t.Errorf("Expected peer 2 for token \"50\", but was %s", actual.Peer())
-	}
-
-	actual = ring.GetHostForToken(p.ParseString("75"))
-	if actual.Peer() != "3" {
-		t.Errorf("Expected peer 3 for token \"01\", but was %s", actual.Peer())
+	for _, host := range hosts {
+		actual = ring.GetHostForToken(p.ParseString(host.tokens[0]))
+		if !actual.Peer().Equal(host.peer) {
+			t.Errorf("Expected peer %v for token %q, but was %v", host.peer, host.tokens[0], actual.peer)
+		}
 	}
 	}
 
 
 	actual = ring.GetHostForToken(p.ParseString("12"))
 	actual = ring.GetHostForToken(p.ParseString("12"))
-	if actual.Peer() != "1" {
+	if !actual.peer.Equal(hosts[1].peer) {
 		t.Errorf("Expected peer 1 for token \"12\", but was %s", actual.Peer())
 		t.Errorf("Expected peer 1 for token \"12\", but was %s", actual.Peer())
 	}
 	}
 
 
 	actual = ring.GetHostForToken(p.ParseString("24324545443332"))
 	actual = ring.GetHostForToken(p.ParseString("24324545443332"))
-	if actual.Peer() != "1" {
+	if !actual.peer.Equal(hosts[1].peer) {
 		t.Errorf("Expected peer 1 for token \"24324545443332\", but was %s", actual.Peer())
 		t.Errorf("Expected peer 1 for token \"24324545443332\", but was %s", actual.Peer())
 	}
 	}
 }
 }
@@ -358,32 +304,7 @@ func TestOrderedTokenRing(t *testing.T) {
 // Test of the tokenRing with the RandomPartitioner
 // Test of the tokenRing with the RandomPartitioner
 func TestRandomTokenRing(t *testing.T) {
 func TestRandomTokenRing(t *testing.T) {
 	// String tokens are parsed into big.Int in base 10
 	// String tokens are parsed into big.Int in base 10
-	hosts := []*HostInfo{
-		{
-			peer: "0",
-			tokens: []string{
-				"00",
-			},
-		},
-		{
-			peer: "1",
-			tokens: []string{
-				"25",
-			},
-		},
-		{
-			peer: "2",
-			tokens: []string{
-				"50",
-			},
-		},
-		{
-			peer: "3",
-			tokens: []string{
-				"75",
-			},
-		},
-	}
+	hosts := hostsForTests(4)
 	ring, err := newTokenRing("RandomPartitioner", hosts)
 	ring, err := newTokenRing("RandomPartitioner", hosts)
 	if err != nil {
 	if err != nil {
 		t.Fatalf("Failed to create token ring due to error: %v", err)
 		t.Fatalf("Failed to create token ring due to error: %v", err)
@@ -392,33 +313,20 @@ func TestRandomTokenRing(t *testing.T) {
 	p := randomPartitioner{}
 	p := randomPartitioner{}
 
 
 	var actual *HostInfo
 	var actual *HostInfo
-	actual = ring.GetHostForToken(p.ParseString("0"))
-	if actual.Peer() != "0" {
-		t.Errorf("Expected peer 0 for token \"0\", but was %s", actual.Peer())
-	}
-
-	actual = ring.GetHostForToken(p.ParseString("25"))
-	if actual.Peer() != "1" {
-		t.Errorf("Expected peer 1 for token \"25\", but was %s", actual.Peer())
-	}
-
-	actual = ring.GetHostForToken(p.ParseString("50"))
-	if actual.Peer() != "2" {
-		t.Errorf("Expected peer 2 for token \"50\", but was %s", actual.Peer())
-	}
-
-	actual = ring.GetHostForToken(p.ParseString("75"))
-	if actual.Peer() != "3" {
-		t.Errorf("Expected peer 3 for token \"01\", but was %s", actual.Peer())
+	for _, host := range hosts {
+		actual = ring.GetHostForToken(p.ParseString(host.tokens[0]))
+		if !actual.Peer().Equal(host.peer) {
+			t.Errorf("Expected peer %v for token %q, but was %v", host.peer, host.tokens[0], actual.peer)
+		}
 	}
 	}
 
 
 	actual = ring.GetHostForToken(p.ParseString("12"))
 	actual = ring.GetHostForToken(p.ParseString("12"))
-	if actual.Peer() != "1" {
+	if !actual.peer.Equal(hosts[1].peer) {
 		t.Errorf("Expected peer 1 for token \"12\", but was %s", actual.Peer())
 		t.Errorf("Expected peer 1 for token \"12\", but was %s", actual.Peer())
 	}
 	}
 
 
 	actual = ring.GetHostForToken(p.ParseString("24324545443332"))
 	actual = ring.GetHostForToken(p.ParseString("24324545443332"))
-	if actual.Peer() != "0" {
-		t.Errorf("Expected peer 0 for token \"24324545443332\", but was %s", actual.Peer())
+	if !actual.peer.Equal(hosts[0].peer) {
+		t.Errorf("Expected peer 1 for token \"24324545443332\", but was %s", actual.Peer())
 	}
 	}
 }
 }