瀏覽代碼

Remove SimpleConnPool, remove custom pools

Remove custom pools so that the driver can be simplified and
make more assumptions about the underlying conneciton pool whilst
still allowing configuration of host select and conn selection policies.
Chris Bannister 10 年之前
父節點
當前提交
50b9680b45
共有 11 個文件被更改,包括 121 次插入517 次删除
  1. 1 1
      cass1batch_test.go
  2. 13 11
      cassandra_test.go
  3. 30 2
      cluster.go
  4. 7 6
      conn_test.go
  5. 14 437
      connectionpool.go
  6. 6 11
      host_source.go
  7. 16 11
      policies.go
  8. 3 3
      policies_test.go
  9. 26 29
      session.go
  10. 4 4
      session_test.go
  11. 1 2
      stress_test.go

+ 1 - 1
cass1batch_test.go

@@ -10,7 +10,7 @@ import (
 func TestProto1BatchInsert(t *testing.T) {
 	session := createSession(t)
 	if err := createTable(session, "CREATE TABLE large (id int primary key)"); err != nil {
-		t.Fatal("create table:", err)
+		t.Fatal(err)
 	}
 	defer session.Close()
 

+ 13 - 11
cassandra_test.go

@@ -96,7 +96,7 @@ func createKeyspace(tb testing.TB, cluster *ClusterConfig, keyspace string) {
 	}
 
 	// should reuse the same conn apparently
-	conn := session.Pool.Pick(nil)
+	conn := session.pool.Pick(nil)
 	if conn == nil {
 		tb.Fatal("no connections available in the pool")
 	}
@@ -189,7 +189,9 @@ func TestRingDiscovery(t *testing.T) {
 		time.Sleep(*flagAutoWait)
 	}
 
-	size := len(session.Pool.(*SimplePool).connPool)
+	session.pool.mu.RLock()
+	size := len(session.pool.hostConnPools)
+	session.pool.mu.RUnlock()
 
 	if *clusterSize != size {
 		t.Logf("WARN: Expected a cluster size of %d, but actual size was %d", *clusterSize, size)
@@ -1139,7 +1141,7 @@ func injectInvalidPreparedStatement(t *testing.T, session *Session, table string
 		t.Fatal("create:", err)
 	}
 	stmt := "INSERT INTO " + table + " (foo, bar) VALUES (?, 7)"
-	conn := session.Pool.Pick(nil)
+	conn := session.pool.Pick(nil)
 	flight := new(inflightPrepare)
 	stmtsLRU.Lock()
 	stmtsLRU.lru.Add(conn.addr+stmt, flight)
@@ -1165,7 +1167,7 @@ func injectInvalidPreparedStatement(t *testing.T, session *Session, table string
 
 func TestMissingSchemaPrepare(t *testing.T) {
 	s := createSession(t)
-	conn := s.Pool.Pick(nil)
+	conn := s.pool.Pick(nil)
 	defer s.Close()
 
 	insertQry := &Query{stmt: "INSERT INTO invalidschemaprep (val) VALUES (?)", values: []interface{}{5}, cons: s.cons,
@@ -1214,7 +1216,7 @@ func TestQueryInfo(t *testing.T) {
 	session := createSession(t)
 	defer session.Close()
 
-	conn := session.Pool.Pick(nil)
+	conn := session.pool.Pick(nil)
 	info, err := conn.prepareStatement("SELECT release_version, host_id FROM system.local WHERE key = ?", nil)
 
 	if err != nil {
@@ -2016,7 +2018,7 @@ func TestRoutingKey(t *testing.T) {
 // Integration test of the token-aware policy-based connection pool
 func TestTokenAwareConnPool(t *testing.T) {
 	cluster := createCluster()
-	cluster.ConnPoolType = NewTokenAwareConnPool
+	cluster.PoolConfig.HostSelectionPolicy = TokenAwareHostPolicy(RoundRobinHostPolicy())
 	cluster.DiscoverHosts = true
 
 	// Drop and re-create the keyspace once. Different tests should use their own
@@ -2037,8 +2039,8 @@ func TestTokenAwareConnPool(t *testing.T) {
 		time.Sleep(*flagAutoWait)
 	}
 
-	if session.Pool.Size() != cluster.NumConns*len(cluster.Hosts) {
-		t.Errorf("Expected pool size %d but was %d", cluster.NumConns*len(cluster.Hosts), session.Pool.Size())
+	if session.pool.Size() != cluster.NumConns*len(cluster.Hosts) {
+		t.Errorf("Expected pool size %d but was %d", cluster.NumConns*len(cluster.Hosts), session.pool.Size())
 	}
 
 	if err := createTable(session, "CREATE TABLE test_token_aware (id int, data text, PRIMARY KEY (id))"); err != nil {
@@ -2077,7 +2079,7 @@ func TestStream0(t *testing.T) {
 			break
 		}
 
-		conn = session.Pool.Pick(nil)
+		conn = session.pool.Pick(nil)
 	}
 
 	if conn == nil {
@@ -2116,7 +2118,7 @@ func TestNegativeStream(t *testing.T) {
 			break
 		}
 
-		conn = session.Pool.Pick(nil)
+		conn = session.pool.Pick(nil)
 	}
 
 	if conn == nil {
@@ -2222,7 +2224,7 @@ func TestLexicalUUIDType(t *testing.T) {
 // Issue 475
 func TestSessionBindRoutingKey(t *testing.T) {
 	cluster := createCluster()
-	cluster.ConnPoolType = NewTokenAwareConnPool
+	cluster.PoolConfig.HostSelectionPolicy = TokenAwareHostPolicy(RoundRobinHostPolicy())
 
 	session := createSessionFromCluster(cluster, t)
 	defer session.Close()

+ 30 - 2
cluster.go

@@ -50,6 +50,33 @@ type DiscoveryConfig struct {
 	Sleep time.Duration
 }
 
+// PoolConfig configures the connection pool used by the driver, it defaults to
+// using a round robbin host selection policy and a round robbin connection selection
+// policy for each host.
+type PoolConfig struct {
+	// HostSelectionPolicy sets the policy for selecting which host to use for a
+	// given query (default: RoundRobinHostPolicy())
+	HostSelectionPolicy HostSelectionPolicy
+
+	// ConnSelectionPolicy sets the policy factory for selecting a connection to use for
+	// each host for a query (default: RoundRobinConnPolicy())
+	ConnSelectionPolicy func() ConnSelectionPolicy
+}
+
+func (p PoolConfig) buildPool(cfg *ClusterConfig) (*policyConnPool, error) {
+	hostSelection := p.HostSelectionPolicy
+	if hostSelection == nil {
+		hostSelection = RoundRobinHostPolicy()
+	}
+
+	connSelection := p.ConnSelectionPolicy
+	if connSelection == nil {
+		connSelection = RoundRobinConnPolicy()
+	}
+
+	return newPolicyConnPool(cfg, hostSelection, connSelection)
+}
+
 // ClusterConfig is a struct to configure the default cluster implementation
 // of gocoql. It has a varity of attributes that can be used to modify the
 // behavior to fit the most common use cases. Applications that requre a
@@ -68,7 +95,6 @@ type ClusterConfig struct {
 	Authenticator     Authenticator     // authenticator (default: nil)
 	RetryPolicy       RetryPolicy       // Default retry policy to use for queries (default: 0)
 	SocketKeepalive   time.Duration     // The keepalive period to use, enabled if > 0 (default: 0)
-	ConnPoolType      NewPoolFunc       // The function used to create the connection pool for the session (default: NewSimplePool)
 	DiscoverHosts     bool              // If set, gocql will attempt to automatically discover other members of the Cassandra cluster (default: false)
 	MaxPreparedStmts  int               // Sets the maximum cache size for prepared statements globally for gocql (default: 1000)
 	MaxRoutingKeyInfo int               // Sets the maximum cache size for query info about statements for each session (default: 1000)
@@ -77,6 +103,9 @@ type ClusterConfig struct {
 	Discovery         DiscoveryConfig
 	SslOpts           *SslOptions
 	DefaultTimestamp  bool // Sends a client side timestamp for all requests which overrides the timestamp at which it arrives at the server. (default: true, only enabled for protocol 3 and above)
+	// PoolConfig configures the underlying connection pool, allowing the
+	// configuration of host selection and connection selection policies.
+	PoolConfig PoolConfig
 }
 
 // NewCluster generates a new config for the default cluster implementation.
@@ -89,7 +118,6 @@ func NewCluster(hosts ...string) *ClusterConfig {
 		Port:              9042,
 		NumConns:          2,
 		Consistency:       Quorum,
-		ConnPoolType:      NewSimplePool,
 		DiscoverHosts:     false,
 		MaxPreparedStmts:  defaultMaxPreparedStmts,
 		MaxRoutingKeyInfo: 1000,

+ 7 - 6
conn_test.go

@@ -259,8 +259,7 @@ func TestConnClosing(t *testing.T) {
 	wg.Wait()
 
 	time.Sleep(1 * time.Second) //Sleep so the fillPool can complete.
-	pool := db.Pool.(ConnectionPool)
-	conns := pool.Size()
+	conns := db.pool.Size()
 
 	if conns != numConns {
 		t.Errorf("Expected to have %d connections but have %d", numConns, conns)
@@ -390,7 +389,8 @@ func TestRoundRobinConnPoolRoundRobin(t *testing.T) {
 
 	// create a new cluster using the policy-based round robin conn pool
 	cluster := NewCluster(addrs...)
-	cluster.ConnPoolType = NewRoundRobinConnPool
+	cluster.PoolConfig.HostSelectionPolicy = RoundRobinHostPolicy()
+	cluster.PoolConfig.ConnSelectionPolicy = RoundRobinConnPolicy()
 
 	db, err := cluster.CreateSession()
 	if err != nil {
@@ -420,7 +420,7 @@ func TestRoundRobinConnPoolRoundRobin(t *testing.T) {
 
 	// wait for the pool to drain
 	time.Sleep(100 * time.Millisecond)
-	size := db.Pool.Size()
+	size := db.pool.Size()
 	if size != 0 {
 		t.Errorf("connection pool did not drain, still contains %d connections", size)
 	}
@@ -450,7 +450,8 @@ func TestPolicyConnPoolSSL(t *testing.T) {
 	defer srv.Stop()
 
 	cluster := createTestSslCluster(srv.Address, defaultProto, true)
-	cluster.ConnPoolType = NewRoundRobinConnPool
+	cluster.PoolConfig.HostSelectionPolicy = RoundRobinHostPolicy()
+	cluster.PoolConfig.ConnSelectionPolicy = RoundRobinConnPolicy()
 
 	db, err := cluster.CreateSession()
 	if err != nil {
@@ -465,7 +466,7 @@ func TestPolicyConnPoolSSL(t *testing.T) {
 
 	// wait for the pool to drain
 	time.Sleep(100 * time.Millisecond)
-	size := db.Pool.Size()
+	size := db.pool.Size()
 	if size != 0 {
 		t.Errorf("connection pool did not drain, still contains %d connections", size)
 	}

+ 14 - 437
connectionpool.go

@@ -17,89 +17,6 @@ import (
 	"time"
 )
 
-/*ConnectionPool represents the interface gocql will use to work with a collection of connections.
-
-Purpose
-
-The connection pool in gocql opens and closes connections as well as selects an available connection
-for gocql to execute a query against. The pool is also responsible for handling connection errors that
-are caught by the connection experiencing the error.
-
-A connection pool should make a copy of the variables used from the ClusterConfig provided to the pool
-upon creation. ClusterConfig is a pointer and can be modified after the creation of the pool. This can
-lead to issues with variables being modified outside the expectations of the ConnectionPool type.
-
-Example of Single Connection Pool:
-
-	type SingleConnection struct {
-		conn *Conn
-		cfg *ClusterConfig
-	}
-
-	func NewSingleConnection(cfg *ClusterConfig) ConnectionPool {
-		addr := JoinHostPort(cfg.Hosts[0], cfg.Port)
-
-		connCfg := ConnConfig{
-			ProtoVersion:  cfg.ProtoVersion,
-			CQLVersion:    cfg.CQLVersion,
-			Timeout:       cfg.Timeout,
-			NumStreams:    cfg.NumStreams,
-			Compressor:    cfg.Compressor,
-			Authenticator: cfg.Authenticator,
-			Keepalive:     cfg.SocketKeepalive,
-		}
-		pool := SingleConnection{cfg:cfg}
-		pool.conn = Connect(addr,connCfg,pool)
-		return &pool
-	}
-
-	func (s *SingleConnection) HandleError(conn *Conn, err error, closed bool) {
-		if closed {
-			connCfg := ConnConfig{
-				ProtoVersion:  cfg.ProtoVersion,
-				CQLVersion:    cfg.CQLVersion,
-				Timeout:       cfg.Timeout,
-				NumStreams:    cfg.NumStreams,
-				Compressor:    cfg.Compressor,
-				Authenticator: cfg.Authenticator,
-				Keepalive:     cfg.SocketKeepalive,
-			}
-			s.conn = Connect(conn.Address(),connCfg,s)
-		}
-	}
-
-	func (s *SingleConnection) Pick(qry *Query) *Conn {
-		if s.conn.isClosed {
-			return nil
-		}
-		return s.conn
-	}
-
-	func (s *SingleConnection) Size() int {
-		return 1
-	}
-
-	func (s *SingleConnection) Close() {
-		s.conn.Close()
-	}
-
-This is a very simple example of a type that exposes the connection pool interface. To assign
-this type as the connection pool to use you would assign it to the ClusterConfig like so:
-
-		cluster := NewCluster("127.0.0.1")
-		cluster.ConnPoolType = NewSingleConnection
-		...
-		session, err := cluster.CreateSession()
-
-To see a more complete example of a ConnectionPool implementation please see the SimplePool type.
-*/
-type ConnectionPool interface {
-	SetHosts
-	Pick(*Query) *Conn
-	Size() int
-	Close()
-}
-
 // interface to implement to receive the host information
 type SetHosts interface {
 	SetHosts(hosts []HostInfo)
@@ -110,35 +27,6 @@ type SetPartitioner interface {
 	SetPartitioner(partitioner string)
 }
 
-//NewPoolFunc is the type used by ClusterConfig to create a pool of a specific type.
-type NewPoolFunc func(*ClusterConfig) (ConnectionPool, error)
-
-//SimplePool is the current implementation of the connection pool inside gocql. This
-//pool is meant to be a simple default used by gocql so users can get up and running
-//quickly.
-type SimplePool struct {
-	cfg      *ClusterConfig
-	hostPool *RoundRobin
-	connPool map[string]*RoundRobin
-	conns    map[*Conn]struct{}
-	keyspace string
-
-	hostMu sync.RWMutex
-	// this is the set of current hosts which the pool will attempt to connect to
-	hosts map[string]*HostInfo
-
-	// protects hostpool, connPoll, conns, quit
-	mu sync.Mutex
-
-	cFillingPool chan int
-
-	quit     bool
-	quitWait chan bool
-	quitOnce sync.Once
-
-	tlsConfig *tls.Config
-}
-
 func setupTLSConfig(sslOpts *SslOptions) (*tls.Config, error) {
 	// ca cert is optional
 	if sslOpts.CaPath != "" {
@@ -169,309 +57,6 @@ func setupTLSConfig(sslOpts *SslOptions) (*tls.Config, error) {
 	return &sslOpts.Config, nil
 }
 
-//NewSimplePool is the function used by gocql to create the simple connection pool.
-//This is the default if no other pool type is specified.
-func NewSimplePool(cfg *ClusterConfig) (ConnectionPool, error) {
-	pool := &SimplePool{
-		cfg:          cfg,
-		hostPool:     NewRoundRobin(),
-		connPool:     make(map[string]*RoundRobin),
-		conns:        make(map[*Conn]struct{}),
-		quitWait:     make(chan bool),
-		cFillingPool: make(chan int, 1),
-		keyspace:     cfg.Keyspace,
-		hosts:        make(map[string]*HostInfo),
-	}
-
-	for _, host := range cfg.Hosts {
-		// seed hosts have unknown topology
-		// TODO: Handle populating this during SetHosts
-		pool.hosts[host] = &HostInfo{Peer: host}
-	}
-
-	if cfg.SslOpts != nil {
-		config, err := setupTLSConfig(cfg.SslOpts)
-		if err != nil {
-			return nil, err
-		}
-		pool.tlsConfig = config
-	}
-
-	//Walk through connecting to hosts. As soon as one host connects
-	//defer the remaining connections to cluster.fillPool()
-	for i := 0; i < len(cfg.Hosts); i++ {
-		addr := JoinHostPort(cfg.Hosts[i], cfg.Port)
-
-		if pool.connect(addr) == nil {
-			pool.cFillingPool <- 1
-			go pool.fillPool()
-			break
-		}
-	}
-
-	return pool, nil
-}
-
-func (c *SimplePool) connect(addr string) error {
-
-	cfg := ConnConfig{
-		ProtoVersion:  c.cfg.ProtoVersion,
-		CQLVersion:    c.cfg.CQLVersion,
-		Timeout:       c.cfg.Timeout,
-		NumStreams:    c.cfg.NumStreams,
-		Compressor:    c.cfg.Compressor,
-		Authenticator: c.cfg.Authenticator,
-		Keepalive:     c.cfg.SocketKeepalive,
-		tlsConfig:     c.tlsConfig,
-	}
-
-	conn, err := Connect(addr, cfg, c)
-	if err != nil {
-		log.Printf("connect: failed to connect to %q: %v", addr, err)
-		return err
-	}
-
-	return c.addConn(conn)
-}
-
-func (c *SimplePool) addConn(conn *Conn) error {
-	c.mu.Lock()
-	defer c.mu.Unlock()
-	if c.quit {
-		conn.Close()
-		return nil
-	}
-
-	//Set the connection's keyspace if any before adding it to the pool
-	if c.keyspace != "" {
-		if err := conn.UseKeyspace(c.keyspace); err != nil {
-			log.Printf("error setting connection keyspace. %v", err)
-			conn.Close()
-			return err
-		}
-	}
-
-	connPool := c.connPool[conn.Address()]
-	if connPool == nil {
-		connPool = NewRoundRobin()
-		c.connPool[conn.Address()] = connPool
-		c.hostPool.AddNode(connPool)
-	}
-
-	connPool.AddNode(conn)
-	c.conns[conn] = struct{}{}
-
-	return nil
-}
-
-//fillPool manages the pool of connections making sure that each host has the correct
-//amount of connections defined. Also the method will test a host with one connection
-//instead of flooding the host with number of connections defined in the cluster config
-func (c *SimplePool) fillPool() {
-	//Debounce large amounts of requests to fill pool
-	select {
-	case <-time.After(1 * time.Millisecond):
-		return
-	case <-c.cFillingPool:
-		defer func() { c.cFillingPool <- 1 }()
-	}
-
-	c.mu.Lock()
-	isClosed := c.quit
-	c.mu.Unlock()
-	//Exit if cluster(session) is closed
-	if isClosed {
-		return
-	}
-
-	c.hostMu.RLock()
-
-	//Walk through list of defined hosts
-	var wg sync.WaitGroup
-	for host := range c.hosts {
-		addr := JoinHostPort(host, c.cfg.Port)
-
-		numConns := 1
-		//See if the host already has connections in the pool
-		c.mu.Lock()
-		conns, ok := c.connPool[addr]
-		c.mu.Unlock()
-
-		if ok {
-			//if the host has enough connections just exit
-			numConns = conns.Size()
-			if numConns >= c.cfg.NumConns {
-				continue
-			}
-		} else {
-			//See if the host is reachable
-			if err := c.connect(addr); err != nil {
-				continue
-			}
-		}
-
-		//This is reached if the host is responsive and needs more connections
-		//Create connections for host synchronously to mitigate flooding the host.
-		wg.Add(1)
-		go func(a string, conns int) {
-			defer wg.Done()
-			for ; conns < c.cfg.NumConns; conns++ {
-				c.connect(a)
-			}
-		}(addr, numConns)
-	}
-
-	c.hostMu.RUnlock()
-
-	//Wait until we're finished connecting to each host before returning
-	wg.Wait()
-}
-
-// Should only be called if c.mu is locked
-func (c *SimplePool) removeConnLocked(conn *Conn) {
-	conn.Close()
-	connPool := c.connPool[conn.addr]
-	if connPool == nil {
-		return
-	}
-	connPool.RemoveNode(conn)
-	if connPool.Size() == 0 {
-		c.hostPool.RemoveNode(connPool)
-		delete(c.connPool, conn.addr)
-	}
-	delete(c.conns, conn)
-}
-
-func (c *SimplePool) removeConn(conn *Conn) {
-	c.mu.Lock()
-	defer c.mu.Unlock()
-	c.removeConnLocked(conn)
-}
-
-//HandleError is called by a Connection object to report to the pool an error has occured.
-//Logic is then executed within the pool to clean up the erroroneous connection and try to
-//top off the pool.
-func (c *SimplePool) HandleError(conn *Conn, err error, closed bool) {
-	if !closed {
-		// ignore all non-fatal errors
-		return
-	}
-	c.removeConn(conn)
-	c.mu.Lock()
-	poolClosed := c.quit
-	c.mu.Unlock()
-	if !poolClosed {
-		go c.fillPool() // top off pool.
-	}
-}
-
-//Pick selects a connection to be used by the query.
-func (c *SimplePool) Pick(qry *Query) *Conn {
-	//Check if connections are available
-	c.mu.Lock()
-	conns := len(c.conns)
-	c.mu.Unlock()
-
-	if conns == 0 {
-		//try to populate the pool before returning.
-		c.fillPool()
-	}
-
-	return c.hostPool.Pick(qry)
-}
-
-//Size returns the number of connections currently active in the pool
-func (p *SimplePool) Size() int {
-	p.mu.Lock()
-	conns := len(p.conns)
-	p.mu.Unlock()
-	return conns
-}
-
-//Close kills the pool and all associated connections.
-func (c *SimplePool) Close() {
-	c.quitOnce.Do(func() {
-		c.mu.Lock()
-		defer c.mu.Unlock()
-		c.quit = true
-		close(c.quitWait)
-		for conn := range c.conns {
-			c.removeConnLocked(conn)
-		}
-	})
-}
-
-func (c *SimplePool) SetHosts(hosts []HostInfo) {
-
-	c.hostMu.Lock()
-	toRemove := make(map[string]struct{})
-	for k := range c.hosts {
-		toRemove[k] = struct{}{}
-	}
-
-	for _, host := range hosts {
-		host := host
-		delete(toRemove, host.Peer)
-		// we already have it
-		if _, ok := c.hosts[host.Peer]; ok {
-			// TODO: Check rack, dc, token range is consistent, trigger topology change
-			// update stored host
-			continue
-		}
-
-		c.hosts[host.Peer] = &host
-	}
-
-	// can we hold c.mu whilst iterating this loop?
-	for addr := range toRemove {
-		c.removeHostLocked(addr)
-	}
-	c.hostMu.Unlock()
-
-	c.fillPool()
-}
-
-func (c *SimplePool) removeHostLocked(addr string) {
-	if _, ok := c.hosts[addr]; !ok {
-		return
-	}
-	delete(c.hosts, addr)
-
-	c.mu.Lock()
-	defer c.mu.Unlock()
-
-	if _, ok := c.connPool[addr]; !ok {
-		return
-	}
-
-	for conn := range c.conns {
-		if conn.Address() == addr {
-			c.removeConnLocked(conn)
-		}
-	}
-}
-
-//NewRoundRobinConnPool creates a connection pool which selects hosts by
-//round-robin, and then selects a connection for that host by round-robin.
-func NewRoundRobinConnPool(cfg *ClusterConfig) (ConnectionPool, error) {
-	return NewPolicyConnPool(
-		cfg,
-		NewRoundRobinHostPolicy(),
-		NewRoundRobinConnPolicy,
-	)
-}
-
-//NewTokenAwareConnPool creates a connection pool which selects hosts by
-//a token aware policy, and then selects a connection for that host by
-//round-robin.
-func NewTokenAwareConnPool(cfg *ClusterConfig) (ConnectionPool, error) {
-	return NewPolicyConnPool(
-		cfg,
-		NewTokenAwareHostPolicy(NewRoundRobinHostPolicy()),
-		NewRoundRobinConnPolicy,
-	)
-}
-
 type policyConnPool struct {
 	port     int
 	numConns int
@@ -484,18 +69,13 @@ type policyConnPool struct {
 	hostConnPools map[string]*hostConnPool
 }
 
-//Creates a policy based connection pool. This func isn't meant to be directly
-//used as a NewPoolFunc in ClusterConfig, instead a func should be created
-//which satisfies the NewPoolFunc type, which calls this func with the desired
-//hostPolicy and connPolicy; see NewRoundRobinConnPool or NewTokenAwareConnPool
-//for examples.
-func NewPolicyConnPool(
-	cfg *ClusterConfig,
-	hostPolicy HostSelectionPolicy,
-	connPolicy func() ConnSelectionPolicy,
-) (ConnectionPool, error) {
-	var err error
-	var tlsConfig *tls.Config
+func newPolicyConnPool(cfg *ClusterConfig, hostPolicy HostSelectionPolicy,
+	connPolicy func() ConnSelectionPolicy) (*policyConnPool, error) {
+
+	var (
+		err       error
+		tlsConfig *tls.Config
+	)
 
 	if cfg.SslOpts != nil {
 		tlsConfig, err = setupTLSConfig(cfg.SslOpts)
@@ -594,9 +174,12 @@ func (p *policyConnPool) Size() int {
 func (p *policyConnPool) Pick(qry *Query) *Conn {
 	nextHost := p.hostPolicy.Pick(qry)
 
+	var (
+		host *HostInfo
+		conn *Conn
+	)
+
 	p.mu.RLock()
-	var host *HostInfo
-	var conn *Conn
 	for conn == nil {
 		host = nextHost()
 		if host == nil {
@@ -639,14 +222,8 @@ type hostConnPool struct {
 	filling bool
 }
 
-func newHostConnPool(
-	host string,
-	port int,
-	size int,
-	connCfg ConnConfig,
-	keyspace string,
-	policy ConnSelectionPolicy,
-) *hostConnPool {
+func newHostConnPool(host string, port int, size int, connCfg ConnConfig,
+	keyspace string, policy ConnSelectionPolicy) *hostConnPool {
 
 	pool := &hostConnPool{
 		host:     host,

+ 6 - 11
host_source.go

@@ -24,14 +24,10 @@ type ringDescriber struct {
 	closeChan       chan bool
 }
 
-func (r *ringDescriber) GetHosts() (
-	hosts []HostInfo,
-	partitioner string,
-	err error,
-) {
+func (r *ringDescriber) GetHosts() (hosts []HostInfo, partitioner string, err error) {
 	// we need conn to be the same because we need to query system.peers and system.local
 	// on the same node to get the whole cluster
-	conn := r.session.Pool.Pick(nil)
+	conn := r.session.pool.Pick(nil)
 	if conn == nil {
 		return r.prevHosts, r.prevPartitioner, nil
 	}
@@ -106,12 +102,11 @@ func (h *ringDescriber) run(sleep time.Duration) {
 			hosts, partitioner, err := h.GetHosts()
 			if err != nil {
 				log.Println("RingDescriber: unable to get ring topology:", err)
-			} else {
-				h.session.Pool.SetHosts(hosts)
-				if v, ok := h.session.Pool.(SetPartitioner); ok {
-					v.SetPartitioner(partitioner)
-				}
+				continue
 			}
+
+			h.session.pool.SetHosts(hosts)
+			h.session.pool.SetPartitioner(partitioner)
 		case <-h.closeChan:
 			return
 		}

+ 16 - 11
policies.go

@@ -10,8 +10,8 @@ import (
 	"sync/atomic"
 )
 
-//RetryableQuery is an interface that represents a query or batch statement that
-//exposes the correct functions for the retry policy logic to evaluate correctly.
+// RetryableQuery is an interface that represents a query or batch statement that
+// exposes the correct functions for the retry policy logic to evaluate correctly.
 type RetryableQuery interface {
 	Attempts() int
 	GetConsistency() Consistency
@@ -48,8 +48,8 @@ func (s *SimpleRetryPolicy) Attempt(q RetryableQuery) bool {
 	return q.Attempts() <= s.NumRetries
 }
 
-//HostSelectionPolicy is an interface for selecting
-//the most appropriate host to execute a given query.
+// HostSelectionPolicy is an interface for selecting
+// the most appropriate host to execute a given query.
 type HostSelectionPolicy interface {
 	SetHosts
 	SetPartitioner
@@ -57,11 +57,12 @@ type HostSelectionPolicy interface {
 	Pick(*Query) NextHost
 }
 
-//NextHost is an iteration function over picked hosts
+// NextHost is an iteration function over picked hosts
 type NextHost func() *HostInfo
 
-//NewRoundRobinHostPolicy is a round-robin load balancing policy
-func NewRoundRobinHostPolicy() HostSelectionPolicy {
+// RoundRobinHostPolicy is a round-robin load balancing policy, where each host
+// is tried sequentially for each query.
+func RoundRobinHostPolicy() HostSelectionPolicy {
 	return &roundRobinHostPolicy{hosts: []HostInfo{}}
 }
 
@@ -105,8 +106,10 @@ func (r *roundRobinHostPolicy) Pick(qry *Query) NextHost {
 	}
 }
 
-//NewTokenAwareHostPolicy is a token aware host selection policy
-func NewTokenAwareHostPolicy(fallback HostSelectionPolicy) HostSelectionPolicy {
+// TokenAwareHostPolicy is a token aware host selection policy, where hosts are
+// selected based on the partition key, so queries are sent to the host which
+// owns the partition. Fallback is used when routing information is not available.
+func TokenAwareHostPolicy(fallback HostSelectionPolicy) HostSelectionPolicy {
 	return &tokenAwareHostPolicy{fallback: fallback, hosts: []HostInfo{}}
 }
 
@@ -227,8 +230,10 @@ type roundRobinConnPolicy struct {
 	mu    sync.RWMutex
 }
 
-func NewRoundRobinConnPolicy() ConnSelectionPolicy {
-	return &roundRobinConnPolicy{}
+func RoundRobinConnPolicy() func() ConnSelectionPolicy {
+	return func() ConnSelectionPolicy {
+		return &roundRobinConnPolicy{}
+	}
 }
 
 func (r *roundRobinConnPolicy) SetConns(conns []*Conn) {

+ 3 - 3
policies_test.go

@@ -8,7 +8,7 @@ import "testing"
 
 // Tests of the round-robin host selection policy implementation
 func TestRoundRobinHostPolicy(t *testing.T) {
-	policy := NewRoundRobinHostPolicy()
+	policy := RoundRobinHostPolicy()
 
 	hosts := []HostInfo{
 		HostInfo{HostId: "0"},
@@ -46,7 +46,7 @@ func TestRoundRobinHostPolicy(t *testing.T) {
 // Tests of the token-aware host selection policy implementation with a
 // round-robin host selection policy fallback.
 func TestTokenAwareHostPolicy(t *testing.T) {
-	policy := NewTokenAwareHostPolicy(NewRoundRobinHostPolicy())
+	policy := TokenAwareHostPolicy(RoundRobinHostPolicy())
 
 	query := &Query{}
 
@@ -101,7 +101,7 @@ func TestTokenAwareHostPolicy(t *testing.T) {
 
 // Tests of the round-robin connection selection policy implementation
 func TestRoundRobinConnPolicy(t *testing.T) {
-	policy := NewRoundRobinConnPolicy()
+	policy := RoundRobinConnPolicy()()
 
 	conn0 := &Conn{}
 	conn1 := &Conn{}

+ 26 - 29
session.go

@@ -28,7 +28,7 @@ import (
 // and automatically sets a default consinstency level on all operations
 // that do not have a consistency level set.
 type Session struct {
-	Pool                ConnectionPool
+	pool                *policyConnPool
 	cons                Consistency
 	pageSize            int
 	prefetch            float64
@@ -60,47 +60,44 @@ func NewSession(cfg ClusterConfig) (*Session, error) {
 		cfg.NumStreams = maxStreams
 	}
 
-	pool, err := cfg.ConnPoolType(&cfg)
-	if err != nil {
-		return nil, err
-	}
-
 	//Adjust the size of the prepared statements cache to match the latest configuration
 	stmtsLRU.Lock()
 	initStmtsLRU(cfg.MaxPreparedStmts)
 	stmtsLRU.Unlock()
 
 	s := &Session{
-		Pool:     pool,
 		cons:     cfg.Consistency,
 		prefetch: 0.25,
 		cfg:      cfg,
+		pageSize: cfg.PageSize,
+	}
+
+	pool, err := cfg.PoolConfig.buildPool(&s.cfg)
+	if err != nil {
+		return nil, err
 	}
+	s.pool = pool
 
 	//See if there are any connections in the pool
-	if pool.Size() > 0 {
-		s.routingKeyInfoCache.lru = lru.New(cfg.MaxRoutingKeyInfo)
-
-		s.SetConsistency(cfg.Consistency)
-		s.SetPageSize(cfg.PageSize)
-
-		if cfg.DiscoverHosts {
-			s.hostSource = &ringDescriber{
-				session:    s,
-				dcFilter:   cfg.Discovery.DcFilter,
-				rackFilter: cfg.Discovery.RackFilter,
-				closeChan:  make(chan bool),
-			}
+	if pool.Size() == 0 {
+		s.Close()
+		return nil, ErrNoConnectionsStarted
+	}
+
+	s.routingKeyInfoCache.lru = lru.New(cfg.MaxRoutingKeyInfo)
 
-			go s.hostSource.run(cfg.Discovery.Sleep)
+	if cfg.DiscoverHosts {
+		s.hostSource = &ringDescriber{
+			session:    s,
+			dcFilter:   cfg.Discovery.DcFilter,
+			rackFilter: cfg.Discovery.RackFilter,
+			closeChan:  make(chan bool),
 		}
 
-		return s, nil
+		go s.hostSource.run(cfg.Discovery.Sleep)
 	}
 
-	s.Close()
-
-	return nil, ErrNoConnectionsStarted
+	return s, nil
 }
 
 // SetConsistency sets the default consistency level for this session. This
@@ -185,7 +182,7 @@ func (s *Session) Close() {
 	}
 	s.isClosed = true
 
-	s.Pool.Close()
+	s.pool.Close()
 
 	if s.hostSource != nil {
 		close(s.hostSource.closeChan)
@@ -210,7 +207,7 @@ func (s *Session) executeQuery(qry *Query) *Iter {
 	qry.attempts = 0
 	qry.totalLatency = 0
 	for {
-		conn := s.Pool.Pick(qry)
+		conn := s.pool.Pick(qry)
 
 		//Assign the error unavailable to the iterator
 		if conn == nil {
@@ -294,7 +291,7 @@ func (s *Session) routingKeyInfo(stmt string) (*routingKeyInfo, error) {
 	)
 
 	// get the query info for the statement
-	conn := s.Pool.Pick(nil)
+	conn := s.pool.Pick(nil)
 	if conn == nil {
 		// no connections
 		inflight.err = ErrNoConnections
@@ -389,7 +386,7 @@ func (s *Session) executeBatch(batch *Batch) (*Iter, error) {
 	batch.attempts = 0
 	batch.totalLatency = 0
 	for {
-		conn := s.Pool.Pick(nil)
+		conn := s.pool.Pick(nil)
 
 		//Assign the error unavailable and break loop
 		if conn == nil {

+ 4 - 4
session_test.go

@@ -10,13 +10,13 @@ import (
 func TestSessionAPI(t *testing.T) {
 
 	cfg := &ClusterConfig{}
-	pool, err := NewSimplePool(cfg)
+	pool, err := cfg.PoolConfig.buildPool(cfg)
 	if err != nil {
 		t.Fatal(err)
 	}
 
 	s := &Session{
-		Pool: pool,
+		pool: pool,
 		cfg:  *cfg,
 		cons: Quorum,
 	}
@@ -160,13 +160,13 @@ func TestQueryShouldPrepare(t *testing.T) {
 func TestBatchBasicAPI(t *testing.T) {
 
 	cfg := &ClusterConfig{RetryPolicy: &SimpleRetryPolicy{NumRetries: 2}}
-	pool, err := NewSimplePool(cfg)
+	pool, err := cfg.PoolConfig.buildPool(cfg)
 	if err != nil {
 		t.Fatal(err)
 	}
 
 	s := &Session{
-		Pool: pool,
+		pool: pool,
 		cfg:  *cfg,
 		cons: Quorum,
 	}

+ 1 - 2
stress_test.go

@@ -36,7 +36,6 @@ func BenchmarkConnStress(b *testing.B) {
 
 	b.SetParallelism(workers)
 	b.RunParallel(writer)
-
 }
 
 func BenchmarkConnRoutingKey(b *testing.B) {
@@ -45,7 +44,7 @@ func BenchmarkConnRoutingKey(b *testing.B) {
 	cluster := createCluster()
 	cluster.NumConns = 1
 	cluster.NumStreams = workers
-	cluster.ConnPoolType = NewTokenAwareConnPool
+	cluster.PoolConfig.HostSelectionPolicy = TokenAwareHostPolicy(RoundRobinHostPolicy())
 	session := createSessionFromCluster(cluster, b)
 	defer session.Close()