conn.go 14 KB


  1. // Copyright (c) 2012 The gocql Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package gocql
import (
	"bufio"
	"bytes"
	"errors"
	"fmt"
	"net"
	"strings"
	"sync"
	"sync/atomic"
	"time"
)
  15. const defaultFrameSize = 4096
  16. const flagResponse = 0x80
  17. const maskVersion = 0x7F
  18. type Authenticator interface {
  19. Challenge(req []byte) (resp []byte, auth Authenticator, err error)
  20. Success(data []byte) error
  21. }
// PasswordAuthenticator answers a server that requests authentication via
// "org.apache.cassandra.auth.PasswordAuthenticator" with a fixed username and
// password.
type PasswordAuthenticator struct {
	Username string // sent as the first field of the response token
	Password string // sent as the second field of the response token
}
  26. func (p PasswordAuthenticator) Challenge(req []byte) ([]byte, Authenticator, error) {
  27. if string(req) != "org.apache.cassandra.auth.PasswordAuthenticator" {
  28. return nil, nil, fmt.Errorf("unexpected authenticator %q", req)
  29. }
  30. resp := make([]byte, 2+len(p.Username)+len(p.Password))
  31. resp[0] = 0
  32. copy(resp[1:], p.Username)
  33. resp[len(p.Username)+1] = 0
  34. copy(resp[2+len(p.Username):], p.Password)
  35. return resp, nil, nil
  36. }
  37. func (p PasswordAuthenticator) Success(data []byte) error {
  38. return nil
  39. }
