conn.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767
  1. // Copyright (c) 2012 The gocql Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package gocql
  5. import (
  6. "bufio"
  7. "crypto/tls"
  8. "errors"
  9. "fmt"
  10. "io"
  11. "log"
  12. "net"
  13. "strconv"
  14. "strings"
  15. "sync"
  16. "sync/atomic"
  17. "time"
  18. )
  19. //JoinHostPort is a utility to return a address string that can be used
  20. //gocql.Conn to form a connection with a host.
  21. func JoinHostPort(addr string, port int) string {
  22. addr = strings.TrimSpace(addr)
  23. if _, _, err := net.SplitHostPort(addr); err != nil {
  24. addr = net.JoinHostPort(addr, strconv.Itoa(port))
  25. }
  26. return addr
  27. }
  28. type Authenticator interface {
  29. Challenge(req []byte) (resp []byte, auth Authenticator, err error)
  30. Success(data []byte) error
  31. }
  32. type PasswordAuthenticator struct {
  33. Username string
  34. Password string
  35. }
  36. func (p PasswordAuthenticator) Challenge(req []byte) ([]byte, Authenticator, error) {
  37. if string(req) != "org.apache.cassandra.auth.PasswordAuthenticator" {
  38. return nil, nil, fmt.Errorf("unexpected authenticator %q", req)
  39. }
  40. resp := make([]byte, 2+len(p.Username)+len(p.Password))
  41. resp[0] = 0
  42. copy(resp[1:], p.Username)
  43. resp[len(p.Username)+1] = 0
  44. copy(resp[2+len(p.Username):], p.Password)
  45. return resp, nil, nil
  46. }
  47. func (p PasswordAuthenticator) Success(data []byte) error {
  48. return nil
  49. }
  50. type SslOptions struct {
  51. // CertPath and KeyPath are optional depending on server
  52. // config, but both fields must be omitted to avoid using a
  53. // client certificate
  54. CertPath string
  55. KeyPath string
  56. CaPath string //optional depending on server config
  57. // If you want to verify the hostname and server cert (like a wildcard for cass cluster) then you should turn this on
  58. // This option is basically the inverse of InSecureSkipVerify
  59. // See InSecureSkipVerify in http://golang.org/pkg/crypto/tls/ for more info
  60. EnableHostVerification bool
  61. }
  62. type ConnConfig struct {
  63. ProtoVersion int
  64. CQLVersion string
  65. Timeout time.Duration
  66. NumStreams int
  67. Compressor Compressor
  68. Authenticator Authenticator
  69. Keepalive time.Duration
  70. tlsConfig *tls.Config
  71. }
  72. type ConnErrorHandler interface {
  73. HandleError(conn *Conn, err error, closed bool)
  74. }
  75. // How many timeouts we will allow to occur before the connection is closed
  76. // and restarted. This is to prevent a single query timeout from killing a connection
  77. // which may be serving more queries just fine.
  78. const timeoutLimit = 10
  79. // Conn is a single connection to a Cassandra node. It can be used to execute
  80. // queries, but users are usually advised to use a more reliable, higher
  81. // level API.
  82. type Conn struct {
  83. conn net.Conn
  84. r *bufio.Reader
  85. timeout time.Duration
  86. headerBuf []byte
  87. uniq chan int
  88. calls []callReq
  89. errorHandler ConnErrorHandler
  90. compressor Compressor
  91. auth Authenticator
  92. addr string
  93. version uint8
  94. currentKeyspace string
  95. started bool
  96. closedMu sync.RWMutex
  97. isClosed bool
  98. timeouts int64
  99. }
  100. // Connect establishes a connection to a Cassandra node.
  101. // You must also call the Serve method before you can execute any queries.
  102. func Connect(addr string, cfg ConnConfig, errorHandler ConnErrorHandler) (*Conn, error) {
  103. var (
  104. err error
  105. conn net.Conn
  106. )
  107. dialer := &net.Dialer{
  108. Timeout: cfg.Timeout,
  109. }
  110. if cfg.tlsConfig != nil {
  111. // the TLS config is safe to be reused by connections but it must not
  112. // be modified after being used.
  113. conn, err = tls.DialWithDialer(dialer, "tcp", addr, cfg.tlsConfig)
  114. } else {
  115. conn, err = dialer.Dial("tcp", addr)
  116. }
  117. if err != nil {
  118. return nil, err
  119. }
  120. // going to default to proto 2
  121. if cfg.ProtoVersion < protoVersion1 || cfg.ProtoVersion > protoVersion3 {
  122. log.Printf("unsupported protocol version: %d using 2\n", cfg.ProtoVersion)
  123. cfg.ProtoVersion = 2
  124. }
  125. headerSize := 8
  126. maxStreams := 128
  127. if cfg.ProtoVersion > protoVersion2 {
  128. maxStreams = 32768
  129. headerSize = 9
  130. }
  131. if cfg.NumStreams <= 0 || cfg.NumStreams > maxStreams {
  132. cfg.NumStreams = maxStreams
  133. }
  134. c := &Conn{
  135. conn: conn,
  136. r: bufio.NewReader(conn),
  137. uniq: make(chan int, cfg.NumStreams),
  138. calls: make([]callReq, cfg.NumStreams),
  139. timeout: cfg.Timeout,
  140. version: uint8(cfg.ProtoVersion),
  141. addr: conn.RemoteAddr().String(),
  142. errorHandler: errorHandler,
  143. compressor: cfg.Compressor,
  144. auth: cfg.Authenticator,
  145. headerBuf: make([]byte, headerSize),
  146. }
  147. if cfg.Keepalive > 0 {
  148. c.setKeepalive(cfg.Keepalive)
  149. }
  150. for i := 0; i < cfg.NumStreams; i++ {
  151. c.calls[i].resp = make(chan error, 1)
  152. c.uniq <- i
  153. }
  154. go c.serve()
  155. if err := c.startup(&cfg); err != nil {
  156. conn.Close()
  157. return nil, err
  158. }
  159. c.started = true
  160. return c, nil
  161. }
  162. func (c *Conn) Write(p []byte) (int, error) {
  163. if c.timeout > 0 {
  164. c.conn.SetWriteDeadline(time.Now().Add(c.timeout))
  165. }
  166. return c.conn.Write(p)
  167. }
  168. func (c *Conn) Read(p []byte) (n int, err error) {
  169. const maxAttempts = 5
  170. for i := 0; i < maxAttempts; i++ {
  171. var nn int
  172. if c.timeout > 0 {
  173. c.conn.SetReadDeadline(time.Now().Add(c.timeout))
  174. }
  175. nn, err = io.ReadFull(c.r, p[n:])
  176. n += nn
  177. if err == nil {
  178. break
  179. }
  180. if verr, ok := err.(net.Error); !ok || !verr.Temporary() {
  181. break
  182. }
  183. }
  184. return
  185. }
  186. func (c *Conn) startup(cfg *ConnConfig) error {
  187. m := map[string]string{
  188. "CQL_VERSION": cfg.CQLVersion,
  189. }
  190. if c.compressor != nil {
  191. m["COMPRESSION"] = c.compressor.Name()
  192. }
  193. frame, err := c.exec(&writeStartupFrame{opts: m}, nil)
  194. if err != nil {
  195. return err
  196. }
  197. switch v := frame.(type) {
  198. case error:
  199. return v
  200. case *readyFrame:
  201. return nil
  202. case *authenticateFrame:
  203. return c.authenticateHandshake(v)
  204. default:
  205. return NewErrProtocol("Unknown type of response to startup frame: %s", v)
  206. }
  207. }
  208. func (c *Conn) authenticateHandshake(authFrame *authenticateFrame) error {
  209. if c.auth == nil {
  210. return fmt.Errorf("authentication required (using %q)", authFrame.class)
  211. }
  212. resp, challenger, err := c.auth.Challenge([]byte(authFrame.class))
  213. if err != nil {
  214. return err
  215. }
  216. req := &writeAuthResponseFrame{data: resp}
  217. for {
  218. frame, err := c.exec(req, nil)
  219. if err != nil {
  220. return err
  221. }
  222. switch v := frame.(type) {
  223. case error:
  224. return v
  225. case *authSuccessFrame:
  226. if challenger != nil {
  227. return challenger.Success(v.data)
  228. }
  229. return nil
  230. case *authChallengeFrame:
  231. resp, challenger, err = challenger.Challenge(v.data)
  232. if err != nil {
  233. return err
  234. }
  235. req = &writeAuthResponseFrame{
  236. data: resp,
  237. }
  238. default:
  239. return fmt.Errorf("unknown frame response during authentication: %v", v)
  240. }
  241. }
  242. }
  243. // Serve starts the stream multiplexer for this connection, which is required
  244. // to execute any queries. This method runs as long as the connection is
  245. // open and is therefore usually called in a separate goroutine.
  246. func (c *Conn) serve() {
  247. var (
  248. err error
  249. )
  250. for {
  251. err = c.recv()
  252. if err != nil {
  253. break
  254. }
  255. }
  256. c.closeWithError(err)
  257. }
  258. func (c *Conn) closeWithError(err error) {
  259. if c.Closed() {
  260. return
  261. }
  262. c.Close()
  263. for id := 0; id < len(c.calls); id++ {
  264. req := &c.calls[id]
  265. // we need to send the error to all waiting queries, put the state
  266. // of this conn into not active so that it can not execute any queries.
  267. select {
  268. case req.resp <- err:
  269. default:
  270. }
  271. close(req.resp)
  272. }
  273. if c.started {
  274. c.errorHandler.HandleError(c, err, true)
  275. }
  276. }
  277. func (c *Conn) recv() error {
  278. // not safe for concurrent reads
  279. // read a full header, ignore timeouts, as this is being ran in a loop
  280. // TODO: TCP level deadlines? or just query level deadlines?
  281. if c.timeout > 0 {
  282. c.conn.SetReadDeadline(time.Time{})
  283. }
  284. // were just reading headers over and over and copy bodies
  285. head, err := readHeader(c.r, c.headerBuf)
  286. if err != nil {
  287. return err
  288. }
  289. call := &c.calls[head.stream]
  290. err = call.framer.readFrame(&head)
  291. if err != nil {
  292. return err
  293. }
  294. // once we get to here we know that the caller must be waiting and that there
  295. // is no error.
  296. select {
  297. case call.resp <- nil:
  298. default:
  299. // in case the caller timedout
  300. }
  301. return nil
  302. }
  303. type callReq struct {
  304. // could use a waitgroup but this allows us to do timeouts on the read/send
  305. resp chan error
  306. framer *framer
  307. }
  308. func (c *Conn) releaseStream(stream int) {
  309. select {
  310. case c.uniq <- stream:
  311. default:
  312. }
  313. }
  314. func (c *Conn) handleTimeout() {
  315. if atomic.AddInt64(&c.timeouts, 1) > timeoutLimit {
  316. c.closeWithError(ErrTooManyTimeouts)
  317. }
  318. }
  319. func (c *Conn) exec(req frameWriter, tracer Tracer) (frame, error) {
  320. // TODO: move tracer onto conn
  321. stream := <-c.uniq
  322. defer c.releaseStream(stream)
  323. call := &c.calls[stream]
  324. // resp is basically a waiting semaphore protecting the framer
  325. framer := newFramer(c, c, c.compressor, c.version)
  326. call.framer = framer
  327. if tracer != nil {
  328. framer.trace()
  329. }
  330. err := req.writeFrame(framer, stream)
  331. if err != nil {
  332. return nil, err
  333. }
  334. select {
  335. case err = <-call.resp:
  336. case <-time.After(c.timeout):
  337. c.handleTimeout()
  338. return nil, ErrTimeoutNoResponse
  339. }
  340. if err != nil {
  341. return nil, err
  342. }
  343. if v := framer.header.version.version(); v != c.version {
  344. return nil, NewErrProtocol("unexpected protocol version in response: got %d expected %d", v, c.version)
  345. }
  346. frame, err := framer.parseFrame()
  347. if err != nil {
  348. return nil, err
  349. }
  350. if len(framer.traceID) > 0 {
  351. tracer.Trace(framer.traceID)
  352. }
  353. framerPool.Put(framer)
  354. call.framer = nil
  355. return frame, nil
  356. }
  357. func (c *Conn) prepareStatement(stmt string, trace Tracer) (*resultPreparedFrame, error) {
  358. stmtsLRU.Lock()
  359. if stmtsLRU.lru == nil {
  360. initStmtsLRU(defaultMaxPreparedStmts)
  361. }
  362. stmtCacheKey := c.addr + c.currentKeyspace + stmt
  363. if val, ok := stmtsLRU.lru.Get(stmtCacheKey); ok {
  364. stmtsLRU.Unlock()
  365. flight := val.(*inflightPrepare)
  366. flight.wg.Wait()
  367. return flight.info, flight.err
  368. }
  369. flight := new(inflightPrepare)
  370. flight.wg.Add(1)
  371. stmtsLRU.lru.Add(stmtCacheKey, flight)
  372. stmtsLRU.Unlock()
  373. prep := &writePrepareFrame{
  374. statement: stmt,
  375. }
  376. resp, err := c.exec(prep, trace)
  377. if err != nil {
  378. flight.err = err
  379. flight.wg.Done()
  380. return nil, err
  381. }
  382. switch x := resp.(type) {
  383. case *resultPreparedFrame:
  384. flight.info = x
  385. case error:
  386. flight.err = x
  387. default:
  388. flight.err = NewErrProtocol("Unknown type in response to prepare frame: %s", x)
  389. }
  390. flight.wg.Done()
  391. if flight.err != nil {
  392. stmtsLRU.Lock()
  393. stmtsLRU.lru.Remove(stmtCacheKey)
  394. stmtsLRU.Unlock()
  395. }
  396. return flight.info, flight.err
  397. }
  398. func (c *Conn) executeQuery(qry *Query) *Iter {
  399. params := queryParams{
  400. consistency: qry.cons,
  401. }
  402. // frame checks that it is not 0
  403. params.serialConsistency = qry.serialCons
  404. params.defaultTimestamp = qry.defaultTimestamp
  405. if len(qry.pageState) > 0 {
  406. params.pagingState = qry.pageState
  407. }
  408. if qry.pageSize > 0 {
  409. params.pageSize = qry.pageSize
  410. }
  411. var frame frameWriter
  412. if qry.shouldPrepare() {
  413. // Prepare all DML queries. Other queries can not be prepared.
  414. info, err := c.prepareStatement(qry.stmt, qry.trace)
  415. if err != nil {
  416. return &Iter{err: err}
  417. }
  418. var values []interface{}
  419. if qry.binding == nil {
  420. values = qry.values
  421. } else {
  422. binding := &QueryInfo{
  423. Id: info.preparedID,
  424. Args: info.reqMeta.columns,
  425. Rval: info.respMeta.columns,
  426. }
  427. values, err = qry.binding(binding)
  428. if err != nil {
  429. return &Iter{err: err}
  430. }
  431. }
  432. if len(values) != len(info.reqMeta.columns) {
  433. return &Iter{err: ErrQueryArgLength}
  434. }
  435. params.values = make([]queryValues, len(values))
  436. for i := 0; i < len(values); i++ {
  437. val, err := Marshal(info.reqMeta.columns[i].TypeInfo, values[i])
  438. if err != nil {
  439. return &Iter{err: err}
  440. }
  441. v := &params.values[i]
  442. v.value = val
  443. // TODO: handle query binding names
  444. }
  445. frame = &writeExecuteFrame{
  446. preparedID: info.preparedID,
  447. params: params,
  448. }
  449. } else {
  450. frame = &writeQueryFrame{
  451. statement: qry.stmt,
  452. params: params,
  453. }
  454. }
  455. resp, err := c.exec(frame, qry.trace)
  456. if err != nil {
  457. return &Iter{err: err}
  458. }
  459. switch x := resp.(type) {
  460. case *resultVoidFrame:
  461. return &Iter{}
  462. case *resultRowsFrame:
  463. iter := &Iter{
  464. meta: x.meta,
  465. rows: x.rows,
  466. }
  467. if len(x.meta.pagingState) > 0 {
  468. iter.next = &nextIter{
  469. qry: *qry,
  470. pos: int((1 - qry.prefetch) * float64(len(iter.rows))),
  471. }
  472. iter.next.qry.pageState = x.meta.pagingState
  473. if iter.next.pos < 1 {
  474. iter.next.pos = 1
  475. }
  476. }
  477. return iter
  478. case *resultKeyspaceFrame, *resultSchemaChangeFrame:
  479. return &Iter{}
  480. case *RequestErrUnprepared:
  481. stmtsLRU.Lock()
  482. stmtCacheKey := c.addr + c.currentKeyspace + qry.stmt
  483. if _, ok := stmtsLRU.lru.Get(stmtCacheKey); ok {
  484. stmtsLRU.lru.Remove(stmtCacheKey)
  485. stmtsLRU.Unlock()
  486. return c.executeQuery(qry)
  487. }
  488. stmtsLRU.Unlock()
  489. return &Iter{err: x}
  490. case error:
  491. return &Iter{err: x}
  492. default:
  493. return &Iter{err: NewErrProtocol("Unknown type in response to execute query: %s", x)}
  494. }
  495. }
  496. func (c *Conn) Pick(qry *Query) *Conn {
  497. if c.Closed() {
  498. return nil
  499. }
  500. return c
  501. }
  502. func (c *Conn) Closed() bool {
  503. c.closedMu.RLock()
  504. closed := c.isClosed
  505. c.closedMu.RUnlock()
  506. return closed
  507. }
  508. func (c *Conn) Close() {
  509. c.closedMu.Lock()
  510. if c.isClosed {
  511. c.closedMu.Unlock()
  512. return
  513. }
  514. c.isClosed = true
  515. c.closedMu.Unlock()
  516. c.conn.Close()
  517. }
  518. func (c *Conn) Address() string {
  519. return c.addr
  520. }
  521. func (c *Conn) AvailableStreams() int {
  522. return len(c.uniq)
  523. }
  524. func (c *Conn) UseKeyspace(keyspace string) error {
  525. q := &writeQueryFrame{statement: `USE "` + keyspace + `"`}
  526. q.params.consistency = Any
  527. resp, err := c.exec(q, nil)
  528. if err != nil {
  529. return err
  530. }
  531. switch x := resp.(type) {
  532. case *resultKeyspaceFrame:
  533. case error:
  534. return x
  535. default:
  536. return NewErrProtocol("unknown frame in response to USE: %v", x)
  537. }
  538. c.currentKeyspace = keyspace
  539. return nil
  540. }
  541. func (c *Conn) executeBatch(batch *Batch) error {
  542. if c.version == protoVersion1 {
  543. return ErrUnsupported
  544. }
  545. n := len(batch.Entries)
  546. req := &writeBatchFrame{
  547. typ: batch.Type,
  548. statements: make([]batchStatment, n),
  549. consistency: batch.Cons,
  550. serialConsistency: batch.serialCons,
  551. defaultTimestamp: batch.defaultTimestamp,
  552. }
  553. stmts := make(map[string]string)
  554. for i := 0; i < n; i++ {
  555. entry := &batch.Entries[i]
  556. b := &req.statements[i]
  557. if len(entry.Args) > 0 || entry.binding != nil {
  558. info, err := c.prepareStatement(entry.Stmt, nil)
  559. if err != nil {
  560. return err
  561. }
  562. var args []interface{}
  563. if entry.binding == nil {
  564. args = entry.Args
  565. } else {
  566. binding := &QueryInfo{
  567. Id: info.preparedID,
  568. Args: info.reqMeta.columns,
  569. Rval: info.respMeta.columns,
  570. }
  571. args, err = entry.binding(binding)
  572. if err != nil {
  573. return err
  574. }
  575. }
  576. if len(args) != len(info.reqMeta.columns) {
  577. return ErrQueryArgLength
  578. }
  579. b.preparedID = info.preparedID
  580. stmts[string(info.preparedID)] = entry.Stmt
  581. b.values = make([]queryValues, len(info.reqMeta.columns))
  582. for j := 0; j < len(info.reqMeta.columns); j++ {
  583. val, err := Marshal(info.reqMeta.columns[j].TypeInfo, args[j])
  584. if err != nil {
  585. return err
  586. }
  587. b.values[j].value = val
  588. // TODO: add names
  589. }
  590. } else {
  591. b.statement = entry.Stmt
  592. }
  593. }
  594. // TODO: should batch support tracing?
  595. resp, err := c.exec(req, nil)
  596. if err != nil {
  597. return err
  598. }
  599. switch x := resp.(type) {
  600. case *resultVoidFrame:
  601. return nil
  602. case *RequestErrUnprepared:
  603. stmt, found := stmts[string(x.StatementId)]
  604. if found {
  605. stmtsLRU.Lock()
  606. stmtsLRU.lru.Remove(c.addr + c.currentKeyspace + stmt)
  607. stmtsLRU.Unlock()
  608. }
  609. if found {
  610. return c.executeBatch(batch)
  611. } else {
  612. return x
  613. }
  614. case error:
  615. return x
  616. default:
  617. return NewErrProtocol("Unknown type in response to batch statement: %s", x)
  618. }
  619. }
  620. func (c *Conn) setKeepalive(d time.Duration) error {
  621. if tc, ok := c.conn.(*net.TCPConn); ok {
  622. err := tc.SetKeepAlivePeriod(d)
  623. if err != nil {
  624. return err
  625. }
  626. return tc.SetKeepAlive(true)
  627. }
  628. return nil
  629. }
  630. type inflightPrepare struct {
  631. info *resultPreparedFrame
  632. err error
  633. wg sync.WaitGroup
  634. }
  635. var (
  636. ErrQueryArgLength = errors.New("query argument length mismatch")
  637. ErrTimeoutNoResponse = errors.New("gocql: no response recieved from cassandra within timeout period")
  638. ErrTooManyTimeouts = errors.New("gocql: too many query timeouts on the connection")
  639. )