Browse Source

Merge pull request #809 from Zariel/host-peer-ip

change HostInfo.Peer to be an IP
Chris Bannister 9 years ago
parent
commit
9395dd748f
20 changed files with 368 additions and 333 deletions
  1. 21 8
      cassandra_test.go
  2. 15 1
      cluster.go
  3. 12 2
      conn.go
  4. 10 4
      conn_test.go
  5. 19 15
      connectionpool.go
  6. 50 22
      control.go
  7. 31 0
      control_test.go
  8. 21 27
      events.go
  9. 1 1
      events_ccm_test.go
  10. 12 4
      filters.go
  11. 17 14
      filters_test.go
  12. 9 14
      host_source.go
  13. 33 26
      policies.go
  14. 24 23
      policies_test.go
  15. 1 1
      query_executor.go
  16. 28 11
      ring.go
  17. 7 4
      ring_test.go
  18. 15 22
      session.go
  19. 1 1
      token.go
  20. 41 133
      token_test.go

+ 21 - 8
cassandra_test.go

@@ -62,11 +62,15 @@ func TestRingDiscovery(t *testing.T) {
 	}
 
 	session.pool.mu.RLock()
+	defer session.pool.mu.RUnlock()
 	size := len(session.pool.hostConnPools)
-	session.pool.mu.RUnlock()
 
 	if *clusterSize != size {
-		t.Fatalf("Expected a cluster size of %d, but actual size was %d", *clusterSize, size)
+		for p, pool := range session.pool.hostConnPools {
+			t.Logf("p=%q host=%v ips=%s", p, pool.host, pool.host.Peer().String())
+
+		}
+		t.Errorf("Expected a cluster size of %d, but actual size was %d", *clusterSize, size)
 	}
 }
 
