buf.go 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. package mssql
  2. import (
  3. "encoding/binary"
  4. "errors"
  5. "io"
  6. )
  // packetType is the one-byte type code stored in the first byte of a packet
  // header.
  type packetType uint8

  // header is the fixed-size header found at the start of every packet. It is
  // decoded from the transport in big-endian byte order, with the fields in
  // the order declared here.
  //
  // Size is the length of the whole packet, header included. Status is a flag
  // byte: bit 0x1 marks the last packet of a message. PacketNo is the packet
  // sequence number.
  type header struct {
  	PacketType    packetType
  	Status        uint8
  	Size, Spid    uint16
  	PacketNo, Pad uint8
  }
  16. // tdsBuffer reads and writes TDS packets of data to the transport.
  17. // The write and read buffers are separate to make sending attn signals
  18. // possible without locks. Currently attn signals are only sent during
  19. // reads, not writes.
  20. type tdsBuffer struct {
  21. transport io.ReadWriteCloser
  22. packetSize int
  23. // Write fields.
  24. wbuf []byte
  25. wpos int
  26. wPacketSeq byte
  27. wPacketType packetType
  28. // Read fields.
  29. rbuf []byte
  30. rpos int
  31. rsize int
  32. final bool
  33. rPacketType packetType
  34. // afterFirst is assigned to right after tdsBuffer is created and
  35. // before the first use. It is executed after the first packet is
  36. // written and then removed.
  37. afterFirst func()
  38. }
  39. func newTdsBuffer(bufsize uint16, transport io.ReadWriteCloser) *tdsBuffer {
  40. return &tdsBuffer{
  41. packetSize: int(bufsize),
  42. wbuf: make([]byte, 1<<16),
  43. rbuf: make([]byte, 1<<16),
  44. rpos: 8,
  45. transport: transport,
  46. }
  47. }
  48. func (rw *tdsBuffer) ResizeBuffer(packetSize int) {
  49. rw.packetSize = packetSize
  50. }
  51. func (w *tdsBuffer) PackageSize() int {
  52. return w.packetSize
  53. }
  54. func (w *tdsBuffer) flush() (err error) {
  55. // Write packet size.
  56. w.wbuf[0] = byte(w.wPacketType)
  57. binary.BigEndian.PutUint16(w.wbuf[2:], uint16(w.wpos))
  58. w.wbuf[6] = w.wPacketSeq
  59. // Write packet into underlying transport.
  60. if _, err = w.transport.Write(w.wbuf[:w.wpos]); err != nil {
  61. return err
  62. }
  63. // It is possible to create a whole new buffer after a flush.
  64. // Useful for debugging. Normally reuse the buffer.
  65. // w.wbuf = make([]byte, 1<<16)
  66. // Execute afterFirst hook if it is set.
  67. if w.afterFirst != nil {
  68. w.afterFirst()
  69. w.afterFirst = nil
  70. }
  71. w.wpos = 8
  72. w.wPacketSeq++
  73. return nil
  74. }
  75. func (w *tdsBuffer) Write(p []byte) (total int, err error) {
  76. for {
  77. copied := copy(w.wbuf[w.wpos:w.packetSize], p)
  78. w.wpos += copied
  79. total += copied
  80. if copied == len(p) {
  81. return
  82. }
  83. if err = w.flush(); err != nil {
  84. return
  85. }
  86. p = p[copied:]
  87. }
  88. }
  89. func (w *tdsBuffer) WriteByte(b byte) error {
  90. if int(w.wpos) == len(w.wbuf) || w.wpos == w.packetSize {
  91. if err := w.flush(); err != nil {
  92. return err
  93. }
  94. }
  95. w.wbuf[w.wpos] = b
  96. w.wpos += 1
  97. return nil
  98. }
  99. func (w *tdsBuffer) BeginPacket(packetType packetType, resetSession bool) {
  100. status := byte(0)
  101. if resetSession {
  102. switch packetType {
  103. // Reset session can only be set on the following packet types.
  104. case packSQLBatch, packRPCRequest, packTransMgrReq:
  105. status = 0x8
  106. }
  107. }
  108. w.wbuf[1] = status // Packet is incomplete. This byte is set again in FinishPacket.
  109. w.wpos = 8
  110. w.wPacketSeq = 1
  111. w.wPacketType = packetType
  112. }
  113. func (w *tdsBuffer) FinishPacket() error {
  114. w.wbuf[1] |= 1 // Mark this as the last packet in the message.
  115. return w.flush()
  116. }
  117. var headerSize = binary.Size(header{})
  118. func (r *tdsBuffer) readNextPacket() error {
  119. h := header{}
  120. var err error
  121. err = binary.Read(r.transport, binary.BigEndian, &h)
  122. if err != nil {
  123. return err
  124. }
  125. if int(h.Size) > r.packetSize {
  126. return errors.New("Invalid packet size, it is longer than buffer size")
  127. }
  128. if headerSize > int(h.Size) {
  129. return errors.New("Invalid packet size, it is shorter than header size")
  130. }
  131. _, err = io.ReadFull(r.transport, r.rbuf[headerSize:h.Size])
  132. if err != nil {
  133. return err
  134. }
  135. r.rpos = headerSize
  136. r.rsize = int(h.Size)
  137. r.final = h.Status != 0
  138. r.rPacketType = h.PacketType
  139. return nil
  140. }
  141. func (r *tdsBuffer) BeginRead() (packetType, error) {
  142. err := r.readNextPacket()
  143. if err != nil {
  144. return 0, err
  145. }
  146. return r.rPacketType, nil
  147. }
  148. func (r *tdsBuffer) ReadByte() (res byte, err error) {
  149. if r.rpos == r.rsize {
  150. if r.final {
  151. return 0, io.EOF
  152. }
  153. err = r.readNextPacket()
  154. if err != nil {
  155. return 0, err
  156. }
  157. }
  158. res = r.rbuf[r.rpos]
  159. r.rpos++
  160. return res, nil
  161. }
  162. func (r *tdsBuffer) byte() byte {
  163. b, err := r.ReadByte()
  164. if err != nil {
  165. badStreamPanic(err)
  166. }
  167. return b
  168. }
  169. func (r *tdsBuffer) ReadFull(buf []byte) {
  170. _, err := io.ReadFull(r, buf[:])
  171. if err != nil {
  172. badStreamPanic(err)
  173. }
  174. }
  175. func (r *tdsBuffer) uint64() uint64 {
  176. var buf [8]byte
  177. r.ReadFull(buf[:])
  178. return binary.LittleEndian.Uint64(buf[:])
  179. }
  180. func (r *tdsBuffer) int32() int32 {
  181. return int32(r.uint32())
  182. }
  183. func (r *tdsBuffer) uint32() uint32 {
  184. var buf [4]byte
  185. r.ReadFull(buf[:])
  186. return binary.LittleEndian.Uint32(buf[:])
  187. }
  188. func (r *tdsBuffer) uint16() uint16 {
  189. var buf [2]byte
  190. r.ReadFull(buf[:])
  191. return binary.LittleEndian.Uint16(buf[:])
  192. }
  193. func (r *tdsBuffer) BVarChar() string {
  194. l := int(r.byte())
  195. return r.readUcs2(l)
  196. }
  197. func (r *tdsBuffer) UsVarChar() string {
  198. l := int(r.uint16())
  199. return r.readUcs2(l)
  200. }
  201. func (r *tdsBuffer) readUcs2(numchars int) string {
  202. b := make([]byte, numchars*2)
  203. r.ReadFull(b)
  204. res, err := ucs22str(b)
  205. if err != nil {
  206. badStreamPanic(err)
  207. }
  208. return res
  209. }
  210. func (r *tdsBuffer) Read(buf []byte) (copied int, err error) {
  211. copied = 0
  212. err = nil
  213. if r.rpos == r.rsize {
  214. if r.final {
  215. return 0, io.EOF
  216. }
  217. err = r.readNextPacket()
  218. if err != nil {
  219. return
  220. }
  221. }
  222. copied = copy(buf, r.rbuf[r.rpos:r.rsize])
  223. r.rpos += copied
  224. return
  225. }