Browse Source

Add Multi-Results support (#537)

Jacek Szwec 8 years ago
parent
commit
ffa70d4981
6 changed files with 377 additions and 86 deletions
  1. 1 0
      AUTHORS
  2. 24 10
      connection.go
  3. 190 0
      driver_go18_test.go
  4. 27 33
      packets.go
  5. 99 23
      rows.go
  6. 36 20
      statement.go

+ 1 - 0
AUTHORS

@@ -25,6 +25,7 @@ Hanno Braun <mail at hannobraun.com>
 Henri Yandell <flamefew at gmail.com>
 Hirotaka Yamamoto <ymmt2005 at gmail.com>
 INADA Naoki <songofacandy at gmail.com>
+Jacek Szwec <szwec.jacek at gmail.com>
 James Harr <james.harr at gmail.com>
 Jian Zhen <zhenjl at gmail.com>
 Joshua Prunier <joshua.prunier at gmail.com>

+ 24 - 10
connection.go

@@ -10,6 +10,7 @@ package mysql
 
 import (
 	"database/sql/driver"
+	"io"
 	"net"
 	"strconv"
 	"strings"
@@ -289,22 +290,29 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err
 // Internal function to execute commands
 func (mc *mysqlConn) exec(query string) error {
 	// Send command
-	err := mc.writeCommandPacketStr(comQuery, query)
-	if err != nil {
+	if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
 		return err
 	}
 
 	// Read Result
 	resLen, err := mc.readResultSetHeaderPacket()
-	if err == nil && resLen > 0 {
-		if err = mc.readUntilEOF(); err != nil {
+	if err != nil {
+		return err
+	}
+
+	if resLen > 0 {
+		// columns
+		if err := mc.readUntilEOF(); err != nil {
 			return err
 		}
 
-		err = mc.readUntilEOF()
+		// rows
+		if err := mc.readUntilEOF(); err != nil {
+			return err
+		}
 	}
 
-	return err
+	return mc.discardResults()
 }
 
 func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
@@ -335,11 +343,17 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
 			rows.mc = mc
 
 			if resLen == 0 {
-				// no columns, no more data
-				return emptyRows{}, nil
+				rows.rs.done = true
+
+				switch err := rows.NextResultSet(); err {
+				case nil, io.EOF:
+					return rows, nil
+				default:
+					return nil, err
+				}
 			}
 			// Columns
-			rows.columns, err = mc.readColumns(resLen)
+			rows.rs.columns, err = mc.readColumns(resLen)
 			return rows, err
 		}
 	}
@@ -359,7 +373,7 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
 	if err == nil {
 		rows := new(textRows)
 		rows.mc = mc
-		rows.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
+		rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
 
 		if resLen > 0 {
 			// Columns

+ 190 - 0
driver_go18_test.go

@@ -0,0 +1,190 @@
+// +build go1.8
+
+package mysql
+
+import (
+	"database/sql"
+	"fmt"
+	"reflect"
+	"testing"
+)
+
+func TestMultiResultSet(t *testing.T) {
+	type result struct {
+		values  [][]int
+		columns []string
+	}
+
+	// checkRows is a helper test function to validate rows containing 3 result
+	// sets with specific values and columns. The basic query would look like this:
+	//
+	// SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4;
+	// SELECT 0 UNION SELECT 1;
+	// SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6;
+	//
+	// to distinguish test cases the first string argument is put in front of
+	// every error or fatal message.
+	checkRows := func(desc string, rows *sql.Rows, dbt *DBTest) {
+		expected := []result{
+			{
+				values:  [][]int{{1, 2}, {3, 4}},
+				columns: []string{"col1", "col2"},
+			},
+			{
+				values:  [][]int{{1, 2, 3}, {4, 5, 6}},
+				columns: []string{"col1", "col2", "col3"},
+			},
+		}
+
+		var res1 result
+		for rows.Next() {
+			var res [2]int
+			if err := rows.Scan(&res[0], &res[1]); err != nil {
+				dbt.Fatal(err)
+			}
+			res1.values = append(res1.values, res[:])
+		}
+
+		cols, err := rows.Columns()
+		if err != nil {
+			dbt.Fatal(desc, err)
+		}
+		res1.columns = cols
+
+		if !reflect.DeepEqual(expected[0], res1) {
+			dbt.Error(desc, "want =", expected[0], "got =", res1)
+		}
+
+		if !rows.NextResultSet() {
+			dbt.Fatal(desc, "expected next result set")
+		}
+
+		// ignoring one result set
+
+		if !rows.NextResultSet() {
+			dbt.Fatal(desc, "expected next result set")
+		}
+
+		var res2 result
+		cols, err = rows.Columns()
+		if err != nil {
+			dbt.Fatal(desc, err)
+		}
+		res2.columns = cols
+
+		for rows.Next() {
+			var res [3]int
+			if err := rows.Scan(&res[0], &res[1], &res[2]); err != nil {
+				dbt.Fatal(desc, err)
+			}
+			res2.values = append(res2.values, res[:])
+		}
+
+		if !reflect.DeepEqual(expected[1], res2) {
+			dbt.Error(desc, "want =", expected[1], "got =", res2)
+		}
+
+		if rows.NextResultSet() {
+			dbt.Error(desc, "unexpected next result set")
+		}
+
+		if err := rows.Err(); err != nil {
+			dbt.Error(desc, err)
+		}
+	}
+
+	runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
+		rows := dbt.mustQuery(`DO 1;
+		SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4;
+		DO 1;
+		SELECT 0 UNION SELECT 1;
+		SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6;`)
+		defer rows.Close()
+		checkRows("query: ", rows, dbt)
+	})
+
+	runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
+		queries := []string{
+			`
+			DROP PROCEDURE IF EXISTS test_mrss;
+			CREATE PROCEDURE test_mrss()
+			BEGIN
+				DO 1;
+				SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4;
+				DO 1;
+				SELECT 0 UNION SELECT 1;
+				SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6;
+			END
+		`,
+			`
+			DROP PROCEDURE IF EXISTS test_mrss;
+			CREATE PROCEDURE test_mrss()
+			BEGIN
+				SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4;
+				SELECT 0 UNION SELECT 1;
+				SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6;
+			END
+		`,
+		}
+
+		defer dbt.mustExec("DROP PROCEDURE IF EXISTS test_mrss")
+
+		for i, query := range queries {
+			dbt.mustExec(query)
+
+			stmt, err := dbt.db.Prepare("CALL test_mrss()")
+			if err != nil {
+				dbt.Fatalf("%v (i=%d)", err, i)
+			}
+			defer stmt.Close()
+
+			for j := 0; j < 2; j++ {
+				rows, err := stmt.Query()
+				if err != nil {
+					dbt.Fatalf("%v (i=%d) (j=%d)", err, i, j)
+				}
+				checkRows(fmt.Sprintf("prepared stmt query (i=%d) (j=%d): ", i, j), rows, dbt)
+			}
+		}
+	})
+}
+
+func TestMultiResultSetNoSelect(t *testing.T) {
+	runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
+		rows := dbt.mustQuery("DO 1; DO 2;")
+		defer rows.Close()
+
+		if rows.Next() {
+			dbt.Error("unexpected row")
+		}
+
+		if rows.NextResultSet() {
+			dbt.Error("unexpected next result set")
+		}
+
+		if err := rows.Err(); err != nil {
+			dbt.Error("expected nil; got ", err)
+		}
+	})
+}
+
+// tests if rows are set in a proper state if some results were ignored before
+// calling rows.NextResultSet.
+func TestSkipResults(t *testing.T) {
+	runTests(t, dsn, func(dbt *DBTest) {
+		rows := dbt.mustQuery("SELECT 1, 2")
+		defer rows.Close()
+
+		if !rows.Next() {
+			dbt.Error("expected row")
+		}
+
+		if rows.NextResultSet() {
+			dbt.Error("unexpected next result set")
+		}
+
+		if err := rows.Err(); err != nil {
+			dbt.Error("expected nil; got ", err)
+		}
+	})
+}

