Browse Source

Merge pull request #80 from MJKWoolnough/master

Added clientFoundRows support.
Julien Schmidt 12 years ago
parent
commit
649219cb87
4 changed files with 50 additions and 2 deletions
  1. 1 0
      README.md
  2. 1 1
      connection.go
  3. 45 0
      driver_test.go
  4. 3 1
      packets.go

+ 1 - 0
README.md

@@ -111,6 +111,7 @@ Possible Parameters are:
   * `parseTime`: `parseTime=true` changes the output type of `DATE` and `DATETIME` values to `time.Time` instead of `[]byte` / `string`
   * `loc`: Sets the location for time.Time values (when using `parseTime=true`). The default is `UTC`. *"Local"* sets the system's location. See [time.LoadLocation](http://golang.org/pkg/time/#LoadLocation) for details.
   * `strict`: Enable strict mode. MySQL warnings are treated as errors.
+  * `clientFoundRows`: `clientFoundRows=true` causes causes an UPDATE to return the number of matching rows instead of the number of rows changed.
 
 All other parameters are interpreted as system variables:
   * `autocommit`: *"SET autocommit=`value`"*

+ 1 - 1
connection.go

@@ -63,7 +63,7 @@ func (mc *mysqlConn) handleParams() (err error) {
 			}
 
 		// handled elsewhere
-		case "timeout", "allowAllFiles", "loc":
+		case "timeout", "allowAllFiles", "loc", "clientFoundRows":
 			continue
 
 		// time.Time parsing

+ 45 - 0
driver_test.go

@@ -1036,6 +1036,51 @@ func TestConcurrent(t *testing.T) {
 	})
 }
 
+func TestFoundRows(t *testing.T) {
+	runTests(t, "TestFoundRows1", dsn, func(dbt *DBTest) {
+		dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)")
+		dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)")
+		
+		res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0")
+		count, err := res.RowsAffected()
+		if err != nil {
+				dbt.Fatalf("res.RowsAffected() returned error: %v", err)
+			}
+		if count != 2 {
+			dbt.Fatalf("Expected 2 affected rows, got %d", count)
+		}
+		res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1")
+		count, err = res.RowsAffected()
+		if err != nil {
+				dbt.Fatalf("res.RowsAffected() returned error: %v", err)
+			}
+		if count != 2 {
+			dbt.Fatalf("Expected 2 affected rows, got %d", count)
+		}
+	})
+	runTests(t, "TestFoundRows2", dsn + "&clientFoundRows=true", func(dbt *DBTest) {
+		dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)")
+		dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)")
+		
+		res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0")
+		count, err := res.RowsAffected()
+		if err != nil {
+				dbt.Fatalf("res.RowsAffected() returned error: %v", err)
+			}
+		if count != 2 {
+			dbt.Fatalf("Expected 2 matched rows, got %d", count)
+		}
+		res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1")
+		count, err = res.RowsAffected()
+		if err != nil {
+				dbt.Fatalf("res.RowsAffected() returned error: %v", err)
+			}
+		if count != 3 {
+			dbt.Fatalf("Expected 3 matched rows, got %d", count)
+		}
+	})
+}
+
 // BENCHMARKS
 var sample []byte
 

+ 3 - 1
packets.go

@@ -215,7 +215,9 @@ func (mc *mysqlConn) writeAuthPacket() error {
 	if mc.flags&clientLongFlag > 0 {
 		clientFlags |= uint32(clientLongFlag)
 	}
-
+	if _, ok := mc.cfg.params["clientFoundRows"]; ok {
+		clientFlags |= uint32(clientFoundRows)
+	}
 	// User Password
 	scrambleBuff := scramblePassword(mc.cipher, []byte(mc.cfg.passwd))
 	mc.cipher = nil