Browse Source

Merge pull request #318 from Zariel/support-protocol-v3

Support protocol v3 wire format
Ben Hood 10 years ago
parent
commit
66e4489a53
7 changed files with 715 additions and 300 deletions
  1. 11 2
      cluster.go
  2. 108 53
      conn.go
  3. 287 85
      conn_test.go
  4. 149 33
      frame.go
  5. 2 0
      integration.sh
  6. 61 30
      marshal.go
  7. 97 97
      marshal_test.go

+ 11 - 2
cluster.go

@@ -62,7 +62,7 @@ type ClusterConfig struct {
 	Port             int           // port (default: 9042)
 	Port             int           // port (default: 9042)
 	Keyspace         string        // initial keyspace (optional)
 	Keyspace         string        // initial keyspace (optional)
 	NumConns         int           // number of connections per host (default: 2)
 	NumConns         int           // number of connections per host (default: 2)
-	NumStreams       int           // number of streams per connection (default: 128)
+	NumStreams       int           // number of streams per connection (default: max per protocol, either 128 or 32768)
 	Consistency      Consistency   // default consistency level (default: Quorum)
 	Consistency      Consistency   // default consistency level (default: Quorum)
 	Compressor       Compressor    // compression algorithm (default: nil)
 	Compressor       Compressor    // compression algorithm (default: nil)
 	Authenticator    Authenticator // authenticator (default: nil)
 	Authenticator    Authenticator // authenticator (default: nil)
@@ -85,7 +85,6 @@ func NewCluster(hosts ...string) *ClusterConfig {
 		Timeout:          600 * time.Millisecond,
 		Timeout:          600 * time.Millisecond,
 		Port:             9042,
 		Port:             9042,
 		NumConns:         2,
 		NumConns:         2,
-		NumStreams:       128,
 		Consistency:      Quorum,
 		Consistency:      Quorum,
 		ConnPoolType:     NewSimplePool,
 		ConnPoolType:     NewSimplePool,
 		DiscoverHosts:    false,
 		DiscoverHosts:    false,
@@ -102,6 +101,16 @@ func (cfg *ClusterConfig) CreateSession() (*Session, error) {
 	if len(cfg.Hosts) < 1 {
 	if len(cfg.Hosts) < 1 {
 		return nil, ErrNoHosts
 		return nil, ErrNoHosts
 	}
 	}
+
+	maxStreams := 128
+	if cfg.ProtoVersion > protoVersion2 {
+		maxStreams = 32768
+	}
+
+	if cfg.NumStreams <= 0 || cfg.NumStreams > maxStreams {
+		cfg.NumStreams = maxStreams
+	}
+
 	pool := cfg.ConnPoolType(cfg)
 	pool := cfg.ConnPoolType(cfg)
 
 
 	//Adjust the size of the prepared statements cache to match the latest configuration
 	//Adjust the size of the prepared statements cache to match the latest configuration

+ 108 - 53
conn.go

@@ -10,7 +10,9 @@ import (
 	"crypto/x509"
 	"crypto/x509"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
+	"io"
 	"io/ioutil"
 	"io/ioutil"
+	"log"
 	"net"
 	"net"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
@@ -19,9 +21,11 @@ import (
 	"time"
 	"time"
 )
 )
 
 
-const defaultFrameSize = 4096
-const flagResponse = 0x80
-const maskVersion = 0x7F
+const (
+	defaultFrameSize = 4096
+	flagResponse     = 0x80
+	maskVersion      = 0x7F
+)
 
 
 //JoinHostPort is a utility to return a address string that can be used
 //JoinHostPort is a utility to return a address string that can be used
 //gocql.Conn to form a connection with a host.
 //gocql.Conn to form a connection with a host.
@@ -88,7 +92,7 @@ type Conn struct {
 	r       *bufio.Reader
 	r       *bufio.Reader
 	timeout time.Duration
 	timeout time.Duration
 
 
-	uniq  chan uint8
+	uniq  chan int
 	calls []callReq
 	calls []callReq
 	nwait int32
 	nwait int32
 
 
@@ -123,14 +127,17 @@ func Connect(addr string, cfg ConnConfig, pool ConnectionPool) (*Conn, error) {
 				return nil, errors.New("Failed parsing or appending certs")
 				return nil, errors.New("Failed parsing or appending certs")
 			}
 			}
 		}
 		}
+
 		mycert, err := tls.LoadX509KeyPair(cfg.SslOpts.CertPath, cfg.SslOpts.KeyPath)
 		mycert, err := tls.LoadX509KeyPair(cfg.SslOpts.CertPath, cfg.SslOpts.KeyPath)
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
+
 		config := tls.Config{
 		config := tls.Config{
 			Certificates: []tls.Certificate{mycert},
 			Certificates: []tls.Certificate{mycert},
 			RootCAs:      certPool,
 			RootCAs:      certPool,
 		}
 		}
+
 		config.InsecureSkipVerify = !cfg.SslOpts.EnableHostVerification
 		config.InsecureSkipVerify = !cfg.SslOpts.EnableHostVerification
 		if conn, err = tls.Dial("tcp", addr, &config); err != nil {
 		if conn, err = tls.Dial("tcp", addr, &config); err != nil {
 			return nil, err
 			return nil, err
@@ -139,16 +146,25 @@ func Connect(addr string, cfg ConnConfig, pool ConnectionPool) (*Conn, error) {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
-	if cfg.NumStreams <= 0 || cfg.NumStreams > 128 {
-		cfg.NumStreams = 128
-	}
-	if cfg.ProtoVersion != 1 && cfg.ProtoVersion != 2 {
+	// going to default to proto 2
+	if cfg.ProtoVersion < protoVersion1 || cfg.ProtoVersion > protoVersion3 {
+		log.Printf("unsupported protocol version: %d using 2\n", cfg.ProtoVersion)
 		cfg.ProtoVersion = 2
 		cfg.ProtoVersion = 2
 	}
 	}
+
+	maxStreams := 128
+	if cfg.ProtoVersion > protoVersion2 {
+		maxStreams = 32768
+	}
+
+	if cfg.NumStreams <= 0 || cfg.NumStreams > maxStreams {
+		cfg.NumStreams = maxStreams
+	}
+
 	c := &Conn{
 	c := &Conn{
 		conn:       conn,
 		conn:       conn,
 		r:          bufio.NewReader(conn),
 		r:          bufio.NewReader(conn),
-		uniq:       make(chan uint8, cfg.NumStreams),
+		uniq:       make(chan int, cfg.NumStreams),
 		calls:      make([]callReq, cfg.NumStreams),
 		calls:      make([]callReq, cfg.NumStreams),
 		timeout:    cfg.Timeout,
 		timeout:    cfg.Timeout,
 		version:    uint8(cfg.ProtoVersion),
 		version:    uint8(cfg.ProtoVersion),
@@ -162,8 +178,8 @@ func Connect(addr string, cfg ConnConfig, pool ConnectionPool) (*Conn, error) {
 		c.setKeepalive(cfg.Keepalive)
 		c.setKeepalive(cfg.Keepalive)
 	}
 	}
 
 
-	for i := 0; i < cap(c.uniq); i++ {
-		c.uniq <- uint8(i)
+	for i := 0; i < cfg.NumStreams; i++ {
+		c.uniq <- i
 	}
 	}
 
 
 	if err := c.startup(&cfg); err != nil {
 	if err := c.startup(&cfg); err != nil {
@@ -254,53 +270,80 @@ func (c *Conn) serve() {
 	c.pool.HandleError(c, err, true)
 	c.pool.HandleError(c, err, true)
 }
 }
 
 
+func (c *Conn) Write(p []byte) (int, error) {
+	c.conn.SetWriteDeadline(time.Now().Add(c.timeout))
+	return c.conn.Write(p)
+}
+
+func (c *Conn) Read(p []byte) (int, error) {
+	return c.r.Read(p)
+}
+
 func (c *Conn) recv() (frame, error) {
 func (c *Conn) recv() (frame, error) {
-	resp := make(frame, headerSize, headerSize+512)
-	c.conn.SetReadDeadline(time.Now().Add(c.timeout))
-	n, last, pinged := 0, 0, false
-	for n < len(resp) {
-		nn, err := c.r.Read(resp[n:])
-		n += nn
-		if err != nil {
-			if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
-				if n > last {
-					// we hit the deadline but we made progress.
-					// simply extend the deadline
-					c.conn.SetReadDeadline(time.Now().Add(c.timeout))
-					last = n
-				} else if n == 0 && !pinged {
-					c.conn.SetReadDeadline(time.Now().Add(c.timeout))
-					if atomic.LoadInt32(&c.nwait) > 0 {
-						go c.ping()
-						pinged = true
-					}
-				} else {
-					return nil, err
-				}
-			} else {
-				return nil, err
-			}
+	size := headerProtoSize[c.version]
+	resp := make(frame, size, size+512)
+
+	// read a full header, ignore timeouts, as this is being ran in a loop
+	c.conn.SetReadDeadline(time.Time{})
+	_, err := io.ReadFull(c.r, resp[:size])
+	if err != nil {
+		return nil, err
+	}
+
+	if v := c.version | flagResponse; resp[0] != v {
+		return nil, NewErrProtocol("recv: response protocol version does not match connection protocol version (%d != %d)", resp[0], v)
+	}
+
+	bodySize := resp.Length(c.version)
+	if bodySize == 0 {
+		return resp, nil
+	}
+	resp.grow(bodySize)
+
+	const maxAttempts = 5
+
+	n := size
+	for i := 0; i < maxAttempts; i++ {
+		var nn int
+		c.conn.SetReadDeadline(time.Now().Add(c.timeout))
+		nn, err = io.ReadFull(c.r, resp[n:size+bodySize])
+		if err == nil {
+			break
 		}
 		}
-		if n == headerSize && len(resp) == headerSize {
-			if resp[0] != c.version|flagResponse {
-				return nil, NewErrProtocol("recv: Response protocol version does not match connection protocol version (%d != %d)", resp[0], c.version|flagResponse)
-			}
-			resp.grow(resp.Length())
+		n += nn
+
+		if verr, ok := err.(net.Error); !ok || !verr.Temporary() {
+			break
 		}
 		}
 	}
 	}
+
+	if err != nil {
+		return nil, err
+	}
+
 	return resp, nil
 	return resp, nil
 }
 }
 
 
 func (c *Conn) execSimple(op operation) (interface{}, error) {
 func (c *Conn) execSimple(op operation) (interface{}, error) {
 	f, err := op.encodeFrame(c.version, nil)
 	f, err := op.encodeFrame(c.version, nil)
-	f.setLength(len(f) - headerSize)
-	if _, err := c.conn.Write([]byte(f)); err != nil {
+	if err != nil {
+		// this should be a noop err
+		return nil, err
+	}
+
+	bodyLen := len(f) - headerProtoSize[c.version]
+	f.setLength(bodyLen, c.version)
+
+	if _, err := c.Write([]byte(f)); err != nil {
 		c.Close()
 		c.Close()
 		return nil, err
 		return nil, err
 	}
 	}
+
+	// here recv wont timeout waiting for a header, should it?
 	if f, err = c.recv(); err != nil {
 	if f, err = c.recv(); err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
+
 	return c.decodeFrame(f, nil)
 	return c.decodeFrame(f, nil)
 }
 }
 
 
@@ -309,9 +352,12 @@ func (c *Conn) exec(op operation, trace Tracer) (interface{}, error) {
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
+
 	if trace != nil {
 	if trace != nil {
 		req[1] |= flagTrace
 		req[1] |= flagTrace
 	}
 	}
+
+	headerSize := headerProtoSize[c.version]
 	if len(req) > headerSize && c.compressor != nil {
 	if len(req) > headerSize && c.compressor != nil {
 		body, err := c.compressor.Encode([]byte(req[headerSize:]))
 		body, err := c.compressor.Encode([]byte(req[headerSize:]))
 		if err != nil {
 		if err != nil {
@@ -320,16 +366,17 @@ func (c *Conn) exec(op operation, trace Tracer) (interface{}, error) {
 		req = append(req[:headerSize], frame(body)...)
 		req = append(req[:headerSize], frame(body)...)
 		req[1] |= flagCompress
 		req[1] |= flagCompress
 	}
 	}
-	req.setLength(len(req) - headerSize)
+	bodyLen := len(req) - headerSize
+	req.setLength(bodyLen, c.version)
 
 
 	id := <-c.uniq
 	id := <-c.uniq
-	req[2] = id
+	req.setStream(id, c.version)
 	call := &c.calls[id]
 	call := &c.calls[id]
 	call.resp = make(chan callResp, 1)
 	call.resp = make(chan callResp, 1)
 	atomic.AddInt32(&c.nwait, 1)
 	atomic.AddInt32(&c.nwait, 1)
 	atomic.StoreInt32(&call.active, 1)
 	atomic.StoreInt32(&call.active, 1)
 
 
-	if _, err := c.conn.Write(req); err != nil {
+	if _, err := c.Write(req); err != nil {
 		c.uniq <- id
 		c.uniq <- id
 		c.Close()
 		c.Close()
 		return nil, err
 		return nil, err
@@ -342,11 +389,12 @@ func (c *Conn) exec(op operation, trace Tracer) (interface{}, error) {
 	if reply.err != nil {
 	if reply.err != nil {
 		return nil, reply.err
 		return nil, reply.err
 	}
 	}
+
 	return c.decodeFrame(reply.buf, trace)
 	return c.decodeFrame(reply.buf, trace)
 }
 }
 
 
 func (c *Conn) dispatch(resp frame) {
 func (c *Conn) dispatch(resp frame) {
-	id := int(resp[2])
+	id := resp.Stream(c.version)
 	if id >= len(c.calls) {
 	if id >= len(c.calls) {
 		return
 		return
 	}
 	}
@@ -543,10 +591,10 @@ func (c *Conn) UseKeyspace(keyspace string) error {
 }
 }
 
 
 func (c *Conn) executeBatch(batch *Batch) error {
 func (c *Conn) executeBatch(batch *Batch) error {
-	if c.version == 1 {
+	if c.version == protoVersion1 {
 		return ErrUnsupported
 		return ErrUnsupported
 	}
 	}
-	f := make(frame, headerSize, defaultFrameSize)
+	f := newFrame(c.version)
 	f.setHeader(c.version, 0, 0, opBatch)
 	f.setHeader(c.version, 0, 0, opBatch)
 	f.writeByte(byte(batch.Type))
 	f.writeByte(byte(batch.Type))
 	f.writeShort(uint16(len(batch.Entries)))
 	f.writeShort(uint16(len(batch.Entries)))
@@ -594,6 +642,10 @@ func (c *Conn) executeBatch(batch *Batch) error {
 		}
 		}
 	}
 	}
 	f.writeConsistency(batch.Cons)
 	f.writeConsistency(batch.Cons)
+	if c.version >= protoVersion3 {
+		// TODO: add support for flags here
+		f.writeByte(0)
+	}
 
 
 	resp, err := c.exec(f, nil)
 	resp, err := c.exec(f, nil)
 	if err != nil {
 	if err != nil {
@@ -631,12 +683,15 @@ func (c *Conn) decodeFrame(f frame, trace Tracer) (rval interface{}, err error)
 			panic(r)
 			panic(r)
 		}
 		}
 	}()
 	}()
+
+	headerSize := headerProtoSize[c.version]
 	if len(f) < headerSize {
 	if len(f) < headerSize {
 		return nil, NewErrProtocol("Decoding frame: less data received than required for header: %d < %d", len(f), headerSize)
 		return nil, NewErrProtocol("Decoding frame: less data received than required for header: %d < %d", len(f), headerSize)
 	} else if f[0] != c.version|flagResponse {
 	} else if f[0] != c.version|flagResponse {
 		return nil, NewErrProtocol("Decoding frame: response protocol version does not match connection protocol version (%d != %d)", f[0], c.version|flagResponse)
 		return nil, NewErrProtocol("Decoding frame: response protocol version does not match connection protocol version (%d != %d)", f[0], c.version|flagResponse)
 	}
 	}
-	flags, op, f := f[1], f[3], f[headerSize:]
+
+	flags, op, f := f[1], f.Op(c.version), f[headerSize:]
 	if flags&flagCompress != 0 && len(f) > 0 && c.compressor != nil {
 	if flags&flagCompress != 0 && len(f) > 0 && c.compressor != nil {
 		if buf, err := c.compressor.Decode([]byte(f)); err != nil {
 		if buf, err := c.compressor.Decode([]byte(f)); err != nil {
 			return nil, err
 			return nil, err
@@ -661,7 +716,7 @@ func (c *Conn) decodeFrame(f frame, trace Tracer) (rval interface{}, err error)
 		case resultKindVoid:
 		case resultKindVoid:
 			return resultVoidFrame{}, nil
 			return resultVoidFrame{}, nil
 		case resultKindRows:
 		case resultKindRows:
-			columns, pageState := f.readMetaData()
+			columns, pageState := f.readMetaData(c.version)
 			numRows := f.readInt()
 			numRows := f.readInt()
 			values := make([][]byte, numRows*len(columns))
 			values := make([][]byte, numRows*len(columns))
 			for i := 0; i < len(values); i++ {
 			for i := 0; i < len(values); i++ {
@@ -677,11 +732,11 @@ func (c *Conn) decodeFrame(f frame, trace Tracer) (rval interface{}, err error)
 			return resultKeyspaceFrame{keyspace}, nil
 			return resultKeyspaceFrame{keyspace}, nil
 		case resultKindPrepared:
 		case resultKindPrepared:
 			id := f.readShortBytes()
 			id := f.readShortBytes()
-			args, _ := f.readMetaData()
+			args, _ := f.readMetaData(c.version)
 			if c.version < 2 {
 			if c.version < 2 {
 				return resultPreparedFrame{PreparedId: id, Arguments: args}, nil
 				return resultPreparedFrame{PreparedId: id, Arguments: args}, nil
 			}
 			}
-			rvals, _ := f.readMetaData()
+			rvals, _ := f.readMetaData(c.version)
 			return resultPreparedFrame{PreparedId: id, Arguments: args, ReturnValues: rvals}, nil
 			return resultPreparedFrame{PreparedId: id, Arguments: args, ReturnValues: rvals}, nil
 		case resultKindSchemaChanged:
 		case resultKindSchemaChanged:
 			return resultVoidFrame{}, nil
 			return resultVoidFrame{}, nil

+ 287 - 85
conn_test.go

@@ -5,6 +5,7 @@ package gocql
 import (
 import (
 	"crypto/tls"
 	"crypto/tls"
 	"crypto/x509"
 	"crypto/x509"
+	"fmt"
 	"io"
 	"io"
 	"io/ioutil"
 	"io/ioutil"
 	"net"
 	"net"
@@ -15,6 +16,10 @@ import (
 	"time"
 	"time"
 )
 )
 
 
+const (
+	defaultProto = protoVersion2
+)
+
 func TestJoinHostPort(t *testing.T) {
 func TestJoinHostPort(t *testing.T) {
 	tests := map[string]string{
 	tests := map[string]string{
 		"127.0.0.1:0":                                 JoinHostPort("127.0.0.1", 0),
 		"127.0.0.1:0":                                 JoinHostPort("127.0.0.1", 0),
@@ -29,43 +34,38 @@ func TestJoinHostPort(t *testing.T) {
 	}
 	}
 }
 }
 
 
-type TestServer struct {
-	Address  string
-	t        *testing.T
-	nreq     uint64
-	listen   net.Listener
-	nKillReq uint64
-}
-
 func TestSimple(t *testing.T) {
 func TestSimple(t *testing.T) {
-	srv := NewTestServer(t)
+	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 	defer srv.Stop()
 
 
-	db, err := NewCluster(srv.Address).CreateSession()
+	cluster := NewCluster(srv.Address)
+	cluster.ProtoVersion = int(defaultProto)
+	db, err := cluster.CreateSession()
 	if err != nil {
 	if err != nil {
-		t.Errorf("NewCluster: %v", err)
+		t.Errorf("0x%x: NewCluster: %v", defaultProto, err)
+		return
 	}
 	}
 
 
 	if err := db.Query("void").Exec(); err != nil {
 	if err := db.Query("void").Exec(); err != nil {
-		t.Error(err)
+		t.Errorf("0x%x: %v", defaultProto, err)
 	}
 	}
 }
 }
 
 
 func TestSSLSimple(t *testing.T) {
 func TestSSLSimple(t *testing.T) {
-	srv := NewSSLTestServer(t)
+	srv := NewSSLTestServer(t, defaultProto)
 	defer srv.Stop()
 	defer srv.Stop()
 
 
-	db, err := createTestSslCluster(srv.Address).CreateSession()
+	db, err := createTestSslCluster(srv.Address, defaultProto).CreateSession()
 	if err != nil {
 	if err != nil {
-		t.Errorf("NewCluster: %v", err)
+		t.Fatalf("0x%x: NewCluster: %v", defaultProto, err)
 	}
 	}
 
 
 	if err := db.Query("void").Exec(); err != nil {
 	if err := db.Query("void").Exec(); err != nil {
-		t.Error(err)
+		t.Fatalf("0x%x: %v", defaultProto, err)
 	}
 	}
 }
 }
 
 
-func createTestSslCluster(hosts string) *ClusterConfig {
+func createTestSslCluster(hosts string, proto uint8) *ClusterConfig {
 	cluster := NewCluster(hosts)
 	cluster := NewCluster(hosts)
 	cluster.SslOpts = &SslOptions{
 	cluster.SslOpts = &SslOptions{
 		CertPath:               "testdata/pki/gocql.crt",
 		CertPath:               "testdata/pki/gocql.crt",
@@ -73,82 +73,103 @@ func createTestSslCluster(hosts string) *ClusterConfig {
 		CaPath:                 "testdata/pki/ca.crt",
 		CaPath:                 "testdata/pki/ca.crt",
 		EnableHostVerification: false,
 		EnableHostVerification: false,
 	}
 	}
+	cluster.ProtoVersion = int(proto)
 	return cluster
 	return cluster
 }
 }
 
 
 func TestClosed(t *testing.T) {
 func TestClosed(t *testing.T) {
 	t.Skip("Skipping the execution of TestClosed for now to try to concentrate on more important test failures on Travis")
 	t.Skip("Skipping the execution of TestClosed for now to try to concentrate on more important test failures on Travis")
-	srv := NewTestServer(t)
+
+	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 	defer srv.Stop()
 
 
-	session, err := NewCluster(srv.Address).CreateSession()
+	cluster := NewCluster(srv.Address)
+	cluster.ProtoVersion = int(defaultProto)
+
+	session, err := cluster.CreateSession()
+	defer session.Close()
 	if err != nil {
 	if err != nil {
-		t.Errorf("NewCluster: %v", err)
+		t.Errorf("0x%x: NewCluster: %v", defaultProto, err)
+		return
 	}
 	}
-	session.Close()
 
 
 	if err := session.Query("void").Exec(); err != ErrSessionClosed {
 	if err := session.Query("void").Exec(); err != ErrSessionClosed {
-		t.Errorf("expected %#v, got %#v", ErrSessionClosed, err)
+		t.Errorf("0x%x: expected %#v, got %#v", defaultProto, ErrSessionClosed, err)
+		return
 	}
 	}
 }
 }
 
 
+func newTestSession(addr string, proto uint8) (*Session, error) {
+	cluster := NewCluster(addr)
+	cluster.ProtoVersion = int(proto)
+	return cluster.CreateSession()
+}
+
 func TestTimeout(t *testing.T) {
 func TestTimeout(t *testing.T) {
-	srv := NewTestServer(t)
+
+	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 	defer srv.Stop()
 
 
-	db, err := NewCluster(srv.Address).CreateSession()
+	db, err := newTestSession(srv.Address, defaultProto)
 	if err != nil {
 	if err != nil {
 		t.Errorf("NewCluster: %v", err)
 		t.Errorf("NewCluster: %v", err)
+		return
 	}
 	}
+	defer db.Close()
 
 
 	go func() {
 	go func() {
 		<-time.After(2 * time.Second)
 		<-time.After(2 * time.Second)
-		t.Fatal("no timeout")
+		t.Errorf("no timeout")
 	}()
 	}()
 
 
 	if err := db.Query("kill").Exec(); err == nil {
 	if err := db.Query("kill").Exec(); err == nil {
-		t.Fatal("expected error")
+		t.Errorf("expected error")
 	}
 	}
 }
 }
 
 
 // TestQueryRetry will test to make sure that gocql will execute
 // TestQueryRetry will test to make sure that gocql will execute
 // the exact amount of retry queries designated by the user.
 // the exact amount of retry queries designated by the user.
 func TestQueryRetry(t *testing.T) {
 func TestQueryRetry(t *testing.T) {
-	srv := NewTestServer(t)
+	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 	defer srv.Stop()
 
 
-	db, err := NewCluster(srv.Address).CreateSession()
+	db, err := newTestSession(srv.Address, defaultProto)
 	if err != nil {
 	if err != nil {
-		t.Errorf("NewCluster: %v", err)
+		t.Fatalf("NewCluster: %v", err)
 	}
 	}
+	defer db.Close()
 
 
 	go func() {
 	go func() {
 		<-time.After(5 * time.Second)
 		<-time.After(5 * time.Second)
-		t.Fatal("no timeout")
+		t.Fatalf("no timeout")
 	}()
 	}()
 	rt := &SimpleRetryPolicy{NumRetries: 1}
 	rt := &SimpleRetryPolicy{NumRetries: 1}
 
 
 	qry := db.Query("kill").RetryPolicy(rt)
 	qry := db.Query("kill").RetryPolicy(rt)
 	if err := qry.Exec(); err == nil {
 	if err := qry.Exec(); err == nil {
-		t.Fatal("expected error")
+		t.Fatalf("expected error")
 	}
 	}
-	requests := srv.nKillReq
-	if requests != uint64(qry.Attempts()) {
-		t.Fatalf("expected requests %v to match query attemps %v", requests, qry.Attempts())
+
+	requests := atomic.LoadInt64(&srv.nKillReq)
+	attempts := qry.Attempts()
+	if requests != int64(attempts) {
+		t.Fatalf("expected requests %v to match query attemps %v", requests, attempts)
 	}
 	}
+
 	//Minus 1 from the requests variable since there is the initial query attempt
 	//Minus 1 from the requests variable since there is the initial query attempt
-	if requests-1 != uint64(rt.NumRetries) {
+	if requests-1 != int64(rt.NumRetries) {
 		t.Fatalf("failed to retry the query %v time(s). Query executed %v times", rt.NumRetries, requests-1)
 		t.Fatalf("failed to retry the query %v time(s). Query executed %v times", rt.NumRetries, requests-1)
 	}
 	}
 }
 }
 
 
 func TestSlowQuery(t *testing.T) {
 func TestSlowQuery(t *testing.T) {
-	srv := NewTestServer(t)
+	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 	defer srv.Stop()
 
 
-	db, err := NewCluster(srv.Address).CreateSession()
+	db, err := newTestSession(srv.Address, defaultProto)
 	if err != nil {
 	if err != nil {
 		t.Errorf("NewCluster: %v", err)
 		t.Errorf("NewCluster: %v", err)
+		return
 	}
 	}
 
 
 	if err := db.Query("slow").Exec(); err != nil {
 	if err := db.Query("slow").Exec(); err != nil {
@@ -159,22 +180,24 @@ func TestSlowQuery(t *testing.T) {
 func TestRoundRobin(t *testing.T) {
 func TestRoundRobin(t *testing.T) {
 	servers := make([]*TestServer, 5)
 	servers := make([]*TestServer, 5)
 	addrs := make([]string, len(servers))
 	addrs := make([]string, len(servers))
-	for i := 0; i < len(servers); i++ {
-		servers[i] = NewTestServer(t)
-		addrs[i] = servers[i].Address
-		defer servers[i].Stop()
+	for n := 0; n < len(servers); n++ {
+		servers[n] = NewTestServer(t, defaultProto)
+		addrs[n] = servers[n].Address
+		defer servers[n].Stop()
 	}
 	}
 	cluster := NewCluster(addrs...)
 	cluster := NewCluster(addrs...)
+	cluster.ProtoVersion = defaultProto
+
 	db, err := cluster.CreateSession()
 	db, err := cluster.CreateSession()
-	time.Sleep(1 * time.Second) //Sleep to allow the Cluster.fillPool to complete
+	time.Sleep(1 * time.Second) // Sleep to allow the Cluster.fillPool to complete
 
 
 	if err != nil {
 	if err != nil {
-		t.Errorf("NewCluster: %v", err)
+		t.Fatalf("NewCluster: %v", err)
 	}
 	}
 
 
 	var wg sync.WaitGroup
 	var wg sync.WaitGroup
 	wg.Add(5)
 	wg.Add(5)
-	for i := 0; i < 5; i++ {
+	for n := 0; n < 5; n++ {
 		go func() {
 		go func() {
 			for j := 0; j < 5; j++ {
 			for j := 0; j < 5; j++ {
 				if err := db.Query("void").Exec(); err != nil {
 				if err := db.Query("void").Exec(); err != nil {
@@ -187,12 +210,12 @@ func TestRoundRobin(t *testing.T) {
 	wg.Wait()
 	wg.Wait()
 
 
 	diff := 0
 	diff := 0
-	for i := 1; i < len(servers); i++ {
+	for n := 1; n < len(servers); n++ {
 		d := 0
 		d := 0
-		if servers[i].nreq > servers[i-1].nreq {
-			d = int(servers[i].nreq - servers[i-1].nreq)
+		if servers[n].nreq > servers[n-1].nreq {
+			d = int(servers[n].nreq - servers[n-1].nreq)
 		} else {
 		} else {
-			d = int(servers[i-1].nreq - servers[i].nreq)
+			d = int(servers[n-1].nreq - servers[n].nreq)
 		}
 		}
 		if d > diff {
 		if d > diff {
 			diff = d
 			diff = d
@@ -206,7 +229,8 @@ func TestRoundRobin(t *testing.T) {
 
 
 func TestConnClosing(t *testing.T) {
 func TestConnClosing(t *testing.T) {
 	t.Skip("Skipping until test can be ran reliably")
 	t.Skip("Skipping until test can be ran reliably")
-	srv := NewTestServer(t)
+
+	srv := NewTestServer(t, protoVersion2)
 	defer srv.Stop()
 	defer srv.Stop()
 
 
 	db, err := NewCluster(srv.Address).CreateSession()
 	db, err := NewCluster(srv.Address).CreateSession()
@@ -238,21 +262,147 @@ func TestConnClosing(t *testing.T) {
 	}
 	}
 }
 }
 
 
-func NewTestServer(t *testing.T) *TestServer {
+func TestStreams_Protocol1(t *testing.T) {
+	srv := NewTestServer(t, protoVersion1)
+	defer srv.Stop()
+
+	// TODO: these are more like session tests and should instead operate
+	// on a single Conn
+	cluster := NewCluster(srv.Address)
+	cluster.NumConns = 1
+	cluster.ProtoVersion = 1
+
+	db, err := cluster.CreateSession()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer db.Close()
+
+	var wg sync.WaitGroup
+	for i := 0; i < db.cfg.NumStreams; i++ {
+		// here were just validating that if we send NumStream request we get
+		// a response for every stream and the lengths for the queries are set
+		// correctly.
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			if err := db.Query("void").Exec(); err != nil {
+				t.Error(err)
+			}
+		}()
+	}
+	wg.Wait()
+}
+
+func TestStreams_Protocol2(t *testing.T) {
+	srv := NewTestServer(t, protoVersion2)
+	defer srv.Stop()
+
+	// TODO: these are more like session tests and should instead operate
+	// on a single Conn
+	cluster := NewCluster(srv.Address)
+	cluster.NumConns = 1
+	cluster.ProtoVersion = 2
+
+	db, err := cluster.CreateSession()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer db.Close()
+
+	for i := 0; i < db.cfg.NumStreams; i++ {
+		// the test server processes each conn synchronously
+		// here were just validating that if we send NumStream request we get
+		// a response for every stream and the lengths for the queries are set
+		// correctly.
+		if err = db.Query("void").Exec(); err != nil {
+			t.Fatal(err)
+		}
+	}
+}
+
+func TestStreams_Protocol3(t *testing.T) {
+	srv := NewTestServer(t, protoVersion3)
+	defer srv.Stop()
+
+	// TODO: these are more like session tests and should instead operate
+	// on a single Conn
+	cluster := NewCluster(srv.Address)
+	cluster.NumConns = 1
+	cluster.ProtoVersion = 3
+
+	db, err := cluster.CreateSession()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer db.Close()
+
+	for i := 0; i < db.cfg.NumStreams; i++ {
+		// the test server processes each conn synchronously
+		// here were just validating that if we send NumStream request we get
+		// a response for every stream and the lengths for the queries are set
+		// correctly.
+		if err = db.Query("void").Exec(); err != nil {
+			t.Fatal(err)
+		}
+	}
+}
+
+func BenchmarkProtocolV3(b *testing.B) {
+	srv := NewTestServer(b, protoVersion3)
+	defer srv.Stop()
+
+	// TODO: these are more like session tests and should instead operate
+	// on a single Conn
+	cluster := NewCluster(srv.Address)
+	cluster.NumConns = 1
+	cluster.ProtoVersion = 3
+
+	db, err := cluster.CreateSession()
+	if err != nil {
+		b.Fatal(err)
+	}
+	defer db.Close()
+
+	b.ResetTimer()
+	b.ReportAllocs()
+	for i := 0; i < b.N; i++ {
+		if err = db.Query("void").Exec(); err != nil {
+			b.Fatal(err)
+		}
+	}
+}
+
+func NewTestServer(t testing.TB, protocol uint8) *TestServer {
 	laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
 	laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
+
 	listen, err := net.ListenTCP("tcp", laddr)
 	listen, err := net.ListenTCP("tcp", laddr)
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
-	srv := &TestServer{Address: listen.Addr().String(), listen: listen, t: t}
+
+	headerSize := 8
+	if protocol > protoVersion2 {
+		headerSize = 9
+	}
+
+	srv := &TestServer{
+		Address:    listen.Addr().String(),
+		listen:     listen,
+		t:          t,
+		protocol:   protocol,
+		headerSize: headerSize,
+	}
+
 	go srv.serve()
 	go srv.serve()
+
 	return srv
 	return srv
 }
 }
 
 
-func NewSSLTestServer(t *testing.T) *TestServer {
+func NewSSLTestServer(t testing.TB, protocol uint8) *TestServer {
 	pem, err := ioutil.ReadFile("testdata/pki/ca.crt")
 	pem, err := ioutil.ReadFile("testdata/pki/ca.crt")
 	certPool := x509.NewCertPool()
 	certPool := x509.NewCertPool()
 	if !certPool.AppendCertsFromPEM(pem) {
 	if !certPool.AppendCertsFromPEM(pem) {
@@ -270,11 +420,34 @@ func NewSSLTestServer(t *testing.T) *TestServer {
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
-	srv := &TestServer{Address: listen.Addr().String(), listen: listen, t: t}
+
+	headerSize := 8
+	if protocol > protoVersion2 {
+		headerSize = 9
+	}
+
+	srv := &TestServer{
+		Address:    listen.Addr().String(),
+		listen:     listen,
+		t:          t,
+		protocol:   protocol,
+		headerSize: headerSize,
+	}
 	go srv.serve()
 	go srv.serve()
 	return srv
 	return srv
 }
 }
 
 
+type TestServer struct {
+	Address  string
+	t        testing.TB
+	nreq     uint64
+	listen   net.Listener
+	nKillReq int64
+
+	protocol   uint8
+	headerSize int
+}
+
 func (srv *TestServer) serve() {
 func (srv *TestServer) serve() {
 	defer srv.listen.Close()
 	defer srv.listen.Close()
 	for {
 	for {
@@ -285,9 +458,16 @@ func (srv *TestServer) serve() {
 		go func(conn net.Conn) {
 		go func(conn net.Conn) {
 			defer conn.Close()
 			defer conn.Close()
 			for {
 			for {
-				frame := srv.readFrame(conn)
+				frame, err := srv.readFrame(conn)
+				if err == io.EOF {
+					return
+				} else if err != nil {
+					srv.t.Error(err)
+					continue
+				}
+
 				atomic.AddUint64(&srv.nreq, 1)
 				atomic.AddUint64(&srv.nreq, 1)
-				srv.process(frame, conn)
+				go srv.process(frame, conn)
 			}
 			}
 		}(conn)
 		}(conn)
 	}
 	}
@@ -297,65 +477,87 @@ func (srv *TestServer) Stop() {
 	srv.listen.Close()
 	srv.listen.Close()
 }
 }
 
 
-func (srv *TestServer) process(frame frame, conn net.Conn) {
-	switch frame[3] {
+func (srv *TestServer) process(f frame, conn net.Conn) {
+	headerSize := headerProtoSize[srv.protocol]
+	stream := f.Stream(srv.protocol)
+
+	switch f.Op(srv.protocol) {
 	case opStartup:
 	case opStartup:
-		frame = frame[:headerSize]
-		frame.setHeader(protoResponse, 0, frame[2], opReady)
+		f = f[:headerSize]
+		f.setHeader(protoDirectionMask|srv.protocol, 0, stream, opReady)
+	case opOptions:
+		f = f[:headerSize]
+		f.setHeader(protoDirectionMask|srv.protocol, 0, stream, opSupported)
+		f.writeShort(0)
 	case opQuery:
 	case opQuery:
-		input := frame
-		input.skipHeader()
+		input := f
+		input.skipHeader(srv.protocol)
 		query := strings.TrimSpace(input.readLongString())
 		query := strings.TrimSpace(input.readLongString())
-		frame = frame[:headerSize]
-		frame.setHeader(protoResponse, 0, frame[2], opResult)
+		f = f[:headerSize]
+		f.setHeader(protoDirectionMask|srv.protocol, 0, stream, opResult)
 		first := query
 		first := query
 		if n := strings.Index(query, " "); n > 0 {
 		if n := strings.Index(query, " "); n > 0 {
 			first = first[:n]
 			first = first[:n]
 		}
 		}
 		switch strings.ToLower(first) {
 		switch strings.ToLower(first) {
 		case "kill":
 		case "kill":
-			atomic.AddUint64(&srv.nKillReq, 1)
-			select {}
+			atomic.AddInt64(&srv.nKillReq, 1)
+			f = f[:headerSize]
+			f.setHeader(protoDirectionMask|srv.protocol, 0, stream, opError)
+			f.writeInt(0x1001)
+			f.writeString("query killed")
 		case "slow":
 		case "slow":
 			go func() {
 			go func() {
 				<-time.After(1 * time.Second)
 				<-time.After(1 * time.Second)
-				frame.writeInt(resultKindVoid)
-				frame.setLength(len(frame) - headerSize)
-				if _, err := conn.Write(frame); err != nil {
+				f.writeInt(resultKindVoid)
+				f.setLength(len(f)-headerSize, srv.protocol)
+				if _, err := conn.Write(f); err != nil {
 					return
 					return
 				}
 				}
 			}()
 			}()
 			return
 			return
 		case "use":
 		case "use":
-			frame.writeInt(3)
-			frame.writeString(strings.TrimSpace(query[3:]))
+			f.writeInt(3)
+			f.writeString(strings.TrimSpace(query[3:]))
 		case "void":
 		case "void":
-			frame.writeInt(resultKindVoid)
+			f.writeInt(resultKindVoid)
 		default:
 		default:
-			frame.writeInt(resultKindVoid)
+			f.writeInt(resultKindVoid)
 		}
 		}
 	default:
 	default:
-		frame = frame[:headerSize]
-		frame.setHeader(protoResponse, 0, frame[2], opError)
-		frame.writeInt(0)
-		frame.writeString("not supported")
+		f = f[:headerSize]
+		f.setHeader(protoDirectionMask|srv.protocol, 0, stream, opError)
+		f.writeInt(0)
+		f.writeString("not supported")
 	}
 	}
-	frame.setLength(len(frame) - headerSize)
-	if _, err := conn.Write(frame); err != nil {
+
+	f.setLength(len(f)-headerSize, srv.protocol)
+	if _, err := conn.Write(f); err != nil {
+		srv.t.Log(err)
 		return
 		return
 	}
 	}
 }
 }
 
 
-func (srv *TestServer) readFrame(conn net.Conn) frame {
-	frame := make(frame, headerSize, headerSize+512)
+func (srv *TestServer) readFrame(conn net.Conn) (frame, error) {
+	frame := make(frame, srv.headerSize, srv.headerSize+512)
 	if _, err := io.ReadFull(conn, frame); err != nil {
 	if _, err := io.ReadFull(conn, frame); err != nil {
-		srv.t.Fatal(err)
+		return nil, err
+	}
+
+	// should be a request frame
+	if frame[0]&protoDirectionMask != 0 {
+		return nil, fmt.Errorf("expected to read a request frame got version: 0x%x", frame[0])
 	}
 	}
-	if n := frame.Length(); n > 0 {
+	if v := frame[0] & protoVersionMask; v != srv.protocol {
+		return nil, fmt.Errorf("expected to read protocol version 0x%x got 0x%x", srv.protocol, v)
+	}
+
+	if n := frame.Length(srv.protocol); n > 0 {
 		frame.grow(n)
 		frame.grow(n)
-		if _, err := io.ReadFull(conn, frame[headerSize:]); err != nil {
-			srv.t.Fatal(err)
+		if _, err := io.ReadFull(conn, frame[srv.headerSize:]); err != nil {
+			return nil, err
 		}
 		}
 	}
 	}
-	return frame
+
+	return frame, nil
 }
 }

+ 149 - 33
frame.go

@@ -5,12 +5,16 @@
 package gocql
 package gocql
 
 
 import (
 import (
+	"fmt"
 	"net"
 	"net"
 )
 )
 
 
 const (
 const (
-	protoRequest  byte = 0x02
-	protoResponse byte = 0x82
+	protoDirectionMask = 0x80
+	protoVersionMask   = 0x7F
+	protoVersion1      = 0x01
+	protoVersion2      = 0x02
+	protoVersion3      = 0x03
 
 
 	opError         byte = 0x00
 	opError         byte = 0x00
 	opStartup       byte = 0x01
 	opStartup       byte = 0x01
@@ -42,13 +46,26 @@ const (
 	flagPageState   uint8 = 8
 	flagPageState   uint8 = 8
 	flagHasMore     uint8 = 2
 	flagHasMore     uint8 = 2
 
 
-	headerSize = 8
-
 	apacheCassandraTypePrefix = "org.apache.cassandra.db.marshal."
 	apacheCassandraTypePrefix = "org.apache.cassandra.db.marshal."
 )
 )
 
 
+var headerProtoSize = [...]int{
+	protoVersion1: 8,
+	protoVersion2: 8,
+	protoVersion3: 9,
+}
+
+// TODO: replace with a struct which has a header and a body buffer,
+// header just has methods like, set/get the options in its backing array
+// then in a writeTo we write the header then the body.
 type frame []byte
 type frame []byte
 
 
+func newFrame(version uint8) frame {
+	// TODO: pool these at the session level incase anyone is using different
+	// clusters with different versions in the same application.
+	return make(frame, headerProtoSize[version], defaultFrameSize)
+}
+
 func (f *frame) writeInt(v int32) {
 func (f *frame) writeInt(v int32) {
 	p := f.grow(4)
 	p := f.grow(4)
 	(*f)[p] = byte(v >> 24)
 	(*f)[p] = byte(v >> 24)
@@ -129,22 +146,67 @@ func (f *frame) writeStringMultimap(v map[string][]string) {
 	}
 	}
 }
 }
 
 
-func (f *frame) setHeader(version, flags, stream, opcode uint8) {
+func (f *frame) setHeader(version, flags uint8, stream int, opcode uint8) {
 	(*f)[0] = version
 	(*f)[0] = version
 	(*f)[1] = flags
 	(*f)[1] = flags
-	(*f)[2] = stream
-	(*f)[3] = opcode
+	p := 2
+	if version&maskVersion > protoVersion2 {
+		(*f)[2] = byte(stream >> 8)
+		(*f)[3] = byte(stream)
+		p += 2
+	} else {
+		(*f)[2] = byte(stream & 0xFF)
+		p++
+	}
+
+	(*f)[p] = opcode
 }
 }
 
 
-func (f *frame) setLength(length int) {
-	(*f)[4] = byte(length >> 24)
-	(*f)[5] = byte(length >> 16)
-	(*f)[6] = byte(length >> 8)
-	(*f)[7] = byte(length)
+func (f *frame) setStream(stream int, version uint8) {
+	if version > protoVersion2 {
+		(*f)[2] = byte(stream >> 8)
+		(*f)[3] = byte(stream)
+	} else {
+		(*f)[2] = byte(stream)
+	}
 }
 }
 
 
-func (f *frame) Length() int {
-	return int((*f)[4])<<24 | int((*f)[5])<<16 | int((*f)[6])<<8 | int((*f)[7])
+func (f *frame) Stream(version uint8) (n int) {
+	if version > protoVersion2 {
+		n = int((*f)[2])<<8 | int((*f)[3])
+	} else {
+		n = int((*f)[2])
+	}
+	return
+}
+
+func (f *frame) setLength(length int, version uint8) {
+	p := 4
+	if version > protoVersion2 {
+		p = 5
+	}
+
+	(*f)[p] = byte(length >> 24)
+	(*f)[p+1] = byte(length >> 16)
+	(*f)[p+2] = byte(length >> 8)
+	(*f)[p+3] = byte(length)
+}
+
+func (f *frame) Op(version uint8) byte {
+	if version > protoVersion2 {
+		return (*f)[4]
+	} else {
+		return (*f)[3]
+	}
+}
+
+func (f *frame) Length(version uint8) int {
+	p := 4
+	if version > protoVersion2 {
+		p = 5
+	}
+
+	return int((*f)[p])<<24 | int((*f)[p+1])<<16 | int((*f)[p+2])<<8 | int((*f)[p+3])
 }
 }
 
 
 func (f *frame) grow(n int) int {
 func (f *frame) grow(n int) int {
@@ -158,13 +220,13 @@ func (f *frame) grow(n int) int {
 	return p
 	return p
 }
 }
 
 
-func (f *frame) skipHeader() {
-	*f = (*f)[headerSize:]
+func (f *frame) skipHeader(version uint8) {
+	*f = (*f)[headerProtoSize[version]:]
 }
 }
 
 
 func (f *frame) readInt() int {
 func (f *frame) readInt() int {
 	if len(*f) < 4 {
 	if len(*f) < 4 {
-		panic(NewErrProtocol("Trying to read an int while >4 bytes in the buffer"))
+		panic(NewErrProtocol("Trying to read an int while <4 bytes in the buffer"))
 	}
 	}
 	v := uint32((*f)[0])<<24 | uint32((*f)[1])<<16 | uint32((*f)[2])<<8 | uint32((*f)[3])
 	v := uint32((*f)[0])<<24 | uint32((*f)[1])<<16 | uint32((*f)[2])<<8 | uint32((*f)[3])
 	*f = (*f)[4:]
 	*f = (*f)[4:]
@@ -173,7 +235,7 @@ func (f *frame) readInt() int {
 
 
 func (f *frame) readShort() uint16 {
 func (f *frame) readShort() uint16 {
 	if len(*f) < 2 {
 	if len(*f) < 2 {
-		panic(NewErrProtocol("Trying to read a short while >2 bytes in the buffer"))
+		panic(NewErrProtocol("Trying to read a short while <2 bytes in the buffer"))
 	}
 	}
 	v := uint16((*f)[0])<<8 | uint16((*f)[1])
 	v := uint16((*f)[0])<<8 | uint16((*f)[1])
 	*f = (*f)[2:]
 	*f = (*f)[2:]
@@ -223,9 +285,12 @@ func (f *frame) readShortBytes() []byte {
 	return v
 	return v
 }
 }
 
 
-func (f *frame) readTypeInfo() *TypeInfo {
+func (f *frame) readTypeInfo(version uint8) *TypeInfo {
 	x := f.readShort()
 	x := f.readShort()
-	typ := &TypeInfo{Type: Type(x)}
+	typ := &TypeInfo{
+		Proto: version,
+		Type:  Type(x),
+	}
 	switch typ.Type {
 	switch typ.Type {
 	case TypeCustom:
 	case TypeCustom:
 		typ.Custom = f.readString()
 		typ.Custom = f.readString()
@@ -233,34 +298,37 @@ func (f *frame) readTypeInfo() *TypeInfo {
 			typ = &TypeInfo{Type: cassType}
 			typ = &TypeInfo{Type: cassType}
 			switch typ.Type {
 			switch typ.Type {
 			case TypeMap:
 			case TypeMap:
-				typ.Key = f.readTypeInfo()
+				typ.Key = f.readTypeInfo(version)
 				fallthrough
 				fallthrough
 			case TypeList, TypeSet:
 			case TypeList, TypeSet:
-				typ.Elem = f.readTypeInfo()
+				typ.Elem = f.readTypeInfo(version)
 			}
 			}
 		}
 		}
 	case TypeMap:
 	case TypeMap:
-		typ.Key = f.readTypeInfo()
+		typ.Key = f.readTypeInfo(version)
 		fallthrough
 		fallthrough
 	case TypeList, TypeSet:
 	case TypeList, TypeSet:
-		typ.Elem = f.readTypeInfo()
+		typ.Elem = f.readTypeInfo(version)
 	}
 	}
 	return typ
 	return typ
 }
 }
 
 
-func (f *frame) readMetaData() ([]ColumnInfo, []byte) {
+func (f *frame) readMetaData(version uint8) ([]ColumnInfo, []byte) {
 	flags := f.readInt()
 	flags := f.readInt()
 	numColumns := f.readInt()
 	numColumns := f.readInt()
+
 	var pageState []byte
 	var pageState []byte
 	if flags&2 != 0 {
 	if flags&2 != 0 {
 		pageState = f.readBytes()
 		pageState = f.readBytes()
 	}
 	}
+
 	globalKeyspace := ""
 	globalKeyspace := ""
 	globalTable := ""
 	globalTable := ""
 	if flags&1 != 0 {
 	if flags&1 != 0 {
 		globalKeyspace = f.readString()
 		globalKeyspace = f.readString()
 		globalTable = f.readString()
 		globalTable = f.readString()
 	}
 	}
+
 	columns := make([]ColumnInfo, numColumns)
 	columns := make([]ColumnInfo, numColumns)
 	for i := 0; i < numColumns; i++ {
 	for i := 0; i < numColumns; i++ {
 		columns[i].Keyspace = globalKeyspace
 		columns[i].Keyspace = globalKeyspace
@@ -270,7 +338,7 @@ func (f *frame) readMetaData() ([]ColumnInfo, []byte) {
 			columns[i].Table = f.readString()
 			columns[i].Table = f.readString()
 		}
 		}
 		columns[i].Name = f.readString()
 		columns[i].Name = f.readString()
-		columns[i].TypeInfo = f.readTypeInfo()
+		columns[i].TypeInfo = f.readTypeInfo(version)
 	}
 	}
 	return columns, pageState
 	return columns, pageState
 }
 }
@@ -381,19 +449,32 @@ type startupFrame struct {
 	Compression string
 	Compression string
 }
 }
 
 
+func (op *startupFrame) String() string {
+	return fmt.Sprintf("[startup cqlversion=%q compression=%q]", op.CQLVersion, op.Compression)
+}
+
 func (op *startupFrame) encodeFrame(version uint8, f frame) (frame, error) {
 func (op *startupFrame) encodeFrame(version uint8, f frame) (frame, error) {
 	if f == nil {
 	if f == nil {
-		f = make(frame, headerSize, defaultFrameSize)
+		f = newFrame(version)
 	}
 	}
+
 	f.setHeader(version, 0, 0, opStartup)
 	f.setHeader(version, 0, 0, opStartup)
-	f.writeShort(1)
+
+	// TODO: fix this, this is actually a StringMap
+	var size uint16 = 1
+	if op.Compression != "" {
+		size++
+	}
+
+	f.writeShort(size)
 	f.writeString("CQL_VERSION")
 	f.writeString("CQL_VERSION")
 	f.writeString(op.CQLVersion)
 	f.writeString(op.CQLVersion)
+
 	if op.Compression != "" {
 	if op.Compression != "" {
-		f[headerSize+1] += 1
 		f.writeString("COMPRESSION")
 		f.writeString("COMPRESSION")
 		f.writeString(op.Compression)
 		f.writeString(op.Compression)
 	}
 	}
+
 	return f, nil
 	return f, nil
 }
 }
 
 
@@ -406,14 +487,20 @@ type queryFrame struct {
 	PageState []byte
 	PageState []byte
 }
 }
 
 
+func (op *queryFrame) String() string {
+	return fmt.Sprintf("[query statement=%q prepared=%x cons=%v ...]", op.Stmt, op.Prepared, op.Cons)
+}
+
 func (op *queryFrame) encodeFrame(version uint8, f frame) (frame, error) {
 func (op *queryFrame) encodeFrame(version uint8, f frame) (frame, error) {
 	if version == 1 && (op.PageSize != 0 || len(op.PageState) > 0 ||
 	if version == 1 && (op.PageSize != 0 || len(op.PageState) > 0 ||
 		(len(op.Values) > 0 && len(op.Prepared) == 0)) {
 		(len(op.Values) > 0 && len(op.Prepared) == 0)) {
 		return nil, ErrUnsupported
 		return nil, ErrUnsupported
 	}
 	}
+
 	if f == nil {
 	if f == nil {
-		f = make(frame, headerSize, defaultFrameSize)
+		f = newFrame(version)
 	}
 	}
+
 	if len(op.Prepared) > 0 {
 	if len(op.Prepared) > 0 {
 		f.setHeader(version, 0, 0, opExecute)
 		f.setHeader(version, 0, 0, opExecute)
 		f.writeShortBytes(op.Prepared)
 		f.writeShortBytes(op.Prepared)
@@ -421,10 +508,12 @@ func (op *queryFrame) encodeFrame(version uint8, f frame) (frame, error) {
 		f.setHeader(version, 0, 0, opQuery)
 		f.setHeader(version, 0, 0, opQuery)
 		f.writeLongString(op.Stmt)
 		f.writeLongString(op.Stmt)
 	}
 	}
+
 	if version >= 2 {
 	if version >= 2 {
 		f.writeConsistency(op.Cons)
 		f.writeConsistency(op.Cons)
 		flagPos := len(f)
 		flagPos := len(f)
 		f.writeByte(0)
 		f.writeByte(0)
+
 		if len(op.Values) > 0 {
 		if len(op.Values) > 0 {
 			f[flagPos] |= flagQueryValues
 			f[flagPos] |= flagQueryValues
 			f.writeShort(uint16(len(op.Values)))
 			f.writeShort(uint16(len(op.Values)))
@@ -432,10 +521,12 @@ func (op *queryFrame) encodeFrame(version uint8, f frame) (frame, error) {
 				f.writeBytes(value)
 				f.writeBytes(value)
 			}
 			}
 		}
 		}
+
 		if op.PageSize > 0 {
 		if op.PageSize > 0 {
 			f[flagPos] |= flagPageSize
 			f[flagPos] |= flagPageSize
 			f.writeInt(int32(op.PageSize))
 			f.writeInt(int32(op.PageSize))
 		}
 		}
+
 		if len(op.PageState) > 0 {
 		if len(op.PageState) > 0 {
 			f[flagPos] |= flagPageState
 			f[flagPos] |= flagPageState
 			f.writeBytes(op.PageState)
 			f.writeBytes(op.PageState)
@@ -449,6 +540,7 @@ func (op *queryFrame) encodeFrame(version uint8, f frame) (frame, error) {
 		}
 		}
 		f.writeConsistency(op.Cons)
 		f.writeConsistency(op.Cons)
 	}
 	}
+
 	return f, nil
 	return f, nil
 }
 }
 
 
@@ -456,9 +548,13 @@ type prepareFrame struct {
 	Stmt string
 	Stmt string
 }
 }
 
 
+func (op *prepareFrame) String() string {
+	return fmt.Sprintf("[prepare statement=%q]", op.Stmt)
+}
+
 func (op *prepareFrame) encodeFrame(version uint8, f frame) (frame, error) {
 func (op *prepareFrame) encodeFrame(version uint8, f frame) (frame, error) {
 	if f == nil {
 	if f == nil {
-		f = make(frame, headerSize, defaultFrameSize)
+		f = newFrame(version)
 	}
 	}
 	f.setHeader(version, 0, 0, opPrepare)
 	f.setHeader(version, 0, 0, opPrepare)
 	f.writeLongString(op.Stmt)
 	f.writeLongString(op.Stmt)
@@ -467,9 +563,13 @@ func (op *prepareFrame) encodeFrame(version uint8, f frame) (frame, error) {
 
 
 type optionsFrame struct{}
 type optionsFrame struct{}
 
 
+func (op *optionsFrame) String() string {
+	return "[options]"
+}
+
 func (op *optionsFrame) encodeFrame(version uint8, f frame) (frame, error) {
 func (op *optionsFrame) encodeFrame(version uint8, f frame) (frame, error) {
 	if f == nil {
 	if f == nil {
-		f = make(frame, headerSize, defaultFrameSize)
+		f = newFrame(version)
 	}
 	}
 	f.setHeader(version, 0, 0, opOptions)
 	f.setHeader(version, 0, 0, opOptions)
 	return f, nil
 	return f, nil
@@ -479,13 +579,21 @@ type authenticateFrame struct {
 	Authenticator string
 	Authenticator string
 }
 }
 
 
+func (op *authenticateFrame) String() string {
+	return fmt.Sprintf("[authenticate authenticator=%q]", op.Authenticator)
+}
+
 type authResponseFrame struct {
 type authResponseFrame struct {
 	Data []byte
 	Data []byte
 }
 }
 
 
+func (op *authResponseFrame) String() string {
+	return fmt.Sprintf("[auth_response data=%q]", op.Data)
+}
+
 func (op *authResponseFrame) encodeFrame(version uint8, f frame) (frame, error) {
 func (op *authResponseFrame) encodeFrame(version uint8, f frame) (frame, error) {
 	if f == nil {
 	if f == nil {
-		f = make(frame, headerSize, defaultFrameSize)
+		f = newFrame(version)
 	}
 	}
 	f.setHeader(version, 0, 0, opAuthResponse)
 	f.setHeader(version, 0, 0, opAuthResponse)
 	f.writeBytes(op.Data)
 	f.writeBytes(op.Data)
@@ -496,6 +604,14 @@ type authSuccessFrame struct {
 	Data []byte
 	Data []byte
 }
 }
 
 
+func (op *authSuccessFrame) String() string {
+	return fmt.Sprintf("[auth_success data=%q]", op.Data)
+}
+
 type authChallengeFrame struct {
 type authChallengeFrame struct {
 	Data []byte
 	Data []byte
 }
 }
+
+func (op *authChallengeFrame) String() string {
+	return fmt.Sprintf("[auth_challenge data=%q]", op.Data)
+}

+ 2 - 0
integration.sh

@@ -19,6 +19,8 @@ function run_tests() {
 	local proto=2
 	local proto=2
 	if [[ $version == 1.2.* ]]; then
 	if [[ $version == 1.2.* ]]; then
 		proto=1
 		proto=1
+	elif [[ $version == 2.1.* ]]; then
+		proto=3
 	fi
 	fi
 
 
 	go test -timeout 5m -tags integration -cover -v -runssl -proto=$proto -rf=3 -cluster=$(ccm liveset) -clusterSize=$clusterSize -autowait=2000ms ./... | tee results.txt
 	go test -timeout 5m -tags integration -cover -v -runssl -proto=$proto -rf=3 -cluster=$(ccm liveset) -clusterSize=$clusterSize -autowait=2000ms ./... | tee results.txt

+ 61 - 30
marshal.go

@@ -40,6 +40,9 @@ func Marshal(info *TypeInfo, value interface{}) ([]byte, error) {
 	if value == nil {
 	if value == nil {
 		return nil, nil
 		return nil, nil
 	}
 	}
+	if info.Proto < protoVersion1 {
+		panic("protocol version not set")
+	}
 
 
 	if v, ok := value.(Marshaler); ok {
 	if v, ok := value.(Marshaler); ok {
 		return v.MarshalCQL(info)
 		return v.MarshalCQL(info)
@@ -814,6 +817,28 @@ func unmarshalTimestamp(info *TypeInfo, data []byte, value interface{}) error {
 	return unmarshalErrorf("can not unmarshal %s into %T", info, value)
 	return unmarshalErrorf("can not unmarshal %s into %T", info, value)
 }
 }
 
 
+func writeCollectionSize(info *TypeInfo, n int, buf *bytes.Buffer) error {
+	if info.Proto > protoVersion2 {
+		if n > math.MaxInt32 {
+			return marshalErrorf("marshal: collection too large")
+		}
+
+		buf.WriteByte(byte(n >> 24))
+		buf.WriteByte(byte(n >> 16))
+		buf.WriteByte(byte(n >> 8))
+		buf.WriteByte(byte(n))
+	} else {
+		if n > math.MaxUint16 {
+			return marshalErrorf("marshal: collection too large")
+		}
+
+		buf.WriteByte(byte(n >> 8))
+		buf.WriteByte(byte(n))
+	}
+
+	return nil
+}
+
 func marshalList(info *TypeInfo, value interface{}) ([]byte, error) {
 func marshalList(info *TypeInfo, value interface{}) ([]byte, error) {
 	rv := reflect.ValueOf(value)
 	rv := reflect.ValueOf(value)
 	t := rv.Type()
 	t := rv.Type()
@@ -825,21 +850,19 @@ func marshalList(info *TypeInfo, value interface{}) ([]byte, error) {
 		}
 		}
 		buf := &bytes.Buffer{}
 		buf := &bytes.Buffer{}
 		n := rv.Len()
 		n := rv.Len()
-		if n > math.MaxUint16 {
-			return nil, marshalErrorf("marshal: slice / array too large")
+
+		if err := writeCollectionSize(info, n, buf); err != nil {
+			return nil, err
 		}
 		}
-		buf.WriteByte(byte(n >> 8))
-		buf.WriteByte(byte(n))
+
 		for i := 0; i < n; i++ {
 		for i := 0; i < n; i++ {
 			item, err := Marshal(info.Elem, rv.Index(i).Interface())
 			item, err := Marshal(info.Elem, rv.Index(i).Interface())
 			if err != nil {
 			if err != nil {
 				return nil, err
 				return nil, err
 			}
 			}
-			if len(item) > math.MaxUint16 {
-				return nil, marshalErrorf("marshal: slice / array item too large")
+			if err := writeCollectionSize(info, len(item), buf); err != nil {
+				return nil, err
 			}
 			}
-			buf.WriteByte(byte(len(item) >> 8))
-			buf.WriteByte(byte(len(item)))
 			buf.Write(item)
 			buf.Write(item)
 		}
 		}
 		return buf.Bytes(), nil
 		return buf.Bytes(), nil
@@ -858,6 +881,17 @@ func marshalList(info *TypeInfo, value interface{}) ([]byte, error) {
 	return nil, marshalErrorf("can not marshal %T into %s", value, info)
 	return nil, marshalErrorf("can not marshal %T into %s", value, info)
 }
 }
 
 
+func readCollectionSize(info *TypeInfo, data []byte) (size, read int) {
+	if info.Proto > protoVersion2 {
+		size = int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3])
+		read = 4
+	} else {
+		size = int(data[0])<<8 | int(data[1])
+		read = 2
+	}
+	return
+}
+
 func unmarshalList(info *TypeInfo, data []byte, value interface{}) error {
 func unmarshalList(info *TypeInfo, data []byte, value interface{}) error {
 	rv := reflect.ValueOf(value)
 	rv := reflect.ValueOf(value)
 	if rv.Kind() != reflect.Ptr {
 	if rv.Kind() != reflect.Ptr {
@@ -879,8 +913,8 @@ func unmarshalList(info *TypeInfo, data []byte, value interface{}) error {
 		if len(data) < 2 {
 		if len(data) < 2 {
 			return unmarshalErrorf("unmarshal list: unexpected eof")
 			return unmarshalErrorf("unmarshal list: unexpected eof")
 		}
 		}
-		n := int(data[0])<<8 | int(data[1])
-		data = data[2:]
+		n, p := readCollectionSize(info, data)
+		data = data[p:]
 		if k == reflect.Array {
 		if k == reflect.Array {
 			if rv.Len() != n {
 			if rv.Len() != n {
 				return unmarshalErrorf("unmarshal list: array with wrong size")
 				return unmarshalErrorf("unmarshal list: array with wrong size")
@@ -894,8 +928,8 @@ func unmarshalList(info *TypeInfo, data []byte, value interface{}) error {
 			if len(data) < 2 {
 			if len(data) < 2 {
 				return unmarshalErrorf("unmarshal list: unexpected eof")
 				return unmarshalErrorf("unmarshal list: unexpected eof")
 			}
 			}
-			m := int(data[0])<<8 | int(data[1])
-			data = data[2:]
+			m, p := readCollectionSize(info, data)
+			data = data[p:]
 			if err := Unmarshal(info.Elem, data[:m], rv.Index(i).Addr().Interface()); err != nil {
 			if err := Unmarshal(info.Elem, data[:m], rv.Index(i).Addr().Interface()); err != nil {
 				return err
 				return err
 			}
 			}
@@ -917,33 +951,29 @@ func marshalMap(info *TypeInfo, value interface{}) ([]byte, error) {
 	}
 	}
 	buf := &bytes.Buffer{}
 	buf := &bytes.Buffer{}
 	n := rv.Len()
 	n := rv.Len()
-	if n > math.MaxUint16 {
-		return nil, marshalErrorf("marshal: map too large")
+
+	if err := writeCollectionSize(info, n, buf); err != nil {
+		return nil, err
 	}
 	}
-	buf.WriteByte(byte(n >> 8))
-	buf.WriteByte(byte(n))
+
 	keys := rv.MapKeys()
 	keys := rv.MapKeys()
 	for _, key := range keys {
 	for _, key := range keys {
 		item, err := Marshal(info.Key, key.Interface())
 		item, err := Marshal(info.Key, key.Interface())
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
-		if len(item) > math.MaxUint16 {
-			return nil, marshalErrorf("marshal: slice / array item too large")
+		if err := writeCollectionSize(info, len(item), buf); err != nil {
+			return nil, err
 		}
 		}
-		buf.WriteByte(byte(len(item) >> 8))
-		buf.WriteByte(byte(len(item)))
 		buf.Write(item)
 		buf.Write(item)
 
 
 		item, err = Marshal(info.Elem, rv.MapIndex(key).Interface())
 		item, err = Marshal(info.Elem, rv.MapIndex(key).Interface())
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
-		if len(item) > math.MaxUint16 {
-			return nil, marshalErrorf("marshal: slice / array item too large")
+		if err := writeCollectionSize(info, len(item), buf); err != nil {
+			return nil, err
 		}
 		}
-		buf.WriteByte(byte(len(item) >> 8))
-		buf.WriteByte(byte(len(item)))
 		buf.Write(item)
 		buf.Write(item)
 	}
 	}
 	return buf.Bytes(), nil
 	return buf.Bytes(), nil
@@ -967,22 +997,22 @@ func unmarshalMap(info *TypeInfo, data []byte, value interface{}) error {
 	if len(data) < 2 {
 	if len(data) < 2 {
 		return unmarshalErrorf("unmarshal map: unexpected eof")
 		return unmarshalErrorf("unmarshal map: unexpected eof")
 	}
 	}
-	n := int(data[1]) | int(data[0])<<8
-	data = data[2:]
+	n, p := readCollectionSize(info, data)
+	data = data[p:]
 	for i := 0; i < n; i++ {
 	for i := 0; i < n; i++ {
 		if len(data) < 2 {
 		if len(data) < 2 {
 			return unmarshalErrorf("unmarshal list: unexpected eof")
 			return unmarshalErrorf("unmarshal list: unexpected eof")
 		}
 		}
-		m := int(data[1]) | int(data[0])<<8
-		data = data[2:]
+		m, p := readCollectionSize(info, data)
+		data = data[p:]
 		key := reflect.New(t.Key())
 		key := reflect.New(t.Key())
 		if err := Unmarshal(info.Key, data[:m], key.Interface()); err != nil {
 		if err := Unmarshal(info.Key, data[:m], key.Interface()); err != nil {
 			return err
 			return err
 		}
 		}
 		data = data[m:]
 		data = data[m:]
 
 
-		m = int(data[1]) | int(data[0])<<8
-		data = data[2:]
+		m, p = readCollectionSize(info, data)
+		data = data[p:]
 		val := reflect.New(t.Elem())
 		val := reflect.New(t.Elem())
 		if err := Unmarshal(info.Elem, data[:m], val.Interface()); err != nil {
 		if err := Unmarshal(info.Elem, data[:m], val.Interface()); err != nil {
 			return err
 			return err
@@ -1120,6 +1150,7 @@ func unmarshalInet(info *TypeInfo, data []byte, value interface{}) error {
 
 
 // TypeInfo describes a Cassandra specific data type.
 // TypeInfo describes a Cassandra specific data type.
 type TypeInfo struct {
 type TypeInfo struct {
+	Proto  byte // version of the protocol
 	Type   Type
 	Type   Type
 	Key    *TypeInfo // only used for TypeMap
 	Key    *TypeInfo // only used for TypeMap
 	Elem   *TypeInfo // only used for TypeMap, TypeList and TypeSet
 	Elem   *TypeInfo // only used for TypeMap, TypeList and TypeSet

+ 97 - 97
marshal_test.go

@@ -21,42 +21,42 @@ var marshalTests = []struct {
 	Value interface{}
 	Value interface{}
 }{
 }{
 	{
 	{
-		&TypeInfo{Type: TypeVarchar},
+		&TypeInfo{Proto: 2, Type: TypeVarchar},
 		[]byte("hello world"),
 		[]byte("hello world"),
 		[]byte("hello world"),
 		[]byte("hello world"),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeVarchar},
+		&TypeInfo{Proto: 2, Type: TypeVarchar},
 		[]byte("hello world"),
 		[]byte("hello world"),
 		"hello world",
 		"hello world",
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeVarchar},
+		&TypeInfo{Proto: 2, Type: TypeVarchar},
 		[]byte(nil),
 		[]byte(nil),
 		[]byte(nil),
 		[]byte(nil),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeVarchar},
+		&TypeInfo{Proto: 2, Type: TypeVarchar},
 		[]byte("hello world"),
 		[]byte("hello world"),
 		MyString("hello world"),
 		MyString("hello world"),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeVarchar},
+		&TypeInfo{Proto: 2, Type: TypeVarchar},
 		[]byte("HELLO WORLD"),
 		[]byte("HELLO WORLD"),
 		CustomString("hello world"),
 		CustomString("hello world"),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeBlob},
+		&TypeInfo{Proto: 2, Type: TypeBlob},
 		[]byte("hello\x00"),
 		[]byte("hello\x00"),
 		[]byte("hello\x00"),
 		[]byte("hello\x00"),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeBlob},
+		&TypeInfo{Proto: 2, Type: TypeBlob},
 		[]byte(nil),
 		[]byte(nil),
 		[]byte(nil),
 		[]byte(nil),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeTimeUUID},
+		&TypeInfo{Proto: 2, Type: TypeTimeUUID},
 		[]byte{0x3d, 0xcd, 0x98, 0x0, 0xf3, 0xd9, 0x11, 0xbf, 0x86, 0xd4, 0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0},
 		[]byte{0x3d, 0xcd, 0x98, 0x0, 0xf3, 0xd9, 0x11, 0xbf, 0x86, 0xd4, 0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0},
 		func() UUID {
 		func() UUID {
 			x, _ := UUIDFromBytes([]byte{0x3d, 0xcd, 0x98, 0x0, 0xf3, 0xd9, 0x11, 0xbf, 0x86, 0xd4, 0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0})
 			x, _ := UUIDFromBytes([]byte{0x3d, 0xcd, 0x98, 0x0, 0xf3, 0xd9, 0x11, 0xbf, 0x86, 0xd4, 0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0})
@@ -64,217 +64,217 @@ var marshalTests = []struct {
 		}(),
 		}(),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeInt},
 		[]byte("\x00\x00\x00\x00"),
 		[]byte("\x00\x00\x00\x00"),
 		0,
 		0,
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeInt},
 		[]byte("\x01\x02\x03\x04"),
 		[]byte("\x01\x02\x03\x04"),
 		int(16909060),
 		int(16909060),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeInt},
 		[]byte("\x80\x00\x00\x00"),
 		[]byte("\x80\x00\x00\x00"),
 		int32(math.MinInt32),
 		int32(math.MinInt32),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeInt},
 		[]byte("\x7f\xff\xff\xff"),
 		[]byte("\x7f\xff\xff\xff"),
 		int32(math.MaxInt32),
 		int32(math.MaxInt32),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeInt},
 		[]byte("\x00\x00\x00\x00"),
 		[]byte("\x00\x00\x00\x00"),
 		"0",
 		"0",
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeInt},
 		[]byte("\x01\x02\x03\x04"),
 		[]byte("\x01\x02\x03\x04"),
 		"16909060",
 		"16909060",
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeInt},
 		[]byte("\x80\x00\x00\x00"),
 		[]byte("\x80\x00\x00\x00"),
 		"-2147483648", // math.MinInt32
 		"-2147483648", // math.MinInt32
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeInt},
 		[]byte("\x7f\xff\xff\xff"),
 		[]byte("\x7f\xff\xff\xff"),
 		"2147483647", // math.MaxInt32
 		"2147483647", // math.MaxInt32
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeBigInt},
+		&TypeInfo{Proto: 2, Type: TypeBigInt},
 		[]byte("\x00\x00\x00\x00\x00\x00\x00\x00"),
 		[]byte("\x00\x00\x00\x00\x00\x00\x00\x00"),
 		0,
 		0,
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeBigInt},
+		&TypeInfo{Proto: 2, Type: TypeBigInt},
 		[]byte("\x01\x02\x03\x04\x05\x06\x07\x08"),
 		[]byte("\x01\x02\x03\x04\x05\x06\x07\x08"),
 		72623859790382856,
 		72623859790382856,
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeBigInt},
+		&TypeInfo{Proto: 2, Type: TypeBigInt},
 		[]byte("\x80\x00\x00\x00\x00\x00\x00\x00"),
 		[]byte("\x80\x00\x00\x00\x00\x00\x00\x00"),
 		int64(math.MinInt64),
 		int64(math.MinInt64),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeBigInt},
+		&TypeInfo{Proto: 2, Type: TypeBigInt},
 		[]byte("\x7f\xff\xff\xff\xff\xff\xff\xff"),
 		[]byte("\x7f\xff\xff\xff\xff\xff\xff\xff"),
 		int64(math.MaxInt64),
 		int64(math.MaxInt64),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeBigInt},
+		&TypeInfo{Proto: 2, Type: TypeBigInt},
 		[]byte("\x00\x00\x00\x00\x00\x00\x00\x00"),
 		[]byte("\x00\x00\x00\x00\x00\x00\x00\x00"),
 		"0",
 		"0",
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeBigInt},
+		&TypeInfo{Proto: 2, Type: TypeBigInt},
 		[]byte("\x01\x02\x03\x04\x05\x06\x07\x08"),
 		[]byte("\x01\x02\x03\x04\x05\x06\x07\x08"),
 		"72623859790382856",
 		"72623859790382856",
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeBigInt},
+		&TypeInfo{Proto: 2, Type: TypeBigInt},
 		[]byte("\x80\x00\x00\x00\x00\x00\x00\x00"),
 		[]byte("\x80\x00\x00\x00\x00\x00\x00\x00"),
 		"-9223372036854775808", // math.MinInt64
 		"-9223372036854775808", // math.MinInt64
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeBigInt},
+		&TypeInfo{Proto: 2, Type: TypeBigInt},
 		[]byte("\x7f\xff\xff\xff\xff\xff\xff\xff"),
 		[]byte("\x7f\xff\xff\xff\xff\xff\xff\xff"),
 		"9223372036854775807", // math.MaxInt64
 		"9223372036854775807", // math.MaxInt64
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeBoolean},
+		&TypeInfo{Proto: 2, Type: TypeBoolean},
 		[]byte("\x00"),
 		[]byte("\x00"),
 		false,
 		false,
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeBoolean},
+		&TypeInfo{Proto: 2, Type: TypeBoolean},
 		[]byte("\x01"),
 		[]byte("\x01"),
 		true,
 		true,
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeFloat},
+		&TypeInfo{Proto: 2, Type: TypeFloat},
 		[]byte("\x40\x49\x0f\xdb"),
 		[]byte("\x40\x49\x0f\xdb"),
 		float32(3.14159265),
 		float32(3.14159265),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeDouble},
+		&TypeInfo{Proto: 2, Type: TypeDouble},
 		[]byte("\x40\x09\x21\xfb\x53\xc8\xd4\xf1"),
 		[]byte("\x40\x09\x21\xfb\x53\xc8\xd4\xf1"),
 		float64(3.14159265),
 		float64(3.14159265),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeDecimal},
+		&TypeInfo{Proto: 2, Type: TypeDecimal},
 		[]byte("\x00\x00\x00\x00\x00"),
 		[]byte("\x00\x00\x00\x00\x00"),
 		inf.NewDec(0, 0),
 		inf.NewDec(0, 0),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeDecimal},
+		&TypeInfo{Proto: 2, Type: TypeDecimal},
 		[]byte("\x00\x00\x00\x00\x64"),
 		[]byte("\x00\x00\x00\x00\x64"),
 		inf.NewDec(100, 0),
 		inf.NewDec(100, 0),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeDecimal},
+		&TypeInfo{Proto: 2, Type: TypeDecimal},
 		[]byte("\x00\x00\x00\x02\x19"),
 		[]byte("\x00\x00\x00\x02\x19"),
 		decimalize("0.25"),
 		decimalize("0.25"),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeDecimal},
+		&TypeInfo{Proto: 2, Type: TypeDecimal},
 		[]byte("\x00\x00\x00\x13\xD5\a;\x20\x14\xA2\x91"),
 		[]byte("\x00\x00\x00\x13\xD5\a;\x20\x14\xA2\x91"),
 		decimalize("-0.0012095473475870063"), // From the iconara/cql-rb test suite
 		decimalize("-0.0012095473475870063"), // From the iconara/cql-rb test suite
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeDecimal},
+		&TypeInfo{Proto: 2, Type: TypeDecimal},
 		[]byte("\x00\x00\x00\x13*\xF8\xC4\xDF\xEB]o"),
 		[]byte("\x00\x00\x00\x13*\xF8\xC4\xDF\xEB]o"),
 		decimalize("0.0012095473475870063"), // From the iconara/cql-rb test suite
 		decimalize("0.0012095473475870063"), // From the iconara/cql-rb test suite
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeDecimal},
+		&TypeInfo{Proto: 2, Type: TypeDecimal},
 		[]byte("\x00\x00\x00\x12\xF2\xD8\x02\xB6R\x7F\x99\xEE\x98#\x99\xA9V"),
 		[]byte("\x00\x00\x00\x12\xF2\xD8\x02\xB6R\x7F\x99\xEE\x98#\x99\xA9V"),
 		decimalize("-1042342234234.123423435647768234"), // From the iconara/cql-rb test suite
 		decimalize("-1042342234234.123423435647768234"), // From the iconara/cql-rb test suite
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeDecimal},
+		&TypeInfo{Proto: 2, Type: TypeDecimal},
 		[]byte("\x00\x00\x00\r\nJ\x04\"^\x91\x04\x8a\xb1\x18\xfe"),
 		[]byte("\x00\x00\x00\r\nJ\x04\"^\x91\x04\x8a\xb1\x18\xfe"),
 		decimalize("1243878957943.1234124191998"), // From the datastax/python-driver test suite
 		decimalize("1243878957943.1234124191998"), // From the datastax/python-driver test suite
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeDecimal},
+		&TypeInfo{Proto: 2, Type: TypeDecimal},
 		[]byte("\x00\x00\x00\x06\xe5\xde]\x98Y"),
 		[]byte("\x00\x00\x00\x06\xe5\xde]\x98Y"),
 		decimalize("-112233.441191"), // From the datastax/python-driver test suite
 		decimalize("-112233.441191"), // From the datastax/python-driver test suite
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeDecimal},
+		&TypeInfo{Proto: 2, Type: TypeDecimal},
 		[]byte("\x00\x00\x00\x14\x00\xfa\xce"),
 		[]byte("\x00\x00\x00\x14\x00\xfa\xce"),
 		decimalize("0.00000000000000064206"), // From the datastax/python-driver test suite
 		decimalize("0.00000000000000064206"), // From the datastax/python-driver test suite
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeDecimal},
+		&TypeInfo{Proto: 2, Type: TypeDecimal},
 		[]byte("\x00\x00\x00\x14\xff\x052"),
 		[]byte("\x00\x00\x00\x14\xff\x052"),
 		decimalize("-0.00000000000000064206"), // From the datastax/python-driver test suite
 		decimalize("-0.00000000000000064206"), // From the datastax/python-driver test suite
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeDecimal},
+		&TypeInfo{Proto: 2, Type: TypeDecimal},
 		[]byte("\xff\xff\xff\x9c\x00\xfa\xce"),
 		[]byte("\xff\xff\xff\x9c\x00\xfa\xce"),
 		inf.NewDec(64206, -100), // From the datastax/python-driver test suite
 		inf.NewDec(64206, -100), // From the datastax/python-driver test suite
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeTimestamp},
+		&TypeInfo{Proto: 2, Type: TypeTimestamp},
 		[]byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"),
 		[]byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"),
 		time.Date(2013, time.August, 13, 9, 52, 3, 0, time.UTC),
 		time.Date(2013, time.August, 13, 9, 52, 3, 0, time.UTC),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeTimestamp},
+		&TypeInfo{Proto: 2, Type: TypeTimestamp},
 		[]byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"),
 		[]byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"),
 		int64(1376387523000),
 		int64(1376387523000),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeList, Elem: &TypeInfo{Type: TypeInt}},
+		&TypeInfo{Proto: 2, Type: TypeList, Elem: &TypeInfo{Proto: 2, Type: TypeInt}},
 		[]byte("\x00\x02\x00\x04\x00\x00\x00\x01\x00\x04\x00\x00\x00\x02"),
 		[]byte("\x00\x02\x00\x04\x00\x00\x00\x01\x00\x04\x00\x00\x00\x02"),
 		[]int{1, 2},
 		[]int{1, 2},
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeList, Elem: &TypeInfo{Type: TypeInt}},
+		&TypeInfo{Proto: 2, Type: TypeList, Elem: &TypeInfo{Proto: 2, Type: TypeInt}},
 		[]byte("\x00\x02\x00\x04\x00\x00\x00\x01\x00\x04\x00\x00\x00\x02"),
 		[]byte("\x00\x02\x00\x04\x00\x00\x00\x01\x00\x04\x00\x00\x00\x02"),
 		[2]int{1, 2},
 		[2]int{1, 2},
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeSet, Elem: &TypeInfo{Type: TypeInt}},
+		&TypeInfo{Proto: 2, Type: TypeSet, Elem: &TypeInfo{Proto: 2, Type: TypeInt}},
 		[]byte("\x00\x02\x00\x04\x00\x00\x00\x01\x00\x04\x00\x00\x00\x02"),
 		[]byte("\x00\x02\x00\x04\x00\x00\x00\x01\x00\x04\x00\x00\x00\x02"),
 		[]int{1, 2},
 		[]int{1, 2},
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeSet, Elem: &TypeInfo{Type: TypeInt}},
+		&TypeInfo{Proto: 2, Type: TypeSet, Elem: &TypeInfo{Proto: 2, Type: TypeInt}},
 		[]byte(nil),
 		[]byte(nil),
 		[]int(nil),
 		[]int(nil),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeMap,
-			Key:  &TypeInfo{Type: TypeVarchar},
-			Elem: &TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeMap,
+			Key:  &TypeInfo{Proto: 2, Type: TypeVarchar},
+			Elem: &TypeInfo{Proto: 2, Type: TypeInt},
 		},
 		},
 		[]byte("\x00\x01\x00\x03foo\x00\x04\x00\x00\x00\x01"),
 		[]byte("\x00\x01\x00\x03foo\x00\x04\x00\x00\x00\x01"),
 		map[string]int{"foo": 1},
 		map[string]int{"foo": 1},
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeMap,
-			Key:  &TypeInfo{Type: TypeVarchar},
-			Elem: &TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeMap,
+			Key:  &TypeInfo{Proto: 2, Type: TypeVarchar},
+			Elem: &TypeInfo{Proto: 2, Type: TypeInt},
 		},
 		},
 		[]byte(nil),
 		[]byte(nil),
 		map[string]int(nil),
 		map[string]int(nil),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeList, Elem: &TypeInfo{Type: TypeVarchar}},
+		&TypeInfo{Proto: 2, Type: TypeList, Elem: &TypeInfo{Proto: 2, Type: TypeVarchar}},
 		bytes.Join([][]byte{
 		bytes.Join([][]byte{
 			[]byte("\x00\x01\xFF\xFF"),
 			[]byte("\x00\x01\xFF\xFF"),
 			bytes.Repeat([]byte("X"), 65535)}, []byte("")),
 			bytes.Repeat([]byte("X"), 65535)}, []byte("")),
 		[]string{strings.Repeat("X", 65535)},
 		[]string{strings.Repeat("X", 65535)},
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeMap,
-			Key:  &TypeInfo{Type: TypeVarchar},
-			Elem: &TypeInfo{Type: TypeVarchar},
+		&TypeInfo{Proto: 2, Type: TypeMap,
+			Key:  &TypeInfo{Proto: 2, Type: TypeVarchar},
+			Elem: &TypeInfo{Proto: 2, Type: TypeVarchar},
 		},
 		},
 		bytes.Join([][]byte{
 		bytes.Join([][]byte{
 			[]byte("\x00\x01\xFF\xFF"),
 			[]byte("\x00\x01\xFF\xFF"),
@@ -286,82 +286,82 @@ var marshalTests = []struct {
 		},
 		},
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeVarint},
+		&TypeInfo{Proto: 2, Type: TypeVarint},
 		[]byte("\x00"),
 		[]byte("\x00"),
 		0,
 		0,
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeVarint},
+		&TypeInfo{Proto: 2, Type: TypeVarint},
 		[]byte("\x37\xE2\x3C\xEC"),
 		[]byte("\x37\xE2\x3C\xEC"),
 		int32(937573612),
 		int32(937573612),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeVarint},
+		&TypeInfo{Proto: 2, Type: TypeVarint},
 		[]byte("\x37\xE2\x3C\xEC"),
 		[]byte("\x37\xE2\x3C\xEC"),
 		big.NewInt(937573612),
 		big.NewInt(937573612),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeVarint},
+		&TypeInfo{Proto: 2, Type: TypeVarint},
 		[]byte("\x03\x9EV \x15\f\x03\x9DK\x18\xCDI\\$?\a["),
 		[]byte("\x03\x9EV \x15\f\x03\x9DK\x18\xCDI\\$?\a["),
 		bigintize("1231312312331283012830129382342342412123"), // From the iconara/cql-rb test suite
 		bigintize("1231312312331283012830129382342342412123"), // From the iconara/cql-rb test suite
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeVarint},
+		&TypeInfo{Proto: 2, Type: TypeVarint},
 		[]byte("\xC9v\x8D:\x86"),
 		[]byte("\xC9v\x8D:\x86"),
 		big.NewInt(-234234234234), // From the iconara/cql-rb test suite
 		big.NewInt(-234234234234), // From the iconara/cql-rb test suite
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeVarint},
+		&TypeInfo{Proto: 2, Type: TypeVarint},
 		[]byte("f\x1e\xfd\xf2\xe3\xb1\x9f|\x04_\x15"),
 		[]byte("f\x1e\xfd\xf2\xe3\xb1\x9f|\x04_\x15"),
 		bigintize("123456789123456789123456789"), // From the datastax/python-driver test suite
 		bigintize("123456789123456789123456789"), // From the datastax/python-driver test suite
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeInet},
+		&TypeInfo{Proto: 2, Type: TypeInet},
 		[]byte("\x7F\x00\x00\x01"),
 		[]byte("\x7F\x00\x00\x01"),
 		net.ParseIP("127.0.0.1").To4(),
 		net.ParseIP("127.0.0.1").To4(),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeInet},
+		&TypeInfo{Proto: 2, Type: TypeInet},
 		[]byte("\xFF\xFF\xFF\xFF"),
 		[]byte("\xFF\xFF\xFF\xFF"),
 		net.ParseIP("255.255.255.255").To4(),
 		net.ParseIP("255.255.255.255").To4(),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeInet},
+		&TypeInfo{Proto: 2, Type: TypeInet},
 		[]byte("\x7F\x00\x00\x01"),
 		[]byte("\x7F\x00\x00\x01"),
 		"127.0.0.1",
 		"127.0.0.1",
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeInet},
+		&TypeInfo{Proto: 2, Type: TypeInet},
 		[]byte("\xFF\xFF\xFF\xFF"),
 		[]byte("\xFF\xFF\xFF\xFF"),
 		"255.255.255.255",
 		"255.255.255.255",
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeInet},
+		&TypeInfo{Proto: 2, Type: TypeInet},
 		[]byte("\x21\xDA\x00\xd3\x00\x00\x2f\x3b\x02\xaa\x00\xff\xfe\x28\x9c\x5a"),
 		[]byte("\x21\xDA\x00\xd3\x00\x00\x2f\x3b\x02\xaa\x00\xff\xfe\x28\x9c\x5a"),
 		"21da:d3:0:2f3b:2aa:ff:fe28:9c5a",
 		"21da:d3:0:2f3b:2aa:ff:fe28:9c5a",
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeInet},
+		&TypeInfo{Proto: 2, Type: TypeInet},
 		[]byte("\xfe\x80\x00\x00\x00\x00\x00\x00\x02\x02\xb3\xff\xfe\x1e\x83\x29"),
 		[]byte("\xfe\x80\x00\x00\x00\x00\x00\x00\x02\x02\xb3\xff\xfe\x1e\x83\x29"),
 		"fe80::202:b3ff:fe1e:8329",
 		"fe80::202:b3ff:fe1e:8329",
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeInet},
+		&TypeInfo{Proto: 2, Type: TypeInet},
 		[]byte("\x21\xDA\x00\xd3\x00\x00\x2f\x3b\x02\xaa\x00\xff\xfe\x28\x9c\x5a"),
 		[]byte("\x21\xDA\x00\xd3\x00\x00\x2f\x3b\x02\xaa\x00\xff\xfe\x28\x9c\x5a"),
 		net.ParseIP("21da:d3:0:2f3b:2aa:ff:fe28:9c5a"),
 		net.ParseIP("21da:d3:0:2f3b:2aa:ff:fe28:9c5a"),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeInet},
+		&TypeInfo{Proto: 2, Type: TypeInet},
 		[]byte("\xfe\x80\x00\x00\x00\x00\x00\x00\x02\x02\xb3\xff\xfe\x1e\x83\x29"),
 		[]byte("\xfe\x80\x00\x00\x00\x00\x00\x00\x02\x02\xb3\xff\xfe\x1e\x83\x29"),
 		net.ParseIP("fe80::202:b3ff:fe1e:8329"),
 		net.ParseIP("fe80::202:b3ff:fe1e:8329"),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeInt},
 		[]byte(nil),
 		[]byte(nil),
 		nil,
 		nil,
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeVarchar},
+		&TypeInfo{Proto: 2, Type: TypeVarchar},
 		[]byte("nullable string"),
 		[]byte("nullable string"),
 		func() *string {
 		func() *string {
 			value := "nullable string"
 			value := "nullable string"
@@ -369,12 +369,12 @@ var marshalTests = []struct {
 		}(),
 		}(),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeVarchar},
+		&TypeInfo{Proto: 2, Type: TypeVarchar},
 		[]byte{},
 		[]byte{},
 		(*string)(nil),
 		(*string)(nil),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeInt},
 		[]byte("\x7f\xff\xff\xff"),
 		[]byte("\x7f\xff\xff\xff"),
 		func() *int {
 		func() *int {
 			var value int = math.MaxInt32
 			var value int = math.MaxInt32
@@ -382,22 +382,22 @@ var marshalTests = []struct {
 		}(),
 		}(),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeInt},
 		[]byte(nil),
 		[]byte(nil),
 		(*int)(nil),
 		(*int)(nil),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeTimeUUID},
+		&TypeInfo{Proto: 2, Type: TypeTimeUUID},
 		[]byte{0x3d, 0xcd, 0x98, 0x0, 0xf3, 0xd9, 0x11, 0xbf, 0x86, 0xd4, 0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0},
 		[]byte{0x3d, 0xcd, 0x98, 0x0, 0xf3, 0xd9, 0x11, 0xbf, 0x86, 0xd4, 0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0},
 		&UUID{0x3d, 0xcd, 0x98, 0x0, 0xf3, 0xd9, 0x11, 0xbf, 0x86, 0xd4, 0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0},
 		&UUID{0x3d, 0xcd, 0x98, 0x0, 0xf3, 0xd9, 0x11, 0xbf, 0x86, 0xd4, 0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0},
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeTimeUUID},
+		&TypeInfo{Proto: 2, Type: TypeTimeUUID},
 		[]byte{},
 		[]byte{},
 		(*UUID)(nil),
 		(*UUID)(nil),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeTimestamp},
+		&TypeInfo{Proto: 2, Type: TypeTimestamp},
 		[]byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"),
 		[]byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"),
 		func() *time.Time {
 		func() *time.Time {
 			t := time.Date(2013, time.August, 13, 9, 52, 3, 0, time.UTC)
 			t := time.Date(2013, time.August, 13, 9, 52, 3, 0, time.UTC)
@@ -405,12 +405,12 @@ var marshalTests = []struct {
 		}(),
 		}(),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeTimestamp},
+		&TypeInfo{Proto: 2, Type: TypeTimestamp},
 		[]byte(nil),
 		[]byte(nil),
 		(*time.Time)(nil),
 		(*time.Time)(nil),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeBoolean},
+		&TypeInfo{Proto: 2, Type: TypeBoolean},
 		[]byte("\x00"),
 		[]byte("\x00"),
 		func() *bool {
 		func() *bool {
 			b := false
 			b := false
@@ -418,7 +418,7 @@ var marshalTests = []struct {
 		}(),
 		}(),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeBoolean},
+		&TypeInfo{Proto: 2, Type: TypeBoolean},
 		[]byte("\x01"),
 		[]byte("\x01"),
 		func() *bool {
 		func() *bool {
 			b := true
 			b := true
@@ -426,12 +426,12 @@ var marshalTests = []struct {
 		}(),
 		}(),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeBoolean},
+		&TypeInfo{Proto: 2, Type: TypeBoolean},
 		[]byte(nil),
 		[]byte(nil),
 		(*bool)(nil),
 		(*bool)(nil),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeFloat},
+		&TypeInfo{Proto: 2, Type: TypeFloat},
 		[]byte("\x40\x49\x0f\xdb"),
 		[]byte("\x40\x49\x0f\xdb"),
 		func() *float32 {
 		func() *float32 {
 			f := float32(3.14159265)
 			f := float32(3.14159265)
@@ -439,12 +439,12 @@ var marshalTests = []struct {
 		}(),
 		}(),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeFloat},
+		&TypeInfo{Proto: 2, Type: TypeFloat},
 		[]byte(nil),
 		[]byte(nil),
 		(*float32)(nil),
 		(*float32)(nil),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeDouble},
