Browse Source

Created the Connection Pool interface and converted the existing connection pool logic to a SimplePool interface.

Phillip Couto 11 years ago
parent
commit
3783dd1c40
5 changed files with 260 additions and 224 deletions
  1. 7 207
      cluster.go
  2. 4 8
      conn.go
  3. 2 5
      conn_test.go
  4. 238 0
      connectionpool.go
  5. 9 4
      session.go

+ 7 - 207
cluster.go

@@ -6,10 +6,6 @@ package gocql
 
 import (
 	"errors"
-	"fmt"
-	"log"
-	"strings"
-	"sync"
 	"time"
 )
 
@@ -31,6 +27,7 @@ type ClusterConfig struct {
 	Authenticator   Authenticator // authenticator (default: nil)
 	RetryPolicy     RetryPolicy   // Default retry policy to use for queries (default: 0)
 	SocketKeepalive time.Duration // The keepalive period to use, enabled if > 0 (default: 0)
+	ConnPoolType    NewPoolFunc   // The function used to create the connection pool for the session (default: NewSimplePool)
 }
 
 // NewCluster generates a new config for the default cluster implementation.
@@ -44,6 +41,7 @@ func NewCluster(hosts ...string) *ClusterConfig {
 		NumConns:     2,
 		NumStreams:   128,
 		Consistency:  Quorum,
+		ConnPoolType: NewSimplePool,
 	}
 	return cfg
 }
