Преглед на файлове

remove QuoteStr() usage

Unknown преди 6 години
родител
ревизия
7d5610bf89
променени са 7 файла, в които са добавени 102 реда и са изтрити 60 реда
  1. 24 16
      engine.go
  2. 22 1
      helpers.go
  3. 21 1
      helpers_test.go
  4. 8 16
      session_insert.go
  5. 8 7
      session_update.go
  6. 10 19
      statement.go
  7. 9 0
      statement_test.go

+ 24 - 16
engine.go

@@ -180,6 +180,7 @@ func (engine *Engine) SupportInsertMany() bool {
 
 // QuoteStr Engine's database use which character as quote.
 // mysql, sqlite use ` and postgres use "
+// Deprecated, use Quote() instead
 func (engine *Engine) QuoteStr() string {
 	return engine.dialect.QuoteStr()
 }
@@ -199,13 +200,10 @@ func (engine *Engine) Quote(value string) string {
 		return value
 	}
 
-	if string(value[0]) == engine.dialect.QuoteStr() || value[0] == '`' {
-		return value
-	}
-
-	value = strings.Replace(value, ".", engine.dialect.QuoteStr()+"."+engine.dialect.QuoteStr(), -1)
+	buf := builder.StringBuilder{}
+	engine.QuoteTo(&buf, value)
 
-	return engine.dialect.QuoteStr() + value + engine.dialect.QuoteStr()
+	return buf.String()
 }
 
 // QuoteTo quotes string and writes into the buffer
@@ -219,20 +217,30 @@ func (engine *Engine) QuoteTo(buf *builder.StringBuilder, value string) {
 		return
 	}
 
-	if string(value[0]) == engine.dialect.QuoteStr() || value[0] == '`' {
-		buf.WriteString(value)
+	quotePair := engine.dialect.Quote("")
+
+	if value[0] == '`' || len(quotePair) < 2 || value[0] == quotePair[0] { // no quote
+		_, _ = buf.WriteString(value)
 		return
+	} else {
+		prefix, suffix := quotePair[0], quotePair[1]
+
+		_ = buf.WriteByte(prefix)
+		for i := 0; i < len(value); i++ {
+			if value[i] == '.' {
+				_ = buf.WriteByte(suffix)
+				_ = buf.WriteByte('.')
+				_ = buf.WriteByte(prefix)
+			} else {
+				_ = buf.WriteByte(value[i])
+			}
+		}
+		_ = buf.WriteByte(suffix)
 	}
-
-	value = strings.Replace(value, ".", engine.dialect.QuoteStr()+"."+engine.dialect.QuoteStr(), -1)
-
-	buf.WriteString(engine.dialect.QuoteStr())
-	buf.WriteString(value)
-	buf.WriteString(engine.dialect.QuoteStr())
 }
 
 func (engine *Engine) quote(sql string) string {
-	return engine.dialect.QuoteStr() + sql + engine.dialect.QuoteStr()
+	return engine.dialect.Quote(sql)
 }
 
 // SqlType will be deprecated, please use SQLType instead
