conn.go 17 KB


  1. // Copyright (c) 2012 The gocql Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package gocql
  5. import (
  6. "bufio"
  7. "crypto/tls"
  8. "errors"
  9. "fmt"
  10. "io"
  11. "log"
  12. "net"
  13. "strconv"
  14. "strings"
  15. "sync"
  16. "sync/atomic"
  17. "time"
  18. )
  // defaultFrameSize is a default buffer size in bytes.
  // NOTE(review): not referenced in this file; presumably used by the frame
  // code elsewhere in the package - confirm.
  const defaultFrameSize = 4096

  // flagResponse is the bit OR-ed into the protocol version to form the first
  // header byte that recv and decodeFrame expect on every response.
  const flagResponse = 0x80

  // maskVersion is the bitwise complement of flagResponse within one byte.
  // NOTE(review): presumably masks the version bits out of the first header
  // byte; it is not referenced in this file - confirm.
  const maskVersion = 0x7F
  // JoinHostPort is a utility that returns an address string of the form
  // "host:port" which can be used to form a connection with a host.
  //
  // Surrounding whitespace is trimmed from addr. If addr already parses as a
  // host:port pair it is returned as is; otherwise port is appended. IPv6
  // literals are accepted with or without brackets, so both "::1" and "[::1]"
  // produce "[::1]:<port>".
  func JoinHostPort(addr string, port int) string {
  	addr = strings.TrimSpace(addr)
  	if _, _, err := net.SplitHostPort(addr); err == nil {
  		// addr already carries a port.
  		return addr
  	}
  	// net.JoinHostPort wraps any host containing a colon in brackets, so an
  	// already bracketed IPv6 literal must be unwrapped first; otherwise the
  	// result would be the invalid "[[::1]]:port".
  	if n := len(addr); n >= 2 && addr[0] == '[' && addr[n-1] == ']' && strings.Contains(addr, ":") {
  		addr = addr[1 : n-1]
  	}
  	return net.JoinHostPort(addr, strconv.Itoa(port))
  }
  // Authenticator drives the authentication exchange performed while a
  // connection starts up.
  //
  // Challenge receives the data sent by the server (for the first call this is
  // the name of the server-side authenticator) and returns the response to send
  // back, together with the Authenticator that should handle any follow-up
  // challenge. Success is called with the final data from the server once
  // authentication has succeeded.
  type Authenticator interface {
  	Challenge(req []byte) (resp []byte, auth Authenticator, err error)
  	Success(data []byte) error
  }

  // PasswordAuthenticator is an Authenticator that answers the server's
  // org.apache.cassandra.auth.PasswordAuthenticator with a username and
  // password.
  type PasswordAuthenticator struct {
  	Username, Password string
  }

  // Challenge checks that req names the password authenticator and returns the
  // credentials encoded as a zero byte, the username, a zero byte and the
  // password. The returned Authenticator is always nil, so any further
  // challenge from the server cannot be answered.
  func (p PasswordAuthenticator) Challenge(req []byte) ([]byte, Authenticator, error) {
  	const expected = "org.apache.cassandra.auth.PasswordAuthenticator"
  	if string(req) != expected {
  		return nil, nil, fmt.Errorf("unexpected authenticator %q", req)
  	}

  	token := make([]byte, 0, 2+len(p.Username)+len(p.Password))
  	token = append(token, 0)
  	token = append(token, p.Username...)
  	token = append(token, 0)
  	token = append(token, p.Password...)
  	return token, nil, nil
  }

  // Success accepts any final server data; there is nothing to verify for
  // password authentication.
  func (p PasswordAuthenticator) Success(data []byte) error {
  	return nil
  }
  // SslOptions holds file locations and verification settings for TLS.
  // NOTE(review): nothing in this file reads these fields; presumably they are
  // turned into a tls.Config elsewhere in the package - confirm.
  type SslOptions struct {
  	CertPath string
  	KeyPath  string

  	// CaPath is optional, depending on the server configuration.
  	CaPath string

  	// EnableHostVerification should be turned on if you want to verify the
  	// hostname and the server certificate (for example a wildcard
  	// certificate for a Cassandra cluster). It is basically the inverse of
  	// InsecureSkipVerify; see http://golang.org/pkg/crypto/tls/ for more
  	// information.
  	EnableHostVerification bool
  }
  64. type ConnConfig struct {
  65. ProtoVersion int
  66. CQLVersion string
  67. Timeout time.Duration
  68. NumStreams int
  69. Compressor Compressor
  70. Authenticator Authenticator
  71. Keepalive time.Duration
  72. TLSConfig *tls.Config
  73. }
  74. // Conn is a single connection to a Cassandra node. It can be used to execute
  75. // queries, but users are usually advised to use a more reliable, higher
  76. // level API.
  77. type Conn struct {
  78. conn net.Conn
  79. r *bufio.Reader
  80. timeout time.Duration
  81. uniq chan int
  82. calls []callReq
  83. nwait int32
  84. pool ConnectionPool
  85. compressor Compressor
  86. auth Authenticator
  87. addr string
  88. version uint8
  89. currentKeyspace string
  90. closedMu sync.RWMutex
  91. isClosed bool
  92. }
  93. // Connect establishes a connection to a Cassandra node.
  94. // You must also call the Serve method before you can execute any queries.
  95. func Connect(addr string, cfg ConnConfig, pool ConnectionPool) (*Conn, error) {
  96. var (
  97. err error
  98. conn net.Conn
  99. )
  100. if cfg.TLSConfig != nil {
  101. // the TLS config is safe to be reused by connections but it must not
  102. // be modified after being used.
  103. if conn, err = tls.Dial("tcp", addr, cfg.TLSConfig); err != nil {
  104. return nil, err
  105. }
  106. } else if conn, err = net.DialTimeout("tcp", addr, cfg.Timeout); err != nil {
  107. return nil, err
  108. }
  109. // going to default to proto 2
  110. if cfg.ProtoVersion < protoVersion1 || cfg.ProtoVersion > protoVersion3 {
  111. log.Printf("unsupported protocol version: %d using 2\n", cfg.ProtoVersion)
  112. cfg.ProtoVersion = 2
  113. }
  114. maxStreams := 128
  115. if cfg.ProtoVersion > protoVersion2 {
  116. maxStreams = 32768
  117. }
  118. if cfg.NumStreams <= 0 || cfg.NumStreams > maxStreams {
  119. cfg.NumStreams = maxStreams
  120. }
  121. c := &Conn{
  122. conn: conn,
  123. r: bufio.NewReader(conn),
  124. uniq: make(chan int, cfg.NumStreams),
  125. calls: make([]callReq, cfg.NumStreams),
  126. timeout: cfg.Timeout,
  127. version: uint8(cfg.ProtoVersion),
  128. addr: conn.RemoteAddr().String(),
  129. pool: pool,
  130. compressor: cfg.Compressor,
  131. auth: cfg.Authenticator,
  132. }
  133. if cfg.Keepalive > 0 {
  134. c.setKeepalive(cfg.Keepalive)
  135. }
  136. for i := 0; i < cfg.NumStreams; i++ {
  137. c.uniq <- i
  138. }
  139. if err := c.startup(&cfg); err != nil {
  140. conn.Close()
  141. return nil, err
  142. }
  143. go c.serve()
  144. return c, nil
  145. }
  146. func (c *Conn) startup(cfg *ConnConfig) error {
  147. compression := ""
  148. if c.compressor != nil {
  149. compression = c.compressor.Name()
  150. }
  151. var req operation = &startupFrame{
  152. CQLVersion: cfg.CQLVersion,
  153. Compression: compression,
  154. }
  155. var challenger Authenticator
  156. for {
  157. resp, err := c.execSimple(req)
  158. if err != nil {
  159. return err
  160. }
  161. switch x := resp.(type) {
  162. case readyFrame:
  163. return nil
  164. case error:
  165. return x
  166. case authenticateFrame:
  167. if c.auth == nil {
  168. return fmt.Errorf("authentication required (using %q)", x.Authenticator)
  169. }
  170. var resp []byte
  171. resp, challenger, err = c.auth.Challenge([]byte(x.Authenticator))
  172. if err != nil {
  173. return err
  174. }
  175. req = &authResponseFrame{resp}
  176. case authChallengeFrame:
  177. if challenger == nil {
  178. return fmt.Errorf("authentication error (invalid challenge)")
  179. }
  180. var resp []byte
  181. resp, challenger, err = challenger.Challenge(x.Data)
  182. if err != nil {
  183. return err
  184. }
  185. req = &authResponseFrame{resp}
  186. case authSuccessFrame:
  187. if challenger != nil {
  188. return challenger.Success(x.Data)
  189. }
  190. return nil
  191. default:
  192. return NewErrProtocol("Unknown type of response to startup frame: %s", x)
  193. }
  194. }
  195. }
  196. // Serve starts the stream multiplexer for this connection, which is required
  197. // to execute any queries. This method runs as long as the connection is
  198. // open and is therefore usually called in a separate goroutine.
  199. func (c *Conn) serve() {
  200. var (
  201. err error
  202. resp frame
  203. )
  204. for {
  205. resp, err = c.recv()
  206. if err != nil {
  207. break
  208. }
  209. c.dispatch(resp)
  210. }
  211. c.Close()
  212. for id := 0; id < len(c.calls); id++ {
  213. req := &c.calls[id]
  214. if atomic.LoadInt32(&req.active) == 1 {
  215. req.resp <- callResp{nil, err}
  216. }
  217. }
  218. c.pool.HandleError(c, err, true)
  219. }
  220. func (c *Conn) Write(p []byte) (int, error) {
  221. c.conn.SetWriteDeadline(time.Now().Add(c.timeout))
  222. return c.conn.Write(p)
  223. }
  224. func (c *Conn) Read(p []byte) (int, error) {
  225. return c.r.Read(p)
  226. }
  227. func (c *Conn) recv() (frame, error) {
  228. size := headerProtoSize[c.version]
  229. resp := make(frame, size, size+512)
  230. // read a full header, ignore timeouts, as this is being ran in a loop
  231. c.conn.SetReadDeadline(time.Time{})
  232. _, err := io.ReadFull(c.r, resp[:size])
  233. if err != nil {
  234. return nil, err
  235. }
  236. if v := c.version | flagResponse; resp[0] != v {
  237. return nil, NewErrProtocol("recv: response protocol version does not match connection protocol version (%d != %d)", resp[0], v)
  238. }
  239. bodySize := resp.Length(c.version)
  240. if bodySize == 0 {
  241. return resp, nil
  242. }
  243. resp.grow(bodySize)
  244. const maxAttempts = 5
  245. n := size
  246. for i := 0; i < maxAttempts; i++ {
  247. var nn int
  248. c.conn.SetReadDeadline(time.Now().Add(c.timeout))
  249. nn, err = io.ReadFull(c.r, resp[n:size+bodySize])
  250. if err == nil {
  251. break
  252. }
  253. n += nn
  254. if verr, ok := err.(net.Error); !ok || !verr.Temporary() {
  255. break
  256. }
  257. }
  258. if err != nil {
  259. return nil, err
  260. }
  261. return resp, nil
  262. }
  263. func (c *Conn) execSimple(op operation) (interface{}, error) {
  264. f, err := op.encodeFrame(c.version, nil)
  265. if err != nil {
  266. // this should be a noop err
  267. return nil, err
  268. }
  269. bodyLen := len(f) - headerProtoSize[c.version]
  270. f.setLength(bodyLen, c.version)
  271. if _, err := c.Write([]byte(f)); err != nil {
  272. c.Close()
  273. return nil, err
  274. }
  275. // here recv wont timeout waiting for a header, should it?
  276. if f, err = c.recv(); err != nil {
  277. return nil, err
  278. }
  279. return c.decodeFrame(f, nil)
  280. }
  281. func (c *Conn) exec(op operation, trace Tracer) (interface{}, error) {
  282. req, err := op.encodeFrame(c.version, nil)
  283. if err != nil {
  284. return nil, err
  285. }
  286. if trace != nil {
  287. req[1] |= flagTrace
  288. }
  289. headerSize := headerProtoSize[c.version]
  290. if len(req) > headerSize && c.compressor != nil {
  291. body, err := c.compressor.Encode([]byte(req[headerSize:]))
  292. if err != nil {
  293. return nil, err
  294. }
  295. req = append(req[:headerSize], frame(body)...)
  296. req[1] |= flagCompress
  297. }
  298. bodyLen := len(req) - headerSize
  299. req.setLength(bodyLen, c.version)
  300. id := <-c.uniq
  301. req.setStream(id, c.version)
  302. call := &c.calls[id]
  303. call.resp = make(chan callResp, 1)
  304. atomic.AddInt32(&c.nwait, 1)
  305. atomic.StoreInt32(&call.active, 1)
  306. if _, err := c.Write(req); err != nil {
  307. c.uniq <- id
  308. c.Close()
  309. return nil, err
  310. }
  311. reply := <-call.resp
  312. call.resp = nil
  313. c.uniq <- id
  314. if reply.err != nil {
  315. return nil, reply.err
  316. }
  317. return c.decodeFrame(reply.buf, trace)
  318. }
  319. func (c *Conn) dispatch(resp frame) {
  320. id := resp.Stream(c.version)
  321. if id >= len(c.calls) {
  322. return
  323. }
  324. call := &c.calls[id]
  325. if !atomic.CompareAndSwapInt32(&call.active, 1, 0) {
  326. return
  327. }
  328. atomic.AddInt32(&c.nwait, -1)
  329. call.resp <- callResp{resp, nil}
  330. }
  331. func (c *Conn) ping() error {
  332. _, err := c.exec(&optionsFrame{}, nil)
  333. return err
  334. }
  335. func (c *Conn) prepareStatement(stmt string, trace Tracer) (*QueryInfo, error) {
  336. stmtsLRU.Lock()
  337. if stmtsLRU.lru == nil {
  338. initStmtsLRU(defaultMaxPreparedStmts)
  339. }
  340. stmtCacheKey := c.addr + c.currentKeyspace + stmt
  341. if val, ok := stmtsLRU.lru.Get(stmtCacheKey); ok {
  342. flight := val.(*inflightPrepare)
  343. stmtsLRU.Unlock()
  344. flight.wg.Wait()
  345. return flight.info, flight.err
  346. }
  347. flight := new(inflightPrepare)
  348. flight.wg.Add(1)
  349. stmtsLRU.lru.Add(stmtCacheKey, flight)
  350. stmtsLRU.Unlock()
  351. resp, err := c.exec(&prepareFrame{Stmt: stmt}, trace)
  352. if err != nil {
  353. flight.err = err
  354. } else {
  355. switch x := resp.(type) {
  356. case resultPreparedFrame:
  357. flight.info = &QueryInfo{
  358. Id: x.PreparedId,
  359. Args: x.Arguments,
  360. Rval: x.ReturnValues,
  361. }
  362. case error:
  363. flight.err = x
  364. default:
  365. flight.err = NewErrProtocol("Unknown type in response to prepare frame: %s", x)
  366. }
  367. err = flight.err
  368. }
  369. flight.wg.Done()
  370. if err != nil {
  371. stmtsLRU.Lock()
  372. stmtsLRU.lru.Remove(stmtCacheKey)
  373. stmtsLRU.Unlock()
  374. }
  375. return flight.info, flight.err
  376. }
  377. func (c *Conn) executeQuery(qry *Query) *Iter {
  378. op := &queryFrame{
  379. Stmt: qry.stmt,
  380. Cons: qry.cons,
  381. PageSize: qry.pageSize,
  382. PageState: qry.pageState,
  383. }
  384. if qry.shouldPrepare() {
  385. // Prepare all DML queries. Other queries can not be prepared.
  386. info, err := c.prepareStatement(qry.stmt, qry.trace)
  387. if err != nil {
  388. return &Iter{err: err}
  389. }
  390. var values []interface{}
  391. if qry.binding == nil {
  392. values = qry.values
  393. } else {
  394. values, err = qry.binding(info)
  395. if err != nil {
  396. return &Iter{err: err}
  397. }
  398. }
  399. if len(values) != len(info.Args) {
  400. return &Iter{err: ErrQueryArgLength}
  401. }
  402. op.Prepared = info.Id
  403. op.Values = make([][]byte, len(values))
  404. for i := 0; i < len(values); i++ {
  405. val, err := Marshal(info.Args[i].TypeInfo, values[i])
  406. if err != nil {
  407. return &Iter{err: err}
  408. }
  409. op.Values[i] = val
  410. }
  411. }
  412. resp, err := c.exec(op, qry.trace)
  413. if err != nil {
  414. return &Iter{err: err}
  415. }
  416. switch x := resp.(type) {
  417. case resultVoidFrame:
  418. return &Iter{}
  419. case resultRowsFrame:
  420. iter := &Iter{columns: x.Columns, rows: x.Rows}
  421. if len(x.PagingState) > 0 {
  422. iter.next = &nextIter{
  423. qry: *qry,
  424. pos: int((1 - qry.prefetch) * float64(len(iter.rows))),
  425. }
  426. iter.next.qry.pageState = x.PagingState
  427. if iter.next.pos < 1 {
  428. iter.next.pos = 1
  429. }
  430. }
  431. return iter
  432. case resultKeyspaceFrame:
  433. return &Iter{}
  434. case RequestErrUnprepared:
  435. stmtsLRU.Lock()
  436. stmtCacheKey := c.addr + c.currentKeyspace + qry.stmt
  437. if _, ok := stmtsLRU.lru.Get(stmtCacheKey); ok {
  438. stmtsLRU.lru.Remove(stmtCacheKey)
  439. stmtsLRU.Unlock()
  440. return c.executeQuery(qry)
  441. }
  442. stmtsLRU.Unlock()
  443. return &Iter{err: x}
  444. case error:
  445. return &Iter{err: x}
  446. default:
  447. return &Iter{err: NewErrProtocol("Unknown type in response to execute query: %s", x)}
  448. }
  449. }
  450. func (c *Conn) Pick(qry *Query) *Conn {
  451. if c.Closed() {
  452. return nil
  453. }
  454. return c
  455. }
  456. func (c *Conn) Closed() bool {
  457. c.closedMu.RLock()
  458. closed := c.isClosed
  459. c.closedMu.RUnlock()
  460. return closed
  461. }
  462. func (c *Conn) Close() {
  463. c.closedMu.Lock()
  464. if c.isClosed {
  465. c.closedMu.Unlock()
  466. return
  467. }
  468. c.isClosed = true
  469. c.closedMu.Unlock()
  470. c.conn.Close()
  471. }
  472. func (c *Conn) Address() string {
  473. return c.addr
  474. }
  475. func (c *Conn) AvailableStreams() int {
  476. return len(c.uniq)
  477. }
  478. func (c *Conn) UseKeyspace(keyspace string) error {
  479. resp, err := c.exec(&queryFrame{Stmt: `USE "` + keyspace + `"`, Cons: Any}, nil)
  480. if err != nil {
  481. return err
  482. }
  483. switch x := resp.(type) {
  484. case resultKeyspaceFrame:
  485. case error:
  486. return x
  487. default:
  488. return NewErrProtocol("Unknown type in response to USE: %s", x)
  489. }
  490. c.currentKeyspace = keyspace
  491. return nil
  492. }
  493. func (c *Conn) executeBatch(batch *Batch) error {
  494. if c.version == protoVersion1 {
  495. return ErrUnsupported
  496. }
  497. f := newFrame(c.version)
  498. f.setHeader(c.version, 0, 0, opBatch)
  499. f.writeByte(byte(batch.Type))
  500. f.writeShort(uint16(len(batch.Entries)))
  501. stmts := make(map[string]string)
  502. for i := 0; i < len(batch.Entries); i++ {
  503. entry := &batch.Entries[i]
  504. var info *QueryInfo
  505. var args []interface{}
  506. if len(entry.Args) > 0 || entry.binding != nil {
  507. var err error
  508. info, err = c.prepareStatement(entry.Stmt, nil)
  509. if err != nil {
  510. return err
  511. }
  512. if entry.binding == nil {
  513. args = entry.Args
  514. } else {
  515. args, err = entry.binding(info)
  516. if err != nil {
  517. return err
  518. }
  519. }
  520. if len(args) != len(info.Args) {
  521. return ErrQueryArgLength
  522. }
  523. stmts[string(info.Id)] = entry.Stmt
  524. f.writeByte(1)
  525. f.writeShortBytes(info.Id)
  526. } else {
  527. f.writeByte(0)
  528. f.writeLongString(entry.Stmt)
  529. }
  530. f.writeShort(uint16(len(args)))
  531. for j := 0; j < len(args); j++ {
  532. val, err := Marshal(info.Args[j].TypeInfo, args[j])
  533. if err != nil {
  534. return err
  535. }
  536. f.writeBytes(val)
  537. }
  538. }
  539. f.writeConsistency(batch.Cons)
  540. if c.version >= protoVersion3 {
  541. // TODO: add support for flags here
  542. f.writeByte(0)
  543. }
  544. resp, err := c.exec(f, nil)
  545. if err != nil {
  546. return err
  547. }
  548. switch x := resp.(type) {
  549. case resultVoidFrame:
  550. return nil
  551. case RequestErrUnprepared:
  552. stmt, found := stmts[string(x.StatementId)]
  553. if found {
  554. stmtsLRU.Lock()
  555. stmtsLRU.lru.Remove(c.addr + c.currentKeyspace + stmt)
  556. stmtsLRU.Unlock()
  557. }
  558. if found {
  559. return c.executeBatch(batch)
  560. } else {
  561. return x
  562. }
  563. case error:
  564. return x
  565. default:
  566. return NewErrProtocol("Unknown type in response to batch statement: %s", x)
  567. }
  568. }
  569. func (c *Conn) decodeFrame(f frame, trace Tracer) (rval interface{}, err error) {
  570. defer func() {
  571. if r := recover(); r != nil {
  572. if e, ok := r.(ErrProtocol); ok {
  573. err = e
  574. return
  575. }
  576. panic(r)
  577. }
  578. }()
  579. headerSize := headerProtoSize[c.version]
  580. if len(f) < headerSize {
  581. return nil, NewErrProtocol("Decoding frame: less data received than required for header: %d < %d", len(f), headerSize)
  582. } else if f[0] != c.version|flagResponse {
  583. return nil, NewErrProtocol("Decoding frame: response protocol version does not match connection protocol version (%d != %d)", f[0], c.version|flagResponse)
  584. }
  585. flags, op, f := f[1], f.Op(c.version), f[headerSize:]
  586. if flags&flagCompress != 0 && len(f) > 0 && c.compressor != nil {
  587. if buf, err := c.compressor.Decode([]byte(f)); err != nil {
  588. return nil, err
  589. } else {
  590. f = frame(buf)
  591. }
  592. }
  593. if flags&flagTrace != 0 {
  594. if len(f) < 16 {
  595. return nil, NewErrProtocol("Decoding frame: length of frame less than 16 while tracing is enabled")
  596. }
  597. traceId := []byte(f[:16])
  598. f = f[16:]
  599. trace.Trace(traceId)
  600. }
  601. switch op {
  602. case opReady:
  603. return readyFrame{}, nil
  604. case opResult:
  605. switch kind := f.readInt(); kind {
  606. case resultKindVoid:
  607. return resultVoidFrame{}, nil
  608. case resultKindRows:
  609. columns, pageState := f.readMetaData(c.version)
  610. numRows := f.readInt()
  611. values := make([][]byte, numRows*len(columns))
  612. for i := 0; i < len(values); i++ {
  613. values[i] = f.readBytes()
  614. }
  615. rows := make([][][]byte, numRows)
  616. for i := 0; i < numRows; i++ {
  617. rows[i], values = values[:len(columns)], values[len(columns):]
  618. }
  619. return resultRowsFrame{columns, rows, pageState}, nil
  620. case resultKindKeyspace:
  621. keyspace := f.readString()
  622. return resultKeyspaceFrame{keyspace}, nil
  623. case resultKindPrepared:
  624. id := f.readShortBytes()
  625. args, _ := f.readMetaData(c.version)
  626. if c.version < 2 {
  627. return resultPreparedFrame{PreparedId: id, Arguments: args}, nil
  628. }
  629. rvals, _ := f.readMetaData(c.version)
  630. return resultPreparedFrame{PreparedId: id, Arguments: args, ReturnValues: rvals}, nil
  631. case resultKindSchemaChanged:
  632. return resultVoidFrame{}, nil
  633. default:
  634. return nil, NewErrProtocol("Decoding frame: unknown result kind %s", kind)
  635. }
  636. case opAuthenticate:
  637. return authenticateFrame{f.readString()}, nil
  638. case opAuthChallenge:
  639. return authChallengeFrame{f.readBytes()}, nil
  640. case opAuthSuccess:
  641. return authSuccessFrame{f.readBytes()}, nil
  642. case opSupported:
  643. return supportedFrame{}, nil
  644. case opError:
  645. return f.readError(), nil
  646. default:
  647. return nil, NewErrProtocol("Decoding frame: unknown op", op)
  648. }
  649. }
  650. func (c *Conn) setKeepalive(d time.Duration) error {
  651. if tc, ok := c.conn.(*net.TCPConn); ok {
  652. err := tc.SetKeepAlivePeriod(d)
  653. if err != nil {
  654. return err
  655. }
  656. return tc.SetKeepAlive(true)
  657. }
  658. return nil
  659. }
  660. // QueryInfo represents the meta data associated with a prepared CQL statement.
  661. type QueryInfo struct {
  662. Id []byte
  663. Args []ColumnInfo
  664. Rval []ColumnInfo
  665. }
  666. type callReq struct {
  667. active int32
  668. resp chan callResp
  669. }
  670. type callResp struct {
  671. buf frame
  672. err error
  673. }
  674. type inflightPrepare struct {
  675. info *QueryInfo
  676. err error
  677. wg sync.WaitGroup
  678. }
  // ErrQueryArgLength is returned when the number of values supplied for a
  // query or batch entry differs from the number of arguments the prepared
  // statement expects.
  var ErrQueryArgLength = errors.New("query argument length mismatch")