浏览代码

dont recycle framers until closing iters

Delay releasing frames until iter.Close so that we can reuse the
underlying buffers and not copy them when we dont need to.
Chris Bannister 10 年之前
父节点
当前提交
d0dbd15082
共有 4 个文件被更改,包括 101 次插入36 次删除
  1. 80 31
      conn.go
  2. 1 1
      frame.go
  3. 6 0
      helpers.go
  4. 14 4
      session.go

+ 80 - 31
conn.go

@@ -240,7 +240,12 @@ func (c *Conn) startup(cfg *ConnConfig) error {
 		m["COMPRESSION"] = c.compressor.Name()
 		m["COMPRESSION"] = c.compressor.Name()
 	}
 	}
 
 
-	frame, err := c.exec(&writeStartupFrame{opts: m}, nil)
+	framer, err := c.exec(&writeStartupFrame{opts: m}, nil)
+	if err != nil {
+		return err
+	}
+
+	frame, err := framer.parseFrame()
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -270,7 +275,12 @@ func (c *Conn) authenticateHandshake(authFrame *authenticateFrame) error {
 	req := &writeAuthResponseFrame{data: resp}
 	req := &writeAuthResponseFrame{data: resp}
 
 
 	for {
 	for {
-		frame, err := c.exec(req, nil)
+		framer, err := c.exec(req, nil)
+		if err != nil {
+			return err
+		}
+
+		frame, err := framer.parseFrame()
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
@@ -295,6 +305,8 @@ func (c *Conn) authenticateHandshake(authFrame *authenticateFrame) error {
 		default:
 		default:
 			return fmt.Errorf("unknown frame response during authentication: %v", v)
 			return fmt.Errorf("unknown frame response during authentication: %v", v)
 		}
 		}
+
+		framerPool.Put(framer)
 	}
 	}
 }
 }
 
 
