浏览代码

infile: refactoring

Julien Schmidt 12 年之前
父节点
当前提交
d7e2ac4160
共有 1 个文件被更改,包括 28 次插入27 次删除
  1. 28 27
      infile.go

+ 28 - 27
infile.go

@@ -9,7 +9,6 @@
 package mysql
 
 import (
-	"database/sql/driver"
 	"fmt"
 	"io"
 	"os"
@@ -86,6 +85,16 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
 			rdr = handler()
 			if rdr != nil {
 				data = make([]byte, 4+mc.maxWriteSize)
+
+				if rdc, ok := rdr.(io.ReadCloser); ok {
+					defer func() {
+						if err == nil {
+							err = rdc.Close()
+						} else {
+							rdc.Close()
+						}
+					}()
+				}
 			} else {
 				err = fmt.Errorf("Reader '%s' is <nil>", name)
 			}
@@ -99,6 +108,15 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
 			var fi os.FileInfo
 
 			if file, err = os.Open(name); err == nil {
+				defer func() {
+					if err == nil {
+						err = file.Close()
+					} else {
+						file.Close()
+					}
+				}()
+
+				// get file size
 				if fi, err = file.Stat(); err == nil {
 					rdr = file
 					if fileSize := int(fi.Size()); fileSize <= mc.maxWriteSize {
@@ -115,45 +133,28 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
 		}
 	}
 
-	if rdc, ok := rdr.(io.ReadCloser); ok {
-		defer func() {
-			if err == nil {
-				err = rdc.Close()
-			} else {
-				rdc.Close()
-			}
-		}()
-	}
-
 	// send content packets
-	var ioErr error
 	if err == nil {
 		var n int
-		for err == nil && ioErr == nil {
+		for err == nil {
 			n, err = rdr.Read(data[4:])
 			if n > 0 {
-				ioErr = mc.writePacket(data[:4+n])
+				if ioErr := mc.writePacket(data[:4+n]); ioErr != nil {
+					return ioErr
+				}
 			}
 		}
 		if err == io.EOF {
 			err = nil
 		}
-		if ioErr != nil {
-			errLog.Print(ioErr.Error())
-			return driver.ErrBadConn
-		}
 	}
 
 	// send empty packet (termination)
-	ioErr = mc.writePacket([]byte{
-		0x00,
-		0x00,
-		0x00,
-		mc.sequence,
-	})
-	if ioErr != nil {
-		errLog.Print(ioErr.Error())
-		return driver.ErrBadConn
+	if data == nil {
+		data = make([]byte, 4)
+	}
+	if ioErr := mc.writePacket(data[:4]); ioErr != nil {
+		return ioErr
 	}
 
 	// read OK packet