Browse Source

Enable Multi Results support and discard additional results

- packets.go: flag clientMultiResults, update status when receiving an
EOF packet, discard additional results on readRow when EOF is reached
- statement.go: currently a nil rows.mc is used as an eof, don’t set it
if there are no columns to avoid that Next() waits indefinitely
- rows.go: discard additional results on close and avoid panic on
Columns()
Idhor 10 years ago
parent
commit
8cbeffa8f6
3 changed files with 51 additions and 4 deletions
  1. 43 2
      packets.go
  2. 7 1
      rows.go
  3. 1 1
      statement.go

+ 43 - 2
packets.go

@@ -224,6 +224,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
 		clientTransactions |
 		clientLocalFiles |
 		clientPluginAuth |
+		clientMultiResults |
 		mc.flags&clientLongFlag
 
 	if mc.cfg.ClientFoundRows {
@@ -519,6 +520,10 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error {
 	}
 }
 
+func readStatus(b []byte) statusFlag {
+	return statusFlag(b[0]) | statusFlag(b[1])<<8
+}
+
 // Ok Packet
 // http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
 func (mc *mysqlConn) handleOkPacket(data []byte) error {
@@ -533,7 +538,7 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error {
 	mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])
 
 	// server_status [2 bytes]
-	mc.status = statusFlag(data[1+n+m]) | statusFlag(data[1+n+m+1])<<8
+	mc.status = readStatus(data[1+n+m : 1+n+m+2])
 
 	// warning count [2 bytes]
 	if !mc.strict {
@@ -652,6 +657,11 @@ func (rows *textRows) readRow(dest []driver.Value) error {
 
 	// EOF Packet
 	if data[0] == iEOF && len(data) == 5 {
+		// server_status [2 bytes]
+		rows.mc.status = readStatus(data[3:])
+		if err := rows.mc.discardMoreResultsIfExists(); err != nil {
+			return err
+		}
 		rows.mc = nil
 		return io.EOF
 	}
@@ -709,6 +719,10 @@ func (mc *mysqlConn) readUntilEOF() error {
 		if err == nil && data[0] != iEOF {
 			continue
 		}
+		if err == nil && data[0] == iEOF && len(data) == 5 {
+			mc.status = readStatus(data[3:])
+		}
+
 		return err // Err or EOF
 	}
 }
@@ -1013,6 +1027,28 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 	return mc.writePacket(data)
 }
 
+func (mc *mysqlConn) discardMoreResultsIfExists() error {
+	for mc.status&statusMoreResultsExists != 0 {
+		resLen, err := mc.readResultSetHeaderPacket()
+		if err != nil {
+			return err
+		}
+		if resLen > 0 {
+			// columns
+			if err := mc.readUntilEOF(); err != nil {
+				return err
+			}
+			// rows
+			if err := mc.readUntilEOF(); err != nil {
+				return err
+			}
+		} else {
+			mc.status &^= statusMoreResultsExists
+		}
+	}
+	return nil
+}
+
 // http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html
 func (rows *binaryRows) readRow(dest []driver.Value) error {
 	data, err := rows.mc.readPacket()
@@ -1022,11 +1058,16 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
 
 	// packet indicator [1 byte]
 	if data[0] != iOK {
-		rows.mc = nil
 		// EOF Packet
 		if data[0] == iEOF && len(data) == 5 {
+			rows.mc.status = readStatus(data[3:])
+			if err := rows.mc.discardMoreResultsIfExists(); err != nil {
+				return err
+			}
+			rows.mc = nil
 			return io.EOF
 		}
+		rows.mc = nil
 
 		// Error otherwise
 		return rows.mc.handleErrorPacket(data)

+ 7 - 1
rows.go

@@ -38,7 +38,7 @@ type emptyRows struct{}
 
 func (rows *mysqlRows) Columns() []string {
 	columns := make([]string, len(rows.columns))
-	if rows.mc.cfg.ColumnsWithAlias {
+	if rows.mc != nil && rows.mc.cfg.ColumnsWithAlias {
 		for i := range columns {
 			if tableName := rows.columns[i].tableName; len(tableName) > 0 {
 				columns[i] = tableName + "." + rows.columns[i].name
@@ -65,6 +65,12 @@ func (rows *mysqlRows) Close() error {
 
 	// Remove unread packets from stream
 	err := mc.readUntilEOF()
+	if err == nil {
+		if err = mc.discardMoreResultsIfExists(); err != nil {
+			return err
+		}
+	}
+
 	rows.mc = nil
 	return err
 }

+ 1 - 1
statement.go

@@ -101,9 +101,9 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
 	}
 
 	rows := new(binaryRows)
-	rows.mc = mc
 
 	if resLen > 0 {
+		rows.mc = mc
 		// Columns
 		// If not cached, read them and cache them
 		if stmt.columns == nil {