فهرست منبع

add CondDeleted method

xormplus 8 سال پیش
والد
کامیت
e71d50cb8e
6فایلهای تغییر یافته به همراه250 افزوده شده و 231 حذف شده
  1. 9 0
      engine.go
  2. 230 0
      engine_cond.go
  3. 2 2
      session.go
  4. 5 4
      session_convert.go
  5. 3 6
      session_find.go
  6. 1 219
      statement.go

+ 9 - 0
engine.go

@@ -20,6 +20,7 @@ import (
 	"time"
 
 	"github.com/fsnotify/fsnotify"
+	"github.com/go-xorm/builder"
 	"github.com/xormplus/core"
 )
 
@@ -1588,3 +1589,11 @@ func (engine *Engine) Unscoped() *Session {
 	session.IsAutoClose = true
 	return session.Unscoped()
 }
+
+// CondDeleted returns the conditions whether a record is soft deleted.
+func (engine *Engine) CondDeleted(colName string) builder.Cond {
+	if engine.dialect.DBType() == core.MSSQL {
+		return builder.IsNull{colName}
+	}
+	return builder.IsNull{colName}.Or(builder.Eq{colName: zeroTime1})
+}

+ 230 - 0
engine_cond.go

@@ -0,0 +1,230 @@
+// Copyright 2017 The Xorm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package xorm
+
+import (
+	"database/sql/driver"
+	"encoding/json"
+	"fmt"
+	"reflect"
+	"time"
+
+	"github.com/go-xorm/builder"
+	"github.com/xormplus/core"
+)
+
+func (engine *Engine) buildConds(table *core.Table, bean interface{},
+	includeVersion bool, includeUpdated bool, includeNil bool,
+	includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool,
+	mustColumnMap map[string]bool, tableName, aliasName string, addedTableName bool) (builder.Cond, error) {
+	var conds []builder.Cond
+	for _, col := range table.Columns() {
+		if !includeVersion && col.IsVersion {
+			continue
+		}
+		if !includeUpdated && col.IsUpdated {
+			continue
+		}
+		if !includeAutoIncr && col.IsAutoIncrement {
+			continue
+		}
+
+		if engine.dialect.DBType() == core.MSSQL && (col.SQLType.Name == core.Text || col.SQLType.IsBlob() || col.SQLType.Name == core.TimeStampz) {
+			continue
+		}
+		if col.SQLType.IsJson() {
+			continue
+		}
+
+		var colName string
+		if addedTableName {
+			var nm = tableName
+			if len(aliasName) > 0 {
+				nm = aliasName
+			}
+			colName = engine.Quote(nm) + "." + engine.Quote(col.Name)
+		} else {
+			colName = engine.Quote(col.Name)
+		}
+
+		fieldValuePtr, err := col.ValueOf(bean)
+		if err != nil {
+			engine.logger.Error(err)
+			continue
+		}
+
+		if col.IsDeleted && !unscoped { // tag "deleted" is enabled
+			conds = append(conds, engine.CondDeleted(colName))
+		}
+
+		fieldValue := *fieldValuePtr
+		if fieldValue.Interface() == nil {
+			continue
+		}
+
+		fieldType := reflect.TypeOf(fieldValue.Interface())
+		requiredField := useAllCols
+
+		if b, ok := getFlagForColumn(mustColumnMap, col); ok {
+			if b {
+				requiredField = true
+			} else {
+				continue
+			}
+		}
+
+		if fieldType.Kind() == reflect.Ptr {
+			if fieldValue.IsNil() {
+				if includeNil {
+					conds = append(conds, builder.Eq{colName: nil})
+				}
+				continue
+			} else if !fieldValue.IsValid() {
+				continue
+			} else {
+				// dereference ptr type to instance type
+				fieldValue = fieldValue.Elem()
+				fieldType = reflect.TypeOf(fieldValue.Interface())
+				requiredField = true
+			}
+		}
+
+		var val interface{}
+		switch fieldType.Kind() {
+		case reflect.Bool:
+			if allUseBool || requiredField {
+				val = fieldValue.Interface()
+			} else {
+				// if a bool in a struct, it will not be as a condition because it default is false,
+				// please use Where() instead
+				continue
+			}
+		case reflect.String:
+			if !requiredField && fieldValue.String() == "" {
+				continue
+			}
+			// for MyString, should convert to string or panic
+			if fieldType.String() != reflect.String.String() {
+				val = fieldValue.String()
+			} else {
+				val = fieldValue.Interface()
+			}
+		case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64:
+			if !requiredField && fieldValue.Int() == 0 {
+				continue
+			}
+			val = fieldValue.Interface()
+		case reflect.Float32, reflect.Float64:
+			if !requiredField && fieldValue.Float() == 0.0 {
+				continue
+			}
+			val = fieldValue.Interface()
+		case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64:
+			if !requiredField && fieldValue.Uint() == 0 {
+				continue
+			}
+			t := int64(fieldValue.Uint())
+			val = reflect.ValueOf(&t).Interface()
+		case reflect.Struct:
+			if fieldType.ConvertibleTo(core.TimeType) {
+				t := fieldValue.Convert(core.TimeType).Interface().(time.Time)
+				if !requiredField && (t.IsZero() || !fieldValue.IsValid()) {
+					continue
+				}
+				val = engine.formatColTime(col, t)
+			} else if _, ok := reflect.New(fieldType).Interface().(core.Conversion); ok {
+				continue
+			} else if valNul, ok := fieldValue.Interface().(driver.Valuer); ok {
+				val, _ = valNul.Value()
+				if val == nil {
+					continue
+				}
+			} else {
+				if col.SQLType.IsJson() {
+					if col.SQLType.IsText() {
+						bytes, err := json.Marshal(fieldValue.Interface())
+						if err != nil {
+							engine.logger.Error(err)
+							continue
+						}
+						val = string(bytes)
+					} else if col.SQLType.IsBlob() {
+						var bytes []byte
+						var err error
+						bytes, err = json.Marshal(fieldValue.Interface())
+						if err != nil {
+							engine.logger.Error(err)
+							continue
+						}
+						val = bytes
+					}
+				} else {
+					engine.autoMapType(fieldValue)
+					if table, ok := engine.Tables[fieldValue.Type()]; ok {
+						if len(table.PrimaryKeys) == 1 {
+							pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)
+							// fix non-int pk issues
+							//if pkField.Int() != 0 {
+							if pkField.IsValid() && !isZero(pkField.Interface()) {
+								val = pkField.Interface()
+							} else {
+								continue
+							}
+						} else {
+							//TODO: how to handler?
+							return nil, fmt.Errorf("not supported %v as %v", fieldValue.Interface(), table.PrimaryKeys)
+						}
+					} else {
+						val = fieldValue.Interface()
+					}
+				}
+			}
+		case reflect.Array:
+			continue
+		case reflect.Slice, reflect.Map:
+			if fieldValue == reflect.Zero(fieldType) {
+				continue
+			}
+			if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 {
+				continue
+			}
+
+			if col.SQLType.IsText() {
+				bytes, err := json.Marshal(fieldValue.Interface())
+				if err != nil {
+					engine.logger.Error(err)
+					continue
+				}
+				val = string(bytes)
+			} else if col.SQLType.IsBlob() {
+				var bytes []byte
+				var err error
+				if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) &&
+					fieldType.Elem().Kind() == reflect.Uint8 {
+					if fieldValue.Len() > 0 {
+						val = fieldValue.Bytes()
+					} else {
+						continue
+					}
+				} else {
+					bytes, err = json.Marshal(fieldValue.Interface())
+					if err != nil {
+						engine.logger.Error(err)
+						continue
+					}
+					val = bytes
+				}
+			} else {
+				continue
+			}
+		default:
+			val = fieldValue.Interface()
+		}
+
+		conds = append(conds, builder.Eq{colName: val})
+	}
+
+	return builder.And(conds...), nil
+}

