Prechádzať zdrojové kódy

io.ReadCloser + doc

Julien Schmidt 12 rokov pred
rodič
commit
539863490e
2 zmenil súbory, kde vykonal 23 pridanie a 16 odobranie
  1. 2 2
      README.md
  2. 21 14
      infile.go

+ 2 - 2
README.md

@@ -144,9 +144,9 @@ For this feature you need direct access to the package. Therefore you must chang
 import "github.com/go-sql-driver/mysql"
 ```
 
-Files must be whitelisted by registering them with `mysql.RegisterLocalFile(filepath)` (reccommended) or the whitelist check must be deactivated by using the DSN parameter `allowAllFiles=true` (might be insecure).
+Files must be whitelisted by registering them with `mysql.RegisterLocalFile(filepath)` (reccommended) or the Whitelist check must be deactivated by using the DSN parameter `allowAllFiles=true` (might be insecure).
 
-`io.Reader`s must be registered with `mysql.RegisterReader(name, reader)`. They are available with the filepath `Reader::<name>` then.
+To use a `io.Reader` a handler function must be registered with `mysql.RegisterReaderHandler(name, handler)` which returns a `io.Reader` or `io.ReadCloser`. The Reader is available with the filepath `Reader::<name>` then.
 
 See also the [godoc of Go-MySQL-Driver](http://godoc.org/github.com/go-sql-driver/mysql "golang mysql driver documentation")
 

+ 21 - 14
infile.go

@@ -35,11 +35,13 @@ func RegisterLocalFile(filepath string) {
 	fileRegister[filepath] = true
 }
 
-// RegisterReader registers a io.Reader so that it can be used by
-// "LOAD DATA LOCAL INFILE Reader::<name>".
-// The use of io.Reader in this context is NOT safe for concurrency!
-func RegisterReaderHandler(name string, cb func() io.Reader) {
-	readerRegister[name] = cb
+// RegisterReaderHandler registers a handler function which is used
+// to receive a io.Reader.
+// The Reader can be used by "LOAD DATA LOCAL INFILE Reader::<name>".
+// If the handler returns a io.ReadCloser Close() is called when the
+// request is finished.
+func RegisterReaderHandler(name string, handler func() io.Reader) {
+	readerRegister[name] = handler
 }
 
 func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
@@ -48,9 +50,9 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
 
 	if strings.HasPrefix(name, "Reader::") { // io.Reader
 		name = name[8:]
-		cb, inMap := readerRegister[name]
-		if cb != nil {
-			rdr = cb()
+		handler, inMap := readerRegister[name]
+		if handler != nil {
+			rdr = handler()
 		}
 		if rdr == nil {
 			if !inMap {
@@ -59,19 +61,24 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
 				err = fmt.Errorf("Reader '%s' is <nil>", name)
 			}
 		}
-
 	} else { // File
 		if fileRegister[name] || mc.cfg.params[`allowAllFiles`] == `true` {
-			var file *os.File
-			file, err = os.Open(name)
-			defer file.Close()
-
-			rdr = file
+			rdr, err = os.Open(name)
 		} else {
 			err = fmt.Errorf("Local File '%s' is not registered. Use the DSN parameter 'allowAllFiles=true' to allow all files", name)
 		}
 	}
 
+	if rdc, ok := rdr.(io.ReadCloser); ok {
+		defer func() {
+			if err == nil {
+				err = rdc.Close()
+			} else {
+				rdc.Close()
+			}
+		}()
+	}
+
 	// send content packets
 	var ioErr error
 	if err == nil {