stmt.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. // Copyright 2019 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package core
  5. import (
  6. "context"
  7. "database/sql"
  8. "errors"
  9. "reflect"
  10. )
  11. // Stmt reprents a stmt objects
  12. type Stmt struct {
  13. *sql.Stmt
  14. db *DB
  15. names map[string]int
  16. }
  17. func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
  18. names := make(map[string]int)
  19. var i int
  20. query = re.ReplaceAllStringFunc(query, func(src string) string {
  21. names[src[1:]] = i
  22. i += 1
  23. return "?"
  24. })
  25. stmt, err := db.DB.PrepareContext(ctx, query)
  26. if err != nil {
  27. return nil, err
  28. }
  29. return &Stmt{stmt, db, names}, nil
  30. }
  31. func (db *DB) Prepare(query string) (*Stmt, error) {
  32. return db.PrepareContext(context.Background(), query)
  33. }
  34. func (s *Stmt) ExecMapContext(ctx context.Context, mp interface{}) (sql.Result, error) {
  35. vv := reflect.ValueOf(mp)
  36. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
  37. return nil, errors.New("mp should be a map's pointer")
  38. }
  39. args := make([]interface{}, len(s.names))
  40. for k, i := range s.names {
  41. args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface()
  42. }
  43. return s.Stmt.ExecContext(ctx, args...)
  44. }
  45. func (s *Stmt) ExecMap(mp interface{}) (sql.Result, error) {
  46. return s.ExecMapContext(context.Background(), mp)
  47. }
  48. func (s *Stmt) ExecStructContext(ctx context.Context, st interface{}) (sql.Result, error) {
  49. vv := reflect.ValueOf(st)
  50. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
  51. return nil, errors.New("mp should be a map's pointer")
  52. }
  53. args := make([]interface{}, len(s.names))
  54. for k, i := range s.names {
  55. args[i] = vv.Elem().FieldByName(k).Interface()
  56. }
  57. return s.Stmt.ExecContext(ctx, args...)
  58. }
  59. func (s *Stmt) ExecStruct(st interface{}) (sql.Result, error) {
  60. return s.ExecStructContext(context.Background(), st)
  61. }
  62. func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, error) {
  63. rows, err := s.Stmt.QueryContext(ctx, args...)
  64. if err != nil {
  65. return nil, err
  66. }
  67. return &Rows{rows, s.db}, nil
  68. }
  69. func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
  70. return s.QueryContext(context.Background(), args...)
  71. }
  72. func (s *Stmt) QueryMapContext(ctx context.Context, mp interface{}) (*Rows, error) {
  73. vv := reflect.ValueOf(mp)
  74. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
  75. return nil, errors.New("mp should be a map's pointer")
  76. }
  77. args := make([]interface{}, len(s.names))
  78. for k, i := range s.names {
  79. args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface()
  80. }
  81. return s.QueryContext(ctx, args...)
  82. }
  83. func (s *Stmt) QueryMap(mp interface{}) (*Rows, error) {
  84. return s.QueryMapContext(context.Background(), mp)
  85. }
  86. func (s *Stmt) QueryStructContext(ctx context.Context, st interface{}) (*Rows, error) {
  87. vv := reflect.ValueOf(st)
  88. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
  89. return nil, errors.New("mp should be a map's pointer")
  90. }
  91. args := make([]interface{}, len(s.names))
  92. for k, i := range s.names {
  93. args[i] = vv.Elem().FieldByName(k).Interface()
  94. }
  95. return s.Query(args...)
  96. }
  97. func (s *Stmt) QueryStruct(st interface{}) (*Rows, error) {
  98. return s.QueryStructContext(context.Background(), st)
  99. }
  100. func (s *Stmt) QueryRowContext(ctx context.Context, args ...interface{}) *Row {
  101. rows, err := s.QueryContext(ctx, args...)
  102. return &Row{rows, err}
  103. }
  104. func (s *Stmt) QueryRow(args ...interface{}) *Row {
  105. return s.QueryRowContext(context.Background(), args...)
  106. }
  107. func (s *Stmt) QueryRowMapContext(ctx context.Context, mp interface{}) *Row {
  108. vv := reflect.ValueOf(mp)
  109. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
  110. return &Row{nil, errors.New("mp should be a map's pointer")}
  111. }
  112. args := make([]interface{}, len(s.names))
  113. for k, i := range s.names {
  114. args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface()
  115. }
  116. return s.QueryRowContext(ctx, args...)
  117. }
  118. func (s *Stmt) QueryRowMap(mp interface{}) *Row {
  119. return s.QueryRowMapContext(context.Background(), mp)
  120. }
  121. func (s *Stmt) QueryRowStructContext(ctx context.Context, st interface{}) *Row {
  122. vv := reflect.ValueOf(st)
  123. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
  124. return &Row{nil, errors.New("st should be a struct's pointer")}
  125. }
  126. args := make([]interface{}, len(s.names))
  127. for k, i := range s.names {
  128. args[i] = vv.Elem().FieldByName(k).Interface()
  129. }
  130. return s.QueryRowContext(ctx, args...)
  131. }
  132. func (s *Stmt) QueryRowStruct(st interface{}) *Row {
  133. return s.QueryRowStructContext(context.Background(), st)
  134. }