|
|
@@ -35,11 +35,13 @@ func RegisterLocalFile(filepath string) {
|
|
|
fileRegister[filepath] = true
|
|
|
}
|
|
|
|
|
|
-// RegisterReader registers a io.Reader so that it can be used by
|
|
|
-// "LOAD DATA LOCAL INFILE Reader::<name>".
|
|
|
-// The use of io.Reader in this context is NOT safe for concurrency!
|
|
|
-func RegisterReaderHandler(name string, cb func() io.Reader) {
|
|
|
- readerRegister[name] = cb
|
|
|
+// RegisterReaderHandler registers a handler function which is used
|
|
|
+// to receive a io.Reader.
|
|
|
+// The Reader can be used by "LOAD DATA LOCAL INFILE Reader::<name>".
|
|
|
+// If the handler returns a io.ReadCloser Close() is called when the
|
|
|
+// request is finished.
|
|
|
+func RegisterReaderHandler(name string, handler func() io.Reader) {
|
|
|
+ readerRegister[name] = handler
|
|
|
}
|
|
|
|
|
|
func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
|
|
|
@@ -48,9 +50,9 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
|
|
|
|
|
|
if strings.HasPrefix(name, "Reader::") { // io.Reader
|
|
|
name = name[8:]
|
|
|
- cb, inMap := readerRegister[name]
|
|
|
- if cb != nil {
|
|
|
- rdr = cb()
|
|
|
+ handler, inMap := readerRegister[name]
|
|
|
+ if handler != nil {
|
|
|
+ rdr = handler()
|
|
|
}
|
|
|
if rdr == nil {
|
|
|
if !inMap {
|
|
|
@@ -59,19 +61,24 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
|
|
|
err = fmt.Errorf("Reader '%s' is <nil>", name)
|
|
|
}
|
|
|
}
|
|
|
-
|
|
|
} else { // File
|
|
|
if fileRegister[name] || mc.cfg.params[`allowAllFiles`] == `true` {
|
|
|
- var file *os.File
|
|
|
- file, err = os.Open(name)
|
|
|
- defer file.Close()
|
|
|
-
|
|
|
- rdr = file
|
|
|
+ rdr, err = os.Open(name)
|
|
|
} else {
|
|
|
err = fmt.Errorf("Local File '%s' is not registered. Use the DSN parameter 'allowAllFiles=true' to allow all files", name)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ if rdc, ok := rdr.(io.ReadCloser); ok {
|
|
|
+ defer func() {
|
|
|
+ if err == nil {
|
|
|
+ err = rdc.Close()
|
|
|
+ } else {
|
|
|
+ rdc.Close()
|
|
|
+ }
|
|
|
+ }()
|
|
|
+ }
|
|
|
+
|
|
|
// send content packets
|
|
|
var ioErr error
|
|
|
if err == nil {
|