conn.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770
  1. // Copyright (c) 2012 The gocql Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package gocql
  5. import (
  6. "bufio"
  7. "crypto/tls"
  8. "errors"
  9. "fmt"
  10. "io"
  11. "log"
  12. "net"
  13. "strconv"
  14. "strings"
  15. "sync"
  16. "sync/atomic"
  17. "time"
  18. )
  19. //JoinHostPort is a utility to return a address string that can be used
  20. //gocql.Conn to form a connection with a host.
  21. func JoinHostPort(addr string, port int) string {
  22. addr = strings.TrimSpace(addr)
  23. if _, _, err := net.SplitHostPort(addr); err != nil {
  24. addr = net.JoinHostPort(addr, strconv.Itoa(port))
  25. }
  26. return addr
  27. }
  28. type Authenticator interface {
  29. Challenge(req []byte) (resp []byte, auth Authenticator, err error)
  30. Success(data []byte) error
  31. }
  32. type PasswordAuthenticator struct {
  33. Username string
  34. Password string
  35. }
  36. func (p PasswordAuthenticator) Challenge(req []byte) ([]byte, Authenticator, error) {
  37. if string(req) != "org.apache.cassandra.auth.PasswordAuthenticator" {
  38. return nil, nil, fmt.Errorf("unexpected authenticator %q", req)
  39. }
  40. resp := make([]byte, 2+len(p.Username)+len(p.Password))
  41. resp[0] = 0
  42. copy(resp[1:], p.Username)
  43. resp[len(p.Username)+1] = 0
  44. copy(resp[2+len(p.Username):], p.Password)
  45. return resp, nil, nil
  46. }
  47. func (p PasswordAuthenticator) Success(data []byte) error {
  48. return nil
  49. }
  50. type SslOptions struct {
  51. tls.Config
  52. // CertPath and KeyPath are optional depending on server
  53. // config, but both fields must be omitted to avoid using a
  54. // client certificate
  55. CertPath string
  56. KeyPath string
  57. CaPath string //optional depending on server config
  58. // If you want to verify the hostname and server cert (like a wildcard for cass cluster) then you should turn this on
  59. // This option is basically the inverse of InSecureSkipVerify
  60. // See InSecureSkipVerify in http://golang.org/pkg/crypto/tls/ for more info
  61. EnableHostVerification bool
  62. }
  63. type ConnConfig struct {
  64. ProtoVersion int
  65. CQLVersion string
  66. Timeout time.Duration
  67. NumStreams int
  68. Compressor Compressor
  69. Authenticator Authenticator
  70. Keepalive time.Duration
  71. tlsConfig *tls.Config
  72. }
  73. type ConnErrorHandler interface {
  74. HandleError(conn *Conn, err error, closed bool)
  75. }
  76. // How many timeouts we will allow to occur before the connection is closed
  77. // and restarted. This is to prevent a single query timeout from killing a connection
  78. // which may be serving more queries just fine.
  79. // Default is 10, should not be changed concurrently with queries.
  80. var TimeoutLimit int64 = 10
  81. // Conn is a single connection to a Cassandra node. It can be used to execute
  82. // queries, but users are usually advised to use a more reliable, higher
  83. // level API.
  84. type Conn struct {
  85. conn net.Conn
  86. r *bufio.Reader
  87. timeout time.Duration
  88. headerBuf []byte
  89. uniq chan int
  90. calls []callReq
  91. errorHandler ConnErrorHandler
  92. compressor Compressor
  93. auth Authenticator
  94. addr string
  95. version uint8
  96. currentKeyspace string
  97. started bool
  98. closedMu sync.RWMutex
  99. isClosed bool
  100. timeouts int64
  101. }
  102. // Connect establishes a connection to a Cassandra node.
  103. // You must also call the Serve method before you can execute any queries.
  104. func Connect(addr string, cfg ConnConfig, errorHandler ConnErrorHandler) (*Conn, error) {
  105. var (
  106. err error
  107. conn net.Conn
  108. )
  109. dialer := &net.Dialer{
  110. Timeout: cfg.Timeout,
  111. }
  112. if cfg.tlsConfig != nil {
  113. // the TLS config is safe to be reused by connections but it must not
  114. // be modified after being used.
  115. conn, err = tls.DialWithDialer(dialer, "tcp", addr, cfg.tlsConfig)
  116. } else {
  117. conn, err = dialer.Dial("tcp", addr)
  118. }
  119. if err != nil {
  120. return nil, err
  121. }
  122. // going to default to proto 2
  123. if cfg.ProtoVersion < protoVersion1 || cfg.ProtoVersion > protoVersion3 {
  124. log.Printf("unsupported protocol version: %d using 2\n", cfg.ProtoVersion)
  125. cfg.ProtoVersion = 2
  126. }
  127. headerSize := 8
  128. maxStreams := 128
  129. if cfg.ProtoVersion > protoVersion2 {
  130. maxStreams = 32768
  131. headerSize = 9
  132. }
  133. if cfg.NumStreams <= 0 || cfg.NumStreams > maxStreams {
  134. cfg.NumStreams = maxStreams
  135. }
  136. c := &Conn{
  137. conn: conn,
  138. r: bufio.NewReader(conn),
  139. uniq: make(chan int, cfg.NumStreams),
  140. calls: make([]callReq, cfg.NumStreams),
  141. timeout: cfg.Timeout,
  142. version: uint8(cfg.ProtoVersion),
  143. addr: conn.RemoteAddr().String(),
  144. errorHandler: errorHandler,
  145. compressor: cfg.Compressor,
  146. auth: cfg.Authenticator,
  147. headerBuf: make([]byte, headerSize),
  148. }
  149. if cfg.Keepalive > 0 {
  150. c.setKeepalive(cfg.Keepalive)
  151. }
  152. for i := 0; i < cfg.NumStreams; i++ {
  153. c.calls[i].resp = make(chan error, 1)
  154. c.uniq <- i
  155. }
  156. go c.serve()
  157. if err := c.startup(&cfg); err != nil {
  158. conn.Close()
  159. return nil, err
  160. }
  161. c.started = true
  162. return c, nil
  163. }
  164. func (c *Conn) Write(p []byte) (int, error) {
  165. if c.timeout > 0 {
  166. c.conn.SetWriteDeadline(time.Now().Add(c.timeout))
  167. }
  168. return c.conn.Write(p)
  169. }
  170. func (c *Conn) Read(p []byte) (n int, err error) {
  171. const maxAttempts = 5
  172. for i := 0; i < maxAttempts; i++ {
  173. var nn int
  174. if c.timeout > 0 {
  175. c.conn.SetReadDeadline(time.Now().Add(c.timeout))
  176. }
  177. nn, err = io.ReadFull(c.r, p[n:])
  178. n += nn
  179. if err == nil {
  180. break
  181. }
  182. if verr, ok := err.(net.Error); !ok || !verr.Temporary() {
  183. break
  184. }
  185. }
  186. return
  187. }
  188. func (c *Conn) startup(cfg *ConnConfig) error {
  189. m := map[string]string{
  190. "CQL_VERSION": cfg.CQLVersion,
  191. }
  192. if c.compressor != nil {
  193. m["COMPRESSION"] = c.compressor.Name()
  194. }
  195. frame, err := c.exec(&writeStartupFrame{opts: m}, nil)
  196. if err != nil {
  197. return err
  198. }
  199. switch v := frame.(type) {
  200. case error:
  201. return v
  202. case *readyFrame:
  203. return nil
  204. case *authenticateFrame:
  205. return c.authenticateHandshake(v)
  206. default:
  207. return NewErrProtocol("Unknown type of response to startup frame: %s", v)
  208. }
  209. }
  210. func (c *Conn) authenticateHandshake(authFrame *authenticateFrame) error {
  211. if c.auth == nil {
  212. return fmt.Errorf("authentication required (using %q)", authFrame.class)
  213. }
  214. resp, challenger, err := c.auth.Challenge([]byte(authFrame.class))
  215. if err != nil {
  216. return err
  217. }
  218. req := &writeAuthResponseFrame{data: resp}
  219. for {
  220. frame, err := c.exec(req, nil)
  221. if err != nil {
  222. return err
  223. }
  224. switch v := frame.(type) {
  225. case error:
  226. return v
  227. case *authSuccessFrame:
  228. if challenger != nil {
  229. return challenger.Success(v.data)
  230. }
  231. return nil
  232. case *authChallengeFrame:
  233. resp, challenger, err = challenger.Challenge(v.data)
  234. if err != nil {
  235. return err
  236. }
  237. req = &writeAuthResponseFrame{
  238. data: resp,
  239. }
  240. default:
  241. return fmt.Errorf("unknown frame response during authentication: %v", v)
  242. }
  243. }
  244. }
  245. // Serve starts the stream multiplexer for this connection, which is required
  246. // to execute any queries. This method runs as long as the connection is
  247. // open and is therefore usually called in a separate goroutine.
  248. func (c *Conn) serve() {
  249. var (
  250. err error
  251. )
  252. for {
  253. err = c.recv()
  254. if err != nil {
  255. break
  256. }
  257. }
  258. c.closeWithError(err)
  259. }
  260. func (c *Conn) closeWithError(err error) {
  261. if c.Closed() {
  262. return
  263. }
  264. c.Close()
  265. for id := 0; id < len(c.calls); id++ {
  266. req := &c.calls[id]
  267. // we need to send the error to all waiting queries, put the state
  268. // of this conn into not active so that it can not execute any queries.
  269. select {
  270. case req.resp <- err:
  271. default:
  272. }
  273. close(req.resp)
  274. }
  275. if c.started {
  276. c.errorHandler.HandleError(c, err, true)
  277. }
  278. }
  279. func (c *Conn) recv() error {
  280. // not safe for concurrent reads
  281. // read a full header, ignore timeouts, as this is being ran in a loop
  282. // TODO: TCP level deadlines? or just query level deadlines?
  283. if c.timeout > 0 {
  284. c.conn.SetReadDeadline(time.Time{})
  285. }
  286. // were just reading headers over and over and copy bodies
  287. head, err := readHeader(c.r, c.headerBuf)
  288. if err != nil {
  289. return err
  290. }
  291. call := &c.calls[head.stream]
  292. err = call.framer.readFrame(&head)
  293. if err != nil {
  294. return err
  295. }
  296. // once we get to here we know that the caller must be waiting and that there
  297. // is no error.
  298. select {
  299. case call.resp <- nil:
  300. default:
  301. // in case the caller timedout
  302. }
  303. return nil
  304. }
  305. type callReq struct {
  306. // could use a waitgroup but this allows us to do timeouts on the read/send
  307. resp chan error
  308. framer *framer
  309. }
  310. func (c *Conn) releaseStream(stream int) {
  311. select {
  312. case c.uniq <- stream:
  313. default:
  314. }
  315. }
  316. func (c *Conn) handleTimeout() {
  317. if atomic.AddInt64(&c.timeouts, 1) > TimeoutLimit {
  318. c.closeWithError(ErrTooManyTimeouts)
  319. }
  320. }
  321. func (c *Conn) exec(req frameWriter, tracer Tracer) (frame, error) {
  322. // TODO: move tracer onto conn
  323. stream := <-c.uniq
  324. defer c.releaseStream(stream)
  325. call := &c.calls[stream]
  326. // resp is basically a waiting semaphore protecting the framer
  327. framer := newFramer(c, c, c.compressor, c.version)
  328. call.framer = framer
  329. if tracer != nil {
  330. framer.trace()
  331. }
  332. err := req.writeFrame(framer, stream)
  333. if err != nil {
  334. return nil, err
  335. }
  336. select {
  337. case err = <-call.resp:
  338. case <-time.After(c.timeout):
  339. c.handleTimeout()
  340. return nil, ErrTimeoutNoResponse
  341. }
  342. if err != nil {
  343. return nil, err
  344. }
  345. if v := framer.header.version.version(); v != c.version {
  346. return nil, NewErrProtocol("unexpected protocol version in response: got %d expected %d", v, c.version)
  347. }
  348. frame, err := framer.parseFrame()
  349. if err != nil {
  350. return nil, err
  351. }
  352. if len(framer.traceID) > 0 {
  353. tracer.Trace(framer.traceID)
  354. }
  355. framerPool.Put(framer)
  356. call.framer = nil
  357. return frame, nil
  358. }
  359. func (c *Conn) prepareStatement(stmt string, trace Tracer) (*resultPreparedFrame, error) {
  360. stmtsLRU.Lock()
  361. if stmtsLRU.lru == nil {
  362. initStmtsLRU(defaultMaxPreparedStmts)
  363. }
  364. stmtCacheKey := c.addr + c.currentKeyspace + stmt
  365. if val, ok := stmtsLRU.lru.Get(stmtCacheKey); ok {
  366. stmtsLRU.Unlock()
  367. flight := val.(*inflightPrepare)
  368. flight.wg.Wait()
  369. return flight.info, flight.err
  370. }
  371. flight := new(inflightPrepare)
  372. flight.wg.Add(1)
  373. stmtsLRU.lru.Add(stmtCacheKey, flight)
  374. stmtsLRU.Unlock()
  375. prep := &writePrepareFrame{
  376. statement: stmt,
  377. }
  378. resp, err := c.exec(prep, trace)
  379. if err != nil {
  380. flight.err = err
  381. flight.wg.Done()
  382. return nil, err
  383. }
  384. switch x := resp.(type) {
  385. case *resultPreparedFrame:
  386. flight.info = x
  387. case error:
  388. flight.err = x
  389. default:
  390. flight.err = NewErrProtocol("Unknown type in response to prepare frame: %s", x)
  391. }
  392. flight.wg.Done()
  393. if flight.err != nil {
  394. stmtsLRU.Lock()
  395. stmtsLRU.lru.Remove(stmtCacheKey)
  396. stmtsLRU.Unlock()
  397. }
  398. return flight.info, flight.err
  399. }
  400. func (c *Conn) executeQuery(qry *Query) *Iter {
  401. params := queryParams{
  402. consistency: qry.cons,
  403. }
  404. // frame checks that it is not 0
  405. params.serialConsistency = qry.serialCons
  406. params.defaultTimestamp = qry.defaultTimestamp
  407. if len(qry.pageState) > 0 {
  408. params.pagingState = qry.pageState
  409. }
  410. if qry.pageSize > 0 {
  411. params.pageSize = qry.pageSize
  412. }
  413. var frame frameWriter
  414. if qry.shouldPrepare() {
  415. // Prepare all DML queries. Other queries can not be prepared.
  416. info, err := c.prepareStatement(qry.stmt, qry.trace)
  417. if err != nil {
  418. return &Iter{err: err}
  419. }
  420. var values []interface{}
  421. if qry.binding == nil {
  422. values = qry.values
  423. } else {
  424. binding := &QueryInfo{
  425. Id: info.preparedID,
  426. Args: info.reqMeta.columns,
  427. Rval: info.respMeta.columns,
  428. }
  429. values, err = qry.binding(binding)
  430. if err != nil {
  431. return &Iter{err: err}
  432. }
  433. }
  434. if len(values) != len(info.reqMeta.columns) {
  435. return &Iter{err: ErrQueryArgLength}
  436. }
  437. params.values = make([]queryValues, len(values))
  438. for i := 0; i < len(values); i++ {
  439. val, err := Marshal(info.reqMeta.columns[i].TypeInfo, values[i])
  440. if err != nil {
  441. return &Iter{err: err}
  442. }
  443. v := &params.values[i]
  444. v.value = val
  445. // TODO: handle query binding names
  446. }
  447. frame = &writeExecuteFrame{
  448. preparedID: info.preparedID,
  449. params: params,
  450. }
  451. } else {
  452. frame = &writeQueryFrame{
  453. statement: qry.stmt,
  454. params: params,
  455. }
  456. }
  457. resp, err := c.exec(frame, qry.trace)
  458. if err != nil {
  459. return &Iter{err: err}
  460. }
  461. switch x := resp.(type) {
  462. case *resultVoidFrame:
  463. return &Iter{}
  464. case *resultRowsFrame:
  465. iter := &Iter{
  466. meta: x.meta,
  467. rows: x.rows,
  468. }
  469. if len(x.meta.pagingState) > 0 {
  470. iter.next = &nextIter{
  471. qry: *qry,
  472. pos: int((1 - qry.prefetch) * float64(len(iter.rows))),
  473. }
  474. iter.next.qry.pageState = x.meta.pagingState
  475. if iter.next.pos < 1 {
  476. iter.next.pos = 1
  477. }
  478. }
  479. return iter
  480. case *resultKeyspaceFrame, *resultSchemaChangeFrame:
  481. return &Iter{}
  482. case *RequestErrUnprepared:
  483. stmtsLRU.Lock()
  484. stmtCacheKey := c.addr + c.currentKeyspace + qry.stmt
  485. if _, ok := stmtsLRU.lru.Get(stmtCacheKey); ok {
  486. stmtsLRU.lru.Remove(stmtCacheKey)
  487. stmtsLRU.Unlock()
  488. return c.executeQuery(qry)
  489. }
  490. stmtsLRU.Unlock()
  491. return &Iter{err: x}
  492. case error:
  493. return &Iter{err: x}
  494. default:
  495. return &Iter{err: NewErrProtocol("Unknown type in response to execute query: %s", x)}
  496. }
  497. }
  498. func (c *Conn) Pick(qry *Query) *Conn {
  499. if c.Closed() {
  500. return nil
  501. }
  502. return c
  503. }
  504. func (c *Conn) Closed() bool {
  505. c.closedMu.RLock()
  506. closed := c.isClosed
  507. c.closedMu.RUnlock()
  508. return closed
  509. }
  510. func (c *Conn) Close() {
  511. c.closedMu.Lock()
  512. if c.isClosed {
  513. c.closedMu.Unlock()
  514. return
  515. }
  516. c.isClosed = true
  517. c.closedMu.Unlock()
  518. c.conn.Close()
  519. }
  520. func (c *Conn) Address() string {
  521. return c.addr
  522. }
  523. func (c *Conn) AvailableStreams() int {
  524. return len(c.uniq)
  525. }
  526. func (c *Conn) UseKeyspace(keyspace string) error {
  527. q := &writeQueryFrame{statement: `USE "` + keyspace + `"`}
  528. q.params.consistency = Any
  529. resp, err := c.exec(q, nil)
  530. if err != nil {
  531. return err
  532. }
  533. switch x := resp.(type) {
  534. case *resultKeyspaceFrame:
  535. case error:
  536. return x
  537. default:
  538. return NewErrProtocol("unknown frame in response to USE: %v", x)
  539. }
  540. c.currentKeyspace = keyspace
  541. return nil
  542. }
  543. func (c *Conn) executeBatch(batch *Batch) error {
  544. if c.version == protoVersion1 {
  545. return ErrUnsupported
  546. }
  547. n := len(batch.Entries)
  548. req := &writeBatchFrame{
  549. typ: batch.Type,
  550. statements: make([]batchStatment, n),
  551. consistency: batch.Cons,
  552. serialConsistency: batch.serialCons,
  553. defaultTimestamp: batch.defaultTimestamp,
  554. }
  555. stmts := make(map[string]string)
  556. for i := 0; i < n; i++ {
  557. entry := &batch.Entries[i]
  558. b := &req.statements[i]
  559. if len(entry.Args) > 0 || entry.binding != nil {
  560. info, err := c.prepareStatement(entry.Stmt, nil)
  561. if err != nil {
  562. return err
  563. }
  564. var args []interface{}
  565. if entry.binding == nil {
  566. args = entry.Args
  567. } else {
  568. binding := &QueryInfo{
  569. Id: info.preparedID,
  570. Args: info.reqMeta.columns,
  571. Rval: info.respMeta.columns,
  572. }
  573. args, err = entry.binding(binding)
  574. if err != nil {
  575. return err
  576. }
  577. }
  578. if len(args) != len(info.reqMeta.columns) {
  579. return ErrQueryArgLength
  580. }
  581. b.preparedID = info.preparedID
  582. stmts[string(info.preparedID)] = entry.Stmt
  583. b.values = make([]queryValues, len(info.reqMeta.columns))
  584. for j := 0; j < len(info.reqMeta.columns); j++ {
  585. val, err := Marshal(info.reqMeta.columns[j].TypeInfo, args[j])
  586. if err != nil {
  587. return err
  588. }
  589. b.values[j].value = val
  590. // TODO: add names
  591. }
  592. } else {
  593. b.statement = entry.Stmt
  594. }
  595. }
  596. // TODO: should batch support tracing?
  597. resp, err := c.exec(req, nil)
  598. if err != nil {
  599. return err
  600. }
  601. switch x := resp.(type) {
  602. case *resultVoidFrame:
  603. return nil
  604. case *RequestErrUnprepared:
  605. stmt, found := stmts[string(x.StatementId)]
  606. if found {
  607. stmtsLRU.Lock()
  608. stmtsLRU.lru.Remove(c.addr + c.currentKeyspace + stmt)
  609. stmtsLRU.Unlock()
  610. }
  611. if found {
  612. return c.executeBatch(batch)
  613. } else {
  614. return x
  615. }
  616. case error:
  617. return x
  618. default:
  619. return NewErrProtocol("Unknown type in response to batch statement: %s", x)
  620. }
  621. }
  622. func (c *Conn) setKeepalive(d time.Duration) error {
  623. if tc, ok := c.conn.(*net.TCPConn); ok {
  624. err := tc.SetKeepAlivePeriod(d)
  625. if err != nil {
  626. return err
  627. }
  628. return tc.SetKeepAlive(true)
  629. }
  630. return nil
  631. }
  632. type inflightPrepare struct {
  633. info *resultPreparedFrame
  634. err error
  635. wg sync.WaitGroup
  636. }
  637. var (
  638. ErrQueryArgLength = errors.New("query argument length mismatch")
  639. ErrTimeoutNoResponse = errors.New("gocql: no response recieved from cassandra within timeout period")
  640. ErrTooManyTimeouts = errors.New("gocql: too many query timeouts on the connection")
  641. )