Pārlūkot izejas kodu

修改原session.Qeury()API

xormplus 10 gadi atpakaļ
vecāks
revīzija
79c68794e6
5 mainītis faili ar 125 papildinājumiem un 195 dzēšanām
  1. 6 0
      engine.go
  2. 2 2
      helpers.go
  3. 42 22
      session.go
  4. 46 138
      statement.go
  5. 29 33
      xorm.go

+ 6 - 0
engine.go

@@ -511,6 +511,12 @@ func (engine *Engine) Distinct(columns ...string) *Session {
 	return session.Distinct(columns...)
 }
 
+func (engine *Engine) Select(str string) *Session {
+	session := engine.NewSession()
+	session.IsAutoClose = true
+	return session.Select(str)
+}
+
 // only use the paramters as select or update columns
 func (engine *Engine) Cols(columns ...string) *Session {
 	session := engine.NewSession()

+ 2 - 2
helpers.go

@@ -133,8 +133,8 @@ func reflect2value(rawValue *reflect.Value) (str string, err error) {
 		}
 	//时间类型
 	case reflect.Struct:
-		if aa == core.TimeType {
-			str = rawValue.Interface().(time.Time).Format(time.RFC3339Nano)
+		if aa.ConvertibleTo(core.TimeType) {
+			str = rawValue.Convert(core.TimeType).Interface().(time.Time).Format(time.RFC3339Nano)
 		} else {
 			err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name())
 		}

+ 42 - 22
session.go

@@ -173,6 +173,12 @@ func (session *Session) SetExpr(column string, expression string) *Session {
 	return session
 }
 
+// Method Cols provides some columns to special
+func (session *Session) Select(str string) *Session {
+	session.Statement.Select(str)
+	return session
+}
+
 // Method Cols provides some columns to special
 func (session *Session) Cols(columns ...string) *Session {
 	session.Statement.Cols(columns...)
@@ -622,12 +628,20 @@ func (statement *Statement) convertIdSql(sqlStr string) string {
 	return ""
 }
 
-func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interface{}) (has bool, err error) {
-	// if has no reftable, then don't use cache currently
+func (session *Session) canCache() bool {
 	if session.Statement.RefTable == nil ||
 		session.Statement.JoinStr != "" ||
 		session.Statement.RawSQL != "" ||
-		session.Tx != nil {
+		session.Tx != nil || 
+		len(session.Statement.selectStr) > 0 {
+		return false
+	}
+	return true
+}
+
+func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interface{}) (has bool, err error) {
+	// if has no reftable, then don't use cache currently
+	if !session.canCache() {
 		return false, ErrCacheFailed
 	}
 
@@ -725,10 +739,9 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
 }
 
 func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr interface{}, args ...interface{}) (err error) {
-	if session.Statement.RefTable == nil ||
+	if !session.canCache() || 
 		indexNoCase(sqlStr, "having") != -1 ||
-		indexNoCase(sqlStr, "group by") != -1 ||
-		session.Tx != nil {
+		indexNoCase(sqlStr, "group by") != -1 {
 		return ErrCacheFailed
 	}
 
@@ -1275,9 +1288,11 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
 	}
 
 	if len(condiBean) > 0 {
+		var addedTableName = (len(session.Statement.JoinStr) > 0)
 		colNames, args := buildConditions(session.Engine, table, condiBean[0], true, true,
 			false, true, session.Statement.allUseBool, session.Statement.useAllCols,
-			session.Statement.unscoped, session.Statement.mustColumnMap)
+			session.Statement.unscoped, session.Statement.mustColumnMap, 
+			session.Statement.TableName(), addedTableName)
 		session.Statement.ConditionStr = strings.Join(colNames, " AND ")
 		session.Statement.BeanArgs = args
 	} else {
@@ -1293,20 +1308,24 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
 	var args []interface{}
 	if session.Statement.RawSQL == "" {
 		var columnStr string = session.Statement.ColumnStr
-		if session.Statement.JoinStr == "" {
-			if columnStr == "" {
-				if session.Statement.GroupByStr != "" {
-					columnStr = session.Statement.Engine.Quote(strings.Replace(session.Statement.GroupByStr, ",", session.Engine.Quote(","), -1))
-				} else {
-					columnStr = session.Statement.genColumnStr()
-				}
-			}
+		if len(session.Statement.selectStr) > 0 {
+			columnStr = session.Statement.selectStr
 		} else {
-			if columnStr == "" {
-				if session.Statement.GroupByStr != "" {
-					columnStr = session.Statement.Engine.Quote(strings.Replace(session.Statement.GroupByStr, ",", session.Engine.Quote(","), -1))
-				} else {
-					columnStr = "*"
+			if session.Statement.JoinStr == "" {
+				if columnStr == "" {
+					if session.Statement.GroupByStr != "" {
+						columnStr = session.Statement.Engine.Quote(strings.Replace(session.Statement.GroupByStr, ",", session.Engine.Quote(","), -1))
+					} else {
+						columnStr = session.Statement.genColumnStr()
+					}
+				}
+			} else {
+				if columnStr == "" {
+					if session.Statement.GroupByStr != "" {
+						columnStr = session.Statement.Engine.Quote(strings.Replace(session.Statement.GroupByStr, ",", session.Engine.Quote(","), -1))
+					} else {
+						columnStr = "*"
+					}
 				}
 			}
 		}
@@ -3560,7 +3579,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
 	if len(condiBean) > 0 {
 		condiColNames, condiArgs = buildConditions(session.Engine, session.Statement.RefTable, condiBean[0], true, true,
 			false, true, session.Statement.allUseBool, session.Statement.useAllCols,
-			session.Statement.unscoped, session.Statement.mustColumnMap)
+			session.Statement.unscoped, session.Statement.mustColumnMap, session.Statement.TableName(), false)
 	}
 
 	var condition = ""
@@ -3780,7 +3799,8 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
 	session.Statement.RefTable = table
 	colNames, args := buildConditions(session.Engine, table, bean, true, true,
 		false, true, session.Statement.allUseBool, session.Statement.useAllCols,
-		session.Statement.unscoped, session.Statement.mustColumnMap)
+		session.Statement.unscoped, session.Statement.mustColumnMap, 
+		session.Statement.TableName(), false)
 
 	var condition = ""
 	var andStr = session.Engine.dialect.AndStr()

+ 46 - 138
statement.go

@@ -49,6 +49,7 @@ type Statement struct {
 	GroupByStr    string
 	HavingStr     string
 	ColumnStr     string
+	selectStr string
 	columnMap     map[string]bool
 	useAllCols    bool
 	OmitStr       string
@@ -100,6 +101,7 @@ func (statement *Statement) Init() {
 	statement.UseAutoTime = true
 	statement.IsDistinct = false
 	statement.TableAlias = ""
+	statement.selectStr = ""
 	statement.allUseBool = false
 	statement.useAllCols = false
 	statement.mustColumnMap = make(map[string]bool)
@@ -170,122 +172,6 @@ func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
 	return statement
 }
 
-/*func (statement *Statement) genFields(bean interface{}) map[string]interface{} {
-    results := make(map[string]interface{})
-    table := statement.Engine.TableInfo(bean)
-    for _, col := range table.Columns {
-        fieldValue := col.ValueOf(bean)
-        fieldType := reflect.TypeOf(fieldValue.Interface())
-        var val interface{}
-        switch fieldType.Kind() {
-        case reflect.Bool:
-            if allUseBool {
-                val = fieldValue.Interface()
-            } else if _, ok := boolColumnMap[col.Name]; ok {
-                val = fieldValue.Interface()
-            } else {
-                // if a bool in a struct, it will not be as a condition because it default is false,
-                // please use Where() instead
-                continue
-            }
-        case reflect.String:
-            if fieldValue.String() == "" {
-                continue
-            }
-            // for MyString, should convert to string or panic
-            if fieldType.String() != reflect.String.String() {
-                val = fieldValue.String()
-            } else {
-                val = fieldValue.Interface()
-            }
-        case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64:
-            if fieldValue.Int() == 0 {
-                continue
-            }
-            val = fieldValue.Interface()
-        case reflect.Float32, reflect.Float64:
-            if fieldValue.Float() == 0.0 {
-                continue
-            }
-            val = fieldValue.Interface()
-        case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64:
-            if fieldValue.Uint() == 0 {
-                continue
-            }
-            val = fieldValue.Interface()
-        case reflect.Struct:
-            if fieldType == reflect.TypeOf(time.Now()) {
-                t := fieldValue.Interface().(time.Time)
-                if t.IsZero() || !fieldValue.IsValid() {
-                    continue
-                }
-                var str string
-                if col.SQLType.Name == Time {
-                    s := t.UTC().Format("2006-01-02 15:04:05")
-                    val = s[11:19]
-                } else if col.SQLType.Name == Date {
-                    str = t.Format("2006-01-02")
-                    val = str
-                } else {
-                    val = t
-                }
-            } else {
-                engine.autoMapType(fieldValue.Type())
-                if table, ok := engine.Tables[fieldValue.Type()]; ok {
-                    pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumn().FieldName)
-                    if pkField.Int() != 0 {
-                        val = pkField.Interface()
-                    } else {
-                        continue
-                    }
-                } else {
-                    val = fieldValue.Interface()
-                }
-            }
-        case reflect.Array, reflect.Slice, reflect.Map:
-            if fieldValue == reflect.Zero(fieldType) {
-                continue
-            }
-            if fieldValue.IsNil() || !fieldValue.IsValid() {
-                continue
-            }
-
-            if col.SQLType.IsText() {
-                bytes, err := json.Marshal(fieldValue.Interface())
-                if err != nil {
-                    engine.LogError(err)
-                    continue
-                }
-                val = string(bytes)
-            } else if col.SQLType.IsBlob() {
-                var bytes []byte
-                var err error
-                if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) &&
-                    fieldType.Elem().Kind() == reflect.Uint8 {
-                    if fieldValue.Len() > 0 {
-                        val = fieldValue.Bytes()
-                    } else {
-                        continue
-                    }
-                } else {
-                    bytes, err = json.Marshal(fieldValue.Interface())
-                    if err != nil {
-                        engine.LogError(err)
-                        continue
-                    }
-                    val = bytes
-                }
-            } else {
-                continue
-            }
-        default:
-            val = fieldValue.Interface()
-        }
-        results[col.Name] = val
-    }
-    return results
-}*/
-
 // Auto generating conditions according a struct
 func buildUpdates(engine *Engine, table *core.Table, bean interface{},
 	includeVersion bool, includeUpdated bool, includeNil bool,
@@ -414,8 +300,8 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{},
 			t := int64(fieldValue.Uint())
 			val = reflect.ValueOf(&t).Interface()
 		case reflect.Struct:
-			if fieldType == reflect.TypeOf(time.Now()) {
-				t := fieldValue.Interface().(time.Time)
+			if fieldType.ConvertibleTo(core.TimeType) {
+				t := fieldValue.Convert(core.TimeType).Interface().(time.Time)
 				if !requiredField && (t.IsZero() || !fieldValue.IsValid()) {
 					continue
 				}
@@ -496,8 +382,7 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{},
 func buildConditions(engine *Engine, table *core.Table, bean interface{},
 	includeVersion bool, includeUpdated bool, includeNil bool,
 	includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool,
-	mustColumnMap map[string]bool) ([]string, []interface{}) {
-
+	mustColumnMap map[string]bool, tableName string, addedTableName bool) ([]string, []interface{}) {
 	colNames := make([]string, 0)
 	var args = make([]interface{}, 0)
 	for _, col := range table.Columns() {
@@ -514,6 +399,14 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{},
 		if engine.dialect.DBType() == core.MSSQL && col.SQLType.Name == core.Text {
 			continue
 		}
+
+		var colName string
+		if addedTableName {
+			colName = engine.Quote(tableName)+"."+engine.Quote(col.Name)
+		} else {
+			colName = engine.Quote(col.Name)
+		}
+
 		fieldValuePtr, err := col.ValueOf(bean)
 		if err != nil {
 			engine.LogError(err)
@@ -521,7 +414,8 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{},
 		}
 
 		if col.IsDeleted && !unscoped { // tag "deleted" is enabled
-			colNames = append(colNames, fmt.Sprintf("(%v IS NULL or %v = '0001-01-01 00:00:00')", engine.Quote(col.Name), engine.Quote(col.Name)))
+			colNames = append(colNames, fmt.Sprintf("(%v IS NULL or %v = '0001-01-01 00:00:00')", 
+				colName, colName))
 		}
 
 		fieldValue := *fieldValuePtr
@@ -543,7 +437,7 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{},
 			if fieldValue.IsNil() {
 				if includeNil {
 					args = append(args, nil)
-					colNames = append(colNames, fmt.Sprintf("%v %s ?", engine.Quote(col.Name), engine.dialect.EqStr()))
+					colNames = append(colNames, fmt.Sprintf("%v %s ?", colName, engine.dialect.EqStr()))
 				}
 				continue
 			} else if !fieldValue.IsValid() {
@@ -666,7 +560,7 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{},
 		if col.IsPrimaryKey && engine.dialect.DBType() == "ql" {
 			condi = "id() == ?"
 		} else {
-			condi = fmt.Sprintf("%v %s ?", engine.Quote(col.Name), engine.dialect.EqStr())
+			condi = fmt.Sprintf("%v %s ?", colName, engine.dialect.EqStr())
 		}
 		colNames = append(colNames, condi)
 	}
@@ -862,6 +756,12 @@ func (statement *Statement) Distinct(columns ...string) *Statement {
 	return statement
 }
 
+// replace select
+func (s *Statement) Select(str string) *Statement {
+	s.selectStr = str
+	return s
+}
+
 // Generate "col1, col2" statement
 func (statement *Statement) Cols(columns ...string) *Statement {
 	newColumns := col2NewCols(columns...)
@@ -1138,28 +1038,34 @@ func (statement *Statement) genGetSql(bean interface{}) (string, []interface{})
 		table = statement.RefTable
 	}
 
+	var addedTableName = (len(statement.JoinStr) > 0)
+
 	colNames, args := buildConditions(statement.Engine, table, bean, true, true,
 		false, true, statement.allUseBool, statement.useAllCols,
-		statement.unscoped, statement.mustColumnMap)
+		statement.unscoped, statement.mustColumnMap, statement.TableName(), addedTableName)
 
 	statement.ConditionStr = strings.Join(colNames, " "+statement.Engine.dialect.AndStr()+" ")
 	statement.BeanArgs = args
 
 	var columnStr string = statement.ColumnStr
-	if len(statement.JoinStr) == 0 {
-		if len(columnStr) == 0 {
-			if statement.GroupByStr != "" {
-				columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
-			} else {
-				columnStr = statement.genColumnStr()
-			}
-		}
+	if len(statement.selectStr) > 0 {
+		columnStr = statement.selectStr
 	} else {
-		if len(columnStr) == 0 {
-			if statement.GroupByStr != "" {
-				columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
-			} else {
-				columnStr = "*"
+		if len(statement.JoinStr) == 0 {
+			if len(columnStr) == 0 {
+				if statement.GroupByStr != "" {
+					columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
+				} else {
+					columnStr = statement.genColumnStr()
+				}
+			}
+		} else {
+			if len(columnStr) == 0 {
+				if statement.GroupByStr != "" {
+					columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
+				} else {
+					columnStr = "*"
+				}
 			}
 		}
 	}
@@ -1193,9 +1099,11 @@ func (statement *Statement) genCountSql(bean interface{}) (string, []interface{}
 	table := statement.Engine.TableInfo(bean)
 	statement.RefTable = table
 
+	var addedTableName = (len(statement.JoinStr) > 0)
+
 	colNames, args := buildConditions(statement.Engine, table, bean, true, true, false,
 		true, statement.allUseBool, statement.useAllCols,
-		statement.unscoped, statement.mustColumnMap)
+		statement.unscoped, statement.mustColumnMap, statement.TableName(), addedTableName)
 
 	statement.ConditionStr = strings.Join(colNames, " "+statement.Engine.Dialect().AndStr()+" ")
 	statement.BeanArgs = args

+ 29 - 33
xorm.go

@@ -20,6 +20,35 @@ const (
 	Version string = "1.0.0"
 )
 
+func regDrvsNDialects() bool {
+	providedDrvsNDialects := map[string]struct {
+		dbType     core.DbType
+		getDriver  func() core.Driver
+		getDialect func() core.Dialect
+	}{
+		"mssql":    {"mssql", func() core.Driver { return &odbcDriver{} }, func() core.Dialect { return &mssql{} }},
+		"odbc":     {"mssql", func() core.Driver { return &odbcDriver{} }, func() core.Dialect { return &mssql{} }}, // !nashtsai! TODO change this when supporting MS Access
+		"mysql":    {"mysql", func() core.Driver { return &mysqlDriver{} }, func() core.Dialect { return &mysql{} }},
+		"mymysql":  {"mysql", func() core.Driver { return &mymysqlDriver{} }, func() core.Dialect { return &mysql{} }},
+		"postgres": {"postgres", func() core.Driver { return &pqDriver{} }, func() core.Dialect { return &postgres{} }},
+		"sqlite3":  {"sqlite3", func() core.Driver { return &sqlite3Driver{} }, func() core.Dialect { return &sqlite3{} }},
+		"oci8":     {"oracle", func() core.Driver { return &oci8Driver{} }, func() core.Dialect { return &oracle{} }},
+		"goracle":  {"oracle", func() core.Driver { return &goracleDriver{} }, func() core.Dialect { return &oracle{} }},
+	}
+
+	for driverName, v := range providedDrvsNDialects {
+		if driver := core.QueryDriver(driverName); driver == nil {
+			core.RegisterDriver(driverName, v.getDriver())
+			core.RegisterDialect(v.dbType, v.getDialect())
+		}
+	}
+	return true
+}
+
+func close(engine *Engine) {
+	engine.Close()
+}
+
 // new a db manager according to the parameter. Currently support four
 // drivers
 func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
@@ -59,10 +88,6 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
 		TZLocation:    time.Local,
 	}
 
-	if err != nil {
-		engine.Logger.Warning(err)
-	}
-
 	engine.dialect.SetLogger(engine.Logger)
 
 	engine.SetMapper(core.NewCacheMapper(new(core.SnakeMapper)))
@@ -72,35 +97,6 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
 	return engine, nil
 }
 
-func regDrvsNDialects() bool {
-	providedDrvsNDialects := map[string]struct {
-		dbType     core.DbType
-		getDriver  func() core.Driver
-		getDialect func() core.Dialect
-	}{
-		"mssql":    {"mssql", func() core.Driver { return &odbcDriver{} }, func() core.Dialect { return &mssql{} }},
-		"odbc":     {"mssql", func() core.Driver { return &odbcDriver{} }, func() core.Dialect { return &mssql{} }}, // !nashtsai! TODO change this when supporting MS Access
-		"mysql":    {"mysql", func() core.Driver { return &mysqlDriver{} }, func() core.Dialect { return &mysql{} }},
-		"mymysql":  {"mysql", func() core.Driver { return &mymysqlDriver{} }, func() core.Dialect { return &mysql{} }},
-		"postgres": {"postgres", func() core.Driver { return &pqDriver{} }, func() core.Dialect { return &postgres{} }},
-		"sqlite3":  {"sqlite3", func() core.Driver { return &sqlite3Driver{} }, func() core.Dialect { return &sqlite3{} }},
-		"oci8":     {"oracle", func() core.Driver { return &oci8Driver{} }, func() core.Dialect { return &oracle{} }},
-		"goracle":  {"oracle", func() core.Driver { return &goracleDriver{} }, func() core.Dialect { return &oracle{} }},
-	}
-
-	for driverName, v := range providedDrvsNDialects {
-		if driver := core.QueryDriver(driverName); driver == nil {
-			core.RegisterDriver(driverName, v.getDriver())
-			core.RegisterDialect(v.dbType, v.getDialect())
-		}
-	}
-	return true
-}
-
-func close(engine *Engine) {
-	engine.Close()
-}
-
 // clone an engine
 func (engine *Engine) Clone() (*Engine, error) {
 	return NewEngine(engine.DriverName(), engine.DataSourceName())