Parcourir la source

Defensivly copy byte slices from the read buffer

When parsing frames which contains [bytes] copy the underlying
bytes into a new slice so that the owner as full control of the
lifecycle of the slice. This has the downside that it increases
our allocations by about a factor of 2 in the benchmarks but it
overcomplicates the assumptions that need to be made about the
underlying memory which will more likely than not lead to a subtle
memory corruption bug.
Chris Bannister il y a 10 ans
Parent
commit
61f9761bea
2 fichiers modifiés avec 17 ajouts et 20 suppressions
  1. 4 5
      conn.go
  2. 13 15
      frame.go

+ 4 - 5
conn.go

@@ -377,6 +377,9 @@ func (c *Conn) exec(req frameWriter, tracer Tracer) (frame, error) {
 		tracer.Trace(framer.traceID)
 		tracer.Trace(framer.traceID)
 	}
 	}
 
 
+	framerPool.Put(framer)
+	call.framer = nil
+
 	return frame, nil
 	return frame, nil
 }
 }
 
 
@@ -410,7 +413,6 @@ func (c *Conn) prepareStatement(stmt string, trace Tracer) (*resultPreparedFrame
 		flight.wg.Done()
 		flight.wg.Done()
 		return nil, err
 		return nil, err
 	}
 	}
-	defer resp.release()
 
 
 	switch x := resp.(type) {
 	switch x := resp.(type) {
 	case *resultPreparedFrame:
 	case *resultPreparedFrame:
@@ -499,7 +501,6 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 	if err != nil {
 	if err != nil {
 		return &Iter{err: err}
 		return &Iter{err: err}
 	}
 	}
-	defer resp.release()
 
 
 	switch x := resp.(type) {
 	switch x := resp.(type) {
 	case *resultVoidFrame:
 	case *resultVoidFrame:
@@ -584,14 +585,13 @@ func (c *Conn) UseKeyspace(keyspace string) error {
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	defer resp.release()
 
 
 	switch x := resp.(type) {
 	switch x := resp.(type) {
 	case *resultKeyspaceFrame:
 	case *resultKeyspaceFrame:
 	case error:
 	case error:
 		return x
 		return x
 	default:
 	default:
-		return NewErrProtocol("Unknown type in response to USE: %s", x)
+		return NewErrProtocol("unknown frame in response to USE: %v", x)
 	}
 	}
 
 
 	c.currentKeyspace = keyspace
 	c.currentKeyspace = keyspace
@@ -665,7 +665,6 @@ func (c *Conn) executeBatch(batch *Batch) error {
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	defer resp.release()
 
 
 	switch x := resp.(type) {
 	switch x := resp.(type) {
 	case *resultVoidFrame:
 	case *resultVoidFrame:

+ 13 - 15
frame.go

@@ -209,15 +209,6 @@ type frameHeader struct {
 	stream  int
 	stream  int
 	op      frameOp
 	op      frameOp
 	length  int
 	length  int
-
-	framer *framer
-}
-
-func (f *frameHeader) release() {
-	if f.framer != nil {
-		framerPool.Put(f.framer)
-		f.framer = nil
-	}
 }
 }
 
 
 func (f frameHeader) String() string {
 func (f frameHeader) String() string {
@@ -296,7 +287,6 @@ func newFramer(r io.Reader, w io.Writer, compressor Compressor, version byte) *f
 
 
 type frame interface {
 type frame interface {
 	Header() frameHeader
 	Header() frameHeader
-	release()
 }
 }
 
 
 func readHeader(r io.Reader, p []byte) (head frameHeader, err error) {
 func readHeader(r io.Reader, p []byte) (head frameHeader, err error) {
@@ -1165,17 +1155,24 @@ func (f *framer) readBytes() []byte {
 		return nil
 		return nil
 	}
 	}
 
 
-	l := f.rbuf[:size]
+	// we cant make assumptions about the length of the life of the supplied byte
+	// slice so we defensivly copy it out of the underlying buffer. This has the
+	// downside of increasing allocs per read but will provide much greater memory
+	// safety. The allocs can hopefully be improved in the future.
+	// TODO: dont copy into a new slice
+	l := make([]byte, size)
+	copy(l, f.rbuf[:size])
 	f.rbuf = f.rbuf[size:]
 	f.rbuf = f.rbuf[size:]
 
 
 	return l
 	return l
 }
 }
 
 
 func (f *framer) readShortBytes() []byte {
 func (f *framer) readShortBytes() []byte {
-	n := f.readShort()
+	size := f.readShort()
 
 
-	l := f.rbuf[:n]
-	f.rbuf = f.rbuf[n:]
+	l := make([]byte, size)
+	copy(l, f.rbuf[:size])
+	f.rbuf = f.rbuf[size:]
 
 
 	return l
 	return l
 }
 }
@@ -1188,7 +1185,8 @@ func (f *framer) readInet() (net.IP, int) {
 		panic(fmt.Sprintf("invalid IP size: %d", size))
 		panic(fmt.Sprintf("invalid IP size: %d", size))
 	}
 	}
 
 
-	ip := f.rbuf[:size]
+	ip := make([]byte, size)
+	copy(ip, f.rbuf[:size])
 	f.rbuf = f.rbuf[size:]
 	f.rbuf = f.rbuf[size:]
 
 
 	port := f.readInt()
 	port := f.readInt()