+		&TypeInfo{Proto: 2, Type: TypeDouble},
 		[]byte("\x40\x09\x21\xfb\x53\xc8\xd4\xf1"),
 		[]byte("\x40\x09\x21\xfb\x53\xc8\xd4\xf1"),
 		func() *float64 {
 		func() *float64 {
 			d := float64(3.14159265)
 			d := float64(3.14159265)
@@ -452,12 +452,12 @@ var marshalTests = []struct {
 		}(),
 		}(),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeDouble},
+		&TypeInfo{Proto: 2, Type: TypeDouble},
 		[]byte(nil),
 		[]byte(nil),
 		(*float64)(nil),
 		(*float64)(nil),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeInet},
+		&TypeInfo{Proto: 2, Type: TypeInet},
 		[]byte("\x7F\x00\x00\x01"),
 		[]byte("\x7F\x00\x00\x01"),
 		func() *net.IP {
 		func() *net.IP {
 			ip := net.ParseIP("127.0.0.1").To4()
 			ip := net.ParseIP("127.0.0.1").To4()
@@ -465,12 +465,12 @@ var marshalTests = []struct {
 		}(),
 		}(),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeInet},
+		&TypeInfo{Proto: 2, Type: TypeInet},
 		[]byte(nil),
 		[]byte(nil),
 		(*net.IP)(nil),
 		(*net.IP)(nil),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeList, Elem: &TypeInfo{Type: TypeInt}},
