prepared.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. package native
  2. import (
  3. "github.com/ziutek/mymysql/mysql"
  4. "log"
  5. )
  6. type Stmt struct {
  7. my *Conn
  8. id uint32
  9. sql string // For reprepare during reconnect
  10. params []paramValue // Parameters binding
  11. rebind bool
  12. binded bool
  13. fields []*mysql.Field
  14. field_count int
  15. param_count int
  16. warning_count int
  17. status mysql.ConnStatus
  18. null_bitmap []byte
  19. }
  20. func (stmt *Stmt) Fields() []*mysql.Field {
  21. return stmt.fields
  22. }
  23. func (stmt *Stmt) NumParam() int {
  24. return stmt.param_count
  25. }
  26. func (stmt *Stmt) WarnCount() int {
  27. return stmt.warning_count
  28. }
  29. func (stmt *Stmt) sendCmdExec() {
  30. // Calculate packet length and NULL bitmap
  31. pkt_len := 1 + 4 + 1 + 4 + 1 + len(stmt.null_bitmap)
  32. for ii := range stmt.null_bitmap {
  33. stmt.null_bitmap[ii] = 0
  34. }
  35. for ii, param := range stmt.params {
  36. par_len := param.Len()
  37. pkt_len += par_len
  38. if par_len == 0 {
  39. null_byte := ii >> 3
  40. null_mask := byte(1) << uint(ii-(null_byte<<3))
  41. stmt.null_bitmap[null_byte] |= null_mask
  42. }
  43. }
  44. if stmt.rebind {
  45. pkt_len += stmt.param_count * 2
  46. }
  47. // Reset sequence number
  48. stmt.my.seq = 0
  49. // Packet sending
  50. pw := stmt.my.newPktWriter(pkt_len)
  51. pw.writeByte(_COM_STMT_EXECUTE)
  52. pw.writeU32(stmt.id)
  53. pw.writeByte(0) // flags = CURSOR_TYPE_NO_CURSOR
  54. pw.writeU32(1) // iteration_count
  55. pw.write(stmt.null_bitmap)
  56. if stmt.rebind {
  57. pw.writeByte(1)
  58. // Types
  59. for _, param := range stmt.params {
  60. pw.writeU16(param.typ)
  61. }
  62. } else {
  63. pw.writeByte(0)
  64. }
  65. // Values
  66. for i := range stmt.params {
  67. pw.writeValue(&stmt.params[i])
  68. }
  69. if stmt.my.Debug {
  70. log.Printf("[%2d <-] Exec command packet: len=%d, null_bitmap=%v, rebind=%t",
  71. stmt.my.seq-1, pkt_len, stmt.null_bitmap, stmt.rebind)
  72. }
  73. // Mark that we sended information about binded types
  74. stmt.rebind = false
  75. }
  76. func (my *Conn) getPrepareResult(stmt *Stmt) interface{} {
  77. loop:
  78. pr := my.newPktReader() // New reader for next packet
  79. pkt0 := pr.readByte()
  80. //log.Println("pkt0:", pkt0, "stmt:", stmt)
  81. if pkt0 == 255 {
  82. // Error packet
  83. my.getErrorPacket(pr)
  84. }
  85. if stmt == nil {
  86. if pkt0 == 0 {
  87. // OK packet
  88. return my.getPrepareOkPacket(pr)
  89. }
  90. } else {
  91. unreaded_params := (stmt.param_count < len(stmt.params))
  92. switch {
  93. case pkt0 == 254:
  94. // EOF packet
  95. stmt.warning_count, stmt.status = my.getEofPacket(pr)
  96. stmt.my.status = stmt.status
  97. return stmt
  98. case pkt0 > 0 && pkt0 < 251 && (stmt.field_count < len(stmt.fields) ||
  99. unreaded_params):
  100. // Field packet
  101. if unreaded_params {
  102. // Read and ignore parameter field. Sentence from MySQL source:
  103. /* skip parameters data: we don't support it yet */
  104. pr.skipAll()
  105. // Increment param_count count
  106. stmt.param_count++
  107. } else {
  108. field := my.getFieldPacket(pr)
  109. stmt.fields[stmt.field_count] = field
  110. // Increment field count
  111. stmt.field_count++
  112. }
  113. // Read next packet
  114. goto loop
  115. }
  116. }
  117. panic(mysql.ErrUnkResultPkt)
  118. }
  119. func (my *Conn) getPrepareOkPacket(pr *pktReader) (stmt *Stmt) {
  120. if my.Debug {
  121. log.Printf("[%2d ->] Perpared OK packet:", my.seq-1)
  122. }
  123. stmt = new(Stmt)
  124. stmt.my = my
  125. // First byte was readed by getPrepRes
  126. stmt.id = pr.readU32()
  127. stmt.fields = make([]*mysql.Field, int(pr.readU16())) // FieldCount
  128. pl := int(pr.readU16()) // ParamCount
  129. if pl > 0 {
  130. stmt.params = make([]paramValue, pl)
  131. stmt.null_bitmap = make([]byte, (pl+7)>>3)
  132. }
  133. pr.skipN(1)
  134. stmt.warning_count = int(pr.readU16())
  135. pr.checkEof()
  136. if my.Debug {
  137. log.Printf(tab8s+"ID=0x%x ParamCount=%d FieldsCount=%d WarnCount=%d",
  138. stmt.id, len(stmt.params), len(stmt.fields), stmt.warning_count,
  139. )
  140. }
  141. return
  142. }