Browse Source

Improve buffer handling (#890)

* Eliminate redundant size test in takeBuffer.
* Change buffer takeXXX functions to return an error to make it explicit that they can fail.
* Add missing error check in handleAuthResult.
* Add buffer.store(..) method which can be used by external buffer consumers to update the raw buffer.
* Fix some typos and unnecessary UTF-8 characters in comments.
* Improve buffer function docs.
* Add comments to explain some non-obvious behavior around buffer handling.
Steven Hartland 7 years ago
parent
commit
6be42e0ff9
6 changed files with 72 additions and 49 deletions
  1. 2 0
      AUTHORS
  2. 5 3
      auth.go
  3. 31 18
      buffer.go
  4. 3 3
      connection.go
  5. 1 1
      driver.go
  6. 30 24
      packets.go

+ 2 - 0
AUTHORS

@@ -73,6 +73,7 @@ Shuode Li <elemount at qq.com>
 Soroush Pour <me at soroushjp.com>
 Stan Putrya <root.vagner at gmail.com>
 Stanley Gunawan <gunawan.stanley at gmail.com>
+Steven Hartland <steven.hartland at multiplay.co.uk>
 Thomas Wodarek <wodarekwebpage at gmail.com>
 Tom Jenkinson <tom at tjenkinson.me>
 Xiangyu Hu <xiangyu.hu at outlook.com>
@@ -90,3 +91,4 @@ Keybase Inc.
 Percona LLC
 Pivotal Inc.
 Stripe Inc.
+Multiplay Ltd.

+ 5 - 3
auth.go

@@ -360,13 +360,15 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
 					pubKey := mc.cfg.pubKey
 					if pubKey == nil {
 						// request public key from server
-						data := mc.buf.takeSmallBuffer(4 + 1)
+						data, err := mc.buf.takeSmallBuffer(4 + 1)
+						if err != nil {
+							return err
+						}
 						data[4] = cachingSha2PasswordRequestPublicKey
 						mc.writePacket(data)
 
 						// parse public key
-						data, err := mc.readPacket()
-						if err != nil {
+						if data, err = mc.readPacket(); err != nil {
 							return err
 						}
 

+ 31 - 18
buffer.go

@@ -22,17 +22,17 @@ const defaultBufSize = 4096
 // The buffer is similar to bufio.Reader / Writer but zero-copy-ish
 // Also highly optimized for this particular use case.
 type buffer struct {
-	buf     []byte
+	buf     []byte // buf is a byte buffer who's length and capacity are equal.
 	nc      net.Conn
 	idx     int
 	length  int
 	timeout time.Duration
 }
 
+// newBuffer allocates and returns a new buffer.
 func newBuffer(nc net.Conn) buffer {
-	var b [defaultBufSize]byte
 	return buffer{
-		buf: b[:],
+		buf: make([]byte, defaultBufSize),
 		nc:  nc,
 	}
 }
@@ -105,43 +105,56 @@ func (b *buffer) readNext(need int) ([]byte, error) {
 	return b.buf[offset:b.idx], nil
 }
 
-// returns a buffer with the requested size.
+// takeBuffer returns a buffer with the requested size.
 // If possible, a slice from the existing buffer is returned.
 // Otherwise a bigger buffer is made.
 // Only one buffer (total) can be used at a time.
-func (b *buffer) takeBuffer(length int) []byte {
+func (b *buffer) takeBuffer(length int) ([]byte, error) {
 	if b.length > 0 {
-		return nil
+		return nil, ErrBusyBuffer
 	}
 
 	// test (cheap) general case first
-	if length <= defaultBufSize || length <= cap(b.buf) {
-		return b.buf[:length]
+	if length <= cap(b.buf) {
+		return b.buf[:length], nil
 	}
 
 	if length < maxPacketSize {
 		b.buf = make([]byte, length)
-		return b.buf
+		return b.buf, nil
 	}
-	return make([]byte, length)
+
+	// buffer is larger than we want to store.
+	return make([]byte, length), nil
 }
 
-// shortcut which can be used if the requested buffer is guaranteed to be
-// smaller than defaultBufSize
+// takeSmallBuffer is shortcut which can be used if length is
+// known to be smaller than defaultBufSize.
 // Only one buffer (total) can be used at a time.
-func (b *buffer) takeSmallBuffer(length int) []byte {
+func (b *buffer) takeSmallBuffer(length int) ([]byte, error) {
 	if b.length > 0 {
-		return nil
+		return nil, ErrBusyBuffer
 	}
-	return b.buf[:length]
+	return b.buf[:length], nil
 }
 
 // takeCompleteBuffer returns the complete existing buffer.
 // This can be used if the necessary buffer size is unknown.
+// cap and len of the returned buffer will be equal.
 // Only one buffer (total) can be used at a time.
-func (b *buffer) takeCompleteBuffer() []byte {
+func (b *buffer) takeCompleteBuffer() ([]byte, error) {
+	if b.length > 0 {
+		return nil, ErrBusyBuffer
+	}
+	return b.buf, nil
+}
+
+// store stores buf, an updated buffer, if its suitable to do so.
+func (b *buffer) store(buf []byte) error {
 	if b.length > 0 {
-		return nil
+		return ErrBusyBuffer
+	} else if cap(buf) <= maxPacketSize && cap(buf) > cap(b.buf) {
+		b.buf = buf[:cap(buf)]
 	}
-	return b.buf
+	return nil
 }

+ 3 - 3
connection.go

@@ -182,10 +182,10 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
 		return "", driver.ErrSkip
 	}
 
-	buf := mc.buf.takeCompleteBuffer()
-	if buf == nil {
+	buf, err := mc.buf.takeCompleteBuffer()
+	if err != nil {
 		// can not take the buffer. Something must be wrong with the connection
-		errLog.Print(ErrBusyBuffer)
+		errLog.Print(err)
 		return "", ErrInvalidConn
 	}
 	buf = buf[:0]

+ 1 - 1
driver.go

@@ -50,7 +50,7 @@ func RegisterDial(net string, dial DialFunc) {
 
 // Open new Connection.
 // See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how
-// the DSN string is formated
+// the DSN string is formatted
 func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
 	var err error
 

+ 30 - 24
packets.go

@@ -51,7 +51,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
 		mc.sequence++
 
 		// packets with length 0 terminate a previous packet which is a
-		// multiple of (2^24)1 bytes long
+		// multiple of (2^24)-1 bytes long
 		if pktLen == 0 {
 			// there was no previous packet
 			if prevData == nil {
@@ -286,10 +286,10 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
 	}
 
 	// Calculate packet length and get buffer with that size
-	data := mc.buf.takeSmallBuffer(pktLen + 4)
-	if data == nil {
+	data, err := mc.buf.takeSmallBuffer(pktLen + 4)
+	if err != nil {
 		// cannot take the buffer. Something must be wrong with the connection
-		errLog.Print(ErrBusyBuffer)
+		errLog.Print(err)
 		return errBadConnNoWrite
 	}
 
@@ -367,10 +367,10 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
 // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
 func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error {
 	pktLen := 4 + len(authData)
-	data := mc.buf.takeSmallBuffer(pktLen)
-	if data == nil {
+	data, err := mc.buf.takeSmallBuffer(pktLen)
+	if err != nil {
 		// cannot take the buffer. Something must be wrong with the connection
-		errLog.Print(ErrBusyBuffer)
+		errLog.Print(err)
 		return errBadConnNoWrite
 	}
 
@@ -387,10 +387,10 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {
 	// Reset Packet Sequence
 	mc.sequence = 0
 
-	data := mc.buf.takeSmallBuffer(4 + 1)
-	if data == nil {
+	data, err := mc.buf.takeSmallBuffer(4 + 1)
+	if err != nil {
 		// cannot take the buffer. Something must be wrong with the connection
-		errLog.Print(ErrBusyBuffer)
+		errLog.Print(err)
 		return errBadConnNoWrite
 	}
 
@@ -406,10 +406,10 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
 	mc.sequence = 0
 
 	pktLen := 1 + len(arg)
-	data := mc.buf.takeBuffer(pktLen + 4)
-	if data == nil {
+	data, err := mc.buf.takeBuffer(pktLen + 4)
+	if err != nil {
 		// cannot take the buffer. Something must be wrong with the connection
-		errLog.Print(ErrBusyBuffer)
+		errLog.Print(err)
 		return errBadConnNoWrite
 	}
 
@@ -427,10 +427,10 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
 	// Reset Packet Sequence
 	mc.sequence = 0
 
-	data := mc.buf.takeSmallBuffer(4 + 1 + 4)
-	if data == nil {
+	data, err := mc.buf.takeSmallBuffer(4 + 1 + 4)
+	if err != nil {
 		// cannot take the buffer. Something must be wrong with the connection
-		errLog.Print(ErrBusyBuffer)
+		errLog.Print(err)
 		return errBadConnNoWrite
 	}
 
@@ -883,7 +883,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 	const minPktLen = 4 + 1 + 4 + 1 + 4
 	mc := stmt.mc
 
-	// Determine threshould dynamically to avoid packet size shortage.
+	// Determine threshold dynamically to avoid packet size shortage.
 	longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1)
 	if longDataSize < 64 {
 		longDataSize = 64
@@ -893,15 +893,17 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 	mc.sequence = 0
 
 	var data []byte
+	var err error
 
 	if len(args) == 0 {
-		data = mc.buf.takeBuffer(minPktLen)
+		data, err = mc.buf.takeBuffer(minPktLen)
 	} else {
-		data = mc.buf.takeCompleteBuffer()
+		data, err = mc.buf.takeCompleteBuffer()
+		// In this case the len(data) == cap(data) which is used to optimise the flow below.
 	}
-	if data == nil {
+	if err != nil {
 		// cannot take the buffer. Something must be wrong with the connection
-		errLog.Print(ErrBusyBuffer)
+		errLog.Print(err)
 		return errBadConnNoWrite
 	}
 
@@ -927,7 +929,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 		pos := minPktLen
 
 		var nullMask []byte
-		if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= len(data) {
+		if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= cap(data) {
 			// buffer has to be extended but we don't know by how much so
 			// we depend on append after all data with known sizes fit.
 			// We stop at that because we deal with a lot of columns here
@@ -936,10 +938,11 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 			copy(tmp[:pos], data[:pos])
 			data = tmp
 			nullMask = data[pos : pos+maskLen]
+			// No need to clean nullMask as make ensures that.
 			pos += maskLen
 		} else {
 			nullMask = data[pos : pos+maskLen]
-			for i := 0; i < maskLen; i++ {
+			for i := range nullMask {
 				nullMask[i] = 0
 			}
 			pos += maskLen
@@ -1076,7 +1079,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 		// In that case we must build the data packet with the new values buffer
 		if valuesCap != cap(paramValues) {
 			data = append(data[:pos], paramValues...)
-			mc.buf.buf = data
+			if err = mc.buf.store(data); err != nil {
+				errLog.Print(err)
+				return errBadConnNoWrite
+			}
 		}
 
 		pos += len(paramValues)