Browse Source

add QueryExpr function

xormplus 7 năm trước cách đây
mục cha
commit
dd83e8e121
17 tập tin đã thay đổi với 293 bổ sung18 xóa
  1. 8 0
      README.md
  2. 1 1
      engine.go
  3. 1 1
      engine_cond.go
  4. 2 0
      error.go
  5. 1 1
      session_cond.go
  6. 1 1
      session_cond_test.go
  7. 1 1
      session_exist.go
  8. 1 1
      session_find.go
  9. 22 1
      session_query.go
  10. 2 2
      session_query_test.go
  11. 1 1
      session_raw.go
  12. 1 1
      session_stats_test.go
  13. 1 1
      session_sum_test.go
  14. 1 1
      session_update.go
  15. 96 0
      sql_expr.go
  16. 34 5
      statement.go
  17. 119 0
      string_builder.go

+ 8 - 0
README.md

@@ -1110,6 +1110,14 @@ err := engine.Table("user").Select("user.*, detail.*")
 // SELECT user.*, detail.* FROM user INNER JOIN detail WHERE user.name = ? limit 0 offset 10
 ```
 
+* 子查询
+
+```Go
+var student []Student
+err = db.Table("student").Select("id ,name").Where("id in (?)", db.Table("studentinfo").Select("id").Where("status = ?", 2).QueryExpr()).Find(&student)
+//SELECT id ,name FROM `student` WHERE (id in (SELECT id FROM `studentinfo` WHERE (status = 2)))
+```
+
 * 根据条件遍历数据库,可以有两种方式: Iterate and Rows
 
 ```Go

+ 1 - 1
engine.go

@@ -20,7 +20,7 @@ import (
 	"time"
 
 	"github.com/fsnotify/fsnotify"
-	"github.com/go-xorm/builder"
+	"github.com/xormplus/builder"
 	"github.com/xormplus/core"
 )
 

+ 1 - 1
engine_cond.go

@@ -12,7 +12,7 @@ import (
 	"strings"
 	"time"
 
-	"github.com/go-xorm/builder"
+	"github.com/xormplus/builder"
 	"github.com/xormplus/core"
 )
 

+ 2 - 0
error.go

@@ -30,6 +30,8 @@ var (
 	ErrNotImplemented = errors.New("Not implemented")
 	// ErrConditionType condition type unsupported
 	ErrConditionType = errors.New("Unsupported condition type")
+	// ErrNeedMoreArguments need more arguments
+	ErrNeedMoreArguments = errors.New("Need more sql arguments")
 )
 
 // ErrFieldIsNotExist columns does not exist

+ 1 - 1
session_cond.go

@@ -4,7 +4,7 @@
 
 package xorm
 
-import "github.com/go-xorm/builder"
+import "github.com/xormplus/builder"
 
 // Sql provides raw sql input parameter. When you have a complex SQL statement
 // and cannot use Where, Id, In and etc. Methods to describe, you can use SQL.

+ 1 - 1
session_cond_test.go

