Browse Source

change HostInfo.Peer to be an IP

When Cassandra returns us a hosts info the peer address is defined as an
inet both at protocol for events and the schema for the peer
information.

Previously we stored this as a string, and used it to
connect to hosts and also to index hosts by. This is different to what
use for user supplied address endpoints, we keep the potentially DNS
name as the peer address. This means that we can end up having duplicate
host pools, duplicate host info in the ring.

Fix this by making everything rely on a hosts address being an IP
address instead of either a DNS name or an IP.
Chris Bannister 9 năm trước cách đây
mục cha
commit
e0a2f2ca85
19 tập tin đã thay đổi với 367 bổ sung332 xóa
  1. 21 8
      cassandra_test.go
  2. 15 1
      cluster.go
  3. 12 2
      conn.go
  4. 10 4
      conn_test.go
  5. 19 15
      connectionpool.go
  6. 50 22
      control.go
  7. 31 0
      control_test.go
  8. 21 27
      events.go
  9. 12 4
      filters.go
  10. 17 14
      filters_test.go
  11. 9 14
      host_source.go
  12. 33 26
      policies.go
  13. 24 23
      policies_test.go
  14. 1 1
      query_executor.go
  15. 28 11
      ring.go
  16. 7 4
      ring_test.go
  17. 15 22
      session.go
  18. 1 1
      token.go
  19. 41 133
      token_test.go

+ 21 - 8
cassandra_test.go

@@ -62,11 +62,15 @@ func TestRingDiscovery(t *testing.T) {
 	}
 
 	session.pool.mu.RLock()
+	defer session.pool.mu.RUnlock()
 	size := len(session.pool.hostConnPools)
-	session.pool.mu.RUnlock()
 
 	if *clusterSize != size {
-		t.Fatalf("Expected a cluster size of %d, but actual size was %d", *clusterSize, size)
+		for p, pool := range session.pool.hostConnPools {
+			t.Logf("p=%q host=%v ips=%s", p, pool.host, pool.host.Peer().String())
+
+		}
+		t.Errorf("Expected a cluster size of %d, but actual size was %d", *clusterSize, size)
 	}
 }
 
