|
|
@@ -25,7 +25,7 @@ import (
|
|
|
|
|
|
// Read packet to buffer 'data'
|
|
|
func (mc *mysqlConn) readPacket() (data []byte, err error) {
|
|
|
- // Read header
|
|
|
+ // Read packet header
|
|
|
data = make([]byte, 4)
|
|
|
err = mc.buf.read(data)
|
|
|
if err != nil {
|
|
|
@@ -63,6 +63,8 @@ func (mc *mysqlConn) readPacket() (data []byte, err error) {
|
|
|
return nil, driver.ErrBadConn
|
|
|
}
|
|
|
|
|
|
+// Write packet buffer 'data'
|
|
|
+// The packet header must be already included
|
|
|
func (mc *mysqlConn) writePacket(data []byte) error {
|
|
|
// Write packet
|
|
|
n, err := mc.netConn.Write(data)
|
|
|
@@ -137,17 +139,8 @@ func (mc *mysqlConn) readInitPacket() (err error) {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
-/* Client Authentication Packet
|
|
|
-Bytes Name
|
|
|
------ ----
|
|
|
-4 client_flags
|
|
|
-4 max_packet_size
|
|
|
-1 charset_number
|
|
|
-23 (filler) always 0x00...
|
|
|
-n (Null-Terminated String) user
|
|
|
-n (Length Coded Binary) scramble_buff (1 + x bytes)
|
|
|
-n (Null-Terminated String) databasename (optional)
|
|
|
-*/
|
|
|
+// Client Authentication Packet
|
|
|
+// http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::HandshakeResponse
|
|
|
func (mc *mysqlConn) writeAuthPacket() error {
|
|
|
// Adjust client flags based on server support
|
|
|
clientFlags := uint32(
|
|
|
@@ -162,6 +155,7 @@ func (mc *mysqlConn) writeAuthPacket() error {
|
|
|
|
|
|
// User Password
|
|
|
scrambleBuff := scramblePassword(mc.scrambleBuff, []byte(mc.cfg.passwd))
|
|
|
+ mc.scrambleBuff = nil
|
|
|
|
|
|
pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.user) + 1 + 1 + len(scrambleBuff)
|
|
|
|
|
|
@@ -172,43 +166,47 @@ func (mc *mysqlConn) writeAuthPacket() error {
|
|
|
}
|
|
|
|
|
|
// Calculate packet length and make buffer with that size
|
|
|
- data := make([]byte, 0, pktLen+4)
|
|
|
+ data := make([]byte, pktLen+4)
|
|
|
|
|
|
- // Add the packet header
|
|
|
- data = append(data, uint24ToBytes(uint32(pktLen))...)
|
|
|
- data = append(data, mc.sequence)
|
|
|
+ // Add the packet header [24bit length + 1 byte sequence]
|
|
|
+ data[0] = byte(pktLen)
|
|
|
+ data[1] = byte(pktLen >> 8)
|
|
|
+ data[2] = byte(pktLen >> 16)
|
|
|
+ data[3] = mc.sequence
|
|
|
|
|
|
- // ClientFlags
|
|
|
- data = append(data, uint32ToBytes(clientFlags)...)
|
|
|
+ // ClientFlags [32 bit]
|
|
|
+ data[4] = byte(clientFlags)
|
|
|
+ data[5] = byte(clientFlags >> 8)
|
|
|
+ data[6] = byte(clientFlags >> 16)
|
|
|
+ data[7] = byte(clientFlags >> 24)
|
|
|
|
|
|
- // MaxPacketSize
|
|
|
- data = append(data, uint32ToBytes(MAX_PACKET_SIZE)...)
|
|
|
+ // MaxPacketSize [32 bit] (1<<24 - 1)
|
|
|
+ data[8] = 0xff
|
|
|
+ data[9] = 0xff
|
|
|
+ data[10] = 0xff
|
|
|
+ //data[11] = 0x00
|
|
|
|
|
|
- // Charset
|
|
|
- data = append(data, mc.charset)
|
|
|
+ // Charset [1 byte]
|
|
|
+ data[12] = mc.charset
|
|
|
|
|
|
- // Filler
|
|
|
- data = append(data, make([]byte, 23)...)
|
|
|
+ // Filler [23 byte] (all 0x00)
|
|
|
+ pos := 13 + 23
|
|
|
|
|
|
- // User
|
|
|
+ // User [null terminated string]
|
|
|
if len(mc.cfg.user) > 0 {
|
|
|
- data = append(data, []byte(mc.cfg.user)...)
|
|
|
+ pos += copy(data[pos:], mc.cfg.user)
|
|
|
}
|
|
|
+ //data[pos] = 0x00
|
|
|
+ pos++
|
|
|
|
|
|
- // Null-Terminator
|
|
|
- data = append(data, 0x0)
|
|
|
-
|
|
|
- // ScrambleBuffer
|
|
|
- data = append(data, byte(len(scrambleBuff)))
|
|
|
- if len(scrambleBuff) > 0 {
|
|
|
- data = append(data, scrambleBuff...)
|
|
|
- }
|
|
|
+ // ScrambleBuffer [length encoded integer]
|
|
|
+ data[pos] = byte(len(scrambleBuff))
|
|
|
+ pos += 1 + copy(data[pos+1:], scrambleBuff)
|
|
|
|
|
|
- // Databasename
|
|
|
+ // Databasename [null terminated string]
|
|
|
if len(mc.cfg.dbname) > 0 {
|
|
|
- data = append(data, []byte(mc.cfg.dbname)...)
|
|
|
- // Null-Terminator
|
|
|
- data = append(data, 0x0)
|
|
|
+ pos += copy(data[pos:], mc.cfg.dbname)
|
|
|
+ //data[pos] = 0x00
|
|
|
}
|
|
|
|
|
|
// Send Auth packet
|
|
|
@@ -219,62 +217,69 @@ func (mc *mysqlConn) writeAuthPacket() error {
|
|
|
* Command Packets *
|
|
|
******************************************************************************/
|
|
|
|
|
|
-/* Command Packet
|
|
|
-Bytes Name
|
|
|
------ ----
|
|
|
-1 command
|
|
|
-n arg
|
|
|
-*/
|
|
|
-func (mc *mysqlConn) writeCommandPacket(command commandType, args ...interface{}) error {
|
|
|
+func (mc *mysqlConn) writeCommandPacket(command commandType) error {
|
|
|
// Reset Packet Sequence
|
|
|
mc.sequence = 0
|
|
|
|
|
|
- var arg []byte
|
|
|
-
|
|
|
- switch command {
|
|
|
-
|
|
|
- // Commands without args
|
|
|
- case COM_QUIT, COM_PING:
|
|
|
- if len(args) > 0 {
|
|
|
- return fmt.Errorf("Too much arguments (Got: %d Has: 0)", len(args))
|
|
|
- }
|
|
|
- arg = []byte{}
|
|
|
-
|
|
|
- // Commands with 1 arg unterminated string
|
|
|
- case COM_QUERY, COM_STMT_PREPARE:
|
|
|
- if len(args) != 1 {
|
|
|
- return fmt.Errorf("Invalid arguments count (Got: %d Has: 1)", len(args))
|
|
|
- }
|
|
|
- arg = []byte(args[0].(string))
|
|
|
-
|
|
|
- // Commands with 1 arg 32 bit uint
|
|
|
- case COM_STMT_CLOSE:
|
|
|
- if len(args) != 1 {
|
|
|
- return fmt.Errorf("Invalid arguments count (Got: %d Has: 1)", len(args))
|
|
|
- }
|
|
|
- arg = uint32ToBytes(args[0].(uint32))
|
|
|
+ // Send CMD packet
|
|
|
+ return mc.writePacket([]byte{
|
|
|
+ // Add the packet header [24bit length + 1 byte sequence]
|
|
|
+ 0x05, // 5 bytes long
|
|
|
+ 0x00,
|
|
|
+ 0x00,
|
|
|
+ mc.sequence,
|
|
|
+
|
|
|
+ // Add command byte
|
|
|
+ byte(command),
|
|
|
+ })
|
|
|
+}
|
|
|
|
|
|
- default:
|
|
|
- return fmt.Errorf("Unknown command: %d", command)
|
|
|
- }
|
|
|
+func (mc *mysqlConn) writeCommandPacketStr(command commandType, arg string) error {
|
|
|
+ // Reset Packet Sequence
|
|
|
+ mc.sequence = 0
|
|
|
|
|
|
pktLen := 1 + len(arg)
|
|
|
- data := make([]byte, 0, pktLen+4)
|
|
|
+ data := make([]byte, pktLen+4)
|
|
|
|
|
|
- // Add the packet header
|
|
|
- data = append(data, uint24ToBytes(uint32(pktLen))...)
|
|
|
- data = append(data, mc.sequence)
|
|
|
+ // Add the packet header [24bit length + 1 byte sequence]
|
|
|
+ data[0] = byte(pktLen)
|
|
|
+ data[1] = byte(pktLen >> 8)
|
|
|
+ data[2] = byte(pktLen >> 16)
|
|
|
+ data[3] = mc.sequence
|
|
|
|
|
|
// Add command byte
|
|
|
- data = append(data, byte(command))
|
|
|
+ data[4] = byte(command)
|
|
|
|
|
|
// Add arg
|
|
|
- data = append(data, arg...)
|
|
|
+ copy(data[5:], arg)
|
|
|
|
|
|
// Send CMD packet
|
|
|
return mc.writePacket(data)
|
|
|
}
|
|
|
|
|
|
+func (mc *mysqlConn) writeCommandPacketUint32(command commandType, arg uint32) error {
|
|
|
+ // Reset Packet Sequence
|
|
|
+ mc.sequence = 0
|
|
|
+
|
|
|
+ // Send CMD packet
|
|
|
+ return mc.writePacket([]byte{
|
|
|
+ // Add the packet header [24bit length + 1 byte sequence]
|
|
|
+ 0x05, // 5 bytes long
|
|
|
+ 0x00,
|
|
|
+ 0x00,
|
|
|
+ mc.sequence,
|
|
|
+
|
|
|
+ // Add command byte
|
|
|
+ byte(command),
|
|
|
+
|
|
|
+ // Add arg [32 bit]
|
|
|
+ byte(arg),
|
|
|
+ byte(arg >> 8),
|
|
|
+ byte(arg >> 16),
|
|
|
+ byte(arg >> 24),
|
|
|
+ })
|
|
|
+}
|
|
|
+
|
|
|
/******************************************************************************
|
|
|
* Result Packets *
|
|
|
******************************************************************************/
|
|
|
@@ -289,28 +294,49 @@ func (mc *mysqlConn) readResultOK() error {
|
|
|
switch data[0] {
|
|
|
// OK
|
|
|
case 0:
|
|
|
- return mc.handleOkPacket(data)
|
|
|
+ mc.handleOkPacket(data)
|
|
|
+ return nil
|
|
|
// EOF, someone is using old_passwords
|
|
|
case 254:
|
|
|
return errOldPassword
|
|
|
+ }
|
|
|
// ERROR
|
|
|
- case 255:
|
|
|
- return mc.handleErrorPacket(data)
|
|
|
+ return mc.handleErrorPacket(data)
|
|
|
+}
|
|
|
+
|
|
|
+// Result Set Header Packet
|
|
|
+// http://dev.mysql.com/doc/internals/en/text-protocol.html#packet-ProtocolText::Resultset
|
|
|
+func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
|
|
|
+ data, err := mc.readPacket()
|
|
|
+ if err != nil {
|
|
|
+ return 0, err
|
|
|
+ }
|
|
|
+
|
|
|
+ if data[0] == 0 {
|
|
|
+ mc.handleOkPacket(data)
|
|
|
+ return 0, nil
|
|
|
+ } else if data[0] == 255 {
|
|
|
+ return 0, mc.handleErrorPacket(data)
|
|
|
}
|
|
|
|
|
|
- return errMalformPkt
|
|
|
+ // column count
|
|
|
+ num, _, n := readLengthEncodedInteger(data)
|
|
|
+ if n-len(data) == 0 {
|
|
|
+ return int(num), nil
|
|
|
+ }
|
|
|
+
|
|
|
+ return 0, errMalformPkt
|
|
|
}
|
|
|
|
|
|
-/* Error Packet
|
|
|
-Bytes Name
|
|
|
------ ----
|
|
|
-1 field_count, always = 0xff
|
|
|
-2 errno
|
|
|
-1 (sqlstate marker), always '#'
|
|
|
-5 sqlstate (5 characters)
|
|
|
-n message
|
|
|
-*/
|
|
|
+// Error Packet
|
|
|
+// http://dev.mysql.com/doc/internals/en/overview.html#packet-ERR_Packet
|
|
|
func (mc *mysqlConn) handleErrorPacket(data []byte) error {
|
|
|
+ if data[0] != 255 {
|
|
|
+ return errMalformPkt
|
|
|
+ }
|
|
|
+
|
|
|
+ // 0xff [1 byte]
|
|
|
+
|
|
|
// Error Number [16 bit uint]
|
|
|
errno := binary.LittleEndian.Uint16(data[1:3])
|
|
|
|
|
|
@@ -321,80 +347,33 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error {
|
|
|
return fmt.Errorf("Error %d: %s", errno, string(data[9:]))
|
|
|
}
|
|
|
|
|
|
-/* Ok Packet
|
|
|
-Bytes Name
|
|
|
------ ----
|
|
|
-1 (Length Coded Binary) field_count, always = 0
|
|
|
-1-9 (Length Coded Binary) affected_rows
|
|
|
-1-9 (Length Coded Binary) insert_id
|
|
|
-2 server_status
|
|
|
-2 warning_count
|
|
|
-n (until end of packet) message
|
|
|
-*/
|
|
|
-func (mc *mysqlConn) handleOkPacket(data []byte) (err error) {
|
|
|
+// Ok Packet
|
|
|
+// http://dev.mysql.com/doc/internals/en/overview.html#packet-OK_Packet
|
|
|
+func (mc *mysqlConn) handleOkPacket(data []byte) {
|
|
|
var n int
|
|
|
|
|
|
+ // 0x00 [1 byte]
|
|
|
+
|
|
|
// Affected rows [Length Coded Binary]
|
|
|
- mc.affectedRows, _, n, err = readLengthEncodedInteger(data[1:])
|
|
|
- if err != nil {
|
|
|
- return
|
|
|
- }
|
|
|
+ mc.affectedRows, _, n = readLengthEncodedInteger(data[1:])
|
|
|
|
|
|
// Insert id [Length Coded Binary]
|
|
|
- mc.insertId, _, _, err = readLengthEncodedInteger(data[1+n:])
|
|
|
- if err != nil {
|
|
|
- return
|
|
|
- }
|
|
|
+ mc.insertId, _, _ = readLengthEncodedInteger(data[1+n:])
|
|
|
|
|
|
- // Skip remaining data
|
|
|
- return
|
|
|
-}
|
|
|
-
|
|
|
-/* Result Set Header Packet
|
|
|
- Bytes Name
|
|
|
- ----- ----
|
|
|
- 1-9 (Length-Coded-Binary) field_count
|
|
|
- 1-9 (Length-Coded-Binary) extra
|
|
|
-
|
|
|
-The order of packets for a result set is:
|
|
|
- (Result Set Header Packet) the number of columns
|
|
|
- (Field Packets) column descriptors
|
|
|
- (EOF Packet) marker: end of Field Packets
|
|
|
- (Row Data Packets) row contents
|
|
|
- (EOF Packet) marker: end of Data Packets
|
|
|
-*/
|
|
|
-func (mc *mysqlConn) readResultSetHeaderPacket() (fieldCount int, err error) {
|
|
|
- data, err := mc.readPacket()
|
|
|
- if err != nil {
|
|
|
- errLog.Print(err)
|
|
|
- err = driver.ErrBadConn
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- if data[0] == 0 {
|
|
|
- err = mc.handleOkPacket(data)
|
|
|
- return
|
|
|
- } else if data[0] == 255 {
|
|
|
- err = mc.handleErrorPacket(data)
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- num, _, n, err := readLengthEncodedInteger(data)
|
|
|
- if err != nil || (n-len(data)) != 0 {
|
|
|
- err = errors.New("Malformed Packet")
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- fieldCount = int(num)
|
|
|
- return
|
|
|
+ // server_status [2 bytes]
|
|
|
+ // warning count [2 bytes]
|
|
|
+ // message [until end of packet]
|
|
|
}
|
|
|
|
|
|
// Read Packets as Field Packets until EOF-Packet or an Error appears
|
|
|
+// http://dev.mysql.com/doc/internals/en/text-protocol.html#packet-Protocol::ColumnDefinition41
|
|
|
func (mc *mysqlConn) readColumns(count int) (columns []mysqlField, err error) {
|
|
|
var data []byte
|
|
|
- var pos, n int
|
|
|
+ var i, pos, n int
|
|
|
var name []byte
|
|
|
|
|
|
+ columns = make([]mysqlField, count)
|
|
|
+
|
|
|
for {
|
|
|
data, err = mc.readPacket()
|
|
|
if err != nil {
|
|
|
@@ -403,34 +382,34 @@ func (mc *mysqlConn) readColumns(count int) (columns []mysqlField, err error) {
|
|
|
|
|
|
// EOF Packet
|
|
|
if data[0] == 254 && len(data) == 5 {
|
|
|
- if len(columns) != count {
|
|
|
+ if i != count {
|
|
|
err = fmt.Errorf("ColumnsCount mismatch n:%d len:%d", count, len(columns))
|
|
|
}
|
|
|
return
|
|
|
}
|
|
|
|
|
|
// Catalog
|
|
|
- pos, err = readAndDropLengthEnodedString(data)
|
|
|
+ pos, err = skipLengthEnodedString(data)
|
|
|
if err != nil {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
// Database [len coded string]
|
|
|
- n, err = readAndDropLengthEnodedString(data[pos:])
|
|
|
+ n, err = skipLengthEnodedString(data[pos:])
|
|
|
if err != nil {
|
|
|
return
|
|
|
}
|
|
|
pos += n
|
|
|
|
|
|
// Table [len coded string]
|
|
|
- n, err = readAndDropLengthEnodedString(data[pos:])
|
|
|
+ n, err = skipLengthEnodedString(data[pos:])
|
|
|
if err != nil {
|
|
|
return
|
|
|
}
|
|
|
pos += n
|
|
|
|
|
|
// Original table [len coded string]
|
|
|
- n, err = readAndDropLengthEnodedString(data[pos:])
|
|
|
+ n, err = skipLengthEnodedString(data[pos:])
|
|
|
if err != nil {
|
|
|
return
|
|
|
}
|
|
|
@@ -441,10 +420,11 @@ func (mc *mysqlConn) readColumns(count int) (columns []mysqlField, err error) {
|
|
|
if err != nil {
|
|
|
return
|
|
|
}
|
|
|
+ columns[i].name = string(name)
|
|
|
pos += n
|
|
|
|
|
|
// Original name [len coded string]
|
|
|
- n, err = readAndDropLengthEnodedString(data[pos:])
|
|
|
+ n, err = skipLengthEnodedString(data[pos:])
|
|
|
if err != nil {
|
|
|
return
|
|
|
}
|
|
|
@@ -455,11 +435,11 @@ func (mc *mysqlConn) readColumns(count int) (columns []mysqlField, err error) {
|
|
|
pos += n + 1 + 2 + 4
|
|
|
|
|
|
// Field type [byte]
|
|
|
- fieldType := FieldType(data[pos])
|
|
|
+ columns[i].fieldType = FieldType(data[pos])
|
|
|
pos++
|
|
|
|
|
|
// Flags [16 bit uint]
|
|
|
- flags := FieldFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
|
|
|
+ columns[i].flags = FieldFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
|
|
|
//pos += 2
|
|
|
|
|
|
// Decimals [8 bit uint]
|
|
|
@@ -470,13 +450,14 @@ func (mc *mysqlConn) readColumns(count int) (columns []mysqlField, err error) {
|
|
|
// defaultVal, _, err = bytesToLengthCodedBinary(data[pos:])
|
|
|
//}
|
|
|
|
|
|
- columns = append(columns, mysqlField{name: string(name), fieldType: fieldType, flags: flags})
|
|
|
+ i++
|
|
|
}
|
|
|
|
|
|
return
|
|
|
}
|
|
|
|
|
|
// Read Packets as Field Packets until EOF-Packet or an Error appears
|
|
|
+// http://dev.mysql.com/doc/internals/en/text-protocol.html#packet-ProtocolText::ResultsetRow
|
|
|
func (rows *mysqlRows) readRow(dest []driver.Value) (err error) {
|
|
|
data, err := rows.mc.readPacket()
|
|
|
if err != nil {
|
|
|
@@ -526,34 +507,8 @@ func (mc *mysqlConn) readUntilEOF() (count uint64, err error) {
|
|
|
* Prepared Statements *
|
|
|
******************************************************************************/
|
|
|
|
|
|
-/* Prepare Result Packets
|
|
|
- Type Of Result Packet Hexadecimal Value Of First Byte (field_count)
|
|
|
- --------------------- ---------------------------------------------
|
|
|
-
|
|
|
- Prepare OK Packet 00
|
|
|
- Error Packet ff
|
|
|
-
|
|
|
-Prepare OK Packet
|
|
|
- Bytes Name
|
|
|
- ----- ----
|
|
|
- 1 0 - marker for OK packet
|
|
|
- 4 statement_handler_id
|
|
|
- 2 number of columns in result set
|
|
|
- 2 number of parameters in query
|
|
|
- 1 filler (always 0)
|
|
|
- 2 warning count
|
|
|
-
|
|
|
- It is made up of:
|
|
|
-
|
|
|
- a PREPARE_OK packet
|
|
|
- if "number of parameters" > 0
|
|
|
- (field packets) as in a Result Set Header Packet
|
|
|
- (EOF packet)
|
|
|
- if "number of columns" > 0
|
|
|
- (field packets) as in a Result Set Header Packet
|
|
|
- (EOF packet)
|
|
|
-
|
|
|
-*/
|
|
|
+// Prepare Result Packets
|
|
|
+// http://dev.mysql.com/doc/internals/en/prepared-statements.html#com-stmt-prepare-response
|
|
|
func (stmt *mysqlStmt) readPrepareResultPacket() (columnCount uint16, err error) {
|
|
|
data, err := stmt.mc.readPacket()
|
|
|
if err != nil {
|
|
|
@@ -563,12 +518,14 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (columnCount uint16, err error)
|
|
|
// Position
|
|
|
pos := 0
|
|
|
|
|
|
- if data[pos] != 0 {
|
|
|
+ // packet marker [1 byte]
|
|
|
+ if data[pos] != 0 { // not OK (0) ?
|
|
|
err = stmt.mc.handleErrorPacket(data)
|
|
|
return
|
|
|
}
|
|
|
pos++
|
|
|
|
|
|
+ // statement id [4 bytes]
|
|
|
stmt.id = binary.LittleEndian.Uint32(data[pos : pos+4])
|
|
|
pos += 4
|
|
|
|
|
|
@@ -586,103 +543,85 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (columnCount uint16, err error)
|
|
|
return
|
|
|
}
|
|
|
|
|
|
-/* Command Packet
|
|
|
-Bytes Name
|
|
|
------ ----
|
|
|
-1 code
|
|
|
-4 statement_id
|
|
|
-1 flags
|
|
|
-4 iteration_count
|
|
|
- if param_count > 0:
|
|
|
-(param_count+7)/8 null_bit_map
|
|
|
-1 new_parameter_bound_flag
|
|
|
- if new_params_bound == 1:
|
|
|
-n*2 type of parameters
|
|
|
-n values for the parameters
|
|
|
-*/
|
|
|
-func (stmt *mysqlStmt) buildExecutePacket(args []driver.Value) error {
|
|
|
- argsLen := len(args)
|
|
|
- if argsLen != stmt.paramCount {
|
|
|
+// Execute Prepared Statement
|
|
|
+// http://dev.mysql.com/doc/internals/en/prepared-statements.html#com-stmt-execute
|
|
|
+func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
|
|
+ if len(args) != stmt.paramCount {
|
|
|
return fmt.Errorf(
|
|
|
"Arguments count mismatch (Got: %d Has: %d",
|
|
|
- argsLen,
|
|
|
+ len(args),
|
|
|
stmt.paramCount)
|
|
|
}
|
|
|
|
|
|
// Reset packet-sequence
|
|
|
stmt.mc.sequence = 0
|
|
|
|
|
|
- pktLen := 1 + 4 + 1 + 4 + ((stmt.paramCount + 7) >> 3) + 1 + (argsLen << 1)
|
|
|
- paramValues := make([][]byte, 0, argsLen)
|
|
|
- paramTypes := make([]byte, 0, (argsLen << 1))
|
|
|
+ pktLen := 1 + 4 + 1 + 4 + ((stmt.paramCount + 7) >> 3) + 1 + (stmt.paramCount << 1)
|
|
|
+ paramValues := make([][]byte, stmt.paramCount)
|
|
|
+ paramTypes := make([]byte, (stmt.paramCount << 1))
|
|
|
bitMask := uint64(0)
|
|
|
- var i, valLen int
|
|
|
+ var i int
|
|
|
+
|
|
|
for i = range args {
|
|
|
- // build nullBitMap
|
|
|
+ // build NULL-bitmap
|
|
|
if args[i] == nil {
|
|
|
bitMask += 1 << uint(i)
|
|
|
+ paramTypes[i<<1] = byte(FIELD_TYPE_NULL)
|
|
|
+ continue
|
|
|
}
|
|
|
|
|
|
// cache types and values
|
|
|
switch args[i].(type) {
|
|
|
- case nil:
|
|
|
- paramTypes = append(paramTypes, []byte{
|
|
|
- byte(FIELD_TYPE_NULL),
|
|
|
- 0x0}...)
|
|
|
- continue
|
|
|
-
|
|
|
case int64:
|
|
|
- paramTypes = append(paramTypes, []byte{byte(FIELD_TYPE_LONGLONG), 0x0}...)
|
|
|
- val := uint64ToBytes(uint64(args[i].(int64)))
|
|
|
- pktLen += len(val)
|
|
|
- paramValues = append(paramValues, val)
|
|
|
+ paramTypes[i<<1] = byte(FIELD_TYPE_LONGLONG)
|
|
|
+ paramValues[i] = uint64ToBytes(uint64(args[i].(int64)))
|
|
|
+ pktLen += 8
|
|
|
continue
|
|
|
|
|
|
case float64:
|
|
|
- paramTypes = append(paramTypes, []byte{byte(FIELD_TYPE_DOUBLE), 0x0}...)
|
|
|
- val := uint64ToBytes(math.Float64bits(args[i].(float64)))
|
|
|
- pktLen += len(val)
|
|
|
- paramValues = append(paramValues, val)
|
|
|
+ paramTypes[i<<1] = byte(FIELD_TYPE_DOUBLE)
|
|
|
+ paramValues[i] = uint64ToBytes(math.Float64bits(args[i].(float64)))
|
|
|
+ pktLen += 8
|
|
|
continue
|
|
|
|
|
|
case bool:
|
|
|
- paramTypes = append(paramTypes, []byte{byte(FIELD_TYPE_TINY), 0x0}...)
|
|
|
+ paramTypes[i<<1] = byte(FIELD_TYPE_TINY)
|
|
|
pktLen++
|
|
|
if args[i].(bool) {
|
|
|
- paramValues = append(paramValues, []byte{byte(1)})
|
|
|
+ paramValues[i] = []byte{0x01}
|
|
|
} else {
|
|
|
- paramValues = append(paramValues, []byte{byte(0)})
|
|
|
+ paramValues[i] = []byte{0x00}
|
|
|
}
|
|
|
continue
|
|
|
|
|
|
case []byte:
|
|
|
- paramTypes = append(paramTypes, []byte{byte(FIELD_TYPE_STRING), 0x0}...)
|
|
|
+ paramTypes[i<<1] = byte(FIELD_TYPE_STRING)
|
|
|
val := args[i].([]byte)
|
|
|
- valLen = len(val)
|
|
|
- lcb := lengthEncodedIntegerToBytes(uint64(valLen))
|
|
|
- pktLen += len(lcb) + valLen
|
|
|
- paramValues = append(paramValues, lcb)
|
|
|
- paramValues = append(paramValues, val)
|
|
|
+ paramValues[i] = append(
|
|
|
+ lengthEncodedIntegerToBytes(uint64(len(val))),
|
|
|
+ val...,
|
|
|
+ )
|
|
|
+ pktLen += len(paramValues[i])
|
|
|
continue
|
|
|
|
|
|
case string:
|
|
|
- paramTypes = append(paramTypes, []byte{byte(FIELD_TYPE_STRING), 0x0}...)
|
|
|
+ paramTypes[i<<1] = byte(FIELD_TYPE_STRING)
|
|
|
val := []byte(args[i].(string))
|
|
|
- valLen = len(val)
|
|
|
- lcb := lengthEncodedIntegerToBytes(uint64(valLen))
|
|
|
- pktLen += valLen + len(lcb)
|
|
|
- paramValues = append(paramValues, lcb)
|
|
|
- paramValues = append(paramValues, val)
|
|
|
+ paramValues[i] = append(
|
|
|
+ lengthEncodedIntegerToBytes(uint64(len(val))),
|
|
|
+ val...,
|
|
|
+ )
|
|
|
+ pktLen += len(paramValues[i])
|
|
|
continue
|
|
|
|
|
|
case time.Time:
|
|
|
- paramTypes = append(paramTypes, []byte{byte(FIELD_TYPE_STRING), 0x0}...)
|
|
|
+ paramTypes[i<<1] = byte(FIELD_TYPE_STRING)
|
|
|
val := []byte(args[i].(time.Time).Format(TIME_FORMAT))
|
|
|
- valLen = len(val)
|
|
|
- lcb := lengthEncodedIntegerToBytes(uint64(valLen))
|
|
|
- pktLen += valLen + len(lcb)
|
|
|
- paramValues = append(paramValues, lcb)
|
|
|
- paramValues = append(paramValues, val)
|
|
|
+ paramValues[i] = append(
|
|
|
+ lengthEncodedIntegerToBytes(uint64(len(val))),
|
|
|
+ val...,
|
|
|
+ )
|
|
|
+ pktLen += len(paramValues[i])
|
|
|
continue
|
|
|
|
|
|
default:
|
|
|
@@ -690,44 +629,51 @@ func (stmt *mysqlStmt) buildExecutePacket(args []driver.Value) error {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- data := make([]byte, 0, pktLen+4)
|
|
|
+ data := make([]byte, pktLen+4)
|
|
|
|
|
|
- // Add the packet header
|
|
|
- data = append(data, uint24ToBytes(uint32(pktLen))...)
|
|
|
- data = append(data, stmt.mc.sequence)
|
|
|
+ // packet header [4 bytes]
|
|
|
+ data[0] = byte(pktLen)
|
|
|
+ data[1] = byte(pktLen >> 8)
|
|
|
+ data[2] = byte(pktLen >> 16)
|
|
|
+ data[3] = stmt.mc.sequence
|
|
|
|
|
|
- // code [1 byte]
|
|
|
- data = append(data, byte(COM_STMT_EXECUTE))
|
|
|
+ // command [1 byte]
|
|
|
+ data[4] = byte(COM_STMT_EXECUTE)
|
|
|
|
|
|
// statement_id [4 bytes]
|
|
|
- data = append(data, uint32ToBytes(stmt.id)...)
|
|
|
+ data[5] = byte(stmt.id)
|
|
|
+ data[6] = byte(stmt.id >> 8)
|
|
|
+ data[7] = byte(stmt.id >> 16)
|
|
|
+ data[8] = byte(stmt.id >> 24)
|
|
|
|
|
|
// flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte]
|
|
|
- data = append(data, byte(0))
|
|
|
+ //data[9] = 0x00
|
|
|
|
|
|
- // iteration_count [4 bytes]
|
|
|
- data = append(data, uint32ToBytes(1)...)
|
|
|
+ // iteration_count (uint32(1)) [4 bytes]
|
|
|
+ data[10] = 0x01
|
|
|
+ //data[11] = 0x00
|
|
|
+ //data[12] = 0x00
|
|
|
+ //data[13] = 0x00
|
|
|
|
|
|
- // append nullBitMap [(param_count+7)/8 bytes]
|
|
|
if stmt.paramCount > 0 {
|
|
|
+ // NULL-bitmap [(param_count+7)/8 bytes]
|
|
|
+ pos := 14 + ((stmt.paramCount + 7) >> 3)
|
|
|
// Convert bitMask to bytes
|
|
|
- nullBitMap := make([]byte, (stmt.paramCount+7)/8)
|
|
|
- for i = 0; i < len(nullBitMap); i++ {
|
|
|
- nullBitMap[i] = byte(bitMask >> uint(i*8))
|
|
|
+ for i = 14; i < pos; i++ {
|
|
|
+ data[i] = byte(bitMask >> uint(i<<3))
|
|
|
}
|
|
|
|
|
|
- data = append(data, nullBitMap...)
|
|
|
- }
|
|
|
-
|
|
|
- // newParameterBoundFlag 1 [1 byte]
|
|
|
- data = append(data, byte(1))
|
|
|
+ // newParameterBoundFlag 1 [1 byte]
|
|
|
+ data[pos] = 0x01
|
|
|
+ pos++
|
|
|
|
|
|
- // type of parameters [n*2 byte]
|
|
|
- data = append(data, paramTypes...)
|
|
|
+ // type of parameters [param_count*2 byte]
|
|
|
+ pos += copy(data[pos:], paramTypes)
|
|
|
|
|
|
- // values for the parameters [n byte]
|
|
|
- for _, paramValue := range paramValues {
|
|
|
- data = append(data, paramValue...)
|
|
|
+ // values for the parameters [n byte]
|
|
|
+ for i = range paramValues {
|
|
|
+ pos += copy(data[pos:], paramValues[i])
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
return stmt.mc.writePacket(data)
|
|
|
@@ -740,20 +686,28 @@ func (rc *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- // EOF Packet
|
|
|
- if data[0] == 254 && len(data) == 5 {
|
|
|
- return io.EOF
|
|
|
+ // packet header [1 byte]
|
|
|
+ if data[0] != 0x00 {
|
|
|
+ // EOF Packet
|
|
|
+ if data[0] == 254 && len(data) == 5 {
|
|
|
+ return io.EOF
|
|
|
+ } else {
|
|
|
+ // Error otherwise
|
|
|
+ return rc.mc.handleErrorPacket(data)
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
- // BinaryRowSet Packet
|
|
|
+ // NULL-bitmap, [(column-count + 7 + 2) / 8 bytes]
|
|
|
pos := 1 + (len(dest)+7+2)>>3
|
|
|
nullBitMap := data[1:pos]
|
|
|
|
|
|
+ // values [rest]
|
|
|
var n int
|
|
|
var unsigned bool
|
|
|
for i := range dest {
|
|
|
// Field is NULL
|
|
|
- if (nullBitMap[(i+2)>>3] >> uint((i+2)&7) & 1) == 1 {
|
|
|
+ // (byte >> bit-pos) % 2 == 1
|
|
|
+ if ((nullBitMap[(i+2)>>3] >> uint((i+2)&7)) & 1) == 1 {
|
|
|
dest[i] = nil
|
|
|
continue
|
|
|
}
|
|
|
@@ -830,19 +784,16 @@ func (rc *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
|
|
|
case FIELD_TYPE_DATE, FIELD_TYPE_NEWDATE:
|
|
|
var num uint64
|
|
|
var isNull bool
|
|
|
- num, isNull, n, err = readLengthEncodedInteger(data[pos:])
|
|
|
- if err != nil {
|
|
|
- return
|
|
|
- }
|
|
|
+ num, isNull, n = readLengthEncodedInteger(data[pos:])
|
|
|
|
|
|
if num == 0 {
|
|
|
if isNull {
|
|
|
dest[i] = nil
|
|
|
- pos++ // n = 1
|
|
|
+ pos++ // always n=1
|
|
|
continue
|
|
|
} else {
|
|
|
dest[i] = []byte("0000-00-00")
|
|
|
- pos++ // n = 1
|
|
|
+ pos += n
|
|
|
continue
|
|
|
}
|
|
|
} else {
|
|
|
@@ -858,19 +809,16 @@ func (rc *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
|
|
|
case FIELD_TYPE_TIME:
|
|
|
var num uint64
|
|
|
var isNull bool
|
|
|
- num, isNull, n, err = readLengthEncodedInteger(data[pos:])
|
|
|
- if err != nil {
|
|
|
- return
|
|
|
- }
|
|
|
+ num, isNull, n = readLengthEncodedInteger(data[pos:])
|
|
|
|
|
|
if num == 0 {
|
|
|
if isNull {
|
|
|
dest[i] = nil
|
|
|
- pos++ // n = 1
|
|
|
+ pos++ // always n=1
|
|
|
continue
|
|
|
} else {
|
|
|
dest[i] = []byte("00:00:00")
|
|
|
- pos++ // n = 1
|
|
|
+ pos += n
|
|
|
continue
|
|
|
}
|
|
|
}
|
|
|
@@ -912,19 +860,16 @@ func (rc *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
|
|
|
case FIELD_TYPE_TIMESTAMP, FIELD_TYPE_DATETIME:
|
|
|
var num uint64
|
|
|
var isNull bool
|
|
|
- num, isNull, n, err = readLengthEncodedInteger(data[pos:])
|
|
|
- if err != nil {
|
|
|
- return
|
|
|
- }
|
|
|
+ num, isNull, n = readLengthEncodedInteger(data[pos:])
|
|
|
|
|
|
if num == 0 {
|
|
|
if isNull {
|
|
|
dest[i] = nil
|
|
|
- pos++ // n = 1
|
|
|
+ pos++ // always n=1
|
|
|
continue
|
|
|
} else {
|
|
|
dest[i] = []byte("0000-00-00 00:00:00")
|
|
|
- pos++ // n = 1
|
|
|
+ pos += n
|
|
|
continue
|
|
|
}
|
|
|
}
|
|
|
@@ -939,7 +884,7 @@ func (rc *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
|
|
|
data[pos+2],
|
|
|
data[pos+3],
|
|
|
))
|
|
|
- pos += 5
|
|
|
+ pos += 4
|
|
|
continue
|
|
|
case 7:
|
|
|
dest[i] = []byte(fmt.Sprintf(
|