Browse Source

Merge pull request #411 from go-sql-driver/multistmt

Multistatements and multi results
Julien Schmidt 10 years ago
parent
commit
b4db83c6fa
8 changed files with 174 additions and 17 deletions
  1. 2 0
      AUTHORS
  2. 10 0
      README.md
  3. 81 0
      driver_test.go
  4. 9 0
      dsn.go
  5. 14 13
      dsn_test.go
  6. 50 2
      packets.go
  7. 7 1
      rows.go
  8. 1 1
      statement.go

+ 2 - 0
AUTHORS

@@ -31,6 +31,7 @@ Julien Schmidt <go-sql-driver at julienschmidt.com>
 Kamil Dziedzic <kamil at klecza.pl>
 Kevin Malachowski <kevin at chowski.com>
 Leonardo YongUk Kim <dalinaum at gmail.com>
+Luca Looz <luca.looz92 at gmail.com>
 Lucas Liu <extrafliu at gmail.com>
 Luke Scott <luke at webconnex.com>
 Michael Woolnough <michael.woolnough at gmail.com>
@@ -38,6 +39,7 @@ Nicola Peduzzi <thenikso at gmail.com>
 Runrioter Wung <runrioter at gmail.com>
 Soroush Pour <me at soroushjp.com>
 Stan Putrya <root.vagner at gmail.com>
+Stanley Gunawan <gunawan.stanley at gmail.com>
 Xiaobing Jiang <s7v7nislands at gmail.com>
 Xiuming Chen <cc at cxm.cc>
 

+ 10 - 0
README.md

@@ -219,6 +219,16 @@ Note that this sets the location for time.Time values but does not change MySQL'
 
 Please keep in mind, that param values must be [url.QueryEscape](http://golang.org/pkg/net/url/#QueryEscape)'ed. Alternatively you can manually replace the `/` with `%2F`. For example `US/Pacific` would be `loc=US%2FPacific`.
 
+##### `multiStatements`
+
+```
+Type:           bool
+Valid Values:   true, false
+Default:        false
+```
+
+Allow multiple statements in one query. While this allows batch queries, it also greatly increases the risk of SQL injections. Only the result of the first query is returned, all other results are silently discarded.
+
 
 ##### `parseTime`
 

+ 81 - 0
driver_test.go

@@ -76,6 +76,28 @@ type DBTest struct {
 	db *sql.DB
 }
 
+func runTestsWithMultiStatement(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
+	if !available {
+		t.Skipf("MySQL server not running on %s", netAddr)
+	}
+
+	dsn += "&multiStatements=true"
+	var db *sql.DB
+	if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation {
+		db, err = sql.Open("mysql", dsn)
+		if err != nil {
+			t.Fatalf("error connecting: %s", err.Error())
+		}
+		defer db.Close()
+	}
+
+	dbt := &DBTest{t, db}
+	for _, test := range tests {
+		test(dbt)
+		dbt.db.Exec("DROP TABLE IF EXISTS test")
+	}
+}
+
 func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
 	if !available {
 		t.Skipf("MySQL server not running on %s", netAddr)
@@ -99,8 +121,19 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
 		defer db2.Close()
 	}
 
+	dsn3 := dsn + "&multiStatements=true"
+	var db3 *sql.DB
+	if _, err := ParseDSN(dsn3); err != errInvalidDSNUnsafeCollation {
+		db3, err = sql.Open("mysql", dsn3)
+		if err != nil {
+			t.Fatalf("error connecting: %s", err.Error())
+		}
+		defer db3.Close()
+	}
+
 	dbt := &DBTest{t, db}
 	dbt2 := &DBTest{t, db2}
+	dbt3 := &DBTest{t, db3}
 	for _, test := range tests {
 		test(dbt)
 		dbt.db.Exec("DROP TABLE IF EXISTS test")
@@ -108,6 +141,10 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
 			test(dbt2)
 			dbt2.db.Exec("DROP TABLE IF EXISTS test")
 		}
+		if db3 != nil {
+			test(dbt3)
+			dbt3.db.Exec("DROP TABLE IF EXISTS test")
+		}
 	}
 }
 
