浏览代码

Allow to use protocol V5 (#1165)

* Allow to use protocol V5

* Protocol V5: add keyspace in query/execute/prepare message

* Protocol V5: add error map in read/write failures

* Protocol V5: duration type
Jaume Marhuenda 7 年之前
父节点
当前提交
78e324ac30
共有 10 个文件被更改,包括 357 次插入15 次删除
  1. 5 5
      .travis.yml
  2. 52 0
      cassandra_test.go
  3. 6 0
      conn.go
  4. 11 0
      cqltypes.go
  5. 4 0
      errors.go
  6. 91 9
      frame.go
  7. 4 0
      helpers.go
  8. 1 1
      integration.sh
  9. 117 0
      marshal.go
  10. 66 0
      marshal_test.go

+ 5 - 5
.travis.yml

@@ -19,13 +19,13 @@ env:
   global:
     - GOMAXPROCS=2
   matrix:
-    - CASS=2.1.12
-      AUTH=false
-    - CASS=2.2.5
+    - CASS=2.2.13
       AUTH=true
-    - CASS=2.2.5
+    - CASS=2.2.13
+      AUTH=false
+    - CASS=3.0.17
       AUTH=false
-    - CASS=3.0.8
+    - CASS=3.11.3
       AUTH=false
 
 go:

+ 52 - 0
cassandra_test.go

@@ -485,6 +485,58 @@ func TestCAS(t *testing.T) {
 	}
 }
 
+func TestDurationType(t *testing.T) {
+	session := createSession(t)
+	defer session.Close()
+
+	if session.cfg.ProtoVersion < 5 {
+		t.Skip("Duration type is not supported. Please use protocol version >= 4 and cassandra version >= 3.11")
+	}
+
+	if err := createTable(session, `CREATE TABLE gocql_test.duration_table (
+		k int primary key, v duration
+	)`); err != nil {
+		t.Fatal("create:", err)
+	}
+
+	durations := []Duration{
+		Duration{
+			Months:      250,
+			Days:        500,
+			Nanoseconds: 300010001,
+		},
+		Duration{
+			Months:      -250,
+			Days:        -500,
+			Nanoseconds: -300010001,
+		},
+		Duration{
+			Months:      0,
+			Days:        128,
+			Nanoseconds: 127,
+		},
+		Duration{
+			Months:      0x7FFFFFFF,
+			Days:        0x7FFFFFFF,
+			Nanoseconds: 0x7FFFFFFFFFFFFFFF,
+		},
+	}
+	for _, durationSend := range durations {
+		if err := session.Query(`INSERT INTO gocql_test.duration_table (k, v) VALUES (1, ?)`, durationSend).Exec(); err != nil {
+			t.Fatal(err)
+		}
+
+		var id int
+		var duration Duration
+		if err := session.Query(`SELECT k, v FROM gocql_test.duration_table`).Scan(&id, &duration); err != nil {
+			t.Fatal(err)
+		}
+		if duration.Months != durationSend.Months || duration.Days != durationSend.Days || duration.Nanoseconds != durationSend.Nanoseconds {
+			t.Fatalf("Unexpeted value returned, expected=%v, received=%v", durationSend, duration)
+		}
+	}
+}
+
 func TestMapScanCAS(t *testing.T) {
 	session := createSession(t)
 	defer session.Close()

+ 6 - 0
conn.go

@@ -721,6 +721,9 @@ func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer)
 	prep := &writePrepareFrame{
 		statement: stmt,
 	}