@@ -56,219 +54,21 @@ func (cfg *ClusterConfig) CreateSession() (*Session, error) {
 	if len(cfg.Hosts) < 1 {
 		return nil, ErrNoHosts
 	}
+	pool := cfg.ConnPoolType(cfg)
 
-	impl := &clusterImpl{
-		cfg:          *cfg,
-		hostPool:     NewRoundRobin(),
-		connPool:     make(map[string]*RoundRobin),
-		conns:        make(map[*Conn]struct{}),
-		quitWait:     make(chan bool),
-		cFillingPool: make(chan int, 1),
-		keyspace:     cfg.Keyspace,
-	}
-	//Walk through connecting to hosts. As soon as one host connects
-	//defer the remaining connections to cluster.fillPool()
-	for i := 0; i < len(impl.cfg.Hosts); i++ {
-		addr := strings.TrimSpace(impl.cfg.Hosts[i])
-		if strings.Index(addr, ":") < 0 {
-			addr = fmt.Sprintf("%s:%d", addr, impl.cfg.DefaultPort)
-		}
-		err := impl.connect(addr)
-		if err == nil {
-			impl.cFillingPool <- 1
-			go impl.fillPool()
-			break
-		}
-
-	}
 	//See if there are any connections in the pool
-	impl.mu.Lock()
-	conns := len(impl.conns)
-	impl.mu.Unlock()
-	if conns > 0 {
-		s := NewSession(impl)
+	if pool.Size() > 0 {
+		s := NewSession(pool, cfg)
 		s.SetConsistency(cfg.Consistency)
 		return s, nil
 	}
+
 	impl.Close()
 	return nil, ErrNoConnectionsStarted
 
 }
 
-type clusterImpl struct {
-	cfg      ClusterConfig
-	hostPool *RoundRobin
-	connPool map[string]*RoundRobin
-	conns    map[*Conn]struct{}
-	keyspace string
-	mu       sync.Mutex
-
-	cFillingPool chan int
-
-	quit     bool
-	quitWait chan bool
-	quitOnce sync.Once
-}
-
-func (c *clusterImpl) connect(addr string) error {
-	cfg := ConnConfig{
-		ProtoVersion:  c.cfg.ProtoVersion,
-		CQLVersion:    c.cfg.CQLVersion,
-		Timeout:       c.cfg.Timeout,
-		NumStreams:    c.cfg.NumStreams,
-		Compressor:    c.cfg.Compressor,
-		Authenticator: c.cfg.Authenticator,
-		Keepalive:     c.cfg.SocketKeepalive,
-	}
-
-	for {
-		conn, err := Connect(addr, cfg, c)
-		if err != nil {
-			log.Printf("failed to connect to %q: %v", addr, err)
-			return err
-		}
-		return c.addConn(conn)
-	}
-}
-
-func (c *clusterImpl) addConn(conn *Conn) error {
-	c.mu.Lock()
-	defer c.mu.Unlock()
-	if c.quit {
-		conn.Close()
-		return nil
-	}
-	//Set the connection's keyspace if any before adding it to the pool
-	if c.keyspace != "" {
-		if err := conn.UseKeyspace(c.keyspace); err != nil {
-			log.Printf("error setting connection keyspace. %v", err)
-			conn.Close()
-			return err
-		}
-	}
-	connPool := c.connPool[conn.Address()]
-	if connPool == nil {
-		connPool = NewRoundRobin()
-		c.connPool[conn.Address()] = connPool
-		c.hostPool.AddNode(connPool)
-	}
-	connPool.AddNode(conn)
-	c.conns[conn] = struct{}{}
-	return nil
-}
-
-//fillPool manages the pool of connections making sure that each host has the correct
-//amount of connections defined. Also the method will test a host with one connection
-//instead of flooding the host with number of connections defined in the cluster config
-func (c *clusterImpl) fillPool() {
-	//Debounce large amounts of requests to fill pool
-	select {
-	case <-time.After(1 * time.Millisecond):
-		return
-	case <-c.cFillingPool:
-		defer func() { c.cFillingPool <- 1 }()
-	}
-
-	c.mu.Lock()
-	isClosed := c.quit
-	c.mu.Unlock()
-	//Exit if cluster(session) is closed
-	if isClosed {
-		return
-	}
-	//Walk through list of defined hosts
-	for i := 0; i < len(c.cfg.Hosts); i++ {
-		addr := strings.TrimSpace(c.cfg.Hosts[i])
-		if strings.Index(addr, ":") < 0 {
-			addr = fmt.Sprintf("%s:%d", addr, c.cfg.DefaultPort)
-		}
-		var numConns int = 1
-		//See if the host already has connections in the pool
-		c.mu.Lock()
-		conns, ok := c.connPool[addr]
-		c.mu.Unlock()
-
-		if ok {
-			//if the host has enough connections just exit
-			numConns = conns.Size()
-			if numConns >= c.cfg.NumConns {
-				continue
-			}
-		} else {
-			//See if the host is reachable
-			if err := c.connect(addr); err != nil {
-				continue
-			}
-		}
-		//This is reached if the host is responsive and needs more connections
-		//Create connections for host synchronously to mitigate flooding the host.
-		go func(a string, conns int) {
-			for ; conns < c.cfg.NumConns; conns++ {
-				c.connect(addr)
-			}
-		}(addr, numConns)
-	}
-}
-
-// Should only be called if c.mu is locked
-func (c *clusterImpl) removeConnLocked(conn *Conn) {
-	conn.Close()
-	connPool := c.connPool[conn.addr]
-	if connPool == nil {
-		return
-	}
-	connPool.RemoveNode(conn)
-	if connPool.Size() == 0 {
-		c.hostPool.RemoveNode(connPool)
-		delete(c.connPool, conn.addr)
-	}
-	delete(c.conns, conn)
-}
-
-func (c *clusterImpl) removeConn(conn *Conn) {
-	c.mu.Lock()
-	defer c.mu.Unlock()
-	c.removeConnLocked(conn)
-}
-
-func (c *clusterImpl) HandleError(conn *Conn, err error, closed bool) {
-	if !closed {
-		// ignore all non-fatal errors
-		return
-	}
-	c.removeConn(conn)
-	if !c.quit {
-		go c.fillPool() // top off pool.
-	}
-}
-
-func (c *clusterImpl) Pick(qry *Query) *Conn {
-	//Check if connections are available
-	c.mu.Lock()
-	conns := len(c.conns)
-	c.mu.Unlock()
-
-	if conns == 0 {
-		//try to populate the pool before returning.
-		c.fillPool()
-	}
-
-	return c.hostPool.Pick(qry)
-}
-
-func (c *clusterImpl) Close() {
-	c.quitOnce.Do(func() {
-		c.mu.Lock()
-		defer c.mu.Unlock()
-		c.quit = true
-		close(c.quitWait)
-		for conn := range c.conns {
-			c.removeConnLocked(conn)
-		}
-	})
-}
-
 var (
-	ErrNoHosts       = errors.New("no hosts provided")
+	ErrNoHosts              = errors.New("no hosts provided")
 	ErrNoConnectionsStarted = errors.New("no connections were made when creating the session")
 )

+ 4 - 8
conn.go

@@ -17,10 +17,6 @@ const defaultFrameSize = 4096
 const flagResponse = 0x80
 const maskVersion = 0x7F
 
-type Cluster interface {
-	HandleError(conn *Conn, err error, closed bool)
-}
-
 type Authenticator interface {
 	Challenge(req []byte) (resp []byte, auth Authenticator, err error)
 	Success(data []byte) error
@@ -72,7 +68,7 @@ type Conn struct {
 	prepMu sync.Mutex
 	prep   map[string]*inflightPrepare
 
-	cluster    Cluster
+	pool       ConnectionPool
 	compressor Compressor
 	auth       Authenticator
 	addr       string
@@ -84,7 +80,7 @@ type Conn struct {
 
 // Connect establishes a connection to a Cassandra node.
 // You must also call the Serve method before you can execute any queries.
-func Connect(addr string, cfg ConnConfig, cluster Cluster) (*Conn, error) {
+func Connect(addr string, cfg ConnConfig, pool ConnectionPool) (*Conn, error) {
 	conn, err := net.DialTimeout("tcp", addr, cfg.Timeout)
 	if err != nil {
 		return nil, err
@@ -105,7 +101,7 @@ func Connect(addr string, cfg ConnConfig, cluster Cluster) (*Conn, error) {
 		timeout:    cfg.Timeout,
 		version:    uint8(cfg.ProtoVersion),
 		addr:       conn.RemoteAddr().String(),
-		cluster:    cluster,
+		pool:       pool,
 		compressor: cfg.Compressor,
 		auth:       cfg.Authenticator,
 	}
@@ -203,7 +199,7 @@ func (c *Conn) serve() {
 			req.resp <- callResp{nil, err}
 		}
 	}
-	c.cluster.HandleError(c, err, true)
+	c.pool.HandleError(c, err, true)
 }
 
 func (c *Conn) recv() (frame, error) {

+ 2 - 5
conn_test.go

@@ -183,11 +183,8 @@ func TestConnClosing(t *testing.T) {
 	wg.Wait()
 
 	time.Sleep(1 * time.Second) //Sleep so the fillPool can complete.
-	cluster := db.Node.(*clusterImpl)
-
-	cluster.mu.Lock()
-	conns := len(cluster.conns)
-	cluster.mu.Unlock()
+	pool := db.Pool.(ConnectionPool)
+	conns := pool.Size()
 
 	if conns != numConns {
 		t.Fatalf("Expected to have %d connections but have %d", numConns, conns)

+ 238 - 0
connectionpool.go

@@ -0,0 +1,238 @@
+package gocql
+
+import (
+	"fmt"
+	"log"
+	"strings"
+	"sync"
+	"time"
+)
+
+//ConnectionPool is the interface gocql expects to be exposed for a connection pool.
+type ConnectionPool interface {
+	Pick(*Query) *Conn
+	Size() int
+	HandleError(*Conn, error, bool)
+	Close()
+}
+
+//NewPoolFunc is the type used by ClusterConfig to create a pool of a specific type.
+type NewPoolFunc func(*ClusterConfig) ConnectionPool
+
+//SimplePool is the current implementation of the connection pool inside gocql. This
+//pool is meant to be a simple default used by gocql so users can get up and running
+//quickly.
+type SimplePool struct {
+	cfg      *ClusterConfig
+	hostPool *RoundRobin
+	connPool map[string]*RoundRobin
+	conns    map[*Conn]struct{}
+	keyspace string
+	mu       sync.Mutex
+
+	cFillingPool chan int
+
+	quit     bool
+	quitWait chan bool
+	quitOnce sync.Once
+}
+
+//NewSimplePool is the function used by gocql to create the simple connection pool.
+//This is the default if no other pool type is specified.
+func NewSimplePool(cfg *ClusterConfig) ConnectionPool {
+	pool := SimplePool{
+		cfg:          cfg,
+		hostPool:     NewRoundRobin(),
+		connPool:     make(map[string]*RoundRobin),
+		conns:        make(map[*Conn]struct{}),
+		quitWait:     make(chan bool),
+		cFillingPool: make(chan int, 1),
+		keyspace:     cfg.Keyspace,
+	}
+	//Walk through connecting to hosts. As soon as one host connects
+	//defer the remaining connections to cluster.fillPool()
+	for i := 0; i < len(pool.cfg.Hosts); i++ {
+		addr := strings.TrimSpace(pool.cfg.Hosts[i])
+		if strings.Index(addr, ":") < 0 {
+			addr = fmt.Sprintf("%s:%d", addr, pool.cfg.DefaultPort)
+		}
+		if pool.connect(addr) == nil {
+			pool.cFillingPool <- 1
+			go pool.fillPool()
+			break
+		}
+
+	}
+	return &pool
+}
+
+func (c *SimplePool) connect(addr string) error {
+	cfg := ConnConfig{
+		ProtoVersion:  c.cfg.ProtoVersion,
+		CQLVersion:    c.cfg.CQLVersion,
+		Timeout:       c.cfg.Timeout,
+		NumStreams:    c.cfg.NumStreams,
+		Compressor:    c.cfg.Compressor,
+		Authenticator: c.cfg.Authenticator,
+		Keepalive:     c.cfg.SocketKeepalive,
+	}
+
+	for {
+		conn, err := Connect(addr, cfg, c)
+		if err != nil {
+			log.Printf("failed to connect to %q: %v", addr, err)
+			return err
+		}
+		return c.addConn(conn)
+	}
+}
+
+func (c *SimplePool) addConn(conn *Conn) error {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	if c.quit {
+		conn.Close()
+		return nil
+	}
+	//Set the connection's keyspace if any before adding it to the pool
+	if c.keyspace != "" {
+		if err := conn.UseKeyspace(c.keyspace); err != nil {
+			log.Printf("error setting connection keyspace. %v", err)
+			conn.Close()
+			return err
+		}
+	}
+	connPool := c.connPool[conn.Address()]
+	if connPool == nil {
+		connPool = NewRoundRobin()
+		c.connPool[conn.Address()] = connPool
+		c.hostPool.AddNode(connPool)
+	}
+	connPool.AddNode(conn)
+	c.conns[conn] = struct{}{}
+	return nil
+}
+
+//fillPool manages the pool of connections making sure that each host has the correct
+//amount of connections defined. Also the method will test a host with one connection
+//instead of flooding the host with number of connections defined in the cluster config
+func (c *SimplePool) fillPool() {
+	//Debounce large amounts of requests to fill pool
+	select {
+	case <-time.After(1 * time.Millisecond):
+		return
+	case <-c.cFillingPool:
+		defer func() { c.cFillingPool <- 1 }()
+	}
+
+	c.mu.Lock()
+	isClosed := c.quit
+	c.mu.Unlock()
+	//Exit if cluster(session) is closed
+	if isClosed {
+		return
+	}
+	//Walk through list of defined hosts
+	for i := 0; i < len(c.cfg.Hosts); i++ {
+		addr := strings.TrimSpace(c.cfg.Hosts[i])
+		if strings.Index(addr, ":") < 0 {
+			addr = fmt.Sprintf("%s:%d", addr, c.cfg.DefaultPort)
+		}
+		var numConns int = 1
+		//See if the host already has connections in the pool
+		c.mu.Lock()
+		conns, ok := c.connPool[addr]
+		c.mu.Unlock()
+
+		if ok {
+			//if the host has enough connections just exit
+			numConns = conns.Size()
+			if numConns >= c.cfg.NumConns {
+				continue
+			}
+		} else {
+			//See if the host is reachable
+			if err := c.connect(addr); err != nil {
+				continue
+			}
+		}
+		//This is reached if the host is responsive and needs more connections
+		//Create connections for host synchronously to mitigate flooding the host.
+		go func(a string, conns int) {
+			for ; conns < c.cfg.NumConns; conns++ {
+				c.connect(addr)
+			}
+		}(addr, numConns)
+	}
+}
+
+// Should only be called if c.mu is locked
+func (c *SimplePool) removeConnLocked(conn *Conn) {
+	conn.Close()
+	connPool := c.connPool[conn.addr]
+	if connPool == nil {
+		return
+	}
+	connPool.RemoveNode(conn)
+	if connPool.Size() == 0 {
+		c.hostPool.RemoveNode(connPool)
+		delete(c.connPool, conn.addr)
+	}
+	delete(c.conns, conn)
+}
+
+func (c *SimplePool) removeConn(conn *Conn) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	c.removeConnLocked(conn)
+}
+
+//HandleError is called by a Connection object to report to the pool an error has occured.
+//Logic is then executed within the pool to clean up the erroroneous connection and try to
+//top off the pool.
+func (c *SimplePool) HandleError(conn *Conn, err error, closed bool) {
+	if !closed {
+		// ignore all non-fatal errors
+		return
+	}
+	c.removeConn(conn)
+	if !c.quit {
+		go c.fillPool() // top off pool.
+	}
+}
+
+//Pick selects a connection to be used by the query.
+func (c *SimplePool) Pick(qry *Query) *Conn {
+	//Check if connections are available
+	c.mu.Lock()
+	conns := len(c.conns)
+	c.mu.Unlock()
+
+	if conns == 0 {
+		//try to populate the pool before returning.
+		c.fillPool()
+	}
+
+	return c.hostPool.Pick(qry)
+}
+
+//Size returns the number of connections currently active in the pool
+func (p *SimplePool) Size() int {
+	p.mu.Lock()
+	conns := len(p.conns)
+	p.mu.Unlock()
+	return conns
+}
+
+//Close kills the pool and all associated connections.
+func (c *SimplePool) Close() {
+	c.quitOnce.Do(func() {
+		c.mu.Lock()
+		defer c.mu.Unlock()
+		c.quit = true
+		close(c.quitWait)
+		for conn := range c.conns {
+			c.removeConnLocked(conn)
+		}
+	})
+}

+ 9 - 4
session.go

@@ -24,21 +24,22 @@ import (
 // and automatically sets a default consinstency level on all operations
 // that do not have a consistency level set.
 type Session struct {
-	Node     Node
+	Pool     ConnectionPool
 	cons     Consistency
 	pageSize int
 	prefetch float64
 	trace    Tracer
 	mu       sync.RWMutex
-	cfg      ClusterConfig
+
+	cfg ClusterConfig
 
 	closeMu  sync.RWMutex
 	isClosed bool
 }
 
 // NewSession wraps an existing Node.
-func NewSession(c *clusterImpl) *Session {
-	return &Session{Node: c, cons: Quorum, prefetch: 0.25, cfg: c.cfg}
+func NewSession(p ConnectionPool, c *ClusterConfig) *Session {
+	return &Session{Pool: p, cons: Quorum, prefetch: 0.25, cfg: c}
 }
 
 // SetConsistency sets the default consistency level for this session. This
@@ -91,6 +92,7 @@ func (s *Session) Query(stmt string, values ...interface{}) *Query {
 // Close closes all connections. The session is unusable after this
 // operation.
 func (s *Session) Close() {
+
 	s.closeMu.Lock()
 	defer s.closeMu.Unlock()
 	if s.isClosed {
@@ -109,6 +111,7 @@ func (s *Session) Closed() bool {
 }
 
 func (s *Session) executeQuery(qry *Query) *Iter {
+
 	// fail fast
 	if s.Closed() {
 		return &Iter{err: ErrSessionClosed}
@@ -117,6 +120,7 @@ func (s *Session) executeQuery(qry *Query) *Iter {
 	var iter *Iter
 	for count := 0; count <= qry.rt.NumRetries; count++ {
 		conn := s.Node.Pick(qry)
+
 		//Assign the error unavailable to the iterator
 		if conn == nil {
 			iter = &Iter{err: ErrNoConnections}
@@ -151,6 +155,7 @@ func (s *Session) ExecuteBatch(batch *Batch) error {
 	var err error
 	for count := 0; count <= batch.rt.NumRetries; count++ {
 		conn := s.Node.Pick(nil)
+
 		//Assign the error unavailable and break loop
 		if conn == nil {
 			err = ErrNoConnections