@@ -573,7 +577,7 @@ func TestReconnection(t *testing.T) {
 	defer session.Close()
 
 	h := session.ring.allHosts()[0]
-	session.handleNodeDown(net.ParseIP(h.Peer()), h.Port())
+	session.handleNodeDown(h.Peer(), h.Port())
 
 	if h.State() != NodeDown {
 		t.Fatal("Host should be NodeDown but not.")
@@ -2477,17 +2481,26 @@ func TestSchemaReset(t *testing.T) {
 }
 
 func TestCreateSession_DontSwallowError(t *testing.T) {
+	t.Skip("This test is bad, and the resultant error from cassandra changes between versions")
 	cluster := createCluster()
-	cluster.ProtoVersion = 100
+	cluster.ProtoVersion = 0x100
 	session, err := cluster.CreateSession()
 	if err == nil {
 		session.Close()
 
 		t.Fatal("expected to get an error for unsupported protocol")
 	}
-	// TODO: we should get a distinct error type here which include the underlying
-	// cassandra error about the protocol version, for now check this here.
-	if !strings.Contains(err.Error(), "Invalid or unsupported protocol version") {
-		t.Fatalf(`expcted to get error "unsupported protocol version" got: %q`, err)
+
+	if flagCassVersion.Major < 3 {
+		// TODO: we should get a distinct error type here which include the underlying
+		// cassandra error about the protocol version, for now check this here.
+		if !strings.Contains(err.Error(), "Invalid or unsupported protocol version") {
+			t.Fatalf(`expcted to get error "unsupported protocol version" got: %q`, err)
+		}
+	} else {
+		if !strings.Contains(err.Error(), "unsupported response version") {
+			t.Fatalf(`expcted to get error "unsupported response version" got: %q`, err)
+		}
 	}
+
 }

+ 15 - 1
cluster.go

@@ -27,7 +27,13 @@ func (p PoolConfig) buildPool(session *Session) *policyConnPool {
 // behavior to fit the most common use cases. Applications that require a
 // different setup must implement their own cluster.
 type ClusterConfig struct {
-	Hosts             []string          // addresses for the initial connections
+	// addresses for the initial connections. It is recomended to use the value set in
+	// the Cassandra config for broadcast_address or listen_address, an IP address not
+	// a domain name. This is because events from Cassandra will use the configured IP
+	// address, which is used to index connected hosts. If the domain name specified
+	// resolves to more than 1 IP address then the driver may connect multiple times to
+	// the same host, and will not mark the node being down or up from events.
+	Hosts             []string
 	CQLVersion        string            // CQL version (default: 3.0.0)
 	ProtoVersion      int               // version of the native protocol (default: 2)
 	Timeout           time.Duration     // connection timeout (default: 600ms)
@@ -100,6 +106,14 @@ type ClusterConfig struct {
 }
 
 // NewCluster generates a new config for the default cluster implementation.
+//
+// The supplied hosts are used to initially connect to the cluster then the rest of
+// the ring will be automatically discovered. It is recomended to use the value set in
+// the Cassandra config for broadcast_address or listen_address, an IP address not
+// a domain name. This is because events from Cassandra will use the configured IP
+// address, which is used to index connected hosts. If the domain name specified
+// resolves to more than 1 IP address then the driver may connect multiple times to
+// the same host, and will not mark the node being down or up from events.
 func NewCluster(hosts ...string) *ClusterConfig {
 	cfg := &ClusterConfig{
 		Hosts:                  hosts,

+ 12 - 2
conn.go

@@ -152,8 +152,15 @@ type Conn struct {
 }
 
 // Connect establishes a connection to a Cassandra node.
-func Connect(host *HostInfo, addr string, cfg *ConnConfig,
-	errorHandler ConnErrorHandler, session *Session) (*Conn, error) {
+func Connect(host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHandler, session *Session) (*Conn, error) {
+	// TODO(zariel): remove these
+	if host == nil {
+		panic("host is nil")
+	} else if len(host.Peer()) == 0 {
+		panic("host missing peer ip address")
+	} else if host.Port() == 0 {
+		panic("host missing port")
+	}
 
 	var (
 		err  error
@@ -164,6 +171,9 @@ func Connect(host *HostInfo, addr string, cfg *ConnConfig,
 		Timeout: cfg.Timeout,
 	}
 
+	// TODO(zariel): handle ipv6 zone
+	addr := (&net.TCPAddr{IP: host.Peer(), Port: host.Port()}).String()
+
 	if cfg.tlsConfig != nil {
 		// the TLS config is safe to be reused by connections but it must not
 		// be modified after being used.

+ 10 - 4
conn_test.go

@@ -473,8 +473,7 @@ func TestStream0(t *testing.T) {
 		}
 	})
 
-	host := &HostInfo{peer: srv.Address}
-	conn, err := Connect(host, srv.Address, &ConnConfig{ProtoVersion: int(srv.protocol)}, errorHandler, nil)
+	conn, err := Connect(srv.host(), &ConnConfig{ProtoVersion: int(srv.protocol)}, errorHandler, nil)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -509,8 +508,7 @@ func TestConnClosedBlocked(t *testing.T) {
 		t.Log(err)
 	})
 
-	host := &HostInfo{peer: srv.Address}
-	conn, err := Connect(host, srv.Address, &ConnConfig{ProtoVersion: int(srv.protocol)}, errorHandler, nil)
+	conn, err := Connect(srv.host(), &ConnConfig{ProtoVersion: int(srv.protocol)}, errorHandler, nil)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -637,6 +635,14 @@ type TestServer struct {
 	closed bool
 }
 
+func (srv *TestServer) host() *HostInfo {
+	host, err := hostInfo(srv.Address, 9042)
+	if err != nil {
+		srv.t.Fatal(err)
+	}
+	return host
+}
+
 func (srv *TestServer) closeWatch() {
 	<-srv.ctx.Done()
 

+ 19 - 15
connectionpool.go

@@ -130,9 +130,10 @@ func (p *policyConnPool) SetHosts(hosts []*HostInfo) {
 			// don't create a connection pool for a down host
 			continue
 		}
-		if _, exists := p.hostConnPools[host.Peer()]; exists {
+		ip := host.Peer().String()
+		if _, exists := p.hostConnPools[ip]; exists {
 			// still have this host, so don't remove it
-			delete(toRemove, host.Peer())
+			delete(toRemove, ip)
 			continue
 		}
 
@@ -155,7 +156,7 @@ func (p *policyConnPool) SetHosts(hosts []*HostInfo) {
 		createCount--
 		if pool.Size() > 0 {
 			// add pool onyl if there a connections available
-			p.hostConnPools[pool.host.Peer()] = pool
+			p.hostConnPools[string(pool.host.Peer())] = pool
 		}
 	}
 
@@ -177,9 +178,10 @@ func (p *policyConnPool) Size() int {
 	return count
 }
 
-func (p *policyConnPool) getPool(addr string) (pool *hostConnPool, ok bool) {
+func (p *policyConnPool) getPool(host *HostInfo) (pool *hostConnPool, ok bool) {
+	ip := host.Peer().String()
 	p.mu.RLock()
-	pool, ok = p.hostConnPools[addr]
+	pool, ok = p.hostConnPools[ip]
 	p.mu.RUnlock()
 	return
 }
@@ -196,8 +198,9 @@ func (p *policyConnPool) Close() {
 }
 
 func (p *policyConnPool) addHost(host *HostInfo) {
+	ip := host.Peer().String()
 	p.mu.Lock()
-	pool, ok := p.hostConnPools[host.Peer()]
+	pool, ok := p.hostConnPools[ip]
 	if !ok {
 		pool = newHostConnPool(
 			p.session,
@@ -207,22 +210,23 @@ func (p *policyConnPool) addHost(host *HostInfo) {
 			p.keyspace,
 		)
 
-		p.hostConnPools[host.Peer()] = pool
+		p.hostConnPools[ip] = pool
 	}
 	p.mu.Unlock()
 
 	pool.fill()
 }
 
-func (p *policyConnPool) removeHost(addr string) {
+func (p *policyConnPool) removeHost(ip net.IP) {
+	k := ip.String()
 	p.mu.Lock()
-	pool, ok := p.hostConnPools[addr]
+	pool, ok := p.hostConnPools[k]
 	if !ok {
 		p.mu.Unlock()
 		return
 	}
 
-	delete(p.hostConnPools, addr)
+	delete(p.hostConnPools, k)
 	p.mu.Unlock()
 
 	go pool.Close()
@@ -234,10 +238,10 @@ func (p *policyConnPool) hostUp(host *HostInfo) {
 	p.addHost(host)
 }
 
-func (p *policyConnPool) hostDown(addr string) {
+func (p *policyConnPool) hostDown(ip net.IP) {
 	// TODO(zariel): mark host as down so we can try to connect to it later, for
 	// now just treat it has removed.
-	p.removeHost(addr)
+	p.removeHost(ip)
 }
 
 // hostConnPool is a connection pool for a single host.
@@ -272,7 +276,7 @@ func newHostConnPool(session *Session, host *HostInfo, port, size int,
 		session:  session,
 		host:     host,
 		port:     port,
-		addr:     JoinHostPort(host.Peer(), port),
+		addr:     (&net.TCPAddr{IP: host.Peer(), Port: host.Port()}).String(),
 		size:     size,
 		keyspace: keyspace,
 		conns:    make([]*Conn, 0, size),
@@ -396,7 +400,7 @@ func (pool *hostConnPool) fill() {
 
 			// this is calle with the connetion pool mutex held, this call will
 			// then recursivly try to lock it again. FIXME
-			go pool.session.handleNodeDown(net.ParseIP(pool.host.Peer()), pool.port)
+			go pool.session.handleNodeDown(pool.host.Peer(), pool.port)
 			return
 		}
 
@@ -477,7 +481,7 @@ func (pool *hostConnPool) connect() (err error) {
 	// try to connect
 	var conn *Conn
 	for i := 0; i < maxAttempts; i++ {
-		conn, err = pool.session.connect(pool.addr, pool, pool.host)
+		conn, err = pool.session.connect(pool.host, pool)
 		if err == nil {
 			break
 		}

+ 50 - 22
control.go

@@ -4,13 +4,14 @@ import (
 	crand "crypto/rand"
 	"errors"
 	"fmt"
-	"golang.org/x/net/context"
 	"log"
 	"math/rand"
 	"net"
 	"strconv"
 	"sync/atomic"
 	"time"
+
+	"golang.org/x/net/context"
 )
 
 var (
@@ -89,6 +90,8 @@ func (c *controlConn) heartBeat() {
 	}
 }
 
+var hostLookupPreferV4 = false
+
 func hostInfo(addr string, defaultPort int) (*HostInfo, error) {
 	var port int
 	host, portStr, err := net.SplitHostPort(addr)
@@ -102,10 +105,37 @@ func hostInfo(addr string, defaultPort int) (*HostInfo, error) {
 		}
 	}
 
-	return &HostInfo{peer: host, port: port}, nil
+	ip := net.ParseIP(host)
+	if ip == nil {
+		ips, err := net.LookupIP(host)
+		if err != nil {
+			return nil, err
+		} else if len(ips) == 0 {
+			return nil, fmt.Errorf("No IP's returned from DNS lookup for %q", addr)
+		}
+
+		if hostLookupPreferV4 {
+			for _, v := range ips {
+				if v4 := v.To4(); v4 != nil {
+					ip = v4
+					break
+				}
+			}
+			if ip == nil {
+				ip = ips[0]
+			}
+		} else {
+			// TODO(zariel): should we check that we can connect to any of the ips?
+			ip = ips[0]
+		}
+
+	}
+
+	return &HostInfo{peer: ip, port: port}, nil
 }
 
 func (c *controlConn) shuffleDial(endpoints []string) (conn *Conn, err error) {
+	// TODO: accept a []*HostInfo
 	perm := randr.Perm(len(endpoints))
 	shuffled := make([]string, len(endpoints))
 
@@ -130,7 +160,7 @@ func (c *controlConn) shuffleDial(endpoints []string) (conn *Conn, err error) {
 		}
 
 		hostInfo, _ := c.session.ring.addHostIfMissing(host)
-		conn, err = c.session.connect(addr, c, hostInfo)
+		conn, err = c.session.connect(hostInfo, c)
 		if err == nil {
 			return conn, err
 		}
@@ -229,22 +259,21 @@ func (c *controlConn) reconnect(refreshring bool) {
 	// TODO: simplify this function, use session.ring to get hosts instead of the
 	// connection pool
 
-	addr := c.addr()
+	var host *HostInfo
 	oldConn := c.conn.Load().(*Conn)
 	if oldConn != nil {
+		host = oldConn.host
 		oldConn.Close()
 	}
 
 	var newConn *Conn
-	if addr != "" {
+	if host != nil {
 		// try to connect to the old host
-		conn, err := c.session.connect(addr, c, oldConn.host)
+		conn, err := c.session.connect(host, c)
 		if err != nil {
 			// host is dead
 			// TODO: this is replicated in a few places
-			ip, portStr, _ := net.SplitHostPort(addr)
-			port, _ := strconv.Atoi(portStr)
-			c.session.handleNodeDown(net.ParseIP(ip), port)
+			c.session.handleNodeDown(host.Peer(), host.Port())
 		} else {
 			newConn = conn
 		}
@@ -260,7 +289,7 @@ func (c *controlConn) reconnect(refreshring bool) {
 		}
 
 		var err error
-		newConn, err = c.session.connect(host.Peer(), c, host)
+		newConn, err = c.session.connect(host, c)
 		if err != nil {
 			// TODO: add log handler for things like this
 			return
@@ -350,29 +379,28 @@ func (c *controlConn) query(statement string, values ...interface{}) (iter *Iter
 	return
 }
 
-func (c *controlConn) fetchHostInfo(addr net.IP, port int) (*HostInfo, error) {
+func (c *controlConn) fetchHostInfo(ip net.IP, port int) (*HostInfo, error) {
 	// TODO(zariel): we should probably move this into host_source or atleast
 	// share code with it.
-	hostname, _, err := net.SplitHostPort(c.addr())
-	if err != nil {
-		return nil, fmt.Errorf("unable to fetch host info, invalid conn addr: %q: %v", c.addr(), err)
+	localHost := c.host()
+	if localHost == nil {
+		return nil, errors.New("unable to fetch host info, invalid conn host")
 	}
 
-	isLocal := hostname == addr.String()
+	isLocal := localHost.Peer().Equal(ip)
 
 	var fn func(*HostInfo) error
 
+	// TODO(zariel): fetch preferred_ip address (is it >3.x only?)
 	if isLocal {
 		fn = func(host *HostInfo) error {
-			// TODO(zariel): should we fetch rpc_address from here?
 			iter := c.query("SELECT data_center, rack, host_id, tokens, release_version FROM system.local WHERE key='local'")
 			iter.Scan(&host.dataCenter, &host.rack, &host.hostId, &host.tokens, &host.version)
 			return iter.Close()
 		}
 	} else {
 		fn = func(host *HostInfo) error {
-			// TODO(zariel): should we fetch rpc_address from here?
-			iter := c.query("SELECT data_center, rack, host_id, tokens, release_version FROM system.peers WHERE peer=?", addr)
+			iter := c.query("SELECT data_center, rack, host_id, tokens, release_version FROM system.peers WHERE peer=?", ip)
 			iter.Scan(&host.dataCenter, &host.rack, &host.hostId, &host.tokens, &host.version)
 			return iter.Close()
 		}
@@ -380,12 +408,12 @@ func (c *controlConn) fetchHostInfo(addr net.IP, port int) (*HostInfo, error) {
 
 	host := &HostInfo{
 		port: port,
+		peer: ip,
 	}
 
 	if err := fn(host); err != nil {
 		return nil, err
 	}
-	host.peer = addr.String()
 
 	return host, nil
 }
@@ -396,12 +424,12 @@ func (c *controlConn) awaitSchemaAgreement() error {
 	}).err
 }
 
-func (c *controlConn) addr() string {
+func (c *controlConn) host() *HostInfo {
 	conn := c.conn.Load().(*Conn)
 	if conn == nil {
-		return ""
+		return nil
 	}
-	return conn.addr
+	return conn.host
 }
 
 func (c *controlConn) close() {

+ 31 - 0
control_test.go

@@ -0,0 +1,31 @@
+package gocql
+
+import (
+	"net"
+	"testing"
+)
+
+func TestHostInfo_Lookup(t *testing.T) {
+	hostLookupPreferV4 = true
+	defer func() { hostLookupPreferV4 = false }()
+
+	tests := [...]struct {
+		addr string
+		ip   net.IP
+	}{
+		{"127.0.0.1", net.IPv4(127, 0, 0, 1)},
+		{"localhost", net.IPv4(127, 0, 0, 1)}, // TODO: this may be host dependant
+	}
+
+	for i, test := range tests {
+		host, err := hostInfo(test.addr, 1)
+		if err != nil {
+			t.Errorf("%d: %v", i, err)
+			continue
+		}
+
+		if !host.peer.Equal(test.ip) {
+			t.Errorf("expected ip %v got %v for addr %q", test.ip, host.peer, test.addr)
+		}
+	}
+}

+ 21 - 27
events.go

@@ -171,25 +171,21 @@ func (s *Session) handleNodeEvent(frames []frame) {
 	}
 }
 
-func (s *Session) handleNewNode(host net.IP, port int, waitForBinary bool) {
-	// TODO(zariel): need to be able to filter discovered nodes
-
+func (s *Session) handleNewNode(ip net.IP, port int, waitForBinary bool) {
 	var hostInfo *HostInfo
 	if s.control != nil && !s.cfg.IgnorePeerAddr {
 		var err error
-		hostInfo, err = s.control.fetchHostInfo(host, port)
+		hostInfo, err = s.control.fetchHostInfo(ip, port)
 		if err != nil {
-			log.Printf("gocql: events: unable to fetch host info for %v: %v\n", host, err)
+			log.Printf("gocql: events: unable to fetch host info for (%s:%d): %v\n", ip, port, err)
 			return
 		}
-
 	} else {
-		hostInfo = &HostInfo{peer: host.String(), port: port, state: NodeUp}
+		hostInfo = &HostInfo{peer: ip, port: port}
 	}
 
-	addr := host.String()
-	if s.cfg.IgnorePeerAddr && hostInfo.Peer() != addr {
-		hostInfo.setPeer(addr)
+	if s.cfg.IgnorePeerAddr && hostInfo.Peer().Equal(ip) {
+		hostInfo.setPeer(ip)
 	}
 
 	if s.cfg.HostFilter != nil && !s.cfg.HostFilter.Accept(hostInfo) {
@@ -217,11 +213,9 @@ func (s *Session) handleNewNode(host net.IP, port int, waitForBinary bool) {
 
 func (s *Session) handleRemovedNode(ip net.IP, port int) {
 	// we remove all nodes but only add ones which pass the filter
-	addr := ip.String()
-
-	host := s.ring.getHost(addr)
+	host := s.ring.getHost(ip)
 	if host == nil {
-		host = &HostInfo{peer: addr}
+		host = &HostInfo{peer: ip, port: port}
 	}
 
 	if s.cfg.HostFilter != nil && !s.cfg.HostFilter.Accept(host) {
@@ -229,9 +223,9 @@ func (s *Session) handleRemovedNode(ip net.IP, port int) {
 	}
 
 	host.setState(NodeDown)
-	s.policy.RemoveHost(addr)
-	s.pool.removeHost(addr)
-	s.ring.removeHost(addr)
+	s.policy.RemoveHost(host)
+	s.pool.removeHost(ip)
+	s.ring.removeHost(ip)
 
 	if !s.cfg.IgnorePeerAddr {
 		s.hostSource.refreshRing()
@@ -242,11 +236,12 @@ func (s *Session) handleNodeUp(ip net.IP, port int, waitForBinary bool) {
 	if gocqlDebug {
 		log.Printf("gocql: Session.handleNodeUp: %s:%d\n", ip.String(), port)
 	}
-	addr := ip.String()
-	host := s.ring.getHost(addr)
+
+	host := s.ring.getHost(ip)
 	if host != nil {
-		if s.cfg.IgnorePeerAddr && host.Peer() != addr {
-			host.setPeer(addr)
+		if s.cfg.IgnorePeerAddr && host.Peer().Equal(ip) {
+			// TODO: how can this ever be true?
+			host.setPeer(ip)
 		}
 
 		if s.cfg.HostFilter != nil && !s.cfg.HostFilter.Accept(host) {
@@ -257,7 +252,6 @@ func (s *Session) handleNodeUp(ip net.IP, port int, waitForBinary bool) {
 			time.Sleep(t)
 		}
 
-		host.setPort(port)
 		s.pool.hostUp(host)
 		s.policy.HostUp(host)
 		host.setState(NodeUp)
@@ -271,10 +265,10 @@ func (s *Session) handleNodeDown(ip net.IP, port int) {
 	if gocqlDebug {
 		log.Printf("gocql: Session.handleNodeDown: %s:%d\n", ip.String(), port)
 	}
-	addr := ip.String()
-	host := s.ring.getHost(addr)
+
+	host := s.ring.getHost(ip)
 	if host == nil {
-		host = &HostInfo{peer: addr}
+		host = &HostInfo{peer: ip, port: port}
 	}
 
 	if s.cfg.HostFilter != nil && !s.cfg.HostFilter.Accept(host) {
@@ -282,6 +276,6 @@ func (s *Session) handleNodeDown(ip net.IP, port int) {
 	}
 
 	host.setState(NodeDown)
-	s.policy.HostDown(addr)
-	s.pool.hostDown(addr)
+	s.policy.HostDown(host)
+	s.pool.hostDown(ip)
 }

+ 12 - 4
filters.go

@@ -1,5 +1,7 @@
 package gocql
 
+import "fmt"
+
 // HostFilter interface is used when a host is discovered via server sent events.
 type HostFilter interface {
 	// Called when a new host is discovered, returning true will cause the host
@@ -38,12 +40,18 @@ func DataCentreHostFilter(dataCentre string) HostFilter {
 // WhiteListHostFilter filters incoming hosts by checking that their address is
 // in the initial hosts whitelist.
 func WhiteListHostFilter(hosts ...string) HostFilter {
-	m := make(map[string]bool, len(hosts))
-	for _, host := range hosts {
-		m[host] = true
+	hostInfos, err := addrsToHosts(hosts, 9042)
+	if err != nil {
+		// dont want to panic here, but rather not break the API
+		panic(fmt.Errorf("unable to lookup host info from address: %v", err))
+	}
+
+	m := make(map[string]bool, len(hostInfos))
+	for _, host := range hostInfos {
+		m[string(host.peer)] = true
 	}
 
 	return HostFilterFunc(func(host *HostInfo) bool {
-		return m[host.Peer()]
+		return m[string(host.Peer())]
 	})
 }

+ 17 - 14
filters_test.go

@@ -1,16 +1,19 @@
 package gocql
 
-import "testing"
+import (
+	"net"
+	"testing"
+)
 
 func TestFilter_WhiteList(t *testing.T) {
-	f := WhiteListHostFilter("addr1", "addr2")
+	f := WhiteListHostFilter("127.0.0.1", "127.0.0.2")
 	tests := [...]struct {
-		addr   string
+		addr   net.IP
 		accept bool
 	}{
-		{"addr1", true},
-		{"addr2", true},
-		{"addr3", false},
+		{net.ParseIP("127.0.0.1"), true},
+		{net.ParseIP("127.0.0.2"), true},
+		{net.ParseIP("127.0.0.3"), false},
 	}
 
 	for i, test := range tests {
@@ -27,12 +30,12 @@ func TestFilter_WhiteList(t *testing.T) {
 func TestFilter_AllowAll(t *testing.T) {
 	f := AcceptAllFilter()
 	tests := [...]struct {
-		addr   string
+		addr   net.IP
 		accept bool
 	}{
-		{"addr1", true},
-		{"addr2", true},
-		{"addr3", true},
+		{net.ParseIP("127.0.0.1"), true},
+		{net.ParseIP("127.0.0.2"), true},
+		{net.ParseIP("127.0.0.3"), true},
 	}
 
 	for i, test := range tests {
@@ -49,12 +52,12 @@ func TestFilter_AllowAll(t *testing.T) {
 func TestFilter_DenyAll(t *testing.T) {
 	f := DenyAllFilter()
 	tests := [...]struct {
-		addr   string
+		addr   net.IP
 		accept bool
 	}{
-		{"addr1", false},
-		{"addr2", false},
-		{"addr3", false},
+		{net.ParseIP("127.0.0.1"), false},
+		{net.ParseIP("127.0.0.2"), false},
+		{net.ParseIP("127.0.0.3"), false},
 	}
 
 	for i, test := range tests {

+ 9 - 14
host_source.go

@@ -100,7 +100,7 @@ type HostInfo struct {
 	// TODO(zariel): reduce locking maybe, not all values will change, but to ensure
 	// that we are thread safe use a mutex to access all fields.
 	mu         sync.RWMutex
-	peer       string
+	peer       net.IP
 	port       int
 	dataCenter string
 	rack       string
@@ -116,16 +116,16 @@ func (h *HostInfo) Equal(host *HostInfo) bool {
 	host.mu.RLock()
 	defer host.mu.RUnlock()
 
-	return h.peer == host.peer && h.hostId == host.hostId
+	return h.peer.Equal(host.peer)
 }
 
-func (h *HostInfo) Peer() string {
+func (h *HostInfo) Peer() net.IP {
 	h.mu.RLock()
 	defer h.mu.RUnlock()
 	return h.peer
 }
 
-func (h *HostInfo) setPeer(peer string) *HostInfo {
+func (h *HostInfo) setPeer(peer net.IP) *HostInfo {
 	h.mu.Lock()
 	defer h.mu.Unlock()
 	h.peer = peer
@@ -314,7 +314,11 @@ func (r *ringDescriber) GetHosts() (hosts []*HostInfo, partitioner string, err e
 			return nil, "", err
 		}
 	} else {
-		iter := r.session.control.query(legacyLocalQuery)
+		iter := r.session.control.withConn(func(c *Conn) *Iter {
+			localHost = c.host
+			return c.query(legacyLocalQuery)
+		})
+
 		if iter == nil {
 			return r.prevHosts, r.prevPartitioner, nil
 		}
@@ -324,15 +328,6 @@ func (r *ringDescriber) GetHosts() (hosts []*HostInfo, partitioner string, err e
 		if err = iter.Close(); err != nil {
 			return nil, "", err
 		}
-
-		addr, _, err := net.SplitHostPort(r.session.control.addr())
-		if err != nil {
-			// this should not happen, ever, as this is the address that was dialed by conn, here
-			// a panic makes sense, please report a bug if it occurs.
-			panic(err)
-		}
-
-		localHost.peer = addr
 	}
 
 	localHost.port = r.session.cfg.Port

+ 33 - 26
policies.go

@@ -7,6 +7,7 @@ package gocql
 import (
 	"fmt"
 	"log"
+	"net"
 	"sync"
 	"sync/atomic"
 
@@ -90,7 +91,7 @@ func (c *cowHostList) update(host *HostInfo) {
 	c.mu.Unlock()
 }
 
-func (c *cowHostList) remove(addr string) bool {
+func (c *cowHostList) remove(ip net.IP) bool {
 	c.mu.Lock()
 	l := c.get()
 	size := len(l)
@@ -102,7 +103,7 @@ func (c *cowHostList) remove(addr string) bool {
 	found := false
 	newL := make([]*HostInfo, 0, size)
 	for i := 0; i < len(l); i++ {
-		if l[i].Peer() != addr {
+		if !l[i].Peer().Equal(ip) {
 			newL = append(newL, l[i])
 		} else {
 			found = true
@@ -161,9 +162,9 @@ func (s *SimpleRetryPolicy) Attempt(q RetryableQuery) bool {
 
 type HostStateNotifier interface {
 	AddHost(host *HostInfo)
-	RemoveHost(addr string)
+	RemoveHost(host *HostInfo)
 	HostUp(host *HostInfo)
-	HostDown(addr string)
+	HostDown(host *HostInfo)
 }
 
 // HostSelectionPolicy is an interface for selecting
@@ -235,16 +236,16 @@ func (r *roundRobinHostPolicy) AddHost(host *HostInfo) {
 	r.hosts.add(host)
 }
 
-func (r *roundRobinHostPolicy) RemoveHost(addr string) {
-	r.hosts.remove(addr)
+func (r *roundRobinHostPolicy) RemoveHost(host *HostInfo) {
+	r.hosts.remove(host.Peer())
 }
 
 func (r *roundRobinHostPolicy) HostUp(host *HostInfo) {
 	r.AddHost(host)
 }
 
-func (r *roundRobinHostPolicy) HostDown(addr string) {
-	r.RemoveHost(addr)
+func (r *roundRobinHostPolicy) HostDown(host *HostInfo) {
+	r.RemoveHost(host)
 }
 
 // TokenAwareHostPolicy is a token aware host selection policy, where hosts are
@@ -278,9 +279,9 @@ func (t *tokenAwareHostPolicy) AddHost(host *HostInfo) {
 	t.resetTokenRing()
 }
 
-func (t *tokenAwareHostPolicy) RemoveHost(addr string) {
-	t.hosts.remove(addr)
-	t.fallback.RemoveHost(addr)
+func (t *tokenAwareHostPolicy) RemoveHost(host *HostInfo) {
+	t.hosts.remove(host.Peer())
+	t.fallback.RemoveHost(host)
 
 	t.resetTokenRing()
 }
@@ -289,8 +290,8 @@ func (t *tokenAwareHostPolicy) HostUp(host *HostInfo) {
 	t.AddHost(host)
 }
 
-func (t *tokenAwareHostPolicy) HostDown(addr string) {
-	t.RemoveHost(addr)
+func (t *tokenAwareHostPolicy) HostDown(host *HostInfo) {
+	t.RemoveHost(host)
 }
 
 func (t *tokenAwareHostPolicy) resetTokenRing() {
@@ -393,8 +394,9 @@ func (r *hostPoolHostPolicy) SetHosts(hosts []*HostInfo) {
 	hostMap := make(map[string]*HostInfo, len(hosts))
 
 	for i, host := range hosts {
-		peers[i] = host.Peer()
-		hostMap[host.Peer()] = host
+		ip := host.Peer().String()
+		peers[i] = ip
+		hostMap[ip] = host
 	}
 
 	r.mu.Lock()
@@ -404,15 +406,17 @@ func (r *hostPoolHostPolicy) SetHosts(hosts []*HostInfo) {
 }
 
 func (r *hostPoolHostPolicy) AddHost(host *HostInfo) {
+	ip := host.Peer().String()
+
 	r.mu.Lock()
 	defer r.mu.Unlock()
 
 	// If the host addr is present and isn't nil return
-	if h, ok := r.hostMap[host.Peer()]; ok && h != nil{
+	if h, ok := r.hostMap[ip]; ok && h != nil {
 		return
 	}
 	// otherwise, add the host to the map
-	r.hostMap[host.Peer()] = host
+	r.hostMap[ip] = host
 	// and construct a new peer list to give to the HostPool
 	hosts := make([]string, 0, len(r.hostMap))
 	for addr := range r.hostMap {
@@ -420,21 +424,22 @@ func (r *hostPoolHostPolicy) AddHost(host *HostInfo) {
 	}
 
 	r.hp.SetHosts(hosts)
-
 }
 
-func (r *hostPoolHostPolicy) RemoveHost(addr string) {
+func (r *hostPoolHostPolicy) RemoveHost(host *HostInfo) {
+	ip := host.Peer().String()
+
 	r.mu.Lock()
 	defer r.mu.Unlock()
 
-	if _, ok := r.hostMap[addr]; !ok {
+	if _, ok := r.hostMap[ip]; !ok {
 		return
 	}
 
-	delete(r.hostMap, addr)
+	delete(r.hostMap, ip)
 	hosts := make([]string, 0, len(r.hostMap))
-	for addr := range r.hostMap {
-		hosts = append(hosts, addr)
+	for _, host := range r.hostMap {
+		hosts = append(hosts, host.Peer().String())
 	}
 
 	r.hp.SetHosts(hosts)
@@ -444,8 +449,8 @@ func (r *hostPoolHostPolicy) HostUp(host *HostInfo) {
 	r.AddHost(host)
 }
 
-func (r *hostPoolHostPolicy) HostDown(addr string) {
-	r.RemoveHost(addr)
+func (r *hostPoolHostPolicy) HostDown(host *HostInfo) {
+	r.RemoveHost(host)
 }
 
 func (r *hostPoolHostPolicy) SetPartitioner(partitioner string) {
@@ -488,10 +493,12 @@ func (host selectedHostPoolHost) Info() *HostInfo {
 }
 
 func (host selectedHostPoolHost) Mark(err error) {
+	ip := host.info.Peer().String()
+
 	host.policy.mu.RLock()
 	defer host.policy.mu.RUnlock()
 
-	if _, ok := host.policy.hostMap[host.info.Peer()]; !ok {
+	if _, ok := host.policy.hostMap[ip]; !ok {
 		// host was removed between pick and mark
 		return
 	}

+ 24 - 23
policies_test.go

@@ -6,6 +6,7 @@ package gocql
 
 import (
 	"fmt"
+	"net"
 	"testing"
 
 	"github.com/hailocab/go-hostpool"
@@ -16,8 +17,8 @@ func TestRoundRobinHostPolicy(t *testing.T) {
 	policy := RoundRobinHostPolicy()
 
 	hosts := [...]*HostInfo{
-		{hostId: "0"},
-		{hostId: "1"},
+		{hostId: "0", peer: net.IPv4(0, 0, 0, 1)},
+		{hostId: "1", peer: net.IPv4(0, 0, 0, 2)},
 	}
 
 	for _, host := range hosts {
@@ -67,10 +68,10 @@ func TestTokenAwareHostPolicy(t *testing.T) {
 
 	// set the hosts
 	hosts := [...]*HostInfo{
-		{peer: "0", tokens: []string{"00"}},
-		{peer: "1", tokens: []string{"25"}},
-		{peer: "2", tokens: []string{"50"}},
-		{peer: "3", tokens: []string{"75"}},
+		{peer: net.IPv4(10, 0, 0, 1), tokens: []string{"00"}},
+		{peer: net.IPv4(10, 0, 0, 2), tokens: []string{"25"}},
+		{peer: net.IPv4(10, 0, 0, 3), tokens: []string{"50"}},
+		{peer: net.IPv4(10, 0, 0, 4), tokens: []string{"75"}},
 	}
 	for _, host := range hosts {
 		policy.AddHost(host)
@@ -78,12 +79,12 @@ func TestTokenAwareHostPolicy(t *testing.T) {
 
 	// the token ring is not setup without the partitioner, but the fallback
 	// should work
-	if actual := policy.Pick(nil)(); actual.Info().Peer() != "0" {
+	if actual := policy.Pick(nil)(); !actual.Info().Peer().Equal(hosts[0].peer) {
 		t.Errorf("Expected peer 0 but was %s", actual.Info().Peer())
 	}
 
 	query.RoutingKey([]byte("30"))
-	if actual := policy.Pick(query)(); actual.Info().Peer() != "1" {
+	if actual := policy.Pick(query)(); !actual.Info().Peer().Equal(hosts[1].peer) {
 		t.Errorf("Expected peer 1 but was %s", actual.Info().Peer())
 	}
 
@@ -92,17 +93,17 @@ func TestTokenAwareHostPolicy(t *testing.T) {
 	// now the token ring is configured
 	query.RoutingKey([]byte("20"))
 	iter = policy.Pick(query)
-	if actual := iter(); actual.Info().Peer() != "1" {
+	if actual := iter(); !actual.Info().Peer().Equal(hosts[1].peer) {
 		t.Errorf("Expected peer 1 but was %s", actual.Info().Peer())
 	}
 	// rest are round robin
-	if actual := iter(); actual.Info().Peer() != "2" {
+	if actual := iter(); !actual.Info().Peer().Equal(hosts[2].peer) {
 		t.Errorf("Expected peer 2 but was %s", actual.Info().Peer())
 	}
-	if actual := iter(); actual.Info().Peer() != "3" {
+	if actual := iter(); !actual.Info().Peer().Equal(hosts[3].peer) {
 		t.Errorf("Expected peer 3 but was %s", actual.Info().Peer())
 	}
-	if actual := iter(); actual.Info().Peer() != "0" {
+	if actual := iter(); !actual.Info().Peer().Equal(hosts[0].peer) {
 		t.Errorf("Expected peer 0 but was %s", actual.Info().Peer())
 	}
 }
@@ -112,8 +113,8 @@ func TestHostPoolHostPolicy(t *testing.T) {
 	policy := HostPoolHostPolicy(hostpool.New(nil))
 
 	hosts := []*HostInfo{
-		{hostId: "0", peer: "0"},
-		{hostId: "1", peer: "1"},
+		{hostId: "0", peer: net.IPv4(10, 0, 0, 0)},
+		{hostId: "1", peer: net.IPv4(10, 0, 0, 1)},
 	}
 
 	// Using set host to control the ordering of the hosts as calling "AddHost" iterates the map
@@ -177,10 +178,10 @@ func TestTokenAwareNilHostInfo(t *testing.T) {
 	policy := TokenAwareHostPolicy(RoundRobinHostPolicy())
 
 	hosts := [...]*HostInfo{
-		{peer: "0", tokens: []string{"00"}},
-		{peer: "1", tokens: []string{"25"}},
-		{peer: "2", tokens: []string{"50"}},
-		{peer: "3", tokens: []string{"75"}},
+		{peer: net.IPv4(10, 0, 0, 0), tokens: []string{"00"}},
+		{peer: net.IPv4(10, 0, 0, 1), tokens: []string{"25"}},
+		{peer: net.IPv4(10, 0, 0, 2), tokens: []string{"50"}},
+		{peer: net.IPv4(10, 0, 0, 3), tokens: []string{"75"}},
 	}
 	for _, host := range hosts {
 		policy.AddHost(host)
@@ -196,13 +197,13 @@ func TestTokenAwareNilHostInfo(t *testing.T) {
 		t.Fatal("got nil host")
 	} else if v := next.Info(); v == nil {
 		t.Fatal("got nil HostInfo")
-	} else if v.Peer() != "1" {
+	} else if !v.Peer().Equal(hosts[1].peer) {
 		t.Fatalf("expected peer 1 got %v", v.Peer())
 	}
 
 	// Empty the hosts to trigger the panic when using the fallback.
 	for _, host := range hosts {
-		policy.RemoveHost(host.Peer())
+		policy.RemoveHost(host)
 	}
 
 	next = iter()
@@ -217,7 +218,7 @@ func TestTokenAwareNilHostInfo(t *testing.T) {
 func TestCOWList_Add(t *testing.T) {
 	var cow cowHostList
 
-	toAdd := [...]string{"peer1", "peer2", "peer3"}
+	toAdd := [...]net.IP{net.IPv4(0, 0, 0, 0), net.IPv4(1, 0, 0, 0), net.IPv4(2, 0, 0, 0)}
 
 	for _, addr := range toAdd {
 		if !cow.add(&HostInfo{peer: addr}) {
@@ -232,11 +233,11 @@ func TestCOWList_Add(t *testing.T) {
 
 	set := make(map[string]bool)
 	for _, host := range hosts {
-		set[host.Peer()] = true
+		set[string(host.Peer())] = true
 	}
 
 	for _, addr := range toAdd {
-		if !set[addr] {
+		if !set[string(addr)] {
 			t.Errorf("addr was not in the host list: %q", addr)
 		}
 	}

+ 1 - 1
query_executor.go

@@ -28,7 +28,7 @@ func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
 			continue
 		}
 
-		pool, ok := q.pool.getPool(host.Peer())
+		pool, ok := q.pool.getPool(host)
 		if !ok {
 			continue
 		}

+ 28 - 11
ring.go

@@ -1,6 +1,7 @@
 package gocql
 
 import (
+	"net"
 	"sync"
 	"sync/atomic"
 )
@@ -34,9 +35,9 @@ func (r *ring) rrHost() *HostInfo {
 	return r.hostList[pos%len(r.hostList)]
 }
 
-func (r *ring) getHost(addr string) *HostInfo {
+func (r *ring) getHost(ip net.IP) *HostInfo {
 	r.mu.RLock()
-	host := r.hosts[addr]
+	host := r.hosts[ip.String()]
 	r.mu.RUnlock()
 	return host
 }
@@ -52,42 +53,58 @@ func (r *ring) allHosts() []*HostInfo {
 }
 
 func (r *ring) addHost(host *HostInfo) bool {
+	ip := host.Peer().String()
+
 	r.mu.Lock()
 	if r.hosts == nil {
 		r.hosts = make(map[string]*HostInfo)
 	}
 
-	addr := host.Peer()
-	_, ok := r.hosts[addr]
-	r.hosts[addr] = host
+	_, ok := r.hosts[ip]
+	if !ok {
+		r.hostList = append(r.hostList, host)
+	}
+
+	r.hosts[ip] = host
 	r.mu.Unlock()
 	return ok
 }
 
 func (r *ring) addHostIfMissing(host *HostInfo) (*HostInfo, bool) {
+	ip := host.Peer().String()
+
 	r.mu.Lock()
 	if r.hosts == nil {
 		r.hosts = make(map[string]*HostInfo)
 	}
 
-	addr := host.Peer()
-	existing, ok := r.hosts[addr]
+	existing, ok := r.hosts[ip]
 	if !ok {
-		r.hosts[addr] = host
+		r.hosts[ip] = host
 		existing = host
+		r.hostList = append(r.hostList, host)
 	}
 	r.mu.Unlock()
 	return existing, ok
 }
 
-func (r *ring) removeHost(addr string) bool {
+func (r *ring) removeHost(ip net.IP) bool {
 	r.mu.Lock()
 	if r.hosts == nil {
 		r.hosts = make(map[string]*HostInfo)
 	}
 
-	_, ok := r.hosts[addr]
-	delete(r.hosts, addr)
+	k := ip.String()
+	_, ok := r.hosts[k]
+	if ok {
+		for i, host := range r.hostList {
+			if host.Peer().Equal(ip) {
+				r.hostList = append(r.hostList[:i], r.hostList[i+1:]...)
+				break
+			}
+		}
+	}
+	delete(r.hosts, k)
 	r.mu.Unlock()
 	return ok
 }

+ 7 - 4
ring_test.go

@@ -1,11 +1,14 @@
 package gocql
 
-import "testing"
+import (
+	"net"
+	"testing"
+)
 
 func TestRing_AddHostIfMissing_Missing(t *testing.T) {
 	ring := &ring{}
 
-	host := &HostInfo{peer: "test1"}
+	host := &HostInfo{peer: net.IPv4(1, 1, 1, 1)}
 	h1, ok := ring.addHostIfMissing(host)
 	if ok {
 		t.Fatal("host was reported as already existing")
@@ -19,10 +22,10 @@ func TestRing_AddHostIfMissing_Missing(t *testing.T) {
 func TestRing_AddHostIfMissing_Existing(t *testing.T) {
 	ring := &ring{}
 
-	host := &HostInfo{peer: "test1"}
+	host := &HostInfo{peer: net.IPv4(1, 1, 1, 1)}
 	ring.addHostIfMissing(host)
 
-	h2 := &HostInfo{peer: "test1"}
+	h2 := &HostInfo{peer: net.IPv4(1, 1, 1, 1)}
 
 	h1, ok := ring.addHostIfMissing(h2)
 	if !ok {

+ 15 - 22
session.go

@@ -11,8 +11,6 @@ import (
 	"fmt"
 	"io"
 	"log"
-	"net"
-	"strconv"
 	"strings"
 	"sync"
 	"sync/atomic"
@@ -81,18 +79,12 @@ var queryPool = &sync.Pool{
 func addrsToHosts(addrs []string, defaultPort int) ([]*HostInfo, error) {
 	hosts := make([]*HostInfo, len(addrs))
 	for i, hostport := range addrs {
-		// TODO: remove duplication
-		addr, portStr, err := net.SplitHostPort(JoinHostPort(hostport, defaultPort))
+		host, err := hostInfo(hostport, defaultPort)
 		if err != nil {
-			return nil, fmt.Errorf("NewSession: unable to parse hostport of addr %q: %v", hostport, err)
-		}
-
-		port, err := strconv.Atoi(portStr)
-		if err != nil {
-			return nil, fmt.Errorf("NewSession: invalid port for hostport of addr %q: %v", hostport, err)
+			return nil, err
 		}
 
-		hosts[i] = &HostInfo{peer: addr, port: port, state: NodeUp}
+		hosts[i] = host
 	}
 
 	return hosts, nil
@@ -156,7 +148,6 @@ func NewSession(cfg ClusterConfig) (*Session, error) {
 		localHasRPCAddr, _ := checkSystemLocal(s.control)
 		s.hostSource.localHasRpcAddr = localHasRPCAddr
 
-		var err error
 		if cfg.DisableInitialHostLookup {
 			// TODO: we could look at system.local to get token and other metadata
 			// in this case.
@@ -165,22 +156,23 @@ func NewSession(cfg ClusterConfig) (*Session, error) {
 			hosts, _, err = s.hostSource.GetHosts()
 		}
 
-		if err != nil {
-			s.Close()
-			return nil, fmt.Errorf("gocql: unable to create session: %v", err)
-		}
 	} else {
 		// we dont get host info
 		hosts, err = addrsToHosts(cfg.Hosts, cfg.Port)
 	}
 
+	if err != nil {
+		s.Close()
+		return nil, fmt.Errorf("gocql: unable to create session: %v", err)
+	}
+
 	for _, host := range hosts {
 		if s.cfg.HostFilter == nil || s.cfg.HostFilter.Accept(host) {
 			if existingHost, ok := s.ring.addHostIfMissing(host); ok {
 				existingHost.update(host)
 			}
 
-			s.handleNodeUp(net.ParseIP(host.Peer()), host.Port(), false)
+			s.handleNodeUp(host.Peer(), host.Port(), false)
 		}
 	}
 
@@ -203,6 +195,7 @@ func NewSession(cfg ClusterConfig) (*Session, error) {
 	// connection is disable, we really have no choice, so we just make our
 	// best guess...
 	if !cfg.disableControlConn && cfg.DisableInitialHostLookup {
+		// TODO(zariel): we dont need to do this twice
 		newer, _ := checkSystemSchema(s.control)
 		s.useSystemSchema = newer
 	} else {
@@ -225,7 +218,7 @@ func (s *Session) reconnectDownedHosts(intv time.Duration) {
 			if gocqlDebug {
 				buf := bytes.NewBufferString("Session.ring:")
 				for _, h := range hosts {
-					buf.WriteString("[" + h.Peer() + ":" + h.State().String() + "]")
+					buf.WriteString("[" + h.Peer().String() + ":" + h.State().String() + "]")
 				}
 				log.Println(buf.String())
 			}
@@ -234,7 +227,7 @@ func (s *Session) reconnectDownedHosts(intv time.Duration) {
 				if h.IsUp() {
 					continue
 				}
-				s.handleNodeUp(net.ParseIP(h.Peer()), h.Port(), true)
+				s.handleNodeUp(h.Peer(), h.Port(), true)
 			}
 		case <-s.quit:
 			return
@@ -409,7 +402,7 @@ func (s *Session) getConn() *Conn {
 			continue
 		}
 
-		pool, ok := s.pool.getPool(host.Peer())
+		pool, ok := s.pool.getPool(host)
 		if !ok {
 			continue
 		}
@@ -628,8 +621,8 @@ func (s *Session) MapExecuteBatchCAS(batch *Batch, dest map[string]interface{})
 	return applied, iter, iter.err
 }
 
-func (s *Session) connect(addr string, errorHandler ConnErrorHandler, host *HostInfo) (*Conn, error) {
-	return Connect(host, addr, s.connCfg, errorHandler, s)
+func (s *Session) connect(host *HostInfo, errorHandler ConnErrorHandler) (*Conn, error) {
+	return Connect(host, s.connCfg, errorHandler, s)
 }
 
 // Query represents a CQL statement that can be executed.

+ 1 - 1
token.go

@@ -184,7 +184,7 @@ func (t *tokenRing) String() string {
 		buf.WriteString("]")
 		buf.WriteString(t.tokens[i].String())
 		buf.WriteString(":")
-		buf.WriteString(t.hosts[i].Peer())
+		buf.WriteString(t.hosts[i].Peer().String())
 	}
 	buf.WriteString("\n}")
 	return string(buf.Bytes())

+ 41 - 133
token_test.go

@@ -6,7 +6,9 @@ package gocql
 
 import (
 	"bytes"
+	"fmt"
 	"math/big"
+	"net"
 	"sort"
 	"strconv"
 	"testing"
@@ -226,27 +228,23 @@ func TestUnknownTokenRing(t *testing.T) {
 	}
 }
 
+func hostsForTests(n int) []*HostInfo {
+	hosts := make([]*HostInfo, n)
+	for i := 0; i < n; i++ {
+		host := &HostInfo{
+			peer:   net.IPv4(1, 1, 1, byte(n)),
+			tokens: []string{fmt.Sprintf("%d", n)},
+		}
+
+		hosts[i] = host
+	}
+	return hosts
+}
+
 // Test of the tokenRing with the Murmur3Partitioner
 func TestMurmur3TokenRing(t *testing.T) {
 	// Note, strings are parsed directly to int64, they are not murmur3 hashed
-	hosts := []*HostInfo{
-		{
-			peer:   "0",
-			tokens: []string{"0"},
-		},
-		{
-			peer:   "1",
-			tokens: []string{"25"},
-		},
-		{
-			peer:   "2",
-			tokens: []string{"50"},
-		},
-		{
-			peer:   "3",
-			tokens: []string{"75"},
-		},
-	}
+	hosts := hostsForTests(4)
 	ring, err := newTokenRing("Murmur3Partitioner", hosts)
 	if err != nil {
 		t.Fatalf("Failed to create token ring due to error: %v", err)
@@ -254,34 +252,20 @@ func TestMurmur3TokenRing(t *testing.T) {
 
 	p := murmur3Partitioner{}
 
-	var actual *HostInfo
-	actual = ring.GetHostForToken(p.ParseString("0"))
-	if actual.Peer() != "0" {
-		t.Errorf("Expected peer 0 for token \"0\", but was %s", actual.Peer())
+	for _, host := range hosts {
+		actual := ring.GetHostForToken(p.ParseString(host.tokens[0]))
+		if !actual.Peer().Equal(host.peer) {
+			t.Errorf("Expected peer %v for token %q, but was %v", host.peer, host.tokens[0], actual.peer)
+		}
 	}
 
-	actual = ring.GetHostForToken(p.ParseString("25"))
-	if actual.Peer() != "1" {
-		t.Errorf("Expected peer 1 for token \"25\", but was %s", actual.Peer())
-	}
-
-	actual = ring.GetHostForToken(p.ParseString("50"))
-	if actual.Peer() != "2" {
-		t.Errorf("Expected peer 2 for token \"50\", but was %s", actual.Peer())
-	}
-
-	actual = ring.GetHostForToken(p.ParseString("75"))
-	if actual.Peer() != "3" {
-		t.Errorf("Expected peer 3 for token \"01\", but was %s", actual.Peer())
-	}
-
-	actual = ring.GetHostForToken(p.ParseString("12"))
-	if actual.Peer() != "1" {
+	actual := ring.GetHostForToken(p.ParseString("12"))
+	if !actual.Peer().Equal(hosts[1].peer) {
 		t.Errorf("Expected peer 1 for token \"12\", but was %s", actual.Peer())
 	}
 
 	actual = ring.GetHostForToken(p.ParseString("24324545443332"))
-	if actual.Peer() != "0" {
+	if !actual.Peer().Equal(hosts[0].peer) {
 		t.Errorf("Expected peer 0 for token \"24324545443332\", but was %s", actual.Peer())
 	}
 }
@@ -290,32 +274,7 @@ func TestMurmur3TokenRing(t *testing.T) {
 func TestOrderedTokenRing(t *testing.T) {
 	// Tokens here more or less are similar layout to the int tokens above due
 	// to each numeric character translating to a consistently offset byte.
-	hosts := []*HostInfo{
-		{
-			peer: "0",
-			tokens: []string{
-				"00",
-			},
-		},
-		{
-			peer: "1",
-			tokens: []string{
-				"25",
-			},
-		},
-		{
-			peer: "2",
-			tokens: []string{
-				"50",
-			},
-		},
-		{
-			peer: "3",
-			tokens: []string{
-				"75",
-			},
-		},
-	}
+	hosts := hostsForTests(4)
 	ring, err := newTokenRing("OrderedPartitioner", hosts)
 	if err != nil {
 		t.Fatalf("Failed to create token ring due to error: %v", err)
@@ -324,33 +283,20 @@ func TestOrderedTokenRing(t *testing.T) {
 	p := orderedPartitioner{}
 
 	var actual *HostInfo
-	actual = ring.GetHostForToken(p.ParseString("0"))
-	if actual.Peer() != "0" {
-		t.Errorf("Expected peer 0 for token \"0\", but was %s", actual.Peer())
-	}
-
-	actual = ring.GetHostForToken(p.ParseString("25"))
-	if actual.Peer() != "1" {
-		t.Errorf("Expected peer 1 for token \"25\", but was %s", actual.Peer())
-	}
-
-	actual = ring.GetHostForToken(p.ParseString("50"))
-	if actual.Peer() != "2" {
-		t.Errorf("Expected peer 2 for token \"50\", but was %s", actual.Peer())
-	}
-
-	actual = ring.GetHostForToken(p.ParseString("75"))
-	if actual.Peer() != "3" {
-		t.Errorf("Expected peer 3 for token \"01\", but was %s", actual.Peer())
+	for _, host := range hosts {
+		actual = ring.GetHostForToken(p.ParseString(host.tokens[0]))
+		if !actual.Peer().Equal(host.peer) {
+			t.Errorf("Expected peer %v for token %q, but was %v", host.peer, host.tokens[0], actual.peer)
+		}
 	}
 
 	actual = ring.GetHostForToken(p.ParseString("12"))
-	if actual.Peer() != "1" {
+	if !actual.peer.Equal(hosts[1].peer) {
 		t.Errorf("Expected peer 1 for token \"12\", but was %s", actual.Peer())
 	}
 
 	actual = ring.GetHostForToken(p.ParseString("24324545443332"))
-	if actual.Peer() != "1" {
+	if !actual.peer.Equal(hosts[1].peer) {
 		t.Errorf("Expected peer 1 for token \"24324545443332\", but was %s", actual.Peer())
 	}
 }
@@ -358,32 +304,7 @@ func TestOrderedTokenRing(t *testing.T) {
 // Test of the tokenRing with the RandomPartitioner
 func TestRandomTokenRing(t *testing.T) {
 	// String tokens are parsed into big.Int in base 10
-	hosts := []*HostInfo{
-		{
-			peer: "0",
-			tokens: []string{
-				"00",
-			},
-		},
-		{
-			peer: "1",
-			tokens: []string{
-				"25",
-			},
-		},
-		{
-			peer: "2",
-			tokens: []string{
-				"50",
-			},
-		},
-		{
-			peer: "3",
-			tokens: []string{
-				"75",
-			},
-		},
-	}
+	hosts := hostsForTests(4)
 	ring, err := newTokenRing("RandomPartitioner", hosts)
 	if err != nil {
 		t.Fatalf("Failed to create token ring due to error: %v", err)
@@ -392,33 +313,20 @@ func TestRandomTokenRing(t *testing.T) {
 	p := randomPartitioner{}
 
 	var actual *HostInfo
-	actual = ring.GetHostForToken(p.ParseString("0"))
-	if actual.Peer() != "0" {
-		t.Errorf("Expected peer 0 for token \"0\", but was %s", actual.Peer())
-	}
-
-	actual = ring.GetHostForToken(p.ParseString("25"))
-	if actual.Peer() != "1" {
-		t.Errorf("Expected peer 1 for token \"25\", but was %s", actual.Peer())
-	}
-
-	actual = ring.GetHostForToken(p.ParseString("50"))
-	if actual.Peer() != "2" {
-		t.Errorf("Expected peer 2 for token \"50\", but was %s", actual.Peer())
-	}
-
-	actual = ring.GetHostForToken(p.ParseString("75"))
-	if actual.Peer() != "3" {
-		t.Errorf("Expected peer 3 for token \"01\", but was %s", actual.Peer())
+	for _, host := range hosts {
+		actual = ring.GetHostForToken(p.ParseString(host.tokens[0]))
+		if !actual.Peer().Equal(host.peer) {
+			t.Errorf("Expected peer %v for token %q, but was %v", host.peer, host.tokens[0], actual.peer)
+		}
 	}
 
 	actual = ring.GetHostForToken(p.ParseString("12"))
-	if actual.Peer() != "1" {
+	if !actual.peer.Equal(hosts[1].peer) {
 		t.Errorf("Expected peer 1 for token \"12\", but was %s", actual.Peer())
 	}
 
 	actual = ring.GetHostForToken(p.ParseString("24324545443332"))
-	if actual.Peer() != "0" {
-		t.Errorf("Expected peer 0 for token \"24324545443332\", but was %s", actual.Peer())
+	if !actual.peer.Equal(hosts[0].peer) {
+		t.Errorf("Expected peer 1 for token \"24324545443332\", but was %s", actual.Peer())
 	}
 }