utils.go 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
// Copied from core/stores/sqlx/utils.go.
  2. package mocksql
import (
	"database/sql"
	"errors"
	"fmt"
	"strings"

	"git.i2edu.net/i2/go-zero/core/logx"
	"git.i2edu.net/i2/go-zero/core/mapping"
)
  10. // ErrNotFound is the alias of sql.ErrNoRows
  11. var ErrNotFound = sql.ErrNoRows
  12. func desensitize(datasource string) string {
  13. // remove account
  14. pos := strings.LastIndex(datasource, "@")
  15. if 0 <= pos && pos+1 < len(datasource) {
  16. datasource = datasource[pos+1:]
  17. }
  18. return datasource
  19. }
  20. func escape(input string) string {
  21. var b strings.Builder
  22. for _, ch := range input {
  23. switch ch {
  24. case '\x00':
  25. b.WriteString(`\x00`)
  26. case '\r':
  27. b.WriteString(`\r`)
  28. case '\n':
  29. b.WriteString(`\n`)
  30. case '\\':
  31. b.WriteString(`\\`)
  32. case '\'':
  33. b.WriteString(`\'`)
  34. case '"':
  35. b.WriteString(`\"`)
  36. case '\x1a':
  37. b.WriteString(`\x1a`)
  38. default:
  39. b.WriteRune(ch)
  40. }
  41. }
  42. return b.String()
  43. }
  44. func format(query string, args ...interface{}) (string, error) {
  45. numArgs := len(args)
  46. if numArgs == 0 {
  47. return query, nil
  48. }
  49. var b strings.Builder
  50. argIndex := 0
  51. for _, ch := range query {
  52. if ch == '?' {
  53. if argIndex >= numArgs {
  54. return "", fmt.Errorf("error: %d ? in sql, but less arguments provided", argIndex)
  55. }
  56. arg := args[argIndex]
  57. argIndex++
  58. switch v := arg.(type) {
  59. case bool:
  60. if v {
  61. b.WriteByte('1')
  62. } else {
  63. b.WriteByte('0')
  64. }
  65. case string:
  66. b.WriteByte('\'')
  67. b.WriteString(escape(v))
  68. b.WriteByte('\'')
  69. default:
  70. b.WriteString(mapping.Repr(v))
  71. }
  72. } else {
  73. b.WriteRune(ch)
  74. }
  75. }
  76. if argIndex < numArgs {
  77. return "", fmt.Errorf("error: %d ? in sql, but more arguments provided", argIndex)
  78. }
  79. return b.String(), nil
  80. }
  81. func logSqlError(stmt string, err error) {
  82. if err != nil && err != ErrNotFound {
  83. logx.Errorf("stmt: %s, error: %s", stmt, err.Error())
  84. }
  85. }