Просмотр исходного кода

Merge pull request #149 from go-sql-driver/em_rows

Make BinaryRows and TextRows distinct types
Julien Schmidt 12 лет назад
Родитель
Сommit
d10c04b36f
4 измененных файлов с 42 добавлено и 21 удалено
  1. 4 2
      connection.go
  2. 2 2
      packets.go
  3. 34 16
      rows.go
  4. 2 1
      statement.go

+ 4 - 2
connection.go

@@ -211,7 +211,8 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
 			var resLen int
 			resLen, err = mc.readResultSetHeaderPacket()
 			if err == nil {
-				rows := &mysqlRows{mc, nil, false}
+				rows := new(textRows)
+				rows.mc = mc
 
 				if resLen > 0 {
 					// Columns
@@ -238,7 +239,8 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
 	// Read Result
 	resLen, err := mc.readResultSetHeaderPacket()
 	if err == nil {
-		rows := &mysqlRows{mc, nil, false}
+		rows := new(textRows)
+		rows.mc = mc
 
 		if resLen > 0 {
 			// Columns

+ 2 - 2
packets.go

@@ -620,7 +620,7 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
 
 // Read Packets as Field Packets until EOF-Packet or an Error appears
 // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow
-func (rows *mysqlRows) readRow(dest []driver.Value) error {
+func (rows *textRows) readRow(dest []driver.Value) error {
 	mc := rows.mc
 
 	data, err := mc.readPacket()
@@ -1002,7 +1002,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 }
 
 // http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html
-func (rows *mysqlRows) readBinaryRow(dest []driver.Value) error {
+func (rows *binaryRows) readRow(dest []driver.Value) error {
 	data, err := rows.mc.readPacket()
 	if err != nil {
 		return err

+ 34 - 16
rows.go

@@ -22,7 +22,14 @@ type mysqlField struct {
 type mysqlRows struct {
 	mc      *mysqlConn
 	columns []mysqlField
-	binary  bool
+}
+
+type binaryRows struct {
+	mysqlRows
+}
+
+type textRows struct {
+	mysqlRows
 }
 
 func (rows *mysqlRows) Columns() []string {
@@ -41,28 +48,39 @@ func (rows *mysqlRows) Close() error {
 	if mc.netConn == nil {
 		return errInvalidConn
 	}
+
 	// Remove unread packets from stream
 	err := mc.readUntilEOF()
 	rows.mc = nil
 	return err
 }
 
-func (rows *mysqlRows) Next(dest []driver.Value) (err error) {
-	mc := rows.mc
-	if mc == nil {
-		return io.EOF
-	}
-	if mc.netConn == nil {
-		return errInvalidConn
-	}
-	// Fetch next row from stream
-	if rows.binary {
-		err = rows.readBinaryRow(dest)
-	} else {
-		err = rows.readRow(dest)
+func (rows *binaryRows) Next(dest []driver.Value) error {
+	if mc := rows.mc; mc != nil {
+		if mc.netConn == nil {
+			return errInvalidConn
+		}
+
+		// Fetch next row from stream
+		if err := rows.readRow(dest); err != io.EOF {
+			return err
+		}
+		rows.mc = nil
 	}
-	if err == io.EOF {
+	return io.EOF
+}
+
+func (rows *textRows) Next(dest []driver.Value) error {
+	if mc := rows.mc; mc != nil {
+		if mc.netConn == nil {
+			return errInvalidConn
+		}
+
+		// Fetch next row from stream
+		if err := rows.readRow(dest); err != io.EOF {
+			return err
+		}
 		rows.mc = nil
 	}
-	return
+	return io.EOF
 }

+ 2 - 1
statement.go

@@ -90,7 +90,8 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
 		return nil, err
 	}
 
-	rows := &mysqlRows{mc, nil, true}
+	rows := new(binaryRows)
+	rows.mc = mc
 
 	if resLen > 0 {
 		// Columns