conn.go 22 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998
  1. // Copyright (c) 2012 The gocql Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package gocql
  5. import (
  6. "bufio"
  7. "crypto/tls"
  8. "errors"
  9. "fmt"
  10. "io"
  11. "io/ioutil"
  12. "log"
  13. "net"
  14. "strconv"
  15. "strings"
  16. "sync"
  17. "sync/atomic"
  18. "time"
  19. "github.com/gocql/gocql/internal/streams"
  20. )
  // approvedAuthenticators lists the server authenticator class names that
// PasswordAuthenticator.Challenge accepts before sending credentials.
var (
	approvedAuthenticators = [...]string{
		"org.apache.cassandra.auth.PasswordAuthenticator",
		"com.instaclustr.cassandra.auth.SharedSecretAuthenticator",
	}
)

// approve reports whether authenticator is exactly (case-sensitively, with no
// trimming) one of the entries in approvedAuthenticators.
func approve(authenticator string) bool {
	for i := range approvedAuthenticators {
		if approvedAuthenticators[i] == authenticator {
			return true
		}
	}
	return false
}
  // JoinHostPort returns an address string that gocql.Conn can use to connect
// to a host. Surrounding whitespace is trimmed from addr. If the trimmed value
// already parses as host:port it is returned unchanged; otherwise port is
// appended (IPv6 literals are bracketed by net.JoinHostPort).
func JoinHostPort(addr string, port int) string {
	trimmed := strings.TrimSpace(addr)
	if _, _, err := net.SplitHostPort(trimmed); err == nil {
		return trimmed
	}
	return net.JoinHostPort(trimmed, strconv.Itoa(port))
}
  44. type Authenticator interface {
  45. Challenge(req []byte) (resp []byte, auth Authenticator, err error)
  46. Success(data []byte) error
  47. }
  48. type PasswordAuthenticator struct {
  49. Username string
  50. Password string
  51. }
  52. func (p PasswordAuthenticator) Challenge(req []byte) ([]byte, Authenticator, error) {
  53. if !approve(string(req)) {
  54. return nil, nil, fmt.Errorf("unexpected authenticator %q", req)
  55. }
  56. resp := make([]byte, 2+len(p.Username)+len(p.Password))
  57. resp[0] = 0
  58. copy(resp[1:], p.Username)
  59. resp[len(p.Username)+1] = 0
  60. copy(resp[2+len(p.Username):], p.Password)
  61. return resp, nil, nil
  62. }
  63. func (p PasswordAuthenticator) Success(data []byte) error {
  64. return nil
  65. }
  // SslOptions holds TLS settings for connecting to Cassandra. It embeds a
// tls.Config and adds file-path options; the files are loaded outside this
// file (NOTE(review): presumably PEM-encoded — confirm where they are read).
type SslOptions struct {
	tls.Config

	// CertPath and KeyPath are optional depending on server
	// config, but both fields must be omitted to avoid using a
	// client certificate
	CertPath string
	KeyPath  string
	CaPath   string //optional depending on server config

	// Turn this on to verify the server's hostname and certificate (for
	// example against a wildcard certificate for the Cassandra cluster).
	// This option is basically the inverse of InsecureSkipVerify.
	// See InsecureSkipVerify in https://golang.org/pkg/crypto/tls/ for more info.
	EnableHostVerification bool
}
  // ConnConfig holds the settings Connect uses to open and configure a Conn.
type ConnConfig struct {
	// ProtoVersion is the native protocol version to speak. Connect overwrites
	// values outside protoVersion1..protoVersion4 with 2.
	ProtoVersion int
	// CQLVersion is sent as CQL_VERSION in the STARTUP message.
	CQLVersion string
	// Timeout is the dial timeout. When positive it is also the per-call
	// read/write deadline applied by Conn.Read and Conn.Write, and exec uses
	// it as the time to wait for a response.
	Timeout time.Duration
	// NumStreams is not read by Connect; the stream id allocator is sized from
	// the protocol version. NOTE(review): check whether other files use it.
	NumStreams int
	// Compressor, if non-nil, is advertised as COMPRESSION in STARTUP and is
	// handed to every framer created for the connection.
	Compressor Compressor
	// Authenticator answers the server's authentication request during
	// startup; if nil and the server requires authentication, Connect fails.
	Authenticator Authenticator
	// Keepalive, when positive, is the TCP keepalive period to use.
	Keepalive time.Duration
	// tlsConfig, when non-nil, makes Connect dial with TLS using this config.
	// It may be shared by connections but must not be modified once used.
	tlsConfig *tls.Config
}
  // ConnErrorHandler is notified when a connection fails. Within this file it
