Browse Source

Merge pull request #58 from go-sql-driver/strict

Add strict mode
Julien Schmidt 12 years ago
parent
commit
26af0ebaba
6 changed files with 203 additions and 20 deletions
  1. 3 1
      README.md
  2. 6 3
      connection.go
  3. 72 6
      driver_test.go
  4. 85 1
      errors.go
  5. 27 9
      packets.go
  6. 10 0
      utils.go

+ 3 - 1
README.md

@@ -110,6 +110,7 @@ Possible Parameters are:
   * `allowAllFiles`: `allowAllFiles=true` disables the file Whitelist for `LOAD DATA LOCAL INFILE` and allows *all* files. *Might be insecure!*
   * `allowAllFiles`: `allowAllFiles=true` disables the file Whitelist for `LOAD DATA LOCAL INFILE` and allows *all* files. *Might be insecure!*
   * `parseTime`: `parseTime=true` changes the output type of `DATE` and `DATETIME` values to `time.Time` instead of `[]byte` / `string`
   * `parseTime`: `parseTime=true` changes the output type of `DATE` and `DATETIME` values to `time.Time` instead of `[]byte` / `string`
   * `loc`: Sets the location for time.Time values (when using `parseTime=true`). The default is `UTC`. *"Local"* sets the system's location. See [time.LoadLocation](http://golang.org/pkg/time/#LoadLocation) for details.
   * `loc`: Sets the location for time.Time values (when using `parseTime=true`). The default is `UTC`. *"Local"* sets the system's location. See [time.LoadLocation](http://golang.org/pkg/time/#LoadLocation) for details.
+  * `strict`: Enable strict mode. MySQL warnings are treated as errors.
 
 
 All other parameters are interpreted as system variables:
 All other parameters are interpreted as system variables:
   * `autocommit`: *"SET autocommit=`value`"*
   * `autocommit`: *"SET autocommit=`value`"*
