Browse Source

Merge pull request #332

Miles Delahunty 10 years ago
parent
commit
ed80a3721e
16 changed files with 2562 additions and 318 deletions
  1. 1 2
      .travis.yml
  2. 1 0
      AUTHORS
  3. 13 2
      README.md
  4. 277 0
      cassandra_test.go
  5. 11 2
      cluster.go
  6. 108 53
      conn.go
  7. 287 85
      conn_test.go
  8. 4 1
      connectionpool.go
  9. 149 33
      frame.go
  10. 3 3
      helpers.go
  11. 2 0
      integration.sh
  12. 62 31
      marshal.go
  13. 100 99
      marshal_test.go
  14. 844 0
      metadata.go
  15. 670 0
      metadata_test.go
  16. 30 7
      session.go

+ 1 - 2
.travis.yml

@@ -5,11 +5,10 @@ matrix:
 
 
 env:
 env:
   - CASS=1.2.19
   - CASS=1.2.19
-  - CASS=2.0.11
+  - CASS=2.0.12
   - CASS=2.1.2
   - CASS=2.1.2
 
 
 go:
 go:
-  - 1.2
   - 1.3
   - 1.3
   - 1.4
   - 1.4
 
 

+ 1 - 0
AUTHORS

@@ -43,4 +43,5 @@ James Maloney <jamessagan@gmail.com>
 Ashwin Purohit <purohit@gmail.com>
 Ashwin Purohit <purohit@gmail.com>
 Dan Kinder <dkinder.is.me@gmail.com>
 Dan Kinder <dkinder.is.me@gmail.com>
 Oliver Beattie <oliver@obeattie.com>
 Oliver Beattie <oliver@obeattie.com>
+Justin Corpron <justin@retailnext.com>
 Miles Delahunty <miles.delahunty@gmail.com>
 Miles Delahunty <miles.delahunty@gmail.com>

+ 13 - 2
README.md

@@ -18,8 +18,14 @@ The following matrix shows the versions of Go and Cassandra that are tested with
 
 
 Go/Cassandra | 1.2.19 | 2.0.11 | 2.1.2
 Go/Cassandra | 1.2.19 | 2.0.11 | 2.1.2
 -------------| -------| ------| ---------
 -------------| -------| ------| ---------
-1.2  | yes | yes | yes
 1.3  | yes | yes | yes
 1.3  | yes | yes | yes
+1.4  | yes | yes | yes
+
+
+Sunsetting Model
+----------------
+
+In general, the gocql team will focus on supporting the current and previous versions of Golang. gocql may still work with older versions of Golang, but offical support for these versions will have been sunset.
 
 
 Installation
 Installation
 ------------
 ------------
@@ -41,14 +47,17 @@ Features
   * Automatic reconnect on connection failures with exponential falloff
   * Automatic reconnect on connection failures with exponential falloff
   * Round robin distribution of queries to different hosts
   * Round robin distribution of queries to different hosts
   * Round robin distribution of queries to different connections on a host
   * Round robin distribution of queries to different connections on a host
-  * Each connection can execute up to 128 concurrent queries
+  * Each connection can execute up to n concurrent queries (whereby n is the limit set by the protocol version the client chooses to use)
   * Optional automatic discovery of nodes
   * Optional automatic discovery of nodes
   * Optional support for periodic node discovery via system.peers
   * Optional support for periodic node discovery via system.peers
+* Support for password authentication
 * Iteration over paged results with configurable page size
 * Iteration over paged results with configurable page size
 * Support for TLS/SSL
 * Support for TLS/SSL
 * Optional frame compression (using snappy)
 * Optional frame compression (using snappy)
 * Automatic query preparation
 * Automatic query preparation
 * Support for query tracing
 * Support for query tracing
