|
|
@@ -9,7 +9,6 @@
|
|
|
package mysql
|
|
|
|
|
|
import (
|
|
|
- "database/sql/driver"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
"os"
|
|
|
@@ -86,6 +85,16 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
|
|
|
rdr = handler()
|
|
|
if rdr != nil {
|
|
|
data = make([]byte, 4+mc.maxWriteSize)
|
|
|
+
|
|
|
+ if rdc, ok := rdr.(io.ReadCloser); ok {
|
|
|
+ defer func() {
|
|
|
+ if err == nil {
|
|
|
+ err = rdc.Close()
|
|
|
+ } else {
|
|
|
+ rdc.Close()
|
|
|
+ }
|
|
|
+ }()
|
|
|
+ }
|
|
|
} else {
|
|
|
err = fmt.Errorf("Reader '%s' is <nil>", name)
|
|
|
}
|
|
|
@@ -99,6 +108,15 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
|
|
|
var fi os.FileInfo
|
|
|
|
|
|
if file, err = os.Open(name); err == nil {
|
|
|
+ defer func() {
|
|
|
+ if err == nil {
|
|
|
+ err = file.Close()
|
|
|
+ } else {
|
|
|
+ file.Close()
|
|
|
+ }
|
|
|
+ }()
|
|
|
+
|
|
|
+ // get file size
|
|
|
if fi, err = file.Stat(); err == nil {
|
|
|
rdr = file
|
|
|
if fileSize := int(fi.Size()); fileSize <= mc.maxWriteSize {
|
|
|
@@ -115,45 +133,28 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- if rdc, ok := rdr.(io.ReadCloser); ok {
|
|
|
- defer func() {
|
|
|
- if err == nil {
|
|
|
- err = rdc.Close()
|
|
|
- } else {
|
|
|
- rdc.Close()
|
|
|
- }
|
|
|
- }()
|
|
|
- }
|
|
|
-
|
|
|
// send content packets
|
|
|
- var ioErr error
|
|
|
if err == nil {
|
|
|
var n int
|
|
|
- for err == nil && ioErr == nil {
|
|
|
+ for err == nil {
|
|
|
n, err = rdr.Read(data[4:])
|
|
|
if n > 0 {
|
|
|
- ioErr = mc.writePacket(data[:4+n])
|
|
|
+ if ioErr := mc.writePacket(data[:4+n]); ioErr != nil {
|
|
|
+ return ioErr
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
if err == io.EOF {
|
|
|
err = nil
|
|
|
}
|
|
|
- if ioErr != nil {
|
|
|
- errLog.Print(ioErr.Error())
|
|
|
- return driver.ErrBadConn
|
|
|
- }
|
|
|
}
|
|
|
|
|
|
// send empty packet (termination)
|
|
|
- ioErr = mc.writePacket([]byte{
|
|
|
- 0x00,
|
|
|
- 0x00,
|
|
|
- 0x00,
|
|
|
- mc.sequence,
|
|
|
- })
|
|
|
- if ioErr != nil {
|
|
|
- errLog.Print(ioErr.Error())
|
|
|
- return driver.ErrBadConn
|
|
|
+ if data == nil {
|
|
|
+ data = make([]byte, 4)
|
|
|
+ }
|
|
|
+ if ioErr := mc.writePacket(data[:4]); ioErr != nil {
|
|
|
+ return ioErr
|
|
|
}
|
|
|
|
|
|
// read OK packet
|