// is only invoked by closeWithError, with closed == true, when a connection
// that had completed startup is closed because of a non-nil error.
type ConnErrorHandler interface {
	HandleError(conn *Conn, err error, closed bool)
}
  // TimeoutLimit is how many query timeouts a connection tolerates: once its
// timeout count exceeds this value, the connection is closed with
// ErrTooManyTimeouts and must be re-established. This prevents a single query
// timeout from killing a connection that may still be serving other queries
// fine. The count is cumulative; it is never reset in this file.
// Default is 10. It should not be changed concurrently with queries.
var TimeoutLimit int64 = 10
  // Conn is a single connection to a Cassandra node. It can be used to execute
// queries, but users are usually advised to use a more reliable, higher
// level API.
type Conn struct {
	conn       net.Conn      // underlying TCP or TLS connection
	r          *bufio.Reader // buffered reader over conn, used by Read and recv
	timeout    time.Duration // copied from cfg.Timeout
	cfg        *ConnConfig
	numStreams int    // NOTE(review): not assigned or read in this file
	headerBuf  []byte // scratch buffer for frame headers (8 or 9 bytes, by protocol version)
	streams    *streams.IDGenerator

	mu    sync.RWMutex      // guards calls
	calls map[int]*callReq // in-flight requests keyed by stream id

	errorHandler    ConnErrorHandler
	compressor      Compressor
	auth            Authenticator
	addr            string   // remote address; also part of prepared-statement cache keys
	version         uint8    // protocol version sent in and expected on frames
	currentKeyspace string   // set by UseKeyspace; also part of prepared-statement cache keys
	started         bool     // true once startup succeeded; gates errorHandler notification
	session         *Session
	closed          int32         // 0 while open, 1 once closed; accessed atomically
	quit            chan struct{} // closed when the connection is closed
	timeouts        int64         // number of timed-out calls; accessed atomically
}
  122. // Connect establishes a connection to a Cassandra node.
  123. // You must also call the Serve method before you can execute any queries.
  124. func Connect(addr string, cfg *ConnConfig, errorHandler ConnErrorHandler, session *Session) (*Conn, error) {
  125. var (
  126. err error
  127. conn net.Conn
  128. )
  129. dialer := &net.Dialer{
  130. Timeout: cfg.Timeout,
  131. }
  132. if cfg.tlsConfig != nil {
  133. // the TLS config is safe to be reused by connections but it must not
  134. // be modified after being used.
  135. conn, err = tls.DialWithDialer(dialer, "tcp", addr, cfg.tlsConfig)
  136. } else {
  137. conn, err = dialer.Dial("tcp", addr)
  138. }
  139. if err != nil {
  140. return nil, err
  141. }
  142. // going to default to proto 2
  143. if cfg.ProtoVersion < protoVersion1 || cfg.ProtoVersion > protoVersion4 {
  144. log.Printf("unsupported protocol version: %d using 2\n", cfg.ProtoVersion)
  145. cfg.ProtoVersion = 2
  146. }
  147. headerSize := 8
  148. if cfg.ProtoVersion > protoVersion2 {
  149. headerSize = 9
  150. }
  151. c := &Conn{
  152. conn: conn,
  153. r: bufio.NewReader(conn),
  154. cfg: cfg,
  155. calls: make(map[int]*callReq),
  156. timeout: cfg.Timeout,
  157. version: uint8(cfg.ProtoVersion),
  158. addr: conn.RemoteAddr().String(),
  159. errorHandler: errorHandler,
  160. compressor: cfg.Compressor,
  161. auth: cfg.Authenticator,
  162. headerBuf: make([]byte, headerSize),
  163. quit: make(chan struct{}),
  164. session: session,
  165. streams: streams.New(cfg.ProtoVersion),
  166. }
  167. if cfg.Keepalive > 0 {
  168. c.setKeepalive(cfg.Keepalive)
  169. }
  170. go c.serve()
  171. if err := c.startup(); err != nil {
  172. conn.Close()
  173. return nil, err
  174. }
  175. c.started = true
  176. return c, nil
  177. }
  178. func (c *Conn) Write(p []byte) (int, error) {
  179. if c.timeout > 0 {
  180. c.conn.SetWriteDeadline(time.Now().Add(c.timeout))
  181. }
  182. return c.conn.Write(p)
  183. }
  184. func (c *Conn) Read(p []byte) (n int, err error) {
  185. const maxAttempts = 5
  186. for i := 0; i < maxAttempts; i++ {
  187. var nn int
  188. if c.timeout > 0 {
  189. c.conn.SetReadDeadline(time.Now().Add(c.timeout))
  190. }
  191. nn, err = io.ReadFull(c.r, p[n:])
  192. n += nn
  193. if err == nil {
  194. break
  195. }
  196. if verr, ok := err.(net.Error); !ok || !verr.Temporary() {
  197. break
  198. }
  199. }
  200. return
  201. }
  202. func (c *Conn) startup() error {
  203. m := map[string]string{
  204. "CQL_VERSION": c.cfg.CQLVersion,
  205. }
  206. if c.compressor != nil {
  207. m["COMPRESSION"] = c.compressor.Name()
  208. }
  209. framer, err := c.exec(&writeStartupFrame{opts: m}, nil)
  210. if err != nil {
  211. return err
  212. }
  213. frame, err := framer.parseFrame()
  214. if err != nil {
  215. return err
  216. }
  217. switch v := frame.(type) {
  218. case error:
  219. return v
  220. case *readyFrame:
  221. return nil
  222. case *authenticateFrame:
  223. return c.authenticateHandshake(v)
  224. default:
  225. return NewErrProtocol("Unknown type of response to startup frame: %s", v)
  226. }
  227. }
  228. func (c *Conn) authenticateHandshake(authFrame *authenticateFrame) error {
  229. if c.auth == nil {
  230. return fmt.Errorf("authentication required (using %q)", authFrame.class)
  231. }
  232. resp, challenger, err := c.auth.Challenge([]byte(authFrame.class))
  233. if err != nil {
  234. return err
  235. }
  236. req := &writeAuthResponseFrame{data: resp}
  237. for {
  238. framer, err := c.exec(req, nil)
  239. if err != nil {
  240. return err
  241. }
  242. frame, err := framer.parseFrame()
  243. if err != nil {
  244. return err
  245. }
  246. switch v := frame.(type) {
  247. case error:
  248. return v
  249. case *authSuccessFrame:
  250. if challenger != nil {
  251. return challenger.Success(v.data)
  252. }
  253. return nil
  254. case *authChallengeFrame:
  255. resp, challenger, err = challenger.Challenge(v.data)
  256. if err != nil {
  257. return err
  258. }
  259. req = &writeAuthResponseFrame{
  260. data: resp,
  261. }
  262. default:
  263. return fmt.Errorf("unknown frame response during authentication: %v", v)
  264. }
  265. framerPool.Put(framer)
  266. }
  267. }
  268. func (c *Conn) closeWithError(err error) {
  269. if !atomic.CompareAndSwapInt32(&c.closed, 0, 1) {
  270. return
  271. }
  272. if err != nil {
  273. // we should attempt to deliver the error back to the caller if it
  274. // exists
  275. c.mu.RLock()
  276. for _, req := range c.calls {
  277. // we need to send the error to all waiting queries, put the state
  278. // of this conn into not active so that it can not execute any queries.
  279. if err != nil {
  280. select {
  281. case req.resp <- err:
  282. default:
  283. }
  284. }
  285. }
  286. c.mu.RUnlock()
  287. }
  288. // if error was nil then unblock the quit channel
  289. close(c.quit)
  290. c.conn.Close()
  291. if c.started && err != nil {
  292. c.errorHandler.HandleError(c, err, true)
  293. }
  294. }
  295. func (c *Conn) Close() {
  296. c.closeWithError(nil)
  297. }
  298. // Serve starts the stream multiplexer for this connection, which is required
  299. // to execute any queries. This method runs as long as the connection is
  300. // open and is therefore usually called in a separate goroutine.
  301. func (c *Conn) serve() {
  302. var (
  303. err error
  304. )
  305. for {
  306. err = c.recv()
  307. if err != nil {
  308. break
  309. }
  310. }
  311. c.closeWithError(err)
  312. }
  313. func (c *Conn) discardFrame(head frameHeader) error {
  314. _, err := io.CopyN(ioutil.Discard, c, int64(head.length))
  315. if err != nil {
  316. return err
  317. }
  318. return nil
  319. }
  // recv reads one frame from the connection and routes it to its caller. It
