bulkinserter.go 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. package sqlx
  2. import (
  3. "database/sql"
  4. "fmt"
  5. "strings"
  6. "time"
  7. "github.com/tal-tech/go-zero/core/executors"
  8. "github.com/tal-tech/go-zero/core/logx"
  9. "github.com/tal-tech/go-zero/core/stringx"
  10. )
  // flushInterval is the interval handed to the periodical executor that
  // drives batch flushing.
  const flushInterval = time.Second

  // maxBulkRows is the buffered-row count at which dbInserter.AddTask reports
  // that the batch is full.
  const maxBulkRows = 1000

  // valuesKeyword is the INSERT keyword that parseInsertStmt splits statements
  // around.
  const valuesKeyword = "values"
  16. var emptyBulkStmt bulkStmt
  17. type (
  18. // ResultHandler defines the method of result handlers.
  19. ResultHandler func(sql.Result, error)
  20. // A BulkInserter is used to batch insert records.
  21. // Postgresql is not supported yet, because of the sql is formated with symbol `$`.
  22. BulkInserter struct {
  23. executor *executors.PeriodicalExecutor
  24. inserter *dbInserter
  25. stmt bulkStmt
  26. }
  27. bulkStmt struct {
  28. prefix string
  29. valueFormat string
  30. suffix string
  31. }
  32. )
  33. // NewBulkInserter returns a BulkInserter.
  34. func NewBulkInserter(sqlConn SqlConn, stmt string) (*BulkInserter, error) {
  35. bkStmt, err := parseInsertStmt(stmt)
  36. if err != nil {
  37. return nil, err
  38. }
  39. inserter := &dbInserter{
  40. sqlConn: sqlConn,
  41. stmt: bkStmt,
  42. }
  43. return &BulkInserter{
  44. executor: executors.NewPeriodicalExecutor(flushInterval, inserter),
  45. inserter: inserter,
  46. stmt: bkStmt,
  47. }, nil
  48. }
  49. // Flush flushes all the pending records.
  50. func (bi *BulkInserter) Flush() {
  51. bi.executor.Flush()
  52. }
  53. // Insert inserts given args.
  54. func (bi *BulkInserter) Insert(args ...interface{}) error {
  55. value, err := format(bi.stmt.valueFormat, args...)
  56. if err != nil {
  57. return err
  58. }
  59. bi.executor.Add(value)
  60. return nil
  61. }
  62. // SetResultHandler sets the given handler.
  63. func (bi *BulkInserter) SetResultHandler(handler ResultHandler) {
  64. bi.executor.Sync(func() {
  65. bi.inserter.resultHandler = handler
  66. })
  67. }
  68. // UpdateOrDelete runs update or delete queries, which flushes pending records first.
  69. func (bi *BulkInserter) UpdateOrDelete(fn func()) {
  70. bi.executor.Flush()
  71. fn()
  72. }
  73. // UpdateStmt updates the insert statement.
  74. func (bi *BulkInserter) UpdateStmt(stmt string) error {
  75. bkStmt, err := parseInsertStmt(stmt)
  76. if err != nil {
  77. return err
  78. }
  79. bi.executor.Flush()
  80. bi.executor.Sync(func() {
  81. bi.inserter.stmt = bkStmt
  82. })
  83. return nil
  84. }
  85. type dbInserter struct {
  86. sqlConn SqlConn
  87. stmt bulkStmt
  88. values []string
  89. resultHandler ResultHandler
  90. }
  91. func (in *dbInserter) AddTask(task interface{}) bool {
  92. in.values = append(in.values, task.(string))
  93. return len(in.values) >= maxBulkRows
  94. }
  95. func (in *dbInserter) Execute(bulk interface{}) {
  96. values := bulk.([]string)
  97. if len(values) == 0 {
  98. return
  99. }
  100. stmtWithoutValues := in.stmt.prefix
  101. valuesStr := strings.Join(values, ", ")
  102. stmt := strings.Join([]string{stmtWithoutValues, valuesStr}, " ")
  103. if len(in.stmt.suffix) > 0 {
  104. stmt = strings.Join([]string{stmt, in.stmt.suffix}, " ")
  105. }
  106. result, err := in.sqlConn.Exec(stmt)
  107. if in.resultHandler != nil {
  108. in.resultHandler(result, err)
  109. } else if err != nil {
  110. logx.Errorf("sql: %s, error: %s", stmt, err)
  111. }
  112. }
  113. func (in *dbInserter) RemoveAll() interface{} {
  114. values := in.values
  115. in.values = nil
  116. return values
  117. }
  118. func parseInsertStmt(stmt string) (bulkStmt, error) {
  119. lower := strings.ToLower(stmt)
  120. pos := strings.Index(lower, valuesKeyword)
  121. if pos <= 0 {
  122. return emptyBulkStmt, fmt.Errorf("bad sql: %q", stmt)
  123. }
  124. var columns int
  125. right := strings.LastIndexByte(lower[:pos], ')')
  126. if right > 0 {
  127. left := strings.LastIndexByte(lower[:right], '(')
  128. if left > 0 {
  129. values := lower[left+1 : right]
  130. values = stringx.Filter(values, func(r rune) bool {
  131. return r == ' ' || r == '\t' || r == '\r' || r == '\n'
  132. })
  133. fields := strings.FieldsFunc(values, func(r rune) bool {
  134. return r == ','
  135. })
  136. columns = len(fields)
  137. }
  138. }
  139. var variables int
  140. var valueFormat string
  141. var suffix string
  142. left := strings.IndexByte(lower[pos:], '(')
  143. if left > 0 {
  144. right = strings.IndexByte(lower[pos+left:], ')')
  145. if right > 0 {
  146. values := lower[pos+left : pos+left+right]
  147. for _, x := range values {
  148. if x == '?' {
  149. variables++
  150. }
  151. }
  152. valueFormat = stmt[pos+left : pos+left+right+1]
  153. suffix = strings.TrimSpace(stmt[pos+left+right+1:])
  154. }
  155. }
  156. if variables == 0 {
  157. return emptyBulkStmt, fmt.Errorf("no variables: %q", stmt)
  158. }
  159. if columns > 0 && columns != variables {
  160. return emptyBulkStmt, fmt.Errorf("columns and variables mismatch: %q", stmt)
  161. }
  162. return bulkStmt{
  163. prefix: stmt[:pos+len(valuesKeyword)],
  164. valueFormat: valueFormat,
  165. suffix: suffix,
  166. }, nil
  167. }