瀏覽代碼

TestMultiQuery
discard additional OK response after Multi Statement Exec Calls

Badoet Endoet 10 年之前
父節點
當前提交
4aa920ddb8
共有 3 個文件被更改,包括 83 次插入0 次删除
  1. 1 0
      AUTHORS
  2. 81 0
      driver_test.go
  3. 1 0
      packets.go

+ 1 - 0
AUTHORS

@@ -39,6 +39,7 @@ Nicola Peduzzi <thenikso at gmail.com>
 Runrioter Wung <runrioter at gmail.com>
 Soroush Pour <me at soroushjp.com>
 Stan Putrya <root.vagner at gmail.com>
+Stanley Gunawan <gunawan.stanley at gmail.com>
 Xiaobing Jiang <s7v7nislands at gmail.com>
 Xiuming Chen <cc at cxm.cc>
 

+ 81 - 0
driver_test.go

@@ -76,6 +76,28 @@ type DBTest struct {
 	db *sql.DB
 }
 
+func runTestsWithMultiStatement(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
+	if !available {
+		t.Skipf("MySQL-Server not running on %s", netAddr)
+	}
+
+	dsn3 := dsn + "&multiStatements=true"
+	var db3 *sql.DB
+	if _, err := parseDSN(dsn3); err != errInvalidDSNUnsafeCollation {
+		db3, err = sql.Open("mysql", dsn3)
+		if err != nil {
+			t.Fatalf("Error connecting: %s", err.Error())
+		}
+		defer db3.Close()
+	}
+
+	dbt3 := &DBTest{t, db3}
+	for _, test := range tests {
+		test(dbt3)
+		dbt3.db.Exec("DROP TABLE IF EXISTS test")
+	}
+}
+
 func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
 	if !available {
 		t.Skipf("MySQL server not running on %s", netAddr)
@@ -99,8 +121,19 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
 		defer db2.Close()
 	}
 
+	dsn3 := dsn + "&multiStatements=true"
+	var db3 *sql.DB
+	if _, err := parseDSN(dsn3); err != errInvalidDSNUnsafeCollation {
+		db3, err = sql.Open("mysql", dsn3)
+		if err != nil {
+			t.Fatalf("Error connecting: %s", err.Error())
+		}
+		defer db3.Close()
+	}
+
 	dbt := &DBTest{t, db}
 	dbt2 := &DBTest{t, db2}
+	dbt3 := &DBTest{t, db3}
 	for _, test := range tests {
 		test(dbt)
 		dbt.db.Exec("DROP TABLE IF EXISTS test")
@@ -108,6 +141,10 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
 			test(dbt2)
 			dbt2.db.Exec("DROP TABLE IF EXISTS test")
 		}
+		if db3 != nil {
+			test(dbt3)
+			dbt3.db.Exec("DROP TABLE IF EXISTS test")
+		}
 	}
 }
 
@@ -237,6 +274,50 @@ func TestCRUD(t *testing.T) {
 	})
 }
 
+func TestMultiQuery(t *testing.T) {
+	runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
+		// Create Table
+		dbt.mustExec("CREATE TABLE `test` (`id` int(11) NOT NULL, `value` int(11) NOT NULL) ")
+
+		// Create Data
+		res := dbt.mustExec("INSERT INTO test VALUES (1, 1)")
+		count, err := res.RowsAffected()
+		if err != nil {
+			dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
+		}
+		if count != 1 {
+			dbt.Fatalf("Expected 1 affected row, got %d", count)
+		}
+
+		// Update
+		res = dbt.mustExec("UPDATE test SET value = 3 WHERE id = 1; UPDATE test SET value = 4 WHERE id = 1; UPDATE test SET value = 5 WHERE id = 1;")
+		count, err = res.RowsAffected()
+		if err != nil {
+			dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
+		}
+		if count != 1 {
+			dbt.Fatalf("Expected 1 affected row, got %d", count)
+		}
+
+		// Read
+		var out int
+		rows := dbt.mustQuery("SELECT value FROM test WHERE id=1;")
+		if rows.Next() {
+			rows.Scan(&out)
+			if 5 != out {
+				dbt.Errorf("5 != %t", out)
+			}
+
+			if rows.Next() {
+				dbt.Error("unexpected data")
+			}
+		} else {
+			dbt.Error("no data")
+		}
+
+	})
+}
+
 func TestInt(t *testing.T) {
 	runTests(t, dsn, func(dbt *DBTest) {
 		types := [5]string{"TINYINT", "SMALLINT", "MEDIUMINT", "INT", "BIGINT"}

+ 1 - 0
packets.go

@@ -543,6 +543,7 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error {
 
 	// server_status [2 bytes]
 	mc.status = readStatus(data[1+n+m : 1+n+m+2])
+	mc.discardMoreResultsIfExists()
 
 	// warning count [2 bytes]
 	if !mc.strict {