@@ -573,7 +577,7 @@ func TestReconnection(t *testing.T) {
 	defer session.Close()
 
 	h := session.ring.allHosts()[0]
-	session.handleNodeDown(net.ParseIP(h.Peer()), h.Port())
+	session.handleNodeDown(h.Peer(), h.Port())
 
 	if h.State() != NodeDown {
 		t.Fatal("Host should be NodeDown but not.")
@@ -2477,17 +2481,26 @@ func TestSchemaReset(t *testing.T) {
 }
 
 func TestCreateSession_DontSwallowError(t *testing.T) {
+	t.Skip("This test is bad, and the resultant error from cassandra changes between versions")
 	cluster := createCluster()
-	cluster.ProtoVersion = 100
+	cluster.ProtoVersion = 0x100
 	session, err := cluster.CreateSession()
 	if err == nil {
 		session.Close()
 
 		t.Fatal("expected to get an error for unsupported protocol")
 	}
-	// TODO: we should get a distinct error type here which include the underlying
-	// cassandra error about the protocol version, for now check this here.
-	if !strings.Contains(err.Error(), "Invalid or unsupported protocol version") {
-		t.Fatalf(`expcted to get error "unsupported protocol version" got: %q`, err)
+
+	if flagCassVersion.Major < 3 {
+		// TODO: we should get a distinct error type here which include the underlying
+		// cassandra error about the protocol version, for now check this here.
+		if !strings.Contains(err.Error(), "Invalid or unsupported protocol version") {
+			t.Fatalf(`expcted to get error "unsupported protocol version" got: %q`, err)
+		}
+	} else {
+		if !strings.Contains(err.Error(), "unsupported response version") {
+			t.Fatalf(`expcted to get error "unsupported response version" got: %q`, err)
+		}
 	}
+
 }

+ 15 - 1
cluster.go

@@ -27,7 +27,13 @@ func (p PoolConfig) buildPool(session *Session) *policyConnPool {
 // behavior to fit the most common use cases. Applications that require a
 // different setup must implement their own cluster.
 type ClusterConfig struct {
-	Hosts             []string          // addresses for the initial connections
+	// addresses for the initial connections. It is recomended to use the value set in
+	// the Cassandra config for broadcast_address or listen_address, an IP address not
+	// a domain name. This is because events from Cassandra will use the configured IP
+	// address, which is used to index connected hosts. If the domain name specified
+	// resolves to more than 1 IP address then the driver may connect multiple times to
+	// the same host, and will not mark the node being down or up from events.
+	Hosts             []string
 	CQLVersion        string            // CQL version (default: 3.0.0)
 	ProtoVersion      int               // version of the native protocol (default: 2)
 	Timeout           time.Duration     // connection timeout (default: 600ms)
@@ -100,6 +106,14 @@ type ClusterConfig struct {
 }
 
 // NewCluster generates a new config for the default cluster implementation.
+//
+// The supplied hosts are used to initially connect to the cluster then the rest of
+// the ring will be automatically discovered. It is recomended to use the value set in
+// the Cassandra config for broadcast_address or listen_address, an IP address not
+// a domain name. This is because events from Cassandra will use the configured IP
+// address, which is used to index connected hosts. If the domain name specified
+// resolves to more than 1 IP address then the driver may connect multiple times to
+// the same host, and will not mark the node being down or up from events.
 func NewCluster(hosts ...string) *ClusterConfig {
 	cfg := &ClusterConfig{
 		Hosts:                  hosts,

+ 12 - 2
conn.go

@@ -152,8 +152,15 @@ type Conn struct {
 }
 
 // Connect establishes a connection to a Cassandra node.
-func Connect(host *HostInfo, addr string, cfg *ConnConfig,
-	errorHandler ConnErrorHandler, session *Session) (*Conn, error) {
+func Connect(host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHandler, session *Session) (*Conn, error) {
+	// TODO(zariel): remove these
+	if host == nil {
+		panic("host is nil")
+	} else if len(host.Peer()) == 0 {
+		panic("host missing peer ip address")
+	} else if host.Port() == 0 {
+		panic("host missing port")
+	}
 
 	var (
 		err  error
@@ -164,6 +171,9 @@ func Connect(host *HostInfo, addr string, cfg *ConnConfig,
 		Timeout: cfg.Timeout,
 	}
 
+	// TODO(zariel): handle ipv6 zone
+	addr := (&net.TCPAddr{IP: host.Peer(), Port: host.Port()}).String()
+
 	if cfg.tlsConfig != nil {
 		// the TLS config is safe to be reused by connections but it must not
 		// be modified after being used.

+ 10 - 4
conn_test.go

@@ -473,8 +473,7 @@ func TestStream0(t *testing.T) {
 		}
 	})
 
-	host := &HostInfo{peer: srv.Address}
-	conn, err := Connect(host, srv.Address, &ConnConfig{ProtoVersion: int(srv.protocol)}, errorHandler, nil)
+	conn, err := Connect(srv.host(), &ConnConfig{ProtoVersion: int(srv.protocol)}, errorHandler, nil)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -509,8 +508,7 @@ func TestConnClosedBlocked(t *testing.T) {
 		t.Log(err)
 	})
 
-	host := &HostInfo{peer: srv.Address}
-	conn, err := Connect(host, srv.Address, &ConnConfig{ProtoVersion: int(srv.protocol)}, errorHandler, nil)
+	conn, err := Connect(srv.host(), &ConnConfig{ProtoVersion: int(srv.protocol)}, errorHandler, nil)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -637,6 +635,14 @@ type TestServer struct {
 	closed bool
 }
 
+func (srv *TestServer) host() *HostInfo {
+	host, err := hostInfo(srv.Address, 9042)
+	if err != nil {
+		srv.t.Fatal(err)
+	}
+	return host
+}
+
 func (srv *TestServer) closeWatch() {
 	<-srv.ctx.Done()
 

+ 19 - 15
connectionpool.go

@@ -130,9 +130,10 @@ func (p *policyConnPool) SetHosts(hosts []*HostInfo) {
 			// don't create a connection pool for a down host
 			continue
 		}
-		if _, exists := p.hostConnPools[host.Peer()]; exists {
+		ip := host.Peer().String()
+		if _, exists := p.hostConnPools[ip]; exists {
 			// still have this host, so don't remove it
-			delete(toRemove, host.Peer())
+			delete(toRemove, ip)
 			continue
 		}
 
@@ -155,7 +156,7 @@ func (p *policyConnPool) SetHosts(hosts []*HostInfo) {
 		createCount--
 		if pool.Size() > 0 {
 			// add pool onyl if there a connections available
-			p.hostConnPools[pool.host.Peer()] = pool
+			p.hostConnPools[string(pool.host.Peer())] = pool
 		}
 	}
 
@@ -177,9 +178,10 @@ func (p *policyConnPool) Size() int {
 	return count
 }
 
-func (p *policyConnPool) getPool(addr string) (pool *hostConnPool, ok bool) {
+func (p *policyConnPool) getPool(host *HostInfo) (pool *hostConnPool, ok bool) {
+	ip := host.Peer().String()
 	p.mu.RLock()
-	pool, ok = p.hostConnPools[addr]
+	pool, ok = p.hostConnPools[ip]
 	p.mu.RUnlock()
 	return
 }
@@ -196,8 +198,9 @@ func (p *policyConnPool) Close() {
 }
 
 func (p *policyConnPool) addHost(host *HostInfo) {
+	ip := host.Peer().String()
 	p.mu.Lock()
-	pool, ok := p.hostConnPools[host.Peer()]
+	pool, ok := p.hostConnPools[ip]
 	if !ok {
 		pool = newHostConnPool(
 			p.session,
@@ -207,22 +210,23 @@ func (p *policyConnPool) addHost(host *HostInfo) {
 			p.keyspace,
 		)
 
-		p.hostConnPools[host.Peer()] = pool
+		p.hostConnPools[ip] = pool
 	}
 	p.mu.Unlock()
 
 	pool.fill()
 }
 
-func (p *policyConnPool) removeHost(addr string) {
+func (p *policyConnPool) removeHost(ip net.IP) {
+	k := ip.String()
 	p.mu.Lock()
-	pool, ok := p.hostConnPools[addr]
+	pool, ok := p.hostConnPools[k]
 	if !ok {
 		p.mu.Unlock()
 		return
 	}
 
-	delete(p.hostConnPools, addr)
+	delete(p.hostConnPools, k)
 	p.mu.Unlock()
 
 	go pool.Close()
@@ -234,10 +238,10 @@ func (p *policyConnPool) hostUp(host *HostInfo) {
 	p.addHost(host)
 }
 
-func (p *policyConnPool) hostDown(addr string) {
+func (p *policyConnPool) hostDown(ip net.IP) {
 	// TODO(zariel): mark host as down so we can try to connect to it later, for
 	// now just treat it has removed.
-	p.removeHost(addr)
+	p.removeHost(ip)
 }
 
 // hostConnPool is a connection pool for a single host.
@@ -272,7 +276,7 @@ func newHostConnPool(session *Session, host *HostInfo, port, size int,
 		session:  session,
 		host:     host,
 		port:     port,
-		addr:     JoinHostPort(host.Peer(), port),
+		addr:     (&net.TCPAddr{IP: host.Peer(), Port: host.Port()}).String(),
 		size:     size,
 		keyspace: keyspace,
 		conns:    make([]*Conn, 0, size),
@@ -396,7 +400,7 @@ func (pool *hostConnPool) fill() {
 
 			// this is calle with the connetion pool mutex held, this call will
 			// then recursivly try to lock it again. FIXME
-			go pool.session.handleNodeDown(net.ParseIP(pool.host.Peer()), pool.port)
+			go pool.session.handleNodeDown(pool.host.Peer(), pool.port)
 			return
 		}
 
@@ -477,7 +481,7 @@ func (pool *hostConnPool) connect() (err error) {
 	// try to connect
 	var conn *Conn
 	for i := 0; i < maxAttempts; i++ {
-		conn, err = pool.session.connect(pool.addr, pool, pool.host)
+		conn, err = pool.session.connect(pool.host, pool)
 		if err == nil {
 			break
 		}

+ 50 - 22
control.go

@@ -4,13 +4,14 @@ import (
 	crand "crypto/rand"
 	"errors"
 	"fmt"
-	"golang.org/x/net/context"
 	"log"
 	"math/rand"
 	"net"
 	"strconv"
 	"sync/atomic"
 	"time"
+
+	"golang.org/x/net/context"
 )
 
 var (
@@ -89,6 +90,8 @@ func (c *controlConn) heartBeat() {
 	}
 }
 
+var hostLookupPreferV4 = false
+
 func hostInfo(addr string, defaultPort int) (*HostInfo, error) {
 	var port int
 	host, portStr, err := net.SplitHostPort(addr)
@@ -102,10 +105,37 @@ func hostInfo(addr string, defaultPort int) (*HostInfo, error) {
 		}
 	}
 
-	return &HostInfo{peer: host, port: port}, nil
+	ip := net.ParseIP(host)
+	if ip == nil {
+		ips, err := net.LookupIP(host)
+		if err != nil {
+			return nil, err
+		} else if len(ips) == 0 {
+			return nil, fmt.Errorf("No IP's returned from DNS lookup for %q", addr)
+		}
+
+		if hostLookupPreferV4 {
+			for _, v := range ips {
+				if v4 := v.To4(); v4 != nil {
+					ip = v4
+					break
+				}
+			}
+			if ip == nil {
+				ip = ips[0]
+			}
+		} else {
+			// TODO(zariel): should we check that we can connect to any of the ips?
+			ip = ips[0]
+		}
+
+	}
+
+	return &HostInfo{peer: ip, port: port}, nil
 }
 
 func (c *controlConn) shuffleDial(endpoints []string) (conn *Conn, err error) {
+	// TODO: accept a []*HostInfo
 	perm := randr.Perm(len(endpoints))
 	shuffled := make([]string, len(endpoints))
 
@@ -130,7 +160,7 @@ func (c *controlConn) shuffleDial(endpoints []string) (conn *Conn, err error) {
 		}
 
 		hostInfo, _ := c.session.ring.addHostIfMissing(host)
-		conn, err = c.session.connect(addr, c, hostInfo)
+		conn, err = c.session.connect(hostInfo, c)
 		if err == nil {
 			return conn, err
 		}
@@ -229,22 +259,21 @@ func (c *controlConn) reconnect(refreshring bool) {
 	// TODO: simplify this function, use session.ring to get hosts instead of the
 	// connection pool
 
-	addr := c.addr()
+	var host *HostInfo
 	oldConn := c.conn.Load().(*Conn)
 	if oldConn != nil {
+		host = oldConn.host
 		oldConn.Close()
 	}
 
 	var newConn *Conn
-	if addr != "" {
+	if host != nil {
 		// try to connect to the old host
-		conn, err := c.session.connect(addr, c, oldConn.host)
+		conn, err := c.session.connect(host, c)
 		if err != nil {
 			// host is dead
 			// TODO: this is replicated in a few places
-			ip, portStr, _ := net.SplitHostPort(addr)
-			port, _ := strconv.Atoi(portStr)
-			c.session.handleNodeDown(net.ParseIP(ip), port)
+			c.session.handleNodeDown(host.Peer(), host.Port())
 		} else {
 			newConn = conn
 		}
@@ -260,7 +289,7 @@ func (c *controlConn) reconnect(refreshring bool) {
 		}
 
 		var err error
-		newConn, err = c.session.connect(host.Peer(), c, host)
+		newConn, err = c.session.connect(host, c)
 		if err != nil {
 			// TODO: add log handler for things like this
 			return
@@ -350,29 +379,28 @@ func (c *controlConn) query(statement string, values ...interface{}) (iter *Iter
 	return
 }
 
-func (c *controlConn) fetchHostInfo(addr net.IP, port int) (*HostInfo, error) {
+func (c *controlConn) fetchHostInfo(ip net.IP, port int) (*HostInfo, error) {
 	// TODO(zariel): we should probably move this into host_source or atleast
 	// share code with it.
-	hostname, _, err := net.SplitHostPort(c.addr())
-	if err != nil {
-		return nil, fmt.Errorf("unable to fetch host info, invalid conn addr: %q: %v", c.addr(), err)
+	localHost := c.host()
+	if localHost == nil {
+		return nil, errors.New("unable to fetch host info, invalid conn host")
 	}
 
-	isLocal := hostname == addr.String()
+	isLocal := localHost.Peer().Equal(ip)
 
 	var fn func(*HostInfo) error
 
+	// TODO(zariel): fetch preferred_ip address (is it >3.x only?)
 	if isLocal {
 		fn = func(host *HostInfo) error {
-			// TODO(zariel): should we fetch rpc_address from here?
 			iter := c.query("SELECT data_center, rack, host_id, tokens, release_version FROM system.local WHERE key='local'")
 			iter.Scan(&host.dataCenter, &host.rack, &host.hostId, &host.tokens, &host.version)
 			return iter.Close()
 		}
 	} else {
 		fn = func(host *HostInfo) error {
-			// TODO(zariel): should we fetch rpc_address from here?
-			iter := c.query("SELECT data_center, rack, host_id, tokens, release_version FROM system.peers WHERE peer=?", addr)
+			iter := c.query("SELECT data_center, rack, host_id, tokens, release_version FROM system.peers WHERE peer=?", ip)
 			iter.Scan(&host.dataCenter, &host.rack, &host.hostId, &host.tokens, &host.version)
 			return iter.Close()
 		}
@@ -380,12 +408,12 @@ func (c *controlConn) fetchHostInfo(addr net.IP, port int) (*HostInfo, error) {
 
 	host := &HostInfo{
 		port: port,
+		peer: ip,
 	}
 
 	if err := fn(host); err != nil {
 		return nil, err
 	}
-	host.peer = addr.String()
 
 	return host, nil
 }
@@ -396,12 +424,12 @@ func (c *controlConn) awaitSchemaAgreement() error {
 	}).err
 }
 
-func (c *controlConn) addr() string {
+func (c *controlConn) host() *HostInfo {
 	conn := c.conn.Load().(*Conn)
 	if conn == nil {
-		return ""
+		return nil
 	}
-	return conn.addr
+	return conn.host
 }
 
 func (c *controlConn) close() {

+ 31 - 0
control_test.go

@@ -0,0 +1,31 @@
+package gocql
+
+import (
+	"net"
+	"testing"
+)
+
+func TestHostInfo_Lookup(t *testing.T) {
+	hostLookupPreferV4 = true
+	defer func() { hostLookupPreferV4 = false }()
+
+	tests := [...]struct {
+		addr string
+		ip   net.IP
+	}{
+		{"127.0.0.1", net.IPv4(127, 0, 0, 1)},
+		{"localhost", net.IPv4(127, 0, 0, 1)}, // TODO: this may be host dependant
+	}
+
+	for i, test := range tests {
+		host, err := hostInfo(test.addr, 1)
+		if err != nil {
+			t.Errorf("%d: %v", i, err)
+			continue
+		}
+
+		if !host.peer.Equal(test.ip) {
+			t.Errorf("expected ip %v got %v for addr %q", test.ip, host.peer, test.addr)
+		}
+	}
+}

+ 21 - 27
events.go

@@ -171,25 +171,21 @@ func (s *Session) handleNodeEvent(frames []frame) {
 	}
 }
 
-func (s *Session) handleNewNode(host net.IP, port int, waitForBinary bool) {
-	// TODO(zariel): need to be able to filter discovered nodes
-
+func (s *Session) handleNewNode(ip net.IP, port int, waitForBinary bool) {
 	var hostInfo *HostInfo
 	if s.control != nil && !s.cfg.IgnorePeerAddr {
 		var err error
-		hostInfo, err = s.control.fetchHostInfo(host, port)
+		hostInfo, err = s.control.fetchHostInfo(ip, port)
 		if err != nil {
-			log.Printf("gocql: events: unable to fetch host info for %v: %v\n", host, err)
+			log.Printf("gocql: events: unable to fetch host info for (%s:%d): %v\n", ip, port, err)
 			return
 		}
-
 	} else {
-		hostInfo = &HostInfo{peer: host.String(), port: port, state: NodeUp}
+		hostInfo = &HostInfo{peer: ip, port: port}
 	}
 
-	addr := host.String()
-	if s.cfg.IgnorePeerAddr && hostInfo.Peer() != addr {
-		hostInfo.setPeer(addr)
+	if s.cfg.IgnorePeerAddr && hostInfo.Peer().Equal(ip) {
+		hostInfo.setPeer(ip)
 	}
 
 	if s.cfg.HostFilter != nil && !s.cfg.HostFilter.Accept(hostInfo) {
@@ -217,11 +213,9 @@ func (s *Session) handleNewNode(host net.IP, port int, waitForBinary bool) {
 
 func (s *Session) handleRemovedNode(ip net.IP, port int) {
 	// we remove all nodes but only add ones which pass the filter
-	addr := ip.String()
-
-	host := s.ring.getHost(addr)
+	host := s.ring.getHost(ip)
 	if host == nil {
-		host = &HostInfo{peer: addr}
+		host = &HostInfo{peer: ip, port: port}
 	}
 
 	if s.cfg.HostFilter != nil && !s.cfg.HostFilter.Accept(host) {
@@ -229,9 +223,9 @@ func (s *Session) handleRemovedNode(ip net.IP, port int) {
 	}
 
 	host.setState(NodeDown)
-	s.policy.RemoveHost(addr)
-	s.pool.removeHost(addr)
-	s.ring.removeHost(addr)
+	s.policy.RemoveHost(host)
+	s.pool.removeHost(ip)
+	s.ring.removeHost(ip)
 
 	if !s.cfg.IgnorePeerAddr {
 		s.hostSource.refreshRing()
@@ -242,11 +236,12 @@ func (s *Session) handleNodeUp(ip net.IP, port int, waitForBinary bool) {
 	if gocqlDebug {
 		log.Printf("gocql: Session.handleNodeUp: %s:%d\n", ip.String(), port)
 	}
-	addr := ip.String()
-	host := s.ring.getHost(addr)
+
+	host := s.ring.getHost(ip)
 	if host != nil {
-		if s.cfg.IgnorePeerAddr && host.Peer() != addr {
-			host.setPeer(addr)
+		if s.cfg.IgnorePeerAddr && host.Peer().Equal(ip) {
+			// TODO: how can this ever be true?
+			host.setPeer(ip)
 		}
 
 		if s.cfg.HostFilter != nil && !s.cfg.HostFilter.Accept(host) {
@@ -257,7 +252,6 @@ func (s *Session) handleNodeUp(ip net.IP, port int, waitForBinary bool) {
 			time.Sleep(t)
 		}
 
-		host.setPort(port)
 		s.pool.hostUp(host)
 		s.policy.HostUp(host)
 		host.setState(NodeUp)
@@ -271,10 +265,10 @@ func (s *Session) handleNodeDown(ip net.IP, port int) {
 	if gocqlDebug {
 		log.Printf("gocql: Session.handleNodeDown: %s:%d\n", ip.String(), port)
 	}
-	addr := ip.String()
-	host := s.ring.getHost(addr)
+
+	host := s.ring.getHost(ip)
 	if host == nil {
-		host = &HostInfo{peer: addr}
+		host = &HostInfo{peer: ip, port: port}
 	}
 
 	if s.cfg.HostFilter != nil && !s.cfg.HostFilter.Accept(host) {
@@ -282,6 +276,6 @@ func (s *Session) handleNodeDown(ip net.IP, port int) {
 	}
 
 	host.setState(NodeDown)
-	s.policy.HostDown(addr)
-	s.pool.hostDown(addr)
+	s.policy.HostDown(host)
+	s.pool.hostDown(ip)
 }

+ 1 - 1
events_ccm_test.go

@@ -1,4 +1,4 @@
-// +build ccm
+// +build ccm, ignore
 
 package gocql
 

+ 12 - 4
filters.go

@@ -1,5 +1,7 @@
 package gocql
 
+import "fmt"
+
 // HostFilter interface is used when a host is discovered via server sent events.
 type HostFilter interface {
 	// Called when a new host is discovered, returning true will cause the host
@@ -38,12 +40,18 @@ func DataCentreHostFilter(dataCentre string) HostFilter {
 // WhiteListHostFilter filters incoming hosts by checking that their address is
 // in the initial hosts whitelist.
 func WhiteListHostFilter(hosts ...string) HostFilter {
-	m := make(map[string]bool, len(hosts))
-	for _, host := range hosts {
-		m[host] = true
+	hostInfos, err := addrsToHosts(hosts, 9042)
+	if err != nil {
+		// dont want to panic here, but rather not break the API
+		panic(fmt.Errorf("unable to lookup host info from address: %v", err))
+	}
+
+	m := make(map[string]bool, len(hostInfos))
+	for _, host := range hostInfos {
+		m[string(host.peer)] = true
 	}
 
 	return HostFilterFunc(func(host *HostInfo) bool {
-		return m[host.Peer()]
+		return m[string(host.Peer())]
 	})
 }

+ 17 - 14
filters_test.go

@@ -1,16 +1,19 @@
 package gocql
 
-import "testing"
+import (
+	"net"
+	"testing"
+)
 
 func TestFilter_WhiteList(t *testing.T) {
-	f := WhiteListHostFilter("addr1", "addr2")
+	f := WhiteListHostFilter("127.0.0.1", "127.0.0.2")
 	tests := [...]struct {
-		addr   string
+		addr   net.IP
 		accept bool
 	}{
-		{"addr1", true},
-		{"addr2", true},
-		{"addr3", false},
+		{net.ParseIP("127.0.0.1"), true},
+		{net.ParseIP("127.0.0.2"), true},
+		{net.ParseIP("127.0.0.3"), false},
 	}
 
 	for i, test := range tests {
@@ -27,12 +30,12 @@ func TestFilter_WhiteList(t *testing.T) {
 func TestFilter_AllowAll(t *testing.T) {
 	f := AcceptAllFilter()
 	tests := [...]struct {
-		addr   string
+		addr   net.IP
 		accept bool
 	}{
-		{"addr1", true},
-		{"addr2", true},
-		{"addr3", true},
+		{net.ParseIP("127.0.0.1"), true},
+		{net.ParseIP("127.0.0.2"), true},
+		{net.ParseIP("127.0.0.3"), true},
 	}
 
 	for i, test := range tests {
@@ -49,12 +52,12 @@ func TestFilter_AllowAll(t *testing.T) {
 func TestFilter_DenyAll(t *testing.T) {
 	f := DenyAllFilter()
 	tests := [...]struct {
-		addr   string
+		addr   net.IP
 		accept bool
 	}{
-		{"addr1", false},
-		{"addr2", false},
-		{"addr3", false},
+		{net.ParseIP("127.0.0.1"), false},
+		{net.ParseIP("127.0.0.2"), false},
+		{net.ParseIP("127.0.0.3"), false},
 	}
 
 	for i, test := range tests {

+ 9 - 14
host_source.go

@@ -100,7 +100,7 @@ type HostInfo struct {
 	// TODO(zariel): reduce locking maybe, not all values will change, but to ensure
 	// that we are thread safe use a mutex to access all fields.
 	mu         sync.RWMutex
-	peer       string
+	peer       net.IP
 	port       int
 	dataCenter string
 	rack       string
@@ -116,16 +116,16 @@ func (h *HostInfo) Equal(host *HostInfo) bool {
 	host.mu.RLock()
 	defer host.mu.RUnlock()
 
-	return h.peer == host.peer && h.hostId == host.hostId
+	return h.peer.Equal(host.peer)
 }
 
-func (h *HostInfo) Peer() string {
+func (h *HostInfo) Peer() net.IP {
 	h.mu.RLock()
 	defer h.mu.RUnlock()
 	return h.peer
 }
 
-func (h *HostInfo) setPeer(peer string) *HostInfo {
+func (h *HostInfo) setPeer(peer net.IP) *HostInfo {
 	h.mu.Lock()
 	defer h.mu.Unlock()
 	h.peer = peer
@@ -314,7 +314,11 @@ func (r *ringDescriber) GetHosts() (hosts []*HostInfo, partitioner string, err e
 			return nil, "", err
 		}
 	} else {
-		iter := r.session.control.query(legacyLocalQuery)
+		iter := r.session.control.withConn(func(c *Conn) *Iter {
+			localHost = c.host
+			return c.query(legacyLocalQuery)
+		})
+
 		if iter == nil {
 			return r.prevHosts, r.prevPartitioner, nil
 		}
@@ -324,15 +328,6 @@ func (r *ringDescriber) GetHosts() (hosts []*HostInfo, partitioner string, err e
 		if err = iter.Close(); err != nil {
 			return nil, "", err
 		}
-
-		addr, _, err := net.SplitHostPort(r.session.control.addr())
-		if err != nil {
-			// this should not happen, ever, as this is the address that was dialed by conn, here
-			// a panic makes sense, please report a bug if it occurs.
-			panic(err)
-		}
-
-		localHost.peer = addr
 	}
 
 	localHost.port = r.session.cfg.Port

+ 33 - 26
policies.go

@@ -7,6 +7,7 @@ package gocql
 import (
 	"fmt"
 	"log"
+	"net"
 	"sync"
 	"sync/atomic"
 
@@ -90,7 +91,7 @@ func (c *cowHostList) update(host *HostInfo) {
 	c.mu.Unlock()
 }
 
-func (c *cowHostList) remove(addr string) bool {
+func (c *cowHostList) remove(ip net.IP) bool {
 	c.mu.Lock()
 	l := c.get()
 	size := len(l)
@@ -102,7 +103,7 @@ func (c *cowHostList) remove(addr string) bool {
 	found := false
 	newL := make([]*HostInfo, 0, size)
 	for i := 0; i < len(l); i++ {
-		if l[i].Peer() != addr {
+		if !l[i].Peer().Equal(ip) {
 			newL = append(newL, l[i])
 		} else {
 			found = true
@@ -161,9 +162,9 @@ func (s *SimpleRetryPolicy) Attempt(q RetryableQuery) bool {
 
 type HostStateNotifier interface {
 	AddHost(host *HostInfo)
-	RemoveHost(addr string)
+	RemoveHost(host *HostInfo)
 	HostUp(host *HostInfo)
-	HostDown(addr string)
+	HostDown(host *HostInfo)
 }
 
 // HostSelectionPolicy is an interface for selecting
@@ -235,16 +236,16 @@ func (r *roundRobinHostPolicy) AddHost(host *HostInfo) {
 	r.hosts.add(host)
 }
 
-func (r *roundRobinHostPolicy) RemoveHost(addr string) {
-	r.hosts.remove(addr)
+func (r *roundRobinHostPolicy) RemoveHost(host *HostInfo) {
+	r.hosts.remove(host.Peer())
 }
 
 func (r *roundRobinHostPolicy) HostUp(host *HostInfo) {
 	r.AddHost(host)
 }
 
-func (r *roundRobinHostPolicy) HostDown(addr string) {
-	r.RemoveHost(addr)
+func (r *roundRobinHostPolicy) HostDown(host *HostInfo) {
+	r.RemoveHost(host)
 }
 
 // TokenAwareHostPolicy is a token aware host selection policy, where hosts are
@@ -278,9 +279,9 @@ func (t *tokenAwareHostPolicy) AddHost(host *HostInfo) {
 	t.resetTokenRing()
 }
 
-func (t *tokenAwareHostPolicy) RemoveHost(addr string) {
-	t.hosts.remove(addr)
-	t.fallback.RemoveHost(addr)
+func (t *tokenAwareHostPolicy) RemoveHost(host *HostInfo) {
+	t.hosts.remove(host.Peer())
+	t.fallback.RemoveHost(host)
 
 	t.resetTokenRing()
 }
@@ -289,8 +290,8 @@ func (t *tokenAwareHostPolicy) HostUp(host *HostInfo) {
 	t.AddHost(host)
 }
 
-func (t *tokenAwareHostPolicy) HostDown(addr string) {
-	t.RemoveHost(addr)
+func (t *tokenAwareHostPolicy) HostDown(host *HostInfo) {
+	t.RemoveHost(host)
 }
 
 func (t *tokenAwareHostPolicy) resetTokenRing() {
@@ -393,8 +394,9 @@ func (r *hostPoolHostPolicy) SetHosts(hosts []*HostInfo) {
 	hostMap := make(map[string]*HostInfo, len(hosts))
 
 	for i, host := range hosts {
-		peers[i] = host.Peer()
-		hostMap[host.Peer()] = host
+		ip := host.Peer().String()
+		peers[i] = ip
+		hostMap[ip] = host
 	}
 
 	r.mu.Lock()
@@ -404,15 +406,17 @@ func (r *hostPoolHostPolicy) SetHosts(hosts []*HostInfo) {
 }
 
 func (r *hostPoolHostPolicy) AddHost(host *HostInfo) {
+	ip := host.Peer().String()
+
 	r.mu.Lock()
 	defer r.mu.Unlock()
 
 	// If the host addr is present and isn't nil return
-	if h, ok := r.hostMap[host.Peer()]; ok && h != nil{
+	if h, ok := r.hostMap[ip]; ok && h != nil {
 		return
 	}
 	// otherwise, add the host to the map
-	r.hostMap[host.Peer()] = host
+	r.hostMap[ip] = host
 	// and construct a new peer list to give to the HostPool
 	hosts := make([]string, 0, len(r.hostMap))
 	for addr := range r.hostMap {
@@ -420,21 +424,22 @@ func (r *hostPoolHostPolicy) AddHost(host *HostInfo) {
 	}
 
 	r.hp.SetHosts(hosts)
-
 }
 
-func (r *hostPoolHostPolicy) RemoveHost(addr string) {
+func (r *hostPoolHostPolicy) RemoveHost(host *HostInfo) {
+	ip := host.Peer().String()
+
 	r.mu.Lock()
 	defer r.mu.Unlock()
 
-	if _, ok := r.hostMap[addr]; !ok {
+	if _, ok := r.hostMap[ip]; !ok {
 		return
 	}
 
-	delete(r.hostMap, addr)
+	delete(r.hostMap, ip)
 	hosts := make([]string, 0, len(r.hostMap))
-	for addr := range r.hostMap {
-		hosts = append(hosts, addr)
+	for _, host := range r.hostMap {
+		hosts = append(hosts, host.Peer().String())
 	}
 
 	r.hp.SetHosts(hosts)
@@ -444,8 +449,8 @@ func (r *hostPoolHostPolicy) HostUp(host *HostInfo) {
 	r.AddHost(host)
 }
 
-func (r *hostPoolHostPolicy) HostDown(addr string) {
-	r.RemoveHost(addr)
+func (r *hostPoolHostPolicy) HostDown(host *HostInfo) {
+	r.RemoveHost(host)
 }
 
 func (r *hostPoolHostPolicy) SetPartitioner(partitioner string) {
@@ -488,10 +493,12 @@ func (host selectedHostPoolHost) Info() *HostInfo {
 }
 
 func (host selectedHostPoolHost) Mark(err error) {
+	ip := host.info.Peer().String()
+
 	host.policy.mu.RLock()
 	defer host.policy.mu.RUnlock()
 
-	if _, ok := host.policy.hostMap[host.info.Peer()]; !ok {
+	if _, ok := host.policy.hostMap[ip]; !ok {
 		// host was removed between pick and mark
 		return
 	}

+ 24 - 23
policies_test.go

@@ -6,6 +6,7 @@ package gocql
 
 import (
 	"fmt"
+	"net"
 	"testing"
 
 	"github.com/hailocab/go-hostpool"
@@ -16,8 +17,8 @@ func TestRoundRobinHostPolicy(t *testing.T) {
 	policy := RoundRobinHostPolicy()
 
 	hosts := [...]*HostInfo{
-		{hostId: "0"},
-		{hostId: "1"},
+		{hostId: "0", peer: net.IPv4(0, 0, 0, 1)},
+		{hostId: "1", peer: net.IPv4(0, 0, 0, 2)},
 	}
 
 	for _, host := range hosts {
@@ -67,10 +68,10 @@ func TestTokenAwareHostPolicy(t *testing.T) {
 
 	// set the hosts
 	hosts := [...]*HostInfo{
-		{peer: "0", tokens: []string{"00"}},
-		{peer: "1", tokens: []string{"25"}},
-		{peer: "2", tokens: []string{"50"}},
-		{peer: "3", tokens: []string{"75"}},
+		{peer: net.IPv4(10, 0, 0, 1), tokens: []string{"00"}},
+		{peer: net.IPv4(10, 0, 0, 2), tokens: []string{"25"}},
+		{peer: net.IPv4(10, 0, 0, 3), tokens: []string{"50"}},
+		{peer: net.IPv4(10, 0, 0, 4), tokens: []string{"75"}},
 	}
 	for _, host := range hosts {
 		policy.AddHost(host)
@@ -78,12 +79,12 @@ func TestTokenAwareHostPolicy(t *testing.T) {
 
 	// the token ring is not setup without the partitioner, but the fallback
 	// should work
-	if actual := policy.Pick(nil)(); actual.Info().Peer() != "0" {
+	if actual := policy.Pick(nil)(); !actual.Info().Peer().Equal(hosts[0].peer) {
 		t.Errorf("Expected peer 0 but was %s", actual.Info().Peer())
 	}
 
 	query.RoutingKey([]byte("30"))
-	if actual := policy.Pick(query)(); actual.Info().Peer() != "1" {
+	if actual := policy.Pick(query)(); !actual.Info().Peer().Equal(hosts[1].peer) {
 		t.Errorf("Expected peer 1 but was %s", actual.Info().Peer())
 	}
 
@@ -92,17 +93,17 @@ func TestTokenAwareHostPolicy(t *testing.T) {
 	// now the token ring is configured
 	query.RoutingKey([]byte("20"))
 	iter = policy.Pick(query)
-	if actual := iter(); actual.Info().Peer() != "1" {
+	if actual := iter(); !actual.Info().Peer().Equal(hosts[1].peer) {
 		t.Errorf("Expected peer 1 but was %s", actual.Info().Peer())
 	}
 	// rest are round robin
-	if actual := iter(); actual.Info().Peer() != "2" {
+	if actual := iter(); !actual.Info().Peer().Equal(hosts[2].peer) {
 		t.Errorf("Expected peer 2 but was %s", actual.Info().Peer())
 	}
-	if actual := iter(); actual.Info().Peer() != "3" {
+	if actual := iter(); !actual.Info().Peer().Equal(hosts[3].peer) {
 		t.Errorf("Expected peer 3 but was %s", actual.Info().Peer())
 	}
-	if actual := iter(); actual.Info().Peer() != "0" {
+	if actual := iter(); !actual.Info().Peer().Equal(hosts[0].peer) {
 		t.Errorf("Expected peer 0 but was %s", actual.Info().Peer())
 	}
 }
@@ -112,8 +113,8 @@ func TestHostPoolHostPolicy(t *testing.T) {
 	policy := HostPoolHostPolicy(hostpool.New(nil))
 
 	hosts := []*HostInfo{
-		{hostId: "0", peer: "0"},
-		{hostId: "1", peer: "1"},
+		{hostId: "0", peer: net.IPv4(10, 0, 0, 0)},
+		{hostId: "1", peer: net.IPv4(10, 0, 0, 1)},
 	}
 
 	// Using set host to control the ordering of the hosts as calling "AddHost" iterates the map
@@ -177,10 +178,10 @@ func TestTokenAwareNilHostInfo(t *testing.T) {
 	policy := TokenAwareHostPolicy(RoundRobinHostPolicy())
 
 	hosts := [...]*HostInfo{
-		{peer: "0", tokens: []string{"00"}},
-		{peer: "1", tokens: []string{"25"}},
-		{peer: "2", tokens: []string{"50"}},
-		{peer: "3", tokens: []string{"75"}},
+		{peer: net.IPv4(10, 0, 0, 0), tokens: []string{"00"}},
+		{peer: net.IPv4(10, 0, 0, 1), tokens: []string{"25"}},
+		{peer: net.IPv4(10, 0, 0, 2), tokens: []string{"50"}},
+		{peer: net.IPv4(10, 0, 0, 3), tokens: []string{"75"}},
 	}
 	for _, host := range hosts {
 		policy.AddHost(host)
@@ -196,13 +197,13 @@ func TestTokenAwareNilHostInfo(t *testing.T) {
 		t.Fatal("got nil host")
 	} else if v := next.Info(); v == nil {
 		t.Fatal("got nil HostInfo")
-	} else if v.Peer() != "1" {
+	} else if !v.Peer().Equal(hosts[1].peer) {
 		t.Fatalf("expected peer 1 got %v", v.Peer())
 	}
 
 	// Empty the hosts to trigger the panic when using the fallback.
 	for _, host := range hosts {
-		policy.RemoveHost(host.Peer())
+		policy.RemoveHost(host)
 	}
 
 	next = iter()
@@ -217,7 +218,7 @@ func TestTokenAwareNilHostInfo(t *testing.T) {
 func TestCOWList_Add(t *testing.T) {
 	var cow cowHostList
 
-	toAdd := [...]string{"peer1", "peer2", "peer3"}
+	toAdd := [...]net.IP{net.IPv4(0, 0, 0, 0), net.IPv4(1, 0, 0, 0), net.IPv4(2, 0, 0, 0)}
 
 	for _, addr := range toAdd {
 		if !cow.add(&HostInfo{peer: addr}) {
@@ -232,11 +233,11 @@ func TestCOWList_Add(t *testing.T) {
 
 	set := make(map[string]bool)
 	for _, host := range hosts {
-		set[host.Peer()] = true
+		set[string(host.Peer())] = true
 	}
 
 	for _, addr := range toAdd {
-		if !set[addr] {
+		if !set[string(addr)] {
 			t.Errorf("addr was not in the host list: %q", addr)
 		}
 	}

+ 1 - 1
query_executor.go

@@ -28,7 +28,7 @@ func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
 			continue
 		}
 
-		pool, ok := q.pool.getPool(host.Peer())
+		pool, ok := q.pool.getPool(host)
 		if !ok {
 			continue
 		}

+ 28 - 11
ring.go

@@ -1,6 +1,7 @@
 package gocql
 
 import (
+	"net"
 	"sync"
 	"sync/atomic"
 )
@@ -34,9 +35,9 @@ func (r *ring) rrHost() *HostInfo {
 	return r.hostList[pos%len(r.hostList)]
 }
 
-func (r *ring) getHost(addr string) *HostInfo {
+func (r *ring) getHost(ip net.IP) *HostInfo {
 	r.mu.RLock()
-	host := r.hosts[addr]
+	host := r.hosts[ip.String()]
 	r.mu.RUnlock()
 	return host
 }
@@ -52,42 +53,58 @@ func (r *ring) allHosts() []*HostInfo {
 }
 
 func (r *ring) addHost(host *HostInfo) bool {
+	ip := host.Peer().String()
+
 	r.mu.Lock()
 	if r.hosts == nil {
 		r.hosts = make(map[string]*HostInfo)
 	}
 
-	addr := host.Peer()
-	_, ok := r.hosts[addr]
-	r.hosts[addr] = host
+	_, ok := r.hosts[ip]
+	if !ok {
+		r.hostList = append(r.hostList, host)
+	}
+
+	r.hosts[ip] = host
 	r.mu.Unlock()
 	return ok
 }
 
 func (r *ring) addHostIfMissing(host *HostInfo) (*HostInfo, bool) {
+	ip := host.Peer().String()
+
 	r.mu.Lock()
 	if r.hosts == nil {
 		r.hosts = make(map[string]*HostInfo)
 	}
 
-	addr := host.Peer()
-	existing, ok := r.hosts[addr]
+	existing, ok := r.hosts[ip]
 	if !ok {
-		r.hosts[addr] = host
+		r.hosts[ip] = host
 		existing = host
+		r.hostList = append(r.hostList, host)
 	}
 	r.mu.Unlock()
 	return existing, ok
 }
 
-func (r *ring) removeHost(addr string) bool {
+func (r *ring) removeHost(ip net.IP) bool {
 	r.mu.Lock()
 	if r.hosts == nil {
 		r.hosts = make(map[string]*HostInfo)
 	}
 
-	_, ok := r.hosts[addr]
-	delete(r.hosts, addr)
+	k := ip.String()
+	_, ok := r.hosts[k]
+	if ok {
+		for i, host := range r.hostList {
+			if host.Peer().Equal(ip) {
+				r.hostList = append(r.hostList[:i], r.hostList[i+1:]...)
+				break
+			}
+		}
+	}
+	delete(r.hosts, k)
 	r.mu.Unlock()
 	return ok
 }

+ 7 - 4
ring_test.go

@@ -1,11 +1,14 @@
 package gocql
 
-import "testing"
+import (
+	"net"
+	"testing"
+)
 
 func TestRing_AddHostIfMissing_Missing(t *testing.T) {
 	ring := &ring{}
 
-	host := &HostInfo{peer: "test1"}
+	host := &HostInfo{peer: net.IPv4(1, 1, 1, 1)}
 	h1, ok := ring.addHostIfMissing(host)
 	if ok {
 		t.Fatal("host was reported as already existing")
@@ -19,10 +22,10 @@ func TestRing_AddHostIfMissing_Missing(t *testing.T) {
 func TestRing_AddHostIfMissing_Existing(t *testing.T) {
 	ring := &ring{}
 
-	host := &HostInfo{peer: "test1"}
+	host := &HostInfo{peer: net.IPv4(1, 1, 1, 1)}
 	ring.addHostIfMissing(host)
 
-	h2 := &HostInfo{peer: "test1"}
+	h2 := &HostInfo{peer: net.IPv4(1, 1, 1, 1)}
 
 	h1, ok := ring.addHostIfMissing(h2)
 	if !ok {

+ 15 - 22
session.go

@@ -11,8 +11,6 @@ import (
 	"fmt"
 	"io"
 	"log"
-	"net"
-	"strconv"
 	"strings"
 	"sync"
 	"sync/atomic"
@@ -81,18 +79,12 @@ var queryPool = &sync.Pool{
 func addrsToHosts(addrs []string, defaultPort int) ([]*HostInfo, error) {
 	hosts := make([]*HostInfo, len(addrs))
 	for i, hostport := range addrs {
-		// TODO: remove duplication
-		addr, portStr, err := net.SplitHostPort(JoinHostPort(hostport, defaultPort))
+		host, err := hostInfo(hostport, defaultPort)
 		if err != nil {
-			return nil, fmt.Errorf("NewSession: unable to parse hostport of addr %q: %v", hostport, err)
-		}
-
-		port, err := strconv.Atoi(portStr)
-		if err != nil {
-			return nil, fmt.Errorf("NewSession: invalid port for hostport of addr %q: %v", hostport, err)
+			return nil, err
 		}
 
-		hosts[i] = &HostInfo{peer: addr, port: port, state: NodeUp}
+		hosts[i] = host
 	}
 
 	return hosts, nil
@@ -156,7 +148,6 @@ func NewSession(cfg ClusterConfig) (*Session, error) {
 		localHasRPCAddr, _ := checkSystemLocal(s.control)
 		s.hostSource.localHasRpcAddr = localHasRPCAddr
 
-		var err error
 		if cfg.DisableInitialHostLookup {
 			// TODO: we could look at system.local to get token and other metadata
 			// in this case.
@@ -165,22 +156,23 @@ func NewSession(cfg ClusterConfig) (*Session, error) {
 			hosts, _, err = s.hostSource.GetHosts()
 		}
 
-		if err != nil {
-			s.Close()
-			return nil, fmt.Errorf("gocql: unable to create session: %v", err)
-		}
 	} else {
 		// we dont get host info
 		hosts, err = addrsToHosts(cfg.Hosts, cfg.Port)
 	}
 
+	if err != nil {
+		s.Close()
+		return nil, fmt.Errorf("gocql: unable to create session: %v", err)
+	}
+
 	for _, host := range hosts {
 		if s.cfg.HostFilter == nil || s.cfg.HostFilter.Accept(host) {
 			if existingHost, ok := s.ring.addHostIfMissing(host); ok {
 				existingHost.update(host)
 			}
 
-			s.handleNodeUp(net.ParseIP(host.Peer()), host.Port(), false)
+			s.handleNodeUp(host.Peer(), host.Port(), false)
 		}
 	}
 
@@ -203,6 +195,7 @@ func NewSession(cfg ClusterConfig) (*Session, error) {
 	// connection is disable, we really have no choice, so we just make our
 	// best guess...
 	if !cfg.disableControlConn && cfg.DisableInitialHostLookup {
+		// TODO(zariel): we dont need to do this twice
 		newer, _ := checkSystemSchema(s.control)
 		s.useSystemSchema = newer
 	} else {
@@ -225,7 +218,7 @@ func (s *Session) reconnectDownedHosts(intv time.Duration) {
 			if gocqlDebug {
 				buf := bytes.NewBufferString("Session.ring:")
 				for _, h := range hosts {
-					buf.WriteString("[" + h.Peer() + ":" + h.State().String() + "]")
+					buf.WriteString("[" + h.Peer().String() + ":" + h.State().String() + "]")
 				}
 				log.Println(buf.String())
 			}
@@ -234,7 +227,7 @@ func (s *Session) reconnectDownedHosts(intv time.Duration) {
 				if h.IsUp() {
 					continue
 				}
-				s.handleNodeUp(net.ParseIP(h.Peer()), h.Port(), true)
+				s.handleNodeUp(h.Peer(), h.Port(), true)
 			}
 		case <-s.quit:
 			return
@@ -409,7 +402,7 @@ func (s *Session) getConn() *Conn {
 			continue
 		}
 
-		pool, ok := s.pool.getPool(host.Peer())
+		pool, ok := s.pool.getPool(host)
 		if !ok {
 			continue
 		}
@@ -628,8 +621,8 @@ func (s *Session) MapExecuteBatchCAS(batch *Batch, dest map[string]interface{})
 	return applied, iter, iter.err
 }
 
-func (s *Session) connect(addr string, errorHandler ConnErrorHandler, host *HostInfo) (*Conn, error) {
-	return Connect(host, addr, s.connCfg, errorHandler, s)
+func (s *Session) connect(host *HostInfo, errorHandler ConnErrorHandler) (*Conn, error) {
+	return Connect(host, s.connCfg, errorHandler, s)
 }
 
 // Query represents a CQL statement that can be executed.

+ 1 - 1
token.go

@@ -184,7 +184,7 @@ func (t *tokenRing) String() string {
 		buf.WriteString("]")
 		buf.WriteString(t.tokens[i].String())
 		buf.WriteString(":")
-		buf.WriteString(t.hosts[i].Peer())
+		buf.WriteString(t.hosts[i].Peer().String())
 	}
 	buf.WriteString("\n}")
 	return string(buf.Bytes())

+ 41 - 133
token_test.go

@@ -6,7 +6,9 @@ package gocql
 
 import (
 	"bytes"
+	"fmt"
 	"math/big"
+	"net"
 	"sort"
 	"strconv"
 	"testing"
@@ -226,27 +228,23 @@ func TestUnknownTokenRing(t *testing.T) {
 	}
 }
 
+func hostsForTests(n int) []*HostInfo {
+	hosts := make([]*HostInfo, n)
+	for i := 0; i < n; i++ {
+		host := &HostInfo{
+			peer:   net.IPv4(1, 1, 1, byte(n)),
+			tokens: []string{fmt.Sprintf("%d", n)},
+		}
+
+		hosts[i] = host
+	}
+	return hosts
+}
+
 // Test of the tokenRing with the Murmur3Partitioner
 func TestMurmur3TokenRing(t *testing.T) {
 	// Note, strings are parsed directly to int64, they are not murmur3 hashed
-	hosts := []*HostInfo{
-		{
-			peer:   "0",
-			tokens: []string{"0"},
-		},
-		{
-			peer:   "1",
-			tokens: []string{"25"},
-		},
-		{
-			peer:   "2",
-			tokens: []string{"50"},
-		},
-		{
-			peer:   "3",
-			tokens: []string{"75"},
-		},
-	}
+	hosts := hostsForTests(4)
 	ring, err := newTokenRing("Murmur3Partitioner", hosts)
 	if err != nil {
 		t.Fatalf("Failed to create token ring due to error: %v", err)
@@ -254,34 +252,20 @@ func TestMurmur3TokenRing(t *testing.T) {
 
 	p := murmur3Partitioner{}
 
-	var actual *HostInfo
-	actual = ring.GetHostForToken(p.ParseString("0"))
-	if actual.Peer() != "0" {
-		t.Errorf("Expected peer 0 for token \"0\", but was %s", actual.Peer())
+	for _, host := range hosts {
+		actual := ring.GetHostForToken(p.ParseString(host.tokens[0]))
+		if !actual.Peer().Equal(host.peer) {
+			t.Errorf("Expected peer %v for token %q, but was %v", host.peer, host.tokens[0], actual.peer)
+		}
 	}
 
-	actual = ring.GetHostForToken(p.ParseString("25"))
-	if actual.Peer() != "1" {
-		t.Errorf("Expected peer 1 for token \"25\", but was %s", actual.Peer())
-	}
-
-	actual = ring.GetHostForToken(p.ParseString("50"))
-	if actual.Peer() != "2" {
-		t.Errorf("Expected peer 2 for token \"50\", but was %s", actual.Peer())
-	}
-
-	actual = ring.GetHostForToken(p.ParseString("75"))
-	if actual.Peer() != "3" {
-		t.Errorf("Expected peer 3 for token \"01\", but was %s", actual.Peer())
-	}
-
-	actual = ring.GetHostForToken(p.ParseString("12"))
-	if actual.Peer() != "1" {
+	actual := ring.GetHostForToken(p.ParseString("12"))
+	if !actual.Peer().Equal(hosts[1].peer) {
 		t.Errorf("Expected peer 1 for token \"12\", but was %s", actual.Peer())
 	}
 
 	actual = ring.GetHostForToken(p.ParseString("24324545443332"))
-	if actual.Peer() != "0" {
+	if !actual.Peer().Equal(hosts[0].peer) {
 		t.Errorf("Expected peer 0 for token \"24324545443332\", but was %s", actual.Peer())
 	}
 }
@@ -290,32 +274,7 @@ func TestMurmur3TokenRing(t *testing.T) {
 func TestOrderedTokenRing(t *testing.T) {
 	// Tokens here more or less are similar layout to the int tokens above due
 	// to each numeric character translating to a consistently offset byte.
-	hosts := []*HostInfo{
-		{
-			peer: "0",
-			tokens: []string{
-				"00",
-			},
-		},
-		{
-			peer: "1",
-			tokens: []string{
-				"25",
-			},
-		},
-		{
-			peer: "2",
-			tokens: []string{
-				"50",
-			},
-		},
-		{
-			peer: "3",
-			tokens: []string{
-				"75",
-			},
-		},
-	}
+	hosts := hostsForTests(4)
 	ring, err := newTokenRing("OrderedPartitioner", hosts)
 	if err != nil {
 		t.Fatalf("Failed to create token ring due to error: %v", err)
@@ -324,33 +283,20 @@ func TestOrderedTokenRing(t *testing.T) {
 	p := orderedPartitioner{}
 
 	var actual *HostInfo
-	actual = ring.GetHostForToken(p.ParseString("0"))
-	if actual.Peer() != "0" {
-		t.Errorf("Expected peer 0 for token \"0\", but was %s", actual.Peer())
-	}
-
-	actual = ring.GetHostForToken(p.ParseString("25"))
-	if actual.Peer() != "1" {
-		t.Errorf("Expected peer 1 for token \"25\", but was %s", actual.Peer())
-	}
-
-	actual = ring.GetHostForToken(p.ParseString("50"))
-	if actual.Peer() != "2" {
-		t.Errorf("Expected peer 2 for token \"50\", but was %s", actual.Peer())
-	}
-
-	actual = ring.GetHostForToken(p.ParseString("75"))
-	if actual.Peer() != "3" {
-		t.Errorf("Expected peer 3 for token \"01\", but was %s", actual.Peer())
+	for _, host := range hosts {
+		actual = ring.GetHostForToken(p.ParseString(host.tokens[0]))
+		if !actual.Peer().Equal(host.peer) {
+			t.Errorf("Expected peer %v for token %q, but was %v", host.peer, host.tokens[0], actual.peer)
+		}
 	}
 
 	actual = ring.GetHostForToken(p.ParseString("12"))
-	if actual.Peer() != "1" {
+	if !actual.peer.Equal(hosts[1].peer) {
 		t.Errorf("Expected peer 1 for token \"12\", but was %s", actual.Peer())
 	}
 
 	actual = ring.GetHostForToken(p.ParseString("24324545443332"))
-	if actual.Peer() != "1" {
+	if !actual.peer.Equal(hosts[1].peer) {
 		t.Errorf("Expected peer 1 for token \"24324545443332\", but was %s", actual.Peer())
 	}
 }
@@ -358,32 +304,7 @@ func TestOrderedTokenRing(t *testing.T) {
 // Test of the tokenRing with the RandomPartitioner
 func TestRandomTokenRing(t *testing.T) {
 	// String tokens are parsed into big.Int in base 10
-	hosts := []*HostInfo{
-		{
-			peer: "0",
-			tokens: []string{
-				"00",
-			},
-		},
-		{
-			peer: "1",
-			tokens: []string{
-				"25",
-			},
-		},
-		{
-			peer: "2",
-			tokens: []string{
-				"50",
-			},
-		},
-		{
-			peer: "3",
-			tokens: []string{
-				"75",
-			},
-		},
-	}
+	hosts := hostsForTests(4)
 	ring, err := newTokenRing("RandomPartitioner", hosts)
 	if err != nil {
 		t.Fatalf("Failed to create token ring due to error: %v", err)
@@ -392,33 +313,20 @@ func TestRandomTokenRing(t *testing.T) {
 	p := randomPartitioner{}
 
 	var actual *HostInfo
-	actual = ring.GetHostForToken(p.ParseString("0"))
-	if actual.Peer() != "0" {
-		t.Errorf("Expected peer 0 for token \"0\", but was %s", actual.Peer())
-	}
-
-	actual = ring.GetHostForToken(p.ParseString("25"))
-	if actual.Peer() != "1" {
-		t.Errorf("Expected peer 1 for token \"25\", but was %s", actual.Peer())
-	}
-
-	actual = ring.GetHostForToken(p.ParseString("50"))
-	if actual.Peer() != "2" {
-		t.Errorf("Expected peer 2 for token \"50\", but was %s", actual.Peer())
-	}
-
-	actual = ring.GetHostForToken(p.ParseString("75"))
-	if actual.Peer() != "3" {
-		t.Errorf("Expected peer 3 for token \"01\", but was %s", actual.Peer())
+	for _, host := range hosts {
+		actual = ring.GetHostForToken(p.ParseString(host.tokens[0]))
+		if !actual.Peer().Equal(host.peer) {
+			t.Errorf("Expected peer %v for token %q, but was %v", host.peer, host.tokens[0], actual.peer)
+		}
 	}
 
 	actual = ring.GetHostForToken(p.ParseString("12"))
-	if actual.Peer() != "1" {
+	if !actual.peer.Equal(hosts[1].peer) {
 		t.Errorf("Expected peer 1 for token \"12\", but was %s", actual.Peer())
 	}
 
 	actual = ring.GetHostForToken(p.ParseString("24324545443332"))
-	if actual.Peer() != "0" {
-		t.Errorf("Expected peer 0 for token \"24324545443332\", but was %s", actual.Peer())
+	if !actual.peer.Equal(hosts[0].peer) {
+		t.Errorf("Expected peer 1 for token \"24324545443332\", but was %s", actual.Peer())
 	}
 }