Преглед изворни кода

rows: Invalidate connection on error in discardResults() (#513)

Fixes #422
Julien Schmidt пре 9 година
родитељ
комит
8f6c67f022
2 измењених фајлова са 61 додато и 6 уклоњено
  1. 47 0
      driver_test.go
  2. 14 6
      packets.go

+ 47 - 0
driver_test.go

@@ -1855,3 +1855,50 @@ func TestUnixSocketAuthFail(t *testing.T) {
 		}
 	})
 }
+
+// See Issue #422
+func TestInterruptBySignal(t *testing.T) {
+	runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
+		dbt.mustExec(`
+			DROP PROCEDURE IF EXISTS test_signal;
+			CREATE PROCEDURE test_signal(ret INT)
+			BEGIN
+				SELECT ret;
+				SIGNAL SQLSTATE
+					'45001'
+				SET
+					MESSAGE_TEXT = "an error",
+					MYSQL_ERRNO = 45001;
+			END
+		`)
+		defer dbt.mustExec("DROP PROCEDURE test_signal")
+
+		var val int
+
+		// text protocol
+		rows, err := dbt.db.Query("CALL test_signal(42)")
+		if err != nil {
+			dbt.Fatalf("error on text query: %s", err.Error())
+		}
+		for rows.Next() {
+			if err := rows.Scan(&val); err != nil {
+				dbt.Error(err)
+			} else if val != 42 {
+				dbt.Errorf("expected val to be 42")
+			}
+		}
+
+		// binary protocol
+		rows, err = dbt.db.Query("CALL test_signal(?)", 42)
+		if err != nil {
+			dbt.Fatalf("error on binary query: %s", err.Error())
+		}
+		for rows.Next() {
+			if err := rows.Scan(&val); err != nil {
+				dbt.Error(err)
+			} else if val != 42 {
+				dbt.Errorf("expected val to be 42")
+			}
+		}
+	})
+}

+ 14 - 6
packets.go

@@ -700,11 +700,15 @@ func (rows *textRows) readRow(dest []driver.Value) error {
 	if data[0] == iEOF && len(data) == 5 {
 		// server_status [2 bytes]
 		rows.mc.status = readStatus(data[3:])
-		if err := rows.mc.discardResults(); err != nil {
-			return err
+		err = rows.mc.discardResults()
+		if err == nil {
+			err = io.EOF
+		} else {
+			// connection unusable
+			rows.mc.Close()
 		}
 		rows.mc = nil
-		return io.EOF
+		return err
 	}
 	if data[0] == iERR {
 		rows.mc = nil
@@ -1105,11 +1109,15 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
 		// EOF Packet
 		if data[0] == iEOF && len(data) == 5 {
 			rows.mc.status = readStatus(data[3:])
-			if err := rows.mc.discardResults(); err != nil {
-				return err
+			err = rows.mc.discardResults()
+			if err == nil {
+				err = io.EOF
+			} else {
+				// connection unusable
+				rows.mc.Close()
 			}
 			rows.mc = nil
-			return io.EOF
+			return err
 		}
 		rows.mc = nil