Przeglądaj źródła

Merge pull request #58 from go-sql-driver/strict

Add strict mode
Julien Schmidt 12 lat temu
rodzic
commit
26af0ebaba
6 zmienionych plików z 203 dodań i 20 usunięć
  1. 3 1
      README.md
  2. 6 3
      connection.go
  3. 72 6
      driver_test.go
  4. 85 1
      errors.go
  5. 27 9
      packets.go
  6. 10 0
      utils.go

+ 3 - 1
README.md

@@ -110,6 +110,7 @@ Possible Parameters are:
   * `allowAllFiles`: `allowAllFiles=true` disables the file Whitelist for `LOAD DATA LOCAL INFILE` and allows *all* files. *Might be insecure!*
   * `parseTime`: `parseTime=true` changes the output type of `DATE` and `DATETIME` values to `time.Time` instead of `[]byte` / `string`
   * `loc`: Sets the location for time.Time values (when using `parseTime=true`). The default is `UTC`. *"Local"* sets the system's location. See [time.LoadLocation](http://golang.org/pkg/time/#LoadLocation) for details.
+  * `strict`: Enable strict mode. MySQL warnings are treated as errors.
 
 All other parameters are interpreted as system variables:
   * `autocommit`: *"SET autocommit=`value`"*
@@ -154,7 +155,8 @@ See also the [godoc of Go-MySQL-Driver](http://godoc.org/github.com/go-sql-drive
 ### `time.Time` support
 The default internal output type of MySQL `DATE` and `DATETIME` values is `[]byte` which allows you to scan the value into a `[]byte`, `string` or `sql.RawBytes` variable in your programm.
 
-However, many want to scan MySQL `DATE` and `DATETIME` values into `time.Time` variables, which is the logical opposite in Go to `DATE` and `DATETIME` in MySQL. You can do that by changing the internal output type from `[]byte` to `time.Time` with the DSN parameter `parseTime=true`. You can set the default [`time.Time` location](http://golang.org/pkg/time/#Location) with the `loc` DSN parameter.  
+However, many want to scan MySQL `DATE` and `DATETIME` values into `time.Time` variables, which is the logical opposite in Go to `DATE` and `DATETIME` in MySQL. You can do that by changing the internal output type from `[]byte` to `time.Time` with the DSN parameter `parseTime=true`. You can set the default [`time.Time` location](http://golang.org/pkg/time/#Location) with the `loc` DSN parameter.
+
 **Caution:** As of Go 1.1, this makes `time.Time` the only variable type you can scan `DATE` and `DATETIME` values into. This breaks for example [`sql.RawBytes` support](https://github.com/go-sql-driver/mysql/wiki/Examples#rawbytes).
 
 

+ 6 - 3
connection.go

@@ -31,6 +31,7 @@ type mysqlConn struct {
 	maxPacketAllowed int
 	maxWriteSize     int
 	parseTime        bool
+	strict           bool
 }
 
 type config struct {
@@ -67,9 +68,11 @@ func (mc *mysqlConn) handleParams() (err error) {
 
 		// time.Time parsing
 		case "parseTime":
-			if val == "true" {
-				mc.parseTime = true
-			}
+			mc.parseTime = readBool(val)
+
+		// Strict mode
+		case "strict":
+			mc.strict = readBool(val)
 
 		// TLS-Encryption
 		case "tls":

+ 72 - 6
driver_test.go

@@ -34,7 +34,7 @@ func init() {
 	dbname := env("MYSQL_TEST_DBNAME", "gotest")
 	charset = "charset=utf8"
 	netAddr = fmt.Sprintf("%s(%s)", prot, addr)
-	dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s&"+charset, user, pass, netAddr, dbname)
+	dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s&strict=true&"+charset, user, pass, netAddr, dbname)
 	c, err := net.Dial(prot, addr)
 	if err == nil {
 		available = true
@@ -118,12 +118,13 @@ func runTests(t *testing.T, name string, tests ...func(dbt *DBTest)) {
 	}
 	defer db.Close()
 
+	db.Exec("DROP TABLE IF EXISTS test")
+
 	dbt := &DBTest{t, db}
-	dbt.mustExec("DROP TABLE IF EXISTS test")
 	for _, test := range tests {
 		test(dbt)
+		dbt.db.Exec("DROP TABLE IF EXISTS test")
 	}
-	dbt.mustExec("DROP TABLE IF EXISTS test")
 }
 
 func (dbt *DBTest) fail(method, query string, err error) {
@@ -446,7 +447,6 @@ func TestDateTime(t *testing.T) {
 	testTime := func(dbt *DBTest) {
 		var rows *sql.Rows
 		for sqltype, tests := range timetests {
-			dbt.mustExec("DROP TABLE IF EXISTS test")
 			dbt.mustExec("CREATE TABLE test (value " + sqltype + ")")
 			for _, test := range tests {
 				for mode, q := range modes {
@@ -466,6 +466,7 @@ func TestDateTime(t *testing.T) {
 					}
 				}
 			}
+			dbt.mustExec("DROP TABLE IF EXISTS test")
 		}
 	}
 
@@ -701,7 +702,7 @@ func TestLoadData(t *testing.T) {
 		file.WriteString("1\ta string\n2\ta string containing a \\t\n3\ta string containing a \\n\n4\ta string containing both \\t\\n\n")
 		file.Close()
 
-		dbt.mustExec("DROP TABLE IF EXISTS test")
+		dbt.db.Exec("DROP TABLE IF EXISTS test")
 		dbt.mustExec("CREATE TABLE test (id INT NOT NULL PRIMARY KEY, value TEXT NOT NULL) CHARACTER SET utf8")
 
 		// Local File
@@ -739,6 +740,71 @@ func TestLoadData(t *testing.T) {
 	})
 }
 
+func TestStrict(t *testing.T) {
+	runTests(t, "TestStrict", func(dbt *DBTest) {
+		dbt.mustExec("CREATE TABLE test (a TINYINT NOT NULL, b CHAR(4))")
+
+		var queries = [...]struct {
+			in    string
+			codes []string
+		}{
+			{"DROP TABLE IF EXISTS no_such_table", []string{"1051"}},
+			{"INSERT INTO test VALUES(10,'mysql'),(NULL,'test'),(300,'Open Source')", []string{"1265", "1048", "1264", "1265"}},
+		}
+		var err error
+
+		var checkWarnings = func(err error, mode string, idx int) {
+			if err == nil {
+				dbt.Errorf("Expected STRICT error on query [%s] %s", mode, queries[idx].in)
+			}
+
+			if warnings, ok := err.(MySQLWarnings); ok {
+				var codes = make([]string, len(warnings))
+				for i := range warnings {
+					codes[i] = warnings[i].Code
+				}
+				if len(codes) != len(queries[idx].codes) {
+					dbt.Errorf("Unexpected STRICT error count on query [%s] %s: Wanted %v, Got %v", mode, queries[idx].in, queries[idx].codes, codes)
+				}
+
+				for i := range warnings {
+					if codes[i] != queries[idx].codes[i] {
+						dbt.Errorf("Unexpected STRICT error codes on query [%s] %s: Wanted %v, Got %v", mode, queries[idx].in, queries[idx].codes, codes)
+						return
+					}
+				}
+
+			} else {
+				dbt.Errorf("Unexpected error on query [%s] %s: %s", mode, queries[idx].in, err.Error())
+			}
+		}
+
+		// text protocol
+		for i := range queries {
+			_, err = dbt.db.Exec(queries[i].in)
+			checkWarnings(err, "text", i)
+		}
+
+		var stmt *sql.Stmt
+
+		// binary protocol
+		for i := range queries {
+			stmt, err = dbt.db.Prepare(queries[i].in)
+			if err != nil {
+				dbt.Error("Error on preparing query %: ", queries[i].in, err.Error())
+			}
+
+			_, err = stmt.Exec()
+			checkWarnings(err, "binary", i)
+
+			err = stmt.Close()
+			if err != nil {
+				dbt.Error("Error on closing stmt for query %: ", queries[i].in, err.Error())
+			}
+		}
+	})
+}
+
 // Special cases
 
 func TestRowsClose(t *testing.T) {
@@ -919,7 +985,7 @@ func TestStmtMultiRows(t *testing.T) {
 }
 
 func TestConcurrent(t *testing.T) {
-	if os.Getenv("MYSQL_TEST_CONCURRENT") != "1" {
+	if readBool(os.Getenv("MYSQL_TEST_CONCURRENT")) != true {
 		t.Log("CONCURRENT env var not set. Skipping TestConcurrent")
 		return
 	}

+ 85 - 1
errors.go

@@ -9,7 +9,12 @@
 
 package mysql
 
-import "errors"
+import (
+	"database/sql/driver"
+	"errors"
+	"fmt"
+	"io"
+)
 
 var (
 	errMalformPkt  = errors.New("Malformed Packet")
@@ -18,3 +23,82 @@ var (
 	errOldPassword = errors.New("It seems like you are using old_passwords, which is unsupported. See https://github.com/go-sql-driver/mysql/wiki/old_passwords")
 	errPktTooLarge = errors.New("Packet for query is too large. You can change this value on the server by adjusting the 'max_allowed_packet' variable.")
 )
+
+// error type which represents a single MySQL error
+type MySQLError struct {
+	Number  uint16
+	Message string
+}
+
+func (me *MySQLError) Error() string {
+	return fmt.Sprintf("Error %d: %s", me.Number, me.Message)
+}
+
+// error type which represents a group (one ore more) MySQL warnings
+type MySQLWarnings []mysqlWarning
+
+func (mws MySQLWarnings) Error() string {
+	var msg string
+	for i, warning := range mws {
+		if i > 0 {
+			msg += "\r\n"
+		}
+		msg += fmt.Sprintf("%s %s: %s", warning.Level, warning.Code, warning.Message)
+	}
+	return msg
+}
+
+// error type which represents a single MySQL warning
+type mysqlWarning struct {
+	Level   string
+	Code    string
+	Message string
+}
+
+func (mc *mysqlConn) getWarnings() (err error) {
+	rows, err := mc.Query("SHOW WARNINGS", []driver.Value{})
+	if err != nil {
+		return
+	}
+
+	var warnings = MySQLWarnings{}
+	var values = make([]driver.Value, 3)
+
+	var warning mysqlWarning
+	var raw []byte
+	var ok bool
+
+	for {
+		err = rows.Next(values)
+		switch err {
+		case nil:
+			warning = mysqlWarning{}
+
+			if raw, ok = values[0].([]byte); ok {
+				warning.Level = string(raw)
+			} else {
+				warning.Level = fmt.Sprintf("%s", values[0])
+			}
+			if raw, ok = values[1].([]byte); ok {
+				warning.Code = string(raw)
+			} else {
+				warning.Code = fmt.Sprintf("%s", values[1])
+			}
+			if raw, ok = values[2].([]byte); ok {
+				warning.Message = string(raw)
+			} else {
+				warning.Message = fmt.Sprintf("%s", values[0])
+			}
+
+			warnings = append(warnings, warning)
+
+		case io.EOF:
+			return warnings
+
+		default:
+			rows.Close()
+			return
+		}
+	}
+	return
+}

+ 27 - 9
packets.go

@@ -352,8 +352,7 @@ func (mc *mysqlConn) readResultOK() error {
 		switch data[0] {
 
 		case iOK:
-			mc.handleOkPacket(data)
-			return nil
+			return mc.handleOkPacket(data)
 
 		case iEOF: // someone is using old_passwords
 			return errOldPassword
@@ -373,8 +372,7 @@ func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
 		switch data[0] {
 
 		case iOK:
-			mc.handleOkPacket(data)
-			return 0, nil
+			return 0, mc.handleOkPacket(data)
 
 		case iERR:
 			return 0, mc.handleErrorPacket(data)
@@ -415,13 +413,16 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error {
 	}
 
 	// Error Message [string]
-	return fmt.Errorf("Error %d: %s", errno, string(data[pos:]))
+	return &MySQLError{
+		Number:  errno,
+		Message: string(data[pos:]),
+	}
 }
 
 // Ok Packet
 // http://dev.mysql.com/doc/internals/en/overview.html#packet-OK_Packet
-func (mc *mysqlConn) handleOkPacket(data []byte) {
-	var n int
+func (mc *mysqlConn) handleOkPacket(data []byte) (err error) {
+	var n, m int
 
 	// 0x00 [1 byte]
 
@@ -429,11 +430,22 @@ func (mc *mysqlConn) handleOkPacket(data []byte) {
 	mc.affectedRows, _, n = readLengthEncodedInteger(data[1:])
 
 	// Insert id [Length Coded Binary]
-	mc.insertId, _, _ = readLengthEncodedInteger(data[1+n:])
+	mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])
 
 	// server_status [2 bytes]
+
 	// warning count [2 bytes]
+	if !mc.strict {
+		return
+	} else {
+		pos := 1 + n + m + 2
+		if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 {
+			err = mc.getWarnings()
+		}
+	}
+
 	// message [until end of packet]
+	return
 }
 
 // Read Packets as Field Packets until EOF-Packet or an Error appears
@@ -625,7 +637,13 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (columnCount uint16, err error)
 		pos += 2
 
 		// Warning count [16 bit uint]
-		// bytesToUint16(data[pos : pos+2])
+		if !stmt.mc.strict {
+			return
+		} else {
+			if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 {
+				err = stmt.mc.getWarnings()
+			}
+		}
 	}
 	return
 }

+ 10 - 0
utils.go

@@ -234,6 +234,16 @@ func formatBinaryDateTime(num uint64, data []byte) (driver.Value, error) {
 	return nil, fmt.Errorf("Invalid DATETIME-packet length %d", num)
 }
 
+func readBool(value string) bool {
+	switch strings.ToLower(value) {
+	case "true":
+		return true
+	case "1":
+		return true
+	}
+	return false
+}
+
 /******************************************************************************
 *                       Convert from and to bytes                             *
 ******************************************************************************/