Browse Source

Add strict mode

Closes #40
Julien Schmidt 12 years ago
parent
commit
3f3502934a
5 changed files with 138 additions and 15 deletions
  1. 2 1
      README.md
  2. 7 0
      connection.go
  3. 63 5
      driver_test.go
  4. 43 1
      errors.go
  5. 23 8
      packets.go

+ 2 - 1
README.md

@@ -110,6 +110,7 @@ Possible Parameters are:
   * `allowAllFiles`: `allowAllFiles=true` disables the file Whitelist for `LOAD DATA LOCAL INFILE` and allows *all* files. *Might be insecure!*
   * `allowAllFiles`: `allowAllFiles=true` disables the file Whitelist for `LOAD DATA LOCAL INFILE` and allows *all* files. *Might be insecure!*
   * `parseTime`: `parseTime=true` changes the output type of `DATE` and `DATETIME` values to `time.Time` instead of `[]byte` / `string`
   * `parseTime`: `parseTime=true` changes the output type of `DATE` and `DATETIME` values to `time.Time` instead of `[]byte` / `string`
   * `loc`: Sets the location for time.Time values (when using `parseTime=true`). The default is `UTC`. *"Local"* sets the system's location. See [time.LoadLocation](http://golang.org/pkg/time/#LoadLocation) for details.
   * `loc`: Sets the location for time.Time values (when using `parseTime=true`). The default is `UTC`. *"Local"* sets the system's location. See [time.LoadLocation](http://golang.org/pkg/time/#LoadLocation) for details.
+  * `strict`: Enable strict mode. MySQL warnings are treated as errors.
 
 
 All other parameters are interpreted as system variables:
 All other parameters are interpreted as system variables:
   * `autocommit`: *"SET autocommit=`value`"*
   * `autocommit`: *"SET autocommit=`value`"*
