|
|
@@ -36,9 +36,9 @@ import (
|
|
|
"encoding/binary"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
- "math/rand"
|
|
|
"net"
|
|
|
"strings"
|
|
|
+ "time"
|
|
|
)
|
|
|
|
|
|
const (
|
|
|
@@ -62,13 +62,12 @@ const (
|
|
|
|
|
|
keyVersion string = "CQL_VERSION"
|
|
|
keyCompression string = "COMPRESSION"
|
|
|
+ keyspaceQuery string = "USE "
|
|
|
)
|
|
|
|
|
|
var consistencyLevels = map[string]byte{"any": 0x00, "one": 0x01, "two": 0x02,
|
|
|
"three": 0x03, "quorum": 0x04, "all": 0x05, "local_quorum": 0x06, "each_quorum": 0x07}
|
|
|
|
|
|
-var rnd = rand.New(rand.NewSource(0))
|
|
|
-
|
|
|
type drv struct{}
|
|
|
|
|
|
func (d drv) Open(name string) (driver.Conn, error) {
|
|
|
@@ -76,23 +75,28 @@ func (d drv) Open(name string) (driver.Conn, error) {
|
|
|
}
|
|
|
|
|
|
type connection struct {
|
|
|
- c net.Conn
|
|
|
+ c net.Conn
|
|
|
+ address string
|
|
|
+ alive bool
|
|
|
+ pool *pool
|
|
|
+}
|
|
|
+
|
|
|
+type pool struct {
|
|
|
+ connections []*connection
|
|
|
+ i int
|
|
|
+ keyspace string
|
|
|
+ version string
|
|
|
compression string
|
|
|
consistency byte
|
|
|
+ dead bool
|
|
|
+ stop chan struct{}
|
|
|
}
|
|
|
|
|
|
-func Open(name string) (*connection, error) {
|
|
|
+func Open(name string) (*pool, error) {
|
|
|
parts := strings.Split(name, " ")
|
|
|
- address := ""
|
|
|
+ var addresses []string
|
|
|
if len(parts) >= 1 {
|
|
|
- addresses := strings.Split(parts[0], ",")
|
|
|
- if len(addresses) > 0 {
|
|
|
- address = addresses[rnd.Intn(len(addresses))]
|
|
|
- }
|
|
|
- }
|
|
|
- c, err := net.Dial("tcp", address)
|
|
|
- if err != nil {
|
|
|
- return nil, err
|
|
|
+ addresses = strings.Split(parts[0], ",")
|
|
|
}
|
|
|
|
|
|
version := "3.0.0"
|
|
|
@@ -126,7 +130,37 @@ func Open(name string) (*connection, error) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- cn := &connection{c: c, compression: compression, consistency: consistency}
|
|
|
+ pool := &pool{
|
|
|
+ keyspace: keyspace,
|
|
|
+ version: version,
|
|
|
+ compression: compression,
|
|
|
+ consistency: consistency,
|
|
|
+ stop: make(chan struct{}),
|
|
|
+ }
|
|
|
+
|
|
|
+ for _, address := range addresses {
|
|
|
+ pool.connections = append(pool.connections, &connection{address: address, pool: pool})
|
|
|
+ }
|
|
|
+
|
|
|
+ pool.join()
|
|
|
+
|
|
|
+ return pool, nil
|
|
|
+}
|
|
|
+
|
|
|
+func (cn *connection) open() {
|
|
|
+ cn.alive = false
|
|
|
+
|
|
|
+ var err error
|
|
|
+ cn.c, err = net.Dial("tcp", cn.address)
|
|
|
+ if err != nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ var (
|
|
|
+ version = cn.pool.version
|
|
|
+ compression = cn.pool.compression
|
|
|
+ keyspace = cn.pool.keyspace
|
|
|
+ )
|
|
|
|
|
|
b := &bytes.Buffer{}
|
|
|
|
|
|
@@ -149,28 +183,22 @@ func Open(name string) (*connection, error) {
|
|
|
}
|
|
|
|
|
|
if err := cn.sendUncompressed(opStartup, b.Bytes()); err != nil {
|
|
|
- return nil, err
|
|
|
+ return
|
|
|
}
|
|
|
|
|
|
opcode, _, err := cn.recv()
|
|
|
if err != nil {
|
|
|
- return nil, err
|
|
|
+ return
|
|
|
}
|
|
|
if opcode != opReady {
|
|
|
- return nil, fmt.Errorf("connection not ready")
|
|
|
+ return
|
|
|
}
|
|
|
|
|
|
if keyspace != "" {
|
|
|
- st, err := cn.Prepare(fmt.Sprintf("USE %s", keyspace))
|
|
|
- if err != nil {
|
|
|
- return nil, err
|
|
|
- }
|
|
|
- if _, err = st.Exec([]driver.Value{}); err != nil {
|
|
|
- return nil, err
|
|
|
- }
|
|
|
+ cn.UseKeyspace(keyspace)
|
|
|
}
|
|
|
|
|
|
- return cn, nil
|
|
|
+ cn.alive = true
|
|
|
}
|
|
|
|
|
|
// close a connection actively, typically used when there's an error and we want to ensure
|
|
|
@@ -178,6 +206,15 @@ func Open(name string) (*connection, error) {
|
|
|
func (cn *connection) close() {
|
|
|
cn.c.Close()
|
|
|
cn.c = nil // ensure we generate ErrBadConn when cn gets reused
|
|
|
+ cn.alive = false
|
|
|
+
|
|
|
+ // Check if the entire pool is dead
|
|
|
+ for _, cn := range cn.pool.connections {
|
|
|
+ if cn.alive {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ }
|
|
|
+ cn.pool.dead = false
|
|
|
}
|
|
|
|
|
|
// explicitly send a request as uncompressed
|
|
|
@@ -187,7 +224,7 @@ func (cn *connection) sendUncompressed(opcode byte, body []byte) error {
|
|
|
}
|
|
|
|
|
|
func (cn *connection) send(opcode byte, body []byte) error {
|
|
|
- return cn._send(opcode, body, cn.compression == "snappy" && len(body) > 0)
|
|
|
+ return cn._send(opcode, body, cn.pool.compression == "snappy" && len(body) > 0)
|
|
|
}
|
|
|
|
|
|
func (cn *connection) _send(opcode byte, body []byte, compression bool) error {
|
|
|
@@ -256,7 +293,7 @@ func (cn *connection) recv() (byte, []byte, error) {
|
|
|
return 0, nil, err
|
|
|
}
|
|
|
}
|
|
|
- if header[1]&flagCompressed != 0 && cn.compression == "snappy" {
|
|
|
+ if header[1]&flagCompressed != 0 && cn.pool.compression == "snappy" {
|
|
|
var err error
|
|
|
body, err = snappy.Decode(nil, body)
|
|
|
if err != nil {
|
|
|
@@ -273,35 +310,132 @@ func (cn *connection) recv() (byte, []byte, error) {
|
|
|
return opcode, body, nil
|
|
|
}
|
|
|
|
|
|
-func (cn *connection) Begin() (driver.Tx, error) {
|
|
|
- if cn.c == nil {
|
|
|
+func (p *pool) conn() (*connection, error) {
|
|
|
+ if p.dead {
|
|
|
return nil, driver.ErrBadConn
|
|
|
}
|
|
|
- return cn, nil
|
|
|
+
|
|
|
+ totalConnections := len(p.connections)
|
|
|
+ start := p.i + 1 // make sure that we start from the next position in the ring
|
|
|
+
|
|
|
+ for i := 0; i < totalConnections; i++ {
|
|
|
+ idx := (i + start) % totalConnections
|
|
|
+ cn := p.connections[idx]
|
|
|
+ if cn.alive {
|
|
|
+ p.i = idx // set the new 'i' so the ring will start again in the right place
|
|
|
+ return cn, nil
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // we've exhausted the pool, gonna have a bad time
|
|
|
+ p.dead = true
|
|
|
+ return nil, driver.ErrBadConn
|
|
|
}
|
|
|
|
|
|
-func (cn *connection) Commit() error {
|
|
|
- if cn.c == nil {
|
|
|
+func (p *pool) join() {
|
|
|
+ p.reconnect()
|
|
|
+
|
|
|
+ // Every 1 second, we want to try reconnecting to disconnected nodes
|
|
|
+ go func() {
|
|
|
+ for {
|
|
|
+ select {
|
|
|
+ case <-p.stop:
|
|
|
+ return
|
|
|
+ default:
|
|
|
+ p.reconnect()
|
|
|
+ time.Sleep(time.Second)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }()
|
|
|
+}
|
|
|
+
|
|
|
+func (p *pool) reconnect() {
|
|
|
+ for _, cn := range p.connections {
|
|
|
+ if !cn.alive {
|
|
|
+ cn.open()
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (p *pool) Begin() (driver.Tx, error) {
|
|
|
+ if p.dead {
|
|
|
+ return nil, driver.ErrBadConn
|
|
|
+ }
|
|
|
+ return p, nil
|
|
|
+}
|
|
|
+
|
|
|
+func (p *pool) Commit() error {
|
|
|
+ if p.dead {
|
|
|
return driver.ErrBadConn
|
|
|
}
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
-func (cn *connection) Close() error {
|
|
|
- if cn.c == nil {
|
|
|
+func (p *pool) Close() error {
|
|
|
+ if p.dead {
|
|
|
return driver.ErrBadConn
|
|
|
}
|
|
|
- cn.close()
|
|
|
+ for _, cn := range p.connections {
|
|
|
+ cn.close()
|
|
|
+ }
|
|
|
+ p.stop <- struct{}{}
|
|
|
+ p.dead = true
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
-func (cn *connection) Rollback() error {
|
|
|
- if cn.c == nil {
|
|
|
+func (p *pool) Rollback() error {
|
|
|
+ if p.dead {
|
|
|
return driver.ErrBadConn
|
|
|
}
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
+func (p *pool) Prepare(query string) (driver.Stmt, error) {
|
|
|
+ // Explicitly check if the query is a "USE <keyspace>"
|
|
|
+ // Since it needs to be special cased and run on each server
|
|
|
+ if strings.HasPrefix(query, keyspaceQuery) {
|
|
|
+ keyspace := query[len(keyspaceQuery):]
|
|
|
+ p.UseKeyspace(keyspace)
|
|
|
+ return &statement{}, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ for {
|
|
|
+ cn, err := p.conn()
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ st, err := cn.Prepare(query)
|
|
|
+ if err != nil {
|
|
|
+ // the cn has gotten marked as dead already
|
|
|
+ if p.dead {
|
|
|
+ // The entire pool is dead, so we bubble up the ErrBadConn
|
|
|
+ return nil, driver.ErrBadConn
|
|
|
+ } else {
|
|
|
+ continue // Retry request on another cn
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return st, nil
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (p *pool) UseKeyspace(keyspace string) {
|
|
|
+ p.keyspace = keyspace
|
|
|
+ for _, cn := range p.connections {
|
|
|
+ cn.UseKeyspace(keyspace)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (cn *connection) UseKeyspace(keyspace string) error {
|
|
|
+ st, err := cn.Prepare(keyspaceQuery + keyspace)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ if _, err = st.Exec([]driver.Value{}); err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
func (cn *connection) Prepare(query string) (driver.Stmt, error) {
|
|
|
body := make([]byte, len(query)+4)
|
|
|
binary.BigEndian.PutUint32(body[0:4], uint32(len(query)))
|
|
|
@@ -389,7 +523,7 @@ func (st *statement) exec(v []driver.Value) error {
|
|
|
copy(body[p+4:], b)
|
|
|
p += 4 + len(b)
|
|
|
}
|
|
|
- binary.BigEndian.PutUint16(body[p:], uint16(st.cn.consistency))
|
|
|
+ binary.BigEndian.PutUint16(body[p:], uint16(st.cn.pool.consistency))
|
|
|
if err := st.cn.send(opExecute, body); err != nil {
|
|
|
return err
|
|
|
}
|
|
|
@@ -397,6 +531,9 @@ func (st *statement) exec(v []driver.Value) error {
|
|
|
}
|
|
|
|
|
|
func (st *statement) Exec(v []driver.Value) (driver.Result, error) {
|
|
|
+ if st.cn == nil {
|
|
|
+ return nil, nil
|
|
|
+ }
|
|
|
if err := st.exec(v); err != nil {
|
|
|
return nil, err
|
|
|
}
|