@@ -1605,7 +1613,7 @@ func (engine *Engine) formatColTime(col *core.Column, t time.Time) (v interface{
 func (engine *Engine) formatTime(sqlTypeName string, t time.Time) (v interface{}) {
 	switch sqlTypeName {
 	case core.Time:
-		s := t.Format("2006-01-02 15:04:05") //time.RFC3339
+		s := t.Format("2006-01-02 15:04:05") // time.RFC3339
 		v = s[11:19]
 	case core.Date:
 		v = t.Format("2006-01-02")

+ 22 - 1
helpers.go

@@ -281,7 +281,7 @@ func rValue(bean interface{}) reflect.Value {
 
 func rType(bean interface{}) reflect.Type {
 	sliceValue := reflect.Indirect(reflect.ValueOf(bean))
-	//return reflect.TypeOf(sliceValue.Interface())
+	// return reflect.TypeOf(sliceValue.Interface())
 	return sliceValue.Type()
 }
 
@@ -309,3 +309,24 @@ func sliceEq(left, right []string) bool {
 func indexName(tableName, idxName string) string {
 	return fmt.Sprintf("IDX_%v_%v", tableName, idxName)
 }
+
+func eraseAny(value string, strToErase ...string) string {
+	if len(strToErase) == 0 {
+		return value
+	}
+	var replaceSeq []string
+	for _, s := range strToErase {
+		replaceSeq = append(replaceSeq, s, "")
+	}
+
+	replacer := strings.NewReplacer(replaceSeq...)
+
+	return replacer.Replace(value)
+}
+
+func quoteColumns(cols []string, quoteFunc func(string) string, sep string) string {
+	for i := range cols {
+		cols[i] = quoteFunc(cols[i])
+	}
+	return strings.Join(cols, sep+" ")
+}

+ 21 - 1
helpers_test.go

@@ -4,7 +4,11 @@
 
 package xorm
 
-import "testing"
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
 
 func TestSplitTag(t *testing.T) {
 	var cases = []struct {
@@ -24,3 +28,19 @@ func TestSplitTag(t *testing.T) {
 		}
 	}
 }
+
+func TestEraseAny(t *testing.T) {
+	raw := "SELECT * FROM `table`.[table_name]"
+	assert.EqualValues(t, raw, eraseAny(raw))
+	assert.EqualValues(t, "SELECT * FROM table.[table_name]", eraseAny(raw, "`"))
+	assert.EqualValues(t, "SELECT * FROM table.table_name", eraseAny(raw, "`", "[", "]"))
+}
+
+func TestQuoteColumns(t *testing.T) {
+	cols := []string{"f1", "f2", "f3"}
+	quoteFunc := func(value string) string {
+		return "[" + value + "]"
+	}
+
+	assert.EqualValues(t, "[f1], [f2], [f3]", quoteColumns(cols, quoteFunc, ","))
+}

+ 8 - 16
session_insert.go

@@ -242,23 +242,17 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
 
 	var sql string
 	if session.engine.dialect.DBType() == core.ORACLE {
-		temp := fmt.Sprintf(") INTO %s (%v%v%v) VALUES (",
+		temp := fmt.Sprintf(") INTO %s (%v) VALUES (",
 			session.engine.Quote(tableName),
-			session.engine.QuoteStr(),
-			strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()),
-			session.engine.QuoteStr())
-		sql = fmt.Sprintf("INSERT ALL INTO %s (%v%v%v) VALUES (%v) SELECT 1 FROM DUAL",
+			quoteColumns(colNames, session.engine.Quote, ","))
+		sql = fmt.Sprintf("INSERT ALL INTO %s (%v) VALUES (%v) SELECT 1 FROM DUAL",
 			session.engine.Quote(tableName),
-			session.engine.QuoteStr(),
-			strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()),
-			session.engine.QuoteStr(),
+			quoteColumns(colNames, session.engine.Quote, ","),
 			strings.Join(colMultiPlaces, temp))
 	} else {
-		sql = fmt.Sprintf("INSERT INTO %s (%v%v%v) VALUES (%v)",
+		sql = fmt.Sprintf("INSERT INTO %s (%v) VALUES (%v)",
 			session.engine.Quote(tableName),
-			session.engine.QuoteStr(),
-			strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()),
-			session.engine.QuoteStr(),
+			quoteColumns(colNames, session.engine.Quote, ","),
 			strings.Join(colMultiPlaces, "),("))
 	}
 	res, err := session.exec(sql, args...)
@@ -379,11 +373,9 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
 		output = fmt.Sprintf(" OUTPUT Inserted.%s", table.AutoIncrement)
 	}
 	if len(colPlaces) > 0 {
-		sqlStr = fmt.Sprintf("INSERT INTO %s (%v%v%v)%s VALUES (%v)",
+		sqlStr = fmt.Sprintf("INSERT INTO %s (%v)%s VALUES (%v)",
 			session.engine.Quote(tableName),
-			session.engine.QuoteStr(),
-			strings.Join(colNames, session.engine.Quote(", ")),
-			session.engine.QuoteStr(),
+			quoteColumns(colNames, session.engine.Quote, ","),
 			output,
 			colPlaces)
 	} else {

+ 8 - 7
session_update.go

@@ -96,14 +96,15 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string,
 				return ErrCacheFailed
 			}
 			kvs := strings.Split(strings.TrimSpace(sqls[1]), ",")
+
 			for idx, kv := range kvs {
 				sps := strings.SplitN(kv, "=", 2)
 				sps2 := strings.Split(sps[0], ".")
 				colName := sps2[len(sps2)-1]
-				if strings.Contains(colName, "`") {
-					colName = strings.TrimSpace(strings.Replace(colName, "`", "", -1))
-				} else if strings.Contains(colName, session.engine.QuoteStr()) {
-					colName = strings.TrimSpace(strings.Replace(colName, session.engine.QuoteStr(), "", -1))
+				// treat quote prefix, suffix and '`' as quotes
+				quotes := append(strings.Split(session.engine.Quote(""), ""), "`")
+				if strings.ContainsAny(colName, strings.Join(quotes, "")) {
+					colName = strings.TrimSpace(eraseAny(colName, quotes...))
 				} else {
 					session.engine.logger.Debug("[cacheUpdate] cannot find column", tableName, colName)
 					return ErrCacheFailed
@@ -221,19 +222,19 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
 		}
 	}
 
-	//for update action to like "column = column + ?"
+	// for update action to like "column = column + ?"
 	incColumns := session.statement.getInc()
 	for _, v := range incColumns {
 		colNames = append(colNames, session.engine.Quote(v.colName)+" = "+session.engine.Quote(v.colName)+" + ?")
 		args = append(args, v.arg)
 	}
-	//for update action to like "column = column - ?"
+	// for update action to like "column = column - ?"
 	decColumns := session.statement.getDec()
 	for _, v := range decColumns {
 		colNames = append(colNames, session.engine.Quote(v.colName)+" = "+session.engine.Quote(v.colName)+" - ?")
 		args = append(args, v.arg)
 	}
-	//for update action to like "column = expression"
+	// for update action to like "column = expression"
 	exprColumns := session.statement.getExpr()
 	for _, v := range exprColumns {
 		colNames = append(colNames, session.engine.Quote(v.colName)+" = "+v.expr)

+ 10 - 19
statement.go

@@ -6,7 +6,6 @@ package xorm
 
 import (
 	"database/sql/driver"
-	"errors"
 	"fmt"
 	"reflect"
 	"strings"
@@ -426,7 +425,7 @@ func (statement *Statement) buildUpdates(bean interface{},
 								continue
 							}
 						} else {
-							//TODO: how to handler?
+							// TODO: how to handler?
 							panic("not supported")
 						}
 					} else {
@@ -607,21 +606,9 @@ func (statement *Statement) getExpr() map[string]exprParam {
 
 func (statement *Statement) col2NewColsWithQuote(columns ...string) []string {
 	newColumns := make([]string, 0)
+	quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")
 	for _, col := range columns {
-		col = strings.Replace(col, "`", "", -1)
-		col = strings.Replace(col, statement.Engine.QuoteStr(), "", -1)
-		ccols := strings.Split(col, ",")
-		for _, c := range ccols {
-			fields := strings.Split(strings.TrimSpace(c), ".")
-			if len(fields) == 1 {
-				newColumns = append(newColumns, statement.Engine.quote(fields[0]))
-			} else if len(fields) == 2 {
-				newColumns = append(newColumns, statement.Engine.quote(fields[0])+"."+
-					statement.Engine.quote(fields[1]))
-			} else {
-				panic(errors.New("unwanted colnames"))
-			}
-		}
+		newColumns = append(newColumns, statement.Engine.Quote(eraseAny(col, quotes...)))
 	}
 	return newColumns
 }
@@ -792,7 +779,9 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
 			return statement
 		}
 		tbs := strings.Split(tp.TableName(), ".")
-		var aliasName = strings.Trim(tbs[len(tbs)-1], statement.Engine.QuoteStr())
+		quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")
+
+		var aliasName = strings.Trim(tbs[len(tbs)-1], strings.Join(quotes, ""))
 		fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
 		statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
 	case *builder.Builder:
@@ -802,7 +791,9 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
 			return statement
 		}
 		tbs := strings.Split(tp.TableName(), ".")
-		var aliasName = strings.Trim(tbs[len(tbs)-1], statement.Engine.QuoteStr())
+		quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")
+
+		var aliasName = strings.Trim(tbs[len(tbs)-1], strings.Join(quotes, ""))
 		fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
 		statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
 	default:
@@ -1272,7 +1263,7 @@ func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) {
 
 	var whereStr = sqls[1]
 
-	//TODO: for postgres only, if any other database?
+	// TODO: for postgres only, if any other database?
 	var paraStr string
 	if statement.Engine.dialect.DBType() == core.POSTGRES {
 		paraStr = "$"

+ 9 - 0
statement_test.go

@@ -237,3 +237,12 @@ func TestUpdateIgnoreOnlyFromDBFields(t *testing.T) {
 	testEngine.Update(record)
 	assertGetRecord()
 }
+
+func TestCol2NewColsWithQuote(t *testing.T) {
+	cols := []string{"f1", "f2", "t3.f3"}
+
+	statement := createTestStatement()
+
+	quotedCols := statement.col2NewColsWithQuote(cols...)
+	assert.EqualValues(t, []string{statement.Engine.Quote("f1"), statement.Engine.Quote("f2"), statement.Engine.Quote("t3.f3")}, quotedCols)
+}