// is called in a loop by serve; a non-nil return makes serve close the
// connection.
//
// Routing by the header's stream id:
//   - greater than c.streams.NumStreams: reported as an error.
//   - -1: a server event frame; currently discarded.
//   - 0 or other negative ids: a reserved stream this client never uses; the
//     frame is parsed only to build a descriptive error.
//   - otherwise: the body is read into the waiting call's framer and the
//     caller is signalled through call.resp.
func (c *Conn) recv() error {
	// not safe for concurrent reads

	// read a full header, ignore timeouts, as this is being ran in a loop
	// TODO: TCP level deadlines? or just query level deadlines?
	// Conn.Read sets a read deadline on every call; clear it so that waiting
	// for the next header on an idle connection does not time out.
	if c.timeout > 0 {
		c.conn.SetReadDeadline(time.Time{})
	}

	// were just reading headers over and over and copy bodies
	head, err := readHeader(c.r, c.headerBuf)
	if err != nil {
		return err
	}

	if head.stream > c.streams.NumStreams {
		return fmt.Errorf("gocql: frame header stream is beyond call exepected bounds: %d", head.stream)
	} else if head.stream == -1 {
		// TODO: handle cassandra event frames, we shouldnt get any currently
		return c.discardFrame(head)
	} else if head.stream <= 0 {
		// reserved stream that we dont use, probably due to a protocol error
		// or a bug in Cassandra, this should be an error, parse it and return.
		framer := newFramer(c, c, c.compressor, c.version)
		if err := framer.readFrame(&head); err != nil {
			return err
		}
		// Recycled only after the error message below has been formatted.
		defer framerPool.Put(framer)

		frame, err := framer.parseFrame()
		if err != nil {
			return err
		}

		switch v := frame.(type) {
		case error:
			return fmt.Errorf("gocql: error on stream %d: %v", head.stream, v)
		default:
			return fmt.Errorf("gocql: received frame on stream %d: %v", head.stream, frame)
		}
	}

	c.mu.RLock()
	call, ok := c.calls[head.stream]
	c.mu.RUnlock()
	if call == nil || call.framer == nil || !ok {
		// Nobody is waiting on this stream: log and skip the body so the next
		// header is read from the right offset.
		log.Printf("gocql: received response for stream which has no handler: header=%v\n", head)
		return c.discardFrame(head)
	}

	err = call.framer.readFrame(&head)
	if err != nil {
		// only net errors should cause the connection to be closed. Though
		// cassandra returning corrupt frames will be returned here as well.
		// Any other error falls through and is delivered to the caller below.
		if _, ok := err.(net.Error); ok {
			return err
		}
	}

	// we either, return a response to the caller, the caller timedout, or the
	// connection has closed. Either way we should never block indefinatly here
	select {
	case call.resp <- err:
	case <-call.timeout:
		// exec gave up on this call and deliberately left the stream
		// allocated; the late response has now been consumed, so free it.
		c.releaseStream(head.stream)
	case <-c.quit:
	}

	return nil
}
  // callReq tracks one in-flight request on a stream. Instances are recycled
