|
|
@@ -9,7 +9,6 @@
|
|
|
package mysql
|
|
|
|
|
|
import (
|
|
|
- "database/sql/driver"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
"os"
|
|
|
@@ -21,11 +20,6 @@ var (
|
|
|
readerRegister map[string]func() io.Reader
|
|
|
)
|
|
|
|
|
|
-func init() {
|
|
|
- fileRegister = make(map[string]bool)
|
|
|
- readerRegister = make(map[string]func() io.Reader)
|
|
|
-}
|
|
|
-
|
|
|
// RegisterLocalFile adds the given file to the file whitelist,
|
|
|
// so that it can be used by "LOAD DATA LOCAL INFILE <filepath>".
|
|
|
// Alternatively you can allow the use of all local files with
|
|
|
@@ -38,6 +32,11 @@ func init() {
|
|
|
// ...
|
|
|
//
|
|
|
func RegisterLocalFile(filePath string) {
|
|
|
+ // lazy map init
|
|
|
+ if fileRegister == nil {
|
|
|
+ fileRegister = make(map[string]bool)
|
|
|
+ }
|
|
|
+
|
|
|
fileRegister[strings.Trim(filePath, `"`)] = true
|
|
|
}
|
|
|
|
|
|
@@ -62,6 +61,11 @@ func DeregisterLocalFile(filePath string) {
|
|
|
// ...
|
|
|
//
|
|
|
func RegisterReaderHandler(name string, handler func() io.Reader) {
|
|
|
+ // lazy map init
|
|
|
+ if readerRegister == nil {
|
|
|
+ readerRegister = make(map[string]func() io.Reader)
|
|
|
+ }
|
|
|
+
|
|
|
readerRegister[name] = handler
|
|
|
}
|
|
|
|
|
|
@@ -71,71 +75,81 @@ func DeregisterReaderHandler(name string) {
|
|
|
delete(readerRegister, name)
|
|
|
}
|
|
|
|
|
|
+func deferredClose(err *error, closer io.Closer) {
|
|
|
+ closeErr := closer.Close()
|
|
|
+ if *err == nil {
|
|
|
+ *err = closeErr
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
|
|
|
var rdr io.Reader
|
|
|
- data := make([]byte, 4+mc.maxWriteSize)
|
|
|
+ var data []byte
|
|
|
|
|
|
if strings.HasPrefix(name, "Reader::") { // io.Reader
|
|
|
name = name[8:]
|
|
|
- handler, inMap := readerRegister[name]
|
|
|
- if handler != nil {
|
|
|
+ if handler, inMap := readerRegister[name]; inMap {
|
|
|
rdr = handler()
|
|
|
- }
|
|
|
- if rdr == nil {
|
|
|
- if !inMap {
|
|
|
- err = fmt.Errorf("Reader '%s' is not registered", name)
|
|
|
+ if rdr != nil {
|
|
|
+ data = make([]byte, 4+mc.maxWriteSize)
|
|
|
+
|
|
|
+ if cl, ok := rdr.(io.Closer); ok {
|
|
|
+ defer deferredClose(&err, cl)
|
|
|
+ }
|
|
|
} else {
|
|
|
err = fmt.Errorf("Reader '%s' is <nil>", name)
|
|
|
}
|
|
|
+ } else {
|
|
|
+ err = fmt.Errorf("Reader '%s' is not registered", name)
|
|
|
}
|
|
|
} else { // File
|
|
|
name = strings.Trim(name, `"`)
|
|
|
if mc.cfg.allowAllFiles || fileRegister[name] {
|
|
|
- rdr, err = os.Open(name)
|
|
|
+ var file *os.File
|
|
|
+ var fi os.FileInfo
|
|
|
+
|
|
|
+ if file, err = os.Open(name); err == nil {
|
|
|
+ defer deferredClose(&err, file)
|
|
|
+
|
|
|
+ // get file size
|
|
|
+ if fi, err = file.Stat(); err == nil {
|
|
|
+ rdr = file
|
|
|
+ if fileSize := int(fi.Size()); fileSize <= mc.maxWriteSize {
|
|
|
+ data = make([]byte, 4+fileSize)
|
|
|
+ } else if fileSize <= mc.maxPacketAllowed {
|
|
|
+ data = make([]byte, 4+mc.maxWriteSize)
|
|
|
+ } else {
|
|
|
+ err = fmt.Errorf("Local File '%s' too large: Size: %d, Max: %d", name, fileSize, mc.maxPacketAllowed)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
} else {
|
|
|
err = fmt.Errorf("Local File '%s' is not registered. Use the DSN parameter 'allowAllFiles=true' to allow all files", name)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- if rdc, ok := rdr.(io.ReadCloser); ok {
|
|
|
- defer func() {
|
|
|
- if err == nil {
|
|
|
- err = rdc.Close()
|
|
|
- } else {
|
|
|
- rdc.Close()
|
|
|
- }
|
|
|
- }()
|
|
|
- }
|
|
|
-
|
|
|
// send content packets
|
|
|
- var ioErr error
|
|
|
if err == nil {
|
|
|
var n int
|
|
|
- for err == nil && ioErr == nil {
|
|
|
+ for err == nil {
|
|
|
n, err = rdr.Read(data[4:])
|
|
|
if n > 0 {
|
|
|
- ioErr = mc.writePacket(data[:4+n])
|
|
|
+ if ioErr := mc.writePacket(data[:4+n]); ioErr != nil {
|
|
|
+ return ioErr
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
if err == io.EOF {
|
|
|
err = nil
|
|
|
}
|
|
|
- if ioErr != nil {
|
|
|
- errLog.Print(ioErr.Error())
|
|
|
- return driver.ErrBadConn
|
|
|
- }
|
|
|
}
|
|
|
|
|
|
// send empty packet (termination)
|
|
|
- ioErr = mc.writePacket([]byte{
|
|
|
- 0x00,
|
|
|
- 0x00,
|
|
|
- 0x00,
|
|
|
- mc.sequence,
|
|
|
- })
|
|
|
- if ioErr != nil {
|
|
|
- errLog.Print(ioErr.Error())
|
|
|
- return driver.ErrBadConn
|
|
|
+ if data == nil {
|
|
|
+ data = make([]byte, 4)
|
|
|
+ }
|
|
|
+ if ioErr := mc.writePacket(data[:4]); ioErr != nil {
|
|
|
+ return ioErr
|
|
|
}
|
|
|
|
|
|
// read OK packet
|