@@ -154,7 +155,8 @@ See also the [godoc of Go-MySQL-Driver](http://godoc.org/github.com/go-sql-drive
 ### `time.Time` support
 ### `time.Time` support
 The default internal output type of MySQL `DATE` and `DATETIME` values is `[]byte` which allows you to scan the value into a `[]byte`, `string` or `sql.RawBytes` variable in your programm.
 The default internal output type of MySQL `DATE` and `DATETIME` values is `[]byte` which allows you to scan the value into a `[]byte`, `string` or `sql.RawBytes` variable in your programm.
 
 
-However, many want to scan MySQL `DATE` and `DATETIME` values into `time.Time` variables, which is the logical opposite in Go to `DATE` and `DATETIME` in MySQL. You can do that by changing the internal output type from `[]byte` to `time.Time` with the DSN parameter `parseTime=true`. You can set the default [`time.Time` location](http://golang.org/pkg/time/#Location) with the `loc` DSN parameter.  
+However, many want to scan MySQL `DATE` and `DATETIME` values into `time.Time` variables, which is the logical opposite in Go to `DATE` and `DATETIME` in MySQL. You can do that by changing the internal output type from `[]byte` to `time.Time` with the DSN parameter `parseTime=true`. You can set the default [`time.Time` location](http://golang.org/pkg/time/#Location) with the `loc` DSN parameter.
+
 **Caution:** As of Go 1.1, this makes `time.Time` the only variable type you can scan `DATE` and `DATETIME` values into. This breaks for example [`sql.RawBytes` support](https://github.com/go-sql-driver/mysql/wiki/Examples#rawbytes).
 **Caution:** As of Go 1.1, this makes `time.Time` the only variable type you can scan `DATE` and `DATETIME` values into. This breaks for example [`sql.RawBytes` support](https://github.com/go-sql-driver/mysql/wiki/Examples#rawbytes).
 
 
 
 

+ 6 - 3
connection.go

@@ -31,6 +31,7 @@ type mysqlConn struct {
 	maxPacketAllowed int
 	maxPacketAllowed int
 	maxWriteSize     int
 	maxWriteSize     int
 	parseTime        bool
 	parseTime        bool
+	strict           bool
 }
 }
 
 
 type config struct {
 type config struct {
@@ -67,9 +68,11 @@ func (mc *mysqlConn) handleParams() (err error) {
 
 
 		// time.Time parsing
 		// time.Time parsing
 		case "parseTime":
 		case "parseTime":
-			if val == "true" {
-				mc.parseTime = true
-			}
+			mc.parseTime = readBool(val)
+
+		// Strict mode
+		case "strict":
+			mc.strict = readBool(val)
 
 
 		// TLS-Encryption
 		// TLS-Encryption
 		case "tls":
 		case "tls":

+ 72 - 6
driver_test.go

@@ -34,7 +34,7 @@ func init() {
 	dbname := env("MYSQL_TEST_DBNAME", "gotest")
 	dbname := env("MYSQL_TEST_DBNAME", "gotest")
 	charset = "charset=utf8"
 	charset = "charset=utf8"
 	netAddr = fmt.Sprintf("%s(%s)", prot, addr)
 	netAddr = fmt.Sprintf("%s(%s)", prot, addr)
-	dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s&"+charset, user, pass, netAddr, dbname)
+	dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s&strict=true&"+charset, user, pass, netAddr, dbname)
 	c, err := net.Dial(prot, addr)
 	c, err := net.Dial(prot, addr)
 	if err == nil {
 	if err == nil {
 		available = true
 		available = true
@@ -118,12 +118,13 @@ func runTests(t *testing.T, name string, tests ...func(dbt *DBTest)) {
 	}
 	}
 	defer db.Close()
 	defer db.Close()
 
 
+	db.Exec("DROP TABLE IF EXISTS test")
+
 	dbt := &DBTest{t, db}
 	dbt := &DBTest{t, db}
-	dbt.mustExec("DROP TABLE IF EXISTS test")
 	for _, test := range tests {
 	for _, test := range tests {
 		test(dbt)
 		test(dbt)
+		dbt.db.Exec("DROP TABLE IF EXISTS test")
 	}
 	}
-	dbt.mustExec("DROP TABLE IF EXISTS test")
 }
 }
 
 
 func (dbt *DBTest) fail(method, query string, err error) {
 func (dbt *DBTest) fail(method, query string, err error) {
@@ -446,7 +447,6 @@ func TestDateTime(t *testing.T) {
 	testTime := func(dbt *DBTest) {
 	testTime := func(dbt *DBTest) {
 		var rows *sql.Rows
 		var rows *sql.Rows
 		for sqltype, tests := range timetests {
 		for sqltype, tests := range timetests {
-			dbt.mustExec("DROP TABLE IF EXISTS test")
 			dbt.mustExec("CREATE TABLE test (value " + sqltype + ")")
 			dbt.mustExec("CREATE TABLE test (value " + sqltype + ")")
 			for _, test := range tests {
 			for _, test := range tests {
 				for mode, q := range modes {
 				for mode, q := range modes {
@@ -466,6 +466,7 @@ func TestDateTime(t *testing.T) {
 					}
 					}
 				}
 				}
 			}
 			}
+			dbt.mustExec("DROP TABLE IF EXISTS test")
 		}
 		}
 	}
 	}
 
 
@@ -701,7 +702,7 @@ func TestLoadData(t *testing.T) {
 		file.WriteString("1\ta string\n2\ta string containing a \\t\n3\ta string containing a \\n\n4\ta string containing both \\t\\n\n")
 		file.WriteString("1\ta string\n2\ta string containing a \\t\n3\ta string containing a \\n\n4\ta string containing both \\t\\n\n")
 		file.Close()
 		file.Close()
 
 
-		dbt.mustExec("DROP TABLE IF EXISTS test")
+		dbt.db.Exec("DROP TABLE IF EXISTS test")
 		dbt.mustExec("CREATE TABLE test (id INT NOT NULL PRIMARY KEY, value TEXT NOT NULL) CHARACTER SET utf8")
 		dbt.mustExec("CREATE TABLE test (id INT NOT NULL PRIMARY KEY, value TEXT NOT NULL) CHARACTER SET utf8")
 
 
 		// Local File
 		// Local File
@@ -739,6 +740,71 @@ func TestLoadData(t *testing.T) {
 	})
 	})
 }
 }
 
 
+func TestStrict(t *testing.T) {
+	runTests(t, "TestStrict", func(dbt *DBTest) {
+		dbt.mustExec("CREATE TABLE test (a TINYINT NOT NULL, b CHAR(4))")
+
+		var queries = [...]struct {
+			in    string
+			codes []string
+		}{
+			{"DROP TABLE IF EXISTS no_such_table", []string{"1051"}},
+			{"INSERT INTO test VALUES(10,'mysql'),(NULL,'test'),(300,'Open Source')", []string{"1265", "1048", "1264", "1265"}},
+		}
+		var err error
+
+		var checkWarnings = func(err error, mode string, idx int) {
+			if err == nil {
+				dbt.Errorf("Expected STRICT error on query [%s] %s", mode, queries[idx].in)
+			}
+
+			if warnings, ok := err.(MySQLWarnings); ok {
+				var codes = make([]string, len(warnings))
+				for i := range warnings {
+					codes[i] = warnings[i].Code
+				}
+				if len(codes) != len(queries[idx].codes) {
+					dbt.Errorf("Unexpected STRICT error count on query [%s] %s: Wanted %v, Got %v", mode, queries[idx].in, queries[idx].codes, codes)
+				}
+
+				for i := range warnings {
+					if codes[i] != queries[idx].codes[i] {
+						dbt.Errorf("Unexpected STRICT error codes on query [%s] %s: Wanted %v, Got %v", mode, queries[idx].in, queries[idx].codes, codes)
+						return
+					}
+				}
+
+			} else {
+				dbt.Errorf("Unexpected error on query [%s] %s: %s", mode, queries[idx].in, err.Error())
+			}
+		}
+
+		// text protocol
+		for i := range queries {
+			_, err = dbt.db.Exec(queries[i].in)
+			checkWarnings(err, "text", i)
+		}
+
+		var stmt *sql.Stmt
+
+		// binary protocol
+		for i := range queries {
+			stmt, err = dbt.db.Prepare(queries[i].in)
+			if err != nil {
+				dbt.Error("Error on preparing query %: ", queries[i].in, err.Error())
+			}
+
+			_, err = stmt.Exec()
+			checkWarnings(err, "binary", i)
+
+			err = stmt.Close()
+			if err != nil {
+				dbt.Error("Error on closing stmt for query %: ", queries[i].in, err.Error())
+			}
+		}
+	})
+}
+
 // Special cases
 // Special cases
 
 
 func TestRowsClose(t *testing.T) {
 func TestRowsClose(t *testing.T) {
@@ -919,7 +985,7 @@ func TestStmtMultiRows(t *testing.T) {
 }
 }
 
 
 func TestConcurrent(t *testing.T) {
 func TestConcurrent(t *testing.T) {
-	if os.Getenv("MYSQL_TEST_CONCURRENT") != "1" {
+	if readBool(os.Getenv("MYSQL_TEST_CONCURRENT")) != true {
 		t.Log("CONCURRENT env var not set. Skipping TestConcurrent")
 		t.Log("CONCURRENT env var not set. Skipping TestConcurrent")
 		return
 		return
 	}
 	}

+ 85 - 1
errors.go

@@ -9,7 +9,12 @@
 
 
 package mysql
 package mysql
 
 
-import "errors"
+import (
+	"database/sql/driver"
+	"errors"
+	"fmt"
+	"io"
+)
 
 
 var (
 var (
 	errMalformPkt  = errors.New("Malformed Packet")
 	errMalformPkt  = errors.New("Malformed Packet")
@@ -18,3 +23,82 @@ var (
 	errOldPassword = errors.New("It seems like you are using old_passwords, which is unsupported. See https://github.com/go-sql-driver/mysql/wiki/old_passwords")
 	errOldPassword = errors.New("It seems like you are using old_passwords, which is unsupported. See https://github.com/go-sql-driver/mysql/wiki/old_passwords")
 	errPktTooLarge = errors.New("Packet for query is too large. You can change this value on the server by adjusting the 'max_allowed_packet' variable.")
 	errPktTooLarge = errors.New("Packet for query is too large. You can change this value on the server by adjusting the 'max_allowed_packet' variable.")
 )
 )
+
+// error type which represents a single MySQL error
+type MySQLError struct {
+	Number  uint16
+	Message string
+}
+
+func (me *MySQLError) Error() string {
+	return fmt.Sprintf("Error %d: %s", me.Number, me.Message)
+}
+
+// error type which represents a group (one ore more) MySQL warnings
+type MySQLWarnings []mysqlWarning
+
+func (mws MySQLWarnings) Error() string {
+	var msg string
+	for i, warning := range mws {
+		if i > 0 {
+			msg += "\r\n"
+		}
+		msg += fmt.Sprintf("%s %s: %s", warning.Level, warning.Code, warning.Message)
+	}
+	return msg
+}
+
+// error type which represents a single MySQL warning
+type mysqlWarning struct {
+	Level   string
+	Code    string
+	Message string
+}
+
+func (mc *mysqlConn) getWarnings() (err error) {
+	rows, err := mc.Query("SHOW WARNINGS", []driver.Value{})
+	if err != nil {
+		return
+	}
+
+	var warnings = MySQLWarnings{}
+	var values = make([]driver.Value, 3)
+
+	var warning mysqlWarning
+	var raw []byte
+	var ok bool
+
+	for {
+		err = rows.Next(values)
+		switch err {
+		case nil:
+			warning = mysqlWarning{}
+
+			if raw, ok = values[0].([]byte); ok {
+				warning.Level = string(raw)
+			} else {
+				warning.Level = fmt.Sprintf("%s", values[0])
+			}
+			if raw, ok = values[1].([]byte); ok {
+				warning.Code = string(raw)
+			} else {
+				warning.Code = fmt.Sprintf("%s", values[1])
+			}
+			if raw, ok = values[2].([]byte); ok {
+				warning.Message = string(raw)
+			} else {
+				warning.Message = fmt.Sprintf("%s", values[0])
+			}
+
+			warnings = append(warnings, warning)
+
+		case io.EOF:
+			return warnings
+
+		default:
+			rows.Close()
+			return
+		}
+	}
+	return
+}

+ 27 - 9
packets.go

@@ -352,8 +352,7 @@ func (mc *mysqlConn) readResultOK() error {
 		switch data[0] {
 		switch data[0] {
 
 
 		case iOK:
 		case iOK:
-			mc.handleOkPacket(data)
-			return nil
+			return mc.handleOkPacket(data)
 
 
 		case iEOF: // someone is using old_passwords
 		case iEOF: // someone is using old_passwords
 			return errOldPassword
 			return errOldPassword
@@ -373,8 +372,7 @@ func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
 		switch data[0] {
 		switch data[0] {
 
 
 		case iOK:
 		case iOK:
-			mc.handleOkPacket(data)
-			return 0, nil
+			return 0, mc.handleOkPacket(data)
 
 
 		case iERR:
 		case iERR:
 			return 0, mc.handleErrorPacket(data)
 			return 0, mc.handleErrorPacket(data)
@@ -415,13 +413,16 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error {
 	}
 	}
 
 
 	// Error Message [string]
 	// Error Message [string]
-	return fmt.Errorf("Error %d: %s", errno, string(data[pos:]))
+	return &MySQLError{
+		Number:  errno,
+		Message: string(data[pos:]),
+	}
 }
 }
 
 
 // Ok Packet
 // Ok Packet
 // http://dev.mysql.com/doc/internals/en/overview.html#packet-OK_Packet
 // http://dev.mysql.com/doc/internals/en/overview.html#packet-OK_Packet
-func (mc *mysqlConn) handleOkPacket(data []byte) {
-	var n int
+func (mc *mysqlConn) handleOkPacket(data []byte) (err error) {
+	var n, m int
 
 
 	// 0x00 [1 byte]
 	// 0x00 [1 byte]
 
 
@@ -429,11 +430,22 @@ func (mc *mysqlConn) handleOkPacket(data []byte) {
 	mc.affectedRows, _, n = readLengthEncodedInteger(data[1:])
 	mc.affectedRows, _, n = readLengthEncodedInteger(data[1:])
 
 
 	// Insert id [Length Coded Binary]
 	// Insert id [Length Coded Binary]
-	mc.insertId, _, _ = readLengthEncodedInteger(data[1+n:])
+	mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])
 
 
 	// server_status [2 bytes]
 	// server_status [2 bytes]
+
 	// warning count [2 bytes]
 	// warning count [2 bytes]
+	if !mc.strict {
+		return
+	} else {
+		pos := 1 + n + m + 2
+		if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 {
+			err = mc.getWarnings()
+		}
+	}
+
 	// message [until end of packet]
 	// message [until end of packet]
+	return
 }
 }
 
 
 // Read Packets as Field Packets until EOF-Packet or an Error appears
 // Read Packets as Field Packets until EOF-Packet or an Error appears
@@ -625,7 +637,13 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (columnCount uint16, err error)
 		pos += 2
 		pos += 2
 
 
 		// Warning count [16 bit uint]
 		// Warning count [16 bit uint]
-		// bytesToUint16(data[pos : pos+2])
+		if !stmt.mc.strict {
+			return
+		} else {
+			if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 {
+				err = stmt.mc.getWarnings()
+			}
+		}
 	}
 	}
 	return
 	return
 }
 }

+ 10 - 0
utils.go

@@ -234,6 +234,16 @@ func formatBinaryDateTime(num uint64, data []byte) (driver.Value, error) {
 	return nil, fmt.Errorf("Invalid DATETIME-packet length %d", num)
 	return nil, fmt.Errorf("Invalid DATETIME-packet length %d", num)
 }
 }
 
 
+func readBool(value string) bool {
+	switch strings.ToLower(value) {
+	case "true":
+		return true
+	case "1":
+		return true
+	}
+	return false
+}
+
 /******************************************************************************
 /******************************************************************************
 *                       Convert from and to bytes                             *
 *                       Convert from and to bytes                             *
 ******************************************************************************/
 ******************************************************************************/