+ 2 - 2
session.go

@@ -573,7 +573,7 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i
 								fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
 							}
 						} else {
-							panic(fmt.Sprintf("rawValueType is %v, value is %v", rawValueType, vv.Interface()))
+							return nil, fmt.Errorf("rawValueType is %v, value is %v", rawValueType, vv.Interface())
 						}
 					}
 				} else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
@@ -613,7 +613,7 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i
 
 					hasAssigned = true
 					if len(table.PrimaryKeys) != 1 {
-						panic("unsupported non or composited primary key cascade")
+						return nil, errors.New("unsupported non or composited primary key cascade")
 					}
 					var pk = make(core.PK, len(table.PrimaryKeys))
 					pk[0], err = asKind(vv, rawValueType)

+ 5 - 4
session_convert.go

@@ -28,8 +28,7 @@ func (session *Session) str2Time(col *core.Column, data string) (outTime time.Ti
 		parseLoc = col.TimeZone
 	}
 
-	if sdata == "0000-00-00 00:00:00" ||
-		sdata == "0001-01-01 00:00:00" {
+	if sdata == zeroTime0 || sdata == zeroTime1 {
 	} else if !strings.ContainsAny(sdata, "- :") { // !nashtsai! has only found that mymysql driver is using this for time type column
 		// time stamp
 		sd, err := strconv.ParseInt(sdata, 10, 64)
@@ -213,8 +212,9 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
 
 				// TODO: current only support 1 primary key
 				if len(table.PrimaryKeys) > 1 {
-					panic("unsupported composited primary key cascade")
+					return errors.New("unsupported composited primary key cascade")
 				}
+
 				var pk = make(core.PK, len(table.PrimaryKeys))
 				rawValueType := table.ColumnType(table.PKColumns()[0].FieldName)
 				pk[0], err = str2PK(string(data), rawValueType)
@@ -496,8 +496,9 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
 					}
 
 					if len(table.PrimaryKeys) > 1 {
-						panic("unsupported composited primary key cascade")
+						return errors.New("unsupported composited primary key cascade")
 					}
+
 					var pk = make(core.PK, len(table.PrimaryKeys))
 					rawValueType := table.ColumnType(table.PKColumns()[0].FieldName)
 					pk[0], err = str2PK(string(data), rawValueType)

+ 3 - 6
session_find.go

@@ -66,7 +66,7 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
 			var err error
 			autoCond, err = session.Statement.buildConds(table, condiBean[0], true, true, false, true, addedTableName)
 			if err != nil {
-				panic(err)
+				return err
 			}
 		} else {
 			// !oinume! Add "<col> IS NULL" to WHERE whatever condiBean is given.
@@ -80,11 +80,8 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
 					}
 					colName = session.Engine.Quote(nm) + "." + colName
 				}
