|
|
@@ -29,7 +29,7 @@ func (mc *mysqlConn) readPacket() (data []byte, err error) {
|
|
|
data = make([]byte, 4)
|
|
|
err = mc.buf.read(data)
|
|
|
if err != nil {
|
|
|
- errLog.Print(err)
|
|
|
+ errLog.Print(err.Error())
|
|
|
return nil, driver.ErrBadConn
|
|
|
}
|
|
|
|
|
|
@@ -40,7 +40,8 @@ func (mc *mysqlConn) readPacket() (data []byte, err error) {
|
|
|
pktLen |= uint32(data[2]) << 16
|
|
|
|
|
|
if pktLen == 0 {
|
|
|
- return nil, err
|
|
|
+ errLog.Print(errMalformPkt.Error())
|
|
|
+ return nil, driver.ErrBadConn
|
|
|
}
|
|
|
|
|
|
// Check Packet Sync
|
|
|
@@ -59,7 +60,7 @@ func (mc *mysqlConn) readPacket() (data []byte, err error) {
|
|
|
if err == nil {
|
|
|
return data, nil
|
|
|
}
|
|
|
- errLog.Print(err)
|
|
|
+ errLog.Print(err.Error())
|
|
|
return nil, driver.ErrBadConn
|
|
|
}
|
|
|
|
|
|
@@ -74,9 +75,9 @@ func (mc *mysqlConn) writePacket(data []byte) error {
|
|
|
}
|
|
|
|
|
|
if err == nil { // n != len(data)
|
|
|
- errLog.Print(errMalformPkt)
|
|
|
+ errLog.Print(errMalformPkt.Error())
|
|
|
} else {
|
|
|
- errLog.Print(err)
|
|
|
+ errLog.Print(err.Error())
|
|
|
}
|
|
|
return driver.ErrBadConn
|
|
|
}
|
|
|
@@ -103,7 +104,7 @@ func (mc *mysqlConn) readInitPacket() (err error) {
|
|
|
|
|
|
// server version [null terminated string]
|
|
|
// connection id [4 bytes]
|
|
|
- pos := 1 + (bytes.IndexByte(data[1:], 0x00) + 1) + 4
|
|
|
+ pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4
|
|
|
|
|
|
// first part of scramble buffer [8 bytes]
|
|
|
mc.scrambleBuff = data[pos : pos+8]
|
|
|
@@ -287,45 +288,43 @@ func (mc *mysqlConn) writeCommandPacketUint32(command commandType, arg uint32) e
|
|
|
// Returns error if Packet is not an 'Result OK'-Packet
|
|
|
func (mc *mysqlConn) readResultOK() error {
|
|
|
data, err := mc.readPacket()
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
-
|
|
|
- switch data[0] {
|
|
|
- // OK
|
|
|
- case 0:
|
|
|
- mc.handleOkPacket(data)
|
|
|
- return nil
|
|
|
- // EOF, someone is using old_passwords
|
|
|
- case 254:
|
|
|
- return errOldPassword
|
|
|
+ if err == nil {
|
|
|
+ switch data[0] {
|
|
|
+ // OK
|
|
|
+ case 0:
|
|
|
+ mc.handleOkPacket(data)
|
|
|
+ return nil
|
|
|
+ // EOF, someone is using old_passwords
|
|
|
+ case 254:
|
|
|
+ return errOldPassword
|
|
|
+ }
|
|
|
+ // ERROR
|
|
|
+ return mc.handleErrorPacket(data)
|
|
|
}
|
|
|
- // ERROR
|
|
|
- return mc.handleErrorPacket(data)
|
|
|
+ return err
|
|
|
}
|
|
|
|
|
|
// Result Set Header Packet
|
|
|
// http://dev.mysql.com/doc/internals/en/text-protocol.html#packet-ProtocolText::Resultset
|
|
|
func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
|
|
|
data, err := mc.readPacket()
|
|
|
- if err != nil {
|
|
|
- return 0, err
|
|
|
- }
|
|
|
+ if err == nil {
|
|
|
+ if data[0] == 0 {
|
|
|
+ mc.handleOkPacket(data)
|
|
|
+ return 0, nil
|
|
|
+ } else if data[0] == 255 {
|
|
|
+ return 0, mc.handleErrorPacket(data)
|
|
|
+ }
|
|
|
|
|
|
- if data[0] == 0 {
|
|
|
- mc.handleOkPacket(data)
|
|
|
- return 0, nil
|
|
|
- } else if data[0] == 255 {
|
|
|
- return 0, mc.handleErrorPacket(data)
|
|
|
- }
|
|
|
+ // column count
|
|
|
+ num, _, n := readLengthEncodedInteger(data)
|
|
|
+ if n-len(data) == 0 {
|
|
|
+ return int(num), nil
|
|
|
+ }
|
|
|
|
|
|
- // column count
|
|
|
- num, _, n := readLengthEncodedInteger(data)
|
|
|
- if n-len(data) == 0 {
|
|
|
- return int(num), nil
|
|
|
+ return 0, errMalformPkt
|
|
|
}
|
|
|
-
|
|
|
- return 0, errMalformPkt
|
|
|
+ return 0, err
|
|
|
}
|
|
|
|
|
|
// Error Packet
|
|
|
@@ -487,18 +486,17 @@ func (rows *mysqlRows) readRow(dest []driver.Value) (err error) {
|
|
|
}
|
|
|
|
|
|
// Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read
|
|
|
-func (mc *mysqlConn) readUntilEOF() (count uint64, err error) {
|
|
|
+func (mc *mysqlConn) readUntilEOF() (err error) {
|
|
|
var data []byte
|
|
|
|
|
|
for {
|
|
|
data, err = mc.readPacket()
|
|
|
|
|
|
- // Err or EOF Packet
|
|
|
- if err != nil || (data[0] == 254 && len(data) == 5) {
|
|
|
- return
|
|
|
+ // No Err and no EOF Packet
|
|
|
+ if err == nil && (data[0] != 254 || len(data) != 5) {
|
|
|
+ continue
|
|
|
}
|
|
|
-
|
|
|
- count++
|
|
|
+ return
|
|
|
}
|
|
|
return
|
|
|
}
|
|
|
@@ -511,35 +509,32 @@ func (mc *mysqlConn) readUntilEOF() (count uint64, err error) {
|
|
|
// http://dev.mysql.com/doc/internals/en/prepared-statements.html#com-stmt-prepare-response
|
|
|
func (stmt *mysqlStmt) readPrepareResultPacket() (columnCount uint16, err error) {
|
|
|
data, err := stmt.mc.readPacket()
|
|
|
- if err != nil {
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- // Position
|
|
|
- pos := 0
|
|
|
-
|
|
|
- // packet marker [1 byte]
|
|
|
- if data[pos] != 0 { // not OK (0) ?
|
|
|
- err = stmt.mc.handleErrorPacket(data)
|
|
|
- return
|
|
|
- }
|
|
|
- pos++
|
|
|
+ if err == nil {
|
|
|
+ // Position
|
|
|
+ pos := 0
|
|
|
|
|
|
- // statement id [4 bytes]
|
|
|
- stmt.id = binary.LittleEndian.Uint32(data[pos : pos+4])
|
|
|
- pos += 4
|
|
|
+ // packet marker [1 byte]
|
|
|
+ if data[pos] != 0 { // not OK (0) ?
|
|
|
+ err = stmt.mc.handleErrorPacket(data)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ pos++
|
|
|
|
|
|
- // Column count [16 bit uint]
|
|
|
- columnCount = binary.LittleEndian.Uint16(data[pos : pos+2])
|
|
|
- pos += 2
|
|
|
+ // statement id [4 bytes]
|
|
|
+ stmt.id = binary.LittleEndian.Uint32(data[pos : pos+4])
|
|
|
+ pos += 4
|
|
|
|
|
|
- // Param count [16 bit uint]
|
|
|
- stmt.paramCount = int(binary.LittleEndian.Uint16(data[pos : pos+2]))
|
|
|
- pos += 2
|
|
|
+ // Column count [16 bit uint]
|
|
|
+ columnCount = binary.LittleEndian.Uint16(data[pos : pos+2])
|
|
|
+ pos += 2
|
|
|
|
|
|
- // Warning count [16 bit uint]
|
|
|
- // bytesToUint16(data[pos : pos+2])
|
|
|
+ // Param count [16 bit uint]
|
|
|
+ stmt.paramCount = int(binary.LittleEndian.Uint16(data[pos : pos+2]))
|
|
|
+ pos += 2
|
|
|
|
|
|
+ // Warning count [16 bit uint]
|
|
|
+ // bytesToUint16(data[pos : pos+2])
|
|
|
+ }
|
|
|
return
|
|
|
}
|
|
|
|