Kaynağa Gözat

Merge pull request #463 from Zariel/proto-v4

Proto v4
Chris Bannister 10 yıl önce
ebeveyn
işleme
3bc659aecb
9 değiştirilmiş dosya ile 349 ekleme ve 116 silme
  1. 2 2
      .travis.yml
  2. 0 1
      cass1batch_test.go
  3. 39 22
      cassandra_test.go
  4. 3 3
      conn.go
  5. 33 15
      errors.go
  6. 176 14
      frame.go
  7. 4 0
      integration.sh
  8. 49 21
      metadata.go
  9. 43 38
      wiki_test.go

+ 2 - 2
.travis.yml

@@ -13,10 +13,10 @@ env:
   global:
     - GOMAXPROCS=2
   matrix:
-    - CASS=1.2.19 AUTH=false
-    - CASS=2.0.14 AUTH=false
+    - CASS=2.0.16 AUTH=false
     - CASS=2.1.5  AUTH=false
     - CASS=2.1.5  AUTH=true
+    - CASS=2.2.0  AUTH=false
 
 go:
   - 1.4

+ 0 - 1
cass1batch_test.go

@@ -22,7 +22,6 @@ func TestProto1BatchInsert(t *testing.T) {
 	if err := session.Query(fullQuery, args...).Consistency(Quorum).Exec(); err != nil {
 		t.Fatal(err)
 	}
-
 }
 
 func TestShouldPrepareFunction(t *testing.T) {

+ 39 - 22
cassandra_test.go

@@ -87,21 +87,34 @@ func createCluster() *ClusterConfig {
 }
 
 func createKeyspace(tb testing.TB, cluster *ClusterConfig, keyspace string) {
-	session, err := cluster.CreateSession()
+	c := *cluster
+	c.Keyspace = "system"
+	session, err := c.CreateSession()
 	if err != nil {
 		tb.Fatal("createSession:", err)
 	}
-	defer session.Close()
-	if err = session.Query(`DROP KEYSPACE IF EXISTS ` + keyspace).Exec(); err != nil {
-		tb.Log("drop keyspace:", err)
+
+	// should reuse the same conn apparently
+	conn := session.Pool.Pick(nil)
+	if conn == nil {
+		tb.Fatal("no connections available in the pool")
+	}
+
+	err = conn.executeQuery(session.Query(`DROP KEYSPACE IF EXISTS ` + keyspace).Consistency(All)).Close()
+	if err != nil {
+		tb.Fatal(err)
 	}
-	if err := session.Query(fmt.Sprintf(`CREATE KEYSPACE %s
+
+	query := session.Query(fmt.Sprintf(`CREATE KEYSPACE %s
 	WITH replication = {
 		'class' : 'SimpleStrategy',
 		'replication_factor' : %d
-	}`, keyspace, *flagRF)).Consistency(All).Exec(); err != nil {
-		tb.Fatalf("error creating keyspace %s: %v", keyspace, err)
+	}`, keyspace, *flagRF)).Consistency(All)
+
+	if err = conn.executeQuery(query).Close(); err != nil {
+		tb.Fatal(err)
 	}
+
 	tb.Logf("Created keyspace %s", keyspace)
 }
 
@@ -1072,14 +1085,16 @@ func injectInvalidPreparedStatement(t *testing.T, session *Session, table string
 	stmtsLRU.Unlock()
 	flight.info = &resultPreparedFrame{
 		preparedID: []byte{'f', 'o', 'o', 'b', 'a', 'r'},
-		reqMeta: resultMetadata{
-			columns: []ColumnInfo{
-				{
-					Keyspace: "gocql_test",
-					Table:    table,
-					Name:     "foo",
-					TypeInfo: NativeType{
-						typ: TypeVarchar,
+		reqMeta: preparedMetadata{
+			resultMetadata: resultMetadata{
+				columns: []ColumnInfo{
+					{
+						Keyspace: "gocql_test",
+						Table:    table,
+						Name:     "foo",
+						TypeInfo: NativeType{
+							typ: TypeVarchar,
+						},
 					},
 				},
 			}},
@@ -1638,13 +1653,15 @@ func TestGetTableMetadata(t *testing.T) {
 	if testTable.DefaultValidator != "org.apache.cassandra.db.marshal.BytesType" {
 		t.Errorf("Expected test_table_metadata key validator to be 'org.apache.cassandra.db.marshal.BytesType' but was '%s'", testTable.DefaultValidator)
 	}
-	expectedKeyAliases := []string{"first_id"}
-	if !reflect.DeepEqual(testTable.KeyAliases, expectedKeyAliases) {
-		t.Errorf("Expected key aliases %v but was %v", expectedKeyAliases, testTable.KeyAliases)
-	}
-	expectedColumnAliases := []string{"second_id"}
-	if !reflect.DeepEqual(testTable.ColumnAliases, expectedColumnAliases) {
-		t.Errorf("Expected key aliases %v but was %v", expectedColumnAliases, testTable.ColumnAliases)
+	if *flagProto < protoVersion4 {
+		expectedKeyAliases := []string{"first_id"}
+		if !reflect.DeepEqual(testTable.KeyAliases, expectedKeyAliases) {
+			t.Errorf("Expected key aliases %v but was %v", expectedKeyAliases, testTable.KeyAliases)
+		}
+		expectedColumnAliases := []string{"second_id"}
+		if !reflect.DeepEqual(testTable.ColumnAliases, expectedColumnAliases) {
+			t.Errorf("Expected key aliases %v but was %v", expectedColumnAliases, testTable.ColumnAliases)
+		}
 	}
 	if testTable.ValueAlias != "" {
 		t.Errorf("Expected value alias '' but was '%s'", testTable.ValueAlias)

+ 3 - 3
conn.go

@@ -144,7 +144,7 @@ func Connect(addr string, cfg ConnConfig, errorHandler ConnErrorHandler) (*Conn,
 	}
 
 	// going to default to proto 2
-	if cfg.ProtoVersion < protoVersion1 || cfg.ProtoVersion > protoVersion3 {
+	if cfg.ProtoVersion < protoVersion1 || cfg.ProtoVersion > protoVersion4 {
 		log.Printf("unsupported protocol version: %d using 2\n", cfg.ProtoVersion)
 		cfg.ProtoVersion = 2
 	}
@@ -653,7 +653,7 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 		}
 
 		return iter
-	case *resultKeyspaceFrame, *resultSchemaChangeFrame:
+	case *resultKeyspaceFrame, *resultSchemaChangeFrame, *schemaChangeKeyspace, *schemaChangeTable:
 		return &Iter{}
 	case *RequestErrUnprepared:
 		stmtsLRU.Lock()
@@ -668,7 +668,7 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 	case error:
 		return &Iter{err: x}
 	default:
-		return &Iter{err: NewErrProtocol("Unknown type in response to execute query: %s", x)}
+		return &Iter{err: NewErrProtocol("Unknown type in response to execute query (%T): %s", x, x)}
 	}
 }
 

+ 33 - 15
errors.go

@@ -3,21 +3,23 @@ package gocql
 import "fmt"
 
 const (
-	errServer        = 0x0000
-	errProtocol      = 0x000A
-	errCredentials   = 0x0100
-	errUnavailable   = 0x1000
-	errOverloaded    = 0x1001
-	errBootstrapping = 0x1002
-	errTruncate      = 0x1003
-	errWriteTimeout  = 0x1100
-	errReadTimeout   = 0x1200
-	errSyntax        = 0x2000
-	errUnauthorized  = 0x2100
-	errInvalid       = 0x2200
-	errConfig        = 0x2300
-	errAlreadyExists = 0x2400
-	errUnprepared    = 0x2500
+	errServer          = 0x0000
+	errProtocol        = 0x000A
+	errCredentials     = 0x0100
+	errUnavailable     = 0x1000
+	errOverloaded      = 0x1001
+	errBootstrapping   = 0x1002
+	errTruncate        = 0x1003
+	errWriteTimeout    = 0x1100
+	errReadTimeout     = 0x1200
+	errReadFailure     = 0x1300
+	errFunctionFailure = 0x1400
+	errSyntax          = 0x2000
+	errUnauthorized    = 0x2100
+	errInvalid         = 0x2200
+	errConfig          = 0x2300
+	errAlreadyExists   = 0x2400
+	errUnprepared      = 0x2500
 )
 
 type RequestError interface {
@@ -86,3 +88,19 @@ type RequestErrUnprepared struct {
 	errorFrame
 	StatementId []byte
 }
+
+type RequestErrReadFailure struct {
+	errorFrame
+	Consistency Consistency
+	Received    int
+	BlockFor    int
+	NumFailures int
+	DataPresent bool
+}
+
+type RequestErrFunctionFailure struct {
+	errorFrame
+	Keyspace string
+	Function string
+	ArgTypes []string
+}

+ 176 - 14
frame.go

@@ -9,6 +9,7 @@ import (
 	"fmt"
 	"io"
 	"io/ioutil"
+	"log"
 	"net"
 	"runtime"
 	"sync"
@@ -21,6 +22,7 @@ const (
 	protoVersion1      = 0x01
 	protoVersion2      = 0x02
 	protoVersion3      = 0x03
+	protoVersion4      = 0x04
 
 	maxFrameSize = 256 * 1024 * 1024
 )
@@ -132,8 +134,10 @@ const (
 	flagWithNameValues             = 0x40
 
 	// header flags
-	flagCompress byte = 0x01
-	flagTracing       = 0x02
+	flagCompress      byte = 0x01
+	flagTracing       byte = 0x02
+	flagCustomPayload byte = 0x04
+	flagWarning       byte = 0x08
 )
 
 type Consistency uint16
@@ -315,8 +319,8 @@ func readHeader(r io.Reader, p []byte) (head frameHeader, err error) {
 
 	version := p[0] & protoVersionMask
 
-	if version < protoVersion1 || version > protoVersion3 {
-		err = fmt.Errorf("invalid version: %x", version)
+	if version < protoVersion1 || version > protoVersion4 {
+		err = fmt.Errorf("gocql: invalid version: %x", version)
 		return
 	}
 
@@ -408,6 +412,14 @@ func (f *framer) parseFrame() (frame frame, err error) {
 		f.readTrace()
 	}
 
+	if f.header.flags&flagWarning == flagWarning {
+		warnings := f.readStringList()
+		// what to do with warnings?
+		for _, v := range warnings {
+			log.Println(v)
+		}
+	}
+
 	// asumes that the frame body has been read into rbuf
 	switch f.header.op {
 	case opError:
@@ -490,6 +502,23 @@ func (f *framer) parseErrorFrame() frame {
 			errorFrame:  errD,
 			StatementId: stmtId,
 		}
+	case errReadFailure:
+		res := &RequestErrReadFailure{
+			errorFrame: errD,
+		}
+		res.Consistency = f.readConsistency()
+		res.Received = f.readInt()
+		res.BlockFor = f.readInt()
+		res.DataPresent = f.readByte() != 0
+		return res
+	case errFunctionFailure:
+		res := RequestErrFunctionFailure{
+			errorFrame: errD,
+		}
+		res.Keyspace = f.readString()
+		res.Function = f.readString()
+		res.ArgTypes = f.readStringList()
+		return res
 	default:
 		return &errD
 	}
@@ -600,6 +629,10 @@ type writeStartupFrame struct {
 	opts map[string]string
 }
 
+func (w writeStartupFrame) String() string {
+	return fmt.Sprintf("[startup opts=%+v]", w.opts)
+}
+
 func (w *writeStartupFrame) writeFrame(framer *framer, streamID int) error {
 	return framer.writeStartupFrame(streamID, w.opts)
 }
@@ -689,6 +722,74 @@ func (f *framer) readTypeInfo() TypeInfo {
 	return simple
 }
 
+type preparedMetadata struct {
+	resultMetadata
+
+	// proto v4+
+	pkeyColumns []int
+}
+
+func (r preparedMetadata) String() string {
+	return fmt.Sprintf("[paging_metadata flags=0x%x pkey=%q paging_state=% X columns=%v]", r.flags, r.pkeyColumns, r.pagingState, r.columns)
+}
+
+func (f *framer) parsePreparedMetadata() preparedMetadata {
+	// TODO: deduplicate this from parseMetadata
+	meta := preparedMetadata{}
+	meta.flags = f.readInt()
+
+	colCount := f.readInt()
+	if colCount < 0 {
+		panic(fmt.Errorf("received negative column count: %d", colCount))
+	}
+	meta.actualColCount = colCount
+
+	if f.proto >= protoVersion4 {
+		pkeyCount := f.readInt()
+		pkeys := make([]int, pkeyCount)
+		for i := 0; i < pkeyCount; i++ {
+			pkeys[i] = int(f.readShort())
+		}
+		meta.pkeyColumns = pkeys
+	}
+
+	if meta.flags&flagHasMorePages == flagHasMorePages {
+		meta.pagingState = f.readBytes()
+	}
+
+	if meta.flags&flagNoMetaData == flagNoMetaData {
+		return meta
+	}
+
+	var keyspace, table string
+	globalSpec := meta.flags&flagGlobalTableSpec == flagGlobalTableSpec
+	if globalSpec {
+		keyspace = f.readString()
+		table = f.readString()
+	}
+
+	var cols []ColumnInfo
+	if colCount < 1000 {
+		// preallocate columninfo to avoid excess copying
+		cols = make([]ColumnInfo, colCount)
+		for i := 0; i < colCount; i++ {
+			f.readCol(&cols[i], &meta.resultMetadata, globalSpec, keyspace, table)
+		}
+	} else {
+		// use append, huge number of columns usually indicates a corrupt frame or
+		// just a huge row.
+		for i := 0; i < colCount; i++ {
+			var col ColumnInfo
+			f.readCol(&col, &meta.resultMetadata, globalSpec, keyspace, table)
+			cols = append(cols, col)
+		}
+	}
+
+	meta.columns = cols
+
+	return meta
+}
+
 type resultMetadata struct {
 	flags int
 
@@ -858,7 +959,7 @@ type resultPreparedFrame struct {
 	frameHeader
 
 	preparedID []byte
-	reqMeta    resultMetadata
+	reqMeta    preparedMetadata
 	respMeta   resultMetadata
 }
 
@@ -866,7 +967,7 @@ func (f *framer) parseResultPrepared() frame {
 	frame := &resultPreparedFrame{
 		frameHeader: *f.header,
 		preparedID:  f.readShortBytes(),
-		reqMeta:     f.parseResultMetadata(),
+		reqMeta:     f.parsePreparedMetadata(),
 	}
 
 	if f.proto < protoVersion2 {
@@ -890,29 +991,90 @@ func (s *resultSchemaChangeFrame) String() string {
 	return fmt.Sprintf("[result_schema_change change=%s keyspace=%s table=%s]", s.change, s.keyspace, s.table)
 }
 
+type schemaChangeKeyspace struct {
+	frameHeader
+
+	change   string
+	keyspace string
+}
+
+func (f schemaChangeKeyspace) String() string {
+	return fmt.Sprintf("[event schema_change_keyspace change=%q keyspace=%q]", f.change, f.keyspace)
+}
+
+type schemaChangeTable struct {
+	frameHeader
+
+	change   string
+	keyspace string
+	object   string
+}
+
+func (f schemaChangeTable) String() string {
+	return fmt.Sprintf("[event schema_change change=%q keyspace=%q object=%q]", f.change, f.keyspace, f.object)
+}
+
+type schemaChangeFunction struct {
+	frameHeader
+
+	change   string
+	keyspace string
+	name     string
+	args     []string
+}
+
 func (f *framer) parseResultSchemaChange() frame {
-	frame := &resultSchemaChangeFrame{
-		frameHeader: *f.header,
-	}
+	if f.proto <= protoVersion2 {
+		frame := &resultSchemaChangeFrame{
+			frameHeader: *f.header,
+		}
 
-	if f.proto < protoVersion3 {
 		frame.change = f.readString()
 		frame.keyspace = f.readString()
 		frame.table = f.readString()
+
+		return frame
 	} else {
-		// TODO: improve type representation of this
-		frame.change = f.readString()
+		change := f.readString()
 		target := f.readString()
+
+		// TODO: could just use a seperate type for each target
 		switch target {
 		case "KEYSPACE":
+			frame := &schemaChangeKeyspace{
+				frameHeader: *f.header,
+				change:      change,
+			}
+
 			frame.keyspace = f.readString()
+
+			return frame
 		case "TABLE", "TYPE":
+			frame := &schemaChangeTable{
+				frameHeader: *f.header,
+				change:      change,
+			}
+
 			frame.keyspace = f.readString()
-			frame.table = f.readString()
+			frame.object = f.readString()
+
+			return frame
+		case "FUNCTION", "AGGREGATE":
+			frame := &schemaChangeFunction{
+				frameHeader: *f.header,
+				change:      change,
+			}
+
+			frame.keyspace = f.readString()
+			frame.name = f.readString()
+			frame.args = f.readStringList()
+
+			return frame
+		default:
+			panic(fmt.Errorf("gocql: unknown SCHEMA_CHANGE target: %q change: %q", target, change))
 		}
 	}
 
-	return frame
 }
 
 type authenticateFrame struct {

+ 4 - 0
integration.sh

@@ -47,8 +47,12 @@ function run_tests() {
 	local proto=2
 	if [[ $version == 1.2.* ]]; then
 		proto=1
+	elif [[ $version == 2.0.* ]]; then
+		proto=2
 	elif [[ $version == 2.1.* ]]; then
 		proto=3
+	elif [[ $version == 2.2.* ]]; then
+		proto=4
 	fi
 
 	if [ "$auth" = true ]

+ 49 - 21
metadata.go

@@ -375,12 +375,20 @@ func getKeyspaceMetadata(
 }
 
 // query for only the table metadata in the specified keyspace from system.schema_columnfamilies
-func getTableMetadata(
-	session *Session,
-	keyspaceName string,
-) ([]TableMetadata, error) {
-	query := session.Query(
-		`
+func getTableMetadata(session *Session, keyspaceName string) ([]TableMetadata, error) {
+
+	var (
+		scan func(iter *Iter, table *TableMetadata) bool
+		stmt string
+
+		keyAliasesJSON    []byte
+		columnAliasesJSON []byte
+	)
+
+	if session.cfg.ProtoVersion < protoVersion4 {
+		// we have key aliases
+		// TODO: Do we need key_aliases?
+		stmt = `
 		SELECT
 			columnfamily_name,
 			key_validator,
@@ -390,29 +398,49 @@ func getTableMetadata(
 			column_aliases,
 			value_alias
 		FROM system.schema_columnfamilies
-		WHERE keyspace_name = ?
-		`,
-		keyspaceName,
-	)
+		WHERE keyspace_name = ?`
+
+		scan = func(iter *Iter, table *TableMetadata) bool {
+			return iter.Scan(
+				&table.Name,
+				&table.KeyValidator,
+				&table.Comparator,
+				&table.DefaultValidator,
+				&keyAliasesJSON,
+				&columnAliasesJSON,
+				&table.ValueAlias,
+			)
+		}
+	} else {
+		stmt = `
+		SELECT
+			columnfamily_name,
+			key_validator,
+			comparator,
+			default_validator
+		FROM system.schema_columnfamilies
+		WHERE keyspace_name = ?`
+
+		scan = func(iter *Iter, table *TableMetadata) bool {
+			return iter.Scan(
+				&table.Name,
+				&table.KeyValidator,
+				&table.Comparator,
+				&table.DefaultValidator,
+			)
+		}
+	}
+
 	// Set a routing key to avoid GetRoutingKey from computing the routing key
 	// TODO use a separate connection (pool) for system keyspace queries.
+	query := session.Query(stmt, keyspaceName)
 	query.RoutingKey([]byte{})
 	iter := query.Iter()
 
 	tables := []TableMetadata{}
 	table := TableMetadata{Keyspace: keyspaceName}
 
-	var keyAliasesJSON []byte
-	var columnAliasesJSON []byte
-	for iter.Scan(
-		&table.Name,
-		&table.KeyValidator,
-		&table.Comparator,
-		&table.DefaultValidator,
-		&keyAliasesJSON,
-		&columnAliasesJSON,
-		&table.ValueAlias,
-	) {
+	for scan(iter, &table) {
 		var err error
 
 		// decode the key aliases

+ 43 - 38
wiki_test.go

@@ -6,9 +6,10 @@ import (
 	"fmt"
 	"reflect"
 	"sort"
-	"gopkg.in/inf.v0"
 	"testing"
 	"time"
+
+	"gopkg.in/inf.v0"
 )
 
 type WikiPage struct {
@@ -49,14 +50,17 @@ var wikiTestData = []*WikiPage{
 type WikiTest struct {
 	session *Session
 	tb      testing.TB
-}
 
-func (w *WikiTest) CreateSchema() {
+	table string
+}
 
-	if err := w.session.Query(`DROP TABLE wiki_page`).Exec(); err != nil && err.Error() != "unconfigured columnfamily wiki_page" {
-		w.tb.Fatal("CreateSchema:", err)
+func CreateSchema(session *Session, tb testing.TB, table string) *WikiTest {
+	table = "wiki_" + table
+	if err := session.Query(fmt.Sprintf("DROP TABLE IF EXISTS %s", table)).Exec(); err != nil {
+		tb.Fatal("CreateSchema:", err)
 	}
-	err := createTable(w.session, `CREATE TABLE wiki_page (
+
+	err := createTable(session, fmt.Sprintf(`CREATE TABLE %s (
 			title       varchar,
 			revid       timeuuid,
 			body        varchar,
@@ -67,13 +71,16 @@ func (w *WikiTest) CreateSchema() {
 			tags        set<varchar>,
 			attachments map<varchar, blob>,
 			PRIMARY KEY (title, revid)
-		)`)
-	if *clusterSize > 1 {
-		// wait for table definition to propogate
-		time.Sleep(250 * time.Millisecond)
-	}
+		)`, table))
+
 	if err != nil {
-		w.tb.Fatal("CreateSchema:", err)
+		tb.Fatal("CreateSchema:", err)
+	}
+
+	return &WikiTest{
+		session: session,
+		tb:      tb,
+		table:   table,
 	}
 }
 
@@ -92,17 +99,17 @@ func (w *WikiTest) CreatePages(n int) {
 }
 
 func (w *WikiTest) InsertPage(page *WikiPage) error {
-	return w.session.Query(`INSERT INTO wiki_page
+	return w.session.Query(fmt.Sprintf(`INSERT INTO %s
 		(title, revid, body, views, protected, modified, rating, tags, attachments)
-		VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
+		VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, w.table),
 		page.Title, page.RevId, page.Body, page.Views, page.Protected,
 		page.Modified, page.Rating, page.Tags, page.Attachments).Exec()
 }
 
 func (w *WikiTest) SelectPage(page *WikiPage, title string, revid UUID) error {
-	return w.session.Query(`SELECT title, revid, body, views, protected,
+	return w.session.Query(fmt.Sprintf(`SELECT title, revid, body, views, protected,
 		modified,tags, attachments, rating
-		FROM wiki_page WHERE title = ? AND revid = ? LIMIT 1`,
+		FROM %s WHERE title = ? AND revid = ? LIMIT 1`, w.table),
 		title, revid).Scan(&page.Title, &page.RevId,
 		&page.Body, &page.Views, &page.Protected, &page.Modified, &page.Tags,
 		&page.Attachments, &page.Rating)
@@ -110,7 +117,7 @@ func (w *WikiTest) SelectPage(page *WikiPage, title string, revid UUID) error {
 
 func (w *WikiTest) GetPageCount() int {
 	var count int
-	if err := w.session.Query(`SELECT COUNT(*) FROM wiki_page`).Scan(&count); err != nil {
+	if err := w.session.Query(fmt.Sprintf(`SELECT COUNT(*) FROM %s`, w.table)).Scan(&count); err != nil {
 		w.tb.Error("GetPageCount", err)
 	}
 	return count
@@ -120,8 +127,7 @@ func TestWikiCreateSchema(t *testing.T) {
 	session := createSession(t)
 	defer session.Close()
 
-	w := WikiTest{session, t}
-	w.CreateSchema()
+	CreateSchema(session, t, "create")
 }
 
 func BenchmarkWikiCreateSchema(b *testing.B) {
@@ -131,11 +137,10 @@ func BenchmarkWikiCreateSchema(b *testing.B) {
 		b.StopTimer()
 		session.Close()
 	}()
-	w := WikiTest{session, b}
-	b.StartTimer()
 
+	b.StartTimer()
 	for i := 0; i < b.N; i++ {
-		w.CreateSchema()
+		CreateSchema(session, b, "bench_create")
 	}
 }
 
@@ -143,8 +148,8 @@ func TestWikiCreatePages(t *testing.T) {
 	session := createSession(t)
 	defer session.Close()
 
-	w := WikiTest{session, t}
-	w.CreateSchema()
+	w := CreateSchema(session, t, "create_pages")
+
 	numPages := 5
 	w.CreatePages(numPages)
 	if count := w.GetPageCount(); count != numPages {
@@ -159,8 +164,9 @@ func BenchmarkWikiCreatePages(b *testing.B) {
 		b.StopTimer()
 		session.Close()
 	}()
-	w := WikiTest{session, b}
-	w.CreateSchema()
+
+	w := CreateSchema(session, b, "bench_create_pages")
+
 	b.StartTimer()
 
 	w.CreatePages(b.N)
@@ -173,16 +179,16 @@ func BenchmarkWikiSelectAllPages(b *testing.B) {
 		b.StopTimer()
 		session.Close()
 	}()
-	w := WikiTest{session, b}
-	w.CreateSchema()
+	w := CreateSchema(session, b, "bench_select_all")
+
 	w.CreatePages(100)
 	b.StartTimer()
 
 	var page WikiPage
 	for i := 0; i < b.N; i++ {
-		iter := session.Query(`SELECT title, revid, body, views, protected,
+		iter := session.Query(fmt.Sprintf(`SELECT title, revid, body, views, protected,
 			modified, tags, attachments, rating
-			FROM wiki_page`).Iter()
+			FROM %s`, w.table)).Iter()
 		for iter.Scan(&page.Title, &page.RevId, &page.Body, &page.Views,
 			&page.Protected, &page.Modified, &page.Tags, &page.Attachments,
 			&page.Rating) {
@@ -201,11 +207,10 @@ func BenchmarkWikiSelectSinglePage(b *testing.B) {
 		b.StopTimer()
 		session.Close()
 	}()
-	w := WikiTest{session, b}
-	w.CreateSchema()
+	w := CreateSchema(session, b, "bench_select_single")
 	pages := make([]WikiPage, 100)
 	w.CreatePages(len(pages))
-	iter := session.Query(`SELECT title, revid FROM wiki_page`).Iter()
+	iter := session.Query(fmt.Sprintf(`SELECT title, revid FROM %s`, w.table)).Iter()
 	for i := 0; i < len(pages); i++ {
 		if !iter.Scan(&pages[i].Title, &pages[i].RevId) {
 			pages = pages[:i]
@@ -233,9 +238,9 @@ func BenchmarkWikiSelectPageCount(b *testing.B) {
 		b.StopTimer()
 		session.Close()
 	}()
-	w := WikiTest{session, b}
-	w.CreateSchema()
-	numPages := 10
+
+	w := CreateSchema(session, b, "bench_page_count")
+	const numPages = 10
 	w.CreatePages(numPages)
 	b.StartTimer()
 	for i := 0; i < b.N; i++ {
@@ -249,8 +254,8 @@ func TestWikiTypicalCRUD(t *testing.T) {
 	session := createSession(t)
 	defer session.Close()
 
-	w := WikiTest{session, t}
-	w.CreateSchema()
+	w := CreateSchema(session, t, "crud")
+
 	for _, page := range wikiTestData {
 		if err := w.InsertPage(page); err != nil {
 			t.Error("InsertPage:", err)