@@ -154,7 +155,7 @@ See also the [godoc of Go-MySQL-Driver](http://godoc.org/github.com/go-sql-drive
 ### `time.Time` support
 ### `time.Time` support
 The default internal output type of MySQL `DATE` and `DATETIME` values is `[]byte` which allows you to scan the value into a `[]byte`, `string` or `sql.RawBytes` variable in your programm.
 The default internal output type of MySQL `DATE` and `DATETIME` values is `[]byte` which allows you to scan the value into a `[]byte`, `string` or `sql.RawBytes` variable in your programm.
 
 
-However, many want to scan MySQL `DATE` and `DATETIME` values into `time.Time` variables, which is the logical opposite in Go to `DATE` and `DATETIME` in MySQL. You can do that by changing the internal output type from `[]byte` to `time.Time` with the DSN parameter `parseTime=true`. You can set the default [`time.Time` location](http://golang.org/pkg/time/#Location) with the `loc` DSN parameter.  
+However, many want to scan MySQL `DATE` and `DATETIME` values into `time.Time` variables, which is the logical opposite in Go to `DATE` and `DATETIME` in MySQL. You can do that by changing the internal output type from `[]byte` to `time.Time` with the DSN parameter `parseTime=true`. You can set the default [`time.Time` location](http://golang.org/pkg/time/#Location) with the `loc` DSN parameter.
 **Caution:** As of Go 1.1, this makes `time.Time` the only variable type you can scan `DATE` and `DATETIME` values into. This breaks for example [`sql.RawBytes` support](https://github.com/go-sql-driver/mysql/wiki/Examples#rawbytes).
 **Caution:** As of Go 1.1, this makes `time.Time` the only variable type you can scan `DATE` and `DATETIME` values into. This breaks for example [`sql.RawBytes` support](https://github.com/go-sql-driver/mysql/wiki/Examples#rawbytes).
 
 
 
 

+ 7 - 0
connection.go

@@ -31,6 +31,7 @@ type mysqlConn struct {
 	maxPacketAllowed int
 	maxPacketAllowed int
 	maxWriteSize     int
 	maxWriteSize     int
 	parseTime        bool
 	parseTime        bool
+	strict           bool
 }
 }
 
 
 type config struct {
 type config struct {
@@ -71,6 +72,12 @@ func (mc *mysqlConn) handleParams() (err error) {
 				mc.parseTime = true
 				mc.parseTime = true
 			}
 			}
 
 
+		// Strict mode
+		case "strict":
+			if val == "true" {
+				mc.strict = true
+			}
+
 		// TLS-Encryption
 		// TLS-Encryption
 		case "tls":
 		case "tls":
 			err = errors.New("TLS-Encryption not implemented yet")
 			err = errors.New("TLS-Encryption not implemented yet")

+ 63 - 5
driver_test.go

@@ -34,7 +34,7 @@ func init() {
 	dbname := env("MYSQL_TEST_DBNAME", "gotest")
 	dbname := env("MYSQL_TEST_DBNAME", "gotest")
 	charset = "charset=utf8"
 	charset = "charset=utf8"
 	netAddr = fmt.Sprintf("%s(%s)", prot, addr)
 	netAddr = fmt.Sprintf("%s(%s)", prot, addr)
-	dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s&"+charset, user, pass, netAddr, dbname)
+	dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s&strict=true&"+charset, user, pass, netAddr, dbname)
 	c, err := net.Dial(prot, addr)
 	c, err := net.Dial(prot, addr)
 	if err == nil {
 	if err == nil {
 		available = true
 		available = true
@@ -119,11 +119,11 @@ func runTests(t *testing.T, name string, tests ...func(dbt *DBTest)) {
 	defer db.Close()
 	defer db.Close()
 
 
 	dbt := &DBTest{t, db}
 	dbt := &DBTest{t, db}
-	dbt.mustExec("DROP TABLE IF EXISTS test")
+	dbt.db.Exec("DROP TABLE IF EXISTS test")
 	for _, test := range tests {
 	for _, test := range tests {
 		test(dbt)
 		test(dbt)
+		dbt.db.Exec("DROP TABLE IF EXISTS test")
 	}
 	}
-	dbt.mustExec("DROP TABLE IF EXISTS test")
 }
 }
 
 
 func (dbt *DBTest) fail(method, query string, err error) {
 func (dbt *DBTest) fail(method, query string, err error) {
@@ -446,7 +446,6 @@ func TestDateTime(t *testing.T) {
 	testTime := func(dbt *DBTest) {
 	testTime := func(dbt *DBTest) {
 		var rows *sql.Rows
 		var rows *sql.Rows
 		for sqltype, tests := range timetests {
 		for sqltype, tests := range timetests {
-			dbt.mustExec("DROP TABLE IF EXISTS test")
 			dbt.mustExec("CREATE TABLE test (value " + sqltype + ")")
 			dbt.mustExec("CREATE TABLE test (value " + sqltype + ")")
 			for _, test := range tests {
 			for _, test := range tests {
 				for mode, q := range modes {
 				for mode, q := range modes {
@@ -466,6 +465,7 @@ func TestDateTime(t *testing.T) {
 					}
 					}
 				}
 				}
 			}
 			}
+			dbt.mustExec("DROP TABLE IF EXISTS test")
 		}
 		}
 	}
 	}
 
 
@@ -701,7 +701,7 @@ func TestLoadData(t *testing.T) {
 		file.WriteString("1\ta string\n2\ta string containing a \\t\n3\ta string containing a \\n\n4\ta string containing both \\t\\n\n")
 		file.WriteString("1\ta string\n2\ta string containing a \\t\n3\ta string containing a \\n\n4\ta string containing both \\t\\n\n")
 		file.Close()
 		file.Close()
 
 
-		dbt.mustExec("DROP TABLE IF EXISTS test")
+		dbt.db.Exec("DROP TABLE IF EXISTS test")
 		dbt.mustExec("CREATE TABLE test (id INT NOT NULL PRIMARY KEY, value TEXT NOT NULL) CHARACTER SET utf8")
 		dbt.mustExec("CREATE TABLE test (id INT NOT NULL PRIMARY KEY, value TEXT NOT NULL) CHARACTER SET utf8")
 
 
 		// Local File
 		// Local File
@@ -739,6 +739,64 @@ func TestLoadData(t *testing.T) {
 	})
 	})
 }
 }
 
 
+func TestStrict(t *testing.T) {
+	runTests(t, "TestStrict", func(dbt *DBTest) {
+		dbt.mustExec("CREATE TABLE test (a TINYINT NOT NULL, b CHAR(4))")
+
+		queries := [...][2]string{
+			{"DROP TABLE IF EXISTS no_such_table", "Note 1051: Unknown table 'no_such_table'"},
+			{"INSERT INTO test VALUES(10,'mysql'),(NULL,'test'),(300,'Open Source')",
+				"Warning 1265: Data truncated for column 'b' at row 1\r\n" +
+					"Warning 1048: Column 'a' cannot be null\r\n" +
+					"Warning 1264: Out of range value for column 'a' at row 3\r\n" +
+					"Warning 1265: Data truncated for column 'b' at row 3",
+			},
+		}
+		var rows *sql.Rows
+		var err error
+
+		// text protocol
+		for i := range queries {
+			rows, err = dbt.db.Query(queries[i][0])
+			if rows != nil {
+				rows.Close()
+			}
+
+			if err == nil {
+				dbt.Errorf("Expecteded strict error on query [text] %s", queries[i][0])
+			} else if err.Error() != queries[i][1] {
+				dbt.Errorf("Unexpected error on query [text] %s: %s != %s", queries[i][0], err.Error(), queries[i][1])
+			}
+		}
+
+		var stmt *sql.Stmt
+
+		// binary protocol
+		for i := range queries {
+			stmt, err = dbt.db.Prepare(queries[i][0])
+			if err != nil {
+				dbt.Error("Error on preparing query %: ", queries[i][0], err.Error())
+			}
+
+			rows, err = stmt.Query()
+			if rows != nil {
+				rows.Close()
+			}
+
+			if err == nil {
+				dbt.Errorf("Expecteded strict error on query [binary] %s", queries[i][0])
+			} else if err.Error() != queries[i][1] {
+				dbt.Errorf("Unexpected error on query [binary] %s: %s != %s", queries[i][0], err.Error(), queries[i][1])
+			}
+
+			err = stmt.Close()
+			if err != nil {
+				dbt.Error("Error on closing stmt for query %: ", queries[i][0], err.Error())
+			}
+		}
+	})
+}
+
 // Special cases
 // Special cases
 
 
 func TestRowsClose(t *testing.T) {
 func TestRowsClose(t *testing.T) {

+ 43 - 1
errors.go

@@ -9,7 +9,12 @@
 
 
 package mysql
 package mysql
 
 
-import "errors"
+import (
+	"database/sql/driver"
+	"errors"
+	"fmt"
+	"io"
+)
 
 
 var (
 var (
 	errMalformPkt  = errors.New("Malformed Packet")
 	errMalformPkt  = errors.New("Malformed Packet")
@@ -18,3 +23,40 @@ var (
 	errOldPassword = errors.New("It seems like you are using old_passwords, which is unsupported. See https://github.com/go-sql-driver/mysql/wiki/old_passwords")
 	errOldPassword = errors.New("It seems like you are using old_passwords, which is unsupported. See https://github.com/go-sql-driver/mysql/wiki/old_passwords")
 	errPktTooLarge = errors.New("Packet for query is too large. You can change this value on the server by adjusting the 'max_allowed_packet' variable.")
 	errPktTooLarge = errors.New("Packet for query is too large. You can change this value on the server by adjusting the 'max_allowed_packet' variable.")
 )
 )
+
+// error type which represents one or more MySQL warnings
+type MySQLWarnings []string
+
+func (mw MySQLWarnings) Error() string {
+	var msg string
+	for i := range mw {
+		if i > 0 {
+			msg += "\r\n"
+		}
+		msg += mw[i]
+	}
+	return msg
+}
+
+func (mc *mysqlConn) getWarnings() (err error) {
+	rows, err := mc.Query("SHOW WARNINGS", []driver.Value{})
+	if err != nil {
+		return
+	}
+
+	var warnings = MySQLWarnings{}
+	var values = make([]driver.Value, 3)
+
+	for {
+		if err = rows.Next(values); err == nil {
+			warnings = append(warnings,
+				fmt.Sprintf("%s %s: %s", values[0], values[1], values[2]),
+			)
+		} else if err == io.EOF {
+			return warnings
+		} else {
+			return
+		}
+	}
+	return
+}

+ 23 - 8
packets.go

@@ -352,8 +352,7 @@ func (mc *mysqlConn) readResultOK() error {
 		switch data[0] {
 		switch data[0] {
 
 
 		case iOK:
 		case iOK:
-			mc.handleOkPacket(data)
-			return nil
+			return mc.handleOkPacket(data)
 
 
 		case iEOF: // someone is using old_passwords
 		case iEOF: // someone is using old_passwords
 			return errOldPassword
 			return errOldPassword
@@ -373,8 +372,7 @@ func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
 		switch data[0] {
 		switch data[0] {
 
 
 		case iOK:
 		case iOK:
-			mc.handleOkPacket(data)
-			return 0, nil
+			return 0, mc.handleOkPacket(data)
 
 
 		case iERR:
 		case iERR:
 			return 0, mc.handleErrorPacket(data)
 			return 0, mc.handleErrorPacket(data)
@@ -420,8 +418,8 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error {
 
 
 // Ok Packet
 // Ok Packet
 // http://dev.mysql.com/doc/internals/en/overview.html#packet-OK_Packet
 // http://dev.mysql.com/doc/internals/en/overview.html#packet-OK_Packet
-func (mc *mysqlConn) handleOkPacket(data []byte) {
-	var n int
+func (mc *mysqlConn) handleOkPacket(data []byte) (err error) {
+	var n, m int
 
 
 	// 0x00 [1 byte]
 	// 0x00 [1 byte]
 
 
@@ -429,11 +427,22 @@ func (mc *mysqlConn) handleOkPacket(data []byte) {
 	mc.affectedRows, _, n = readLengthEncodedInteger(data[1:])
 	mc.affectedRows, _, n = readLengthEncodedInteger(data[1:])
 
 
 	// Insert id [Length Coded Binary]
 	// Insert id [Length Coded Binary]
-	mc.insertId, _, _ = readLengthEncodedInteger(data[1+n:])
+	mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])
 
 
 	// server_status [2 bytes]
 	// server_status [2 bytes]
+
 	// warning count [2 bytes]
 	// warning count [2 bytes]
+	if !mc.strict {
+		return
+	} else {
+		pos := 1 + n + m + 2
+		if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 {
+			err = mc.getWarnings()
+		}
+	}
+
 	// message [until end of packet]
 	// message [until end of packet]
+	return
 }
 }
 
 
 // Read Packets as Field Packets until EOF-Packet or an Error appears
 // Read Packets as Field Packets until EOF-Packet or an Error appears
@@ -625,7 +634,13 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (columnCount uint16, err error)
 		pos += 2
 		pos += 2
 
 
 		// Warning count [16 bit uint]
 		// Warning count [16 bit uint]
-		// bytesToUint16(data[pos : pos+2])
+		if !stmt.mc.strict {
+			return
+		} else {
+			if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 {
+				err = stmt.mc.getWarnings()
+			}
+		}
 	}
 	}
 	return
 	return
 }
 }