utils.go 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. // copy from core/stores/sqlx/utils.go
  2. package mocksql
import (
	"database/sql"
	"errors"
	"fmt"
	"strings"

	"github.com/tal-tech/go-zero/core/logx"
	"github.com/tal-tech/go-zero/core/mapping"
)
  10. var ErrNotFound = sql.ErrNoRows
  11. func desensitize(datasource string) string {
  12. // remove account
  13. pos := strings.LastIndex(datasource, "@")
  14. if 0 <= pos && pos+1 < len(datasource) {
  15. datasource = datasource[pos+1:]
  16. }
  17. return datasource
  18. }
  19. func escape(input string) string {
  20. var b strings.Builder
  21. for _, ch := range input {
  22. switch ch {
  23. case '\x00':
  24. b.WriteString(`\x00`)
  25. case '\r':
  26. b.WriteString(`\r`)
  27. case '\n':
  28. b.WriteString(`\n`)
  29. case '\\':
  30. b.WriteString(`\\`)
  31. case '\'':
  32. b.WriteString(`\'`)
  33. case '"':
  34. b.WriteString(`\"`)
  35. case '\x1a':
  36. b.WriteString(`\x1a`)
  37. default:
  38. b.WriteRune(ch)
  39. }
  40. }
  41. return b.String()
  42. }
  43. func format(query string, args ...interface{}) (string, error) {
  44. numArgs := len(args)
  45. if numArgs == 0 {
  46. return query, nil
  47. }
  48. var b strings.Builder
  49. argIndex := 0
  50. for _, ch := range query {
  51. if ch == '?' {
  52. if argIndex >= numArgs {
  53. return "", fmt.Errorf("error: %d ? in sql, but less arguments provided", argIndex)
  54. }
  55. arg := args[argIndex]
  56. argIndex++
  57. switch v := arg.(type) {
  58. case bool:
  59. if v {
  60. b.WriteByte('1')
  61. } else {
  62. b.WriteByte('0')
  63. }
  64. case string:
  65. b.WriteByte('\'')
  66. b.WriteString(escape(v))
  67. b.WriteByte('\'')
  68. default:
  69. b.WriteString(mapping.Repr(v))
  70. }
  71. } else {
  72. b.WriteRune(ch)
  73. }
  74. }
  75. if argIndex < numArgs {
  76. return "", fmt.Errorf("error: %d ? in sql, but more arguments provided", argIndex)
  77. }
  78. return b.String(), nil
  79. }
  80. func logSqlError(stmt string, err error) {
  81. if err != nil && err != ErrNotFound {
  82. logx.Errorf("stmt: %s, error: %s", stmt, err.Error())
  83. }
  84. }