+* Experimental support for CQL protocol version 3
+* An API to access the schema metadata of a given keyspace
 
 
 Please visit the [Roadmap](https://github.com/gocql/gocql/wiki/Roadmap) page to see what is on the horizion.
 Please visit the [Roadmap](https://github.com/gocql/gocql/wiki/Roadmap) page to see what is on the horizion.
 
 
@@ -149,6 +158,7 @@ There are various ways to bind application level data structures to CQL statemen
 * The `Bind()` API provides a client app with a low level mechanism to introspect query meta data and extract appropriate field values from application level data structures.
 * The `Bind()` API provides a client app with a low level mechanism to introspect query meta data and extract appropriate field values from application level data structures.
 * Building on top of the gocql driver, [cqlr](https://github.com/relops/cqlr) adds the ability to auto-bind a CQL iterator to a struct or to bind a struct to an INSERT statement.
 * Building on top of the gocql driver, [cqlr](https://github.com/relops/cqlr) adds the ability to auto-bind a CQL iterator to a struct or to bind a struct to an INSERT statement.
 * Another external project that layers on top of gocql is [cqlc](http://relops.com/cqlc) which generates gocql compliant code from your Cassandra schema so that you can write type safe CQL statements in Go with a natural query syntax.
 * Another external project that layers on top of gocql is [cqlc](http://relops.com/cqlc) which generates gocql compliant code from your Cassandra schema so that you can write type safe CQL statements in Go with a natural query syntax.
+*  [gocassa](https://github.com/hailocab/gocassa) is an external project that layers on top of gocql to provide convenient query building and data binding.
 
 
 Ecosphere
 Ecosphere
 ---------
 ---------
@@ -159,6 +169,7 @@ The following community maintained tools are known to integrate with gocql:
 * [negronicql](https://github.com/mikebthun/negronicql) is gocql middleware for Negroni.
 * [negronicql](https://github.com/mikebthun/negronicql) is gocql middleware for Negroni.
 * [cqlr](https://github.com/relops/cqlr) adds the ability to auto-bind a CQL iterator to a struct or to bind a struct to an INSERT statement.
 * [cqlr](https://github.com/relops/cqlr) adds the ability to auto-bind a CQL iterator to a struct or to bind a struct to an INSERT statement.
 * [cqlc](http://relops.com/cqlc) which generates gocql compliant code from your Cassandra schema so that you can write type safe CQL statements in Go with a natural query syntax.
 * [cqlc](http://relops.com/cqlc) which generates gocql compliant code from your Cassandra schema so that you can write type safe CQL statements in Go with a natural query syntax.
+* [gocassa](https://github.com/hailocab/gocassa) provides query building, adds data binding, and provides easy-to-use "recipe" tables for common query use-cases.
 
 
 Other Projects
 Other Projects
 --------------
 --------------

+ 277 - 0
cassandra_test.go

@@ -1427,3 +1427,280 @@ func TestEmptyTimestamp(t *testing.T) {
 		t.Errorf("time.Time bind variable should still be empty (was %s)", timeVal)
 		t.Errorf("time.Time bind variable should still be empty (was %s)", timeVal)
 	}
 	}
 }
 }
+
+func TestGetKeyspaceMetadata(t *testing.T) {
+	session := createSession(t)
+	defer session.Close()
+
+	keyspaceMetadata, err := getKeyspaceMetadata(session, "gocql_test")
+	if err != nil {
+		t.Fatalf("failed to query the keyspace metadata with err: %v", err)
+	}
+	if keyspaceMetadata == nil {
+		t.Fatal("failed to query the keyspace metadata, nil returned")
+	}
+	if keyspaceMetadata.Name != "gocql_test" {
+		t.Errorf("Expected keyspace name to be 'gocql' but was '%s'", keyspaceMetadata.Name)
+	}
+	if keyspaceMetadata.StrategyClass != "org.apache.cassandra.locator.SimpleStrategy" {
+		t.Errorf("Expected replication strategy class to be 'org.apache.cassandra.locator.SimpleStrategy' but was '%s'", keyspaceMetadata.StrategyClass)
+	}
+	if keyspaceMetadata.StrategyOptions == nil {
+		t.Error("Expected replication strategy options map but was nil")
+	}
+	rfStr, ok := keyspaceMetadata.StrategyOptions["replication_factor"]
+	if !ok {
+		t.Fatalf("Expected strategy option 'replication_factor' but was not found in %v", keyspaceMetadata.StrategyOptions)
+	}
+	rfInt, err := strconv.Atoi(rfStr.(string))
+	if err != nil {
+		t.Fatalf("Error converting string to int with err: %v", err)
+	}
+	if rfInt != *flagRF {
+		t.Errorf("Expected replication factor to be %d but was %d", *flagRF, rfInt)
+	}
+}
+
+func TestGetTableMetadata(t *testing.T) {
+	session := createSession(t)
+	defer session.Close()
+
+	if err := createTable(session, "CREATE TABLE test_table_metadata (first_id int, second_id int, third_id int, PRIMARY KEY (first_id, second_id))"); err != nil {
+		t.Fatalf("failed to create table with error '%v'", err)
+	}
+
+	tables, err := getTableMetadata(session, "gocql_test")
+	if err != nil {
+		t.Fatalf("failed to query the table metadata with err: %v", err)
+	}
+	if tables == nil {
+		t.Fatal("failed to query the table metadata, nil returned")
+	}
+
+	var testTable *TableMetadata
+
+	// verify all tables have minimum expected data
+	for i := range tables {
+		table := &tables[i]
+
+		if table.Name == "" {
+			t.Errorf("Expected table name to be set, but it was empty: index=%d metadata=%+v", i, table)
+		}
+		if table.Keyspace != "gocql_test" {
+			t.Errorf("Expected keyspace for '%d' table metadata to be 'gocql_test' but was '%s'", table.Name, table.Keyspace)
+		}
+		if table.KeyValidator == "" {
+			t.Errorf("Expected key validator to be set for table %s", table.Name)
+		}
+		if table.Comparator == "" {
+			t.Errorf("Expected comparator to be set for table %s", table.Name)
+		}
+		if table.DefaultValidator == "" {
+			t.Errorf("Expected default validator to be set for table %s", table.Name)
+		}
+
+		// these fields are not set until the metadata is compiled
+		if table.PartitionKey != nil {
+			t.Errorf("Did not expect partition key for table %s", table.Name)
+		}
+		if table.ClusteringColumns != nil {
+			t.Errorf("Did not expect clustering columns for table %s", table.Name)
+		}
+		if table.Columns != nil {
+			t.Errorf("Did not expect columns for table %s", table.Name)
+		}
+
+		// for the next part of the test after this loop, find the metadata for the test table
+		if table.Name == "test_table_metadata" {
+			testTable = table
+		}
+	}
+
+	// verify actual values on the test tables
+	if testTable == nil {
+		t.Fatal("Expected table metadata for name 'test_table_metadata'")
+	}
+	if testTable.KeyValidator != "org.apache.cassandra.db.marshal.Int32Type" {
+		t.Errorf("Expected test_table_metadata key validator to be 'org.apache.cassandra.db.marshal.Int32Type' but was '%s'", testTable.KeyValidator)
+	}
+	if testTable.Comparator != "org.apache.cassandra.db.marshal.CompositeType(org.apache.cassandra.db.marshal.Int32Type,org.apache.cassandra.db.marshal.UTF8Type)" {
+		t.Errorf("Expected test_table_metadata key validator to be 'org.apache.cassandra.db.marshal.CompositeType(org.apache.cassandra.db.marshal.Int32Type,org.apache.cassandra.db.marshal.UTF8Type)' but was '%s'", testTable.Comparator)
+	}
+	if testTable.DefaultValidator != "org.apache.cassandra.db.marshal.BytesType" {
+		t.Errorf("Expected test_table_metadata key validator to be 'org.apache.cassandra.db.marshal.BytesType' but was '%s'", testTable.DefaultValidator)
+	}
+	expectedKeyAliases := []string{"first_id"}
+	if !reflect.DeepEqual(testTable.KeyAliases, expectedKeyAliases) {
+		t.Errorf("Expected key aliases %v but was %v", expectedKeyAliases, testTable.KeyAliases)
+	}
+	expectedColumnAliases := []string{"second_id"}
+	if !reflect.DeepEqual(testTable.ColumnAliases, expectedColumnAliases) {
+		t.Errorf("Expected key aliases %v but was %v", expectedColumnAliases, testTable.ColumnAliases)
+	}
+	if testTable.ValueAlias != "" {
+		t.Errorf("Expected value alias '' but was '%s'", testTable.ValueAlias)
+	}
+}
+
+func TestGetColumnMetadata(t *testing.T) {
+	session := createSession(t)
+	defer session.Close()
+
+	if err := createTable(session, "CREATE TABLE test_column_metadata (first_id int, second_id int, third_id int, PRIMARY KEY (first_id, second_id))"); err != nil {
+		t.Fatalf("failed to create table with error '%v'", err)
+	}
+
+	if err := session.Query("CREATE INDEX index_column_metadata ON test_column_metadata ( third_id )").Exec(); err != nil {
+		t.Fatalf("failed to create index with err: %v", err)
+	}
+
+	columns, err := getColumnMetadata(session, "gocql_test")
+	if err != nil {
+		t.Fatalf("failed to query column metadata with err: %v", err)
+	}
+	if columns == nil {
+		t.Fatal("failed to query column metadata, nil returned")
+	}
+
+	testColumns := map[string]*ColumnMetadata{}
+
+	// verify actual values on the test columns
+	for i := range columns {
+		column := &columns[i]
+
+		if column.Name == "" {
+			t.Errorf("Expected column name to be set, but it was empty: index=%d metadata=%+v", i, column)
+		}
+		if column.Table == "" {
+			t.Errorf("Expected column %s table name to be set, but it was empty", column.Name)
+		}
+		if column.Keyspace != "gocql_test" {
+			t.Errorf("Expected column %s keyspace name to be 'gocql_test', but it was '%s'", column.Name, column.Keyspace)
+		}
+		if column.Kind == "" {
+			t.Errorf("Expected column %s kind to be set, but it was empty", column.Name)
+		}
+		if session.cfg.ProtoVersion == 1 && column.Kind != "regular" {
+			t.Errorf("Expected column %s kind to be set to 'regular' for proto V1 but it was '%s'", column.Name, column.Kind)
+		}
+		if column.Validator == "" {
+			t.Errorf("Expected column %s validator to be set, but it was empty", column.Name)
+		}
+
+		// find the test table columns for the next step after this loop
+		if column.Table == "test_column_metadata" {
+			testColumns[column.Name] = column
+		}
+	}
+
+	if *flagProto == 1 {
+		// V1 proto only returns "regular columns"
+		if len(testColumns) != 1 {
+			t.Errorf("Expected 1 test columns but there were %d", len(testColumns))
+		}
+		thirdID, found := testColumns["third_id"]
+		if !found {
+			t.Fatalf("Expected to find column 'third_id' metadata but there was only %v", testColumns)
+		}
+
+		if thirdID.Kind != REGULAR {
+			t.Errorf("Expected %s column kind to be '%s' but it was '%s'", thirdID.Name, REGULAR, thirdID.Kind)
+		}
+
+		if thirdID.Index.Name != "index_column_metadata" {
+			t.Errorf("Expected %s column index name to be 'index_column_metadata' but it was '%s'", thirdID.Name, thirdID.Index.Name)
+		}
+	} else {
+		if len(testColumns) != 3 {
+			t.Errorf("Expected 3 test columns but there were %d", len(testColumns))
+		}
+		firstID, found := testColumns["first_id"]
+		if !found {
+			t.Fatalf("Expected to find column 'first_id' metadata but there was only %v", testColumns)
+		}
+		secondID, found := testColumns["second_id"]
+		if !found {
+			t.Fatalf("Expected to find column 'second_id' metadata but there was only %v", testColumns)
+		}
+		thirdID, found := testColumns["third_id"]
+		if !found {
+			t.Fatalf("Expected to find column 'third_id' metadata but there was only %v", testColumns)
+		}
+
+		if firstID.Kind != PARTITION_KEY {
+			t.Errorf("Expected %s column kind to be '%s' but it was '%s'", firstID.Name, PARTITION_KEY, firstID.Kind)
+		}
+		if secondID.Kind != CLUSTERING_KEY {
+			t.Errorf("Expected %s column kind to be '%s' but it was '%s'", secondID.Name, CLUSTERING_KEY, secondID.Kind)
+		}
+		if thirdID.Kind != REGULAR {
+			t.Errorf("Expected %s column kind to be '%s' but it was '%s'", thirdID.Name, REGULAR, thirdID.Kind)
+		}
+
+		if thirdID.Index.Name != "index_column_metadata" {
+			t.Errorf("Expected %s column index name to be 'index_column_metadata' but it was '%s'", thirdID.Name, thirdID.Index.Name)
+		}
+	}
+}
+
+func TestKeyspaceMetadata(t *testing.T) {
+	session := createSession(t)
+	defer session.Close()
+
+	if err := createTable(session, "CREATE TABLE test_metadata (first_id int, second_id int, third_id int, PRIMARY KEY (first_id, second_id))"); err != nil {
+		t.Fatalf("failed to create table with error '%v'", err)
+	}
+
+	if err := session.Query("CREATE INDEX index_metadata ON test_metadata ( third_id )").Exec(); err != nil {
+		t.Fatalf("failed to create index with err: %v", err)
+	}
+
+	keyspaceMetadata, err := session.KeyspaceMetadata("gocql_test")
+	if err != nil {
+		t.Fatalf("failed to query keyspace metadata with err: %v", err)
+	}
+	if keyspaceMetadata == nil {
+		t.Fatal("expected the keyspace metadata to not be nil, but it was nil")
+	}
+	if keyspaceMetadata.Name != session.cfg.Keyspace {
+		t.Fatalf("Expected the keyspace name to be %s but was %s", session.cfg.Keyspace, keyspaceMetadata.Name)
+	}
+	if len(keyspaceMetadata.Tables) == 0 {
+		t.Errorf("Expected tables but there were none")
+	}
+
+	tableMetadata, found := keyspaceMetadata.Tables["test_metadata"]
+	if !found {
+		t.Fatalf("failed to find the test_metadata table metadata")
+	}
+
+	if len(tableMetadata.PartitionKey) != 1 {
+		t.Errorf("expected partition key length of 1, but was %d", len(tableMetadata.PartitionKey))
+	}
+	for i, column := range tableMetadata.PartitionKey {
+		if column == nil {
+			t.Errorf("partition key column metadata at index %d was nil", i)
+		}
+	}
+	if tableMetadata.PartitionKey[0].Name != "first_id" {
+		t.Errorf("Expected the first partition key column to be 'first_id' but was '%s'", tableMetadata.PartitionKey[0].Name)
+	}
+	if len(tableMetadata.ClusteringColumns) != 1 {
+		t.Fatalf("expected clustering columns length of 1, but was %d", len(tableMetadata.ClusteringColumns))
+	}
+	for i, column := range tableMetadata.ClusteringColumns {
+		if column == nil {
+			t.Fatalf("clustering column metadata at index %d was nil", i)
+		}
+	}
+	if tableMetadata.ClusteringColumns[0].Name != "second_id" {
+		t.Errorf("Expected the first clustering column to be 'second_id' but was '%s'", tableMetadata.ClusteringColumns[0].Name)
+	}
+	thirdColumn, found := tableMetadata.Columns["third_id"]
+	if !found {
+		t.Fatalf("Expected a column definition for 'third_id'")
+	}
+	if thirdColumn.Index.Name != "index_metadata" {
+		t.Errorf("Expected column index named 'index_metadata' but was '%s'", thirdColumn.Index.Name)
+	}
+}

+ 11 - 2
cluster.go

@@ -62,7 +62,7 @@ type ClusterConfig struct {
 	Port             int           // port (default: 9042)
 	Port             int           // port (default: 9042)
 	Keyspace         string        // initial keyspace (optional)
 	Keyspace         string        // initial keyspace (optional)
 	NumConns         int           // number of connections per host (default: 2)
 	NumConns         int           // number of connections per host (default: 2)
-	NumStreams       int           // number of streams per connection (default: 128)
+	NumStreams       int           // number of streams per connection (default: max per protocol, either 128 or 32768)
 	Consistency      Consistency   // default consistency level (default: Quorum)
 	Consistency      Consistency   // default consistency level (default: Quorum)
 	Compressor       Compressor    // compression algorithm (default: nil)
 	Compressor       Compressor    // compression algorithm (default: nil)
 	Authenticator    Authenticator // authenticator (default: nil)
 	Authenticator    Authenticator // authenticator (default: nil)
@@ -85,7 +85,6 @@ func NewCluster(hosts ...string) *ClusterConfig {
 		Timeout:          600 * time.Millisecond,
 		Timeout:          600 * time.Millisecond,
 		Port:             9042,
 		Port:             9042,
 		NumConns:         2,
 		NumConns:         2,
-		NumStreams:       128,
 		Consistency:      Quorum,
 		Consistency:      Quorum,
 		ConnPoolType:     NewSimplePool,
 		ConnPoolType:     NewSimplePool,
 		DiscoverHosts:    false,
 		DiscoverHosts:    false,
@@ -102,6 +101,16 @@ func (cfg *ClusterConfig) CreateSession() (*Session, error) {
 	if len(cfg.Hosts) < 1 {
 	if len(cfg.Hosts) < 1 {
 		return nil, ErrNoHosts
 		return nil, ErrNoHosts
 	}
 	}
+
+	maxStreams := 128
+	if cfg.ProtoVersion > protoVersion2 {
+		maxStreams = 32768
+	}
+
+	if cfg.NumStreams <= 0 || cfg.NumStreams > maxStreams {
+		cfg.NumStreams = maxStreams
+	}
+
 	pool := cfg.ConnPoolType(cfg)
 	pool := cfg.ConnPoolType(cfg)
 
 
 	//Adjust the size of the prepared statements cache to match the latest configuration
 	//Adjust the size of the prepared statements cache to match the latest configuration

+ 108 - 53
conn.go

@@ -10,7 +10,9 @@ import (
 	"crypto/x509"
 	"crypto/x509"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
+	"io"
 	"io/ioutil"
 	"io/ioutil"
+	"log"
 	"net"
 	"net"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
@@ -19,9 +21,11 @@ import (
 	"time"
 	"time"
 )
 )
 
 
-const defaultFrameSize = 4096
-const flagResponse = 0x80
-const maskVersion = 0x7F
+const (
+	defaultFrameSize = 4096
+	flagResponse     = 0x80
+	maskVersion      = 0x7F
+)
 
 
 //JoinHostPort is a utility to return a address string that can be used
 //JoinHostPort is a utility to return a address string that can be used
 //gocql.Conn to form a connection with a host.
 //gocql.Conn to form a connection with a host.
@@ -88,7 +92,7 @@ type Conn struct {
 	r       *bufio.Reader
 	r       *bufio.Reader
 	timeout time.Duration
 	timeout time.Duration
 
 
-	uniq  chan uint8
+	uniq  chan int
 	calls []callReq
 	calls []callReq
 	nwait int32
 	nwait int32
 
 
@@ -123,14 +127,17 @@ func Connect(addr string, cfg ConnConfig, pool ConnectionPool) (*Conn, error) {
 				return nil, errors.New("Failed parsing or appending certs")
 				return nil, errors.New("Failed parsing or appending certs")
 			}
 			}
 		}
 		}
+
 		mycert, err := tls.LoadX509KeyPair(cfg.SslOpts.CertPath, cfg.SslOpts.KeyPath)
 		mycert, err := tls.LoadX509KeyPair(cfg.SslOpts.CertPath, cfg.SslOpts.KeyPath)
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
+
 		config := tls.Config{
 		config := tls.Config{
 			Certificates: []tls.Certificate{mycert},
 			Certificates: []tls.Certificate{mycert},
 			RootCAs:      certPool,
 			RootCAs:      certPool,
 		}
 		}
+
 		config.InsecureSkipVerify = !cfg.SslOpts.EnableHostVerification
 		config.InsecureSkipVerify = !cfg.SslOpts.EnableHostVerification
 		if conn, err = tls.Dial("tcp", addr, &config); err != nil {
 		if conn, err = tls.Dial("tcp", addr, &config); err != nil {
 			return nil, err
 			return nil, err
@@ -139,16 +146,25 @@ func Connect(addr string, cfg ConnConfig, pool ConnectionPool) (*Conn, error) {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
-	if cfg.NumStreams <= 0 || cfg.NumStreams > 128 {
-		cfg.NumStreams = 128
-	}
-	if cfg.ProtoVersion != 1 && cfg.ProtoVersion != 2 {
+	// going to default to proto 2
+	if cfg.ProtoVersion < protoVersion1 || cfg.ProtoVersion > protoVersion3 {
+		log.Printf("unsupported protocol version: %d using 2\n", cfg.ProtoVersion)
 		cfg.ProtoVersion = 2
 		cfg.ProtoVersion = 2
 	}
 	}
+
+	maxStreams := 128
+	if cfg.ProtoVersion > protoVersion2 {
+		maxStreams = 32768
+	}
+
+	if cfg.NumStreams <= 0 || cfg.NumStreams > maxStreams {
+		cfg.NumStreams = maxStreams
+	}
+
 	c := &Conn{
 	c := &Conn{
 		conn:       conn,
 		conn:       conn,
 		r:          bufio.NewReader(conn),
 		r:          bufio.NewReader(conn),
-		uniq:       make(chan uint8, cfg.NumStreams),
+		uniq:       make(chan int, cfg.NumStreams),
 		calls:      make([]callReq, cfg.NumStreams),
 		calls:      make([]callReq, cfg.NumStreams),
 		timeout:    cfg.Timeout,
 		timeout:    cfg.Timeout,
 		version:    uint8(cfg.ProtoVersion),
 		version:    uint8(cfg.ProtoVersion),
@@ -162,8 +178,8 @@ func Connect(addr string, cfg ConnConfig, pool ConnectionPool) (*Conn, error) {
 		c.setKeepalive(cfg.Keepalive)
 		c.setKeepalive(cfg.Keepalive)
 	}
 	}
 
 
-	for i := 0; i < cap(c.uniq); i++ {
-		c.uniq <- uint8(i)
+	for i := 0; i < cfg.NumStreams; i++ {
+		c.uniq <- i
 	}
 	}
 
 
 	if err := c.startup(&cfg); err != nil {
 	if err := c.startup(&cfg); err != nil {
@@ -254,53 +270,80 @@ func (c *Conn) serve() {
 	c.pool.HandleError(c, err, true)
 	c.pool.HandleError(c, err, true)
 }
 }
 
 
+func (c *Conn) Write(p []byte) (int, error) {
+	c.conn.SetWriteDeadline(time.Now().Add(c.timeout))
+	return c.conn.Write(p)
+}
+
+func (c *Conn) Read(p []byte) (int, error) {
+	return c.r.Read(p)
+}
+
 func (c *Conn) recv() (frame, error) {
 func (c *Conn) recv() (frame, error) {
-	resp := make(frame, headerSize, headerSize+512)
-	c.conn.SetReadDeadline(time.Now().Add(c.timeout))
-	n, last, pinged := 0, 0, false
-	for n < len(resp) {
-		nn, err := c.r.Read(resp[n:])
-		n += nn
-		if err != nil {
-			if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
-				if n > last {
-					// we hit the deadline but we made progress.
-					// simply extend the deadline
-					c.conn.SetReadDeadline(time.Now().Add(c.timeout))
-					last = n
-				} else if n == 0 && !pinged {
-					c.conn.SetReadDeadline(time.Now().Add(c.timeout))
-					if atomic.LoadInt32(&c.nwait) > 0 {
-						go c.ping()
-						pinged = true
-					}
-				} else {
-					return nil, err
-				}
-			} else {
-				return nil, err
-			}
+	size := headerProtoSize[c.version]
+	resp := make(frame, size, size+512)
+
+	// read a full header, ignore timeouts, as this is being ran in a loop
+	c.conn.SetReadDeadline(time.Time{})
+	_, err := io.ReadFull(c.r, resp[:size])
+	if err != nil {
+		return nil, err
+	}
+
+	if v := c.version | flagResponse; resp[0] != v {
+		return nil, NewErrProtocol("recv: response protocol version does not match connection protocol version (%d != %d)", resp[0], v)
+	}
+
+	bodySize := resp.Length(c.version)
+	if bodySize == 0 {
+		return resp, nil
+	}
+	resp.grow(bodySize)
+
+	const maxAttempts = 5
+
+	n := size
+	for i := 0; i < maxAttempts; i++ {
+		var nn int
+		c.conn.SetReadDeadline(time.Now().Add(c.timeout))
+		nn, err = io.ReadFull(c.r, resp[n:size+bodySize])
+		if err == nil {
+			break
 		}
 		}
-		if n == headerSize && len(resp) == headerSize {
-			if resp[0] != c.version|flagResponse {
-				return nil, NewErrProtocol("recv: Response protocol version does not match connection protocol version (%d != %d)", resp[0], c.version|flagResponse)
-			}
-			resp.grow(resp.Length())
+		n += nn
+
+		if verr, ok := err.(net.Error); !ok || !verr.Temporary() {
+			break
 		}
 		}
 	}
 	}
+
+	if err != nil {
+		return nil, err
+	}
+
 	return resp, nil
 	return resp, nil
 }
 }
 
 
 func (c *Conn) execSimple(op operation) (interface{}, error) {
 func (c *Conn) execSimple(op operation) (interface{}, error) {
 	f, err := op.encodeFrame(c.version, nil)
 	f, err := op.encodeFrame(c.version, nil)
-	f.setLength(len(f) - headerSize)
-	if _, err := c.conn.Write([]byte(f)); err != nil {
+	if err != nil {
+		// this should be a noop err
+		return nil, err
+	}
+
+	bodyLen := len(f) - headerProtoSize[c.version]
+	f.setLength(bodyLen, c.version)
+
+	if _, err := c.Write([]byte(f)); err != nil {
 		c.Close()
 		c.Close()
 		return nil, err
 		return nil, err
 	}
 	}
+
+	// here recv wont timeout waiting for a header, should it?
 	if f, err = c.recv(); err != nil {
 	if f, err = c.recv(); err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
+
 	return c.decodeFrame(f, nil)
 	return c.decodeFrame(f, nil)
 }
 }
 
 
@@ -309,9 +352,12 @@ func (c *Conn) exec(op operation, trace Tracer) (interface{}, error) {
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
+
 	if trace != nil {
 	if trace != nil {
 		req[1] |= flagTrace
 		req[1] |= flagTrace
 	}
 	}
+
+	headerSize := headerProtoSize[c.version]
 	if len(req) > headerSize && c.compressor != nil {
 	if len(req) > headerSize && c.compressor != nil {
 		body, err := c.compressor.Encode([]byte(req[headerSize:]))
 		body, err := c.compressor.Encode([]byte(req[headerSize:]))
 		if err != nil {
 		if err != nil {
@@ -320,16 +366,17 @@ func (c *Conn) exec(op operation, trace Tracer) (interface{}, error) {
 		req = append(req[:headerSize], frame(body)...)
 		req = append(req[:headerSize], frame(body)...)
 		req[1] |= flagCompress
 		req[1] |= flagCompress
 	}
 	}
-	req.setLength(len(req) - headerSize)
+	bodyLen := len(req) - headerSize
+	req.setLength(bodyLen, c.version)
 
 
 	id := <-c.uniq
 	id := <-c.uniq
-	req[2] = id
+	req.setStream(id, c.version)
 	call := &c.calls[id]
 	call := &c.calls[id]
 	call.resp = make(chan callResp, 1)
 	call.resp = make(chan callResp, 1)
 	atomic.AddInt32(&c.nwait, 1)
 	atomic.AddInt32(&c.nwait, 1)
 	atomic.StoreInt32(&call.active, 1)
 	atomic.StoreInt32(&call.active, 1)
 
 
-	if _, err := c.conn.Write(req); err != nil {
+	if _, err := c.Write(req); err != nil {
 		c.uniq <- id
 		c.uniq <- id
 		c.Close()
 		c.Close()
 		return nil, err
 		return nil, err
@@ -342,11 +389,12 @@ func (c *Conn) exec(op operation, trace Tracer) (interface{}, error) {
 	if reply.err != nil {
 	if reply.err != nil {
 		return nil, reply.err
 		return nil, reply.err
 	}
 	}
+
 	return c.decodeFrame(reply.buf, trace)
 	return c.decodeFrame(reply.buf, trace)
 }
 }
 
 
 func (c *Conn) dispatch(resp frame) {
 func (c *Conn) dispatch(resp frame) {
-	id := int(resp[2])
+	id := resp.Stream(c.version)
 	if id >= len(c.calls) {
 	if id >= len(c.calls) {
 		return
 		return
 	}
 	}
@@ -543,10 +591,10 @@ func (c *Conn) UseKeyspace(keyspace string) error {
 }
 }
 
 
 func (c *Conn) executeBatch(batch *Batch) error {
 func (c *Conn) executeBatch(batch *Batch) error {
-	if c.version == 1 {
+	if c.version == protoVersion1 {
 		return ErrUnsupported
 		return ErrUnsupported
 	}
 	}
-	f := make(frame, headerSize, defaultFrameSize)
+	f := newFrame(c.version)
 	f.setHeader(c.version, 0, 0, opBatch)
 	f.setHeader(c.version, 0, 0, opBatch)
 	f.writeByte(byte(batch.Type))
 	f.writeByte(byte(batch.Type))
 	f.writeShort(uint16(len(batch.Entries)))
 	f.writeShort(uint16(len(batch.Entries)))
@@ -594,6 +642,10 @@ func (c *Conn) executeBatch(batch *Batch) error {
 		}
 		}
 	}
 	}
 	f.writeConsistency(batch.Cons)
 	f.writeConsistency(batch.Cons)
+	if c.version >= protoVersion3 {
+		// TODO: add support for flags here
+		f.writeByte(0)
+	}
 
 
 	resp, err := c.exec(f, nil)
 	resp, err := c.exec(f, nil)
 	if err != nil {
 	if err != nil {
@@ -631,12 +683,15 @@ func (c *Conn) decodeFrame(f frame, trace Tracer) (rval interface{}, err error)
 			panic(r)
 			panic(r)
 		}
 		}
 	}()
 	}()
+
+	headerSize := headerProtoSize[c.version]
 	if len(f) < headerSize {
 	if len(f) < headerSize {
 		return nil, NewErrProtocol("Decoding frame: less data received than required for header: %d < %d", len(f), headerSize)
 		return nil, NewErrProtocol("Decoding frame: less data received than required for header: %d < %d", len(f), headerSize)
 	} else if f[0] != c.version|flagResponse {
 	} else if f[0] != c.version|flagResponse {
 		return nil, NewErrProtocol("Decoding frame: response protocol version does not match connection protocol version (%d != %d)", f[0], c.version|flagResponse)
 		return nil, NewErrProtocol("Decoding frame: response protocol version does not match connection protocol version (%d != %d)", f[0], c.version|flagResponse)
 	}
 	}
-	flags, op, f := f[1], f[3], f[headerSize:]
+
+	flags, op, f := f[1], f.Op(c.version), f[headerSize:]
 	if flags&flagCompress != 0 && len(f) > 0 && c.compressor != nil {
 	if flags&flagCompress != 0 && len(f) > 0 && c.compressor != nil {
 		if buf, err := c.compressor.Decode([]byte(f)); err != nil {
 		if buf, err := c.compressor.Decode([]byte(f)); err != nil {
 			return nil, err
 			return nil, err
@@ -661,7 +716,7 @@ func (c *Conn) decodeFrame(f frame, trace Tracer) (rval interface{}, err error)
 		case resultKindVoid:
 		case resultKindVoid:
 			return resultVoidFrame{}, nil
 			return resultVoidFrame{}, nil
 		case resultKindRows:
 		case resultKindRows:
-			columns, pageState := f.readMetaData()
+			columns, pageState := f.readMetaData(c.version)
 			numRows := f.readInt()
 			numRows := f.readInt()
 			values := make([][]byte, numRows*len(columns))
 			values := make([][]byte, numRows*len(columns))
 			for i := 0; i < len(values); i++ {
 			for i := 0; i < len(values); i++ {
@@ -677,11 +732,11 @@ func (c *Conn) decodeFrame(f frame, trace Tracer) (rval interface{}, err error)
 			return resultKeyspaceFrame{keyspace}, nil
 			return resultKeyspaceFrame{keyspace}, nil
 		case resultKindPrepared:
 		case resultKindPrepared:
 			id := f.readShortBytes()
 			id := f.readShortBytes()
-			args, _ := f.readMetaData()
+			args, _ := f.readMetaData(c.version)
 			if c.version < 2 {
 			if c.version < 2 {
 				return resultPreparedFrame{PreparedId: id, Arguments: args}, nil
 				return resultPreparedFrame{PreparedId: id, Arguments: args}, nil
 			}
 			}
-			rvals, _ := f.readMetaData()
+			rvals, _ := f.readMetaData(c.version)
 			return resultPreparedFrame{PreparedId: id, Arguments: args, ReturnValues: rvals}, nil
 			return resultPreparedFrame{PreparedId: id, Arguments: args, ReturnValues: rvals}, nil
 		case resultKindSchemaChanged:
 		case resultKindSchemaChanged:
 			return resultVoidFrame{}, nil
 			return resultVoidFrame{}, nil

+ 287 - 85
conn_test.go

@@ -5,6 +5,7 @@ package gocql
 import (
 import (
 	"crypto/tls"
 	"crypto/tls"
 	"crypto/x509"
 	"crypto/x509"
+	"fmt"
 	"io"
 	"io"
 	"io/ioutil"
 	"io/ioutil"
 	"net"
 	"net"
@@ -15,6 +16,10 @@ import (
 	"time"
 	"time"
 )
 )
 
 
+const (
+	defaultProto = protoVersion2
+)
+
 func TestJoinHostPort(t *testing.T) {
 func TestJoinHostPort(t *testing.T) {
 	tests := map[string]string{
 	tests := map[string]string{
 		"127.0.0.1:0":                                 JoinHostPort("127.0.0.1", 0),
 		"127.0.0.1:0":                                 JoinHostPort("127.0.0.1", 0),
@@ -29,43 +34,38 @@ func TestJoinHostPort(t *testing.T) {
 	}
 	}
 }
 }
 
 
-type TestServer struct {
-	Address  string
-	t        *testing.T
-	nreq     uint64
-	listen   net.Listener
-	nKillReq uint64
-}
-
 func TestSimple(t *testing.T) {
 func TestSimple(t *testing.T) {
-	srv := NewTestServer(t)
+	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 	defer srv.Stop()
 
 
-	db, err := NewCluster(srv.Address).CreateSession()
+	cluster := NewCluster(srv.Address)
+	cluster.ProtoVersion = int(defaultProto)
+	db, err := cluster.CreateSession()
 	if err != nil {
 	if err != nil {
-		t.Errorf("NewCluster: %v", err)
+		t.Errorf("0x%x: NewCluster: %v", defaultProto, err)
+		return
 	}
 	}
 
 
 	if err := db.Query("void").Exec(); err != nil {
 	if err := db.Query("void").Exec(); err != nil {
-		t.Error(err)
+		t.Errorf("0x%x: %v", defaultProto, err)
 	}
 	}
 }
 }
 
 
 func TestSSLSimple(t *testing.T) {
 func TestSSLSimple(t *testing.T) {
-	srv := NewSSLTestServer(t)
+	srv := NewSSLTestServer(t, defaultProto)
 	defer srv.Stop()
 	defer srv.Stop()
 
 
-	db, err := createTestSslCluster(srv.Address).CreateSession()
+	db, err := createTestSslCluster(srv.Address, defaultProto).CreateSession()
 	if err != nil {
 	if err != nil {
-		t.Errorf("NewCluster: %v", err)
+		t.Fatalf("0x%x: NewCluster: %v", defaultProto, err)
 	}
 	}
 
 
 	if err := db.Query("void").Exec(); err != nil {
 	if err := db.Query("void").Exec(); err != nil {
-		t.Error(err)
+		t.Fatalf("0x%x: %v", defaultProto, err)
 	}
 	}
 }
 }
 
 
-func createTestSslCluster(hosts string) *ClusterConfig {
+func createTestSslCluster(hosts string, proto uint8) *ClusterConfig {
 	cluster := NewCluster(hosts)
 	cluster := NewCluster(hosts)
 	cluster.SslOpts = &SslOptions{
 	cluster.SslOpts = &SslOptions{
 		CertPath:               "testdata/pki/gocql.crt",
 		CertPath:               "testdata/pki/gocql.crt",
@@ -73,82 +73,103 @@ func createTestSslCluster(hosts string) *ClusterConfig {
 		CaPath:                 "testdata/pki/ca.crt",
 		CaPath:                 "testdata/pki/ca.crt",
 		EnableHostVerification: false,
 		EnableHostVerification: false,
 	}
 	}
+	cluster.ProtoVersion = int(proto)
 	return cluster
 	return cluster
 }
 }
 
 
 func TestClosed(t *testing.T) {
 func TestClosed(t *testing.T) {
 	t.Skip("Skipping the execution of TestClosed for now to try to concentrate on more important test failures on Travis")
 	t.Skip("Skipping the execution of TestClosed for now to try to concentrate on more important test failures on Travis")
-	srv := NewTestServer(t)
+
+	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 	defer srv.Stop()
 
 
-	session, err := NewCluster(srv.Address).CreateSession()
+	cluster := NewCluster(srv.Address)
+	cluster.ProtoVersion = int(defaultProto)
+
+	session, err := cluster.CreateSession()
+	defer session.Close()
 	if err != nil {
 	if err != nil {
-		t.Errorf("NewCluster: %v", err)
+		t.Errorf("0x%x: NewCluster: %v", defaultProto, err)
+		return
 	}
 	}
-	session.Close()
 
 
 	if err := session.Query("void").Exec(); err != ErrSessionClosed {
 	if err := session.Query("void").Exec(); err != ErrSessionClosed {
-		t.Errorf("expected %#v, got %#v", ErrSessionClosed, err)
+		t.Errorf("0x%x: expected %#v, got %#v", defaultProto, ErrSessionClosed, err)
+		return
 	}
 	}
 }
 }
 
 
+func newTestSession(addr string, proto uint8) (*Session, error) {
+	cluster := NewCluster(addr)
+	cluster.ProtoVersion = int(proto)
+	return cluster.CreateSession()
+}
+
 func TestTimeout(t *testing.T) {
 func TestTimeout(t *testing.T) {
-	srv := NewTestServer(t)
+
+	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 	defer srv.Stop()
 
 
-	db, err := NewCluster(srv.Address).CreateSession()
+	db, err := newTestSession(srv.Address, defaultProto)
 	if err != nil {
 	if err != nil {
 		t.Errorf("NewCluster: %v", err)
 		t.Errorf("NewCluster: %v", err)
+		return
 	}
 	}
+	defer db.Close()
 
 
 	go func() {
 	go func() {
 		<-time.After(2 * time.Second)
 		<-time.After(2 * time.Second)
-		t.Fatal("no timeout")
+		t.Errorf("no timeout")
 	}()
 	}()
 
 
 	if err := db.Query("kill").Exec(); err == nil {
 	if err := db.Query("kill").Exec(); err == nil {
-		t.Fatal("expected error")
+		t.Errorf("expected error")
 	}
 	}
 }
 }
 
 
 // TestQueryRetry will test to make sure that gocql will execute
 // TestQueryRetry will test to make sure that gocql will execute
 // the exact amount of retry queries designated by the user.
 // the exact amount of retry queries designated by the user.
 func TestQueryRetry(t *testing.T) {
 func TestQueryRetry(t *testing.T) {
-	srv := NewTestServer(t)
+	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 	defer srv.Stop()
 
 
-	db, err := NewCluster(srv.Address).CreateSession()
+	db, err := newTestSession(srv.Address, defaultProto)
 	if err != nil {
 	if err != nil {
-		t.Errorf("NewCluster: %v", err)
+		t.Fatalf("NewCluster: %v", err)
 	}
 	}
+	defer db.Close()
 
 
 	go func() {
 	go func() {
 		<-time.After(5 * time.Second)
 		<-time.After(5 * time.Second)
-		t.Fatal("no timeout")
+		t.Fatalf("no timeout")
 	}()
 	}()
 	rt := &SimpleRetryPolicy{NumRetries: 1}
 	rt := &SimpleRetryPolicy{NumRetries: 1}
 
 
 	qry := db.Query("kill").RetryPolicy(rt)
 	qry := db.Query("kill").RetryPolicy(rt)
 	if err := qry.Exec(); err == nil {
 	if err := qry.Exec(); err == nil {
-		t.Fatal("expected error")
+		t.Fatalf("expected error")
 	}
 	}
-	requests := srv.nKillReq
-	if requests != uint64(qry.Attempts()) {
-		t.Fatalf("expected requests %v to match query attemps %v", requests, qry.Attempts())
+
+	requests := atomic.LoadInt64(&srv.nKillReq)
+	attempts := qry.Attempts()
+	if requests != int64(attempts) {
+		t.Fatalf("expected requests %v to match query attemps %v", requests, attempts)
 	}
 	}
+
 	//Minus 1 from the requests variable since there is the initial query attempt
 	//Minus 1 from the requests variable since there is the initial query attempt
-	if requests-1 != uint64(rt.NumRetries) {
+	if requests-1 != int64(rt.NumRetries) {
 		t.Fatalf("failed to retry the query %v time(s). Query executed %v times", rt.NumRetries, requests-1)
 		t.Fatalf("failed to retry the query %v time(s). Query executed %v times", rt.NumRetries, requests-1)
 	}
 	}
 }
 }
 
 
 func TestSlowQuery(t *testing.T) {
 func TestSlowQuery(t *testing.T) {
-	srv := NewTestServer(t)
+	srv := NewTestServer(t, defaultProto)
 	defer srv.Stop()
 	defer srv.Stop()
 
 
-	db, err := NewCluster(srv.Address).CreateSession()
+	db, err := newTestSession(srv.Address, defaultProto)
 	if err != nil {
 	if err != nil {
 		t.Errorf("NewCluster: %v", err)
 		t.Errorf("NewCluster: %v", err)
+		return
 	}
 	}
 
 
 	if err := db.Query("slow").Exec(); err != nil {
 	if err := db.Query("slow").Exec(); err != nil {
@@ -159,22 +180,24 @@ func TestSlowQuery(t *testing.T) {
 func TestRoundRobin(t *testing.T) {
 func TestRoundRobin(t *testing.T) {
 	servers := make([]*TestServer, 5)
 	servers := make([]*TestServer, 5)
 	addrs := make([]string, len(servers))
 	addrs := make([]string, len(servers))
-	for i := 0; i < len(servers); i++ {
-		servers[i] = NewTestServer(t)
-		addrs[i] = servers[i].Address
-		defer servers[i].Stop()
+	for n := 0; n < len(servers); n++ {
+		servers[n] = NewTestServer(t, defaultProto)
+		addrs[n] = servers[n].Address
+		defer servers[n].Stop()
 	}
 	}
 	cluster := NewCluster(addrs...)
 	cluster := NewCluster(addrs...)
+	cluster.ProtoVersion = defaultProto
+
 	db, err := cluster.CreateSession()
 	db, err := cluster.CreateSession()
-	time.Sleep(1 * time.Second) //Sleep to allow the Cluster.fillPool to complete
+	time.Sleep(1 * time.Second) // Sleep to allow the Cluster.fillPool to complete
 
 
 	if err != nil {
 	if err != nil {
-		t.Errorf("NewCluster: %v", err)
+		t.Fatalf("NewCluster: %v", err)
 	}
 	}
 
 
 	var wg sync.WaitGroup
 	var wg sync.WaitGroup
 	wg.Add(5)
 	wg.Add(5)
-	for i := 0; i < 5; i++ {
+	for n := 0; n < 5; n++ {
 		go func() {
 		go func() {
 			for j := 0; j < 5; j++ {
 			for j := 0; j < 5; j++ {
 				if err := db.Query("void").Exec(); err != nil {
 				if err := db.Query("void").Exec(); err != nil {
@@ -187,12 +210,12 @@ func TestRoundRobin(t *testing.T) {
 	wg.Wait()
 	wg.Wait()
 
 
 	diff := 0
 	diff := 0
-	for i := 1; i < len(servers); i++ {
+	for n := 1; n < len(servers); n++ {
 		d := 0
 		d := 0
-		if servers[i].nreq > servers[i-1].nreq {
-			d = int(servers[i].nreq - servers[i-1].nreq)
+		if servers[n].nreq > servers[n-1].nreq {
+			d = int(servers[n].nreq - servers[n-1].nreq)
 		} else {
 		} else {
-			d = int(servers[i-1].nreq - servers[i].nreq)
+			d = int(servers[n-1].nreq - servers[n].nreq)
 		}
 		}
 		if d > diff {
 		if d > diff {
 			diff = d
 			diff = d
@@ -206,7 +229,8 @@ func TestRoundRobin(t *testing.T) {
 
 
 func TestConnClosing(t *testing.T) {
 func TestConnClosing(t *testing.T) {
 	t.Skip("Skipping until test can be ran reliably")
 	t.Skip("Skipping until test can be ran reliably")
-	srv := NewTestServer(t)
+
+	srv := NewTestServer(t, protoVersion2)
 	defer srv.Stop()
 	defer srv.Stop()
 
 
 	db, err := NewCluster(srv.Address).CreateSession()
 	db, err := NewCluster(srv.Address).CreateSession()
@@ -238,21 +262,147 @@ func TestConnClosing(t *testing.T) {
 	}
 	}
 }
 }
 
 
-func NewTestServer(t *testing.T) *TestServer {
+func TestStreams_Protocol1(t *testing.T) {
+	srv := NewTestServer(t, protoVersion1)
+	defer srv.Stop()
+
+	// TODO: these are more like session tests and should instead operate
+	// on a single Conn
+	cluster := NewCluster(srv.Address)
+	cluster.NumConns = 1
+	cluster.ProtoVersion = 1
+
+	db, err := cluster.CreateSession()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer db.Close()
+
+	var wg sync.WaitGroup
+	for i := 0; i < db.cfg.NumStreams; i++ {
+		// here we're just validating that if we send NumStreams requests we get
+		// a response for every stream and the lengths for the queries are set
+		// correctly.
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			if err := db.Query("void").Exec(); err != nil {
+				t.Error(err)
+			}
+		}()
+	}
+	wg.Wait()
+}
+
+func TestStreams_Protocol2(t *testing.T) {
+	srv := NewTestServer(t, protoVersion2)
+	defer srv.Stop()
+
+	// TODO: these are more like session tests and should instead operate
+	// on a single Conn
+	cluster := NewCluster(srv.Address)
+	cluster.NumConns = 1
+	cluster.ProtoVersion = 2
+
+	db, err := cluster.CreateSession()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer db.Close()
+
+	for i := 0; i < db.cfg.NumStreams; i++ {
+		// the test server processes each conn synchronously
+		// here we're just validating that if we send NumStreams requests we get
+		// a response for every stream and the lengths for the queries are set
+		// correctly.
+		if err = db.Query("void").Exec(); err != nil {
+			t.Fatal(err)
+		}
+	}
+}
+
+func TestStreams_Protocol3(t *testing.T) {
+	srv := NewTestServer(t, protoVersion3)
+	defer srv.Stop()
+
+	// TODO: these are more like session tests and should instead operate
+	// on a single Conn
+	cluster := NewCluster(srv.Address)
+	cluster.NumConns = 1
+	cluster.ProtoVersion = 3
+
+	db, err := cluster.CreateSession()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer db.Close()
+
+	for i := 0; i < db.cfg.NumStreams; i++ {
+		// the test server processes each conn synchronously
+		// here we're just validating that if we send NumStreams requests we get
+		// a response for every stream and the lengths for the queries are set
+		// correctly.
+		if err = db.Query("void").Exec(); err != nil {
+			t.Fatal(err)
+		}
+	}
+}
+
+func BenchmarkProtocolV3(b *testing.B) {
+	srv := NewTestServer(b, protoVersion3)
+	defer srv.Stop()
+
+	// TODO: these are more like session tests and should instead operate
+	// on a single Conn
+	cluster := NewCluster(srv.Address)
+	cluster.NumConns = 1
+	cluster.ProtoVersion = 3
+
+	db, err := cluster.CreateSession()
+	if err != nil {
+		b.Fatal(err)
+	}
+	defer db.Close()
+
+	b.ResetTimer()
+	b.ReportAllocs()
+	for i := 0; i < b.N; i++ {
+		if err = db.Query("void").Exec(); err != nil {
+			b.Fatal(err)
+		}
+	}
+}
+
+func NewTestServer(t testing.TB, protocol uint8) *TestServer {
 	laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
 	laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
+
 	listen, err := net.ListenTCP("tcp", laddr)
 	listen, err := net.ListenTCP("tcp", laddr)
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
-	srv := &TestServer{Address: listen.Addr().String(), listen: listen, t: t}
+
+	headerSize := 8
+	if protocol > protoVersion2 {
+		headerSize = 9
+	}
+
+	srv := &TestServer{
+		Address:    listen.Addr().String(),
+		listen:     listen,
+		t:          t,
+		protocol:   protocol,
+		headerSize: headerSize,
+	}
+
 	go srv.serve()
 	go srv.serve()
+
 	return srv
 	return srv
 }
 }
 
 
-func NewSSLTestServer(t *testing.T) *TestServer {
+func NewSSLTestServer(t testing.TB, protocol uint8) *TestServer {
 	pem, err := ioutil.ReadFile("testdata/pki/ca.crt")
 	pem, err := ioutil.ReadFile("testdata/pki/ca.crt")
 	certPool := x509.NewCertPool()
 	certPool := x509.NewCertPool()
 	if !certPool.AppendCertsFromPEM(pem) {
 	if !certPool.AppendCertsFromPEM(pem) {
@@ -270,11 +420,34 @@ func NewSSLTestServer(t *testing.T) *TestServer {
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
-	srv := &TestServer{Address: listen.Addr().String(), listen: listen, t: t}
+
+	headerSize := 8
+	if protocol > protoVersion2 {
+		headerSize = 9
+	}
+
+	srv := &TestServer{
+		Address:    listen.Addr().String(),
+		listen:     listen,
+		t:          t,
+		protocol:   protocol,
+		headerSize: headerSize,
+	}
 	go srv.serve()
 	go srv.serve()
 	return srv
 	return srv
 }
 }
 
 
+type TestServer struct {
+	Address  string
+	t        testing.TB
+	nreq     uint64
+	listen   net.Listener
+	nKillReq int64
+
+	protocol   uint8
+	headerSize int
+}
+
 func (srv *TestServer) serve() {
 func (srv *TestServer) serve() {
 	defer srv.listen.Close()
 	defer srv.listen.Close()
 	for {
 	for {
@@ -285,9 +458,16 @@ func (srv *TestServer) serve() {
 		go func(conn net.Conn) {
 		go func(conn net.Conn) {
 			defer conn.Close()
 			defer conn.Close()
 			for {
 			for {
-				frame := srv.readFrame(conn)
+				frame, err := srv.readFrame(conn)
+				if err == io.EOF {
+					return
+				} else if err != nil {
+					srv.t.Error(err)
+					continue
+				}
+
 				atomic.AddUint64(&srv.nreq, 1)
 				atomic.AddUint64(&srv.nreq, 1)
-				srv.process(frame, conn)
+				go srv.process(frame, conn)
 			}
 			}
 		}(conn)
 		}(conn)
 	}
 	}
@@ -297,65 +477,87 @@ func (srv *TestServer) Stop() {
 	srv.listen.Close()
 	srv.listen.Close()
 }
 }
 
 
-func (srv *TestServer) process(frame frame, conn net.Conn) {
-	switch frame[3] {
+func (srv *TestServer) process(f frame, conn net.Conn) {
+	headerSize := headerProtoSize[srv.protocol]
+	stream := f.Stream(srv.protocol)
+
+	switch f.Op(srv.protocol) {
 	case opStartup:
 	case opStartup:
-		frame = frame[:headerSize]
-		frame.setHeader(protoResponse, 0, frame[2], opReady)
+		f = f[:headerSize]
+		f.setHeader(protoDirectionMask|srv.protocol, 0, stream, opReady)
+	case opOptions:
+		f = f[:headerSize]
+		f.setHeader(protoDirectionMask|srv.protocol, 0, stream, opSupported)
+		f.writeShort(0)
 	case opQuery:
 	case opQuery:
-		input := frame
-		input.skipHeader()
+		input := f
+		input.skipHeader(srv.protocol)
 		query := strings.TrimSpace(input.readLongString())
 		query := strings.TrimSpace(input.readLongString())
-		frame = frame[:headerSize]
-		frame.setHeader(protoResponse, 0, frame[2], opResult)
+		f = f[:headerSize]
+		f.setHeader(protoDirectionMask|srv.protocol, 0, stream, opResult)
 		first := query
 		first := query
 		if n := strings.Index(query, " "); n > 0 {
 		if n := strings.Index(query, " "); n > 0 {
 			first = first[:n]
 			first = first[:n]
 		}
 		}
 		switch strings.ToLower(first) {
 		switch strings.ToLower(first) {
 		case "kill":
 		case "kill":
-			atomic.AddUint64(&srv.nKillReq, 1)
-			select {}
+			atomic.AddInt64(&srv.nKillReq, 1)
+			f = f[:headerSize]
+			f.setHeader(protoDirectionMask|srv.protocol, 0, stream, opError)
+			f.writeInt(0x1001)
+			f.writeString("query killed")
 		case "slow":
 		case "slow":
 			go func() {
 			go func() {
 				<-time.After(1 * time.Second)
 				<-time.After(1 * time.Second)
-				frame.writeInt(resultKindVoid)
-				frame.setLength(len(frame) - headerSize)
-				if _, err := conn.Write(frame); err != nil {
+				f.writeInt(resultKindVoid)
+				f.setLength(len(f)-headerSize, srv.protocol)
+				if _, err := conn.Write(f); err != nil {
 					return
 					return
 				}
 				}
 			}()
 			}()
 			return
 			return
 		case "use":
 		case "use":
-			frame.writeInt(3)
-			frame.writeString(strings.TrimSpace(query[3:]))
+			f.writeInt(3)
+			f.writeString(strings.TrimSpace(query[3:]))
 		case "void":
 		case "void":
-			frame.writeInt(resultKindVoid)
+			f.writeInt(resultKindVoid)
 		default:
 		default:
-			frame.writeInt(resultKindVoid)
+			f.writeInt(resultKindVoid)
 		}
 		}
 	default:
 	default:
-		frame = frame[:headerSize]
-		frame.setHeader(protoResponse, 0, frame[2], opError)
-		frame.writeInt(0)
-		frame.writeString("not supported")
+		f = f[:headerSize]
+		f.setHeader(protoDirectionMask|srv.protocol, 0, stream, opError)
+		f.writeInt(0)
+		f.writeString("not supported")
 	}
 	}
-	frame.setLength(len(frame) - headerSize)
-	if _, err := conn.Write(frame); err != nil {
+
+	f.setLength(len(f)-headerSize, srv.protocol)
+	if _, err := conn.Write(f); err != nil {
+		srv.t.Log(err)
 		return
 		return
 	}
 	}
 }
 }
 
 
-func (srv *TestServer) readFrame(conn net.Conn) frame {
-	frame := make(frame, headerSize, headerSize+512)
+func (srv *TestServer) readFrame(conn net.Conn) (frame, error) {
+	frame := make(frame, srv.headerSize, srv.headerSize+512)
 	if _, err := io.ReadFull(conn, frame); err != nil {
 	if _, err := io.ReadFull(conn, frame); err != nil {
-		srv.t.Fatal(err)
+		return nil, err
+	}
+
+	// should be a request frame
+	if frame[0]&protoDirectionMask != 0 {
+		return nil, fmt.Errorf("expected to read a request frame got version: 0x%x", frame[0])
 	}
 	}
-	if n := frame.Length(); n > 0 {
+	if v := frame[0] & protoVersionMask; v != srv.protocol {
+		return nil, fmt.Errorf("expected to read protocol version 0x%x got 0x%x", srv.protocol, v)
+	}
+
+	if n := frame.Length(srv.protocol); n > 0 {
 		frame.grow(n)
 		frame.grow(n)
-		if _, err := io.ReadFull(conn, frame[headerSize:]); err != nil {
-			srv.t.Fatal(err)
+		if _, err := io.ReadFull(conn, frame[srv.headerSize:]); err != nil {
+			return nil, err
 		}
 		}
 	}
 	}
-	return frame
+
+	return frame, nil
 }
 }

+ 4 - 1
connectionpool.go

@@ -297,7 +297,10 @@ func (c *SimplePool) HandleError(conn *Conn, err error, closed bool) {
 		return
 		return
 	}
 	}
 	c.removeConn(conn)
 	c.removeConn(conn)
-	if !c.quit {
+	c.mu.Lock()
+	poolClosed := c.quit
+	c.mu.Unlock()
+	if !poolClosed {
 		go c.fillPool() // top off pool.
 		go c.fillPool() // top off pool.
 	}
 	}
 }
 }

+ 149 - 33
frame.go

@@ -5,12 +5,16 @@
 package gocql
 package gocql
 
 
 import (
 import (
+	"fmt"
 	"net"
 	"net"
 )
 )
 
 
 const (
 const (
-	protoRequest  byte = 0x02
-	protoResponse byte = 0x82
+	protoDirectionMask = 0x80
+	protoVersionMask   = 0x7F
+	protoVersion1      = 0x01
+	protoVersion2      = 0x02
+	protoVersion3      = 0x03
 
 
 	opError         byte = 0x00
 	opError         byte = 0x00
 	opStartup       byte = 0x01
 	opStartup       byte = 0x01
@@ -42,13 +46,26 @@ const (
 	flagPageState   uint8 = 8
 	flagPageState   uint8 = 8
 	flagHasMore     uint8 = 2
 	flagHasMore     uint8 = 2
 
 
-	headerSize = 8
-
 	apacheCassandraTypePrefix = "org.apache.cassandra.db.marshal."
 	apacheCassandraTypePrefix = "org.apache.cassandra.db.marshal."
 )
 )
 
 
+var headerProtoSize = [...]int{
+	protoVersion1: 8,
+	protoVersion2: 8,
+	protoVersion3: 9,
+}
+
+// TODO: replace with a struct which has a header and a body buffer,
+// header just has methods like, set/get the options in its backing array
+// then in a writeTo we write the header then the body.
 type frame []byte
 type frame []byte
 
 
+func newFrame(version uint8) frame {
+	// TODO: pool these at the session level incase anyone is using different
+	// clusters with different versions in the same application.
+	return make(frame, headerProtoSize[version], defaultFrameSize)
+}
+
 func (f *frame) writeInt(v int32) {
 func (f *frame) writeInt(v int32) {
 	p := f.grow(4)
 	p := f.grow(4)
 	(*f)[p] = byte(v >> 24)
 	(*f)[p] = byte(v >> 24)
@@ -129,22 +146,67 @@ func (f *frame) writeStringMultimap(v map[string][]string) {
 	}
 	}
 }
 }
 
 
-func (f *frame) setHeader(version, flags, stream, opcode uint8) {
+func (f *frame) setHeader(version, flags uint8, stream int, opcode uint8) {
 	(*f)[0] = version
 	(*f)[0] = version
 	(*f)[1] = flags
 	(*f)[1] = flags
-	(*f)[2] = stream
-	(*f)[3] = opcode
+	p := 2
+	if version&protoVersionMask > protoVersion2 {
+		(*f)[2] = byte(stream >> 8)
+		(*f)[3] = byte(stream)
+		p += 2
+	} else {
+		(*f)[2] = byte(stream & 0xFF)
+		p++
+	}
+
+	(*f)[p] = opcode
 }
 }
 
 
-func (f *frame) setLength(length int) {
-	(*f)[4] = byte(length >> 24)
-	(*f)[5] = byte(length >> 16)
-	(*f)[6] = byte(length >> 8)
-	(*f)[7] = byte(length)
+func (f *frame) setStream(stream int, version uint8) {
+	if version > protoVersion2 {
+		(*f)[2] = byte(stream >> 8)
+		(*f)[3] = byte(stream)
+	} else {
+		(*f)[2] = byte(stream)
+	}
 }
 }
 
 
-func (f *frame) Length() int {
-	return int((*f)[4])<<24 | int((*f)[5])<<16 | int((*f)[6])<<8 | int((*f)[7])
+func (f *frame) Stream(version uint8) (n int) {
+	if version > protoVersion2 {
+		n = int((*f)[2])<<8 | int((*f)[3])
+	} else {
+		n = int((*f)[2])
+	}
+	return
+}
+
+func (f *frame) setLength(length int, version uint8) {
+	p := 4
+	if version > protoVersion2 {
+		p = 5
+	}
+
+	(*f)[p] = byte(length >> 24)
+	(*f)[p+1] = byte(length >> 16)
+	(*f)[p+2] = byte(length >> 8)
+	(*f)[p+3] = byte(length)
+}
+
+func (f *frame) Op(version uint8) byte {
+	if version > protoVersion2 {
+		return (*f)[4]
+	} else {
+		return (*f)[3]
+	}
+}
+
+func (f *frame) Length(version uint8) int {
+	p := 4
+	if version > protoVersion2 {
+		p = 5
+	}
+
+	return int((*f)[p])<<24 | int((*f)[p+1])<<16 | int((*f)[p+2])<<8 | int((*f)[p+3])
 }
 }
 
 
 func (f *frame) grow(n int) int {
 func (f *frame) grow(n int) int {
@@ -158,13 +220,13 @@ func (f *frame) grow(n int) int {
 	return p
 	return p
 }
 }
 
 
-func (f *frame) skipHeader() {
-	*f = (*f)[headerSize:]
+func (f *frame) skipHeader(version uint8) {
+	*f = (*f)[headerProtoSize[version]:]
 }
 }
 
 
 func (f *frame) readInt() int {
 func (f *frame) readInt() int {
 	if len(*f) < 4 {
 	if len(*f) < 4 {
-		panic(NewErrProtocol("Trying to read an int while >4 bytes in the buffer"))
+		panic(NewErrProtocol("Trying to read an int while <4 bytes in the buffer"))
 	}
 	}
 	v := uint32((*f)[0])<<24 | uint32((*f)[1])<<16 | uint32((*f)[2])<<8 | uint32((*f)[3])
 	v := uint32((*f)[0])<<24 | uint32((*f)[1])<<16 | uint32((*f)[2])<<8 | uint32((*f)[3])
 	*f = (*f)[4:]
 	*f = (*f)[4:]
@@ -173,7 +235,7 @@ func (f *frame) readInt() int {
 
 
 func (f *frame) readShort() uint16 {
 func (f *frame) readShort() uint16 {
 	if len(*f) < 2 {
 	if len(*f) < 2 {
-		panic(NewErrProtocol("Trying to read a short while >2 bytes in the buffer"))
+		panic(NewErrProtocol("Trying to read a short while <2 bytes in the buffer"))
 	}
 	}
 	v := uint16((*f)[0])<<8 | uint16((*f)[1])
 	v := uint16((*f)[0])<<8 | uint16((*f)[1])
 	*f = (*f)[2:]
 	*f = (*f)[2:]
@@ -223,9 +285,12 @@ func (f *frame) readShortBytes() []byte {
 	return v
 	return v
 }
 }
 
 
-func (f *frame) readTypeInfo() *TypeInfo {
+func (f *frame) readTypeInfo(version uint8) *TypeInfo {
 	x := f.readShort()
 	x := f.readShort()
-	typ := &TypeInfo{Type: Type(x)}
+	typ := &TypeInfo{
+		Proto: version,
+		Type:  Type(x),
+	}
 	switch typ.Type {
 	switch typ.Type {
 	case TypeCustom:
 	case TypeCustom:
 		typ.Custom = f.readString()
 		typ.Custom = f.readString()
@@ -233,34 +298,37 @@ func (f *frame) readTypeInfo() *TypeInfo {
 			typ = &TypeInfo{Type: cassType}
 			typ = &TypeInfo{Type: cassType}
 			switch typ.Type {
 			switch typ.Type {
 			case TypeMap:
 			case TypeMap:
-				typ.Key = f.readTypeInfo()
+				typ.Key = f.readTypeInfo(version)
 				fallthrough
 				fallthrough
 			case TypeList, TypeSet:
 			case TypeList, TypeSet:
-				typ.Elem = f.readTypeInfo()
+				typ.Elem = f.readTypeInfo(version)
 			}
 			}
 		}
 		}
 	case TypeMap:
 	case TypeMap:
-		typ.Key = f.readTypeInfo()
+		typ.Key = f.readTypeInfo(version)
 		fallthrough
 		fallthrough
 	case TypeList, TypeSet:
 	case TypeList, TypeSet:
-		typ.Elem = f.readTypeInfo()
+		typ.Elem = f.readTypeInfo(version)
 	}
 	}
 	return typ
 	return typ
 }
 }
 
 
-func (f *frame) readMetaData() ([]ColumnInfo, []byte) {
+func (f *frame) readMetaData(version uint8) ([]ColumnInfo, []byte) {
 	flags := f.readInt()
 	flags := f.readInt()
 	numColumns := f.readInt()
 	numColumns := f.readInt()
+
 	var pageState []byte
 	var pageState []byte
 	if flags&2 != 0 {
 	if flags&2 != 0 {
 		pageState = f.readBytes()
 		pageState = f.readBytes()
 	}
 	}
+
 	globalKeyspace := ""
 	globalKeyspace := ""
 	globalTable := ""
 	globalTable := ""
 	if flags&1 != 0 {
 	if flags&1 != 0 {
 		globalKeyspace = f.readString()
 		globalKeyspace = f.readString()
 		globalTable = f.readString()
 		globalTable = f.readString()
 	}
 	}
+
 	columns := make([]ColumnInfo, numColumns)
 	columns := make([]ColumnInfo, numColumns)
 	for i := 0; i < numColumns; i++ {
 	for i := 0; i < numColumns; i++ {
 		columns[i].Keyspace = globalKeyspace
 		columns[i].Keyspace = globalKeyspace
@@ -270,7 +338,7 @@ func (f *frame) readMetaData() ([]ColumnInfo, []byte) {
 			columns[i].Table = f.readString()
 			columns[i].Table = f.readString()
 		}
 		}
 		columns[i].Name = f.readString()
 		columns[i].Name = f.readString()
-		columns[i].TypeInfo = f.readTypeInfo()
+		columns[i].TypeInfo = f.readTypeInfo(version)
 	}
 	}
 	return columns, pageState
 	return columns, pageState
 }
 }
@@ -381,19 +449,32 @@ type startupFrame struct {
 	Compression string
 	Compression string
 }
 }
 
 
+func (op *startupFrame) String() string {
+	return fmt.Sprintf("[startup cqlversion=%q compression=%q]", op.CQLVersion, op.Compression)
+}
+
 func (op *startupFrame) encodeFrame(version uint8, f frame) (frame, error) {
 func (op *startupFrame) encodeFrame(version uint8, f frame) (frame, error) {
 	if f == nil {
 	if f == nil {
-		f = make(frame, headerSize, defaultFrameSize)
+		f = newFrame(version)
 	}
 	}
+
 	f.setHeader(version, 0, 0, opStartup)
 	f.setHeader(version, 0, 0, opStartup)
-	f.writeShort(1)
+
+	// TODO: fix this, this is actually a StringMap
+	var size uint16 = 1
+	if op.Compression != "" {
+		size++
+	}
+
+	f.writeShort(size)
 	f.writeString("CQL_VERSION")
 	f.writeString("CQL_VERSION")
 	f.writeString(op.CQLVersion)
 	f.writeString(op.CQLVersion)
+
 	if op.Compression != "" {
 	if op.Compression != "" {
-		f[headerSize+1] += 1
 		f.writeString("COMPRESSION")
 		f.writeString("COMPRESSION")
 		f.writeString(op.Compression)
 		f.writeString(op.Compression)
 	}
 	}
+
 	return f, nil
 	return f, nil
 }
 }
 
 
@@ -406,14 +487,20 @@ type queryFrame struct {
 	PageState []byte
 	PageState []byte
 }
 }
 
 
+func (op *queryFrame) String() string {
+	return fmt.Sprintf("[query statement=%q prepared=%x cons=%v ...]", op.Stmt, op.Prepared, op.Cons)
+}
+
 func (op *queryFrame) encodeFrame(version uint8, f frame) (frame, error) {
 func (op *queryFrame) encodeFrame(version uint8, f frame) (frame, error) {
 	if version == 1 && (op.PageSize != 0 || len(op.PageState) > 0 ||
 	if version == 1 && (op.PageSize != 0 || len(op.PageState) > 0 ||
 		(len(op.Values) > 0 && len(op.Prepared) == 0)) {
 		(len(op.Values) > 0 && len(op.Prepared) == 0)) {
 		return nil, ErrUnsupported
 		return nil, ErrUnsupported
 	}
 	}
+
 	if f == nil {
 	if f == nil {
-		f = make(frame, headerSize, defaultFrameSize)
+		f = newFrame(version)
 	}
 	}
+
 	if len(op.Prepared) > 0 {
 	if len(op.Prepared) > 0 {
 		f.setHeader(version, 0, 0, opExecute)
 		f.setHeader(version, 0, 0, opExecute)
 		f.writeShortBytes(op.Prepared)
 		f.writeShortBytes(op.Prepared)
@@ -421,10 +508,12 @@ func (op *queryFrame) encodeFrame(version uint8, f frame) (frame, error) {
 		f.setHeader(version, 0, 0, opQuery)
 		f.setHeader(version, 0, 0, opQuery)
 		f.writeLongString(op.Stmt)
 		f.writeLongString(op.Stmt)
 	}
 	}
+
 	if version >= 2 {
 	if version >= 2 {
 		f.writeConsistency(op.Cons)
 		f.writeConsistency(op.Cons)
 		flagPos := len(f)
 		flagPos := len(f)
 		f.writeByte(0)
 		f.writeByte(0)
+
 		if len(op.Values) > 0 {
 		if len(op.Values) > 0 {
 			f[flagPos] |= flagQueryValues
 			f[flagPos] |= flagQueryValues
 			f.writeShort(uint16(len(op.Values)))
 			f.writeShort(uint16(len(op.Values)))
@@ -432,10 +521,12 @@ func (op *queryFrame) encodeFrame(version uint8, f frame) (frame, error) {
 				f.writeBytes(value)
 				f.writeBytes(value)
 			}
 			}
 		}
 		}
+
 		if op.PageSize > 0 {
 		if op.PageSize > 0 {
 			f[flagPos] |= flagPageSize
 			f[flagPos] |= flagPageSize
 			f.writeInt(int32(op.PageSize))
 			f.writeInt(int32(op.PageSize))
 		}
 		}
+
 		if len(op.PageState) > 0 {
 		if len(op.PageState) > 0 {
 			f[flagPos] |= flagPageState
 			f[flagPos] |= flagPageState
 			f.writeBytes(op.PageState)
 			f.writeBytes(op.PageState)
@@ -449,6 +540,7 @@ func (op *queryFrame) encodeFrame(version uint8, f frame) (frame, error) {
 		}
 		}
 		f.writeConsistency(op.Cons)
 		f.writeConsistency(op.Cons)
 	}
 	}
+
 	return f, nil
 	return f, nil
 }
 }
 
 
@@ -456,9 +548,13 @@ type prepareFrame struct {
 	Stmt string
 	Stmt string
 }
 }
 
 
+func (op *prepareFrame) String() string {
+	return fmt.Sprintf("[prepare statement=%q]", op.Stmt)
+}
+
 func (op *prepareFrame) encodeFrame(version uint8, f frame) (frame, error) {
 func (op *prepareFrame) encodeFrame(version uint8, f frame) (frame, error) {
 	if f == nil {
 	if f == nil {
-		f = make(frame, headerSize, defaultFrameSize)
+		f = newFrame(version)
 	}
 	}
 	f.setHeader(version, 0, 0, opPrepare)
 	f.setHeader(version, 0, 0, opPrepare)
 	f.writeLongString(op.Stmt)
 	f.writeLongString(op.Stmt)
@@ -467,9 +563,13 @@ func (op *prepareFrame) encodeFrame(version uint8, f frame) (frame, error) {
 
 
 type optionsFrame struct{}
 type optionsFrame struct{}
 
 
+func (op *optionsFrame) String() string {
+	return "[options]"
+}
+
 func (op *optionsFrame) encodeFrame(version uint8, f frame) (frame, error) {
 func (op *optionsFrame) encodeFrame(version uint8, f frame) (frame, error) {
 	if f == nil {
 	if f == nil {
-		f = make(frame, headerSize, defaultFrameSize)
+		f = newFrame(version)
 	}
 	}
 	f.setHeader(version, 0, 0, opOptions)
 	f.setHeader(version, 0, 0, opOptions)
 	return f, nil
 	return f, nil
@@ -479,13 +579,21 @@ type authenticateFrame struct {
 	Authenticator string
 	Authenticator string
 }
 }
 
 
+func (op *authenticateFrame) String() string {
+	return fmt.Sprintf("[authenticate authenticator=%q]", op.Authenticator)
+}
+
 type authResponseFrame struct {
 type authResponseFrame struct {
 	Data []byte
 	Data []byte
 }
 }
 
 
+func (op *authResponseFrame) String() string {
+	return fmt.Sprintf("[auth_response data=%q]", op.Data)
+}
+
 func (op *authResponseFrame) encodeFrame(version uint8, f frame) (frame, error) {
 func (op *authResponseFrame) encodeFrame(version uint8, f frame) (frame, error) {
 	if f == nil {
 	if f == nil {
-		f = make(frame, headerSize, defaultFrameSize)
+		f = newFrame(version)
 	}
 	}
 	f.setHeader(version, 0, 0, opAuthResponse)
 	f.setHeader(version, 0, 0, opAuthResponse)
 	f.writeBytes(op.Data)
 	f.writeBytes(op.Data)
@@ -496,6 +604,14 @@ type authSuccessFrame struct {
 	Data []byte
 	Data []byte
 }
 }
 
 
+func (op *authSuccessFrame) String() string {
+	return fmt.Sprintf("[auth_success data=%q]", op.Data)
+}
+
 type authChallengeFrame struct {
 type authChallengeFrame struct {
 	Data []byte
 	Data []byte
 }
 }
+
+func (op *authChallengeFrame) String() string {
+	return fmt.Sprintf("[auth_challenge data=%q]", op.Data)
+}

+ 3 - 3
helpers.go

@@ -82,7 +82,7 @@ func getApacheCassandraType(class string) Type {
 			return TypeFloat
 			return TypeFloat
 		case "Int32Type":
 		case "Int32Type":
 			return TypeInt
 			return TypeInt
-		case "DateType":
+		case "DateType", "TimestampType":
 			return TypeTimestamp
 			return TypeTimestamp
 		case "UUIDType":
 		case "UUIDType":
 			return TypeUUID
 			return TypeUUID
@@ -97,9 +97,9 @@ func getApacheCassandraType(class string) Type {
 		case "MapType":
 		case "MapType":
 			return TypeMap
 			return TypeMap
 		case "ListType":
 		case "ListType":
-			return TypeInet
+			return TypeList
 		case "SetType":
 		case "SetType":
-			return TypeInet
+			return TypeSet
 		}
 		}
 	}
 	}
 	return TypeCustom
 	return TypeCustom

+ 2 - 0
integration.sh

@@ -19,6 +19,8 @@ function run_tests() {
 	local proto=2
 	local proto=2
 	if [[ $version == 1.2.* ]]; then
 	if [[ $version == 1.2.* ]]; then
 		proto=1
 		proto=1
+	elif [[ $version == 2.1.* ]]; then
+		proto=3
 	fi
 	fi
 
 
 	go test -timeout 5m -tags integration -cover -v -runssl -proto=$proto -rf=3 -cluster=$(ccm liveset) -clusterSize=$clusterSize -autowait=2000ms ./... | tee results.txt
 	go test -timeout 5m -tags integration -cover -v -runssl -proto=$proto -rf=3 -cluster=$(ccm liveset) -clusterSize=$clusterSize -autowait=2000ms ./... | tee results.txt

+ 62 - 31
marshal.go

@@ -40,6 +40,9 @@ func Marshal(info *TypeInfo, value interface{}) ([]byte, error) {
 	if value == nil {
 	if value == nil {
 		return nil, nil
 		return nil, nil
 	}
 	}
+	if info.Proto < protoVersion1 {
+		panic("protocol version not set")
+	}
 
 
 	if v, ok := value.(Marshaler); ok {
 	if v, ok := value.(Marshaler); ok {
 		return v.MarshalCQL(info)
 		return v.MarshalCQL(info)
@@ -814,6 +817,28 @@ func unmarshalTimestamp(info *TypeInfo, data []byte, value interface{}) error {
 	return unmarshalErrorf("can not unmarshal %s into %T", info, value)
 	return unmarshalErrorf("can not unmarshal %s into %T", info, value)
 }
 }
 
 
+func writeCollectionSize(info *TypeInfo, n int, buf *bytes.Buffer) error {
+	if info.Proto > protoVersion2 {
+		if n > math.MaxInt32 {
+			return marshalErrorf("marshal: collection too large")
+		}
+
+		buf.WriteByte(byte(n >> 24))
+		buf.WriteByte(byte(n >> 16))
+		buf.WriteByte(byte(n >> 8))
+		buf.WriteByte(byte(n))
+	} else {
+		if n > math.MaxUint16 {
+			return marshalErrorf("marshal: collection too large")
+		}
+
+		buf.WriteByte(byte(n >> 8))
+		buf.WriteByte(byte(n))
+	}
+
+	return nil
+}
+
 func marshalList(info *TypeInfo, value interface{}) ([]byte, error) {
 func marshalList(info *TypeInfo, value interface{}) ([]byte, error) {
 	rv := reflect.ValueOf(value)
 	rv := reflect.ValueOf(value)
 	t := rv.Type()
 	t := rv.Type()
@@ -825,21 +850,19 @@ func marshalList(info *TypeInfo, value interface{}) ([]byte, error) {
 		}
 		}
 		buf := &bytes.Buffer{}
 		buf := &bytes.Buffer{}
 		n := rv.Len()
 		n := rv.Len()
-		if n > math.MaxUint16 {
-			return nil, marshalErrorf("marshal: slice / array too large")
+
+		if err := writeCollectionSize(info, n, buf); err != nil {
+			return nil, err
 		}
 		}
-		buf.WriteByte(byte(n >> 8))
-		buf.WriteByte(byte(n))
+
 		for i := 0; i < n; i++ {
 		for i := 0; i < n; i++ {
 			item, err := Marshal(info.Elem, rv.Index(i).Interface())
 			item, err := Marshal(info.Elem, rv.Index(i).Interface())
 			if err != nil {
 			if err != nil {
 				return nil, err
 				return nil, err
 			}
 			}
-			if len(item) > math.MaxUint16 {
-				return nil, marshalErrorf("marshal: slice / array item too large")
+			if err := writeCollectionSize(info, len(item), buf); err != nil {
+				return nil, err
 			}
 			}
-			buf.WriteByte(byte(len(item) >> 8))
-			buf.WriteByte(byte(len(item)))
 			buf.Write(item)
 			buf.Write(item)
 		}
 		}
 		return buf.Bytes(), nil
 		return buf.Bytes(), nil
@@ -858,6 +881,17 @@ func marshalList(info *TypeInfo, value interface{}) ([]byte, error) {
 	return nil, marshalErrorf("can not marshal %T into %s", value, info)
 	return nil, marshalErrorf("can not marshal %T into %s", value, info)
 }
 }
 
 
+func readCollectionSize(info *TypeInfo, data []byte) (size, read int) {
+	if info.Proto > protoVersion2 {
+		size = int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3])
+		read = 4
+	} else {
+		size = int(data[0])<<8 | int(data[1])
+		read = 2
+	}
+	return
+}
+
 func unmarshalList(info *TypeInfo, data []byte, value interface{}) error {
 func unmarshalList(info *TypeInfo, data []byte, value interface{}) error {
 	rv := reflect.ValueOf(value)
 	rv := reflect.ValueOf(value)
 	if rv.Kind() != reflect.Ptr {
 	if rv.Kind() != reflect.Ptr {
@@ -879,8 +913,8 @@ func unmarshalList(info *TypeInfo, data []byte, value interface{}) error {
 		if len(data) < 2 {
 		if len(data) < 2 {
 			return unmarshalErrorf("unmarshal list: unexpected eof")
 			return unmarshalErrorf("unmarshal list: unexpected eof")
 		}
 		}
-		n := int(data[0])<<8 | int(data[1])
-		data = data[2:]
+		n, p := readCollectionSize(info, data)
+		data = data[p:]
 		if k == reflect.Array {
 		if k == reflect.Array {
 			if rv.Len() != n {
 			if rv.Len() != n {
 				return unmarshalErrorf("unmarshal list: array with wrong size")
 				return unmarshalErrorf("unmarshal list: array with wrong size")
@@ -894,8 +928,8 @@ func unmarshalList(info *TypeInfo, data []byte, value interface{}) error {
 			if len(data) < 2 {
 			if len(data) < 2 {
 				return unmarshalErrorf("unmarshal list: unexpected eof")
 				return unmarshalErrorf("unmarshal list: unexpected eof")
 			}
 			}
-			m := int(data[0])<<8 | int(data[1])
-			data = data[2:]
+			m, p := readCollectionSize(info, data)
+			data = data[p:]
 			if err := Unmarshal(info.Elem, data[:m], rv.Index(i).Addr().Interface()); err != nil {
 			if err := Unmarshal(info.Elem, data[:m], rv.Index(i).Addr().Interface()); err != nil {
 				return err
 				return err
 			}
 			}
@@ -917,33 +951,29 @@ func marshalMap(info *TypeInfo, value interface{}) ([]byte, error) {
 	}
 	}
 	buf := &bytes.Buffer{}
 	buf := &bytes.Buffer{}
 	n := rv.Len()
 	n := rv.Len()
-	if n > math.MaxUint16 {
-		return nil, marshalErrorf("marshal: map too large")
+
+	if err := writeCollectionSize(info, n, buf); err != nil {
+		return nil, err
 	}
 	}
-	buf.WriteByte(byte(n >> 8))
-	buf.WriteByte(byte(n))
+
 	keys := rv.MapKeys()
 	keys := rv.MapKeys()
 	for _, key := range keys {
 	for _, key := range keys {
 		item, err := Marshal(info.Key, key.Interface())
 		item, err := Marshal(info.Key, key.Interface())
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
-		if len(item) > math.MaxUint16 {
-			return nil, marshalErrorf("marshal: slice / array item too large")
+		if err := writeCollectionSize(info, len(item), buf); err != nil {
+			return nil, err
 		}
 		}
-		buf.WriteByte(byte(len(item) >> 8))
-		buf.WriteByte(byte(len(item)))
 		buf.Write(item)
 		buf.Write(item)
 
 
 		item, err = Marshal(info.Elem, rv.MapIndex(key).Interface())
 		item, err = Marshal(info.Elem, rv.MapIndex(key).Interface())
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
-		if len(item) > math.MaxUint16 {
-			return nil, marshalErrorf("marshal: slice / array item too large")
+		if err := writeCollectionSize(info, len(item), buf); err != nil {
+			return nil, err
 		}
 		}
-		buf.WriteByte(byte(len(item) >> 8))
-		buf.WriteByte(byte(len(item)))
 		buf.Write(item)
 		buf.Write(item)
 	}
 	}
 	return buf.Bytes(), nil
 	return buf.Bytes(), nil
@@ -967,22 +997,22 @@ func unmarshalMap(info *TypeInfo, data []byte, value interface{}) error {
 	if len(data) < 2 {
 	if len(data) < 2 {
 		return unmarshalErrorf("unmarshal map: unexpected eof")
 		return unmarshalErrorf("unmarshal map: unexpected eof")
 	}
 	}
-	n := int(data[1]) | int(data[0])<<8
-	data = data[2:]
+	n, p := readCollectionSize(info, data)
+	data = data[p:]
 	for i := 0; i < n; i++ {
 	for i := 0; i < n; i++ {
 		if len(data) < 2 {
 		if len(data) < 2 {
 			return unmarshalErrorf("unmarshal list: unexpected eof")
 			return unmarshalErrorf("unmarshal list: unexpected eof")
 		}
 		}
-		m := int(data[1]) | int(data[0])<<8
-		data = data[2:]
+		m, p := readCollectionSize(info, data)
+		data = data[p:]
 		key := reflect.New(t.Key())
 		key := reflect.New(t.Key())
 		if err := Unmarshal(info.Key, data[:m], key.Interface()); err != nil {
 		if err := Unmarshal(info.Key, data[:m], key.Interface()); err != nil {
 			return err
 			return err
 		}
 		}
 		data = data[m:]
 		data = data[m:]
 
 
-		m = int(data[1]) | int(data[0])<<8
-		data = data[2:]
+		m, p = readCollectionSize(info, data)
+		data = data[p:]
 		val := reflect.New(t.Elem())
 		val := reflect.New(t.Elem())
 		if err := Unmarshal(info.Elem, data[:m], val.Interface()); err != nil {
 		if err := Unmarshal(info.Elem, data[:m], val.Interface()); err != nil {
 			return err
 			return err
@@ -1120,10 +1150,11 @@ func unmarshalInet(info *TypeInfo, data []byte, value interface{}) error {
 
 
 // TypeInfo describes a Cassandra specific data type.
 // TypeInfo describes a Cassandra specific data type.
 type TypeInfo struct {
 type TypeInfo struct {
+	Proto  byte // version of the protocol
 	Type   Type
 	Type   Type
 	Key    *TypeInfo // only used for TypeMap
 	Key    *TypeInfo // only used for TypeMap
 	Elem   *TypeInfo // only used for TypeMap, TypeList and TypeSet
 	Elem   *TypeInfo // only used for TypeMap, TypeList and TypeSet
-	Custom string    // only used for TypeCostum
+	Custom string    // only used for TypeCustom
 }
 }
 
 
 // String returns a human readable name for the Cassandra datatype
 // String returns a human readable name for the Cassandra datatype

+ 100 - 99
marshal_test.go

@@ -21,42 +21,42 @@ var marshalTests = []struct {
 	Value interface{}
 	Value interface{}
 }{
 }{
 	{
 	{
-		&TypeInfo{Type: TypeVarchar},
+		&TypeInfo{Proto: 2, Type: TypeVarchar},
 		[]byte("hello world"),
 		[]byte("hello world"),
 		[]byte("hello world"),
 		[]byte("hello world"),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeVarchar},
+		&TypeInfo{Proto: 2, Type: TypeVarchar},
 		[]byte("hello world"),
 		[]byte("hello world"),
 		"hello world",
 		"hello world",
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeVarchar},
+		&TypeInfo{Proto: 2, Type: TypeVarchar},
 		[]byte(nil),
 		[]byte(nil),
 		[]byte(nil),
 		[]byte(nil),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeVarchar},
+		&TypeInfo{Proto: 2, Type: TypeVarchar},
 		[]byte("hello world"),
 		[]byte("hello world"),
 		MyString("hello world"),
 		MyString("hello world"),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeVarchar},
+		&TypeInfo{Proto: 2, Type: TypeVarchar},
 		[]byte("HELLO WORLD"),
 		[]byte("HELLO WORLD"),
 		CustomString("hello world"),
 		CustomString("hello world"),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeBlob},
+		&TypeInfo{Proto: 2, Type: TypeBlob},
 		[]byte("hello\x00"),
 		[]byte("hello\x00"),
 		[]byte("hello\x00"),
 		[]byte("hello\x00"),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeBlob},
+		&TypeInfo{Proto: 2, Type: TypeBlob},
 		[]byte(nil),
 		[]byte(nil),
 		[]byte(nil),
 		[]byte(nil),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeTimeUUID},
+		&TypeInfo{Proto: 2, Type: TypeTimeUUID},
 		[]byte{0x3d, 0xcd, 0x98, 0x0, 0xf3, 0xd9, 0x11, 0xbf, 0x86, 0xd4, 0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0},
 		[]byte{0x3d, 0xcd, 0x98, 0x0, 0xf3, 0xd9, 0x11, 0xbf, 0x86, 0xd4, 0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0},
 		func() UUID {
 		func() UUID {
 			x, _ := UUIDFromBytes([]byte{0x3d, 0xcd, 0x98, 0x0, 0xf3, 0xd9, 0x11, 0xbf, 0x86, 0xd4, 0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0})
 			x, _ := UUIDFromBytes([]byte{0x3d, 0xcd, 0x98, 0x0, 0xf3, 0xd9, 0x11, 0xbf, 0x86, 0xd4, 0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0})
@@ -64,217 +64,217 @@ var marshalTests = []struct {
 		}(),
 		}(),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeInt},
 		[]byte("\x00\x00\x00\x00"),
 		[]byte("\x00\x00\x00\x00"),
 		0,
 		0,
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeInt},
 		[]byte("\x01\x02\x03\x04"),
 		[]byte("\x01\x02\x03\x04"),
 		int(16909060),
 		int(16909060),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeInt},
 		[]byte("\x80\x00\x00\x00"),
 		[]byte("\x80\x00\x00\x00"),
 		int32(math.MinInt32),
 		int32(math.MinInt32),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeInt},
 		[]byte("\x7f\xff\xff\xff"),
 		[]byte("\x7f\xff\xff\xff"),
 		int32(math.MaxInt32),
 		int32(math.MaxInt32),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeInt},
 		[]byte("\x00\x00\x00\x00"),
 		[]byte("\x00\x00\x00\x00"),
 		"0",
 		"0",
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeInt},
 		[]byte("\x01\x02\x03\x04"),
 		[]byte("\x01\x02\x03\x04"),
 		"16909060",
 		"16909060",
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeInt},
 		[]byte("\x80\x00\x00\x00"),
 		[]byte("\x80\x00\x00\x00"),
 		"-2147483648", // math.MinInt32
 		"-2147483648", // math.MinInt32
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeInt},
 		[]byte("\x7f\xff\xff\xff"),
 		[]byte("\x7f\xff\xff\xff"),
 		"2147483647", // math.MaxInt32
 		"2147483647", // math.MaxInt32
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeBigInt},
+		&TypeInfo{Proto: 2, Type: TypeBigInt},
 		[]byte("\x00\x00\x00\x00\x00\x00\x00\x00"),
 		[]byte("\x00\x00\x00\x00\x00\x00\x00\x00"),
 		0,
 		0,
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeBigInt},
+		&TypeInfo{Proto: 2, Type: TypeBigInt},
 		[]byte("\x01\x02\x03\x04\x05\x06\x07\x08"),
 		[]byte("\x01\x02\x03\x04\x05\x06\x07\x08"),
 		72623859790382856,
 		72623859790382856,
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeBigInt},
+		&TypeInfo{Proto: 2, Type: TypeBigInt},
 		[]byte("\x80\x00\x00\x00\x00\x00\x00\x00"),
 		[]byte("\x80\x00\x00\x00\x00\x00\x00\x00"),
 		int64(math.MinInt64),
 		int64(math.MinInt64),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeBigInt},
+		&TypeInfo{Proto: 2, Type: TypeBigInt},
 		[]byte("\x7f\xff\xff\xff\xff\xff\xff\xff"),
 		[]byte("\x7f\xff\xff\xff\xff\xff\xff\xff"),
 		int64(math.MaxInt64),
 		int64(math.MaxInt64),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeBigInt},
+		&TypeInfo{Proto: 2, Type: TypeBigInt},
 		[]byte("\x00\x00\x00\x00\x00\x00\x00\x00"),
 		[]byte("\x00\x00\x00\x00\x00\x00\x00\x00"),
 		"0",
 		"0",
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeBigInt},
+		&TypeInfo{Proto: 2, Type: TypeBigInt},
 		[]byte("\x01\x02\x03\x04\x05\x06\x07\x08"),
 		[]byte("\x01\x02\x03\x04\x05\x06\x07\x08"),
 		"72623859790382856",
 		"72623859790382856",
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeBigInt},
+		&TypeInfo{Proto: 2, Type: TypeBigInt},
 		[]byte("\x80\x00\x00\x00\x00\x00\x00\x00"),
 		[]byte("\x80\x00\x00\x00\x00\x00\x00\x00"),
 		"-9223372036854775808", // math.MinInt64
 		"-9223372036854775808", // math.MinInt64
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeBigInt},
+		&TypeInfo{Proto: 2, Type: TypeBigInt},
 		[]byte("\x7f\xff\xff\xff\xff\xff\xff\xff"),
 		[]byte("\x7f\xff\xff\xff\xff\xff\xff\xff"),
 		"9223372036854775807", // math.MaxInt64
 		"9223372036854775807", // math.MaxInt64
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeBoolean},
+		&TypeInfo{Proto: 2, Type: TypeBoolean},
 		[]byte("\x00"),
 		[]byte("\x00"),
 		false,
 		false,
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeBoolean},
+		&TypeInfo{Proto: 2, Type: TypeBoolean},
 		[]byte("\x01"),
 		[]byte("\x01"),
 		true,
 		true,
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeFloat},
+		&TypeInfo{Proto: 2, Type: TypeFloat},
 		[]byte("\x40\x49\x0f\xdb"),
 		[]byte("\x40\x49\x0f\xdb"),
 		float32(3.14159265),
 		float32(3.14159265),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeDouble},
+		&TypeInfo{Proto: 2, Type: TypeDouble},
 		[]byte("\x40\x09\x21\xfb\x53\xc8\xd4\xf1"),
 		[]byte("\x40\x09\x21\xfb\x53\xc8\xd4\xf1"),
 		float64(3.14159265),
 		float64(3.14159265),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeDecimal},
+		&TypeInfo{Proto: 2, Type: TypeDecimal},
 		[]byte("\x00\x00\x00\x00\x00"),
 		[]byte("\x00\x00\x00\x00\x00"),
 		inf.NewDec(0, 0),
 		inf.NewDec(0, 0),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeDecimal},
+		&TypeInfo{Proto: 2, Type: TypeDecimal},
 		[]byte("\x00\x00\x00\x00\x64"),
 		[]byte("\x00\x00\x00\x00\x64"),
 		inf.NewDec(100, 0),
 		inf.NewDec(100, 0),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeDecimal},
+		&TypeInfo{Proto: 2, Type: TypeDecimal},
 		[]byte("\x00\x00\x00\x02\x19"),
 		[]byte("\x00\x00\x00\x02\x19"),
 		decimalize("0.25"),
 		decimalize("0.25"),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeDecimal},
+		&TypeInfo{Proto: 2, Type: TypeDecimal},
 		[]byte("\x00\x00\x00\x13\xD5\a;\x20\x14\xA2\x91"),
 		[]byte("\x00\x00\x00\x13\xD5\a;\x20\x14\xA2\x91"),
 		decimalize("-0.0012095473475870063"), // From the iconara/cql-rb test suite
 		decimalize("-0.0012095473475870063"), // From the iconara/cql-rb test suite
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeDecimal},
+		&TypeInfo{Proto: 2, Type: TypeDecimal},
 		[]byte("\x00\x00\x00\x13*\xF8\xC4\xDF\xEB]o"),
 		[]byte("\x00\x00\x00\x13*\xF8\xC4\xDF\xEB]o"),
 		decimalize("0.0012095473475870063"), // From the iconara/cql-rb test suite
 		decimalize("0.0012095473475870063"), // From the iconara/cql-rb test suite
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeDecimal},
+		&TypeInfo{Proto: 2, Type: TypeDecimal},
 		[]byte("\x00\x00\x00\x12\xF2\xD8\x02\xB6R\x7F\x99\xEE\x98#\x99\xA9V"),
 		[]byte("\x00\x00\x00\x12\xF2\xD8\x02\xB6R\x7F\x99\xEE\x98#\x99\xA9V"),
 		decimalize("-1042342234234.123423435647768234"), // From the iconara/cql-rb test suite
 		decimalize("-1042342234234.123423435647768234"), // From the iconara/cql-rb test suite
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeDecimal},
+		&TypeInfo{Proto: 2, Type: TypeDecimal},
 		[]byte("\x00\x00\x00\r\nJ\x04\"^\x91\x04\x8a\xb1\x18\xfe"),
 		[]byte("\x00\x00\x00\r\nJ\x04\"^\x91\x04\x8a\xb1\x18\xfe"),
 		decimalize("1243878957943.1234124191998"), // From the datastax/python-driver test suite
 		decimalize("1243878957943.1234124191998"), // From the datastax/python-driver test suite
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeDecimal},
+		&TypeInfo{Proto: 2, Type: TypeDecimal},
 		[]byte("\x00\x00\x00\x06\xe5\xde]\x98Y"),
 		[]byte("\x00\x00\x00\x06\xe5\xde]\x98Y"),
 		decimalize("-112233.441191"), // From the datastax/python-driver test suite
 		decimalize("-112233.441191"), // From the datastax/python-driver test suite
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeDecimal},
+		&TypeInfo{Proto: 2, Type: TypeDecimal},
 		[]byte("\x00\x00\x00\x14\x00\xfa\xce"),
 		[]byte("\x00\x00\x00\x14\x00\xfa\xce"),
 		decimalize("0.00000000000000064206"), // From the datastax/python-driver test suite
 		decimalize("0.00000000000000064206"), // From the datastax/python-driver test suite
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeDecimal},
+		&TypeInfo{Proto: 2, Type: TypeDecimal},
 		[]byte("\x00\x00\x00\x14\xff\x052"),
 		[]byte("\x00\x00\x00\x14\xff\x052"),
 		decimalize("-0.00000000000000064206"), // From the datastax/python-driver test suite
 		decimalize("-0.00000000000000064206"), // From the datastax/python-driver test suite
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeDecimal},
+		&TypeInfo{Proto: 2, Type: TypeDecimal},
 		[]byte("\xff\xff\xff\x9c\x00\xfa\xce"),
 		[]byte("\xff\xff\xff\x9c\x00\xfa\xce"),
 		inf.NewDec(64206, -100), // From the datastax/python-driver test suite
 		inf.NewDec(64206, -100), // From the datastax/python-driver test suite
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeTimestamp},
+		&TypeInfo{Proto: 2, Type: TypeTimestamp},
 		[]byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"),
 		[]byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"),
 		time.Date(2013, time.August, 13, 9, 52, 3, 0, time.UTC),
 		time.Date(2013, time.August, 13, 9, 52, 3, 0, time.UTC),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeTimestamp},
+		&TypeInfo{Proto: 2, Type: TypeTimestamp},
 		[]byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"),
 		[]byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"),
 		int64(1376387523000),
 		int64(1376387523000),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeList, Elem: &TypeInfo{Type: TypeInt}},
+		&TypeInfo{Proto: 2, Type: TypeList, Elem: &TypeInfo{Proto: 2, Type: TypeInt}},
 		[]byte("\x00\x02\x00\x04\x00\x00\x00\x01\x00\x04\x00\x00\x00\x02"),
 		[]byte("\x00\x02\x00\x04\x00\x00\x00\x01\x00\x04\x00\x00\x00\x02"),
 		[]int{1, 2},
 		[]int{1, 2},
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeList, Elem: &TypeInfo{Type: TypeInt}},
+		&TypeInfo{Proto: 2, Type: TypeList, Elem: &TypeInfo{Proto: 2, Type: TypeInt}},
 		[]byte("\x00\x02\x00\x04\x00\x00\x00\x01\x00\x04\x00\x00\x00\x02"),
 		[]byte("\x00\x02\x00\x04\x00\x00\x00\x01\x00\x04\x00\x00\x00\x02"),
 		[2]int{1, 2},
 		[2]int{1, 2},
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeSet, Elem: &TypeInfo{Type: TypeInt}},
+		&TypeInfo{Proto: 2, Type: TypeSet, Elem: &TypeInfo{Proto: 2, Type: TypeInt}},
 		[]byte("\x00\x02\x00\x04\x00\x00\x00\x01\x00\x04\x00\x00\x00\x02"),
 		[]byte("\x00\x02\x00\x04\x00\x00\x00\x01\x00\x04\x00\x00\x00\x02"),
 		[]int{1, 2},
 		[]int{1, 2},
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeSet, Elem: &TypeInfo{Type: TypeInt}},
+		&TypeInfo{Proto: 2, Type: TypeSet, Elem: &TypeInfo{Proto: 2, Type: TypeInt}},
 		[]byte(nil),
 		[]byte(nil),
 		[]int(nil),
 		[]int(nil),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeMap,
-			Key:  &TypeInfo{Type: TypeVarchar},
-			Elem: &TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeMap,
+			Key:  &TypeInfo{Proto: 2, Type: TypeVarchar},
+			Elem: &TypeInfo{Proto: 2, Type: TypeInt},
 		},
 		},
 		[]byte("\x00\x01\x00\x03foo\x00\x04\x00\x00\x00\x01"),
 		[]byte("\x00\x01\x00\x03foo\x00\x04\x00\x00\x00\x01"),
 		map[string]int{"foo": 1},
 		map[string]int{"foo": 1},
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeMap,
-			Key:  &TypeInfo{Type: TypeVarchar},
-			Elem: &TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeMap,
+			Key:  &TypeInfo{Proto: 2, Type: TypeVarchar},
+			Elem: &TypeInfo{Proto: 2, Type: TypeInt},
 		},
 		},
 		[]byte(nil),
 		[]byte(nil),
 		map[string]int(nil),
 		map[string]int(nil),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeList, Elem: &TypeInfo{Type: TypeVarchar}},
+		&TypeInfo{Proto: 2, Type: TypeList, Elem: &TypeInfo{Proto: 2, Type: TypeVarchar}},
 		bytes.Join([][]byte{
 		bytes.Join([][]byte{
 			[]byte("\x00\x01\xFF\xFF"),
 			[]byte("\x00\x01\xFF\xFF"),
 			bytes.Repeat([]byte("X"), 65535)}, []byte("")),
 			bytes.Repeat([]byte("X"), 65535)}, []byte("")),
 		[]string{strings.Repeat("X", 65535)},
 		[]string{strings.Repeat("X", 65535)},
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeMap,
-			Key:  &TypeInfo{Type: TypeVarchar},
-			Elem: &TypeInfo{Type: TypeVarchar},
+		&TypeInfo{Proto: 2, Type: TypeMap,
+			Key:  &TypeInfo{Proto: 2, Type: TypeVarchar},
+			Elem: &TypeInfo{Proto: 2, Type: TypeVarchar},
 		},
 		},
 		bytes.Join([][]byte{
 		bytes.Join([][]byte{
 			[]byte("\x00\x01\xFF\xFF"),
 			[]byte("\x00\x01\xFF\xFF"),
@@ -286,82 +286,82 @@ var marshalTests = []struct {
 		},
 		},
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeVarint},
+		&TypeInfo{Proto: 2, Type: TypeVarint},
 		[]byte("\x00"),
 		[]byte("\x00"),
 		0,
 		0,
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeVarint},
+		&TypeInfo{Proto: 2, Type: TypeVarint},
 		[]byte("\x37\xE2\x3C\xEC"),
 		[]byte("\x37\xE2\x3C\xEC"),
 		int32(937573612),
 		int32(937573612),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeVarint},
+		&TypeInfo{Proto: 2, Type: TypeVarint},
 		[]byte("\x37\xE2\x3C\xEC"),
 		[]byte("\x37\xE2\x3C\xEC"),
 		big.NewInt(937573612),
 		big.NewInt(937573612),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeVarint},
+		&TypeInfo{Proto: 2, Type: TypeVarint},
 		[]byte("\x03\x9EV \x15\f\x03\x9DK\x18\xCDI\\$?\a["),
 		[]byte("\x03\x9EV \x15\f\x03\x9DK\x18\xCDI\\$?\a["),
 		bigintize("1231312312331283012830129382342342412123"), // From the iconara/cql-rb test suite
 		bigintize("1231312312331283012830129382342342412123"), // From the iconara/cql-rb test suite
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeVarint},
+		&TypeInfo{Proto: 2, Type: TypeVarint},
 		[]byte("\xC9v\x8D:\x86"),
 		[]byte("\xC9v\x8D:\x86"),
 		big.NewInt(-234234234234), // From the iconara/cql-rb test suite
 		big.NewInt(-234234234234), // From the iconara/cql-rb test suite
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeVarint},
+		&TypeInfo{Proto: 2, Type: TypeVarint},
 		[]byte("f\x1e\xfd\xf2\xe3\xb1\x9f|\x04_\x15"),
 		[]byte("f\x1e\xfd\xf2\xe3\xb1\x9f|\x04_\x15"),
 		bigintize("123456789123456789123456789"), // From the datastax/python-driver test suite
 		bigintize("123456789123456789123456789"), // From the datastax/python-driver test suite
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeInet},
+		&TypeInfo{Proto: 2, Type: TypeInet},
 		[]byte("\x7F\x00\x00\x01"),
 		[]byte("\x7F\x00\x00\x01"),
 		net.ParseIP("127.0.0.1").To4(),
 		net.ParseIP("127.0.0.1").To4(),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeInet},
+		&TypeInfo{Proto: 2, Type: TypeInet},
 		[]byte("\xFF\xFF\xFF\xFF"),
 		[]byte("\xFF\xFF\xFF\xFF"),
 		net.ParseIP("255.255.255.255").To4(),
 		net.ParseIP("255.255.255.255").To4(),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeInet},
+		&TypeInfo{Proto: 2, Type: TypeInet},
 		[]byte("\x7F\x00\x00\x01"),
 		[]byte("\x7F\x00\x00\x01"),
 		"127.0.0.1",
 		"127.0.0.1",
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeInet},
+		&TypeInfo{Proto: 2, Type: TypeInet},
 		[]byte("\xFF\xFF\xFF\xFF"),
 		[]byte("\xFF\xFF\xFF\xFF"),
 		"255.255.255.255",
 		"255.255.255.255",
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeInet},
+		&TypeInfo{Proto: 2, Type: TypeInet},
 		[]byte("\x21\xDA\x00\xd3\x00\x00\x2f\x3b\x02\xaa\x00\xff\xfe\x28\x9c\x5a"),
 		[]byte("\x21\xDA\x00\xd3\x00\x00\x2f\x3b\x02\xaa\x00\xff\xfe\x28\x9c\x5a"),
 		"21da:d3:0:2f3b:2aa:ff:fe28:9c5a",
 		"21da:d3:0:2f3b:2aa:ff:fe28:9c5a",
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeInet},
+		&TypeInfo{Proto: 2, Type: TypeInet},
 		[]byte("\xfe\x80\x00\x00\x00\x00\x00\x00\x02\x02\xb3\xff\xfe\x1e\x83\x29"),
 		[]byte("\xfe\x80\x00\x00\x00\x00\x00\x00\x02\x02\xb3\xff\xfe\x1e\x83\x29"),
 		"fe80::202:b3ff:fe1e:8329",
 		"fe80::202:b3ff:fe1e:8329",
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeInet},
+		&TypeInfo{Proto: 2, Type: TypeInet},
 		[]byte("\x21\xDA\x00\xd3\x00\x00\x2f\x3b\x02\xaa\x00\xff\xfe\x28\x9c\x5a"),
 		[]byte("\x21\xDA\x00\xd3\x00\x00\x2f\x3b\x02\xaa\x00\xff\xfe\x28\x9c\x5a"),
 		net.ParseIP("21da:d3:0:2f3b:2aa:ff:fe28:9c5a"),
 		net.ParseIP("21da:d3:0:2f3b:2aa:ff:fe28:9c5a"),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeInet},
+		&TypeInfo{Proto: 2, Type: TypeInet},
 		[]byte("\xfe\x80\x00\x00\x00\x00\x00\x00\x02\x02\xb3\xff\xfe\x1e\x83\x29"),
 		[]byte("\xfe\x80\x00\x00\x00\x00\x00\x00\x02\x02\xb3\xff\xfe\x1e\x83\x29"),
 		net.ParseIP("fe80::202:b3ff:fe1e:8329"),
 		net.ParseIP("fe80::202:b3ff:fe1e:8329"),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeInt},
 		[]byte(nil),
 		[]byte(nil),
 		nil,
 		nil,
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeVarchar},
+		&TypeInfo{Proto: 2, Type: TypeVarchar},
 		[]byte("nullable string"),
 		[]byte("nullable string"),
 		func() *string {
 		func() *string {
 			value := "nullable string"
 			value := "nullable string"
@@ -369,12 +369,12 @@ var marshalTests = []struct {
 		}(),
 		}(),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeVarchar},
+		&TypeInfo{Proto: 2, Type: TypeVarchar},
 		[]byte{},
 		[]byte{},
 		(*string)(nil),
 		(*string)(nil),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeInt},
 		[]byte("\x7f\xff\xff\xff"),
 		[]byte("\x7f\xff\xff\xff"),
 		func() *int {
 		func() *int {
 			var value int = math.MaxInt32
 			var value int = math.MaxInt32
@@ -382,22 +382,22 @@ var marshalTests = []struct {
 		}(),
 		}(),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeInt},
 		[]byte(nil),
 		[]byte(nil),
 		(*int)(nil),
 		(*int)(nil),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeTimeUUID},
+		&TypeInfo{Proto: 2, Type: TypeTimeUUID},
 		[]byte{0x3d, 0xcd, 0x98, 0x0, 0xf3, 0xd9, 0x11, 0xbf, 0x86, 0xd4, 0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0},
 		[]byte{0x3d, 0xcd, 0x98, 0x0, 0xf3, 0xd9, 0x11, 0xbf, 0x86, 0xd4, 0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0},
 		&UUID{0x3d, 0xcd, 0x98, 0x0, 0xf3, 0xd9, 0x11, 0xbf, 0x86, 0xd4, 0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0},
 		&UUID{0x3d, 0xcd, 0x98, 0x0, 0xf3, 0xd9, 0x11, 0xbf, 0x86, 0xd4, 0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0},
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeTimeUUID},
+		&TypeInfo{Proto: 2, Type: TypeTimeUUID},
 		[]byte{},
 		[]byte{},
 		(*UUID)(nil),
 		(*UUID)(nil),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeTimestamp},
+		&TypeInfo{Proto: 2, Type: TypeTimestamp},
 		[]byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"),
 		[]byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"),
 		func() *time.Time {
 		func() *time.Time {
 			t := time.Date(2013, time.August, 13, 9, 52, 3, 0, time.UTC)
 			t := time.Date(2013, time.August, 13, 9, 52, 3, 0, time.UTC)
@@ -405,12 +405,12 @@ var marshalTests = []struct {
 		}(),
 		}(),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeTimestamp},
+		&TypeInfo{Proto: 2, Type: TypeTimestamp},
 		[]byte(nil),
 		[]byte(nil),
 		(*time.Time)(nil),
 		(*time.Time)(nil),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeBoolean},
+		&TypeInfo{Proto: 2, Type: TypeBoolean},
 		[]byte("\x00"),
 		[]byte("\x00"),
 		func() *bool {
 		func() *bool {
 			b := false
 			b := false
@@ -418,7 +418,7 @@ var marshalTests = []struct {
 		}(),
 		}(),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeBoolean},
+		&TypeInfo{Proto: 2, Type: TypeBoolean},
 		[]byte("\x01"),
 		[]byte("\x01"),
 		func() *bool {
 		func() *bool {
 			b := true
 			b := true
@@ -426,12 +426,12 @@ var marshalTests = []struct {
 		}(),
 		}(),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeBoolean},
+		&TypeInfo{Proto: 2, Type: TypeBoolean},
 		[]byte(nil),
 		[]byte(nil),
 		(*bool)(nil),
 		(*bool)(nil),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeFloat},
+		&TypeInfo{Proto: 2, Type: TypeFloat},
 		[]byte("\x40\x49\x0f\xdb"),
 		[]byte("\x40\x49\x0f\xdb"),
 		func() *float32 {
 		func() *float32 {
 			f := float32(3.14159265)
 			f := float32(3.14159265)
@@ -439,12 +439,12 @@ var marshalTests = []struct {
 		}(),
 		}(),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeFloat},
+		&TypeInfo{Proto: 2, Type: TypeFloat},
 		[]byte(nil),
 		[]byte(nil),
 		(*float32)(nil),
 		(*float32)(nil),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeDouble},
+		&TypeInfo{Proto: 2, Type: TypeDouble},
 		[]byte("\x40\x09\x21\xfb\x53\xc8\xd4\xf1"),
 		[]byte("\x40\x09\x21\xfb\x53\xc8\xd4\xf1"),
 		func() *float64 {
 		func() *float64 {
 			d := float64(3.14159265)
 			d := float64(3.14159265)
@@ -452,12 +452,12 @@ var marshalTests = []struct {
 		}(),
 		}(),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeDouble},
+		&TypeInfo{Proto: 2, Type: TypeDouble},
 		[]byte(nil),
 		[]byte(nil),
 		(*float64)(nil),
 		(*float64)(nil),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeInet},
+		&TypeInfo{Proto: 2, Type: TypeInet},
 		[]byte("\x7F\x00\x00\x01"),
 		[]byte("\x7F\x00\x00\x01"),
 		func() *net.IP {
 		func() *net.IP {
 			ip := net.ParseIP("127.0.0.1").To4()
 			ip := net.ParseIP("127.0.0.1").To4()
@@ -465,12 +465,12 @@ var marshalTests = []struct {
 		}(),
 		}(),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeInet},
+		&TypeInfo{Proto: 2, Type: TypeInet},
 		[]byte(nil),
 		[]byte(nil),
 		(*net.IP)(nil),
 		(*net.IP)(nil),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeList, Elem: &TypeInfo{Type: TypeInt}},
+		&TypeInfo{Proto: 2, Type: TypeList, Elem: &TypeInfo{Proto: 2, Type: TypeInt}},
 		[]byte("\x00\x02\x00\x04\x00\x00\x00\x01\x00\x04\x00\x00\x00\x02"),
 		[]byte("\x00\x02\x00\x04\x00\x00\x00\x01\x00\x04\x00\x00\x00\x02"),
 		func() *[]int {
 		func() *[]int {
 			l := []int{1, 2}
 			l := []int{1, 2}
@@ -478,14 +478,14 @@ var marshalTests = []struct {
 		}(),
 		}(),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeList, Elem: &TypeInfo{Type: TypeInt}},
+		&TypeInfo{Proto: 2, Type: TypeList, Elem: &TypeInfo{Proto: 2, Type: TypeInt}},
 		[]byte(nil),
 		[]byte(nil),
 		(*[]int)(nil),
 		(*[]int)(nil),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeMap,
-			Key:  &TypeInfo{Type: TypeVarchar},
-			Elem: &TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeMap,
+			Key:  &TypeInfo{Proto: 2, Type: TypeVarchar},
+			Elem: &TypeInfo{Proto: 2, Type: TypeInt},
 		},
 		},
 		[]byte("\x00\x01\x00\x03foo\x00\x04\x00\x00\x00\x01"),
 		[]byte("\x00\x01\x00\x03foo\x00\x04\x00\x00\x00\x01"),
 		func() *map[string]int {
 		func() *map[string]int {
@@ -494,9 +494,9 @@ var marshalTests = []struct {
 		}(),
 		}(),
 	},
 	},
 	{
 	{
-		&TypeInfo{Type: TypeMap,
-			Key:  &TypeInfo{Type: TypeVarchar},
-			Elem: &TypeInfo{Type: TypeInt},
+		&TypeInfo{Proto: 2, Type: TypeMap,
+			Key:  &TypeInfo{Proto: 2, Type: TypeVarchar},
+			Elem: &TypeInfo{Proto: 2, Type: TypeInt},
 		},
 		},
 		[]byte(nil),
 		[]byte(nil),
 		(*map[string]int)(nil),
 		(*map[string]int)(nil),
@@ -610,7 +610,7 @@ func TestMarshalVarint(t *testing.T) {
 	}
 	}
 
 
 	for i, test := range varintTests {
 	for i, test := range varintTests {
-		data, err := Marshal(&TypeInfo{Type: TypeVarint}, test.Value)
+		data, err := Marshal(&TypeInfo{Proto: 2, Type: TypeVarint}, test.Value)
 		if err != nil {
 		if err != nil {
 			t.Errorf("error marshaling varint: %v (test #%d)", err, i)
 			t.Errorf("error marshaling varint: %v (test #%d)", err, i)
 		}
 		}
@@ -620,7 +620,7 @@ func TestMarshalVarint(t *testing.T) {
 		}
 		}
 
 
 		binder := new(big.Int)
 		binder := new(big.Int)
-		err = Unmarshal(&TypeInfo{Type: TypeVarint}, test.Marshaled, binder)
+		err = Unmarshal(&TypeInfo{Proto: 2, Type: TypeVarint}, test.Marshaled, binder)
 		if err != nil {
 		if err != nil {
 			t.Errorf("error unmarshaling varint: %v (test #%d)", err, i)
 			t.Errorf("error unmarshaling varint: %v (test #%d)", err, i)
 		}
 		}
@@ -659,14 +659,15 @@ var typeLookupTest = []struct {
 	{"FloatType", TypeFloat},
 	{"FloatType", TypeFloat},
 	{"Int32Type", TypeInt},
 	{"Int32Type", TypeInt},
 	{"DateType", TypeTimestamp},
 	{"DateType", TypeTimestamp},
+	{"TimestampType", TypeTimestamp},
 	{"UUIDType", TypeUUID},
 	{"UUIDType", TypeUUID},
 	{"UTF8Type", TypeVarchar},
 	{"UTF8Type", TypeVarchar},
 	{"IntegerType", TypeVarint},
 	{"IntegerType", TypeVarint},
 	{"TimeUUIDType", TypeTimeUUID},
 	{"TimeUUIDType", TypeTimeUUID},
 	{"InetAddressType", TypeInet},
 	{"InetAddressType", TypeInet},
 	{"MapType", TypeMap},
 	{"MapType", TypeMap},
-	{"ListType", TypeInet},
-	{"SetType", TypeInet},
+	{"ListType", TypeList},
+	{"SetType", TypeSet},
 	{"unknown", TypeCustom},
 	{"unknown", TypeCustom},
 }
 }
 
 

+ 844 - 0
metadata.go

@@ -0,0 +1,844 @@
+// Copyright (c) 2015 The gocql Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gocql
+
+import (
+	"encoding/hex"
+	"encoding/json"
+	"fmt"
+	"log"
+	"strconv"
+	"strings"
+	"sync"
+)
+
+// KeyspaceMetadata describes the schema of a single keyspace, as read from
+// the system.schema_keyspaces table.
+type KeyspaceMetadata struct {
+	Name            string
+	DurableWrites   bool
+	// StrategyClass is the replication strategy class name.
+	StrategyClass   string
+	// StrategyOptions is decoded from the JSON strategy_options column.
+	StrategyOptions map[string]interface{}
+	// Tables maps table name to metadata; populated by compileMetadata.
+	Tables          map[string]*TableMetadata
+}
+
+// TableMetadata describes the schema of a table (a.k.a. column family), as
+// read from the system.schema_columnfamilies table.
+type TableMetadata struct {
+	Keyspace          string
+	Name              string
+	// KeyValidator, Comparator, and DefaultValidator hold the raw Cassandra
+	// marshal class definitions; they are decoded with parseType.
+	KeyValidator      string
+	Comparator        string
+	DefaultValidator  string
+	// The alias fields are decoded from the JSON alias columns; they are
+	// used by compileV1Metadata to reconstruct key and clustering columns.
+	KeyAliases        []string
+	ColumnAliases     []string
+	ValueAlias        string
+	// PartitionKey, ClusteringColumns, and Columns are computed by
+	// compileMetadata rather than read directly from the system tables.
+	PartitionKey      []*ColumnMetadata
+	ClusteringColumns []*ColumnMetadata
+	Columns           map[string]*ColumnMetadata
+}
+
+// ColumnMetadata describes the schema of a single column, as read from the
+// system.schema_columns table.
+type ColumnMetadata struct {
+	Keyspace       string
+	Table          string
+	Name           string
+	// ComponentIndex is the column's position within the partition key or
+	// clustering key, when applicable.
+	ComponentIndex int
+	// Kind is one of PARTITION_KEY, CLUSTERING_KEY, or REGULAR.
+	Kind           string
+	// Validator is the raw marshal class definition; Type and Order are
+	// decoded from it by compileMetadata.
+	Validator      string
+	Type           TypeInfo
+	Order          ColumnOrder
+	Index          ColumnIndexMetadata
+}
+
+// ColumnOrder is the ordering of the column with regard to its comparator:
+// ASC (the zero value) or DESC for a reversed comparator.
+type ColumnOrder bool
+
+const (
+	ASC  ColumnOrder = false
+	DESC             = true
+)
+
+// ColumnIndexMetadata describes a secondary index on a column.
+type ColumnIndexMetadata struct {
+	Name    string
+	Type    string
+	// Options is decoded from the JSON index_options column.
+	Options map[string]interface{}
+}
+
+// Column kind values, as stored in the "type" column of
+// system.schema_columns (protocol V2+).
+const (
+	PARTITION_KEY  = "partition_key"
+	CLUSTERING_KEY = "clustering_key"
+	REGULAR        = "regular"
+)
+
+// Default alias values, used by compileV1Metadata when a table defines no
+// explicit key, column, or value aliases.
+const (
+	DEFAULT_KEY_ALIAS    = "key"
+	DEFAULT_COLUMN_ALIAS = "column"
+	DEFAULT_VALUE_ALIAS  = "value"
+)
+
+// schemaDescriber queries the cluster for schema information for a specific
+// keyspace and caches the compiled result per keyspace name.
+type schemaDescriber struct {
+	session *Session
+	mu      sync.Mutex // guards cache
+
+	// cache maps keyspace name to its compiled metadata.
+	cache map[string]*KeyspaceMetadata
+}
+
+// newSchemaDescriber returns a schemaDescriber for the given session with an
+// empty metadata cache.
+func newSchemaDescriber(session *Session) *schemaDescriber {
+	return &schemaDescriber{
+		session: session,
+		cache:   map[string]*KeyspaceMetadata{},
+	}
+}
+
+// getSchema returns the cached KeyspaceMetadata for keyspaceName, querying
+// the cluster to populate the cache on the first request for a keyspace.
+// Safe for concurrent use: s.mu serializes all cache access.
+func (s *schemaDescriber) getSchema(keyspaceName string) (*KeyspaceMetadata, error) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	// TODO handle schema change events
+
+	metadata, found := s.cache[keyspaceName]
+	if !found {
+		// refresh the cache for this keyspace
+		err := s.refreshSchema(keyspaceName)
+		if err != nil {
+			return nil, err
+		}
+
+		metadata = s.cache[keyspaceName]
+	}
+
+	return metadata, nil
+}
+
+// refreshSchema re-queries the system schema tables for keyspaceName,
+// compiles the results together, and replaces the cache entry. The caller
+// (getSchema) must hold s.mu.
+func (s *schemaDescriber) refreshSchema(keyspaceName string) error {
+	var err error
+
+	// query the system keyspace for schema data
+	// TODO retrieve concurrently
+	keyspace, err := getKeyspaceMetadata(s.session, keyspaceName)
+	if err != nil {
+		return err
+	}
+	tables, err := getTableMetadata(s.session, keyspaceName)
+	if err != nil {
+		return err
+	}
+	columns, err := getColumnMetadata(s.session, keyspaceName)
+	if err != nil {
+		return err
+	}
+
+	// organize the schema data
+	compileMetadata(s.session.cfg.ProtoVersion, keyspace, tables, columns)
+
+	// update the cache
+	s.cache[keyspaceName] = keyspace
+
+	return nil
+}
+
+// compileMetadata "compiles" keyspace, table, and column metadata for a
+// keyspace together, linking the metadata objects together and calculating
+// the partition key and clustering key for each table. The keyspace is
+// mutated in place; tables and columns are linked into it by pointer.
+func compileMetadata(
+	protoVersion int,
+	keyspace *KeyspaceMetadata,
+	tables []TableMetadata,
+	columns []ColumnMetadata,
+) {
+	// index the tables by name; taking &tables[i] is safe because the slice
+	// is never resized after this point
+	keyspace.Tables = make(map[string]*TableMetadata)
+	for i := range tables {
+		tables[i].Columns = make(map[string]*ColumnMetadata)
+
+		keyspace.Tables[tables[i].Name] = &tables[i]
+	}
+
+	// add columns from the schema data
+	for i := range columns {
+		// decode the validator for TypeInfo and order
+		validatorParsed := parseType(columns[i].Validator)
+		columns[i].Type = validatorParsed.types[0]
+		columns[i].Order = ASC
+		if validatorParsed.reversed[0] {
+			columns[i].Order = DESC
+		}
+
+		// NOTE(review): assumes every column's table appears in tables; a
+		// column row for an unknown table would nil-pointer here — confirm
+		// the system tables guarantee this invariant.
+		table := keyspace.Tables[columns[i].Table]
+		table.Columns[columns[i].Name] = &columns[i]
+	}
+
+	// V1 reports less column metadata, so the keys must be reconstructed
+	// from the table aliases; V2+ labels key columns directly.
+	if protoVersion == 1 {
+		compileV1Metadata(tables)
+	} else {
+		compileV2Metadata(tables)
+	}
+}
+
+// compileV1Metadata derives PartitionKey and ClusteringColumns for each
+// table under the V1 protocol. V1 does not return as much column metadata
+// as V2+ (system.schema_columns has no "type" column), so the key structure
+// must be reconstructed from the key validator, the comparator, and the
+// table's alias lists.
+func compileV1Metadata(tables []TableMetadata) {
+	for i := range tables {
+		table := &tables[i]
+
+		// decode the key validator
+		keyValidatorParsed := parseType(table.KeyValidator)
+		// decode the comparator
+		comparatorParsed := parseType(table.Comparator)
+
+		// the partition key length is the same as the number of types in the
+		// key validator
+		table.PartitionKey = make([]*ColumnMetadata, len(keyValidatorParsed.types))
+
+		// V1 protocol only returns "regular" columns from
+		// system.schema_columns (there is no type field for columns)
+		// so the alias information is used to
+		// create the partition key and clustering columns
+
+		// construct the partition key from the alias; fall back to the
+		// default "key", "key2", "key3", ... names when aliases run out
+		for i := range table.PartitionKey {
+			var alias string
+			if len(table.KeyAliases) > i {
+				alias = table.KeyAliases[i]
+			} else if i == 0 {
+				alias = DEFAULT_KEY_ALIAS
+			} else {
+				alias = DEFAULT_KEY_ALIAS + strconv.Itoa(i+1)
+			}
+
+			column := &ColumnMetadata{
+				Keyspace:       table.Keyspace,
+				Table:          table.Name,
+				Name:           alias,
+				Type:           keyValidatorParsed.types[i],
+				Kind:           PARTITION_KEY,
+				ComponentIndex: i,
+			}
+
+			table.PartitionKey[i] = column
+			table.Columns[alias] = column
+		}
+
+		// determine the number of clustering columns: a composite
+		// comparator's trailing component is not a clustering column when it
+		// is the collection map or the CQL3 column-name component.
+		// NOTE(review): this heuristic mirrors how CQL3 tables were encoded
+		// under thrift-era schemas — verify against the Cassandra 1.2 schema
+		// docs before changing.
+		size := len(comparatorParsed.types)
+		if comparatorParsed.isComposite {
+			if len(comparatorParsed.collections) != 0 ||
+				(len(table.ColumnAliases) == size-1 &&
+					comparatorParsed.types[size-1].Type == TypeVarchar) {
+				size = size - 1
+			}
+		} else {
+			if !(len(table.ColumnAliases) != 0 || len(table.Columns) == 0) {
+				size = 0
+			}
+		}
+
+		table.ClusteringColumns = make([]*ColumnMetadata, size)
+
+		for i := range table.ClusteringColumns {
+			var alias string
+			if len(table.ColumnAliases) > i {
+				alias = table.ColumnAliases[i]
+			} else if i == 0 {
+				alias = DEFAULT_COLUMN_ALIAS
+			} else {
+				alias = DEFAULT_COLUMN_ALIAS + strconv.Itoa(i+1)
+			}
+
+			// the comparator component's reversed flag gives the order
+			order := ASC
+			if comparatorParsed.reversed[i] {
+				order = DESC
+			}
+
+			column := &ColumnMetadata{
+				Keyspace:       table.Keyspace,
+				Table:          table.Name,
+				Name:           alias,
+				Type:           comparatorParsed.types[i],
+				Order:          order,
+				Kind:           CLUSTERING_KEY,
+				ComponentIndex: i,
+			}
+
+			table.ClusteringColumns[i] = column
+			table.Columns[alias] = column
+		}
+
+		// when the clustering columns did not consume all comparator
+		// components, the table has a single value column named by the
+		// value alias (or the default) and typed by the default validator
+		if size != len(comparatorParsed.types)-1 {
+			alias := DEFAULT_VALUE_ALIAS
+			if len(table.ValueAlias) > 0 {
+				alias = table.ValueAlias
+			}
+			// decode the default validator
+			defaultValidatorParsed := parseType(table.DefaultValidator)
+			column := &ColumnMetadata{
+				Keyspace: table.Keyspace,
+				Table:    table.Name,
+				Name:     alias,
+				Type:     defaultValidatorParsed.types[0],
+				Kind:     REGULAR,
+			}
+			table.Columns[alias] = column
+		}
+	}
+}
+
+// compileV2Metadata is the simpler compile case for the V2+ protocol: the
+// Kind and ComponentIndex reported by the server are used directly to place
+// each column into the partition key and clustering column slices.
+func compileV2Metadata(tables []TableMetadata) {
+	for i := range tables {
+		table := &tables[i]
+
+		// size the key slices from the column kind counts
+		partitionColumnCount := countColumnsOfKind(table.Columns, PARTITION_KEY)
+		table.PartitionKey = make([]*ColumnMetadata, partitionColumnCount)
+
+		clusteringColumnCount := countColumnsOfKind(table.Columns, CLUSTERING_KEY)
+		table.ClusteringColumns = make([]*ColumnMetadata, clusteringColumnCount)
+
+		// place each key column at its reported component index
+		for _, column := range table.Columns {
+			if column.Kind == PARTITION_KEY {
+				table.PartitionKey[column.ComponentIndex] = column
+			} else if column.Kind == CLUSTERING_KEY {
+				table.ClusteringColumns[column.ComponentIndex] = column
+			}
+		}
+
+	}
+}
+
+// countColumnsOfKind returns the number of columns whose Kind equals kind.
+func countColumnsOfKind(columns map[string]*ColumnMetadata, kind string) int {
+	count := 0
+	for _, column := range columns {
+		if column.Kind == kind {
+			count++
+		}
+	}
+	return count
+}
+
+// getKeyspaceMetadata queries system.schema_keyspaces for the keyspace-level
+// metadata of the specified keyspace. Tables is left nil; it is populated
+// later by compileMetadata.
+func getKeyspaceMetadata(
+	session *Session,
+	keyspaceName string,
+) (*KeyspaceMetadata, error) {
+	query := session.Query(
+		`
+		SELECT durable_writes, strategy_class, strategy_options
+		FROM system.schema_keyspaces
+		WHERE keyspace_name = ?
+		`,
+		keyspaceName,
+	)
+
+	keyspace := &KeyspaceMetadata{Name: keyspaceName}
+	var strategyOptionsJSON []byte
+
+	err := query.Scan(
+		&keyspace.DurableWrites,
+		&keyspace.StrategyClass,
+		&strategyOptionsJSON,
+	)
+	if err != nil {
+		return nil, fmt.Errorf("Error querying keyspace schema: %v", err)
+	}
+
+	// strategy_options is stored as a JSON map
+	// NOTE(review): message reads "for in keyspace" — looks like a stray
+	// "for"; consider fixing the wording in a follow-up.
+	err = json.Unmarshal(strategyOptionsJSON, &keyspace.StrategyOptions)
+	if err != nil {
+		return nil, fmt.Errorf(
+			"Invalid JSON value '%s' as strategy_options for in keyspace '%s': %v",
+			strategyOptionsJSON, keyspace.Name, err,
+		)
+	}
+
+	return keyspace, nil
+}
+
+// getTableMetadata queries system.schema_columnfamilies for the table-level
+// metadata of every table in the specified keyspace. The alias columns are
+// stored as JSON arrays and are decoded here; key/clustering structure is
+// computed later by compileMetadata.
+func getTableMetadata(
+	session *Session,
+	keyspaceName string,
+) ([]TableMetadata, error) {
+	query := session.Query(
+		`
+		SELECT
+			columnfamily_name,
+			key_validator,
+			comparator,
+			default_validator,
+			key_aliases,
+			column_aliases,
+			value_alias
+		FROM system.schema_columnfamilies
+		WHERE keyspace_name = ?
+		`,
+		keyspaceName,
+	)
+	iter := query.Iter()
+
+	tables := []TableMetadata{}
+	table := TableMetadata{Keyspace: keyspaceName}
+
+	var keyAliasesJSON []byte
+	var columnAliasesJSON []byte
+	for iter.Scan(
+		&table.Name,
+		&table.KeyValidator,
+		&table.Comparator,
+		&table.DefaultValidator,
+		&keyAliasesJSON,
+		&columnAliasesJSON,
+		&table.ValueAlias,
+	) {
+		var err error
+
+		// decode the key aliases
+		if keyAliasesJSON != nil {
+			table.KeyAliases = []string{}
+			err = json.Unmarshal(keyAliasesJSON, &table.KeyAliases)
+			if err != nil {
+				iter.Close()
+				return nil, fmt.Errorf(
+					"Invalid JSON value '%s' as key_aliases for in table '%s': %v",
+					keyAliasesJSON, table.Name, err,
+				)
+			}
+		}
+
+		// decode the column aliases
+		if columnAliasesJSON != nil {
+			table.ColumnAliases = []string{}
+			err = json.Unmarshal(columnAliasesJSON, &table.ColumnAliases)
+			if err != nil {
+				iter.Close()
+				return nil, fmt.Errorf(
+					"Invalid JSON value '%s' as column_aliases for in table '%s': %v",
+					columnAliasesJSON, table.Name, err,
+				)
+			}
+		}
+
+		// reset the scan target for the next row
+		tables = append(tables, table)
+		table = TableMetadata{Keyspace: keyspaceName}
+	}
+
+	// ErrNotFound simply means the keyspace has no tables
+	err := iter.Close()
+	if err != nil && err != ErrNotFound {
+		return nil, fmt.Errorf("Error querying table schema: %v", err)
+	}
+
+	return tables, nil
+}
+
+// getColumnMetadata queries system.schema_columns for the column-level
+// metadata of every column in the specified keyspace. The statement and the
+// row-scanning closure differ by protocol version: V1 has no "type" column
+// (all rows are treated as REGULAR), while V2+ selects it into column.Kind.
+func getColumnMetadata(
+	session *Session,
+	keyspaceName string,
+) ([]ColumnMetadata, error) {
+	// Deal with differences in protocol versions
+	var stmt string
+	var scan func(*Iter, *ColumnMetadata, *[]byte) bool
+	if session.cfg.ProtoVersion == 1 {
+		// V1 does not support the type column, and all returned rows are
+		// of kind "regular".
+		stmt = `
+			SELECT
+				columnfamily_name,
+				column_name,
+				component_index,
+				validator,
+				index_name,
+				index_type,
+				index_options
+			FROM system.schema_columns
+			WHERE keyspace_name = ?
+			`
+		scan = func(
+			iter *Iter,
+			column *ColumnMetadata,
+			indexOptionsJSON *[]byte,
+		) bool {
+			// all columns returned by V1 are regular
+			column.Kind = REGULAR
+			// NOTE(review): indexOptionsJSON is already a *[]byte, so
+			// &indexOptionsJSON passes a **[]byte to Scan — confirm the
+			// Iter.Scan/Unmarshal path handles the double pointer.
+			return iter.Scan(
+				&column.Table,
+				&column.Name,
+				&column.ComponentIndex,
+				&column.Validator,
+				&column.Index.Name,
+				&column.Index.Type,
+				&indexOptionsJSON,
+			)
+		}
+	} else {
+		// V2+ supports the type column
+		stmt = `
+			SELECT
+				columnfamily_name,
+				column_name,
+				component_index,
+				validator,
+				index_name,
+				index_type,
+				index_options,
+				type
+			FROM system.schema_columns
+			WHERE keyspace_name = ?
+			`
+		scan = func(
+			iter *Iter,
+			column *ColumnMetadata,
+			indexOptionsJSON *[]byte,
+		) bool {
+			// NOTE(review): same **[]byte concern as the V1 closure above.
+			return iter.Scan(
+				&column.Table,
+				&column.Name,
+				&column.ComponentIndex,
+				&column.Validator,
+				&column.Index.Name,
+				&column.Index.Type,
+				&indexOptionsJSON,
+				&column.Kind,
+			)
+		}
+	}
+
+	// get the columns metadata
+	columns := []ColumnMetadata{}
+	column := ColumnMetadata{Keyspace: keyspaceName}
+
+	var indexOptionsJSON []byte
+
+	query := session.Query(stmt, keyspaceName)
+	iter := query.Iter()
+
+	for scan(iter, &column, &indexOptionsJSON) {
+		var err error
+
+		// decode the index options stored as a JSON map
+		if indexOptionsJSON != nil {
+			err = json.Unmarshal(indexOptionsJSON, &column.Index.Options)
+			if err != nil {
+				iter.Close()
+				return nil, fmt.Errorf(
+					"Invalid JSON value '%s' as index_options for column '%s' in table '%s': %v",
+					indexOptionsJSON,
+					column.Name,
+					column.Table,
+					err,
+				)
+			}
+		}
+
+		// reset the scan target for the next row
+		columns = append(columns, column)
+		column = ColumnMetadata{Keyspace: keyspaceName}
+	}
+
+	// ErrNotFound simply means the keyspace has no columns
+	err := iter.Close()
+	if err != nil && err != ErrNotFound {
+		return nil, fmt.Errorf("Error querying column schema: %v", err)
+	}
+
+	return columns, nil
+}
+
+// typeParser holds the state (input string and cursor) for parsing a
+// Cassandra marshal-class type definition.
+type typeParser struct {
+	input string
+	index int
+}
+
+// typeParserResult is the result of parsing a type definition: the
+// component types in order, a parallel reversed flag per component, whether
+// the definition was a CompositeType, and any collection columns found in a
+// trailing ColumnToCollectionType parameter.
+type typeParserResult struct {
+	isComposite bool
+	types       []TypeInfo
+	reversed    []bool
+	collections map[string]TypeInfo
+}
+
+// parseType parses the type definition used for validator and comparator
+// schema data (e.g. "org.apache.cassandra.db.marshal.UTF8Type").
+func parseType(def string) typeParserResult {
+	parser := &typeParser{input: def}
+	return parser.parse()
+}
+
+// Fully-qualified Cassandra marshal class names recognized by the parser.
+const (
+	REVERSED_TYPE   = "org.apache.cassandra.db.marshal.ReversedType"
+	COMPOSITE_TYPE  = "org.apache.cassandra.db.marshal.CompositeType"
+	COLLECTION_TYPE = "org.apache.cassandra.db.marshal.ColumnToCollectionType"
+	LIST_TYPE       = "org.apache.cassandra.db.marshal.ListType"
+	SET_TYPE        = "org.apache.cassandra.db.marshal.SetType"
+	MAP_TYPE        = "org.apache.cassandra.db.marshal.MapType"
+)
+
+// typeParserClassNode represents a class specification in the type def AST:
+// a class name plus its optional parenthesized parameters.
+type typeParserClassNode struct {
+	name   string
+	params []typeParserParamNode
+	// this is the segment of the input string that defined this node
+	input string
+}
+
+// typeParserParamNode represents a class parameter in the type def AST.
+// name is nil for unnamed parameters; when present it is a hex-encoded
+// UTF-8 string (used by ColumnToCollectionType).
+type typeParserParamNode struct {
+	name  *string
+	class typeParserClassNode
+}
+
+// parse parses t.input into an AST and interprets it into a
+// typeParserResult. Unparseable input is treated as a single custom type
+// rather than an error.
+func (t *typeParser) parse() typeParserResult {
+	// parse the AST
+	ast, ok := t.parseClassNode()
+	if !ok {
+		// treat this is a custom type
+		return typeParserResult{
+			isComposite: false,
+			types: []TypeInfo{
+				TypeInfo{
+					Type:   TypeCustom,
+					Custom: t.input,
+				},
+			},
+			reversed:    []bool{false},
+			collections: nil,
+		}
+	}
+
+	// interpret the AST
+	if strings.HasPrefix(ast.name, COMPOSITE_TYPE) {
+		count := len(ast.params)
+
+		// look for a collections param: a trailing ColumnToCollectionType
+		// carries the collection columns and is not a key component
+		// NOTE(review): ast.params[count-1] panics if a CompositeType has no
+		// parameters — confirm the schema never produces an empty composite.
+		last := ast.params[count-1]
+		collections := map[string]TypeInfo{}
+		if strings.HasPrefix(last.class.name, COLLECTION_TYPE) {
+			count--
+
+			for _, param := range last.class.params {
+				// decode the hex-encoded collection column name
+				var name string
+				decoded, err := hex.DecodeString(*param.name)
+				if err != nil {
+					log.Printf(
+						"Error parsing type '%s', contains collection name '%s' with an invalid format: %v",
+						t.input,
+						*param.name,
+						err,
+					)
+					// just use the provided name
+					name = *param.name
+				} else {
+					name = string(decoded)
+				}
+				collections[name] = param.class.asTypeInfo()
+			}
+		}
+
+		types := make([]TypeInfo, count)
+		reversed := make([]bool, count)
+
+		// unwrap ReversedType wrappers, recording the direction per component
+		for i, param := range ast.params[:count] {
+			class := param.class
+			reversed[i] = strings.HasPrefix(class.name, REVERSED_TYPE)
+			if reversed[i] {
+				class = class.params[0].class
+			}
+			types[i] = class.asTypeInfo()
+		}
+
+		return typeParserResult{
+			isComposite: true,
+			types:       types,
+			reversed:    reversed,
+			collections: collections,
+		}
+	} else {
+		// not composite, so one type
+		class := *ast
+		reversed := strings.HasPrefix(class.name, REVERSED_TYPE)
+		if reversed {
+			class = class.params[0].class
+		}
+		typeInfo := class.asTypeInfo()
+
+		return typeParserResult{
+			isComposite: false,
+			types:       []TypeInfo{typeInfo},
+			reversed:    []bool{reversed},
+		}
+	}
+}
+
+// asTypeInfo converts a parsed class node into a TypeInfo, recursing into
+// the element (and key, for maps) parameters of collection classes. A class
+// name that is neither a collection nor a known simple type becomes
+// TypeCustom with the full class definition preserved in Custom.
+func (class *typeParserClassNode) asTypeInfo() TypeInfo {
+	if strings.HasPrefix(class.name, LIST_TYPE) {
+		elem := class.params[0].class.asTypeInfo()
+		return TypeInfo{
+			Type: TypeList,
+			Elem: &elem,
+		}
+	}
+	if strings.HasPrefix(class.name, SET_TYPE) {
+		elem := class.params[0].class.asTypeInfo()
+		return TypeInfo{
+			Type: TypeSet,
+			Elem: &elem,
+		}
+	}
+	if strings.HasPrefix(class.name, MAP_TYPE) {
+		key := class.params[0].class.asTypeInfo()
+		elem := class.params[1].class.asTypeInfo()
+		return TypeInfo{
+			Type: TypeMap,
+			Key:  &key,
+			Elem: &elem,
+		}
+	}
+
+	// must be a simple type or custom type
+	info := TypeInfo{Type: getApacheCassandraType(class.name)}
+	if info.Type == TypeCustom {
+		// add the entire class definition
+		info.Custom = class.input
+	}
+	return info
+}
+
+// parseClassNode parses a class specification at the current cursor
+// position, recording the exact input segment it consumed. Grammar:
+//
+// CLASS := ID [ PARAMS ]
+func (t *typeParser) parseClassNode() (node *typeParserClassNode, ok bool) {
+	t.skipWhitespace()
+
+	startIndex := t.index
+
+	name, ok := t.nextIdentifier()
+	if !ok {
+		return nil, false
+	}
+
+	params, ok := t.parseParamNodes()
+	if !ok {
+		return nil, false
+	}
+
+	endIndex := t.index
+
+	node = &typeParserClassNode{
+		name:   name,
+		params: params,
+		input:  t.input[startIndex:endIndex],
+	}
+	return node, true
+}
+
+// parseParamNodes parses an optional parenthesized parameter list at the
+// current cursor position. Grammar:
+//
+// PARAMS := "(" PARAM { "," PARAM } ")"
+// PARAM := [ PARAM_NAME ":" ] CLASS
+// PARAM_NAME := ID
+//
+// NOTE(review): after the opening '(' the loop indexes t.input without a
+// bounds check, so truncated/malformed input would panic — confirm the
+// schema source guarantees well-formed definitions.
+func (t *typeParser) parseParamNodes() (params []typeParserParamNode, ok bool) {
+	t.skipWhitespace()
+
+	// the params are optional
+	if t.index == len(t.input) || t.input[t.index] != '(' {
+		return nil, true
+	}
+
+	params = []typeParserParamNode{}
+
+	// consume the '('
+	t.index++
+
+	t.skipWhitespace()
+
+	for t.input[t.index] != ')' {
+		// look for a named param, but if no colon, then we want to backup
+		backupIndex := t.index
+
+		// name will be a hex encoded version of a utf-8 string
+		name, ok := t.nextIdentifier()
+		if !ok {
+			return nil, false
+		}
+		hasName := true
+
+		// TODO handle '=>' used for DynamicCompositeType
+
+		t.skipWhitespace()
+
+		if t.input[t.index] == ':' {
+			// there is a name for this parameter
+
+			// consume the ':'
+			t.index++
+
+			t.skipWhitespace()
+		} else {
+			// no name, backup
+			hasName = false
+			t.index = backupIndex
+		}
+
+		// parse the next full parameter
+		classNode, ok := t.parseClassNode()
+		if !ok {
+			return nil, false
+		}
+
+		if hasName {
+			params = append(
+				params,
+				typeParserParamNode{name: &name, class: *classNode},
+			)
+		} else {
+			params = append(
+				params,
+				typeParserParamNode{class: *classNode},
+			)
+		}
+
+		t.skipWhitespace()
+
+		if t.input[t.index] == ',' {
+			// consume the comma
+			t.index++
+
+			t.skipWhitespace()
+		}
+	}
+
+	// consume the ')'
+	t.index++
+
+	return params, true
+}
+
+// skipWhitespace advances the cursor past any whitespace characters.
+func (t *typeParser) skipWhitespace() {
+	for t.index < len(t.input) && isWhitespaceChar(t.input[t.index]) {
+		t.index++
+	}
+}
+
+// isWhitespaceChar reports whether c is a whitespace byte recognized by the
+// parser (space, newline, or tab).
+func isWhitespaceChar(c byte) bool {
+	return c == ' ' || c == '\n' || c == '\t'
+}
+
+// nextIdentifier consumes and returns the identifier at the current cursor
+// position; found is false when no identifier characters are present.
+// Grammar:
+//
+// ID := LETTER { LETTER }
+// LETTER := "0"..."9" | "a"..."z" | "A"..."Z" | "-" | "+" | "." | "_" | "&"
+func (t *typeParser) nextIdentifier() (id string, found bool) {
+	startIndex := t.index
+	for t.index < len(t.input) && isIdentifierChar(t.input[t.index]) {
+		t.index++
+	}
+	if startIndex == t.index {
+		return "", false
+	}
+	return t.input[startIndex:t.index], true
+}
+
+// isIdentifierChar reports whether c may appear in an identifier; the set
+// includes '.' and '&' so fully-qualified Java class names parse as one ID.
+func isIdentifierChar(c byte) bool {
+	return (c >= '0' && c <= '9') ||
+		(c >= 'a' && c <= 'z') ||
+		(c >= 'A' && c <= 'Z') ||
+		c == '-' ||
+		c == '+' ||
+		c == '.' ||
+		c == '_' ||
+		c == '&'
+}

