utils.go 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. package sqlx
  2. import (
  3. "fmt"
  4. "strconv"
  5. "strings"
  6. "git.i2edu.net/i2/go-zero/core/logx"
  7. "git.i2edu.net/i2/go-zero/core/mapping"
  8. )
  9. func desensitize(datasource string) string {
  10. // remove account
  11. pos := strings.LastIndex(datasource, "@")
  12. if 0 <= pos && pos+1 < len(datasource) {
  13. datasource = datasource[pos+1:]
  14. }
  15. return datasource
  16. }
  17. func escape(input string) string {
  18. var b strings.Builder
  19. for _, ch := range input {
  20. switch ch {
  21. case '\x00':
  22. b.WriteString(`\x00`)
  23. case '\r':
  24. b.WriteString(`\r`)
  25. case '\n':
  26. b.WriteString(`\n`)
  27. case '\\':
  28. b.WriteString(`\\`)
  29. case '\'':
  30. b.WriteString(`\'`)
  31. case '"':
  32. b.WriteString(`\"`)
  33. case '\x1a':
  34. b.WriteString(`\x1a`)
  35. default:
  36. b.WriteRune(ch)
  37. }
  38. }
  39. return b.String()
  40. }
  41. func format(query string, args ...interface{}) (string, error) {
  42. numArgs := len(args)
  43. if numArgs == 0 {
  44. return query, nil
  45. }
  46. var b strings.Builder
  47. var argIndex int
  48. bytes := len(query)
  49. for i := 0; i < bytes; i++ {
  50. ch := query[i]
  51. switch ch {
  52. case '?':
  53. if argIndex >= numArgs {
  54. return "", fmt.Errorf("error: %d ? in sql, but less arguments provided", argIndex)
  55. }
  56. writeValue(&b, args[argIndex])
  57. argIndex++
  58. case '$':
  59. var j int
  60. for j = i + 1; j < bytes; j++ {
  61. char := query[j]
  62. if char < '0' || '9' < char {
  63. break
  64. }
  65. }
  66. if j > i+1 {
  67. index, err := strconv.Atoi(query[i+1 : j])
  68. if err != nil {
  69. return "", err
  70. }
  71. // index starts from 1 for pg
  72. if index > argIndex {
  73. argIndex = index
  74. }
  75. index--
  76. if index < 0 || numArgs <= index {
  77. return "", fmt.Errorf("error: wrong index %d in sql", index)
  78. }
  79. writeValue(&b, args[index])
  80. i = j - 1
  81. }
  82. default:
  83. b.WriteByte(ch)
  84. }
  85. }
  86. if argIndex < numArgs {
  87. return "", fmt.Errorf("error: %d arguments provided, not matching sql", argIndex)
  88. }
  89. return b.String(), nil
  90. }
  91. func logInstanceError(datasource string, err error) {
  92. datasource = desensitize(datasource)
  93. logx.Errorf("Error on getting sql instance of %s: %v", datasource, err)
  94. }
  95. func logSqlError(stmt string, err error) {
  96. if err != nil && err != ErrNotFound {
  97. logx.Errorf("stmt: %s, error: %s", stmt, err.Error())
  98. }
  99. }
  100. func writeValue(buf *strings.Builder, arg interface{}) {
  101. switch v := arg.(type) {
  102. case bool:
  103. if v {
  104. buf.WriteByte('1')
  105. } else {
  106. buf.WriteByte('0')
  107. }
  108. case string:
  109. buf.WriteByte('\'')
  110. buf.WriteString(escape(v))
  111. buf.WriteByte('\'')
  112. default:
  113. buf.WriteString(mapping.Repr(v))
  114. }
  115. }