Browse Source

LONG DATA handling

prerequisite for issue #33
Julien Schmidt 12 years ago
parent
commit
e9a3f83b7e
8 changed files with 302 additions and 54 deletions
  1. 2 0
      README.md
  2. 40 10
      connection.go
  3. 3 5
      const.go
  4. 15 2
      driver.go
  5. 77 0
      driver_test.go
  6. 1 0
      errors.go
  7. 154 37
      packets.go
  8. 10 0
      utils.go

+ 2 - 0
README.md

@@ -30,6 +30,8 @@ A MySQL-Driver for Go's [database/sql](http://golang.org/pkg/database/sql) packa
   * Connections over TCP/IPv4, TCP/IPv6 or Unix Sockets
   * Automatic handling of broken connections
   * Automatic Connection-Pooling *(by database/sql package)*
+  * Supports queries larger than 16MB
+  * Intelligent `LONG DATA` handling in prepared statements
 
 ## Requirements
   * Go 1.0.3 or higher

+ 40 - 10
connection.go

@@ -17,16 +17,18 @@ import (
 )
 
 type mysqlConn struct {
-	cfg          *config
-	flags        clientFlag
-	charset      byte
-	cipher       []byte
-	netConn      net.Conn
-	buf          *buffer
-	protocol     uint8
-	sequence     uint8
-	affectedRows uint64
-	insertId     uint64
+	cfg              *config
+	flags            clientFlag
+	charset          byte
+	cipher           []byte
+	netConn          net.Conn
+	buf              *buffer
+	protocol         uint8
+	sequence         uint8
+	affectedRows     uint64
+	insertId         uint64
+	maxPacketAllowed int
+	maxWriteSize     int
 }
 
 type config struct {
@@ -192,3 +194,31 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
 	// with args, must use prepared stmt
 	return nil, driver.ErrSkip
 }
+
+// Gets the value of the given MySQL System Variable
+func (mc *mysqlConn) getSystemVar(name string) (val []byte, err error) {
+	// Send command
+	err = mc.writeCommandPacketStr(comQuery, "SELECT @@"+name)
+	if err == nil {
+		// Read Result
+		var resLen int
+		resLen, err = mc.readResultSetHeaderPacket()
+		if err == nil {
+			rows := &mysqlRows{mc, false, nil, false}
+
+			if resLen > 0 {
+				// Columns
+				rows.columns, err = mc.readColumns(resLen)
+			}
+
+			dest := make([]driver.Value, resLen)
+			err = rows.readRow(dest)
+			if err == nil {
+				val = dest[0].([]byte)
+				err = mc.readUntilEOF()
+			}
+		}
+	}
+
+	return
+}

+ 3 - 5
const.go

@@ -11,8 +11,8 @@ package mysql
 
 const (
 	minProtocolVersion byte = 10
-	//maxPacketSize      = 1<<24 - 1
-	timeFormat = "2006-01-02 15:04:05"
+	maxPacketSize           = 1<<24 - 1
+	timeFormat              = "2006-01-02 15:04:05"
 )
 
 // MySQL constants documentation:
@@ -47,10 +47,8 @@ const (
 	clientMultiResults
 )
 
