bulkinserter.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. package sqlx
  2. import (
  3. "database/sql"
  4. "fmt"
  5. "strings"
  6. "time"
  7. "github.com/tal-tech/go-zero/core/executors"
  8. "github.com/tal-tech/go-zero/core/logx"
  9. "github.com/tal-tech/go-zero/core/stringx"
  10. )
  11. const (
  12. flushInterval = time.Second
  13. maxBulkRows = 1000
  14. valuesKeyword = "values"
  15. )
  16. var emptyBulkStmt bulkStmt
  17. type (
  18. ResultHandler func(sql.Result, error)
  19. BulkInserter struct {
  20. executor *executors.PeriodicalExecutor
  21. inserter *dbInserter
  22. stmt bulkStmt
  23. }
  24. bulkStmt struct {
  25. prefix string
  26. valueFormat string
  27. suffix string
  28. }
  29. )
  30. func NewBulkInserter(sqlConn SqlConn, stmt string) (*BulkInserter, error) {
  31. bkStmt, err := parseInsertStmt(stmt)
  32. if err != nil {
  33. return nil, err
  34. }
  35. inserter := &dbInserter{
  36. sqlConn: sqlConn,
  37. stmt: bkStmt,
  38. }
  39. return &BulkInserter{
  40. executor: executors.NewPeriodicalExecutor(flushInterval, inserter),
  41. inserter: inserter,
  42. stmt: bkStmt,
  43. }, nil
  44. }
  45. func (bi *BulkInserter) Flush() {
  46. bi.executor.Flush()
  47. }
  48. func (bi *BulkInserter) Insert(args ...interface{}) error {
  49. value, err := format(bi.stmt.valueFormat, args...)
  50. if err != nil {
  51. return err
  52. }
  53. bi.executor.Add(value)
  54. return nil
  55. }
  56. func (bi *BulkInserter) SetResultHandler(handler ResultHandler) {
  57. bi.executor.Sync(func() {
  58. bi.inserter.resultHandler = handler
  59. })
  60. }
  61. func (bi *BulkInserter) UpdateOrDelete(fn func()) {
  62. bi.executor.Flush()
  63. fn()
  64. }
  65. func (bi *BulkInserter) UpdateStmt(stmt string) error {
  66. bkStmt, err := parseInsertStmt(stmt)
  67. if err != nil {
  68. return err
  69. }
  70. bi.executor.Flush()
  71. bi.executor.Sync(func() {
  72. bi.inserter.stmt = bkStmt
  73. })
  74. return nil
  75. }
// dbInserter is the task container handed to the PeriodicalExecutor in
// NewBulkInserter.
//
// AddTask appends formatted rows and RemoveAll drains them. Execute turns a
// batch (presumably the slice returned by RemoveAll — confirm against
// executors) into one multi-row INSERT.
type dbInserter struct {
	sqlConn       SqlConn       // connection the batched statement is executed on
	stmt          bulkStmt      // parsed statement used to assemble each batch
	values        []string      // formatted value tuples not yet drained
	resultHandler ResultHandler // optional per-batch callback; when nil, Exec errors are logged
}
  82. func (in *dbInserter) AddTask(task interface{}) bool {
  83. in.values = append(in.values, task.(string))
  84. return len(in.values) >= maxBulkRows
  85. }
  86. func (in *dbInserter) Execute(bulk interface{}) {
  87. values := bulk.([]string)
  88. if len(values) == 0 {
  89. return
  90. }
  91. stmtWithoutValues := in.stmt.prefix
  92. valuesStr := strings.Join(values, ", ")
  93. stmt := strings.Join([]string{stmtWithoutValues, valuesStr}, " ")
  94. if len(in.stmt.suffix) > 0 {
  95. stmt = strings.Join([]string{stmt, in.stmt.suffix}, " ")
  96. }
  97. result, err := in.sqlConn.Exec(stmt)
  98. if in.resultHandler != nil {
  99. in.resultHandler(result, err)
  100. } else if err != nil {
  101. logx.Errorf("sql: %s, error: %s", stmt, err)
  102. }
  103. }
  104. func (in *dbInserter) RemoveAll() interface{} {
  105. values := in.values
  106. in.values = nil
  107. return values
  108. }
  109. func parseInsertStmt(stmt string) (bulkStmt, error) {
  110. lower := strings.ToLower(stmt)
  111. pos := strings.Index(lower, valuesKeyword)
  112. if pos <= 0 {
  113. return emptyBulkStmt, fmt.Errorf("bad sql: %q", stmt)
  114. }
  115. var columns int
  116. right := strings.LastIndexByte(lower[:pos], ')')
  117. if right > 0 {
  118. left := strings.LastIndexByte(lower[:right], '(')
  119. if left > 0 {
  120. values := lower[left+1 : right]
  121. values = stringx.Filter(values, func(r rune) bool {
  122. return r == ' ' || r == '\t' || r == '\r' || r == '\n'
  123. })
  124. fields := strings.FieldsFunc(values, func(r rune) bool {
  125. return r == ','
  126. })
  127. columns = len(fields)
  128. }
  129. }
  130. var variables int
  131. var valueFormat string
  132. var suffix string
  133. left := strings.IndexByte(lower[pos:], '(')
  134. if left > 0 {
  135. right = strings.IndexByte(lower[pos+left:], ')')
  136. if right > 0 {
  137. values := lower[pos+left : pos+left+right]
  138. for _, x := range values {
  139. if x == '?' {
  140. variables++
  141. }
  142. }
  143. valueFormat = stmt[pos+left : pos+left+right+1]
  144. suffix = strings.TrimSpace(stmt[pos+left+right+1:])
  145. }
  146. }
  147. if variables == 0 {
  148. return emptyBulkStmt, fmt.Errorf("no variables: %q", stmt)
  149. }
  150. if columns > 0 && columns != variables {
  151. return emptyBulkStmt, fmt.Errorf("columns and variables mismatch: %q", stmt)
  152. }
  153. return bulkStmt{
  154. prefix: stmt[:pos+len(valuesKeyword)],
  155. valueFormat: valueFormat,
  156. suffix: suffix,
  157. }, nil
  158. }