Просмотр исходного кода

Test warning codes + fmt

Test wwarning codes instead of the full message.
Julien Schmidt 12 лет назад
Родитель
Сommit
26d3f1848a
2 измененных файлов с 43 добавлено и 26 удалено
  1. 40 23
      driver_test.go
  2. 3 3
      errors.go

+ 40 - 23
driver_test.go

@@ -118,8 +118,9 @@ func runTests(t *testing.T, name string, tests ...func(dbt *DBTest)) {
 	}
 	defer db.Close()
 
+	db.Exec("DROP TABLE IF EXISTS test")
+
 	dbt := &DBTest{t, db}
-	dbt.db.Exec("DROP TABLE IF EXISTS test")
 	for _, test := range tests {
 		test(dbt)
 		dbt.db.Exec("DROP TABLE IF EXISTS test")
@@ -743,46 +744,62 @@ func TestStrict(t *testing.T) {
 	runTests(t, "TestStrict", func(dbt *DBTest) {
 		dbt.mustExec("CREATE TABLE test (a TINYINT NOT NULL, b CHAR(4))")
 
-		queries := [...][2]string{
-			{"DROP TABLE IF EXISTS no_such_table", "Note 1051: Unknown table 'no_such_table'"},
-			{"INSERT INTO test VALUES(10,'mysql'),(NULL,'test'),(300,'Open Source')",
-				"Warning 1265: Data truncated for column 'b' at row 1\r\n" +
-					"Warning 1048: Column 'a' cannot be null\r\n" +
-					"Warning 1264: Out of range value for column 'a' at row 3\r\n" +
-					"Warning 1265: Data truncated for column 'b' at row 3",
-			},
+		var queries = [...]struct {
+			in    string
+			codes []string
+		}{
+			{"DROP TABLE IF EXISTS no_such_table", []string{"1051"}},
+			{"INSERT INTO test VALUES(10,'mysql'),(NULL,'test'),(300,'Open Source')", []string{"1265", "1048", "1264", "1265"}},
 		}
 		var err error
 
-		// text protocol
-		for i := range queries {
-			_, err = dbt.db.Exec(queries[i][0])
+		var checkWarnings = func(err error, mode string, idx int) {
 			if err == nil {
-				dbt.Errorf("Expecteded strict error on query [text] %s", queries[i][0])
-			} else if err.Error() != queries[i][1] {
-				dbt.Errorf("Unexpected error on query [text] %s: %s != %s", queries[i][0], err.Error(), queries[i][1])
+				dbt.Errorf("Expected STRICT error on query [%s] %s", mode, queries[idx].in)
+			}
+
+			if warnings, ok := err.(MySQLWarnings); ok {
+				var codes = make([]string, len(warnings))
+				for i := range warnings {
+					codes[i] = warnings[i].Code
+				}
+				if len(codes) != len(queries[idx].codes) {
+					dbt.Errorf("Unexpected STRICT error count on query [%s] %s: Wanted %v, Got %v", mode, queries[idx].in, queries[idx].codes, codes)
+				}
+
+				for i := range warnings {
+					if codes[i] != queries[idx].codes[i] {
+						dbt.Errorf("Unexpected STRICT error codes on query [%s] %s: Wanted %v, Got %v", mode, queries[idx].in, queries[idx].codes, codes)
+						return
+					}
+				}
+
+			} else {
+				dbt.Errorf("Unexpected error on query [%s] %s: %s", mode, queries[idx].in, err.Error())
 			}
 		}
 
+		// text protocol
+		for i := range queries {
+			_, err = dbt.db.Exec(queries[i].in)
+			checkWarnings(err, "text", i)
+		}
+
 		var stmt *sql.Stmt
 
 		// binary protocol
 		for i := range queries {
-			stmt, err = dbt.db.Prepare(queries[i][0])
+			stmt, err = dbt.db.Prepare(queries[i].in)
 			if err != nil {
-				dbt.Error("Error on preparing query %: ", queries[i][0], err.Error())
+				dbt.Error("Error on preparing query %: ", queries[i].in, err.Error())
 			}
 
 			_, err = stmt.Exec()
-			if err == nil {
-				dbt.Errorf("Expecteded strict error on query [binary] %s", queries[i][0])
-			} else if err.Error() != queries[i][1] {
-				dbt.Errorf("Unexpected error on query [binary] %s: %s != %s", queries[i][0], err.Error(), queries[i][1])
-			}
+			checkWarnings(err, "binary", i)
 
 			err = stmt.Close()
 			if err != nil {
-				dbt.Error("Error on closing stmt for query %: ", queries[i][0], err.Error())
+				dbt.Error("Error on closing stmt for query %: ", queries[i].in, err.Error())
 			}
 		}
 	})

+ 3 - 3
errors.go

@@ -37,9 +37,9 @@ func (me *MySQLError) Error() string {
 // error type which represents a group (one ore more) MySQL warnings
 type MySQLWarnings []mysqlWarning
 
-func (mws *MySQLWarnings) Error() string {
+func (mws MySQLWarnings) Error() string {
 	var msg string
-	for i, warning := range *mws {
+	for i, warning := range mws {
 		if i > 0 {
 			msg += "\r\n"
 		}
@@ -93,7 +93,7 @@ func (mc *mysqlConn) getWarnings() (err error) {
 			warnings = append(warnings, warning)
 
 		case io.EOF:
-			return &warnings
+			return warnings
 
 		default:
 			rows.Close()