Pārlūkot izejas kodu

Fix Row-Fetching

Julien Schmidt 13 gadi atpakaļ
vecāks
revīzija
57fe6e4e22
3 mainītis faili ar 30 papildinājumiem un 28 dzēšanām
  1. 1 1
      connection.go
  2. 28 26
      rows.go
  3. 1 1
      statement.go

+ 1 - 1
connection.go

@@ -256,7 +256,7 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
 		return nil, e
 	}
 
-	rows := mysqlRows{&rowsContent{mc, false, resLen, nil}}
+	rows := mysqlRows{&rowsContent{mc, false, nil, false}}
 
 	if resLen > 0 {
 		// Columns

+ 28 - 26
rows.go

@@ -24,8 +24,8 @@ type mysqlField struct {
 type rowsContent struct {
 	mc      *mysqlConn
 	binary  bool
-	unread  int
 	columns []mysqlField
+	eof     bool
 }
 
 type mysqlRows struct {
@@ -47,7 +47,7 @@ func (rows mysqlRows) Close() (e error) {
 	}()
 
 	// Remove unread packets from stream
-	if rows.content.unread > -1 {
+	if !rows.content.eof {
 		if rows.content.mc == nil {
 			return errors.New("Invalid Connection")
 		}
@@ -66,36 +66,38 @@ func (rows mysqlRows) Close() (e error) {
 // when the dest type is know, which makes type conversion easier and avoids 
 // unnecessary conversions.
 func (rows mysqlRows) Next(dest []driver.Value) error {
-	if rows.content.unread > 0 {
-		if rows.content.mc == nil {
-			return errors.New("Invalid Connection")
-		}
+	if rows.content.eof {
+		return io.EOF
+	}
 
-		columnsCount := cap(dest)
+	if rows.content.mc == nil {
+		return errors.New("Invalid Connection")
+	}
 
-		// Fetch next row from stream
-		var row *[]*[]byte
-		var e error
-		if rows.content.binary {
-			row, e = rows.content.mc.readBinaryRow(rows.content)
-		} else {
-			row, e = rows.content.mc.readRow(columnsCount)
-		}
-		rows.content.unread--
+	columnsCount := cap(dest)
 
-		if e != nil {
-			return e
+	// Fetch next row from stream
+	var row *[]*[]byte
+	var e error
+	if rows.content.binary {
+		row, e = rows.content.mc.readBinaryRow(rows.content)
+	} else {
+		row, e = rows.content.mc.readRow(columnsCount)
+	}
+
+	if e != nil {
+		if e == io.EOF {
+			rows.content.eof = true
 		}
+		return e
+	}
 
-		for i := 0; i < columnsCount; i++ {
-			if (*row)[i] == nil {
-				dest[i] = nil
-			} else {
-				dest[i] = *(*row)[i]
-			}
+	for i := 0; i < columnsCount; i++ {
+		if (*row)[i] == nil {
+			dest[i] = nil
+		} else {
+			dest[i] = *(*row)[i]
 		}
-	} else {
-		return io.EOF
 	}
 
 	return nil

+ 1 - 1
statement.go

@@ -100,7 +100,7 @@ func (stmt mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
 		return nil, e
 	}
 
-	rows := mysqlRows{&rowsContent{stmt.mc, true, resLen, nil}}
+	rows := mysqlRows{&rowsContent{stmt.mc, true, nil, false}}
 
 	if resLen > 0 {
 		// Columns