@@ -385,6 +397,7 @@ func (c *Conn) recv() error {
 		if err := framer.readFrame(&head); err != nil {
 		if err := framer.readFrame(&head); err != nil {
 			return err
 			return err
 		}
 		}
+		defer framerPool.Put(framer)
 
 
 		frame, err := framer.parseFrame()
 		frame, err := framer.parseFrame()
 		if err != nil {
 		if err != nil {
@@ -435,7 +448,6 @@ type callReq struct {
 
 
 func (c *Conn) releaseStream(stream int) {
 func (c *Conn) releaseStream(stream int) {
 	call := &c.calls[stream]
 	call := &c.calls[stream]
-	framerPool.Put(call.framer)
 	call.framer = nil
 	call.framer = nil
 
 
 	select {
 	select {
@@ -450,7 +462,7 @@ func (c *Conn) handleTimeout() {
 	}
 	}
 }
 }
 
 
-func (c *Conn) exec(req frameWriter, tracer Tracer) (frame, error) {
+func (c *Conn) exec(req frameWriter, tracer Tracer) (*framer, error) {
 	// TODO: move tracer onto conn
 	// TODO: move tracer onto conn
 	var stream int
 	var stream int
 	select {
 	select {
@@ -512,19 +524,10 @@ func (c *Conn) exec(req frameWriter, tracer Tracer) (frame, error) {
 		return nil, NewErrProtocol("unexpected protocol version in response: got %d expected %d", v, c.version)
 		return nil, NewErrProtocol("unexpected protocol version in response: got %d expected %d", v, c.version)
 	}
 	}
 
 
-	frame, err := framer.parseFrame()
-	if err != nil {
-		return nil, err
-	}
-
-	if len(framer.traceID) > 0 {
-		tracer.Trace(framer.traceID)
-	}
-
-	return frame, nil
+	return framer, nil
 }
 }
 
 
-func (c *Conn) prepareStatement(stmt string, trace Tracer) (*QueryInfo, error) {
+func (c *Conn) prepareStatement(stmt string, tracer Tracer) (*QueryInfo, error) {
 	stmtsLRU.Lock()
 	stmtsLRU.Lock()
 	if stmtsLRU.lru == nil {
 	if stmtsLRU.lru == nil {
 		initStmtsLRU(defaultMaxPreparedStmts)
 		initStmtsLRU(defaultMaxPreparedStmts)
@@ -548,17 +551,31 @@ func (c *Conn) prepareStatement(stmt string, trace Tracer) (*QueryInfo, error) {
 		statement: stmt,
 		statement: stmt,
 	}
 	}
 
 
-	resp, err := c.exec(prep, trace)
+	framer, err := c.exec(prep, tracer)
 	if err != nil {
 	if err != nil {
 		flight.err = err
 		flight.err = err
 		flight.wg.Done()
 		flight.wg.Done()
 		return nil, err
 		return nil, err
 	}
 	}
 
 
-	switch x := resp.(type) {
+	frame, err := framer.parseFrame()
+	if err != nil {
+		flight.err = err
+		flight.wg.Done()
+		return nil, err
+	}
+
+	// TODO(zariel): tidy this up, simplify handling of frame parsing so its not duplicated
+	// everytime we need to parse a frame.
+	if len(framer.traceID) > 0 {
+		tracer.Trace(framer.traceID)
+	}
+
+	switch x := frame.(type) {
 	case *resultPreparedFrame:
 	case *resultPreparedFrame:
-		flight.info.Id = make([]byte, len(x.preparedID))
-		copy(flight.info.Id, x.preparedID)
+		// defensivly copy as we will recycle the underlying buffer after we
+		// return.
+		flight.info.Id = copyBytes(x.preparedID)
 		// the type info's should _not_ have a reference to the framers read buffer,
 		// the type info's should _not_ have a reference to the framers read buffer,
 		// therefore we can just copy them directly.
 		// therefore we can just copy them directly.
 		flight.info.Args = x.reqMeta.columns
 		flight.info.Args = x.reqMeta.columns
@@ -577,6 +594,8 @@ func (c *Conn) prepareStatement(stmt string, trace Tracer) (*QueryInfo, error) {
 		stmtsLRU.Unlock()
 		stmtsLRU.Unlock()
 	}
 	}
 
 
+	framerPool.Put(framer)
+
 	return &flight.info, flight.err
 	return &flight.info, flight.err
 }
 }
 
 
@@ -642,18 +661,28 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 		}
 		}
 	}
 	}
 
 
-	resp, err := c.exec(frame, qry.trace)
+	framer, err := c.exec(frame, qry.trace)
 	if err != nil {
 	if err != nil {
 		return &Iter{err: err}
 		return &Iter{err: err}
 	}
 	}
 
 
+	resp, err := framer.parseFrame()
+	if err != nil {
+		return &Iter{err: err}
+	}
+
+	if len(framer.traceID) > 0 {
+		qry.trace.Trace(framer.traceID)
+	}
+
 	switch x := resp.(type) {
 	switch x := resp.(type) {
 	case *resultVoidFrame:
 	case *resultVoidFrame:
-		return &Iter{}
+		return &Iter{framer: framer}
 	case *resultRowsFrame:
 	case *resultRowsFrame:
 		iter := &Iter{
 		iter := &Iter{
-			meta: x.meta,
-			rows: x.rows,
+			meta:   x.meta,
+			rows:   x.rows,
+			framer: framer,
 		}
 		}
 
 
 		if len(x.meta.pagingState) > 0 && !qry.disableAutoPage {
 		if len(x.meta.pagingState) > 0 && !qry.disableAutoPage {
@@ -670,7 +699,7 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 
 
 		return iter
 		return iter
 	case *resultKeyspaceFrame, *resultSchemaChangeFrame, *schemaChangeKeyspace, *schemaChangeTable:
 	case *resultKeyspaceFrame, *resultSchemaChangeFrame, *schemaChangeKeyspace, *schemaChangeTable:
-		return &Iter{}
+		return &Iter{framer: framer}
 	case *RequestErrUnprepared:
 	case *RequestErrUnprepared:
 		stmtsLRU.Lock()
 		stmtsLRU.Lock()
 		stmtCacheKey := c.addr + c.currentKeyspace + qry.stmt
 		stmtCacheKey := c.addr + c.currentKeyspace + qry.stmt
@@ -680,11 +709,14 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 			return c.executeQuery(qry)
 			return c.executeQuery(qry)
 		}
 		}
 		stmtsLRU.Unlock()
 		stmtsLRU.Unlock()
-		return &Iter{err: x}
+		return &Iter{err: x, framer: framer}
 	case error:
 	case error:
-		return &Iter{err: x}
+		return &Iter{err: x, framer: framer}
 	default:
 	default:
-		return &Iter{err: NewErrProtocol("Unknown type in response to execute query (%T): %s", x, x)}
+		return &Iter{
+			err:    NewErrProtocol("Unknown type in response to execute query (%T): %s", x, x),
+			framer: framer,
+		}
 	}
 	}
 }
 }
 
 
@@ -711,7 +743,12 @@ func (c *Conn) UseKeyspace(keyspace string) error {
 	q := &writeQueryFrame{statement: `USE "` + keyspace + `"`}
 	q := &writeQueryFrame{statement: `USE "` + keyspace + `"`}
 	q.params.consistency = Any
 	q.params.consistency = Any
 
 
-	resp, err := c.exec(q, nil)
+	framer, err := c.exec(q, nil)
+	if err != nil {
+		return err
+	}
+
+	resp, err := framer.parseFrame()
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -788,13 +825,19 @@ func (c *Conn) executeBatch(batch *Batch) (*Iter, error) {
 	}
 	}
 
 
 	// TODO: should batch support tracing?
 	// TODO: should batch support tracing?
-	resp, err := c.exec(req, nil)
+	framer, err := c.exec(req, nil)
+	if err != nil {
+		return nil, err
+	}
+
+	resp, err := framer.parseFrame()
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
 	switch x := resp.(type) {
 	switch x := resp.(type) {
 	case *resultVoidFrame:
 	case *resultVoidFrame:
+		framerPool.Put(framer)
 		return nil, nil
 		return nil, nil
 	case *RequestErrUnprepared:
 	case *RequestErrUnprepared:
 		stmt, found := stmts[string(x.StatementId)]
 		stmt, found := stmts[string(x.StatementId)]
@@ -803,6 +846,9 @@ func (c *Conn) executeBatch(batch *Batch) (*Iter, error) {
 			stmtsLRU.lru.Remove(c.addr + c.currentKeyspace + stmt)
 			stmtsLRU.lru.Remove(c.addr + c.currentKeyspace + stmt)
 			stmtsLRU.Unlock()
 			stmtsLRU.Unlock()
 		}
 		}
+
+		framerPool.Put(framer)
+
 		if found {
 		if found {
 			return c.executeBatch(batch)
 			return c.executeBatch(batch)
 		} else {
 		} else {
@@ -810,14 +856,17 @@ func (c *Conn) executeBatch(batch *Batch) (*Iter, error) {
 		}
 		}
 	case *resultRowsFrame:
 	case *resultRowsFrame:
 		iter := &Iter{
 		iter := &Iter{
-			meta: x.meta,
-			rows: x.rows,
+			meta:   x.meta,
+			rows:   x.rows,
+			framer: framer,
 		}
 		}
 
 
 		return iter, nil
 		return iter, nil
 	case error:
 	case error:
+		framerPool.Put(framer)
 		return nil, x
 		return nil, x
 	default:
 	default:
+		framerPool.Put(framer)
 		return nil, NewErrProtocol("Unknown type in response to batch statement: %s", x)
 		return nil, NewErrProtocol("Unknown type in response to batch statement: %s", x)
 	}
 	}
 }
 }

