Ver código fonte

Enormous performance optimization.
Socket-IO reading is now buffered.
Less memory allocations
Fixed Issue #4

Julien Schmidt 13 anos atrás
pai
commit
4965942d3e
6 arquivos alterados com 234 adições e 241 exclusões
  1. 14 12
      connection.go
  2. 2 0
      driver.go
  3. 195 197
      packets.go
  4. 3 10
      rows.go
  5. 11 6
      statement.go
  6. 9 16
      utils.go

+ 14 - 12
connection.go

@@ -9,6 +9,7 @@
 package mysql
 
 import (
+	"bufio"
 	"database/sql/driver"
 	"errors"
 	"net"
@@ -20,6 +21,7 @@ type mysqlConn struct {
 	cfg            *config
 	server         *serverSettings
 	netConn        net.Conn
+	bufReader      *bufio.Reader
 	protocol       uint8
 	sequence       uint8
 	affectedRows   uint64
@@ -142,15 +144,17 @@ func (mc *mysqlConn) Close() (e error) {
 		mc.keepaliveTimer.Stop()
 	}
 	mc.writeCommandPacket(COM_QUIT)
+	mc.bufReader = nil
+	mc.netConn.Close()
 	mc.netConn = nil
 	return
 }
 
-func (mc *mysqlConn) Prepare(query string) (ds driver.Stmt, e error) {
+func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
 	// Send command
-	e = mc.writeCommandPacket(COM_STMT_PREPARE, query)
+	e := mc.writeCommandPacket(COM_STMT_PREPARE, query)
 	if e != nil {
-		return
+		return nil, e
 	}
 
 	stmt := mysqlStmt{new(stmtContent)}
@@ -160,26 +164,24 @@ func (mc *mysqlConn) Prepare(query string) (ds driver.Stmt, e error) {
 	var columnCount uint16
 	columnCount, e = stmt.readPrepareResultPacket()
 	if e != nil {
-		return
+		return nil, e
 	}
 
 	if stmt.paramCount > 0 {
 		stmt.params, e = stmt.mc.readColumns(stmt.paramCount)
 		if e != nil {
-			return
+			return nil, e
 		}
 	}
 
 	if columnCount > 0 {
-		_, e = stmt.mc.readColumns(int(columnCount))
+		_, e = stmt.mc.readUntilEOF()
 		if e != nil {
-			return
+			return nil, e
 		}
 	}
 
-	stmt.query = query
-	ds = stmt
-	return
+	return stmt, e
 }
 
 func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
@@ -258,13 +260,13 @@ func (mc *mysqlConn) getSystemVar(name string) (val string, e error) {
 			return
 		}
 
-		var rows []*[]*[]byte
+		var rows []*[][]byte
 		rows, e = mc.readRows(int(n))
 		if e != nil {
 			return
 		}
 
-		val = string(*(*rows[0])[0])
+		val = string((*rows[0])[0])
 	}
 
 	return

+ 2 - 0
driver.go

@@ -9,6 +9,7 @@
 package mysql
 
 import (
+	"bufio"
 	"database/sql"
 	"database/sql/driver"
 	"errors"
@@ -37,6 +38,7 @@ func (d *mysqlDriver) Open(dsn string) (driver.Conn, error) {
 	if e != nil {
 		return nil, e
 	}
+	mc.bufReader = bufio.NewReader(mc.netConn)
 
 	// Reading Handshake Initialization Packet 
 	e = mc.readInitPacket()

+ 195 - 197
packets.go

@@ -12,7 +12,6 @@ import (
 	"database/sql/driver"
 	"errors"
 	"fmt"
-	"io"
 	"reflect"
 	"time"
 )
@@ -21,103 +20,91 @@ import (
 // http://forge.mysql.com/wiki/MySQL_Internals_ClientServer_Protocol
 
 // Read packet to buffer 'data'
-func (mc *mysqlConn) readPacket() (data []byte, e error) {
+func (mc *mysqlConn) readPacket() ([]byte, error) {
 	// Packet Length
 	pktLen, e := mc.readNumber(3)
 	if e != nil {
-		return
+		return nil, e
 	}
 
 	if int(pktLen) == 0 {
-		return
+		return nil, e
 	}
 
 	// Packet Number
 	pktSeq, e := mc.readNumber(1)
 	if e != nil {
-		return
+		return nil, e
 	}
 
 	// Check Packet Sync
 	if uint8(pktSeq) != mc.sequence {
 		e = errors.New("Commands out of sync; you can't run this command now")
-		return
+		return nil, e
 	}
 	mc.sequence++
 
 	// Read rest of packet
-	data = make([]byte, pktLen)
+	data := make([]byte, pktLen)
 	var n, add int
-	n, e = mc.netConn.Read(data)
-	
-	// Read conventionally returns what is available instead of waiting for more
 	for e == nil && n < int(pktLen) {
-		add, e = mc.netConn.Read(data[n:])
+		add, e = mc.bufReader.Read(data[n:])
 		n += add
 	}
-	
-	if e != nil || n != int(pktLen) {
-		errLog.Print(e)
-		e = driver.ErrBadConn
-		return
-	}
-	return data[:pktLen], e // Return without scratch space
-}
-
-// Send Packet with given data
-func (mc *mysqlConn) writePacket(data []byte) (e error) {
-	// Set time BEFORE to avoid possible collisions
-	if mc.server.keepalive > 0 {
-		mc.lastCmdTime = time.Now()
-	}
-
-	pktLen := uint32(len(data))
-	if int(pktLen) == 0 {
-		return
-	}
-
-	// Add the packet header
-	pktData := make([]byte, 0, len(data)+4)
-	pktData = append(pktData, uint24ToBytes(pktLen)...)
-	pktData = append(pktData, mc.sequence)
-	pktData = append(pktData, data...)
-
-	// Write packet
-	n, e := mc.netConn.Write(pktData)
-	if e != nil || n != len(pktData) {
+	if e != nil || n < int(pktLen) {
 		if e == nil {
-			e = errors.New("Length of send data does not match packet length")
+			e = fmt.Errorf("Length of read data (%d) does not match body length (%d)", n, pktLen)
 		}
-		errLog.Print(e)
-		e = driver.ErrBadConn
-		return
+		errLog.Print(`packets:58 `, e)
+		return nil, driver.ErrBadConn
 	}
-
-	mc.sequence++
-	return
+	return data, e
 }
 
 // Read n bytes long number num
-func (mc *mysqlConn) readNumber(n uint8) (num uint64, e error) {
+func (mc *mysqlConn) readNumber(nr uint8) (uint64, error) {
 	// Read bytes into array
-	buf := make([]byte, n)
-
-	nr, err := io.ReadFull(mc.netConn, buf)
-	if err != nil || nr != int(n) {
+	buf := make([]byte, nr)
+	var n, add int
+	var e error
+	for e == nil && n < int(nr) {
+		add, e = mc.bufReader.Read(buf[n:])
+		n += add
+	}
+	if e != nil || n < int(nr) {
 		if e == nil {
-			e = errors.New("Length of read data does not match header length")
+			e = fmt.Errorf("Length of read data (%d) does not match header length (%d)", n, nr)
 		}
-		errLog.Print(e)
-		e = driver.ErrBadConn
-		return
+		errLog.Print(`packets:78 `, e)
+		return 0, driver.ErrBadConn
 	}
 
 	// Convert to uint64
-	num = 0
-	for i := uint8(0); i < n; i++ {
+	var num uint64 = 0
+	for i := uint8(0); i < nr; i++ {
 		num |= uint64(buf[i]) << (i * 8)
 	}
-	return
+	return num, e
+}
+
+func (mc *mysqlConn) writePacket(data *[]byte) error {
+	// Set time BEFORE to avoid possible collisions
+	if mc.server.keepalive > 0 {
+		mc.lastCmdTime = time.Now()
+	}
+	
+	// Write packet
+	n, e := mc.netConn.Write(*data)
+	if e != nil || n != len(*data) {
+		if e == nil {
+			e = errors.New("Length of send data does not match packet length")
+		}
+		errLog.Print(`packets:102 `, e)
+		return driver.ErrBadConn
+	}
+
+	mc.sequence++
+	return nil
 }
 
 /******************************************************************************
@@ -229,8 +216,12 @@ func (mc *mysqlConn) writeAuthPacket() (e error) {
 	scrambleBuff := scramblePassword(mc.server.scrambleBuff, []byte(mc.cfg.passwd))
 
 	// Calculate packet length and make buffer with that size
-	dataLen := 4 + 4 + 1 + 23 + len(mc.cfg.user) + 1 + 1 + len(scrambleBuff) + len(mc.cfg.dbname) + 1
-	data := make([]byte, 0, dataLen)
+	pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.user) + 1 + 1 + len(scrambleBuff) + len(mc.cfg.dbname) + 1
+	data := make([]byte, 0, pktLen+4)
+	
+	// Add the packet header
+	data = append(data, uint24ToBytes(uint32(pktLen))...)
+	data = append(data, mc.sequence)
 
 	// ClientFlags
 	data = append(data, uint32ToBytes(clientFlags)...)
@@ -265,9 +256,8 @@ func (mc *mysqlConn) writeAuthPacket() (e error) {
 		data = append(data, 0x0)
 	}
 
-	// Send Auth-Packet
-	mc.writePacket(data)
-	return
+	// Send Auth packet
+	return mc.writePacket(&data)
 }
 
 /******************************************************************************
@@ -284,8 +274,7 @@ func (mc *mysqlConn) writeCommandPacket(command commandType, args ...interface{}
 	// Reset Packet Sequence
 	mc.sequence = 0
 
-	// Make slice from command byte
-	data := []byte{byte(command)}
+	var arg []byte
 
 	switch command {
 
@@ -294,26 +283,41 @@ func (mc *mysqlConn) writeCommandPacket(command commandType, args ...interface{}
 		if len(args) > 0 {
 			return fmt.Errorf("Too much arguments (Got: %d Has: 0)", len(args))
 		}
+		arg = []byte{}
 
 	// Commands with 1 arg unterminated string
 	case COM_QUERY, COM_STMT_PREPARE:
 		if len(args) != 1 {
 			return fmt.Errorf("Invalid arguments count (Got: %d Has: 1)", len(args))
 		}
-		data = append(data, []byte(args[0].(string))...)
+		arg = []byte(args[0].(string))
 
 	// Commands with 1 arg 32 bit uint
 	case COM_STMT_CLOSE:
 		if len(args) != 1 {
 			return fmt.Errorf("Invalid arguments count (Got: %d Has: 1)", len(args))
 		}
-		data = append(data, uint32ToBytes(args[0].(uint32))...)
+		arg = uint32ToBytes(args[0].(uint32))
+
 	default:
 		return fmt.Errorf("Unknown command: %d", command)
 	}
+	
+	pktLen := 1 + len(arg)
+	data := make([]byte, 0, pktLen + 4)
+	
+	// Add the packet header
+	data = append(data, uint24ToBytes(uint32(pktLen))...)
+	data = append(data, mc.sequence)
+	
+	// Add command byte
+	data = append(data, byte(command))
+	
+	// Add arg
+	data = append(data, arg...)
 
 	// Send CMD packet
-	return mc.writePacket(data)
+	return mc.writePacket(&data)
 }
 
 /******************************************************************************
@@ -430,7 +434,7 @@ The order of packets for a result set is:
 func (mc *mysqlConn) readResultSetHeaderPacket() (fieldCount int, e error) {
 	data, e := mc.readPacket()
 	if e != nil {
-		errLog.Print(e)
+		errLog.Print(`packets:437 `, e)
 		e = driver.ErrBadConn
 		return
 	}
@@ -454,7 +458,7 @@ func (mc *mysqlConn) readResultSetHeaderPacket() (fieldCount int, e error) {
 }
 
 // Read Packets as Field Packets until EOF-Packet or an Error appears
-func (mc *mysqlConn) readColumns(n int) (columns []*mysqlField, e error) {
+func (mc *mysqlConn) readColumns(n int) (columns []mysqlField, e error) {
 	var data []byte
 
 	for {
@@ -551,14 +555,14 @@ func (mc *mysqlConn) readColumns(n int) (columns []*mysqlField, e error) {
 		//	defaultVal, _, e = bytesToLengthCodedBinary(data[pos:])
 		//}
 
-		columns = append(columns, &mysqlField{name: string(name), fieldType: fieldType, flags: flags})
+		columns = append(columns, mysqlField{name: string(name), fieldType: fieldType, flags: flags})
 	}
 
 	return
 }
 
 // Read Packets as Field Packets until EOF-Packet or an Error appears
-func (mc *mysqlConn) readRows(columnsCount int) (rows []*[]*[]byte, e error) {
+func (mc *mysqlConn) readRows(columnsCount int) (rows []*[][]byte, e error) {
 	var data []byte
 	var i, pos, n int
 	var isNull bool
@@ -575,22 +579,18 @@ func (mc *mysqlConn) readRows(columnsCount int) (rows []*[]*[]byte, e error) {
 		}
 
 		// RowSet Packet
-		row := make([]*[]byte, 0, columnsCount)
+		row := make([][]byte, columnsCount)
 		pos = 0
-
 		for i = 0; i < columnsCount; i++ {
 			// Read bytes and convert to string
-			var value []byte
-			value, n, isNull, e = readLengthCodedBinary(data[pos:])
+			row[i], n, isNull, e = readLengthCodedBinary(data[pos:])
 			if e != nil {
 				return
 			}
 
 			// Append nil if field is NULL
 			if isNull {
-				row = append(row, nil)
-			} else {
-				row = append(row, &value)
+				 row[i] = nil
 			}
 			pos += n
 		}
@@ -700,18 +700,101 @@ n*2                  type of parameters
 n                    values for the parameters 
 */
 func (stmt mysqlStmt) buildExecutePacket(args *[]driver.Value) (e error) {
-	if len(*args) < stmt.paramCount {
+	argsLen := len(*args)
+	if argsLen < stmt.paramCount {
 		return fmt.Errorf(
 			"Not enough Arguments to call STMT_EXEC (Got: %d Has: %d",
-			len(*args),
+			argsLen,
 			stmt.paramCount)
 	}
 
 	// Reset packet-sequence
 	stmt.mc.sequence = 0
-
-	data := make([]byte, 0, 10)
-
+	
+	pktLen := 1 + 4 + 1 + 4 + (stmt.paramCount+7)/8 + 1 + argsLen*2
+	paramValues := make([][]byte, 0, argsLen)
+	paramTypes := make([]byte, 0, argsLen*2)
+	bitMask := uint64(0)
+	var i, valLen int
+	var pv reflect.Value
+	for i = 0; i < stmt.paramCount; i++ {
+		// build nullBitMap	
+		if (*args)[i] == nil {
+			bitMask += 1 << uint(i)
+		}
+		
+		// cache types and values
+		switch (*args)[i].(type) {
+		case nil:
+			paramTypes = append(paramTypes, []byte{
+				byte(FIELD_TYPE_NULL),
+				0x0}...)
+			continue
+		
+		case []byte:
+			paramTypes = append(paramTypes, []byte{byte(FIELD_TYPE_STRING),0x0}...)
+			val := (*args)[i].([]byte)
+			valLen = len(val)
+			lcb := lengthCodedBinaryToBytes(uint64(valLen))
+			pktLen += len(lcb) + valLen 
+			paramValues = append(paramValues, lcb)
+			paramValues = append(paramValues, val)
+			continue
+		
+		case time.Time:
+			// Format to string for time+date Fields
+			// Data is packed in case reflect.String below
+			(*args)[i] = (*args)[i].(time.Time).Format(TIME_FORMAT)
+		}
+		
+		pv = reflect.ValueOf((*args)[i])
+		switch pv.Kind() {
+		case reflect.Int64:
+			paramTypes = append(paramTypes, []byte{byte(FIELD_TYPE_LONGLONG),0x0}...)
+			val := int64ToBytes(pv.Int())
+			pktLen += len(val)
+			paramValues = append(paramValues, val)
+			continue
+		
+		case reflect.Float64:
+			paramTypes = append(paramTypes, []byte{byte(FIELD_TYPE_DOUBLE),0x0}...)
+			val := float64ToBytes(pv.Float())
+			pktLen += len(val)
+			paramValues = append(paramValues, val)
+			continue
+		
+		case reflect.Bool:
+			paramTypes = append(paramTypes, []byte{byte(FIELD_TYPE_TINY),0x0}...)
+			val := pv.Bool()
+			pktLen++
+			if val {
+				paramValues = append(paramValues, []byte{byte(1)})
+			} else {
+				paramValues = append(paramValues, []byte{byte(0)})
+			}
+			continue
+		
+		case reflect.String:
+			paramTypes = append(paramTypes, []byte{byte(FIELD_TYPE_STRING),0x0}...)
+			val := []byte(pv.String())
+			valLen = len(val)
+			lcb := lengthCodedBinaryToBytes(uint64(valLen))
+			pktLen += valLen + len(lcb)
+			paramValues = append(paramValues, lcb)
+			paramValues = append(paramValues, val)
+			continue
+		
+		default:
+			return fmt.Errorf("Invalid Value: %s", pv.Kind().String())
+		}
+	}
+	
+	data := make([]byte, 0, pktLen+4)
+	
+	// Add the packet header
+	data = append(data, uint24ToBytes(uint32(pktLen))...)
+	data = append(data, stmt.mc.sequence)
+	
 	// code [1 byte]
 	data = append(data, byte(COM_STMT_EXECUTE))
 
@@ -723,103 +806,30 @@ func (stmt mysqlStmt) buildExecutePacket(args *[]driver.Value) (e error) {
 
 	// iteration_count [4 bytes]
 	data = append(data, uint32ToBytes(1)...)
-
-	if stmt.paramCount > 0 {
-		var i int
-
-		// build nullBitMap
-		nullBitMap := make([]byte, (stmt.paramCount+7)/8)
-		bitMask := uint64(0)
-
-		// Check for NULL fields
-		for i = 0; i < stmt.paramCount; i++ {
-			if (*args)[i] == nil {
-				bitMask += 1 << uint(i)
-			}
-		}
+	
+	// append nullBitMap [(param_count+7)/8 bytes]
+	if stmt.paramCount > 0 {	
 		// Convert bitMask to bytes
+		nullBitMap := make([]byte, (stmt.paramCount+7)/8)
 		for i = 0; i < len(nullBitMap); i++ {
 			nullBitMap[i] = byte(bitMask >> uint(i*8))
 		}
-
-		// append nullBitMap [(param_count+7)/8 bytes]
+		
 		data = append(data, nullBitMap...)
+	}
 
-		// newParameterBoundFlag 1 [1 byte]
-		data = append(data, byte(1))
-
-		// append types and cache values
-		paramValues := make([]byte, 0)
-		var pv reflect.Value
-		for i = 0; i < stmt.paramCount; i++ {
-			switch (*args)[i].(type) {
-			case nil:
-				data = append(data, []byte{
-					byte(FIELD_TYPE_NULL),
-					0x0}...)
-				continue
-
-			case []byte:
-				data = append(data, []byte{
-					byte(FIELD_TYPE_STRING),
-					0x0}...)
-				val := (*args)[i].([]byte)
-				paramValues = append(paramValues, lengthCodedBinaryToBytes(uint64(len(val)))...)
-				paramValues = append(paramValues, val...)
-				continue
-
-			case time.Time:
-				// Format to string for time+date Fields
-				// Data is packed in case reflect.String below
-				(*args)[i] = (*args)[i].(time.Time).Format(TIME_FORMAT)
-			}
-
-			pv = reflect.ValueOf((*args)[i])
-			switch pv.Kind() {
-			case reflect.Int64:
-				data = append(data, []byte{
-					byte(FIELD_TYPE_LONGLONG),
-					0x0}...)
-				paramValues = append(paramValues, int64ToBytes(pv.Int())...)
-				continue
-
-			case reflect.Float64:
-				data = append(data, []byte{
-					byte(FIELD_TYPE_DOUBLE),
-					0x0}...)
-				paramValues = append(paramValues, float64ToBytes(pv.Float())...)
-				continue
-
-			case reflect.Bool:
-				data = append(data, []byte{
-					byte(FIELD_TYPE_TINY),
-					0x0}...)
-				val := pv.Bool()
-				if val {
-					paramValues = append(paramValues, byte(1))
-				} else {
-					paramValues = append(paramValues, byte(0))
-				}
-				continue
-
-			case reflect.String:
-				data = append(data, []byte{
-					byte(FIELD_TYPE_STRING),
-					0x0}...)
-				val := pv.String()
-				paramValues = append(paramValues, lengthCodedBinaryToBytes(uint64(len(val)))...)
-				paramValues = append(paramValues, []byte(val)...)
-				continue
-
-			default:
-				return fmt.Errorf("Invalid Value: %s", pv.Kind().String())
-			}
-		}
-
-		// append cached values
-		data = append(data, paramValues...)
+	// newParameterBoundFlag 1 [1 byte]
+	data = append(data, byte(1))
+	
+	// type of parameters [n*2 byte]
+	data = append(data, paramTypes...)
+	
+	// values for the parameters [n byte]
+	for _, paramValue := range paramValues {
+		data = append(data, paramValue...)
 	}
-	return stmt.mc.writePacket(data)
+	
+	return stmt.mc.writePacket(&data)
 }
 
 func (mc *mysqlConn) readBinaryRows(rc *rowsContent) (e error) {
@@ -844,7 +854,7 @@ func (mc *mysqlConn) readBinaryRows(rc *rowsContent) (e error) {
 		pos++
 
 		// BinaryRowSet Packet
-		row := make([]*[]byte, columnsCount)
+		row := make([][]byte, columnsCount)
 
 		nullBitMap = data[pos : pos+(columnsCount+7+2)/8]
 		pos += (columnsCount + 7 + 2) / 8
@@ -905,16 +915,13 @@ func (mc *mysqlConn) readBinaryRows(rc *rowsContent) (e error) {
 				pos += 8
 
 			case FIELD_TYPE_DECIMAL, FIELD_TYPE_NEWDECIMAL:
-				var tmp []byte
-				tmp, n, isNull, e = readLengthCodedBinary(data[pos:])
+				row[i], n, isNull, e = readLengthCodedBinary(data[pos:])
 				if e != nil {
 					return
 				}
 
 				if isNull && rc.columns[i].flags&FLAG_NOT_NULL == 0 {
 					row[i] = nil
-				} else {
-					row[i] = &tmp
 				}
 				pos += n
 
@@ -923,16 +930,13 @@ func (mc *mysqlConn) readBinaryRows(rc *rowsContent) (e error) {
 				FIELD_TYPE_SET, FIELD_TYPE_TINY_BLOB, FIELD_TYPE_MEDIUM_BLOB,
 				FIELD_TYPE_LONG_BLOB, FIELD_TYPE_BLOB, FIELD_TYPE_VAR_STRING,
 				FIELD_TYPE_STRING, FIELD_TYPE_GEOMETRY:
-				var tmp []byte
-				tmp, n, isNull, e = readLengthCodedBinary(data[pos:])
+				row[i], n, isNull, e = readLengthCodedBinary(data[pos:])
 				if e != nil {
 					return
 				}
 
 				if isNull && rc.columns[i].flags&FLAG_NOT_NULL == 0 {
 					row[i] = nil
-				} else {
-					row[i] = &tmp
 				}
 				pos += n
 
@@ -945,16 +949,14 @@ func (mc *mysqlConn) readBinaryRows(rc *rowsContent) (e error) {
 				}
 				pos += n
 
-				var tmp []byte
 				if num == 0 {
-					tmp = []byte("0000-00-00")
+					row[i] = []byte("0000-00-00")
 				} else {
-					tmp = []byte(fmt.Sprintf("%04d-%02d-%02d",
+					row[i] = []byte(fmt.Sprintf("%04d-%02d-%02d",
 						bytesToUint16(data[pos:pos+2]),
 						data[pos+2],
 						data[pos+3]))
 				}
-				row[i] = &tmp
 				pos += int(num)
 
 			// Time HH:MM:SS
@@ -965,16 +967,14 @@ func (mc *mysqlConn) readBinaryRows(rc *rowsContent) (e error) {
 					return
 				}
 
-				var tmp []byte
 				if num == 0 {
-					tmp = []byte("00:00:00")
+					row[i] = []byte("00:00:00")
 				} else {
-					tmp = []byte(fmt.Sprintf("%02d:%02d:%02d",
+					row[i] = []byte(fmt.Sprintf("%02d:%02d:%02d",
 						data[pos+6],
 						data[pos+7],
 						data[pos+8]))
 				}
-				row[i] = &tmp
 				pos += n + int(num)
 
 			// Timestamp YYYY-MM-DD HH:MM:SS
@@ -986,11 +986,10 @@ func (mc *mysqlConn) readBinaryRows(rc *rowsContent) (e error) {
 				}
 				pos += n
 
-				var tmp []byte
 				if num == 0 {
-					tmp = []byte("0000-00-00 00:00:00")
+					row[i] = []byte("0000-00-00 00:00:00")
 				} else {
-					tmp = []byte(fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d",
+					row[i] = []byte(fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d",
 						bytesToUint16(data[pos:pos+2]),
 						data[pos+2],
 						data[pos+3],
@@ -998,7 +997,6 @@ func (mc *mysqlConn) readBinaryRows(rc *rowsContent) (e error) {
 						data[pos+5],
 						data[pos+6]))
 				}
-				row[i] = &tmp
 				pos += int(num)
 
 			// Please report if this happens!

+ 3 - 10
rows.go

@@ -20,8 +20,8 @@ type mysqlField struct {
 }
 
 type rowsContent struct {
-	columns []*mysqlField
-	rows    []*[]*[]byte
+	columns []mysqlField
+	rows    []*[][]byte
 }
 
 type mysqlRows struct {
@@ -47,15 +47,8 @@ func (rows mysqlRows) Close() error {
 // unnecessary conversions.
 func (rows mysqlRows) Next(dest []driver.Value) error {
 	if len(rows.content.rows) > 0 {
-		var value *[]byte
 		for i := 0; i < cap(dest); i++ {
-			value = (*rows.content.rows[0])[i]
-
-			if value == nil {
-				dest[i] = nil
-			} else {
-				dest[i] = *value
-			}
+			dest[i] = (*rows.content.rows[0])[i]
 		}
 		rows.content.rows = rows.content.rows[1:]
 	} else {

+ 11 - 6
statement.go

@@ -10,14 +10,14 @@ package mysql
 
 import (
 	"database/sql/driver"
+	"errors"
 )
 
 type stmtContent struct {
 	mc             *mysqlConn
 	id             uint32
-	query          string
 	paramCount     int
-	params         []*mysqlField
+	params         []mysqlField
 }
 
 type mysqlStmt struct {
@@ -26,7 +26,6 @@ type mysqlStmt struct {
 
 func (stmt mysqlStmt) Close() error {
 	e := stmt.mc.writeCommandPacket(COM_STMT_CLOSE, stmt.id)
-	stmt.params = nil
 	stmt.mc = nil
 	return e
 }
@@ -36,6 +35,9 @@ func (stmt mysqlStmt) NumInput() int {
 }
 
 func (stmt mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
+	if stmt.mc == nil {
+		return nil, errors.New(`Invalid Statement`)
+	}
 	stmt.mc.affectedRows = 0
 	stmt.mc.insertId = 0
 
@@ -73,13 +75,17 @@ func (stmt mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
 		return driver.ResultNoRows, nil
 	}
 
-	return &mysqlResult{
+	return mysqlResult{
 			affectedRows: int64(stmt.mc.affectedRows),
 			insertId:     int64(stmt.mc.insertId)},
 		nil
 }
 
 func (stmt mysqlStmt) Query(args []driver.Value) (dr driver.Rows, e error) {
+	if stmt.mc == nil {
+		return nil, errors.New(`Invalid Statement`)
+	}
+	
 	// Send command
 	e = stmt.buildExecutePacket(&args)
 	if e != nil {
@@ -88,8 +94,7 @@ func (stmt mysqlStmt) Query(args []driver.Value) (dr driver.Rows, e error) {
 
 	// Get Result
 	var resLen int
-	rows := new(mysqlRows)
-	rows.content = new(rowsContent)
+	rows := mysqlRows{new(rowsContent)}
 	resLen, e = stmt.mc.readResultSetHeaderPacket()
 	if e != nil {
 		return nil, e

+ 9 - 16
utils.go

@@ -298,26 +298,19 @@ func lengthCodedBinaryToBytes(n uint64) (b []byte) {
 	return
 }
 
-func intToByteStr(i int64) (d *[]byte) {
-	tmp := make([]byte, 0)
-	tmp = strconv.AppendInt(tmp, i, 10)
-	return &tmp
+func intToByteStr(i int64) (b []byte) {
+	//tmp := make([]byte, 0)
+	return strconv.AppendInt(b, i, 10)
 }
 
-func uintToByteStr(u uint64) (d *[]byte) {
-	tmp := make([]byte, 0)
-	tmp = strconv.AppendUint(tmp, u, 10)
-	return &tmp
+func uintToByteStr(u uint64) (b []byte) {
+	return strconv.AppendUint(b, u, 10)
 }
 
-func float32ToByteStr(f float32) (d *[]byte) {
-	tmp := make([]byte, 0)
-	tmp = strconv.AppendFloat(tmp, float64(f), 'f', -1, 32)
-	return &tmp
+func float32ToByteStr(f float32) (b []byte) { 
+	return strconv.AppendFloat(b, float64(f), 'f', -1, 32)
 }
 
-func float64ToByteStr(f float64) (d *[]byte) {
-	tmp := make([]byte, 0)
-	tmp = strconv.AppendFloat(tmp, f, 'f', -1, 64)
-	return &tmp
+func float64ToByteStr(f float64) (b []byte) {
+	return strconv.AppendFloat(b, f, 'f', -1, 64)
 }