// through streamPool.
type callReq struct {
	// resp carries the outcome to the waiting exec call: nil once the response
	// frame has been read into framer, otherwise the error.
	// could use a waitgroup but this allows us to do timeouts on the read/send
	resp chan error
	// framer receives the response frame read by recv.
	framer   *framer
	timeout  chan struct{} // closed by exec on timeout; tells recv to release the stream
	streamID int           // current stream in use
}
  388. func (c *Conn) releaseStream(stream int) {
  389. c.mu.Lock()
  390. call := c.calls[stream]
  391. if call != nil && stream != call.streamID {
  392. panic(fmt.Sprintf("attempt to release streamID with ivalid stream: %d -> %+v\n", stream, call))
  393. } else if call == nil {
  394. panic(fmt.Sprintf("releasing a stream not in use: %d", stream))
  395. }
  396. delete(c.calls, stream)
  397. c.mu.Unlock()
  398. streamPool.Put(call)
  399. c.streams.Clear(stream)
  400. }
  401. func (c *Conn) handleTimeout() {
  402. if atomic.AddInt64(&c.timeouts, 1) > TimeoutLimit {
  403. c.closeWithError(ErrTooManyTimeouts)
  404. }
  405. }
  // streamPool recycles callReq values between requests. New entries get an
// unbuffered resp channel; exec reassigns the remaining fields before each use.
var (
	streamPool = sync.Pool{
		New: func() interface{} {
			return &callReq{
				resp: make(chan error),
			}
		},
	}
)
  415. func (c *Conn) exec(req frameWriter, tracer Tracer) (*framer, error) {
  416. // TODO: move tracer onto conn
  417. stream, ok := c.streams.GetStream()
  418. if !ok {
  419. fmt.Println(c.streams)
  420. return nil, ErrNoStreams
  421. }
  422. // resp is basically a waiting semaphore protecting the framer
  423. framer := newFramer(c, c, c.compressor, c.version)
  424. c.mu.Lock()
  425. call := c.calls[stream]
  426. if call != nil {
  427. c.mu.Unlock()
  428. return nil, fmt.Errorf("attempting to use stream already in use: %d -> %d", stream, call.streamID)
  429. } else {
  430. call = streamPool.Get().(*callReq)
  431. }
  432. c.calls[stream] = call
  433. c.mu.Unlock()
  434. call.framer = framer
  435. call.timeout = make(chan struct{})
  436. call.streamID = stream
  437. if tracer != nil {
  438. framer.trace()
  439. }
  440. err := req.writeFrame(framer, stream)
  441. if err != nil {
  442. // I think this is the correct thing to do, im not entirely sure. It is not
  443. // ideal as readers might still get some data, but they probably wont.
  444. // Here we need to be careful as the stream is not available and if all
  445. // writes just timeout or fail then the pool might use this connection to
  446. // send a frame on, with all the streams used up and not returned.
  447. c.closeWithError(err)
  448. return nil, err
  449. }
  450. select {
  451. case err := <-call.resp:
  452. if err != nil {
  453. if !c.Closed() {
  454. // if the connection is closed then we cant release the stream,
  455. // this is because the request is still outstanding and we have
  456. // been handed another error from another stream which caused the
  457. // connection to close.
  458. c.releaseStream(stream)
  459. }
  460. return nil, err
  461. }
  462. case <-time.After(c.timeout):
  463. close(call.timeout)
  464. c.handleTimeout()
  465. return nil, ErrTimeoutNoResponse
  466. case <-c.quit:
  467. return nil, ErrConnectionClosed
  468. }
  469. // dont release the stream if detect a timeout as another request can reuse
  470. // that stream and get a response for the old request, which we have no
  471. // easy way of detecting.
  472. //
  473. // Ensure that the stream is not released if there are potentially outstanding
  474. // requests on the stream to prevent nil pointer dereferences in recv().
  475. defer c.releaseStream(stream)
  476. if v := framer.header.version.version(); v != c.version {
  477. return nil, NewErrProtocol("unexpected protocol version in response: got %d expected %d", v, c.version)
  478. }
  479. return framer, nil
  480. }
  481. func (c *Conn) prepareStatement(stmt string, tracer Tracer) (*QueryInfo, error) {
  482. stmtsLRU.Lock()
  483. if stmtsLRU.lru == nil {
  484. initStmtsLRU(defaultMaxPreparedStmts)
  485. }
  486. stmtCacheKey := c.addr + c.currentKeyspace + stmt
  487. if val, ok := stmtsLRU.lru.Get(stmtCacheKey); ok {
  488. stmtsLRU.Unlock()
  489. flight := val.(*inflightPrepare)
  490. flight.wg.Wait()
  491. return &flight.info, flight.err
  492. }
  493. flight := new(inflightPrepare)
  494. flight.wg.Add(1)
  495. stmtsLRU.lru.Add(stmtCacheKey, flight)
  496. stmtsLRU.Unlock()
  497. prep := &writePrepareFrame{
  498. statement: stmt,
  499. }
  500. framer, err := c.exec(prep, tracer)
  501. if err != nil {
  502. flight.err = err
  503. flight.wg.Done()
  504. return nil, err
  505. }
  506. frame, err := framer.parseFrame()
  507. if err != nil {
  508. flight.err = err
  509. flight.wg.Done()
  510. return nil, err
  511. }
  512. // TODO(zariel): tidy this up, simplify handling of frame parsing so its not duplicated
  513. // everytime we need to parse a frame.
  514. if len(framer.traceID) > 0 {
  515. tracer.Trace(framer.traceID)
  516. }
  517. switch x := frame.(type) {
  518. case *resultPreparedFrame:
  519. // defensivly copy as we will recycle the underlying buffer after we
  520. // return.
  521. flight.info.Id = copyBytes(x.preparedID)
  522. // the type info's should _not_ have a reference to the framers read buffer,
  523. // therefore we can just copy them directly.
  524. flight.info.Args = x.reqMeta.columns
  525. flight.info.PKeyColumns = x.reqMeta.pkeyColumns
  526. flight.info.Rval = x.respMeta.columns
  527. case error:
  528. flight.err = x
  529. default:
  530. flight.err = NewErrProtocol("Unknown type in response to prepare frame: %s", x)
  531. }
  532. flight.wg.Done()
  533. if flight.err != nil {
  534. stmtsLRU.Lock()
  535. stmtsLRU.lru.Remove(stmtCacheKey)
  536. stmtsLRU.Unlock()
  537. }
  538. framerPool.Put(framer)
  539. return &flight.info, flight.err
  540. }
  // executeQuery runs qry on this connection and wraps the outcome in an Iter.