+ 670 - 0
metadata_test.go

@@ -0,0 +1,670 @@
+// Copyright (c) 2015 The gocql Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gocql
+
+import (
+	"strconv"
+	"testing"
+)
+
+func TestCompileMetadata(t *testing.T) {
+	// V1 tests - these are all based on real examples from the integration test ccm cluster
+	keyspace := &KeyspaceMetadata{
+		Name: "V1Keyspace",
+	}
+	tables := []TableMetadata{
+		TableMetadata{
+			// This table, found in the system keyspace, has no key aliases or column aliases
+			Keyspace:         "V1Keyspace",
+			Name:             "Schema",
+			KeyValidator:     "org.apache.cassandra.db.marshal.BytesType",
+			Comparator:       "org.apache.cassandra.db.marshal.UTF8Type",
+			DefaultValidator: "org.apache.cassandra.db.marshal.BytesType",
+			KeyAliases:       []string{},
+			ColumnAliases:    []string{},
+			ValueAlias:       "",
+		},
+		TableMetadata{
+			// This table, found in the system keyspace, has key aliases, column aliases, and a value alias.
+			Keyspace:         "V1Keyspace",
+			Name:             "hints",
+			KeyValidator:     "org.apache.cassandra.db.marshal.UUIDType",
+			Comparator:       "org.apache.cassandra.db.marshal.CompositeType(org.apache.cassandra.db.marshal.TimeUUIDType,org.apache.cassandra.db.marshal.Int32Type)",
+			DefaultValidator: "org.apache.cassandra.db.marshal.BytesType",
+			KeyAliases:       []string{"target_id"},
+			ColumnAliases:    []string{"hint_id", "message_version"},
+			ValueAlias:       "mutation",
+		},
+		TableMetadata{
+			// This table, found in the system keyspace, has a comparator with collections, but no column aliases
+			Keyspace:         "V1Keyspace",
+			Name:             "peers",
+			KeyValidator:     "org.apache.cassandra.db.marshal.InetAddressType",
+			Comparator:       "org.apache.cassandra.db.marshal.CompositeType(org.apache.cassandra.db.marshal.UTF8Type,org.apache.cassandra.db.marshal.ColumnToCollectionType(746f6b656e73:org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.UTF8Type)))",
+			DefaultValidator: "org.apache.cassandra.db.marshal.BytesType",
+			KeyAliases:       []string{"peer"},
+			ColumnAliases:    []string{},
+			ValueAlias:       "",
+		},
+		TableMetadata{
+			// This table, found in the system keyspace, has a column alias, but not a composite comparator
+			Keyspace:         "V1Keyspace",
+			Name:             "IndexInfo",
+			KeyValidator:     "org.apache.cassandra.db.marshal.UTF8Type",
+			Comparator:       "org.apache.cassandra.db.marshal.UTF8Type",
+			DefaultValidator: "org.apache.cassandra.db.marshal.BytesType",
+			KeyAliases:       []string{"table_name"},
+			ColumnAliases:    []string{"index_name"},
+			ValueAlias:       "",
+		},
+		TableMetadata{
+			// This table, found in the gocql_test keyspace following an integration test run, has a composite comparator with collections as well as a column alias
+			Keyspace:         "V1Keyspace",
+			Name:             "wiki_page",
+			KeyValidator:     "org.apache.cassandra.db.marshal.UTF8Type",
+			Comparator:       "org.apache.cassandra.db.marshal.CompositeType(org.apache.cassandra.db.marshal.TimeUUIDType,org.apache.cassandra.db.marshal.UTF8Type,org.apache.cassandra.db.marshal.ColumnToCollectionType(74616773:org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.UTF8Type),6174746163686d656e7473:org.apache.cassandra.db.marshal.MapType(org.apache.cassandra.db.marshal.UTF8Type,org.apache.cassandra.db.marshal.BytesType)))",
+			DefaultValidator: "org.apache.cassandra.db.marshal.BytesType",
+			KeyAliases:       []string{"title"},
+			ColumnAliases:    []string{"revid"},
+			ValueAlias:       "",
+		},
+	}
+	columns := []ColumnMetadata{
+		// Here are the regular columns from the peers table for testing regular columns
+		ColumnMetadata{Keyspace: "V1Keyspace", Table: "peers", Kind: REGULAR, Name: "data_center", ComponentIndex: 0, Validator: "org.apache.cassandra.db.marshal.UTF8Type"},
+		ColumnMetadata{Keyspace: "V1Keyspace", Table: "peers", Kind: REGULAR, Name: "host_id", ComponentIndex: 0, Validator: "org.apache.cassandra.db.marshal.UUIDType"},
+		ColumnMetadata{Keyspace: "V1Keyspace", Table: "peers", Kind: REGULAR, Name: "rack", ComponentIndex: 0, Validator: "org.apache.cassandra.db.marshal.UTF8Type"},
+		ColumnMetadata{Keyspace: "V1Keyspace", Table: "peers", Kind: REGULAR, Name: "release_version", ComponentIndex: 0, Validator: "org.apache.cassandra.db.marshal.UTF8Type"},
+		ColumnMetadata{Keyspace: "V1Keyspace", Table: "peers", Kind: REGULAR, Name: "rpc_address", ComponentIndex: 0, Validator: "org.apache.cassandra.db.marshal.InetAddressType"},
+		ColumnMetadata{Keyspace: "V1Keyspace", Table: "peers", Kind: REGULAR, Name: "schema_version", ComponentIndex: 0, Validator: "org.apache.cassandra.db.marshal.UUIDType"},
+		ColumnMetadata{Keyspace: "V1Keyspace", Table: "peers", Kind: REGULAR, Name: "tokens", ComponentIndex: 0, Validator: "org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.UTF8Type)"},
+	}
+	compileMetadata(1, keyspace, tables, columns)
+	assertKeyspaceMetadata(
+		t,
+		keyspace,
+		&KeyspaceMetadata{
+			Name: "V1Keyspace",
+			Tables: map[string]*TableMetadata{
+				"Schema": &TableMetadata{
+					PartitionKey: []*ColumnMetadata{
+						&ColumnMetadata{
+							Name: "key",
+							Type: TypeInfo{Type: TypeBlob},
+						},
+					},
+					ClusteringColumns: []*ColumnMetadata{},
+					Columns: map[string]*ColumnMetadata{
+						"key": &ColumnMetadata{
+							Name: "key",
+							Type: TypeInfo{Type: TypeBlob},
+							Kind: PARTITION_KEY,
+						},
+					},
+				},
+				"hints": &TableMetadata{
+					PartitionKey: []*ColumnMetadata{
+						&ColumnMetadata{
+							Name: "target_id",
+							Type: TypeInfo{Type: TypeUUID},
+						},
+					},
+					ClusteringColumns: []*ColumnMetadata{
+						&ColumnMetadata{
+							Name:  "hint_id",
+							Type:  TypeInfo{Type: TypeTimeUUID},
+							Order: ASC,
+						},
+						&ColumnMetadata{
+							Name:  "message_version",
+							Type:  TypeInfo{Type: TypeInt},
+							Order: ASC,
+						},
+					},
+					Columns: map[string]*ColumnMetadata{
+						"target_id": &ColumnMetadata{
+							Name: "target_id",
+							Type: TypeInfo{Type: TypeUUID},
+							Kind: PARTITION_KEY,
+						},
+						"hint_id": &ColumnMetadata{
+							Name:  "hint_id",
+							Type:  TypeInfo{Type: TypeTimeUUID},
+							Order: ASC,
+							Kind:  CLUSTERING_KEY,
+						},
+						"message_version": &ColumnMetadata{
+							Name:  "message_version",
+							Type:  TypeInfo{Type: TypeInt},
+							Order: ASC,
+							Kind:  CLUSTERING_KEY,
+						},
+						"mutation": &ColumnMetadata{
+							Name: "mutation",
+							Type: TypeInfo{Type: TypeBlob},
+							Kind: REGULAR,
+						},
+					},
+				},
+				"peers": &TableMetadata{
+					PartitionKey: []*ColumnMetadata{
+						&ColumnMetadata{
+							Name: "peer",
+							Type: TypeInfo{Type: TypeInet},
+						},
+					},
+					ClusteringColumns: []*ColumnMetadata{},
+					Columns: map[string]*ColumnMetadata{
+						"peer": &ColumnMetadata{
+							Name: "peer",
+							Type: TypeInfo{Type: TypeInet},
+							Kind: PARTITION_KEY,
+						},
+						"data_center":     &ColumnMetadata{Keyspace: "V1Keyspace", Table: "peers", Kind: REGULAR, Name: "data_center", ComponentIndex: 0, Validator: "org.apache.cassandra.db.marshal.UTF8Type", Type: TypeInfo{Type: TypeVarchar}},
+						"host_id":         &ColumnMetadata{Keyspace: "V1Keyspace", Table: "peers", Kind: REGULAR, Name: "host_id", ComponentIndex: 0, Validator: "org.apache.cassandra.db.marshal.UUIDType", Type: TypeInfo{Type: TypeUUID}},
+						"rack":            &ColumnMetadata{Keyspace: "V1Keyspace", Table: "peers", Kind: REGULAR, Name: "rack", ComponentIndex: 0, Validator: "org.apache.cassandra.db.marshal.UTF8Type", Type: TypeInfo{Type: TypeVarchar}},
+						"release_version": &ColumnMetadata{Keyspace: "V1Keyspace", Table: "peers", Kind: REGULAR, Name: "release_version", ComponentIndex: 0, Validator: "org.apache.cassandra.db.marshal.UTF8Type", Type: TypeInfo{Type: TypeVarchar}},
+						"rpc_address":     &ColumnMetadata{Keyspace: "V1Keyspace", Table: "peers", Kind: REGULAR, Name: "rpc_address", ComponentIndex: 0, Validator: "org.apache.cassandra.db.marshal.InetAddressType", Type: TypeInfo{Type: TypeInet}},
+						"schema_version":  &ColumnMetadata{Keyspace: "V1Keyspace", Table: "peers", Kind: REGULAR, Name: "schema_version", ComponentIndex: 0, Validator: "org.apache.cassandra.db.marshal.UUIDType", Type: TypeInfo{Type: TypeUUID}},
+						"tokens":          &ColumnMetadata{Keyspace: "V1Keyspace", Table: "peers", Kind: REGULAR, Name: "tokens", ComponentIndex: 0, Validator: "org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.UTF8Type)", Type: TypeInfo{Type: TypeSet}},
+					},
+				},
+				"IndexInfo": &TableMetadata{
+					PartitionKey: []*ColumnMetadata{
+						&ColumnMetadata{
+							Name: "table_name",
+							Type: TypeInfo{Type: TypeVarchar},
+						},
+					},
+					ClusteringColumns: []*ColumnMetadata{
+						&ColumnMetadata{
+							Name:  "index_name",
+							Type:  TypeInfo{Type: TypeVarchar},
+							Order: ASC,
+						},
+					},
+					Columns: map[string]*ColumnMetadata{
+						"table_name": &ColumnMetadata{
+							Name: "table_name",
+							Type: TypeInfo{Type: TypeVarchar},
+							Kind: PARTITION_KEY,
+						},
+						"index_name": &ColumnMetadata{
+							Name: "index_name",
+							Type: TypeInfo{Type: TypeVarchar},
+							Kind: CLUSTERING_KEY,
+						},
+						"value": &ColumnMetadata{
+							Name: "value",
+							Type: TypeInfo{Type: TypeBlob},
+							Kind: REGULAR,
+						},
+					},
+				},
+				"wiki_page": &TableMetadata{
+					PartitionKey: []*ColumnMetadata{
+						&ColumnMetadata{
+							Name: "title",
+							Type: TypeInfo{Type: TypeVarchar},
+						},
+					},
+					ClusteringColumns: []*ColumnMetadata{
+						&ColumnMetadata{
+							Name:  "revid",
+							Type:  TypeInfo{Type: TypeTimeUUID},
+							Order: ASC,
+						},
+					},
+					Columns: map[string]*ColumnMetadata{
+						"title": &ColumnMetadata{
+							Name: "title",
+							Type: TypeInfo{Type: TypeVarchar},
+							Kind: PARTITION_KEY,
+						},
+						"revid": &ColumnMetadata{
+							Name: "revid",
+							Type: TypeInfo{Type: TypeTimeUUID},
+							Kind: CLUSTERING_KEY,
+						},
+					},
+				},
+			},
+		},
+	)
+
+	// V2 test - V2+ protocol is simpler so here are some toy examples to verify that the mapping works
+	keyspace = &KeyspaceMetadata{
+		Name: "V2Keyspace",
+	}
+	tables = []TableMetadata{
+		TableMetadata{
+			Keyspace: "V2Keyspace",
+			Name:     "Table1",
+		},
+		TableMetadata{
+			Keyspace: "V2Keyspace",
+			Name:     "Table2",
+		},
+	}
+	columns = []ColumnMetadata{
+		ColumnMetadata{
+			Keyspace:  "V2Keyspace",
+			Table:     "Table1",
+			Name:      "Key1",
+			Kind:      PARTITION_KEY,
+			Validator: "org.apache.cassandra.db.marshal.UTF8Type",
+		},
+		ColumnMetadata{
+			Keyspace:  "V2Keyspace",
+			Table:     "Table2",
+			Name:      "Column1",
+			Kind:      PARTITION_KEY,
+			Validator: "org.apache.cassandra.db.marshal.UTF8Type",
+		},
+		ColumnMetadata{
+			Keyspace:  "V2Keyspace",
+			Table:     "Table2",
+			Name:      "Column2",
+			Kind:      CLUSTERING_KEY,
+			Validator: "org.apache.cassandra.db.marshal.UTF8Type",
+		},
+		ColumnMetadata{
+			Keyspace:  "V2Keyspace",
+			Table:     "Table2",
+			Name:      "Column3",
+			Kind:      REGULAR,
+			Validator: "org.apache.cassandra.db.marshal.UTF8Type",
+		},
+	}
+	compileMetadata(2, keyspace, tables, columns)
+	assertKeyspaceMetadata(
+		t,
+		keyspace,
+		&KeyspaceMetadata{
+			Name: "V2Keyspace",
+			Tables: map[string]*TableMetadata{
+				"Table1": &TableMetadata{
+					PartitionKey: []*ColumnMetadata{
+						&ColumnMetadata{
+							Name: "Key1",
+							Type: TypeInfo{Type: TypeVarchar},
+						},
+					},
+					ClusteringColumns: []*ColumnMetadata{},
+					Columns: map[string]*ColumnMetadata{
+						"Key1": &ColumnMetadata{
+							Name: "Key1",
+							Type: TypeInfo{Type: TypeVarchar},
+							Kind: PARTITION_KEY,
+						},
+					},
+				},
+				"Table2": &TableMetadata{
+					PartitionKey: []*ColumnMetadata{
+						&ColumnMetadata{
+							Name: "Column1",
+							Type: TypeInfo{Type: TypeVarchar},
+						},
+					},
+					ClusteringColumns: []*ColumnMetadata{
+						&ColumnMetadata{
+							Name: "Column2",
+							Type: TypeInfo{Type: TypeVarchar},
+						},
+					},
+					Columns: map[string]*ColumnMetadata{
+						"Column1": &ColumnMetadata{
+							Name: "Column1",
+							Type: TypeInfo{Type: TypeVarchar},
+							Kind: PARTITION_KEY,
+						},
+						"Column2": &ColumnMetadata{
+							Name: "Column2",
+							Type: TypeInfo{Type: TypeVarchar},
+							Kind: CLUSTERING_KEY,
+						},
+						"Column3": &ColumnMetadata{
+							Name: "Column3",
+							Type: TypeInfo{Type: TypeVarchar},
+							Kind: REGULAR,
+						},
+					},
+				},
+			},
+		},
+	)
+}
+
+func assertKeyspaceMetadata(t *testing.T, actual, expected *KeyspaceMetadata) {
+	if len(expected.Tables) != len(actual.Tables) {
+		t.Errorf("Expected len(%s.Tables) to be %v but was %v", expected.Name, len(expected.Tables), len(actual.Tables))
+	}
+	for keyT := range expected.Tables {
+		et := expected.Tables[keyT]
+		at, found := actual.Tables[keyT]
+
+		if !found {
+			t.Errorf("Expected %s.Tables[%s] but was not found", expected.Name, keyT)
+		} else {
+			if keyT != at.Name {
+				t.Errorf("Expected %s.Tables[%s].Name to be %v but was %v", expected.Name, keyT, keyT, at.Name)
+			}
+			if len(et.PartitionKey) != len(at.PartitionKey) {
+				t.Errorf("Expected len(%s.Tables[%s].PartitionKey) to be %v but was %v", expected.Name, keyT, len(et.PartitionKey), len(at.PartitionKey))
+			} else {
+				for i := range et.PartitionKey {
+					if et.PartitionKey[i].Name != at.PartitionKey[i].Name {
+						t.Errorf("Expected %s.Tables[%s].PartitionKey[%d].Name to be '%v' but was '%v'", expected.Name, keyT, i, et.PartitionKey[i].Name, at.PartitionKey[i].Name)
+					}
+					if expected.Name != at.PartitionKey[i].Keyspace {
+						t.Errorf("Expected %s.Tables[%s].PartitionKey[%d].Keyspace to be '%v' but was '%v'", expected.Name, keyT, i, expected.Name, at.PartitionKey[i].Keyspace)
+					}
+					if keyT != at.PartitionKey[i].Table {
+						t.Errorf("Expected %s.Tables[%s].PartitionKey[%d].Table to be '%v' but was '%v'", expected.Name, keyT, i, keyT, at.PartitionKey[i].Table)
+					}
+					if et.PartitionKey[i].Type.Type != at.PartitionKey[i].Type.Type {
+						t.Errorf("Expected %s.Tables[%s].PartitionKey[%d].Type.Type to be %v but was %v", expected.Name, keyT, i, et.PartitionKey[i].Type.Type, at.PartitionKey[i].Type.Type)
+					}
+					if i != at.PartitionKey[i].ComponentIndex {
+						t.Errorf("Expected %s.Tables[%s].PartitionKey[%d].ComponentIndex to be %v but was %v", expected.Name, keyT, i, i, at.PartitionKey[i].ComponentIndex)
+					}
+					if PARTITION_KEY != at.PartitionKey[i].Kind {
+						t.Errorf("Expected %s.Tables[%s].PartitionKey[%d].Kind to be '%v' but was '%v'", expected.Name, keyT, i, PARTITION_KEY, at.PartitionKey[i].Kind)
+					}
+				}
+			}
+			if len(et.ClusteringColumns) != len(at.ClusteringColumns) {
+				t.Errorf("Expected len(%s.Tables[%s].ClusteringColumns) to be %v but was %v", expected.Name, keyT, len(et.ClusteringColumns), len(at.ClusteringColumns))
+			} else {
+				for i := range et.ClusteringColumns {
+					if et.ClusteringColumns[i].Name != at.ClusteringColumns[i].Name {
+						t.Errorf("Expected %s.Tables[%s].ClusteringColumns[%d].Name to be '%v' but was '%v'", expected.Name, keyT, i, et.ClusteringColumns[i].Name, at.ClusteringColumns[i].Name)
+					}
+					if expected.Name != at.ClusteringColumns[i].Keyspace {
+						t.Errorf("Expected %s.Tables[%s].ClusteringColumns[%d].Keyspace to be '%v' but was '%v'", expected.Name, keyT, i, expected.Name, at.ClusteringColumns[i].Keyspace)
+					}
+					if keyT != at.ClusteringColumns[i].Table {
+						t.Errorf("Expected %s.Tables[%s].ClusteringColumns[%d].Table to be '%v' but was '%v'", expected.Name, keyT, i, keyT, at.ClusteringColumns[i].Table)
+					}
+					if et.ClusteringColumns[i].Type.Type != at.ClusteringColumns[i].Type.Type {
+						t.Errorf("Expected %s.Tables[%s].ClusteringColumns[%d].Type.Type to be %v but was %v", expected.Name, keyT, i, et.ClusteringColumns[i].Type.Type, at.ClusteringColumns[i].Type.Type)
+					}
+					if i != at.ClusteringColumns[i].ComponentIndex {
+						t.Errorf("Expected %s.Tables[%s].ClusteringColumns[%d].ComponentIndex to be %v but was %v", expected.Name, keyT, i, i, at.ClusteringColumns[i].ComponentIndex)
+					}
+					if et.ClusteringColumns[i].Order != at.ClusteringColumns[i].Order {
+						t.Errorf("Expected %s.Tables[%s].ClusteringColumns[%d].Order to be %v but was %v", expected.Name, keyT, i, et.ClusteringColumns[i].Order, at.ClusteringColumns[i].Order)
+					}
+					if CLUSTERING_KEY != at.ClusteringColumns[i].Kind {
+						t.Errorf("Expected %s.Tables[%s].ClusteringColumns[%d].Kind to be '%v' but was '%v'", expected.Name, keyT, i, CLUSTERING_KEY, at.ClusteringColumns[i].Kind)
+					}
+				}
+			}
+			if len(et.Columns) != len(at.Columns) {
+				eKeys := make([]string, 0, len(et.Columns))
+				for key := range et.Columns {
+					eKeys = append(eKeys, key)
+				}
+				aKeys := make([]string, 0, len(at.Columns))
+				for key := range at.Columns {
+					aKeys = append(aKeys, key)
+				}
+				t.Errorf("Expected len(%s.Tables[%s].Columns) to be %v (keys:%v) but was %v (keys:%v)", expected.Name, keyT, len(et.Columns), eKeys, len(at.Columns), aKeys)
+			} else {
+				for keyC := range et.Columns {
+					ec := et.Columns[keyC]
+					ac, found := at.Columns[keyC]
+
+					if !found {
+						t.Errorf("Expected %s.Tables[%s].Columns[%s] but was not found", expected.Name, keyT, keyC)
+					} else {
+						if keyC != ac.Name {
+							t.Errorf("Expected %s.Tables[%s].Columns[%s].Name to be '%v' but was '%v'", expected.Name, keyT, keyC, keyC, at.Name)
+						}
+						if expected.Name != ac.Keyspace {
+							t.Errorf("Expected %s.Tables[%s].Columns[%s].Keyspace to be '%v' but was '%v'", expected.Name, keyT, keyC, expected.Name, ac.Keyspace)
+						}
+						if keyT != ac.Table {
+							t.Errorf("Expected %s.Tables[%s].Columns[%s].Table to be '%v' but was '%v'", expected.Name, keyT, keyC, keyT, ac.Table)
+						}
+						if ec.Type.Type != ac.Type.Type {
+							t.Errorf("Expected %s.Tables[%s].Columns[%s].Type.Type to be %v but was %v", expected.Name, keyT, keyC, ec.Type.Type, ac.Type.Type)
+						}
+						if ec.Order != ac.Order {
+							t.Errorf("Expected %s.Tables[%s].Columns[%s].Order to be %v but was %v", expected.Name, keyT, keyC, ec.Order, ac.Order)
+						}
+						if ec.Kind != ac.Kind {
+							t.Errorf("Expected %s.Tables[%s].Columns[%s].Kind to be '%v' but was '%v'", expected.Name, keyT, keyC, ec.Kind, ac.Kind)
+						}
+					}
+				}
+			}
+		}
+	}
+}
+
// TestTypeParser checks parseType against real marshal class strings:
// native types, reversed types, collections, a custom (dynamic composite)
// type, and composite definitions with and without collection columns.
func TestTypeParser(t *testing.T) {
	// native type
	assertParseNonCompositeType(
		t,
		"org.apache.cassandra.db.marshal.UTF8Type",
		assertTypeInfo{Type: TypeVarchar},
	)

	// reversed
	assertParseNonCompositeType(
		t,
		"org.apache.cassandra.db.marshal.ReversedType(org.apache.cassandra.db.marshal.UUIDType)",
		assertTypeInfo{Type: TypeUUID, Reversed: true},
	)

	// set
	assertParseNonCompositeType(
		t,
		"org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.Int32Type)",
		assertTypeInfo{
			Type: TypeSet,
			Elem: &assertTypeInfo{Type: TypeInt},
		},
	)

	// map
	assertParseNonCompositeType(
		t,
		"org.apache.cassandra.db.marshal.MapType(org.apache.cassandra.db.marshal.UUIDType,org.apache.cassandra.db.marshal.BytesType)",
		assertTypeInfo{
			Type: TypeMap,
			Key:  &assertTypeInfo{Type: TypeUUID},
			Elem: &assertTypeInfo{Type: TypeBlob},
		},
	)

	// custom: an unrecognized class is kept verbatim as the Custom string
	assertParseNonCompositeType(
		t,
		"org.apache.cassandra.db.marshal.DynamicCompositeType(u=>org.apache.cassandra.db.marshal.UUIDType,d=>org.apache.cassandra.db.marshal.DateType,t=>org.apache.cassandra.db.marshal.TimeUUIDType,b=>org.apache.cassandra.db.marshal.BytesType,s=>org.apache.cassandra.db.marshal.UTF8Type,B=>org.apache.cassandra.db.marshal.BooleanType,a=>org.apache.cassandra.db.marshal.AsciiType,l=>org.apache.cassandra.db.marshal.LongType,i=>org.apache.cassandra.db.marshal.IntegerType,x=>org.apache.cassandra.db.marshal.LexicalUUIDType)",
		assertTypeInfo{Type: TypeCustom, Custom: "org.apache.cassandra.db.marshal.DynamicCompositeType(u=>org.apache.cassandra.db.marshal.UUIDType,d=>org.apache.cassandra.db.marshal.DateType,t=>org.apache.cassandra.db.marshal.TimeUUIDType,b=>org.apache.cassandra.db.marshal.BytesType,s=>org.apache.cassandra.db.marshal.UTF8Type,B=>org.apache.cassandra.db.marshal.BooleanType,a=>org.apache.cassandra.db.marshal.AsciiType,l=>org.apache.cassandra.db.marshal.LongType,i=>org.apache.cassandra.db.marshal.IntegerType,x=>org.apache.cassandra.db.marshal.LexicalUUIDType)"},
	)

	// composite defs
	assertParseCompositeType(
		t,
		"org.apache.cassandra.db.marshal.CompositeType(org.apache.cassandra.db.marshal.UTF8Type)",
		[]assertTypeInfo{
			assertTypeInfo{Type: TypeVarchar},
		},
		nil,
	)
	assertParseCompositeType(
		t,
		"org.apache.cassandra.db.marshal.CompositeType(org.apache.cassandra.db.marshal.DateType,org.apache.cassandra.db.marshal.UTF8Type)",
		[]assertTypeInfo{
			assertTypeInfo{Type: TypeTimestamp},
			assertTypeInfo{Type: TypeVarchar},
		},
		nil,
	)
	// composite with a collection column; the hex name 726f77735f6d6572676564
	// decodes to "rows_merged", which is the key expected below
	assertParseCompositeType(
		t,
		"org.apache.cassandra.db.marshal.CompositeType(org.apache.cassandra.db.marshal.UTF8Type,org.apache.cassandra.db.marshal.ColumnToCollectionType(726f77735f6d6572676564:org.apache.cassandra.db.marshal.MapType(org.apache.cassandra.db.marshal.Int32Type,org.apache.cassandra.db.marshal.LongType)))",
		[]assertTypeInfo{
			assertTypeInfo{Type: TypeVarchar},
		},
		map[string]assertTypeInfo{
			"rows_merged": assertTypeInfo{
				Type: TypeMap,
				Key:  &assertTypeInfo{Type: TypeInt},
				Elem: &assertTypeInfo{Type: TypeBigInt},
			},
		},
	)
}
+
+//---------------------------------------
+// some code to assert the parser result
+//---------------------------------------
+
// assertTypeInfo is a pared-down, self-referential description of an
// expected parse result, mirroring the fields of TypeInfo that the
// assertions below compare.
type assertTypeInfo struct {
	Type     Type            // expected CQL type constant
	Reversed bool            // expected ReversedType wrapping
	Elem     *assertTypeInfo // expected element type (sets/lists/maps), nil if none
	Key      *assertTypeInfo // expected key type (maps), nil if none
	Custom   string          // expected custom class definition, empty if none
}
+
+func assertParseNonCompositeType(
+	t *testing.T,
+	def string,
+	typeExpected assertTypeInfo,
+) {
+
+	result := parseType(def)
+	if len(result.reversed) != 1 {
+		t.Errorf("%s expected %d reversed values but there were %d", def, 1, len(result.reversed))
+	}
+
+	assertParseNonCompositeTypes(
+		t,
+		def,
+		[]assertTypeInfo{typeExpected},
+		result.types,
+	)
+
+	// expect no composite part of the result
+	if result.isComposite {
+		t.Errorf("%s: Expected not composite", def)
+	}
+	if result.collections != nil {
+		t.Errorf("%s: Expected nil collections: %v", def, result.collections)
+	}
+}
+
+func assertParseCompositeType(
+	t *testing.T,
+	def string,
+	typesExpected []assertTypeInfo,
+	collectionsExpected map[string]assertTypeInfo,
+) {
+
+	result := parseType(def)
+	if len(result.reversed) != len(typesExpected) {
+		t.Errorf("%s expected %d reversed values but there were %d", def, len(typesExpected), len(result.reversed))
+	}
+
+	assertParseNonCompositeTypes(
+		t,
+		def,
+		typesExpected,
+		result.types,
+	)
+
+	// expect composite part of the result
+	if !result.isComposite {
+		t.Errorf("%s: Expected composite", def)
+	}
+	if result.collections == nil {
+		t.Errorf("%s: Expected non-nil collections: %v", def, result.collections)
+	}
+
+	for name, typeExpected := range collectionsExpected {
+		// check for an actual type for this name
+		typeActual, found := result.collections[name]
+		if !found {
+			t.Errorf("%s.tcollections: Expected param named %s but there wasn't", def, name)
+		} else {
+			// remove the actual from the collection so we can detect extras
+			delete(result.collections, name)
+
+			// check the type
+			assertParseNonCompositeTypes(
+				t,
+				def+"collections["+name+"]",
+				[]assertTypeInfo{typeExpected},
+				[]TypeInfo{typeActual},
+			)
+		}
+	}
+
+	if len(result.collections) != 0 {
+		t.Errorf("%s.collections: Expected no more types in collections, but there was %v", def, result.collections)
+	}
+}
+
+func assertParseNonCompositeTypes(
+	t *testing.T,
+	context string,
+	typesExpected []assertTypeInfo,
+	typesActual []TypeInfo,
+) {
+	if len(typesActual) != len(typesExpected) {
+		t.Errorf("%s: Expected %d types, but there were %d", context, len(typesExpected), len(typesActual))
+	}
+
+	for i := range typesExpected {
+		typeExpected := typesExpected[i]
+		typeActual := typesActual[i]
+
+		// shadow copy the context for local modification
+		context := context
+		if len(typesExpected) > 1 {
+			context = context + "[" + strconv.Itoa(i) + "]"
+		}
+
+		// check the type
+		if typeActual.Type != typeExpected.Type {
+			t.Errorf("%s: Expected to parse Type to %s but was %s", context, typeExpected.Type, typeActual.Type)
+		}
+		// check the custom
+		if typeActual.Custom != typeExpected.Custom {
+			t.Errorf("%s: Expected to parse Custom %s but was %s", context, typeExpected.Custom, typeActual.Custom)
+		}
+		// check the elem
+		if typeActual.Elem == nil && typeExpected.Elem != nil {
+			t.Errorf("%s: Expected to parse Elem, but was nil ", context)
+		} else if typeExpected.Elem == nil && typeActual.Elem != nil {
+			t.Errorf("%s: Expected to not parse Elem, but was %+v", context, typeActual.Elem)
+		} else if typeActual.Elem != nil && typeExpected.Elem != nil {
+			assertParseNonCompositeTypes(
+				t,
+				context+".Elem",
+				[]assertTypeInfo{*typeExpected.Elem},
+				[]TypeInfo{*typeActual.Elem},
+			)
+		}
+		// check the key
+		if typeActual.Key == nil && typeExpected.Key != nil {
+			t.Errorf("%s: Expected to parse Key, but was nil ", context)
+		} else if typeExpected.Key == nil && typeActual.Key != nil {
+			t.Errorf("%s: Expected to not parse Key, but was %+v", context, typeActual.Key)
+		} else if typeActual.Key != nil && typeExpected.Key != nil {
+			assertParseNonCompositeTypes(
+				t,
+				context+".Key",
+				[]assertTypeInfo{*typeExpected.Key},
+				[]TypeInfo{*typeActual.Key},
+			)
+		}
+	}
+}

