Przeglądaj źródła

register a io.Reader handle func instead

Julien Schmidt 12 lat temu
rodzic
commit
72059433b1
2 zmienionych plików z 16 dodań i 12 usunięć
  1. 8 6
      driver_test.go
  2. 8 6
      infile.go

+ 8 - 6
driver_test.go

@@ -3,6 +3,7 @@ package mysql
 import (
 	"database/sql"
 	"fmt"
+	"io"
 	"io/ioutil"
 	"net"
 	"os"
@@ -657,11 +658,13 @@ func TestLoadData(t *testing.T) {
 	mustExec(t, db, "TRUNCATE TABLE test")
 
 	// Reader
-	file, err = os.Open(file.Name())
-	if err != nil {
-		t.Fatal(err)
-	}
-	RegisterReader("test", file)
+	RegisterReaderHandler("test", func() io.Reader {
+		file, err = os.Open(file.Name())
+		if err != nil {
+			t.Fatal(err)
+		}
+		return file
+	})
 	mustExec(t, db, "LOAD DATA LOCAL INFILE 'Reader::test' INTO TABLE test")
 	verifyLoadDataResult(t, db)
 	// negative test
@@ -671,7 +674,6 @@ func TestLoadData(t *testing.T) {
 	} else if err.Error() != "Reader 'doesnotexist' is not registered" {
 		t.Fatal(err.Error())
 	}
-	file.Close()
 
 	mustExec(t, db, "DROP TABLE IF EXISTS test")
 }

+ 8 - 6
infile.go

@@ -19,12 +19,12 @@ import (
 
 var (
 	fileRegister   map[string]bool
-	readerRegister map[string]io.Reader
+	readerRegister map[string]func() io.Reader
 )
 
 func init() {
 	fileRegister = make(map[string]bool)
-	readerRegister = make(map[string]io.Reader)
+	readerRegister = make(map[string]func() io.Reader)
 }
 
 // RegisterLocalFile adds the given file to the file whitelist,
@@ -38,8 +38,8 @@ func RegisterLocalFile(filepath string) {
 // RegisterReader registers a io.Reader so that it can be used by
 // "LOAD DATA LOCAL INFILE Reader::<name>".
 // The use of io.Reader in this context is NOT safe for concurrency!
-func RegisterReader(name string, rdr io.Reader) {
-	readerRegister[name] = rdr
+func RegisterReaderHandler(name string, cb func() io.Reader) {
+	readerRegister[name] = cb
 }
 
 func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
@@ -48,8 +48,10 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
 
 	if strings.HasPrefix(name, "Reader::") { // io.Reader
 		name = name[8:]
-		var inMap bool
-		rdr, inMap = readerRegister[name]
+		cb, inMap := readerRegister[name]
+		if cb != nil {
+			rdr = cb()
+		}
 		if rdr == nil {
 			if !inMap {
 				err = fmt.Errorf("Reader '%s' is not registered", name)