Pārlūkot izejas kodu

dont recycle framers until closing iters

Delay releasing frames until iter.Close so that we can reuse the
underlying buffers and not copy them when we dont need to.
Chris Bannister 10 gadi atpakaļ
vecāks
revīzija
d0dbd15082
4 mainītis faili ar 101 papildinājumiem un 36 dzēšanām
  1. 80 31
      conn.go
  2. 1 1
      frame.go
  3. 6 0
      helpers.go
  4. 14 4
      session.go

+ 80 - 31
conn.go

@@ -240,7 +240,12 @@ func (c *Conn) startup(cfg *ConnConfig) error {
 		m["COMPRESSION"] = c.compressor.Name()
 	}
 
-	frame, err := c.exec(&writeStartupFrame{opts: m}, nil)
+	framer, err := c.exec(&writeStartupFrame{opts: m}, nil)
+	if err != nil {
+		return err
+	}
+
+	frame, err := framer.parseFrame()
 	if err != nil {
 		return err
 	}
@@ -270,7 +275,12 @@ func (c *Conn) authenticateHandshake(authFrame *authenticateFrame) error {
 	req := &writeAuthResponseFrame{data: resp}
 
 	for {
-		frame, err := c.exec(req, nil)
+		framer, err := c.exec(req, nil)
+		if err != nil {
+			return err
+		}
+
+		frame, err := framer.parseFrame()
 		if err != nil {
 			return err
 		}
@@ -295,6 +305,8 @@ func (c *Conn) authenticateHandshake(authFrame *authenticateFrame) error {
 		default:
 			return fmt.Errorf("unknown frame response during authentication: %v", v)
 		}
+
+		framerPool.Put(framer)
 	}
 }
 
