conn.go 16 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768
  1. // Copyright (c) 2012 The gocql Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package gocql
  5. import (
  6. "bufio"
  7. "crypto/tls"
  8. "errors"
  9. "fmt"
  10. "io"
  11. "log"
  12. "net"
  13. "strconv"
  14. "strings"
  15. "sync"
  16. "sync/atomic"
  17. "time"
  18. )
  19. //JoinHostPort is a utility to return a address string that can be used
  20. //gocql.Conn to form a connection with a host.
  21. func JoinHostPort(addr string, port int) string {
  22. addr = strings.TrimSpace(addr)
  23. if _, _, err := net.SplitHostPort(addr); err != nil {
  24. addr = net.JoinHostPort(addr, strconv.Itoa(port))
  25. }
  26. return addr
  27. }
  28. type Authenticator interface {
  29. Challenge(req []byte) (resp []byte, auth Authenticator, err error)
  30. Success(data []byte) error
  31. }
  32. type PasswordAuthenticator struct {
  33. Username string
  34. Password string
  35. }
  36. func (p PasswordAuthenticator) Challenge(req []byte) ([]byte, Authenticator, error) {
  37. if string(req) != "org.apache.cassandra.auth.PasswordAuthenticator" {
  38. return nil, nil, fmt.Errorf("unexpected authenticator %q", req)
  39. }
  40. resp := make([]byte, 2+len(p.Username)+len(p.Password))
  41. resp[0] = 0
  42. copy(resp[1:], p.Username)
  43. resp[len(p.Username)+1] = 0
  44. copy(resp[2+len(p.Username):], p.Password)
  45. return resp, nil, nil
  46. }
  47. func (p PasswordAuthenticator) Success(data []byte) error {
  48. return nil
  49. }
// SslOptions describes the certificate files and verification policy used
// for TLS connections to Cassandra.
// NOTE(review): SslOptions is not read in this file; presumably it is turned
// into the tls.Config carried by ConnConfig elsewhere — confirm.
type SslOptions struct {
	// CertPath and KeyPath (the client certificate and its private key) may
	// be required depending on the server config. Leave both empty to
	// connect without a client certificate.
	CertPath string
	KeyPath  string
	CaPath   string // CA bundle used to verify the server; optional depending on server config
	// EnableHostVerification verifies the server certificate and hostname
	// (for example against a wildcard certificate for the cluster). It is
	// effectively the inverse of tls.Config.InsecureSkipVerify; see
	// https://golang.org/pkg/crypto/tls/ for details.
	EnableHostVerification bool
}
// ConnConfig holds the per-connection settings consumed by Connect.
type ConnConfig struct {
	// ProtoVersion is the native protocol version to use. Connect replaces
	// values outside [protoVersion1, protoVersion3] with 2.
	ProtoVersion int
	// CQLVersion is sent as CQL_VERSION in the STARTUP frame.
	CQLVersion string
	// Timeout is the dial timeout. It is also the read/write deadline set
	// before socket I/O (only when > 0) and the time exec waits for a response.
	Timeout time.Duration
	// NumStreams is the number of concurrent request streams. Values <= 0 or
	// above the protocol maximum (128 before v3, 32768 for v3) are set to
	// that maximum.
	NumStreams int
	// Compressor, if non-nil, is advertised as COMPRESSION during STARTUP
	// and handed to every frame's framer.
	Compressor Compressor
	// Authenticator answers the server's AUTHENTICATE request. If it is nil,
	// servers that require authentication fail the handshake.
	Authenticator Authenticator
	// Keepalive, when > 0, requests TCP keepalive with this period.
	Keepalive time.Duration
	// tlsConfig, when non-nil, makes Connect dial with TLS. It may be shared
	// between connections but must not be modified after first use.
	tlsConfig *tls.Config
}
// ConnErrorHandler is notified when a connection fails. In this file it is
// only invoked from closeWithError, with closed == true, after a connection
// that completed its startup handshake has been torn down because of err.
type ConnErrorHandler interface {
	HandleError(conn *Conn, err error, closed bool)
}
  75. // How many timeouts we will allow to occur before the connection is closed
  76. // and restarted. This is to prevent a single query timeout from killing a connection
  77. // which may be serving more queries just fine.
  78. // Default is 10, should not be changed concurrently with queries.
  79. var TimeoutLimit int64 = 10
