Browse Source

Add frame level support for protocol v3

Make operations on frame headers dependant on the protocol version
Chris Bannister 10 years ago
parent
commit
175294b952
1 changed files with 138 additions and 25 deletions
  1. 138 25
      frame.go

+ 138 - 25
frame.go

@@ -5,12 +5,16 @@
 package gocql
 
 import (
+	"fmt"
 	"net"
 )
 
 const (
-	protoRequest  byte = 0x02
-	protoResponse byte = 0x82
+	protoDirectionMask = 0x80
+	protoVersionMask   = 0x7F
+	protoVersion1      = 0x01
+	protoVersion2      = 0x02
+	protoVersion3      = 0x03
 
 	opError         byte = 0x00
 	opStartup       byte = 0x01
@@ -42,13 +46,26 @@ const (
 	flagPageState   uint8 = 8
 	flagHasMore     uint8 = 2
 
-	headerSize = 8
-
 	apacheCassandraTypePrefix = "org.apache.cassandra.db.marshal."
 )
 
+var headerProtoSize = [...]int{
+	protoVersion1: 8,
+	protoVersion2: 8,
+	protoVersion3: 9,
+}
+
+// TODO: replace with a struct which has a header and a body buffer,
+// header just has methods like, set/get the options in its backing array
+// then in a writeTo we write the header then the body.
 type frame []byte
 
+func newFrame(version uint8) frame {
+	// TODO: pool these at the session level incase anyone is using different
+	// clusters with different versions in the same application.
+	return make(frame, headerProtoSize[version], defaultFrameSize)
+}
+
 func (f *frame) writeInt(v int32) {
 	p := f.grow(4)
 	(*f)[p] = byte(v >> 24)
@@ -129,22 +146,67 @@ func (f *frame) writeStringMultimap(v map[string][]string) {
 	}
 }
 
-func (f *frame) setHeader(version, flags, stream, opcode uint8) {
+func (f *frame) setHeader(version, flags uint8, stream int, opcode uint8) {
 	(*f)[0] = version
 	(*f)[1] = flags
-	(*f)[2] = stream
-	(*f)[3] = opcode
+	p := 2
+	if version&maskVersion > protoVersion2 {
+		(*f)[2] = byte(stream >> 8)
+		(*f)[3] = byte(stream)
+		p += 2
+	} else {
+		(*f)[2] = byte(stream & 0xFF)
+		p++
+	}
+
+	(*f)[p] = opcode
 }
 
-func (f *frame) setLength(length int) {
-	(*f)[4] = byte(length >> 24)
-	(*f)[5] = byte(length >> 16)
-	(*f)[6] = byte(length >> 8)
-	(*f)[7] = byte(length)
+func (f *frame) setStream(stream int, version uint8) {
+	if version > protoVersion2 {
+		(*f)[2] = byte(stream >> 8)
+		(*f)[3] = byte(stream)
+	} else {
+		(*f)[2] = byte(stream)
+	}
 }
 
-func (f *frame) Length() int {
-	return int((*f)[4])<<24 | int((*f)[5])<<16 | int((*f)[6])<<8 | int((*f)[7])
+func (f *frame) Stream(version uint8) (n int) {
+	if version > protoVersion2 {
+		n = int((*f)[2])<<8 | int((*f)[3])
+	} else {
+		n = int((*f)[2])
+	}
+	return
+}
+
+func (f *frame) setLength(length int, version uint8) {
+	p := 4
+	if version > protoVersion2 {
+		p = 5
+	}
+
+	(*f)[p] = byte(length >> 24)
+	(*f)[p+1] = byte(length >> 16)
+	(*f)[p+2] = byte(length >> 8)
+	(*f)[p+3] = byte(length)
+}
+
+func (f *frame) Op(version uint8) byte {
+	if version > protoVersion2 {
+		return (*f)[4]
+	} else {
+		return (*f)[3]
+	}
+}
+
+func (f *frame) Length(version uint8) int {
+	p := 4
+	if version > protoVersion2 {
+		p = 5
+	}
+
+	return int((*f)[p])<<24 | int((*f)[p+1])<<16 | int((*f)[p+2])<<8 | int((*f)[p+3])
 }
 
 func (f *frame) grow(n int) int {
@@ -158,13 +220,13 @@ func (f *frame) grow(n int) int {
 	return p
 }
 
-func (f *frame) skipHeader() {
-	*f = (*f)[headerSize:]
+func (f *frame) skipHeader(version uint8) {
+	*f = (*f)[headerProtoSize[version]:]
 }
 
 func (f *frame) readInt() int {
 	if len(*f) < 4 {
-		panic(NewErrProtocol("Trying to read an int while >4 bytes in the buffer"))
+		panic(NewErrProtocol("Trying to read an int while <4 bytes in the buffer"))
 	}
 	v := uint32((*f)[0])<<24 | uint32((*f)[1])<<16 | uint32((*f)[2])<<8 | uint32((*f)[3])
 	*f = (*f)[4:]
@@ -173,7 +235,7 @@ func (f *frame) readInt() int {
 
 func (f *frame) readShort() uint16 {
 	if len(*f) < 2 {
-		panic(NewErrProtocol("Trying to read a short while >2 bytes in the buffer"))
+		panic(NewErrProtocol("Trying to read a short while <2 bytes in the buffer"))
 	}
 	v := uint16((*f)[0])<<8 | uint16((*f)[1])
 	*f = (*f)[2:]
@@ -251,16 +313,19 @@ func (f *frame) readTypeInfo() *TypeInfo {
 func (f *frame) readMetaData() ([]ColumnInfo, []byte) {
 	flags := f.readInt()
 	numColumns := f.readInt()
+
 	var pageState []byte
 	if flags&2 != 0 {
 		pageState = f.readBytes()
 	}
+
 	globalKeyspace := ""
 	globalTable := ""
 	if flags&1 != 0 {
 		globalKeyspace = f.readString()
 		globalTable = f.readString()
 	}
+
 	columns := make([]ColumnInfo, numColumns)
 	for i := 0; i < numColumns; i++ {
 		columns[i].Keyspace = globalKeyspace
@@ -381,19 +446,32 @@ type startupFrame struct {
 	Compression string
 }
 
+func (op *startupFrame) String() string {
+	return fmt.Sprintf("[startup cqlversion=%q compression=%q]", op.CQLVersion, op.Compression)
+}
+
 func (op *startupFrame) encodeFrame(version uint8, f frame) (frame, error) {
 	if f == nil {
-		f = make(frame, headerSize, defaultFrameSize)
+		f = newFrame(version)
 	}
+
 	f.setHeader(version, 0, 0, opStartup)
-	f.writeShort(1)
+
+	// TODO: fix this, this is actually a StringMap
+	var size uint16 = 1
+	if op.Compression != "" {
+		size++
+	}
+
+	f.writeShort(size)
 	f.writeString("CQL_VERSION")
 	f.writeString(op.CQLVersion)
+
 	if op.Compression != "" {
-		f[headerSize+1] += 1
 		f.writeString("COMPRESSION")
 		f.writeString(op.Compression)
 	}
+
 	return f, nil
 }
 
@@ -406,14 +484,20 @@ type queryFrame struct {
 	PageState []byte
 }
 
+func (op *queryFrame) String() string {
+	return fmt.Sprintf("[query statement=%q prepared=%x cons=%v ...]", op.Stmt, op.Prepared, op.Cons)
+}
+
 func (op *queryFrame) encodeFrame(version uint8, f frame) (frame, error) {
 	if version == 1 && (op.PageSize != 0 || len(op.PageState) > 0 ||
 		(len(op.Values) > 0 && len(op.Prepared) == 0)) {
 		return nil, ErrUnsupported
 	}
+
 	if f == nil {
-		f = make(frame, headerSize, defaultFrameSize)
+		f = newFrame(version)
 	}
+
 	if len(op.Prepared) > 0 {
 		f.setHeader(version, 0, 0, opExecute)
 		f.writeShortBytes(op.Prepared)
@@ -421,10 +505,12 @@ func (op *queryFrame) encodeFrame(version uint8, f frame) (frame, error) {
 		f.setHeader(version, 0, 0, opQuery)
 		f.writeLongString(op.Stmt)
 	}
+
 	if version >= 2 {
 		f.writeConsistency(op.Cons)
 		flagPos := len(f)
 		f.writeByte(0)
+
 		if len(op.Values) > 0 {
 			f[flagPos] |= flagQueryValues
 			f.writeShort(uint16(len(op.Values)))
@@ -432,10 +518,12 @@ func (op *queryFrame) encodeFrame(version uint8, f frame) (frame, error) {
 				f.writeBytes(value)
 			}
 		}
+
 		if op.PageSize > 0 {
 			f[flagPos] |= flagPageSize
 			f.writeInt(int32(op.PageSize))
 		}
+
 		if len(op.PageState) > 0 {
 			f[flagPos] |= flagPageState
 			f.writeBytes(op.PageState)
@@ -449,6 +537,7 @@ func (op *queryFrame) encodeFrame(version uint8, f frame) (frame, error) {
 		}
 		f.writeConsistency(op.Cons)
 	}
+
 	return f, nil
 }
 
@@ -456,9 +545,13 @@ type prepareFrame struct {
 	Stmt string
 }
 
+func (op *prepareFrame) String() string {
+	return fmt.Sprintf("[prepare statement=%q]", op.Stmt)
+}
+
 func (op *prepareFrame) encodeFrame(version uint8, f frame) (frame, error) {
 	if f == nil {
-		f = make(frame, headerSize, defaultFrameSize)
+		f = newFrame(version)
 	}
 	f.setHeader(version, 0, 0, opPrepare)
 	f.writeLongString(op.Stmt)
@@ -467,9 +560,13 @@ func (op *prepareFrame) encodeFrame(version uint8, f frame) (frame, error) {
 
 type optionsFrame struct{}
 
+func (op *optionsFrame) String() string {
+	return "[options]"
+}
+
 func (op *optionsFrame) encodeFrame(version uint8, f frame) (frame, error) {
 	if f == nil {
-		f = make(frame, headerSize, defaultFrameSize)
+		f = newFrame(version)
 	}
 	f.setHeader(version, 0, 0, opOptions)
 	return f, nil
@@ -479,13 +576,21 @@ type authenticateFrame struct {
 	Authenticator string
 }
 
+func (op *authenticateFrame) String() string {
+	return fmt.Sprintf("[authenticate authenticator=%q]", op.Authenticator)
+}
+
 type authResponseFrame struct {
 	Data []byte
 }
 
+func (op *authResponseFrame) String() string {
+	return fmt.Sprintf("[auth_response data=%q]", op.Data)
+}
+
 func (op *authResponseFrame) encodeFrame(version uint8, f frame) (frame, error) {
 	if f == nil {
-		f = make(frame, headerSize, defaultFrameSize)
+		f = newFrame(version)
 	}
 	f.setHeader(version, 0, 0, opAuthResponse)
 	f.writeBytes(op.Data)
@@ -496,6 +601,14 @@ type authSuccessFrame struct {
 	Data []byte
 }
 
+func (op *authSuccessFrame) String() string {
+	return fmt.Sprintf("[auth_success data=%q]", op.Data)
+}
+
 type authChallengeFrame struct {
 	Data []byte
 }
+
+func (op *authChallengeFrame) String() string {
+	return fmt.Sprintf("[auth_challenge data=%q]", op.Data)
+}