Ver código fonte

Move escape funcs to utils.go, export them, add references to mysql surce code

arvenil 11 anos atrás
pai
commit
058ce87948
2 arquivos alterados com 75 adições e 51 exclusões
  1. 5 50
      connection.go
  2. 70 1
      utils.go

+ 5 - 50
connection.go

@@ -165,60 +165,15 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
 	return stmt, err
 }
 
+// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/libmysql/libmysql.c#L1150-L1156
 func (mc *mysqlConn) escapeBytes(v []byte) string {
-	buf := make([]byte, len(v)*2+2)
-	buf[0] = '\''
-	pos := 1
+	var escape func([]byte) []byte
 	if mc.status&statusNoBackslashEscapes == 0 {
-		for _, c := range v {
-			switch c {
-			case '\x00':
-				buf[pos] = '\\'
-				buf[pos+1] = '0'
-				pos += 2
-			case '\n':
-				buf[pos] = '\\'
-				buf[pos+1] = 'n'
-				pos += 2
-			case '\r':
-				buf[pos] = '\\'
-				buf[pos+1] = 'r'
-				pos += 2
-			case '\x1a':
-				buf[pos] = '\\'
-				buf[pos+1] = 'Z'
-				pos += 2
-			case '\'':
-				buf[pos] = '\\'
-				buf[pos+1] = '\''
-				pos += 2
-			case '"':
-				buf[pos] = '\\'
-				buf[pos+1] = '"'
-				pos += 2
-			case '\\':
-				buf[pos] = '\\'
-				buf[pos+1] = '\\'
-				pos += 2
-			default:
-				buf[pos] = c
-				pos += 1
-			}
-		}
+		escape = EscapeString
 	} else {
-		for _, c := range v {
-			if c == '\'' {
-				buf[pos] = '\''
-				buf[pos+1] = '\''
-				pos += 2
-			} else {
-				buf[pos] = c
-				pos++
-			}
-		}
+		escape = EscapeQuotes
 	}
-	buf[pos] = '\''
-	return string(buf[:pos+1])
+	return "'" + string(escape(v)) + "'"
 }
 
 func (mc *mysqlConn) buildQuery(query string, args []driver.Value) (string, error) {

+ 70 - 1
utils.go

@@ -224,7 +224,7 @@ func parseDSNParams(cfg *config, params string) (err error) {
 			}
 			cfg.collation = collation
 			break
-		
+
 		case "columnsWithAlias":
 			var isBool bool
 			cfg.columnsWithAlias, isBool = readBool(value)
@@ -806,3 +806,72 @@ func appendLengthEncodedInteger(b []byte, n uint64) []byte {
 	return append(b, 0xfe, byte(n), byte(n>>8), byte(n>>16), byte(n>>24),
 		byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56))
 }
+
+// Escape string with backslashes (\)
+// This escapes the contents of a string (provided as []byte) by adding backslashes before special
+// characters, and turning others into specific escape sequences, such as
+// turning newlines into \n and null bytes into \0.
+// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L823-L932
+func EscapeString(v []byte) []byte {
+	buf := make([]byte, len(v)*2)
+	pos := 0
+	for _, c := range v {
+		switch c {
+		case '\x00':
+			buf[pos] = '\\'
+			buf[pos+1] = '0'
+			pos += 2
+		case '\n':
+			buf[pos] = '\\'
+			buf[pos+1] = 'n'
+			pos += 2
+		case '\r':
+			buf[pos] = '\\'
+			buf[pos+1] = 'r'
+			pos += 2
+		case '\x1a':
+			buf[pos] = '\\'
+			buf[pos+1] = 'Z'
+			pos += 2
+		case '\'':
+			buf[pos] = '\\'
+			buf[pos+1] = '\''
+			pos += 2
+		case '"':
+			buf[pos] = '\\'
+			buf[pos+1] = '"'
+			pos += 2
+		case '\\':
+			buf[pos] = '\\'
+			buf[pos+1] = '\\'
+			pos += 2
+		default:
+			buf[pos] = c
+			pos += 1
+		}
+	}
+
+	return buf[:pos]
+}
+
+// Escape apostrophes by doubling them up
+// This escapes the contents of a string by doubling up any apostrophes that
+// it contains. This is used when the NO_BACKSLASH_ESCAPES SQL_MODE is in
+// effect on the server.
+// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L963-L1038
+func EscapeQuotes(v []byte) []byte {
+	buf := make([]byte, len(v)*2)
+	pos := 0
+	for _, c := range v {
+		if c == '\'' {
+			buf[pos] = '\''
+			buf[pos+1] = '\''
+			pos += 2
+		} else {
+			buf[pos] = c
+			pos++
+		}
+	}
+
+	return buf[:pos]
+}