Browse Source

Merge pull request #318 from Zariel/support-protocol-v3

Support protocol v3 wire format
Ben Hood 10 years ago
parent
commit
66e4489a53
7 changed files with 715 additions and 300 deletions
  1. 11 2
      cluster.go
  2. 108 53
      conn.go
  3. 287 85
      conn_test.go
  4. 149 33
      frame.go
  5. 2 0
      integration.sh
  6. 61 30
      marshal.go
  7. 97 97
      marshal_test.go

+ 11 - 2
cluster.go

@@ -62,7 +62,7 @@ type ClusterConfig struct {
 	Port             int           // port (default: 9042)
 	Keyspace         string        // initial keyspace (optional)
 	NumConns         int           // number of connections per host (default: 2)
-	NumStreams       int           // number of streams per connection (default: 128)
+	NumStreams       int           // number of streams per connection (default: max per protocol, either 128 or 32768)
 	Consistency      Consistency   // default consistency level (default: Quorum)
 	Compressor       Compressor    // compression algorithm (default: nil)
 	Authenticator    Authenticator // authenticator (default: nil)
@@ -85,7 +85,6 @@ func NewCluster(hosts ...string) *ClusterConfig {
 		Timeout:          600 * time.Millisecond,
 		Port:             9042,
 		NumConns:         2,
-		NumStreams:       128,
 		Consistency:      Quorum,
 		ConnPoolType:     NewSimplePool,
 		DiscoverHosts:    false,
@@ -102,6 +101,16 @@ func (cfg *ClusterConfig) CreateSession() (*Session, error) {
 	if len(cfg.Hosts) < 1 {
 		return nil, ErrNoHosts
 	}
+
+	maxStreams := 128
+	if cfg.ProtoVersion > protoVersion2 {
+		maxStreams = 32768
+	}
+
+	if cfg.NumStreams <= 0 || cfg.NumStreams > maxStreams {
+		cfg.NumStreams = maxStreams
+	}
+
 	pool := cfg.ConnPoolType(cfg)
 
 	//Adjust the size of the prepared statements cache to match the latest configuration

+ 108 - 53
conn.go

@@ -10,7 +10,9 @@ import (
 	"crypto/x509"
 	"errors"
 	"fmt"
+	"io"
 	"io/ioutil"
+	"log"
 	"net"
 	"strconv"
 	"strings"
@@ -19,9 +21,11 @@ import (
 	"time"
 )
 
-const defaultFrameSize = 4096
-const flagResponse = 0x80
-const maskVersion = 0x7F
+const (
+	defaultFrameSize = 4096
+	flagResponse     = 0x80
+	maskVersion      = 0x7F
+)
 
 //JoinHostPort is a utility to return a address string that can be used
 //gocql.Conn to form a connection with a host.
@@ -88,7 +92,7 @@ type Conn struct {
 	r       *bufio.Reader
 	timeout time.Duration
 
-	uniq  chan uint8
+	uniq  chan int
 	calls []callReq
 	nwait int32
 
@@ -123,14 +127,17 @@ func Connect(addr string, cfg ConnConfig, pool ConnectionPool) (*Conn, error) {
 				return nil, errors.New("Failed parsing or appending certs")
 			}
 		}
+
 		mycert, err := tls.LoadX509KeyPair(cfg.SslOpts.CertPath, cfg.SslOpts.KeyPath)
 		if err != nil {
 			return nil, err
 		}
+
 		config := tls.Config{
 			Certificates: []tls.Certificate{mycert},
 			RootCAs:      certPool,
 		}
+
 		config.InsecureSkipVerify = !cfg.SslOpts.EnableHostVerification
 		if conn, err = tls.Dial("tcp", addr, &config); err != nil {
 			return nil, err
@@ -139,16 +146,25 @@ func Connect(addr string, cfg ConnConfig, pool ConnectionPool) (*Conn, error) {
 		return nil, err
 	}
 
-	if cfg.NumStreams <= 0 || cfg.NumStreams > 128 {
-		cfg.NumStreams = 128
-	}
-	if cfg.ProtoVersion != 1 && cfg.ProtoVersion != 2 {
+	// going to default to proto 2
+	if cfg.ProtoVersion < protoVersion1 || cfg.ProtoVersion > protoVersion3 {
+		log.Printf("unsupported protocol version: %d using 2\n", cfg.ProtoVersion)
 		cfg.ProtoVersion = 2
 	}
+
+	maxStreams := 128
+	if cfg.ProtoVersion > protoVersion2 {
+		maxStreams = 32768
+	}
+
+	if cfg.NumStreams <= 0 || cfg.NumStreams > maxStreams {
+		cfg.NumStreams = maxStreams
+	}
+
 	c := &Conn{
 		conn:       conn,
 		r:          bufio.NewReader(conn),
-		uniq:       make(chan uint8, cfg.NumStreams),
+		uniq:       make(chan int, cfg.NumStreams),
 		calls:      make([]callReq, cfg.NumStreams),
 		timeout:    cfg.Timeout,
 		version:    uint8(cfg.ProtoVersion),
@@ -162,8 +178,8 @@ func Connect(addr string, cfg ConnConfig, pool ConnectionPool) (*Conn, error) {
 		c.setKeepalive(cfg.Keepalive)
 	}
 
-	for i := 0; i < cap(c.uniq); i++ {
-		c.uniq <- uint8(i)
+	for i := 0; i < cfg.NumStreams; i++ {
+		c.uniq <- i
 	}
 
 	if err := c.startup(&cfg); err != nil {
@@ -254,53 +270,80 @@ func (c *Conn) serve() {
 	c.pool.HandleError(c, err, true)
 }
 
+func (c *Conn) Write(p []byte) (int, error) {
+	c.conn.SetWriteDeadline(time.Now().Add(c.timeout))
+	return c.conn.Write(p)
+}
+
+func (c *Conn) Read(p []byte) (int, error) {
+	return c.r.Read(p)
+}
+
 func (c *Conn) recv() (frame, error) {
-	resp := make(frame, headerSize, headerSize+512)
-	c.conn.SetReadDeadline(time.Now().Add(c.timeout))
-	n, last, pinged := 0, 0, false
-	for n < len(resp) {
-		nn, err := c.r.Read(resp[n:])
-		n += nn
-		if err != nil {
-			if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
-				if n > last {
-					// we hit the deadline but we made progress.
-					// simply extend the deadline
-					c.conn.SetReadDeadline(time.Now().Add(c.timeout))
-					last = n
-				} else if n == 0 && !pinged {
-					c.conn.SetReadDeadline(time.Now().Add(c.timeout))
-					if atomic.LoadInt32(&c.nwait) > 0 {
-						go c.ping()
-						pinged = true
-					}
-				} else {
-					return nil, err
-				}
-			} else {
-				return nil, err
-			}
+	size := headerProtoSize[c.version]
+	resp := make(frame, size, size+512)
+
+	// read a full header, ignore timeouts, as this is being ran in a loop
+	c.conn.SetReadDeadline(time.Time{})
+	_, err := io.ReadFull(c.r, resp[:size])
+	if err != nil {
+		return nil, err
+	}
+
+	if v := c.version | flagResponse; resp[0] != v {
+		return nil, NewErrProtocol("recv: response protocol version does not match connection protocol version (%d != %d)", resp[0], v)
+	}
+
+	bodySize := resp.Length(c.version)
+	if bodySize == 0 {
+		return resp, nil
+	}
+	resp.grow(bodySize)
+
+	const maxAttempts = 5
+
+	n := size
+	for i := 0; i < maxAttempts; i++ {
+		var nn int
+		c.conn.SetReadDeadline(time.Now().Add(c.timeout))
+		nn, err = io.ReadFull(c.r, resp[n:size+bodySize])
+		if err == nil {
+			break
 		}
-		if n == headerSize && len(resp) == headerSize {
-			if resp[0] != c.version|flagResponse {
-				return nil, NewErrProtocol("recv: Response protocol version does not match connection protocol version (%d != %d)", resp[0], c.version|flagResponse)
-			}
-			resp.grow(resp.Length())
+		n += nn
+
+		if verr, ok := err.(net.Error); !ok || !verr.Temporary() {
+			break
 		}
 	}
+
+	if err != nil {
+		return nil, err
+	}
+
 	return resp, nil
 }
 
 func (c *Conn) execSimple(op operation) (interface{}, error) {
 	f, err := op.encodeFrame(c.version, nil)
-	f.setLength(len(f) - headerSize)
-	if _, err := c.conn.Write([]byte(f)); err != nil {
+	if err != nil {
+		// this should be a noop err
+		return nil, err
+	}
+
+	bodyLen := len(f) - headerProtoSize[c.version]
+	f.setLength(bodyLen, c.version)
+
+	if _, err := c.Write([]byte(f)); err != nil {
 		c.Close()
 		return nil, err
 	}
+
+	// here recv wont timeout waiting for a header, should it?
 	if f, err = c.recv(); err != nil {
 		return nil, err
 	}
+
 	return c.decodeFrame(f, nil)
 }
 
@@ -309,9 +352,12 @@ func (c *Conn) exec(op operation, trace Tracer) (interface{}, error) {
 	if err != nil {
 		return nil, err
 	}
+
 	if trace != nil {
 		req[1] |= flagTrace
 	}
+
+	headerSize := headerProtoSize[c.version]
 	if len(req) > headerSize && c.compressor != nil {
 		body, err := c.compressor.Encode([]byte(req[headerSize:]))
 		if err != nil {
@@ -320,16 +366,17 @@ func (c *Conn) exec(op operation, trace Tracer) (interface{}, error) {
 		req = append(req[:headerSize], frame(body)...)
 		req[1] |= flagCompress
 	}
-	req.setLength(len(req) - headerSize)
+	bodyLen := len(req) - headerSize
+	req.setLength(bodyLen, c.version)
 
 	id := <-c.uniq
-	req[2] = id
+	req.setStream(id, c.version)
 	call := &c.calls[id]
 	call.resp = make(chan callResp, 1)
 	atomic.AddInt32(&c.nwait, 1)
 	atomic.StoreInt32(&call.active, 1)
 
-	if _, err := c.conn.Write(req); err != nil {
+	if _, err := c.Write(req); err != nil {
 		c.uniq <- id
 		c.Close()
 		return nil, err
@@ -342,11 +389,12 @@ func (c *Conn) exec(op operation, trace Tracer) (interface{}, error) {
 	if reply.err != nil {
 		return nil, reply.err
 	}
+
 	return c.decodeFrame(reply.buf, trace)
 }
 
 func (c *Conn) dispatch(resp frame) {
-	id := int(resp[2])
+	id := resp.Stream(c.version)
 	if id >= len(c.calls) {
 		return
 	}
@@ -543,10 +591,10 @@ func (c *Conn) UseKeyspace(keyspace string) error {
 }
 
 func (c *Conn) executeBatch(batch *Batch) error {
-	if c.version == 1 {
+	if c.version == protoVersion1 {
 		return ErrUnsupported
 	}
-	f := make(frame, headerSize, defaultFrameSize)
+	f := newFrame(c.version)
 	f.setHeader(c.version, 0, 0, opBatch)
 	f.writeByte(byte(batch.Type))
 	f.writeShort(uint16(len(batch.Entries)))
@@ -594,6 +642,10 @@ func (c *Conn) executeBatch(batch *Batch) error {
 		}
 	}
 	f.writeConsistency(batch.Cons)
+	if c.version >= protoVersion3 {
+		// TODO: add support for flags here
+		f.writeByte(0)
+	}
 
 	resp, err := c.exec(f, nil)
 	if err != nil {
@@ -631,12 +683,15 @@ func (c *Conn) decodeFrame(f frame, trace Tracer) (rval interface{}, err error)
 			panic(r)
 		}
 	}()
+
+	headerSize := headerProtoSize[c.version]
 	if len(f) < headerSize {
 		return nil, NewErrProtocol("Decoding frame: less data received than required for header: %d < %d", len(f), headerSize)
 	} else if f[0] != c.version|flagResponse {
 		return nil, NewErrProtocol("Decoding frame: response protocol version does not match connection protocol version (%d != %d)", f[0], c.version|flagResponse)
 	}
-	flags, op, f := f[1], f[3], f[headerSize:]
+
+	flags, op, f := f[1], f.Op(c.version), f[headerSize:]
 	if flags&flagCompress != 0 && len(f) > 0 && c.compressor != nil {
 		if buf, err := c.compressor.Decode([]byte(f)); err != nil {
 			return nil, err
@@ -661,7 +716,7 @@ func (c *Conn) decodeFrame(f frame, trace Tracer) (rval interface{}, err error)
 		case resultKindVoid:
 			return resultVoidFrame{}, nil
 		case resultKindRows:
-			columns, pageState := f.readMetaData()
+			columns, pageState := f.readMetaData(c.version)
 			numRows := f.readInt()
 			values := make([][]byte, numRows*len(columns))
 			for i := 0; i < len(values); i++ {
@@ -677,11 +732,11 @@ func (c *Conn) decodeFrame(f frame, trace Tracer) (rval interface{}, err error)
 			return resultKeyspaceFrame{keyspace}, nil
 		case resultKindPrepared:
 			id := f.readShortBytes()
-			args, _ := f.readMetaData()
+			args, _ := f.readMetaData(c.version)
 			if c.version < 2 {
 				return resultPreparedFrame{PreparedId: id, Arguments: args}, nil
 			}
-			rvals, _ := f.readMetaData()
+			rvals, _ := f.readMetaData(c.version)
 			return resultPreparedFrame{PreparedId: id, Arguments: args, ReturnValues: rvals}, nil
 		case resultKindSchemaChanged:
 			return resultVoidFrame{}, nil

+ 287 - 85
conn_test.go

@@ -5,6 +5,7 @@ package gocql
 import (
 	"crypto/tls"
 	"crypto/x509"
+	"fmt"
 	"io"
 	"io/ioutil"
 	"net"
@@ -15,6 +16,10 @@ import (
 	"time"
 )
 
+const (
+	defaultProto = protoVersion2
+)
+
 func TestJoinHostPort(t *testing.T) {
 	tests := map[string]string{
 		"127.0.0.1:0":                                 JoinHostPort("127.0.0.1", 0),
@@ -29,43 +34,38 @@ func TestJoinHostPort(t *testing.T) {
 	}
 }
 
-type TestServer struct {
-	Address  string
-	t        *testing.T
-	nreq     uint64
-	listen   net.Listener
-	nKillReq uint64
-}
-
 func TestSimple(t *testing.T) {
-	srv := NewTestServer(t)
+	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 
-	db, err := NewCluster(srv.Address).CreateSession()
+	cluster := NewCluster(srv.Address)
+	cluster.ProtoVersion = int(defaultProto)
+	db, err := cluster.CreateSession()
 	if err != nil {
-		t.Errorf("NewCluster: %v", err)
+		t.Errorf("0x%x: NewCluster: %v", defaultProto, err)
+		return
 	}
 
 	if err := db.Query("void").Exec(); err != nil {
-		t.Error(err)
+		t.Errorf("0x%x: %v", defaultProto, err)
 	}
 }
 
 func TestSSLSimple(t *testing.T) {
-	srv := NewSSLTestServer(t)
+	srv := NewSSLTestServer(t, defaultProto)
 	defer srv.Stop()
 
-	db, err := createTestSslCluster(srv.Address).CreateSession()
+	db, err := createTestSslCluster(srv.Address, defaultProto).CreateSession()
 	if err != nil {
-		t.Errorf("NewCluster: %v", err)
+		t.Fatalf("0x%x: NewCluster: %v", defaultProto, err)
 	}
 
 	if err := db.Query("void").Exec(); err != nil {
-		t.Error(err)
+		t.Fatalf("0x%x: %v", defaultProto, err)
 	}
 }
 
-func createTestSslCluster(hosts string) *ClusterConfig {
+func createTestSslCluster(hosts string, proto uint8) *ClusterConfig {
 	cluster := NewCluster(hosts)
 	cluster.SslOpts = &SslOptions{
 		CertPath:               "testdata/pki/gocql.crt",
@@ -73,82 +73,103 @@ func createTestSslCluster(hosts string) *ClusterConfig {
 		CaPath:                 "testdata/pki/ca.crt",
 		EnableHostVerification: false,
 	}
+	cluster.ProtoVersion = int(proto)
 	return cluster
 }
 
 func TestClosed(t *testing.T) {
 	t.Skip("Skipping the execution of TestClosed for now to try to concentrate on more important test failures on Travis")
-	srv := NewTestServer(t)
+
+	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 
-	session, err := NewCluster(srv.Address).CreateSession()
+	cluster := NewCluster(srv.Address)
+	cluster.ProtoVersion = int(defaultProto)
+
+	session, err := cluster.CreateSession()
+	defer session.Close()
 	if err != nil {
-		t.Errorf("NewCluster: %v", err)
+		t.Errorf("0x%x: NewCluster: %v", defaultProto, err)
+		return
 	}
-	session.Close()
 
 	if err := session.Query("void").Exec(); err != ErrSessionClosed {
-		t.Errorf("expected %#v, got %#v", ErrSessionClosed, err)
+		t.Errorf("0x%x: expected %#v, got %#v", defaultProto, ErrSessionClosed, err)
+		return
 	}
 }
 
+func newTestSession(addr string, proto uint8) (*Session, error) {
+	cluster := NewCluster(addr)
+	cluster.ProtoVersion = int(proto)
+	return cluster.CreateSession()
+}
+
 func TestTimeout(t *testing.T) {
-	srv := NewTestServer(t)
+
+	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 
-	db, err := NewCluster(srv.Address).CreateSession()
+	db, err := newTestSession(srv.Address, defaultProto)
 	if err != nil {
 		t.Errorf("NewCluster: %v", err)
+		return
 	}
+	defer db.Close()
 
 	go func() {
 		<-time.After(2 * time.Second)
-		t.Fatal("no timeout")
+		t.Errorf("no timeout")
 	}()
 
 	if err := db.Query("kill").Exec(); err == nil {
-		t.Fatal("expected error")
+		t.Errorf("expected error")
 	}
 }
 
 // TestQueryRetry will test to make sure that gocql will execute
 // the exact amount of retry queries designated by the user.
 func TestQueryRetry(t *testing.T) {
-	srv := NewTestServer(t)
+	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 
-	db, err := NewCluster(srv.Address).CreateSession()
+	db, err := newTestSession(srv.Address, defaultProto)
 	if err != nil {
-		t.Errorf("NewCluster: %v", err)
+		t.Fatalf("NewCluster: %v", err)
 	}
+	defer db.Close()
 
 	go func() {
 		<-time.After(5 * time.Second)
-		t.Fatal("no timeout")
+		t.Fatalf("no timeout")
 	}()
 	rt := &SimpleRetryPolicy{NumRetries: 1}
 
 	qry := db.Query("kill").RetryPolicy(rt)
 	if err := qry.Exec(); err == nil {
-		t.Fatal("expected error")
+		t.Fatalf("expected error")
 	}
-	requests := srv.nKillReq
-	if requests != uint64(qry.Attempts()) {
-		t.Fatalf("expected requests %v to match query attemps %v", requests, qry.Attempts())
+
+	requests := atomic.LoadInt64(&srv.nKillReq)
+	attempts := qry.Attempts()
+	if requests != int64(attempts) {
+		t.Fatalf("expected requests %v to match query attemps %v", requests, attempts)
 	}
+
 	//Minus 1 from the requests variable since there is the initial query attempt
-	if requests-1 != uint64(rt.NumRetries) {
+	if requests-1 != int64(rt.NumRetries) {
 		t.Fatalf("failed to retry the query %v time(s). Query executed %v times", rt.NumRetries, requests-1)
 	}
 }
 
 func TestSlowQuery(t *testing.T) {
-	srv := NewTestServer(t)
+	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 
-	db, err := NewCluster(srv.Address).CreateSession()
+	db, err := newTestSession(srv.Address, defaultProto)
 	if err != nil {
 		t.Errorf("NewCluster: %v", err)
+		return
 	}
 
 	if err := db.Query("slow").Exec(); err != nil {
@@ -159,22 +180,24 @@ func TestSlowQuery(t *testing.T) {
 func TestRoundRobin(t *testing.T) {
 	servers := make([]*TestServer, 5)
 	addrs := make([]string, len(servers))
-	for i := 0; i < len(servers); i++ {
-		servers[i] = NewTestServer(t)
-		addrs[i] = servers[i].Address
-		defer servers[i].Stop()
+	for n := 0; n < len(servers); n++ {
+		servers[n] = NewTestServer(t, defaultProto)
+		addrs[n] = servers[n].Address
+		defer servers[n].Stop()
 	}
 	cluster := NewCluster(addrs...)
+	cluster.ProtoVersion = defaultProto
+
 	db, err := cluster.CreateSession()
-	time.Sleep(1 * time.Second) //Sleep to allow the Cluster.fillPool to complete
+	time.Sleep(1 * time.Second) // Sleep to allow the Cluster.fillPool to complete
 
 	if err != nil {
-		t.Errorf("NewCluster: %v", err)
+		t.Fatalf("NewCluster: %v", err)
 	}
 
 	var wg sync.WaitGroup
 	wg.Add(5)
-	for i := 0; i < 5; i++ {
+	for n := 0; n < 5; n++ {
 		go func() {
 			for j := 0; j < 5; j++ {
 				if err := db.Query("void").Exec(); err != nil {
@@ -187,12 +210,12 @@ func TestRoundRobin(t *testing.T) {
 	wg.Wait()
 
 	diff := 0
-	for i := 1; i < len(servers); i++ {
+	for n := 1; n < len(servers); n++ {
 		d := 0
-		if servers[i].nreq > servers[i-1].nreq {
-			d = int(servers[i].nreq - servers[i-1].nreq)
+		if servers[n].nreq > servers[n-1].nreq {
+			d = int(servers[n].nreq - servers[n-1].nreq)
 		} else {
-			d = int(servers[i-1].nreq - servers[i].nreq)
+			d = int(servers[n-1].nreq - servers[n].nreq)
 		}
 		if d > diff {
 			diff = d
@@ -206,7 +229,8 @@ func TestRoundRobin(t *testing.T) {
 
 func TestConnClosing(t *testing.T) {
 	t.Skip("Skipping until test can be ran reliably")
-	srv := NewTestServer(t)
+
+	srv := NewTestServer(t, protoVersion2)
 	defer srv.Stop()
 
 	db, err := NewCluster(srv.Address).CreateSession()
@@ -238,21 +262,147 @@ func TestConnClosing(t *testing.T) {
 	}
 }
 
-func NewTestServer(t *testing.T) *TestServer {
+func TestStreams_Protocol1(t *testing.T) {
+	srv := NewTestServer(t, protoVersion1)
+	defer srv.Stop()
+
+	// TODO: these are more like session tests and should instead operate
+	// on a single Conn
+	cluster := NewCluster(srv.Address)
+	cluster.NumConns = 1
+	cluster.ProtoVersion = 1
+
+	db, err := cluster.CreateSession()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer db.Close()
+
+	var wg sync.WaitGroup
+	for i := 0; i < db.cfg.NumStreams; i++ {
+		// here were just validating that if we send NumStream request we get
+		// a response for every stream and the lengths for the queries are set
+		// correctly.
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			if err := db.Query("void").Exec(); err != nil {
+				t.Error(err)
+			}
+		}()
+	}
+	wg.Wait()
+}
+
+func TestStreams_Protocol2(t *testing.T) {
+	srv := NewTestServer(t, protoVersion2)
+	defer srv.Stop()
+
+	// TODO: these are more like session tests and should instead operate
+	// on a single Conn
+	cluster := NewCluster(srv.Address)
+	cluster.NumConns = 1
+	cluster.ProtoVersion = 2
+
+	db, err := cluster.CreateSession()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer db.Close()
+
+	for i := 0; i < db.cfg.NumStreams; i++ {
+		// the test server processes each conn synchronously
+		// here were just validating that if we send NumStream request we get
+		// a response for every stream and the lengths for the queries are set
+		// correctly.
+		if err = db.Query("void").Exec(); err != nil {
+			t.Fatal(err)
+		}
+	}
+}
+
+func TestStreams_Protocol3(t *testing.T) {
+	srv := NewTestServer(t, protoVersion3)
+	defer srv.Stop()
+
+	// TODO: these are more like session tests and should instead operate
+	// on a single Conn
+	cluster := NewCluster(srv.Address)
+	cluster.NumConns = 1
+	cluster.ProtoVersion = 3
+
+	db, err := cluster.CreateSession()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer db.Close()
+
+	for i := 0; i < db.cfg.NumStreams; i++ {
+		// the test server processes each conn synchronously
+		// here were just validating that if we send NumStream request we get
+		// a response for every stream and the lengths for the queries are set
+		// correctly.
+		if err = db.Query("void").Exec(); err != nil {
+			t.Fatal(err)
+		}
+	}
+}
+
+func BenchmarkProtocolV3(b *testing.B) {
+	srv := NewTestServer(b, protoVersion3)
+	defer srv.Stop()
+
+	// TODO: these are more like session tests and should instead operate
+	// on a single Conn
+	cluster := NewCluster(srv.Address)
+	cluster.NumConns = 1
+	cluster.ProtoVersion = 3
+
+	db, err := cluster.CreateSession()
+	if err != nil {
+		b.Fatal(err)
+	}
+	defer db.Close()
+
+	b.ResetTimer()
+	b.ReportAllocs()
+	for i := 0; i < b.N; i++ {
+		if err = db.Query("void").Exec(); err != nil {
+			b.Fatal(err)
+		}
+	}
+}
+
+func NewTestServer(t testing.TB, protocol uint8) *TestServer {
 	laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
 	if err != nil {
 		t.Fatal(err)
 	}
+
 	listen, err := net.ListenTCP("tcp", laddr)
 	if err != nil {
 		t.Fatal(err)
 	}
-	srv := &TestServer{Address: listen.Addr().String(), listen: listen, t: t}
+
+	headerSize := 8
+	if protocol > protoVersion2 {
+		headerSize = 9
+	}
+
+	srv := &TestServer{
+		Address:    listen.Addr().String(),
+		listen:     listen,
+		t:          t,
+		protocol:   protocol,
+		headerSize: headerSize,
+	}
+
 	go srv.serve()
+
 	return srv
 }
 
-func NewSSLTestServer(t *testing.T) *TestServer {
+func NewSSLTestServer(t testing.TB, protocol uint8) *TestServer {
 	pem, err := ioutil.ReadFile("testdata/pki/ca.crt")
 	certPool := x509.NewCertPool()
 	if !certPool.AppendCertsFromPEM(pem) {
@@ -270,11 +420,34 @@ func NewSSLTestServer(t *testing.T) *TestServer {
 	if err != nil {
 		t.Fatal(err)
 	}
-	srv := &TestServer{Address: listen.Addr().String(), listen: listen, t: t}
+
+	headerSize := 8
+	if protocol > protoVersion2 {
+		headerSize = 9
+	}
+
+	srv := &TestServer{
+		Address:    listen.Addr().String(),
+		listen:     listen,
+		t:          t,
+		protocol:   protocol,
+		headerSize: headerSize,
+	}
 	go srv.serve()
 	return srv
 }
 
+type TestServer struct {
+	Address  string
+	t        testing.TB
+	nreq     uint64
+	listen   net.Listener
+	nKillReq int64
+
+	protocol   uint8
+	headerSize int
+}
+
 func (srv *TestServer) serve() {
 	defer srv.listen.Close()
 	for {
@@ -285,9 +458,16 @@ func (srv *TestServer) serve() {
 		go func(conn net.Conn) {
 			defer conn.Close()
 			for {
-				frame := srv.readFrame(conn)
+				frame, err := srv.readFrame(conn)
+				if err == io.EOF {
+					return
+				} else if err != nil {
+					srv.t.Error(err)
+					continue
+				}
+
 				atomic.AddUint64(&srv.nreq, 1)
-				srv.process(frame, conn)
+				go srv.process(frame, conn)
 			}
 		}(conn)
 	}
@@ -297,65 +477,87 @@ func (srv *TestServer) Stop() {
 	srv.listen.Close()
 }
 
-func (srv *TestServer) process(frame frame, conn net.Conn) {
-	switch frame[3] {
+func (srv *TestServer) process(f frame, conn net.Conn) {
+	headerSize := headerProtoSize[srv.protocol]
+	stream := f.Stream(srv.protocol)
+
+	switch f.Op(srv.protocol) {
 	case opStartup:
-		frame = frame[:headerSize]
-		frame.setHeader(protoResponse, 0, frame[2], opReady)
+		f = f[:headerSize]
+		f.setHeader(protoDirectionMask|srv.protocol, 0, stream, opReady)
+	case opOptions:
+		f = f[:headerSize]
+		f.setHeader(protoDirectionMask|srv.protocol, 0, stream, opSupported)
+		f.writeShort(0)
 	case opQuery:
-		input := frame
-		input.skipHeader()
+		input := f
+		input.skipHeader(srv.protocol)
 		query := strings.TrimSpace(input.readLongString())
-		frame = frame[:headerSize]
-		frame.setHeader(protoResponse, 0, frame[2], opResult)
+		f = f[:headerSize]
+		f.setHeader(protoDirectionMask|srv.protocol, 0, stream, opResult)
 		first := query
 		if n := strings.Index(query, " "); n > 0 {
 			first = first[:n]
 		}
 		switch strings.ToLower(first) {
 		case "kill":
-			atomic.AddUint64(&srv.nKillReq, 1)
-			select {}
+			atomic.AddInt64(&srv.nKillReq, 1)
+			f = f[:headerSize]
+			f.setHeader(protoDirectionMask|srv.protocol, 0, stream, opError)
+			f.writeInt(0x1001)
+			f.writeString("query killed")
 		case "slow":
 			go func() {
 				<-time.After(1 * time.Second)
-				frame.writeInt(resultKindVoid)
-				frame.setLength(len(frame) - headerSize)
-				if _, err := conn.Write(frame); err != nil {
+				f.writeInt(resultKindVoid)
+				f.setLength(len(f)-headerSize, srv.protocol)
+				if _, err := conn.Write(f); err != nil {
 					return
 				}
 			}()
 			return
 		case "use":
-			frame.writeInt(3)
-			frame.writeString(strings.TrimSpace(query[3:]))
+			f.writeInt(3)
+			f.writeString(strings.TrimSpace(query[3:]))
 		case "void":
-			frame.writeInt(resultKindVoid)
+			f.writeInt(resultKindVoid)
 		default:
-			frame.writeInt(resultKindVoid)
+			f.writeInt(resultKindVoid)
 		}
 	default:
-		frame = frame[:headerSize]
-		frame.setHeader(protoResponse, 0, frame[2], opError)
-		frame.writeInt(0)
-		frame.writeString("not supported")
+		f = f[:headerSize]
+		f.setHeader(protoDirectionMask|srv.protocol, 0, stream, opError)
+		f.writeInt(0)
+		f.writeString("not supported")
 	}
-	frame.setLength(len(frame) - headerSize)
-	if _, err := conn.Write(frame); err != nil {
+
+	f.setLength(len(f)-headerSize, srv.protocol)
+	if _, err := conn.Write(f); err != nil {
+		srv.t.Log(err)
 		return
 	}
 }
 
-func (srv *TestServer) readFrame(conn net.Conn) frame {
-	frame := make(frame, headerSize, headerSize+512)
+func (srv *TestServer) readFrame(conn net.Conn) (frame, error) {
+	frame := make(frame, srv.headerSize, srv.headerSize+512)
 	if _, err := io.ReadFull(conn, frame); err != nil {
-		srv.t.Fatal(err)
+		return nil, err
+	}
+
+	// should be a request frame
+	if frame[0]&protoDirectionMask != 0 {
+		return nil, fmt.Errorf("expected to read a request frame got version: 0x%x", frame[0])
 	}
-	if n := frame.Length(); n > 0 {
+	if v := frame[0] & protoVersionMask; v != srv.protocol {
+		return nil, fmt.Errorf("expected to read protocol version 0x%x got 0x%x", srv.protocol, v)
+	}
+
+	if n := frame.Length(srv.protocol); n > 0 {
 		frame.grow(n)
-		if _, err := io.ReadFull(conn, frame[headerSize:]); err != nil {
-			srv.t.Fatal(err)
+		if _, err := io.ReadFull(conn, frame[srv.headerSize:]); err != nil {
+			return nil, err
 		}
 	}
-	return frame
+
+	return frame, nil
 }

+ 149 - 33
frame.go

@@ -5,12 +5,16 @@
 package gocql
 
 import (
+	"fmt"
 	"net"
 )
 
 const (
-	protoRequest  byte = 0x02
-	protoResponse byte = 0x82
+	protoDirectionMask = 0x80
+	protoVersionMask   = 0x7F
+	protoVersion1      = 0x01
+	protoVersion2      = 0x02
+	protoVersion3      = 0x03
 
 	opError         byte = 0x00
 	opStartup       byte = 0x01
@@ -42,13 +46,26 @@ const (
 	flagPageState   uint8 = 8
 	flagHasMore     uint8 = 2
 
-	headerSize = 8
-
 	apacheCassandraTypePrefix = "org.apache.cassandra.db.marshal."
 )
 
+var headerProtoSize = [...]int{
+	protoVersion1: 8,
+	protoVersion2: 8,
+	protoVersion3: 9,
+}
+
+// TODO: replace with a struct which has a header and a body buffer,
+// header just has methods like, set/get the options in its backing array
+// then in a writeTo we write the header then the body.
 type frame []byte
 
+func newFrame(version uint8) frame {
+	// TODO: pool these at the session level incase anyone is using different
+	// clusters with different versions in the same application.
+	return make(frame, headerProtoSize[version], defaultFrameSize)
+}
+
 func (f *frame) writeInt(v int32) {
 	p := f.grow(4)
 	(*f)[p] = byte(v >> 24)
@@ -129,22 +146,67 @@ func (f *frame) writeStringMultimap(v map[string][]string) {
 	}
 }
 
-func (f *frame) setHeader(version, flags, stream, opcode uint8) {
+func (f *frame) setHeader(version, flags uint8, stream int, opcode uint8) {
 	(*f)[0] = version
 	(*f)[1] = flags
-	(*f)[2] = stream
-	(*f)[3] = opcode
+	p := 2
+	if version&maskVersion > protoVersion2 {
+		(*f)[2] = byte(stream >> 8)
+		(*f)[3] = byte(stream)
+		p += 2
+	} else {
+		(*f)[2] = byte(stream & 0xFF)
+		p++
+	}
+
+	(*f)[p] = opcode
 }
 
-func (f *frame) setLength(length int) {
-	(*f)[4] = byte(length >> 24)
-	(*f)[5] = byte(length >> 16)
-	(*f)[6] = byte(length >> 8)
-	(*f)[7] = byte(length)
+func (f *frame) setStream(stream int, version uint8) {
+	if version > protoVersion2 {
+		(*f)[2] = byte(stream >> 8)
+		(*f)[3] = byte(stream)
+	} else {
+		(*f)[2] = byte(stream)
+	}
 }
 
-func (f *frame) Length() int {
-	return int((*f)[4])<<24 | int((*f)[5])<<16 | int((*f)[6])<<8 | int((*f)[7])
+func (f *frame) Stream(version uint8) (n int) {
+	if version > protoVersion2 {
+		n = int((*f)[2])<<8 | int((*f)[3])
+	} else {
+		n = int((*f)[2])
+	}
+	return
+}
+
+func (f *frame) setLength(length int, version uint8) {
+	p := 4
+	if version > protoVersion2 {
+		p = 5
+	}
+
+	(*f)[p] = byte(length >> 24)
+	(*f)[p+1] = byte(length >> 16)
+	(*f)[p+2] = byte(length >> 8)
+	(*f)[p+3] = byte(length)
+}
+
+func (f *frame) Op(version uint8) byte {
+	if version > protoVersion2 {
+		return (*f)[4]
+	} else {
+		return (*f)[3]
+	}
+}
+
+func (f *frame) Length(version uint8) int {
+	p := 4
+	if version > protoVersion2 {
+		p = 5
+	}
+
+	return int((*f)[p])<<24 | int((*f)[p+1])<<16 | int((*f)[p+2])<<8 | int((*f)[p+3])
 }
 
 func (f *frame) grow(n int) int {
@@ -158,13 +220,13 @@ func (f *frame) grow(n int) int {
 	return p
 }
 
-func (f *frame) skipHeader() {
-	*f = (*f)[headerSize:]
+func (f *frame) skipHeader(version uint8) {
+	*f = (*f)[headerProtoSize[version]:]
 }
 
 func (f *frame) readInt() int {
 	if len(*f) < 4 {
-		panic(NewErrProtocol("Trying to read an int while >4 bytes in the buffer"))
+		panic(NewErrProtocol("Trying to read an int while <4 bytes in the buffer"))
 	}
 	v := uint32((*f)[0])<<24 | uint32((*f)[1])<<16 | uint32((*f)[2])<<8 | uint32((*f)[3])
 	*f = (*f)[4:]
@@ -173,7 +235,7 @@ func (f *frame) readInt() int {
 
 func (f *frame) readShort() uint16 {
 	if len(*f) < 2 {
-		panic(NewErrProtocol("Trying to read a short while >2 bytes in the buffer"))
+		panic(NewErrProtocol("Trying to read a short while <2 bytes in the buffer"))
 	}
 	v := uint16((*f)[0])<<8 | uint16((*f)[1])
 	*f = (*f)[2:]
@@ -223,9 +285,12 @@ func (f *frame) readShortBytes() []byte {
 	return v
 }
 
-func (f *frame) readTypeInfo() *TypeInfo {
+func (f *frame) readTypeInfo(version uint8) *TypeInfo {
 	x := f.readShort()
-	typ := &TypeInfo{Type: Type(x)}
+	typ := &TypeInfo{
+		Proto: version,
+		Type:  Type(x),
+	}
 	switch typ.Type {
 	case TypeCustom:
 		typ.Custom = f.readString()
@@ -233,34 +298,37 @@ func (f *frame) readTypeInfo() *TypeInfo {
 			typ = &TypeInfo{Type: cassType}
 			switch typ.Type {
 			case TypeMap:
-				typ.Key = f.readTypeInfo()
+				typ.Key = f.readTypeInfo(version)
 				fallthrough
 			case TypeList, TypeSet:
-				typ.Elem = f.readTypeInfo()
+				typ.Elem = f.readTypeInfo(version)
 			}
 		}
 	case TypeMap:
-		typ.Key = f.readTypeInfo()
+		typ.Key = f.readTypeInfo(version)
 		fallthrough
 	case TypeList, TypeSet:
-		typ.Elem = f.readTypeInfo()
+		typ.Elem = f.readTypeInfo(version)
 	}
 	return typ
 }
 
-func (f *frame) readMetaData() ([]ColumnInfo, []byte) {
+func (f *frame) readMetaData(version uint8) ([]ColumnInfo, []byte) {
 	flags := f.readInt()
 	numColumns := f.readInt()
+
 	var pageState []byte
 	if flags&2 != 0 {
 		pageState = f.readBytes()
 	}
+
 	globalKeyspace := ""
 	globalTable := ""
 	if flags&1 != 0 {
 		globalKeyspace = f.readString()
 		globalTable = f.readString()
 	}
+
 	columns := make([]ColumnInfo, numColumns)
 	for i := 0; i < numColumns; i++ {
 		columns[i].Keyspace = globalKeyspace
@@ -270,7 +338,7 @@ func (f *frame) readMetaData() ([]ColumnInfo, []byte) {
 			columns[i].Table = f.readString()
 		}
 		columns[i].Name = f.readString()
-		columns[i].TypeInfo = f.readTypeInfo()
+		columns[i].TypeInfo = f.readTypeInfo(version)
 	}
 	return columns, pageState
 }
@@ -381,19 +449,32 @@ type startupFrame struct {
 	Compression string
 }
 
+func (op *startupFrame) String() string {
+	return fmt.Sprintf("[startup cqlversion=%q compression=%q]", op.CQLVersion, op.Compression)
+}
+
 func (op *startupFrame) encodeFrame(version uint8, f frame) (frame, error) {
 	if f == nil {
-		f = make(frame, headerSize, defaultFrameSize)
+		f = newFrame(version)
 	}
+
 	f.setHeader(version, 0, 0, opStartup)
-	f.writeShort(1)
+
+	// TODO: fix this, this is actually a StringMap
+	var size uint16 = 1
+	if op.Compression != "" {
+		size++
+	}
+
+	f.writeShort(size)
 	f.writeString("CQL_VERSION")
 	f.writeString(op.CQLVersion)
+
 	if op.Compression != "" {
-		f[headerSize+1] += 1
 		f.writeString("COMPRESSION")
 		f.writeString(op.Compression)
 	}
+
 	return f, nil
 }
 
@@ -406,14 +487,20 @@ type queryFrame struct {
 	PageState []byte
 }
 
+func (op *queryFrame) String() string {
+	return fmt.Sprintf("[query statement=%q prepared=%x cons=%v ...]", op.Stmt, op.Prepared, op.Cons)
+}
+
 func (op *queryFrame) encodeFrame(version uint8, f frame) (frame, error) {
 	if version == 1 && (op.PageSize != 0 || len(op.PageState) > 0 ||
 		(len(op.Values) > 0 && len(op.Prepared) == 0)) {
 		return nil, ErrUnsupported
 	}
+
 	if f == nil {
-		f = make(frame, headerSize, defaultFrameSize)
+		f = newFrame(version)
 	}
+
 	if len(op.Prepared) > 0 {
 		f.setHeader(version, 0, 0, opExecute)
 		f.writeShortBytes(op.Prepared)
@@ -421,10 +508,12 @@ func (op *queryFrame) encodeFrame(version uint8, f frame) (frame, error) {
 		f.setHeader(version, 0, 0, opQuery)
 		f.writeLongString(op.Stmt)
 	}
+
 	if version >= 2 {
 		f.writeConsistency(op.Cons)
 		flagPos := len(f)
 		f.writeByte(0)
+
 		if len(op.Values) > 0 {
 			f[flagPos] |= flagQueryValues
 			f.writeShort(uint16(len(op.Values)))
@@ -432,10 +521,12 @@ func (op *queryFrame) encodeFrame(version uint8, f frame) (frame, error) {
 				f.writeBytes(value)
 			}
 		}
+
 		if op.PageSize > 0 {
 			f[flagPos] |= flagPageSize
 			f.writeInt(int32(op.PageSize))
 		}
+
 		if len(op.PageState) > 0 {
 			f[flagPos] |= flagPageState
 			f.writeBytes(op.PageState)
@@ -449,6 +540,7 @@ func (op *queryFrame) encodeFrame(version uint8, f frame) (frame, error) {
 		}
 		f.writeConsistency(op.Cons)
 	}
+
 	return f, nil
 }
 
@@ -456,9 +548,13 @@ type prepareFrame struct {
 	Stmt string
 }
 
+func (op *prepareFrame) String() string {
+	return fmt.Sprintf("[prepare statement=%q]", op.Stmt)
+}
+
 func (op *prepareFrame) encodeFrame(version uint8, f frame) (frame, error) {
 	if f == nil {
-		f = make(frame, headerSize, defaultFrameSize)
+		f = newFrame(version)
 	}
 	f.setHeader(version, 0, 0, opPrepare)
 	f.writeLongString(op.Stmt)
@@ -467,9 +563,13 @@ func (op *prepareFrame) encodeFrame(version uint8, f frame) (frame, error) {
 
 type optionsFrame struct{}
 
+func (op *optionsFrame) String() string {
+	return "[options]"
+}
+
 func (op *optionsFrame) encodeFrame(version uint8, f frame) (frame, error) {
 	if f == nil {
-		f = make(frame, headerSize, defaultFrameSize)
+		f = newFrame(version)
 	}
 	f.setHeader(version, 0, 0, opOptions)
 	return f, nil
@@ -479,13 +579,21 @@ type authenticateFrame struct {
 	Authenticator string
 }
 
+func (op *authenticateFrame) String() string {
+	return fmt.Sprintf("[authenticate authenticator=%q]", op.Authenticator)
+}
+
 type authResponseFrame struct {
 	Data []byte
 }
 
+func (op *authResponseFrame) String() string {
+	return fmt.Sprintf("[auth_response data=%q]", op.Data)
+}
+
 func (op *authResponseFrame) encodeFrame(version uint8, f frame) (frame, error) {
 	if f == nil {
-		f = make(frame, headerSize, defaultFrameSize)
+		f = newFrame(version)
 	}
 	f.setHeader(version, 0, 0, opAuthResponse)
 	f.writeBytes(op.Data)
@@ -496,6 +604,14 @@ type authSuccessFrame struct {
 	Data []byte
 }
 
+func (op *authSuccessFrame) String() string {
+	return fmt.Sprintf("[auth_success data=%q]", op.Data)
+}
+
 type authChallengeFrame struct {
 	Data []byte
 }
+
+func (op *authChallengeFrame) String() string {
+	return fmt.Sprintf("[auth_challenge data=%q]", op.Data)
+}

+ 2 - 0
integration.sh

@@ -19,6 +19,8 @@ function run_tests() {
 	local proto=2
 	if [[ $version == 1.2.* ]]; then
 		proto=1
+	elif [[ $version == 2.1.* ]]; then
+		proto=3
 	fi
 
 	go test -timeout 5m -tags integration -cover -v -runssl -proto=$proto -rf=3 -cluster=$(ccm liveset) -clusterSize=$clusterSize -autowait=2000ms ./... | tee results.txt

+ 61 - 30
marshal.go

@@ -40,6 +40,9 @@ func Marshal(info *TypeInfo, value interface{}) ([]byte, error) {
 	if value == nil {
 		return nil, nil
 	}
+	if info.Proto < protoVersion1 {
+		panic("protocol version not set")
+	}
 
 	if v, ok := value.(Marshaler); ok {
 		return v.MarshalCQL(info)
@@ -814,6 +817,28 @@ func unmarshalTimestamp(info *TypeInfo, data []byte, value interface{}) error {
 	return unmarshalErrorf("can not unmarshal %s into %T", info, value)
 }
 
+func writeCollectionSize(info *TypeInfo, n int, buf *bytes.Buffer) error {
+	if info.Proto > protoVersion2 {
+		if n > math.MaxInt32 {
+			return marshalErrorf("marshal: collection too large")
+		}
+
+		buf.WriteByte(byte(n >> 24))
+		buf.WriteByte(byte(n >> 16))
+		buf.WriteByte(byte(n >> 8))
+		buf.WriteByte(byte(n))
+	} else {
+		if n > math.MaxUint16 {
+			return marshalErrorf("marshal: collection too large")
+		}
+
+		buf.WriteByte(byte(n >> 8))
+		buf.WriteByte(byte(n))
+	}
+
+	return nil
+}
+
 func marshalList(info *TypeInfo, value interface{}) ([]byte, error) {
 	rv := reflect.ValueOf(value)
 	t := rv.Type()
@@ -825,21 +850,19 @@ func marshalList(info *TypeInfo, value interface{}) ([]byte, error) {
 		}
 		buf := &bytes.Buffer{}
 		n := rv.Len()
-		if n > math.MaxUint16 {
-			return nil, marshalErrorf("marshal: slice / array too large")
+
+		if err := writeCollectionSize(info, n, buf); err != nil {
+			return nil, err
 		}
-		buf.WriteByte(byte(n >> 8))
-		buf.WriteByte(byte(n))
+
 		for i := 0; i < n; i++ {
 			item, err := Marshal(info.Elem, rv.Index(i).Interface())
 			if err != nil {
 				return nil, err
 			}
-			if len(item) > math.MaxUint16 {
-				return nil, marshalErrorf("marshal: slice / array item too large")
+			if err := writeCollectionSize(info, len(item), buf); err != nil {
+				return nil, err
 			}
-			buf.WriteByte(byte(len(item) >> 8))
-			buf.WriteByte(byte(len(item)))
 			buf.Write(item)
 		}
 		return buf.Bytes(), nil
@@ -858,6 +881,17 @@ func marshalList(info *TypeInfo, value interface{}) ([]byte, error) {
 	return nil, marshalErrorf("can not marshal %T into %s", value, info)
 }
 
+func readCollectionSize(info *TypeInfo, data []byte) (size, read int) {
+	if info.Proto > protoVersion2 {
+		size = int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3])
+		read = 4
+	} else {
+		size = int(data[0])<<8 | int(data[1])
+		read = 2
+	}
+	return
+}
+
 func unmarshalList(info *TypeInfo, data []byte, value interface{}) error {
 	rv := reflect.ValueOf(value)
 	if rv.Kind() != reflect.Ptr {
@@ -879,8 +913,8 @@ func unmarshalList(info *TypeInfo, data []byte, value interface{}) error {
 		if len(data) < 2 {
 			return unmarshalErrorf("unmarshal list: unexpected eof")
 		}
-		n := int(data[0])<<8 | int(data[1])
-		data = data[2:]
+		n, p := readCollectionSize(info, data)
+		data = data[p:]
 		if k == reflect.Array {
 			if rv.Len() != n {
 				return unmarshalErrorf("unmarshal list: array with wrong size")
@@ -894,8 +928,8 @@ func unmarshalList(info *TypeInfo, data []byte, value interface{}) error {
 			if len(data) < 2 {
 				return unmarshalErrorf("unmarshal list: unexpected eof")
 			}
-			m := int(data[0])<<8 | int(data[1])
-			data = data[2:]
+			m, p := readCollectionSize(info, data)
+			data = data[p:]
 			if err := Unmarshal(info.Elem, data[:m], rv.Index(i).Addr().Interface()); err != nil {
 				return err
 			}
@@ -917,33 +951,29 @@ func marshalMap(info *TypeInfo, value interface{}) ([]byte, error) {
 	}
 	buf := &bytes.Buffer{}
 	n := rv.Len()
-	if n > math.MaxUint16 {
-		return nil, marshalErrorf("marshal: map too large")
+
+	if err := writeCollectionSize(info, n, buf); err != nil {
+		return nil, err
 	}
-	buf.WriteByte(byte(n >> 8))
-	buf.WriteByte(byte(n))
+
 	keys := rv.MapKeys()
 	for _, key := range keys {
 		item, err := Marshal(info.Key, key.Interface())
 		if err != nil {
 			return nil, err
 		}
-		if len(item) > math.MaxUint16 {
-			return nil, marshalErrorf("marshal: slice / array item too large")
+		if err := writeCollectionSize(info, len(item), buf); err != nil {
+			return nil, err
 		}
-		buf.WriteByte(byte(len(item) >> 8))
-		buf.WriteByte(byte(len(item)))
 		buf.Write(item)
 
 		item, err = Marshal(info.Elem, rv.MapIndex(key).Interface())
 		if err != nil {
 			return nil, err
 		}
-		if len(item) > math.MaxUint16 {
-			return nil, marshalErrorf("marshal: slice / array item too large")
+		if err := writeCollectionSize(info, len(item), buf); err != nil {
+			return nil, err
 		}
-		buf.WriteByte(byte(len(item) >> 8))
-		buf.WriteByte(byte(len(item)))
 		buf.Write(item)
 	}
 	return buf.Bytes(), nil
@@ -967,22 +997,22 @@ func unmarshalMap(info *TypeInfo, data []byte, value interface{}) error {
 	if len(data) < 2 {
 		return unmarshalErrorf("unmarshal map: unexpected eof")
 	}
-	n := int(data[1]) | int(data[0])<<8
-	data = data[2:]
+	n, p := readCollectionSize(info, data)
+	data = data[p:]
 	for i := 0; i < n; i++ {
 		if len(data) < 2 {
 			return unmarshalErrorf("unmarshal list: unexpected eof")
 		}
-		m := int(data[1]) | int(data[0])<<8
-		data = data[2:]
+		m, p := readCollectionSize(info, data)
+		data = data[p:]
 		key := reflect.New(t.Key())
 		if err := Unmarshal(info.Key, data[:m], key.Interface()); err != nil {
 			return err
 		}
 		data = data[m:]
 
-		m = int(data[1]) | int(data[0])<<8
-		data = data[2:]
+		m, p = readCollectionSize(info, data)
+		data = data[p:]
 		val := reflect.New(t.Elem())
 		if err := Unmarshal(info.Elem, data[:m], val.Interface()); err != nil {
 			return err
@@ -1120,6 +1150,7 @@ func unmarshalInet(info *TypeInfo, data []byte, value interface{}) error {
 
 // TypeInfo describes a Cassandra specific data type.
 type TypeInfo struct {
+	Proto  byte // version of the protocol
 	Type   Type
 	Key    *TypeInfo // only used for TypeMap
 	Elem   *TypeInfo // only used for TypeMap, TypeList and TypeSet

+ 97 - 97
marshal_test.go

@@ -21,42 +21,42 @@ var marshalTests = []struct {
 	Value interface{}
 }{
 	{
-		&TypeInfo{Type: TypeVarchar},
+		&TypeInfo{Proto: 2, Type: TypeVarchar},
 		[]byte("hello world"),
 		[]byte("hello world"),
 	},
 	{
-		&TypeInfo{Type: TypeVarchar},
+		&TypeInfo{Proto: 2, Type: TypeVarchar},
 		[]byte("hello world"),
 		"hello world",
 	},
 	{
-		&TypeInfo{Type: TypeVarchar},
+		&TypeInfo{Proto: 2, Type: TypeVarchar},
 		[]byte(nil),
 		[]byte(nil),
 	},
 	{
-		&TypeInfo{Type: TypeVarchar},
+		&TypeInfo{Proto: 2, Type: TypeVarchar},
 		[]byte("hello world"),
 		MyString("hello world"),
 	},
 	{
-		&TypeInfo{Type: TypeVarchar},
+		&TypeInfo{Proto: 2, Type: TypeVarchar},
 		[]byte("HELLO WORLD"),
 		CustomString("hello world"),
 	},
 	{
-		&TypeInfo{Type: TypeBlob},
+		&TypeInfo{Proto: 2, Type: TypeBlob},
 		[]byte("hello\x00"),
 		[]byte("hello\x00"),
 	},
 	{
-		&TypeInfo{Type: TypeBlob},
+		&TypeInfo{Proto: 2, Type: TypeBlob},
 		[]byte(nil),
 		[]byte(nil),
 	},
 	{
-		&TypeInfo{Type: TypeTimeUUID},
+		&TypeInfo{Proto: 2, Type: TypeTimeUUID},
 		[]byte{0x3d, 0xcd, 0x98, 0x0, 0xf3, 0xd9, 0x11, 0xbf, 0x86, 0xd4, 0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0},
 		func() UUID {
 			x, _ := UUIDFromBytes([]byte{0x3d, 0xcd, 0x98, 0x0, 0xf3, 0xd9, 0x11, 0xbf, 0x86, 0xd4, 0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0})
@@ -64,217 +64,217 @@ var marshalTests = []struct {
 		}(),
 	},
 	{
-		&TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeInt},
 		[]byte("\x00\x00\x00\x00"),
 		0,
 	},
 	{
-		&TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeInt},
 		[]byte("\x01\x02\x03\x04"),
 		int(16909060),
 	},
 	{
-		&TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeInt},
 		[]byte("\x80\x00\x00\x00"),
 		int32(math.MinInt32),
 	},
 	{
-		&TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeInt},
 		[]byte("\x7f\xff\xff\xff"),
 		int32(math.MaxInt32),
 	},
 	{
-		&TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeInt},
 		[]byte("\x00\x00\x00\x00"),
 		"0",
 	},
 	{
-		&TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeInt},
 		[]byte("\x01\x02\x03\x04"),
 		"16909060",
 	},
 	{
-		&TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeInt},
 		[]byte("\x80\x00\x00\x00"),
 		"-2147483648", // math.MinInt32
 	},
 	{
-		&TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeInt},
 		[]byte("\x7f\xff\xff\xff"),
 		"2147483647", // math.MaxInt32
 	},
 	{
-		&TypeInfo{Type: TypeBigInt},
+		&TypeInfo{Proto: 2, Type: TypeBigInt},
 		[]byte("\x00\x00\x00\x00\x00\x00\x00\x00"),
 		0,
 	},
 	{
-		&TypeInfo{Type: TypeBigInt},
+		&TypeInfo{Proto: 2, Type: TypeBigInt},
 		[]byte("\x01\x02\x03\x04\x05\x06\x07\x08"),
 		72623859790382856,
 	},
 	{
-		&TypeInfo{Type: TypeBigInt},
+		&TypeInfo{Proto: 2, Type: TypeBigInt},
 		[]byte("\x80\x00\x00\x00\x00\x00\x00\x00"),
 		int64(math.MinInt64),
 	},
 	{
-		&TypeInfo{Type: TypeBigInt},
+		&TypeInfo{Proto: 2, Type: TypeBigInt},
 		[]byte("\x7f\xff\xff\xff\xff\xff\xff\xff"),
 		int64(math.MaxInt64),
 	},
 	{
-		&TypeInfo{Type: TypeBigInt},
+		&TypeInfo{Proto: 2, Type: TypeBigInt},
 		[]byte("\x00\x00\x00\x00\x00\x00\x00\x00"),
 		"0",
 	},
 	{
-		&TypeInfo{Type: TypeBigInt},
+		&TypeInfo{Proto: 2, Type: TypeBigInt},
 		[]byte("\x01\x02\x03\x04\x05\x06\x07\x08"),
 		"72623859790382856",
 	},
 	{
-		&TypeInfo{Type: TypeBigInt},
+		&TypeInfo{Proto: 2, Type: TypeBigInt},
 		[]byte("\x80\x00\x00\x00\x00\x00\x00\x00"),
 		"-9223372036854775808", // math.MinInt64
 	},
 	{
-		&TypeInfo{Type: TypeBigInt},
+		&TypeInfo{Proto: 2, Type: TypeBigInt},
 		[]byte("\x7f\xff\xff\xff\xff\xff\xff\xff"),
 		"9223372036854775807", // math.MaxInt64
 	},
 	{
-		&TypeInfo{Type: TypeBoolean},
+		&TypeInfo{Proto: 2, Type: TypeBoolean},
 		[]byte("\x00"),
 		false,
 	},
 	{
-		&TypeInfo{Type: TypeBoolean},
+		&TypeInfo{Proto: 2, Type: TypeBoolean},
 		[]byte("\x01"),
 		true,
 	},
 	{
-		&TypeInfo{Type: TypeFloat},
+		&TypeInfo{Proto: 2, Type: TypeFloat},
 		[]byte("\x40\x49\x0f\xdb"),
 		float32(3.14159265),
 	},
 	{
-		&TypeInfo{Type: TypeDouble},
+		&TypeInfo{Proto: 2, Type: TypeDouble},
 		[]byte("\x40\x09\x21\xfb\x53\xc8\xd4\xf1"),
 		float64(3.14159265),
 	},
 	{
-		&TypeInfo{Type: TypeDecimal},
+		&TypeInfo{Proto: 2, Type: TypeDecimal},
 		[]byte("\x00\x00\x00\x00\x00"),
 		inf.NewDec(0, 0),
 	},
 	{
-		&TypeInfo{Type: TypeDecimal},
+		&TypeInfo{Proto: 2, Type: TypeDecimal},
 		[]byte("\x00\x00\x00\x00\x64"),
 		inf.NewDec(100, 0),
 	},
 	{
-		&TypeInfo{Type: TypeDecimal},
+		&TypeInfo{Proto: 2, Type: TypeDecimal},
 		[]byte("\x00\x00\x00\x02\x19"),
 		decimalize("0.25"),
 	},
 	{
-		&TypeInfo{Type: TypeDecimal},
+		&TypeInfo{Proto: 2, Type: TypeDecimal},
 		[]byte("\x00\x00\x00\x13\xD5\a;\x20\x14\xA2\x91"),
 		decimalize("-0.0012095473475870063"), // From the iconara/cql-rb test suite
 	},
 	{
-		&TypeInfo{Type: TypeDecimal},
+		&TypeInfo{Proto: 2, Type: TypeDecimal},
 		[]byte("\x00\x00\x00\x13*\xF8\xC4\xDF\xEB]o"),
 		decimalize("0.0012095473475870063"), // From the iconara/cql-rb test suite
 	},
 	{
-		&TypeInfo{Type: TypeDecimal},
+		&TypeInfo{Proto: 2, Type: TypeDecimal},
 		[]byte("\x00\x00\x00\x12\xF2\xD8\x02\xB6R\x7F\x99\xEE\x98#\x99\xA9V"),
 		decimalize("-1042342234234.123423435647768234"), // From the iconara/cql-rb test suite
 	},
 	{
-		&TypeInfo{Type: TypeDecimal},
+		&TypeInfo{Proto: 2, Type: TypeDecimal},
 		[]byte("\x00\x00\x00\r\nJ\x04\"^\x91\x04\x8a\xb1\x18\xfe"),
 		decimalize("1243878957943.1234124191998"), // From the datastax/python-driver test suite
 	},
 	{
-		&TypeInfo{Type: TypeDecimal},
+		&TypeInfo{Proto: 2, Type: TypeDecimal},
 		[]byte("\x00\x00\x00\x06\xe5\xde]\x98Y"),
 		decimalize("-112233.441191"), // From the datastax/python-driver test suite
 	},
 	{
-		&TypeInfo{Type: TypeDecimal},
+		&TypeInfo{Proto: 2, Type: TypeDecimal},
 		[]byte("\x00\x00\x00\x14\x00\xfa\xce"),
 		decimalize("0.00000000000000064206"), // From the datastax/python-driver test suite
 	},
 	{
-		&TypeInfo{Type: TypeDecimal},
+		&TypeInfo{Proto: 2, Type: TypeDecimal},
 		[]byte("\x00\x00\x00\x14\xff\x052"),
 		decimalize("-0.00000000000000064206"), // From the datastax/python-driver test suite
 	},
 	{
-		&TypeInfo{Type: TypeDecimal},
+		&TypeInfo{Proto: 2, Type: TypeDecimal},
 		[]byte("\xff\xff\xff\x9c\x00\xfa\xce"),
 		inf.NewDec(64206, -100), // From the datastax/python-driver test suite
 	},
 	{
-		&TypeInfo{Type: TypeTimestamp},
+		&TypeInfo{Proto: 2, Type: TypeTimestamp},
 		[]byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"),
 		time.Date(2013, time.August, 13, 9, 52, 3, 0, time.UTC),
 	},
 	{
-		&TypeInfo{Type: TypeTimestamp},
+		&TypeInfo{Proto: 2, Type: TypeTimestamp},
 		[]byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"),
 		int64(1376387523000),
 	},
 	{
-		&TypeInfo{Type: TypeList, Elem: &TypeInfo{Type: TypeInt}},
+		&TypeInfo{Proto: 2, Type: TypeList, Elem: &TypeInfo{Proto: 2, Type: TypeInt}},
 		[]byte("\x00\x02\x00\x04\x00\x00\x00\x01\x00\x04\x00\x00\x00\x02"),
 		[]int{1, 2},
 	},
 	{
-		&TypeInfo{Type: TypeList, Elem: &TypeInfo{Type: TypeInt}},
+		&TypeInfo{Proto: 2, Type: TypeList, Elem: &TypeInfo{Proto: 2, Type: TypeInt}},
 		[]byte("\x00\x02\x00\x04\x00\x00\x00\x01\x00\x04\x00\x00\x00\x02"),
 		[2]int{1, 2},
 	},
 	{
-		&TypeInfo{Type: TypeSet, Elem: &TypeInfo{Type: TypeInt}},
+		&TypeInfo{Proto: 2, Type: TypeSet, Elem: &TypeInfo{Proto: 2, Type: TypeInt}},
 		[]byte("\x00\x02\x00\x04\x00\x00\x00\x01\x00\x04\x00\x00\x00\x02"),
 		[]int{1, 2},
 	},
 	{
-		&TypeInfo{Type: TypeSet, Elem: &TypeInfo{Type: TypeInt}},
+		&TypeInfo{Proto: 2, Type: TypeSet, Elem: &TypeInfo{Proto: 2, Type: TypeInt}},
 		[]byte(nil),
 		[]int(nil),
 	},
 	{
-		&TypeInfo{Type: TypeMap,
-			Key:  &TypeInfo{Type: TypeVarchar},
-			Elem: &TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeMap,
+			Key:  &TypeInfo{Proto: 2, Type: TypeVarchar},
+			Elem: &TypeInfo{Proto: 2, Type: TypeInt},
 		},
 		[]byte("\x00\x01\x00\x03foo\x00\x04\x00\x00\x00\x01"),
 		map[string]int{"foo": 1},
 	},
 	{
-		&TypeInfo{Type: TypeMap,
-			Key:  &TypeInfo{Type: TypeVarchar},
-			Elem: &TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeMap,
+			Key:  &TypeInfo{Proto: 2, Type: TypeVarchar},
+			Elem: &TypeInfo{Proto: 2, Type: TypeInt},
 		},
 		[]byte(nil),
 		map[string]int(nil),
 	},
 	{
-		&TypeInfo{Type: TypeList, Elem: &TypeInfo{Type: TypeVarchar}},
+		&TypeInfo{Proto: 2, Type: TypeList, Elem: &TypeInfo{Proto: 2, Type: TypeVarchar}},
 		bytes.Join([][]byte{
 			[]byte("\x00\x01\xFF\xFF"),
 			bytes.Repeat([]byte("X"), 65535)}, []byte("")),
 		[]string{strings.Repeat("X", 65535)},
 	},
 	{
-		&TypeInfo{Type: TypeMap,
-			Key:  &TypeInfo{Type: TypeVarchar},
-			Elem: &TypeInfo{Type: TypeVarchar},
+		&TypeInfo{Proto: 2, Type: TypeMap,
+			Key:  &TypeInfo{Proto: 2, Type: TypeVarchar},
+			Elem: &TypeInfo{Proto: 2, Type: TypeVarchar},
 		},
 		bytes.Join([][]byte{
 			[]byte("\x00\x01\xFF\xFF"),
@@ -286,82 +286,82 @@ var marshalTests = []struct {
 		},
 	},
 	{
-		&TypeInfo{Type: TypeVarint},
+		&TypeInfo{Proto: 2, Type: TypeVarint},
 		[]byte("\x00"),
 		0,
 	},
 	{
-		&TypeInfo{Type: TypeVarint},
+		&TypeInfo{Proto: 2, Type: TypeVarint},
 		[]byte("\x37\xE2\x3C\xEC"),
 		int32(937573612),
 	},
 	{
-		&TypeInfo{Type: TypeVarint},
+		&TypeInfo{Proto: 2, Type: TypeVarint},
 		[]byte("\x37\xE2\x3C\xEC"),
 		big.NewInt(937573612),
 	},
 	{
-		&TypeInfo{Type: TypeVarint},
+		&TypeInfo{Proto: 2, Type: TypeVarint},
 		[]byte("\x03\x9EV \x15\f\x03\x9DK\x18\xCDI\\$?\a["),
 		bigintize("1231312312331283012830129382342342412123"), // From the iconara/cql-rb test suite
 	},
 	{
-		&TypeInfo{Type: TypeVarint},
+		&TypeInfo{Proto: 2, Type: TypeVarint},
 		[]byte("\xC9v\x8D:\x86"),
 		big.NewInt(-234234234234), // From the iconara/cql-rb test suite
 	},
 	{
-		&TypeInfo{Type: TypeVarint},
+		&TypeInfo{Proto: 2, Type: TypeVarint},
 		[]byte("f\x1e\xfd\xf2\xe3\xb1\x9f|\x04_\x15"),
 		bigintize("123456789123456789123456789"), // From the datastax/python-driver test suite
 	},
 	{
-		&TypeInfo{Type: TypeInet},
+		&TypeInfo{Proto: 2, Type: TypeInet},
 		[]byte("\x7F\x00\x00\x01"),
 		net.ParseIP("127.0.0.1").To4(),
 	},
 	{
-		&TypeInfo{Type: TypeInet},
+		&TypeInfo{Proto: 2, Type: TypeInet},
 		[]byte("\xFF\xFF\xFF\xFF"),
 		net.ParseIP("255.255.255.255").To4(),
 	},
 	{
-		&TypeInfo{Type: TypeInet},
+		&TypeInfo{Proto: 2, Type: TypeInet},
 		[]byte("\x7F\x00\x00\x01"),
 		"127.0.0.1",
 	},
 	{
-		&TypeInfo{Type: TypeInet},
+		&TypeInfo{Proto: 2, Type: TypeInet},
 		[]byte("\xFF\xFF\xFF\xFF"),
 		"255.255.255.255",
 	},
 	{
-		&TypeInfo{Type: TypeInet},
+		&TypeInfo{Proto: 2, Type: TypeInet},
 		[]byte("\x21\xDA\x00\xd3\x00\x00\x2f\x3b\x02\xaa\x00\xff\xfe\x28\x9c\x5a"),
 		"21da:d3:0:2f3b:2aa:ff:fe28:9c5a",
 	},
 	{
-		&TypeInfo{Type: TypeInet},
+		&TypeInfo{Proto: 2, Type: TypeInet},
 		[]byte("\xfe\x80\x00\x00\x00\x00\x00\x00\x02\x02\xb3\xff\xfe\x1e\x83\x29"),
 		"fe80::202:b3ff:fe1e:8329",
 	},
 	{
-		&TypeInfo{Type: TypeInet},
+		&TypeInfo{Proto: 2, Type: TypeInet},
 		[]byte("\x21\xDA\x00\xd3\x00\x00\x2f\x3b\x02\xaa\x00\xff\xfe\x28\x9c\x5a"),
 		net.ParseIP("21da:d3:0:2f3b:2aa:ff:fe28:9c5a"),
 	},
 	{
-		&TypeInfo{Type: TypeInet},
+		&TypeInfo{Proto: 2, Type: TypeInet},
 		[]byte("\xfe\x80\x00\x00\x00\x00\x00\x00\x02\x02\xb3\xff\xfe\x1e\x83\x29"),
 		net.ParseIP("fe80::202:b3ff:fe1e:8329"),
 	},
 	{
-		&TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeInt},
 		[]byte(nil),
 		nil,
 	},
 	{
-		&TypeInfo{Type: TypeVarchar},
+		&TypeInfo{Proto: 2, Type: TypeVarchar},
 		[]byte("nullable string"),
 		func() *string {
 			value := "nullable string"
@@ -369,12 +369,12 @@ var marshalTests = []struct {
 		}(),
 	},
 	{
-		&TypeInfo{Type: TypeVarchar},
+		&TypeInfo{Proto: 2, Type: TypeVarchar},
 		[]byte{},
 		(*string)(nil),
 	},
 	{
-		&TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeInt},
 		[]byte("\x7f\xff\xff\xff"),
 		func() *int {
 			var value int = math.MaxInt32
@@ -382,22 +382,22 @@ var marshalTests = []struct {
 		}(),
 	},
 	{
-		&TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeInt},
 		[]byte(nil),
 		(*int)(nil),
 	},
 	{
-		&TypeInfo{Type: TypeTimeUUID},
+		&TypeInfo{Proto: 2, Type: TypeTimeUUID},
 		[]byte{0x3d, 0xcd, 0x98, 0x0, 0xf3, 0xd9, 0x11, 0xbf, 0x86, 0xd4, 0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0},
 		&UUID{0x3d, 0xcd, 0x98, 0x0, 0xf3, 0xd9, 0x11, 0xbf, 0x86, 0xd4, 0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0},
 	},
 	{
-		&TypeInfo{Type: TypeTimeUUID},
+		&TypeInfo{Proto: 2, Type: TypeTimeUUID},
 		[]byte{},
 		(*UUID)(nil),
 	},
 	{
-		&TypeInfo{Type: TypeTimestamp},
+		&TypeInfo{Proto: 2, Type: TypeTimestamp},
 		[]byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"),
 		func() *time.Time {
 			t := time.Date(2013, time.August, 13, 9, 52, 3, 0, time.UTC)
@@ -405,12 +405,12 @@ var marshalTests = []struct {
 		}(),
 	},
 	{
-		&TypeInfo{Type: TypeTimestamp},
+		&TypeInfo{Proto: 2, Type: TypeTimestamp},
 		[]byte(nil),
 		(*time.Time)(nil),
 	},
 	{
-		&TypeInfo{Type: TypeBoolean},
+		&TypeInfo{Proto: 2, Type: TypeBoolean},
 		[]byte("\x00"),
 		func() *bool {
 			b := false
@@ -418,7 +418,7 @@ var marshalTests = []struct {
 		}(),
 	},
 	{
-		&TypeInfo{Type: TypeBoolean},
+		&TypeInfo{Proto: 2, Type: TypeBoolean},
 		[]byte("\x01"),
 		func() *bool {
 			b := true
@@ -426,12 +426,12 @@ var marshalTests = []struct {
 		}(),
 	},
 	{
-		&TypeInfo{Type: TypeBoolean},
+		&TypeInfo{Proto: 2, Type: TypeBoolean},
 		[]byte(nil),
 		(*bool)(nil),
 	},
 	{
-		&TypeInfo{Type: TypeFloat},
+		&TypeInfo{Proto: 2, Type: TypeFloat},
 		[]byte("\x40\x49\x0f\xdb"),
 		func() *float32 {
 			f := float32(3.14159265)
@@ -439,12 +439,12 @@ var marshalTests = []struct {
 		}(),
 	},
 	{
-		&TypeInfo{Type: TypeFloat},
+		&TypeInfo{Proto: 2, Type: TypeFloat},
 		[]byte(nil),
 		(*float32)(nil),
 	},
 	{
-		&TypeInfo{Type: TypeDouble},
+		&TypeInfo{Proto: 2, Type: TypeDouble},
 		[]byte("\x40\x09\x21\xfb\x53\xc8\xd4\xf1"),
 		func() *float64 {
 			d := float64(3.14159265)
@@ -452,12 +452,12 @@ var marshalTests = []struct {
 		}(),
 	},
 	{
-		&TypeInfo{Type: TypeDouble},
+		&TypeInfo{Proto: 2, Type: TypeDouble},
 		[]byte(nil),
 		(*float64)(nil),
 	},
 	{
-		&TypeInfo{Type: TypeInet},
+		&TypeInfo{Proto: 2, Type: TypeInet},
 		[]byte("\x7F\x00\x00\x01"),
 		func() *net.IP {
 			ip := net.ParseIP("127.0.0.1").To4()
@@ -465,12 +465,12 @@ var marshalTests = []struct {
 		}(),
 	},
 	{
-		&TypeInfo{Type: TypeInet},
+		&TypeInfo{Proto: 2, Type: TypeInet},
 		[]byte(nil),
 		(*net.IP)(nil),
 	},
 	{
-		&TypeInfo{Type: TypeList, Elem: &TypeInfo{Type: TypeInt}},
+		&TypeInfo{Proto: 2, Type: TypeList, Elem: &TypeInfo{Proto: 2, Type: TypeInt}},
 		[]byte("\x00\x02\x00\x04\x00\x00\x00\x01\x00\x04\x00\x00\x00\x02"),
 		func() *[]int {
 			l := []int{1, 2}
@@ -478,14 +478,14 @@ var marshalTests = []struct {
 		}(),
 	},
 	{
-		&TypeInfo{Type: TypeList, Elem: &TypeInfo{Type: TypeInt}},
+		&TypeInfo{Proto: 2, Type: TypeList, Elem: &TypeInfo{Proto: 2, Type: TypeInt}},
 		[]byte(nil),
 		(*[]int)(nil),
 	},
 	{
-		&TypeInfo{Type: TypeMap,
-			Key:  &TypeInfo{Type: TypeVarchar},
-			Elem: &TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeMap,
+			Key:  &TypeInfo{Proto: 2, Type: TypeVarchar},
+			Elem: &TypeInfo{Proto: 2, Type: TypeInt},
 		},
 		[]byte("\x00\x01\x00\x03foo\x00\x04\x00\x00\x00\x01"),
 		func() *map[string]int {
@@ -494,9 +494,9 @@ var marshalTests = []struct {
 		}(),
 	},
 	{
-		&TypeInfo{Type: TypeMap,
-			Key:  &TypeInfo{Type: TypeVarchar},
-			Elem: &TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeMap,
+			Key:  &TypeInfo{Proto: 2, Type: TypeVarchar},
+			Elem: &TypeInfo{Proto: 2, Type: TypeInt},
 		},
 		[]byte(nil),
 		(*map[string]int)(nil),
@@ -610,7 +610,7 @@ func TestMarshalVarint(t *testing.T) {
 	}
 
 	for i, test := range varintTests {
-		data, err := Marshal(&TypeInfo{Type: TypeVarint}, test.Value)
+		data, err := Marshal(&TypeInfo{Proto: 2, Type: TypeVarint}, test.Value)
 		if err != nil {
 			t.Errorf("error marshaling varint: %v (test #%d)", err, i)
 		}
@@ -620,7 +620,7 @@ func TestMarshalVarint(t *testing.T) {
 		}
 
 		binder := new(big.Int)
-		err = Unmarshal(&TypeInfo{Type: TypeVarint}, test.Marshaled, binder)
+		err = Unmarshal(&TypeInfo{Proto: 2, Type: TypeVarint}, test.Marshaled, binder)
 		if err != nil {
 			t.Errorf("error unmarshaling varint: %v (test #%d)", err, i)
 		}