瀏覽代碼

optimized command execution

Also inlined packet docs
Julien Schmidt 12 年之前
父節點
當前提交
80ad073efe
共有 5 個文件被更改,包括 328 次插入406 次删除
  1. 69 67
      connection.go
  2. 2 2
      const.go
  3. 241 296
      packets.go
  4. 6 14
      statement.go
  5. 10 27
      utils.go

+ 69 - 67
connection.go

@@ -47,13 +47,10 @@ func (mc *mysqlConn) handleParams() (err error) {
 			charsets := strings.Split(val, ",")
 			for _, charset := range charsets {
 				err = mc.exec("SET NAMES " + charset)
-				if err == nil {
-					break
+				if err != nil {
+					return
 				}
 			}
-			if err != nil {
-				return
-			}
 
 		// TLS-Encryption
 		case "tls":
@@ -78,11 +75,11 @@ func (mc *mysqlConn) handleParams() (err error) {
 
 func (mc *mysqlConn) Begin() (driver.Tx, error) {
 	err := mc.exec("START TRANSACTION")
-	if err != nil {
-		return nil, err
+	if err == nil {
+		return &mysqlTx{mc}, err
 	}
 
-	return &mysqlTx{mc}, err
+	return nil, err
 }
 
 func (mc *mysqlConn) Close() (err error) {
@@ -96,7 +93,7 @@ func (mc *mysqlConn) Close() (err error) {
 
 func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
 	// Send command
-	err := mc.writeCommandPacket(COM_STMT_PREPARE, query)
+	err := mc.writeCommandPacketStr(COM_STMT_PREPARE, query)
 	if err != nil {
 		return nil, err
 	}
@@ -106,52 +103,54 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
 	}
 
 	// Read Result
-	var columnCount uint16
-	columnCount, err = stmt.readPrepareResultPacket()
-	if err != nil {
-		return nil, err
-	}
-
-	if stmt.paramCount > 0 {
-		stmt.params, err = stmt.mc.readColumns(stmt.paramCount)
-		if err != nil {
-			return nil, err
+	columnCount, err := stmt.readPrepareResultPacket()
+	if err == nil {
+		if stmt.paramCount > 0 {
+			stmt.params, err = stmt.mc.readColumns(stmt.paramCount)
+			if err != nil {
+				return nil, err
+			}
 		}
-	}
 
-	if columnCount > 0 {
-		_, err = stmt.mc.readUntilEOF()
-		if err != nil {
-			return nil, err
+		if columnCount > 0 {
+			_, err = stmt.mc.readUntilEOF()
 		}
 	}
 
 	return stmt, err
 }
 
-func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
-	if len(args) > 0 {
-		return nil, driver.ErrSkip
-	}
-
-	mc.affectedRows = 0
-	mc.insertId = 0
-
-	err := mc.exec(query)
-	if err != nil {
-		return nil, err
+func (mc *mysqlConn) Exec(query string, args []driver.Value) (_ driver.Result, err error) {
+	if len(args) > 0 { // with args, must use prepared stmt
+		var res driver.Result
+		var stmt driver.Stmt
+		stmt, err = mc.Prepare(query)
+		if err == nil {
+			res, err = stmt.Exec(args)
+			if err == nil {
+				return res, stmt.Close()
+			}
+		}
+	} else { // no args, fastpath
+		mc.affectedRows = 0
+		mc.insertId = 0
+
+		err = mc.exec(query)
+		if err == nil {
+			return &mysqlResult{
+				affectedRows: int64(mc.affectedRows),
+				insertId:     int64(mc.insertId),
+			}, err
+		}
 	}
+	return nil, err
 
-	return &mysqlResult{
-		affectedRows: int64(mc.affectedRows),
-		insertId:     int64(mc.insertId),
-	}, err
 }
 
 // Internal function to execute commands
 func (mc *mysqlConn) exec(query string) (err error) {
 	// Send command
-	err = mc.writeCommandPacket(COM_QUERY, query)
+	err = mc.writeCommandPacketStr(COM_QUERY, query)
 	if err != nil {
 		return
 	}
@@ -175,39 +174,42 @@ func (mc *mysqlConn) exec(query string) (err error) {
 		}
 
 		mc.affectedRows, err = mc.readUntilEOF()
-		return
 	}
 
 	return
 }
 
-func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
-	if len(args) > 0 {
-		return nil, driver.ErrSkip
-	}
-
-	// Send command
-	err := mc.writeCommandPacket(COM_QUERY, query)
-	if err != nil {
-		return nil, err
-	}
-
-	// Read Result
-	var resLen int
-	resLen, err = mc.readResultSetHeaderPacket()
-	if err != nil {
-		return nil, err
-	}
-
-	rows := &mysqlRows{mc, false, nil, false}
-
-	if resLen > 0 {
-		// Columns
-		rows.columns, err = mc.readColumns(resLen)
-		if err != nil {
-			return nil, err
+func (mc *mysqlConn) Query(query string, args []driver.Value) (_ driver.Rows, err error) {
+	if len(args) > 0 { // with args, must use prepared stmt
+		var rows driver.Rows
+		var stmt driver.Stmt
+		stmt, err = mc.Prepare(query)
+		if err == nil {
+			rows, err = stmt.Query(args)
+			if err == nil {
+				return rows, stmt.Close()
+			}
+		}
+		return
+	} else { // no args, fastpath
+		var rows *mysqlRows
+		// Send command
+		err = mc.writeCommandPacketStr(COM_QUERY, query)
+		if err == nil {
+			// Read Result
+			var resLen int
+			resLen, err = mc.readResultSetHeaderPacket()
+			if err == nil {
+				rows = &mysqlRows{mc, false, nil, false}
+
+				if resLen > 0 {
+					// Columns
+					rows.columns, err = mc.readColumns(resLen)
+				}
+				return rows, err
+			}
 		}
 	}
 
-	return rows, err
+	return nil, err
 }

+ 2 - 2
const.go

@@ -11,8 +11,8 @@ package mysql
 
 const (
 	MIN_PROTOCOL_VERSION = 10
-	MAX_PACKET_SIZE      = 1<<24 - 1
-	TIME_FORMAT          = "2006-01-02 15:04:05"
+	//MAX_PACKET_SIZE      = 1<<24 - 1
+	TIME_FORMAT = "2006-01-02 15:04:05"
 )
 
 // MySQL constants documentation:

+ 241 - 296
packets.go

@@ -25,7 +25,7 @@ import (
 
 // Read packet to buffer 'data'
 func (mc *mysqlConn) readPacket() (data []byte, err error) {
-	// Read header
+	// Read packet header
 	data = make([]byte, 4)
 	err = mc.buf.read(data)
 	if err != nil {
@@ -63,6 +63,8 @@ func (mc *mysqlConn) readPacket() (data []byte, err error) {
 	return nil, driver.ErrBadConn
 }
 
+// Write packet buffer 'data'
+// The packet header must be already included
 func (mc *mysqlConn) writePacket(data []byte) error {
 	// Write packet
 	n, err := mc.netConn.Write(data)
@@ -137,17 +139,8 @@ func (mc *mysqlConn) readInitPacket() (err error) {
 	return
 }
 
-/* Client Authentication Packet
-Bytes                        Name
------                        ----
-4                            client_flags
-4                            max_packet_size
-1                            charset_number
-23                           (filler) always 0x00...
-n (Null-Terminated String)   user
-n (Length Coded Binary)      scramble_buff (1 + x bytes)
-n (Null-Terminated String)   databasename (optional)
-*/
+// Client Authentication Packet
+// http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::HandshakeResponse
 func (mc *mysqlConn) writeAuthPacket() error {
 	// Adjust client flags based on server support
 	clientFlags := uint32(
@@ -162,6 +155,7 @@ func (mc *mysqlConn) writeAuthPacket() error {
 
 	// User Password
 	scrambleBuff := scramblePassword(mc.scrambleBuff, []byte(mc.cfg.passwd))
+	mc.scrambleBuff = nil
 
 	pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.user) + 1 + 1 + len(scrambleBuff)
 
@@ -172,43 +166,47 @@ func (mc *mysqlConn) writeAuthPacket() error {
 	}
 
 	// Calculate packet length and make buffer with that size
-	data := make([]byte, 0, pktLen+4)
+	data := make([]byte, pktLen+4)
 
-	// Add the packet header
-	data = append(data, uint24ToBytes(uint32(pktLen))...)
-	data = append(data, mc.sequence)
+	// Add the packet header  [24bit length + 1 byte sequence]
+	data[0] = byte(pktLen)
+	data[1] = byte(pktLen >> 8)
+	data[2] = byte(pktLen >> 16)
+	data[3] = mc.sequence
 
-	// ClientFlags
-	data = append(data, uint32ToBytes(clientFlags)...)
+	// ClientFlags [32 bit]
+	data[4] = byte(clientFlags)
+	data[5] = byte(clientFlags >> 8)
+	data[6] = byte(clientFlags >> 16)
+	data[7] = byte(clientFlags >> 24)
 
-	// MaxPacketSize
-	data = append(data, uint32ToBytes(MAX_PACKET_SIZE)...)
+	// MaxPacketSize [32 bit] (1<<24 - 1)
+	data[8] = 0xff
+	data[9] = 0xff
+	data[10] = 0xff
+	//data[11] = 0x00
 
-	// Charset
-	data = append(data, mc.charset)
+	// Charset [1 byte]
+	data[12] = mc.charset
 
-	// Filler
-	data = append(data, make([]byte, 23)...)
+	// Filler [23 byte] (all 0x00)
+	pos := 13 + 23
 
-	// User
+	// User [null terminated string]
 	if len(mc.cfg.user) > 0 {
-		data = append(data, []byte(mc.cfg.user)...)
+		pos += copy(data[pos:], mc.cfg.user)
 	}
+	//data[pos] = 0x00
+	pos++
 
-	// Null-Terminator
-	data = append(data, 0x0)
-
-	// ScrambleBuffer
-	data = append(data, byte(len(scrambleBuff)))
-	if len(scrambleBuff) > 0 {
-		data = append(data, scrambleBuff...)
-	}
+	// ScrambleBuffer [length encoded integer]
+	data[pos] = byte(len(scrambleBuff))
+	pos += 1 + copy(data[pos+1:], scrambleBuff)
 
-	// Databasename
+	// Databasename [null terminated string]
 	if len(mc.cfg.dbname) > 0 {
-		data = append(data, []byte(mc.cfg.dbname)...)
-		// Null-Terminator
-		data = append(data, 0x0)
+		pos += copy(data[pos:], mc.cfg.dbname)
+		//data[pos] = 0x00
 	}
 
 	// Send Auth packet
@@ -219,62 +217,69 @@ func (mc *mysqlConn) writeAuthPacket() error {
 *                             Command Packets                                 *
 ******************************************************************************/
 
-/* Command Packet
-Bytes                        Name
------                        ----
-1                            command
-n                            arg
-*/
-func (mc *mysqlConn) writeCommandPacket(command commandType, args ...interface{}) error {
+func (mc *mysqlConn) writeCommandPacket(command commandType) error {
 	// Reset Packet Sequence
 	mc.sequence = 0
 
-	var arg []byte
-
-	switch command {
-
-	// Commands without args
-	case COM_QUIT, COM_PING:
-		if len(args) > 0 {
-			return fmt.Errorf("Too much arguments (Got: %d Has: 0)", len(args))
-		}
-		arg = []byte{}
-
-	// Commands with 1 arg unterminated string
-	case COM_QUERY, COM_STMT_PREPARE:
-		if len(args) != 1 {
-			return fmt.Errorf("Invalid arguments count (Got: %d Has: 1)", len(args))
-		}
-		arg = []byte(args[0].(string))
-
-	// Commands with 1 arg 32 bit uint
-	case COM_STMT_CLOSE:
-		if len(args) != 1 {
-			return fmt.Errorf("Invalid arguments count (Got: %d Has: 1)", len(args))
-		}
-		arg = uint32ToBytes(args[0].(uint32))
+	// Send CMD packet
+	return mc.writePacket([]byte{
+		// Add the packet header [24bit length + 1 byte sequence]
+		0x05, // 5 bytes long
+		0x00,
+		0x00,
+		mc.sequence,
+
+		// Add command byte
+		byte(command),
+	})
+}
 
-	default:
-		return fmt.Errorf("Unknown command: %d", command)
-	}
+func (mc *mysqlConn) writeCommandPacketStr(command commandType, arg string) error {
+	// Reset Packet Sequence
+	mc.sequence = 0
 
 	pktLen := 1 + len(arg)
-	data := make([]byte, 0, pktLen+4)
+	data := make([]byte, pktLen+4)
 
-	// Add the packet header
-	data = append(data, uint24ToBytes(uint32(pktLen))...)
-	data = append(data, mc.sequence)
+	// Add the packet header [24bit length + 1 byte sequence]
+	data[0] = byte(pktLen)
+	data[1] = byte(pktLen >> 8)
+	data[2] = byte(pktLen >> 16)
+	data[3] = mc.sequence
 
 	// Add command byte
-	data = append(data, byte(command))
+	data[4] = byte(command)
 
 	// Add arg
-	data = append(data, arg...)
+	copy(data[5:], arg)
 
 	// Send CMD packet
 	return mc.writePacket(data)
 }
 
+func (mc *mysqlConn) writeCommandPacketUint32(command commandType, arg uint32) error {
+	// Reset Packet Sequence
+	mc.sequence = 0
+
+	// Send CMD packet
+	return mc.writePacket([]byte{
+		// Add the packet header [24bit length + 1 byte sequence]
+		0x05, // 5 bytes long
+		0x00,
+		0x00,
+		mc.sequence,
+
+		// Add command byte
+		byte(command),
+
+		// Add arg [32 bit]
+		byte(arg),
+		byte(arg >> 8),
+		byte(arg >> 16),
+		byte(arg >> 24),
+	})
+}
+
 /******************************************************************************
 *                              Result Packets                                 *
 ******************************************************************************/
@@ -289,28 +294,49 @@ func (mc *mysqlConn) readResultOK() error {
 	switch data[0] {
 	// OK
 	case 0:
-		return mc.handleOkPacket(data)
+		mc.handleOkPacket(data)
+		return nil
 	// EOF, someone is using old_passwords
 	case 254:
 		return errOldPassword
+	}
 	// ERROR
-	case 255:
-		return mc.handleErrorPacket(data)
+	return mc.handleErrorPacket(data)
+}
+
+// Result Set Header Packet
+// http://dev.mysql.com/doc/internals/en/text-protocol.html#packet-ProtocolText::Resultset
+func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
+	data, err := mc.readPacket()
+	if err != nil {
+		return 0, err
+	}
+
+	if data[0] == 0 {
+		mc.handleOkPacket(data)
+		return 0, nil
+	} else if data[0] == 255 {
+		return 0, mc.handleErrorPacket(data)
 	}
 
-	return errMalformPkt
+	// column count
+	num, _, n := readLengthEncodedInteger(data)
+	if n-len(data) == 0 {
+		return int(num), nil
+	}
+
+	return 0, errMalformPkt
 }
 
-/* Error Packet
-Bytes                       Name
------                       ----
-1                           field_count, always = 0xff
-2                           errno
-1                           (sqlstate marker), always '#'
-5                           sqlstate (5 characters)
-n                           message
-*/
+// Error Packet
+// http://dev.mysql.com/doc/internals/en/overview.html#packet-ERR_Packet
 func (mc *mysqlConn) handleErrorPacket(data []byte) error {
+	if data[0] != 255 {
+		return errMalformPkt
+	}
+
+	// 0xff [1 byte]
+
 	// Error Number [16 bit uint]
 	errno := binary.LittleEndian.Uint16(data[1:3])
 
@@ -321,80 +347,33 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error {
 	return fmt.Errorf("Error %d: %s", errno, string(data[9:]))
 }
 
-/* Ok Packet
-Bytes                       Name
------                       ----
-1   (Length Coded Binary)   field_count, always = 0
-1-9 (Length Coded Binary)   affected_rows
-1-9 (Length Coded Binary)   insert_id
-2                           server_status
-2                           warning_count
-n   (until end of packet)   message
-*/
-func (mc *mysqlConn) handleOkPacket(data []byte) (err error) {
+// Ok Packet
+// http://dev.mysql.com/doc/internals/en/overview.html#packet-OK_Packet
+func (mc *mysqlConn) handleOkPacket(data []byte) {
 	var n int
 
+	// 0x00 [1 byte]
+
 	// Affected rows [Length Coded Binary]
-	mc.affectedRows, _, n, err = readLengthEncodedInteger(data[1:])
-	if err != nil {
-		return
-	}
+	mc.affectedRows, _, n = readLengthEncodedInteger(data[1:])
 
 	// Insert id [Length Coded Binary]
-	mc.insertId, _, _, err = readLengthEncodedInteger(data[1+n:])
-	if err != nil {
-		return
-	}
+	mc.insertId, _, _ = readLengthEncodedInteger(data[1+n:])
 
-	// Skip remaining data
-	return
-}
-
-/* Result Set Header Packet
- Bytes                        Name
- -----                        ----
- 1-9   (Length-Coded-Binary)  field_count
- 1-9   (Length-Coded-Binary)  extra
-
-The order of packets for a result set is:
-  (Result Set Header Packet)  the number of columns
-  (Field Packets)             column descriptors
-  (EOF Packet)                marker: end of Field Packets
-  (Row Data Packets)          row contents
-  (EOF Packet)                marker: end of Data Packets
-*/
-func (mc *mysqlConn) readResultSetHeaderPacket() (fieldCount int, err error) {
-	data, err := mc.readPacket()
-	if err != nil {
-		errLog.Print(err)
-		err = driver.ErrBadConn
-		return
-	}
-
-	if data[0] == 0 {
-		err = mc.handleOkPacket(data)
-		return
-	} else if data[0] == 255 {
-		err = mc.handleErrorPacket(data)
-		return
-	}
-
-	num, _, n, err := readLengthEncodedInteger(data)
-	if err != nil || (n-len(data)) != 0 {
-		err = errors.New("Malformed Packet")
-		return
-	}
-
-	fieldCount = int(num)
-	return
+	// server_status [2 bytes]
+	// warning count [2 bytes]
+	// message [until end of packet]
 }
 
 // Read Packets as Field Packets until EOF-Packet or an Error appears
+// http://dev.mysql.com/doc/internals/en/text-protocol.html#packet-Protocol::ColumnDefinition41
 func (mc *mysqlConn) readColumns(count int) (columns []mysqlField, err error) {
 	var data []byte
-	var pos, n int
+	var i, pos, n int
 	var name []byte
 
+	columns = make([]mysqlField, count)
+
 	for {
 		data, err = mc.readPacket()
 		if err != nil {
@@ -403,34 +382,34 @@ func (mc *mysqlConn) readColumns(count int) (columns []mysqlField, err error) {
 
 		// EOF Packet
 		if data[0] == 254 && len(data) == 5 {
-			if len(columns) != count {
+			if i != count {
 				err = fmt.Errorf("ColumnsCount mismatch n:%d len:%d", count, len(columns))
 			}
 			return
 		}
 
 		// Catalog
-		pos, err = readAndDropLengthEnodedString(data)
+		pos, err = skipLengthEnodedString(data)
 		if err != nil {
 			return
 		}
 
 		// Database [len coded string]
-		n, err = readAndDropLengthEnodedString(data[pos:])
+		n, err = skipLengthEnodedString(data[pos:])
 		if err != nil {
 			return
 		}
 		pos += n
 
 		// Table [len coded string]
-		n, err = readAndDropLengthEnodedString(data[pos:])
+		n, err = skipLengthEnodedString(data[pos:])
 		if err != nil {
 			return
 		}
 		pos += n
 
 		// Original table [len coded string]
-		n, err = readAndDropLengthEnodedString(data[pos:])
+		n, err = skipLengthEnodedString(data[pos:])
 		if err != nil {
 			return
 		}
@@ -441,10 +420,11 @@ func (mc *mysqlConn) readColumns(count int) (columns []mysqlField, err error) {
 		if err != nil {
 			return
 		}
+		columns[i].name = string(name)
 		pos += n
 
 		// Original name [len coded string]
-		n, err = readAndDropLengthEnodedString(data[pos:])
+		n, err = skipLengthEnodedString(data[pos:])
 		if err != nil {
 			return
 		}
@@ -455,11 +435,11 @@ func (mc *mysqlConn) readColumns(count int) (columns []mysqlField, err error) {
 		pos += n + 1 + 2 + 4
 
 		// Field type [byte]
-		fieldType := FieldType(data[pos])
+		columns[i].fieldType = FieldType(data[pos])
 		pos++
 
 		// Flags [16 bit uint]
-		flags := FieldFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
+		columns[i].flags = FieldFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
 		//pos += 2
 
 		// Decimals [8 bit uint]
@@ -470,13 +450,14 @@ func (mc *mysqlConn) readColumns(count int) (columns []mysqlField, err error) {
 		//	defaultVal, _, err = bytesToLengthCodedBinary(data[pos:])
 		//}
 
-		columns = append(columns, mysqlField{name: string(name), fieldType: fieldType, flags: flags})
+		i++
 	}
 
 	return
 }
 
 // Read Packets as Field Packets until EOF-Packet or an Error appears
+// http://dev.mysql.com/doc/internals/en/text-protocol.html#packet-ProtocolText::ResultsetRow
 func (rows *mysqlRows) readRow(dest []driver.Value) (err error) {
 	data, err := rows.mc.readPacket()
 	if err != nil {
@@ -526,34 +507,8 @@ func (mc *mysqlConn) readUntilEOF() (count uint64, err error) {
 *                           Prepared Statements                               *
 ******************************************************************************/
 
-/* Prepare Result Packets
- Type Of Result Packet       Hexadecimal Value Of First Byte (field_count)
- ---------------------       ---------------------------------------------
-
- Prepare OK Packet           00
- Error Packet                ff
-
-Prepare OK Packet
- Bytes              Name
- -----              ----
- 1                  0 - marker for OK packet
- 4                  statement_handler_id
- 2                  number of columns in result set
- 2                  number of parameters in query
- 1                  filler (always 0)
- 2                  warning count
-
- It is made up of:
-
-    a PREPARE_OK packet
-    if "number of parameters" > 0
-        (field packets) as in a Result Set Header Packet
-        (EOF packet)
-    if "number of columns" > 0
-        (field packets) as in a Result Set Header Packet
-        (EOF packet)
-
-*/
+// Prepare Result Packets
+// http://dev.mysql.com/doc/internals/en/prepared-statements.html#com-stmt-prepare-response
 func (stmt *mysqlStmt) readPrepareResultPacket() (columnCount uint16, err error) {
 	data, err := stmt.mc.readPacket()
 	if err != nil {
@@ -563,12 +518,14 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (columnCount uint16, err error)
 	// Position
 	pos := 0
 
-	if data[pos] != 0 {
+	// packet marker [1 byte]
+	if data[pos] != 0 { // not OK (0) ?
 		err = stmt.mc.handleErrorPacket(data)
 		return
 	}
 	pos++
 
+	// statement id [4 bytes]
 	stmt.id = binary.LittleEndian.Uint32(data[pos : pos+4])
 	pos += 4
 
@@ -586,103 +543,85 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (columnCount uint16, err error)
 	return
 }
 
-/* Command Packet
-Bytes                Name
------                ----
-1                    code
-4                    statement_id
-1                    flags
-4                    iteration_count
-  if param_count > 0:
-(param_count+7)/8    null_bit_map
-1                    new_parameter_bound_flag
-  if new_params_bound == 1:
-n*2                  type of parameters
-n                    values for the parameters
-*/
-func (stmt *mysqlStmt) buildExecutePacket(args []driver.Value) error {
-	argsLen := len(args)
-	if argsLen != stmt.paramCount {
+// Execute Prepared Statement
+// http://dev.mysql.com/doc/internals/en/prepared-statements.html#com-stmt-execute
+func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
+	if len(args) != stmt.paramCount {
 		return fmt.Errorf(
 			"Arguments count mismatch (Got: %d Has: %d",
-			argsLen,
+			len(args),
 			stmt.paramCount)
 	}
 
 	// Reset packet-sequence
 	stmt.mc.sequence = 0
 
-	pktLen := 1 + 4 + 1 + 4 + ((stmt.paramCount + 7) >> 3) + 1 + (argsLen << 1)
-	paramValues := make([][]byte, 0, argsLen)
-	paramTypes := make([]byte, 0, (argsLen << 1))
+	pktLen := 1 + 4 + 1 + 4 + ((stmt.paramCount + 7) >> 3) + 1 + (stmt.paramCount << 1)
+	paramValues := make([][]byte, stmt.paramCount)
+	paramTypes := make([]byte, (stmt.paramCount << 1))
 	bitMask := uint64(0)
-	var i, valLen int
+	var i int
+
 	for i = range args {
-		// build nullBitMap
+		// build NULL-bitmap
 		if args[i] == nil {
 			bitMask += 1 << uint(i)
+			paramTypes[i<<1] = byte(FIELD_TYPE_NULL)
+			continue
 		}
 
 		// cache types and values
 		switch args[i].(type) {
-		case nil:
-			paramTypes = append(paramTypes, []byte{
-				byte(FIELD_TYPE_NULL),
-				0x0}...)
-			continue
-
 		case int64:
-			paramTypes = append(paramTypes, []byte{byte(FIELD_TYPE_LONGLONG), 0x0}...)
-			val := uint64ToBytes(uint64(args[i].(int64)))
-			pktLen += len(val)
-			paramValues = append(paramValues, val)
+			paramTypes[i<<1] = byte(FIELD_TYPE_LONGLONG)
+			paramValues[i] = uint64ToBytes(uint64(args[i].(int64)))
+			pktLen += 8
 			continue
 
 		case float64:
-			paramTypes = append(paramTypes, []byte{byte(FIELD_TYPE_DOUBLE), 0x0}...)
-			val := uint64ToBytes(math.Float64bits(args[i].(float64)))
-			pktLen += len(val)
-			paramValues = append(paramValues, val)
+			paramTypes[i<<1] = byte(FIELD_TYPE_DOUBLE)
+			paramValues[i] = uint64ToBytes(math.Float64bits(args[i].(float64)))
+			pktLen += 8
 			continue
 
 		case bool:
-			paramTypes = append(paramTypes, []byte{byte(FIELD_TYPE_TINY), 0x0}...)
+			paramTypes[i<<1] = byte(FIELD_TYPE_TINY)
 			pktLen++
 			if args[i].(bool) {
-				paramValues = append(paramValues, []byte{byte(1)})
+				paramValues[i] = []byte{0x01}
 			} else {
-				paramValues = append(paramValues, []byte{byte(0)})
+				paramValues[i] = []byte{0x00}
 			}
 			continue
 
 		case []byte:
-			paramTypes = append(paramTypes, []byte{byte(FIELD_TYPE_STRING), 0x0}...)
+			paramTypes[i<<1] = byte(FIELD_TYPE_STRING)
 			val := args[i].([]byte)
-			valLen = len(val)
-			lcb := lengthEncodedIntegerToBytes(uint64(valLen))
-			pktLen += len(lcb) + valLen
-			paramValues = append(paramValues, lcb)
-			paramValues = append(paramValues, val)
+			paramValues[i] = append(
+				lengthEncodedIntegerToBytes(uint64(len(val))),
+				val...,
+			)
+			pktLen += len(paramValues[i])
 			continue
 
 		case string:
-			paramTypes = append(paramTypes, []byte{byte(FIELD_TYPE_STRING), 0x0}...)
+			paramTypes[i<<1] = byte(FIELD_TYPE_STRING)
 			val := []byte(args[i].(string))
-			valLen = len(val)
-			lcb := lengthEncodedIntegerToBytes(uint64(valLen))
-			pktLen += valLen + len(lcb)
-			paramValues = append(paramValues, lcb)
-			paramValues = append(paramValues, val)
+			paramValues[i] = append(
+				lengthEncodedIntegerToBytes(uint64(len(val))),
+				val...,
+			)
+			pktLen += len(paramValues[i])
 			continue
 
 		case time.Time:
-			paramTypes = append(paramTypes, []byte{byte(FIELD_TYPE_STRING), 0x0}...)
+			paramTypes[i<<1] = byte(FIELD_TYPE_STRING)
 			val := []byte(args[i].(time.Time).Format(TIME_FORMAT))
-			valLen = len(val)
-			lcb := lengthEncodedIntegerToBytes(uint64(valLen))
-			pktLen += valLen + len(lcb)
-			paramValues = append(paramValues, lcb)
-			paramValues = append(paramValues, val)
+			paramValues[i] = append(
+				lengthEncodedIntegerToBytes(uint64(len(val))),
+				val...,
+			)
+			pktLen += len(paramValues[i])
 			continue
 
 		default:
@@ -690,44 +629,51 @@ func (stmt *mysqlStmt) buildExecutePacket(args []driver.Value) error {
 		}
 	}
 
-	data := make([]byte, 0, pktLen+4)
+	data := make([]byte, pktLen+4)
 
-	// Add the packet header
-	data = append(data, uint24ToBytes(uint32(pktLen))...)
-	data = append(data, stmt.mc.sequence)
+	// packet header [4 bytes]
+	data[0] = byte(pktLen)
+	data[1] = byte(pktLen >> 8)
+	data[2] = byte(pktLen >> 16)
+	data[3] = stmt.mc.sequence
 
-	// code [1 byte]
-	data = append(data, byte(COM_STMT_EXECUTE))
+	// command [1 byte]
+	data[4] = byte(COM_STMT_EXECUTE)
 
 	// statement_id [4 bytes]
-	data = append(data, uint32ToBytes(stmt.id)...)
+	data[5] = byte(stmt.id)
+	data[6] = byte(stmt.id >> 8)
+	data[7] = byte(stmt.id >> 16)
+	data[8] = byte(stmt.id >> 24)
 
 	// flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte]
-	data = append(data, byte(0))
+	//data[9] = 0x00
 
-	// iteration_count [4 bytes]
-	data = append(data, uint32ToBytes(1)...)
+	// iteration_count (uint32(1)) [4 bytes]
+	data[10] = 0x01
+	//data[11] = 0x00
+	//data[12] = 0x00
+	//data[13] = 0x00
 
-	// append nullBitMap [(param_count+7)/8 bytes]
 	if stmt.paramCount > 0 {
+		// NULL-bitmap [(param_count+7)/8 bytes]
+		pos := 14 + ((stmt.paramCount + 7) >> 3)
 		// Convert bitMask to bytes
-		nullBitMap := make([]byte, (stmt.paramCount+7)/8)
-		for i = 0; i < len(nullBitMap); i++ {
-			nullBitMap[i] = byte(bitMask >> uint(i*8))
+		for i = 14; i < pos; i++ {
+			data[i] = byte(bitMask >> uint(i<<3))
 		}
 
-		data = append(data, nullBitMap...)
-	}
-
-	// newParameterBoundFlag 1 [1 byte]
-	data = append(data, byte(1))
+		// newParameterBoundFlag 1 [1 byte]
+		data[pos] = 0x01
+		pos++
 
-	// type of parameters [n*2 byte]
-	data = append(data, paramTypes...)
+		// type of parameters [param_count*2 byte]
+		pos += copy(data[pos:], paramTypes)
 
-	// values for the parameters [n byte]
-	for _, paramValue := range paramValues {
-		data = append(data, paramValue...)
+		// values for the parameters [n byte]
+		for i = range paramValues {
+			pos += copy(data[pos:], paramValues[i])
+		}
 	}
 
 	return stmt.mc.writePacket(data)
@@ -740,20 +686,28 @@ func (rc *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
 		return
 	}
 
-	// EOF Packet
-	if data[0] == 254 && len(data) == 5 {
-		return io.EOF
+	// packet header [1 byte]
+	if data[0] != 0x00 {
+		// EOF Packet
+		if data[0] == 254 && len(data) == 5 {
+			return io.EOF
+		} else {
+			// Error otherwise
+			return rc.mc.handleErrorPacket(data)
+		}
 	}
 
-	// BinaryRowSet Packet
+	// NULL-bitmap,  [(column-count + 7 + 2) / 8 bytes]
 	pos := 1 + (len(dest)+7+2)>>3
 	nullBitMap := data[1:pos]
 
+	// values [rest]
 	var n int
 	var unsigned bool
 	for i := range dest {
 		// Field is NULL
-		if (nullBitMap[(i+2)>>3] >> uint((i+2)&7) & 1) == 1 {
+		// (byte >> bit-pos) % 2 == 1
+		if ((nullBitMap[(i+2)>>3] >> uint((i+2)&7)) & 1) == 1 {
 			dest[i] = nil
 			continue
 		}
@@ -830,19 +784,16 @@ func (rc *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
 		case FIELD_TYPE_DATE, FIELD_TYPE_NEWDATE:
 			var num uint64
 			var isNull bool
-			num, isNull, n, err = readLengthEncodedInteger(data[pos:])
-			if err != nil {
-				return
-			}
+			num, isNull, n = readLengthEncodedInteger(data[pos:])
 
 			if num == 0 {
 				if isNull {
 					dest[i] = nil
-					pos++ // n = 1
+					pos++ // always n=1
 					continue
 				} else {
 					dest[i] = []byte("0000-00-00")
-					pos++ // n = 1
+					pos += n
 					continue
 				}
 			} else {
@@ -858,19 +809,16 @@ func (rc *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
 		case FIELD_TYPE_TIME:
 			var num uint64
 			var isNull bool
-			num, isNull, n, err = readLengthEncodedInteger(data[pos:])
-			if err != nil {
-				return
-			}
+			num, isNull, n = readLengthEncodedInteger(data[pos:])
 
 			if num == 0 {
 				if isNull {
 					dest[i] = nil
-					pos++ // n = 1
+					pos++ // always n=1
 					continue
 				} else {
 					dest[i] = []byte("00:00:00")
-					pos++ // n = 1
+					pos += n
 					continue
 				}
 			}
@@ -912,19 +860,16 @@ func (rc *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
 		case FIELD_TYPE_TIMESTAMP, FIELD_TYPE_DATETIME:
 			var num uint64
 			var isNull bool
-			num, isNull, n, err = readLengthEncodedInteger(data[pos:])
-			if err != nil {
-				return
-			}
+			num, isNull, n = readLengthEncodedInteger(data[pos:])
 
 			if num == 0 {
 				if isNull {
 					dest[i] = nil
-					pos++ // n = 1
+					pos++ // always n=1
 					continue
 				} else {
 					dest[i] = []byte("0000-00-00 00:00:00")
-					pos++ // n = 1
+					pos += n
 					continue
 				}
 			}
@@ -939,7 +884,7 @@ func (rc *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
 					data[pos+2],
 					data[pos+3],
 				))
-				pos += 5
+				pos += 4
 				continue
 			case 7:
 				dest[i] = []byte(fmt.Sprintf(

+ 6 - 14
statement.go

@@ -11,7 +11,6 @@ package mysql
 
 import (
 	"database/sql/driver"
-	"errors"
 )
 
 type mysqlStmt struct {
@@ -22,7 +21,7 @@ type mysqlStmt struct {
 }
 
 func (stmt *mysqlStmt) Close() (err error) {
-	err = stmt.mc.writeCommandPacket(COM_STMT_CLOSE, stmt.id)
+	err = stmt.mc.writeCommandPacketUint32(COM_STMT_CLOSE, stmt.id)
 	stmt.mc = nil
 	return
 }
@@ -32,14 +31,11 @@ func (stmt *mysqlStmt) NumInput() int {
 }
 
 func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
-	if stmt.mc == nil {
-		return nil, errors.New(`Invalid Statement`)
-	}
 	stmt.mc.affectedRows = 0
 	stmt.mc.insertId = 0
 
 	// Send command
-	err := stmt.buildExecutePacket(args)
+	err := stmt.writeExecutePacket(args)
 	if err != nil {
 		return nil, err
 	}
@@ -69,18 +65,14 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
 	}
 
 	return &mysqlResult{
-			affectedRows: int64(stmt.mc.affectedRows),
-			insertId:     int64(stmt.mc.insertId)},
-		nil
+		affectedRows: int64(stmt.mc.affectedRows),
+		insertId:     int64(stmt.mc.insertId),
+	}, nil
 }
 
 func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
-	if stmt.mc == nil {
-		return nil, errors.New(`Invalid Statement`)
-	}
-
 	// Send command
-	err := stmt.buildExecutePacket(args)
+	err := stmt.writeExecutePacket(args)
 	if err != nil {
 		return nil, err
 	}

+ 10 - 27
utils.go

@@ -115,23 +115,6 @@ func scramblePassword(scramble, password []byte) []byte {
 *                       Convert from and to bytes                             *
 ******************************************************************************/
 
-func uint24ToBytes(n uint32) []byte {
-	return []byte{
-		byte(n),
-		byte(n >> 8),
-		byte(n >> 16),
-	}
-}
-
-func uint32ToBytes(n uint32) []byte {
-	return []byte{
-		byte(n),
-		byte(n >> 8),
-		byte(n >> 16),
-		byte(n >> 24),
-	}
-}
-
 func uint64ToBytes(n uint64) []byte {
 	return []byte{
 		byte(n),
@@ -147,37 +130,37 @@ func uint64ToBytes(n uint64) []byte {
 
 func readLengthEnodedString(b []byte) ([]byte, int, error) {
 	// Get length
-	num, _, n, err := readLengthEncodedInteger(b)
-	if err != nil || num < 1 {
-		return nil, n, err
+	num, _, n := readLengthEncodedInteger(b)
+	if num < 1 {
+		return nil, n, nil
 	}
 
 	n += int(num)
 
 	// Check data length
 	if len(b) >= n {
-		return b[n-int(num) : n], n, err
+		return b[n-int(num) : n], n, nil
 	}
 	return nil, n, io.EOF
 }
 
-func readAndDropLengthEnodedString(b []byte) (n int, err error) {
+func skipLengthEnodedString(b []byte) (int, error) {
 	// Get length
-	num, _, n, err := readLengthEncodedInteger(b)
-	if err != nil || num < 1 {
-		return
+	num, _, n := readLengthEncodedInteger(b)
+	if num < 1 {
+		return n, nil
 	}
 
 	n += int(num)
 
 	// Check data length
 	if len(b) >= n {
-		return
+		return n, nil
 	}
 	return n, io.EOF
 }
 
-func readLengthEncodedInteger(b []byte) (num uint64, isNull bool, n int, err error) {
+func readLengthEncodedInteger(b []byte) (num uint64, isNull bool, n int) {
 	switch (b)[0] {
 
 	// 251: NULL