stmt.go 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. // copy from core/stores/sqlx/stmt.go
  2. package mocksql
  3. import (
  4. "database/sql"
  5. "fmt"
  6. "time"
  7. "github.com/tal-tech/go-zero/core/logx"
  8. "github.com/tal-tech/go-zero/core/timex"
  9. )
  10. const slowThreshold = time.Millisecond * 500
  11. func exec(db *sql.DB, q string, args ...interface{}) (sql.Result, error) {
  12. tx, err := db.Begin()
  13. if err != nil {
  14. return nil, err
  15. }
  16. defer func() {
  17. switch err {
  18. case nil:
  19. err = tx.Commit()
  20. default:
  21. tx.Rollback()
  22. }
  23. }()
  24. stmt, err := format(q, args...)
  25. if err != nil {
  26. return nil, err
  27. }
  28. startTime := timex.Now()
  29. result, err := tx.Exec(q, args...)
  30. duration := timex.Since(startTime)
  31. if duration > slowThreshold {
  32. logx.WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt)
  33. } else {
  34. logx.WithDuration(duration).Infof("sql exec: %s", stmt)
  35. }
  36. if err != nil {
  37. logSqlError(stmt, err)
  38. }
  39. return result, err
  40. }
  41. func execStmt(conn *sql.Stmt, args ...interface{}) (sql.Result, error) {
  42. stmt := fmt.Sprint(args...)
  43. startTime := timex.Now()
  44. result, err := conn.Exec(args...)
  45. duration := timex.Since(startTime)
  46. if duration > slowThreshold {
  47. logx.WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt)
  48. } else {
  49. logx.WithDuration(duration).Infof("sql execStmt: %s", stmt)
  50. }
  51. if err != nil {
  52. logSqlError(stmt, err)
  53. }
  54. return result, err
  55. }
  56. func query(db *sql.DB, scanner func(*sql.Rows) error, q string, args ...interface{}) error {
  57. tx, err := db.Begin()
  58. if err != nil {
  59. return err
  60. }
  61. defer func() {
  62. switch err {
  63. case nil:
  64. err = tx.Commit()
  65. default:
  66. tx.Rollback()
  67. }
  68. }()
  69. stmt, err := format(q, args...)
  70. if err != nil {
  71. return err
  72. }
  73. startTime := timex.Now()
  74. rows, err := tx.Query(q, args...)
  75. duration := timex.Since(startTime)
  76. if duration > slowThreshold {
  77. logx.WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt)
  78. } else {
  79. logx.WithDuration(duration).Infof("sql query: %s", stmt)
  80. }
  81. if err != nil {
  82. logSqlError(stmt, err)
  83. return err
  84. }
  85. defer rows.Close()
  86. return scanner(rows)
  87. }
  88. func queryStmt(conn *sql.Stmt, scanner func(*sql.Rows) error, args ...interface{}) error {
  89. stmt := fmt.Sprint(args...)
  90. startTime := timex.Now()
  91. rows, err := conn.Query(args...)
  92. duration := timex.Since(startTime)
  93. if duration > slowThreshold {
  94. logx.WithDuration(duration).Slowf("[SQL] queryStmt: slowcall - %s", stmt)
  95. } else {
  96. logx.WithDuration(duration).Infof("sql queryStmt: %s", stmt)
  97. }
  98. if err != nil {
  99. logSqlError(stmt, err)
  100. return err
  101. }
  102. defer rows.Close()
  103. return scanner(rows)
  104. }