Parcourir la source

add Scan features to Get method

xormplus il y a 8 ans
Parent
commit
52c6a34eff
2 fichiers modifiés avec 25 ajouts et 18 suppressions
  1. 9 14
      session_get.go
  2. 16 4
      statement.go

+ 9 - 14
session_get.go

@@ -22,12 +22,7 @@ func (session *Session) Get(bean interface{}) (bool, error) {
 
 	beanValue := reflect.ValueOf(bean)
 	if beanValue.Kind() != reflect.Ptr {
-		return false, errors.New("needs a pointer to a struct")
-	}
-
-	// FIXME: remove this after support non-struct Get
-	if beanValue.Elem().Kind() != reflect.Struct {
-		return false, errors.New("needs a pointer to a struct")
+		return false, errors.New("needs a pointer")
 	}
 
 	if beanValue.Elem().Kind() == reflect.Struct {
@@ -48,7 +43,7 @@ func (session *Session) Get(bean interface{}) (bool, error) {
 		args = session.Statement.RawParams
 	}
 
-	if session.canCache() {
+	if session.canCache() && beanValue.Elem().Kind() == reflect.Struct {
 		if cacher := session.Engine.getCacher2(session.Statement.RefTable); cacher != nil &&
 			!session.Statement.unscoped {
 			has, err := session.cacheGet(bean, sqlStr, args...)
@@ -62,9 +57,10 @@ func (session *Session) Get(bean interface{}) (bool, error) {
 }
 
 func (session *Session) nocacheGet(beanKind reflect.Kind, bean interface{}, sqlStr string, args ...interface{}) (bool, error) {
+	session.queryPreprocess(&sqlStr, args...)
+
 	var rawRows *core.Rows
 	var err error
-	session.queryPreprocess(&sqlStr, args...)
 	if session.IsAutoCommit {
 		_, rawRows, err = session.innerQuery(sqlStr, args...)
 	} else {
@@ -77,14 +73,13 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, bean interface{}, sqlS
 	defer rawRows.Close()
 
 	if rawRows.Next() {
-		fields, err := rawRows.Columns()
-		if err != nil {
-			// WARN: Alougth rawRows return true, but get fields failed
-			return true, err
-		}
-
 		switch beanKind {
 		case reflect.Struct:
+			fields, err := rawRows.Columns()
+			if err != nil {
+				// WARN: Alougth rawRows return true, but get fields failed
+				return true, err
+			}
 			dataStruct := rValue(bean)
 			session.Statement.setRefValue(dataStruct)
 			_, err = session.row2Bean(rawRows, fields, len(fields), bean, &dataStruct, session.Statement.RefTable)

+ 16 - 4
statement.go

@@ -1124,7 +1124,11 @@ func (statement *Statement) genConds(bean interface{}) (string, []interface{}, e
 }
 
 func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}) {
-	statement.setRefValue(rValue(bean))
+	v := rValue(bean)
+	isStruct := v.Kind() == reflect.Struct
+	if isStruct {
+		statement.setRefValue(v)
+	}
 
 	var columnStr = statement.ColumnStr
 	if len(statement.selectStr) > 0 {
@@ -1143,14 +1147,22 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{})
 			if len(columnStr) == 0 {
 				if len(statement.GroupByStr) > 0 {
 					columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
-				} else {
-					columnStr = "*"
 				}
 			}
 		}
 	}
 
-	condSQL, condArgs, _ := statement.genConds(bean)
+	if len(columnStr) == 0 {
+		columnStr = "*"
+	}
+
+	var condSQL string
+	var condArgs []interface{}
+	if isStruct {
+		condSQL, condArgs, _ = statement.genConds(bean)
+	} else {
+		condSQL, condArgs, _ = builder.ToSQL(statement.cond)
+	}
 
 	return statement.genSelectSQL(columnStr, condSQL), append(statement.joinArgs, condArgs...)
 }