// Conn is a single connection to a Cassandra node. It can be used to execute
// queries, but users are usually advised to use a more reliable, higher
// level API.
type Conn struct {
	conn net.Conn      // underlying socket (a *tls.Conn when TLS is configured)
	r    *bufio.Reader // buffered reader over conn, shared by recv and Read

	// timeout is used as the read/write deadline (when > 0) and as the
	// response wait in exec.
	timeout   time.Duration
	headerBuf []byte // scratch space for frame headers (8 bytes before v3, 9 for v3)

	uniq  chan int  // free stream ids; exec takes one and releaseStream returns it
	calls []callReq // per-stream request state, indexed by stream id

	errorHandler ConnErrorHandler
	compressor   Compressor
	auth         Authenticator

	addr            string // remote address reported by the socket at connect time
	version         uint8  // protocol version written in requests and expected in responses
	currentKeyspace string // last keyspace set by UseKeyspace; part of the prepared-statement cache key

	// started is set once the startup handshake succeeded. It gates
	// errorHandler notifications in closeWithError.
	started bool

	closedMu sync.RWMutex // guards isClosed
	isClosed bool

	timeouts int64 // number of response timeouts seen by exec; updated atomically
}
  101. // Connect establishes a connection to a Cassandra node.
  102. // You must also call the Serve method before you can execute any queries.
  103. func Connect(addr string, cfg ConnConfig, errorHandler ConnErrorHandler) (*Conn, error) {
  104. var (
  105. err error
  106. conn net.Conn
  107. )
  108. dialer := &net.Dialer{
  109. Timeout: cfg.Timeout,
  110. }
  111. if cfg.tlsConfig != nil {
  112. // the TLS config is safe to be reused by connections but it must not
  113. // be modified after being used.
  114. conn, err = tls.DialWithDialer(dialer, "tcp", addr, cfg.tlsConfig)
  115. } else {
  116. conn, err = dialer.Dial("tcp", addr)
  117. }
  118. if err != nil {
  119. return nil, err
  120. }
  121. // going to default to proto 2
  122. if cfg.ProtoVersion < protoVersion1 || cfg.ProtoVersion > protoVersion3 {
  123. log.Printf("unsupported protocol version: %d using 2\n", cfg.ProtoVersion)
  124. cfg.ProtoVersion = 2
  125. }
  126. headerSize := 8
  127. maxStreams := 128
  128. if cfg.ProtoVersion > protoVersion2 {
  129. maxStreams = 32768
  130. headerSize = 9
  131. }
  132. if cfg.NumStreams <= 0 || cfg.NumStreams > maxStreams {
  133. cfg.NumStreams = maxStreams
  134. }
  135. c := &Conn{
  136. conn: conn,
  137. r: bufio.NewReader(conn),
  138. uniq: make(chan int, cfg.NumStreams),
  139. calls: make([]callReq, cfg.NumStreams),
  140. timeout: cfg.Timeout,
  141. version: uint8(cfg.ProtoVersion),
  142. addr: conn.RemoteAddr().String(),
  143. errorHandler: errorHandler,
  144. compressor: cfg.Compressor,
  145. auth: cfg.Authenticator,
  146. headerBuf: make([]byte, headerSize),
  147. }
  148. if cfg.Keepalive > 0 {
  149. c.setKeepalive(cfg.Keepalive)
  150. }
  151. for i := 0; i < cfg.NumStreams; i++ {
  152. c.calls[i].resp = make(chan error, 1)
  153. c.uniq <- i
  154. }
  155. go c.serve()
  156. if err := c.startup(&cfg); err != nil {
  157. conn.Close()
  158. return nil, err
  159. }
  160. c.started = true
  161. return c, nil
  162. }
  163. func (c *Conn) Write(p []byte) (int, error) {
  164. if c.timeout > 0 {
  165. c.conn.SetWriteDeadline(time.Now().Add(c.timeout))
  166. }
  167. return c.conn.Write(p)
  168. }
  169. func (c *Conn) Read(p []byte) (n int, err error) {
  170. const maxAttempts = 5
  171. for i := 0; i < maxAttempts; i++ {
  172. var nn int
  173. if c.timeout > 0 {
  174. c.conn.SetReadDeadline(time.Now().Add(c.timeout))
  175. }
  176. nn, err = io.ReadFull(c.r, p[n:])
  177. n += nn
  178. if err == nil {
  179. break
  180. }
  181. if verr, ok := err.(net.Error); !ok || !verr.Temporary() {
  182. break
  183. }
  184. }
  185. return
  186. }
  187. func (c *Conn) startup(cfg *ConnConfig) error {
  188. m := map[string]string{
  189. "CQL_VERSION": cfg.CQLVersion,
  190. }
  191. if c.compressor != nil {
  192. m["COMPRESSION"] = c.compressor.Name()
  193. }
  194. frame, err := c.exec(&writeStartupFrame{opts: m}, nil)
  195. if err != nil {
  196. return err
  197. }
  198. switch v := frame.(type) {
  199. case error:
  200. return v
  201. case *readyFrame:
  202. return nil
  203. case *authenticateFrame:
  204. return c.authenticateHandshake(v)
  205. default:
  206. return NewErrProtocol("Unknown type of response to startup frame: %s", v)
  207. }
  208. }
  209. func (c *Conn) authenticateHandshake(authFrame *authenticateFrame) error {
  210. if c.auth == nil {
  211. return fmt.Errorf("authentication required (using %q)", authFrame.class)
  212. }
  213. resp, challenger, err := c.auth.Challenge([]byte(authFrame.class))
  214. if err != nil {
  215. return err
  216. }
  217. req := &writeAuthResponseFrame{data: resp}
  218. for {
  219. frame, err := c.exec(req, nil)
  220. if err != nil {
  221. return err
  222. }
  223. switch v := frame.(type) {
  224. case error:
  225. return v
  226. case *authSuccessFrame:
  227. if challenger != nil {
  228. return challenger.Success(v.data)
  229. }
  230. return nil
  231. case *authChallengeFrame:
  232. resp, challenger, err = challenger.Challenge(v.data)
  233. if err != nil {
  234. return err
  235. }
  236. req = &writeAuthResponseFrame{
  237. data: resp,
  238. }
  239. default:
  240. return fmt.Errorf("unknown frame response during authentication: %v", v)
  241. }
  242. }
  243. }
  244. // Serve starts the stream multiplexer for this connection, which is required
  245. // to execute any queries. This method runs as long as the connection is
  246. // open and is therefore usually called in a separate goroutine.
  247. func (c *Conn) serve() {
  248. var (
  249. err error
  250. )
  251. for {
  252. err = c.recv()
  253. if err != nil {
  254. break
  255. }
  256. }
  257. c.closeWithError(err)
  258. }
  259. func (c *Conn) closeWithError(err error) {
  260. if c.Closed() {
  261. return
  262. }
  263. c.Close()
  264. for id := 0; id < len(c.calls); id++ {
  265. req := &c.calls[id]
  266. // we need to send the error to all waiting queries, put the state
  267. // of this conn into not active so that it can not execute any queries.
  268. select {
  269. case req.resp <- err:
  270. default:
  271. }
  272. close(req.resp)
  273. }
  274. if c.started {
  275. c.errorHandler.HandleError(c, err, true)
  276. }
  277. }
  278. func (c *Conn) recv() error {
  279. // not safe for concurrent reads
  280. // read a full header, ignore timeouts, as this is being ran in a loop
  281. // TODO: TCP level deadlines? or just query level deadlines?
  282. if c.timeout > 0 {
  283. c.conn.SetReadDeadline(time.Time{})
  284. }
  285. // were just reading headers over and over and copy bodies
  286. head, err := readHeader(c.r, c.headerBuf)
  287. if err != nil {
  288. return err
  289. }
  290. call := &c.calls[head.stream]
  291. err = call.framer.readFrame(&head)
  292. if err != nil {
  293. return err
  294. }
  295. // once we get to here we know that the caller must be waiting and that there
  296. // is no error.
  297. select {
  298. case call.resp <- nil:
  299. default:
  300. // in case the caller timedout
  301. }
  302. return nil
  303. }
