conn.go 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. // Copyright (c) 2012 The gocql Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package gocql
  import (
  	"fmt"
  	"net"
  	"sync"
  	"sync/atomic"
  	"time"
  )
  11. type connection struct {
  12. conn net.Conn
  13. uniq chan uint8
  14. requests []frameRequest
  15. nwait int32
  16. prepMu sync.Mutex
  17. prep map[string]*queryInfo
  18. timeout time.Duration
  19. }
  20. func connect(addr string, cfg *Config) (*connection, error) {
  21. conn, err := net.Dial("tcp", addr)
  22. if err != nil {
  23. return nil, err
  24. }
  25. c := &connection{
  26. conn: conn,
  27. uniq: make(chan uint8, 64),
  28. requests: make([]frameRequest, 64),
  29. prep: make(map[string]*queryInfo),
  30. timeout: cfg.Timeout,
  31. }
  32. for i := 0; i < cap(c.uniq); i++ {
  33. c.uniq <- uint8(i)
  34. }
  35. go c.run()
  36. frame := make(buffer, headerSize)
  37. frame.setHeader(protoRequest, 0, 0, opStartup)
  38. frame.writeStringMap(map[string]string{
  39. "CQL_VERSION": cfg.CQLVersion,
  40. })
  41. frame.setLength(len(frame) - headerSize)
  42. frame, err = c.request(frame)
  43. if err != nil {
  44. return nil, err
  45. }
  46. if cfg.Keyspace != "" {
  47. qry := &Query{stmt: "USE " + cfg.Keyspace}
  48. frame, err = c.executeQuery(qry)
  49. }
  50. return c, nil
  51. }
  52. func (c *connection) run() {
  53. var err error
  54. for {
  55. var frame buffer
  56. frame, err = c.recv()
  57. if err != nil {
  58. break
  59. }
  60. c.dispatch(frame)
  61. }
  62. c.conn.Close()
  63. for id := 0; id < len(c.requests); id++ {
  64. req := &c.requests[id]
  65. if atomic.LoadInt32(&req.active) == 1 {
  66. req.reply <- frameReply{nil, err}
  67. }
  68. }
  69. }
  70. func (c *connection) recv() (buffer, error) {
  71. frame := make(buffer, headerSize, headerSize+512)
  72. c.conn.SetReadDeadline(time.Now().Add(c.timeout))
  73. n, last, pinged := 0, 0, false
  74. for n < len(frame) {
  75. nn, err := c.conn.Read(frame[n:])
  76. n += nn
  77. if err != nil {
  78. if err, ok := err.(net.Error); ok && err.Timeout() {
  79. if n > last {
  80. // we hit the deadline but we made progress.
  81. // simply extend the deadline
  82. c.conn.SetReadDeadline(time.Now().Add(c.timeout))
  83. last = n
  84. } else if n == 0 && !pinged {
  85. c.conn.SetReadDeadline(time.Now().Add(c.timeout))
  86. if atomic.LoadInt32(&c.nwait) > 0 {
  87. go c.ping()
  88. pinged = true
  89. }
  90. } else {
  91. return nil, err
  92. }
  93. } else {
  94. return nil, err
  95. }
  96. }
  97. if n == headerSize && len(frame) == headerSize {
  98. if frame[0] != protoResponse {
  99. return nil, ErrInvalid
  100. }
  101. frame.grow(frame.Length())
  102. }
  103. }
  104. return frame, nil
  105. }
  106. func (c *connection) ping() error {
  107. frame := make(buffer, headerSize, headerSize)
  108. frame.setHeader(protoRequest, 0, 0, opOptions)
  109. frame.setLength(0)
  110. _, err := c.request(frame)
  111. return err
  112. }
  113. func (c *connection) request(frame buffer) (buffer, error) {
  114. id := <-c.uniq
  115. frame[2] = id
  116. req := &c.requests[id]
  117. req.reply = make(chan frameReply, 1)
  118. atomic.AddInt32(&c.nwait, 1)
  119. atomic.StoreInt32(&req.active, 1)
  120. if _, err := c.conn.Write(frame); err != nil {
  121. return nil, err
  122. }
  123. reply := <-req.reply
  124. req.reply = nil
  125. c.uniq <- id
  126. return reply.buf, reply.err
  127. }
  128. func (c *connection) dispatch(frame buffer) {
  129. id := int(frame[2])
  130. if id >= len(c.requests) {
  131. return
  132. }
  133. req := &c.requests[id]
  134. if !atomic.CompareAndSwapInt32(&req.active, 1, 0) {
  135. return
  136. }
  137. atomic.AddInt32(&c.nwait, -1)
  138. req.reply <- frameReply{frame, nil}
  139. }
  140. func (c *connection) prepareQuery(stmt string) *queryInfo {
  141. c.prepMu.Lock()
  142. info := c.prep[stmt]
  143. if info != nil {
  144. c.prepMu.Unlock()
  145. info.wg.Wait()
  146. return info
  147. }
  148. info = new(queryInfo)
  149. info.wg.Add(1)
  150. c.prep[stmt] = info
  151. c.prepMu.Unlock()
  152. frame := make(buffer, headerSize, headerSize+512)
  153. frame.setHeader(protoRequest, 0, 0, opPrepare)
  154. frame.writeLongString(stmt)
  155. frame.setLength(len(frame) - headerSize)
  156. frame, err := c.request(frame)
  157. if err != nil {
  158. return nil
  159. }
  160. frame.skipHeader()
  161. frame.readInt() // kind
  162. info.id = frame.readShortBytes()
  163. info.args = frame.readMetaData()
  164. info.rval = frame.readMetaData()
  165. info.wg.Done()
  166. return info
  167. }
  168. func (c *connection) executeQuery(query *Query) (buffer, error) {
  169. var info *queryInfo
  170. if len(query.args) > 0 {
  171. info = c.prepareQuery(query.stmt)
  172. }
  173. frame := make(buffer, headerSize, headerSize+512)
  174. if info == nil {
  175. frame.setHeader(protoRequest, 0, 0, opQuery)
  176. frame.writeLongString(query.stmt)
  177. } else {
  178. frame.setHeader(protoRequest, 0, 0, opExecute)
  179. frame.writeShortBytes(info.id)
  180. }
  181. frame.writeShort(uint16(query.cons))
  182. flags := uint8(0)
  183. if len(query.args) > 0 {
  184. flags |= flagQueryValues
  185. }
  186. frame.writeByte(flags)
  187. if len(query.args) > 0 {
  188. frame.writeShort(uint16(len(query.args)))
  189. for i := 0; i < len(query.args); i++ {
  190. val, err := Marshal(info.args[i].TypeInfo, query.args[i])
  191. if err != nil {
  192. return nil, err
  193. }
  194. frame.writeBytes(val)
  195. }
  196. }
  197. frame.setLength(len(frame) - headerSize)
  198. frame, err := c.request(frame)
  199. if err != nil {
  200. return nil, err
  201. }
  202. if frame[3] == opError {
  203. frame.skipHeader()
  204. code := frame.readInt()
  205. desc := frame.readString()
  206. return nil, Error{code, desc}
  207. }
  208. return frame, nil
  209. }
  210. type queryInfo struct {
  211. id []byte
  212. args []columnInfo
  213. rval []columnInfo
  214. wg sync.WaitGroup
  215. }
  216. type frameRequest struct {
  217. active int32
  218. reply chan frameReply
  219. }
  220. type frameReply struct {
  221. buf buffer
  222. err error
  223. }