+ 27 - 33
packets.go

@@ -584,8 +584,8 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error {
 
 	// server_status [2 bytes]
 	mc.status = readStatus(data[1+n+m : 1+n+m+2])
-	if err := mc.discardResults(); err != nil {
-		return err
+	if mc.status&statusMoreResultsExists != 0 {
+		return nil
 	}
 
 	// warning count [2 bytes]
@@ -698,6 +698,10 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
 func (rows *textRows) readRow(dest []driver.Value) error {
 	mc := rows.mc
 
+	if rows.rs.done {
+		return io.EOF
+	}
+
 	data, err := mc.readPacket()
 	if err != nil {
 		return err
@@ -707,15 +711,11 @@ func (rows *textRows) readRow(dest []driver.Value) error {
 	if data[0] == iEOF && len(data) == 5 {
 		// server_status [2 bytes]
 		rows.mc.status = readStatus(data[3:])
-		err = rows.mc.discardResults()
-		if err == nil {
-			err = io.EOF
-		} else {
-			// connection unusable
-			rows.mc.Close()
+		rows.rs.done = true
+		if !rows.HasNextResultSet() {
+			rows.mc = nil
 		}
-		rows.mc = nil
-		return err
+		return io.EOF
 	}
 	if data[0] == iERR {
 		rows.mc = nil
@@ -736,7 +736,7 @@ func (rows *textRows) readRow(dest []driver.Value) error {
 				if !mc.parseTime {
 					continue
 				} else {
-					switch rows.columns[i].fieldType {
+					switch rows.rs.columns[i].fieldType {
 					case fieldTypeTimestamp, fieldTypeDateTime,
 						fieldTypeDate, fieldTypeNewDate:
 						dest[i], err = parseDateTime(
@@ -1097,8 +1097,6 @@ func (mc *mysqlConn) discardResults() error {
 			if err := mc.readUntilEOF(); err != nil {
 				return err
 			}
-		} else {
-			mc.status &^= statusMoreResultsExists
 		}
 	}
 	return nil
@@ -1116,15 +1114,11 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
 		// EOF Packet
 		if data[0] == iEOF && len(data) == 5 {
 			rows.mc.status = readStatus(data[3:])
-			err = rows.mc.discardResults()
-			if err == nil {
-				err = io.EOF
-			} else {
-				// connection unusable
-				rows.mc.Close()
+			rows.rs.done = true
+			if !rows.HasNextResultSet() {
+				rows.mc = nil
 			}
-			rows.mc = nil
-			return err
+			return io.EOF
 		}
 		rows.mc = nil
 
@@ -1145,14 +1139,14 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
 		}
 
 		// Convert to byte-coded string
-		switch rows.columns[i].fieldType {
+		switch rows.rs.columns[i].fieldType {
 		case fieldTypeNULL:
 			dest[i] = nil
 			continue
 
 		// Numeric Types
 		case fieldTypeTiny:
-			if rows.columns[i].flags&flagUnsigned != 0 {
+			if rows.rs.columns[i].flags&flagUnsigned != 0 {
 				dest[i] = int64(data[pos])
 			} else {
 				dest[i] = int64(int8(data[pos]))
@@ -1161,7 +1155,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
 			continue
 
 		case fieldTypeShort, fieldTypeYear:
-			if rows.columns[i].flags&flagUnsigned != 0 {
+			if rows.rs.columns[i].flags&flagUnsigned != 0 {
 				dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2]))
 			} else {
 				dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2])))
@@ -1170,7 +1164,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
 			continue
 
 		case fieldTypeInt24, fieldTypeLong:
-			if rows.columns[i].flags&flagUnsigned != 0 {
+			if rows.rs.columns[i].flags&flagUnsigned != 0 {
 				dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4]))
 			} else {
 				dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4])))
@@ -1179,7 +1173,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
 			continue
 
 		case fieldTypeLongLong:
-			if rows.columns[i].flags&flagUnsigned != 0 {
+			if rows.rs.columns[i].flags&flagUnsigned != 0 {
 				val := binary.LittleEndian.Uint64(data[pos : pos+8])
 				if val > math.MaxInt64 {
 					dest[i] = uint64ToString(val)
@@ -1233,10 +1227,10 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
 			case isNull:
 				dest[i] = nil
 				continue
-			case rows.columns[i].fieldType == fieldTypeTime:
+			case rows.rs.columns[i].fieldType == fieldTypeTime:
 				// database/sql does not support an equivalent to TIME, return a string
 				var dstlen uint8
-				switch decimals := rows.columns[i].decimals; decimals {
+				switch decimals := rows.rs.columns[i].decimals; decimals {
 				case 0x00, 0x1f:
 					dstlen = 8
 				case 1, 2, 3, 4, 5, 6:
@@ -1244,7 +1238,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
 				default:
 					return fmt.Errorf(
 						"protocol error, illegal decimals value %d",
-						rows.columns[i].decimals,
+						rows.rs.columns[i].decimals,
 					)
 				}
 				dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true)
@@ -1252,10 +1246,10 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
 				dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc)
 			default:
 				var dstlen uint8
-				if rows.columns[i].fieldType == fieldTypeDate {
+				if rows.rs.columns[i].fieldType == fieldTypeDate {
 					dstlen = 10
 				} else {
-					switch decimals := rows.columns[i].decimals; decimals {
+					switch decimals := rows.rs.columns[i].decimals; decimals {
 					case 0x00, 0x1f:
 						dstlen = 19
 					case 1, 2, 3, 4, 5, 6:
@@ -1263,7 +1257,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
 					default:
 						return fmt.Errorf(
 							"protocol error, illegal decimals value %d",
-							rows.columns[i].decimals,
+							rows.rs.columns[i].decimals,
 						)
 					}
 				}
@@ -1279,7 +1273,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
 
 		// Please report if this happens!
 		default:
-			return fmt.Errorf("unknown field type %d", rows.columns[i].fieldType)
+			return fmt.Errorf("unknown field type %d", rows.rs.columns[i].fieldType)
 		}
 	}
 

+ 99 - 23
rows.go

@@ -21,40 +21,49 @@ type mysqlField struct {
 	decimals  byte
 }
 
-type mysqlRows struct {
-	mc      *mysqlConn
+type resultSet struct {
 	columns []mysqlField
+	done    bool
+}
+
+type mysqlRows struct {
+	mc *mysqlConn
+	rs resultSet
 }
 
 type binaryRows struct {
 	mysqlRows
+	// stmtCols is a pointer to the statement's cached columns for different
+	// result sets.
+	stmtCols *[][]mysqlField
+	// i is a number of the current result set. It is used to fetch proper
+	// columns from stmtCols.
+	i int
 }
 
 type textRows struct {
 	mysqlRows
 }
 
-type emptyRows struct{}
-
 func (rows *mysqlRows) Columns() []string {
-	columns := make([]string, len(rows.columns))
+	columns := make([]string, len(rows.rs.columns))
 	if rows.mc != nil && rows.mc.cfg.ColumnsWithAlias {
 		for i := range columns {
-			if tableName := rows.columns[i].tableName; len(tableName) > 0 {
-				columns[i] = tableName + "." + rows.columns[i].name
+			if tableName := rows.rs.columns[i].tableName; len(tableName) > 0 {
+				columns[i] = tableName + "." + rows.rs.columns[i].name
 			} else {
-				columns[i] = rows.columns[i].name
+				columns[i] = rows.rs.columns[i].name
 			}
 		}
 	} else {
 		for i := range columns {
-			columns[i] = rows.columns[i].name
+			columns[i] = rows.rs.columns[i].name
 		}
 	}
 	return columns
 }
 
-func (rows *mysqlRows) Close() error {
+func (rows *mysqlRows) Close() (err error) {
 	mc := rows.mc
 	if mc == nil {
 		return nil
@@ -64,7 +73,9 @@ func (rows *mysqlRows) Close() error {
 	}
 
 	// Remove unread packets from stream
-	err := mc.readUntilEOF()
+	if !rows.rs.done {
+		err = mc.readUntilEOF()
+	}
 	if err == nil {
 		if err = mc.discardResults(); err != nil {
 			return err
@@ -75,6 +86,73 @@ func (rows *mysqlRows) Close() error {
 	return err
 }
 
+func (rows *mysqlRows) HasNextResultSet() (b bool) {
+	if rows.mc == nil {
+		return false
+	}
+	return rows.mc.status&statusMoreResultsExists != 0
+}
+
+func (rows *mysqlRows) nextResultSet() (int, error) {
+	if rows.mc == nil {
+		return 0, io.EOF
+	}
+	if rows.mc.netConn == nil {
+		return 0, ErrInvalidConn
+	}
+
+	// Remove unread packets from stream
+	if !rows.rs.done {
+		if err := rows.mc.readUntilEOF(); err != nil {
+			return 0, err
+		}
+		rows.rs.done = true
+	}
+
+	if !rows.HasNextResultSet() {
+		rows.mc = nil
+		return 0, io.EOF
+	}
+	rows.rs = resultSet{}
+	return rows.mc.readResultSetHeaderPacket()
+}
+
+func (rows *mysqlRows) nextNotEmptyResultSet() (int, error) {
+	for {
+		resLen, err := rows.nextResultSet()
+		if err != nil {
+			return 0, err
+		}
+
+		if resLen > 0 {
+			return resLen, nil
+		}
+
+		rows.rs.done = true
+	}
+}
+
+func (rows *binaryRows) NextResultSet() (err error) {
+	resLen, err := rows.nextNotEmptyResultSet()
+	if err != nil {
+		return err
+	}
+
+	// get columns, if not cached, read them and cache them.
+	if rows.i >= len(*rows.stmtCols) {
+		rows.rs.columns, err = rows.mc.readColumns(resLen)
+		*rows.stmtCols = append(*rows.stmtCols, rows.rs.columns)
+	} else {
+		rows.rs.columns = (*rows.stmtCols)[rows.i]
+		if err := rows.mc.readUntilEOF(); err != nil {
+			return err
+		}
+	}
+
+	rows.i++
+	return nil
+}
+
 func (rows *binaryRows) Next(dest []driver.Value) error {
 	if mc := rows.mc; mc != nil {
 		if mc.netConn == nil {
@@ -87,6 +165,16 @@ func (rows *binaryRows) Next(dest []driver.Value) error {
 	return io.EOF
 }
 
+func (rows *textRows) NextResultSet() (err error) {
+	resLen, err := rows.nextNotEmptyResultSet()
+	if err != nil {
+		return err
+	}
+
+	rows.rs.columns, err = rows.mc.readColumns(resLen)
+	return err
+}
+
 func (rows *textRows) Next(dest []driver.Value) error {
 	if mc := rows.mc; mc != nil {
 		if mc.netConn == nil {
@@ -98,15 +186,3 @@ func (rows *textRows) Next(dest []driver.Value) error {
 	}
 	return io.EOF
 }
-
-func (rows emptyRows) Columns() []string {
-	return nil
-}
-
-func (rows emptyRows) Close() error {
-	return nil
-}
-
-func (rows emptyRows) Next(dest []driver.Value) error {
-	return io.EOF
-}

+ 36 - 20
statement.go

@@ -11,6 +11,7 @@ package mysql
 import (
 	"database/sql/driver"
 	"fmt"
+	"io"
 	"reflect"
 	"strconv"
 )
@@ -19,7 +20,7 @@ type mysqlStmt struct {
 	mc         *mysqlConn
 	id         uint32
 	paramCount int
-	columns    []mysqlField // cached from the first query
+	columns    [][]mysqlField // cached from the first query
 }
 
 func (stmt *mysqlStmt) Close() error {
@@ -62,26 +63,30 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
 
 	// Read Result
 	resLen, err := mc.readResultSetHeaderPacket()
-	if err == nil {
-		if resLen > 0 {
-			// Columns
-			err = mc.readUntilEOF()
-			if err != nil {
-				return nil, err
-			}
+	if err != nil {
+		return nil, err
+	}
 
-			// Rows
-			err = mc.readUntilEOF()
+	if resLen > 0 {
+		// Columns
+		if err = mc.readUntilEOF(); err != nil {
+			return nil, err
 		}
-		if err == nil {
-			return &mysqlResult{
-				affectedRows: int64(mc.affectedRows),
-				insertId:     int64(mc.insertId),
-			}, nil
+
+		// Rows
+		if err := mc.readUntilEOF(); err != nil {
+			return nil, err
 		}
 	}
 
-	return nil, err
+	if err := mc.discardResults(); err != nil {
+		return nil, err
+	}
+
+	return &mysqlResult{
+		affectedRows: int64(mc.affectedRows),
+		insertId:     int64(mc.insertId),
+	}, nil
 }
 
 func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
@@ -104,18 +109,29 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
 	}
 
 	rows := new(binaryRows)
+	rows.stmtCols = &stmt.columns
 
 	if resLen > 0 {
 		rows.mc = mc
+		rows.i++
 		// Columns
 		// If not cached, read them and cache them
-		if stmt.columns == nil {
-			rows.columns, err = mc.readColumns(resLen)
-			stmt.columns = rows.columns
+		if len(stmt.columns) == 0 {
+			rows.rs.columns, err = mc.readColumns(resLen)
+			stmt.columns = append(stmt.columns, rows.rs.columns)
 		} else {
-			rows.columns = stmt.columns
+			rows.rs.columns = stmt.columns[0]
 			err = mc.readUntilEOF()
 		}
+	} else {
+		rows.rs.done = true
+
+		switch err := rows.NextResultSet(); err {
+		case nil, io.EOF:
+			return rows, nil
+		default:
+			return nil, err
+		}
 	}
 
 	return rows, err