statement.go 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
  2. //
  3. // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
  4. //
  5. // This Source Code Form is subject to the terms of the Mozilla Public
  6. // License, v. 2.0. If a copy of the MPL was not distributed with this file,
  7. // You can obtain one at http://mozilla.org/MPL/2.0/.
  8. package mysql
  9. import (
  10. "database/sql/driver"
  11. "fmt"
  12. "io"
  13. "reflect"
  14. )
  15. type mysqlStmt struct {
  16. mc *mysqlConn
  17. id uint32
  18. paramCount int
  19. }
  20. func (stmt *mysqlStmt) Close() error {
  21. if stmt.mc == nil || stmt.mc.closed.IsSet() {
  22. // driver.Stmt.Close can be called more than once, thus this function
  23. // has to be idempotent.
  24. // See also Issue #450 and golang/go#16019.
  25. //errLog.Print(ErrInvalidConn)
  26. return driver.ErrBadConn
  27. }
  28. err := stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id)
  29. stmt.mc = nil
  30. return err
  31. }
  32. func (stmt *mysqlStmt) NumInput() int {
  33. return stmt.paramCount
  34. }
  35. func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter {
  36. return converter{}
  37. }
  38. func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
  39. if stmt.mc.closed.IsSet() {
  40. errLog.Print(ErrInvalidConn)
  41. return nil, driver.ErrBadConn
  42. }
  43. // Send command
  44. err := stmt.writeExecutePacket(args)
  45. if err != nil {
  46. return nil, stmt.mc.markBadConn(err)
  47. }
  48. mc := stmt.mc
  49. mc.affectedRows = 0
  50. mc.insertId = 0
  51. // Read Result
  52. resLen, err := mc.readResultSetHeaderPacket()
  53. if err != nil {
  54. return nil, err
  55. }
  56. if resLen > 0 {
  57. // Columns
  58. if err = mc.readUntilEOF(); err != nil {
  59. return nil, err
  60. }
  61. // Rows
  62. if err := mc.readUntilEOF(); err != nil {
  63. return nil, err
  64. }
  65. }
  66. if err := mc.discardResults(); err != nil {
  67. return nil, err
  68. }
  69. return &mysqlResult{
  70. affectedRows: int64(mc.affectedRows),
  71. insertId: int64(mc.insertId),
  72. }, nil
  73. }
  74. func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
  75. return stmt.query(args)
  76. }
  77. func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
  78. if stmt.mc.closed.IsSet() {
  79. errLog.Print(ErrInvalidConn)
  80. return nil, driver.ErrBadConn
  81. }
  82. // Send command
  83. err := stmt.writeExecutePacket(args)
  84. if err != nil {
  85. return nil, stmt.mc.markBadConn(err)
  86. }
  87. mc := stmt.mc
  88. // Read Result
  89. resLen, err := mc.readResultSetHeaderPacket()
  90. if err != nil {
  91. return nil, err
  92. }
  93. rows := new(binaryRows)
  94. if resLen > 0 {
  95. rows.mc = mc
  96. rows.rs.columns, err = mc.readColumns(resLen)
  97. } else {
  98. rows.rs.done = true
  99. switch err := rows.NextResultSet(); err {
  100. case nil, io.EOF:
  101. return rows, nil
  102. default:
  103. return nil, err
  104. }
  105. }
  106. return rows, err
  107. }
  108. type converter struct{}
  109. // ConvertValue mirrors the reference/default converter in database/sql/driver
  110. // with _one_ exception. We support uint64 with their high bit and the default
  111. // implementation does not. This function should be kept in sync with
  112. // database/sql/driver defaultConverter.ConvertValue() except for that
  113. // deliberate difference.
  114. func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
  115. if driver.IsValue(v) {
  116. return v, nil
  117. }
  118. if vr, ok := v.(driver.Valuer); ok {
  119. sv, err := callValuerValue(vr)
  120. if err != nil {
  121. return nil, err
  122. }
  123. if !driver.IsValue(sv) {
  124. return nil, fmt.Errorf("non-Value type %T returned from Value", sv)
  125. }
  126. return sv, nil
  127. }
  128. rv := reflect.ValueOf(v)
  129. switch rv.Kind() {
  130. case reflect.Ptr:
  131. // indirect pointers
  132. if rv.IsNil() {
  133. return nil, nil
  134. } else {
  135. return c.ConvertValue(rv.Elem().Interface())
  136. }
  137. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  138. return rv.Int(), nil
  139. case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
  140. return rv.Uint(), nil
  141. case reflect.Float32, reflect.Float64:
  142. return rv.Float(), nil
  143. case reflect.Bool:
  144. return rv.Bool(), nil
  145. case reflect.Slice:
  146. ek := rv.Type().Elem().Kind()
  147. if ek == reflect.Uint8 {
  148. return rv.Bytes(), nil
  149. }
  150. return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek)
  151. case reflect.String:
  152. return rv.String(), nil
  153. }
  154. return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind())
  155. }
  156. var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
  157. // callValuerValue returns vr.Value(), with one exception:
  158. // If vr.Value is an auto-generated method on a pointer type and the
  159. // pointer is nil, it would panic at runtime in the panicwrap
  160. // method. Treat it like nil instead.
  161. //
  162. // This is so people can implement driver.Value on value types and
  163. // still use nil pointers to those types to mean nil/NULL, just like
  164. // string/*string.
  165. //
  166. // This is an exact copy of the same-named unexported function from the
  167. // database/sql package.
  168. func callValuerValue(vr driver.Valuer) (v driver.Value, err error) {
  169. if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr &&
  170. rv.IsNil() &&
  171. rv.Type().Elem().Implements(valuerReflectType) {
  172. return nil, nil
  173. }
  174. return vr.Value()
  175. }