conn.go 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420
  1. // Copyright (c) 2012 The gocql Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package gocql
  5. import (
  6. "net"
  7. "sync"
  8. "sync/atomic"
  9. "time"
  10. )
  11. const defaultFrameSize = 4096
  12. const flagResponse = 0x80
  13. const maskVersion = 0x7F
  14. type Cluster interface {
  15. //HandleAuth(addr, method string) ([]byte, Challenger, error)
  16. HandleError(conn *Conn, err error, closed bool)
  17. HandleKeyspace(conn *Conn, keyspace string)
  18. // Authenticate(addr string)
  19. }
  20. /* type Challenger interface {
  21. Challenge(data []byte) ([]byte, error)
  22. } */
  23. type ConnConfig struct {
  24. ProtoVersion int
  25. CQLVersion string
  26. Timeout time.Duration
  27. NumStreams int
  28. }
  29. // Conn is a single connection to a Cassandra node. It can be used to execute
  30. // queries, but users are usually advised to use a more reliable, higher
  31. // level API.
  32. type Conn struct {
  33. conn net.Conn
  34. timeout time.Duration
  35. uniq chan uint8
  36. calls []callReq
  37. nwait int32
  38. prepMu sync.Mutex
  39. prep map[string]*queryInfo
  40. cluster Cluster
  41. addr string
  42. version uint8
  43. }
  44. // Connect establishes a connection to a Cassandra node.
  45. // You must also call the Serve method before you can execute any queries.
  46. func Connect(addr string, cfg ConnConfig, cluster Cluster) (*Conn, error) {
  47. conn, err := net.DialTimeout("tcp", addr, cfg.Timeout)
  48. if err != nil {
  49. return nil, err
  50. }
  51. if cfg.NumStreams <= 0 || cfg.NumStreams > 128 {
  52. cfg.NumStreams = 128
  53. }
  54. if cfg.ProtoVersion != 1 && cfg.ProtoVersion != 2 {
  55. cfg.ProtoVersion = 2
  56. }
  57. c := &Conn{
  58. conn: conn,
  59. uniq: make(chan uint8, cfg.NumStreams),
  60. calls: make([]callReq, cfg.NumStreams),
  61. prep: make(map[string]*queryInfo),
  62. timeout: cfg.Timeout,
  63. version: uint8(cfg.ProtoVersion),
  64. addr: conn.RemoteAddr().String(),
  65. cluster: cluster,
  66. }
  67. for i := 0; i < cap(c.uniq); i++ {
  68. c.uniq <- uint8(i)
  69. }
  70. if err := c.startup(&cfg); err != nil {
  71. return nil, err
  72. }
  73. go c.serve()
  74. return c, nil
  75. }
  76. func (c *Conn) startup(cfg *ConnConfig) error {
  77. req := make(frame, headerSize, defaultFrameSize)
  78. req.setHeader(c.version, 0, 0, opStartup)
  79. req.writeStringMap(map[string]string{
  80. "CQL_VERSION": cfg.CQLVersion,
  81. })
  82. resp, err := c.callSimple(req)
  83. if err != nil {
  84. return err
  85. }
  86. switch x := resp.(type) {
  87. case readyFrame:
  88. case error:
  89. return x
  90. default:
  91. return ErrProtocol
  92. }
  93. return nil
  94. }
  95. // Serve starts the stream multiplexer for this connection, which is required
  96. // to execute any queries. This method runs as long as the connection is
  97. // open and is therefore usually called in a separate goroutine.
  98. func (c *Conn) serve() {
  99. for {
  100. resp, err := c.recv()
  101. if err != nil {
  102. break
  103. }
  104. c.dispatch(resp)
  105. }
  106. c.conn.Close()
  107. for id := 0; id < len(c.calls); id++ {
  108. req := &c.calls[id]
  109. if atomic.LoadInt32(&req.active) == 1 {
  110. req.resp <- callResp{nil, ErrProtocol}
  111. }
  112. }
  113. c.cluster.HandleError(c, ErrProtocol, true)
  114. }
  115. func (c *Conn) recv() (frame, error) {
  116. resp := make(frame, headerSize, headerSize+512)
  117. c.conn.SetReadDeadline(time.Now().Add(c.timeout))
  118. n, last, pinged := 0, 0, false
  119. for n < len(resp) {
  120. nn, err := c.conn.Read(resp[n:])
  121. n += nn
  122. if err != nil {
  123. if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
  124. if n > last {
  125. // we hit the deadline but we made progress.
  126. // simply extend the deadline
  127. c.conn.SetReadDeadline(time.Now().Add(c.timeout))
  128. last = n
  129. } else if n == 0 && !pinged {
  130. c.conn.SetReadDeadline(time.Now().Add(c.timeout))
  131. if atomic.LoadInt32(&c.nwait) > 0 {
  132. go c.ping()
  133. pinged = true
  134. }
  135. } else {
  136. return nil, err
  137. }
  138. } else {
  139. return nil, err
  140. }
  141. }
  142. if n == headerSize && len(resp) == headerSize {
  143. if resp[0] != c.version|flagResponse {
  144. return nil, ErrProtocol
  145. }
  146. resp.grow(resp.Length())
  147. }
  148. }
  149. return resp, nil
  150. }
  151. func (c *Conn) callSimple(req frame) (interface{}, error) {
  152. req.setLength(len(req) - headerSize)
  153. if _, err := c.conn.Write(req); err != nil {
  154. c.conn.Close()
  155. return nil, err
  156. }
  157. buf, err := c.recv()
  158. if err != nil {
  159. return nil, err
  160. }
  161. return decodeFrame(buf)
  162. }
  163. func (c *Conn) call(req frame) (interface{}, error) {
  164. id := <-c.uniq
  165. req[2] = id
  166. call := &c.calls[id]
  167. call.resp = make(chan callResp, 1)
  168. atomic.AddInt32(&c.nwait, 1)
  169. atomic.StoreInt32(&call.active, 1)
  170. req.setLength(len(req) - headerSize)
  171. if n, err := c.conn.Write(req); err != nil {
  172. c.conn.Close()
  173. if n > 0 {
  174. return nil, ErrProtocol
  175. }
  176. return nil, ErrUnavailable
  177. }
  178. reply := <-call.resp
  179. call.resp = nil
  180. c.uniq <- id
  181. if reply.err != nil {
  182. return nil, reply.err
  183. }
  184. return decodeFrame(reply.buf)
  185. }
  186. func (c *Conn) dispatch(resp frame) {
  187. id := int(resp[2])
  188. if id >= len(c.calls) {
  189. return
  190. }
  191. call := &c.calls[id]
  192. if !atomic.CompareAndSwapInt32(&call.active, 1, 0) {
  193. return
  194. }
  195. atomic.AddInt32(&c.nwait, -1)
  196. call.resp <- callResp{resp, nil}
  197. }
  198. func (c *Conn) ping() error {
  199. req := make(frame, headerSize)
  200. req.setHeader(c.version, 0, 0, opOptions)
  201. _, err := c.call(req)
  202. return err
  203. }
  204. func (c *Conn) prepareStatement(stmt string) (*queryInfo, error) {
  205. c.prepMu.Lock()
  206. info := c.prep[stmt]
  207. if info != nil {
  208. c.prepMu.Unlock()
  209. info.wg.Wait()
  210. return info, nil
  211. }
  212. info = new(queryInfo)
  213. info.wg.Add(1)
  214. c.prep[stmt] = info
  215. c.prepMu.Unlock()
  216. frame := make(frame, headerSize, defaultFrameSize)
  217. frame.setHeader(c.version, 0, 0, opPrepare)
  218. frame.writeLongString(stmt)
  219. frame.setLength(len(frame) - headerSize)
  220. resp, err := c.call(frame)
  221. if err != nil {
  222. return nil, err
  223. }
  224. switch x := resp.(type) {
  225. case resultPreparedFrame:
  226. info.id = x.PreparedId
  227. info.args = x.Values
  228. info.wg.Done()
  229. case error:
  230. return nil, x
  231. default:
  232. return nil, ErrProtocol
  233. }
  234. return info, nil
  235. }
  236. func (c *Conn) ExecuteQuery(qry *Query) (*Iter, error) {
  237. var info *queryInfo
  238. if len(qry.Args) > 0 {
  239. var err error
  240. info, err = c.prepareStatement(qry.Stmt)
  241. if err != nil {
  242. return nil, err
  243. }
  244. }
  245. req := make(frame, headerSize, defaultFrameSize)
  246. if info == nil {
  247. req.setHeader(c.version, 0, 0, opQuery)
  248. req.writeLongString(qry.Stmt)
  249. req.writeConsistency(qry.Cons)
  250. if c.version > 1 {
  251. req.writeByte(0)
  252. }
  253. } else {
  254. req.setHeader(c.version, 0, 0, opExecute)
  255. req.writeShortBytes(info.id)
  256. if c.version == 1 {
  257. req.writeShort(uint16(len(qry.Args)))
  258. } else {
  259. req.writeConsistency(qry.Cons)
  260. flags := uint8(0)
  261. if len(qry.Args) > 0 {
  262. flags |= flagQueryValues
  263. }
  264. req.writeByte(flags)
  265. if flags&flagQueryValues != 0 {
  266. req.writeShort(uint16(len(qry.Args)))
  267. }
  268. }
  269. for i := 0; i < len(qry.Args); i++ {
  270. val, err := Marshal(info.args[i].TypeInfo, qry.Args[i])
  271. if err != nil {
  272. return nil, err
  273. }
  274. req.writeBytes(val)
  275. }
  276. if c.version == 1 {
  277. req.writeConsistency(qry.Cons)
  278. }
  279. }
  280. resp, err := c.call(req)
  281. if err != nil {
  282. return nil, err
  283. }
  284. switch x := resp.(type) {
  285. case resultVoidFrame:
  286. return &Iter{}, nil
  287. case resultRowsFrame:
  288. return &Iter{columns: x.Columns, rows: x.Rows}, nil
  289. case resultKeyspaceFrame:
  290. c.cluster.HandleKeyspace(c, x.Keyspace)
  291. return &Iter{}, nil
  292. case error:
  293. return &Iter{err: x}, nil
  294. }
  295. return nil, ErrProtocol
  296. }
  297. func (c *Conn) ExecuteBatch(batch *Batch) error {
  298. if c.version == 1 {
  299. return ErrProtocol
  300. }
  301. frame := make(frame, headerSize, defaultFrameSize)
  302. frame.setHeader(c.version, 0, 0, opBatch)
  303. frame.writeByte(byte(batch.Type))
  304. frame.writeShort(uint16(len(batch.Entries)))
  305. for i := 0; i < len(batch.Entries); i++ {
  306. entry := &batch.Entries[i]
  307. var info *queryInfo
  308. if len(entry.Args) > 0 {
  309. var err error
  310. info, err = c.prepareStatement(entry.Stmt)
  311. if err != nil {
  312. return err
  313. }
  314. frame.writeByte(1)
  315. frame.writeShortBytes(info.id)
  316. } else {
  317. frame.writeByte(0)
  318. frame.writeLongString(entry.Stmt)
  319. }
  320. frame.writeShort(uint16(len(entry.Args)))
  321. for j := 0; j < len(entry.Args); j++ {
  322. val, err := Marshal(info.args[j].TypeInfo, entry.Args[j])
  323. if err != nil {
  324. return err
  325. }
  326. frame.writeBytes(val)
  327. }
  328. }
  329. frame.writeConsistency(batch.Cons)
  330. resp, err := c.call(frame)
  331. if err != nil {
  332. return err
  333. }
  334. switch x := resp.(type) {
  335. case resultVoidFrame:
  336. case error:
  337. return x
  338. default:
  339. return ErrProtocol
  340. }
  341. return nil
  342. }
  343. func (c *Conn) Close() {
  344. c.conn.Close()
  345. }
  346. func (c *Conn) Address() string {
  347. return c.addr
  348. }
  349. func (c *Conn) UseKeyspace(keyspace string) error {
  350. frame := make(frame, headerSize, defaultFrameSize)
  351. frame.setHeader(c.version, 0, 0, opQuery)
  352. frame.writeLongString("USE " + keyspace)
  353. frame.writeConsistency(1)
  354. frame.writeByte(0)
  355. resp, err := c.call(frame)
  356. if err != nil {
  357. return err
  358. }
  359. switch x := resp.(type) {
  360. case resultKeyspaceFrame:
  361. case error:
  362. return x
  363. default:
  364. return ErrProtocol
  365. }
  366. return nil
  367. }
  368. type queryInfo struct {
  369. id []byte
  370. args []ColumnInfo
  371. rval []ColumnInfo
  372. wg sync.WaitGroup
  373. }
  374. type callReq struct {
  375. active int32
  376. resp chan callResp
  377. }
  378. type callResp struct {
  379. buf frame
  380. err error
  381. }