Explorar el Código

LOAD DATA LOCAL INFILE support

closes #33
Julien Schmidt hace 12 años
padre
commit
5b516b3632
Se han modificado 5 ficheros con 255 adiciones y 19 borrados
  1. 27 12
      README.md
  2. 4 3
      const.go
  3. 99 1
      driver_test.go
  4. 115 0
      infile.go
  5. 10 3
      packets.go

+ 27 - 12
README.md

@@ -13,12 +13,13 @@ A MySQL-Driver for Go's [database/sql](http://golang.org/pkg/database/sql) packa
   * [Requirements](#requirements)
   * [Installation](#installation)
   * [Usage](#usage)
-  * [DSN (Data Source Name)](#dsn-data-source-name)
-    * [Password](#password)
-    * [Protocol](#protocol)
-    * [Address](#address)
-    * [Parameters](#parameters)
-    * [Examples](#examples)
+    * [DSN (Data Source Name)](#dsn-data-source-name)
+      * [Password](#password)
+      * [Protocol](#protocol)
+      * [Address](#address)
+      * [Parameters](#parameters)
+      * [Examples](#examples)
+    * [LOAD DATA LOCAL INFILE support](#load-data-local-infile-support) 
   * [Testing / Development](#testing--development)
   * [License](#license)
 
@@ -32,6 +33,7 @@ A MySQL-Driver for Go's [database/sql](http://golang.org/pkg/database/sql) packa
   * Automatic Connection-Pooling *(by database/sql package)*
   * Supports queries larger than 16MB
   * Intelligent `LONG DATA` handling in prepared statements
+  * Secure `LOAD DATA LOCAL INFILE` support with file Whitelisting and `io.Reader` support
 
 ## Requirements
   * Go 1.0.3 or higher
@@ -62,7 +64,7 @@ All further methods are listed here: http://golang.org/pkg/database/sql
 [Examples are available in our Wiki](https://github.com/go-sql-driver/mysql/wiki/Examples "Go-MySQL-Driver Examples").
 
 
-## DSN (Data Source Name)
+### DSN (Data Source Name)
 
 The Data Source Name has a common format, like e.g. [PEAR DB](http://pear.php.net/manual/en/package.database.db.intro-dsn.php) uses it, but without type-prefix (optional parts marked by squared brackets):
 ```
@@ -84,21 +86,21 @@ If you do not want to preselect a database, leave `dbname` empty:
 /
 ```
 
-### Password
+#### Password
 Passwords can consist of any character. Escaping is **not** necessary.
 
-### Protocol
+#### Protocol
 See [net.Dial](http://golang.org/pkg/net/#Dial) for more information which networks are available.
 In general you should use an Unix-socket if available and TCP otherwise for best performance.
 
-### Address
+#### Address
 For TCP and UDP networks, addresses have the form `host:port`.
 If `host` is a literal IPv6 address, it must be enclosed in square brackets.
 The functions [net.JoinHostPort](http://golang.org/pkg/net/#JoinHostPort) and [net.SplitHostPort](http://golang.org/pkg/net/#SplitHostPort) manipulate addresses in this form.
 
 For Unix-sockets the address is the absolute path to the MySQL-Server-socket, e.g. `/var/run/mysqld/mysqld.sock` or `/tmp/mysql.sock`.
 
-### Parameters
+#### Parameters
 **Parameters are case-sensitive!**
 
 Possible Parameters are:
@@ -113,7 +115,7 @@ All other parameters are interpreted as system variables:
   * `tx_isolation`: *"SET [tx_isolation](https://dev.mysql.com/doc/refman/5.5/en/server-system-variables.html#sysvar_tx_isolation)='`value`'"*
   * `param`: *"SET `param`=`value`"*
 
-### Examples
+#### Examples
 ```
 user@unix(/path/to/socket)/dbname
 ```
@@ -135,6 +137,19 @@ No Database preselected:
 user:password@/
 ```
 
+### `LOAD DATA LOCAL INFILE` support
+For this feature you need direct access to the package. Therefore you must change the import path (no `_`):
+```go
+import "github.com/go-sql-driver/mysql"
+```
+
+Files must be whitelisted by registering them with `mysql.RegisterLocalFile(filepath)` (reccommended) or the whitelist check must be deactivated by using the DSN parameter `allowAllFiles=true` (might be insecure).
+
+`io.Reader`s must be registered with `mysql.RegisterReader(name, reader)`. They are available with the filepath `Reader::<name>` then.
+
+See also the [godoc of Go-MySQL-Driver](http://godoc.org/github.com/go-sql-driver/mysql "golang mysql driver documentation")
+
+
 ## Testing / Development
 To run the driver tests you may need to adjust the configuration. See [this Wiki-Page](https://github.com/go-sql-driver/mysql/wiki/Testing "Testing") for details.
 

+ 4 - 3
const.go

@@ -19,9 +19,10 @@ const (
 // http://dev.mysql.com/doc/internals/en/client-server-protocol.html
 
 const (
-	iOK  byte = 0x00
-	iEOF byte = 0xfe
-	iERR byte = 0xff
+	iOK          byte = 0x00
+	iLocalInFile byte = 0xfb
+	iEOF         byte = 0xfe
+	iERR         byte = 0xff
 )
 
 type clientFlag uint32

+ 99 - 1
driver_test.go

@@ -3,6 +3,7 @@ package mysql
 import (
 	"database/sql"
 	"fmt"
+	"io/ioutil"
 	"net"
 	"os"
 	"strings"
@@ -570,7 +571,6 @@ func TestLongData(t *testing.T) {
 		if rows.Next() {
 			t.Error("LONGBLOB: unexpexted row")
 		}
-		//t.Fatalf("%d %d %d", len(in)+nonDataQueryLen, len(out)+nonDataQueryLen, maxAllowedPacketSize)
 	} else {
 		t.Fatalf("LONGBLOB: no data")
 	}
@@ -578,6 +578,104 @@ func TestLongData(t *testing.T) {
 	mustExec(t, db, "DROP TABLE IF EXISTS test")
 }
 
+func verifyLoadDataResult(t *testing.T, db *sql.DB) {
+	rows, err := db.Query("SELECT * FROM test")
+	if err != nil {
+		t.Fatal(err.Error())
+	}
+
+	i := 0
+	values := [4]string{
+		"a string",
+		"a string containing a \t",
+		"a string containing a \n",
+		"a string containing both \t\n",
+	}
+
+	var id int
+	var value string
+
+	for rows.Next() {
+		i++
+		err = rows.Scan(&id, &value)
+		if err != nil {
+			t.Fatal(err.Error())
+		}
+		if i != id {
+			t.Fatalf("%d != %d", i, id)
+		}
+		if values[i-1] != value {
+			t.Fatalf("%s != %s", values[i-1], value)
+		}
+	}
+	err = rows.Err()
+	if err != nil {
+		t.Fatal(err.Error())
+	}
+
+	if i != 4 {
+		t.Fatalf("Rows count mismatch. Got %d, want 4", i)
+	}
+}
+
+func TestLoadData(t *testing.T) {
+	if !getEnv() {
+		t.Logf("MySQL-Server not running on %s. Skipping TestLoadData", netAddr)
+		return
+	}
+
+	db, err := sql.Open("mysql", dsn)
+	if err != nil {
+		t.Fatalf("Error connecting: %v", err)
+	}
+	defer db.Close()
+
+	file, err := ioutil.TempFile("", "gotest")
+	defer os.Remove(file.Name())
+	if err != nil {
+		t.Fatal(err)
+	}
+	file.WriteString("1\ta string\n2\ta string containing a \\t\n3\ta string containing a \\n\n4\ta string containing both \\t\\n\n")
+	file.Close()
+
+	mustExec(t, db, "DROP TABLE IF EXISTS test")
+	mustExec(t, db, "CREATE TABLE test (id INT NOT NULL PRIMARY KEY, value TEXT NOT NULL) CHARACTER SET utf8 COLLATE utf8_unicode_ci")
+
+	// Local File
+	RegisterLocalFile(file.Name())
+	mustExec(t, db, "LOAD DATA LOCAL INFILE '"+file.Name()+"' INTO TABLE test")
+	verifyLoadDataResult(t, db)
+	// negative test
+	_, err = db.Exec("LOAD DATA LOCAL INFILE 'doesnotexist' INTO TABLE test")
+	if err == nil {
+		t.Fatal("Load non-existent file didn't fail")
+	} else if err.Error() != "Local File 'doesnotexist' is not registered. Use the DSN parameter 'allowAllFiles=true' to allow all files" {
+		t.Fatal(err.Error())
+	}
+
+	// Empty table
+	mustExec(t, db, "TRUNCATE TABLE test")
+
+	// Reader
+	file, err = os.Open(file.Name())
+	if err != nil {
+		t.Fatal(err)
+	}
+	RegisterReader("test", file)
+	mustExec(t, db, "LOAD DATA LOCAL INFILE 'Reader::test' INTO TABLE test")
+	verifyLoadDataResult(t, db)
+	// negative test
+	_, err = db.Exec("LOAD DATA LOCAL INFILE 'Reader::doesnotexist' INTO TABLE test")
+	if err == nil {
+		t.Fatal("Load non-existent Reader didn't fail")
+	} else if err.Error() != "Reader 'doesnotexist' is not registered" {
+		t.Fatal(err.Error())
+	}
+	file.Close()
+
+	mustExec(t, db, "DROP TABLE IF EXISTS test")
+}
+
 // Special cases
 
 func TestRowsClose(t *testing.T) {

+ 115 - 0
infile.go

@@ -0,0 +1,115 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2013 Julien Schmidt. All rights reserved.
+// http://www.julienschmidt.com
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mysql
+
+import (
+	"database/sql/driver"
+	"fmt"
+	"io"
+	"os"
+	"strings"
+)
+
+var (
+	fileRegister   map[string]bool
+	readerRegister map[string]io.Reader
+)
+
+func init() {
+	fileRegister = make(map[string]bool)
+	readerRegister = make(map[string]io.Reader)
+}
+
+// RegisterLocalFile adds the given file to the file whitelist,
+// so that it can be used by "LOAD DATA LOCAL INFILE <filepath".
+// Alternatively you can allow the use of all local files with
+// the DSN parameter 'allowAllFiles=true'
+func RegisterLocalFile(filepath string) {
+	fileRegister[filepath] = true
+}
+
+// RegisterReader registers a io.Reader so that it can be used by
+// "LOAD DATA LOCAL INFILE Reader::<name>".
+// The use of io.Reader in this context is NOT safe for concurrency!
+func RegisterReader(name string, rdr io.Reader) {
+	readerRegister[name] = rdr
+}
+
+func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
+	var rdr io.Reader
+	data := make([]byte, 4+mc.maxWriteSize)
+
+	if strings.HasPrefix(name, "Reader::") { // io.Reader
+		name = name[8:]
+		var inMap bool
+		rdr, inMap = readerRegister[name]
+		if rdr == nil {
+			if !inMap {
+				err = fmt.Errorf("Reader '%s' is not registered", name)
+			} else {
+				err = fmt.Errorf("Reader '%s' is <nil>", name)
+			}
+		}
+
+	} else { // File
+		if fileRegister[name] || mc.cfg.params[`allowAllFiles`] == `true` {
+			var file *os.File
+			file, err = os.Open(name)
+			defer file.Close()
+
+			rdr = file
+		} else {
+			err = fmt.Errorf("Local File '%s' is not registered. Use the DSN parameter 'allowAllFiles=true' to allow all files", name)
+		}
+	}
+
+	// send content packets
+	var ioErr error
+	if err == nil {
+		var n int
+		for err == nil && ioErr == nil {
+			n, err = rdr.Read(data[4:])
+			if n > 0 {
+				data[0] = byte(n)
+				data[1] = byte(n >> 8)
+				data[2] = byte(n >> 16)
+				data[3] = mc.sequence
+				ioErr = mc.writePacket(data[:4+n])
+			}
+		}
+		if err == io.EOF {
+			err = nil
+		}
+		if ioErr != nil {
+			errLog.Print(ioErr.Error())
+			return driver.ErrBadConn
+		}
+	}
+
+	// send empty packet (termination)
+	ioErr = mc.writePacket([]byte{
+		0x00,
+		0x00,
+		0x00,
+		mc.sequence,
+	})
+	if ioErr != nil {
+		errLog.Print(ioErr.Error())
+		return driver.ErrBadConn
+	}
+
+	// read OK packet
+	if err == nil {
+		return mc.readResultOK()
+	} else {
+		mc.readPacket()
+	}
+	return err
+}

+ 10 - 3
packets.go

@@ -206,7 +206,8 @@ func (mc *mysqlConn) writeAuthPacket() error {
 		clientProtocol41 |
 			clientSecureConn |
 			clientLongPassword |
-			clientTransactions,
+			clientTransactions |
+			clientLocalFiles,
 	)
 	if mc.flags&clientLongFlag > 0 {
 		clientFlags |= uint32(clientLongFlag)
@@ -369,11 +370,17 @@ func (mc *mysqlConn) readResultOK() error {
 func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
 	data, err := mc.readPacket()
 	if err == nil {
-		if data[0] == iOK {
+		switch data[0] {
+
+		case iOK:
 			mc.handleOkPacket(data)
 			return 0, nil
-		} else if data[0] == iERR {
+
+		case iERR:
 			return 0, mc.handleErrorPacket(data)
+
+		case iLocalInFile:
+			return 0, mc.handleInFileRequest(string(data[1:]))
 		}
 
 		// column count