// Errors are reported through the Iter's err field rather than a separate
// return value.
//
// When qry.shouldPrepare() is true, the statement is prepared (through the
// shared statement cache) and sent as an execute frame, with values
// marshalled against the prepared argument types. Otherwise it is sent as a
// plain query frame.
//
// Response handling:
//   - rows: returned in the Iter. If the server sent a paging state and
//     auto-paging is enabled, iter.next is set up for the following page.
//   - schema change: waits, best effort, for schema agreement before
//     returning.
//   - unprepared: if the statement is still cached, the entry is evicted and
//     executeQuery is called again, which re-prepares it.
//     NOTE(review): this recursion has no bound if the server keeps answering
//     "unprepared" — confirm that cannot happen.
func (c *Conn) executeQuery(qry *Query) *Iter {
	params := queryParams{
		consistency: qry.cons,
	}

	// frame checks that it is not 0
	params.serialConsistency = qry.serialCons
	params.defaultTimestamp = qry.defaultTimestamp

	if len(qry.pageState) > 0 {
		params.pagingState = qry.pageState
	}
	if qry.pageSize > 0 {
		params.pageSize = qry.pageSize
	}

	var frame frameWriter
	if qry.shouldPrepare() {
		// Prepare all DML queries. Other queries can not be prepared.
		info, err := c.prepareStatement(qry.stmt, qry.trace)
		if err != nil {
			return &Iter{err: err}
		}

		// A binding function, when present, supplies the values from the
		// prepared metadata instead of qry.values.
		var values []interface{}

		if qry.binding == nil {
			values = qry.values
		} else {
			values, err = qry.binding(info)
			if err != nil {
				return &Iter{err: err}
			}
		}

		if len(values) != len(info.Args) {
			return &Iter{err: ErrQueryArgLength}
		}
		params.values = make([]queryValues, len(values))
		for i := 0; i < len(values); i++ {
			val, err := Marshal(info.Args[i].TypeInfo, values[i])
			if err != nil {
				return &Iter{err: err}
			}

			v := &params.values[i]
			v.value = val
			// TODO: handle query binding names
		}

		frame = &writeExecuteFrame{
			preparedID: info.Id,
			params:     params,
		}
	} else {
		frame = &writeQueryFrame{
			statement: qry.stmt,
			params:    params,
		}
	}

	framer, err := c.exec(frame, qry.trace)
	if err != nil {
		return &Iter{err: err}
	}

	resp, err := framer.parseFrame()
	if err != nil {
		return &Iter{err: err}
	}

	if len(framer.traceID) > 0 {
		qry.trace.Trace(framer.traceID)
	}

	switch x := resp.(type) {
	case *resultVoidFrame:
		return &Iter{framer: framer}
	case *resultRowsFrame:
		iter := &Iter{
			meta:   x.meta,
			rows:   x.rows,
			framer: framer,
		}

		if len(x.meta.pagingState) > 0 && !qry.disableAutoPage {
			// Copy of the query carrying the new paging state; pos is
			// (1 - prefetch) * rows in this page, clamped to at least 1.
			// NOTE(review): presumably the row index at which the next page is
			// fetched — confirm against nextIter.
			iter.next = &nextIter{
				qry: *qry,
				pos: int((1 - qry.prefetch) * float64(len(iter.rows))),
			}

			iter.next.qry.pageState = x.meta.pagingState
			if iter.next.pos < 1 {
				iter.next.pos = 1
			}
		}

		return iter
	case *resultKeyspaceFrame:
		return &Iter{framer: framer}
	case *resultSchemaChangeFrame, *schemaChangeKeyspace, *schemaChangeTable, *schemaChangeFunction:
		iter := &Iter{framer: framer}
		c.awaitSchemaAgreement()
		// dont return an error from this, might be a good idea to give a warning
		// though. The impact of this returning an error would be that the cluster
		// is not consistent with regards to its schema.
		return iter
	case *RequestErrUnprepared:
		stmtsLRU.Lock()
		stmtCacheKey := c.addr + c.currentKeyspace + qry.stmt
		if _, ok := stmtsLRU.lru.Get(stmtCacheKey); ok {
			stmtsLRU.lru.Remove(stmtCacheKey)
			stmtsLRU.Unlock()
			return c.executeQuery(qry)
		}
		stmtsLRU.Unlock()
		return &Iter{err: x, framer: framer}
	case error:
		return &Iter{err: x, framer: framer}
	default:
		return &Iter{
			err:    NewErrProtocol("Unknown type in response to execute query (%T): %s", x, x),
			framer: framer,
		}
	}
}
  652. func (c *Conn) Pick(qry *Query) *Conn {
  653. if c.Closed() {
  654. return nil
  655. }
  656. return c
  657. }
  658. func (c *Conn) Closed() bool {
  659. return atomic.LoadInt32(&c.closed) == 1
  660. }
  661. func (c *Conn) Address() string {
  662. return c.addr
  663. }
  664. func (c *Conn) AvailableStreams() int {
  665. return c.streams.Available()
  666. }
  667. func (c *Conn) UseKeyspace(keyspace string) error {
  668. q := &writeQueryFrame{statement: `USE "` + keyspace + `"`}
  669. q.params.consistency = Any
  670. framer, err := c.exec(q, nil)
  671. if err != nil {
  672. return err
  673. }
  674. resp, err := framer.parseFrame()
  675. if err != nil {
  676. return err
  677. }
  678. switch x := resp.(type) {
  679. case *resultKeyspaceFrame:
  680. case error:
  681. return x
  682. default:
  683. return NewErrProtocol("unknown frame in response to USE: %v", x)
  684. }
  685. c.currentKeyspace = keyspace
  686. return nil
  687. }
  // executeBatch sends batch as a single batch frame; this is unsupported on
