statement.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
  2. //
  3. // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
  4. //
  5. // This Source Code Form is subject to the terms of the Mozilla Public
  6. // License, v. 2.0. If a copy of the MPL was not distributed with this file,
  7. // You can obtain one at http://mozilla.org/MPL/2.0/.
  8. package mysql
  9. import (
  10. "database/sql/driver"
  11. "fmt"
  12. "reflect"
  13. )
// mysqlStmt is a prepared statement on the server, identified by id and
// executed over the connection mc.
type mysqlStmt struct {
	mc         *mysqlConn   // connection the statement belongs to; Close sets it to nil
	id         uint32       // statement id, sent with comStmtClose when the statement is closed
	paramCount int          // number of placeholder parameters; returned by NumInput
	columns    []mysqlField // cached from the first query
}
  20. func (stmt *mysqlStmt) Close() error {
  21. if stmt.mc == nil || stmt.mc.netConn == nil {
  22. errLog.Print(ErrInvalidConn)
  23. return driver.ErrBadConn
  24. }
  25. err := stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id)
  26. stmt.mc = nil
  27. return err
  28. }
  29. func (stmt *mysqlStmt) NumInput() int {
  30. return stmt.paramCount
  31. }
  32. func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter {
  33. return converter{}
  34. }
  35. func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
  36. if stmt.mc.netConn == nil {
  37. errLog.Print(ErrInvalidConn)
  38. return nil, driver.ErrBadConn
  39. }
  40. // Send command
  41. err := stmt.writeExecutePacket(args)
  42. if err != nil {
  43. return nil, err
  44. }
  45. mc := stmt.mc
  46. mc.affectedRows = 0
  47. mc.insertId = 0
  48. // Read Result
  49. resLen, err := mc.readResultSetHeaderPacket()
  50. if err == nil {
  51. if resLen > 0 {
  52. // Columns
  53. err = mc.readUntilEOF()
  54. if err != nil {
  55. return nil, err
  56. }
  57. // Rows
  58. err = mc.readUntilEOF()
  59. }
  60. if err == nil {
  61. return &mysqlResult{
  62. affectedRows: int64(mc.affectedRows),
  63. insertId: int64(mc.insertId),
  64. }, nil
  65. }
  66. }
  67. return nil, err
  68. }
  69. func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
  70. if stmt.mc.netConn == nil {
  71. errLog.Print(ErrInvalidConn)
  72. return nil, driver.ErrBadConn
  73. }
  74. // Send command
  75. err := stmt.writeExecutePacket(args)
  76. if err != nil {
  77. return nil, err
  78. }
  79. mc := stmt.mc
  80. // Read Result
  81. resLen, err := mc.readResultSetHeaderPacket()
  82. if err != nil {
  83. return nil, err
  84. }
  85. rows := new(binaryRows)
  86. rows.mc = mc
  87. if resLen > 0 {
  88. // Columns
  89. // If not cached, read them and cache them
  90. if stmt.columns == nil {
  91. rows.columns, err = mc.readColumns(resLen)
  92. stmt.columns = rows.columns
  93. } else {
  94. rows.columns = stmt.columns
  95. err = mc.readUntilEOF()
  96. }
  97. }
  98. return rows, err
  99. }
  100. type converter struct{}
  101. func (converter) ConvertValue(v interface{}) (driver.Value, error) {
  102. if driver.IsValue(v) {
  103. return v, nil
  104. }
  105. if svi, ok := v.(driver.Valuer); ok {
  106. sv, err := svi.Value()
  107. if err != nil {
  108. return nil, err
  109. }
  110. if !driver.IsValue(sv) {
  111. return nil, fmt.Errorf("non-Value type %T returned from Value", sv)
  112. }
  113. return sv, nil
  114. }
  115. rv := reflect.ValueOf(v)
  116. switch rv.Kind() {
  117. case reflect.Ptr:
  118. // indirect pointers
  119. if rv.IsNil() {
  120. return nil, nil
  121. }
  122. return driver.DefaultParameterConverter.ConvertValue(rv.Elem().Interface())
  123. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  124. return rv.Int(), nil
  125. case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
  126. return int64(rv.Uint()), nil
  127. case reflect.Uint64:
  128. u64 := rv.Uint()
  129. if u64 >= 1<<63 {
  130. return fmt.Sprintf("%d", rv.Uint()), nil
  131. }
  132. return int64(u64), nil
  133. case reflect.Float32, reflect.Float64:
  134. return rv.Float(), nil
  135. }
  136. return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind())
  137. }