Browse Source

Merge pull request #480 from Zariel/remove-simple-pool

Remove SimpleConnPool, remove custom pools
Chris Bannister 10 years ago
parent
commit
139a10f05e
11 changed files with 121 additions and 517 deletions
  1. 1 1
      cass1batch_test.go
  2. 13 11
      cassandra_test.go
  3. 30 2
      cluster.go
  4. 7 6
      conn_test.go
  5. 14 437
      connectionpool.go
  6. 6 11
      host_source.go
  7. 16 11
      policies.go
  8. 3 3
      policies_test.go
  9. 26 29
      session.go
  10. 4 4
      session_test.go
  11. 1 2
      stress_test.go

+ 1 - 1
cass1batch_test.go

@@ -10,7 +10,7 @@ import (
 func TestProto1BatchInsert(t *testing.T) {
 func TestProto1BatchInsert(t *testing.T) {
 	session := createSession(t)
 	session := createSession(t)
 	if err := createTable(session, "CREATE TABLE large (id int primary key)"); err != nil {
 	if err := createTable(session, "CREATE TABLE large (id int primary key)"); err != nil {
-		t.Fatal("create table:", err)
+		t.Fatal(err)
 	}
 	}
 	defer session.Close()
 	defer session.Close()
 
 

+ 13 - 11
cassandra_test.go

@@ -96,7 +96,7 @@ func createKeyspace(tb testing.TB, cluster *ClusterConfig, keyspace string) {
 	}
 	}
 
 
 	// should reuse the same conn apparently
 	// should reuse the same conn apparently
-	conn := session.Pool.Pick(nil)
+	conn := session.pool.Pick(nil)
 	if conn == nil {
 	if conn == nil {
 		tb.Fatal("no connections available in the pool")
 		tb.Fatal("no connections available in the pool")
 	}
 	}
@@ -189,7 +189,9 @@ func TestRingDiscovery(t *testing.T) {
 		time.Sleep(*flagAutoWait)
 		time.Sleep(*flagAutoWait)
 	}
 	}
 
 
-	size := len(session.Pool.(*SimplePool).connPool)
+	session.pool.mu.RLock()
+	size := len(session.pool.hostConnPools)
+	session.pool.mu.RUnlock()
 
 
 	if *clusterSize != size {
 	if *clusterSize != size {
 		t.Logf("WARN: Expected a cluster size of %d, but actual size was %d", *clusterSize, size)
 		t.Logf("WARN: Expected a cluster size of %d, but actual size was %d", *clusterSize, size)
@@ -1139,7 +1141,7 @@ func injectInvalidPreparedStatement(t *testing.T, session *Session, table string
 		t.Fatal("create:", err)
 		t.Fatal("create:", err)
 	}
 	}
 	stmt := "INSERT INTO " + table + " (foo, bar) VALUES (?, 7)"
 	stmt := "INSERT INTO " + table + " (foo, bar) VALUES (?, 7)"
-	conn := session.Pool.Pick(nil)
+	conn := session.pool.Pick(nil)
 	flight := new(inflightPrepare)
 	flight := new(inflightPrepare)
 	stmtsLRU.Lock()
 	stmtsLRU.Lock()
 	stmtsLRU.lru.Add(conn.addr+stmt, flight)
 	stmtsLRU.lru.Add(conn.addr+stmt, flight)
@@ -1165,7 +1167,7 @@ func injectInvalidPreparedStatement(t *testing.T, session *Session, table string
 
 
 func TestMissingSchemaPrepare(t *testing.T) {
 func TestMissingSchemaPrepare(t *testing.T) {
 	s := createSession(t)
 	s := createSession(t)
-	conn := s.Pool.Pick(nil)
+	conn := s.pool.Pick(nil)
 	defer s.Close()
 	defer s.Close()
 
 
 	insertQry := &Query{stmt: "INSERT INTO invalidschemaprep (val) VALUES (?)", values: []interface{}{5}, cons: s.cons,
 	insertQry := &Query{stmt: "INSERT INTO invalidschemaprep (val) VALUES (?)", values: []interface{}{5}, cons: s.cons,
@@ -1214,7 +1216,7 @@ func TestQueryInfo(t *testing.T) {
 	session := createSession(t)
 	session := createSession(t)
 	defer session.Close()
 	defer session.Close()
 
 
-	conn := session.Pool.Pick(nil)
+	conn := session.pool.Pick(nil)
 	info, err := conn.prepareStatement("SELECT release_version, host_id FROM system.local WHERE key = ?", nil)
 	info, err := conn.prepareStatement("SELECT release_version, host_id FROM system.local WHERE key = ?", nil)
 
 
 	if err != nil {
 	if err != nil {
@@ -2016,7 +2018,7 @@ func TestRoutingKey(t *testing.T) {
 // Integration test of the token-aware policy-based connection pool
 // Integration test of the token-aware policy-based connection pool
 func TestTokenAwareConnPool(t *testing.T) {
 func TestTokenAwareConnPool(t *testing.T) {
 	cluster := createCluster()
 	cluster := createCluster()
-	cluster.ConnPoolType = NewTokenAwareConnPool
+	cluster.PoolConfig.HostSelectionPolicy = TokenAwareHostPolicy(RoundRobinHostPolicy())
 	cluster.DiscoverHosts = true
 	cluster.DiscoverHosts = true
 
 
 	// Drop and re-create the keyspace once. Different tests should use their own
 	// Drop and re-create the keyspace once. Different tests should use their own
@@ -2037,8 +2039,8 @@ func TestTokenAwareConnPool(t *testing.T) {
 		time.Sleep(*flagAutoWait)
 		time.Sleep(*flagAutoWait)
 	}
 	}
 
 
-	if session.Pool.Size() != cluster.NumConns*len(cluster.Hosts) {
-		t.Errorf("Expected pool size %d but was %d", cluster.NumConns*len(cluster.Hosts), session.Pool.Size())
+	if session.pool.Size() != cluster.NumConns*len(cluster.Hosts) {
+		t.Errorf("Expected pool size %d but was %d", cluster.NumConns*len(cluster.Hosts), session.pool.Size())
 	}
 	}
 
 
 	if err := createTable(session, "CREATE TABLE test_token_aware (id int, data text, PRIMARY KEY (id))"); err != nil {
 	if err := createTable(session, "CREATE TABLE test_token_aware (id int, data text, PRIMARY KEY (id))"); err != nil {
@@ -2077,7 +2079,7 @@ func TestStream0(t *testing.T) {
 			break
 			break
 		}
 		}
 
 
-		conn = session.Pool.Pick(nil)
+		conn = session.pool.Pick(nil)
 	}
 	}
 
 
 	if conn == nil {
 	if conn == nil {
@@ -2116,7 +2118,7 @@ func TestNegativeStream(t *testing.T) {
 			break
 			break
 		}
 		}
 
 
-		conn = session.Pool.Pick(nil)
+		conn = session.pool.Pick(nil)
 	}
 	}
 
 
 	if conn == nil {
 	if conn == nil {
@@ -2222,7 +2224,7 @@ func TestLexicalUUIDType(t *testing.T) {
 // Issue 475
 // Issue 475
 func TestSessionBindRoutingKey(t *testing.T) {
 func TestSessionBindRoutingKey(t *testing.T) {
 	cluster := createCluster()
 	cluster := createCluster()
-	cluster.ConnPoolType = NewTokenAwareConnPool
+	cluster.PoolConfig.HostSelectionPolicy = TokenAwareHostPolicy(RoundRobinHostPolicy())
 
 
 	session := createSessionFromCluster(cluster, t)
 	session := createSessionFromCluster(cluster, t)
 	defer session.Close()
 	defer session.Close()

+ 30 - 2
cluster.go

@@ -50,6 +50,33 @@ type DiscoveryConfig struct {
 	Sleep time.Duration
 	Sleep time.Duration
 }
 }
 
 
+// PoolConfig configures the connection pool used by the driver, it defaults to
+// using a round robbin host selection policy and a round robbin connection selection
+// policy for each host.
+type PoolConfig struct {
+	// HostSelectionPolicy sets the policy for selecting which host to use for a
+	// given query (default: RoundRobinHostPolicy())
+	HostSelectionPolicy HostSelectionPolicy
+
+	// ConnSelectionPolicy sets the policy factory for selecting a connection to use for
+	// each host for a query (default: RoundRobinConnPolicy())
+	ConnSelectionPolicy func() ConnSelectionPolicy
+}
+
+func (p PoolConfig) buildPool(cfg *ClusterConfig) (*policyConnPool, error) {
+	hostSelection := p.HostSelectionPolicy
+	if hostSelection == nil {
+		hostSelection = RoundRobinHostPolicy()
+	}
+
+	connSelection := p.ConnSelectionPolicy
+	if connSelection == nil {
+		connSelection = RoundRobinConnPolicy()
+	}
+
+	return newPolicyConnPool(cfg, hostSelection, connSelection)
+}
+
 // ClusterConfig is a struct to configure the default cluster implementation
 // ClusterConfig is a struct to configure the default cluster implementation
 // of gocoql. It has a varity of attributes that can be used to modify the
 // of gocoql. It has a varity of attributes that can be used to modify the
 // behavior to fit the most common use cases. Applications that requre a
 // behavior to fit the most common use cases. Applications that requre a
@@ -68,7 +95,6 @@ type ClusterConfig struct {
 	Authenticator     Authenticator     // authenticator (default: nil)
 	Authenticator     Authenticator     // authenticator (default: nil)
 	RetryPolicy       RetryPolicy       // Default retry policy to use for queries (default: 0)
 	RetryPolicy       RetryPolicy       // Default retry policy to use for queries (default: 0)
 	SocketKeepalive   time.Duration     // The keepalive period to use, enabled if > 0 (default: 0)
 	SocketKeepalive   time.Duration     // The keepalive period to use, enabled if > 0 (default: 0)
-	ConnPoolType      NewPoolFunc       // The function used to create the connection pool for the session (default: NewSimplePool)
 	DiscoverHosts     bool              // If set, gocql will attempt to automatically discover other members of the Cassandra cluster (default: false)
 	DiscoverHosts     bool              // If set, gocql will attempt to automatically discover other members of the Cassandra cluster (default: false)
 	MaxPreparedStmts  int               // Sets the maximum cache size for prepared statements globally for gocql (default: 1000)
 	MaxPreparedStmts  int               // Sets the maximum cache size for prepared statements globally for gocql (default: 1000)
 	MaxRoutingKeyInfo int               // Sets the maximum cache size for query info about statements for each session (default: 1000)
 	MaxRoutingKeyInfo int               // Sets the maximum cache size for query info about statements for each session (default: 1000)
@@ -77,6 +103,9 @@ type ClusterConfig struct {
 	Discovery         DiscoveryConfig
 	Discovery         DiscoveryConfig
 	SslOpts           *SslOptions
 	SslOpts           *SslOptions
 	DefaultTimestamp  bool // Sends a client side timestamp for all requests which overrides the timestamp at which it arrives at the server. (default: true, only enabled for protocol 3 and above)
 	DefaultTimestamp  bool // Sends a client side timestamp for all requests which overrides the timestamp at which it arrives at the server. (default: true, only enabled for protocol 3 and above)
+	// PoolConfig configures the underlying connection pool, allowing the
+	// configuration of host selection and connection selection policies.
+	PoolConfig PoolConfig
 }
 }
 
 
 // NewCluster generates a new config for the default cluster implementation.
 // NewCluster generates a new config for the default cluster implementation.
@@ -89,7 +118,6 @@ func NewCluster(hosts ...string) *ClusterConfig {
 		Port:              9042,
 		Port:              9042,
 		NumConns:          2,
 		NumConns:          2,
 		Consistency:       Quorum,
 		Consistency:       Quorum,
-		ConnPoolType:      NewSimplePool,
 		DiscoverHosts:     false,
 		DiscoverHosts:     false,
 		MaxPreparedStmts:  defaultMaxPreparedStmts,
 		MaxPreparedStmts:  defaultMaxPreparedStmts,
 		MaxRoutingKeyInfo: 1000,
 		MaxRoutingKeyInfo: 1000,

+ 7 - 6
conn_test.go

@@ -259,8 +259,7 @@ func TestConnClosing(t *testing.T) {
 	wg.Wait()
 	wg.Wait()
 
 
 	time.Sleep(1 * time.Second) //Sleep so the fillPool can complete.
 	time.Sleep(1 * time.Second) //Sleep so the fillPool can complete.
-	pool := db.Pool.(ConnectionPool)
-	conns := pool.Size()
+	conns := db.pool.Size()
 
 
 	if conns != numConns {
 	if conns != numConns {
 		t.Errorf("Expected to have %d connections but have %d", numConns, conns)
 		t.Errorf("Expected to have %d connections but have %d", numConns, conns)
@@ -390,7 +389,8 @@ func TestRoundRobinConnPoolRoundRobin(t *testing.T) {
 
 
 	// create a new cluster using the policy-based round robin conn pool
 	// create a new cluster using the policy-based round robin conn pool
 	cluster := NewCluster(addrs...)
 	cluster := NewCluster(addrs...)
-	cluster.ConnPoolType = NewRoundRobinConnPool
+	cluster.PoolConfig.HostSelectionPolicy = RoundRobinHostPolicy()
+	cluster.PoolConfig.ConnSelectionPolicy = RoundRobinConnPolicy()
 
 
 	db, err := cluster.CreateSession()
 	db, err := cluster.CreateSession()
 	if err != nil {
 	if err != nil {
@@ -420,7 +420,7 @@ func TestRoundRobinConnPoolRoundRobin(t *testing.T) {
 
 
 	// wait for the pool to drain
 	// wait for the pool to drain
 	time.Sleep(100 * time.Millisecond)
 	time.Sleep(100 * time.Millisecond)
-	size := db.Pool.Size()
+	size := db.pool.Size()
 	if size != 0 {
 	if size != 0 {
 		t.Errorf("connection pool did not drain, still contains %d connections", size)
 		t.Errorf("connection pool did not drain, still contains %d connections", size)
 	}
 	}
@@ -450,7 +450,8 @@ func TestPolicyConnPoolSSL(t *testing.T) {
 	defer srv.Stop()
 	defer srv.Stop()
 
 
 	cluster := createTestSslCluster(srv.Address, defaultProto, true)
 	cluster := createTestSslCluster(srv.Address, defaultProto, true)
-	cluster.ConnPoolType = NewRoundRobinConnPool
+	cluster.PoolConfig.HostSelectionPolicy = RoundRobinHostPolicy()
+	cluster.PoolConfig.ConnSelectionPolicy = RoundRobinConnPolicy()
 
 
 	db, err := cluster.CreateSession()
 	db, err := cluster.CreateSession()
 	if err != nil {
 	if err != nil {
@@ -465,7 +466,7 @@ func TestPolicyConnPoolSSL(t *testing.T) {
 
 
 	// wait for the pool to drain
 	// wait for the pool to drain
 	time.Sleep(100 * time.Millisecond)
 	time.Sleep(100 * time.Millisecond)
-	size := db.Pool.Size()
+	size := db.pool.Size()
 	if size != 0 {
 	if size != 0 {
 		t.Errorf("connection pool did not drain, still contains %d connections", size)
 		t.Errorf("connection pool did not drain, still contains %d connections", size)
 	}
 	}

+ 14 - 437
connectionpool.go

@@ -17,89 +17,6 @@ import (
 	"time"
 	"time"
 )
 )
 
 
-/*ConnectionPool represents the interface gocql will use to work with a collection of connections.
-
-Purpose
-
-The connection pool in gocql opens and closes connections as well as selects an available connection
-for gocql to execute a query against. The pool is also responsible for handling connection errors that
-are caught by the connection experiencing the error.
-
-A connection pool should make a copy of the variables used from the ClusterConfig provided to the pool
-upon creation. ClusterConfig is a pointer and can be modified after the creation of the pool. This can
-lead to issues with variables being modified outside the expectations of the ConnectionPool type.
-
-Example of Single Connection Pool:
-
-	type SingleConnection struct {
-		conn *Conn
-		cfg *ClusterConfig
-	}
-
-	func NewSingleConnection(cfg *ClusterConfig) ConnectionPool {
-		addr := JoinHostPort(cfg.Hosts[0], cfg.Port)
-
-		connCfg := ConnConfig{
-			ProtoVersion:  cfg.ProtoVersion,
-			CQLVersion:    cfg.CQLVersion,
-			Timeout:       cfg.Timeout,
-			NumStreams:    cfg.NumStreams,
-			Compressor:    cfg.Compressor,
-			Authenticator: cfg.Authenticator,
-			Keepalive:     cfg.SocketKeepalive,
-		}
-		pool := SingleConnection{cfg:cfg}
-		pool.conn = Connect(addr,connCfg,pool)
-		return &pool
-	}
-
-	func (s *SingleConnection) HandleError(conn *Conn, err error, closed bool) {
-		if closed {
-			connCfg := ConnConfig{
-				ProtoVersion:  cfg.ProtoVersion,
-				CQLVersion:    cfg.CQLVersion,
-				Timeout:       cfg.Timeout,
-				NumStreams:    cfg.NumStreams,
-				Compressor:    cfg.Compressor,
-				Authenticator: cfg.Authenticator,
-				Keepalive:     cfg.SocketKeepalive,
-			}
-			s.conn = Connect(conn.Address(),connCfg,s)
-		}
-	}
-
-	func (s *SingleConnection) Pick(qry *Query) *Conn {
-		if s.conn.isClosed {
-			return nil
-		}
-		return s.conn
-	}
-
-	func (s *SingleConnection) Size() int {
-		return 1
-	}
-
-	func (s *SingleConnection) Close() {
-		s.conn.Close()
-	}
-
-This is a very simple example of a type that exposes the connection pool interface. To assign
-this type as the connection pool to use you would assign it to the ClusterConfig like so:
-
-		cluster := NewCluster("127.0.0.1")
-		cluster.ConnPoolType = NewSingleConnection
-		...
-		session, err := cluster.CreateSession()
-
-To see a more complete example of a ConnectionPool implementation please see the SimplePool type.
-*/
-type ConnectionPool interface {
-	SetHosts
-	Pick(*Query) *Conn
-	Size() int
-	Close()
-}
-
 // interface to implement to receive the host information
 // interface to implement to receive the host information
 type SetHosts interface {
 type SetHosts interface {
 	SetHosts(hosts []HostInfo)
 	SetHosts(hosts []HostInfo)
@@ -110,35 +27,6 @@ type SetPartitioner interface {
 	SetPartitioner(partitioner string)
 	SetPartitioner(partitioner string)
 }
 }
 
 
-//NewPoolFunc is the type used by ClusterConfig to create a pool of a specific type.
-type NewPoolFunc func(*ClusterConfig) (ConnectionPool, error)
-
-//SimplePool is the current implementation of the connection pool inside gocql. This
-//pool is meant to be a simple default used by gocql so users can get up and running
-//quickly.
-type SimplePool struct {
-	cfg      *ClusterConfig
-	hostPool *RoundRobin
-	connPool map[string]*RoundRobin
-	conns    map[*Conn]struct{}
-	keyspace string
-
-	hostMu sync.RWMutex
-	// this is the set of current hosts which the pool will attempt to connect to
-	hosts map[string]*HostInfo
-
-	// protects hostpool, connPoll, conns, quit
-	mu sync.Mutex
-
-	cFillingPool chan int
-
-	quit     bool
-	quitWait chan bool
-	quitOnce sync.Once
-
-	tlsConfig *tls.Config
-}
-
 func setupTLSConfig(sslOpts *SslOptions) (*tls.Config, error) {
 func setupTLSConfig(sslOpts *SslOptions) (*tls.Config, error) {
 	// ca cert is optional
 	// ca cert is optional
 	if sslOpts.CaPath != "" {
 	if sslOpts.CaPath != "" {
@@ -169,309 +57,6 @@ func setupTLSConfig(sslOpts *SslOptions) (*tls.Config, error) {
 	return &sslOpts.Config, nil
 	return &sslOpts.Config, nil
 }
 }
 
 
-//NewSimplePool is the function used by gocql to create the simple connection pool.
-//This is the default if no other pool type is specified.
-func NewSimplePool(cfg *ClusterConfig) (ConnectionPool, error) {
-	pool := &SimplePool{
-		cfg:          cfg,
-		hostPool:     NewRoundRobin(),
-		connPool:     make(map[string]*RoundRobin),
-		conns:        make(map[*Conn]struct{}),
-		quitWait:     make(chan bool),
-		cFillingPool: make(chan int, 1),
-		keyspace:     cfg.Keyspace,
-		hosts:        make(map[string]*HostInfo),
-	}
-
-	for _, host := range cfg.Hosts {
-		// seed hosts have unknown topology
-		// TODO: Handle populating this during SetHosts
-		pool.hosts[host] = &HostInfo{Peer: host}
-	}
-
-	if cfg.SslOpts != nil {
-		config, err := setupTLSConfig(cfg.SslOpts)
-		if err != nil {
-			return nil, err
-		}
-		pool.tlsConfig = config
-	}
-
-	//Walk through connecting to hosts. As soon as one host connects
-	//defer the remaining connections to cluster.fillPool()
-	for i := 0; i < len(cfg.Hosts); i++ {
-		addr := JoinHostPort(cfg.Hosts[i], cfg.Port)
-
-		if pool.connect(addr) == nil {
-			pool.cFillingPool <- 1
-			go pool.fillPool()
-			break
-		}
-	}
-
-	return pool, nil
-}
-
-func (c *SimplePool) connect(addr string) error {
-
-	cfg := ConnConfig{
-		ProtoVersion:  c.cfg.ProtoVersion,
-		CQLVersion:    c.cfg.CQLVersion,
-		Timeout:       c.cfg.Timeout,
-		NumStreams:    c.cfg.NumStreams,
-		Compressor:    c.cfg.Compressor,
-		Authenticator: c.cfg.Authenticator,
-		Keepalive:     c.cfg.SocketKeepalive,
-		tlsConfig:     c.tlsConfig,
-	}
-
-	conn, err := Connect(addr, cfg, c)
-	if err != nil {
-		log.Printf("connect: failed to connect to %q: %v", addr, err)
-		return err
-	}
-
-	return c.addConn(conn)
-}
-
-func (c *SimplePool) addConn(conn *Conn) error {
-	c.mu.Lock()
-	defer c.mu.Unlock()
-	if c.quit {
-		conn.Close()
-		return nil
-	}
-
-	//Set the connection's keyspace if any before adding it to the pool
-	if c.keyspace != "" {
-		if err := conn.UseKeyspace(c.keyspace); err != nil {
-			log.Printf("error setting connection keyspace. %v", err)
-			conn.Close()
-			return err
-		}
-	}
-
-	connPool := c.connPool[conn.Address()]
-	if connPool == nil {
-		connPool = NewRoundRobin()
-		c.connPool[conn.Address()] = connPool
-		c.hostPool.AddNode(connPool)
-	}
-
-	connPool.AddNode(conn)
-	c.conns[conn] = struct{}{}
-
-	return nil
-}
-
-//fillPool manages the pool of connections making sure that each host has the correct
-//amount of connections defined. Also the method will test a host with one connection
-//instead of flooding the host with number of connections defined in the cluster config
-func (c *SimplePool) fillPool() {
-	//Debounce large amounts of requests to fill pool
-	select {
-	case <-time.After(1 * time.Millisecond):
-		return
-	case <-c.cFillingPool:
-		defer func() { c.cFillingPool <- 1 }()
-	}
-
-	c.mu.Lock()
-	isClosed := c.quit
-	c.mu.Unlock()
-	//Exit if cluster(session) is closed
-	if isClosed {
-		return
-	}
-
-	c.hostMu.RLock()
-
-	//Walk through list of defined hosts
-	var wg sync.WaitGroup
-	for host := range c.hosts {
-		addr := JoinHostPort(host, c.cfg.Port)
-
-		numConns := 1
-		//See if the host already has connections in the pool
-		c.mu.Lock()
-		conns, ok := c.connPool[addr]
-		c.mu.Unlock()
-
-		if ok {
-			//if the host has enough connections just exit
-			numConns = conns.Size()
-			if numConns >= c.cfg.NumConns {
-				continue
-			}
-		} else {
-			//See if the host is reachable
-			if err := c.connect(addr); err != nil {
-				continue
-			}
-		}
-
-		//This is reached if the host is responsive and needs more connections
-		//Create connections for host synchronously to mitigate flooding the host.
-		wg.Add(1)
-		go func(a string, conns int) {
-			defer wg.Done()
-			for ; conns < c.cfg.NumConns; conns++ {
-				c.connect(a)
-			}
-		}(addr, numConns)
-	}
-
-	c.hostMu.RUnlock()
-
-	//Wait until we're finished connecting to each host before returning
-	wg.Wait()
-}
-
-// Should only be called if c.mu is locked
-func (c *SimplePool) removeConnLocked(conn *Conn) {
-	conn.Close()
-	connPool := c.connPool[conn.addr]
-	if connPool == nil {
-		return
-	}
-	connPool.RemoveNode(conn)
-	if connPool.Size() == 0 {
-		c.hostPool.RemoveNode(connPool)
-		delete(c.connPool, conn.addr)
-	}
-	delete(c.conns, conn)
-}
-
-func (c *SimplePool) removeConn(conn *Conn) {
-	c.mu.Lock()
-	defer c.mu.Unlock()
-	c.removeConnLocked(conn)
-}
-
-//HandleError is called by a Connection object to report to the pool an error has occured.
-//Logic is then executed within the pool to clean up the erroroneous connection and try to
-//top off the pool.
-func (c *SimplePool) HandleError(conn *Conn, err error, closed bool) {
-	if !closed {
-		// ignore all non-fatal errors
-		return
-	}
-	c.removeConn(conn)
-	c.mu.Lock()
-	poolClosed := c.quit
-	c.mu.Unlock()
-	if !poolClosed {
-		go c.fillPool() // top off pool.
-	}
-}
-
-//Pick selects a connection to be used by the query.
-func (c *SimplePool) Pick(qry *Query) *Conn {
-	//Check if connections are available
-	c.mu.Lock()
-	conns := len(c.conns)
-	c.mu.Unlock()
-
-	if conns == 0 {
-		//try to populate the pool before returning.
-		c.fillPool()
-	}
-
-	return c.hostPool.Pick(qry)
-}
-
-//Size returns the number of connections currently active in the pool
-func (p *SimplePool) Size() int {
-	p.mu.Lock()
-	conns := len(p.conns)
-	p.mu.Unlock()
-	return conns
-}
-
-//Close kills the pool and all associated connections.
-func (c *SimplePool) Close() {
-	c.quitOnce.Do(func() {
-		c.mu.Lock()
-		defer c.mu.Unlock()
-		c.quit = true
-		close(c.quitWait)
-		for conn := range c.conns {
-			c.removeConnLocked(conn)
-		}
-	})
-}
-
-func (c *SimplePool) SetHosts(hosts []HostInfo) {
-
-	c.hostMu.Lock()
-	toRemove := make(map[string]struct{})
-	for k := range c.hosts {
-		toRemove[k] = struct{}{}
-	}
-
-	for _, host := range hosts {
-		host := host
-		delete(toRemove, host.Peer)
-		// we already have it
-		if _, ok := c.hosts[host.Peer]; ok {
-			// TODO: Check rack, dc, token range is consistent, trigger topology change
-			// update stored host
-			continue
-		}
-
-		c.hosts[host.Peer] = &host
-	}
-
-	// can we hold c.mu whilst iterating this loop?
-	for addr := range toRemove {
-		c.removeHostLocked(addr)
-	}
-	c.hostMu.Unlock()
-
-	c.fillPool()
-}
-
-func (c *SimplePool) removeHostLocked(addr string) {
-	if _, ok := c.hosts[addr]; !ok {
-		return
-	}
-	delete(c.hosts, addr)
-
-	c.mu.Lock()
-	defer c.mu.Unlock()
-
-	if _, ok := c.connPool[addr]; !ok {
-		return
-	}
-
-	for conn := range c.conns {
-		if conn.Address() == addr {
-			c.removeConnLocked(conn)
-		}
-	}
-}
-
-//NewRoundRobinConnPool creates a connection pool which selects hosts by
-//round-robin, and then selects a connection for that host by round-robin.
-func NewRoundRobinConnPool(cfg *ClusterConfig) (ConnectionPool, error) {
-	return NewPolicyConnPool(
-		cfg,
-		NewRoundRobinHostPolicy(),
-		NewRoundRobinConnPolicy,
-	)
-}
-
-//NewTokenAwareConnPool creates a connection pool which selects hosts by
-//a token aware policy, and then selects a connection for that host by
-//round-robin.
-func NewTokenAwareConnPool(cfg *ClusterConfig) (ConnectionPool, error) {
-	return NewPolicyConnPool(
-		cfg,
-		NewTokenAwareHostPolicy(NewRoundRobinHostPolicy()),
-		NewRoundRobinConnPolicy,
-	)
-}
-
 type policyConnPool struct {
 type policyConnPool struct {
 	port     int
 	port     int
 	numConns int
 	numConns int
@@ -484,18 +69,13 @@ type policyConnPool struct {
 	hostConnPools map[string]*hostConnPool
 	hostConnPools map[string]*hostConnPool
 }
 }
 
 
-//Creates a policy based connection pool. This func isn't meant to be directly
-//used as a NewPoolFunc in ClusterConfig, instead a func should be created
-//which satisfies the NewPoolFunc type, which calls this func with the desired
-//hostPolicy and connPolicy; see NewRoundRobinConnPool or NewTokenAwareConnPool
-//for examples.
-func NewPolicyConnPool(
-	cfg *ClusterConfig,
-	hostPolicy HostSelectionPolicy,
-	connPolicy func() ConnSelectionPolicy,
-) (ConnectionPool, error) {
-	var err error
-	var tlsConfig *tls.Config
+func newPolicyConnPool(cfg *ClusterConfig, hostPolicy HostSelectionPolicy,
+	connPolicy func() ConnSelectionPolicy) (*policyConnPool, error) {
+
+	var (
+		err       error
+		tlsConfig *tls.Config
+	)
 
 
 	if cfg.SslOpts != nil {
 	if cfg.SslOpts != nil {
 		tlsConfig, err = setupTLSConfig(cfg.SslOpts)
 		tlsConfig, err = setupTLSConfig(cfg.SslOpts)
@@ -594,9 +174,12 @@ func (p *policyConnPool) Size() int {
 func (p *policyConnPool) Pick(qry *Query) *Conn {
 func (p *policyConnPool) Pick(qry *Query) *Conn {
 	nextHost := p.hostPolicy.Pick(qry)
 	nextHost := p.hostPolicy.Pick(qry)
 
 
+	var (
+		host *HostInfo
+		conn *Conn
+	)
+
 	p.mu.RLock()
 	p.mu.RLock()
-	var host *HostInfo
-	var conn *Conn
 	for conn == nil {
 	for conn == nil {
 		host = nextHost()
 		host = nextHost()
 		if host == nil {
 		if host == nil {
@@ -639,14 +222,8 @@ type hostConnPool struct {
 	filling bool
 	filling bool
 }
 }
 
 
-func newHostConnPool(
-	host string,
-	port int,
-	size int,
-	connCfg ConnConfig,
-	keyspace string,
-	policy ConnSelectionPolicy,
-) *hostConnPool {
+func newHostConnPool(host string, port int, size int, connCfg ConnConfig,
+	keyspace string, policy ConnSelectionPolicy) *hostConnPool {
 
 
 	pool := &hostConnPool{
 	pool := &hostConnPool{
 		host:     host,
 		host:     host,

+ 6 - 11
host_source.go

@@ -24,14 +24,10 @@ type ringDescriber struct {
 	closeChan       chan bool
 	closeChan       chan bool
 }
 }
 
 
-func (r *ringDescriber) GetHosts() (
-	hosts []HostInfo,
-	partitioner string,
-	err error,
-) {
+func (r *ringDescriber) GetHosts() (hosts []HostInfo, partitioner string, err error) {
 	// we need conn to be the same because we need to query system.peers and system.local
 	// we need conn to be the same because we need to query system.peers and system.local
 	// on the same node to get the whole cluster
 	// on the same node to get the whole cluster
-	conn := r.session.Pool.Pick(nil)
+	conn := r.session.pool.Pick(nil)
 	if conn == nil {
 	if conn == nil {
 		return r.prevHosts, r.prevPartitioner, nil
 		return r.prevHosts, r.prevPartitioner, nil
 	}
 	}
@@ -106,12 +102,11 @@ func (h *ringDescriber) run(sleep time.Duration) {
 			hosts, partitioner, err := h.GetHosts()
 			hosts, partitioner, err := h.GetHosts()
 			if err != nil {
 			if err != nil {
 				log.Println("RingDescriber: unable to get ring topology:", err)
 				log.Println("RingDescriber: unable to get ring topology:", err)
-			} else {
-				h.session.Pool.SetHosts(hosts)
-				if v, ok := h.session.Pool.(SetPartitioner); ok {
-					v.SetPartitioner(partitioner)
-				}
+				continue
 			}
 			}
+
+			h.session.pool.SetHosts(hosts)
+			h.session.pool.SetPartitioner(partitioner)
 		case <-h.closeChan:
 		case <-h.closeChan:
 			return
 			return
 		}
 		}

+ 16 - 11
policies.go

@@ -10,8 +10,8 @@ import (
 	"sync/atomic"
 	"sync/atomic"
 )
 )
 
 
-//RetryableQuery is an interface that represents a query or batch statement that
-//exposes the correct functions for the retry policy logic to evaluate correctly.
+// RetryableQuery is an interface that represents a query or batch statement that
+// exposes the correct functions for the retry policy logic to evaluate correctly.
 type RetryableQuery interface {
 type RetryableQuery interface {
 	Attempts() int
 	Attempts() int
 	GetConsistency() Consistency
 	GetConsistency() Consistency
@@ -48,8 +48,8 @@ func (s *SimpleRetryPolicy) Attempt(q RetryableQuery) bool {
 	return q.Attempts() <= s.NumRetries
 	return q.Attempts() <= s.NumRetries
 }
 }
 
 
-//HostSelectionPolicy is an interface for selecting
-//the most appropriate host to execute a given query.
+// HostSelectionPolicy is an interface for selecting
+// the most appropriate host to execute a given query.
 type HostSelectionPolicy interface {
 type HostSelectionPolicy interface {
 	SetHosts
 	SetHosts
 	SetPartitioner
 	SetPartitioner
@@ -57,11 +57,12 @@ type HostSelectionPolicy interface {
 	Pick(*Query) NextHost
 	Pick(*Query) NextHost
 }
 }
 
 
-//NextHost is an iteration function over picked hosts
+// NextHost is an iteration function over picked hosts
 type NextHost func() *HostInfo
 type NextHost func() *HostInfo
 
 
-//NewRoundRobinHostPolicy is a round-robin load balancing policy
-func NewRoundRobinHostPolicy() HostSelectionPolicy {
+// RoundRobinHostPolicy is a round-robin load balancing policy, where each host
+// is tried sequentially for each query.
+func RoundRobinHostPolicy() HostSelectionPolicy {
 	return &roundRobinHostPolicy{hosts: []HostInfo{}}
 	return &roundRobinHostPolicy{hosts: []HostInfo{}}
 }
 }
 
 
@@ -105,8 +106,10 @@ func (r *roundRobinHostPolicy) Pick(qry *Query) NextHost {
 	}
 	}
 }
 }
 
 
-//NewTokenAwareHostPolicy is a token aware host selection policy
-func NewTokenAwareHostPolicy(fallback HostSelectionPolicy) HostSelectionPolicy {
+// TokenAwareHostPolicy is a token aware host selection policy, where hosts are
+// selected based on the partition key, so queries are sent to the host which
+// owns the partition. Fallback is used when routing information is not available.
+func TokenAwareHostPolicy(fallback HostSelectionPolicy) HostSelectionPolicy {
 	return &tokenAwareHostPolicy{fallback: fallback, hosts: []HostInfo{}}
 	return &tokenAwareHostPolicy{fallback: fallback, hosts: []HostInfo{}}
 }
 }
 
 
@@ -227,8 +230,10 @@ type roundRobinConnPolicy struct {
 	mu    sync.RWMutex
 	mu    sync.RWMutex
 }
 }
 
 
-func NewRoundRobinConnPolicy() ConnSelectionPolicy {
-	return &roundRobinConnPolicy{}
+func RoundRobinConnPolicy() func() ConnSelectionPolicy {
+	return func() ConnSelectionPolicy {
+		return &roundRobinConnPolicy{}
+	}
 }
 }
 
 
 func (r *roundRobinConnPolicy) SetConns(conns []*Conn) {
 func (r *roundRobinConnPolicy) SetConns(conns []*Conn) {

+ 3 - 3
policies_test.go

@@ -8,7 +8,7 @@ import "testing"
 
 
 // Tests of the round-robin host selection policy implementation
 // Tests of the round-robin host selection policy implementation
 func TestRoundRobinHostPolicy(t *testing.T) {
 func TestRoundRobinHostPolicy(t *testing.T) {
-	policy := NewRoundRobinHostPolicy()
+	policy := RoundRobinHostPolicy()
 
 
 	hosts := []HostInfo{
 	hosts := []HostInfo{
 		HostInfo{HostId: "0"},
 		HostInfo{HostId: "0"},
@@ -46,7 +46,7 @@ func TestRoundRobinHostPolicy(t *testing.T) {
 // Tests of the token-aware host selection policy implementation with a
 // Tests of the token-aware host selection policy implementation with a
 // round-robin host selection policy fallback.
 // round-robin host selection policy fallback.
 func TestTokenAwareHostPolicy(t *testing.T) {
 func TestTokenAwareHostPolicy(t *testing.T) {
-	policy := NewTokenAwareHostPolicy(NewRoundRobinHostPolicy())
+	policy := TokenAwareHostPolicy(RoundRobinHostPolicy())
 
 
 	query := &Query{}
 	query := &Query{}
 
 
@@ -101,7 +101,7 @@ func TestTokenAwareHostPolicy(t *testing.T) {
 
 
 // Tests of the round-robin connection selection policy implementation
 // Tests of the round-robin connection selection policy implementation
 func TestRoundRobinConnPolicy(t *testing.T) {
 func TestRoundRobinConnPolicy(t *testing.T) {
-	policy := NewRoundRobinConnPolicy()
+	policy := RoundRobinConnPolicy()()
 
 
 	conn0 := &Conn{}
 	conn0 := &Conn{}
 	conn1 := &Conn{}
 	conn1 := &Conn{}

+ 26 - 29
session.go

@@ -28,7 +28,7 @@ import (
 // and automatically sets a default consinstency level on all operations
 // and automatically sets a default consinstency level on all operations
 // that do not have a consistency level set.
 // that do not have a consistency level set.
 type Session struct {
 type Session struct {
-	Pool                ConnectionPool
+	pool                *policyConnPool
 	cons                Consistency
 	cons                Consistency
 	pageSize            int
 	pageSize            int
 	prefetch            float64
 	prefetch            float64
@@ -60,47 +60,44 @@ func NewSession(cfg ClusterConfig) (*Session, error) {
 		cfg.NumStreams = maxStreams
 		cfg.NumStreams = maxStreams
 	}
 	}
 
 
-	pool, err := cfg.ConnPoolType(&cfg)
-	if err != nil {
-		return nil, err
-	}
-
 	//Adjust the size of the prepared statements cache to match the latest configuration
 	//Adjust the size of the prepared statements cache to match the latest configuration
 	stmtsLRU.Lock()
 	stmtsLRU.Lock()
 	initStmtsLRU(cfg.MaxPreparedStmts)
 	initStmtsLRU(cfg.MaxPreparedStmts)
 	stmtsLRU.Unlock()
 	stmtsLRU.Unlock()
 
 
 	s := &Session{
 	s := &Session{
-		Pool:     pool,
 		cons:     cfg.Consistency,
 		cons:     cfg.Consistency,
 		prefetch: 0.25,
 		prefetch: 0.25,
 		cfg:      cfg,
 		cfg:      cfg,
+		pageSize: cfg.PageSize,
+	}
+
+	pool, err := cfg.PoolConfig.buildPool(&s.cfg)
+	if err != nil {
+		return nil, err
 	}
 	}
+	s.pool = pool
 
 
 	//See if there are any connections in the pool
 	//See if there are any connections in the pool
-	if pool.Size() > 0 {
-		s.routingKeyInfoCache.lru = lru.New(cfg.MaxRoutingKeyInfo)
-
-		s.SetConsistency(cfg.Consistency)
-		s.SetPageSize(cfg.PageSize)
-
-		if cfg.DiscoverHosts {
-			s.hostSource = &ringDescriber{
-				session:    s,
-				dcFilter:   cfg.Discovery.DcFilter,
-				rackFilter: cfg.Discovery.RackFilter,
-				closeChan:  make(chan bool),
-			}
+	if pool.Size() == 0 {
+		s.Close()
+		return nil, ErrNoConnectionsStarted
+	}
+
+	s.routingKeyInfoCache.lru = lru.New(cfg.MaxRoutingKeyInfo)
 
 
-			go s.hostSource.run(cfg.Discovery.Sleep)
+	if cfg.DiscoverHosts {
+		s.hostSource = &ringDescriber{
+			session:    s,
+			dcFilter:   cfg.Discovery.DcFilter,
+			rackFilter: cfg.Discovery.RackFilter,
+			closeChan:  make(chan bool),
 		}
 		}
 
 
-		return s, nil
+		go s.hostSource.run(cfg.Discovery.Sleep)
 	}
 	}
 
 
-	s.Close()
-
-	return nil, ErrNoConnectionsStarted
+	return s, nil
 }
 }
 
 
 // SetConsistency sets the default consistency level for this session. This
 // SetConsistency sets the default consistency level for this session. This
@@ -185,7 +182,7 @@ func (s *Session) Close() {
 	}
 	}
 	s.isClosed = true
 	s.isClosed = true
 
 
-	s.Pool.Close()
+	s.pool.Close()
 
 
 	if s.hostSource != nil {
 	if s.hostSource != nil {
 		close(s.hostSource.closeChan)
 		close(s.hostSource.closeChan)
@@ -210,7 +207,7 @@ func (s *Session) executeQuery(qry *Query) *Iter {
 	qry.attempts = 0
 	qry.attempts = 0
 	qry.totalLatency = 0
 	qry.totalLatency = 0
 	for {
 	for {
-		conn := s.Pool.Pick(qry)
+		conn := s.pool.Pick(qry)
 
 
 		//Assign the error unavailable to the iterator
 		//Assign the error unavailable to the iterator
 		if conn == nil {
 		if conn == nil {
@@ -294,7 +291,7 @@ func (s *Session) routingKeyInfo(stmt string) (*routingKeyInfo, error) {
 	)
 	)
 
 
 	// get the query info for the statement
 	// get the query info for the statement
-	conn := s.Pool.Pick(nil)
+	conn := s.pool.Pick(nil)
 	if conn == nil {
 	if conn == nil {
 		// no connections
 		// no connections
 		inflight.err = ErrNoConnections
 		inflight.err = ErrNoConnections
@@ -389,7 +386,7 @@ func (s *Session) executeBatch(batch *Batch) (*Iter, error) {
 	batch.attempts = 0
 	batch.attempts = 0
 	batch.totalLatency = 0
 	batch.totalLatency = 0
 	for {
 	for {
-		conn := s.Pool.Pick(nil)
+		conn := s.pool.Pick(nil)
 
 
 		//Assign the error unavailable and break loop
 		//Assign the error unavailable and break loop
 		if conn == nil {
 		if conn == nil {

+ 4 - 4
session_test.go

@@ -10,13 +10,13 @@ import (
 func TestSessionAPI(t *testing.T) {
 func TestSessionAPI(t *testing.T) {
 
 
 	cfg := &ClusterConfig{}
 	cfg := &ClusterConfig{}
-	pool, err := NewSimplePool(cfg)
+	pool, err := cfg.PoolConfig.buildPool(cfg)
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
 
 
 	s := &Session{
 	s := &Session{
-		Pool: pool,
+		pool: pool,
 		cfg:  *cfg,
 		cfg:  *cfg,
 		cons: Quorum,
 		cons: Quorum,
 	}
 	}
@@ -160,13 +160,13 @@ func TestQueryShouldPrepare(t *testing.T) {
 func TestBatchBasicAPI(t *testing.T) {
 func TestBatchBasicAPI(t *testing.T) {
 
 
 	cfg := &ClusterConfig{RetryPolicy: &SimpleRetryPolicy{NumRetries: 2}}
 	cfg := &ClusterConfig{RetryPolicy: &SimpleRetryPolicy{NumRetries: 2}}
-	pool, err := NewSimplePool(cfg)
+	pool, err := cfg.PoolConfig.buildPool(cfg)
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
 
 
 	s := &Session{
 	s := &Session{
-		Pool: pool,
+		pool: pool,
 		cfg:  *cfg,
 		cfg:  *cfg,
 		cons: Quorum,
 		cons: Quorum,
 	}
 	}

+ 1 - 2
stress_test.go

@@ -36,7 +36,6 @@ func BenchmarkConnStress(b *testing.B) {
 
 
 	b.SetParallelism(workers)
 	b.SetParallelism(workers)
 	b.RunParallel(writer)
 	b.RunParallel(writer)
-
 }
 }
 
 
 func BenchmarkConnRoutingKey(b *testing.B) {
 func BenchmarkConnRoutingKey(b *testing.B) {
@@ -45,7 +44,7 @@ func BenchmarkConnRoutingKey(b *testing.B) {
 	cluster := createCluster()
 	cluster := createCluster()
 	cluster.NumConns = 1
 	cluster.NumConns = 1
 	cluster.NumStreams = workers
 	cluster.NumStreams = workers
-	cluster.ConnPoolType = NewTokenAwareConnPool
+	cluster.PoolConfig.HostSelectionPolicy = TokenAwareHostPolicy(RoundRobinHostPolicy())
 	session := createSessionFromCluster(cluster, b)
 	session := createSessionFromCluster(cluster, b)
 	defer session.Close()
 	defer session.Close()