Ver código fonte

remove eof field from mysqlRows

Arne Hormann 12 anos atrás
pai
commit
63dc97f1f3
3 arquivos alterados com 19 adições e 26 exclusões
  1. 2 2
      connection.go
  2. 16 23
      rows.go
  3. 1 1
      statement.go

+ 2 - 2
connection.go

@@ -211,7 +211,7 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
 			var resLen int
 			resLen, err = mc.readResultSetHeaderPacket()
 			if err == nil {
-				rows := &mysqlRows{mc, nil, false, false}
+				rows := &mysqlRows{mc, nil, false}
 
 				if resLen > 0 {
 					// Columns
@@ -238,7 +238,7 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
 	// Read Result
 	resLen, err := mc.readResultSetHeaderPacket()
 	if err == nil {
-		rows := &mysqlRows{mc, nil, false, false}
+		rows := &mysqlRows{mc, nil, false}
 
 		if resLen > 0 {
 			// Columns

+ 16 - 23
rows.go

@@ -23,7 +23,6 @@ type mysqlRows struct {
 	mc      *mysqlConn
 	columns []mysqlField
 	binary  bool
-	eof     bool
 }
 
 func (rows *mysqlRows) Columns() []string {
@@ -34,43 +33,37 @@ func (rows *mysqlRows) Columns() []string {
 	return columns
 }
 
-func (rows *mysqlRows) Close() (err error) {
-	// Remove unread packets from stream
-	if !rows.eof {
-		if rows.mc == nil || rows.mc.netConn == nil {
-			return errInvalidConn
-		}
-
-		err = rows.mc.readUntilEOF()
-
-		// explicitly set because readUntilEOF might return early in case of an
-		// error
-		rows.eof = true
+func (rows *mysqlRows) Close() error {
+	mc := rows.mc
+	if mc == nil {
+		return nil
 	}
-
+	if mc.netConn == nil {
+		return errInvalidConn
+	}
+	// Remove unread packets from stream
+	err := mc.readUntilEOF()
 	rows.mc = nil
-
-	return
+	return err
 }
 
-func (rows *mysqlRows) Next(dest []driver.Value) (err error) {
-	if rows.eof {
+func (rows *mysqlRows) Next(dest []driver.Value) error {
+	mc := rows.mc
+	if mc == nil {
 		return io.EOF
 	}
-
-	if rows.mc == nil || rows.mc.netConn == nil {
+	if mc.netConn == nil {
 		return errInvalidConn
 	}
-
+	var err error
 	// Fetch next row from stream
 	if rows.binary {
 		err = rows.readBinaryRow(dest)
 	} else {
 		err = rows.readRow(dest)
 	}
-
 	if err == io.EOF {
-		rows.eof = true
+		rows.mc = nil
 	}
 	return err
 }

+ 1 - 1
statement.go

@@ -90,7 +90,7 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
 		return nil, err
 	}
 
-	rows := &mysqlRows{mc, nil, true, false}
+	rows := &mysqlRows{mc, nil, true}
 
 	if resLen > 0 {
 		// Columns