-				if session.Engine.dialect.DBType() == core.MSSQL {
-					autoCond = builder.IsNull{colName}
-				} else {
-					autoCond = builder.IsNull{colName}.Or(builder.Eq{colName: "0001-01-01 00:00:00"})
-				}
+
+				autoCond = session.Engine.CondDeleted(colName)
 			}
 		}
 	}

+ 1 - 219
statement.go

@@ -490,224 +490,6 @@ func (statement *Statement) colName(col *core.Column, tableName string) string {
 	return statement.Engine.Quote(col.Name)
 }
 
-func buildConds(engine *Engine, table *core.Table, bean interface{},
-	includeVersion bool, includeUpdated bool, includeNil bool,
-	includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool,
-	mustColumnMap map[string]bool, tableName, aliasName string, addedTableName bool) (builder.Cond, error) {
-	var conds []builder.Cond
-	for _, col := range table.Columns() {
-		if !includeVersion && col.IsVersion {
-			continue
-		}
-		if !includeUpdated && col.IsUpdated {
-			continue
-		}
-		if !includeAutoIncr && col.IsAutoIncrement {
-			continue
-		}
-
-		if engine.dialect.DBType() == core.MSSQL && (col.SQLType.Name == core.Text || col.SQLType.IsBlob() || col.SQLType.Name == core.TimeStampz) {
-			continue
-		}
-		if col.SQLType.IsJson() {
-			continue
-		}
-
-		var colName string
-		if addedTableName {
-			var nm = tableName
-			if len(aliasName) > 0 {
-				nm = aliasName
-			}
-			colName = engine.Quote(nm) + "." + engine.Quote(col.Name)
-		} else {
-			colName = engine.Quote(col.Name)
-		}
-
-		fieldValuePtr, err := col.ValueOf(bean)
-		if err != nil {
-			engine.logger.Error(err)
-			continue
-		}
-
-		if col.IsDeleted && !unscoped { // tag "deleted" is enabled
-			if engine.dialect.DBType() == core.MSSQL {
-				conds = append(conds, builder.IsNull{colName})
-			} else {
-				conds = append(conds, builder.IsNull{colName}.Or(builder.Eq{colName: "0001-01-01 00:00:00"}))
-			}
-		}
-
-		fieldValue := *fieldValuePtr
-		if fieldValue.Interface() == nil {
-			continue
-		}
-
-		fieldType := reflect.TypeOf(fieldValue.Interface())
-		requiredField := useAllCols
-
-		if b, ok := getFlagForColumn(mustColumnMap, col); ok {
-			if b {
-				requiredField = true
-			} else {
-				continue
-			}
-		}
-
-		if fieldType.Kind() == reflect.Ptr {
-			if fieldValue.IsNil() {
-				if includeNil {
-					conds = append(conds, builder.Eq{colName: nil})
-				}
-				continue
-			} else if !fieldValue.IsValid() {
-				continue
-			} else {
-				// dereference ptr type to instance type
-				fieldValue = fieldValue.Elem()
-				fieldType = reflect.TypeOf(fieldValue.Interface())
-				requiredField = true
-			}
-		}
-
-		var val interface{}
-		switch fieldType.Kind() {
-		case reflect.Bool:
-			if allUseBool || requiredField {
-				val = fieldValue.Interface()
-			} else {
-				// if a bool in a struct, it will not be as a condition because it default is false,
-				// please use Where() instead
-				continue
-			}
-		case reflect.String:
-			if !requiredField && fieldValue.String() == "" {
-				continue
-			}
-			// for MyString, should convert to string or panic
-			if fieldType.String() != reflect.String.String() {
-				val = fieldValue.String()
-			} else {
-				val = fieldValue.Interface()
-			}
-		case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64:
-			if !requiredField && fieldValue.Int() == 0 {
-				continue
-			}
-			val = fieldValue.Interface()
-		case reflect.Float32, reflect.Float64:
-			if !requiredField && fieldValue.Float() == 0.0 {
-				continue
-			}
-			val = fieldValue.Interface()
-		case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64:
-			if !requiredField && fieldValue.Uint() == 0 {
-				continue
-			}
-			t := int64(fieldValue.Uint())
-			val = reflect.ValueOf(&t).Interface()
-		case reflect.Struct:
-			if fieldType.ConvertibleTo(core.TimeType) {
-				t := fieldValue.Convert(core.TimeType).Interface().(time.Time)
-				if !requiredField && (t.IsZero() || !fieldValue.IsValid()) {
-					continue
-				}
-				val = engine.formatColTime(col, t)
-			} else if _, ok := reflect.New(fieldType).Interface().(core.Conversion); ok {
-				continue
-			} else if valNul, ok := fieldValue.Interface().(driver.Valuer); ok {
-				val, _ = valNul.Value()
-				if val == nil {
-					continue
-				}
-			} else {
-				if col.SQLType.IsJson() {
-					if col.SQLType.IsText() {
-						bytes, err := json.Marshal(fieldValue.Interface())
-						if err != nil {
-							engine.logger.Error(err)
-							continue
-						}
-						val = string(bytes)
-					} else if col.SQLType.IsBlob() {
-						var bytes []byte
-						var err error
-						bytes, err = json.Marshal(fieldValue.Interface())
-						if err != nil {
-							engine.logger.Error(err)
-							continue
-						}
-						val = bytes
-					}
-				} else {
-					engine.autoMapType(fieldValue)
-					if table, ok := engine.Tables[fieldValue.Type()]; ok {
-						if len(table.PrimaryKeys) == 1 {
-							pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)
-							// fix non-int pk issues
-							//if pkField.Int() != 0 {
-							if pkField.IsValid() && !isZero(pkField.Interface()) {
-								val = pkField.Interface()
-							} else {
-								continue
-							}
-						} else {
-							//TODO: how to handler?
-							panic(fmt.Sprintln("not supported", fieldValue.Interface(), "as", table.PrimaryKeys))
-						}
-					} else {
-						val = fieldValue.Interface()
-					}
-				}
-			}
-		case reflect.Array:
-			continue
-		case reflect.Slice, reflect.Map:
-			if fieldValue == reflect.Zero(fieldType) {
-				continue
-			}
-			if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 {
-				continue
-			}
-
-			if col.SQLType.IsText() {
-				bytes, err := json.Marshal(fieldValue.Interface())
-				if err != nil {
-					engine.logger.Error(err)
-					continue
-				}
-				val = string(bytes)
-			} else if col.SQLType.IsBlob() {
-				var bytes []byte
-				var err error
-				if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) &&
-					fieldType.Elem().Kind() == reflect.Uint8 {
-					if fieldValue.Len() > 0 {
-						val = fieldValue.Bytes()
-					} else {
-						continue
-					}
-				} else {
-					bytes, err = json.Marshal(fieldValue.Interface())
-					if err != nil {
-						engine.logger.Error(err)
-						continue
-					}
-					val = bytes
-				}
-			} else {
-				continue
-			}
-		default:
-			val = fieldValue.Interface()
-		}
-
-		conds = append(conds, builder.Eq{colName: val})
-	}
-
-	return builder.And(conds...), nil
-}
-
 // TableName return current tableName
 func (statement *Statement) TableName() string {
 	if statement.AltTableName != "" {
@@ -1104,7 +886,7 @@ func (statement *Statement) genAddColumnStr(col *core.Column) (string, []interfa
 }
 
 func (statement *Statement) buildConds(table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) {
-	return buildConds(statement.Engine, table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols,
+	return statement.Engine.buildConds(table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols,
 		statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName)
 }