+		&TypeInfo{Proto: 2, Type: TypeList, Elem: &TypeInfo{Proto: 2, Type: TypeInt}},
 		[]byte("\x00\x02\x00\x04\x00\x00\x00\x01\x00\x04\x00\x00\x00\x02"),
 		[]byte("\x00\x02\x00\x04\x00\x00\x00\x01\x00\x04\x00\x00\x00\x02"),
 		func() *[]int {
 		func() *[]int {
 			l := []int{1, 2}
 			l := []int{1, 2}
@@ -478,14 +478,14 @@ var marshalTests = []struct {
 		}(),
 		}(),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeList, Elem: &TypeInfo{Type: TypeInt}},
+		&TypeInfo{Proto: 2, Type: TypeList, Elem: &TypeInfo{Proto: 2, Type: TypeInt}},
 		[]byte(nil),
 		[]byte(nil),
 		(*[]int)(nil),
 		(*[]int)(nil),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeMap,
-			Key:  &TypeInfo{Type: TypeVarchar},
-			Elem: &TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeMap,
+			Key:  &TypeInfo{Proto: 2, Type: TypeVarchar},
+			Elem: &TypeInfo{Proto: 2, Type: TypeInt},
 		},
 		},
 		[]byte("\x00\x01\x00\x03foo\x00\x04\x00\x00\x00\x01"),
 		[]byte("\x00\x01\x00\x03foo\x00\x04\x00\x00\x00\x01"),
 		func() *map[string]int {
 		func() *map[string]int {
@@ -494,9 +494,9 @@ var marshalTests = []struct {
 		}(),
 		}(),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeMap,
-			Key:  &TypeInfo{Type: TypeVarchar},
-			Elem: &TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeMap,
+			Key:  &TypeInfo{Proto: 2, Type: TypeVarchar},
+			Elem: &TypeInfo{Proto: 2, Type: TypeInt},
 		},
 		},
 		[]byte(nil),
 		[]byte(nil),
 		(*map[string]int)(nil),
 		(*map[string]int)(nil),
@@ -610,7 +610,7 @@ func TestMarshalVarint(t *testing.T) {
 	}
 	}
 
 
 	for i, test := range varintTests {
 	for i, test := range varintTests {
-		data, err := Marshal(&TypeInfo{Type: TypeVarint}, test.Value)
+		data, err := Marshal(&TypeInfo{Proto: 2, Type: TypeVarint}, test.Value)
 		if err != nil {
 		if err != nil {
 			t.Errorf("error marshaling varint: %v (test #%d)", err, i)
 			t.Errorf("error marshaling varint: %v (test #%d)", err, i)
 		}
 		}
@@ -620,7 +620,7 @@ func TestMarshalVarint(t *testing.T) {
 		}
 		}
 
 
 		binder := new(big.Int)
 		binder := new(big.Int)
-		err = Unmarshal(&TypeInfo{Type: TypeVarint}, test.Marshaled, binder)
+		err = Unmarshal(&TypeInfo{Proto: 2, Type: TypeVarint}, test.Marshaled, binder)
 		if err != nil {
 		if err != nil {
 			t.Errorf("error unmarshaling varint: %v (test #%d)", err, i)
 			t.Errorf("error unmarshaling varint: %v (test #%d)", err, i)
 		}
 		}