فهرست منبع

Merge pull request #607 from Zariel/skip-metadata

Skip metadata
Chris Bannister 10 سال پیش
والد
کامیت
8bb0dcb0c6
5فایلهای تغییر یافته به همراه137 افزوده شده و 94 حذف شده
  1. 23 31
      cassandra_test.go
  2. 9 0
      cluster.go
  3. 75 36
      conn.go
  4. 23 23
      frame.go
  5. 7 4
      session.go

+ 23 - 31
cassandra_test.go

@@ -461,11 +461,7 @@ func TestTooManyQueryArgs(t *testing.T) {
 	_, err := session.Query(`SELECT * FROM too_many_query_args WHERE id = ?`, 1, 2).Iter().SliceMap()
 
 	if err == nil {
-		t.Fatal("'`SELECT * FROM too_many_query_args WHERE id = ?`, 1, 2' should return an ErrQueryArgLength")
-	}
-
-	if err != ErrQueryArgLength {
-		t.Fatalf("'`SELECT * FROM too_many_query_args WHERE id = ?`, 1, 2' should return an ErrQueryArgLength, but returned: %s", err)
+		t.Fatal("'`SELECT * FROM too_many_query_args WHERE id = ?`, 1, 2' should return an error")
 	}
 
 	batch := session.NewBatch(UnloggedBatch)
@@ -473,12 +469,10 @@ func TestTooManyQueryArgs(t *testing.T) {
 	err = session.ExecuteBatch(batch)
 
 	if err == nil {
-		t.Fatal("'`INSERT INTO too_many_query_args (id, value) VALUES (?, ?)`, 1, 2, 3' should return an ErrQueryArgLength")
+		t.Fatal("'`INSERT INTO too_many_query_args (id, value) VALUES (?, ?)`, 1, 2, 3' should return an error")
 	}
 
-	if err != ErrQueryArgLength {
-		t.Fatalf("'INSERT INTO too_many_query_args (id, value) VALUES (?, ?)`, 1, 2, 3' should return an ErrQueryArgLength, but returned: %s", err)
-	}
+	// TODO: should indicate via an error code that it is an invalid arg?
 
 }
 
@@ -498,11 +492,7 @@ func TestNotEnoughQueryArgs(t *testing.T) {
 	_, err := session.Query(`SELECT * FROM not_enough_query_args WHERE id = ? and cluster = ?`, 1).Iter().SliceMap()
 
 	if err == nil {
-		t.Fatal("'`SELECT * FROM not_enough_query_args WHERE id = ? and cluster = ?`, 1' should return an ErrQueryArgLength")
-	}
-
-	if err != ErrQueryArgLength {
-		t.Fatalf("'`SELECT * FROM too_few_query_args WHERE id = ? and cluster = ?`, 1' should return an ErrQueryArgLength, but returned: %s", err)
+		t.Fatal("'`SELECT * FROM not_enough_query_args WHERE id = ? and cluster = ?`, 1' should return an error")
 	}
 
 	batch := session.NewBatch(UnloggedBatch)
@@ -510,11 +500,7 @@ func TestNotEnoughQueryArgs(t *testing.T) {
 	err = session.ExecuteBatch(batch)
 
 	if err == nil {
-		t.Fatal("'`INSERT INTO not_enough_query_args (id, cluster, value) VALUES (?, ?, ?)`, 1, 2' should return an ErrQueryArgLength")
-	}
-
-	if err != ErrQueryArgLength {
-		t.Fatalf("'INSERT INTO not_enough_query_args (id, cluster, value) VALUES (?, ?, ?)`, 1, 2' should return an ErrQueryArgLength, but returned: %s", err)
+		t.Fatal("'`INSERT INTO not_enough_query_args (id, cluster, value) VALUES (?, ?, ?)`, 1, 2' should return an error")
 	}
 }
 
@@ -1011,15 +997,21 @@ func injectInvalidPreparedStatement(t *testing.T, session *Session, table string
 	session.stmtsLRU.Lock()
 	session.stmtsLRU.lru.Add(conn.addr+stmt, flight)
 	session.stmtsLRU.Unlock()
-	flight.info = QueryInfo{
-		Id: []byte{'f', 'o', 'o', 'b', 'a', 'r'},
-		Args: []ColumnInfo{
-			{
-				Keyspace: "gocql_test",
-				Table:    table,
-				Name:     "foo",
-				TypeInfo: NativeType{
-					typ: TypeVarchar,
+	flight.preparedStatment = &preparedStatment{
+		id: []byte{'f', 'o', 'o', 'b', 'a', 'r'},
+		request: preparedMetadata{
+			resultMetadata: resultMetadata{
+				colCount:       1,
+				actualColCount: 1,
+				columns: []ColumnInfo{
+					{
+						Keyspace: "gocql_test",
+						Table:    table,
+						Name:     "foo",
+						TypeInfo: NativeType{
+							typ: TypeVarchar,
+						},
+					},
 				},
 			},
 		},
@@ -1085,12 +1077,12 @@ func TestQueryInfo(t *testing.T) {
 		t.Fatalf("Failed to execute query for preparing statement: %v", err)
 	}
 
-	if x := len(info.Args); x != 1 {
+	if x := len(info.request.columns); x != 1 {
 		t.Fatalf("Was not expecting meta data for %d query arguments, but got %d\n", 1, x)
 	}
 
 	if *flagProto > 1 {
-		if x := len(info.Rval); x != 2 {
+		if x := len(info.response.columns); x != 2 {
 			t.Fatalf("Was not expecting meta data for %d result columns, but got %d\n", 2, x)
 		}
 	}
@@ -1982,7 +1974,7 @@ func TestManualQueryPaging(t *testing.T) {
 	}
 
 	if fetched != rowsToInsert {
-		t.Fatalf("expected to fetch %d rows got %d", fetched, rowsToInsert)
+		t.Fatalf("expected to fetch %d rows got %d", rowsToInsert, fetched)
 	}
 }
 

+ 9 - 0
cluster.go

@@ -32,6 +32,15 @@ func (p *preparedLRU) max(max int) {
 	p.lru.MaxEntries = max
 }
 
+func (p *preparedLRU) clear() {
+	p.Lock()
+	defer p.Unlock()
+
+	for p.lru.Len() > 0 {
+		p.lru.RemoveOldest()
+	}
+}
+
 // PoolConfig configures the connection pool used by the driver, it defaults to
 // using a round robbin host selection policy and a round robbin connection selection
 // policy for each host.

+ 75 - 36
conn.go

@@ -579,14 +579,27 @@ func (c *Conn) exec(req frameWriter, tracer Tracer) (*framer, error) {
 	return framer, nil
 }
 
-func (c *Conn) prepareStatement(stmt string, tracer Tracer) (*QueryInfo, error) {
+type preparedStatment struct {
+	id       []byte
+	request  preparedMetadata
+	response resultMetadata
+}
+
+type inflightPrepare struct {
+	wg  sync.WaitGroup
+	err error
+
+	preparedStatment *preparedStatment
+}
+
+func (c *Conn) prepareStatement(stmt string, tracer Tracer) (*preparedStatment, error) {
 	c.session.stmtsLRU.Lock()
 	stmtCacheKey := c.addr + c.currentKeyspace + stmt
 	if val, ok := c.session.stmtsLRU.lru.Get(stmtCacheKey); ok {
 		c.session.stmtsLRU.Unlock()
 		flight := val.(*inflightPrepare)
 		flight.wg.Wait()
-		return &flight.info, flight.err
+		return flight.preparedStatment, flight.err
 	}
 
 	flight := new(inflightPrepare)
@@ -620,14 +633,15 @@ func (c *Conn) prepareStatement(stmt string, tracer Tracer) (*QueryInfo, error)
 
 	switch x := frame.(type) {
 	case *resultPreparedFrame:
-		// defensivly copy as we will recycle the underlying buffer after we
-		// return.
-		flight.info.Id = copyBytes(x.preparedID)
-		// the type info's should _not_ have a reference to the framers read buffer,
-		// therefore we can just copy them directly.
-		flight.info.Args = x.reqMeta.columns
-		flight.info.PKeyColumns = x.reqMeta.pkeyColumns
-		flight.info.Rval = x.respMeta.columns
+		flight.preparedStatment = &preparedStatment{
+			// defensivly copy as we will recycle the underlying buffer after we
+			// return.
+			id: copyBytes(x.preparedID),
+			// the type info's should _not_ have a reference to the framers read buffer,
+			// therefore we can just copy them directly.
+			request:  x.reqMeta,
+			response: x.respMeta,
+		}
 	case error:
 		flight.err = x
 	default:
@@ -643,7 +657,7 @@ func (c *Conn) prepareStatement(stmt string, tracer Tracer) (*QueryInfo, error)
 
 	framerPool.Put(framer)
 
-	return &flight.info, flight.err
+	return flight.preparedStatment, flight.err
 }
 
 func (c *Conn) executeQuery(qry *Query) *Iter {
@@ -662,10 +676,15 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 		params.pageSize = qry.pageSize
 	}
 
-	var frame frameWriter
+	var (
+		frame frameWriter
+		info  *preparedStatment
+	)
+
 	if qry.shouldPrepare() {
 		// Prepare all DML queries. Other queries can not be prepared.
-		info, err := c.prepareStatement(qry.stmt, qry.trace)
+		var err error
+		info, err = c.prepareStatement(qry.stmt, qry.trace)
 		if err != nil {
 			return &Iter{err: err}
 		}
@@ -675,19 +694,25 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 		if qry.binding == nil {
 			values = qry.values
 		} else {
-			values, err = qry.binding(info)
+			values, err = qry.binding(&QueryInfo{
+				Id:          info.id,
+				Args:        info.request.columns,
+				Rval:        info.response.columns,
+				PKeyColumns: info.request.pkeyColumns,
+			})
+
 			if err != nil {
 				return &Iter{err: err}
 			}
 		}
 
-		if len(values) != len(info.Args) {
-			return &Iter{err: ErrQueryArgLength}
+		if len(values) != info.request.actualColCount {
+			return &Iter{err: fmt.Errorf("gocql: expected %d values send got %d", info.request.actualColCount, len(values))}
 		}
 
 		params.values = make([]queryValues, len(values))
 		for i := 0; i < len(values); i++ {
-			val, err := Marshal(info.Args[i].TypeInfo, values[i])
+			val, err := Marshal(info.request.columns[i].TypeInfo, values[i])
 			if err != nil {
 				return &Iter{err: err}
 			}
@@ -697,8 +722,10 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 			// TODO: handle query binding names
 		}
 
+		params.skipMeta = !qry.isCAS
+
 		frame = &writeExecuteFrame{
-			preparedID: info.Id,
+			preparedID: info.id,
 			params:     params,
 		}
 	} else {
@@ -727,18 +754,28 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 		return &Iter{framer: framer}
 	case *resultRowsFrame:
 		iter := &Iter{
-			meta:   x.meta,
 			rows:   x.rows,
 			framer: framer,
 		}
 
+		if params.skipMeta {
+			if info != nil {
+				iter.meta = info.response
+				iter.meta.pagingState = x.meta.pagingState
+			} else {
+				return &Iter{framer: framer, err: errors.New("gocql: did not receive metadata but prepared info is nil")}
+			}
+		} else {
+			iter.meta = x.meta
+		}
+
 		if len(x.meta.pagingState) > 0 && !qry.disableAutoPage {
 			iter.next = &nextIter{
 				qry: *qry,
 				pos: int((1 - qry.prefetch) * float64(len(iter.rows))),
 			}
 
-			iter.next.qry.pageState = x.meta.pagingState
+			iter.next.qry.pageState = copyBytes(x.meta.pagingState)
 			if iter.next.pos < 1 {
 				iter.next.pos = 1
 			}
@@ -748,6 +785,9 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 	case *resultKeyspaceFrame:
 		return &Iter{framer: framer}
 	case *resultSchemaChangeFrame, *schemaChangeKeyspace, *schemaChangeTable, *schemaChangeFunction:
+		// Clear the statments cache so that we dont use stale table info for requests.
+		// TODO: only reset a specific table/keyapce and only when it is changed.
+		c.session.stmtsLRU.clear()
 		iter := &Iter{framer: framer}
 		if err := c.awaitSchemaAgreement(); err != nil {
 			// TODO: should have this behind a flag
@@ -848,27 +888,32 @@ func (c *Conn) executeBatch(batch *Batch) *Iter {
 				return &Iter{err: err}
 			}
 
-			var args []interface{}
+			var values []interface{}
 			if entry.binding == nil {
-				args = entry.Args
+				values = entry.Args
 			} else {
-				args, err = entry.binding(info)
+				values, err = entry.binding(&QueryInfo{
+					Id:          info.id,
+					Args:        info.request.columns,
+					Rval:        info.response.columns,
+					PKeyColumns: info.request.pkeyColumns,
+				})
 				if err != nil {
 					return &Iter{err: err}
 				}
 			}
 
-			if len(args) != len(info.Args) {
-				return &Iter{err: ErrQueryArgLength}
+			if len(values) != info.request.actualColCount {
+				return &Iter{err: fmt.Errorf("gocql: batch statment %d expected %d values send got %d", i, info.request.actualColCount, len(values))}
 			}
 
-			b.preparedID = info.Id
-			stmts[string(info.Id)] = entry.Stmt
+			b.preparedID = info.id
+			stmts[string(info.id)] = entry.Stmt
 
-			b.values = make([]queryValues, len(info.Args))
+			b.values = make([]queryValues, info.request.actualColCount)
 
-			for j := 0; j < len(info.Args); j++ {
-				val, err := Marshal(info.Args[j].TypeInfo, args[j])
+			for j := 0; j < info.request.actualColCount; j++ {
+				val, err := Marshal(info.request.columns[j].TypeInfo, values[j])
 				if err != nil {
 					return &Iter{err: err}
 				}
@@ -1000,12 +1045,6 @@ func (c *Conn) awaitSchemaAgreement() (err error) {
 	return fmt.Errorf("gocql: cluster schema versions not consistent: %+v", schemas)
 }
 
-type inflightPrepare struct {
-	info QueryInfo
-	err  error
-	wg   sync.WaitGroup
-}
-
 var (
 	ErrQueryArgLength    = errors.New("gocql: query argument length mismatch")
 	ErrTimeoutNoResponse = errors.New("gocql: no response received from cassandra within timeout period")

+ 23 - 23
frame.go

@@ -758,19 +758,19 @@ type preparedMetadata struct {
 }
 
 func (r preparedMetadata) String() string {
-	return fmt.Sprintf("[paging_metadata flags=0x%x pkey=%q paging_state=% X columns=%v]", r.flags, r.pkeyColumns, r.pagingState, r.columns)
+	return fmt.Sprintf("[prepared flags=0x%x pkey=%v paging_state=% X columns=%v col_count=%d actual_col_count=%d]", r.flags, r.pkeyColumns, r.pagingState, r.columns, r.colCount, r.actualColCount)
 }
 
 func (f *framer) parsePreparedMetadata() preparedMetadata {
 	// TODO: deduplicate this from parseMetadata
 	meta := preparedMetadata{}
-	meta.flags = f.readInt()
 
-	colCount := f.readInt()
-	if colCount < 0 {
-		panic(fmt.Errorf("received negative column count: %d", colCount))
+	meta.flags = f.readInt()
+	meta.colCount = f.readInt()
+	if meta.colCount < 0 {
+		panic(fmt.Errorf("received negative column count: %d", meta.colCount))
 	}
-	meta.actualColCount = colCount
+	meta.actualColCount = meta.colCount
 
 	if f.proto >= protoVersion4 {
 		pkeyCount := f.readInt()
@@ -797,16 +797,16 @@ func (f *framer) parsePreparedMetadata() preparedMetadata {
 	}
 
 	var cols []ColumnInfo
-	if colCount < 1000 {
+	if meta.colCount < 1000 {
 		// preallocate columninfo to avoid excess copying
-		cols = make([]ColumnInfo, colCount)
-		for i := 0; i < colCount; i++ {
+		cols = make([]ColumnInfo, meta.colCount)
+		for i := 0; i < meta.colCount; i++ {
 			f.readCol(&cols[i], &meta.resultMetadata, globalSpec, keyspace, table)
 		}
 	} else {
 		// use append, huge number of columns usually indicates a corrupt frame or
 		// just a huge row.
-		for i := 0; i < colCount; i++ {
+		for i := 0; i < meta.colCount; i++ {
 			var col ColumnInfo
 			f.readCol(&col, &meta.resultMetadata, globalSpec, keyspace, table)
 			cols = append(cols, col)
@@ -824,7 +824,8 @@ type resultMetadata struct {
 	// only if flagPageState
 	pagingState []byte
 
-	columns []ColumnInfo
+	columns  []ColumnInfo
+	colCount int
 
 	// this is a count of the total number of columns which can be scanned,
 	// it is at minimum len(columns) but may be larger, for instance when a column
@@ -856,15 +857,14 @@ func (f *framer) readCol(col *ColumnInfo, meta *resultMetadata, globalSpec bool,
 }
 
 func (f *framer) parseResultMetadata() resultMetadata {
-	meta := resultMetadata{
-		flags: f.readInt(),
-	}
+	var meta resultMetadata
 
-	colCount := f.readInt()
-	if colCount < 0 {
-		panic(fmt.Errorf("received negative column count: %d", colCount))
+	meta.flags = f.readInt()
+	meta.colCount = f.readInt()
+	if meta.colCount < 0 {
+		panic(fmt.Errorf("received negative column count: %d", meta.colCount))
 	}
-	meta.actualColCount = colCount
+	meta.actualColCount = meta.colCount
 
 	if meta.flags&flagHasMorePages == flagHasMorePages {
 		meta.pagingState = f.readBytes()
@@ -882,17 +882,17 @@ func (f *framer) parseResultMetadata() resultMetadata {
 	}
 
 	var cols []ColumnInfo
-	if colCount < 1000 {
+	if meta.colCount < 1000 {
 		// preallocate columninfo to avoid excess copying
-		cols = make([]ColumnInfo, colCount)
-		for i := 0; i < colCount; i++ {
+		cols = make([]ColumnInfo, meta.colCount)
+		for i := 0; i < meta.colCount; i++ {
 			f.readCol(&cols[i], &meta, globalSpec, keyspace, table)
 		}
 
 	} else {
 		// use append, huge number of columns usually indicates a corrupt frame or
 		// just a huge row.
-		for i := 0; i < colCount; i++ {
+		for i := 0; i < meta.colCount; i++ {
 			var col ColumnInfo
 			f.readCol(&col, &meta, globalSpec, keyspace, table)
 			cols = append(cols, col)
@@ -950,7 +950,7 @@ func (f *framer) parseResultRows() frame {
 		panic(fmt.Errorf("invalid row_count in result frame: %d", numRows))
 	}
 
-	colCount := len(meta.columns)
+	colCount := meta.colCount
 
 	rows := make([][][]byte, numRows)
 	for i := 0; i < numRows; i++ {

+ 7 - 4
session.go

@@ -363,7 +363,7 @@ func (s *Session) routingKeyInfo(stmt string) (*routingKeyInfo, error) {
 	s.routingKeyInfoCache.mu.Unlock()
 
 	var (
-		info         *QueryInfo
+		info         *preparedStatment
 		partitionKey []*ColumnMetadata
 	)
 
@@ -388,13 +388,13 @@ func (s *Session) routingKeyInfo(stmt string) (*routingKeyInfo, error) {
 	// Mark host as OK
 	host.Mark(nil)
 
-	if len(info.Args) == 0 {
+	if info.request.colCount == 0 {
 		// no arguments, no routing key, and no error
 		return nil, nil
 	}
 
 	// get the table metadata
-	table := info.Args[0].Table
+	table := info.request.columns[0].Table
 
 	var keyspaceMetadata *KeyspaceMetadata
 	keyspaceMetadata, inflight.err = s.KeyspaceMetadata(s.cfg.Keyspace)
@@ -427,7 +427,7 @@ func (s *Session) routingKeyInfo(stmt string) (*routingKeyInfo, error) {
 		routingKeyInfo.indexes[keyIndex] = -1
 
 		// find the column in the query info
-		for argIndex, boundColumn := range info.Args {
+		for argIndex, boundColumn := range info.request.columns {
 			if keyColumn.Name == boundColumn.Name {
 				// there may be many such bound columns, pick the first
 				routingKeyInfo.indexes[keyIndex] = argIndex
@@ -576,6 +576,7 @@ type Query struct {
 	totalLatency     int64
 	serialCons       SerialConsistency
 	defaultTimestamp bool
+	isCAS            bool
 
 	disableAutoPage bool
 }
@@ -819,6 +820,7 @@ func (q *Query) Scan(dest ...interface{}) error {
 // the existing values did not match, the previous values will be stored
 // in dest.
 func (q *Query) ScanCAS(dest ...interface{}) (applied bool, err error) {
+	q.isCAS = true
 	iter := q.Iter()
 	if err := iter.checkErrAndNotFound(); err != nil {
 		return false, err
@@ -841,6 +843,7 @@ func (q *Query) ScanCAS(dest ...interface{}) (applied bool, err error) {
 // SELECT * FROM. So using ScanCAS with INSERT is inherently prone to
 // column mismatching. MapScanCAS is added to capture them safely.
 func (q *Query) MapScanCAS(dest map[string]interface{}) (applied bool, err error) {
+	q.isCAS = true
 	iter := q.Iter()
 	if err := iter.checkErrAndNotFound(); err != nil {
 		return false, err