@@ -385,6 +397,7 @@ func (c *Conn) recv() error {
 		if err := framer.readFrame(&head); err != nil {
 			return err
 		}
+		defer framerPool.Put(framer)
 
 		frame, err := framer.parseFrame()
 		if err != nil {
@@ -435,7 +448,6 @@ type callReq struct {
 
 func (c *Conn) releaseStream(stream int) {
 	call := &c.calls[stream]
-	framerPool.Put(call.framer)
 	call.framer = nil
 
 	select {
@@ -450,7 +462,7 @@ func (c *Conn) handleTimeout() {
 	}
 }
 
-func (c *Conn) exec(req frameWriter, tracer Tracer) (frame, error) {
+func (c *Conn) exec(req frameWriter, tracer Tracer) (*framer, error) {
 	// TODO: move tracer onto conn
 	var stream int
 	select {
@@ -512,19 +524,10 @@ func (c *Conn) exec(req frameWriter, tracer Tracer) (frame, error) {
 		return nil, NewErrProtocol("unexpected protocol version in response: got %d expected %d", v, c.version)
 	}
 
-	frame, err := framer.parseFrame()
-	if err != nil {
-		return nil, err
-	}
-
-	if len(framer.traceID) > 0 {
-		tracer.Trace(framer.traceID)
-	}
-
-	return frame, nil
+	return framer, nil
 }
 
-func (c *Conn) prepareStatement(stmt string, trace Tracer) (*QueryInfo, error) {
+func (c *Conn) prepareStatement(stmt string, tracer Tracer) (*QueryInfo, error) {
 	stmtsLRU.Lock()
 	if stmtsLRU.lru == nil {
 		initStmtsLRU(defaultMaxPreparedStmts)
@@ -548,17 +551,31 @@ func (c *Conn) prepareStatement(stmt string, trace Tracer) (*QueryInfo, error) {
 		statement: stmt,
 	}
 
-	resp, err := c.exec(prep, trace)
+	framer, err := c.exec(prep, tracer)
 	if err != nil {
 		flight.err = err
 		flight.wg.Done()
 		return nil, err
 	}
 
-	switch x := resp.(type) {
+	frame, err := framer.parseFrame()
+	if err != nil {
+		flight.err = err
+		flight.wg.Done()
+		return nil, err
+	}
+
+	// TODO(zariel): tidy this up, simplify handling of frame parsing so its not duplicated
+	// everytime we need to parse a frame.
+	if len(framer.traceID) > 0 {
+		tracer.Trace(framer.traceID)
+	}
+
+	switch x := frame.(type) {
 	case *resultPreparedFrame:
-		flight.info.Id = make([]byte, len(x.preparedID))
-		copy(flight.info.Id, x.preparedID)
+		// defensivly copy as we will recycle the underlying buffer after we
+		// return.
+		flight.info.Id = copyBytes(x.preparedID)
 		// the type info's should _not_ have a reference to the framers read buffer,
 		// therefore we can just copy them directly.
 		flight.info.Args = x.reqMeta.columns
@@ -577,6 +594,8 @@ func (c *Conn) prepareStatement(stmt string, trace Tracer) (*QueryInfo, error) {
 		stmtsLRU.Unlock()
 	}
 
+	framerPool.Put(framer)
+
 	return &flight.info, flight.err
 }
 
@@ -642,18 +661,28 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 		}
 	}
 
-	resp, err := c.exec(frame, qry.trace)
+	framer, err := c.exec(frame, qry.trace)
 	if err != nil {
 		return &Iter{err: err}
 	}
 
+	resp, err := framer.parseFrame()
+	if err != nil {
+		return &Iter{err: err}
+	}
+
+	if len(framer.traceID) > 0 {
+		qry.trace.Trace(framer.traceID)
+	}
+
 	switch x := resp.(type) {
 	case *resultVoidFrame:
-		return &Iter{}
+		return &Iter{framer: framer}
 	case *resultRowsFrame:
 		iter := &Iter{
-			meta: x.meta,
-			rows: x.rows,
+			meta:   x.meta,
+			rows:   x.rows,
+			framer: framer,
 		}
 
 		if len(x.meta.pagingState) > 0 && !qry.disableAutoPage {
@@ -670,7 +699,7 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 
 		return iter
 	case *resultKeyspaceFrame, *resultSchemaChangeFrame, *schemaChangeKeyspace, *schemaChangeTable:
-		return &Iter{}
+		return &Iter{framer: framer}
 	case *RequestErrUnprepared:
 		stmtsLRU.Lock()
 		stmtCacheKey := c.addr + c.currentKeyspace + qry.stmt
@@ -680,11 +709,14 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 			return c.executeQuery(qry)
 		}
 		stmtsLRU.Unlock()
-		return &Iter{err: x}
+		return &Iter{err: x, framer: framer}
 	case error:
-		return &Iter{err: x}
+		return &Iter{err: x, framer: framer}
 	default:
-		return &Iter{err: NewErrProtocol("Unknown type in response to execute query (%T): %s", x, x)}
+		return &Iter{
+			err:    NewErrProtocol("Unknown type in response to execute query (%T): %s", x, x),
+			framer: framer,
+		}
 	}
 }
 
@@ -711,7 +743,12 @@ func (c *Conn) UseKeyspace(keyspace string) error {
 	q := &writeQueryFrame{statement: `USE "` + keyspace + `"`}
 	q.params.consistency = Any
 
-	resp, err := c.exec(q, nil)
+	framer, err := c.exec(q, nil)
+	if err != nil {
+		return err
+	}
+
+	resp, err := framer.parseFrame()
 	if err != nil {
 		return err
 	}
@@ -788,13 +825,19 @@ func (c *Conn) executeBatch(batch *Batch) (*Iter, error) {
 	}
 
 	// TODO: should batch support tracing?
-	resp, err := c.exec(req, nil)
+	framer, err := c.exec(req, nil)
+	if err != nil {
+		return nil, err
+	}
+
+	resp, err := framer.parseFrame()
 	if err != nil {
 		return nil, err
 	}
 
 	switch x := resp.(type) {
 	case *resultVoidFrame:
+		framerPool.Put(framer)
 		return nil, nil
 	case *RequestErrUnprepared:
 		stmt, found := stmts[string(x.StatementId)]
@@ -803,6 +846,9 @@ func (c *Conn) executeBatch(batch *Batch) (*Iter, error) {
 			stmtsLRU.lru.Remove(c.addr + c.currentKeyspace + stmt)
 			stmtsLRU.Unlock()
 		}
+
+		framerPool.Put(framer)
+
 		if found {
 			return c.executeBatch(batch)
 		} else {
@@ -810,14 +856,17 @@ func (c *Conn) executeBatch(batch *Batch) (*Iter, error) {
 		}
 	case *resultRowsFrame:
 		iter := &Iter{
-			meta: x.meta,
-			rows: x.rows,
+			meta:   x.meta,
+			rows:   x.rows,
+			framer: framer,
 		}
 
 		return iter, nil
 	case error:
+		framerPool.Put(framer)
 		return nil, x
 	default:
+		framerPool.Put(framer)
 		return nil, NewErrProtocol("Unknown type in response to batch statement: %s", x)
 	}
 }

+ 1 - 1
frame.go

@@ -526,7 +526,7 @@ func (f *framer) parseErrorFrame() frame {
 		stmtId := f.readShortBytes()
 		return &RequestErrUnprepared{
 			errorFrame:  errD,
-			StatementId: stmtId,
+			StatementId: copyBytes(stmtId), // defensivly copy
 		}
 	case errReadFailure:
 		res := &RequestErrReadFailure{

+ 6 - 0
helpers.go

@@ -180,3 +180,9 @@ func (iter *Iter) MapScan(m map[string]interface{}) bool {
 	}
 	return false
 }
+
+func copyBytes(p []byte) []byte {
+	b := make([]byte, len(p))
+	copy(b, p)
+	return b
+}

+ 14 - 4
session.go

@@ -151,9 +151,9 @@ func (s *Session) Query(stmt string, values ...interface{}) *Query {
 }
 
 type QueryInfo struct {
-	Id   []byte
-	Args []ColumnInfo
-	Rval []ColumnInfo
+	Id          []byte
+	Args        []ColumnInfo
+	Rval        []ColumnInfo
 	PKeyColumns []int
 }
 
@@ -287,7 +287,7 @@ func (s *Session) routingKeyInfo(stmt string) (*routingKeyInfo, error) {
 	s.routingKeyInfoCache.mu.Unlock()
 
 	var (
-		info     *QueryInfo
+		info         *QueryInfo
 		partitionKey []*ColumnMetadata
 	)
 
@@ -764,6 +764,9 @@ type Iter struct {
 	rows [][][]byte
 	meta resultMetadata
 	next *nextIter
+
+	framer *framer
+	once   sync.Once
 }
 
 // Columns returns the name and type of the selected columns.
@@ -837,6 +840,13 @@ func (iter *Iter) Scan(dest ...interface{}) bool {
 // Close closes the iterator and returns any errors that happened during
 // the query or the iteration.
 func (iter *Iter) Close() error {
+	iter.once.Do(func() {
+		if iter.framer != nil {
+			framerPool.Put(iter.framer)
+			iter.framer = nil
+		}
+	})
+
 	return iter.err
 }