Przeglądaj źródła

Merge pull request #42 from go-sql-driver/readerHandler

register a io.Reader handle func instead
Julien Schmidt 12 lat temu
rodzic
commit
59ce2f9913
3 zmienionych plików z 34 dodań i 23 usunięć
  1. 2 2
      README.md
  2. 8 6
      driver_test.go
  3. 24 15
      infile.go

+ 2 - 2
README.md

@@ -144,9 +144,9 @@ For this feature you need direct access to the package. Therefore you must chang
 import "github.com/go-sql-driver/mysql"
 ```
 
-Files must be whitelisted by registering them with `mysql.RegisterLocalFile(filepath)` (reccommended) or the whitelist check must be deactivated by using the DSN parameter `allowAllFiles=true` (might be insecure).
+Files must be whitelisted by registering them with `mysql.RegisterLocalFile(filepath)` (reccommended) or the Whitelist check must be deactivated by using the DSN parameter `allowAllFiles=true` (might be insecure).
 
-`io.Reader`s must be registered with `mysql.RegisterReader(name, reader)`. They are available with the filepath `Reader::<name>` then.
+To use a `io.Reader` a handler function must be registered with `mysql.RegisterReaderHandler(name, handler)` which returns a `io.Reader` or `io.ReadCloser`. The Reader is available with the filepath `Reader::<name>` then.
 
 See also the [godoc of Go-MySQL-Driver](http://godoc.org/github.com/go-sql-driver/mysql "golang mysql driver documentation")
 

+ 8 - 6
driver_test.go

@@ -3,6 +3,7 @@ package mysql
 import (
 	"database/sql"
 	"fmt"
+	"io"
 	"io/ioutil"
 	"net"
 	"os"
@@ -657,11 +658,13 @@ func TestLoadData(t *testing.T) {
 	mustExec(t, db, "TRUNCATE TABLE test")
 
 	// Reader
-	file, err = os.Open(file.Name())
-	if err != nil {
-		t.Fatal(err)
-	}
-	RegisterReader("test", file)
+	RegisterReaderHandler("test", func() io.Reader {
+		file, err = os.Open(file.Name())
+		if err != nil {
+			t.Fatal(err)
+		}
+		return file
+	})
 	mustExec(t, db, "LOAD DATA LOCAL INFILE 'Reader::test' INTO TABLE test")
 	verifyLoadDataResult(t, db)
 	// negative test
@@ -671,7 +674,6 @@ func TestLoadData(t *testing.T) {
 	} else if err.Error() != "Reader 'doesnotexist' is not registered" {
 		t.Fatal(err.Error())
 	}
-	file.Close()
 
 	mustExec(t, db, "DROP TABLE IF EXISTS test")
 }

+ 24 - 15
infile.go

@@ -19,12 +19,12 @@ import (
 
 var (
 	fileRegister   map[string]bool
-	readerRegister map[string]io.Reader
+	readerRegister map[string]func() io.Reader
 )
 
 func init() {
 	fileRegister = make(map[string]bool)
-	readerRegister = make(map[string]io.Reader)
+	readerRegister = make(map[string]func() io.Reader)
 }
 
 // RegisterLocalFile adds the given file to the file whitelist,
@@ -35,11 +35,13 @@ func RegisterLocalFile(filepath string) {
 	fileRegister[filepath] = true
 }
 
-// RegisterReader registers a io.Reader so that it can be used by
-// "LOAD DATA LOCAL INFILE Reader::<name>".
-// The use of io.Reader in this context is NOT safe for concurrency!
-func RegisterReader(name string, rdr io.Reader) {
-	readerRegister[name] = rdr
+// RegisterReaderHandler registers a handler function which is used
+// to receive a io.Reader.
+// The Reader can be used by "LOAD DATA LOCAL INFILE Reader::<name>".
+// If the handler returns a io.ReadCloser Close() is called when the
+// request is finished.
+func RegisterReaderHandler(name string, handler func() io.Reader) {
+	readerRegister[name] = handler
 }
 
 func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
@@ -48,8 +50,10 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
 
 	if strings.HasPrefix(name, "Reader::") { // io.Reader
 		name = name[8:]
-		var inMap bool
-		rdr, inMap = readerRegister[name]
+		handler, inMap := readerRegister[name]
+		if handler != nil {
+			rdr = handler()
+		}
 		if rdr == nil {
 			if !inMap {
 				err = fmt.Errorf("Reader '%s' is not registered", name)
@@ -57,19 +61,24 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
 				err = fmt.Errorf("Reader '%s' is <nil>", name)
 			}
 		}
-
 	} else { // File
 		if fileRegister[name] || mc.cfg.params[`allowAllFiles`] == `true` {
-			var file *os.File
-			file, err = os.Open(name)
-			defer file.Close()
-
-			rdr = file
+			rdr, err = os.Open(name)
 		} else {
 			err = fmt.Errorf("Local File '%s' is not registered. Use the DSN parameter 'allowAllFiles=true' to allow all files", name)
 		}
 	}
 
+	if rdc, ok := rdr.(io.ReadCloser); ok {
+		defer func() {
+			if err == nil {
+				err = rdc.Close()
+			} else {
+				rdc.Close()
+			}
+		}()
+	}
+
 	// send content packets
 	var ioErr error
 	if err == nil {