conn.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. package gocql
  import (
  	"fmt"
  	"io"
  	"net"
  	"sync"
  	"sync/atomic"
  )
  8. type queryInfo struct {
  9. id []byte
  10. args []columnInfo
  11. rval []columnInfo
  12. avail chan bool
  13. }
  14. type connection struct {
  15. conn net.Conn
  16. uniq chan uint8
  17. reply []chan buffer
  18. waiting uint64
  19. prepMu sync.Mutex
  20. prep map[string]*queryInfo
  21. }
  22. func connect(addr string, cfg *Config) (*connection, error) {
  23. conn, err := net.Dial("tcp", addr)
  24. if err != nil {
  25. return nil, err
  26. }
  27. c := &connection{
  28. conn: conn,
  29. uniq: make(chan uint8, 64),
  30. reply: make([]chan buffer, 64),
  31. prep: make(map[string]*queryInfo),
  32. }
  33. for i := 0; i < cap(c.uniq); i++ {
  34. c.uniq <- uint8(i)
  35. }
  36. go c.recv()
  37. frame := make(buffer, headerSize)
  38. frame.setHeader(protoRequest, 0, 0, opStartup)
  39. frame.writeStringMap(map[string]string{
  40. "CQL_VERSION": cfg.CQLVersion,
  41. })
  42. frame.setLength(len(frame) - headerSize)
  43. frame = c.request(frame)
  44. if cfg.Keyspace != "" {
  45. qry := &Query{stmt: "USE " + cfg.Keyspace}
  46. frame, err = c.executeQuery(qry)
  47. }
  48. return c, nil
  49. }
  50. func (c *connection) recv() {
  51. for {
  52. frame := make(buffer, headerSize, headerSize+512)
  53. if _, err := io.ReadFull(c.conn, frame); err != nil {
  54. return
  55. }
  56. if frame[0] != protoResponse {
  57. continue
  58. }
  59. if length := frame.Length(); length > 0 {
  60. frame.grow(frame.Length())
  61. io.ReadFull(c.conn, frame[headerSize:])
  62. }
  63. c.dispatch(frame)
  64. }
  65. panic("not possible")
  66. }
  67. func (c *connection) request(frame buffer) buffer {
  68. id := <-c.uniq
  69. frame[2] = id
  70. c.reply[id] = make(chan buffer, 1)
  71. for {
  72. w := atomic.LoadUint64(&c.waiting)
  73. if atomic.CompareAndSwapUint64(&c.waiting, w, w|(1<<id)) {
  74. break
  75. }
  76. }
  77. c.conn.Write(frame)
  78. resp := <-c.reply[id]
  79. c.uniq <- id
  80. return resp
  81. }
  82. func (c *connection) dispatch(frame buffer) {
  83. id := frame[2]
  84. if id >= 128 {
  85. return
  86. }
  87. for {
  88. w := atomic.LoadUint64(&c.waiting)
  89. if w&(1<<id) == 0 {
  90. return
  91. }
  92. if atomic.CompareAndSwapUint64(&c.waiting, w, w&^(1<<id)) {
  93. break
  94. }
  95. }
  96. c.reply[id] <- frame
  97. }
  98. func (c *connection) prepareQuery(stmt string) *queryInfo {
  99. c.prepMu.Lock()
  100. info := c.prep[stmt]
  101. if info != nil {
  102. c.prepMu.Unlock()
  103. <-info.avail
  104. return info
  105. }
  106. info = &queryInfo{avail: make(chan bool)}
  107. c.prep[stmt] = info
  108. c.prepMu.Unlock()
  109. frame := make(buffer, headerSize, headerSize+512)
  110. frame.setHeader(protoRequest, 0, 0, opPrepare)
  111. frame.writeLongString(stmt)
  112. frame.setLength(len(frame) - headerSize)
  113. frame = c.request(frame)
  114. frame.skipHeader()
  115. frame.readInt() // kind
  116. info.id = frame.readShortBytes()
  117. info.args = frame.readMetaData()
  118. info.rval = frame.readMetaData()
  119. close(info.avail)
  120. return info
  121. }
  122. func (c *connection) executeQuery(query *Query) (buffer, error) {
  123. var info *queryInfo
  124. if len(query.args) > 0 {
  125. info = c.prepareQuery(query.stmt)
  126. }
  127. frame := make(buffer, headerSize, headerSize+512)
  128. frame.setHeader(protoRequest, 0, 0, opQuery)
  129. frame.writeLongString(query.stmt)
  130. frame.writeShort(uint16(query.cons))
  131. flags := uint8(0)
  132. if len(query.args) > 0 {
  133. flags |= flagQueryValues
  134. }
  135. frame.writeByte(flags)
  136. if len(query.args) > 0 {
  137. frame.writeShort(uint16(len(query.args)))
  138. for i := 0; i < len(query.args); i++ {
  139. val, err := Marshal(info.args[i].TypeInfo, query.args[i])
  140. if err != nil {
  141. return nil, err
  142. }
  143. frame.writeBytes(val)
  144. }
  145. }
  146. frame.setLength(len(frame) - headerSize)
  147. frame = c.request(frame)
  148. if frame[3] == opError {
  149. frame.skipHeader()
  150. code := frame.readInt()
  151. desc := frame.readString()
  152. return nil, Error{code, desc}
  153. }
  154. return frame, nil
  155. }