// protocol v1. Entries that have arguments or a binding function are prepared
// first and sent by prepared id with marshalled values. All other entries are
// sent as raw statement text.
//
// It returns (nil, nil) for a void result, an Iter for a rows result, or an
// error. If the server reports one of this batch's prepared ids as
// unprepared, that statement's cache entry is evicted and the whole batch is
// retried. The framer is returned to framerPool on every path except the rows
// path, where the Iter keeps it.
func (c *Conn) executeBatch(batch *Batch) (*Iter, error) {
	if c.version == protoVersion1 {
		return nil, ErrUnsupported
	}

	n := len(batch.Entries)
	req := &writeBatchFrame{
		typ:               batch.Type,
		statements:        make([]batchStatment, n),
		consistency:       batch.Cons,
		serialConsistency: batch.serialCons,
		defaultTimestamp:  batch.defaultTimestamp,
	}

	// Maps prepared id -> statement text, so an unprepared error can be traced
	// back to the cache entry to evict.
	stmts := make(map[string]string)

	for i := 0; i < n; i++ {
		entry := &batch.Entries[i]
		b := &req.statements[i]
		if len(entry.Args) > 0 || entry.binding != nil {
			info, err := c.prepareStatement(entry.Stmt, nil)
			if err != nil {
				return nil, err
			}

			var args []interface{}
			if entry.binding == nil {
				args = entry.Args
			} else {
				args, err = entry.binding(info)
				if err != nil {
					return nil, err
				}
			}

			if len(args) != len(info.Args) {
				return nil, ErrQueryArgLength
			}

			b.preparedID = info.Id
			stmts[string(info.Id)] = entry.Stmt

			b.values = make([]queryValues, len(info.Args))

			for j := 0; j < len(info.Args); j++ {
				val, err := Marshal(info.Args[j].TypeInfo, args[j])
				if err != nil {
					return nil, err
				}

				b.values[j].value = val
				// TODO: add names
			}
		} else {
			b.statement = entry.Stmt
		}
	}

	// TODO: should batch support tracing?
	framer, err := c.exec(req, nil)
	if err != nil {
		return nil, err
	}

	resp, err := framer.parseFrame()
	if err != nil {
		return nil, err
	}

	switch x := resp.(type) {
	case *resultVoidFrame:
		framerPool.Put(framer)
		return nil, nil
	case *RequestErrUnprepared:
		// The lookup happens before the framer is recycled.
		// NOTE(review): presumably x.StatementId may point into the framer's
		// buffer — confirm before reordering.
		stmt, found := stmts[string(x.StatementId)]
		if found {
			stmtsLRU.Lock()
			stmtsLRU.lru.Remove(c.addr + c.currentKeyspace + stmt)
			stmtsLRU.Unlock()
		}

		framerPool.Put(framer)

		if found {
			return c.executeBatch(batch)
		} else {
			return nil, x
		}
	case *resultRowsFrame:
		iter := &Iter{
			meta:   x.meta,
			rows:   x.rows,
			framer: framer,
		}

		return iter, nil
	case error:
		framerPool.Put(framer)
		return nil, x
	default:
		framerPool.Put(framer)
		return nil, NewErrProtocol("Unknown type in response to batch statement: %s", x)
	}
}
  777. func (c *Conn) setKeepalive(d time.Duration) error {
  778. if tc, ok := c.conn.(*net.TCPConn); ok {
  779. err := tc.SetKeepAlivePeriod(d)
  780. if err != nil {
  781. return err
  782. }
  783. return tc.SetKeepAlive(true)
  784. }
  785. return nil
  786. }
  787. func (c *Conn) query(statement string, values ...interface{}) (iter *Iter) {
  788. q := c.session.Query(statement, values...).Consistency(One)
  789. return c.executeQuery(q)
  790. }
  791. func (c *Conn) awaitSchemaAgreement() (err error) {
  792. const (
  793. peerSchemas = "SELECT schema_version FROM system.peers"
  794. localSchemas = "SELECT schema_version FROM system.local WHERE key='local'"
  795. )
  796. endDeadline := time.Now().Add(c.session.cfg.MaxWaitSchemaAgreement)
  797. for time.Now().Before(endDeadline) {
  798. iter := c.query(peerSchemas)
  799. versions := make(map[string]struct{})
  800. var schemaVersion string
  801. for iter.Scan(&schemaVersion) {
  802. versions[schemaVersion] = struct{}{}
  803. schemaVersion = ""
  804. }
  805. if err = iter.Close(); err != nil {
  806. goto cont
  807. }
  808. iter = c.query(localSchemas)
  809. for iter.Scan(&schemaVersion) {
  810. versions[schemaVersion] = struct{}{}
  811. schemaVersion = ""
  812. }
  813. if err = iter.Close(); err != nil {
  814. goto cont
  815. }
  816. if len(versions) <= 1 {
  817. return nil
  818. }
  819. cont:
  820. time.Sleep(200 * time.Millisecond)
  821. }
  822. if err != nil {
  823. return
  824. }
  825. // not exported
  826. return errors.New("gocql: cluster schema versions not consistent")
  827. }
  // inflightPrepare is the value stored in the prepared-statement cache. The
// goroutine that sends the PREPARE fills in info/err and then calls
// wg.Done(); other goroutines that find the same cache key wait on wg and
// then read info and err.
type inflightPrepare struct {
	info QueryInfo
	err  error
	wg   sync.WaitGroup
}
  // Errors returned by Conn operations.
var (
	// ErrQueryArgLength is returned when the number of bound values differs
	// from the prepared statement's argument count.
	ErrQueryArgLength = errors.New("gocql: query argument length mismatch")
	// ErrTimeoutNoResponse is returned by exec when no response arrived
	// within the connection timeout.
	ErrTimeoutNoResponse = errors.New("gocql: no response received from cassandra within timeout period")
	// ErrTooManyTimeouts is the error a connection is closed with once its
	// timeout count exceeds TimeoutLimit.
	ErrTooManyTimeouts = errors.New("gocql: too many query timeouts on the connection")
	// ErrConnectionClosed is returned by exec when the connection closes while
	// a call is waiting for its response.
	ErrConnectionClosed = errors.New("gocql: connection closed waiting for response")
	// ErrNoStreams is returned by exec when no stream id is free.
	ErrNoStreams = errors.New("gocql: no streams available on connection")
)