Browse Source

Add support for map[string]interface{} as condition on Update and Where

xormplus 8 years ago
parent
commit
89d83cae6d
4 changed files with 74 additions and 15 deletions
  1. 20 11
      error.go
  2. 17 4
      session_update.go
  3. 31 0
      session_update_test.go
  4. 6 0
      statement.go

+ 20 - 11
error.go

@@ -9,15 +9,24 @@ import (
 )
 
 var (
-	ErrParamsType            error = errors.New("Params type error")
-	ErrParamsFormat          error = errors.New("Params format error")
-	ErrTableNotFound         error = errors.New("Not found table")
-	ErrUnSupportedType       error = errors.New("Unsupported type error")
-	ErrNotExist              error = errors.New("Not exist error")
-	ErrCacheFailed           error = errors.New("Cache failed")
-	ErrNeedDeletedCond       error = errors.New("Delete need at least one condition")
-	ErrNotImplemented        error = errors.New("Not implemented.")
-	ErrNotInTransaction      error = errors.New("Not in transaction.")
-	ErrNestedTransaction     error = errors.New("Nested transaction error.")
-	ErrTransactionDefinition error = errors.New("Transaction definition error.")
+	// ErrParamsType params error
+	ErrParamsType   = errors.New("Params type error")
+	ErrParamsFormat = errors.New("Params format error")
+	// ErrTableNotFound table not found error
+	ErrTableNotFound = errors.New("Not found table")
+	// ErrUnSupportedType unsupported error
+	ErrUnSupportedType = errors.New("Unsupported type error")
+	// ErrNotExist record is not exist error
+	ErrNotExist              = errors.New("Not exist error")
+	ErrNotInTransaction      = errors.New("Not in transaction.")
+	ErrNestedTransaction     = errors.New("Nested transaction error.")
+	ErrTransactionDefinition = errors.New("Transaction definition error.")
+	// ErrCacheFailed cache failed error
+	ErrCacheFailed = errors.New("Cache failed")
+	// ErrNeedDeletedCond delete needs less one condition error
+	ErrNeedDeletedCond = errors.New("Delete need at least one condition")
+	// ErrNotImplemented not implemented
+	ErrNotImplemented = errors.New("Not implemented")
+	// ErrConditionType condition type unsupported
+	ErrConditionType = errors.New("Unsupported conditon type")
 )

+ 17 - 4
session_update.go

@@ -242,10 +242,23 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
 
 	var autoCond builder.Cond
 	if !session.statement.noAutoCondition && len(condiBean) > 0 {
-		var err error
-		autoCond, err = session.statement.buildConds(session.statement.RefTable, condiBean[0], true, true, false, true, false)
-		if err != nil {
-			return 0, err
+		if c, ok := condiBean[0].(map[string]interface{}); ok {
+			autoCond = builder.Eq(c)
+		} else {
+			ct := reflect.TypeOf(condiBean[0])
+			k := ct.Kind()
+			if k == reflect.Ptr {
+				k = ct.Elem().Kind()
+			}
+			if k == reflect.Struct {
+				var err error
+				autoCond, err = session.statement.buildConds(session.statement.RefTable, condiBean[0], true, true, false, true, false)
+				if err != nil {
+					return 0, err
+				}
+			} else {
+				return 0, ErrConditionType
+			}
 		}
 	}
 

+ 31 - 0
session_update_test.go

@@ -1215,3 +1215,34 @@ func TestCreatedUpdated2(t *testing.T) {
 	assert.True(t, s2.UpdateAt.Unix() > s.UpdateAt.Unix())
 	assert.True(t, s2.UpdateAt.Unix() > s2.CreateAt.Unix())
 }
+
+func TestUpdateMapCondition(t *testing.T) {
+	assert.NoError(t, prepareEngine())
+
+	type UpdateMapCondition struct {
+		Id     int64
+		String string
+	}
+
+	assertSync(t, new(UpdateMapCondition))
+
+	var c = UpdateMapCondition{
+		String: "string",
+	}
+	_, err := testEngine.Insert(&c)
+	assert.NoError(t, err)
+
+	cnt, err := testEngine.Update(&UpdateMapCondition{
+		String: "string1",
+	}, map[string]interface{}{
+		"id": c.Id,
+	})
+	assert.NoError(t, err)
+	assert.EqualValues(t, 1, cnt)
+
+	var c2 UpdateMapCondition
+	has, err := testEngine.ID(c.Id).Get(&c2)
+	assert.NoError(t, err)
+	assert.True(t, has)
+	assert.EqualValues(t, "string1", c2.String)
+}

+ 6 - 0
statement.go

@@ -160,6 +160,9 @@ func (statement *Statement) And(query interface{}, args ...interface{}) *Stateme
 	case string:
 		cond := builder.Expr(query.(string), args...)
 		statement.cond = statement.cond.And(cond)
+	case map[string]interface{}:
+		cond := builder.Eq(query.(map[string]interface{}))
+		statement.cond = statement.cond.And(cond)
 	case builder.Cond:
 		cond := query.(builder.Cond)
 		statement.cond = statement.cond.And(cond)
@@ -181,6 +184,9 @@ func (statement *Statement) Or(query interface{}, args ...interface{}) *Statemen
 	case string:
 		cond := builder.Expr(query.(string), args...)
 		statement.cond = statement.cond.Or(cond)
+	case map[string]interface{}:
+		cond := builder.Eq(query.(map[string]interface{}))
+		statement.cond = statement.cond.Or(cond)
 	case builder.Cond:
 		cond := query.(builder.Cond)
 		statement.cond = statement.cond.Or(cond)