Selaa lähdekoodia

refactoring and micro-optimisation

Julien Schmidt 12 vuotta sitten
vanhempi
commit
4e1d8236e3
7 muutettua tiedostoa jossa 248 lisäystä ja 323 poistoa
  1. 3 10
      connection.go
  2. 1 1
      driver_test.go
  3. 19 0
      errors.go
  4. 167 207
      packets.go
  5. 4 7
      rows.go
  6. 2 2
      statement.go
  7. 52 96
      utils.go

+ 3 - 10
connection.go

@@ -18,7 +18,9 @@ import (
 
 type mysqlConn struct {
 	cfg          *config
-	server       *serverSettings
+	flags        ClientFlag
+	charset      byte
+	scrambleBuff []byte
 	netConn      net.Conn
 	buf          *buffer
 	protocol     uint8
@@ -36,15 +38,6 @@ type config struct {
 	params map[string]string
 }
 
-type serverSettings struct {
-	protocol     byte
-	version      string
-	flags        ClientFlag
-	charset      uint8
-	scrambleBuff []byte
-	threadID     uint32
-}
-
 // Handles parameters set in DSN
 func (mc *mysqlConn) handleParams() (err error) {
 	for param, val := range mc.cfg.params {

+ 1 - 1
driver_test.go

@@ -301,7 +301,7 @@ func TestString(t *testing.T) {
 		if rows.Next() {
 			rows.Scan(&out)
 			if in != out {
-				t.Errorf("%s: %d != %d", v, in, out)
+				t.Errorf("%s: %s != %s", v, in, out)
 			}
 		} else {
 			t.Errorf("%s: no data", v)

+ 19 - 0
errors.go

@@ -0,0 +1,19 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2013 Julien Schmidt. All rights reserved.
+// http://www.julienschmidt.com
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mysql
+
+import "errors"
+
+var (
+	errMalformPkt  = errors.New("Malformed Packet")
+	errPktSync     = errors.New("Commands out of sync. You can't run this command now")
+	errPktSyncMul  = errors.New("Commands out of sync. Did you run multiple statements at once?")
+	errOldPassword = errors.New("It seems like you are using old_passwords, which is unsupported. See https://github.com/Go-SQL-Driver/MySQL/wiki/old_passwords")
+)

+ 167 - 207
packets.go

@@ -10,12 +10,13 @@
 package mysql
 
 import (
+	"bytes"
 	"database/sql/driver"
 	"encoding/binary"
 	"errors"
 	"fmt"
 	"io"
-	"reflect"
+	"math"
 	"time"
 )
 
@@ -45,31 +46,29 @@ func (mc *mysqlConn) readPacket() (data []byte, err error) {
 	// Check Packet Sync
 	if data[3] != mc.sequence {
 		if data[3] > mc.sequence {
-			err = errors.New("Commands out of sync. Did you run multiple statements at once?")
+			return nil, errPktSyncMul
 		} else {
-			err = errors.New("Commands out of sync; you can't run this command now")
+			return nil, errPktSync
 		}
-		return nil, err
 	}
 	mc.sequence++
 
-	// Read rest of packet
+	// Read packet body
 	data = make([]byte, pktLen)
 	err = mc.buf.read(data)
-	if err != nil {
-		errLog.Print(err)
-		return nil, driver.ErrBadConn
+	if err == nil {
+		return data, nil
 	}
-
-	return data, err
+	errLog.Print(err)
+	return nil, driver.ErrBadConn
 }
 
-func (mc *mysqlConn) writePacket(data *[]byte) error {
+func (mc *mysqlConn) writePacket(data []byte) error {
 	// Write packet
-	n, err := mc.netConn.Write(*data)
-	if err != nil || n != len(*data) {
+	n, err := mc.netConn.Write(data)
+	if err != nil || n != len(data) {
 		if err == nil {
-			err = errors.New("Length of send data does not match packet length")
+			errLog.Print(errMalformPkt)
 		}
 		errLog.Print(err)
 		return driver.ErrBadConn
@@ -83,76 +82,56 @@ func (mc *mysqlConn) writePacket(data *[]byte) error {
 *                           Initialisation Process                            *
 ******************************************************************************/
 
-/* Handshake Initialization Packet
- Bytes                        Name
- -----                        ----
- 1                            protocol_version
- n (Null-Terminated String)   server_version
- 4                            thread_id
- 8                            scramble_buff
- 1                            (filler) always 0x00
- 2                            server_capabilities
- 1                            server_language
- 2                            server_status
- 2                            server capabilities (two upper bytes)
- 1                            length of the scramble
-10                            (filler)  always 0
- n                            rest of the plugin provided data (at least 12 bytes)
- 1                            \0 byte, terminating the second part of a scramble
-*/
+// Handshake Initialization Packet
+// http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::Handshake
 func (mc *mysqlConn) readInitPacket() (err error) {
 	data, err := mc.readPacket()
 	if err != nil {
 		return
 	}
 
-	mc.server = new(serverSettings)
-
-	// Position
-	pos := 0
-
-	// Protocol version [8 bit uint]
-	mc.server.protocol = data[pos]
-	if mc.server.protocol < MIN_PROTOCOL_VERSION {
+	// protocol version [1 byte]
+	if data[0] < MIN_PROTOCOL_VERSION {
 		err = fmt.Errorf(
 			"Unsupported MySQL Protocol Version %d. Protocol Version %d or higher is required",
-			mc.server.protocol,
+			data[0],
 			MIN_PROTOCOL_VERSION)
 	}
-	pos++
 
-	// Server version [null terminated string]
-	slice, err := readSlice(data[pos:], 0x00)
-	if err != nil {
-		return
-	}
-	mc.server.version = string(slice)
-	pos += len(slice) + 1
+	// server version [null terminated string]
+	// connection id [4 bytes]
+	pos := 1 + (bytes.IndexByte(data[1:], 0x00) + 1) + 4
 
-	// Thread id [32 bit uint]
-	mc.server.threadID = binary.LittleEndian.Uint32(data[pos : pos+4])
-	pos += 4
+	// first part of scramble buffer [8 bytes]
+	mc.scrambleBuff = data[pos : pos+8]
 
-	// First part of scramble buffer [8 bytes]
-	mc.server.scrambleBuff = make([]byte, 8)
-	mc.server.scrambleBuff = data[pos : pos+8]
-	pos += 9
+	// (filler) always 0x00 [1 byte]
+	pos += 8 + 1
 
-	// Server capabilities [16 bit uint]
-	mc.server.flags = ClientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
-	if mc.server.flags&CLIENT_PROTOCOL_41 == 0 {
+	// capability flags (lower 2 bytes) [2 bytes]
+	mc.flags = ClientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
+	if mc.flags&CLIENT_PROTOCOL_41 == 0 {
 		err = errors.New("MySQL-Server does not support required Protocol 41+")
 	}
 	pos += 2
 
-	// Server language [8 bit uint]
-	mc.server.charset = data[pos]
-	pos++
+	if len(data) > pos {
+		// character set [1 byte]
+		mc.charset = data[pos]
 
-	// Server status [16 bit uint]
-	pos += 15
+		// status flags [2 bytes]
+		// capability flags (upper 2 bytes) [2 bytes]
+		// length of auth-plugin-data [1 byte]
+		// reserved (all [00]) [10 byte]
+		pos += 1 + 2 + 2 + 1 + 10
 
-	mc.server.scrambleBuff = append(mc.server.scrambleBuff, data[pos:pos+12]...)
+		mc.scrambleBuff = append(mc.scrambleBuff, data[pos:len(data)-1]...)
+
+		if data[len(data)-1] == 0 {
+			return
+		}
+		return errMalformPkt
+	}
 
 	return
 }
@@ -170,18 +149,18 @@ n (Null-Terminated String)   databasename (optional)
 */
 func (mc *mysqlConn) writeAuthPacket() error {
 	// Adjust client flags based on server support
-	clientFlags := uint32(CLIENT_MULTI_STATEMENTS |
-		// CLIENT_MULTI_RESULTS |
+	clientFlags := uint32(
 		CLIENT_PROTOCOL_41 |
-		CLIENT_SECURE_CONN |
-		CLIENT_LONG_PASSWORD |
-		CLIENT_TRANSACTIONS)
-	if mc.server.flags&CLIENT_LONG_FLAG > 0 {
+			CLIENT_SECURE_CONN |
+			CLIENT_LONG_PASSWORD |
+			CLIENT_TRANSACTIONS,
+	)
+	if mc.flags&CLIENT_LONG_FLAG > 0 {
 		clientFlags |= uint32(CLIENT_LONG_FLAG)
 	}
 
 	// User Password
-	scrambleBuff := scramblePassword(mc.server.scrambleBuff, []byte(mc.cfg.passwd))
+	scrambleBuff := scramblePassword(mc.scrambleBuff, []byte(mc.cfg.passwd))
 
 	pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.user) + 1 + 1 + len(scrambleBuff)
 
@@ -205,7 +184,7 @@ func (mc *mysqlConn) writeAuthPacket() error {
 	data = append(data, uint32ToBytes(MAX_PACKET_SIZE)...)
 
 	// Charset
-	data = append(data, mc.server.charset)
+	data = append(data, mc.charset)
 
 	// Filler
 	data = append(data, make([]byte, 23)...)
@@ -232,7 +211,7 @@ func (mc *mysqlConn) writeAuthPacket() error {
 	}
 
 	// Send Auth packet
-	return mc.writePacket(&data)
+	return mc.writePacket(data)
 }
 
 /******************************************************************************
@@ -292,7 +271,7 @@ func (mc *mysqlConn) writeCommandPacket(command commandType, args ...interface{}
 	data = append(data, arg...)
 
 	// Send CMD packet
-	return mc.writePacket(&data)
+	return mc.writePacket(data)
 }
 
 /******************************************************************************
@@ -312,13 +291,13 @@ func (mc *mysqlConn) readResultOK() error {
 		return mc.handleOkPacket(data)
 	// EOF, someone is using old_passwords
 	case 254:
-		return errors.New("It seems like you are using old_passwords, which is unsupported. See https://github.com/Go-SQL-Driver/MySQL/wiki/old_passwords")
+		return errOldPassword
 	// ERROR
 	case 255:
 		return mc.handleErrorPacket(data)
 	}
 
-	return errors.New("Invalid Result Packet-Type")
+	return errMalformPkt
 }
 
 /* Error Packet
@@ -331,24 +310,14 @@ Bytes                       Name
 n                           message
 */
 func (mc *mysqlConn) handleErrorPacket(data []byte) error {
-	if data[0] != 255 {
-		return errors.New("Wrong Packet-Type: Not an Error-Packet")
-	}
-
-	pos := 1
-
 	// Error Number [16 bit uint]
-	errno := binary.LittleEndian.Uint16(data[pos : pos+2])
-	pos += 2
+	errno := binary.LittleEndian.Uint16(data[1:3])
 
 	// SQL State [# + 5bytes string]
 	//sqlstate := string(data[pos : pos+6])
-	pos += 6
 
 	// Error Message [string]
-	message := string(data[pos:])
-
-	return fmt.Errorf("Error %d: %s", errno, message)
+	return fmt.Errorf("Error %d: %s", errno, string(data[9:]))
 }
 
 /* Ok Packet
@@ -362,31 +331,21 @@ Bytes                       Name
 n   (until end of packet)   message
 */
 func (mc *mysqlConn) handleOkPacket(data []byte) (err error) {
-	if data[0] != 0 {
-		err = errors.New("Wrong Packet-Type: Not an OK-Packet")
-		return
-	}
-
-	// Position
-	pos := 1
-
 	var n int
 
 	// Affected rows [Length Coded Binary]
-	mc.affectedRows, _, n, err = bytesToLengthEncodedInteger(data[pos:])
+	mc.affectedRows, _, n, err = readLengthEncodedInteger(data[1:])
 	if err != nil {
 		return
 	}
-	pos += n
 
 	// Insert id [Length Coded Binary]
-	mc.insertId, _, n, err = bytesToLengthEncodedInteger(data[pos:])
+	mc.insertId, _, _, err = readLengthEncodedInteger(data[1+n:])
 	if err != nil {
 		return
 	}
 
 	// Skip remaining data
-
 	return
 }
 
@@ -411,15 +370,15 @@ func (mc *mysqlConn) readResultSetHeaderPacket() (fieldCount int, err error) {
 		return
 	}
 
-	if data[0] == 255 {
-		err = mc.handleErrorPacket(data)
-		return
-	} else if data[0] == 0 {
+	if data[0] == 0 {
 		err = mc.handleOkPacket(data)
 		return
+	} else if data[0] == 255 {
+		err = mc.handleErrorPacket(data)
+		return
 	}
 
-	num, _, n, err := bytesToLengthEncodedInteger(data)
+	num, _, n, err := readLengthEncodedInteger(data)
 	if err != nil || (n-len(data)) != 0 {
 		err = errors.New("Malformed Packet")
 		return
@@ -449,14 +408,11 @@ func (mc *mysqlConn) readColumns(count int) (columns []mysqlField, err error) {
 			return
 		}
 
-		pos = 0
-
 		// Catalog
-		n, err = readAndDropLengthEnodedString(data)
+		pos, err = readAndDropLengthEnodedString(data)
 		if err != nil {
 			return
 		}
-		pos += n
 
 		// Database [len coded string]
 		n, err = readAndDropLengthEnodedString(data[pos:])
@@ -480,7 +436,7 @@ func (mc *mysqlConn) readColumns(count int) (columns []mysqlField, err error) {
 		pos += n
 
 		// Name [len coded string]
-		name, _, n, err = readLengthEnodedString(data[pos:])
+		name, n, err = readLengthEnodedString(data[pos:])
 		if err != nil {
 			return
 		}
@@ -491,16 +447,11 @@ func (mc *mysqlConn) readColumns(count int) (columns []mysqlField, err error) {
 		if err != nil {
 			return
 		}
-		pos += n
-
-		// Filler
-		pos++
 
+		// Filler [1 byte]
 		// Charset [16 bit uint]
-		pos += 2
-
 		// Length [32 bit uint]
-		pos += 4
+		pos += n + 1 + 2 + 4
 
 		// Field type [byte]
 		fieldType := FieldType(data[pos])
@@ -525,7 +476,7 @@ func (mc *mysqlConn) readColumns(count int) (columns []mysqlField, err error) {
 }
 
 // Read Packets as Field Packets until EOF-Packet or an Error appears
-func (rows *mysqlRows) readRow(dest *[]driver.Value) (err error) {
+func (rows *mysqlRows) readRow(dest []driver.Value) (err error) {
 	data, err := rows.mc.readPacket()
 	if err != nil {
 		return
@@ -538,16 +489,16 @@ func (rows *mysqlRows) readRow(dest *[]driver.Value) (err error) {
 
 	// RowSet Packet
 	var n int
-	columnsCount := len(*dest)
 	pos := 0
 
-	for i := 0; i < columnsCount; i++ {
+	for i := range dest {
 		// Read bytes and convert to string
-		(*dest)[i], _, n, err = readLengthEnodedString(data[pos:])
-		if err != nil {
-			return
-		}
+		dest[i], n, err = readLengthEnodedString(data[pos:])
 		pos += n
+		if err == nil {
+			continue
+		}
+		return // err
 	}
 
 	return
@@ -559,12 +510,9 @@ func (mc *mysqlConn) readUntilEOF() (count uint64, err error) {
 
 	for {
 		data, err = mc.readPacket()
-		if err != nil {
-			return
-		}
 
-		// EOF Packet
-		if data[0] == 254 && len(data) == 5 {
+		// Err or EOF Packet
+		if err != nil || (data[0] == 254 && len(data) == 5) {
 			return
 		}
 
@@ -651,11 +599,11 @@ Bytes                Name
 n*2                  type of parameters
 n                    values for the parameters
 */
-func (stmt *mysqlStmt) buildExecutePacket(args *[]driver.Value) error {
-	argsLen := len(*args)
-	if argsLen < stmt.paramCount {
+func (stmt *mysqlStmt) buildExecutePacket(args []driver.Value) error {
+	argsLen := len(args)
+	if argsLen != stmt.paramCount {
 		return fmt.Errorf(
-			"Not enough Arguments to call STMT_EXEC (Got: %d Has: %d",
+			"Arguments count mismatch (Got: %d Has: %d",
 			argsLen,
 			stmt.paramCount)
 	}
@@ -668,67 +616,67 @@ func (stmt *mysqlStmt) buildExecutePacket(args *[]driver.Value) error {
 	paramTypes := make([]byte, 0, (argsLen << 1))
 	bitMask := uint64(0)
 	var i, valLen int
-	var pv reflect.Value
-	for i = 0; i < stmt.paramCount; i++ {
+	for i = range args {
 		// build nullBitMap
-		if (*args)[i] == nil {
+		if args[i] == nil {
 			bitMask += 1 << uint(i)
 		}
 
 		// cache types and values
-		switch (*args)[i].(type) {
+		switch args[i].(type) {
 		case nil:
 			paramTypes = append(paramTypes, []byte{
 				byte(FIELD_TYPE_NULL),
 				0x0}...)
 			continue
 
-		case []byte:
-			paramTypes = append(paramTypes, []byte{byte(FIELD_TYPE_STRING), 0x0}...)
-			val := (*args)[i].([]byte)
-			valLen = len(val)
-			lcb := lengthEncodedIntegerToBytes(uint64(valLen))
-			pktLen += len(lcb) + valLen
-			paramValues = append(paramValues, lcb)
-			paramValues = append(paramValues, val)
-			continue
-
-		case time.Time:
-			// Format to string for time+date Fields
-			// Data is packed in case reflect.String below
-			(*args)[i] = (*args)[i].(time.Time).Format(TIME_FORMAT)
-		}
-
-		pv = reflect.ValueOf((*args)[i])
-		switch pv.Kind() {
-		case reflect.Int64:
+		case int64:
 			paramTypes = append(paramTypes, []byte{byte(FIELD_TYPE_LONGLONG), 0x0}...)
-			val := int64ToBytes(pv.Int())
+			val := uint64ToBytes(uint64(args[i].(int64)))
 			pktLen += len(val)
 			paramValues = append(paramValues, val)
 			continue
 
-		case reflect.Float64:
+		case float64:
 			paramTypes = append(paramTypes, []byte{byte(FIELD_TYPE_DOUBLE), 0x0}...)
-			val := float64ToBytes(pv.Float())
+			val := uint64ToBytes(math.Float64bits(args[i].(float64)))
 			pktLen += len(val)
 			paramValues = append(paramValues, val)
 			continue
 
-		case reflect.Bool:
+		case bool:
 			paramTypes = append(paramTypes, []byte{byte(FIELD_TYPE_TINY), 0x0}...)
-			val := pv.Bool()
 			pktLen++
-			if val {
+			if args[i].(bool) {
 				paramValues = append(paramValues, []byte{byte(1)})
 			} else {
 				paramValues = append(paramValues, []byte{byte(0)})
 			}
 			continue
 
-		case reflect.String:
+		case []byte:
 			paramTypes = append(paramTypes, []byte{byte(FIELD_TYPE_STRING), 0x0}...)
-			val := []byte(pv.String())
+			val := args[i].([]byte)
+			valLen = len(val)
+			lcb := lengthEncodedIntegerToBytes(uint64(valLen))
+			pktLen += len(lcb) + valLen
+			paramValues = append(paramValues, lcb)
+			paramValues = append(paramValues, val)
+			continue
+
+		case string:
+			paramTypes = append(paramTypes, []byte{byte(FIELD_TYPE_STRING), 0x0}...)
+			val := []byte(args[i].(string))
+			valLen = len(val)
+			lcb := lengthEncodedIntegerToBytes(uint64(valLen))
+			pktLen += valLen + len(lcb)
+			paramValues = append(paramValues, lcb)
+			paramValues = append(paramValues, val)
+			continue
+
+		case time.Time:
+			paramTypes = append(paramTypes, []byte{byte(FIELD_TYPE_STRING), 0x0}...)
+			val := []byte(args[i].(time.Time).Format(TIME_FORMAT))
 			valLen = len(val)
 			lcb := lengthEncodedIntegerToBytes(uint64(valLen))
 			pktLen += valLen + len(lcb)
@@ -737,7 +685,7 @@ func (stmt *mysqlStmt) buildExecutePacket(args *[]driver.Value) error {
 			continue
 
 		default:
-			return fmt.Errorf("Invalid Value: %s", pv.Kind().String())
+			return fmt.Errorf("Can't convert type: %T", args[i])
 		}
 	}
 
@@ -781,36 +729,31 @@ func (stmt *mysqlStmt) buildExecutePacket(args *[]driver.Value) error {
 		data = append(data, paramValue...)
 	}
 
-	return stmt.mc.writePacket(&data)
+	return stmt.mc.writePacket(data)
 }
 
 // http://dev.mysql.com/doc/internals/en/prepared-statements.html#packet-ProtocolBinary::ResultsetRow
-func (rc *mysqlRows) readBinaryRow(dest *[]driver.Value) (err error) {
+func (rc *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
 	data, err := rc.mc.readPacket()
 	if err != nil {
 		return
 	}
 
-	pos := 0
-
 	// EOF Packet
-	if data[pos] == 254 && len(data) == 5 {
+	if data[0] == 254 && len(data) == 5 {
 		return io.EOF
 	}
-	pos++
 
 	// BinaryRowSet Packet
-	columnsCount := len(*dest)
-
-	nullBitMap := data[pos : pos+(columnsCount+7+2)>>3]
-	pos += (columnsCount + 7 + 2) >> 3
+	pos := 1 + (len(dest)+7+2)>>3
+	nullBitMap := data[1:pos]
 
 	var n int
 	var unsigned bool
-	for i := 0; i < columnsCount; i++ {
+	for i := range dest {
 		// Field is NULL
 		if (nullBitMap[(i+2)>>3] >> uint((i+2)&7) & 1) == 1 {
-			(*dest)[i] = nil
+			dest[i] = nil
 			continue
 		}
 
@@ -819,48 +762,55 @@ func (rc *mysqlRows) readBinaryRow(dest *[]driver.Value) (err error) {
 		// Convert to byte-coded string
 		switch rc.columns[i].fieldType {
 		case FIELD_TYPE_NULL:
-			(*dest)[i] = nil
+			dest[i] = nil
+			continue
 
 		// Numeric Typs
 		case FIELD_TYPE_TINY:
 			if unsigned {
-				(*dest)[i] = uint64(data[pos])
+				dest[i] = uint64(data[pos])
 			} else {
-				(*dest)[i] = int64(int8(data[pos]))
+				dest[i] = int64(int8(data[pos]))
 			}
 			pos++
+			continue
 
 		case FIELD_TYPE_SHORT, FIELD_TYPE_YEAR:
 			if unsigned {
-				(*dest)[i] = uint64(binary.LittleEndian.Uint16(data[pos : pos+2]))
+				dest[i] = uint64(binary.LittleEndian.Uint16(data[pos : pos+2]))
 			} else {
-				(*dest)[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2])))
+				dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2])))
 			}
 			pos += 2
+			continue
 
 		case FIELD_TYPE_INT24, FIELD_TYPE_LONG:
 			if unsigned {
-				(*dest)[i] = uint64(binary.LittleEndian.Uint32(data[pos : pos+4]))
+				dest[i] = uint64(binary.LittleEndian.Uint32(data[pos : pos+4]))
 			} else {
-				(*dest)[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4])))
+				dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4])))
 			}
 			pos += 4
+			continue
 
 		case FIELD_TYPE_LONGLONG:
 			if unsigned {
-				(*dest)[i] = binary.LittleEndian.Uint64(data[pos : pos+8])
+				dest[i] = binary.LittleEndian.Uint64(data[pos : pos+8])
 			} else {
-				(*dest)[i] = int64(binary.LittleEndian.Uint64(data[pos : pos+8]))
+				dest[i] = int64(binary.LittleEndian.Uint64(data[pos : pos+8]))
 			}
 			pos += 8
+			continue
 
 		case FIELD_TYPE_FLOAT:
-			(*dest)[i] = bytesToFloat32(data[pos : pos+4])
+			dest[i] = math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4]))
 			pos += 4
+			continue
 
 		case FIELD_TYPE_DOUBLE:
-			(*dest)[i] = bytesToFloat64(data[pos : pos+8])
+			dest[i] = math.Float64frombits(binary.LittleEndian.Uint64(data[pos : pos+8]))
 			pos += 8
+			continue
 
 		// Length coded Binary Strings
 		case FIELD_TYPE_DECIMAL, FIELD_TYPE_NEWDECIMAL, FIELD_TYPE_VARCHAR,
@@ -868,82 +818,92 @@ func (rc *mysqlRows) readBinaryRow(dest *[]driver.Value) (err error) {
 			FIELD_TYPE_TINY_BLOB, FIELD_TYPE_MEDIUM_BLOB, FIELD_TYPE_LONG_BLOB,
 			FIELD_TYPE_BLOB, FIELD_TYPE_VAR_STRING, FIELD_TYPE_STRING,
 			FIELD_TYPE_GEOMETRY:
-			(*dest)[i], _, n, err = readLengthEnodedString(data[pos:])
-			if err != nil {
-				return
-			}
+			dest[i], n, err = readLengthEnodedString(data[pos:])
 			pos += n
+			if err == nil {
+				continue
+			}
+			return // err
 
 		// Date YYYY-MM-DD
 		case FIELD_TYPE_DATE, FIELD_TYPE_NEWDATE:
 			var num uint64
 			// TODO(js): allow nil values
-			num, _, n, err = bytesToLengthEncodedInteger(data[pos:])
+			num, _, n, err = readLengthEncodedInteger(data[pos:])
 			if err != nil {
 				return
 			}
-			pos += n
 
 			if num == 0 {
-				(*dest)[i] = []byte("0000-00-00")
+				dest[i] = []byte("0000-00-00")
+				pos += n
+				continue
 			} else {
-				(*dest)[i] = []byte(fmt.Sprintf("%04d-%02d-%02d",
+				dest[i] = []byte(fmt.Sprintf("%04d-%02d-%02d",
 					binary.LittleEndian.Uint16(data[pos:pos+2]),
 					data[pos+2],
 					data[pos+3]))
+				pos += n + int(num)
+				continue
 			}
-			pos += int(num)
 
 		// Time HH:MM:SS
 		case FIELD_TYPE_TIME:
 			var num uint64
 			// TODO(js): allow nil values
-			num, _, n, err = bytesToLengthEncodedInteger(data[pos:])
+			num, _, n, err = readLengthEncodedInteger(data[pos:])
 			if err != nil {
 				return
 			}
 
 			if num == 0 {
-				(*dest)[i] = []byte("00:00:00")
+				dest[i] = []byte("00:00:00")
+				pos += n
+				continue
 			} else {
-				(*dest)[i] = []byte(fmt.Sprintf("%02d:%02d:%02d",
+				dest[i] = []byte(fmt.Sprintf("%02d:%02d:%02d",
 					data[pos+6],
 					data[pos+7],
 					data[pos+8]))
+				pos += n + int(num)
+				continue
 			}
-			pos += n + int(num)
 
 		// Timestamp YYYY-MM-DD HH:MM:SS
 		case FIELD_TYPE_TIMESTAMP, FIELD_TYPE_DATETIME:
 			var num uint64
 			// TODO(js): allow nil values
-			num, _, n, err = bytesToLengthEncodedInteger(data[pos:])
+			num, _, n, err = readLengthEncodedInteger(data[pos:])
 			if err != nil {
 				return
 			}
-			pos += n
 
 			switch num {
 			case 0:
-				(*dest)[i] = []byte("0000-00-00 00:00:00")
+				dest[i] = []byte("0000-00-00 00:00:00")
+				pos += n
+				continue
 			case 4:
-				(*dest)[i] = []byte(fmt.Sprintf("%04d-%02d-%02d 00:00:00",
+				dest[i] = []byte(fmt.Sprintf("%04d-%02d-%02d 00:00:00",
 					binary.LittleEndian.Uint16(data[pos:pos+2]),
 					data[pos+2],
 					data[pos+3]))
+				pos += n + int(num)
+				continue
 			default:
 				if num < 7 {
 					return fmt.Errorf("Invalid datetime-packet length %d", num)
 				}
-				(*dest)[i] = []byte(fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d",
+				dest[i] = []byte(fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d",
 					binary.LittleEndian.Uint16(data[pos:pos+2]),
 					data[pos+2],
 					data[pos+3],
 					data[pos+4],
 					data[pos+5],
 					data[pos+6]))
+				pos += n + int(num)
+				continue
 			}
-			pos += int(num)
 
 		// Please report if this happens!
 		default:

+ 4 - 7
rows.go

@@ -30,7 +30,7 @@ type mysqlRows struct {
 
 func (rows *mysqlRows) Columns() (columns []string) {
 	columns = make([]string, len(rows.columns))
-	for i := 0; i < cap(columns); i++ {
+	for i := range columns {
 		columns[i] = rows.columns[i].name
 	}
 	return
@@ -48,12 +48,9 @@ func (rows *mysqlRows) Close() (err error) {
 		}
 
 		_, err = rows.mc.readUntilEOF()
-		if err != nil {
-			return
-		}
 	}
 
-	return nil
+	return
 }
 
 func (rows *mysqlRows) Next(dest []driver.Value) error {
@@ -68,9 +65,9 @@ func (rows *mysqlRows) Next(dest []driver.Value) error {
 	// Fetch next row from stream
 	var err error
 	if rows.binary {
-		err = rows.readBinaryRow(&dest)
+		err = rows.readBinaryRow(dest)
 	} else {
-		err = rows.readRow(&dest)
+		err = rows.readRow(dest)
 	}
 
 	if err == io.EOF {

+ 2 - 2
statement.go

@@ -39,7 +39,7 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
 	stmt.mc.insertId = 0
 
 	// Send command
-	err := stmt.buildExecutePacket(&args)
+	err := stmt.buildExecutePacket(args)
 	if err != nil {
 		return nil, err
 	}
@@ -80,7 +80,7 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
 	}
 
 	// Send command
-	err := stmt.buildExecutePacket(&args)
+	err := stmt.buildExecutePacket(args)
 	if err != nil {
 		return nil, err
 	}

+ 52 - 96
utils.go

@@ -10,15 +10,11 @@
 package mysql
 
 import (
-	"bytes"
 	"crypto/sha1"
-	"encoding/binary"
 	"io"
 	"log"
-	"math"
 	"os"
 	"regexp"
-	"strconv"
 	"strings"
 )
 
@@ -85,9 +81,9 @@ func parseDSN(dsn string) *config {
 
 // Encrypt password using 4.1+ method
 // http://forge.mysql.com/wiki/MySQL_Internals_ClientServer_Protocol#4.1_and_later
-func scramblePassword(scramble, password []byte) (result []byte) {
+func scramblePassword(scramble, password []byte) []byte {
 	if len(password) == 0 {
-		return
+		return nil
 	}
 
 	// stage1Hash = SHA1(password)
@@ -108,105 +104,81 @@ func scramblePassword(scramble, password []byte) (result []byte) {
 	scrambleHash = crypt.Sum(nil)
 
 	// token = scrambleHash XOR stage1Hash
-	result = make([]byte, 20)
+	result := make([]byte, 20)
 	for i := range result {
 		result[i] = scrambleHash[i] ^ stage1Hash[i]
 	}
-	return
+	return result[0:]
 }
 
 /******************************************************************************
-*                       Read data-types from bytes                            *
+*                       Convert from and to bytes                             *
 ******************************************************************************/
 
-// Read a slice from the data slice
-func readSlice(data []byte, delim byte) (slice []byte, err error) {
-	pos := bytes.IndexByte(data, delim)
-	if pos > -1 {
-		slice = data[:pos]
-	} else {
-		slice = data
-		err = io.EOF
+func uint24ToBytes(n uint32) []byte {
+	return []byte{
+		byte(n),
+		byte(n >> 8),
+		byte(n >> 16),
 	}
-	return
 }
 
-func readLengthEnodedString(data []byte) ([]byte, bool, int, error) {
-	// Get length
-	num, isNull, n, err := bytesToLengthEncodedInteger(data)
-	if err != nil || isNull {
-		return nil, isNull, n, err
+func uint32ToBytes(n uint32) []byte {
+	return []byte{
+		byte(n),
+		byte(n >> 8),
+		byte(n >> 16),
+		byte(n >> 24),
 	}
+}
 
-	// Check data length
-	if len(data) < n+int(num) {
-		return nil, true, n, io.EOF
+func uint64ToBytes(n uint64) []byte {
+	return []byte{
+		byte(n),
+		byte(n >> 8),
+		byte(n >> 16),
+		byte(n >> 24),
+		byte(n >> 32),
+		byte(n >> 40),
+		byte(n >> 48),
+		byte(n >> 56),
 	}
-
-	return data[n : n+int(num)], isNull, n + int(num), err
 }
 
-func readAndDropLengthEnodedString(data []byte) (n int, err error) {
+func readLengthEnodedString(b []byte) ([]byte, int, error) {
 	// Get length
-	num, _, n, err := bytesToLengthEncodedInteger(data)
+	num, _, n, err := readLengthEncodedInteger(b)
 	if err != nil || num < 1 {
-		return n, err
+		return nil, n, err
 	}
 
+	n += int(num)
+
 	// Check data length
-	if len(data) < n+int(num) {
-		return n, io.EOF
+	if len(b) >= n {
+		return b[n-int(num) : n], n, err
 	}
-
-	return n + int(num), err
+	return nil, n, io.EOF
 }
 
-/******************************************************************************
-*                       Convert from and to bytes                             *
-******************************************************************************/
-
-func uint24ToBytes(n uint32) (b []byte) {
-	b = make([]byte, 3)
-	for i := uint8(0); i < 3; i++ {
-		b[i] = byte(n >> (i << 3))
+func readAndDropLengthEnodedString(b []byte) (n int, err error) {
+	// Get length
+	num, _, n, err := readLengthEncodedInteger(b)
+	if err != nil || num < 1 {
+		return
 	}
-	return
-}
 
-func uint32ToBytes(n uint32) (b []byte) {
-	b = make([]byte, 4)
-	for i := uint8(0); i < 4; i++ {
-		b[i] = byte(n >> (i << 3))
-	}
-	return
-}
+	n += int(num)
 
-func uint64ToBytes(n uint64) (b []byte) {
-	b = make([]byte, 8)
-	for i := uint8(0); i < 8; i++ {
-		b[i] = byte(n >> (i << 3))
+	// Check data length
+	if len(b) >= n {
+		return
 	}
-	return
-}
-
-func int64ToBytes(n int64) []byte {
-	return uint64ToBytes(uint64(n))
+	return n, io.EOF
 }
 
-func bytesToFloat32(b []byte) float32 {
-	return math.Float32frombits(binary.LittleEndian.Uint32(b))
-}
-
-func bytesToFloat64(b []byte) float64 {
-	return math.Float64frombits(binary.LittleEndian.Uint64(b))
-}
-
-func float64ToBytes(f float64) []byte {
-	return uint64ToBytes(math.Float64bits(f))
-}
-
-func bytesToLengthEncodedInteger(b []byte) (num uint64, isNull bool, n int, err error) {
-	switch b[0] {
+func readLengthEncodedInteger(b []byte) (num uint64, isNull bool, n int, err error) {
+	switch (b)[0] {
 
 	// 251: NULL
 	case 0xfb:
@@ -248,32 +220,16 @@ func bytesToLengthEncodedInteger(b []byte) (num uint64, isNull bool, n int, err
 	return
 }
 
-func lengthEncodedIntegerToBytes(n uint64) (b []byte) {
+func lengthEncodedIntegerToBytes(n uint64) []byte {
 	switch {
 	case n <= 250:
-		b = []byte{byte(n)}
+		return []byte{byte(n)}
 
 	case n <= 0xffff:
-		b = []byte{0xfc, byte(n), byte(n >> 8)}
+		return []byte{0xfc, byte(n), byte(n >> 8)}
 
 	case n <= 0xffffff:
-		b = []byte{0xfd, byte(n), byte(n >> 8), byte(n >> 16)}
+		return []byte{0xfd, byte(n), byte(n >> 8), byte(n >> 16)}
 	}
-	return
-}
-
-func intToByteStr(i int64) (b []byte) {
-	return strconv.AppendInt(b, i, 10)
-}
-
-func uintToByteStr(u uint64) (b []byte) {
-	return strconv.AppendUint(b, u, 10)
-}
-
-func float32ToByteStr(f float32) (b []byte) {
-	return strconv.AppendFloat(b, float64(f), 'f', -1, 32)
-}
-
-func float64ToByteStr(f float64) (b []byte) {
-	return strconv.AppendFloat(b, f, 'f', -1, 64)
+	return nil
 }