+ 1 - 1
frame.go

@@ -526,7 +526,7 @@ func (f *framer) parseErrorFrame() frame {
 		stmtId := f.readShortBytes()
 		stmtId := f.readShortBytes()
 		return &RequestErrUnprepared{
 		return &RequestErrUnprepared{
 			errorFrame:  errD,
 			errorFrame:  errD,
-			StatementId: stmtId,
+			StatementId: copyBytes(stmtId), // defensivly copy
 		}
 		}
 	case errReadFailure:
 	case errReadFailure:
 		res := &RequestErrReadFailure{
 		res := &RequestErrReadFailure{

+ 6 - 0
helpers.go

@@ -180,3 +180,9 @@ func (iter *Iter) MapScan(m map[string]interface{}) bool {
 	}
 	}
 	return false
 	return false
 }
 }
+
+func copyBytes(p []byte) []byte {
+	b := make([]byte, len(p))
+	copy(b, p)
+	return b
+}

+ 14 - 4
session.go

@@ -151,9 +151,9 @@ func (s *Session) Query(stmt string, values ...interface{}) *Query {
 }
 }
 
 
 type QueryInfo struct {
 type QueryInfo struct {
-	Id   []byte
-	Args []ColumnInfo
-	Rval []ColumnInfo
+	Id          []byte
+	Args        []ColumnInfo
+	Rval        []ColumnInfo
 	PKeyColumns []int
 	PKeyColumns []int
 }
 }
 
 
@@ -287,7 +287,7 @@ func (s *Session) routingKeyInfo(stmt string) (*routingKeyInfo, error) {
 	s.routingKeyInfoCache.mu.Unlock()
 	s.routingKeyInfoCache.mu.Unlock()
 
 
 	var (
 	var (
-		info     *QueryInfo
+		info         *QueryInfo
 		partitionKey []*ColumnMetadata
 		partitionKey []*ColumnMetadata
 	)
 	)
 
 
@@ -764,6 +764,9 @@ type Iter struct {
 	rows [][][]byte
 	rows [][][]byte
 	meta resultMetadata
 	meta resultMetadata
 	next *nextIter
 	next *nextIter
+
+	framer *framer
+	once   sync.Once
 }
 }
 
 
 // Columns returns the name and type of the selected columns.
 // Columns returns the name and type of the selected columns.
@@ -837,6 +840,13 @@ func (iter *Iter) Scan(dest ...interface{}) bool {
 // Close closes the iterator and returns any errors that happened during
 // Close closes the iterator and returns any errors that happened during
 // the query or the iteration.
 // the query or the iteration.
 func (iter *Iter) Close() error {
 func (iter *Iter) Close() error {
+	iter.once.Do(func() {
+		if iter.framer != nil {
+			framerPool.Put(iter.framer)
+			iter.framer = nil
+		}
+	})
+
 	return iter.err
 	return iter.err
 }
 }