Browse Source

merged branch mattrobenolt:pool

Christoph Hack 12 years ago
parent
commit
a1bebffbf3
1 changed files with 176 additions and 39 deletions
  1. 176 39
      gocql.go

+ 176 - 39
gocql.go

@@ -36,9 +36,9 @@ import (
 	"encoding/binary"
 	"fmt"
 	"io"
-	"math/rand"
 	"net"
 	"strings"
+	"time"
 )
 
 const (
@@ -62,13 +62,12 @@ const (
 
 	keyVersion     string = "CQL_VERSION"
 	keyCompression string = "COMPRESSION"
+	keyspaceQuery  string = "USE "
 )
 
 var consistencyLevels = map[string]byte{"any": 0x00, "one": 0x01, "two": 0x02,
 	"three": 0x03, "quorum": 0x04, "all": 0x05, "local_quorum": 0x06, "each_quorum": 0x07}
 
-var rnd = rand.New(rand.NewSource(0))
-
 type drv struct{}
 
 func (d drv) Open(name string) (driver.Conn, error) {
@@ -76,23 +75,28 @@ func (d drv) Open(name string) (driver.Conn, error) {
 }
 
 type connection struct {
-	c           net.Conn
+	c       net.Conn
+	address string
+	alive   bool
+	pool    *pool
+}
+
+type pool struct {
+	connections []*connection
+	i           int
+	keyspace    string
+	version     string
 	compression string
 	consistency byte
+	dead        bool
+	stop        chan struct{}
 }
 
-func Open(name string) (*connection, error) {
+func Open(name string) (*pool, error) {
 	parts := strings.Split(name, " ")
-	address := ""
+	var addresses []string
 	if len(parts) >= 1 {
-		addresses := strings.Split(parts[0], ",")
-		if len(addresses) > 0 {
-			address = addresses[rnd.Intn(len(addresses))]
-		}
-	}
-	c, err := net.Dial("tcp", address)
-	if err != nil {
-		return nil, err
+		addresses = strings.Split(parts[0], ",")
 	}
 
 	version := "3.0.0"
@@ -126,7 +130,37 @@ func Open(name string) (*connection, error) {
 		}
 	}
 
-	cn := &connection{c: c, compression: compression, consistency: consistency}
+	pool := &pool{
+		keyspace:    keyspace,
+		version:     version,
+		compression: compression,
+		consistency: consistency,
+		stop:        make(chan struct{}),
+	}
+
+	for _, address := range addresses {
+		pool.connections = append(pool.connections, &connection{address: address, pool: pool})
+	}
+
+	pool.join()
+
+	return pool, nil
+}
+
+func (cn *connection) open() {
+	cn.alive = false
+
+	var err error
+	cn.c, err = net.Dial("tcp", cn.address)
+	if err != nil {
+		return
+	}
+
+	var (
+		version     = cn.pool.version
+		compression = cn.pool.compression
+		keyspace    = cn.pool.keyspace
+	)
 
 	b := &bytes.Buffer{}
 
@@ -149,28 +183,22 @@ func Open(name string) (*connection, error) {
 	}
 
 	if err := cn.sendUncompressed(opStartup, b.Bytes()); err != nil {
-		return nil, err
+		return
 	}
 
 	opcode, _, err := cn.recv()
 	if err != nil {
-		return nil, err
+		return
 	}
 	if opcode != opReady {
-		return nil, fmt.Errorf("connection not ready")
+		return
 	}
 
 	if keyspace != "" {
-		st, err := cn.Prepare(fmt.Sprintf("USE %s", keyspace))
-		if err != nil {
-			return nil, err
-		}
-		if _, err = st.Exec([]driver.Value{}); err != nil {
-			return nil, err
-		}
+		cn.UseKeyspace(keyspace)
 	}
 
-	return cn, nil
+	cn.alive = true
 }
 
 // close a connection actively, typically used when there's an error and we want to ensure
@@ -178,6 +206,15 @@ func Open(name string) (*connection, error) {
 func (cn *connection) close() {
 	cn.c.Close()
 	cn.c = nil // ensure we generate ErrBadConn when cn gets reused
+	cn.alive = false
+
+	// Check if the entire pool is dead
+	for _, cn := range cn.pool.connections {
+		if cn.alive {
+			return
+		}
+	}
+	cn.pool.dead = false
 }
 
 // explicitly send a request as uncompressed
@@ -187,7 +224,7 @@ func (cn *connection) sendUncompressed(opcode byte, body []byte) error {
 }
 
 func (cn *connection) send(opcode byte, body []byte) error {
-	return cn._send(opcode, body, cn.compression == "snappy" && len(body) > 0)
+	return cn._send(opcode, body, cn.pool.compression == "snappy" && len(body) > 0)
 }
 
 func (cn *connection) _send(opcode byte, body []byte, compression bool) error {
@@ -256,7 +293,7 @@ func (cn *connection) recv() (byte, []byte, error) {
 			return 0, nil, err
 		}
 	}
-	if header[1]&flagCompressed != 0 && cn.compression == "snappy" {
+	if header[1]&flagCompressed != 0 && cn.pool.compression == "snappy" {
 		var err error
 		body, err = snappy.Decode(nil, body)
 		if err != nil {
@@ -273,35 +310,132 @@ func (cn *connection) recv() (byte, []byte, error) {
 	return opcode, body, nil
 }
 
-func (cn *connection) Begin() (driver.Tx, error) {
-	if cn.c == nil {
+func (p *pool) conn() (*connection, error) {
+	if p.dead {
 		return nil, driver.ErrBadConn
 	}
-	return cn, nil
+
+	totalConnections := len(p.connections)
+	start := p.i + 1 // make sure that we start from the next position in the ring
+
+	for i := 0; i < totalConnections; i++ {
+		idx := (i + start) % totalConnections
+		cn := p.connections[idx]
+		if cn.alive {
+			p.i = idx // set the new 'i' so the ring will start again in the right place
+			return cn, nil
+		}
+	}
+
+	// we've exhausted the pool, gonna have a bad time
+	p.dead = true
+	return nil, driver.ErrBadConn
 }
 
-func (cn *connection) Commit() error {
-	if cn.c == nil {
+func (p *pool) join() {
+	p.reconnect()
+
+	// Every 1 second, we want to try reconnecting to disconnected nodes
+	go func() {
+		for {
+			select {
+			case <-p.stop:
+				return
+			default:
+				p.reconnect()
+				time.Sleep(time.Second)
+			}
+		}
+	}()
+}
+
+func (p *pool) reconnect() {
+	for _, cn := range p.connections {
+		if !cn.alive {
+			cn.open()
+		}
+	}
+}
+
+func (p *pool) Begin() (driver.Tx, error) {
+	if p.dead {
+		return nil, driver.ErrBadConn
+	}
+	return p, nil
+}
+
+func (p *pool) Commit() error {
+	if p.dead {
 		return driver.ErrBadConn
 	}
 	return nil
 }
 
-func (cn *connection) Close() error {
-	if cn.c == nil {
+func (p *pool) Close() error {
+	if p.dead {
 		return driver.ErrBadConn
 	}
-	cn.close()
+	for _, cn := range p.connections {
+		cn.close()
+	}
+	p.stop <- struct{}{}
+	p.dead = true
 	return nil
 }
 
-func (cn *connection) Rollback() error {
-	if cn.c == nil {
+func (p *pool) Rollback() error {
+	if p.dead {
 		return driver.ErrBadConn
 	}
 	return nil
 }
 
+func (p *pool) Prepare(query string) (driver.Stmt, error) {
+	// Explicitly check if the query is a "USE <keyspace>"
+	// Since it needs to be special cased and run on each server
+	if strings.HasPrefix(query, keyspaceQuery) {
+		keyspace := query[len(keyspaceQuery):]
+		p.UseKeyspace(keyspace)
+		return &statement{}, nil
+	}
+
+	for {
+		cn, err := p.conn()
+		if err != nil {
+			return nil, err
+		}
+		st, err := cn.Prepare(query)
+		if err != nil {
+			// the cn has gotten marked as dead already
+			if p.dead {
+				// The entire pool is dead, so we bubble up the ErrBadConn
+				return nil, driver.ErrBadConn
+			} else {
+				continue // Retry request on another cn
+			}
+		}
+		return st, nil
+	}
+}
+
+func (p *pool) UseKeyspace(keyspace string) {
+	p.keyspace = keyspace
+	for _, cn := range p.connections {
+		cn.UseKeyspace(keyspace)
+	}
+}
+
+func (cn *connection) UseKeyspace(keyspace string) error {
+	st, err := cn.Prepare(keyspaceQuery + keyspace)
+	if err != nil {
+		return err
+	}
+	if _, err = st.Exec([]driver.Value{}); err != nil {
+		return err
+	}
+	return nil
+}
+
 func (cn *connection) Prepare(query string) (driver.Stmt, error) {
 	body := make([]byte, len(query)+4)
 	binary.BigEndian.PutUint32(body[0:4], uint32(len(query)))
@@ -389,7 +523,7 @@ func (st *statement) exec(v []driver.Value) error {
 		copy(body[p+4:], b)
 		p += 4 + len(b)
 	}
-	binary.BigEndian.PutUint16(body[p:], uint16(st.cn.consistency))
+	binary.BigEndian.PutUint16(body[p:], uint16(st.cn.pool.consistency))
 	if err := st.cn.send(opExecute, body); err != nil {
 		return err
 	}
@@ -397,6 +531,9 @@ func (st *statement) exec(v []driver.Value) error {
 }
 
 func (st *statement) Exec(v []driver.Value) (driver.Result, error) {
+	if st.cn == nil {
+		return nil, nil
+	}
 	if err := st.exec(v); err != nil {
 		return nil, err
 	}