sql_expr.go 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. package xorm
  import (
  	sql2 "database/sql"
  	"fmt"
  	"reflect"
  	"strings"
  	"time"
  )
  // sqlExpr wraps a fragment of SQL text. ConvertToBoundSQL writes the wrapped
  // text into the statement as-is, without surrounding quotes, and
  // noSQLQuoteNeeded reports true for it.
  type sqlExpr struct {
  	// sqlExpr holds the SQL fragment itself.
  	sqlExpr string
  }
  11. func noSQLQuoteNeeded(a interface{}) bool {
  12. switch a.(type) {
  13. case int, int8, int16, int32, int64:
  14. return true
  15. case uint, uint8, uint16, uint32, uint64:
  16. return true
  17. case float32, float64:
  18. return true
  19. case bool:
  20. return true
  21. case string:
  22. return false
  23. case time.Time, *time.Time:
  24. return false
  25. case sqlExpr, *sqlExpr:
  26. return true
  27. }
  28. t := reflect.TypeOf(a)
  29. switch t.Kind() {
  30. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  31. return true
  32. case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
  33. return true
  34. case reflect.Float32, reflect.Float64:
  35. return true
  36. case reflect.Bool:
  37. return true
  38. case reflect.String:
  39. return false
  40. }
  41. return false
  42. }
  43. // ConvertToBoundSQL will convert SQL and args to a bound SQL
  44. func ConvertToBoundSQL(sql string, args []interface{}) (string, error) {
  45. buf := StringBuilder{}
  46. var i, j, start int
  47. for ; i < len(sql); i++ {
  48. if sql[i] == '?' {
  49. _, err := buf.WriteString(sql[start:i])
  50. if err != nil {
  51. return "", err
  52. }
  53. start = i + 1
  54. if len(args) == j {
  55. return "", ErrNeedMoreArguments
  56. }
  57. arg := args[j]
  58. if exprArg, ok := arg.(sqlExpr); ok {
  59. _, err = fmt.Fprint(&buf, exprArg.sqlExpr)
  60. if err != nil {
  61. return "", err
  62. }
  63. } else {
  64. if namedArg, ok := arg.(sql2.NamedArg); ok {
  65. arg = namedArg.Value
  66. }
  67. if noSQLQuoteNeeded(arg) {
  68. _, err = fmt.Fprint(&buf, arg)
  69. } else {
  70. _, err = fmt.Fprintf(&buf, "'%v'", arg)
  71. }
  72. if err != nil {
  73. return "", err
  74. }
  75. }
  76. j = j + 1
  77. }
  78. }
  79. _, err := buf.WriteString(sql[start:])
  80. if err != nil {
  81. return "", err
  82. }
  83. return buf.String(), nil
  84. }