// callReq is the per-stream state of a request.
type callReq struct {
	// resp (buffered, capacity 1 as created in Connect) is signalled by recv
	// with nil once the response frame has been read into framer. When the
	// connection dies, closeWithError sends an error if there is room and
	// then closes it.
	// could use a waitgroup but this allows us to do timeouts on the read/send
	resp chan error
	// framer is the framer exec allocated for the current request. recv
	// reads the response into it.
	framer *framer
}
  309. func (c *Conn) releaseStream(stream int) {
  310. select {
  311. case c.uniq <- stream:
  312. default:
  313. }
  314. }
  315. func (c *Conn) handleTimeout() {
  316. if atomic.AddInt64(&c.timeouts, 1) > TimeoutLimit {
  317. c.closeWithError(ErrTooManyTimeouts)
  318. }
  319. }
  320. func (c *Conn) exec(req frameWriter, tracer Tracer) (frame, error) {
  321. // TODO: move tracer onto conn
  322. stream := <-c.uniq
  323. defer c.releaseStream(stream)
  324. call := &c.calls[stream]
  325. // resp is basically a waiting semaphore protecting the framer
  326. framer := newFramer(c, c, c.compressor, c.version)
  327. call.framer = framer
  328. if tracer != nil {
  329. framer.trace()
  330. }
  331. err := req.writeFrame(framer, stream)
  332. if err != nil {
  333. return nil, err
  334. }
  335. select {
  336. case err = <-call.resp:
  337. case <-time.After(c.timeout):
  338. c.handleTimeout()
  339. return nil, ErrTimeoutNoResponse
  340. }
  341. if err != nil {
  342. return nil, err
  343. }
  344. if v := framer.header.version.version(); v != c.version {
  345. return nil, NewErrProtocol("unexpected protocol version in response: got %d expected %d", v, c.version)
  346. }
  347. frame, err := framer.parseFrame()
  348. if err != nil {
  349. return nil, err
  350. }
  351. if len(framer.traceID) > 0 {
  352. tracer.Trace(framer.traceID)
  353. }
  354. framerPool.Put(framer)
  355. call.framer = nil
  356. return frame, nil
  357. }
  358. func (c *Conn) prepareStatement(stmt string, trace Tracer) (*resultPreparedFrame, error) {
  359. stmtsLRU.Lock()
  360. if stmtsLRU.lru == nil {
  361. initStmtsLRU(defaultMaxPreparedStmts)
  362. }
  363. stmtCacheKey := c.addr + c.currentKeyspace + stmt
  364. if val, ok := stmtsLRU.lru.Get(stmtCacheKey); ok {
  365. stmtsLRU.Unlock()
  366. flight := val.(*inflightPrepare)
  367. flight.wg.Wait()
  368. return flight.info, flight.err
  369. }
  370. flight := new(inflightPrepare)
  371. flight.wg.Add(1)
  372. stmtsLRU.lru.Add(stmtCacheKey, flight)
  373. stmtsLRU.Unlock()
  374. prep := &writePrepareFrame{
  375. statement: stmt,
  376. }
  377. resp, err := c.exec(prep, trace)
  378. if err != nil {
  379. flight.err = err
  380. flight.wg.Done()
  381. return nil, err
  382. }
  383. switch x := resp.(type) {
  384. case *resultPreparedFrame:
  385. flight.info = x
  386. case error:
  387. flight.err = x
  388. default:
  389. flight.err = NewErrProtocol("Unknown type in response to prepare frame: %s", x)
  390. }
  391. flight.wg.Done()
  392. if flight.err != nil {
  393. stmtsLRU.Lock()
  394. stmtsLRU.lru.Remove(stmtCacheKey)
  395. stmtsLRU.Unlock()
  396. }
  397. return flight.info, flight.err
  398. }