// ConnConfig holds the settings Connect uses to open and initialize a Conn.
type ConnConfig struct {
	// ProtoVersion is the protocol version written into request headers and
	// expected in responses. Connect replaces any value other than 1 or 2
	// with 2.
	ProtoVersion int
	// CQLVersion is sent in the startup frame.
	CQLVersion string
	// Timeout is passed to net.DialTimeout and used as the read timeout when
	// receiving frames.
	Timeout time.Duration
	// NumStreams is the number of stream ids, i.e. the maximum number of
	// requests that can be in flight at once; values outside 1..128 are
	// replaced with 128.
	NumStreams int
	// Compressor, if non-nil, is advertised by name during startup and used
	// to compress request bodies and decompress flagged response bodies.
	Compressor Compressor
	// Authenticator answers the server's authentication request during
	// startup. If it is nil, connecting to a server that requires
	// authentication fails.
	Authenticator Authenticator
	// Keepalive, if positive, is passed to setKeepalive on the new
	// connection. NOTE(review): presumably a TCP keepalive period; confirm.
	Keepalive time.Duration
}
// Conn is a single connection to a Cassandra node. It can be used to execute
// queries, but users are usually advised to use a more reliable, higher
// level API.
//
// Requests are multiplexed by stream id: every id waiting in uniq is a free
// slot in calls, and a read loop started by Connect hands each response to
// the request waiting on that id.
type Conn struct {
	conn    net.Conn
	r       *bufio.Reader // buffered reader over conn
	timeout time.Duration // read timeout (ConnConfig.Timeout)
	uniq    chan uint8    // free stream ids; taking one blocks while all are in use
	calls   []callReq     // per-stream request state, indexed by stream id
	nwait   int32         // requests awaiting a response; accessed atomically

	prepMu sync.Mutex                  // guards prep
	prep   map[string]*inflightPrepare // prepared statements keyed by statement text

	pool       ConnectionPool // told about the failure via HandleError when the read loop exits
	compressor Compressor     // nil means no compression
	auth       Authenticator  // used if the server requests authentication during startup
	addr       string         // conn.RemoteAddr().String()
	version    uint8          // protocol version used in frame headers

	closedMu sync.RWMutex // guards isClosed
	isClosed bool
}
  69. // Connect establishes a connection to a Cassandra node.
  70. // You must also call the Serve method before you can execute any queries.
  71. func Connect(addr string, cfg ConnConfig, pool ConnectionPool) (*Conn, error) {
  72. conn, err := net.DialTimeout("tcp", addr, cfg.Timeout)
  73. if err != nil {
  74. return nil, err
  75. }
  76. if cfg.NumStreams <= 0 || cfg.NumStreams > 128 {
  77. cfg.NumStreams = 128
  78. }
  79. if cfg.ProtoVersion != 1 && cfg.ProtoVersion != 2 {
  80. cfg.ProtoVersion = 2
  81. }
  82. c := &Conn{
  83. conn: conn,
  84. r: bufio.NewReader(conn),
  85. uniq: make(chan uint8, cfg.NumStreams),
  86. calls: make([]callReq, cfg.NumStreams),
  87. prep: make(map[string]*inflightPrepare),
  88. timeout: cfg.Timeout,
  89. version: uint8(cfg.ProtoVersion),
  90. addr: conn.RemoteAddr().String(),
  91. pool: pool,
  92. compressor: cfg.Compressor,
  93. auth: cfg.Authenticator,
  94. }
  95. if cfg.Keepalive > 0 {
  96. c.setKeepalive(cfg.Keepalive)
  97. }
  98. for i := 0; i < cap(c.uniq); i++ {
  99. c.uniq <- uint8(i)
  100. }
  101. if err := c.startup(&cfg); err != nil {
  102. conn.Close()
  103. return nil, err
  104. }
  105. go c.serve()
  106. return c, nil
  107. }
  108. func (c *Conn) startup(cfg *ConnConfig) error {
  109. compression := ""
  110. if c.compressor != nil {
  111. compression = c.compressor.Name()
  112. }
  113. var req operation = &startupFrame{
  114. CQLVersion: cfg.CQLVersion,
  115. Compression: compression,
  116. }
  117. var challenger Authenticator
  118. for {
  119. resp, err := c.execSimple(req)
  120. if err != nil {
  121. return err
  122. }
  123. switch x := resp.(type) {
  124. case readyFrame:
  125. return nil
  126. case error:
  127. return x
  128. case authenticateFrame:
  129. if c.auth == nil {
  130. return fmt.Errorf("authentication required (using %q)", x.Authenticator)
  131. }
  132. var resp []byte
  133. resp, challenger, err = c.auth.Challenge([]byte(x.Authenticator))
  134. if err != nil {
  135. return err
  136. }
  137. req = &authResponseFrame{resp}
  138. case authChallengeFrame:
  139. if challenger == nil {
  140. return fmt.Errorf("authentication error (invalid challenge)")
  141. }
  142. var resp []byte
  143. resp, challenger, err = challenger.Challenge(x.Data)
  144. if err != nil {
  145. return err
  146. }
  147. req = &authResponseFrame{resp}
  148. case authSuccessFrame:
  149. if challenger != nil {
  150. return challenger.Success(x.Data)
  151. }
  152. return nil
  153. default:
  154. return NewErrProtocol("Unknown type of response to startup frame: %s", x)
  155. }
  156. }
  157. }
  158. // Serve starts the stream multiplexer for this connection, which is required
  159. // to execute any queries. This method runs as long as the connection is
  160. // open and is therefore usually called in a separate goroutine.
  161. func (c *Conn) serve() {
  162. var (
  163. err error
  164. resp frame
  165. )
  166. for {
  167. resp, err = c.recv()
  168. if err != nil {
  169. break
  170. }
  171. c.dispatch(resp)
  172. }
  173. c.Close()
  174. for id := 0; id < len(c.calls); id++ {
  175. req := &c.calls[id]
  176. if atomic.LoadInt32(&req.active) == 1 {
  177. req.resp <- callResp{nil, err}
  178. }
  179. }
  180. c.pool.HandleError(c, err, true)
  181. }
  182. func (c *Conn) recv() (frame, error) {
  183. resp := make(frame, headerSize, headerSize+512)
  184. c.conn.SetReadDeadline(time.Now().Add(c.timeout))
  185. n, last, pinged := 0, 0, false
  186. for n < len(resp) {
  187. nn, err := c.r.Read(resp[n:])
  188. n += nn
  189. if err != nil {
  190. if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
  191. if n > last {
  192. // we hit the deadline but we made progress.
  193. // simply extend the deadline
  194. c.conn.SetReadDeadline(time.Now().Add(c.timeout))
  195. last = n
  196. } else if n == 0 && !pinged {
  197. c.conn.SetReadDeadline(time.Now().Add(c.timeout))
  198. if atomic.LoadInt32(&c.nwait) > 0 {
  199. go c.ping()
  200. pinged = true
  201. }
  202. } else {
  203. return nil, err
  204. }
  205. } else {
  206. return nil, err
  207. }
  208. }
  209. if n == headerSize && len(resp) == headerSize {
  210. if resp[0] != c.version|flagResponse {
  211. return nil, NewErrProtocol("recv: Response protocol version does not match connection protocol version (%d != %d)", resp[0], c.version|flagResponse)
  212. }
  213. resp.grow(resp.Length())
  214. }
  215. }
  216. return resp, nil
  217. }
  218. func (c *Conn) execSimple(op operation) (interface{}, error) {
  219. f, err := op.encodeFrame(c.version, nil)
  220. f.setLength(len(f) - headerSize)
  221. if _, err := c.conn.Write([]byte(f)); err != nil {
  222. c.Close()
  223. return nil, err
  224. }
  225. if f, err = c.recv(); err != nil {
  226. return nil, err
  227. }
  228. return c.decodeFrame(f, nil)
  229. }
  230. func (c *Conn) exec(op operation, trace Tracer) (interface{}, error) {
  231. req, err := op.encodeFrame(c.version, nil)
  232. if err != nil {
  233. return nil, err
  234. }
  235. if trace != nil {
  236. req[1] |= flagTrace
  237. }
  238. if len(req) > headerSize && c.compressor != nil {
  239. body, err := c.compressor.Encode([]byte(req[headerSize:]))
  240. if err != nil {
  241. return nil, err
  242. }
  243. req = append(req[:headerSize], frame(body)...)
  244. req[1] |= flagCompress
  245. }
  246. req.setLength(len(req) - headerSize)
  247. id := <-c.uniq
  248. req[2] = id
  249. call := &c.calls[id]
  250. call.resp = make(chan callResp, 1)
  251. atomic.AddInt32(&c.nwait, 1)
  252. atomic.StoreInt32(&call.active, 1)
  253. if _, err := c.conn.Write(req); err != nil {
  254. c.uniq <- id
  255. c.Close()
  256. return nil, err
  257. }
  258. reply := <-call.resp
  259. call.resp = nil
  260. c.uniq <- id
  261. if reply.err != nil {
  262. return nil, reply.err
  263. }
  264. return c.decodeFrame(reply.buf, trace)
  265. }
  266. func (c *Conn) dispatch(resp frame) {
  267. id := int(resp[2])
  268. if id >= len(c.calls) {
  269. return
  270. }
  271. call := &c.calls[id]
  272. if !atomic.CompareAndSwapInt32(&call.active, 1, 0) {
  273. return
  274. }
  275. atomic.AddInt32(&c.nwait, -1)
  276. call.resp <- callResp{resp, nil}
  277. }
  278. func (c *Conn) ping() error {
  279. _, err := c.exec(&optionsFrame{}, nil)
  280. return err
  281. }
  282. func (c *Conn) prepareStatement(stmt string, trace Tracer) (*queryInfo, error) {
  283. c.prepMu.Lock()
  284. flight := c.prep[stmt]
  285. if flight != nil {
  286. c.prepMu.Unlock()
  287. flight.wg.Wait()
  288. return flight.info, flight.err
  289. }
  290. flight = new(inflightPrepare)
  291. flight.wg.Add(1)
  292. c.prep[stmt] = flight
  293. c.prepMu.Unlock()
  294. resp, err := c.exec(&prepareFrame{Stmt: stmt}, trace)
  295. if err != nil {
  296. flight.err = err
  297. } else {
  298. switch x := resp.(type) {
  299. case resultPreparedFrame:
  300. flight.info = &queryInfo{
  301. id: x.PreparedId,
  302. args: x.Values,
  303. }
  304. case error:
  305. flight.err = x
  306. default:
  307. flight.err = NewErrProtocol("Unknown type in response to prepare frame: %s", x)
  308. }
  309. }
  310. flight.wg.Done()
  311. if err != nil {
  312. c.prepMu.Lock()
  313. delete(c.prep, stmt)
  314. c.prepMu.Unlock()
  315. }
  316. return flight.info, flight.err
  317. }
  318. func (c *Conn) executeQuery(qry *Query) *Iter {
  319. op := &queryFrame{
  320. Stmt: qry.stmt,
  321. Cons: qry.cons,
  322. PageSize: qry.pageSize,
  323. PageState: qry.pageState,
  324. }
  325. if qry.shouldPrepare() {
  326. // Prepare all DML queries. Other queries can not be prepared.
  327. info, err := c.prepareStatement(qry.stmt, qry.trace)
  328. if err != nil {
  329. return &Iter{err: err}
  330. }
  331. if len(qry.values) != len(info.args) {
  332. return &Iter{err: ErrQueryArgLength}
  333. }
  334. op.Prepared = info.id
  335. op.Values = make([][]byte, len(qry.values))
  336. for i := 0; i < len(qry.values); i++ {
  337. val, err := Marshal(info.args[i].TypeInfo, qry.values[i])
  338. if err != nil {
  339. return &Iter{err: err}
  340. }
  341. op.Values[i] = val
  342. }
  343. }
  344. resp, err := c.exec(op, qry.trace)
  345. if err != nil {
  346. return &Iter{err: err}
  347. }
  348. switch x := resp.(type) {
  349. case resultVoidFrame:
  350. return &Iter{}
  351. case resultRowsFrame:
  352. iter := &Iter{columns: x.Columns, rows: x.Rows}
  353. if len(x.PagingState) > 0 {
  354. iter.next = &nextIter{
  355. qry: *qry,
  356. pos: int((1 - qry.prefetch) * float64(len(iter.rows))),
  357. }
  358. iter.next.qry.pageState = x.PagingState
  359. if iter.next.pos < 1 {
  360. iter.next.pos = 1
  361. }
  362. }
  363. return iter
  364. case resultKeyspaceFrame:
  365. return &Iter{}
  366. case RequestErrUnprepared:
  367. c.prepMu.Lock()
  368. if val, ok := c.prep[qry.stmt]; ok && val != nil {
  369. delete(c.prep, qry.stmt)
  370. c.prepMu.Unlock()
  371. return c.executeQuery(qry)
  372. }
  373. c.prepMu.Unlock()
  374. return &Iter{err: x}
  375. case error:
  376. return &Iter{err: x}
  377. default:
  378. return &Iter{err: NewErrProtocol("Unknown type in response to execute query: %s", x)}
  379. }
  380. }
  381. func (c *Conn) Pick(qry *Query) *Conn {
  382. if c.Closed() {
  383. return nil
  384. }
  385. return c
  386. }
  387. func (c *Conn) Closed() bool {
  388. c.closedMu.RLock()
  389. closed := c.isClosed
  390. c.closedMu.RUnlock()
  391. return closed
  392. }
  393. func (c *Conn) Close() {
  394. c.closedMu.Lock()
  395. if c.isClosed {
  396. c.closedMu.Unlock()
  397. return
  398. }
  399. c.isClosed = true
  400. c.closedMu.Unlock()
  401. c.conn.Close()
  402. }
  403. func (c *Conn) Address() string {
  404. return c.addr
  405. }
  406. func (c *Conn) UseKeyspace(keyspace string) error {
  407. resp, err := c.exec(&queryFrame{Stmt: `USE "` + keyspace + `"`, Cons: Any}, nil)
  408. if err != nil {
  409. return err
  410. }
  411. switch x := resp.(type) {
  412. case resultKeyspaceFrame:
  413. case error:
  414. return x
  415. default:
  416. return NewErrProtocol("Unknown type in response to USE: %s", x)
  417. }
  418. return nil
  419. }
  420. func (c *Conn) executeBatch(batch *Batch) error {
  421. if c.version == 1 {
  422. return ErrUnsupported
  423. }
  424. f := make(frame, headerSize, defaultFrameSize)
  425. f.setHeader(c.version, 0, 0, opBatch)
  426. f.writeByte(byte(batch.Type))
  427. f.writeShort(uint16(len(batch.Entries)))
  428. for i := 0; i < len(batch.Entries); i++ {
  429. entry := &batch.Entries[i]
  430. var info *queryInfo
  431. if len(entry.Args) > 0 {
  432. var err error
  433. info, err = c.prepareStatement(entry.Stmt, nil)
  434. if err != nil {
  435. return err
  436. }
  437. f.writeByte(1)
  438. f.writeShortBytes(info.id)
  439. } else {
  440. f.writeByte(0)
  441. f.writeLongString(entry.Stmt)
  442. }
  443. f.writeShort(uint16(len(entry.Args)))
  444. for j := 0; j < len(entry.Args); j++ {
  445. val, err := Marshal(info.args[j].TypeInfo, entry.Args[j])
  446. if err != nil {
  447. return err
  448. }
  449. f.writeBytes(val)
  450. }
  451. }
  452. f.writeConsistency(batch.Cons)
  453. resp, err := c.exec(f, nil)
  454. if err != nil {
  455. return err
  456. }
  457. switch x := resp.(type) {
  458. case resultVoidFrame:
  459. return nil
  460. case RequestErrUnprepared:
  461. c.prepMu.Lock()
  462. found := false
  463. for stmt, flight := range c.prep {
  464. if flight == nil || flight.info == nil {
  465. continue
  466. }
  467. if bytes.Equal(flight.info.id, x.StatementId) {
  468. found = true
  469. delete(c.prep, stmt)
  470. break
  471. }
  472. }
  473. c.prepMu.Unlock()
  474. if found {
  475. return c.executeBatch(batch)
  476. } else {
  477. return x
  478. }
  479. case error:
  480. return x
  481. default:
  482. return NewErrProtocol("Unknown type in response to batch statement: %s", x)
  483. }
  484. }
  485. func (c *Conn) decodeFrame(f frame, trace Tracer) (rval interface{}, err error) {
  486. defer func() {
  487. if r := recover(); r != nil {
  488. if e, ok := r.(ErrProtocol); ok {
  489. err = e
  490. return
  491. }
  492. panic(r)
  493. }
  494. }()
  495. if len(f) < headerSize {
  496. return nil, NewErrProtocol("Decoding frame: less data received than required for header: %d < %d", len(f), headerSize)
  497. } else if f[0] != c.version|flagResponse {
  498. return nil, NewErrProtocol("Decoding frame: response protocol version does not match connection protocol version (%d != %d)", f[0], c.version|flagResponse)
  499. }
  500. flags, op, f := f[1], f[3], f[headerSize:]
  501. if flags&flagCompress != 0 && len(f) > 0 && c.compressor != nil {
  502. if buf, err := c.compressor.Decode([]byte(f)); err != nil {
  503. return nil, err
  504. } else {
  505. f = frame(buf)
  506. }
  507. }
  508. if flags&flagTrace != 0 {
  509. if len(f) < 16 {
  510. return nil, NewErrProtocol("Decoding frame: length of frame less than 16 while tracing is enabled")
  511. }
  512. traceId := []byte(f[:16])
  513. f = f[16:]
  514. trace.Trace(traceId)
  515. }
  516. switch op {
  517. case opReady:
  518. return readyFrame{}, nil
  519. case opResult:
  520. switch kind := f.readInt(); kind {
  521. case resultKindVoid:
  522. return resultVoidFrame{}, nil
  523. case resultKindRows:
  524. columns, pageState := f.readMetaData()
  525. numRows := f.readInt()
  526. values := make([][]byte, numRows*len(columns))
  527. for i := 0; i < len(values); i++ {
  528. values[i] = f.readBytes()
  529. }
  530. rows := make([][][]byte, numRows)
  531. for i := 0; i < numRows; i++ {
  532. rows[i], values = values[:len(columns)], values[len(columns):]
  533. }
  534. return resultRowsFrame{columns, rows, pageState}, nil
  535. case resultKindKeyspace:
  536. keyspace := f.readString()
  537. return resultKeyspaceFrame{keyspace}, nil
  538. case resultKindPrepared:
  539. id := f.readShortBytes()
  540. values, _ := f.readMetaData()
  541. return resultPreparedFrame{id, values}, nil
  542. case resultKindSchemaChanged:
  543. return resultVoidFrame{}, nil
  544. default:
  545. return nil, NewErrProtocol("Decoding frame: unknown result kind %s", kind)
  546. }
  547. case opAuthenticate:
  548. return authenticateFrame{f.readString()}, nil
  549. case opAuthChallenge:
  550. return authChallengeFrame{f.readBytes()}, nil
  551. case opAuthSuccess:
  552. return authSuccessFrame{f.readBytes()}, nil
  553. case opSupported:
  554. return supportedFrame{}, nil
  555. case opError:
  556. return f.readError(), nil
  557. default:
  558. return nil, NewErrProtocol("Decoding frame: unknown op", op)
  559. }
  560. }
// queryInfo is the metadata of a statement prepared on a connection.
type queryInfo struct {
	id   []byte       // prepared-statement id returned by the server; sent instead of the statement text
	args []ColumnInfo // bind-parameter metadata; args[i].TypeInfo is used to marshal the i-th value
	rval []ColumnInfo // NOTE(review): not read or written in this file; presumably result-column metadata
}
// callReq is the per-stream state of a request; Conn.calls holds one per
// stream id.
type callReq struct {
	// active is 1 while a request on this stream awaits its response. It is
	// accessed atomically; dispatch swaps it from 1 to 0 before delivering.
	active int32
	// resp receives the reply; exec creates it with a buffer of one.
	resp chan callResp
}
// callResp is what a waiting request receives: either the raw response frame
// (buf) or the error that ended the connection (err).
type callResp struct {
	buf frame
	err error
}
// inflightPrepare is a prepared-statement cache entry in Conn.prep. The
// goroutine that created it fills in info or err and then calls wg.Done;
// other goroutines wait on wg before reading either field.
type inflightPrepare struct {
	info *queryInfo
	err  error
	wg   sync.WaitGroup
}
  579. var (
  580. ErrQueryArgLength = errors.New("query argument length mismatch")
  581. )