Pārlūkot izejas kodu

improved frame parsing, support protocol v1

Christoph Hack 12 gadi atpakaļ
vecāks
revīzija
ec8aa7ca9c
7 mainītis faili ar 325 papildinājumiem un 254 dzēšanām
  1. 2 1
      cluster.go
  2. 124 111
      conn.go
  3. 98 85
      conn_test.go
  4. 83 16
      frame.go
  5. 3 1
      gocql_test/main.go
  6. 13 38
      session.go
  7. 2 2
      topology.go

+ 2 - 1
cluster.go

@@ -56,6 +56,7 @@ func (cfg *ClusterConfig) CreateSession() *Session {
 		hostPool: NewRoundRobin(),
 		connPool: make(map[string]*RoundRobin),
 		conns:    make(map[*Conn]struct{}),
+		quitWait: make(chan bool),
 	}
 	impl.wgStart.Add(1)
 	for i := 0; i < len(impl.cfg.Hosts); i++ {
@@ -89,7 +90,7 @@ type clusterImpl struct {
 
 func (c *clusterImpl) connect(addr string) {
 	cfg := ConnConfig{
-		ProtoVersion: 2,
+		ProtoVersion: c.cfg.ProtoVersion,
 		CQLVersion:   c.cfg.CQLVersion,
 		Timeout:      c.cfg.Timeout,
 		NumStreams:   c.cfg.NumStreams,

+ 124 - 111
conn.go

@@ -12,11 +12,14 @@ import (
 )
 
 const defaultFrameSize = 4096
+const flagResponse = 0x80
+const maskVersion = 0x7F
 
 type Cluster interface {
 	//HandleAuth(addr, method string) ([]byte, Challenger, error)
 	HandleError(conn *Conn, err error, closed bool)
 	HandleKeyspace(conn *Conn, keyspace string)
+	// Authenticate(addr string)
 }
 
 /* type Challenger interface {
@@ -46,6 +49,7 @@ type Conn struct {
 
 	cluster Cluster
 	addr    string
+	version uint8
 }
 
 // Connect establishes a connection to a Cassandra node.
@@ -58,12 +62,16 @@ func Connect(addr string, cfg ConnConfig, cluster Cluster) (*Conn, error) {
 	if cfg.NumStreams <= 0 || cfg.NumStreams > 128 {
 		cfg.NumStreams = 128
 	}
+	if cfg.ProtoVersion != 1 && cfg.ProtoVersion != 2 {
+		cfg.ProtoVersion = 2
+	}
 	c := &Conn{
 		conn:    conn,
 		uniq:    make(chan uint8, cfg.NumStreams),
 		calls:   make([]callReq, cfg.NumStreams),
 		prep:    make(map[string]*queryInfo),
 		timeout: cfg.Timeout,
+		version: uint8(cfg.ProtoVersion),
 		addr:    conn.RemoteAddr().String(),
 		cluster: cluster,
 	}
@@ -82,19 +90,21 @@ func Connect(addr string, cfg ConnConfig, cluster Cluster) (*Conn, error) {
 
 func (c *Conn) startup(cfg *ConnConfig) error {
 	req := make(frame, headerSize, defaultFrameSize)
-	req.setHeader(protoRequest, 0, 0, opStartup)
+	req.setHeader(c.version, 0, 0, opStartup)
 	req.writeStringMap(map[string]string{
 		"CQL_VERSION": cfg.CQLVersion,
 	})
 	resp, err := c.callSimple(req)
 	if err != nil {
 		return err
-	} else if resp[3] == opError {
-		return resp.readErrorFrame()
-	} else if resp[3] != opReady {
+	}
+	switch x := resp.(type) {
+	case readyFrame:
+	case error:
+		return x
+	default:
 		return ErrProtocol
 	}
-
 	return nil
 }
 
@@ -102,24 +112,22 @@ func (c *Conn) startup(cfg *ConnConfig) error {
 // to execute any queries. This method runs as long as the connection is
 // open and is therefore usually called in a separate goroutine.
 func (c *Conn) serve() {
-	var err error
 	for {
-		var frame frame
-		frame, err = c.recv()
+		resp, err := c.recv()
 		if err != nil {
 			break
 		}
-		c.dispatch(frame)
+		c.dispatch(resp)
 	}
 
 	c.conn.Close()
 	for id := 0; id < len(c.calls); id++ {
 		req := &c.calls[id]
 		if atomic.LoadInt32(&req.active) == 1 {
-			req.resp <- callResp{nil, err}
+			req.resp <- callResp{nil, ErrProtocol}
 		}
 	}
-	c.cluster.HandleError(c, err, true)
+	c.cluster.HandleError(c, ErrProtocol, true)
 }
 
 func (c *Conn) recv() (frame, error) {
@@ -130,7 +138,7 @@ func (c *Conn) recv() (frame, error) {
 		nn, err := c.conn.Read(resp[n:])
 		n += nn
 		if err != nil {
-			if err, ok := err.(net.Error); ok && err.Timeout() {
+			if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
 				if n > last {
 					// we hit the deadline but we made progress.
 					// simply extend the deadline
@@ -150,7 +158,7 @@ func (c *Conn) recv() (frame, error) {
 			}
 		}
 		if n == headerSize && len(resp) == headerSize {
-			if resp[0] != protoResponse {
+			if resp[0] != c.version|flagResponse {
 				return nil, ErrProtocol
 			}
 			resp.grow(resp.Length())
@@ -159,16 +167,20 @@ func (c *Conn) recv() (frame, error) {
 	return resp, nil
 }
 
-func (c *Conn) callSimple(req frame) (frame, error) {
+func (c *Conn) callSimple(req frame) (interface{}, error) {
 	req.setLength(len(req) - headerSize)
 	if _, err := c.conn.Write(req); err != nil {
 		c.conn.Close()
 		return nil, err
 	}
-	return c.recv()
+	buf, err := c.recv()
+	if err != nil {
+		return nil, err
+	}
+	return decodeFrame(buf)
 }
 
-func (c *Conn) call(req frame) (frame, error) {
+func (c *Conn) call(req frame) (interface{}, error) {
 	id := <-c.uniq
 	req[2] = id
 
@@ -178,16 +190,22 @@ func (c *Conn) call(req frame) (frame, error) {
 	atomic.StoreInt32(&call.active, 1)
 
 	req.setLength(len(req) - headerSize)
-	if _, err := c.conn.Write(req); err != nil {
+	if n, err := c.conn.Write(req); err != nil {
 		c.conn.Close()
-		return nil, err
+		if n > 0 {
+			return nil, ErrProtocol
+		}
+		return nil, ErrUnavailable
 	}
 
 	reply := <-call.resp
 	call.resp = nil
-
 	c.uniq <- id
-	return reply.buf, reply.err
+
+	if reply.err != nil {
+		return nil, reply.err
+	}
+	return decodeFrame(reply.buf)
 }
 
 func (c *Conn) dispatch(resp frame) {
@@ -205,7 +223,7 @@ func (c *Conn) dispatch(resp frame) {
 
 func (c *Conn) ping() error {
 	req := make(frame, headerSize)
-	req.setHeader(protoRequest, 0, 0, opOptions)
+	req.setHeader(c.version, 0, 0, opOptions)
 	_, err := c.call(req)
 	return err
 }
@@ -224,44 +242,95 @@ func (c *Conn) prepareStatement(stmt string) (*queryInfo, error) {
 	c.prepMu.Unlock()
 
 	frame := make(frame, headerSize, defaultFrameSize)
-	frame.setHeader(protoRequest, 0, 0, opPrepare)
+	frame.setHeader(c.version, 0, 0, opPrepare)
 	frame.writeLongString(stmt)
 	frame.setLength(len(frame) - headerSize)
 
-	frame, err := c.call(frame)
+	resp, err := c.call(frame)
 	if err != nil {
 		return nil, err
 	}
-	if frame[3] == opError {
-		return nil, frame.readErrorFrame()
+	switch x := resp.(type) {
+	case resultPreparedFrame:
+		info.id = x.PreparedId
+		info.args = x.Values
+		info.wg.Done()
+	case error:
+		return nil, x
+	default:
+		return nil, ErrProtocol
 	}
-	frame.skipHeader()
-	frame.readInt() // kind
-	info.id = frame.readShortBytes()
-	info.args = frame.readMetaData()
-	info.rval = frame.readMetaData()
-	info.wg.Done()
 	return info, nil
 }
 
 func (c *Conn) ExecuteQuery(qry *Query) (*Iter, error) {
-	frame, err := c.executeQuery(qry)
+	var info *queryInfo
+	if len(qry.Args) > 0 {
+		var err error
+		info, err = c.prepareStatement(qry.Stmt)
+		if err != nil {
+			return nil, err
+		}
+	}
+	req := make(frame, headerSize, defaultFrameSize)
+	if info == nil {
+		req.setHeader(c.version, 0, 0, opQuery)
+		req.writeLongString(qry.Stmt)
+		req.writeConsistency(qry.Cons)
+		if c.version > 1 {
+			req.writeByte(0)
+		}
+	} else {
+		req.setHeader(c.version, 0, 0, opExecute)
+		req.writeShortBytes(info.id)
+		if c.version == 1 {
+			req.writeShort(uint16(len(qry.Args)))
+		} else {
+			req.writeConsistency(qry.Cons)
+			flags := uint8(0)
+			if len(qry.Args) > 0 {
+				flags |= flagQueryValues
+			}
+			req.writeByte(flags)
+			if flags&flagQueryValues != 0 {
+				req.writeShort(uint16(len(qry.Args)))
+			}
+		}
+		for i := 0; i < len(qry.Args); i++ {
+			val, err := Marshal(info.args[i].TypeInfo, qry.Args[i])
+			if err != nil {
+				return nil, err
+			}
+			req.writeBytes(val)
+		}
+		if c.version == 1 {
+			req.writeConsistency(qry.Cons)
+		}
+	}
+	resp, err := c.call(req)
 	if err != nil {
 		return nil, err
 	}
-	if frame[3] == opError {
-		return nil, frame.readErrorFrame()
-	} else if frame[3] == opResult {
-		iter := new(Iter)
-		iter.readFrame(frame)
-		return iter, nil
+	switch x := resp.(type) {
+	case resultVoidFrame:
+		return &Iter{}, nil
+	case resultRowsFrame:
+		return &Iter{columns: x.Columns, rows: x.Rows}, nil
+	case resultKeyspaceFrame:
+		c.cluster.HandleKeyspace(c, x.Keyspace)
+		return &Iter{}, nil
+	case error:
+		return &Iter{err: x}, nil
 	}
-	return nil, nil
+	return nil, ErrProtocol
 }
 
 func (c *Conn) ExecuteBatch(batch *Batch) error {
+	if c.version == 1 {
+		return ErrProtocol
+	}
 	frame := make(frame, headerSize, defaultFrameSize)
-	frame.setHeader(protoRequest, 0, 0, opBatch)
+	frame.setHeader(c.version, 0, 0, opBatch)
 	frame.writeByte(byte(batch.Type))
 	frame.writeShort(uint16(len(batch.Entries)))
 	for i := 0; i < len(batch.Entries); i++ {
@@ -290,15 +359,17 @@ func (c *Conn) ExecuteBatch(batch *Batch) error {
 	}
 	frame.writeConsistency(batch.Cons)
 
-	frame, err := c.call(frame)
+	resp, err := c.call(frame)
 	if err != nil {
 		return err
 	}
-
-	if frame[3] == opError {
-		return frame.readErrorFrame()
+	switch x := resp.(type) {
+	case resultVoidFrame:
+	case error:
+		return x
+	default:
+		return ErrProtocol
 	}
-
 	return nil
 }
 
@@ -310,81 +381,23 @@ func (c *Conn) Address() string {
 	return c.addr
 }
 
-func (c *Conn) executeQuery(query *Query) (frame, error) {
-	var info *queryInfo
-	if len(query.Args) > 0 {
-		var err error
-		info, err = c.prepareStatement(query.Stmt)
-		if err != nil {
-			return nil, err
-		}
-	}
-
-	frame := make(frame, headerSize, defaultFrameSize)
-	if info == nil {
-		frame.setHeader(protoRequest, 0, 0, opQuery)
-		frame.writeLongString(query.Stmt)
-	} else {
-		frame.setHeader(protoRequest, 0, 0, opExecute)
-		frame.writeShortBytes(info.id)
-	}
-	frame.writeConsistency(query.Cons)
-	flags := uint8(0)
-	if len(query.Args) > 0 {
-		flags |= flagQueryValues
-	}
-	frame.writeByte(flags)
-	if len(query.Args) > 0 {
-		frame.writeShort(uint16(len(query.Args)))
-		for i := 0; i < len(query.Args); i++ {
-			val, err := Marshal(info.args[i].TypeInfo, query.Args[i])
-			if err != nil {
-				return nil, err
-			}
-			frame.writeBytes(val)
-		}
-	}
-
-	frame, err := c.call(frame)
-	if err != nil {
-		return nil, err
-	}
-
-	if frame[3] == opResult {
-		f := frame
-		f.skipHeader()
-		if f.readInt() == resultKindKeyspace {
-			keyspace := f.readString()
-			c.cluster.HandleKeyspace(c, keyspace)
-		}
-	}
-
-	if frame[3] == opError {
-		frame.skipHeader()
-		code := frame.readInt()
-		desc := frame.readString()
-		return nil, Error{code, desc}
-	}
-	return frame, nil
-}
-
 func (c *Conn) UseKeyspace(keyspace string) error {
 	frame := make(frame, headerSize, defaultFrameSize)
-	frame.setHeader(protoRequest, 0, 0, opQuery)
+	frame.setHeader(c.version, 0, 0, opQuery)
 	frame.writeLongString("USE " + keyspace)
 	frame.writeConsistency(1)
 	frame.writeByte(0)
 
-	frame, err := c.call(frame)
+	resp, err := c.call(frame)
 	if err != nil {
 		return err
 	}
-
-	if frame[3] == opError {
-		frame.skipHeader()
-		code := frame.readInt()
-		desc := frame.readString()
-		return Error{code, desc}
+	switch x := resp.(type) {
+	case resultKeyspaceFrame:
+	case error:
+		return x
+	default:
+		return ErrProtocol
 	}
 	return nil
 }

+ 98 - 85
gocql_test.go → conn_test.go

@@ -21,6 +21,101 @@ type TestServer struct {
 	listen  net.Listener
 }
 
+func TestSimple(t *testing.T) {
+	srv := NewTestServer(t)
+	defer srv.Stop()
+
+	db := NewCluster(srv.Address).CreateSession()
+
+	if err := db.Query("void").Exec(); err != nil {
+		t.Error(err)
+	}
+}
+
+func TestClosed(t *testing.T) {
+	srv := NewTestServer(t)
+	defer srv.Stop()
+
+	session := NewCluster(srv.Address).CreateSession()
+
+	session.Close()
+
+	if err := session.Query("void").Exec(); err != ErrUnavailable {
+		t.Errorf("expected %#v, got %#v", ErrUnavailable, err)
+	}
+}
+
+func TestTimeout(t *testing.T) {
+	srv := NewTestServer(t)
+	defer srv.Stop()
+
+	db := NewCluster(srv.Address).CreateSession()
+
+	go func() {
+		<-time.After(1 * time.Second)
+		t.Fatal("no timeout")
+	}()
+
+	if err := db.Query("kill").Exec(); err == nil {
+		t.Fatal("expected error")
+	}
+}
+
+func TestSlowQuery(t *testing.T) {
+	srv := NewTestServer(t)
+	defer srv.Stop()
+
+	db := NewCluster(srv.Address).CreateSession()
+
+	if err := db.Query("slow").Exec(); err != nil {
+		t.Fatal(err)
+	}
+}
+
+func TestRoundRobin(t *testing.T) {
+	servers := make([]*TestServer, 5)
+	addrs := make([]string, len(servers))
+	for i := 0; i < len(servers); i++ {
+		servers[i] = NewTestServer(t)
+		addrs[i] = servers[i].Address
+		defer servers[i].Stop()
+	}
+	cluster := NewCluster(addrs...)
+	cluster.StartupMin = len(addrs)
+	db := cluster.CreateSession()
+
+	var wg sync.WaitGroup
+	wg.Add(5)
+	for i := 0; i < 5; i++ {
+		go func() {
+			for j := 0; j < 5; j++ {
+				if err := db.Query("void").Exec(); err != nil {
+					t.Fatal(err)
+				}
+			}
+			wg.Done()
+		}()
+	}
+	wg.Wait()
+
+	diff := 0
+	for i := 1; i < len(servers); i++ {
+		d := 0
+		if servers[i].nreq > servers[i-1].nreq {
+			d = int(servers[i].nreq - servers[i-1].nreq)
+		} else {
+			d = int(servers[i-1].nreq - servers[i].nreq)
+		}
+		if d > diff {
+			diff = d
+		}
+	}
+
+	if diff > 0 {
+		t.Fatal("diff:", diff)
+	}
+}
+
 func NewTestServer(t *testing.T) *TestServer {
 	laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
 	if err != nil {
@@ -78,7 +173,7 @@ func (srv *TestServer) process(frame frame, conn net.Conn) {
 		case "slow":
 			go func() {
 				<-time.After(1 * time.Second)
-				frame.writeInt(0)
+				frame.writeInt(resultKindVoid)
 				frame.setLength(len(frame) - headerSize)
 				if _, err := conn.Write(frame); err != nil {
 					return
@@ -89,9 +184,9 @@ func (srv *TestServer) process(frame frame, conn net.Conn) {
 			frame.writeInt(3)
 			frame.writeString(strings.TrimSpace(query[3:]))
 		case "void":
-			frame.writeInt(0)
+			frame.writeInt(resultKindVoid)
 		default:
-			frame.writeInt(0)
+			frame.writeInt(resultKindVoid)
 		}
 	default:
 		frame = frame[:headerSize]
@@ -118,85 +213,3 @@ func (srv *TestServer) readFrame(conn net.Conn) frame {
 	}
 	return frame
 }
-
-func TestSimple(t *testing.T) {
-	srv := NewTestServer(t)
-	defer srv.Stop()
-
-	db := NewCluster(srv.Address).CreateSession()
-
-	if err := db.Query("void").Exec(); err != nil {
-		t.Error(err)
-	}
-}
-
-func TestTimeout(t *testing.T) {
-	srv := NewTestServer(t)
-	defer srv.Stop()
-
-	db := NewCluster(srv.Address).CreateSession()
-
-	go func() {
-		<-time.After(1 * time.Second)
-		t.Fatal("no timeout")
-	}()
-
-	if err := db.Query("kill").Exec(); err == nil {
-		t.Fatal("expected error")
-	}
-}
-
-func TestSlowQuery(t *testing.T) {
-	srv := NewTestServer(t)
-	defer srv.Stop()
-
-	db := NewCluster(srv.Address).CreateSession()
-
-	if err := db.Query("slow").Exec(); err != nil {
-		t.Fatal(err)
-	}
-}
-
-func TestRoundRobin(t *testing.T) {
-	servers := make([]*TestServer, 5)
-	addrs := make([]string, len(servers))
-	for i := 0; i < len(servers); i++ {
-		servers[i] = NewTestServer(t)
-		addrs[i] = servers[i].Address
-		defer servers[i].Stop()
-	}
-	cluster := NewCluster(addrs...)
-	cluster.StartupMin = len(addrs)
-	db := cluster.CreateSession()
-
-	var wg sync.WaitGroup
-	wg.Add(5)
-	for i := 0; i < 5; i++ {
-		go func() {
-			for j := 0; j < 5; j++ {
-				if err := db.Query("void").Exec(); err != nil {
-					t.Fatal(err)
-				}
-			}
-			wg.Done()
-		}()
-	}
-	wg.Wait()
-
-	diff := 0
-	for i := 1; i < len(servers); i++ {
-		d := 0
-		if servers[i].nreq > servers[i-1].nreq {
-			d = int(servers[i].nreq - servers[i-1].nreq)
-		} else {
-			d = int(servers[i-1].nreq - servers[i].nreq)
-		}
-		if d > diff {
-			diff = d
-		}
-	}
-
-	if diff > 0 {
-		t.Fatal("diff:", diff)
-	}
-}

+ 83 - 16
frame.go

@@ -254,22 +254,6 @@ func (f *frame) readMetaData() []ColumnInfo {
 	return info
 }
 
-func (f *frame) readErrorFrame() (err error) {
-	defer func() {
-		if r := recover(); r != nil {
-			if e, ok := r.(error); ok && e == ErrProtocol {
-				err = e
-				return
-			}
-			panic(r)
-		}
-	}()
-	f.skipHeader()
-	code := f.readInt()
-	desc := f.readString()
-	return Error{code, desc}
-}
-
 func (f *frame) writeConsistency(c Consistency) {
 	f.writeShort(consistencyCodes[c])
 }
@@ -286,3 +270,86 @@ var consistencyCodes = []uint16{
 	Serial:      0x0008,
 	LocalSerial: 0x0009,
 }
+
+func decodeFrame(f frame) (rval interface{}, err error) {
+	defer func() {
+		if r := recover(); r != nil {
+			if e, ok := r.(error); ok && e == ErrProtocol {
+				err = e
+				return
+			}
+			panic(r)
+		}
+	}()
+	if len(f) < headerSize || (f[0] != 1|flagResponse && f[0] != 2|flagResponse) {
+		return nil, ErrProtocol
+	}
+	switch f[3] {
+	case opReady:
+		return readyFrame{}, nil
+	case opResult:
+		f.skipHeader()
+		switch kind := f.readInt(); kind {
+		case resultKindVoid:
+			return resultVoidFrame{}, nil
+		case resultKindRows:
+			columns := f.readMetaData()
+			numRows := f.readInt()
+			values := make([][]byte, numRows*len(columns))
+			for i := 0; i < len(values); i++ {
+				values[i] = f.readBytes()
+			}
+			rows := make([][][]byte, numRows)
+			for i := 0; i < len(values); i += len(columns) {
+				rows[i] = values[i : i+len(columns)]
+			}
+			return resultRowsFrame{columns, rows, nil}, nil
+		case resultKindKeyspace:
+			keyspace := f.readString()
+			return resultKeyspaceFrame{keyspace}, nil
+		case resultKindPrepared:
+			id := f.readShortBytes()
+			values := f.readMetaData()
+			return resultPreparedFrame{id, values}, nil
+		case resultKindSchemaChanged:
+			return resultVoidFrame{}, nil
+		default:
+			return nil, ErrProtocol
+		}
+	case opError:
+		f.skipHeader()
+		code := f.readInt()
+		msg := f.readString()
+		return errorFrame{code, msg}, nil
+	default:
+		return nil, ErrProtocol
+	}
+}
+
+type readyFrame struct{}
+
+type resultVoidFrame struct{}
+
+type resultRowsFrame struct {
+	Columns     []ColumnInfo
+	Rows        [][][]byte
+	PagingState []byte
+}
+
+type resultKeyspaceFrame struct {
+	Keyspace string
+}
+
+type resultPreparedFrame struct {
+	PreparedId []byte
+	Values     []ColumnInfo
+}
+
+type errorFrame struct {
+	Code    int
+	Message string
+}
+
+func (e errorFrame) Error() string {
+	return e.Message
+}

+ 3 - 1
gocql_test/main.go

@@ -35,7 +35,9 @@ type Page struct {
 type Attachment []byte
 
 func initSchema() error {
-	session.Query("DROP KEYSPACE gocql_test").Exec()
+	if err := session.Query("DROP KEYSPACE gocql_test").Exec(); err != nil {
+		log.Println("drop keyspace", err)
+	}
 
 	if err := session.Query(`CREATE KEYSPACE gocql_test
 		WITH replication = {

+ 13 - 38
session.go

@@ -127,57 +127,32 @@ func (b QueryBuilder) Scan(values ...interface{}) error {
 }
 
 type Iter struct {
-	err    error
-	pos    int
-	values [][]byte
-	info   []ColumnInfo
-}
-
-func (iter *Iter) readFrame(frame frame) {
-	defer func() {
-		if r := recover(); r != nil {
-			if e, ok := r.(error); ok && e == ErrProtocol {
-				iter.err = e
-				return
-			}
-			panic(r)
-		}
-	}()
-	frame.skipHeader()
-	iter.pos = 0
-	iter.err = nil
-	iter.values = nil
-	if frame.readInt() != resultKindRows {
-		return
-	}
-	iter.info = frame.readMetaData()
-	numRows := frame.readInt()
-	iter.values = make([][]byte, numRows*len(iter.info))
-	for i := 0; i < len(iter.values); i++ {
-		iter.values[i] = frame.readBytes()
-	}
+	err     error
+	pos     int
+	rows    [][][]byte
+	columns []ColumnInfo
 }
 
 func (iter *Iter) Columns() []ColumnInfo {
-	return iter.info
+	return iter.columns
 }
 
 func (iter *Iter) Scan(values ...interface{}) bool {
-	if iter.err != nil || iter.pos >= len(iter.values) {
+	if iter.err != nil || iter.pos >= len(iter.rows) {
 		return false
 	}
-	if len(values) != len(iter.info) {
+	if len(values) != len(iter.columns) {
 		iter.err = errors.New("count mismatch")
 		return false
 	}
-	for i := 0; i < len(values); i++ {
-		err := Unmarshal(iter.info[i].TypeInfo, iter.values[i+iter.pos], values[i])
+	for i := 0; i < len(iter.columns); i++ {
+		err := Unmarshal(iter.columns[i].TypeInfo, iter.rows[iter.pos][i], values[i])
 		if err != nil {
 			iter.err = err
 			return false
 		}
 	}
-	iter.pos += len(values)
+	iter.pos++
 	return true
 }
 
@@ -262,7 +237,7 @@ func (e Error) Error() string {
 }
 
 var (
-	ErrNotFound        = errors.New("not found")
-	ErrNoHostAvailable = errors.New("no host available")
-	ErrProtocol        = errors.New("protocol error")
+	ErrNotFound    = errors.New("not found")
+	ErrUnavailable = errors.New("unavailable")
+	ErrProtocol    = errors.New("protocol error")
 )

+ 2 - 2
topology.go

@@ -54,7 +54,7 @@ func (r *RoundRobin) Size() int {
 func (r *RoundRobin) ExecuteQuery(qry *Query) (*Iter, error) {
 	node := r.pick()
 	if node == nil {
-		return nil, ErrNoHostAvailable
+		return nil, ErrUnavailable
 	}
 	return node.ExecuteQuery(qry)
 }
@@ -62,7 +62,7 @@ func (r *RoundRobin) ExecuteQuery(qry *Query) (*Iter, error) {
 func (r *RoundRobin) ExecuteBatch(batch *Batch) error {
 	node := r.pick()
 	if node == nil {
-		return ErrNoHostAvailable
+		return ErrUnavailable
 	}
 	return node.ExecuteBatch(batch)
 }