|
|
@@ -11,6 +11,7 @@ package mysql
|
|
|
|
|
|
import (
|
|
|
"database/sql/driver"
|
|
|
+ "encoding/binary"
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
@@ -22,26 +23,28 @@ import (
|
|
|
// http://dev.mysql.com/doc/internals/en/client-server-protocol.html
|
|
|
|
|
|
// Read packet to buffer 'data'
|
|
|
-func (mc *mysqlConn) readPacket() ([]byte, error) {
|
|
|
- // Packet Length
|
|
|
- pktLen, err := mc.readNumber(3)
|
|
|
- if err != nil {
|
|
|
- return nil, err
|
|
|
+func (mc *mysqlConn) readPacket() (data []byte, err error) {
|
|
|
+ // Read header
|
|
|
+ data = make([]byte, 4)
|
|
|
+ var n, add int
|
|
|
+ for err == nil && n < 4 {
|
|
|
+ add, err = mc.bufReader.Read(data[n:])
|
|
|
+ n += add
|
|
|
}
|
|
|
|
|
|
- if int(pktLen) == 0 {
|
|
|
- return nil, err
|
|
|
- }
|
|
|
+ // Packet Length
|
|
|
+ var pktLen uint32
|
|
|
+ pktLen |= uint32(data[0])
|
|
|
+ pktLen |= uint32(data[1]) << 8
|
|
|
+ pktLen |= uint32(data[2]) << 16
|
|
|
|
|
|
- // Packet Number
|
|
|
- pktSeq, err := mc.readNumber(1)
|
|
|
- if err != nil {
|
|
|
+ if pktLen == 0 {
|
|
|
return nil, err
|
|
|
}
|
|
|
|
|
|
// Check Packet Sync
|
|
|
- if uint8(pktSeq) != mc.sequence {
|
|
|
- if uint8(pktSeq) > mc.sequence {
|
|
|
+ if data[3] != mc.sequence {
|
|
|
+ if data[3] > mc.sequence {
|
|
|
err = errors.New("Commands out of sync. Did you run multiple statements at once?")
|
|
|
} else {
|
|
|
err = errors.New("Commands out of sync; you can't run this command now")
|
|
|
@@ -51,8 +54,8 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
|
|
|
mc.sequence++
|
|
|
|
|
|
// Read rest of packet
|
|
|
- data := make([]byte, pktLen)
|
|
|
- var n, add int
|
|
|
+ data = make([]byte, pktLen)
|
|
|
+ n = 0
|
|
|
for err == nil && n < int(pktLen) {
|
|
|
add, err = mc.bufReader.Read(data[n:])
|
|
|
n += add
|
|
|
@@ -68,32 +71,6 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
|
|
|
return data, err
|
|
|
}
|
|
|
|
|
|
-// Read n bytes long number num
|
|
|
-func (mc *mysqlConn) readNumber(nr uint8) (uint64, error) {
|
|
|
- // Read bytes into array
|
|
|
- buf := make([]byte, nr)
|
|
|
- var n, add int
|
|
|
- var err error
|
|
|
- for err == nil && n < int(nr) {
|
|
|
- add, err = mc.bufReader.Read(buf[n:])
|
|
|
- n += add
|
|
|
- }
|
|
|
- if err != nil || n < int(nr) {
|
|
|
- if err == nil {
|
|
|
- err = fmt.Errorf("Length of read data (%d) does not match header length (%d)", n, nr)
|
|
|
- }
|
|
|
- errLog.Print(err)
|
|
|
- return 0, driver.ErrBadConn
|
|
|
- }
|
|
|
-
|
|
|
- // Convert to uint64
|
|
|
- var num uint64 = 0
|
|
|
- for i := uint8(0); i < nr; i++ {
|
|
|
- num |= uint64(buf[i]) << (i * 8)
|
|
|
- }
|
|
|
- return num, err
|
|
|
-}
|
|
|
-
|
|
|
func (mc *mysqlConn) writePacket(data *[]byte) error {
|
|
|
// Write packet
|
|
|
n, err := mc.netConn.Write(*data)
|
|
|
@@ -160,7 +137,7 @@ func (mc *mysqlConn) readInitPacket() (err error) {
|
|
|
pos += len(slice) + 1
|
|
|
|
|
|
// Thread id [32 bit uint]
|
|
|
- mc.server.threadID = bytesToUint32(data[pos : pos+4])
|
|
|
+ mc.server.threadID = binary.LittleEndian.Uint32(data[pos : pos+4])
|
|
|
pos += 4
|
|
|
|
|
|
// First part of scramble buffer [8 bytes]
|
|
|
@@ -169,7 +146,7 @@ func (mc *mysqlConn) readInitPacket() (err error) {
|
|
|
pos += 9
|
|
|
|
|
|
// Server capabilities [16 bit uint]
|
|
|
- mc.server.flags = ClientFlag(bytesToUint16(data[pos : pos+2]))
|
|
|
+ mc.server.flags = ClientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
|
|
|
if mc.server.flags&CLIENT_PROTOCOL_41 == 0 {
|
|
|
err = errors.New("MySQL-Server does not support required Protocol 41+")
|
|
|
}
|
|
|
@@ -368,7 +345,7 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error {
|
|
|
pos := 1
|
|
|
|
|
|
// Error Number [16 bit uint]
|
|
|
- errno := bytesToUint16(data[pos : pos+2])
|
|
|
+ errno := binary.LittleEndian.Uint16(data[pos : pos+2])
|
|
|
pos += 2
|
|
|
|
|
|
// SQL State [# + 5bytes string]
|
|
|
@@ -403,14 +380,14 @@ func (mc *mysqlConn) handleOkPacket(data []byte) (err error) {
|
|
|
var n int
|
|
|
|
|
|
// Affected rows [Length Coded Binary]
|
|
|
- mc.affectedRows, n, err = bytesToLengthCodedBinary(data[pos:])
|
|
|
+ mc.affectedRows, _, n, err = bytesToLengthEncodedInteger(data[pos:])
|
|
|
if err != nil {
|
|
|
return
|
|
|
}
|
|
|
pos += n
|
|
|
|
|
|
// Insert id [Length Coded Binary]
|
|
|
- mc.insertId, n, err = bytesToLengthCodedBinary(data[pos:])
|
|
|
+ mc.insertId, _, n, err = bytesToLengthEncodedInteger(data[pos:])
|
|
|
if err != nil {
|
|
|
return
|
|
|
}
|
|
|
@@ -449,7 +426,7 @@ func (mc *mysqlConn) readResultSetHeaderPacket() (fieldCount int, err error) {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- num, n, err := bytesToLengthCodedBinary(data)
|
|
|
+ num, _, n, err := bytesToLengthEncodedInteger(data)
|
|
|
if err != nil || (n-len(data)) != 0 {
|
|
|
err = errors.New("Malformed Packet")
|
|
|
return
|
|
|
@@ -460,8 +437,10 @@ func (mc *mysqlConn) readResultSetHeaderPacket() (fieldCount int, err error) {
|
|
|
}
|
|
|
|
|
|
// Read Packets as Field Packets until EOF-Packet or an Error appears
|
|
|
-func (mc *mysqlConn) readColumns(n int) (columns []mysqlField, err error) {
|
|
|
+func (mc *mysqlConn) readColumns(count int) (columns []mysqlField, err error) {
|
|
|
var data []byte
|
|
|
+ var pos, n int
|
|
|
+ var name []byte
|
|
|
|
|
|
for {
|
|
|
data, err = mc.readPacket()
|
|
|
@@ -471,59 +450,51 @@ func (mc *mysqlConn) readColumns(n int) (columns []mysqlField, err error) {
|
|
|
|
|
|
// EOF Packet
|
|
|
if data[0] == 254 && len(data) == 5 {
|
|
|
- if len(columns) != n {
|
|
|
- err = fmt.Errorf("ColumnsCount mismatch n:%d len:%d", n, len(columns))
|
|
|
+ if len(columns) != count {
|
|
|
+ err = fmt.Errorf("ColumnsCount mismatch n:%d len:%d", count, len(columns))
|
|
|
}
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- var pos, n int
|
|
|
- var name *[]byte
|
|
|
- //var catalog, database, table, orgTable, name, orgName []byte
|
|
|
- //var defaultVal uint64
|
|
|
+ pos = 0
|
|
|
|
|
|
// Catalog
|
|
|
- //catalog, n, _, err = readLengthCodedBinary(data)
|
|
|
- n, err = readAndDropLengthCodedBinary(data)
|
|
|
+ n, err = readAndDropLengthEnodedString(data)
|
|
|
if err != nil {
|
|
|
return
|
|
|
}
|
|
|
pos += n
|
|
|
|
|
|
// Database [len coded string]
|
|
|
- //database, n, _, err = readLengthCodedBinary(data[pos:])
|
|
|
- n, err = readAndDropLengthCodedBinary(data[pos:])
|
|
|
+ n, err = readAndDropLengthEnodedString(data[pos:])
|
|
|
if err != nil {
|
|
|
return
|
|
|
}
|
|
|
pos += n
|
|
|
|
|
|
// Table [len coded string]
|
|
|
- //table, n, _, err = readLengthCodedBinary(data[pos:])
|
|
|
- n, err = readAndDropLengthCodedBinary(data[pos:])
|
|
|
+ n, err = readAndDropLengthEnodedString(data[pos:])
|
|
|
if err != nil {
|
|
|
return
|
|
|
}
|
|
|
pos += n
|
|
|
|
|
|
// Original table [len coded string]
|
|
|
- //orgTable, n, _, err = readLengthCodedBinary(data[pos:])
|
|
|
- n, err = readAndDropLengthCodedBinary(data[pos:])
|
|
|
+ n, err = readAndDropLengthEnodedString(data[pos:])
|
|
|
if err != nil {
|
|
|
return
|
|
|
}
|
|
|
pos += n
|
|
|
|
|
|
// Name [len coded string]
|
|
|
- name, n, _, err = readLengthCodedBinary(data[pos:])
|
|
|
+ name, _, n, err = readLengthEnodedString(data[pos:])
|
|
|
if err != nil {
|
|
|
return
|
|
|
}
|
|
|
pos += n
|
|
|
|
|
|
// Original name [len coded string]
|
|
|
- //orgName, n, _, err = readLengthCodedBinary(data[pos:])
|
|
|
- n, err = readAndDropLengthCodedBinary(data[pos:])
|
|
|
+ n, err = readAndDropLengthEnodedString(data[pos:])
|
|
|
if err != nil {
|
|
|
return
|
|
|
}
|
|
|
@@ -533,11 +504,9 @@ func (mc *mysqlConn) readColumns(n int) (columns []mysqlField, err error) {
|
|
|
pos++
|
|
|
|
|
|
// Charset [16 bit uint]
|
|
|
- //charsetNumber := bytesToUint16(data[pos : pos+2])
|
|
|
pos += 2
|
|
|
|
|
|
// Length [32 bit uint]
|
|
|
- //length := bytesToUint32(data[pos : pos+4])
|
|
|
pos += 4
|
|
|
|
|
|
// Field type [byte]
|
|
|
@@ -545,11 +514,10 @@ func (mc *mysqlConn) readColumns(n int) (columns []mysqlField, err error) {
|
|
|
pos++
|
|
|
|
|
|
// Flags [16 bit uint]
|
|
|
- flags := FieldFlag(bytesToUint16(data[pos : pos+2]))
|
|
|
+ flags := FieldFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
|
|
|
//pos += 2
|
|
|
|
|
|
// Decimals [8 bit uint]
|
|
|
- //decimals := data[pos]
|
|
|
//pos++
|
|
|
|
|
|
// Default value [len coded binary]
|
|
|
@@ -557,7 +525,7 @@ func (mc *mysqlConn) readColumns(n int) (columns []mysqlField, err error) {
|
|
|
// defaultVal, _, err = bytesToLengthCodedBinary(data[pos:])
|
|
|
//}
|
|
|
|
|
|
- columns = append(columns, mysqlField{name: string(*name), fieldType: fieldType, flags: flags})
|
|
|
+ columns = append(columns, mysqlField{name: string(name), fieldType: fieldType, flags: flags})
|
|
|
}
|
|
|
|
|
|
return
|
|
|
@@ -577,24 +545,15 @@ func (rows *mysqlRows) readRow(dest *[]driver.Value) (err error) {
|
|
|
|
|
|
// RowSet Packet
|
|
|
var n int
|
|
|
- var isNull bool
|
|
|
- var val *[]byte
|
|
|
columnsCount := len(*dest)
|
|
|
pos := 0
|
|
|
|
|
|
for i := 0; i < columnsCount; i++ {
|
|
|
// Read bytes and convert to string
|
|
|
- val, n, isNull, err = readLengthCodedBinary(data[pos:])
|
|
|
+ (*dest)[i], _, n, err = readLengthEnodedString(data[pos:])
|
|
|
if err != nil {
|
|
|
return
|
|
|
}
|
|
|
-
|
|
|
- // nil if field is NULL
|
|
|
- if isNull {
|
|
|
- (*dest)[i] = nil
|
|
|
- } else {
|
|
|
- (*dest)[i] = *val
|
|
|
- }
|
|
|
pos += n
|
|
|
}
|
|
|
|
|
|
@@ -668,15 +627,15 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (columnCount uint16, err error)
|
|
|
}
|
|
|
pos++
|
|
|
|
|
|
- stmt.id = bytesToUint32(data[pos : pos+4])
|
|
|
+ stmt.id = binary.LittleEndian.Uint32(data[pos : pos+4])
|
|
|
pos += 4
|
|
|
|
|
|
// Column count [16 bit uint]
|
|
|
- columnCount = bytesToUint16(data[pos : pos+2])
|
|
|
+ columnCount = binary.LittleEndian.Uint16(data[pos : pos+2])
|
|
|
pos += 2
|
|
|
|
|
|
// Param count [16 bit uint]
|
|
|
- stmt.paramCount = int(bytesToUint16(data[pos : pos+2]))
|
|
|
+ stmt.paramCount = int(binary.LittleEndian.Uint16(data[pos : pos+2]))
|
|
|
pos += 2
|
|
|
|
|
|
// Warning count [16 bit uint]
|
|
|
@@ -735,7 +694,7 @@ func (stmt *mysqlStmt) buildExecutePacket(args *[]driver.Value) error {
|
|
|
paramTypes = append(paramTypes, []byte{byte(FIELD_TYPE_STRING), 0x0}...)
|
|
|
val := (*args)[i].([]byte)
|
|
|
valLen = len(val)
|
|
|
- lcb := lengthCodedBinaryToBytes(uint64(valLen))
|
|
|
+ lcb := lengthEncodedIntegerToBytes(uint64(valLen))
|
|
|
pktLen += len(lcb) + valLen
|
|
|
paramValues = append(paramValues, lcb)
|
|
|
paramValues = append(paramValues, val)
|
|
|
@@ -778,7 +737,7 @@ func (stmt *mysqlStmt) buildExecutePacket(args *[]driver.Value) error {
|
|
|
paramTypes = append(paramTypes, []byte{byte(FIELD_TYPE_STRING), 0x0}...)
|
|
|
val := []byte(pv.String())
|
|
|
valLen = len(val)
|
|
|
- lcb := lengthCodedBinaryToBytes(uint64(valLen))
|
|
|
+ lcb := lengthEncodedIntegerToBytes(uint64(valLen))
|
|
|
pktLen += valLen + len(lcb)
|
|
|
paramValues = append(paramValues, lcb)
|
|
|
paramValues = append(paramValues, val)
|
|
|
@@ -848,13 +807,13 @@ func (rc *mysqlRows) readBinaryRow(dest *[]driver.Value) (err error) {
|
|
|
pos++
|
|
|
|
|
|
// BinaryRowSet Packet
|
|
|
- columnsCount := len(rc.columns)
|
|
|
+ columnsCount := len(*dest)
|
|
|
|
|
|
nullBitMap := data[pos : pos+(columnsCount+7+2)>>3]
|
|
|
pos += (columnsCount + 7 + 2) >> 3
|
|
|
|
|
|
var n int
|
|
|
- var unsigned, isNull bool
|
|
|
+ var unsigned bool
|
|
|
for i := 0; i < columnsCount; i++ {
|
|
|
// Field is NULL
|
|
|
if (nullBitMap[(i+2)>>3] >> uint((i+2)&7) & 1) == 1 {
|
|
|
@@ -872,42 +831,42 @@ func (rc *mysqlRows) readBinaryRow(dest *[]driver.Value) (err error) {
|
|
|
// Numeric Typs
|
|
|
case FIELD_TYPE_TINY:
|
|
|
if unsigned {
|
|
|
- (*dest)[i] = uintToByteStr(uint64(byteToUint8(data[pos])))
|
|
|
+ (*dest)[i] = uint64(data[pos])
|
|
|
} else {
|
|
|
- (*dest)[i] = intToByteStr(int64(int8(byteToUint8(data[pos]))))
|
|
|
+ (*dest)[i] = int64(int8(data[pos]))
|
|
|
}
|
|
|
pos++
|
|
|
|
|
|
case FIELD_TYPE_SHORT, FIELD_TYPE_YEAR:
|
|
|
if unsigned {
|
|
|
- (*dest)[i] = uintToByteStr(uint64(bytesToUint16(data[pos : pos+2])))
|
|
|
+ (*dest)[i] = uint64(binary.LittleEndian.Uint16(data[pos : pos+2]))
|
|
|
} else {
|
|
|
- (*dest)[i] = intToByteStr(int64(int16(bytesToUint16(data[pos : pos+2]))))
|
|
|
+ (*dest)[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2])))
|
|
|
}
|
|
|
pos += 2
|
|
|
|
|
|
case FIELD_TYPE_INT24, FIELD_TYPE_LONG:
|
|
|
if unsigned {
|
|
|
- (*dest)[i] = uintToByteStr(uint64(bytesToUint32(data[pos : pos+4])))
|
|
|
+ (*dest)[i] = uint64(binary.LittleEndian.Uint32(data[pos : pos+4]))
|
|
|
} else {
|
|
|
- (*dest)[i] = intToByteStr(int64(int32(bytesToUint32(data[pos : pos+4]))))
|
|
|
+ (*dest)[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4])))
|
|
|
}
|
|
|
pos += 4
|
|
|
|
|
|
case FIELD_TYPE_LONGLONG:
|
|
|
if unsigned {
|
|
|
- (*dest)[i] = uintToByteStr(bytesToUint64(data[pos : pos+8]))
|
|
|
+ (*dest)[i] = binary.LittleEndian.Uint64(data[pos : pos+8])
|
|
|
} else {
|
|
|
- (*dest)[i] = intToByteStr(int64(bytesToUint64(data[pos : pos+8])))
|
|
|
+ (*dest)[i] = int64(binary.LittleEndian.Uint64(data[pos : pos+8]))
|
|
|
}
|
|
|
pos += 8
|
|
|
|
|
|
case FIELD_TYPE_FLOAT:
|
|
|
- (*dest)[i] = float32ToByteStr(bytesToFloat32(data[pos : pos+4]))
|
|
|
+ (*dest)[i] = bytesToFloat32(data[pos : pos+4])
|
|
|
pos += 4
|
|
|
|
|
|
case FIELD_TYPE_DOUBLE:
|
|
|
- (*dest)[i] = float64ToByteStr(bytesToFloat64(data[pos : pos+8]))
|
|
|
+ (*dest)[i] = bytesToFloat64(data[pos : pos+8])
|
|
|
pos += 8
|
|
|
|
|
|
// Length coded Binary Strings
|
|
|
@@ -916,24 +875,17 @@ func (rc *mysqlRows) readBinaryRow(dest *[]driver.Value) (err error) {
|
|
|
FIELD_TYPE_TINY_BLOB, FIELD_TYPE_MEDIUM_BLOB, FIELD_TYPE_LONG_BLOB,
|
|
|
FIELD_TYPE_BLOB, FIELD_TYPE_VAR_STRING, FIELD_TYPE_STRING,
|
|
|
FIELD_TYPE_GEOMETRY:
|
|
|
- var val *[]byte
|
|
|
- val, n, isNull, err = readLengthCodedBinary(data[pos:])
|
|
|
+ (*dest)[i], _, n, err = readLengthEnodedString(data[pos:])
|
|
|
if err != nil {
|
|
|
return
|
|
|
}
|
|
|
-
|
|
|
- if isNull && rc.columns[i].flags&FLAG_NOT_NULL == 0 {
|
|
|
- (*dest)[i] = nil
|
|
|
- } else {
|
|
|
- (*dest)[i] = *val
|
|
|
- }
|
|
|
-
|
|
|
pos += n
|
|
|
|
|
|
// Date YYYY-MM-DD
|
|
|
case FIELD_TYPE_DATE, FIELD_TYPE_NEWDATE:
|
|
|
var num uint64
|
|
|
- num, n, err = bytesToLengthCodedBinary(data[pos:])
|
|
|
+ // TODO(js): allow nil values
|
|
|
+ num, _, n, err = bytesToLengthEncodedInteger(data[pos:])
|
|
|
if err != nil {
|
|
|
return
|
|
|
}
|
|
|
@@ -943,7 +895,7 @@ func (rc *mysqlRows) readBinaryRow(dest *[]driver.Value) (err error) {
|
|
|
(*dest)[i] = []byte("0000-00-00")
|
|
|
} else {
|
|
|
(*dest)[i] = []byte(fmt.Sprintf("%04d-%02d-%02d",
|
|
|
- bytesToUint16(data[pos:pos+2]),
|
|
|
+ binary.LittleEndian.Uint16(data[pos:pos+2]),
|
|
|
data[pos+2],
|
|
|
data[pos+3]))
|
|
|
}
|
|
|
@@ -952,7 +904,8 @@ func (rc *mysqlRows) readBinaryRow(dest *[]driver.Value) (err error) {
|
|
|
// Time HH:MM:SS
|
|
|
case FIELD_TYPE_TIME:
|
|
|
var num uint64
|
|
|
- num, n, err = bytesToLengthCodedBinary(data[pos:])
|
|
|
+ // TODO(js): allow nil values
|
|
|
+ num, _, n, err = bytesToLengthEncodedInteger(data[pos:])
|
|
|
if err != nil {
|
|
|
return
|
|
|
}
|
|
|
@@ -970,7 +923,8 @@ func (rc *mysqlRows) readBinaryRow(dest *[]driver.Value) (err error) {
|
|
|
// Timestamp YYYY-MM-DD HH:MM:SS
|
|
|
case FIELD_TYPE_TIMESTAMP, FIELD_TYPE_DATETIME:
|
|
|
var num uint64
|
|
|
- num, n, err = bytesToLengthCodedBinary(data[pos:])
|
|
|
+ // TODO(js): allow nil values
|
|
|
+ num, _, n, err = bytesToLengthEncodedInteger(data[pos:])
|
|
|
if err != nil {
|
|
|
return
|
|
|
}
|
|
|
@@ -981,7 +935,7 @@ func (rc *mysqlRows) readBinaryRow(dest *[]driver.Value) (err error) {
|
|
|
(*dest)[i] = []byte("0000-00-00 00:00:00")
|
|
|
case 4:
|
|
|
(*dest)[i] = []byte(fmt.Sprintf("%04d-%02d-%02d 00:00:00",
|
|
|
- bytesToUint16(data[pos:pos+2]),
|
|
|
+ binary.LittleEndian.Uint16(data[pos:pos+2]),
|
|
|
data[pos+2],
|
|
|
data[pos+3]))
|
|
|
default:
|
|
|
@@ -989,7 +943,7 @@ func (rc *mysqlRows) readBinaryRow(dest *[]driver.Value) (err error) {
|
|
|
return fmt.Errorf("Invalid datetime-packet length %d", num)
|
|
|
}
|
|
|
(*dest)[i] = []byte(fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d",
|
|
|
- bytesToUint16(data[pos:pos+2]),
|
|
|
+ binary.LittleEndian.Uint16(data[pos:pos+2]),
|
|
|
data[pos+2],
|
|
|
data[pos+3],
|
|
|
data[pos+4],
|