Browse Source

Merge pull request #174 from go-sql-driver/infile

infile: refactoring
Julien Schmidt 12 years ago
parent
commit
5d25a76e20
1 changed files with 54 additions and 40 deletions
  1. 54 40
      infile.go

+ 54 - 40
infile.go

@@ -9,7 +9,6 @@
 package mysql
 
 import (
-	"database/sql/driver"
 	"fmt"
 	"io"
 	"os"
@@ -21,11 +20,6 @@ var (
 	readerRegister map[string]func() io.Reader
 )
 
-func init() {
-	fileRegister = make(map[string]bool)
-	readerRegister = make(map[string]func() io.Reader)
-}
-
 // RegisterLocalFile adds the given file to the file whitelist,
 // so that it can be used by "LOAD DATA LOCAL INFILE <filepath>".
 // Alternatively you can allow the use of all local files with
@@ -38,6 +32,11 @@ func init() {
 //  ...
 //
 func RegisterLocalFile(filePath string) {
+	// lazy map init
+	if fileRegister == nil {
+		fileRegister = make(map[string]bool)
+	}
+
 	fileRegister[strings.Trim(filePath, `"`)] = true
 }
 
@@ -62,6 +61,11 @@ func DeregisterLocalFile(filePath string) {
 //  ...
 //
 func RegisterReaderHandler(name string, handler func() io.Reader) {
+	// lazy map init
+	if readerRegister == nil {
+		readerRegister = make(map[string]func() io.Reader)
+	}
+
 	readerRegister[name] = handler
 }
 
@@ -71,71 +75,81 @@ func DeregisterReaderHandler(name string) {
 	delete(readerRegister, name)
 }
 
+func deferredClose(err *error, closer io.Closer) {
+	closeErr := closer.Close()
+	if *err == nil {
+		*err = closeErr
+	}
+}
+
 func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
 	var rdr io.Reader
-	data := make([]byte, 4+mc.maxWriteSize)
+	var data []byte
 
 	if strings.HasPrefix(name, "Reader::") { // io.Reader
 		name = name[8:]
-		handler, inMap := readerRegister[name]
-		if handler != nil {
+		if handler, inMap := readerRegister[name]; inMap {
 			rdr = handler()
-		}
-		if rdr == nil {
-			if !inMap {
-				err = fmt.Errorf("Reader '%s' is not registered", name)
+			if rdr != nil {
+				data = make([]byte, 4+mc.maxWriteSize)
+
+				if cl, ok := rdr.(io.Closer); ok {
+					defer deferredClose(&err, cl)
+				}
 			} else {
 				err = fmt.Errorf("Reader '%s' is <nil>", name)
 			}
+		} else {
+			err = fmt.Errorf("Reader '%s' is not registered", name)
 		}
 	} else { // File
 		name = strings.Trim(name, `"`)
 		if mc.cfg.allowAllFiles || fileRegister[name] {
-			rdr, err = os.Open(name)
+			var file *os.File
+			var fi os.FileInfo
+
+			if file, err = os.Open(name); err == nil {
+				defer deferredClose(&err, file)
+
+				// get file size
+				if fi, err = file.Stat(); err == nil {
+					rdr = file
+					if fileSize := int(fi.Size()); fileSize <= mc.maxWriteSize {
+						data = make([]byte, 4+fileSize)
+					} else if fileSize <= mc.maxPacketAllowed {
+						data = make([]byte, 4+mc.maxWriteSize)
+					} else {
+						err = fmt.Errorf("Local File '%s' too large: Size: %d, Max: %d", name, fileSize, mc.maxPacketAllowed)
+					}
+				}
+			}
 		} else {
 			err = fmt.Errorf("Local File '%s' is not registered. Use the DSN parameter 'allowAllFiles=true' to allow all files", name)
 		}
 	}
 
-	if rdc, ok := rdr.(io.ReadCloser); ok {
-		defer func() {
-			if err == nil {
-				err = rdc.Close()
-			} else {
-				rdc.Close()
-			}
-		}()
-	}
-
 	// send content packets
-	var ioErr error
 	if err == nil {
 		var n int
-		for err == nil && ioErr == nil {
+		for err == nil {
 			n, err = rdr.Read(data[4:])
 			if n > 0 {
-				ioErr = mc.writePacket(data[:4+n])
+				if ioErr := mc.writePacket(data[:4+n]); ioErr != nil {
+					return ioErr
+				}
 			}
 		}
 		if err == io.EOF {
 			err = nil
 		}
-		if ioErr != nil {
-			errLog.Print(ioErr.Error())
-			return driver.ErrBadConn
-		}
 	}
 
 	// send empty packet (termination)
-	ioErr = mc.writePacket([]byte{
-		0x00,
-		0x00,
-		0x00,
-		mc.sequence,
-	})
-	if ioErr != nil {
-		errLog.Print(ioErr.Error())
-		return driver.ErrBadConn
+	if data == nil {
+		data = make([]byte, 4)
+	}
+	if ioErr := mc.writePacket(data[:4]); ioErr != nil {
+		return ioErr
 	}
 
 	// read OK packet