Browse Source

use the connection buffer for writing

Julien Schmidt 12 years ago
parent
commit
ddf24e6427
3 changed files with 281 additions and 167 deletions
  1. 5 2
      benchmark_test.go
  2. 49 2
      buffer.go
  3. 227 163
      packets.go

+ 5 - 2
benchmark_test.go

@@ -69,23 +69,26 @@ func BenchmarkQuery(b *testing.B) {
 
 	stmt := tb.checkStmt(db.Prepare("SELECT val FROM foo WHERE id=?"))
 	defer stmt.Close()
-	b.StartTimer()
 
 	remain := int64(b.N)
 	var wg sync.WaitGroup
 	wg.Add(concurrencyLevel)
 	defer wg.Wait()
+	b.StartTimer()
+
 	for i := 0; i < concurrencyLevel; i++ {
 		go func() {
-			defer wg.Done()
 			for {
 				if atomic.AddInt64(&remain, -1) < 0 {
+					wg.Done()
 					return
 				}
+
 				var got string
 				tb.check(stmt.QueryRow(1).Scan(&got))
 				if got != "one" {
 					b.Errorf("query = %q; want one", got)
+					wg.Done()
 					return
 				}
 			}

+ 49 - 2
buffer.go

@@ -12,7 +12,10 @@ import "io"
 
 const defaultBufSize = 4096
 
-// A read buffer similar to bufio.Reader but zero-copy-ish
+// A buffer which is used for both reading and writing.
+// This is possible since communication on each connection is synchronous.
+// In other words, we can't write and read simultaneously on the same connection.
+// The buffer is similar to bufio.Reader / Writer but zero-copy-ish
 // Also highly optimized for this particular use case.
 type buffer struct {
 	buf    []byte
@@ -37,8 +40,11 @@ func (b *buffer) fill(need int) (err error) {
 	}
 
 	// grow buffer if necessary
+	// TODO: let the buffer shrink again at some point
+	//       Maybe keep the org buf slice and swap back?
 	if need > len(b.buf) {
-		newBuf := make([]byte, need)
+		// Round up to the next multiple of the default size
+		newBuf := make([]byte, ((need/defaultBufSize)+1)*defaultBufSize)
 		copy(newBuf, b.buf)
 		b.buf = newBuf
 	}
@@ -74,3 +80,44 @@ func (b *buffer) readNext(need int) (p []byte, err error) {
 	b.length -= need
 	return
 }
+
+// returns a buffer with the requested size.
+// If possible, a slice from the existing buffer is returned.
+// Otherwise a bigger buffer is made.
+// Only one buffer (total) can be used at a time.
+func (b *buffer) writeBuffer(length int) []byte {
+	if b.length > 0 {
+		return nil
+	}
+
+	// test (cheap) general case first
+	if length <= defaultBufSize || length <= cap(b.buf) {
+		return b.buf[:length]
+	}
+
+	if length < maxPacketSize {
+		b.buf = make([]byte, length)
+		return b.buf
+	}
+	return make([]byte, length)
+}
+
+// shortcut which can be used if the requested buffer is guaranteed to be
+// smaller than defaultBufSize
+// Only one buffer (total) can be used at a time.
+func (b *buffer) smallWriteBuffer(length int) []byte {
+	if b.length == 0 {
+		return b.buf[:length]
+	}
+	return nil
+}
+
+// takeCompleteBuffer returns the complete existing buffer.
+// This can be used if the necessary buffer size is unknown.
+// Only one buffer (total) can be used at a time.
+func (b *buffer) takeCompleteBuffer() []byte {
+	if b.length == 0 {
+		return b.buf
+	}
+	return nil
+}

+ 227 - 163
packets.go

@@ -239,8 +239,13 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
 		pktLen += n + 1
 	}
 
-	// Calculate packet length and make buffer with that size
-	data := make([]byte, pktLen+4)
+	// Calculate packet length and get buffer with that size
+	data := mc.buf.smallWriteBuffer(pktLen + 4)
+	if data == nil {
+		// can not take the buffer. Something must be wrong with the connection
+		errLog.Print("Busy buffer")
+		return driver.ErrBadConn
+	}
 
 	// ClientFlags [32 bit]
 	data[4] = byte(clientFlags)
@@ -249,10 +254,10 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
 	data[7] = byte(clientFlags >> 24)
 
 	// MaxPacketSize [32 bit] (none)
-	//data[8] = 0x00
-	//data[9] = 0x00
-	//data[10] = 0x00
-	//data[11] = 0x00
+	data[8] = 0x00
+	data[9] = 0x00
+	data[10] = 0x00
+	data[11] = 0x00
 
 	// Charset [1 byte]
 	data[12] = collation_utf8_general_ci
@@ -293,7 +298,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
 	if len(mc.cfg.user) > 0 {
 		pos += copy(data[pos:], mc.cfg.user)
 	}
-	//data[pos] = 0x00
+	data[pos] = 0x00
 	pos++
 
 	// ScrambleBuffer [length encoded integer]
@@ -303,7 +308,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
 	// Databasename [null terminated string]
 	if len(mc.cfg.dbname) > 0 {
 		pos += copy(data[pos:], mc.cfg.dbname)
-		//data[pos] = 0x00
+		data[pos] = 0x00
 	}
 
 	// Send Auth packet
@@ -318,7 +323,12 @@ func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error {
 
 	// Calculate the packet lenght and add a tailing 0
 	pktLen := len(scrambleBuff) + 1
-	data := make([]byte, pktLen+4)
+	data := mc.buf.smallWriteBuffer(pktLen + 4)
+	if data == nil {
+		// can not take the buffer. Something must be wrong with the connection
+		errLog.Print("Busy buffer")
+		return driver.ErrBadConn
+	}
 
 	// Add the packet header  [24bit length + 1 byte sequence]
 	data[0] = byte(pktLen)
@@ -340,17 +350,24 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {
 	// Reset Packet Sequence
 	mc.sequence = 0
 
+	data := mc.buf.smallWriteBuffer(4 + 1)
+	if data == nil {
+		// can not take the buffer. Something must be wrong with the connection
+		errLog.Print("Busy buffer")
+		return driver.ErrBadConn
+	}
+
+	// Add the packet header [24bit length + 1 byte sequence]
+	data[0] = 0x01 // 1 byte long
+	data[1] = 0x00
+	data[2] = 0x00
+	data[3] = 0x00 // sequence is always 0
+
+	// Add command byte
+	data[4] = command
+
 	// Send CMD packet
-	return mc.writePacket([]byte{
-		// Add the packet header [24bit length + 1 byte sequence]
-		0x01, // 1 byte long
-		0x00,
-		0x00,
-		0x00, // mc.sequence
-
-		// Add command byte
-		command,
-	})
+	return mc.writePacket(data)
 }
 
 func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
@@ -358,13 +375,18 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
 	mc.sequence = 0
 
 	pktLen := 1 + len(arg)
-	data := make([]byte, pktLen+4)
+	data := mc.buf.writeBuffer(pktLen + 4)
+	if data == nil {
+		// can not take the buffer. Something must be wrong with the connection
+		errLog.Print("Busy buffer")
+		return driver.ErrBadConn
+	}
 
 	// Add the packet header [24bit length + 1 byte sequence]
 	data[0] = byte(pktLen)
 	data[1] = byte(pktLen >> 8)
 	data[2] = byte(pktLen >> 16)
-	//data[3] = mc.sequence
+	data[3] = 0x00 // sequence is always 0
 
 	// Add command byte
 	data[4] = command
@@ -380,23 +402,30 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
 	// Reset Packet Sequence
 	mc.sequence = 0
 
+	data := mc.buf.smallWriteBuffer(4 + 1 + 4)
+	if data == nil {
+		// can not take the buffer. Something must be wrong with the connection
+		errLog.Print("Busy buffer")
+		return driver.ErrBadConn
+	}
+
+	// Add the packet header [24bit length + 1 byte sequence]
+	data[0] = 0x05 // 1 bytes long
+	data[1] = 0x00
+	data[2] = 0x00
+	data[3] = 0x00 // sequence is always 0
+
+	// Add command byte
+	data[4] = command
+
+	// Add arg [32 bit]
+	data[5] = byte(arg)
+	data[6] = byte(arg >> 8)
+	data[7] = byte(arg >> 16)
+	data[8] = byte(arg >> 24)
+
 	// Send CMD packet
-	return mc.writePacket([]byte{
-		// Add the packet header [24bit length + 1 byte sequence]
-		0x05, // 5 bytes long
-		0x00,
-		0x00,
-		0x00, // mc.sequence
-
-		// Add command byte
-		command,
-
-		// Add arg [32 bit]
-		byte(arg),
-		byte(arg >> 8),
-		byte(arg >> 16),
-		byte(arg >> 24),
-	})
+	return mc.writePacket(data)
 }
 
 /******************************************************************************
@@ -599,10 +628,10 @@ func (mc *mysqlConn) readColumns(count int) (columns []mysqlField, err error) {
 
 // Read Packets as Field Packets until EOF-Packet or an Error appears
 // http://dev.mysql.com/doc/internals/en/text-protocol.html#packet-ProtocolText::ResultsetRow
-func (rows *mysqlRows) readRow(dest []driver.Value) (err error) {
+func (rows *mysqlRows) readRow(dest []driver.Value) error {
 	data, err := rows.mc.readPacket()
 	if err != nil {
-		return
+		return err
 	}
 
 	// EOF Packet
@@ -641,24 +670,22 @@ func (rows *mysqlRows) readRow(dest []driver.Value) (err error) {
 				continue
 			}
 		}
-		return // err
+		return err // err != nil
 	}
 
-	return
+	return nil
 }
 
 // Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read
-func (mc *mysqlConn) readUntilEOF() (err error) {
-	var data []byte
-
+func (mc *mysqlConn) readUntilEOF() error {
 	for {
-		data, err = mc.readPacket()
+		data, err := mc.readPacket()
 
 		// No Err and no EOF Packet
 		if err == nil && data[0] != iEOF {
 			continue
 		}
-		return // Err or EOF
+		return err // Err or EOF
 	}
 }
 
@@ -710,11 +737,16 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (columnCount uint16, err error)
 }
 
 // http://dev.mysql.com/doc/internals/en/prepared-statements.html#com-stmt-send-long-data
-func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) (err error) {
+func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
 	maxLen := stmt.mc.maxPacketAllowed - 1
 	pktLen := maxLen
 	argLen := len(arg)
+
+	// Can not use the write buffer since
+	// a) the buffer is too small
+	// b) it is in use
 	data := make([]byte, 4+1+4+2+argLen)
+
 	copy(data[4+1+4+2:], arg)
 
 	for argLen > 0 {
@@ -742,7 +774,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) (err error)
 		data[10] = byte(paramID >> 8)
 
 		// Send CMD packet
-		err = stmt.mc.writePacket(data[:4+pktLen])
+		err := stmt.mc.writePacket(data[:4+pktLen])
 		if err == nil {
 			argLen -= pktLen - (1 + 4 + 2)
 			data = data[pktLen-(1+4+2):]
@@ -758,7 +790,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) (err error)
 }
 
 // Execute Prepared Statement
-// http://dev.mysql.com/doc/internals/en/prepared-statements.html#com-stmt-execute
+// http://dev.mysql.com/doc/internals/en/com-stmt-execute.html
 func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 	if len(args) != stmt.paramCount {
 		return fmt.Errorf(
@@ -770,107 +802,32 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 	// Reset packet-sequence
 	stmt.mc.sequence = 0
 
-	pktLen := 1 + 4 + 1 + 4 + ((stmt.paramCount + 7) >> 3) + 1 + (stmt.paramCount << 1)
-	paramValues := make([][]byte, stmt.paramCount)
-	paramTypes := make([]byte, (stmt.paramCount << 1))
-	bitMask := uint64(0)
-	var i int
-
-	for i = range args {
-		// build NULL-bitmap
-		if args[i] == nil {
-			bitMask += 1 << uint(i)
-			paramTypes[i<<1] = fieldTypeNULL
-			continue
-		}
-
-		// cache types and values
-		switch v := args[i].(type) {
-		case int64:
-			paramTypes[i<<1] = fieldTypeLongLong
-			paramValues[i] = uint64ToBytes(uint64(v))
-			pktLen += 8
-			continue
-
-		case float64:
-			paramTypes[i<<1] = fieldTypeDouble
-			paramValues[i] = uint64ToBytes(math.Float64bits(v))
-			pktLen += 8
-			continue
-
-		case bool:
-			paramTypes[i<<1] = fieldTypeTiny
-			pktLen++
-			if v {
-				paramValues[i] = []byte{0x01}
-			} else {
-				paramValues[i] = []byte{0x00}
-			}
-			continue
-
-		case []byte:
-			paramTypes[i<<1] = fieldTypeString
-			if len(v) < stmt.mc.maxPacketAllowed-pktLen-(stmt.paramCount-(i+1))*64 {
-				paramValues[i] = append(
-					lengthEncodedIntegerToBytes(uint64(len(v))),
-					v...,
-				)
-				pktLen += len(paramValues[i])
-				continue
-			} else {
-				err := stmt.writeCommandLongData(i, v)
-				if err == nil {
-					continue
-				}
-				return err
-			}
-
-		case string:
-			paramTypes[i<<1] = fieldTypeString
-			if len(v) < stmt.mc.maxPacketAllowed-pktLen-(stmt.paramCount-(i+1))*64 {
-				paramValues[i] = append(
-					lengthEncodedIntegerToBytes(uint64(len(v))),
-					[]byte(v)...,
-				)
-				pktLen += len(paramValues[i])
-				continue
-			} else {
-				err := stmt.writeCommandLongData(i, []byte(v))
-				if err == nil {
-					continue
-				}
-				return err
-			}
-
-		case time.Time:
-			paramTypes[i<<1] = fieldTypeString
-
-			var val []byte
-			if v.IsZero() {
-				val = []byte("0000-00-00")
-			} else {
-				val = []byte(v.In(stmt.mc.cfg.loc).Format(timeFormat))
-			}
-
-			paramValues[i] = append(
-				lengthEncodedIntegerToBytes(uint64(len(val))),
-				val...,
-			)
-			pktLen += len(paramValues[i])
-			continue
+	var data []byte
 
-		default:
-			return fmt.Errorf("Can't convert type: %T", args[i])
+	if len(args) == 0 {
+		const pktLen = 1 + 4 + 1 + 4
+		data = stmt.mc.buf.writeBuffer(4 + pktLen)
+		if data == nil {
+			// can not take the buffer. Something must be wrong with the connection
+			errLog.Print("Busy buffer")
+			return driver.ErrBadConn
 		}
-	}
 
-	data := make([]byte, pktLen+4)
+		// packet header [4 bytes]
+		data[0] = byte(pktLen)
+		data[1] = byte(pktLen >> 8)
+		data[2] = byte(pktLen >> 16)
+		data[3] = 0x00 // sequence is always 0
+	} else {
+		data = stmt.mc.buf.takeCompleteBuffer()
+		if data == nil {
+			// can not take the buffer. Something must be wrong with the connection
+			errLog.Print("Busy buffer")
+			return driver.ErrBadConn
+		}
 
-	// packet header [4 bytes]
-	data[0] = byte(pktLen)
-	data[1] = byte(pktLen >> 8)
-	data[2] = byte(pktLen >> 16)
-	data[3] = stmt.mc.sequence
+		// header (bytes 0-3) is added after we know the packet size
+	}
 
 	// command [1 byte]
 	data[4] = comStmtExecute
@@ -882,32 +839,139 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 	data[8] = byte(stmt.id >> 24)
 
 	// flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte]
-	//data[9] = 0x00
+	data[9] = 0x00
 
 	// iteration_count (uint32(1)) [4 bytes]
 	data[10] = 0x01
-	//data[11] = 0x00
-	//data[12] = 0x00
-	//data[13] = 0x00
-
-	if stmt.paramCount > 0 {
-		// NULL-bitmap [(param_count+7)/8 bytes]
-		pos := 14 + ((stmt.paramCount + 7) >> 3)
-		// Convert bitMask to bytes
-		for i = 14; i < pos; i++ {
-			data[i] = byte(bitMask >> uint((i-14)<<3))
-		}
+	data[11] = 0x00
+	data[12] = 0x00
+	data[13] = 0x00
+
+	if len(args) > 0 {
+		// NULL-bitmap [(len(args)+7)/8 bytes]
+		nullMask := uint64(0)
+
+		pos := 4 + 1 + 4 + 1 + 4 + ((len(args) + 7) >> 3)
 
 		// newParameterBoundFlag 1 [1 byte]
 		data[pos] = 0x01
 		pos++
 
-		// type of parameters [param_count*2 bytes]
-		pos += copy(data[pos:], paramTypes)
+		// type of each parameter [len(args)*2 bytes]
+		paramTypes := data[pos:]
+		pos += (len(args) << 1)
+
+		// value of each parameter [n bytes]
+		paramValues := data[pos:pos]
+		valuesCap := cap(paramValues)
+
+		for i := range args {
+			// build NULL-bitmap
+			if args[i] == nil {
+				nullMask += 1 << uint(i)
+				paramTypes[i+i] = fieldTypeNULL
+				continue
+			}
+
+			// cache types and values
+			switch v := args[i].(type) {
+			case int64:
+				paramTypes[i+i] = fieldTypeLongLong
+				if cap(paramValues) <= len(paramValues)+8 {
+					paramValues = paramValues[:len(paramValues)+8]
+					binary.LittleEndian.PutUint64(paramValues, uint64(v))
+				} else {
+					paramValues = append(paramValues,
+						uint64ToBytes(uint64(v))...,
+					)
+				}
+
+			case float64:
+				paramTypes[i+i] = fieldTypeDouble
+				if cap(paramValues) <= len(paramValues)+8 {
+					paramValues = paramValues[:len(paramValues)+8]
+					binary.LittleEndian.PutUint64(paramValues, math.Float64bits(v))
+				} else {
+					paramValues = append(paramValues,
+						uint64ToBytes(math.Float64bits(v))...,
+					)
+				}
+
+			case bool:
+				paramTypes[i+i] = fieldTypeTiny
+				if v {
+					paramValues = append(paramValues, 0x01)
+				} else {
+					paramValues = append(paramValues, 0x00)
+				}
+
+			case []byte:
+				paramTypes[i+i] = fieldTypeString
+				if len(v) < stmt.mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 {
+					paramValues = append(paramValues,
+						lengthEncodedIntegerToBytes(uint64(len(v)))...,
+					)
+					paramValues = append(paramValues, v...)
+				} else {
+					if err := stmt.writeCommandLongData(i, v); err != nil {
+						return err
+					}
+				}
+
+			case string:
+				paramTypes[i+i] = fieldTypeString
+				if len(v) < stmt.mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 {
+					paramValues = append(paramValues,
+						lengthEncodedIntegerToBytes(uint64(len(v)))...,
+					)
+					paramValues = append(paramValues, v...)
+				} else {
+					if err := stmt.writeCommandLongData(i, []byte(v)); err != nil {
+						return err
+					}
+				}
+
+			case time.Time:
+				paramTypes[i+i] = fieldTypeString
+
+				var val []byte
+				if v.IsZero() {
+					val = []byte("0000-00-00")
+				} else {
+					val = []byte(v.In(stmt.mc.cfg.loc).Format(timeFormat))
+				}
+
+				paramValues = append(paramValues,
+					lengthEncodedIntegerToBytes(uint64(len(val)))...,
+				)
+				paramValues = append(paramValues, val...)
+
+			default:
+				return fmt.Errorf("Can't convert type: %T", args[i])
+			}
+		}
+
+		// Check if param values exceeded the available buffer
+		// In that case we must build the data packet with the new values buffer
+		if valuesCap != cap(paramValues) {
+			data = append(data[:pos], paramValues...)
+			stmt.mc.buf.buf = data
+		}
+
+		pos += len(paramValues)
+		data = data[:pos]
+
+		pktLen := pos - 4
+
+		// packet header [4 bytes]
+		data[0] = byte(pktLen)
+		data[1] = byte(pktLen >> 8)
+		data[2] = byte(pktLen >> 16)
+		data[3] = stmt.mc.sequence
 
-		// values for the parameters [n bytes]
-		for i = range paramValues {
-			pos += copy(data[pos:], paramValues[i])
+		// Convert nullMask to bytes
+		for i, max := 14, 14+((stmt.paramCount+7)>>3); i < max; i++ {
+			data[i] = byte(nullMask >> uint((i-14)<<3))
 		}
 	}