|
|
@@ -12,11 +12,14 @@ import (
|
|
|
)
|
|
|
|
|
|
const defaultFrameSize = 4096
|
|
|
+const flagResponse = 0x80
|
|
|
+const maskVersion = 0x7F
|
|
|
|
|
|
type Cluster interface {
|
|
|
//HandleAuth(addr, method string) ([]byte, Challenger, error)
|
|
|
HandleError(conn *Conn, err error, closed bool)
|
|
|
HandleKeyspace(conn *Conn, keyspace string)
|
|
|
+ // Authenticate(addr string)
|
|
|
}
|
|
|
|
|
|
/* type Challenger interface {
|
|
|
@@ -46,6 +49,7 @@ type Conn struct {
|
|
|
|
|
|
cluster Cluster
|
|
|
addr string
|
|
|
+ version uint8
|
|
|
}
|
|
|
|
|
|
// Connect establishes a connection to a Cassandra node.
|
|
|
@@ -58,12 +62,16 @@ func Connect(addr string, cfg ConnConfig, cluster Cluster) (*Conn, error) {
|
|
|
if cfg.NumStreams <= 0 || cfg.NumStreams > 128 {
|
|
|
cfg.NumStreams = 128
|
|
|
}
|
|
|
+ if cfg.ProtoVersion != 1 && cfg.ProtoVersion != 2 {
|
|
|
+ cfg.ProtoVersion = 2
|
|
|
+ }
|
|
|
c := &Conn{
|
|
|
conn: conn,
|
|
|
uniq: make(chan uint8, cfg.NumStreams),
|
|
|
calls: make([]callReq, cfg.NumStreams),
|
|
|
prep: make(map[string]*queryInfo),
|
|
|
timeout: cfg.Timeout,
|
|
|
+ version: uint8(cfg.ProtoVersion),
|
|
|
addr: conn.RemoteAddr().String(),
|
|
|
cluster: cluster,
|
|
|
}
|
|
|
@@ -82,19 +90,21 @@ func Connect(addr string, cfg ConnConfig, cluster Cluster) (*Conn, error) {
|
|
|
|
|
|
func (c *Conn) startup(cfg *ConnConfig) error {
|
|
|
req := make(frame, headerSize, defaultFrameSize)
|
|
|
- req.setHeader(protoRequest, 0, 0, opStartup)
|
|
|
+ req.setHeader(c.version, 0, 0, opStartup)
|
|
|
req.writeStringMap(map[string]string{
|
|
|
"CQL_VERSION": cfg.CQLVersion,
|
|
|
})
|
|
|
resp, err := c.callSimple(req)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
- } else if resp[3] == opError {
|
|
|
- return resp.readErrorFrame()
|
|
|
- } else if resp[3] != opReady {
|
|
|
+ }
|
|
|
+ switch x := resp.(type) {
|
|
|
+ case readyFrame:
|
|
|
+ case error:
|
|
|
+ return x
|
|
|
+ default:
|
|
|
return ErrProtocol
|
|
|
}
|
|
|
-
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
@@ -102,24 +112,22 @@ func (c *Conn) startup(cfg *ConnConfig) error {
|
|
|
// to execute any queries. This method runs as long as the connection is
|
|
|
// open and is therefore usually called in a separate goroutine.
|
|
|
func (c *Conn) serve() {
|
|
|
- var err error
|
|
|
for {
|
|
|
- var frame frame
|
|
|
- frame, err = c.recv()
|
|
|
+ resp, err := c.recv()
|
|
|
if err != nil {
|
|
|
break
|
|
|
}
|
|
|
- c.dispatch(frame)
|
|
|
+ c.dispatch(resp)
|
|
|
}
|
|
|
|
|
|
c.conn.Close()
|
|
|
for id := 0; id < len(c.calls); id++ {
|
|
|
req := &c.calls[id]
|
|
|
if atomic.LoadInt32(&req.active) == 1 {
|
|
|
- req.resp <- callResp{nil, err}
|
|
|
+ req.resp <- callResp{nil, ErrProtocol}
|
|
|
}
|
|
|
}
|
|
|
- c.cluster.HandleError(c, err, true)
|
|
|
+ c.cluster.HandleError(c, ErrProtocol, true)
|
|
|
}
|
|
|
|
|
|
func (c *Conn) recv() (frame, error) {
|
|
|
@@ -130,7 +138,7 @@ func (c *Conn) recv() (frame, error) {
|
|
|
nn, err := c.conn.Read(resp[n:])
|
|
|
n += nn
|
|
|
if err != nil {
|
|
|
- if err, ok := err.(net.Error); ok && err.Timeout() {
|
|
|
+ if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
|
|
|
if n > last {
|
|
|
// we hit the deadline but we made progress.
|
|
|
// simply extend the deadline
|
|
|
@@ -150,7 +158,7 @@ func (c *Conn) recv() (frame, error) {
|
|
|
}
|
|
|
}
|
|
|
if n == headerSize && len(resp) == headerSize {
|
|
|
- if resp[0] != protoResponse {
|
|
|
+ if resp[0] != c.version|flagResponse {
|
|
|
return nil, ErrProtocol
|
|
|
}
|
|
|
resp.grow(resp.Length())
|
|
|
@@ -159,16 +167,20 @@ func (c *Conn) recv() (frame, error) {
|
|
|
return resp, nil
|
|
|
}
|
|
|
|
|
|
-func (c *Conn) callSimple(req frame) (frame, error) {
|
|
|
+func (c *Conn) callSimple(req frame) (interface{}, error) {
|
|
|
req.setLength(len(req) - headerSize)
|
|
|
if _, err := c.conn.Write(req); err != nil {
|
|
|
c.conn.Close()
|
|
|
return nil, err
|
|
|
}
|
|
|
- return c.recv()
|
|
|
+ buf, err := c.recv()
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ return decodeFrame(buf)
|
|
|
}
|
|
|
|
|
|
-func (c *Conn) call(req frame) (frame, error) {
|
|
|
+func (c *Conn) call(req frame) (interface{}, error) {
|
|
|
id := <-c.uniq
|
|
|
req[2] = id
|
|
|
|
|
|
@@ -178,16 +190,22 @@ func (c *Conn) call(req frame) (frame, error) {
|
|
|
atomic.StoreInt32(&call.active, 1)
|
|
|
|
|
|
req.setLength(len(req) - headerSize)
|
|
|
- if _, err := c.conn.Write(req); err != nil {
|
|
|
+ if n, err := c.conn.Write(req); err != nil {
|
|
|
c.conn.Close()
|
|
|
- return nil, err
|
|
|
+ if n > 0 {
|
|
|
+ return nil, ErrProtocol
|
|
|
+ }
|
|
|
+ return nil, ErrUnavailable
|
|
|
}
|
|
|
|
|
|
reply := <-call.resp
|
|
|
call.resp = nil
|
|
|
-
|
|
|
c.uniq <- id
|
|
|
- return reply.buf, reply.err
|
|
|
+
|
|
|
+ if reply.err != nil {
|
|
|
+ return nil, reply.err
|
|
|
+ }
|
|
|
+ return decodeFrame(reply.buf)
|
|
|
}
|
|
|
|
|
|
func (c *Conn) dispatch(resp frame) {
|
|
|
@@ -205,7 +223,7 @@ func (c *Conn) dispatch(resp frame) {
|
|
|
|
|
|
func (c *Conn) ping() error {
|
|
|
req := make(frame, headerSize)
|
|
|
- req.setHeader(protoRequest, 0, 0, opOptions)
|
|
|
+ req.setHeader(c.version, 0, 0, opOptions)
|
|
|
_, err := c.call(req)
|
|
|
return err
|
|
|
}
|
|
|
@@ -224,44 +242,95 @@ func (c *Conn) prepareStatement(stmt string) (*queryInfo, error) {
|
|
|
c.prepMu.Unlock()
|
|
|
|
|
|
frame := make(frame, headerSize, defaultFrameSize)
|
|
|
- frame.setHeader(protoRequest, 0, 0, opPrepare)
|
|
|
+ frame.setHeader(c.version, 0, 0, opPrepare)
|
|
|
frame.writeLongString(stmt)
|
|
|
frame.setLength(len(frame) - headerSize)
|
|
|
|
|
|
- frame, err := c.call(frame)
|
|
|
+ resp, err := c.call(frame)
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
- if frame[3] == opError {
|
|
|
- return nil, frame.readErrorFrame()
|
|
|
+ switch x := resp.(type) {
|
|
|
+ case resultPreparedFrame:
|
|
|
+ info.id = x.PreparedId
|
|
|
+ info.args = x.Values
|
|
|
+ info.wg.Done()
|
|
|
+ case error:
|
|
|
+ return nil, x
|
|
|
+ default:
|
|
|
+ return nil, ErrProtocol
|
|
|
}
|
|
|
- frame.skipHeader()
|
|
|
- frame.readInt() // kind
|
|
|
- info.id = frame.readShortBytes()
|
|
|
- info.args = frame.readMetaData()
|
|
|
- info.rval = frame.readMetaData()
|
|
|
- info.wg.Done()
|
|
|
return info, nil
|
|
|
}
|
|
|
|
|
|
func (c *Conn) ExecuteQuery(qry *Query) (*Iter, error) {
|
|
|
- frame, err := c.executeQuery(qry)
|
|
|
+ var info *queryInfo
|
|
|
+ if len(qry.Args) > 0 {
|
|
|
+ var err error
|
|
|
+ info, err = c.prepareStatement(qry.Stmt)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ }
|
|
|
+ req := make(frame, headerSize, defaultFrameSize)
|
|
|
+ if info == nil {
|
|
|
+ req.setHeader(c.version, 0, 0, opQuery)
|
|
|
+ req.writeLongString(qry.Stmt)
|
|
|
+ req.writeConsistency(qry.Cons)
|
|
|
+ if c.version > 1 {
|
|
|
+ req.writeByte(0)
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ req.setHeader(c.version, 0, 0, opExecute)
|
|
|
+ req.writeShortBytes(info.id)
|
|
|
+ if c.version == 1 {
|
|
|
+ req.writeShort(uint16(len(qry.Args)))
|
|
|
+ } else {
|
|
|
+ req.writeConsistency(qry.Cons)
|
|
|
+ flags := uint8(0)
|
|
|
+ if len(qry.Args) > 0 {
|
|
|
+ flags |= flagQueryValues
|
|
|
+ }
|
|
|
+ req.writeByte(flags)
|
|
|
+ if flags&flagQueryValues != 0 {
|
|
|
+ req.writeShort(uint16(len(qry.Args)))
|
|
|
+ }
|
|
|
+ }
|
|
|
+ for i := 0; i < len(qry.Args); i++ {
|
|
|
+ val, err := Marshal(info.args[i].TypeInfo, qry.Args[i])
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ req.writeBytes(val)
|
|
|
+ }
|
|
|
+ if c.version == 1 {
|
|
|
+ req.writeConsistency(qry.Cons)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ resp, err := c.call(req)
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
- if frame[3] == opError {
|
|
|
- return nil, frame.readErrorFrame()
|
|
|
- } else if frame[3] == opResult {
|
|
|
- iter := new(Iter)
|
|
|
- iter.readFrame(frame)
|
|
|
- return iter, nil
|
|
|
+ switch x := resp.(type) {
|
|
|
+ case resultVoidFrame:
|
|
|
+ return &Iter{}, nil
|
|
|
+ case resultRowsFrame:
|
|
|
+ return &Iter{columns: x.Columns, rows: x.Rows}, nil
|
|
|
+ case resultKeyspaceFrame:
|
|
|
+ c.cluster.HandleKeyspace(c, x.Keyspace)
|
|
|
+ return &Iter{}, nil
|
|
|
+ case error:
|
|
|
+ return &Iter{err: x}, nil
|
|
|
}
|
|
|
- return nil, nil
|
|
|
+ return nil, ErrProtocol
|
|
|
}
|
|
|
|
|
|
func (c *Conn) ExecuteBatch(batch *Batch) error {
|
|
|
+ if c.version == 1 {
|
|
|
+ return ErrProtocol
|
|
|
+ }
|
|
|
frame := make(frame, headerSize, defaultFrameSize)
|
|
|
- frame.setHeader(protoRequest, 0, 0, opBatch)
|
|
|
+ frame.setHeader(c.version, 0, 0, opBatch)
|
|
|
frame.writeByte(byte(batch.Type))
|
|
|
frame.writeShort(uint16(len(batch.Entries)))
|
|
|
for i := 0; i < len(batch.Entries); i++ {
|
|
|
@@ -290,15 +359,17 @@ func (c *Conn) ExecuteBatch(batch *Batch) error {
|
|
|
}
|
|
|
frame.writeConsistency(batch.Cons)
|
|
|
|
|
|
- frame, err := c.call(frame)
|
|
|
+ resp, err := c.call(frame)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
-
|
|
|
- if frame[3] == opError {
|
|
|
- return frame.readErrorFrame()
|
|
|
+ switch x := resp.(type) {
|
|
|
+ case resultVoidFrame:
|
|
|
+ case error:
|
|
|
+ return x
|
|
|
+ default:
|
|
|
+ return ErrProtocol
|
|
|
}
|
|
|
-
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
@@ -310,81 +381,23 @@ func (c *Conn) Address() string {
|
|
|
return c.addr
|
|
|
}
|
|
|
|
|
|
-func (c *Conn) executeQuery(query *Query) (frame, error) {
|
|
|
- var info *queryInfo
|
|
|
- if len(query.Args) > 0 {
|
|
|
- var err error
|
|
|
- info, err = c.prepareStatement(query.Stmt)
|
|
|
- if err != nil {
|
|
|
- return nil, err
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- frame := make(frame, headerSize, defaultFrameSize)
|
|
|
- if info == nil {
|
|
|
- frame.setHeader(protoRequest, 0, 0, opQuery)
|
|
|
- frame.writeLongString(query.Stmt)
|
|
|
- } else {
|
|
|
- frame.setHeader(protoRequest, 0, 0, opExecute)
|
|
|
- frame.writeShortBytes(info.id)
|
|
|
- }
|
|
|
- frame.writeConsistency(query.Cons)
|
|
|
- flags := uint8(0)
|
|
|
- if len(query.Args) > 0 {
|
|
|
- flags |= flagQueryValues
|
|
|
- }
|
|
|
- frame.writeByte(flags)
|
|
|
- if len(query.Args) > 0 {
|
|
|
- frame.writeShort(uint16(len(query.Args)))
|
|
|
- for i := 0; i < len(query.Args); i++ {
|
|
|
- val, err := Marshal(info.args[i].TypeInfo, query.Args[i])
|
|
|
- if err != nil {
|
|
|
- return nil, err
|
|
|
- }
|
|
|
- frame.writeBytes(val)
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- frame, err := c.call(frame)
|
|
|
- if err != nil {
|
|
|
- return nil, err
|
|
|
- }
|
|
|
-
|
|
|
- if frame[3] == opResult {
|
|
|
- f := frame
|
|
|
- f.skipHeader()
|
|
|
- if f.readInt() == resultKindKeyspace {
|
|
|
- keyspace := f.readString()
|
|
|
- c.cluster.HandleKeyspace(c, keyspace)
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- if frame[3] == opError {
|
|
|
- frame.skipHeader()
|
|
|
- code := frame.readInt()
|
|
|
- desc := frame.readString()
|
|
|
- return nil, Error{code, desc}
|
|
|
- }
|
|
|
- return frame, nil
|
|
|
-}
|
|
|
-
|
|
|
func (c *Conn) UseKeyspace(keyspace string) error {
|
|
|
frame := make(frame, headerSize, defaultFrameSize)
|
|
|
- frame.setHeader(protoRequest, 0, 0, opQuery)
|
|
|
+ frame.setHeader(c.version, 0, 0, opQuery)
|
|
|
frame.writeLongString("USE " + keyspace)
|
|
|
frame.writeConsistency(1)
|
|
|
frame.writeByte(0)
|
|
|
|
|
|
- frame, err := c.call(frame)
|
|
|
+ resp, err := c.call(frame)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
-
|
|
|
- if frame[3] == opError {
|
|
|
- frame.skipHeader()
|
|
|
- code := frame.readInt()
|
|
|
- desc := frame.readString()
|
|
|
- return Error{code, desc}
|
|
|
+ switch x := resp.(type) {
|
|
|
+ case resultKeyspaceFrame:
|
|
|
+ case error:
|
|
|
+ return x
|
|
|
+ default:
|
|
|
+ return ErrProtocol
|
|
|
}
|
|
|
return nil
|
|
|
}
|