Browse Source

Add strict mode

Closes #40
Julien Schmidt 12 years ago
parent
commit
3f3502934a
5 changed files with 138 additions and 15 deletions
  1. 2 1
      README.md
  2. 7 0
      connection.go
  3. 63 5
      driver_test.go
  4. 43 1
      errors.go
  5. 23 8
      packets.go

+ 2 - 1
README.md

@@ -110,6 +110,7 @@ Possible Parameters are:
   * `allowAllFiles`: `allowAllFiles=true` disables the file Whitelist for `LOAD DATA LOCAL INFILE` and allows *all* files. *Might be insecure!*
   * `parseTime`: `parseTime=true` changes the output type of `DATE` and `DATETIME` values to `time.Time` instead of `[]byte` / `string`
   * `loc`: Sets the location for time.Time values (when using `parseTime=true`). The default is `UTC`. *"Local"* sets the system's location. See [time.LoadLocation](http://golang.org/pkg/time/#LoadLocation) for details.
+  * `strict`: Enable strict mode. MySQL warnings are treated as errors.
 
 All other parameters are interpreted as system variables:
   * `autocommit`: *"SET autocommit=`value`"*
@@ -154,7 +155,7 @@ See also the [godoc of Go-MySQL-Driver](http://godoc.org/github.com/go-sql-drive
 ### `time.Time` support
 The default internal output type of MySQL `DATE` and `DATETIME` values is `[]byte` which allows you to scan the value into a `[]byte`, `string` or `sql.RawBytes` variable in your programm.
 
-However, many want to scan MySQL `DATE` and `DATETIME` values into `time.Time` variables, which is the logical opposite in Go to `DATE` and `DATETIME` in MySQL. You can do that by changing the internal output type from `[]byte` to `time.Time` with the DSN parameter `parseTime=true`. You can set the default [`time.Time` location](http://golang.org/pkg/time/#Location) with the `loc` DSN parameter.  
+However, many want to scan MySQL `DATE` and `DATETIME` values into `time.Time` variables, which is the logical opposite in Go to `DATE` and `DATETIME` in MySQL. You can do that by changing the internal output type from `[]byte` to `time.Time` with the DSN parameter `parseTime=true`. You can set the default [`time.Time` location](http://golang.org/pkg/time/#Location) with the `loc` DSN parameter.
 **Caution:** As of Go 1.1, this makes `time.Time` the only variable type you can scan `DATE` and `DATETIME` values into. This breaks for example [`sql.RawBytes` support](https://github.com/go-sql-driver/mysql/wiki/Examples#rawbytes).
 
 

+ 7 - 0
connection.go

@@ -31,6 +31,7 @@ type mysqlConn struct {
 	maxPacketAllowed int
 	maxWriteSize     int
 	parseTime        bool
+	strict           bool
 }
 
 type config struct {
@@ -71,6 +72,12 @@ func (mc *mysqlConn) handleParams() (err error) {
 				mc.parseTime = true
 			}
 
+		// Strict mode
+		case "strict":
+			if val == "true" {
+				mc.strict = true
+			}
+
 		// TLS-Encryption
 		case "tls":
 			err = errors.New("TLS-Encryption not implemented yet")

+ 63 - 5
driver_test.go

@@ -34,7 +34,7 @@ func init() {
 	dbname := env("MYSQL_TEST_DBNAME", "gotest")
 	charset = "charset=utf8"
 	netAddr = fmt.Sprintf("%s(%s)", prot, addr)
-	dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s&"+charset, user, pass, netAddr, dbname)
+	dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s&strict=true&"+charset, user, pass, netAddr, dbname)
 	c, err := net.Dial(prot, addr)
 	if err == nil {
 		available = true
@@ -119,11 +119,11 @@ func runTests(t *testing.T, name string, tests ...func(dbt *DBTest)) {
 	defer db.Close()
 
 	dbt := &DBTest{t, db}
-	dbt.mustExec("DROP TABLE IF EXISTS test")
+	dbt.db.Exec("DROP TABLE IF EXISTS test")
 	for _, test := range tests {
 		test(dbt)
+		dbt.db.Exec("DROP TABLE IF EXISTS test")
 	}
-	dbt.mustExec("DROP TABLE IF EXISTS test")
 }
 
 func (dbt *DBTest) fail(method, query string, err error) {
@@ -446,7 +446,6 @@ func TestDateTime(t *testing.T) {
 	testTime := func(dbt *DBTest) {
 		var rows *sql.Rows
 		for sqltype, tests := range timetests {
-			dbt.mustExec("DROP TABLE IF EXISTS test")
 			dbt.mustExec("CREATE TABLE test (value " + sqltype + ")")
 			for _, test := range tests {
 				for mode, q := range modes {
@@ -466,6 +465,7 @@ func TestDateTime(t *testing.T) {
 					}
 				}
 			}
+			dbt.mustExec("DROP TABLE IF EXISTS test")
 		}
 	}
 
@@ -701,7 +701,7 @@ func TestLoadData(t *testing.T) {
 		file.WriteString("1\ta string\n2\ta string containing a \\t\n3\ta string containing a \\n\n4\ta string containing both \\t\\n\n")
 		file.Close()
 
-		dbt.mustExec("DROP TABLE IF EXISTS test")
+		dbt.db.Exec("DROP TABLE IF EXISTS test")
 		dbt.mustExec("CREATE TABLE test (id INT NOT NULL PRIMARY KEY, value TEXT NOT NULL) CHARACTER SET utf8")
 
 		// Local File
@@ -739,6 +739,64 @@ func TestLoadData(t *testing.T) {
 	})
 }
 
+func TestStrict(t *testing.T) {
+	runTests(t, "TestStrict", func(dbt *DBTest) {
+		dbt.mustExec("CREATE TABLE test (a TINYINT NOT NULL, b CHAR(4))")
+
+		queries := [...][2]string{
+			{"DROP TABLE IF EXISTS no_such_table", "Note 1051: Unknown table 'no_such_table'"},
+			{"INSERT INTO test VALUES(10,'mysql'),(NULL,'test'),(300,'Open Source')",
+				"Warning 1265: Data truncated for column 'b' at row 1\r\n" +
+					"Warning 1048: Column 'a' cannot be null\r\n" +
+					"Warning 1264: Out of range value for column 'a' at row 3\r\n" +
+					"Warning 1265: Data truncated for column 'b' at row 3",
+			},
+		}
+		var rows *sql.Rows
+		var err error
+
+		// text protocol
+		for i := range queries {
+			rows, err = dbt.db.Query(queries[i][0])
+			if rows != nil {
+				rows.Close()
+			}
+
+			if err == nil {
+				dbt.Errorf("Expecteded strict error on query [text] %s", queries[i][0])
+			} else if err.Error() != queries[i][1] {
+				dbt.Errorf("Unexpected error on query [text] %s: %s != %s", queries[i][0], err.Error(), queries[i][1])
+			}
+		}
+
+		var stmt *sql.Stmt
+
+		// binary protocol
+		for i := range queries {
+			stmt, err = dbt.db.Prepare(queries[i][0])
+			if err != nil {
+				dbt.Error("Error on preparing query %: ", queries[i][0], err.Error())
+			}
+
+			rows, err = stmt.Query()
+			if rows != nil {
+				rows.Close()
+			}
+
+			if err == nil {
+				dbt.Errorf("Expecteded strict error on query [binary] %s", queries[i][0])
+			} else if err.Error() != queries[i][1] {
+				dbt.Errorf("Unexpected error on query [binary] %s: %s != %s", queries[i][0], err.Error(), queries[i][1])
+			}
+
+			err = stmt.Close()
+			if err != nil {
+				dbt.Error("Error on closing stmt for query %: ", queries[i][0], err.Error())
+			}
+		}
+	})
+}
+
 // Special cases
 
 func TestRowsClose(t *testing.T) {

+ 43 - 1
errors.go

@@ -9,7 +9,12 @@
 
 package mysql
 
-import "errors"
+import (
+	"database/sql/driver"
+	"errors"
+	"fmt"
+	"io"
+)
 
 var (
 	errMalformPkt  = errors.New("Malformed Packet")
@@ -18,3 +23,40 @@ var (
 	errOldPassword = errors.New("It seems like you are using old_passwords, which is unsupported. See https://github.com/go-sql-driver/mysql/wiki/old_passwords")
 	errPktTooLarge = errors.New("Packet for query is too large. You can change this value on the server by adjusting the 'max_allowed_packet' variable.")
 )
+
+// error type which represents one or more MySQL warnings
+type MySQLWarnings []string
+
+func (mw MySQLWarnings) Error() string {
+	var msg string
+	for i := range mw {
+		if i > 0 {
+			msg += "\r\n"
+		}
+		msg += mw[i]
+	}
+	return msg
+}
+
+func (mc *mysqlConn) getWarnings() (err error) {
+	rows, err := mc.Query("SHOW WARNINGS", []driver.Value{})
+	if err != nil {
+		return
+	}
+
+	var warnings = MySQLWarnings{}
+	var values = make([]driver.Value, 3)
+
+	for {
+		if err = rows.Next(values); err == nil {
+			warnings = append(warnings,
+				fmt.Sprintf("%s %s: %s", values[0], values[1], values[2]),
+			)
+		} else if err == io.EOF {
+			return warnings
+		} else {
+			return
+		}
+	}
+	return
+}

+ 23 - 8
packets.go

@@ -352,8 +352,7 @@ func (mc *mysqlConn) readResultOK() error {
 		switch data[0] {
 
 		case iOK:
-			mc.handleOkPacket(data)
-			return nil
+			return mc.handleOkPacket(data)
 
 		case iEOF: // someone is using old_passwords
 			return errOldPassword
@@ -373,8 +372,7 @@ func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
 		switch data[0] {
 
 		case iOK:
-			mc.handleOkPacket(data)
-			return 0, nil
+			return 0, mc.handleOkPacket(data)
 
 		case iERR:
 			return 0, mc.handleErrorPacket(data)
@@ -420,8 +418,8 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error {
 
 // Ok Packet
 // http://dev.mysql.com/doc/internals/en/overview.html#packet-OK_Packet
-func (mc *mysqlConn) handleOkPacket(data []byte) {
-	var n int
+func (mc *mysqlConn) handleOkPacket(data []byte) (err error) {
+	var n, m int
 
 	// 0x00 [1 byte]
 
@@ -429,11 +427,22 @@ func (mc *mysqlConn) handleOkPacket(data []byte) {
 	mc.affectedRows, _, n = readLengthEncodedInteger(data[1:])
 
 	// Insert id [Length Coded Binary]
-	mc.insertId, _, _ = readLengthEncodedInteger(data[1+n:])
+	mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])
 
 	// server_status [2 bytes]
+
 	// warning count [2 bytes]
+	if !mc.strict {
+		return
+	} else {
+		pos := 1 + n + m + 2
+		if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 {
+			err = mc.getWarnings()
+		}
+	}
+
 	// message [until end of packet]
+	return
 }
 
 // Read Packets as Field Packets until EOF-Packet or an Error appears
@@ -625,7 +634,13 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (columnCount uint16, err error)
 		pos += 2
 
 		// Warning count [16 bit uint]
-		// bytesToUint16(data[pos : pos+2])
+		if !stmt.mc.strict {
+			return
+		} else {
+			if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 {
+				err = stmt.mc.getWarnings()
+			}
+		}
 	}
 	return
 }