|
@@ -10,7 +10,9 @@ import (
|
|
|
"crypto/x509"
|
|
"crypto/x509"
|
|
|
"errors"
|
|
"errors"
|
|
|
"fmt"
|
|
"fmt"
|
|
|
|
|
+ "io"
|
|
|
"io/ioutil"
|
|
"io/ioutil"
|
|
|
|
|
+ "log"
|
|
|
"net"
|
|
"net"
|
|
|
"strconv"
|
|
"strconv"
|
|
|
"strings"
|
|
"strings"
|
|
@@ -88,7 +90,7 @@ type Conn struct {
|
|
|
r *bufio.Reader
|
|
r *bufio.Reader
|
|
|
timeout time.Duration
|
|
timeout time.Duration
|
|
|
|
|
|
|
|
- uniq chan uint8
|
|
|
|
|
|
|
+ uniq chan int
|
|
|
calls []callReq
|
|
calls []callReq
|
|
|
nwait int32
|
|
nwait int32
|
|
|
|
|
|
|
@@ -123,14 +125,17 @@ func Connect(addr string, cfg ConnConfig, pool ConnectionPool) (*Conn, error) {
|
|
|
return nil, errors.New("Failed parsing or appending certs")
|
|
return nil, errors.New("Failed parsing or appending certs")
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
mycert, err := tls.LoadX509KeyPair(cfg.SslOpts.CertPath, cfg.SslOpts.KeyPath)
|
|
mycert, err := tls.LoadX509KeyPair(cfg.SslOpts.CertPath, cfg.SslOpts.KeyPath)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return nil, err
|
|
return nil, err
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
config := tls.Config{
|
|
config := tls.Config{
|
|
|
Certificates: []tls.Certificate{mycert},
|
|
Certificates: []tls.Certificate{mycert},
|
|
|
RootCAs: certPool,
|
|
RootCAs: certPool,
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
config.InsecureSkipVerify = !cfg.SslOpts.EnableHostVerification
|
|
config.InsecureSkipVerify = !cfg.SslOpts.EnableHostVerification
|
|
|
if conn, err = tls.Dial("tcp", addr, &config); err != nil {
|
|
if conn, err = tls.Dial("tcp", addr, &config); err != nil {
|
|
|
return nil, err
|
|
return nil, err
|
|
@@ -139,16 +144,25 @@ func Connect(addr string, cfg ConnConfig, pool ConnectionPool) (*Conn, error) {
|
|
|
return nil, err
|
|
return nil, err
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- if cfg.NumStreams <= 0 || cfg.NumStreams > 128 {
|
|
|
|
|
- cfg.NumStreams = 128
|
|
|
|
|
- }
|
|
|
|
|
- if cfg.ProtoVersion != 1 && cfg.ProtoVersion != 2 {
|
|
|
|
|
|
|
+ // going to default to proto 2
|
|
|
|
|
+ if cfg.ProtoVersion < protoVersion1 || cfg.ProtoVersion > protoVersion3 {
|
|
|
|
|
+ log.Printf("unsupported protocol version: %d using 2\n", cfg.ProtoVersion)
|
|
|
cfg.ProtoVersion = 2
|
|
cfg.ProtoVersion = 2
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
|
|
+ maxStreams := 128
|
|
|
|
|
+ if cfg.ProtoVersion > protoVersion2 {
|
|
|
|
|
+ maxStreams = 32768
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if cfg.NumStreams <= 0 || cfg.NumStreams > maxStreams {
|
|
|
|
|
+ cfg.NumStreams = maxStreams
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
c := &Conn{
|
|
c := &Conn{
|
|
|
conn: conn,
|
|
conn: conn,
|
|
|
r: bufio.NewReader(conn),
|
|
r: bufio.NewReader(conn),
|
|
|
- uniq: make(chan uint8, cfg.NumStreams),
|
|
|
|
|
|
|
+ uniq: make(chan int, cfg.NumStreams),
|
|
|
calls: make([]callReq, cfg.NumStreams),
|
|
calls: make([]callReq, cfg.NumStreams),
|
|
|
timeout: cfg.Timeout,
|
|
timeout: cfg.Timeout,
|
|
|
version: uint8(cfg.ProtoVersion),
|
|
version: uint8(cfg.ProtoVersion),
|
|
@@ -162,8 +176,8 @@ func Connect(addr string, cfg ConnConfig, pool ConnectionPool) (*Conn, error) {
|
|
|
c.setKeepalive(cfg.Keepalive)
|
|
c.setKeepalive(cfg.Keepalive)
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- for i := 0; i < cap(c.uniq); i++ {
|
|
|
|
|
- c.uniq <- uint8(i)
|
|
|
|
|
|
|
+ for i := 0; i < cfg.NumStreams; i++ {
|
|
|
|
|
+ c.uniq <- i
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
if err := c.startup(&cfg); err != nil {
|
|
if err := c.startup(&cfg); err != nil {
|
|
@@ -254,47 +268,69 @@ func (c *Conn) serve() {
|
|
|
c.pool.HandleError(c, err, true)
|
|
c.pool.HandleError(c, err, true)
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+func (c *Conn) Write(p []byte) (int, error) {
|
|
|
|
|
+ c.conn.SetWriteDeadline(time.Now().Add(c.timeout))
|
|
|
|
|
+ return c.conn.Write(p)
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func (c *Conn) Read(p []byte) (int, error) {
|
|
|
|
|
+ return c.r.Read(p)
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
func (c *Conn) recv() (frame, error) {
|
|
func (c *Conn) recv() (frame, error) {
|
|
|
- resp := make(frame, headerSize, headerSize+512)
|
|
|
|
|
- c.conn.SetReadDeadline(time.Now().Add(c.timeout))
|
|
|
|
|
- n, last, pinged := 0, 0, false
|
|
|
|
|
- for n < len(resp) {
|
|
|
|
|
- nn, err := c.r.Read(resp[n:])
|
|
|
|
|
- n += nn
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
|
|
|
|
|
- if n > last {
|
|
|
|
|
- // we hit the deadline but we made progress.
|
|
|
|
|
- // simply extend the deadline
|
|
|
|
|
- c.conn.SetReadDeadline(time.Now().Add(c.timeout))
|
|
|
|
|
- last = n
|
|
|
|
|
- } else if n == 0 && !pinged {
|
|
|
|
|
- c.conn.SetReadDeadline(time.Now().Add(c.timeout))
|
|
|
|
|
- if atomic.LoadInt32(&c.nwait) > 0 {
|
|
|
|
|
- go c.ping()
|
|
|
|
|
- pinged = true
|
|
|
|
|
- }
|
|
|
|
|
- } else {
|
|
|
|
|
- return nil, err
|
|
|
|
|
- }
|
|
|
|
|
- } else {
|
|
|
|
|
- return nil, err
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ size := headerProtoSize[c.version]
|
|
|
|
|
+ resp := make(frame, size, size+512)
|
|
|
|
|
+
|
|
|
|
|
+ // read a full header, ignore timeouts, as this is being ran in a loop
|
|
|
|
|
+ c.conn.SetReadDeadline(time.Time{})
|
|
|
|
|
+ _, err := io.ReadFull(c.r, resp[:size])
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, err
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if v := c.version | flagResponse; resp[0] != v {
|
|
|
|
|
+ return nil, NewErrProtocol("recv: response protocol version does not match connection protocol version (%d != %d)", resp[0], v)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ bodySize := resp.Length(c.version)
|
|
|
|
|
+ if bodySize == 0 {
|
|
|
|
|
+ return resp, nil
|
|
|
|
|
+ }
|
|
|
|
|
+ resp.grow(bodySize)
|
|
|
|
|
+
|
|
|
|
|
+ const maxAttempts = 5
|
|
|
|
|
+
|
|
|
|
|
+ n := size
|
|
|
|
|
+ for i := 0; i < maxAttempts; i++ {
|
|
|
|
|
+ var nn int
|
|
|
|
|
+ c.conn.SetReadDeadline(time.Now().Add(c.timeout))
|
|
|
|
|
+ nn, err = c.Read(resp[n : size+bodySize])
|
|
|
|
|
+ if err == nil {
|
|
|
|
|
+ break
|
|
|
}
|
|
}
|
|
|
- if n == headerSize && len(resp) == headerSize {
|
|
|
|
|
- if resp[0] != c.version|flagResponse {
|
|
|
|
|
- return nil, NewErrProtocol("recv: Response protocol version does not match connection protocol version (%d != %d)", resp[0], c.version|flagResponse)
|
|
|
|
|
- }
|
|
|
|
|
- resp.grow(resp.Length())
|
|
|
|
|
|
|
+ n += nn
|
|
|
|
|
+
|
|
|
|
|
+ if verr, ok := err.(net.Error); !ok || !verr.Temporary() {
|
|
|
|
|
+ break
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, err
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
return resp, nil
|
|
return resp, nil
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func (c *Conn) execSimple(op operation) (interface{}, error) {
|
|
func (c *Conn) execSimple(op operation) (interface{}, error) {
|
|
|
f, err := op.encodeFrame(c.version, nil)
|
|
f, err := op.encodeFrame(c.version, nil)
|
|
|
- f.setLength(len(f) - headerSize)
|
|
|
|
|
- if _, err := c.conn.Write([]byte(f)); err != nil {
|
|
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ // this should be a noop err
|
|
|
|
|
+ return nil, err
|
|
|
|
|
+ }
|
|
|
|
|
+ f.setLength(len(f)-headerProtoSize[c.version], c.version)
|
|
|
|
|
+
|
|
|
|
|
+ if _, err := c.Write([]byte(f)); err != nil {
|
|
|
c.Close()
|
|
c.Close()
|
|
|
return nil, err
|
|
return nil, err
|
|
|
}
|
|
}
|
|
@@ -312,6 +348,8 @@ func (c *Conn) exec(op operation, trace Tracer) (interface{}, error) {
|
|
|
if trace != nil {
|
|
if trace != nil {
|
|
|
req[1] |= flagTrace
|
|
req[1] |= flagTrace
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
|
|
+ headerSize := headerProtoSize[c.version]
|
|
|
if len(req) > headerSize && c.compressor != nil {
|
|
if len(req) > headerSize && c.compressor != nil {
|
|
|
body, err := c.compressor.Encode([]byte(req[headerSize:]))
|
|
body, err := c.compressor.Encode([]byte(req[headerSize:]))
|
|
|
if err != nil {
|
|
if err != nil {
|
|
@@ -320,16 +358,16 @@ func (c *Conn) exec(op operation, trace Tracer) (interface{}, error) {
|
|
|
req = append(req[:headerSize], frame(body)...)
|
|
req = append(req[:headerSize], frame(body)...)
|
|
|
req[1] |= flagCompress
|
|
req[1] |= flagCompress
|
|
|
}
|
|
}
|
|
|
- req.setLength(len(req) - headerSize)
|
|
|
|
|
|
|
+ req.setLength(len(req)-headerSize, c.version)
|
|
|
|
|
|
|
|
id := <-c.uniq
|
|
id := <-c.uniq
|
|
|
- req[2] = id
|
|
|
|
|
|
|
+ req.setStream(id, c.version)
|
|
|
call := &c.calls[id]
|
|
call := &c.calls[id]
|
|
|
call.resp = make(chan callResp, 1)
|
|
call.resp = make(chan callResp, 1)
|
|
|
atomic.AddInt32(&c.nwait, 1)
|
|
atomic.AddInt32(&c.nwait, 1)
|
|
|
atomic.StoreInt32(&call.active, 1)
|
|
atomic.StoreInt32(&call.active, 1)
|
|
|
|
|
|
|
|
- if _, err := c.conn.Write(req); err != nil {
|
|
|
|
|
|
|
+ if _, err := c.Write(req); err != nil {
|
|
|
c.uniq <- id
|
|
c.uniq <- id
|
|
|
c.Close()
|
|
c.Close()
|
|
|
return nil, err
|
|
return nil, err
|
|
@@ -346,7 +384,7 @@ func (c *Conn) exec(op operation, trace Tracer) (interface{}, error) {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func (c *Conn) dispatch(resp frame) {
|
|
func (c *Conn) dispatch(resp frame) {
|
|
|
- id := int(resp[2])
|
|
|
|
|
|
|
+ id := resp.Stream(c.version)
|
|
|
if id >= len(c.calls) {
|
|
if id >= len(c.calls) {
|
|
|
return
|
|
return
|
|
|
}
|
|
}
|
|
@@ -543,10 +581,10 @@ func (c *Conn) UseKeyspace(keyspace string) error {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func (c *Conn) executeBatch(batch *Batch) error {
|
|
func (c *Conn) executeBatch(batch *Batch) error {
|
|
|
- if c.version == 1 {
|
|
|
|
|
|
|
+ if c.version == protoVersion1 {
|
|
|
return ErrUnsupported
|
|
return ErrUnsupported
|
|
|
}
|
|
}
|
|
|
- f := make(frame, headerSize, defaultFrameSize)
|
|
|
|
|
|
|
+ f := make(frame, headerProtoSize[c.version], defaultFrameSize)
|
|
|
f.setHeader(c.version, 0, 0, opBatch)
|
|
f.setHeader(c.version, 0, 0, opBatch)
|
|
|
f.writeByte(byte(batch.Type))
|
|
f.writeByte(byte(batch.Type))
|
|
|
f.writeShort(uint16(len(batch.Entries)))
|
|
f.writeShort(uint16(len(batch.Entries)))
|
|
@@ -631,12 +669,14 @@ func (c *Conn) decodeFrame(f frame, trace Tracer) (rval interface{}, err error)
|
|
|
panic(r)
|
|
panic(r)
|
|
|
}
|
|
}
|
|
|
}()
|
|
}()
|
|
|
- if len(f) < headerSize {
|
|
|
|
|
- return nil, NewErrProtocol("Decoding frame: less data received than required for header: %d < %d", len(f), headerSize)
|
|
|
|
|
|
|
+
|
|
|
|
|
+ if len(f) < headerProtoSize[c.version] {
|
|
|
|
|
+ return nil, NewErrProtocol("Decoding frame: less data received than required for header: %d < %d", len(f), headerProtoSize[c.version])
|
|
|
} else if f[0] != c.version|flagResponse {
|
|
} else if f[0] != c.version|flagResponse {
|
|
|
return nil, NewErrProtocol("Decoding frame: response protocol version does not match connection protocol version (%d != %d)", f[0], c.version|flagResponse)
|
|
return nil, NewErrProtocol("Decoding frame: response protocol version does not match connection protocol version (%d != %d)", f[0], c.version|flagResponse)
|
|
|
}
|
|
}
|
|
|
- flags, op, f := f[1], f[3], f[headerSize:]
|
|
|
|
|
|
|
+
|
|
|
|
|
+ flags, op, f := f[1], f.Op(c.version), f[headerProtoSize[c.version]:]
|
|
|
if flags&flagCompress != 0 && len(f) > 0 && c.compressor != nil {
|
|
if flags&flagCompress != 0 && len(f) > 0 && c.compressor != nil {
|
|
|
if buf, err := c.compressor.Decode([]byte(f)); err != nil {
|
|
if buf, err := c.compressor.Decode([]byte(f)); err != nil {
|
|
|
return nil, err
|
|
return nil, err
|