瀏覽代碼

bug fixed

xormplus 9 年之前
父節點
當前提交
b893234e54
共有 5 個文件被更改,包括 409 次插入333 次删除
  1. 23 2
      engine.go
  2. 123 198
      session.go
  3. 62 50
      sessionplus.go
  4. 200 82
      statement.go
  5. 1 1
      xorm.go

+ 23 - 2
engine.go

@@ -930,8 +930,16 @@ func (engine *Engine) mapType(v reflect.Value) *core.Table {
 		fieldType := fieldValue.Type()
 
 		if ormTagStr != "" {
-			col = &core.Column{FieldName: t.Field(i).Name, Nullable: true, IsPrimaryKey: false,
-				IsAutoIncrement: false, MapType: core.TWOSIDES, Indexes: make(map[string]bool)}
+			col = &core.Column{
+				FieldName:       t.Field(i).Name,
+				TableName:       table.Name,
+				Nullable:        true,
+				IsPrimaryKey:    false,
+				IsAutoIncrement: false,
+				MapType:         core.TWOSIDES,
+				Indexes:         make(map[string]bool),
+			}
+
 			tags := splitTag(ormTagStr)
 
 			if len(tags) > 0 {
@@ -953,6 +961,18 @@ func (engine *Engine) mapType(v reflect.Value) *core.Table {
 					case reflect.Struct:
 						parentTable := engine.mapType(fieldValue)
 						for _, col := range parentTable.Columns() {
+							/*if t.Field(i).Anonymous {
+								col.TableName = parentTable.Name
+							} else {
+								col.TableName = engine.TableMapper.Obj2Table(t.Field(i).Name)
+							}*/
+							if len(col.TableName) <= 0 {
+								if _, ok := fieldValue.Interface().(TableName); ok {
+									col.TableName = fieldValue.Interface().(TableName).TableName()
+								} else {
+									col.TableName = engine.TableMapper.Obj2Table(fieldType.Name())
+								}
+							}
 							col.FieldName = fmt.Sprintf("%v.%v", t.Field(i).Name, col.FieldName)
 							table.AddColumn(col)
 						}
@@ -1134,6 +1154,7 @@ func (engine *Engine) mapType(v reflect.Value) *core.Table {
 			col = core.NewColumn(engine.ColumnMapper.Obj2Table(t.Field(i).Name),
 				t.Field(i).Name, sqlType, sqlType.DefaultLength,
 				sqlType.DefaultLength2, true)
+			col.TableName = table.Name
 		}
 		if col.IsAutoIncrement {
 			col.Nullable = false

+ 123 - 198
session.go

@@ -649,39 +649,6 @@ func (session *Session) DropTable(beanOrTableName interface{}) error {
 	return nil
 }
 
-func (statement *Statement) JoinColumns(cols []*core.Column, includeTableName bool) string {
-	var colnames = make([]string, len(cols))
-	for i, col := range cols {
-		if includeTableName {
-			colnames[i] = statement.Engine.Quote(statement.TableName()) +
-				"." + statement.Engine.Quote(col.Name)
-		} else {
-			colnames[i] = statement.Engine.Quote(col.Name)
-		}
-	}
-	return strings.Join(colnames, ", ")
-}
-
-func (statement *Statement) convertIdSql(sqlStr string) string {
-	if statement.RefTable != nil {
-		cols := statement.RefTable.PKColumns()
-		if len(cols) == 0 {
-			return ""
-		}
-
-		colstrs := statement.JoinColumns(cols, false)
-		sqls := splitNNoCase(sqlStr, " from ", 2)
-		if len(sqls) != 2 {
-			return ""
-		}
-		if statement.Engine.dialect.DBType() == "ql" {
-			return fmt.Sprintf("SELECT id() FROM %v", sqls[1])
-		}
-		return fmt.Sprintf("SELECT %s FROM %v", colstrs, sqls[1])
-	}
-	return ""
-}
-
 func (session *Session) canCache() bool {
 	if session.Statement.RefTable == nil ||
 		session.Statement.JoinStr != "" ||
@@ -1044,8 +1011,10 @@ func (session *Session) Get(bean interface{}) (bool, error) {
 	var sqlStr string
 	var args []interface{}
 
+	session.Statement.OutTable = session.Engine.TableInfo(bean)
+
 	if session.Statement.RefTable == nil {
-		session.Statement.RefTable = session.Engine.TableInfo(bean)
+		session.Statement.RefTable = session.Statement.OutTable
 	}
 
 	if session.Statement.RawSQL == "" {
@@ -1139,72 +1108,48 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
 
 	sliceElementType := sliceValue.Type().Elem()
 	var table *core.Table
-	if session.Statement.RefTable == nil {
-		if sliceElementType.Kind() == reflect.Ptr {
-			if sliceElementType.Elem().Kind() == reflect.Struct {
-				pv := reflect.New(sliceElementType.Elem())
-				table = session.Engine.autoMapType(pv.Elem())
-			} else {
-				return errors.New("slice type")
-			}
-		} else if sliceElementType.Kind() == reflect.Struct {
-			pv := reflect.New(sliceElementType)
+
+	if sliceElementType.Kind() == reflect.Ptr {
+		if sliceElementType.Elem().Kind() == reflect.Struct {
+			pv := reflect.New(sliceElementType.Elem())
 			table = session.Engine.autoMapType(pv.Elem())
 		} else {
 			return errors.New("slice type")
 		}
-		session.Statement.RefTable = table
+	} else if sliceElementType.Kind() == reflect.Struct {
+		pv := reflect.New(sliceElementType)
+		table = session.Engine.autoMapType(pv.Elem())
 	} else {
-		table = session.Statement.RefTable
+		return errors.New("slice type")
+	}
+
+	session.Statement.OutTable = table
+
+	if session.Statement.RefTable == nil {
+		session.Statement.RefTable = table
 	}
 
-	var addedTableName = (len(session.Statement.JoinStr) > 0)
 	if !session.Statement.noAutoCondition && len(condiBean) > 0 {
-		colNames, args := session.Statement.buildConditions(table, condiBean[0], true, true, false, true, addedTableName)
+		colNames, args := session.Statement.buildConditions(
+			table, condiBean[0], true, true, false, true, session.Statement.needTableName())
+
 		session.Statement.ConditionStr = strings.Join(colNames, " AND ")
 		session.Statement.BeanArgs = args
 	} else {
 		// !oinume! Add "<col> IS NULL" to WHERE whatever condiBean is given.
 		// See https://github.com/go-xorm/xorm/issues/179
-		if col := table.DeletedColumn(); col != nil && !session.Statement.unscoped { // tag "deleted" is enabled
-			var colName = session.Engine.Quote(col.Name)
-			if addedTableName {
-				var nm = session.Statement.TableName()
-				if len(session.Statement.TableAlias) > 0 {
-					nm = session.Statement.TableAlias
-				}
-				colName = session.Engine.Quote(nm) + "." + colName
-			}
-			session.Statement.ConditionStr = fmt.Sprintf("(%v IS NULL OR %v = '0001-01-01 00:00:00')",
-				colName, colName)
+		if col := table.DeletedColumn(); col != nil && !session.Statement.unscoped {
+			// tag "deleted" is enabled
+			var colName = session.Statement.colName(col)
+			session.Statement.ConditionStr = fmt.Sprintf(
+				"(%v IS NULL OR %v = '0001-01-01 00:00:00')", colName, colName)
 		}
 	}
 
 	var sqlStr string
 	var args []interface{}
 	if session.Statement.RawSQL == "" {
-		var columnStr = session.Statement.ColumnStr
-		if len(session.Statement.selectStr) > 0 {
-			columnStr = session.Statement.selectStr
-		} else {
-			if session.Statement.JoinStr == "" {
-				if columnStr == "" {
-					if session.Statement.GroupByStr != "" {
-						columnStr = session.Statement.Engine.Quote(strings.Replace(session.Statement.GroupByStr, ",", session.Engine.Quote(","), -1))
-					} else {
-						columnStr = session.Statement.genColumnStr()
-					}
-				}
-			} else {
-				if columnStr == "" {
-					if session.Statement.GroupByStr != "" {
-						columnStr = session.Statement.Engine.Quote(strings.Replace(session.Statement.GroupByStr, ",", session.Engine.Quote(","), -1))
-					} else {
-						columnStr = "*"
-					}
-				}
-			}
-		}
+		columnStr := session.Statement.genColumnStr()
 
 		session.Statement.Params = append(session.Statement.joinArgs, append(session.Statement.Params, session.Statement.BeanArgs...)...)
 
@@ -1596,7 +1541,7 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount
 		if fieldValue := session.getField(dataStruct, key, table, idx); fieldValue != nil {
 			rawValue := reflect.Indirect(reflect.ValueOf(scanResults[ii]))
 
-			//if row is null then ignore
+			// if row is null then ignore
 			if rawValue.Interface() == nil {
 				continue
 			}
@@ -1635,28 +1580,30 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount
 				var bs []byte
 				if rawValueType.Kind() == reflect.String {
 					bs = []byte(vv.String())
-				} else if rawValueType.ConvertibleTo(reflect.SliceOf(reflect.TypeOf(uint8(1)))) {
+				} else if rawValueType.ConvertibleTo(core.BytesType) {
 					bs = vv.Bytes()
 				} else {
-					return errors.New("unsupported database data type")
+					return fmt.Errorf("unsupported database data type: %s %v", key, rawValueType.Kind())
 				}
 
 				hasAssigned = true
 
-				if fieldValue.CanAddr() {
-					err := json.Unmarshal(bs, fieldValue.Addr().Interface())
-					if err != nil {
-						session.Engine.logger.Error(err)
-						return err
-					}
-				} else {
-					x := reflect.New(fieldType)
-					err := json.Unmarshal(bs, x.Interface())
-					if err != nil {
-						session.Engine.logger.Error(err)
-						return err
+				if len(bs) > 0 {
+					if fieldValue.CanAddr() {
+						err := json.Unmarshal(bs, fieldValue.Addr().Interface())
+						if err != nil {
+							session.Engine.logger.Error(key, err)
+							return err
+						}
+					} else {
+						x := reflect.New(fieldType)
+						err := json.Unmarshal(bs, x.Interface())
+						if err != nil {
+							session.Engine.logger.Error(key, err)
+							return err
+						}
+						fieldValue.Set(x.Elem())
 					}
-					fieldValue.Set(x.Elem())
 				}
 
 				continue
@@ -1668,25 +1615,27 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount
 				var bs []byte
 				if rawValueType.Kind() == reflect.String {
 					bs = []byte(vv.String())
-				} else if rawValueType.Kind() == reflect.Slice {
+				} else if rawValueType.ConvertibleTo(core.BytesType) {
 					bs = vv.Bytes()
 				}
 
 				hasAssigned = true
-				if fieldValue.CanAddr() {
-					err := json.Unmarshal(bs, fieldValue.Addr().Interface())
-					if err != nil {
-						session.Engine.logger.Error(err)
-						return err
-					}
-				} else {
-					x := reflect.New(fieldType)
-					err := json.Unmarshal(bs, x.Interface())
-					if err != nil {
-						session.Engine.logger.Error(err)
-						return err
+				if len(bs) > 0 {
+					if fieldValue.CanAddr() {
+						err := json.Unmarshal(bs, fieldValue.Addr().Interface())
+						if err != nil {
+							session.Engine.logger.Error(err)
+							return err
+						}
+					} else {
+						x := reflect.New(fieldType)
+						err := json.Unmarshal(bs, x.Interface())
+						if err != nil {
+							session.Engine.logger.Error(err)
+							return err
+						}
+						fieldValue.Set(x.Elem())
 					}
-					fieldValue.Set(x.Elem())
 				}
 			case reflect.Slice, reflect.Array:
 				switch rawValueType.Kind() {
@@ -1800,21 +1749,25 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount
 					if rawValueType.Kind() == reflect.String {
 						hasAssigned = true
 						x := reflect.New(fieldType)
-						err := json.Unmarshal([]byte(vv.String()), x.Interface())
-						if err != nil {
-							session.Engine.logger.Error(err)
-							return err
+						if len([]byte(vv.String())) > 0 {
+							err := json.Unmarshal([]byte(vv.String()), x.Interface())
+							if err != nil {
+								session.Engine.logger.Error(err)
+								return err
+							}
+							fieldValue.Set(x.Elem())
 						}
-						fieldValue.Set(x.Elem())
 					} else if rawValueType.Kind() == reflect.Slice {
 						hasAssigned = true
 						x := reflect.New(fieldType)
-						err := json.Unmarshal(vv.Bytes(), x.Interface())
-						if err != nil {
-							session.Engine.logger.Error(err)
-							return err
+						if len(vv.Bytes()) > 0 {
+							err := json.Unmarshal(vv.Bytes(), x.Interface())
+							if err != nil {
+								session.Engine.logger.Error(err)
+								return err
+							}
+							fieldValue.Set(x.Elem())
 						}
-						fieldValue.Set(x.Elem())
 					}
 				} else if session.Statement.UseCascade {
 					table := session.Engine.autoMapType(*fieldValue)
@@ -1972,20 +1925,24 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount
 					}
 				case core.Complex64Type:
 					var x complex64
-					err := json.Unmarshal([]byte(vv.String()), &x)
-					if err != nil {
-						session.Engine.logger.Error(err)
-					} else {
-						fieldValue.Set(reflect.ValueOf(&x))
+					if len([]byte(vv.String())) > 0 {
+						err := json.Unmarshal([]byte(vv.String()), &x)
+						if err != nil {
+							session.Engine.logger.Error(err)
+						} else {
+							fieldValue.Set(reflect.ValueOf(&x))
+						}
 					}
 					hasAssigned = true
 				case core.Complex128Type:
 					var x complex128
-					err := json.Unmarshal([]byte(vv.String()), &x)
-					if err != nil {
-						session.Engine.logger.Error(err)
-					} else {
-						fieldValue.Set(reflect.ValueOf(&x))
+					if len([]byte(vv.String())) > 0 {
+						err := json.Unmarshal([]byte(vv.String()), &x)
+						if err != nil {
+							session.Engine.logger.Error(err)
+						} else {
+							fieldValue.Set(reflect.ValueOf(&x))
+						}
 					}
 					hasAssigned = true
 				} // switch fieldType
@@ -2430,36 +2387,41 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
 	switch fieldType.Kind() {
 	case reflect.Complex64, reflect.Complex128:
 		x := reflect.New(fieldType)
-
-		err := json.Unmarshal(data, x.Interface())
-		if err != nil {
-			session.Engine.logger.Error(err)
-			return err
+		if len(data) > 0 {
+			err := json.Unmarshal(data, x.Interface())
+			if err != nil {
+				session.Engine.logger.Error(err)
+				return err
+			}
+			fieldValue.Set(x.Elem())
 		}
-		fieldValue.Set(x.Elem())
 	case reflect.Slice, reflect.Array, reflect.Map:
 		v = data
 		t := fieldType.Elem()
 		k := t.Kind()
 		if col.SQLType.IsText() {
 			x := reflect.New(fieldType)
-			err := json.Unmarshal(data, x.Interface())
-			if err != nil {
-				session.Engine.logger.Error(err)
-				return err
+			if len(data) > 0 {
+				err := json.Unmarshal(data, x.Interface())
+				if err != nil {
+					session.Engine.logger.Error(err)
+					return err
+				}
+				fieldValue.Set(x.Elem())
 			}
-			fieldValue.Set(x.Elem())
 		} else if col.SQLType.IsBlob() {
 			if k == reflect.Uint8 {
 				fieldValue.Set(reflect.ValueOf(v))
 			} else {
 				x := reflect.New(fieldType)
-				err := json.Unmarshal(data, x.Interface())
-				if err != nil {
-					session.Engine.logger.Error(err)
-					return err
+				if len(data) > 0 {
+					err := json.Unmarshal(data, x.Interface())
+					if err != nil {
+						session.Engine.logger.Error(err)
+						return err
+					}
+					fieldValue.Set(x.Elem())
 				}
-				fieldValue.Set(x.Elem())
 			}
 		} else {
 			return ErrUnSupportedType
@@ -2584,21 +2546,25 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
 		// case "*complex64":
 		case core.Complex64Type.Kind():
 			var x complex64
-			err := json.Unmarshal(data, &x)
-			if err != nil {
-				session.Engine.logger.Error(err)
-				return err
+			if len(data) > 0 {
+				err := json.Unmarshal(data, &x)
+				if err != nil {
+					session.Engine.logger.Error(err)
+					return err
+				}
+				fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType))
 			}
-			fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType))
 		// case "*complex128":
 		case core.Complex128Type.Kind():
 			var x complex128
-			err := json.Unmarshal(data, &x)
-			if err != nil {
-				session.Engine.logger.Error(err)
-				return err
+			if len(data) > 0 {
+				err := json.Unmarshal(data, &x)
+				if err != nil {
+					session.Engine.logger.Error(err)
+					return err
+				}
+				fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType))
 			}
-			fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType))
 		// case "*float64":
 		case core.Float64Type.Kind():
 			x, err := strconv.ParseFloat(string(data), 64)
@@ -3207,47 +3173,6 @@ func (session *Session) InsertOne(bean interface{}) (int64, error) {
 	return session.innerInsert(bean)
 }
 
-func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) {
-	if statement.RefTable == nil || len(statement.RefTable.PrimaryKeys) != 1 {
-		return "", ""
-	}
-
-	colstrs := statement.JoinColumns(statement.RefTable.PKColumns(), true)
-	sqls := splitNNoCase(sqlStr, "where", 2)
-	if len(sqls) != 2 {
-		if len(sqls) == 1 {
-			return sqls[0], fmt.Sprintf("SELECT %v FROM %v",
-				colstrs, statement.Engine.Quote(statement.TableName()))
-		}
-		return "", ""
-	}
-
-	var whereStr = sqls[1]
-
-	//TODO: for postgres only, if any other database?
-	var paraStr string
-	if statement.Engine.dialect.DBType() == core.POSTGRES {
-		paraStr = "$"
-	} else if statement.Engine.dialect.DBType() == core.MSSQL {
-		paraStr = ":"
-	}
-
-	if paraStr != "" {
-		if strings.Contains(sqls[1], paraStr) {
-			dollers := strings.Split(sqls[1], paraStr)
-			whereStr = dollers[0]
-			for i, c := range dollers[1:] {
-				ccs := strings.SplitN(c, " ", 2)
-				whereStr += fmt.Sprintf(paraStr+"%v %v", i+1, ccs[1])
-			}
-		}
-	}
-
-	return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v",
-		colstrs, statement.Engine.Quote(statement.TableName()),
-		whereStr)
-}
-
 func (session *Session) cacheInsert(tables ...string) error {
 	if session.Statement.RefTable == nil {
 		return ErrCacheFailed

+ 62 - 50
sessionplus.go

@@ -457,7 +457,7 @@ func (session *Session) _row2BeanWithDateFormat(dateFormat string, rows *core.Ro
 		if fieldValue := session.getField(dataStruct, key, table, idx); fieldValue != nil {
 			rawValue := reflect.Indirect(reflect.ValueOf(scanResults[ii]))
 
-			//if row is null then ignore
+			// if row is null then ignore
 			if rawValue.Interface() == nil {
 				continue
 			}
@@ -496,28 +496,30 @@ func (session *Session) _row2BeanWithDateFormat(dateFormat string, rows *core.Ro
 				var bs []byte
 				if rawValueType.Kind() == reflect.String {
 					bs = []byte(vv.String())
-				} else if rawValueType.ConvertibleTo(reflect.SliceOf(reflect.TypeOf(uint8(1)))) {
+				} else if rawValueType.ConvertibleTo(core.BytesType) {
 					bs = vv.Bytes()
 				} else {
-					return errors.New("unsupported database data type")
+					return fmt.Errorf("unsupported database data type: %s %v", key, rawValueType.Kind())
 				}
 
 				hasAssigned = true
 
-				if fieldValue.CanAddr() {
-					err := json.Unmarshal(bs, fieldValue.Addr().Interface())
-					if err != nil {
-						session.Engine.logger.Error(err)
-						return err
-					}
-				} else {
-					x := reflect.New(fieldType)
-					err := json.Unmarshal(bs, x.Interface())
-					if err != nil {
-						session.Engine.logger.Error(err)
-						return err
+				if len(bs) > 0 {
+					if fieldValue.CanAddr() {
+						err := json.Unmarshal(bs, fieldValue.Addr().Interface())
+						if err != nil {
+							session.Engine.logger.Error(key, err)
+							return err
+						}
+					} else {
+						x := reflect.New(fieldType)
+						err := json.Unmarshal(bs, x.Interface())
+						if err != nil {
+							session.Engine.logger.Error(key, err)
+							return err
+						}
+						fieldValue.Set(x.Elem())
 					}
-					fieldValue.Set(x.Elem())
 				}
 
 				continue
@@ -529,25 +531,27 @@ func (session *Session) _row2BeanWithDateFormat(dateFormat string, rows *core.Ro
 				var bs []byte
 				if rawValueType.Kind() == reflect.String {
 					bs = []byte(vv.String())
-				} else if rawValueType.Kind() == reflect.Slice {
+				} else if rawValueType.ConvertibleTo(core.BytesType) {
 					bs = vv.Bytes()
 				}
 
 				hasAssigned = true
-				if fieldValue.CanAddr() {
-					err := json.Unmarshal(bs, fieldValue.Addr().Interface())
-					if err != nil {
-						session.Engine.logger.Error(err)
-						return err
-					}
-				} else {
-					x := reflect.New(fieldType)
-					err := json.Unmarshal(bs, x.Interface())
-					if err != nil {
-						session.Engine.logger.Error(err)
-						return err
+				if len(bs) > 0 {
+					if fieldValue.CanAddr() {
+						err := json.Unmarshal(bs, fieldValue.Addr().Interface())
+						if err != nil {
+							session.Engine.logger.Error(err)
+							return err
+						}
+					} else {
+						x := reflect.New(fieldType)
+						err := json.Unmarshal(bs, x.Interface())
+						if err != nil {
+							session.Engine.logger.Error(err)
+							return err
+						}
+						fieldValue.Set(x.Elem())
 					}
-					fieldValue.Set(x.Elem())
 				}
 			case reflect.Slice, reflect.Array:
 				switch rawValueType.Kind() {
@@ -662,21 +666,25 @@ func (session *Session) _row2BeanWithDateFormat(dateFormat string, rows *core.Ro
 					if rawValueType.Kind() == reflect.String {
 						hasAssigned = true
 						x := reflect.New(fieldType)
-						err := json.Unmarshal([]byte(vv.String()), x.Interface())
-						if err != nil {
-							session.Engine.logger.Error(err)
-							return err
+						if len([]byte(vv.String())) > 0 {
+							err := json.Unmarshal([]byte(vv.String()), x.Interface())
+							if err != nil {
+								session.Engine.logger.Error(err)
+								return err
+							}
+							fieldValue.Set(x.Elem())
 						}
-						fieldValue.Set(x.Elem())
 					} else if rawValueType.Kind() == reflect.Slice {
 						hasAssigned = true
 						x := reflect.New(fieldType)
-						err := json.Unmarshal(vv.Bytes(), x.Interface())
-						if err != nil {
-							session.Engine.logger.Error(err)
-							return err
+						if len(vv.Bytes()) > 0 {
+							err := json.Unmarshal(vv.Bytes(), x.Interface())
+							if err != nil {
+								session.Engine.logger.Error(err)
+								return err
+							}
+							fieldValue.Set(x.Elem())
 						}
-						fieldValue.Set(x.Elem())
 					}
 				} else if session.Statement.UseCascade {
 					table := session.Engine.autoMapType(*fieldValue)
@@ -834,20 +842,24 @@ func (session *Session) _row2BeanWithDateFormat(dateFormat string, rows *core.Ro
 					}
 				case core.Complex64Type:
 					var x complex64
-					err := json.Unmarshal([]byte(vv.String()), &x)
-					if err != nil {
-						session.Engine.logger.Error(err)
-					} else {
-						fieldValue.Set(reflect.ValueOf(&x))
+					if len([]byte(vv.String())) > 0 {
+						err := json.Unmarshal([]byte(vv.String()), &x)
+						if err != nil {
+							session.Engine.logger.Error(err)
+						} else {
+							fieldValue.Set(reflect.ValueOf(&x))
+						}
 					}
 					hasAssigned = true
 				case core.Complex128Type:
 					var x complex128
-					err := json.Unmarshal([]byte(vv.String()), &x)
-					if err != nil {
-						session.Engine.logger.Error(err)
-					} else {
-						fieldValue.Set(reflect.ValueOf(&x))
+					if len([]byte(vv.String())) > 0 {
+						err := json.Unmarshal([]byte(vv.String()), &x)
+						if err != nil {
+							session.Engine.logger.Error(err)
+						} else {
+							fieldValue.Set(reflect.ValueOf(&x))
+						}
 					}
 					hasAssigned = true
 				} // switch fieldType

+ 200 - 82
statement.go

@@ -40,6 +40,7 @@ type exprParam struct {
 // Statement save all the sql info for executing SQL
 type Statement struct {
 	RefTable        *core.Table
+	OutTable        *core.Table
 	Engine          *Engine
 	Start           int
 	LimitN          int
@@ -54,6 +55,7 @@ type Statement struct {
 	ColumnStr       string
 	selectStr       string
 	columnMap       map[string]bool
+	tableMap        map[string]string
 	useAllCols      bool
 	OmitStr         string
 	ConditionStr    string
@@ -85,6 +87,7 @@ type Statement struct {
 // Init reset all the statment's fields
 func (statement *Statement) Init() {
 	statement.RefTable = nil
+	statement.OutTable = nil
 	statement.Start = 0
 	statement.LimitN = 0
 	statement.WhereStr = ""
@@ -98,6 +101,7 @@ func (statement *Statement) Init() {
 	statement.ColumnStr = ""
 	statement.OmitStr = ""
 	statement.columnMap = make(map[string]bool)
+	statement.tableMap = make(map[string]string)
 	statement.ConditionStr = ""
 	statement.AltTableName = ""
 	statement.IdParam = nil
@@ -141,7 +145,14 @@ func (statement *Statement) Sql(querystring string, args ...interface{}) *Statem
 
 // Alias set the table alias
 func (statement *Statement) Alias(alias string) *Statement {
+	if statement.TableName() != "" {
+		statement.tableMapDelete(statement.TableName())
+	}
+	if statement.TableAlias != "" {
+		statement.tableMapDelete(statement.TableAlias)
+	}
 	statement.TableAlias = alias
+	statement.tableMapAdd(alias)
 	return statement
 }
 
@@ -190,6 +201,9 @@ func (statement *Statement) Or(querystring string, args ...interface{}) *Stateme
 
 // Table tempororily set table name, the parameter could be a string or a pointer of struct
 func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
+	if statement.TableAlias == "" && statement.TableName() != "" {
+		statement.tableMapDelete(statement.TableName())
+	}
 	v := rValue(tableNameOrBean)
 	t := v.Type()
 	if t.Kind() == reflect.String {
@@ -197,6 +211,9 @@ func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
 	} else if t.Kind() == reflect.Struct {
 		statement.RefTable = statement.Engine.autoMapType(v)
 	}
+	if statement.TableAlias == "" {
+		statement.tableMapAdd(statement.TableName())
+	}
 	return statement
 }
 
@@ -439,22 +456,70 @@ func (statement *Statement) needTableName() bool {
 	return len(statement.JoinStr) > 0
 }
 
-func (statement *Statement) colName(col *core.Column, tableName string) string {
-	if statement.needTableName() {
-		var nm = tableName
+func (statement *Statement) tableMapAdd(table string) {
+	tableName := statement.Engine.Quote(strings.ToLower(table))
+	statement.tableMap[tableName] = table
+}
+
+func (statement *Statement) tableMapDelete(table string) {
+	tableName := statement.Engine.Quote(strings.ToLower(table))
+	delete(statement.tableMap, tableName)
+}
+
+func (statement *Statement) isKnownTable(table string) (string, bool) {
+	if len(table) > 0 {
+		var mainTable string
+
 		if len(statement.TableAlias) > 0 {
-			nm = statement.TableAlias
+			mainTable = statement.TableAlias
+		} else {
+			mainTable = statement.TableName()
+		}
+
+		cm := statement.Engine.Quote(strings.ToLower(mainTable))
+		ct := statement.Engine.Quote(strings.ToLower(table))
+
+		if name, ok := statement.tableMap[ct]; ok {
+			return name, true
+		}
+
+		if ct == cm {
+			return mainTable, true
+		}
+	}
+	return "", false
+}
+
+func (statement *Statement) colName(col *core.Column) string {
+	var colTable string
+
+	if statement.needTableName() {
+		if name, ok := statement.isKnownTable(col.TableName); ok {
+			colTable = name
+		} else if name, ok := statement.isKnownTable(statement.outTableName()); ok {
+			colTable = name
+		} else {
+			if statement.TableAlias != "" {
+				colTable = statement.TableAlias
+			} else {
+				colTable = statement.TableName()
+			}
 		}
-		return statement.Engine.Quote(nm) + "." + statement.Engine.Quote(col.Name)
 	}
-	return statement.Engine.Quote(col.Name)
+
+	if colTable != "" {
+		return statement.Engine.Quote(colTable) + "." + statement.Engine.Quote(col.Name)
+	} else {
+		return statement.Engine.Quote(col.Name)
+	}
 }
 
 // Auto generating conditions according a struct
-func buildConditions(engine *Engine, table *core.Table, bean interface{},
-	includeVersion bool, includeUpdated bool, includeNil bool,
-	includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool,
-	mustColumnMap map[string]bool, tableName, aliasName string, addedTableName bool) ([]string, []interface{}) {
+func (statement *Statement) buildConditions(
+	table *core.Table, bean interface{},
+	includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool,
+	addedTableName bool) ([]string, []interface{}) {
+	engine := statement.Engine
 	var colNames []string
 	var args = make([]interface{}, 0)
 	for _, col := range table.Columns() {
@@ -475,16 +540,7 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{},
 			continue
 		}
 
-		var colName string
-		if addedTableName {
-			var nm = tableName
-			if len(aliasName) > 0 {
-				nm = aliasName
-			}
-			colName = engine.Quote(nm) + "." + engine.Quote(col.Name)
-		} else {
-			colName = engine.Quote(col.Name)
-		}
+		colName := statement.colName(col)
 
 		fieldValuePtr, err := col.ValueOf(bean)
 		if err != nil {
@@ -492,9 +548,9 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{},
 			continue
 		}
 
-		if col.IsDeleted && !unscoped { // tag "deleted" is enabled
-			colNames = append(colNames, fmt.Sprintf("(%v IS NULL OR %v = '0001-01-01 00:00:00')",
-				colName, colName))
+		if col.IsDeleted && !statement.unscoped { // tag "deleted" is enabled
+			colNames = append(colNames, fmt.Sprintf(
+				"(%v IS NULL OR %v = '0001-01-01 00:00:00')", colName, colName))
 		}
 
 		fieldValue := *fieldValuePtr
@@ -503,8 +559,8 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{},
 		}
 
 		fieldType := reflect.TypeOf(fieldValue.Interface())
-		requiredField := useAllCols
-		if b, ok := mustColumnMap[strings.ToLower(col.Name)]; ok {
+		requiredField := statement.useAllCols
+		if b, ok := statement.mustColumnMap[strings.ToLower(col.Name)]; ok {
 			if b {
 				requiredField = true
 			} else {
@@ -532,7 +588,7 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{},
 		var val interface{}
 		switch fieldType.Kind() {
 		case reflect.Bool:
-			if allUseBool || requiredField {
+			if statement.allUseBool || requiredField {
 				val = fieldValue.Interface()
 			} else {
 				// if a bool in a struct, it will not be as a condition because it default is false,
@@ -688,6 +744,13 @@ func (statement *Statement) TableName() string {
 	return ""
 }
 
+func (statement *Statement) outTableName() string {
+	if statement.OutTable != nil {
+		return statement.OutTable.Name
+	}
+	return ""
+}
+
 // Id generate "where id = ? " statment or for composite key "where key1 = ? and key2 = ?"
 func (statement *Statement) Id(id interface{}) *Statement {
 	idValue := reflect.ValueOf(id)
@@ -979,12 +1042,15 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
 		fmt.Fprintf(&buf, "%v JOIN ", joinOP)
 	}
 
+	var refName string
 	switch tablename.(type) {
 	case []string:
 		t := tablename.([]string)
 		if len(t) > 1 {
+			refName = t[1]
 			fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(t[0]), statement.Engine.Quote(t[1]))
 		} else if len(t) == 1 {
+			refName = t[0]
 			fmt.Fprintf(&buf, statement.Engine.Quote(t[0]))
 		}
 	case []interface{}:
@@ -1003,18 +1069,22 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
 			}
 		}
 		if l > 1 {
+			refName = fmt.Sprintf("%v", t[1])
 			fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(table),
-				statement.Engine.Quote(fmt.Sprintf("%v", t[1])))
+				statement.Engine.Quote(refName))
 		} else if l == 1 {
+			refName = table
 			fmt.Fprintf(&buf, statement.Engine.Quote(table))
 		}
 	default:
-		fmt.Fprintf(&buf, statement.Engine.Quote(fmt.Sprintf("%v", tablename)))
+		refName = fmt.Sprintf("%v", tablename)
+		fmt.Fprintf(&buf, statement.Engine.Quote(refName))
 	}
 
 	fmt.Fprintf(&buf, " ON %v", condition)
 	statement.JoinStr = buf.String()
 	statement.joinArgs = append(statement.joinArgs, args...)
+	statement.tableMapAdd(refName)
 	return statement
 }
 
@@ -1037,7 +1107,23 @@ func (statement *Statement) Unscoped() *Statement {
 }
 
 func (statement *Statement) genColumnStr() string {
+	if len(statement.selectStr) > 0 {
+		return statement.selectStr
+	}
+
+	if len(statement.ColumnStr) > 0 {
+		return statement.ColumnStr
+	}
+
+	if len(statement.GroupByStr) > 0 {
+		return statement.Engine.Quote(
+			strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
+	}
+
 	table := statement.RefTable
+	if statement.OutTable != nil {
+		table = statement.OutTable
+	}
 	colNames := make([]string, 0)
 	for _, col := range table.Columns() {
 		if statement.OmitStr != "" {
@@ -1049,26 +1135,12 @@ func (statement *Statement) genColumnStr() string {
 			continue
 		}
 
-		if statement.JoinStr != "" {
-			var name string
-			if statement.TableAlias != "" {
-				name = statement.Engine.Quote(statement.TableAlias)
-			} else {
-				name = statement.Engine.Quote(statement.TableName())
-			}
-			name += "." + statement.Engine.Quote(col.Name)
-			if col.IsPrimaryKey && statement.Engine.Dialect().DBType() == "ql" {
-				colNames = append(colNames, "id() AS "+name)
-			} else {
-				colNames = append(colNames, name)
-			}
+		name := statement.colName(col)
+
+		if col.IsPrimaryKey && statement.Engine.Dialect().DBType() == "ql" {
+			colNames = append(colNames, "id() AS "+name)
 		} else {
-			name := statement.Engine.Quote(col.Name)
-			if col.IsPrimaryKey && statement.Engine.Dialect().DBType() == "ql" {
-				colNames = append(colNames, "id() AS "+name)
-			} else {
-				colNames = append(colNames, name)
-			}
+			colNames = append(colNames, name)
 		}
 	}
 	return strings.Join(colNames, ", ")
@@ -1135,38 +1207,15 @@ func (statement *Statement) genGetSql(bean interface{}) (string, []interface{})
 		table = statement.RefTable
 	}
 
-	var addedTableName = (len(statement.JoinStr) > 0)
-
 	if !statement.noAutoCondition {
-		colNames, args := statement.buildConditions(table, bean, true, true, false, true, addedTableName)
+		colNames, args := statement.buildConditions(
+			table, bean, true, true, false, true, statement.needTableName())
 
 		statement.ConditionStr = strings.Join(colNames, " "+statement.Engine.dialect.AndStr()+" ")
 		statement.BeanArgs = args
 	}
 
-	var columnStr string = statement.ColumnStr
-	if len(statement.selectStr) > 0 {
-		columnStr = statement.selectStr
-	} else {
-		// TODO: always generate column names, not use * even if join
-		if len(statement.JoinStr) == 0 {
-			if len(columnStr) == 0 {
-				if len(statement.GroupByStr) > 0 {
-					columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
-				} else {
-					columnStr = statement.genColumnStr()
-				}
-			}
-		} else {
-			if len(columnStr) == 0 {
-				if len(statement.GroupByStr) > 0 {
-					columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
-				} else {
-					columnStr = "*"
-				}
-			}
-		}
-	}
+	columnStr := statement.genColumnStr()
 
 	statement.attachInSql() // !admpub!  fix bug:Iterate func missing "... IN (...)"
 	return statement.genSelectSQL(columnStr), append(append(statement.joinArgs, statement.Params...), statement.BeanArgs...)
@@ -1193,21 +1242,16 @@ func (s *Statement) genAddUniqueStr(uqeName string, cols []string) (string, []in
 	return sql, []interface{}{}
 }*/
 
-func (statement *Statement) buildConditions(table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) ([]string, []interface{}) {
-	return buildConditions(statement.Engine, table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols,
-		statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName)
-}
-
 func (statement *Statement) genCountSql(bean interface{}) (string, []interface{}) {
 	table := statement.Engine.TableInfo(bean)
 	statement.RefTable = table
 
-	var addedTableName = (len(statement.JoinStr) > 0)
-
 	if !statement.noAutoCondition {
-		colNames, args := statement.buildConditions(table, bean, true, true, false, true, addedTableName)
+		colNames, args := statement.buildConditions(
+			table, bean, true, true, false, true, statement.needTableName())
 
 		statement.ConditionStr = strings.Join(colNames, " "+statement.Engine.Dialect().AndStr()+" ")
+
 		statement.BeanArgs = args
 	}
 
@@ -1331,7 +1375,7 @@ func (statement *Statement) processIdParam() {
 	if statement.IdParam != nil {
 		if statement.Engine.dialect.DBType() != "ql" {
 			for i, col := range statement.RefTable.PKColumns() {
-				var colName = statement.colName(col, statement.TableName())
+				var colName = statement.colName(col)
 				if i < len(*(statement.IdParam)) {
 					statement.And(fmt.Sprintf("%v %s ?", colName,
 						statement.Engine.dialect.EqStr()), (*(statement.IdParam))[i])
@@ -1347,3 +1391,77 @@ func (statement *Statement) processIdParam() {
 		}
 	}
 }
+
+func (statement *Statement) JoinColumns(cols []*core.Column, includeTableName bool) string {
+	var colnames = make([]string, len(cols))
+	for i, col := range cols {
+		if includeTableName {
+			colnames[i] = statement.Engine.Quote(statement.TableName()) +
+				"." + statement.Engine.Quote(col.Name)
+		} else {
+			colnames[i] = statement.Engine.Quote(col.Name)
+		}
+	}
+	return strings.Join(colnames, ", ")
+}
+
+func (statement *Statement) convertIdSql(sqlStr string) string {
+	if statement.RefTable != nil {
+		cols := statement.RefTable.PKColumns()
+		if len(cols) == 0 {
+			return ""
+		}
+
+		colstrs := statement.JoinColumns(cols, false)
+		sqls := splitNNoCase(sqlStr, " from ", 2)
+		if len(sqls) != 2 {
+			return ""
+		}
+		if statement.Engine.dialect.DBType() == "ql" {
+			return fmt.Sprintf("SELECT id() FROM %v", sqls[1])
+		}
+		return fmt.Sprintf("SELECT %s FROM %v", colstrs, sqls[1])
+	}
+	return ""
+}
+
+func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) {
+	if statement.RefTable == nil || len(statement.RefTable.PrimaryKeys) != 1 {
+		return "", ""
+	}
+
+	colstrs := statement.JoinColumns(statement.RefTable.PKColumns(), true)
+	sqls := splitNNoCase(sqlStr, "where", 2)
+	if len(sqls) != 2 {
+		if len(sqls) == 1 {
+			return sqls[0], fmt.Sprintf("SELECT %v FROM %v",
+				colstrs, statement.Engine.Quote(statement.TableName()))
+		}
+		return "", ""
+	}
+
+	var whereStr = sqls[1]
+
+	//TODO: for postgres only, if any other database?
+	var paraStr string
+	if statement.Engine.dialect.DBType() == core.POSTGRES {
+		paraStr = "$"
+	} else if statement.Engine.dialect.DBType() == core.MSSQL {
+		paraStr = ":"
+	}
+
+	if paraStr != "" {
+		if strings.Contains(sqls[1], paraStr) {
+			dollers := strings.Split(sqls[1], paraStr)
+			whereStr = dollers[0]
+			for i, c := range dollers[1:] {
+				ccs := strings.SplitN(c, " ", 2)
+				whereStr += fmt.Sprintf(paraStr+"%v %v", i+1, ccs[1])
+			}
+		}
+	}
+
+	return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v",
+		colstrs, statement.Engine.Quote(statement.TableName()),
+		whereStr)
+}

+ 1 - 1
xorm.go

@@ -17,7 +17,7 @@ import (
 
 const (
 	// Version show the xorm's version
-	Version string = "0.5.4.0508"
+	Version string = "0.5.4.0513"
 )
 
 func regDrvsNDialects() bool {