
Merge pull request #551 from Zariel/event-frames

Add support for node events for up and down hosts
Chris Bannister 10 years ago
parent
commit
3b1e36f335
23 files changed, 2017 additions and 555 deletions
  1. cassandra_test.go  +2 -125
  2. ccm_test/ccm.go  +167 -0
  3. ccm_test/ccm_test.go  +48 -0
  4. cluster.go  +24 -14
  5. common_test.go  +134 -0
  6. conn.go  +19 -5
  7. conn_test.go  +22 -91
  8. connectionpool.go  +127 -54
  9. control.go  +172 -22
  10. events.go  +230 -0
  11. events_ccm_test.go  +168 -0
  12. events_test.go  +32 -0
  13. filters.go  +43 -0
  14. frame.go  +68 -1
  15. host_source.go  +242 -51
  16. integration.sh  +6 -1
  17. policies.go  +200 -24
  18. policies_test.go  +74 -47
  19. ring.go  +75 -0
  20. session.go  +82 -29
  21. session_test.go  +3 -11
  22. token.go  +4 -5
  23. token_test.go  +75 -75

+ 2 - 125
cassandra_test.go

@@ -4,10 +4,7 @@ package gocql
 
 import (
 	"bytes"
-	"flag"
-	"fmt"
 	"io"
-	"log"
 	"math"
 	"math/big"
 	"net"
@@ -22,118 +19,6 @@ import (
 	"gopkg.in/inf.v0"
 )
 
-var (
-	flagCluster      = flag.String("cluster", "127.0.0.1", "a comma-separated list of host:port tuples")
-	flagProto        = flag.Int("proto", 2, "protcol version")
-	flagCQL          = flag.String("cql", "3.0.0", "CQL version")
-	flagRF           = flag.Int("rf", 1, "replication factor for test keyspace")
-	clusterSize      = flag.Int("clusterSize", 1, "the expected size of the cluster")
-	flagRetry        = flag.Int("retries", 5, "number of times to retry queries")
-	flagAutoWait     = flag.Duration("autowait", 1000*time.Millisecond, "time to wait for autodiscovery to fill the hosts poll")
-	flagRunSslTest   = flag.Bool("runssl", false, "Set to true to run ssl test")
-	flagRunAuthTest  = flag.Bool("runauth", false, "Set to true to run authentication test")
-	flagCompressTest = flag.String("compressor", "", "compressor to use")
-	flagTimeout      = flag.Duration("gocql.timeout", 5*time.Second, "sets the connection `timeout` for all operations")
-	clusterHosts     []string
-)
-
-func init() {
-	flag.Parse()
-	clusterHosts = strings.Split(*flagCluster, ",")
-	log.SetFlags(log.Lshortfile | log.LstdFlags)
-}
-
-func addSslOptions(cluster *ClusterConfig) *ClusterConfig {
-	if *flagRunSslTest {
-		cluster.SslOpts = &SslOptions{
-			CertPath:               "testdata/pki/gocql.crt",
-			KeyPath:                "testdata/pki/gocql.key",
-			CaPath:                 "testdata/pki/ca.crt",
-			EnableHostVerification: false,
-		}
-	}
-	return cluster
-}
-
-var initOnce sync.Once
-
-func createTable(s *Session, table string) error {
-	if err := s.control.query(table).Close(); err != nil {
-		return err
-	}
-
-	return nil
-}
-
-func createCluster() *ClusterConfig {
-	cluster := NewCluster(clusterHosts...)
-	cluster.ProtoVersion = *flagProto
-	cluster.CQLVersion = *flagCQL
-	cluster.Timeout = *flagTimeout
-	cluster.Consistency = Quorum
-	cluster.MaxWaitSchemaAgreement = 2 * time.Minute // travis might be slow
-	if *flagRetry > 0 {
-		cluster.RetryPolicy = &SimpleRetryPolicy{NumRetries: *flagRetry}
-	}
-
-	switch *flagCompressTest {
-	case "snappy":
-		cluster.Compressor = &SnappyCompressor{}
-	case "":
-	default:
-		panic("invalid compressor: " + *flagCompressTest)
-	}
-
-	cluster = addSslOptions(cluster)
-	return cluster
-}
-
-func createKeyspace(tb testing.TB, cluster *ClusterConfig, keyspace string) {
-	c := *cluster
-	c.Keyspace = "system"
-	c.Timeout = 20 * time.Second
-	session, err := c.CreateSession()
-	if err != nil {
-		tb.Fatal("createSession:", err)
-	}
-
-	err = session.control.query(`DROP KEYSPACE IF EXISTS ` + keyspace).Close()
-	if err != nil {
-		tb.Fatal(err)
-	}
-
-	err = session.control.query(fmt.Sprintf(`CREATE KEYSPACE %s
-	WITH replication = {
-		'class' : 'SimpleStrategy',
-		'replication_factor' : %d
-	}`, keyspace, *flagRF)).Close()
-
-	if err != nil {
-		tb.Fatal(err)
-	}
-}
-
-func createSessionFromCluster(cluster *ClusterConfig, tb testing.TB) *Session {
-	// Drop and re-create the keyspace once. Different tests should use their own
-	// individual tables, but can assume that the table does not exist before.
-	initOnce.Do(func() {
-		createKeyspace(tb, cluster, "gocql_test")
-	})
-
-	cluster.Keyspace = "gocql_test"
-	session, err := cluster.CreateSession()
-	if err != nil {
-		tb.Fatal("createSession:", err)
-	}
-
-	return session
-}
-
-func createSession(tb testing.TB) *Session {
-	cluster := createCluster()
-	return createSessionFromCluster(cluster, tb)
-}
-
 // TestAuthentication verifies that gocql will work with a host configured to only accept authenticated connections
 func TestAuthentication(t *testing.T) {
 
@@ -165,7 +50,6 @@ func TestAuthentication(t *testing.T) {
 func TestRingDiscovery(t *testing.T) {
 	cluster := createCluster()
 	cluster.Hosts = clusterHosts[:1]
-	cluster.DiscoverHosts = true
 
 	session := createSessionFromCluster(cluster, t)
 	defer session.Close()
@@ -649,10 +533,6 @@ func TestCreateSessionTimeout(t *testing.T) {
 		session.Close()
 		t.Fatal("expected ErrNoConnectionsStarted, but no error was returned.")
 	}
-
-	if err != ErrNoConnectionsStarted {
-		t.Fatalf("expected ErrNoConnectionsStarted, but received %v", err)
-	}
 }
 
 type FullName struct {
@@ -2001,13 +1881,12 @@ func TestRoutingKey(t *testing.T) {
 func TestTokenAwareConnPool(t *testing.T) {
 	cluster := createCluster()
 	cluster.PoolConfig.HostSelectionPolicy = TokenAwareHostPolicy(RoundRobinHostPolicy())
-	cluster.DiscoverHosts = true
 
 	session := createSessionFromCluster(cluster, t)
 	defer session.Close()
 
-	if session.pool.Size() != cluster.NumConns*len(cluster.Hosts) {
-		t.Errorf("Expected pool size %d but was %d", cluster.NumConns*len(cluster.Hosts), session.pool.Size())
+	if expected := cluster.NumConns * len(session.ring.allHosts()); session.pool.Size() != expected {
+		t.Errorf("Expected pool size %d but was %d", expected, session.pool.Size())
 	}
 
 	if err := createTable(session, "CREATE TABLE gocql_test.test_token_aware (id int, data text, PRIMARY KEY (id))"); err != nil {
@@ -2379,9 +2258,7 @@ func TestDiscoverViaProxy(t *testing.T) {
 	proxyAddr := proxy.Addr().String()
 
 	cluster := createCluster()
-	cluster.DiscoverHosts = true
 	cluster.NumConns = 1
-	cluster.Discovery.Sleep = 100 * time.Millisecond
 	// initial host is the proxy address
 	cluster.Hosts = []string{proxyAddr}
 

+ 167 - 0
ccm_test/ccm.go

@@ -0,0 +1,167 @@
+// +build ccm
+
+package ccm
+
+import (
+	"bufio"
+	"bytes"
+	"errors"
+	"fmt"
+	"os/exec"
+	"strings"
+)
+
+func execCmd(args ...string) (*bytes.Buffer, error) {
+	cmd := exec.Command("ccm", args...)
+	stdout := &bytes.Buffer{}
+	cmd.Stdout = stdout
+	cmd.Stderr = &bytes.Buffer{}
+	if err := cmd.Run(); err != nil {
+		return nil, errors.New(cmd.Stderr.(*bytes.Buffer).String())
+	}
+
+	return stdout, nil
+}
+
+func AllUp() error {
+	status, err := Status()
+	if err != nil {
+		return err
+	}
+
+	for _, host := range status {
+		if !host.State.IsUp() {
+			if err := NodeUp(host.Name); err != nil {
+				return err
+			}
+		}
+	}
+
+	return nil
+}
+
+func NodeUp(node string) error {
+	_, err := execCmd(node, "start", "--wait-for-binary-proto", "--wait-other-notice")
+	return err
+}
+
+func NodeDown(node string) error {
+	_, err := execCmd(node, "stop")
+	return err
+}
+
+type Host struct {
+	State NodeState
+	Addr  string
+	Name  string
+}
+
+type NodeState int
+
+func (n NodeState) String() string {
+	if n == NodeStateUp {
+		return "UP"
+	} else if n == NodeStateDown {
+		return "DOWN"
+	} else {
+		return fmt.Sprintf("UNKNOWN_STATE_%d", n)
+	}
+}
+
+func (n NodeState) IsUp() bool {
+	return n == NodeStateUp
+}
+
+const (
+	NodeStateUp NodeState = iota
+	NodeStateDown
+)
+
+func Status() (map[string]Host, error) {
+	// TODO: parse into struct o maniuplate
+	out, err := execCmd("status", "-v")
+	if err != nil {
+		return nil, err
+	}
+
+	const (
+		stateCluster = iota
+		stateCommas
+		stateNode
+		stateOption
+	)
+
+	nodes := make(map[string]Host)
+	// didnt really want to write a full state machine parser
+	state := stateCluster
+	sc := bufio.NewScanner(out)
+
+	var host Host
+
+	for sc.Scan() {
+		switch state {
+		case stateCluster:
+			text := sc.Text()
+			if !strings.HasPrefix(text, "Cluster:") {
+				return nil, fmt.Errorf("expected 'Cluster:' got %q", text)
+			}
+			state = stateCommas
+		case stateCommas:
+			text := sc.Text()
+			if !strings.HasPrefix(text, "-") {
+				return nil, fmt.Errorf("expected commas got %q", text)
+			}
+			state = stateNode
+		case stateNode:
+			// assume nodes start with node
+			text := sc.Text()
+			if !strings.HasPrefix(text, "node") {
+				return nil, fmt.Errorf("expected 'node' got %q", text)
+			}
+			line := strings.Split(text, ":")
+			host.Name = line[0]
+
+			nodeState := strings.TrimSpace(line[1])
+			switch nodeState {
+			case "UP":
+				host.State = NodeStateUp
+			case "DOWN":
+				host.State = NodeStateDown
+			default:
+				return nil, fmt.Errorf("unknown node state from ccm: %q", nodeState)
+			}
+
+			state = stateOption
+		case stateOption:
+			text := sc.Text()
+			if text == "" {
+				state = stateNode
+				nodes[host.Name] = host
+				host = Host{}
+				continue
+			}
+
+			line := strings.Split(strings.TrimSpace(text), "=")
+			k, v := line[0], line[1]
+			if k == "binary" {
+				// could check errors
+				// ('127.0.0.1', 9042)
+				v = v[2:] // (''
+				if i := strings.IndexByte(v, '\''); i < 0 {
+					return nil, fmt.Errorf("invalid binary v=%q", v)
+				} else {
+					host.Addr = v[:i]
+					// dont need port
+				}
+			}
+		default:
+			return nil, fmt.Errorf("unexpected state: %q", state)
+		}
+	}
+
+	if err := sc.Err(); err != nil {
+		return nil, fmt.Errorf("unable to parse ccm status: %v", err)
+	}
+
+	return nodes, nil
+}

+ 48 - 0
ccm_test/ccm_test.go

@@ -0,0 +1,48 @@
+// +build ccm
+
+package ccm
+
+import (
+	"testing"
+)
+
+func TestCCM(t *testing.T) {
+	if err := AllUp(); err != nil {
+		t.Fatal(err)
+	}
+
+	status, err := Status()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if host, ok := status["node1"]; !ok {
+		t.Fatal("node1 not in status list")
+	} else if !host.State.IsUp() {
+		t.Fatal("node1 is not up")
+	}
+
+	NodeDown("node1")
+	status, err = Status()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if host, ok := status["node1"]; !ok {
+		t.Fatal("node1 not in status list")
+	} else if host.State.IsUp() {
+		t.Fatal("node1 is not down")
+	}
+
+	NodeUp("node1")
+	status, err = Status()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if host, ok := status["node1"]; !ok {
+		t.Fatal("node1 not in status list")
+	} else if !host.State.IsUp() {
+		t.Fatal("node1 is not up")
+	}
+}

+ 24 - 14
cluster.go

@@ -40,16 +40,6 @@ func initStmtsLRU(max int) {
 	}
 }
 
-// To enable periodic node discovery enable DiscoverHosts in ClusterConfig
-type DiscoveryConfig struct {
-	// If not empty will filter all discoverred hosts to a single Data Centre (default: "")
-	DcFilter string
-	// If not empty will filter all discoverred hosts to a single Rack (default: "")
-	RackFilter string
-	// The interval to check for new hosts (default: 30s)
-	Sleep time.Duration
-}
-
 // PoolConfig configures the connection pool used by the driver, it defaults to
 // using a round robbin host selection policy and a round robbin connection selection
 // policy for each host.
@@ -63,7 +53,7 @@ type PoolConfig struct {
 	ConnSelectionPolicy func() ConnSelectionPolicy
 }
 
-func (p PoolConfig) buildPool(session *Session) (*policyConnPool, error) {
+func (p PoolConfig) buildPool(session *Session) *policyConnPool {
 	hostSelection := p.HostSelectionPolicy
 	if hostSelection == nil {
 		hostSelection = RoundRobinHostPolicy()
@@ -77,6 +67,27 @@ func (p PoolConfig) buildPool(session *Session) (*policyConnPool, error) {
 	return newPolicyConnPool(session, hostSelection, connSelection)
 }
 
+type DiscoveryConfig struct {
+	// If not empty will filter all discoverred hosts to a single Data Centre (default: "")
+	DcFilter string
+	// If not empty will filter all discoverred hosts to a single Rack (default: "")
+	RackFilter string
+	// ignored
+	Sleep time.Duration
+}
+
+func (d DiscoveryConfig) matchFilter(host *HostInfo) bool {
+	if d.DcFilter != "" && d.DcFilter != host.DataCenter() {
+		return false
+	}
+
+	if d.RackFilter != "" && d.RackFilter != host.Rack() {
+		return false
+	}
+
+	return true
+}
+
 // ClusterConfig is a struct to configure the default cluster implementation
 // of gocoql. It has a varity of attributes that can be used to modify the
 // behavior to fit the most common use cases. Applications that requre a
@@ -94,18 +105,18 @@ type ClusterConfig struct {
 	Authenticator     Authenticator     // authenticator (default: nil)
 	RetryPolicy       RetryPolicy       // Default retry policy to use for queries (default: 0)
 	SocketKeepalive   time.Duration     // The keepalive period to use, enabled if > 0 (default: 0)
-	DiscoverHosts     bool              // If set, gocql will attempt to automatically discover other members of the Cassandra cluster (default: false)
 	MaxPreparedStmts  int               // Sets the maximum cache size for prepared statements globally for gocql (default: 1000)
 	MaxRoutingKeyInfo int               // Sets the maximum cache size for query info about statements for each session (default: 1000)
 	PageSize          int               // Default page size to use for created sessions (default: 5000)
 	SerialConsistency SerialConsistency // Sets the consistency for the serial part of queries, values can be either SERIAL or LOCAL_SERIAL (default: unset)
-	Discovery         DiscoveryConfig
 	SslOpts           *SslOptions
 	DefaultTimestamp  bool // Sends a client side timestamp for all requests which overrides the timestamp at which it arrives at the server. (default: true, only enabled for protocol 3 and above)
 	// PoolConfig configures the underlying connection pool, allowing the
 	// configuration of host selection and connection selection policies.
 	PoolConfig PoolConfig
 
+	Discovery DiscoveryConfig
+
 	// The maximum amount of time to wait for schema agreement in a cluster after
 	// receiving a schema change frame. (deault: 60s)
 	MaxWaitSchemaAgreement time.Duration
@@ -124,7 +135,6 @@ func NewCluster(hosts ...string) *ClusterConfig {
 		Port:                   9042,
 		NumConns:               2,
 		Consistency:            Quorum,
-		DiscoverHosts:          false,
 		MaxPreparedStmts:       defaultMaxPreparedStmts,
 		MaxRoutingKeyInfo:      1000,
 		PageSize:               5000,
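
With DiscoverHosts gone, peer discovery is always on and driven by server events; DiscoveryConfig now only carries the data-centre and rack filters (Sleep is ignored). A minimal configuration sketch, assuming the upstream import path; the contact points and data-centre name are placeholders:

package main

import (
	"log"

	"github.com/gocql/gocql"
)

func main() {
	// Contact points are placeholders; peers discovered via events are run
	// through cluster.Discovery before being added to the connection pool.
	cluster := gocql.NewCluster("10.0.0.1", "10.0.0.2")
	cluster.Discovery = gocql.DiscoveryConfig{
		DcFilter:   "dc1", // keep only hosts reporting data centre "dc1"
		RackFilter: "",    // empty string disables rack filtering
	}

	session, err := cluster.CreateSession()
	if err != nil {
		log.Fatal(err)
	}
	defer session.Close()
}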

+ 134 - 0
common_test.go

@@ -0,0 +1,134 @@
+package gocql
+
+import (
+	"flag"
+	"fmt"
+	"log"
+	"strings"
+	"sync"
+	"testing"
+	"time"
+)
+
+var (
+	flagCluster      = flag.String("cluster", "127.0.0.1", "a comma-separated list of host:port tuples")
+	flagProto        = flag.Int("proto", 2, "protcol version")
+	flagCQL          = flag.String("cql", "3.0.0", "CQL version")
+	flagRF           = flag.Int("rf", 1, "replication factor for test keyspace")
+	clusterSize      = flag.Int("clusterSize", 1, "the expected size of the cluster")
+	flagRetry        = flag.Int("retries", 5, "number of times to retry queries")
+	flagAutoWait     = flag.Duration("autowait", 1000*time.Millisecond, "time to wait for autodiscovery to fill the hosts poll")
+	flagRunSslTest   = flag.Bool("runssl", false, "Set to true to run ssl test")
+	flagRunAuthTest  = flag.Bool("runauth", false, "Set to true to run authentication test")
+	flagCompressTest = flag.String("compressor", "", "compressor to use")
+	flagTimeout      = flag.Duration("gocql.timeout", 5*time.Second, "sets the connection `timeout` for all operations")
+	clusterHosts     []string
+)
+
+func init() {
+	flag.Parse()
+	clusterHosts = strings.Split(*flagCluster, ",")
+	log.SetFlags(log.Lshortfile | log.LstdFlags)
+}
+
+func addSslOptions(cluster *ClusterConfig) *ClusterConfig {
+	if *flagRunSslTest {
+		cluster.SslOpts = &SslOptions{
+			CertPath:               "testdata/pki/gocql.crt",
+			KeyPath:                "testdata/pki/gocql.key",
+			CaPath:                 "testdata/pki/ca.crt",
+			EnableHostVerification: false,
+		}
+	}
+	return cluster
+}
+
+var initOnce sync.Once
+
+func createTable(s *Session, table string) error {
+	if err := s.control.query(table).Close(); err != nil {
+		return err
+	}
+
+	if err := s.control.awaitSchemaAgreement(); err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func createCluster() *ClusterConfig {
+	cluster := NewCluster(clusterHosts...)
+	cluster.ProtoVersion = *flagProto
+	cluster.CQLVersion = *flagCQL
+	cluster.Timeout = *flagTimeout
+	cluster.Consistency = Quorum
+	cluster.MaxWaitSchemaAgreement = 2 * time.Minute // travis might be slow
+	if *flagRetry > 0 {
+		cluster.RetryPolicy = &SimpleRetryPolicy{NumRetries: *flagRetry}
+	}
+
+	switch *flagCompressTest {
+	case "snappy":
+		cluster.Compressor = &SnappyCompressor{}
+	case "":
+	default:
+		panic("invalid compressor: " + *flagCompressTest)
+	}
+
+	cluster = addSslOptions(cluster)
+	return cluster
+}
+
+func createKeyspace(tb testing.TB, cluster *ClusterConfig, keyspace string) {
+	c := *cluster
+	c.Keyspace = "system"
+	c.Timeout = 20 * time.Second
+	session, err := c.CreateSession()
+	if err != nil {
+		tb.Fatal("createSession:", err)
+	}
+	defer session.Close()
+	defer log.Println("closing keyspace session")
+
+	err = session.control.query(`DROP KEYSPACE IF EXISTS ` + keyspace).Close()
+	if err != nil {
+		tb.Fatal(err)
+	}
+
+	err = session.control.query(fmt.Sprintf(`CREATE KEYSPACE %s
+	WITH replication = {
+		'class' : 'SimpleStrategy',
+		'replication_factor' : %d
+	}`, keyspace, *flagRF)).Close()
+
+	if err != nil {
+		tb.Fatal(err)
+	}
+
+	// lets just be sure
+	if err := session.control.awaitSchemaAgreement(); err != nil {
+		tb.Fatal(err)
+	}
+}
+
+func createSessionFromCluster(cluster *ClusterConfig, tb testing.TB) *Session {
+	// Drop and re-create the keyspace once. Different tests should use their own
+	// individual tables, but can assume that the table does not exist before.
+	initOnce.Do(func() {
+		createKeyspace(tb, cluster, "gocql_test")
+	})
+
+	cluster.Keyspace = "gocql_test"
+	session, err := cluster.CreateSession()
+	if err != nil {
+		tb.Fatal("createSession:", err)
+	}
+
+	return session
+}
+
+func createSession(tb testing.TB) *Session {
+	cluster := createCluster()
+	return createSessionFromCluster(cluster, tb)
+}

+ 19 - 5
conn.go

@@ -141,7 +141,6 @@ type Conn struct {
 }
 
 // Connect establishes a connection to a Cassandra node.
-// You must also call the Serve method before you can execute any queries.
 func Connect(addr string, cfg *ConnConfig, errorHandler ConnErrorHandler, session *Session) (*Conn, error) {
 	var (
 		err  error
@@ -397,7 +396,12 @@ func (c *Conn) recv() error {
 		return fmt.Errorf("gocql: frame header stream is beyond call exepected bounds: %d", head.stream)
 	} else if head.stream == -1 {
 		// TODO: handle cassandra event frames, we shouldnt get any currently
-		return c.discardFrame(head)
+		framer := newFramer(c, c, c.compressor, c.version)
+		if err := framer.readFrame(&head); err != nil {
+			return err
+		}
+		go c.session.handleEvent(framer)
+		return nil
 	} else if head.stream <= 0 {
 		// reserved stream that we dont use, probably due to a protocol error
 		// or a bug in Cassandra, this should be an error, parse it and return.
@@ -739,7 +743,10 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 		return &Iter{framer: framer}
 	case *resultSchemaChangeFrame, *schemaChangeKeyspace, *schemaChangeTable, *schemaChangeFunction:
 		iter := &Iter{framer: framer}
-		c.awaitSchemaAgreement()
+		if err := c.awaitSchemaAgreement(); err != nil {
+			// TODO: should have this behind a flag
+			log.Println(err)
+		}
 		// dont return an error from this, might be a good idea to give a warning
 		// though. The impact of this returning an error would be that the cluster
 		// is not consistent with regards to its schema.
@@ -939,11 +946,13 @@ func (c *Conn) awaitSchemaAgreement() (err error) {
 		localSchemas = "SELECT schema_version FROM system.local WHERE key='local'"
 	)
 
+	var versions map[string]struct{}
+
 	endDeadline := time.Now().Add(c.session.cfg.MaxWaitSchemaAgreement)
 	for time.Now().Before(endDeadline) {
 		iter := c.query(peerSchemas)
 
-		versions := make(map[string]struct{})
+		versions = make(map[string]struct{})
 
 		var schemaVersion string
 		for iter.Scan(&schemaVersion) {
@@ -977,8 +986,13 @@ func (c *Conn) awaitSchemaAgreement() (err error) {
 		return
 	}
 
+	schemas := make([]string, 0, len(versions))
+	for schema := range versions {
+		schemas = append(schemas, schema)
+	}
+
 	// not exported
-	return errors.New("gocql: cluster schema versions not consistent")
+	return fmt.Errorf("gocql: cluster schema versions not consistent: %+v", schemas)
 }
 
 type inflightPrepare struct {

+ 22 - 91
conn_test.go

@@ -50,12 +50,18 @@ func TestJoinHostPort(t *testing.T) {
 	}
 }
 
+func testCluster(addr string, proto protoVersion) *ClusterConfig {
+	cluster := NewCluster(addr)
+	cluster.ProtoVersion = int(proto)
+	cluster.disableControlConn = true
+	return cluster
+}
+
 func TestSimple(t *testing.T) {
 	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 
-	cluster := NewCluster(srv.Address)
-	cluster.ProtoVersion = int(defaultProto)
+	cluster := testCluster(srv.Address, defaultProto)
 	db, err := cluster.CreateSession()
 	if err != nil {
 		t.Fatalf("0x%x: NewCluster: %v", defaultProto, err)
@@ -94,18 +100,19 @@ func TestSSLSimpleNoClientCert(t *testing.T) {
 	}
 }
 
-func createTestSslCluster(hosts string, proto uint8, useClientCert bool) *ClusterConfig {
-	cluster := NewCluster(hosts)
+func createTestSslCluster(addr string, proto protoVersion, useClientCert bool) *ClusterConfig {
+	cluster := testCluster(addr, proto)
 	sslOpts := &SslOptions{
 		CaPath:                 "testdata/pki/ca.crt",
 		EnableHostVerification: false,
 	}
+
 	if useClientCert {
 		sslOpts.CertPath = "testdata/pki/gocql.crt"
 		sslOpts.KeyPath = "testdata/pki/gocql.key"
 	}
+
 	cluster.SslOpts = sslOpts
-	cluster.ProtoVersion = int(proto)
 	return cluster
 }
 
@@ -115,28 +122,23 @@ func TestClosed(t *testing.T) {
 	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 
-	cluster := NewCluster(srv.Address)
-	cluster.ProtoVersion = int(defaultProto)
-
-	session, err := cluster.CreateSession()
-	defer session.Close()
+	session, err := newTestSession(srv.Address, defaultProto)
 	if err != nil {
 		t.Fatalf("0x%x: NewCluster: %v", defaultProto, err)
 	}
 
+	session.Close()
+
 	if err := session.Query("void").Exec(); err != ErrSessionClosed {
 		t.Fatalf("0x%x: expected %#v, got %#v", defaultProto, ErrSessionClosed, err)
 	}
 }
 
-func newTestSession(addr string, proto uint8) (*Session, error) {
-	cluster := NewCluster(addr)
-	cluster.ProtoVersion = int(proto)
-	return cluster.CreateSession()
+func newTestSession(addr string, proto protoVersion) (*Session, error) {
+	return testCluster(addr, proto).CreateSession()
 }
 
 func TestTimeout(t *testing.T) {
-
 	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 
@@ -197,7 +199,7 @@ func TestStreams_Protocol1(t *testing.T) {
 
 	// TODO: these are more like session tests and should instead operate
 	// on a single Conn
-	cluster := NewCluster(srv.Address)
+	cluster := testCluster(srv.Address, protoVersion1)
 	cluster.NumConns = 1
 	cluster.ProtoVersion = 1
 
@@ -229,7 +231,7 @@ func TestStreams_Protocol3(t *testing.T) {
 
 	// TODO: these are more like session tests and should instead operate
 	// on a single Conn
-	cluster := NewCluster(srv.Address)
+	cluster := testCluster(srv.Address, protoVersion3)
 	cluster.NumConns = 1
 	cluster.ProtoVersion = 3
 
@@ -275,76 +277,6 @@ func BenchmarkProtocolV3(b *testing.B) {
 	}
 }
 
-func TestRoundRobinConnPoolRoundRobin(t *testing.T) {
-	// create 5 test servers
-	servers := make([]*TestServer, 5)
-	addrs := make([]string, len(servers))
-	for n := 0; n < len(servers); n++ {
-		servers[n] = NewTestServer(t, defaultProto)
-		addrs[n] = servers[n].Address
-		defer servers[n].Stop()
-	}
-
-	// create a new cluster using the policy-based round robin conn pool
-	cluster := NewCluster(addrs...)
-	cluster.PoolConfig.HostSelectionPolicy = RoundRobinHostPolicy()
-	cluster.PoolConfig.ConnSelectionPolicy = RoundRobinConnPolicy()
-	cluster.disableControlConn = true
-
-	db, err := cluster.CreateSession()
-	if err != nil {
-		t.Fatalf("failed to create a new session: %v", err)
-	}
-
-	// Sleep to allow the pool to fill
-	time.Sleep(100 * time.Millisecond)
-
-	// run concurrent queries against the pool, server usage should
-	// be even
-	var wg sync.WaitGroup
-	wg.Add(5)
-	for n := 0; n < 5; n++ {
-		go func() {
-			defer wg.Done()
-
-			for j := 0; j < 5; j++ {
-				if err := db.Query("void").Exec(); err != nil {
-					t.Errorf("Query failed with error: %v", err)
-					return
-				}
-			}
-		}()
-	}
-	wg.Wait()
-
-	db.Close()
-
-	// wait for the pool to drain
-	time.Sleep(100 * time.Millisecond)
-	size := db.pool.Size()
-	if size != 0 {
-		t.Errorf("connection pool did not drain, still contains %d connections", size)
-	}
-
-	// verify that server usage is even
-	diff := 0
-	for n := 1; n < len(servers); n++ {
-		d := 0
-		if servers[n].nreq > servers[n-1].nreq {
-			d = int(servers[n].nreq - servers[n-1].nreq)
-		} else {
-			d = int(servers[n-1].nreq - servers[n].nreq)
-		}
-		if d > diff {
-			diff = d
-		}
-	}
-
-	if diff > 0 {
-		t.Fatalf("expected 0 difference in usage but was %d", diff)
-	}
-}
-
 // This tests that the policy connection pool handles SSL correctly
 func TestPolicyConnPoolSSL(t *testing.T) {
 	srv := NewSSLTestServer(t, defaultProto)
@@ -356,7 +288,6 @@ func TestPolicyConnPoolSSL(t *testing.T) {
 
 	db, err := cluster.CreateSession()
 	if err != nil {
-		db.Close()
 		t.Fatalf("failed to create new session: %v", err)
 	}
 
@@ -377,7 +308,7 @@ func TestQueryTimeout(t *testing.T) {
 	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 
-	cluster := NewCluster(srv.Address)
+	cluster := testCluster(srv.Address, defaultProto)
 	// Set the timeout arbitrarily low so that the query hits the timeout in a
 	// timely manner.
 	cluster.Timeout = 1 * time.Millisecond
@@ -418,7 +349,7 @@ func TestQueryTimeoutReuseStream(t *testing.T) {
 	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 
-	cluster := NewCluster(srv.Address)
+	cluster := testCluster(srv.Address, defaultProto)
 	// Set the timeout arbitrarily low so that the query hits the timeout in a
 	// timely manner.
 	cluster.Timeout = 1 * time.Millisecond
@@ -442,7 +373,7 @@ func TestQueryTimeoutClose(t *testing.T) {
 	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 
-	cluster := NewCluster(srv.Address)
+	cluster := testCluster(srv.Address, defaultProto)
 	// Set the timeout arbitrarily low so that the query hits the timeout in a
 	// timely manner.
 	cluster.Timeout = 1000 * time.Millisecond

+ 127 - 54
connectionpool.go

@@ -19,7 +19,7 @@ import (
 
 // interface to implement to receive the host information
 type SetHosts interface {
-	SetHosts(hosts []HostInfo)
+	SetHosts(hosts []*HostInfo)
 }
 
 // interface to implement to receive the partitioner value
@@ -62,25 +62,25 @@ type policyConnPool struct {
 
 	port     int
 	numConns int
-	connCfg  *ConnConfig
 	keyspace string
 
 	mu            sync.RWMutex
 	hostPolicy    HostSelectionPolicy
 	connPolicy    func() ConnSelectionPolicy
 	hostConnPools map[string]*hostConnPool
+
+	endpoints []string
 }
 
-func newPolicyConnPool(session *Session, hostPolicy HostSelectionPolicy,
-	connPolicy func() ConnSelectionPolicy) (*policyConnPool, error) {
+func connConfig(session *Session) (*ConnConfig, error) {
+	cfg := session.cfg
 
 	var (
 		err       error
 		tlsConfig *tls.Config
 	)
 
-	cfg := session.cfg
-
+	// TODO(zariel): move tls config setup into session init.
 	if cfg.SslOpts != nil {
 		tlsConfig, err = setupTLSConfig(cfg.SslOpts)
 		if err != nil {
@@ -88,37 +88,38 @@ func newPolicyConnPool(session *Session, hostPolicy HostSelectionPolicy,
 		}
 	}
 
+	return &ConnConfig{
+		ProtoVersion:  cfg.ProtoVersion,
+		CQLVersion:    cfg.CQLVersion,
+		Timeout:       cfg.Timeout,
+		Compressor:    cfg.Compressor,
+		Authenticator: cfg.Authenticator,
+		Keepalive:     cfg.SocketKeepalive,
+		tlsConfig:     tlsConfig,
+	}, nil
+}
+
+func newPolicyConnPool(session *Session, hostPolicy HostSelectionPolicy,
+	connPolicy func() ConnSelectionPolicy) *policyConnPool {
+
 	// create the pool
 	pool := &policyConnPool{
-		session:  session,
-		port:     cfg.Port,
-		numConns: cfg.NumConns,
-		connCfg: &ConnConfig{
-			ProtoVersion:  cfg.ProtoVersion,
-			CQLVersion:    cfg.CQLVersion,
-			Timeout:       cfg.Timeout,
-			Compressor:    cfg.Compressor,
-			Authenticator: cfg.Authenticator,
-			Keepalive:     cfg.SocketKeepalive,
-			tlsConfig:     tlsConfig,
-		},
-		keyspace:      cfg.Keyspace,
+		session:       session,
+		port:          session.cfg.Port,
+		numConns:      session.cfg.NumConns,
+		keyspace:      session.cfg.Keyspace,
 		hostPolicy:    hostPolicy,
 		connPolicy:    connPolicy,
 		hostConnPools: map[string]*hostConnPool{},
 	}
 
-	hosts := make([]HostInfo, len(cfg.Hosts))
-	for i, hostAddr := range cfg.Hosts {
-		hosts[i].Peer = hostAddr
-	}
-
-	pool.SetHosts(hosts)
+	pool.endpoints = make([]string, len(session.cfg.Hosts))
+	copy(pool.endpoints, session.cfg.Hosts)
 
-	return pool, nil
+	return pool
 }
 
-func (p *policyConnPool) SetHosts(hosts []HostInfo) {
+func (p *policyConnPool) SetHosts(hosts []*HostInfo) {
 	p.mu.Lock()
 	defer p.mu.Unlock()
 
@@ -129,24 +130,22 @@ func (p *policyConnPool) SetHosts(hosts []HostInfo) {
 
 	// TODO connect to hosts in parallel, but wait for pools to be
 	// created before returning
-
-	for i := range hosts {
-		pool, exists := p.hostConnPools[hosts[i].Peer]
-		if !exists {
+	for _, host := range hosts {
+		pool, exists := p.hostConnPools[host.Peer()]
+		if !exists && host.IsUp() {
 			// create a connection pool for the host
 			pool = newHostConnPool(
 				p.session,
-				hosts[i].Peer,
+				host,
 				p.port,
 				p.numConns,
-				p.connCfg,
 				p.keyspace,
 				p.connPolicy(),
 			)
-			p.hostConnPools[hosts[i].Peer] = pool
+			p.hostConnPools[host.Peer()] = pool
 		} else {
 			// still have this host, so don't remove it
-			delete(toRemove, hosts[i].Peer)
+			delete(toRemove, host.Peer())
 		}
 	}
 
@@ -158,7 +157,6 @@ func (p *policyConnPool) SetHosts(hosts []HostInfo) {
 
 	// update the policy
 	p.hostPolicy.SetHosts(hosts)
-
 }
 
 func (p *policyConnPool) SetPartitioner(partitioner string) {
@@ -194,7 +192,7 @@ func (p *policyConnPool) Pick(qry *Query) (SelectedHost, *Conn) {
 			panic(fmt.Sprintf("policy %T returned no host info: %+v", p.hostPolicy, host))
 		}
 
-		pool, ok := p.hostConnPools[host.Info().Peer]
+		pool, ok := p.hostConnPools[host.Info().Peer()]
 		if !ok {
 			continue
 		}
@@ -209,7 +207,7 @@ func (p *policyConnPool) Close() {
 	defer p.mu.Unlock()
 
 	// remove the hosts from the policy
-	p.hostPolicy.SetHosts([]HostInfo{})
+	p.hostPolicy.SetHosts(nil)
 
 	// close the pools
 	for addr, pool := range p.hostConnPools {
@@ -218,15 +216,69 @@ func (p *policyConnPool) Close() {
 	}
 }
 
+func (p *policyConnPool) addHost(host *HostInfo) {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+
+	pool, ok := p.hostConnPools[host.Peer()]
+	if ok {
+		go pool.fill()
+		return
+	}
+
+	pool = newHostConnPool(
+		p.session,
+		host,
+		host.Port(),
+		p.numConns,
+		p.keyspace,
+		p.connPolicy(),
+	)
+
+	p.hostConnPools[host.Peer()] = pool
+
+	// update policy
+	// TODO: policy should not have conns, it should have hosts and return a host
+	// iter which the pool will use to serve conns
+	p.hostPolicy.AddHost(host)
+}
+
+func (p *policyConnPool) removeHost(addr string) {
+	p.hostPolicy.RemoveHost(addr)
+	p.mu.Lock()
+
+	pool, ok := p.hostConnPools[addr]
+	if !ok {
+		p.mu.Unlock()
+		return
+	}
+
+	delete(p.hostConnPools, addr)
+	p.mu.Unlock()
+
+	pool.Close()
+}
+
+func (p *policyConnPool) hostUp(host *HostInfo) {
+	// TODO(zariel): have a set of up hosts and down hosts, we can internally
+	// detect down hosts, then try to reconnect to them.
+	p.addHost(host)
+}
+
+func (p *policyConnPool) hostDown(addr string) {
+	// TODO(zariel): mark host as down so we can try to connect to it later, for
+	// now just treat it has removed.
+	p.removeHost(addr)
+}
+
 // hostConnPool is a connection pool for a single host.
 // Connection selection is based on a provided ConnSelectionPolicy
 type hostConnPool struct {
 	session  *Session
-	host     string
+	host     *HostInfo
 	port     int
 	addr     string
 	size     int
-	connCfg  *ConnConfig
 	keyspace string
 	policy   ConnSelectionPolicy
 	// protection for conns, closed, filling
@@ -236,16 +288,22 @@ type hostConnPool struct {
 	filling bool
 }
 
-func newHostConnPool(session *Session, host string, port, size int, connCfg *ConnConfig,
+func (h *hostConnPool) String() string {
+	h.mu.RLock()
+	defer h.mu.RUnlock()
+	return fmt.Sprintf("[filling=%v closed=%v conns=%v size=%v host=%v]",
+		h.filling, h.closed, len(h.conns), h.size, h.host)
+}
+
+func newHostConnPool(session *Session, host *HostInfo, port, size int,
 	keyspace string, policy ConnSelectionPolicy) *hostConnPool {
 
 	pool := &hostConnPool{
 		session:  session,
 		host:     host,
 		port:     port,
-		addr:     JoinHostPort(host, port),
+		addr:     JoinHostPort(host.Peer(), port),
 		size:     size,
-		connCfg:  connCfg,
 		keyspace: keyspace,
 		policy:   policy,
 		conns:    make([]*Conn, 0, size),
@@ -267,13 +325,16 @@ func (pool *hostConnPool) Pick(qry *Query) *Conn {
 		return nil
 	}
 
-	empty := len(pool.conns) == 0
+	size := len(pool.conns)
 	pool.mu.RUnlock()
 
-	if empty {
-		// try to fill the empty pool
+	if size < pool.size {
+		// try to fill the pool
 		go pool.fill()
-		return nil
+
+		if size == 0 {
+			return nil
+		}
 	}
 
 	return pool.policy.Pick(qry)
@@ -350,7 +411,11 @@ func (pool *hostConnPool) fill() {
 
 		if err != nil {
 			// probably unreachable host
-			go pool.fillingStopped()
+			pool.fillingStopped()
+
+			// this is calle with the connetion pool mutex held, this call will
+			// then recursivly try to lock it again. FIXME
+			go pool.session.handleNodeDown(net.ParseIP(pool.host.Peer()), pool.port)
 			return
 		}
 
@@ -366,7 +431,7 @@ func (pool *hostConnPool) fill() {
 			fillCount--
 		}
 
-		go pool.fillingStopped()
+		pool.fillingStopped()
 		return
 	}
 
@@ -410,7 +475,7 @@ func (pool *hostConnPool) fillingStopped() {
 // create a new connection to the host and add it to the pool
 func (pool *hostConnPool) connect() error {
 	// try to connect
-	conn, err := Connect(pool.addr, pool.connCfg, pool, pool.session)
+	conn, err := pool.session.connect(pool.addr, pool)
 	if err != nil {
 		return err
 	}
@@ -433,7 +498,11 @@ func (pool *hostConnPool) connect() error {
 	}
 
 	pool.conns = append(pool.conns, conn)
-	pool.policy.SetConns(pool.conns)
+
+	conns := make([]*Conn, len(pool.conns))
+	copy(conns, pool.conns)
+	pool.policy.SetConns(conns)
+
 	return nil
 }
 
@@ -444,6 +513,8 @@ func (pool *hostConnPool) HandleError(conn *Conn, err error, closed bool) {
 		return
 	}
 
+	// TODO: track the number of errors per host and detect when a host is dead,
+	// then also have something which can detect when a host comes back.
 	pool.mu.Lock()
 	defer pool.mu.Unlock()
 
@@ -459,7 +530,9 @@ func (pool *hostConnPool) HandleError(conn *Conn, err error, closed bool) {
 			pool.conns[i], pool.conns = pool.conns[len(pool.conns)-1], pool.conns[:len(pool.conns)-1]
 
 			// update the policy
-			pool.policy.SetConns(pool.conns)
+			conns := make([]*Conn, len(pool.conns))
+			copy(conns, pool.conns)
+			pool.policy.SetConns(conns)
 
 			// lost a connection, so fill the pool
 			go pool.fill()
@@ -475,10 +548,10 @@ func (pool *hostConnPool) drain() {
 
 	// empty the pool
 	conns := pool.conns
-	pool.conns = pool.conns[:0]
+	pool.conns = pool.conns[:0:0]
 
 	// update the policy
-	pool.policy.SetConns(pool.conns)
+	pool.policy.SetConns(nil)
 
 	// close the connections
 	for _, conn := range conns {
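
Note that the SetHosts interface at the top of this file now receives []*HostInfo instead of []HostInfo, so anything implementing it (for example a custom host selection policy; the built-in ones live in policies.go, not shown here) must be updated. A minimal sketch of a conforming receiver using only the exported HostInfo accessors added in host_source.go; hostLogger is hypothetical:

package main

import (
	"log"

	"github.com/gocql/gocql"
)

// hostLogger is a hypothetical SetHosts receiver that just logs the ring.
type hostLogger struct{}

func (hostLogger) SetHosts(hosts []*gocql.HostInfo) {
	for _, h := range hosts {
		log.Printf("host=%s dc=%s rack=%s state=%v",
			h.Peer(), h.DataCenter(), h.Rack(), h.State())
	}
}

func main() {
	// In practice the driver calls SetHosts with the current ring; pass an
	// empty slice here just to exercise the signature.
	hostLogger{}.SetHosts(nil)
}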

+ 172 - 22
control.go

@@ -3,22 +3,27 @@ package gocql
 import (
 	"errors"
 	"fmt"
+	"log"
+	"math/rand"
+	"net"
+	"strconv"
+	"sync"
 	"sync/atomic"
 	"time"
 )
 
-// Ensure that the atomic variable is aligned to a 64bit boundary 
+// Ensure that the atomic variable is aligned to a 64bit boundary
 // so that atomic operations can be applied on 32bit architectures.
 type controlConn struct {
-	connecting uint64
+	connecting int64
 
 	session *Session
-
-	conn       atomic.Value
+	conn    atomic.Value
 
 	retry RetryPolicy
 
-	quit chan struct{}
+	closeWg sync.WaitGroup
+	quit    chan struct{}
 }
 
 func createControlConn(session *Session) *controlConn {
@@ -29,12 +34,13 @@ func createControlConn(session *Session) *controlConn {
 	}
 
 	control.conn.Store((*Conn)(nil))
-	go control.heartBeat()
 
 	return control
 }
 
 func (c *controlConn) heartBeat() {
+	defer c.closeWg.Done()
+
 	for {
 		select {
 		case <-c.quit:
@@ -60,12 +66,84 @@ func (c *controlConn) heartBeat() {
 		c.reconnect(true)
 		// time.Sleep(5 * time.Second)
 		continue
+	}
+}
+
+func (c *controlConn) connect(endpoints []string) error {
+	// intial connection attmept, try to connect to each endpoint to get an initial
+	// list of nodes.
+
+	// shuffle endpoints so not all drivers will connect to the same initial
+	// node.
+	r := rand.New(rand.NewSource(time.Now().UnixNano()))
+	perm := r.Perm(len(endpoints))
+	shuffled := make([]string, len(endpoints))
+
+	for i, endpoint := range endpoints {
+		shuffled[perm[i]] = endpoint
+	}
+
+	// store that we are not connected so that reconnect wont happen if we error
+	atomic.StoreInt64(&c.connecting, -1)
+
+	var (
+		conn *Conn
+		err  error
+	)
+
+	for _, addr := range shuffled {
+		conn, err = c.session.connect(JoinHostPort(addr, c.session.cfg.Port), c)
+		if err != nil {
+			log.Printf("gocql: unable to control conn dial %v: %v\n", addr, err)
+			continue
+		}
+
+		if err = c.registerEvents(conn); err != nil {
+			conn.Close()
+			continue
+		}
+
+		// we should fetch the initial ring here and update initial host data. So that
+		// when we return from here we have a ring topology ready to go.
+		break
+	}
 
+	if conn == nil {
+		// this is fatal, not going to connect a session
+		return err
 	}
+
+	c.conn.Store(conn)
+	atomic.StoreInt64(&c.connecting, 0)
+
+	c.closeWg.Add(1)
+	go c.heartBeat()
+
+	return nil
+}
+
+func (c *controlConn) registerEvents(conn *Conn) error {
+	framer, err := conn.exec(&writeRegisterFrame{
+		events: []string{"TOPOLOGY_CHANGE", "STATUS_CHANGE", "STATUS_CHANGE"},
+	}, nil)
+	if err != nil {
+		return err
+	}
+
+	frame, err := framer.parseFrame()
+	if err != nil {
+		return err
+	} else if _, ok := frame.(*readyFrame); !ok {
+		return fmt.Errorf("unexpected frame in response to register: got %T: %v\n", frame, frame)
+	}
+
+	return nil
 }
 
 func (c *controlConn) reconnect(refreshring bool) {
-	if !atomic.CompareAndSwapUint64(&c.connecting, 0, 1) {
+	// TODO: simplify this function, use session.ring to get hosts instead of the
+	// connection pool
+	if !atomic.CompareAndSwapInt64(&c.connecting, 0, 1) {
 		return
 	}
 
@@ -75,38 +153,65 @@ func (c *controlConn) reconnect(refreshring bool) {
 		if success {
 			go func() {
 				time.Sleep(500 * time.Millisecond)
-				atomic.StoreUint64(&c.connecting, 0)
+				atomic.StoreInt64(&c.connecting, 0)
 			}()
 		} else {
-			atomic.StoreUint64(&c.connecting, 0)
+			atomic.StoreInt64(&c.connecting, 0)
 		}
 	}()
 
+	addr := c.addr()
 	oldConn := c.conn.Load().(*Conn)
+	if oldConn != nil {
+		oldConn.Close()
+	}
+
+	var newConn *Conn
+	if addr != "" {
+		// try to connect to the old host
+		conn, err := c.session.connect(addr, c)
+		if err != nil {
+			// host is dead
+			// TODO: this is replicated in a few places
+			ip, portStr, _ := net.SplitHostPort(addr)
+			port, _ := strconv.Atoi(portStr)
+			c.session.handleNodeDown(net.ParseIP(ip), port)
+		} else {
+			newConn = conn
+		}
+	}
 
 	// TODO: should have our own roundrobbin for hosts so that we can try each
 	// in succession and guantee that we get a different host each time.
-	host, conn := c.session.pool.Pick(nil)
-	if conn == nil {
-		return
+	if newConn == nil {
+		_, conn := c.session.pool.Pick(nil)
+		if conn == nil {
+			return
+		}
+
+		if conn == nil {
+			return
+		}
+
+		var err error
+		newConn, err = c.session.connect(conn.addr, c)
+		if err != nil {
+			// TODO: add log handler for things like this
+			return
+		}
 	}
 
-	newConn, err := Connect(conn.addr, conn.cfg, c, c.session)
-	if err != nil {
-		host.Mark(err)
-		// TODO: add log handler for things like this
+	if err := c.registerEvents(newConn); err != nil {
+		// TODO: handle this case better
+		newConn.Close()
+		log.Printf("gocql: control unable to register events: %v\n", err)
 		return
 	}
 
-	host.Mark(nil)
 	c.conn.Store(newConn)
 	success = true
 
-	if oldConn != nil {
-		oldConn.Close()
-	}
-
-	if refreshring && c.session.cfg.DiscoverHosts {
+	if refreshring {
 		c.session.hostSource.refreshRing()
 	}
 }
@@ -179,6 +284,46 @@ func (c *controlConn) query(statement string, values ...interface{}) (iter *Iter
 	return
 }
 
+func (c *controlConn) fetchHostInfo(addr net.IP, port int) (*HostInfo, error) {
+	// TODO(zariel): we should probably move this into host_source or atleast
+	// share code with it.
+	hostname, _, err := net.SplitHostPort(c.addr())
+	if err != nil {
+		return nil, fmt.Errorf("unable to fetch host info, invalid conn addr: %q: %v", c.addr(), err)
+	}
+
+	isLocal := hostname == addr.String()
+
+	var fn func(*HostInfo) error
+
+	if isLocal {
+		fn = func(host *HostInfo) error {
+			// TODO(zariel): should we fetch rpc_address from here?
+			iter := c.query("SELECT data_center, rack, host_id, tokens, release_version FROM system.local WHERE key='local'")
+			iter.Scan(&host.dataCenter, &host.rack, &host.hostId, &host.tokens, &host.version)
+			return iter.Close()
+		}
+	} else {
+		fn = func(host *HostInfo) error {
+			// TODO(zariel): should we fetch rpc_address from here?
+			iter := c.query("SELECT data_center, rack, host_id, tokens, release_version FROM system.peers WHERE peer=?", addr)
+			iter.Scan(&host.dataCenter, &host.rack, &host.hostId, &host.tokens, &host.version)
+			return iter.Close()
+		}
+	}
+
+	host := &HostInfo{
+		port: port,
+	}
+
+	if err := fn(host); err != nil {
+		return nil, err
+	}
+	host.peer = addr.String()
+
+	return host, nil
+}
+
 func (c *controlConn) awaitSchemaAgreement() error {
 	return c.withConn(func(conn *Conn) *Iter {
 		return &Iter{err: conn.awaitSchemaAgreement()}
@@ -196,6 +341,11 @@ func (c *controlConn) addr() string {
 func (c *controlConn) close() {
 	// TODO: handle more gracefully
 	close(c.quit)
+	c.closeWg.Wait()
+	conn := c.conn.Load().(*Conn)
+	if conn != nil {
+		conn.Close()
+	}
 }
 
 var errNoControl = errors.New("gocql: no control connection available")

+ 230 - 0
events.go

@@ -0,0 +1,230 @@
+package gocql
+
+import (
+	"log"
+	"net"
+	"sync"
+	"time"
+)
+
+type eventDeouncer struct {
+	name   string
+	timer  *time.Timer
+	mu     sync.Mutex
+	events []frame
+
+	callback func([]frame)
+	quit     chan struct{}
+}
+
+func newEventDeouncer(name string, eventHandler func([]frame)) *eventDeouncer {
+	e := &eventDeouncer{
+		name:     name,
+		quit:     make(chan struct{}),
+		timer:    time.NewTimer(eventDebounceTime),
+		callback: eventHandler,
+	}
+	e.timer.Stop()
+	go e.flusher()
+
+	return e
+}
+
+func (e *eventDeouncer) stop() {
+	e.quit <- struct{}{} // sync with flusher
+	close(e.quit)
+}
+
+func (e *eventDeouncer) flusher() {
+	for {
+		select {
+		case <-e.timer.C:
+			e.mu.Lock()
+			e.flush()
+			e.mu.Unlock()
+		case <-e.quit:
+			return
+		}
+	}
+}
+
+const (
+	eventBufferSize   = 1000
+	eventDebounceTime = 1 * time.Second
+)
+
+// flush must be called with mu locked
+func (e *eventDeouncer) flush() {
+	if len(e.events) == 0 {
+		return
+	}
+
+	// if the flush interval is faster than the callback then we will end up calling
+	// the callback multiple times, probably a bad idea. In this case we could drop
+	// frames?
+	go e.callback(e.events)
+	e.events = make([]frame, 0, eventBufferSize)
+}
+
+func (e *eventDeouncer) debounce(frame frame) {
+	e.mu.Lock()
+	e.timer.Reset(eventDebounceTime)
+
+	// TODO: probably need a warning to track if this threshold is too low
+	if len(e.events) < eventBufferSize {
+		e.events = append(e.events, frame)
+	} else {
+		log.Printf("%s: buffer full, dropping event frame: %s", e.name, frame)
+	}
+
+	e.mu.Unlock()
+}
+
+func (s *Session) handleNodeEvent(frames []frame) {
+	type nodeEvent struct {
+		change string
+		host   net.IP
+		port   int
+	}
+
+	events := make(map[string]*nodeEvent)
+
+	for _, frame := range frames {
+		// TODO: can we be sure the order of events in the buffer is correct?
+		switch f := frame.(type) {
+		case *topologyChangeEventFrame:
+			event, ok := events[f.host.String()]
+			if !ok {
+				event = &nodeEvent{change: f.change, host: f.host, port: f.port}
+				events[f.host.String()] = event
+			}
+			event.change = f.change
+
+		case *statusChangeEventFrame:
+			event, ok := events[f.host.String()]
+			if !ok {
+				event = &nodeEvent{change: f.change, host: f.host, port: f.port}
+				events[f.host.String()] = event
+			}
+			event.change = f.change
+		}
+	}
+
+	for _, f := range events {
+		switch f.change {
+		case "NEW_NODE":
+			s.handleNewNode(f.host, f.port, true)
+		case "REMOVED_NODE":
+			s.handleRemovedNode(f.host, f.port)
+		case "MOVED_NODE":
+		// java-driver handles this, not mentioned in the spec
+		// TODO(zariel): refresh token map
+		case "UP":
+			s.handleNodeUp(f.host, f.port, true)
+		case "DOWN":
+			s.handleNodeDown(f.host, f.port)
+		}
+	}
+}
+
+func (s *Session) handleEvent(framer *framer) {
+	// TODO(zariel): need to debounce events frames, and possible also events
+	defer framerPool.Put(framer)
+
+	frame, err := framer.parseFrame()
+	if err != nil {
+		// TODO: logger
+		log.Printf("gocql: unable to parse event frame: %v\n", err)
+		return
+	}
+
+	// TODO: handle medatadata events
+	switch f := frame.(type) {
+	case *schemaChangeKeyspace:
+	case *schemaChangeFunction:
+	case *schemaChangeTable:
+	case *topologyChangeEventFrame, *statusChangeEventFrame:
+		s.nodeEvents.debounce(frame)
+	default:
+		log.Printf("gocql: invalid event frame (%T): %v\n", f, f)
+	}
+
+}
+
+func (s *Session) handleNewNode(host net.IP, port int, waitForBinary bool) {
+	// TODO(zariel): need to be able to filter discovered nodes
+
+	var hostInfo *HostInfo
+	if s.control != nil {
+		var err error
+		hostInfo, err = s.control.fetchHostInfo(host, port)
+		if err != nil {
+			log.Printf("gocql: events: unable to fetch host info for %v: %v\n", host, err)
+			return
+		}
+
+	} else {
+		hostInfo = &HostInfo{peer: host.String(), port: port, state: NodeUp}
+	}
+
+	// TODO: remove this when the host selection policy is more sophisticated
+	if !s.cfg.Discovery.matchFilter(hostInfo) {
+		return
+	}
+
+	if t := hostInfo.Version().nodeUpDelay(); t > 0 && waitForBinary {
+		time.Sleep(t)
+	}
+
+	// should this handle token moving?
+	if existing, ok := s.ring.addHostIfMissing(hostInfo); !ok {
+		existing.update(hostInfo)
+		hostInfo = existing
+	}
+
+	s.pool.addHost(hostInfo)
+
+	if s.control != nil {
+		s.hostSource.refreshRing()
+	}
+}
+
+func (s *Session) handleRemovedNode(ip net.IP, port int) {
+	// we remove all nodes but only add ones which pass the filter
+	addr := ip.String()
+	s.pool.removeHost(addr)
+	s.ring.removeHost(addr)
+
+	s.hostSource.refreshRing()
+}
+
+func (s *Session) handleNodeUp(ip net.IP, port int, waitForBinary bool) {
+	addr := ip.String()
+	host := s.ring.getHost(addr)
+	if host != nil {
+		// TODO: remove this when the host selection policy is more sophisticated
+		if !s.cfg.Discovery.matchFilter(host) {
+			return
+		}
+
+		if t := host.Version().nodeUpDelay(); t > 0 && waitForBinary {
+			time.Sleep(t)
+		}
+
+		host.setState(NodeUp)
+		s.pool.hostUp(host)
+		return
+	}
+
+	s.handleNewNode(ip, port, waitForBinary)
+}
+
+func (s *Session) handleNodeDown(ip net.IP, port int) {
+	addr := ip.String()
+	host := s.ring.getHost(addr)
+	if host != nil {
+		host.setState(NodeDown)
+	}
+
+	s.pool.hostDown(addr)
+}
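
handleNodeEvent above coalesces each debounced batch so that only the most recent change per host is acted on. A standalone sketch of that coalescing, using hypothetical frames and no Session, written as if inside package gocql since the event frame types are unexported:

package gocql

import "net"

// coalesceExample mirrors the per-host map in handleNodeEvent: a DOWN followed
// by an UP for the same host within one debounced batch leaves only the UP.
func coalesceExample() map[string]string {
	frames := []frame{
		&statusChangeEventFrame{change: "DOWN", host: net.IPv4(10, 0, 0, 1), port: 9042},
		&statusChangeEventFrame{change: "UP", host: net.IPv4(10, 0, 0, 1), port: 9042},
	}

	last := make(map[string]string)
	for _, f := range frames {
		sc := f.(*statusChangeEventFrame)
		last[sc.host.String()] = sc.change
	}

	// last["10.0.0.1"] == "UP"
	return last
}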

+ 168 - 0
events_ccm_test.go

@@ -0,0 +1,168 @@
+// +build ccm
+
+package gocql
+
+import (
+	"github.com/gocql/gocql/ccm_test"
+	"log"
+	"testing"
+	"time"
+)
+
+func TestEventDiscovery(t *testing.T) {
+	if err := ccm.AllUp(); err != nil {
+		t.Fatal(err)
+	}
+
+	session := createSession(t)
+	defer session.Close()
+
+	status, err := ccm.Status()
+	if err != nil {
+		t.Fatal(err)
+	}
+	t.Logf("status=%+v\n", status)
+
+	session.pool.mu.RLock()
+	poolHosts := session.pool.hostConnPools // TODO: replace with session.ring
+	t.Logf("poolhosts=%+v\n", poolHosts)
+	// check we discovered all the nodes in the ring
+	for _, host := range status {
+		if _, ok := poolHosts[host.Addr]; !ok {
+			t.Errorf("did not discover %q", host.Addr)
+		}
+	}
+	session.pool.mu.RUnlock()
+	if t.Failed() {
+		t.FailNow()
+	}
+}
+
+func TestEventNodeDownControl(t *testing.T) {
+	const targetNode = "node1"
+	t.Log("marking " + targetNode + " as down")
+	if err := ccm.AllUp(); err != nil {
+		t.Fatal(err)
+	}
+
+	session := createSession(t)
+	defer session.Close()
+
+	if err := ccm.NodeDown(targetNode); err != nil {
+		t.Fatal(err)
+	}
+
+	status, err := ccm.Status()
+	if err != nil {
+		t.Fatal(err)
+	}
+	t.Logf("status=%+v\n", status)
+	t.Logf("marking node %q down: %v\n", targetNode, status[targetNode])
+
+	time.Sleep(5 * time.Second)
+
+	session.pool.mu.RLock()
+
+	poolHosts := session.pool.hostConnPools
+	node := status[targetNode]
+	t.Logf("poolhosts=%+v\n", poolHosts)
+
+	if _, ok := poolHosts[node.Addr]; ok {
+		session.pool.mu.RUnlock()
+		t.Fatal("node not removed after remove event")
+	}
+	session.pool.mu.RUnlock()
+}
+
+func TestEventNodeDown(t *testing.T) {
+	const targetNode = "node3"
+	if err := ccm.AllUp(); err != nil {
+		t.Fatal(err)
+	}
+
+	session := createSession(t)
+	defer session.Close()
+
+	if err := ccm.NodeDown(targetNode); err != nil {
+		t.Fatal(err)
+	}
+
+	status, err := ccm.Status()
+	if err != nil {
+		t.Fatal(err)
+	}
+	t.Logf("status=%+v\n", status)
+	t.Logf("marking node %q down: %v\n", targetNode, status[targetNode])
+
+	time.Sleep(5 * time.Second)
+
+	session.pool.mu.RLock()
+	defer session.pool.mu.RUnlock()
+
+	poolHosts := session.pool.hostConnPools
+	node := status[targetNode]
+	t.Logf("poolhosts=%+v\n", poolHosts)
+
+	if _, ok := poolHosts[node.Addr]; ok {
+		t.Fatal("node not removed after remove event")
+	}
+}
+
+func TestEventNodeUp(t *testing.T) {
+	if err := ccm.AllUp(); err != nil {
+		t.Fatal(err)
+	}
+
+	status, err := ccm.Status()
+	if err != nil {
+		t.Fatal(err)
+	}
+	log.Printf("status=%+v\n", status)
+
+	session := createSession(t)
+	defer session.Close()
+	poolHosts := session.pool.hostConnPools
+
+	const targetNode = "node2"
+
+	session.pool.mu.RLock()
+	_, ok := poolHosts[status[targetNode].Addr]
+	session.pool.mu.RUnlock()
+	if !ok {
+		session.pool.mu.RLock()
+		t.Errorf("target pool not in connection pool: addr=%q pools=%v", status[targetNode].Addr, poolHosts)
+		session.pool.mu.RUnlock()
+		t.FailNow()
+	}
+
+	if err := ccm.NodeDown(targetNode); err != nil {
+		t.Fatal(err)
+	}
+
+	time.Sleep(5 * time.Second)
+
+	session.pool.mu.RLock()
+	log.Printf("poolhosts=%+v\n", poolHosts)
+	node := status[targetNode]
+
+	if _, ok := poolHosts[node.Addr]; ok {
+		session.pool.mu.RUnlock()
+		t.Fatal("node not removed after remove event")
+	}
+	session.pool.mu.RUnlock()
+
+	if err := ccm.NodeUp(targetNode); err != nil {
+		t.Fatal(err)
+	}
+
+	// cassandra < 2.2 needs 10 seconds to start up the binary service
+	time.Sleep(10 * time.Second)
+
+	session.pool.mu.RLock()
+	log.Printf("poolhosts=%+v\n", poolHosts)
+	if _, ok := poolHosts[node.Addr]; !ok {
+		session.pool.mu.RUnlock()
+		t.Fatal("node not added after node added event")
+	}
+	session.pool.mu.RUnlock()
+}

+ 32 - 0
events_test.go

@@ -0,0 +1,32 @@
+package gocql
+
+import (
+	"net"
+	"sync"
+	"testing"
+)
+
+func TestEventDebounce(t *testing.T) {
+	const eventCount = 150
+	wg := &sync.WaitGroup{}
+	wg.Add(1)
+
+	eventsSeen := 0
+	debouncer := newEventDeouncer("testDebouncer", func(events []frame) {
+		defer wg.Done()
+		eventsSeen += len(events)
+	})
+
+	for i := 0; i < eventCount; i++ {
+		debouncer.debounce(&statusChangeEventFrame{
+			change: "UP",
+			host:   net.IPv4(127, 0, 0, 1),
+			port:   9042,
+		})
+	}
+
+	wg.Wait()
+	if eventCount != eventsSeen {
+		t.Fatalf("expected to see %d events but got %d", eventCount, eventsSeen)
+	}
+}

+ 43 - 0
filters.go

@@ -0,0 +1,43 @@
+package gocql
+
+// HostFilter interface is used when a host is discovered via server sent events.
+type HostFilter interface {
+	// Called when a new host is discovered, returning true will cause the host
+	// to be added to the pools.
+	Accept(host *HostInfo) bool
+}
+
+// HostFilterFunc converts a func(host HostInfo) bool into a HostFilter
+type HostFilterFunc func(host *HostInfo) bool
+
+func (fn HostFilterFunc) Accept(host *HostInfo) bool {
+	return fn(host)
+}
+
+// AcceptAllFilter will accept all hosts
+func AcceptAllFilterfunc() HostFilter {
+	return HostFilterFunc(func(host *HostInfo) bool {
+		return true
+	})
+}
+
+// DataCentreHostFilter filters all hosts such that they are in the same data centre
+// as the supplied data centre.
+func DataCentreHostFilter(dataCentre string) HostFilter {
+	return HostFilterFunc(func(host *HostInfo) bool {
+		return host.DataCenter() == dataCentre
+	})
+}
+
+// WhiteListHostFilter filters incoming hosts by checking that their address is
+// in the initial hosts whitelist.
+func WhiteListHostFilter(hosts ...string) HostFilter {
+	m := make(map[string]bool, len(hosts))
+	for _, host := range hosts {
+		m[host] = true
+	}
+
+	return HostFilterFunc(func(host *HostInfo) bool {
+		return m[host.Peer()]
+	})
+}
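
HostFilter is introduced here, but this diff does not yet show it wired into ClusterConfig, so the sketch below only builds filters and evaluates them directly. It is written as if inside package gocql (as a test would be) because HostInfo's setters are unexported; the addresses and data centre names are made up:

package gocql

import "fmt"

func hostFilterExample() {
	dcOnly := DataCentreHostFilter("dc1")
	allowList := WhiteListHostFilter("10.0.0.1", "10.0.0.2")

	host := (&HostInfo{}).setPeer("10.0.0.1").setDataCenter("dc2")

	fmt.Println(dcOnly.Accept(host))    // false: host reports data centre "dc2"
	fmt.Println(allowList.Accept(host)) // true: 10.0.0.1 is on the whitelist
}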

+ 68 - 1
frame.go

@@ -346,7 +346,7 @@ func readHeader(r io.Reader, p []byte) (head frameHeader, err error) {
 	version := p[0] & protoVersionMask
 
 	if version < protoVersion1 || version > protoVersion4 {
-		err = fmt.Errorf("gocql: invalid version: %x", version)
+		err = fmt.Errorf("gocql: invalid version: %d", version)
 		return
 	}
 
@@ -462,6 +462,8 @@ func (f *framer) parseFrame() (frame frame, err error) {
 		frame = f.parseAuthChallengeFrame()
 	case opAuthSuccess:
 		frame = f.parseAuthSuccessFrame()
+	case opEvent:
+		frame = f.parseEventFrame()
 	default:
 		return nil, NewErrProtocol("unknown op in frame header: %s", f.header.op)
 	}
@@ -1154,6 +1156,56 @@ func (f *framer) parseAuthChallengeFrame() frame {
 	}
 }
 
+type statusChangeEventFrame struct {
+	frameHeader
+
+	change string
+	host   net.IP
+	port   int
+}
+
+func (t statusChangeEventFrame) String() string {
+	return fmt.Sprintf("[status_change change=%s host=%v port=%v]", t.change, t.host, t.port)
+}
+
+// essentially the same as statusChange
+type topologyChangeEventFrame struct {
+	frameHeader
+
+	change string
+	host   net.IP
+	port   int
+}
+
+func (t topologyChangeEventFrame) String() string {
+	return fmt.Sprintf("[topology_change change=%s host=%v port=%v]", t.change, t.host, t.port)
+}
+
+func (f *framer) parseEventFrame() frame {
+	eventType := f.readString()
+
+	switch eventType {
+	case "TOPOLOGY_CHANGE":
+		frame := &topologyChangeEventFrame{frameHeader: *f.header}
+		frame.change = f.readString()
+		frame.host, frame.port = f.readInet()
+
+		return frame
+	case "STATUS_CHANGE":
+		frame := &statusChangeEventFrame{frameHeader: *f.header}
+		frame.change = f.readString()
+		frame.host, frame.port = f.readInet()
+
+		return frame
+	case "SCHEMA_CHANGE":
+		// this should work for all versions
+		return f.parseResultSchemaChange()
+	default:
+		panic(fmt.Errorf("gocql: unknown event type: %q", eventType))
+	}
+
+}
+
 type writeAuthResponseFrame struct {
 	data []byte
 }
@@ -1408,6 +1460,21 @@ func (f *framer) writeOptionsFrame(stream int, _ *writeOptionsFrame) error {
 	return f.finishWrite()
 }
 
+type writeRegisterFrame struct {
+	events []string
+}
+
+func (w *writeRegisterFrame) writeFrame(framer *framer, streamID int) error {
+	return framer.writeRegisterFrame(streamID, w)
+}
+
+func (f *framer) writeRegisterFrame(streamID int, w *writeRegisterFrame) error {
+	f.writeHeader(f.flags, opRegister, streamID)
+	f.writeStringList(w.events)
+
+	return f.finishWrite()
+}
+
 func (f *framer) readByte() byte {
 	if len(f.rbuf) < 1 {
 		panic(fmt.Errorf("not enough bytes in buffer to read byte require 1 got: %d", len(f.rbuf)))

+ 242 - 51
host_source.go

@@ -2,36 +2,231 @@ package gocql
 
 import (
 	"fmt"
-	"log"
 	"net"
+	"strconv"
+	"strings"
 	"sync"
 	"time"
 )
 
+type nodeState int32
+
+func (n nodeState) String() string {
+	if n == NodeUp {
+		return "UP"
+	} else if n == NodeDown {
+		return "DOWN"
+	}
+	return fmt.Sprintf("UNKNOWN_%d", n)
+}
+
+const (
+	NodeUp nodeState = iota
+	NodeDown
+)
+
+type cassVersion struct {
+	Major, Minor, Patch int
+}
+
+func (c *cassVersion) UnmarshalCQL(info TypeInfo, data []byte) error {
+	version := strings.TrimSuffix(string(data), "-SNAPSHOT")
+	v := strings.Split(version, ".")
+	if len(v) != 3 {
+		return fmt.Errorf("invalid schema_version: %v", string(data))
+	}
+
+	var err error
+	c.Major, err = strconv.Atoi(v[0])
+	if err != nil {
+		return fmt.Errorf("invalid major version %v: %v", v[0], err)
+	}
+
+	c.Minor, err = strconv.Atoi(v[1])
+	if err != nil {
+		return fmt.Errorf("invalid minor version %v: %v", v[1], err)
+	}
+
+	c.Patch, err = strconv.Atoi(v[2])
+	if err != nil {
+		return fmt.Errorf("invalid patch version %v: %v", v[2], err)
+	}
+
+	return nil
+}
+
+func (c cassVersion) String() string {
+	return fmt.Sprintf("v%d.%d.%d", c.Major, c.Minor, c.Patch)
+}
+
+func (c cassVersion) nodeUpDelay() time.Duration {
+	if c.Major > 2 || (c.Major == 2 && c.Minor >= 2) {
+		// CASSANDRA-8236
+		return 0
+	}
+
+	return 10 * time.Second
+}
+
 type HostInfo struct {
-	Peer       string
-	DataCenter string
-	Rack       string
-	HostId     string
-	Tokens     []string
+	// TODO(zariel): maybe reduce the locking; not all values change, but for now
+	// guard every field with the mutex so that access stays thread safe.
+	mu         sync.RWMutex
+	peer       string
+	port       int
+	dataCenter string
+	rack       string
+	hostId     string
+	version    cassVersion
+	state      nodeState
+	tokens     []string
+}
+
+func (h *HostInfo) Equal(host *HostInfo) bool {
+	h.mu.RLock()
+	defer h.mu.RUnlock()
+	host.mu.RLock()
+	defer host.mu.RUnlock()
+
+	return h.peer == host.peer && h.hostId == host.hostId
+}
+
+func (h *HostInfo) Peer() string {
+	h.mu.RLock()
+	defer h.mu.RUnlock()
+	return h.peer
+}
+
+func (h *HostInfo) setPeer(peer string) *HostInfo {
+	h.mu.Lock()
+	defer h.mu.Unlock()
+	h.peer = peer
+	return h
+}
+
+func (h *HostInfo) DataCenter() string {
+	h.mu.RLock()
+	defer h.mu.RUnlock()
+	return h.dataCenter
+}
+
+func (h *HostInfo) setDataCenter(dataCenter string) *HostInfo {
+	h.mu.Lock()
+	defer h.mu.Unlock()
+	h.dataCenter = dataCenter
+	return h
+}
+
+func (h *HostInfo) Rack() string {
+	h.mu.RLock()
+	defer h.mu.RUnlock()
+	return h.rack
 }
 
-func (h HostInfo) String() string {
-	return fmt.Sprintf("[hostinfo peer=%q data_centre=%q rack=%q host_id=%q num_tokens=%d]", h.Peer, h.DataCenter, h.Rack, h.HostId, len(h.Tokens))
+func (h *HostInfo) setRack(rack string) *HostInfo {
+	h.mu.Lock()
+	defer h.mu.Unlock()
+	h.rack = rack
+	return h
+}
+
+func (h *HostInfo) HostID() string {
+	h.mu.RLock()
+	defer h.mu.RUnlock()
+	return h.hostId
+}
+
+func (h *HostInfo) setHostID(hostID string) *HostInfo {
+	h.mu.Lock()
+	defer h.mu.Unlock()
+	h.hostId = hostID
+	return h
+}
+
+func (h *HostInfo) Version() cassVersion {
+	h.mu.RLock()
+	defer h.mu.RUnlock()
+	return h.version
+}
+
+func (h *HostInfo) setVersion(major, minor, patch int) *HostInfo {
+	h.mu.Lock()
+	defer h.mu.Unlock()
+	h.version = cassVersion{major, minor, patch}
+	return h
+}
+
+func (h *HostInfo) State() nodeState {
+	h.mu.RLock()
+	defer h.mu.RUnlock()
+	return h.state
+}
+
+func (h *HostInfo) setState(state nodeState) *HostInfo {
+	h.mu.Lock()
+	defer h.mu.Unlock()
+	h.state = state
+	return h
+}
+
+func (h *HostInfo) Tokens() []string {
+	h.mu.RLock()
+	defer h.mu.RUnlock()
+	return h.tokens
+}
+
+func (h *HostInfo) setTokens(tokens []string) *HostInfo {
+	h.mu.Lock()
+	defer h.mu.Unlock()
+	h.tokens = tokens
+	return h
+}
+
+func (h *HostInfo) Port() int {
+	h.mu.RLock()
+	defer h.mu.RUnlock()
+	return h.port
+}
+
+func (h *HostInfo) setPort(port int) *HostInfo {
+	h.mu.Lock()
+	defer h.mu.Unlock()
+	h.port = port
+	return h
+}
+
+func (h *HostInfo) update(from *HostInfo) {
+	h.mu.Lock()
+	defer h.mu.Unlock()
+
+	h.tokens = from.tokens
+	h.version = from.version
+	h.hostId = from.hostId
+	h.dataCenter = from.dataCenter
+}
+
+func (h *HostInfo) IsUp() bool {
+	return h.State() == NodeUp
+}
+
+func (h *HostInfo) String() string {
+	h.mu.RLock()
+	defer h.mu.RUnlock()
+	return fmt.Sprintf("[hostinfo peer=%q port=%d data_centre=%q rack=%q host_id=%q version=%q state=%s num_tokens=%d]", h.peer, h.port, h.dataCenter, h.rack, h.hostId, h.version, h.state, len(h.tokens))
 }
 
 // Polls system.peers at a specific interval to find new hosts
 type ringDescriber struct {
-	dcFilter        string
-	rackFilter      string
-	prevHosts       []HostInfo
-	prevPartitioner string
-	session         *Session
-	closeChan       chan bool
+	dcFilter   string
+	rackFilter string
+	session    *Session
+	closeChan  chan bool
 	// indicates that we can use system.local to get the connection's remote address
 	localHasRpcAddr bool
 
-	mu sync.Mutex
+	mu              sync.Mutex
+	prevHosts       []*HostInfo
+	prevPartitioner string
 }
 
 func checkSystemLocal(control *controlConn) (bool, error) {
@@ -49,27 +244,27 @@ func checkSystemLocal(control *controlConn) (bool, error) {
 	return true, nil
 }
 
-func (r *ringDescriber) GetHosts() (hosts []HostInfo, partitioner string, err error) {
+func (r *ringDescriber) GetHosts() (hosts []*HostInfo, partitioner string, err error) {
 	r.mu.Lock()
 	defer r.mu.Unlock()
 	// we need conn to be the same because we need to query system.peers and system.local
 	// on the same node to get the whole cluster
 
 	const (
-		legacyLocalQuery = "SELECT data_center, rack, host_id, tokens, partitioner FROM system.local"
+		legacyLocalQuery = "SELECT data_center, rack, host_id, tokens, partitioner, release_version FROM system.local"
 		// only supported in 2.2.0, 2.1.6, 2.0.16
-		localQuery = "SELECT broadcast_address, data_center, rack, host_id, tokens, partitioner FROM system.local"
+		localQuery = "SELECT broadcast_address, data_center, rack, host_id, tokens, partitioner, release_version FROM system.local"
 	)
 
-	var localHost HostInfo
+	localHost := &HostInfo{}
 	if r.localHasRpcAddr {
 		iter := r.session.control.query(localQuery)
 		if iter == nil {
 			return r.prevHosts, r.prevPartitioner, nil
 		}
 
-		iter.Scan(&localHost.Peer, &localHost.DataCenter, &localHost.Rack,
-			&localHost.HostId, &localHost.Tokens, &partitioner)
+		iter.Scan(&localHost.peer, &localHost.dataCenter, &localHost.rack,
+			&localHost.hostId, &localHost.tokens, &partitioner, &localHost.version)
 
 		if err = iter.Close(); err != nil {
 			return nil, "", err
@@ -80,7 +275,7 @@ func (r *ringDescriber) GetHosts() (hosts []HostInfo, partitioner string, err er
 			return r.prevHosts, r.prevPartitioner, nil
 		}
 
-		iter.Scan(&localHost.DataCenter, &localHost.Rack, &localHost.HostId, &localHost.Tokens, &partitioner)
+		iter.Scan(&localHost.dataCenter, &localHost.rack, &localHost.hostId, &localHost.tokens, &partitioner, &localHost.version)
 
 		if err = iter.Close(); err != nil {
 			return nil, "", err
@@ -93,22 +288,26 @@ func (r *ringDescriber) GetHosts() (hosts []HostInfo, partitioner string, err er
 			panic(err)
 		}
 
-		localHost.Peer = addr
+		localHost.peer = addr
 	}
 
-	hosts = []HostInfo{localHost}
+	localHost.port = r.session.cfg.Port
+
+	hosts = []*HostInfo{localHost}
 
-	iter := r.session.control.query("SELECT rpc_address, data_center, rack, host_id, tokens FROM system.peers")
+	iter := r.session.control.query("SELECT rpc_address, data_center, rack, host_id, tokens, release_version FROM system.peers")
 	if iter == nil {
 		return r.prevHosts, r.prevPartitioner, nil
 	}
 
-	host := HostInfo{}
-	for iter.Scan(&host.Peer, &host.DataCenter, &host.Rack, &host.HostId, &host.Tokens) {
-		if r.matchFilter(&host) {
+	host := &HostInfo{port: r.session.cfg.Port}
+	for iter.Scan(&host.peer, &host.dataCenter, &host.rack, &host.hostId, &host.tokens, &host.version) {
+		if r.matchFilter(host) {
 			hosts = append(hosts, host)
 		}
-		host = HostInfo{}
+		host = &HostInfo{
+			port: r.session.cfg.Port,
+		}
 	}
 
 	if err = iter.Close(); err != nil {
@@ -122,45 +321,37 @@ func (r *ringDescriber) GetHosts() (hosts []HostInfo, partitioner string, err er
 }
 
 func (r *ringDescriber) matchFilter(host *HostInfo) bool {
-
-	if r.dcFilter != "" && r.dcFilter != host.DataCenter {
+	if r.dcFilter != "" && r.dcFilter != host.DataCenter() {
 		return false
 	}
 
-	if r.rackFilter != "" && r.rackFilter != host.Rack {
+	if r.rackFilter != "" && r.rackFilter != host.Rack() {
 		return false
 	}
 
 	return true
 }
 
-func (r *ringDescriber) refreshRing() {
+func (r *ringDescriber) refreshRing() error {
 	// if we have 0 hosts this will return the previous list of hosts to
 	// attempt to reconnect to the cluster, otherwise we would never find
 	// downed hosts again. Could possibly have an optimisation to only
 	// try to add new hosts if GetHosts didn't error and the hosts didn't change.
 	hosts, partitioner, err := r.GetHosts()
 	if err != nil {
-		log.Println("RingDescriber: unable to get ring topology:", err)
-		return
-	}
-
-	r.session.pool.SetHosts(hosts)
-	r.session.pool.SetPartitioner(partitioner)
-}
-
-func (r *ringDescriber) run(sleep time.Duration) {
-	if sleep == 0 {
-		sleep = 30 * time.Second
+		return err
 	}
 
-	for {
-		r.refreshRing()
-
-		select {
-		case <-time.After(sleep):
-		case <-r.closeChan:
-			return
+	// TODO: move this to session
+	// TODO: handle removing hosts here
+	for _, h := range hosts {
+		if host, ok := r.session.ring.addHostIfMissing(h); !ok {
+			r.session.pool.addHost(h)
+		} else {
+			host.update(h)
 		}
 	}
+
+	r.session.pool.SetPartitioner(partitioner)
+	return nil
 }
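For reference, a small in-package sketch of the new cassVersion parsing and the accessor-based HostInfo. The exampleHostMetadata name is illustrative, the printed values follow from the inputs used here, and nil is passed for the TypeInfo argument because UnmarshalCQL ignores it.

    package gocql

    import "fmt"

    // exampleHostMetadata is a sketch exercising the types added above.
    func exampleHostMetadata() {
    	var v cassVersion
    	if err := v.UnmarshalCQL(nil, []byte("2.1.8-SNAPSHOT")); err != nil {
    		panic(err)
    	}
    	fmt.Println(v)               // v2.1.8
    	fmt.Println(v.nodeUpDelay()) // 10s: only 2.2+ can skip the delay (CASSANDRA-8236)

    	host := &HostInfo{}
    	host.setPeer("10.0.0.1").setPort(9042).setState(NodeUp)
    	fmt.Println(host.IsUp(), host.Peer(), host.Port()) // true 10.0.0.1 9042
    }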

+ 6 - 1
integration.sh

@@ -52,6 +52,9 @@ function run_tests() {
 		ccm updateconf 'enable_user_defined_functions: true'
 	fi
 
+	sleep 1s
+
+	ccm list
 	ccm start -v
 	ccm status
 	ccm node1 nodetool status
@@ -62,7 +65,7 @@ function run_tests() {
     	go test -v . -timeout 15s -run=TestAuthentication -tags integration -runssl -runauth -proto=$proto -cluster=$(ccm liveset) -clusterSize=$clusterSize -autowait=1000ms
 	else
 
-		go test -timeout 5m -tags integration -v -gocql.timeout=10s -runssl -proto=$proto -rf=3 -cluster=$(ccm liveset) -clusterSize=$clusterSize -autowait=2000ms -compressor=snappy ./...
+		go test -timeout 10m -tags integration -v -gocql.timeout=10s -runssl -proto=$proto -rf=3 -cluster=$(ccm liveset) -clusterSize=$clusterSize -autowait=2000ms -compressor=snappy ./...
 
 		if [ ${PIPESTATUS[0]} -ne 0 ]; then
 			echo "--- FAIL: ccm status follows:"
@@ -73,6 +76,8 @@ function run_tests() {
 			echo "--- FAIL: Received a non-zero exit code from the go test execution, please investigate this"
 			exit 1
 		fi
+
+		go test -timeout 10m -tags ccm -v -gocql.timeout=10s -runssl -proto=$proto -rf=3 -cluster=$(ccm liveset) -clusterSize=$clusterSize -autowait=2000ms -compressor=snappy ./...
 	fi
 
 	ccm remove

+ 200 - 24
policies.go

@@ -5,6 +5,7 @@
 package gocql
 
 import (
+	"fmt"
 	"log"
 	"sync"
 	"sync/atomic"
@@ -12,6 +13,114 @@ import (
 	"github.com/hailocab/go-hostpool"
 )
 
+// cowHostList implements a copy-on-write host list; it is equivalent to a []*HostInfo.
+type cowHostList struct {
+	list atomic.Value
+	mu   sync.Mutex
+}
+
+func (c *cowHostList) String() string {
+	return fmt.Sprintf("%+v", c.get())
+}
+
+func (c *cowHostList) get() []*HostInfo {
+	// TODO(zariel): should we replace this with []*HostInfo?
+	l, ok := c.list.Load().(*[]*HostInfo)
+	if !ok {
+		return nil
+	}
+	return *l
+}
+
+func (c *cowHostList) set(list []*HostInfo) {
+	c.mu.Lock()
+	c.list.Store(&list)
+	c.mu.Unlock()
+}
+
+// add will add a host if it is not already in the list.
+func (c *cowHostList) add(host *HostInfo) bool {
+	c.mu.Lock()
+	l := c.get()
+
+	if n := len(l); n == 0 {
+		l = []*HostInfo{host}
+	} else {
+		newL := make([]*HostInfo, n+1)
+		for i := 0; i < n; i++ {
+			if host.Equal(l[i]) {
+				c.mu.Unlock()
+				return false
+			}
+			newL[i] = l[i]
+		}
+		newL[n] = host
+		l = newL
+	}
+
+	c.list.Store(&l)
+	c.mu.Unlock()
+	return true
+}
+
+func (c *cowHostList) update(host *HostInfo) {
+	c.mu.Lock()
+	l := c.get()
+
+	if len(l) == 0 {
+		c.mu.Unlock()
+		return
+	}
+
+	found := false
+	newL := make([]*HostInfo, len(l))
+	for i := range l {
+		if host.Equal(l[i]) {
+			newL[i] = host
+			found = true
+		} else {
+			newL[i] = l[i]
+		}
+	}
+
+	if found {
+		c.list.Store(&newL)
+	}
+
+	c.mu.Unlock()
+}
+
+func (c *cowHostList) remove(addr string) bool {
+	c.mu.Lock()
+	l := c.get()
+	size := len(l)
+	if size == 0 {
+		c.mu.Unlock()
+		return false
+	}
+
+	found := false
+	newL := make([]*HostInfo, 0, size)
+	for i := 0; i < len(l); i++ {
+		if l[i].Peer() != addr {
+			newL = append(newL, l[i])
+		} else {
+			found = true
+		}
+	}
+
+	if !found {
+		c.mu.Unlock()
+		return false
+	}
+
+	newL = newL[:size-1 : size-1]
+	c.list.Store(&newL)
+	c.mu.Unlock()
+
+	return true
+}
+
 // RetryableQuery is an interface that represents a query or batch statement that
 // exposes the correct functions for the retry policy logic to evaluate correctly.
 type RetryableQuery interface {
@@ -50,9 +159,16 @@ func (s *SimpleRetryPolicy) Attempt(q RetryableQuery) bool {
 	return q.Attempts() <= s.NumRetries
 }
 
+type HostStateNotifier interface {
+	AddHost(host *HostInfo)
+	RemoveHost(addr string)
+	// TODO(zariel): add host up/down
+}
+
 // HostSelectionPolicy is an interface for selecting
 // the most appropriate host to execute a given query.
 type HostSelectionPolicy interface {
+	HostStateNotifier
 	SetHosts
 	SetPartitioner
 	//Pick returns an iteration function over selected hosts
@@ -72,19 +188,17 @@ type NextHost func() SelectedHost
 // RoundRobinHostPolicy is a round-robin load balancing policy, where each host
 // is tried sequentially for each query.
 func RoundRobinHostPolicy() HostSelectionPolicy {
-	return &roundRobinHostPolicy{hosts: []HostInfo{}}
+	return &roundRobinHostPolicy{}
 }
 
 type roundRobinHostPolicy struct {
-	hosts []HostInfo
+	hosts cowHostList
 	pos   uint32
 	mu    sync.RWMutex
 }
 
-func (r *roundRobinHostPolicy) SetHosts(hosts []HostInfo) {
-	r.mu.Lock()
-	r.hosts = hosts
-	r.mu.Unlock()
+func (r *roundRobinHostPolicy) SetHosts(hosts []*HostInfo) {
+	r.hosts.set(hosts)
 }
 
 func (r *roundRobinHostPolicy) SetPartitioner(partitioner string) {
@@ -96,24 +210,31 @@ func (r *roundRobinHostPolicy) Pick(qry *Query) NextHost {
 	// to the number of hosts known to this policy
 	var i int
 	return func() SelectedHost {
-		r.mu.RLock()
-		defer r.mu.RUnlock()
-		if len(r.hosts) == 0 {
+		hosts := r.hosts.get()
+		if len(hosts) == 0 {
 			return nil
 		}
 
 		// always increment pos to evenly distribute traffic in case of
 		// failures
-		pos := atomic.AddUint32(&r.pos, 1)
-		if i >= len(r.hosts) {
+		pos := atomic.AddUint32(&r.pos, 1) - 1
+		if i >= len(hosts) {
 			return nil
 		}
-		host := &r.hosts[(pos)%uint32(len(r.hosts))]
+		host := hosts[(pos)%uint32(len(hosts))]
 		i++
 		return selectedRoundRobinHost{host}
 	}
 }
 
+func (r *roundRobinHostPolicy) AddHost(host *HostInfo) {
+	r.hosts.add(host)
+}
+
+func (r *roundRobinHostPolicy) RemoveHost(addr string) {
+	r.hosts.remove(addr)
+}
+
 // selectedRoundRobinHost is a host returned by the roundRobinHostPolicy and
 // implements the SelectedHost interface
 type selectedRoundRobinHost struct {
@@ -132,24 +253,25 @@ func (host selectedRoundRobinHost) Mark(err error) {
 // selected based on the partition key, so queries are sent to the host which
 // owns the partition. Fallback is used when routing information is not available.
 func TokenAwareHostPolicy(fallback HostSelectionPolicy) HostSelectionPolicy {
-	return &tokenAwareHostPolicy{fallback: fallback, hosts: []HostInfo{}}
+	return &tokenAwareHostPolicy{fallback: fallback}
 }
 
 type tokenAwareHostPolicy struct {
+	hosts       cowHostList
 	mu          sync.RWMutex
-	hosts       []HostInfo
 	partitioner string
 	tokenRing   *tokenRing
 	fallback    HostSelectionPolicy
 }
 
-func (t *tokenAwareHostPolicy) SetHosts(hosts []HostInfo) {
+func (t *tokenAwareHostPolicy) SetHosts(hosts []*HostInfo) {
+	t.hosts.set(hosts)
+
 	t.mu.Lock()
 	defer t.mu.Unlock()
 
 	// always update the fallback
 	t.fallback.SetHosts(hosts)
-	t.hosts = hosts
 
 	t.resetTokenRing()
 }
@@ -166,6 +288,23 @@ func (t *tokenAwareHostPolicy) SetPartitioner(partitioner string) {
 	}
 }
 
+func (t *tokenAwareHostPolicy) AddHost(host *HostInfo) {
+	t.hosts.add(host)
+	t.fallback.AddHost(host)
+
+	t.mu.Lock()
+	t.resetTokenRing()
+	t.mu.Unlock()
+}
+
+func (t *tokenAwareHostPolicy) RemoveHost(addr string) {
+	t.hosts.remove(addr)
+	t.fallback.RemoveHost(addr)
+
+	t.mu.Lock()
+	t.resetTokenRing()
+	t.mu.Unlock()
+}
+
 func (t *tokenAwareHostPolicy) resetTokenRing() {
 	if t.partitioner == "" {
 		// partitioner not yet set
@@ -173,7 +312,8 @@ func (t *tokenAwareHostPolicy) resetTokenRing() {
 	}
 
 	// create a new token ring
-	tokenRing, err := newTokenRing(t.partitioner, t.hosts)
+	hosts := t.hosts.get()
+	tokenRing, err := newTokenRing(t.partitioner, hosts)
 	if err != nil {
 		log.Printf("Unable to update the token ring due to error: %s", err)
 		return
@@ -215,6 +355,7 @@ func (t *tokenAwareHostPolicy) Pick(qry *Query) NextHost {
 		hostReturned bool
 		fallbackIter NextHost
 	)
+
 	return func() SelectedHost {
 		if !hostReturned {
 			hostReturned = true
@@ -266,22 +407,22 @@ func (host selectedTokenAwareHost) Mark(err error) {
 //     )
 //
 func HostPoolHostPolicy(hp hostpool.HostPool) HostSelectionPolicy {
-	return &hostPoolHostPolicy{hostMap: map[string]HostInfo{}, hp: hp}
+	return &hostPoolHostPolicy{hostMap: map[string]*HostInfo{}, hp: hp}
 }
 
 type hostPoolHostPolicy struct {
 	hp      hostpool.HostPool
-	hostMap map[string]HostInfo
 	mu      sync.RWMutex
+	hostMap map[string]*HostInfo
 }
 
-func (r *hostPoolHostPolicy) SetHosts(hosts []HostInfo) {
+func (r *hostPoolHostPolicy) SetHosts(hosts []*HostInfo) {
 	peers := make([]string, len(hosts))
-	hostMap := make(map[string]HostInfo, len(hosts))
+	hostMap := make(map[string]*HostInfo, len(hosts))
 
 	for i, host := range hosts {
-		peers[i] = host.Peer
-		hostMap[host.Peer] = host
+		peers[i] = host.Peer()
+		hostMap[host.Peer()] = host
 	}
 
 	r.mu.Lock()
@@ -290,6 +431,41 @@ func (r *hostPoolHostPolicy) SetHosts(hosts []HostInfo) {
 	r.mu.Unlock()
 }
 
+func (r *hostPoolHostPolicy) AddHost(host *HostInfo) {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+
+	if _, ok := r.hostMap[host.Peer()]; ok {
+		return
+	}
+
+	hosts := make([]string, 0, len(r.hostMap)+1)
+	for addr := range r.hostMap {
+		hosts = append(hosts, addr)
+	}
+	hosts = append(hosts, host.Peer())
+
+	r.hp.SetHosts(hosts)
+	r.hostMap[host.Peer()] = host
+}
+
+func (r *hostPoolHostPolicy) RemoveHost(addr string) {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+
+	if _, ok := r.hostMap[addr]; !ok {
+		return
+	}
+
+	delete(r.hostMap, addr)
+	hosts := make([]string, 0, len(r.hostMap))
+	for addr := range r.hostMap {
+		hosts = append(hosts, addr)
+	}
+
+	r.hp.SetHosts(hosts)
+}
+
 func (r *hostPoolHostPolicy) SetPartitioner(partitioner string) {
 	// noop
 }
@@ -309,7 +485,7 @@ func (r *hostPoolHostPolicy) Pick(qry *Query) NextHost {
 			return nil
 		}
 
-		return selectedHostPoolHost{&host, hostR}
+		return selectedHostPoolHost{host, hostR}
 	}
 }
 

+ 74 - 47
policies_test.go

@@ -16,36 +16,35 @@ import (
 func TestRoundRobinHostPolicy(t *testing.T) {
 	policy := RoundRobinHostPolicy()
 
-	hosts := []HostInfo{
-		HostInfo{HostId: "0"},
-		HostInfo{HostId: "1"},
+	hosts := []*HostInfo{
+		{hostId: "0"},
+		{hostId: "1"},
 	}
 
 	policy.SetHosts(hosts)
 
-	// the first host selected is actually at [1], but this is ok for RR
 	// interleaved iteration should always increment the host
 	iterA := policy.Pick(nil)
-	if actual := iterA(); actual.Info() != &hosts[1] {
-		t.Errorf("Expected hosts[1] but was hosts[%s]", actual.Info().HostId)
+	if actual := iterA(); actual.Info() != hosts[0] {
+		t.Errorf("Expected hosts[0] but was hosts[%s]", actual.Info().HostID())
 	}
 	iterB := policy.Pick(nil)
-	if actual := iterB(); actual.Info() != &hosts[0] {
-		t.Errorf("Expected hosts[0] but was hosts[%s]", actual.Info().HostId)
+	if actual := iterB(); actual.Info() != hosts[1] {
+		t.Errorf("Expected hosts[1] but was hosts[%s]", actual.Info().HostID())
 	}
-	if actual := iterB(); actual.Info() != &hosts[1] {
-		t.Errorf("Expected hosts[1] but was hosts[%s]", actual.Info().HostId)
+	if actual := iterB(); actual.Info() != hosts[0] {
+		t.Errorf("Expected hosts[0] but was hosts[%s]", actual.Info().HostID())
 	}
-	if actual := iterA(); actual.Info() != &hosts[0] {
-		t.Errorf("Expected hosts[0] but was hosts[%s]", actual.Info().HostId)
+	if actual := iterA(); actual.Info() != hosts[1] {
+		t.Errorf("Expected hosts[1] but was hosts[%s]", actual.Info().HostID())
 	}
 
 	iterC := policy.Pick(nil)
-	if actual := iterC(); actual.Info() != &hosts[1] {
-		t.Errorf("Expected hosts[1] but was hosts[%s]", actual.Info().HostId)
+	if actual := iterC(); actual.Info() != hosts[0] {
+		t.Errorf("Expected hosts[0] but was hosts[%s]", actual.Info().HostID())
 	}
-	if actual := iterC(); actual.Info() != &hosts[0] {
-		t.Errorf("Expected hosts[0] but was hosts[%s]", actual.Info().HostId)
+	if actual := iterC(); actual.Info() != hosts[1] {
+		t.Errorf("Expected hosts[1] but was hosts[%s]", actual.Info().HostID())
 	}
 }
 
@@ -66,23 +65,23 @@ func TestTokenAwareHostPolicy(t *testing.T) {
 	}
 
 	// set the hosts
-	hosts := []HostInfo{
-		HostInfo{Peer: "0", Tokens: []string{"00"}},
-		HostInfo{Peer: "1", Tokens: []string{"25"}},
-		HostInfo{Peer: "2", Tokens: []string{"50"}},
-		HostInfo{Peer: "3", Tokens: []string{"75"}},
+	hosts := []*HostInfo{
+		{peer: "0", tokens: []string{"00"}},
+		{peer: "1", tokens: []string{"25"}},
+		{peer: "2", tokens: []string{"50"}},
+		{peer: "3", tokens: []string{"75"}},
 	}
 	policy.SetHosts(hosts)
 
 	// the token ring is not set up without the partitioner, but the fallback
 	// should work
-	if actual := policy.Pick(nil)(); actual.Info().Peer != "1" {
-		t.Errorf("Expected peer 1 but was %s", actual.Info().Peer)
+	if actual := policy.Pick(nil)(); actual.Info().Peer() != "0" {
+		t.Errorf("Expected peer 0 but was %s", actual.Info().Peer())
 	}
 
 	query.RoutingKey([]byte("30"))
-	if actual := policy.Pick(query)(); actual.Info().Peer != "2" {
-		t.Errorf("Expected peer 2 but was %s", actual.Info().Peer)
+	if actual := policy.Pick(query)(); actual.Info().Peer() != "1" {
+		t.Errorf("Expected peer 1 but was %s", actual.Info().Peer())
 	}
 
 	policy.SetPartitioner("OrderedPartitioner")
@@ -90,18 +89,18 @@ func TestTokenAwareHostPolicy(t *testing.T) {
 	// now the token ring is configured
 	query.RoutingKey([]byte("20"))
 	iter = policy.Pick(query)
-	if actual := iter(); actual.Info().Peer != "1" {
-		t.Errorf("Expected peer 1 but was %s", actual.Info().Peer)
+	if actual := iter(); actual.Info().Peer() != "1" {
+		t.Errorf("Expected peer 1 but was %s", actual.Info().Peer())
 	}
 	// rest are round robin
-	if actual := iter(); actual.Info().Peer != "3" {
-		t.Errorf("Expected peer 3 but was %s", actual.Info().Peer)
+	if actual := iter(); actual.Info().Peer() != "2" {
+		t.Errorf("Expected peer 2 but was %s", actual.Info().Peer())
 	}
-	if actual := iter(); actual.Info().Peer != "0" {
-		t.Errorf("Expected peer 0 but was %s", actual.Info().Peer)
+	if actual := iter(); actual.Info().Peer() != "3" {
+		t.Errorf("Expected peer 3 but was %s", actual.Info().Peer())
 	}
-	if actual := iter(); actual.Info().Peer != "2" {
-		t.Errorf("Expected peer 2 but was %s", actual.Info().Peer)
+	if actual := iter(); actual.Info().Peer() != "0" {
+		t.Errorf("Expected peer 0 but was %s", actual.Info().Peer())
 	}
 }
 
@@ -109,9 +108,9 @@ func TestTokenAwareHostPolicy(t *testing.T) {
 func TestHostPoolHostPolicy(t *testing.T) {
 	policy := HostPoolHostPolicy(hostpool.New(nil))
 
-	hosts := []HostInfo{
-		HostInfo{HostId: "0", Peer: "0"},
-		HostInfo{HostId: "1", Peer: "1"},
+	hosts := []*HostInfo{
+		{hostId: "0", peer: "0"},
+		{hostId: "1", peer: "1"},
 	}
 
 	policy.SetHosts(hosts)
@@ -120,26 +119,26 @@ func TestHostPoolHostPolicy(t *testing.T) {
 	// interleaved iteration should always increment the host
 	iter := policy.Pick(nil)
 	actualA := iter()
-	if actualA.Info().HostId != "0" {
-		t.Errorf("Expected hosts[0] but was hosts[%s]", actualA.Info().HostId)
+	if actualA.Info().HostID() != "0" {
+		t.Errorf("Expected hosts[0] but was hosts[%s]", actualA.Info().HostID())
 	}
 	actualA.Mark(nil)
 
 	actualB := iter()
-	if actualB.Info().HostId != "1" {
-		t.Errorf("Expected hosts[1] but was hosts[%s]", actualB.Info().HostId)
+	if actualB.Info().HostID() != "1" {
+		t.Errorf("Expected hosts[1] but was hosts[%s]", actualB.Info().HostID())
 	}
 	actualB.Mark(fmt.Errorf("error"))
 
 	actualC := iter()
-	if actualC.Info().HostId != "0" {
-		t.Errorf("Expected hosts[0] but was hosts[%s]", actualC.Info().HostId)
+	if actualC.Info().HostID() != "0" {
+		t.Errorf("Expected hosts[0] but was hosts[%s]", actualC.Info().HostID())
 	}
 	actualC.Mark(nil)
 
 	actualD := iter()
-	if actualD.Info().HostId != "0" {
-		t.Errorf("Expected hosts[0] but was hosts[%s]", actualD.Info().HostId)
+	if actualD.Info().HostID() != "0" {
+		t.Errorf("Expected hosts[0] but was hosts[%s]", actualD.Info().HostID())
 	}
 	actualD.Mark(nil)
 }
@@ -171,8 +170,8 @@ func TestRoundRobinConnPolicy(t *testing.T) {
 func TestRoundRobinNilHostInfo(t *testing.T) {
 	policy := RoundRobinHostPolicy()
 
-	host := HostInfo{HostId: "host-1"}
-	policy.SetHosts([]HostInfo{host})
+	host := &HostInfo{hostId: "host-1"}
+	policy.SetHosts([]*HostInfo{host})
 
 	iter := policy.Pick(nil)
 	next := iter()
@@ -180,7 +179,7 @@ func TestRoundRobinNilHostInfo(t *testing.T) {
 		t.Fatal("got nil host")
 	} else if v := next.Info(); v == nil {
 		t.Fatal("got nil HostInfo")
-	} else if v.HostId != host.HostId {
+	} else if v.HostID() != host.HostID() {
 		t.Fatalf("expected host %v got %v", host, *v)
 	}
 
@@ -192,3 +191,31 @@ func TestRoundRobinNilHostInfo(t *testing.T) {
 		}
 	}
 }
+
+func TestCOWList_Add(t *testing.T) {
+	var cow cowHostList
+
+	toAdd := [...]string{"peer1", "peer2", "peer3"}
+
+	for _, addr := range toAdd {
+		if !cow.add(&HostInfo{peer: addr}) {
+			t.Fatal("did not add peer which was not in the set")
+		}
+	}
+
+	hosts := cow.get()
+	if len(hosts) != len(toAdd) {
+		t.Fatalf("expected to have %d hosts got %d", len(toAdd), len(hosts))
+	}
+
+	set := make(map[string]bool)
+	for _, host := range hosts {
+		set[host.Peer()] = true
+	}
+
+	for _, addr := range toAdd {
+		if !set[addr] {
+			t.Errorf("addr was not in the host list: %q", addr)
+		}
+	}
+}

+ 75 - 0
ring.go

@@ -0,0 +1,75 @@
+package gocql
+
+import (
+	"sync"
+)
+
+type ring struct {
+	// endpoints are the set of endpoints which the driver will attempt to connect
+	// to if it cannot reach any of its hosts. They are also used to bootstrap
+	// the initial connection.
+	endpoints []string
+	// hosts are the set of all hosts in the cassandra ring that we know of
+	mu    sync.RWMutex
+	hosts map[string]*HostInfo
+
+	// TODO: we should store the ring metadata here also.
+}
+
+func (r *ring) getHost(addr string) *HostInfo {
+	r.mu.RLock()
+	host := r.hosts[addr]
+	r.mu.RUnlock()
+	return host
+}
+
+func (r *ring) allHosts() []*HostInfo {
+	r.mu.RLock()
+	hosts := make([]*HostInfo, 0, len(r.hosts))
+	for _, host := range r.hosts {
+		hosts = append(hosts, host)
+	}
+	r.mu.RUnlock()
+	return hosts
+}
+
+func (r *ring) addHost(host *HostInfo) bool {
+	r.mu.Lock()
+	if r.hosts == nil {
+		r.hosts = make(map[string]*HostInfo)
+	}
+
+	addr := host.Peer()
+	_, ok := r.hosts[addr]
+	r.hosts[addr] = host
+	r.mu.Unlock()
+	return ok
+}
+
+func (r *ring) addHostIfMissing(host *HostInfo) (*HostInfo, bool) {
+	r.mu.Lock()
+	if r.hosts == nil {
+		r.hosts = make(map[string]*HostInfo)
+	}
+
+	addr := host.Peer()
+	existing, ok := r.hosts[addr]
+	if !ok {
+		r.hosts[addr] = host
+		existing = host
+	}
+	r.mu.Unlock()
+	return existing, ok
+}
+
+func (r *ring) removeHost(addr string) bool {
+	r.mu.Lock()
+	if r.hosts == nil {
+		r.hosts = make(map[string]*HostInfo)
+	}
+
+	_, ok := r.hosts[addr]
+	delete(r.hosts, addr)
+	r.mu.Unlock()
+	return ok
+}
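ring is the session-level view of the cluster, keyed by peer address. The contract that refreshRing relies on above is that addHostIfMissing returns the already-known entry (and true) for a duplicate, so callers either dial the new host or just refresh its metadata. A sketch:

    package gocql

    import "fmt"

    // exampleRing is a sketch of the addHostIfMissing contract used by refreshRing.
    func exampleRing() {
    	var r ring

    	if _, ok := r.addHostIfMissing(&HostInfo{peer: "10.0.0.1"}); !ok {
    		fmt.Println("10.0.0.1 is new") // first sighting: the pool should dial it
    	}

    	update := &HostInfo{peer: "10.0.0.1", dataCenter: "dc1"}
    	if existing, ok := r.addHostIfMissing(update); ok {
    		existing.update(update) // already known: refresh its metadata instead
    	}

    	fmt.Println(len(r.allHosts()), r.getHost("10.0.0.1").DataCenter()) // 1 dc1
    }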

+ 82 - 29
session.go

@@ -10,7 +10,8 @@ import (
 	"errors"
 	"fmt"
 	"io"
-	"log"
+	"net"
+	"strconv"
 	"strings"
 	"sync"
 	"time"
@@ -37,10 +38,22 @@ type Session struct {
 	schemaDescriber     *schemaDescriber
 	trace               Tracer
 	hostSource          *ringDescriber
-	mu                  sync.RWMutex
+	ring                ring
+
+	connCfg *ConnConfig
+
+	mu sync.RWMutex
+
+	hostFilter HostFilter
 
 	control *controlConn
 
+	// event handlers
+	nodeEvents *eventDeouncer
+
+	// ring metadata
+	hosts []HostInfo
+
 	cfg ClusterConfig
 
 	closeMu  sync.RWMutex
@@ -66,49 +79,75 @@ func NewSession(cfg ClusterConfig) (*Session, error) {
 		pageSize: cfg.PageSize,
 	}
 
-	pool, err := cfg.PoolConfig.buildPool(s)
+	connCfg, err := connConfig(s)
 	if err != nil {
-		return nil, err
-	}
-	s.pool = pool
-
-	// See if there are any connections in the pool
-	if pool.Size() == 0 {
 		s.Close()
-		return nil, ErrNoConnectionsStarted
+		return nil, fmt.Errorf("gocql: unable to create session: %v", err)
 	}
+	s.connCfg = connCfg
+
+	s.nodeEvents = newEventDeouncer("NodeEvents", s.handleNodeEvent)
 
 	s.routingKeyInfoCache.lru = lru.New(cfg.MaxRoutingKeyInfo)
 
 	// I think it might be a good idea to simplify this and make it always discover
 	// hosts, maybe with more filters.
-	if cfg.DiscoverHosts {
-		s.hostSource = &ringDescriber{
-			session:    s,
-			dcFilter:   cfg.Discovery.DcFilter,
-			rackFilter: cfg.Discovery.RackFilter,
-			closeChan:  make(chan bool),
-		}
+	s.hostSource = &ringDescriber{
+		session:   s,
+		closeChan: make(chan bool),
 	}
 
+	s.pool = cfg.PoolConfig.buildPool(s)
+
+	var hosts []*HostInfo
+
 	if !cfg.disableControlConn {
 		s.control = createControlConn(s)
-		s.control.reconnect(false)
+		if err := s.control.connect(cfg.Hosts); err != nil {
+			s.Close()
+			return nil, err
+		}
 
-		// need to setup host source to check for rpc_address in system.local
-		localHasRPCAddr, err := checkSystemLocal(s.control)
+		// need to set up the host source to check for broadcast_address in system.local
+		localHasRPCAddr, _ := checkSystemLocal(s.control)
+		s.hostSource.localHasRpcAddr = localHasRPCAddr
+		hosts, _, err = s.hostSource.GetHosts()
 		if err != nil {
-			log.Printf("gocql: unable to verify if system.local table contains rpc_address, falling back to connection address: %v", err)
+			s.Close()
+			return nil, err
 		}
 
-		if cfg.DiscoverHosts {
-			s.hostSource.localHasRpcAddr = localHasRPCAddr
+	} else {
+		// we don't get host info from the cluster, so fall back to the configured hosts
+		hosts = make([]*HostInfo, len(cfg.Hosts))
+		for i, hostport := range cfg.Hosts {
+			// TODO: remove duplication
+			addr, portStr, err := net.SplitHostPort(JoinHostPort(hostport, cfg.Port))
+			if err != nil {
+				s.Close()
+				return nil, fmt.Errorf("NewSession: unable to parse hostport of addr %q: %v", hostport, err)
+			}
+
+			port, err := strconv.Atoi(portStr)
+			if err != nil {
+				s.Close()
+				return nil, fmt.Errorf("NewSession: invalid port for hostport of addr %q: %v", hostport, err)
+			}
+
+			hosts[i] = &HostInfo{peer: addr, port: port, state: NodeUp}
 		}
 	}
 
-	if cfg.DiscoverHosts {
-		s.hostSource.refreshRing()
-		go s.hostSource.run(cfg.Discovery.Sleep)
+	for _, host := range hosts {
+		s.handleNodeUp(net.ParseIP(host.Peer()), host.Port(), false)
+	}
+
+	// TODO(zariel): we probably don't need this any more as we verify that we
+	// can connect to one of the endpoints supplied by using the control conn.
+	// See if there are any connections in the pool
+	if s.pool.Size() == 0 {
+		s.Close()
+		return nil, ErrNoConnectionsStarted
 	}
 
 	return s, nil
@@ -206,6 +245,10 @@ func (s *Session) Close() {
 	if s.control != nil {
 		s.control.close()
 	}
+
+	if s.nodeEvents != nil {
+		s.nodeEvents.stop()
+	}
 }
 
 func (s *Session) Closed() bool {
@@ -228,6 +271,7 @@ func (s *Session) executeQuery(qry *Query) *Iter {
 	for {
 		host, conn := s.pool.Pick(qry)
 
+		qry.attempts++
 		//Assign the error unavailable to the iterator
 		if conn == nil {
 			if qry.rt == nil || !qry.rt.Attempt(qry) {
@@ -241,11 +285,10 @@ func (s *Session) executeQuery(qry *Query) *Iter {
 		t := time.Now()
 		iter = conn.executeQuery(qry)
 		qry.totalLatency += time.Now().Sub(t).Nanoseconds()
-		qry.attempts++
 
 		//Exit for loop if the query was successful
 		if iter.err == nil {
-			host.Mark(iter.err)
+			host.Mark(nil)
 			break
 		}
 
@@ -495,6 +538,10 @@ func (s *Session) MapExecuteBatchCAS(batch *Batch, dest map[string]interface{})
 	}
 }
 
+func (s *Session) connect(addr string, errorHandler ConnErrorHandler) (*Conn, error) {
+	return Connect(addr, s.connCfg, errorHandler, s)
+}
+
 // Query represents a CQL statement that can be executed.
 type Query struct {
 	stmt             string
@@ -707,7 +754,7 @@ func (q *Query) PageState(state []byte) *Query {
 // Exec executes the query without returning any rows.
 func (q *Query) Exec() error {
 	iter := q.Iter()
-	return iter.err
+	return iter.Close()
 }
 
 func isUseStatement(stmt string) bool {
@@ -798,11 +845,17 @@ type Iter struct {
 	rows [][][]byte
 	meta resultMetadata
 	next *nextIter
+	host *HostInfo
 
 	framer *framer
 	once   sync.Once
 }
 
+// Host returns the host which the query was sent to.
+func (iter *Iter) Host() *HostInfo {
+	return iter.host
+}
+
 // Columns returns the name and type of the selected columns.
 func (iter *Iter) Columns() []ColumnInfo {
 	return iter.meta.columns
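From the API surface the visible additions are small: Iter.Host reports which node served a query, and Exec now returns iter.Close so errors are no longer dropped. A sketch from outside the package, assuming a reachable cluster and a hypothetical kv table:

    package main

    import (
    	"fmt"

    	"github.com/gocql/gocql"
    )

    // dumpKeys is a sketch only; the table and statements are assumptions.
    func dumpKeys(session *gocql.Session) error {
    	iter := session.Query(`SELECT key FROM kv`).Iter()

    	var key string
    	for iter.Scan(&key) {
    		fmt.Println(key)
    	}

    	if err := iter.Close(); err != nil {
    		// Iter.Host reports the host the query was sent to.
    		return fmt.Errorf("query failed on host %v: %v", iter.Host(), err)
    	}

    	// Exec now surfaces errors via iter.Close internally.
    	return session.Query(`INSERT INTO kv(key) VALUES ('a')`).Exec()
    }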

+ 3 - 11
session_test.go

@@ -16,11 +16,7 @@ func TestSessionAPI(t *testing.T) {
 		cons: Quorum,
 	}
 
-	var err error
-	s.pool, err = cfg.PoolConfig.buildPool(s)
-	if err != nil {
-		t.Fatal(err)
-	}
+	s.pool = cfg.PoolConfig.buildPool(s)
 	defer s.Close()
 
 	s.SetConsistency(All)
@@ -70,7 +66,7 @@ func TestSessionAPI(t *testing.T) {
 
 	testBatch := s.NewBatch(LoggedBatch)
 	testBatch.Query("test")
-	err = s.ExecuteBatch(testBatch)
+	err := s.ExecuteBatch(testBatch)
 
 	if err != ErrNoConnections {
 		t.Fatalf("expected session.ExecuteBatch to return '%v', got '%v'", ErrNoConnections, err)
@@ -167,11 +163,7 @@ func TestBatchBasicAPI(t *testing.T) {
 	}
 	defer s.Close()
 
-	var err error
-	s.pool, err = cfg.PoolConfig.buildPool(s)
-	if err != nil {
-		t.Fatal(err)
-	}
+	s.pool = cfg.PoolConfig.buildPool(s)
 
 	b := s.NewBatch(UnloggedBatch)
 	if b.Type != UnloggedBatch {

+ 4 - 5
token.go

@@ -121,7 +121,7 @@ type tokenRing struct {
 	hosts       []*HostInfo
 }
 
-func newTokenRing(partitioner string, hosts []HostInfo) (*tokenRing, error) {
+func newTokenRing(partitioner string, hosts []*HostInfo) (*tokenRing, error) {
 	tokenRing := &tokenRing{
 		tokens: []token{},
 		hosts:  []*HostInfo{},
@@ -137,9 +137,8 @@ func newTokenRing(partitioner string, hosts []HostInfo) (*tokenRing, error) {
 		return nil, fmt.Errorf("Unsupported partitioner '%s'", partitioner)
 	}
 
-	for i := range hosts {
-		host := &hosts[i]
-		for _, strToken := range host.Tokens {
+	for _, host := range hosts {
+		for _, strToken := range host.Tokens() {
 			token := tokenRing.partitioner.ParseString(strToken)
 			tokenRing.tokens = append(tokenRing.tokens, token)
 			tokenRing.hosts = append(tokenRing.hosts, host)
@@ -181,7 +180,7 @@ func (t *tokenRing) String() string {
 		buf.WriteString("]")
 		buf.WriteString(t.tokens[i].String())
 		buf.WriteString(":")
-		buf.WriteString(t.hosts[i].Peer)
+		buf.WriteString(t.hosts[i].Peer())
 	}
 	buf.WriteString("\n}")
 	return string(buf.Bytes())

+ 75 - 75
token_test.go

@@ -215,22 +215,22 @@ func TestUnknownTokenRing(t *testing.T) {
 // Test of the tokenRing with the Murmur3Partitioner
 func TestMurmur3TokenRing(t *testing.T) {
 	// Note, strings are parsed directly to int64, they are not murmur3 hashed
-	var hosts []HostInfo = []HostInfo{
-		HostInfo{
-			Peer:   "0",
-			Tokens: []string{"0"},
+	hosts := []*HostInfo{
+		{
+			peer:   "0",
+			tokens: []string{"0"},
 		},
-		HostInfo{
-			Peer:   "1",
-			Tokens: []string{"25"},
+		{
+			peer:   "1",
+			tokens: []string{"25"},
 		},
-		HostInfo{
-			Peer:   "2",
-			Tokens: []string{"50"},
+		{
+			peer:   "2",
+			tokens: []string{"50"},
 		},
-		HostInfo{
-			Peer:   "3",
-			Tokens: []string{"75"},
+		{
+			peer:   "3",
+			tokens: []string{"75"},
 		},
 	}
 	ring, err := newTokenRing("Murmur3Partitioner", hosts)
@@ -242,33 +242,33 @@ func TestMurmur3TokenRing(t *testing.T) {
 
 	var actual *HostInfo
 	actual = ring.GetHostForToken(p.ParseString("0"))
-	if actual.Peer != "0" {
-		t.Errorf("Expected peer 0 for token \"0\", but was %s", actual.Peer)
+	if actual.Peer() != "0" {
+		t.Errorf("Expected peer 0 for token \"0\", but was %s", actual.Peer())
 	}
 
 	actual = ring.GetHostForToken(p.ParseString("25"))
-	if actual.Peer != "1" {
-		t.Errorf("Expected peer 1 for token \"25\", but was %s", actual.Peer)
+	if actual.Peer() != "1" {
+		t.Errorf("Expected peer 1 for token \"25\", but was %s", actual.Peer())
 	}
 
 	actual = ring.GetHostForToken(p.ParseString("50"))
-	if actual.Peer != "2" {
-		t.Errorf("Expected peer 2 for token \"50\", but was %s", actual.Peer)
+	if actual.Peer() != "2" {
+		t.Errorf("Expected peer 2 for token \"50\", but was %s", actual.Peer())
 	}
 
 	actual = ring.GetHostForToken(p.ParseString("75"))
-	if actual.Peer != "3" {
-		t.Errorf("Expected peer 3 for token \"01\", but was %s", actual.Peer)
+	if actual.Peer() != "3" {
+		t.Errorf("Expected peer 3 for token \"01\", but was %s", actual.Peer())
 	}
 
 	actual = ring.GetHostForToken(p.ParseString("12"))
-	if actual.Peer != "1" {
-		t.Errorf("Expected peer 1 for token \"12\", but was %s", actual.Peer)
+	if actual.Peer() != "1" {
+		t.Errorf("Expected peer 1 for token \"12\", but was %s", actual.Peer())
 	}
 
 	actual = ring.GetHostForToken(p.ParseString("24324545443332"))
-	if actual.Peer != "0" {
-		t.Errorf("Expected peer 0 for token \"24324545443332\", but was %s", actual.Peer)
+	if actual.Peer() != "0" {
+		t.Errorf("Expected peer 0 for token \"24324545443332\", but was %s", actual.Peer())
 	}
 }
 
@@ -276,28 +276,28 @@ func TestMurmur3TokenRing(t *testing.T) {
 func TestOrderedTokenRing(t *testing.T) {
 	// Tokens here more or less are similar layout to the int tokens above due
 	// to each numeric character translating to a consistently offset byte.
-	var hosts []HostInfo = []HostInfo{
-		HostInfo{
-			Peer: "0",
-			Tokens: []string{
+	hosts := []*HostInfo{
+		{
+			peer: "0",
+			tokens: []string{
 				"00",
 			},
 		},
-		HostInfo{
-			Peer: "1",
-			Tokens: []string{
+		{
+			peer: "1",
+			tokens: []string{
 				"25",
 			},
 		},
-		HostInfo{
-			Peer: "2",
-			Tokens: []string{
+		{
+			peer: "2",
+			tokens: []string{
 				"50",
 			},
 		},
-		HostInfo{
-			Peer: "3",
-			Tokens: []string{
+		{
+			peer: "3",
+			tokens: []string{
 				"75",
 			},
 		},
@@ -311,61 +311,61 @@ func TestOrderedTokenRing(t *testing.T) {
 
 	var actual *HostInfo
 	actual = ring.GetHostForToken(p.ParseString("0"))
-	if actual.Peer != "0" {
-		t.Errorf("Expected peer 0 for token \"0\", but was %s", actual.Peer)
+	if actual.Peer() != "0" {
+		t.Errorf("Expected peer 0 for token \"0\", but was %s", actual.Peer())
 	}
 
 	actual = ring.GetHostForToken(p.ParseString("25"))
-	if actual.Peer != "1" {
-		t.Errorf("Expected peer 1 for token \"25\", but was %s", actual.Peer)
+	if actual.Peer() != "1" {
+		t.Errorf("Expected peer 1 for token \"25\", but was %s", actual.Peer())
 	}
 
 	actual = ring.GetHostForToken(p.ParseString("50"))
-	if actual.Peer != "2" {
-		t.Errorf("Expected peer 2 for token \"50\", but was %s", actual.Peer)
+	if actual.Peer() != "2" {
+		t.Errorf("Expected peer 2 for token \"50\", but was %s", actual.Peer())
 	}
 
 	actual = ring.GetHostForToken(p.ParseString("75"))
-	if actual.Peer != "3" {
-		t.Errorf("Expected peer 3 for token \"01\", but was %s", actual.Peer)
+	if actual.Peer() != "3" {
+		t.Errorf("Expected peer 3 for token \"01\", but was %s", actual.Peer())
 	}
 
 	actual = ring.GetHostForToken(p.ParseString("12"))
-	if actual.Peer != "1" {
-		t.Errorf("Expected peer 1 for token \"12\", but was %s", actual.Peer)
+	if actual.Peer() != "1" {
+		t.Errorf("Expected peer 1 for token \"12\", but was %s", actual.Peer())
 	}
 
 	actual = ring.GetHostForToken(p.ParseString("24324545443332"))
-	if actual.Peer != "1" {
-		t.Errorf("Expected peer 1 for token \"24324545443332\", but was %s", actual.Peer)
+	if actual.Peer() != "1" {
+		t.Errorf("Expected peer 1 for token \"24324545443332\", but was %s", actual.Peer())
 	}
 }
 
 // Test of the tokenRing with the RandomPartitioner
 func TestRandomTokenRing(t *testing.T) {
 	// String tokens are parsed into big.Int in base 10
-	var hosts []HostInfo = []HostInfo{
-		HostInfo{
-			Peer: "0",
-			Tokens: []string{
+	hosts := []*HostInfo{
+		{
+			peer: "0",
+			tokens: []string{
 				"00",
 			},
 		},
-		HostInfo{
-			Peer: "1",
-			Tokens: []string{
+		{
+			peer: "1",
+			tokens: []string{
 				"25",
 			},
 		},
-		HostInfo{
-			Peer: "2",
-			Tokens: []string{
+		{
+			peer: "2",
+			tokens: []string{
 				"50",
 			},
 		},
-		HostInfo{
-			Peer: "3",
-			Tokens: []string{
+		{
+			peer: "3",
+			tokens: []string{
 				"75",
 			},
 		},
@@ -379,32 +379,32 @@ func TestRandomTokenRing(t *testing.T) {
 
 	var actual *HostInfo
 	actual = ring.GetHostForToken(p.ParseString("0"))
-	if actual.Peer != "0" {
-		t.Errorf("Expected peer 0 for token \"0\", but was %s", actual.Peer)
+	if actual.Peer() != "0" {
+		t.Errorf("Expected peer 0 for token \"0\", but was %s", actual.Peer())
 	}
 
 	actual = ring.GetHostForToken(p.ParseString("25"))
-	if actual.Peer != "1" {
-		t.Errorf("Expected peer 1 for token \"25\", but was %s", actual.Peer)
+	if actual.Peer() != "1" {
+		t.Errorf("Expected peer 1 for token \"25\", but was %s", actual.Peer())
 	}
 
 	actual = ring.GetHostForToken(p.ParseString("50"))
-	if actual.Peer != "2" {
-		t.Errorf("Expected peer 2 for token \"50\", but was %s", actual.Peer)
+	if actual.Peer() != "2" {
+		t.Errorf("Expected peer 2 for token \"50\", but was %s", actual.Peer())
 	}
 
 	actual = ring.GetHostForToken(p.ParseString("75"))
-	if actual.Peer != "3" {
-		t.Errorf("Expected peer 3 for token \"01\", but was %s", actual.Peer)
+	if actual.Peer() != "3" {
+		t.Errorf("Expected peer 3 for token \"01\", but was %s", actual.Peer())
 	}
 
 	actual = ring.GetHostForToken(p.ParseString("12"))
-	if actual.Peer != "1" {
-		t.Errorf("Expected peer 1 for token \"12\", but was %s", actual.Peer)
+	if actual.Peer() != "1" {
+		t.Errorf("Expected peer 1 for token \"12\", but was %s", actual.Peer())
 	}
 
 	actual = ring.GetHostForToken(p.ParseString("24324545443332"))
-	if actual.Peer != "0" {
-		t.Errorf("Expected peer 0 for token \"24324545443332\", but was %s", actual.Peer)
+	if actual.Peer() != "0" {
+		t.Errorf("Expected peer 0 for token \"24324545443332\", but was %s", actual.Peer())
 	}
 }