@@ -9,8 +9,8 @@ import (
 	"fmt"
 	"testing"
 
-	"github.com/go-xorm/builder"
 	"github.com/stretchr/testify/assert"
+	"github.com/xormplus/builder"
 )
 
 func TestBuilder(t *testing.T) {

+ 1 - 1
session_exist.go

@@ -9,7 +9,7 @@ import (
 	"fmt"
 	"reflect"
 
-	"github.com/go-xorm/builder"
+	"github.com/xormplus/builder"
 	"github.com/xormplus/core"
 )
 

+ 1 - 1
session_find.go

@@ -10,7 +10,7 @@ import (
 	"reflect"
 	"strings"
 
-	"github.com/go-xorm/builder"
+	"github.com/xormplus/builder"
 	"github.com/xormplus/core"
 )
 

+ 22 - 1
session_query.go

@@ -11,7 +11,7 @@ import (
 	"strings"
 	"time"
 
-	"github.com/go-xorm/builder"
+	"github.com/xormplus/builder"
 	"github.com/xormplus/core"
 )
 
@@ -384,3 +384,24 @@ func (session *Session) QueryInterface(sqlorArgs ...interface{}) ([]map[string]i
 
 	return rows2Interfaces(rows)
 }
+
+// QueryExpr returns the query as bound SQL
+func (session *Session) QueryExpr(sqlorArgs ...interface{}) sqlExpr {
+	if session.isAutoClose {
+		defer session.Close()
+	}
+
+	sqlStr, args, err := session.genQuerySQL()
+	if err != nil {
+		session.engine.logger.Error(err)
+		return sqlExpr{sqlExpr: ""}
+	}
+
+	sqlStr, err = ConvertToBoundSQL(sqlStr, args)
+	if err != nil {
+		session.engine.logger.Error(err)
+		return sqlExpr{sqlExpr: ""}
+	}
+
+	return sqlExpr{sqlExpr: sqlStr}
+}

+ 2 - 2
session_query_test.go

@@ -10,8 +10,8 @@ import (
 	"testing"
 	"time"
 
-	"github.com/go-xorm/builder"
-	"github.com/go-xorm/core"
+	"github.com/xormplus/builder"
+	"github.com/xormplus/core"
 
 	"github.com/stretchr/testify/assert"
 )

+ 1 - 1
session_raw.go

@@ -9,7 +9,7 @@ import (
 	"reflect"
 	"time"
 
-	"github.com/go-xorm/builder"
+	"github.com/xormplus/builder"
 	"github.com/xormplus/core"
 )
 

+ 1 - 1
session_stats_test.go

@@ -9,8 +9,8 @@ import (
 	"strconv"
 	"testing"
 
-	"github.com/go-xorm/builder"
 	"github.com/stretchr/testify/assert"
+	"github.com/xormplus/builder"
 )
 
 func isFloatEq(i, j float64, precision int) bool {

+ 1 - 1
session_sum_test.go

@@ -9,8 +9,8 @@ import (
 	"strconv"
 	"testing"
 
-	"github.com/go-xorm/builder"
 	"github.com/stretchr/testify/assert"
+	"github.com/xormplus/builder"
 )
 
 func isFloatEq(i, j float64, precision int) bool {

+ 1 - 1
session_update.go

@@ -11,7 +11,7 @@ import (
 	"strconv"
 	"strings"
 
-	"github.com/go-xorm/builder"
+	"github.com/xormplus/builder"
 	"github.com/xormplus/core"
 )
 

+ 96 - 0
sql_expr.go

@@ -0,0 +1,96 @@
+package xorm
+
+import (
+	sql2 "database/sql"
+	"fmt"
+	"reflect"
+	"time"
+)
+
+type sqlExpr struct {
+	sqlExpr string
+}
+
+func noSQLQuoteNeeded(a interface{}) bool {
+	switch a.(type) {
+	case int, int8, int16, int32, int64:
+		return true
+	case uint, uint8, uint16, uint32, uint64:
+		return true
+	case float32, float64:
+		return true
+	case bool:
+		return true
+	case string:
+		return false
+	case time.Time, *time.Time:
+		return false
+	case sqlExpr, *sqlExpr:
+		return true
+	}
+
+	t := reflect.TypeOf(a)
+	switch t.Kind() {
+	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+		return true
+	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+		return true
+	case reflect.Float32, reflect.Float64:
+		return true
+	case reflect.Bool:
+		return true
+	case reflect.String:
+		return false
+	}
+
+	return false
+}
+
+// ConvertToBoundSQL will convert SQL and args to a bound SQL
+func ConvertToBoundSQL(sql string, args []interface{}) (string, error) {
+	buf := StringBuilder{}
+	var i, j, start int
+	for ; i < len(sql); i++ {
+		if sql[i] == '?' {
+			_, err := buf.WriteString(sql[start:i])
+			if err != nil {
+				return "", err
+			}
+			start = i + 1
+
+			if len(args) == j {
+				return "", ErrNeedMoreArguments
+			}
+
+			arg := args[j]
+
+			if exprArg, ok := arg.(sqlExpr); ok {
+				_, err = fmt.Fprint(&buf, exprArg.sqlExpr)
+				if err != nil {
+					return "", err
+				}
+
+			} else {
+				if namedArg, ok := arg.(sql2.NamedArg); ok {
+					arg = namedArg.Value
+				}
+
+				if noSQLQuoteNeeded(arg) {
+					_, err = fmt.Fprint(&buf, arg)
+				} else {
+					_, err = fmt.Fprintf(&buf, "'%v'", arg)
+				}
+				if err != nil {
+					return "", err
+				}
+			}
+
+			j = j + 1
+		}
+	}
+	_, err := buf.WriteString(sql[start:])
+	if err != nil {
+		return "", err
+	}
+	return buf.String(), nil
+}

+ 34 - 5
statement.go

@@ -13,7 +13,7 @@ import (
 	"strings"
 	"time"
 
-	"github.com/go-xorm/builder"
+	"github.com/xormplus/builder"
 	"github.com/xormplus/core"
 )
 
@@ -146,8 +146,22 @@ func (statement *Statement) Where(query interface{}, args ...interface{}) *State
 func (statement *Statement) And(query interface{}, args ...interface{}) *Statement {
 	switch query.(type) {
 	case string:
-		cond := builder.Expr(query.(string), args...)
-		statement.cond = statement.cond.And(cond)
+		isExpr := false
+		var cargs []interface{}
+		for i, _ := range args {
+			if _, ok := args[i].(sqlExpr); ok {
+				isExpr = true
+			}
+			cargs = append(cargs, args[i])
+		}
+		if isExpr {
+			sqlStr, _ := ConvertToBoundSQL(query.(string), cargs)
+			cond := builder.Expr(sqlStr)
+			statement.cond = statement.cond.And(cond)
+		} else {
+			cond := builder.Expr(query.(string), args...)
+			statement.cond = statement.cond.And(cond)
+		}
 	case map[string]interface{}:
 		cond := builder.Eq(query.(map[string]interface{}))
 		statement.cond = statement.cond.And(cond)
@@ -170,8 +184,23 @@ func (statement *Statement) And(query interface{}, args ...interface{}) *Stateme
 func (statement *Statement) Or(query interface{}, args ...interface{}) *Statement {
 	switch query.(type) {
 	case string:
-		cond := builder.Expr(query.(string), args...)
-		statement.cond = statement.cond.Or(cond)
+		isExpr := false
+		var cargs []interface{}
+		for i, _ := range args {
+			if _, ok := args[i].(sqlExpr); ok {
+				isExpr = true
+			}
+			cargs = append(cargs, args[i])
+		}
+		if isExpr {
+			sqlStr, _ := ConvertToBoundSQL(query.(string), cargs)
+			cond := builder.Expr(sqlStr)
+			statement.cond = statement.cond.Or(cond)
+		} else {
+			cond := builder.Expr(query.(string), args...)
+			statement.cond = statement.cond.Or(cond)
+		}
+
 	case map[string]interface{}:
 		cond := builder.Eq(query.(map[string]interface{}))
 		statement.cond = statement.cond.Or(cond)

+ 119 - 0
string_builder.go

@@ -0,0 +1,119 @@
+// Copyright 2017 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package xorm
+
+import (
+	"unicode/utf8"
+	"unsafe"
+)
+
+// A StringBuilder is used to efficiently build a string using Write methods.
+// It minimizes memory copying. The zero value is ready to use.
+// Do not copy a non-zero Builder.
+type StringBuilder struct {
+	addr *StringBuilder // of receiver, to detect copies by value
+	buf  []byte
+}
+
+// noescape hides a pointer from escape analysis.  noescape is
+// the identity function but escape analysis doesn't think the
+// output depends on the input. noescape is inlined and currently
+// compiles down to zero instructions.
+// USE CAREFULLY!
+// This was copied from the runtime; see issues 23382 and 7921.
+//go:nosplit
+func noescape(p unsafe.Pointer) unsafe.Pointer {
+	x := uintptr(p)
+	return unsafe.Pointer(x ^ 0)
+}
+
+func (b *StringBuilder) copyCheck() {
+	if b.addr == nil {
+		// This hack works around a failing of Go's escape analysis
+		// that was causing b to escape and be heap allocated.
+		// See issue 23382.
+		// TODO: once issue 7921 is fixed, this should be reverted to
+		// just "b.addr = b".
+		b.addr = (*StringBuilder)(noescape(unsafe.Pointer(b)))
+	} else if b.addr != b {
+		panic("strings: illegal use of non-zero Builder copied by value")
+	}
+}
+
+// String returns the accumulated string.
+func (b *StringBuilder) String() string {
+	return *(*string)(unsafe.Pointer(&b.buf))
+}
+
+// Len returns the number of accumulated bytes; b.Len() == len(b.String()).
+func (b *StringBuilder) Len() int { return len(b.buf) }
+
+// Reset resets the Builder to be empty.
+func (b *StringBuilder) Reset() {
+	b.addr = nil
+	b.buf = nil
+}
+
+// grow copies the buffer to a new, larger buffer so that there are at least n
+// bytes of capacity beyond len(b.buf).
+func (b *StringBuilder) grow(n int) {
+	buf := make([]byte, len(b.buf), 2*cap(b.buf)+n)
+	copy(buf, b.buf)
+	b.buf = buf
+}
+
+// Grow grows b's capacity, if necessary, to guarantee space for
+// another n bytes. After Grow(n), at least n bytes can be written to b
+// without another allocation. If n is negative, Grow panics.
+func (b *StringBuilder) Grow(n int) {
+	b.copyCheck()
+	if n < 0 {
+		panic("strings.Builder.Grow: negative count")
+	}
+	if cap(b.buf)-len(b.buf) < n {
+		b.grow(n)
+	}
+}
+
+// Write appends the contents of p to b's buffer.
+// Write always returns len(p), nil.
+func (b *StringBuilder) Write(p []byte) (int, error) {
+	b.copyCheck()
+	b.buf = append(b.buf, p...)
+	return len(p), nil
+}
+
+// WriteByte appends the byte c to b's buffer.
+// The returned error is always nil.
+func (b *StringBuilder) WriteByte(c byte) error {
+	b.copyCheck()
+	b.buf = append(b.buf, c)
+	return nil
+}
+
+// WriteRune appends the UTF-8 encoding of Unicode code point r to b's buffer.
+// It returns the length of r and a nil error.
+func (b *StringBuilder) WriteRune(r rune) (int, error) {
+	b.copyCheck()
+	if r < utf8.RuneSelf {
+		b.buf = append(b.buf, byte(r))
+		return 1, nil
+	}
+	l := len(b.buf)
+	if cap(b.buf)-l < utf8.UTFMax {
+		b.grow(utf8.UTFMax)
+	}
+	n := utf8.EncodeRune(b.buf[l:l+utf8.UTFMax], r)
+	b.buf = b.buf[:l+n]
+	return n, nil
+}
+
+// WriteString appends the contents of s to b's buffer.
+// It returns the length of s and a nil error.
+func (b *StringBuilder) WriteString(s string) (int, error) {
+	b.copyCheck()
+	b.buf = append(b.buf, s...)
+	return len(s), nil
+}