|
|
@@ -55,7 +55,16 @@ func (mc *mysqlConn) readPacket() (data []byte, err error) {
|
|
|
data = make([]byte, pktLen)
|
|
|
err = mc.buf.read(data)
|
|
|
if err == nil {
|
|
|
- return data, nil
|
|
|
+ if pktLen < maxPacketSize {
|
|
|
+ return data, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ // More data
|
|
|
+ var data2 []byte
|
|
|
+ data2, err = mc.readPacket()
|
|
|
+ if err == nil {
|
|
|
+ return append(data, data2...), nil
|
|
|
+ }
|
|
|
}
|
|
|
errLog.Print(err.Error())
|
|
|
return nil, driver.ErrBadConn
|
|
|
@@ -64,19 +73,63 @@ func (mc *mysqlConn) readPacket() (data []byte, err error) {
|
|
|
// Write packet buffer 'data'
|
|
|
// The packet header must be already included
|
|
|
func (mc *mysqlConn) writePacket(data []byte) error {
|
|
|
- // Write packet
|
|
|
- n, err := mc.netConn.Write(data)
|
|
|
- if err == nil && n == len(data) {
|
|
|
- mc.sequence++
|
|
|
- return nil
|
|
|
+ if len(data)-4 <= mc.maxWriteSize { // Can send data at once
|
|
|
+ // Write packet
|
|
|
+ n, err := mc.netConn.Write(data)
|
|
|
+ if err == nil && n == len(data) {
|
|
|
+ mc.sequence++
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+
|
|
|
+ // Handle error
|
|
|
+ if err == nil { // n != len(data)
|
|
|
+ errLog.Print(errMalformPkt.Error())
|
|
|
+ } else {
|
|
|
+ errLog.Print(err.Error())
|
|
|
+ }
|
|
|
+ return driver.ErrBadConn
|
|
|
}
|
|
|
|
|
|
- if err == nil { // n != len(data)
|
|
|
- errLog.Print(errMalformPkt.Error())
|
|
|
- } else {
|
|
|
- errLog.Print(err.Error())
|
|
|
+ // Must split packet
|
|
|
+ return mc.splitPacket(data)
|
|
|
+}
|
|
|
+
|
|
|
+func (mc *mysqlConn) splitPacket(data []byte) (err error) {
|
|
|
+ pktLen := len(data) - 4
|
|
|
+
|
|
|
+ if pktLen > mc.maxPacketAllowed {
|
|
|
+ return errPktTooLarge
|
|
|
+ }
|
|
|
+
|
|
|
+ for pktLen >= maxPacketSize {
|
|
|
+ data[0] = 0xff
|
|
|
+ data[1] = 0xff
|
|
|
+ data[2] = 0xff
|
|
|
+ data[3] = mc.sequence
|
|
|
+
|
|
|
+ // Write packet
|
|
|
+ n, err := mc.netConn.Write(data[:4+maxPacketSize])
|
|
|
+ if err == nil && n == 4+maxPacketSize {
|
|
|
+ mc.sequence++
|
|
|
+ data = data[maxPacketSize:]
|
|
|
+ pktLen -= maxPacketSize
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ // Handle error
|
|
|
+ if err == nil { // n != len(data)
|
|
|
+ errLog.Print(errMalformPkt.Error())
|
|
|
+ } else {
|
|
|
+ errLog.Print(err.Error())
|
|
|
+ }
|
|
|
+ return driver.ErrBadConn
|
|
|
}
|
|
|
- return driver.ErrBadConn
|
|
|
+
|
|
|
+ data[0] = byte(pktLen)
|
|
|
+ data[1] = byte(pktLen >> 8)
|
|
|
+ data[2] = byte(pktLen >> 16)
|
|
|
+ data[3] = mc.sequence
|
|
|
+ return mc.writePacket(data)
|
|
|
}
|
|
|
|
|
|
/******************************************************************************
|
|
|
@@ -186,10 +239,10 @@ func (mc *mysqlConn) writeAuthPacket() error {
|
|
|
data[6] = byte(clientFlags >> 16)
|
|
|
data[7] = byte(clientFlags >> 24)
|
|
|
|
|
|
- // MaxPacketSize [32 bit] (1<<24 - 1)
|
|
|
- data[8] = 0xff
|
|
|
- data[9] = 0xff
|
|
|
- data[10] = 0xff
|
|
|
+ // MaxPacketSize [32 bit] (none)
|
|
|
+ //data[8] = 0x00
|
|
|
+ //data[9] = 0x00
|
|
|
+ //data[10] = 0x00
|
|
|
//data[11] = 0x00
|
|
|
|
|
|
// Charset [1 byte]
|
|
|
@@ -223,7 +276,7 @@ func (mc *mysqlConn) writeAuthPacket() error {
|
|
|
* Command Packets *
|
|
|
******************************************************************************/
|
|
|
|
|
|
-func (mc *mysqlConn) writeCommandPacket(command commandType) error {
|
|
|
+func (mc *mysqlConn) writeCommandPacket(command byte) error {
|
|
|
// Reset Packet Sequence
|
|
|
mc.sequence = 0
|
|
|
|
|
|
@@ -233,14 +286,14 @@ func (mc *mysqlConn) writeCommandPacket(command commandType) error {
|
|
|
0x05, // 5 bytes long
|
|
|
0x00,
|
|
|
0x00,
|
|
|
- mc.sequence,
|
|
|
+ 0x00, // mc.sequence
|
|
|
|
|
|
// Add command byte
|
|
|
- byte(command),
|
|
|
+ command,
|
|
|
})
|
|
|
}
|
|
|
|
|
|
-func (mc *mysqlConn) writeCommandPacketStr(command commandType, arg string) error {
|
|
|
+func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
|
|
|
// Reset Packet Sequence
|
|
|
mc.sequence = 0
|
|
|
|
|
|
@@ -251,10 +304,10 @@ func (mc *mysqlConn) writeCommandPacketStr(command commandType, arg string) erro
|
|
|
data[0] = byte(pktLen)
|
|
|
data[1] = byte(pktLen >> 8)
|
|
|
data[2] = byte(pktLen >> 16)
|
|
|
- data[3] = mc.sequence
|
|
|
+ //data[3] = mc.sequence
|
|
|
|
|
|
// Add command byte
|
|
|
- data[4] = byte(command)
|
|
|
+ data[4] = command
|
|
|
|
|
|
// Add arg
|
|
|
copy(data[5:], arg)
|
|
|
@@ -263,7 +316,7 @@ func (mc *mysqlConn) writeCommandPacketStr(command commandType, arg string) erro
|
|
|
return mc.writePacket(data)
|
|
|
}
|
|
|
|
|
|
-func (mc *mysqlConn) writeCommandPacketUint32(command commandType, arg uint32) error {
|
|
|
+func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
|
|
|
// Reset Packet Sequence
|
|
|
mc.sequence = 0
|
|
|
|
|
|
@@ -273,10 +326,10 @@ func (mc *mysqlConn) writeCommandPacketUint32(command commandType, arg uint32) e
|
|
|
0x05, // 5 bytes long
|
|
|
0x00,
|
|
|
0x00,
|
|
|
- mc.sequence,
|
|
|
+ 0x00, // mc.sequence
|
|
|
|
|
|
// Add command byte
|
|
|
- byte(command),
|
|
|
+ command,
|
|
|
|
|
|
// Add arg [32 bit]
|
|
|
byte(arg),
|
|
|
@@ -556,6 +609,54 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (columnCount uint16, err error)
|
|
|
return
|
|
|
}
|
|
|
|
|
|
+// http://dev.mysql.com/doc/internals/en/prepared-statements.html#com-stmt-send-long-data
|
|
|
+func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) (err error) {
|
|
|
+ maxLen := stmt.mc.maxPacketAllowed - 1
|
|
|
+ pktLen := maxLen
|
|
|
+ argLen := len(arg)
|
|
|
+ data := make([]byte, 4+1+4+2+argLen)
|
|
|
+ copy(data[4+1+4+2:], arg)
|
|
|
+
|
|
|
+ for argLen > 0 {
|
|
|
+ if 1+4+2+argLen < maxLen {
|
|
|
+ pktLen = 1 + 4 + 2 + argLen
|
|
|
+ }
|
|
|
+
|
|
|
+ // Add the packet header [24bit length + 1 byte sequence]
|
|
|
+ data[0] = byte(pktLen)
|
|
|
+ data[1] = byte(pktLen >> 8)
|
|
|
+ data[2] = byte(pktLen >> 16)
|
|
|
+ data[3] = 0x00 // mc.sequence
|
|
|
+
|
|
|
+ // Add command byte [1 byte]
|
|
|
+ data[4] = comStmtSendLongData
|
|
|
+
|
|
|
+ // Add stmtID [32 bit]
|
|
|
+ data[5] = byte(stmt.id)
|
|
|
+ data[6] = byte(stmt.id >> 8)
|
|
|
+ data[7] = byte(stmt.id >> 16)
|
|
|
+ data[8] = byte(stmt.id >> 24)
|
|
|
+
|
|
|
+ // Add paramID [16 bit]
|
|
|
+ data[9] = byte(paramID)
|
|
|
+ data[10] = byte(paramID >> 8)
|
|
|
+
|
|
|
+ // Send CMD packet
|
|
|
+ err = stmt.mc.writePacket(data[:4+pktLen])
|
|
|
+ if err == nil {
|
|
|
+ argLen -= pktLen - (1 + 4 + 2)
|
|
|
+ data = data[pktLen-(1+4+2):]
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ return err
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
+ // Reset Packet Sequence
|
|
|
+ stmt.mc.sequence = 0
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
// Execute Prepared Statement
|
|
|
// http://dev.mysql.com/doc/internals/en/prepared-statements.html#com-stmt-execute
|
|
|
func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
|
|
@@ -609,21 +710,37 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
|
|
|
|
|
case []byte:
|
|
|
paramTypes[i<<1] = fieldTypeString
|
|
|
- paramValues[i] = append(
|
|
|
- lengthEncodedIntegerToBytes(uint64(len(v))),
|
|
|
- v...,
|
|
|
- )
|
|
|
- pktLen += len(paramValues[i])
|
|
|
- continue
|
|
|
+ if len(v) < stmt.mc.maxPacketAllowed-pktLen-(stmt.paramCount-(i+1))*64 {
|
|
|
+ paramValues[i] = append(
|
|
|
+ lengthEncodedIntegerToBytes(uint64(len(v))),
|
|
|
+ v...,
|
|
|
+ )
|
|
|
+ pktLen += len(paramValues[i])
|
|
|
+ continue
|
|
|
+ } else {
|
|
|
+ err := stmt.writeCommandLongData(i, v)
|
|
|
+ if err == nil {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ return err
|
|
|
+ }
|
|
|
|
|
|
case string:
|
|
|
paramTypes[i<<1] = fieldTypeString
|
|
|
- paramValues[i] = append(
|
|
|
- lengthEncodedIntegerToBytes(uint64(len(v))),
|
|
|
- []byte(v)...,
|
|
|
- )
|
|
|
- pktLen += len(paramValues[i])
|
|
|
- continue
|
|
|
+ if len(v) < stmt.mc.maxPacketAllowed-pktLen-(stmt.paramCount-(i+1))*64 {
|
|
|
+ paramValues[i] = append(
|
|
|
+ lengthEncodedIntegerToBytes(uint64(len(v))),
|
|
|
+ []byte(v)...,
|
|
|
+ )
|
|
|
+ pktLen += len(paramValues[i])
|
|
|
+ continue
|
|
|
+ } else {
|
|
|
+ err := stmt.writeCommandLongData(i, []byte(v))
|
|
|
+ if err == nil {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ return err
|
|
|
+ }
|
|
|
|
|
|
case time.Time:
|
|
|
paramTypes[i<<1] = fieldTypeString
|
|
|
@@ -649,7 +766,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
|
|
data[3] = stmt.mc.sequence
|
|
|
|
|
|
// command [1 byte]
|
|
|
- data[4] = byte(comStmtExecute)
|
|
|
+ data[4] = comStmtExecute
|
|
|
|
|
|
// statement_id [4 bytes]
|
|
|
data[5] = byte(stmt.id)
|