statement.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
  2. //
  3. // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
  4. //
  5. // This Source Code Form is subject to the terms of the Mozilla Public
  6. // License, v. 2.0. If a copy of the MPL was not distributed with this file,
  7. // You can obtain one at http://mozilla.org/MPL/2.0/.
  8. package mysql
  9. import (
  10. "database/sql/driver"
  11. "fmt"
  12. "io"
  13. "reflect"
  14. "strconv"
  15. )
  16. type mysqlStmt struct {
  17. mc *mysqlConn
  18. id uint32
  19. paramCount int
  20. columns [][]mysqlField // cached from the first query
  21. }
  22. func (stmt *mysqlStmt) Close() error {
  23. if stmt.mc == nil || stmt.mc.netConn == nil {
  24. // driver.Stmt.Close can be called more than once, thus this function
  25. // has to be idempotent.
  26. // See also Issue #450 and golang/go#16019.
  27. //errLog.Print(ErrInvalidConn)
  28. return driver.ErrBadConn
  29. }
  30. err := stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id)
  31. stmt.mc = nil
  32. return err
  33. }
  34. func (stmt *mysqlStmt) NumInput() int {
  35. return stmt.paramCount
  36. }
  37. func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter {
  38. return converter{}
  39. }
  40. func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
  41. if stmt.mc.netConn == nil {
  42. errLog.Print(ErrInvalidConn)
  43. return nil, driver.ErrBadConn
  44. }
  45. // Send command
  46. err := stmt.writeExecutePacket(args)
  47. if err != nil {
  48. return nil, err
  49. }
  50. mc := stmt.mc
  51. mc.affectedRows = 0
  52. mc.insertId = 0
  53. // Read Result
  54. resLen, err := mc.readResultSetHeaderPacket()
  55. if err != nil {
  56. return nil, err
  57. }
  58. if resLen > 0 {
  59. // Columns
  60. if err = mc.readUntilEOF(); err != nil {
  61. return nil, err
  62. }
  63. // Rows
  64. if err := mc.readUntilEOF(); err != nil {
  65. return nil, err
  66. }
  67. }
  68. if err := mc.discardResults(); err != nil {
  69. return nil, err
  70. }
  71. return &mysqlResult{
  72. affectedRows: int64(mc.affectedRows),
  73. insertId: int64(mc.insertId),
  74. }, nil
  75. }
  76. func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
  77. if stmt.mc.netConn == nil {
  78. errLog.Print(ErrInvalidConn)
  79. return nil, driver.ErrBadConn
  80. }
  81. // Send command
  82. err := stmt.writeExecutePacket(args)
  83. if err != nil {
  84. return nil, err
  85. }
  86. mc := stmt.mc
  87. // Read Result
  88. resLen, err := mc.readResultSetHeaderPacket()
  89. if err != nil {
  90. return nil, err
  91. }
  92. rows := new(binaryRows)
  93. rows.stmtCols = &stmt.columns
  94. if resLen > 0 {
  95. rows.mc = mc
  96. rows.i++
  97. // Columns
  98. // If not cached, read them and cache them
  99. if len(stmt.columns) == 0 {
  100. rows.rs.columns, err = mc.readColumns(resLen)
  101. stmt.columns = append(stmt.columns, rows.rs.columns)
  102. } else {
  103. rows.rs.columns = stmt.columns[0]
  104. err = mc.readUntilEOF()
  105. }
  106. } else {
  107. rows.rs.done = true
  108. switch err := rows.NextResultSet(); err {
  109. case nil, io.EOF:
  110. return rows, nil
  111. default:
  112. return nil, err
  113. }
  114. }
  115. return rows, err
  116. }
  // converter is the driver.ValueConverter returned by
// mysqlStmt.ColumnConverter for every argument position.
type converter struct{}

// ConvertValue converts v into a driver.Value.
//
// Values for which driver.IsValue reports true are returned unchanged.
// Everything else is classified by reflect.Kind, so named types such as
// `type Status int` or `type Name string` map to their underlying
// representation:
//   - a nil pointer becomes nil; other pointers are dereferenced recursively;
//   - signed integers become int64;
//   - unsigned integers become int64, or their decimal string when the value
//     is >= 1<<63 and would not fit in int64;
//   - floats become float64;
//   - bools become bool and strings become string;
//   - slices whose element kind is uint8 become []byte.
//
// Any other kind yields an "unsupported type" error.
func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
	if driver.IsValue(v) {
		return v, nil
	}

	rv := reflect.ValueOf(v)
	switch rv.Kind() {
	case reflect.Ptr:
		// indirect pointers
		if rv.IsNil() {
			return nil, nil
		}
		return c.ConvertValue(rv.Elem().Interface())
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return rv.Int(), nil
	case reflect.Uint8, reflect.Uint16, reflect.Uint32:
		return int64(rv.Uint()), nil
	case reflect.Uint, reflect.Uint64:
		// uint is 64 bits wide on 64-bit platforms, so it needs the same
		// overflow check as uint64; a bare int64(...) would wrap negative.
		u64 := rv.Uint()
		if u64 >= 1<<63 {
			return strconv.FormatUint(u64, 10), nil
		}
		return int64(u64), nil
	case reflect.Float32, reflect.Float64:
		return rv.Float(), nil
	case reflect.Bool:
		return rv.Bool(), nil
	case reflect.String:
		return rv.String(), nil
	case reflect.Slice:
		if rv.Type().Elem().Kind() == reflect.Uint8 {
			return rv.Bytes(), nil
		}
		// Other slice types are unsupported; reported below.
	}
	return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind())
}