+ 30 - 7
session.go

@@ -24,12 +24,13 @@ import (
 // and automatically sets a default consistency level on all operations
 // and automatically sets a default consistency level on all operations
 // that do not have a consistency level set.
 // that do not have a consistency level set.
 type Session struct {
 type Session struct {
-	Pool     ConnectionPool
-	cons     Consistency
-	pageSize int
-	prefetch float64
-	trace    Tracer
-	mu       sync.RWMutex
+	Pool            ConnectionPool
+	cons            Consistency
+	pageSize        int
+	prefetch        float64
+	schemaDescriber *schemaDescriber
+	trace           Tracer
+	mu              sync.RWMutex
 
 
 	cfg ClusterConfig
 	cfg ClusterConfig
 
 
@@ -39,7 +40,7 @@ type Session struct {
 
 
 // NewSession wraps an existing Node.
 // NewSession wraps an existing Node.
 func NewSession(p ConnectionPool, c ClusterConfig) *Session {
 func NewSession(p ConnectionPool, c ClusterConfig) *Session {
-	return &Session{Pool: p, cons: Quorum, prefetch: 0.25, cfg: c}
+	return &Session{Pool: p, cons: c.Consistency, prefetch: 0.25, cfg: c}
 }
 }
 
 
 // SetConsistency sets the default consistency level for this session. This
 // SetConsistency sets the default consistency level for this session. This
@@ -163,6 +164,27 @@ func (s *Session) executeQuery(qry *Query) *Iter {
 	return iter
 	return iter
 }
 }
 
 
// KeyspaceMetadata returns the schema metadata for the keyspace specified.
// It returns ErrSessionClosed if the session is closed and ErrNoKeyspace
// when keyspace is empty.
func (s *Session) KeyspaceMetadata(keyspace string) (*KeyspaceMetadata, error) {
	// fail fast
	if s.Closed() {
		return nil, ErrSessionClosed
	}

	if keyspace == "" {
		return nil, ErrNoKeyspace
	}

	s.mu.Lock()
	// lazy-init schemaDescriber; guarded by s.mu so that concurrent callers
	// share a single describer instance
	if s.schemaDescriber == nil {
		s.schemaDescriber = newSchemaDescriber(s)
	}
	s.mu.Unlock()

	// delegate the actual schema lookup to the describer
	return s.schemaDescriber.getSchema(keyspace)
}
+
 // ExecuteBatch executes a batch operation and returns nil if successful
 // ExecuteBatch executes a batch operation and returns nil if successful
 // otherwise an error is returned describing the failure.
 // otherwise an error is returned describing the failure.
 func (s *Session) ExecuteBatch(batch *Batch) error {
 func (s *Session) ExecuteBatch(batch *Batch) error {
@@ -659,6 +681,7 @@ var (
 	ErrUseStmt       = errors.New("use statements aren't supported. Please see https://github.com/gocql/gocql for explaination.")
 	ErrUseStmt       = errors.New("use statements aren't supported. Please see https://github.com/gocql/gocql for explaination.")
 	ErrSessionClosed = errors.New("session has been closed")
 	ErrSessionClosed = errors.New("session has been closed")
 	ErrNoConnections = errors.New("no connections available")
 	ErrNoConnections = errors.New("no connections available")
+	ErrNoKeyspace    = errors.New("no keyspace provided")
 )
 )
 
 
 type ErrProtocol struct{ error }
 type ErrProtocol struct{ error }