|
|
@@ -19,12 +19,12 @@ import (
|
|
|
|
|
|
var (
|
|
|
fileRegister map[string]bool
|
|
|
- readerRegister map[string]io.Reader
|
|
|
+ readerRegister map[string]func() io.Reader
|
|
|
)
|
|
|
|
|
|
func init() {
|
|
|
fileRegister = make(map[string]bool)
|
|
|
- readerRegister = make(map[string]io.Reader)
|
|
|
+ readerRegister = make(map[string]func() io.Reader)
|
|
|
}
|
|
|
|
|
|
// RegisterLocalFile adds the given file to the file whitelist,
|
|
|
@@ -38,8 +38,8 @@ func RegisterLocalFile(filepath string) {
|
|
|
// RegisterReader registers a io.Reader so that it can be used by
|
|
|
// "LOAD DATA LOCAL INFILE Reader::<name>".
|
|
|
// The use of io.Reader in this context is NOT safe for concurrency!
|
|
|
-func RegisterReader(name string, rdr io.Reader) {
|
|
|
- readerRegister[name] = rdr
|
|
|
+func RegisterReaderHandler(name string, cb func() io.Reader) {
|
|
|
+ readerRegister[name] = cb
|
|
|
}
|
|
|
|
|
|
func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
|
|
|
@@ -48,8 +48,10 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
|
|
|
|
|
|
if strings.HasPrefix(name, "Reader::") { // io.Reader
|
|
|
name = name[8:]
|
|
|
- var inMap bool
|
|
|
- rdr, inMap = readerRegister[name]
|
|
|
+ cb, inMap := readerRegister[name]
|
|
|
+ if cb != nil {
|
|
|
+ rdr = cb()
|
|
|
+ }
|
|
|
if rdr == nil {
|
|
|
if !inMap {
|
|
|
err = fmt.Errorf("Reader '%s' is not registered", name)
|