+	if c.version > protoVersion4 {
+		prep.keyspace = c.currentKeyspace
+	}
 
 	framer, err := c.exec(ctx, prep, tracer)
 	if err != nil {
@@ -805,6 +808,9 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 	if qry.pageSize > 0 {
 		params.pageSize = qry.pageSize
 	}
+	if c.version > protoVersion4 {
+		params.keyspace = c.currentKeyspace
+	}
 
 	var (
 		frame frameWriter

+ 11 - 0
cqltypes.go

@@ -0,0 +1,11 @@
+// Copyright (c) 2012 The gocql Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gocql
+
+type Duration struct {
+	Months      int32
+	Days        int32
+	Nanoseconds int64
+}

+ 4 - 0
errors.go

@@ -64,6 +64,8 @@ func (e *RequestErrUnavailable) String() string {
 	return fmt.Sprintf("[request_error_unavailable consistency=%s required=%d alive=%d]", e.Consistency, e.Required, e.Alive)
 }
 
+type ErrorMap map[string]uint16
+
 type RequestErrWriteTimeout struct {
 	errorFrame
 	Consistency Consistency
@@ -79,6 +81,7 @@ type RequestErrWriteFailure struct {
 	BlockFor    int
 	NumFailures int
 	WriteType   string
+	ErrorMap    ErrorMap
 }
 
 type RequestErrCDCWriteFailure struct {
@@ -111,6 +114,7 @@ type RequestErrReadFailure struct {
 	BlockFor    int
 	NumFailures int
 	DataPresent bool
+	ErrorMap    ErrorMap
 }
 
 type RequestErrFunctionFailure struct {

+ 91 - 9
frame.go

@@ -157,12 +157,17 @@ const (
 	flagWithSerialConsistency byte = 0x10
 	flagDefaultTimestamp      byte = 0x20
 	flagWithNameValues        byte = 0x40
+	flagWithKeyspace          byte = 0x80
+
+	// prepare flags
+	flagWithPreparedKeyspace uint32 = 0x01
 
 	// header flags
 	flagCompress      byte = 0x01
 	flagTracing       byte = 0x02
 	flagCustomPayload byte = 0x04
 	flagWarning       byte = 0x08
+	flagBetaProtocol  byte = 0x10
 )
 
 type Consistency uint16
@@ -404,6 +409,9 @@ func newFramer(r io.Reader, w io.Writer, compressor Compressor, version byte) *f
 	if compressor != nil {
 		flags |= flagCompress
 	}
+	if version == protoVersion5 {
+		flags |= flagBetaProtocol
+	}
 
 	version &= protoVersionMask
 
@@ -441,7 +449,7 @@ func readHeader(r io.Reader, p []byte) (head frameHeader, err error) {
 
 	version := p[0] & protoVersionMask
 
-	if version < protoVersion1 || version > protoVersion4 {
+	if version < protoVersion1 || version > protoVersion5 {
 		return frameHeader{}, fmt.Errorf("gocql: unsupported protocol response version: %d", version)
 	}
 
@@ -644,7 +652,14 @@ func (f *framer) parseErrorFrame() frame {
 		res.Consistency = f.readConsistency()
 		res.Received = f.readInt()
 		res.BlockFor = f.readInt()
+		if f.proto > protoVersion4 {
+			res.ErrorMap = f.readErrorMap()
+			res.NumFailures = len(res.ErrorMap)
+		} else {
+			res.NumFailures = f.readInt()
+		}
 		res.DataPresent = f.readByte() != 0
+
 		return res
 	case errWriteFailure:
 		res := &RequestErrWriteFailure{
@@ -653,7 +668,12 @@ func (f *framer) parseErrorFrame() frame {
 		res.Consistency = f.readConsistency()
 		res.Received = f.readInt()
 		res.BlockFor = f.readInt()
-		res.NumFailures = f.readInt()
+		if f.proto > protoVersion4 {
+			res.ErrorMap = f.readErrorMap()
+			res.NumFailures = len(res.ErrorMap)
+		} else {
+			res.NumFailures = f.readInt()
+		}
 		res.WriteType = f.readString()
 		return res
 	case errFunctionFailure:
@@ -680,6 +700,16 @@ func (f *framer) parseErrorFrame() frame {
 	}
 }
 
+func (f *framer) readErrorMap() (errMap ErrorMap) {
+	errMap = make(ErrorMap)
+	numErrs := f.readInt()
+	for i := 0; i < numErrs; i++ {
+		ip := f.readInetAdressOnly().String()
+		errMap[ip] = f.readShort()
+	}
+	return
+}
+
 func (f *framer) writeHeader(flags byte, op frameOp, stream int) {
 	f.wbuf = f.wbuf[:0]
 	f.wbuf = append(f.wbuf,
@@ -798,11 +828,28 @@ func (w *writeStartupFrame) writeFrame(f *framer, streamID int) error {
 
 type writePrepareFrame struct {
 	statement string
+	keyspace  string
 }
 
 func (w *writePrepareFrame) writeFrame(f *framer, streamID int) error {
 	f.writeHeader(f.flags, opPrepare, streamID)
 	f.writeLongString(w.statement)
+
+	var flags uint32 = 0
+	if w.keyspace != "" {
+		if f.proto > protoVersion4 {
+			flags |= flagWithPreparedKeyspace
+		} else {
+			panic(fmt.Errorf("The keyspace can only be set with protocol 5 or higher"))
+		}
+	}
+	if f.proto > protoVersion4 {
+		f.writeUint(flags)
+	}
+	if w.keyspace != "" {
+		f.writeString(w.keyspace)
+	}
+
 	return f.finishWrite()
 }
 
@@ -1386,11 +1433,13 @@ type queryParams struct {
 	// v3+
 	defaultTimestamp      bool
 	defaultTimestampValue int64
+	// v5+
+	keyspace string
 }
 
 func (q queryParams) String() string {
-	return fmt.Sprintf("[query_params consistency=%v skip_meta=%v page_size=%d paging_state=%q serial_consistency=%v default_timestamp=%v values=%v]",
-		q.consistency, q.skipMeta, q.pageSize, q.pagingState, q.serialConsistency, q.defaultTimestamp, q.values)
+	return fmt.Sprintf("[query_params consistency=%v skip_meta=%v page_size=%d paging_state=%q serial_consistency=%v default_timestamp=%v values=%v keyspace=%s]",
+		q.consistency, q.skipMeta, q.pageSize, q.pagingState, q.serialConsistency, q.defaultTimestamp, q.values, q.keyspace)
 }
 
 func (f *framer) writeQueryParams(opts *queryParams) {
@@ -1431,7 +1480,19 @@ func (f *framer) writeQueryParams(opts *queryParams) {
 		}
 	}
 
-	f.writeByte(flags)
+	if opts.keyspace != "" {
+		if f.proto > protoVersion4 {
+			flags |= flagWithKeyspace
+		} else {
+			panic(fmt.Errorf("The keyspace can only be set with protocol 5 or higher"))
+		}
+	}
+
+	if f.proto > protoVersion4 {
+		f.writeUint(uint32(flags))
+	} else {
+		f.writeByte(flags)
+	}
 
 	if n := len(opts.values); n > 0 {
 		f.writeShort(uint16(n))
@@ -1470,6 +1531,10 @@ func (f *framer) writeQueryParams(opts *queryParams) {
 		}
 		f.writeLong(ts)
 	}
+
+	if opts.keyspace != "" {
+		f.writeString(opts.keyspace)
+	}
 }
 
 type writeQueryFrame struct {
@@ -1609,7 +1674,11 @@ func (f *framer) writeBatchFrame(streamID int, w *writeBatchFrame) error {
 			flags |= flagDefaultTimestamp
 		}
 
-		f.writeByte(flags)
+		if f.proto > protoVersion4 {
+			f.writeUint(uint32(flags))
+		} else {
+			f.writeByte(flags)
+		}
 
 		if w.serialConsistency > 0 {
 			f.writeConsistency(Consistency(w.serialConsistency))
@@ -1777,7 +1846,7 @@ func (f *framer) readShortBytes() []byte {
 	return l
 }
 
-func (f *framer) readInet() (net.IP, int) {
+func (f *framer) readInetAdressOnly() net.IP {
 	if len(f.rbuf) < 1 {
 		panic(fmt.Errorf("not enough bytes in buffer to read inet size require %d got: %d", 1, len(f.rbuf)))
 	}
@@ -1796,9 +1865,11 @@ func (f *framer) readInet() (net.IP, int) {
 	ip := make([]byte, size)
 	copy(ip, f.rbuf[:size])
 	f.rbuf = f.rbuf[size:]
+	return net.IP(ip)
+}
 
-	port := f.readInt()
-	return net.IP(ip), port
+func (f *framer) readInet() (net.IP, int) {
+	return f.readInetAdressOnly(), f.readInt()
 }
 
 func (f *framer) readConsistency() Consistency {
@@ -1871,6 +1942,13 @@ func appendInt(p []byte, n int32) []byte {
 		byte(n))
 }
 
+func appendUint(p []byte, n uint32) []byte {
+	return append(p, byte(n>>24),
+		byte(n>>16),
+		byte(n>>8),
+		byte(n))
+}
+
 func appendLong(p []byte, n int64) []byte {
 	return append(p,
 		byte(n>>56),
@@ -1889,6 +1967,10 @@ func (f *framer) writeInt(n int32) {
 	f.wbuf = appendInt(f.wbuf, n)
 }
 
+func (f *framer) writeUint(n uint32) {
+	f.wbuf = appendUint(f.wbuf, n)
+}
+
 func (f *framer) writeShort(n uint16) {
 	f.wbuf = appendShort(f.wbuf, n)
 }

+ 4 - 0
helpers.go

@@ -60,6 +60,8 @@ func goType(t TypeInfo) reflect.Type {
 		return reflect.TypeOf(make(map[string]interface{}))
 	case TypeDate:
 		return reflect.TypeOf(*new(time.Time))
+	case TypeDuration:
+		return reflect.TypeOf(*new(Duration))
 	default:
 		return nil
 	}
@@ -203,6 +205,8 @@ func getApacheCassandraType(class string) Type {
 		return TypeSet
 	case "TupleType":
 		return TypeTuple
+	case "DurationType":
+		return TypeDuration
 	default:
 		return TypeCustom
 	}

+ 1 - 1
integration.sh

@@ -51,7 +51,7 @@ function run_tests() {
 		proto=4
 		ccm updateconf 'enable_user_defined_functions: true'
 	elif [[ $version == 3.*.* ]]; then
-		proto=4
+		proto=5
 		ccm updateconf 'enable_user_defined_functions: true'
 	fi
 

+ 117 - 0
marshal.go

@@ -11,6 +11,7 @@ import (
 	"fmt"
 	"math"
 	"math/big"
+	"math/bits"
 	"net"
 	"reflect"
 	"strconv"
@@ -99,6 +100,8 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) {
 		return marshalUDT(info, value)
 	case TypeDate:
 		return marshalDate(info, value)
+	case TypeDuration:
+		return marshalDuration(info, value)
 	}
 
 	// detect protocol 2 UDT
@@ -161,6 +164,8 @@ func Unmarshal(info TypeInfo, data []byte, value interface{}) error {
 		return unmarshalUDT(info, data, value)
 	case TypeDate:
 		return unmarshalDate(info, data, value)
+	case TypeDuration:
+		return unmarshalDuration(info, data, value)
 	}
 
 	// detect protocol 2 UDT
@@ -1211,6 +1216,115 @@ func unmarshalDate(info TypeInfo, data []byte, value interface{}) error {
 	return unmarshalErrorf("can not unmarshal %s into %T", info, value)
 }
 
+func marshalDuration(info TypeInfo, value interface{}) ([]byte, error) {
+	switch v := value.(type) {
+	case Marshaler:
+		return v.MarshalCQL(info)
+	case unsetColumn:
+		return nil, nil
+	case int64:
+		return encVints(0, 0, v), nil
+	case time.Duration:
+		return encVints(0, 0, v.Nanoseconds()), nil
+	case string:
+		d, err := time.ParseDuration(v)
+		if err != nil {
+			return nil, err
+		}
+		return encVints(0, 0, d.Nanoseconds()), nil
+	case Duration:
+		return encVints(v.Months, v.Days, v.Nanoseconds), nil
+	}
+
+	if value == nil {
+		return nil, nil
+	}
+
+	rv := reflect.ValueOf(value)
+	switch rv.Type().Kind() {
+	case reflect.Int64:
+		return encBigInt(rv.Int()), nil
+	}
+	return nil, marshalErrorf("can not marshal %T into %s", value, info)
+}
+
+func unmarshalDuration(info TypeInfo, data []byte, value interface{}) error {
+	switch v := value.(type) {
+	case Unmarshaler:
+		return v.UnmarshalCQL(info, data)
+	case *Duration:
+		if len(data) == 0 {
+			*v = Duration{
+				Months:      0,
+				Days:        0,
+				Nanoseconds: 0,
+			}
+			return nil
+		}
+		months, days, nanos := decVints(data)
+		*v = Duration{
+			Months:      months,
+			Days:        days,
+			Nanoseconds: nanos,
+		}
+		return nil
+	}
+	return unmarshalErrorf("can not unmarshal %s into %T", info, value)
+}
+
+func decVints(data []byte) (int32, int32, int64) {
+	month, i := decVint(data)
+	days, j := decVint(data[i:])
+	nanos, _ := decVint(data[i+j:])
+	return int32(month), int32(days), nanos
+}
+
+func decVint(data []byte) (int64, int) {
+	firstByte := data[0]
+	if firstByte&0x80 == 0 {
+		return decIntZigZag(uint64(firstByte)), 1
+	}
+	numBytes := bits.LeadingZeros32(uint32(^firstByte)) - 24
+	ret := uint64(firstByte & (0xff >> uint(numBytes)))
+	for i := 0; i < numBytes; i++ {
+		ret <<= 8
+		ret |= uint64(data[i+1] & 0xff)
+	}
+	return decIntZigZag(ret), numBytes + 1
+}
+
+func decIntZigZag(n uint64) int64 {
+	return int64((n >> 1) ^ -(n & 1))
+}
+
+func encIntZigZag(n int64) uint64 {
+	return uint64((n >> 63) ^ (n << 1))
+}
+
+func encVints(months int32, seconds int32, nanos int64) []byte {
+	buf := append(encVint(int64(months)), encVint(int64(seconds))...)
+	return append(buf, encVint(nanos)...)
+}
+
+func encVint(v int64) []byte {
+	vEnc := encIntZigZag(v)
+	lead0 := bits.LeadingZeros64(vEnc)
+	numBytes := (639 - lead0*9) >> 6
+
+	// It can be 1 or 0 is v ==0
+	if numBytes <= 1 {
+		return []byte{byte(vEnc)}
+	}
+	extraBytes := numBytes - 1
+	var buf = make([]byte, numBytes)
+	for i := extraBytes; i >= 0; i-- {
+		buf[i] = byte(vEnc)
+		vEnc >>= 8
+	}
+	buf[0] |= byte(^(0xff >> uint(extraBytes)))
+	return buf
+}
+
 func writeCollectionSize(info CollectionType, n int, buf *bytes.Buffer) error {
 	if info.proto > protoVersion2 {
 		if n > math.MaxInt32 {
@@ -2131,6 +2245,7 @@ const (
 	TypeTime      Type = 0x0012
 	TypeSmallInt  Type = 0x0013
 	TypeTinyInt   Type = 0x0014
+	TypeDuration  Type = 0x0015
 	TypeList      Type = 0x0020
 	TypeMap       Type = 0x0021
 	TypeSet       Type = 0x0022
@@ -2175,6 +2290,8 @@ func (t Type) String() string {
 		return "inet"
 	case TypeDate:
 		return "date"
+	case TypeDuration:
+		return "duration"
 	case TypeTime:
 		return "time"
 	case TypeSmallInt:

+ 66 - 0
marshal_test.go

@@ -328,6 +328,27 @@ var marshalTests = []struct {
 		nil,
 		nil,
 	},
+	{
+		NativeType{proto: 5, typ: TypeDuration},
+		[]byte("\x89\xa2\xc3\xc2\x9a\xe0F\x91\x06"),
+		Duration{Months: 1233, Days: 123213, Nanoseconds: 2312323},
+		nil,
+		nil,
+	},
+	{
+		NativeType{proto: 5, typ: TypeDuration},
+		[]byte("\x89\xa1\xc3\xc2\x99\xe0F\x91\x05"),
+		Duration{Months: -1233, Days: -123213, Nanoseconds: -2312323},
+		nil,
+		nil,
+	},
+	{
+		NativeType{proto: 5, typ: TypeDuration},
+		[]byte("\x02\x04\x80\xe6"),
+		Duration{Months: 1, Days: 2, Nanoseconds: 115},
+		nil,
+		nil,
+	},
 	{
 		CollectionType{
 			NativeType: NativeType{proto: 2, typ: TypeList},
@@ -1414,3 +1435,48 @@ func BenchmarkUnmarshalVarchar(b *testing.B) {
 		}
 	}
 }
+
+func TestMarshalDuration(t *testing.T) {
+	durationS := "1h10m10s"
+	duration, _ := time.ParseDuration(durationS)
+	expectedData := append([]byte{0, 0}, encVint(duration.Nanoseconds())...)
+	var marshalDurationTests = []struct {
+		Info  TypeInfo
+		Data  []byte
+		Value interface{}
+	}{
+		{
+			NativeType{proto: 5, typ: TypeDuration},
+			expectedData,
+			duration.Nanoseconds(),
+		},
+		{
+			NativeType{proto: 5, typ: TypeDuration},
+			expectedData,
+			duration,
+		},
+		{
+			NativeType{proto: 5, typ: TypeDuration},
+			expectedData,
+			durationS,
+		},
+		{
+			NativeType{proto: 5, typ: TypeDuration},
+			expectedData,
+			&duration,
+		},
+	}
+
+	for i, test := range marshalDurationTests {
+		t.Log(i, test)
+		data, err := Marshal(test.Info, test.Value)
+		if err != nil {
+			t.Errorf("marshalTest[%d]: %v", i, err)
+			continue
+		}
+		if !bytes.Equal(data, test.Data) {
+			t.Errorf("marshalTest[%d]: expected %x (%v), got %x (%v) for time %s", i,
+				test.Data, decInt(test.Data), data, decInt(data), test.Value)
+		}
+	}
+}