@@ -237,6 +274,50 @@ func TestCRUD(t *testing.T) {
 	})
 }
 
+func TestMultiQuery(t *testing.T) {
+	runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
+		// Create Table
+		dbt.mustExec("CREATE TABLE `test` (`id` int(11) NOT NULL, `value` int(11) NOT NULL) ")
+
+		// Create Data
+		res := dbt.mustExec("INSERT INTO test VALUES (1, 1)")
+		count, err := res.RowsAffected()
+		if err != nil {
+			dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
+		}
+		if count != 1 {
+			dbt.Fatalf("expected 1 affected row, got %d", count)
+		}
+
+		// Update
+		res = dbt.mustExec("UPDATE test SET value = 3 WHERE id = 1; UPDATE test SET value = 4 WHERE id = 1; UPDATE test SET value = 5 WHERE id = 1;")
+		count, err = res.RowsAffected()
+		if err != nil {
+			dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
+		}
+		if count != 1 {
+			dbt.Fatalf("expected 1 affected row, got %d", count)
+		}
+
+		// Read
+		var out int
+		rows := dbt.mustQuery("SELECT value FROM test WHERE id=1;")
+		if rows.Next() {
+			rows.Scan(&out)
+			if 5 != out {
+				dbt.Errorf("5 != %t", out)
+			}
+
+			if rows.Next() {
+				dbt.Error("unexpected data")
+			}
+		} else {
+			dbt.Error("no data")
+		}
+
+	})
+}
+
 func TestInt(t *testing.T) {
 	runTests(t, dsn, func(dbt *DBTest) {
 		types := [5]string{"TINYINT", "SMALLINT", "MEDIUMINT", "INT", "BIGINT"}

+ 9 - 0
dsn.go

@@ -46,6 +46,7 @@ type Config struct {
 	ClientFoundRows         bool // Return number of matching rows instead of rows changed
 	ColumnsWithAlias        bool // Prepend table alias to column names
 	InterpolateParams       bool // Interpolate placeholders into query string
+	MultiStatements         bool // Allow multiple statements in one query
 	ParseTime               bool // Parse time values to time.Time
 	Strict                  bool // Return warnings as errors
 }
@@ -235,6 +236,14 @@ func parseDSNParams(cfg *Config, params string) (err error) {
 				return
 			}
 
+		// multiple statements in one query
+		case "multiStatements":
+			var isBool bool
+			cfg.MultiStatements, isBool = readBool(value)
+			if !isBool {
+				return errors.New("invalid bool value: " + value)
+			}
+
 		// time.Time parsing
 		case "parseTime":
 			var isBool bool

+ 14 - 13
dsn_test.go

@@ -19,19 +19,20 @@ var testDSNs = []struct {
 	in  string
 	out string
 }{
-	{"username:password@protocol(address)/dbname?param=value", "&{User:username Passwd:password Net:protocol Addr:address DBName:dbname Params:map[param:value] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"},
-	{"username:password@protocol(address)/dbname?param=value&columnsWithAlias=true", "&{User:username Passwd:password Net:protocol Addr:address DBName:dbname Params:map[param:value] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:true InterpolateParams:false ParseTime:false Strict:false}"},
-	{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{User:user Passwd: Net:unix Addr:/path/to/socket DBName:dbname Params:map[charset:utf8] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"},
-	{"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{User:user Passwd:password Net:tcp Addr:localhost:5555 DBName:dbname Params:map[charset:utf8] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"},
-	{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{User:user Passwd:password Net:tcp Addr:localhost:5555 DBName:dbname Params:map[charset:utf8mb4,utf8] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"},
-	{"user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci", "&{User:user Passwd:password Net:tcp Addr:127.0.0.1:3306 DBName:dbname Params:map[] Loc:UTC TLS:<nil> Timeout:30s ReadTimeout:1s WriteTimeout:1s Collation:224 AllowAllFiles:true AllowCleartextPasswords:false AllowOldPasswords:true ClientFoundRows:true ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"},
-	{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{User:user Passwd:p@ss(word) Net:tcp Addr:[de:ad:be:ef::ca:fe]:80 DBName:dbname Params:map[] Loc:Local TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"},
-	{"/dbname", "&{User: Passwd: Net:tcp Addr:127.0.0.1:3306 DBName:dbname Params:map[] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"},
-	{"@/", "&{User: Passwd: Net:tcp Addr:127.0.0.1:3306 DBName: Params:map[] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"},
-	{"/", "&{User: Passwd: Net:tcp Addr:127.0.0.1:3306 DBName: Params:map[] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"},
-	{"", "&{User: Passwd: Net:tcp Addr:127.0.0.1:3306 DBName: Params:map[] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"},
-	{"user:p@/ssword@/", "&{User:user Passwd:p@/ssword Net:tcp Addr:127.0.0.1:3306 DBName: Params:map[] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"},
-	{"unix/?arg=%2Fsome%2Fpath.ext", "&{User: Passwd: Net:unix Addr:/tmp/mysql.sock DBName: Params:map[arg:/some/path.ext] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"},
+	{"username:password@protocol(address)/dbname?param=value", "&{User:username Passwd:password Net:protocol Addr:address DBName:dbname Params:map[param:value] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
+	{"username:password@protocol(address)/dbname?param=value&columnsWithAlias=true", "&{User:username Passwd:password Net:protocol Addr:address DBName:dbname Params:map[param:value] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:true InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
+	{"username:password@protocol(address)/dbname?param=value&columnsWithAlias=true&multiStatements=true", "&{User:username Passwd:password Net:protocol Addr:address DBName:dbname Params:map[param:value] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:true InterpolateParams:false MultiStatements:true ParseTime:false Strict:false}"},
+	{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{User:user Passwd: Net:unix Addr:/path/to/socket DBName:dbname Params:map[charset:utf8] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
+	{"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{User:user Passwd:password Net:tcp Addr:localhost:5555 DBName:dbname Params:map[charset:utf8] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
+	{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{User:user Passwd:password Net:tcp Addr:localhost:5555 DBName:dbname Params:map[charset:utf8mb4,utf8] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
+	{"user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci", "&{User:user Passwd:password Net:tcp Addr:127.0.0.1:3306 DBName:dbname Params:map[] Loc:UTC TLS:<nil> Timeout:30s ReadTimeout:1s WriteTimeout:1s Collation:224 AllowAllFiles:true AllowCleartextPasswords:false AllowOldPasswords:true ClientFoundRows:true ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
+	{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{User:user Passwd:p@ss(word) Net:tcp Addr:[de:ad:be:ef::ca:fe]:80 DBName:dbname Params:map[] Loc:Local TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
+	{"/dbname", "&{User: Passwd: Net:tcp Addr:127.0.0.1:3306 DBName:dbname Params:map[] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
+	{"@/", "&{User: Passwd: Net:tcp Addr:127.0.0.1:3306 DBName: Params:map[] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
+	{"/", "&{User: Passwd: Net:tcp Addr:127.0.0.1:3306 DBName: Params:map[] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
+	{"", "&{User: Passwd: Net:tcp Addr:127.0.0.1:3306 DBName: Params:map[] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
+	{"user:p@/ssword@/", "&{User:user Passwd:p@/ssword Net:tcp Addr:127.0.0.1:3306 DBName: Params:map[] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
+	{"unix/?arg=%2Fsome%2Fpath.ext", "&{User: Passwd: Net:unix Addr:/tmp/mysql.sock DBName: Params:map[arg:/some/path.ext] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
 }
 
 func TestDSNParser(t *testing.T) {

+ 50 - 2
packets.go

@@ -224,6 +224,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
 		clientTransactions |
 		clientLocalFiles |
 		clientPluginAuth |
+		clientMultiResults |
 		mc.flags&clientLongFlag
 
 	if mc.cfg.ClientFoundRows {
@@ -235,6 +236,10 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
 		clientFlags |= clientSSL
 	}
 
+	if mc.cfg.MultiStatements {
+		clientFlags |= clientMultiStatements
+	}
+
 	// User Password
 	scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd))
 
@@ -519,6 +524,10 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error {
 	}
 }
 
+func readStatus(b []byte) statusFlag {
+	return statusFlag(b[0]) | statusFlag(b[1])<<8
+}
+
 // Ok Packet
 // http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
 func (mc *mysqlConn) handleOkPacket(data []byte) error {
@@ -533,7 +542,10 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error {
 	mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])
 
 	// server_status [2 bytes]
-	mc.status = statusFlag(data[1+n+m]) | statusFlag(data[1+n+m+1])<<8
+	mc.status = readStatus(data[1+n+m : 1+n+m+2])
+	if err := mc.discardResults(); err != nil {
+		return err
+	}
 
 	// warning count [2 bytes]
 	if !mc.strict {
@@ -652,6 +664,11 @@ func (rows *textRows) readRow(dest []driver.Value) error {
 
 	// EOF Packet
 	if data[0] == iEOF && len(data) == 5 {
+		// server_status [2 bytes]
+		rows.mc.status = readStatus(data[3:])
+		if err := rows.mc.discardResults(); err != nil {
+			return err
+		}
 		rows.mc = nil
 		return io.EOF
 	}
@@ -709,6 +726,10 @@ func (mc *mysqlConn) readUntilEOF() error {
 		if err == nil && data[0] != iEOF {
 			continue
 		}
+		if err == nil && data[0] == iEOF && len(data) == 5 {
+			mc.status = readStatus(data[3:])
+		}
+
 		return err // Err or EOF
 	}
 }
@@ -1013,6 +1034,28 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 	return mc.writePacket(data)
 }
 
+func (mc *mysqlConn) discardResults() error {
+	for mc.status&statusMoreResultsExists != 0 {
+		resLen, err := mc.readResultSetHeaderPacket()
+		if err != nil {
+			return err
+		}
+		if resLen > 0 {
+			// columns
+			if err := mc.readUntilEOF(); err != nil {
+				return err
+			}
+			// rows
+			if err := mc.readUntilEOF(); err != nil {
+				return err
+			}
+		} else {
+			mc.status &^= statusMoreResultsExists
+		}
+	}
+	return nil
+}
+
 // http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html
 func (rows *binaryRows) readRow(dest []driver.Value) error {
 	data, err := rows.mc.readPacket()
@@ -1022,11 +1065,16 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
 
 	// packet indicator [1 byte]
 	if data[0] != iOK {
-		rows.mc = nil
 		// EOF Packet
 		if data[0] == iEOF && len(data) == 5 {
+			rows.mc.status = readStatus(data[3:])
+			if err := rows.mc.discardResults(); err != nil {
+				return err
+			}
+			rows.mc = nil
 			return io.EOF
 		}
+		rows.mc = nil
 
 		// Error otherwise
 		return rows.mc.handleErrorPacket(data)

+ 7 - 1
rows.go

@@ -38,7 +38,7 @@ type emptyRows struct{}
 
 func (rows *mysqlRows) Columns() []string {
 	columns := make([]string, len(rows.columns))
-	if rows.mc.cfg.ColumnsWithAlias {
+	if rows.mc != nil && rows.mc.cfg.ColumnsWithAlias {
 		for i := range columns {
 			if tableName := rows.columns[i].tableName; len(tableName) > 0 {
 				columns[i] = tableName + "." + rows.columns[i].name
@@ -65,6 +65,12 @@ func (rows *mysqlRows) Close() error {
 
 	// Remove unread packets from stream
 	err := mc.readUntilEOF()
+	if err == nil {
+		if err = mc.discardResults(); err != nil {
+			return err
+		}
+	}
+
 	rows.mc = nil
 	return err
 }

+ 1 - 1
statement.go

@@ -101,9 +101,9 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
 	}
 
 	rows := new(binaryRows)
-	rows.mc = mc
 
 	if resLen > 0 {
+		rows.mc = mc
 		// Columns
 		// If not cached, read them and cache them
 		if stmt.columns == nil {