瀏覽代碼

added snappy support

Christoph Hack 12 年之前
父節點
當前提交
d7b351c455
共有 4 個文件被更改,包括 121 次插入70 次删除
  1. 4 0
      cluster.go
  2. 115 15
      conn.go
  3. 1 55
      frame.go
  4. 1 0
      gocql_test/main.go

+ 4 - 0
cluster.go

@@ -6,6 +6,7 @@ package gocql
 
 import (
 	"fmt"
+	"log"
 	"strings"
 	"sync"
 	"time"
@@ -28,6 +29,7 @@ type ClusterConfig struct {
 	DelayMax     time.Duration // maximum reconnection delay (default: 10min)
 	StartupMin   int           // wait for StartupMin hosts (default: len(Hosts)/2+1)
 	Consistency  Consistency   // default consistency level (default: Quorum)
+	Compressor   Compressor    // compression algorithm (default: nil)
 }
 
 // NewCluster generates a new config for the default cluster implementation.
@@ -94,11 +96,13 @@ func (c *clusterImpl) connect(addr string) {
 		CQLVersion:   c.cfg.CQLVersion,
 		Timeout:      c.cfg.Timeout,
 		NumStreams:   c.cfg.NumStreams,
+		Compressor:   c.cfg.Compressor,
 	}
 	delay := c.cfg.DelayMin
 	for {
 		conn, err := Connect(addr, cfg, c)
 		if err != nil {
+			log.Printf("failed to connect to %q: %v", addr, err)
 			select {
 			case <-time.After(delay):
 				if delay *= 2; delay > c.cfg.DelayMax {

+ 115 - 15
conn.go

@@ -9,6 +9,8 @@ import (
 	"sync"
 	"sync/atomic"
 	"time"
+
+	"code.google.com/p/snappy-go/snappy"
 )
 
 const defaultFrameSize = 4096
@@ -31,6 +33,7 @@ type ConnConfig struct {
 	CQLVersion   string
 	Timeout      time.Duration
 	NumStreams   int
+	Compressor   Compressor
 }
 
 // Conn is a single connection to a Cassandra node. It can be used to execute
@@ -47,9 +50,10 @@ type Conn struct {
 	prepMu sync.Mutex
 	prep   map[string]*queryInfo
 
-	cluster Cluster
-	addr    string
-	version uint8
+	cluster    Cluster
+	compressor Compressor
+	addr       string
+	version    uint8
 }
 
 // Connect establishes a connection to a Cassandra node.
@@ -66,14 +70,15 @@ func Connect(addr string, cfg ConnConfig, cluster Cluster) (*Conn, error) {
 		cfg.ProtoVersion = 2
 	}
 	c := &Conn{
-		conn:    conn,
-		uniq:    make(chan uint8, cfg.NumStreams),
-		calls:   make([]callReq, cfg.NumStreams),
-		prep:    make(map[string]*queryInfo),
-		timeout: cfg.Timeout,
-		version: uint8(cfg.ProtoVersion),
-		addr:    conn.RemoteAddr().String(),
-		cluster: cluster,
+		conn:       conn,
+		uniq:       make(chan uint8, cfg.NumStreams),
+		calls:      make([]callReq, cfg.NumStreams),
+		prep:       make(map[string]*queryInfo),
+		timeout:    cfg.Timeout,
+		version:    uint8(cfg.ProtoVersion),
+		addr:       conn.RemoteAddr().String(),
+		cluster:    cluster,
+		compressor: cfg.Compressor,
 	}
 	for i := 0; i < cap(c.uniq); i++ {
 		c.uniq <- uint8(i)
@@ -91,9 +96,13 @@ func Connect(addr string, cfg ConnConfig, cluster Cluster) (*Conn, error) {
 func (c *Conn) startup(cfg *ConnConfig) error {
 	req := make(frame, headerSize, defaultFrameSize)
 	req.setHeader(c.version, 0, 0, opStartup)
-	req.writeStringMap(map[string]string{
+	m := map[string]string{
 		"CQL_VERSION": cfg.CQLVersion,
-	})
+	}
+	if c.compressor != nil {
+		m["COMPRESSION"] = c.compressor.Name()
+	}
+	req.writeStringMap(m)
 	resp, err := c.callSimple(req)
 	if err != nil {
 		return err
@@ -177,7 +186,7 @@ func (c *Conn) callSimple(req frame) (interface{}, error) {
 	if err != nil {
 		return nil, err
 	}
-	return decodeFrame(buf)
+	return c.decodeFrame(buf)
 }
 
 func (c *Conn) call(req frame) (interface{}, error) {
@@ -190,6 +199,15 @@ func (c *Conn) call(req frame) (interface{}, error) {
 	atomic.StoreInt32(&call.active, 1)
 
 	req.setLength(len(req) - headerSize)
+	if len(req) > headerSize && c.compressor != nil {
+		body, err := c.compressor.Encode([]byte(req[headerSize:]))
+		if err != nil {
+			return nil, err
+		}
+		req = append(req[:headerSize], frame(body)...)
+		req[1] |= flagCompress
+		req.setLength(len(req) - headerSize)
+	}
 	if n, err := c.conn.Write(req); err != nil {
 		c.conn.Close()
 		if n > 0 {
@@ -205,7 +223,7 @@ func (c *Conn) call(req frame) (interface{}, error) {
 	if reply.err != nil {
 		return nil, reply.err
 	}
-	return decodeFrame(reply.buf)
+	return c.decodeFrame(reply.buf)
 }
 
 func (c *Conn) dispatch(resp frame) {
@@ -402,6 +420,68 @@ func (c *Conn) UseKeyspace(keyspace string) error {
 	return nil
 }
 
+func (c *Conn) decodeFrame(f frame) (rval interface{}, err error) {
+	defer func() {
+		if r := recover(); r != nil {
+			if e, ok := r.(error); ok && e == ErrProtocol {
+				err = e
+				return
+			}
+			panic(r)
+		}
+	}()
+	if len(f) < headerSize || (f[0] != c.version|flagResponse) {
+		return nil, ErrProtocol
+	}
+	flags, op, f := f[1], f[3], f[headerSize:]
+	if flags&flagCompress != 0 && len(f) > 0 && c.compressor != nil {
+		if buf, err := c.compressor.Decode([]byte(f)); err != nil {
+			return nil, err
+		} else {
+			f = frame(buf)
+		}
+	}
+
+	switch op {
+	case opReady:
+		return readyFrame{}, nil
+	case opResult:
+		switch kind := f.readInt(); kind {
+		case resultKindVoid:
+			return resultVoidFrame{}, nil
+		case resultKindRows:
+			columns := f.readMetaData()
+			numRows := f.readInt()
+			values := make([][]byte, numRows*len(columns))
+			for i := 0; i < len(values); i++ {
+				values[i] = f.readBytes()
+			}
+			rows := make([][][]byte, numRows)
+			for i := 0; i < len(values); i += len(columns) {
+				rows[i] = values[i : i+len(columns)]
+			}
+			return resultRowsFrame{columns, rows, nil}, nil
+		case resultKindKeyspace:
+			keyspace := f.readString()
+			return resultKeyspaceFrame{keyspace}, nil
+		case resultKindPrepared:
+			id := f.readShortBytes()
+			values := f.readMetaData()
+			return resultPreparedFrame{id, values}, nil
+		case resultKindSchemaChanged:
+			return resultVoidFrame{}, nil
+		default:
+			return nil, ErrProtocol
+		}
+	case opError:
+		code := f.readInt()
+		msg := f.readString()
+		return errorFrame{code, msg}, nil
+	default:
+		return nil, ErrProtocol
+	}
+}
+
 type queryInfo struct {
 	id   []byte
 	args []ColumnInfo
@@ -418,3 +498,23 @@ type callResp struct {
 	buf frame
 	err error
 }
+
+type Compressor interface {
+	Name() string
+	Encode(data []byte) ([]byte, error)
+	Decode(data []byte) ([]byte, error)
+}
+
+type SnappyCompressor struct{}
+
+func (s SnappyCompressor) Name() string {
+	return "snappy"
+}
+
+func (s SnappyCompressor) Encode(data []byte) ([]byte, error) {
+	return snappy.Encode(nil, data)
+}
+
+func (s SnappyCompressor) Decode(data []byte) ([]byte, error) {
+	return snappy.Decode(nil, data)
+}

+ 1 - 55
frame.go

@@ -36,6 +36,7 @@ const (
 	resultKindSchemaChanged = 5
 
 	flagQueryValues uint8 = 1
+	flagCompress    uint8 = 1
 
 	headerSize = 8
 )
@@ -271,61 +272,6 @@ var consistencyCodes = []uint16{
 	LocalSerial: 0x0009,
 }
 
-func decodeFrame(f frame) (rval interface{}, err error) {
-	defer func() {
-		if r := recover(); r != nil {
-			if e, ok := r.(error); ok && e == ErrProtocol {
-				err = e
-				return
-			}
-			panic(r)
-		}
-	}()
-	if len(f) < headerSize || (f[0] != 1|flagResponse && f[0] != 2|flagResponse) {
-		return nil, ErrProtocol
-	}
-	switch f[3] {
-	case opReady:
-		return readyFrame{}, nil
-	case opResult:
-		f.skipHeader()
-		switch kind := f.readInt(); kind {
-		case resultKindVoid:
-			return resultVoidFrame{}, nil
-		case resultKindRows:
-			columns := f.readMetaData()
-			numRows := f.readInt()
-			values := make([][]byte, numRows*len(columns))
-			for i := 0; i < len(values); i++ {
-				values[i] = f.readBytes()
-			}
-			rows := make([][][]byte, numRows)
-			for i := 0; i < len(values); i += len(columns) {
-				rows[i] = values[i : i+len(columns)]
-			}
-			return resultRowsFrame{columns, rows, nil}, nil
-		case resultKindKeyspace:
-			keyspace := f.readString()
-			return resultKeyspaceFrame{keyspace}, nil
-		case resultKindPrepared:
-			id := f.readShortBytes()
-			values := f.readMetaData()
-			return resultPreparedFrame{id, values}, nil
-		case resultKindSchemaChanged:
-			return resultVoidFrame{}, nil
-		default:
-			return nil, ErrProtocol
-		}
-	case opError:
-		f.skipHeader()
-		code := f.readInt()
-		msg := f.readString()
-		return errorFrame{code, msg}, nil
-	default:
-		return nil, ErrProtocol
-	}
-}
-
 type readyFrame struct{}
 
 type resultVoidFrame struct{}

+ 1 - 0
gocql_test/main.go

@@ -18,6 +18,7 @@ var session *gocql.Session
 
 func init() {
 	cluster := gocql.NewCluster("127.0.0.1")
+	cluster.Compressor = gocql.SnappyCompressor{}
 	session = cluster.CreateSession()
 }