// executeQuery runs qry on this connection and wraps the outcome in an Iter.
// Failures are reported through the Iter's err field rather than returned.
//
// When qry.shouldPrepare() is true, the statement is prepared through the
// shared prepared-statement cache and sent as EXECUTE. Its values come from
// qry.values or qry.binding and are marshalled against the prepared argument
// metadata. Everything else is sent as a plain QUERY.
func (c *Conn) executeQuery(qry *Query) *Iter {
	params := queryParams{
		consistency: qry.cons,
	}

	// frame checks that it is not 0
	params.serialConsistency = qry.serialCons
	params.defaultTimestamp = qry.defaultTimestamp

	if len(qry.pageState) > 0 {
		params.pagingState = qry.pageState
	}
	if qry.pageSize > 0 {
		params.pageSize = qry.pageSize
	}

	var frame frameWriter
	if qry.shouldPrepare() {
		// Prepare all DML queries. Other queries can not be prepared.
		info, err := c.prepareStatement(qry.stmt, qry.trace)
		if err != nil {
			return &Iter{err: err}
		}

		var values []interface{}

		if qry.binding == nil {
			values = qry.values
		} else {
			// Let the caller's binding function compute the values from the
			// prepared statement's argument and result metadata.
			binding := &QueryInfo{
				Id:   info.preparedID,
				Args: info.reqMeta.columns,
				Rval: info.respMeta.columns,
			}

			values, err = qry.binding(binding)
			if err != nil {
				return &Iter{err: err}
			}
		}

		// Exactly one value per bind marker is required.
		if len(values) != len(info.reqMeta.columns) {
			return &Iter{err: ErrQueryArgLength}
		}
		params.values = make([]queryValues, len(values))
		for i := 0; i < len(values); i++ {
			val, err := Marshal(info.reqMeta.columns[i].TypeInfo, values[i])
			if err != nil {
				return &Iter{err: err}
			}

			v := &params.values[i]
			v.value = val
			// TODO: handle query binding names
		}

		frame = &writeExecuteFrame{
			preparedID: info.preparedID,
			params:     params,
		}
	} else {
		frame = &writeQueryFrame{
			statement: qry.stmt,
			params:    params,
		}
	}

	resp, err := c.exec(frame, qry.trace)
	if err != nil {
		return &Iter{err: err}
	}

	switch x := resp.(type) {
	case *resultVoidFrame:
		return &Iter{}
	case *resultRowsFrame:
		iter := &Iter{
			meta: x.meta,
			rows: x.rows,
		}

		if len(x.meta.pagingState) > 0 {
			// More pages exist. A copy of qry carrying the server's paging
			// state is kept for the next page. pos is derived from
			// qry.prefetch and is at least 1.
			// NOTE(review): pos presumably marks the row at which the next
			// page is prefetched; its consumer is not in this file.
			iter.next = &nextIter{
				qry: *qry,
				pos: int((1 - qry.prefetch) * float64(len(iter.rows))),
			}

			iter.next.qry.pageState = x.meta.pagingState
			if iter.next.pos < 1 {
				iter.next.pos = 1
			}
		}

		return iter
	case *resultKeyspaceFrame, *resultSchemaChangeFrame:
		return &Iter{}
	case *RequestErrUnprepared:
		// The node no longer knows the prepared statement. Evict it from the
		// cache and retry, which prepares it again.
		// NOTE(review): the retry is unbounded; a node that keeps answering
		// Unprepared would recurse indefinitely.
		stmtsLRU.Lock()
		stmtCacheKey := c.addr + c.currentKeyspace + qry.stmt
		if _, ok := stmtsLRU.lru.Get(stmtCacheKey); ok {
			stmtsLRU.lru.Remove(stmtCacheKey)
			stmtsLRU.Unlock()
			return c.executeQuery(qry)
		}
		stmtsLRU.Unlock()
		return &Iter{err: x}
	case error:
		return &Iter{err: x}
	default:
		return &Iter{err: NewErrProtocol("Unknown type in response to execute query: %s", x)}
	}
}
  497. func (c *Conn) Pick(qry *Query) *Conn {
  498. if c.Closed() {
  499. return nil
  500. }
  501. return c
  502. }
  503. func (c *Conn) Closed() bool {
  504. c.closedMu.RLock()
  505. closed := c.isClosed
  506. c.closedMu.RUnlock()
  507. return closed
  508. }
  509. func (c *Conn) Close() {
  510. c.closedMu.Lock()
  511. if c.isClosed {
  512. c.closedMu.Unlock()
  513. return
  514. }
  515. c.isClosed = true
  516. c.closedMu.Unlock()
  517. c.conn.Close()
  518. }
  519. func (c *Conn) Address() string {
  520. return c.addr
  521. }
  522. func (c *Conn) AvailableStreams() int {
  523. return len(c.uniq)
  524. }
  525. func (c *Conn) UseKeyspace(keyspace string) error {
  526. q := &writeQueryFrame{statement: `USE "` + keyspace + `"`}
  527. q.params.consistency = Any
  528. resp, err := c.exec(q, nil)
  529. if err != nil {
  530. return err
  531. }
  532. switch x := resp.(type) {
  533. case *resultKeyspaceFrame:
  534. case error:
  535. return x
  536. default:
  537. return NewErrProtocol("unknown frame in response to USE: %v", x)
  538. }
  539. c.currentKeyspace = keyspace
  540. return nil
  541. }
  542. func (c *Conn) executeBatch(batch *Batch) error {
  543. if c.version == protoVersion1 {
  544. return ErrUnsupported
  545. }
  546. n := len(batch.Entries)
  547. req := &writeBatchFrame{
  548. typ: batch.Type,
  549. statements: make([]batchStatment, n),
  550. consistency: batch.Cons,
  551. serialConsistency: batch.serialCons,
  552. defaultTimestamp: batch.defaultTimestamp,
  553. }
  554. stmts := make(map[string]string)
  555. for i := 0; i < n; i++ {
  556. entry := &batch.Entries[i]
  557. b := &req.statements[i]
  558. if len(entry.Args) > 0 || entry.binding != nil {
  559. info, err := c.prepareStatement(entry.Stmt, nil)
  560. if err != nil {
  561. return err
  562. }
  563. var args []interface{}
  564. if entry.binding == nil {
  565. args = entry.Args
  566. } else {
  567. binding := &QueryInfo{
  568. Id: info.preparedID,
  569. Args: info.reqMeta.columns,
  570. Rval: info.respMeta.columns,
  571. }
  572. args, err = entry.binding(binding)
  573. if err != nil {
  574. return err
  575. }
  576. }
  577. if len(args) != len(info.reqMeta.columns) {
  578. return ErrQueryArgLength
  579. }
  580. b.preparedID = info.preparedID
  581. stmts[string(info.preparedID)] = entry.Stmt
  582. b.values = make([]queryValues, len(info.reqMeta.columns))
  583. for j := 0; j < len(info.reqMeta.columns); j++ {
  584. val, err := Marshal(info.reqMeta.columns[j].TypeInfo, args[j])
  585. if err != nil {
  586. return err
  587. }
  588. b.values[j].value = val
  589. // TODO: add names
  590. }
  591. } else {
  592. b.statement = entry.Stmt
  593. }
  594. }
  595. // TODO: should batch support tracing?
  596. resp, err := c.exec(req, nil)
  597. if err != nil {
  598. return err
  599. }
  600. switch x := resp.(type) {
  601. case *resultVoidFrame:
  602. return nil
  603. case *RequestErrUnprepared:
  604. stmt, found := stmts[string(x.StatementId)]
  605. if found {
  606. stmtsLRU.Lock()
  607. stmtsLRU.lru.Remove(c.addr + c.currentKeyspace + stmt)
  608. stmtsLRU.Unlock()
  609. }
  610. if found {
  611. return c.executeBatch(batch)
  612. } else {
  613. return x
  614. }
  615. case error:
  616. return x
  617. default:
  618. return NewErrProtocol("Unknown type in response to batch statement: %s", x)
  619. }
  620. }
  621. func (c *Conn) setKeepalive(d time.Duration) error {
  622. if tc, ok := c.conn.(*net.TCPConn); ok {
  623. err := tc.SetKeepAlivePeriod(d)
  624. if err != nil {
  625. return err
  626. }
  627. return tc.SetKeepAlive(true)
  628. }
  629. return nil
  630. }
// inflightPrepare is the value stored in the prepared-statement cache.
// Whoever inserts it performs the PREPARE and calls wg.Done when finished.
// Other callers wait on wg and then read info or err.
type inflightPrepare struct {
	info *resultPreparedFrame // set when the PREPARE succeeded
	err  error                // set when the PREPARE failed

	wg sync.WaitGroup
}
  636. var (
  637. ErrQueryArgLength = errors.New("query argument length mismatch")
  638. ErrTimeoutNoResponse = errors.New("gocql: no response recieved from cassandra within timeout period")
  639. ErrTooManyTimeouts = errors.New("gocql: too many query timeouts on the connection")
  640. )