-type commandType byte
-
 const (
-	comQuit commandType = iota + 1
+	comQuit byte = iota + 1
 	comInitDB
 	comQuery
 	comFieldList

+ 15 - 2
driver.go

@@ -24,8 +24,11 @@ func (d *mysqlDriver) Open(dsn string) (driver.Conn, error) {
 	var err error
 
 	// New mysqlConn
-	mc := new(mysqlConn)
-	mc.cfg = parseDSN(dsn)
+	mc := &mysqlConn{
+		cfg:              parseDSN(dsn),
+		maxPacketAllowed: maxPacketSize,
+		maxWriteSize:     maxPacketSize - 1,
+	}
 
 	// Connect to Server
 	if _, ok := mc.cfg.params["timeout"]; ok { // with timeout
@@ -60,6 +63,16 @@ func (d *mysqlDriver) Open(dsn string) (driver.Conn, error) {
 		return nil, err
 	}
 
+	// Get max allowed packet size
+	maxap, err := mc.getSystemVar("max_allowed_packet")
+	if err != nil {
+		return nil, err
+	}
+	mc.maxPacketAllowed = stringToInt(maxap) - 1
+	if mc.maxPacketAllowed < maxPacketSize {
+		mc.maxWriteSize = mc.maxPacketAllowed
+	}
+
 	// Handle DSN Params
 	err = mc.handleParams()
 	if err != nil {

+ 77 - 0
driver_test.go

@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"net"
 	"os"
+	"strings"
 	"sync"
 	"testing"
 )
@@ -57,6 +58,9 @@ func getEnv() bool {
 func mustExec(t *testing.T, db *sql.DB, query string, args ...interface{}) (res sql.Result) {
 	res, err := db.Exec(query, args...)
 	if err != nil {
+		if len(query) > 300 {
+			query = "[query too large to print]"
+		}
 		t.Fatalf("Error on Exec %q: %v", query, err)
 	}
 	return
@@ -65,6 +69,9 @@ func mustExec(t *testing.T, db *sql.DB, query string, args ...interface{}) (res
 func mustQuery(t *testing.T, db *sql.DB, query string, args ...interface{}) (rows *sql.Rows) {
 	rows, err := db.Query(query, args...)
 	if err != nil {
+		if len(query) > 300 {
+			query = "[query too large to print]"
+		}
 		t.Fatalf("Error on Query %q: %v", query, err)
 	}
 	return
@@ -501,6 +508,76 @@ func TestNULL(t *testing.T) {
 	mustExec(t, db, "DROP TABLE IF EXISTS test")
 }
 
+func TestLongData(t *testing.T) {
+	if !getEnv() {
+		t.Logf("MySQL-Server not running on %s. Skipping TestLongData", netAddr)
+		return
+	}
+
+	db, err := sql.Open("mysql", dsn)
+	if err != nil {
+		t.Fatalf("Error connecting: %v", err)
+	}
+	defer db.Close()
+
+	var maxAllowedPacketSize int
+	err = db.QueryRow("select @@max_allowed_packet").Scan(&maxAllowedPacketSize)
+	if err != nil {
+		t.Fatal(err)
+	}
+	maxAllowedPacketSize--
+
+	// don't get too ambitious
+	if maxAllowedPacketSize > 1<<25 {
+		maxAllowedPacketSize = 1 << 25
+	}
+
+	mustExec(t, db, "DROP TABLE IF EXISTS test")
+	mustExec(t, db, "CREATE TABLE test (value LONGBLOB) CHARACTER SET utf8 COLLATE utf8_unicode_ci")
+
+	in := strings.Repeat(`0`, maxAllowedPacketSize+1)
+	var out string
+	var rows *sql.Rows
+
+	// Long text data
+	const nonDataQueryLen = 28 // length query w/o value
+	inS := in[:maxAllowedPacketSize-nonDataQueryLen]
+	mustExec(t, db, "INSERT INTO test VALUES('"+inS+"')")
+	rows = mustQuery(t, db, "SELECT value FROM test")
+	if rows.Next() {
+		rows.Scan(&out)
+		if inS != out {
+			t.Fatalf("LONGBLOB: length in: %d, length out: %d", len(inS), len(out))
+		}
+		if rows.Next() {
+			t.Error("LONGBLOB: unexpexted row")
+		}
+	} else {
+		t.Fatalf("LONGBLOB: no data")
+	}
+
+	// Empty table
+	mustExec(t, db, "TRUNCATE TABLE test")
+
+	// Long binary data
+	mustExec(t, db, "INSERT INTO test VALUES(?)", in)
+	rows = mustQuery(t, db, "SELECT value FROM test WHERE 1=?", 1)
+	if rows.Next() {
+		rows.Scan(&out)
+		if in != out {
+			t.Fatalf("LONGBLOB: length in: %d, length out: %d", len(in), len(out))
+		}
+		if rows.Next() {
+			t.Error("LONGBLOB: unexpexted row")
+		}
+		//t.Fatalf("%d %d %d", len(in)+nonDataQueryLen, len(out)+nonDataQueryLen, maxAllowedPacketSize)
+	} else {
+		t.Fatalf("LONGBLOB: no data")
+	}
+
+	mustExec(t, db, "DROP TABLE IF EXISTS test")
+}
+
 // Special cases
 
 func TestRowsClose(t *testing.T) {

+ 1 - 0
errors.go

@@ -16,4 +16,5 @@ var (
 	errPktSync     = errors.New("Commands out of sync. You can't run this command now")
 	errPktSyncMul  = errors.New("Commands out of sync. Did you run multiple statements at once?")
 	errOldPassword = errors.New("It seems like you are using old_passwords, which is unsupported. See https://github.com/Go-SQL-Driver/MySQL/wiki/old_passwords")
+	errPktTooLarge = errors.New("Packet for query is too large. You can change this value on the server by adjusting the 'max_allowed_packet' variable.")
 )

+ 154 - 37
packets.go

@@ -55,7 +55,16 @@ func (mc *mysqlConn) readPacket() (data []byte, err error) {
 	data = make([]byte, pktLen)
 	err = mc.buf.read(data)
 	if err == nil {
-		return data, nil
+		if pktLen < maxPacketSize {
+			return data, nil
+		}
+
+		// More data
+		var data2 []byte
+		data2, err = mc.readPacket()
+		if err == nil {
+			return append(data, data2...), nil
+		}
 	}
 	errLog.Print(err.Error())
 	return nil, driver.ErrBadConn
@@ -64,19 +73,63 @@ func (mc *mysqlConn) readPacket() (data []byte, err error) {
 // Write packet buffer 'data'
 // The packet header must be already included
 func (mc *mysqlConn) writePacket(data []byte) error {
-	// Write packet
-	n, err := mc.netConn.Write(data)
-	if err == nil && n == len(data) {
-		mc.sequence++
-		return nil
+	if len(data)-4 <= mc.maxWriteSize { // Can send data at once
+		// Write packet
+		n, err := mc.netConn.Write(data)
+		if err == nil && n == len(data) {
+			mc.sequence++
+			return nil
+		}
+
+		// Handle error
+		if err == nil { // n != len(data)
+			errLog.Print(errMalformPkt.Error())
+		} else {
+			errLog.Print(err.Error())
+		}
+		return driver.ErrBadConn
 	}
 
-	if err == nil { // n != len(data)
-		errLog.Print(errMalformPkt.Error())
-	} else {
-		errLog.Print(err.Error())
+	// Must split packet
+	return mc.splitPacket(data)
+}
+
+func (mc *mysqlConn) splitPacket(data []byte) (err error) {
+	pktLen := len(data) - 4
+
+	if pktLen > mc.maxPacketAllowed {
+		return errPktTooLarge
+	}
+
+	for pktLen >= maxPacketSize {
+		data[0] = 0xff
+		data[1] = 0xff
+		data[2] = 0xff
+		data[3] = mc.sequence
+
+		// Write packet
+		n, err := mc.netConn.Write(data[:4+maxPacketSize])
+		if err == nil && n == 4+maxPacketSize {
+			mc.sequence++
+			data = data[maxPacketSize:]
+			pktLen -= maxPacketSize
+			continue
+		}
+
+		// Handle error
+		if err == nil { // n != len(data)
+			errLog.Print(errMalformPkt.Error())
+		} else {
+			errLog.Print(err.Error())
+		}
+		return driver.ErrBadConn
 	}
-	return driver.ErrBadConn
+
+	data[0] = byte(pktLen)
+	data[1] = byte(pktLen >> 8)
+	data[2] = byte(pktLen >> 16)
+	data[3] = mc.sequence
+	return mc.writePacket(data)
 }
 
 /******************************************************************************
@@ -186,10 +239,10 @@ func (mc *mysqlConn) writeAuthPacket() error {
 	data[6] = byte(clientFlags >> 16)
 	data[7] = byte(clientFlags >> 24)
 
-	// MaxPacketSize [32 bit] (1<<24 - 1)
-	data[8] = 0xff
-	data[9] = 0xff
-	data[10] = 0xff
+	// MaxPacketSize [32 bit] (none)
+	//data[8] = 0x00
+	//data[9] = 0x00
+	//data[10] = 0x00
 	//data[11] = 0x00
 
 	// Charset [1 byte]
@@ -223,7 +276,7 @@ func (mc *mysqlConn) writeAuthPacket() error {
 *                             Command Packets                                 *
 ******************************************************************************/
 
-func (mc *mysqlConn) writeCommandPacket(command commandType) error {
+func (mc *mysqlConn) writeCommandPacket(command byte) error {
 	// Reset Packet Sequence
 	mc.sequence = 0
 
@@ -233,14 +286,14 @@ func (mc *mysqlConn) writeCommandPacket(command commandType) error {
 		0x05, // 5 bytes long
 		0x00,
 		0x00,
-		mc.sequence,
+		0x00, // mc.sequence
 
 		// Add command byte
-		byte(command),
+		command,
 	})
 }
 
-func (mc *mysqlConn) writeCommandPacketStr(command commandType, arg string) error {
+func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
 	// Reset Packet Sequence
 	mc.sequence = 0
 
@@ -251,10 +304,10 @@ func (mc *mysqlConn) writeCommandPacketStr(command commandType, arg string) erro
 	data[0] = byte(pktLen)
 	data[1] = byte(pktLen >> 8)
 	data[2] = byte(pktLen >> 16)
-	data[3] = mc.sequence
+	//data[3] = mc.sequence
 
 	// Add command byte
-	data[4] = byte(command)
+	data[4] = command
 
 	// Add arg
 	copy(data[5:], arg)
@@ -263,7 +316,7 @@ func (mc *mysqlConn) writeCommandPacketStr(command commandType, arg string) erro
 	return mc.writePacket(data)
 }
 
-func (mc *mysqlConn) writeCommandPacketUint32(command commandType, arg uint32) error {
+func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
 	// Reset Packet Sequence
 	mc.sequence = 0
 
@@ -273,10 +326,10 @@ func (mc *mysqlConn) writeCommandPacketUint32(command commandType, arg uint32) e
 		0x05, // 5 bytes long
 		0x00,
 		0x00,
-		mc.sequence,
+		0x00, // mc.sequence
 
 		// Add command byte
-		byte(command),
+		command,
 
 		// Add arg [32 bit]
 		byte(arg),
@@ -556,6 +609,54 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (columnCount uint16, err error)
 	return
 }
 
+// http://dev.mysql.com/doc/internals/en/prepared-statements.html#com-stmt-send-long-data
+func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) (err error) {
+	maxLen := stmt.mc.maxPacketAllowed - 1
+	pktLen := maxLen
+	argLen := len(arg)
+	data := make([]byte, 4+1+4+2+argLen)
+	copy(data[4+1+4+2:], arg)
+
+	for argLen > 0 {
+		if 1+4+2+argLen < maxLen {
+			pktLen = 1 + 4 + 2 + argLen
+		}
+
+		// Add the packet header [24bit length + 1 byte sequence]
+		data[0] = byte(pktLen)
+		data[1] = byte(pktLen >> 8)
+		data[2] = byte(pktLen >> 16)
+		data[3] = 0x00 // mc.sequence
+
+		// Add command byte [1 byte]
+		data[4] = comStmtSendLongData
+
+		// Add stmtID [32 bit]
+		data[5] = byte(stmt.id)
+		data[6] = byte(stmt.id >> 8)
+		data[7] = byte(stmt.id >> 16)
+		data[8] = byte(stmt.id >> 24)
+
+		// Add paramID [16 bit]
+		data[9] = byte(paramID)
+		data[10] = byte(paramID >> 8)
+
+		// Send CMD packet
+		err = stmt.mc.writePacket(data[:4+pktLen])
+		if err == nil {
+			argLen -= pktLen - (1 + 4 + 2)
+			data = data[pktLen-(1+4+2):]
+			continue
+		}
+		return err
+
+	}
+
+	// Reset Packet Sequence
+	stmt.mc.sequence = 0
+	return nil
+}
+
 // Execute Prepared Statement
 // http://dev.mysql.com/doc/internals/en/prepared-statements.html#com-stmt-execute
 func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
@@ -609,21 +710,37 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 
 		case []byte:
 			paramTypes[i<<1] = fieldTypeString
-			paramValues[i] = append(
-				lengthEncodedIntegerToBytes(uint64(len(v))),
-				v...,
-			)
-			pktLen += len(paramValues[i])
-			continue
+			if len(v) < stmt.mc.maxPacketAllowed-pktLen-(stmt.paramCount-(i+1))*64 {
+				paramValues[i] = append(
+					lengthEncodedIntegerToBytes(uint64(len(v))),
+					v...,
+				)
+				pktLen += len(paramValues[i])
+				continue
+			} else {
+				err := stmt.writeCommandLongData(i, v)
+				if err == nil {
+					continue
+				}
+				return err
+			}
 
 		case string:
 			paramTypes[i<<1] = fieldTypeString
-			paramValues[i] = append(
-				lengthEncodedIntegerToBytes(uint64(len(v))),
-				[]byte(v)...,
-			)
-			pktLen += len(paramValues[i])
-			continue
+			if len(v) < stmt.mc.maxPacketAllowed-pktLen-(stmt.paramCount-(i+1))*64 {
+				paramValues[i] = append(
+					lengthEncodedIntegerToBytes(uint64(len(v))),
+					[]byte(v)...,
+				)
+				pktLen += len(paramValues[i])
+				continue
+			} else {
+				err := stmt.writeCommandLongData(i, []byte(v))
+				if err == nil {
+					continue
+				}
+				return err
+			}
 
 		case time.Time:
 			paramTypes[i<<1] = fieldTypeString
@@ -649,7 +766,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 	data[3] = stmt.mc.sequence
 
 	// command [1 byte]
-	data[4] = byte(comStmtExecute)
+	data[4] = comStmtExecute
 
 	// statement_id [4 bytes]
 	data[5] = byte(stmt.id)

+ 10 - 0
utils.go

@@ -149,6 +149,16 @@ func uint64ToString(n uint64) []byte {
 	return a[i:]
 }
 
+// treats string value as unsigned integer representation
+func stringToInt(b []byte) int {
+	val := 0
+	for i := range b {
+		val *= 10
+		val += int(b[i] - 0x30)
+	}
+	return val
+}
+
 func readLengthEnodedString(b []byte) ([]byte, bool, int, error) {
 	// Get length
 	num, isNull, n := readLengthEncodedInteger(b)