瀏覽代碼

Join add TableName interface support

xormplus 7 年之前
父節點
當前提交
a6b3ff88db
共有 3 個文件被更改,包括 112 次插入20 次删除
  1. 64 13
      session_find_test.go
  2. 24 0
      session_get_test.go
  3. 24 7
      statement.go

+ 64 - 13
session_find_test.go

@@ -120,10 +120,8 @@ func TestFind2(t *testing.T) {
 	assertSync(t, new(Userinfo))
 
 	err := testEngine.Find(&users)
-	if err != nil {
-		t.Error(err)
-		panic(err)
-	}
+	assert.NoError(t, err)
+
 	for _, user := range users {
 		fmt.Println(user)
 	}
@@ -139,13 +137,15 @@ type TeamUser struct {
 	TeamId int64
 }
 
+func (TeamUser) TableName() string {
+	return "team_user"
+}
+
 func TestFind3(t *testing.T) {
+	var teamUser = new(TeamUser)
 	assert.NoError(t, prepareEngine())
-	err := testEngine.Sync2(new(Team), new(TeamUser))
-	if err != nil {
-		t.Error(err)
-		panic(err.Error())
-	}
+	err := testEngine.Sync2(new(Team), teamUser)
+	assert.NoError(t, err)
 
 	var teams []Team
 	err = testEngine.Cols("`team`.id").
@@ -153,10 +153,47 @@ func TestFind3(t *testing.T) {
 		And("`team_user`.uid=?", 2).
 		Join("INNER", "`team_user`", "`team_user`.team_id=`team`.id").
 		Find(&teams)
-	if err != nil {
-		t.Error(err)
-		panic(err.Error())
-	}
+	assert.NoError(t, err)
+
+	teams = make([]Team, 0)
+	err = testEngine.Cols("`team`.id").
+		Where("`team_user`.org_id=?", 1).
+		And("`team_user`.uid=?", 2).
+		Join("INNER", teamUser, "`team_user`.team_id=`team`.id").
+		Find(&teams)
+	assert.NoError(t, err)
+
+	teams = make([]Team, 0)
+	err = testEngine.Cols("`team`.id").
+		Where("`team_user`.org_id=?", 1).
+		And("`team_user`.uid=?", 2).
+		Join("INNER", []interface{}{teamUser}, "`team_user`.team_id=`team`.id").
+		Find(&teams)
+	assert.NoError(t, err)
+
+	teams = make([]Team, 0)
+	err = testEngine.Cols("`team`.id").
+		Where("`tu`.org_id=?", 1).
+		And("`tu`.uid=?", 2).
+		Join("INNER", []string{"team_user", "tu"}, "`tu`.team_id=`team`.id").
+		Find(&teams)
+	assert.NoError(t, err)
+
+	teams = make([]Team, 0)
+	err = testEngine.Cols("`team`.id").
+		Where("`tu`.org_id=?", 1).
+		And("`tu`.uid=?", 2).
+		Join("INNER", []interface{}{"team_user", "tu"}, "`tu`.team_id=`team`.id").
+		Find(&teams)
+	assert.NoError(t, err)
+
+	teams = make([]Team, 0)
+	err = testEngine.Cols("`team`.id").
+		Where("`tu`.org_id=?", 1).
+		And("`tu`.uid=?", 2).
+		Join("INNER", []interface{}{teamUser, "tu"}, "`tu`.team_id=`team`.id").
+		Find(&teams)
+	assert.NoError(t, err)
 }
 
 func TestFindMap(t *testing.T) {
@@ -561,4 +598,18 @@ func TestFindAndCountOneFunc(t *testing.T) {
 	assert.NoError(t, err)
 	assert.EqualValues(t, 1, len(results))
 	assert.EqualValues(t, 1, cnt)
+
+	results = make([]FindAndCountStruct, 0, 1)
+	cnt, err = testEngine.Where("msg = ?", true).Select("id, content, msg").
+		Limit(1).FindAndCount(&results)
+	assert.NoError(t, err)
+	assert.EqualValues(t, 1, len(results))
+	assert.EqualValues(t, 1, cnt)
+
+	results = make([]FindAndCountStruct, 0, 1)
+	cnt, err = testEngine.Where("msg = ?", true).Desc("id").
+		Limit(1).FindAndCount(&results)
+	assert.NoError(t, err)
+	assert.EqualValues(t, 1, len(results))
+	assert.EqualValues(t, 1, cnt)
 }

+ 24 - 0
session_get_test.go

@@ -255,3 +255,27 @@ func TestJSONString(t *testing.T) {
 	assert.EqualValues(t, 1, len(jss))
 	assert.EqualValues(t, `["1","2"]`, jss[0].Content)
 }
+
+func TestGetActionMapping(t *testing.T) {
+	assert.NoError(t, prepareEngine())
+
+	type ActionMapping struct {
+		ActionId    string `xorm:"pk"`
+		ActionName  string `xorm:"index"`
+		ScriptId    string `xorm:"unique"`
+		RollbackId  string `xorm:"unique"`
+		Env         string
+		Tags        string
+		Description string
+		UpdateTime  time.Time `xorm:"updated"`
+		DeleteTime  time.Time `xorm:"deleted"`
+	}
+
+	assertSync(t, new(ActionMapping))
+
+	var valuesSlice = make([]string, 2)
+	_, err := testEngine.Table(new(ActionMapping)).
+		Cols("script_id", "rollback_id").
+		ID(1).Get(&valuesSlice)
+	assert.NoError(t, err)
+}

+ 24 - 7
statement.go

@@ -766,12 +766,19 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
 		var table string
 		if l > 0 {
 			f := t[0]
-			v := rValue(f)
-			t := v.Type()
-			if t.Kind() == reflect.String {
+			switch f.(type) {
+			case string:
 				table = f.(string)
-			} else if t.Kind() == reflect.Struct {
-				table = statement.Engine.tbName(v)
+			case TableName:
+				table = f.(TableName).TableName()
+			default:
+				v := rValue(f)
+				t := v.Type()
+				if t.Kind() == reflect.Struct {
+					fmt.Fprintf(&buf, statement.Engine.tbName(v))
+				} else {
+					fmt.Fprintf(&buf, statement.Engine.Quote(fmt.Sprintf("%v", f)))
+				}
 			}
 		}
 		if l > 1 {
@@ -780,8 +787,18 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
 		} else if l == 1 {
 			fmt.Fprintf(&buf, statement.Engine.Quote(table))
 		}
+	case TableName:
+		fmt.Fprintf(&buf, tablename.(TableName).TableName())
+	case string:
+		fmt.Fprintf(&buf, tablename.(string))
 	default:
-		fmt.Fprintf(&buf, statement.Engine.Quote(fmt.Sprintf("%v", tablename)))
+		v := rValue(tablename)
+		t := v.Type()
+		if t.Kind() == reflect.Struct {
+			fmt.Fprintf(&buf, statement.Engine.tbName(v))
+		} else {
+			fmt.Fprintf(&buf, statement.Engine.Quote(fmt.Sprintf("%v", tablename)))
+		}
 	}
 
 	fmt.Fprintf(&buf, " ON %v", condition)
@@ -907,7 +924,7 @@ func (statement *Statement) genDelIndexSQL() []string {
 
 func (statement *Statement) genAddColumnStr(col *core.Column) (string, []interface{}) {
 	quote := statement.Engine.Quote
-	sql := fmt.Sprintf("ALTER TABLE %v ADD %v;", quote(statement.TableName()),
+	sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quote(statement.TableName()),
 		col.String(statement.Engine.dialect))
 	if statement.Engine.dialect.DBType() == core.MYSQL && len(col.Comment) > 0 {
 		sql += " COMMENT '" + col.Comment + "'"