Explorar el Código

Add frame level support for protocol v3

Make operations on frame headers dependant on the protocol version
Chris Bannister hace 11 años
padre
commit
175294b952
Se han modificado 1 ficheros con 138 adiciones y 25 borrados
  1. 138 25
      frame.go

+ 138 - 25
frame.go

@@ -5,12 +5,16 @@
 package gocql
 package gocql
 
 
 import (
 import (
+	"fmt"
 	"net"
 	"net"
 )
 )
 
 
 const (
 const (
-	protoRequest  byte = 0x02
-	protoResponse byte = 0x82
+	protoDirectionMask = 0x80
+	protoVersionMask   = 0x7F
+	protoVersion1      = 0x01
+	protoVersion2      = 0x02
+	protoVersion3      = 0x03
 
 
 	opError         byte = 0x00
 	opError         byte = 0x00
 	opStartup       byte = 0x01
 	opStartup       byte = 0x01
@@ -42,13 +46,26 @@ const (
 	flagPageState   uint8 = 8
 	flagPageState   uint8 = 8
 	flagHasMore     uint8 = 2
 	flagHasMore     uint8 = 2
 
 
-	headerSize = 8
-
 	apacheCassandraTypePrefix = "org.apache.cassandra.db.marshal."
 	apacheCassandraTypePrefix = "org.apache.cassandra.db.marshal."
 )
 )
 
 
+var headerProtoSize = [...]int{
+	protoVersion1: 8,
+	protoVersion2: 8,
+	protoVersion3: 9,
+}
+
+// TODO: replace with a struct which has a header and a body buffer,
+// header just has methods like, set/get the options in its backing array
+// then in a writeTo we write the header then the body.
 type frame []byte
 type frame []byte
 
 
+func newFrame(version uint8) frame {
+	// TODO: pool these at the session level incase anyone is using different
+	// clusters with different versions in the same application.
+	return make(frame, headerProtoSize[version], defaultFrameSize)
+}
+
 func (f *frame) writeInt(v int32) {
 func (f *frame) writeInt(v int32) {
 	p := f.grow(4)
 	p := f.grow(4)
 	(*f)[p] = byte(v >> 24)
 	(*f)[p] = byte(v >> 24)
@@ -129,22 +146,67 @@ func (f *frame) writeStringMultimap(v map[string][]string) {
 	}
 	}
 }
 }
 
 
-func (f *frame) setHeader(version, flags, stream, opcode uint8) {
+func (f *frame) setHeader(version, flags uint8, stream int, opcode uint8) {
 	(*f)[0] = version
 	(*f)[0] = version
 	(*f)[1] = flags
 	(*f)[1] = flags
-	(*f)[2] = stream
-	(*f)[3] = opcode
+	p := 2
+	if version&maskVersion > protoVersion2 {
+		(*f)[2] = byte(stream >> 8)
+		(*f)[3] = byte(stream)
+		p += 2
+	} else {
+		(*f)[2] = byte(stream & 0xFF)
+		p++
+	}
+
+	(*f)[p] = opcode
 }
 }
 
 
-func (f *frame) setLength(length int) {
-	(*f)[4] = byte(length >> 24)
-	(*f)[5] = byte(length >> 16)
-	(*f)[6] = byte(length >> 8)
-	(*f)[7] = byte(length)
+func (f *frame) setStream(stream int, version uint8) {
+	if version > protoVersion2 {
+		(*f)[2] = byte(stream >> 8)
+		(*f)[3] = byte(stream)
+	} else {
+		(*f)[2] = byte(stream)
+	}
 }
 }
 
 
-func (f *frame) Length() int {
-	return int((*f)[4])<<24 | int((*f)[5])<<16 | int((*f)[6])<<8 | int((*f)[7])
+func (f *frame) Stream(version uint8) (n int) {
+	if version > protoVersion2 {
+		n = int((*f)[2])<<8 | int((*f)[3])
+	} else {
+		n = int((*f)[2])
+	}
+	return
+}
+
+func (f *frame) setLength(length int, version uint8) {
+	p := 4
+	if version > protoVersion2 {
+		p = 5
+	}
+
+	(*f)[p] = byte(length >> 24)
+	(*f)[p+1] = byte(length >> 16)
+	(*f)[p+2] = byte(length >> 8)
+	(*f)[p+3] = byte(length)
+}
+
+func (f *frame) Op(version uint8) byte {
+	if version > protoVersion2 {
+		return (*f)[4]
+	} else {
+		return (*f)[3]
+	}
+}
+
+func (f *frame) Length(version uint8) int {
+	p := 4
+	if version > protoVersion2 {
+		p = 5
+	}
+
+	return int((*f)[p])<<24 | int((*f)[p+1])<<16 | int((*f)[p+2])<<8 | int((*f)[p+3])
 }
 }
 
 
 func (f *frame) grow(n int) int {
 func (f *frame) grow(n int) int {
@@ -158,13 +220,13 @@ func (f *frame) grow(n int) int {
 	return p
 	return p
 }
 }
 
 
-func (f *frame) skipHeader() {
-	*f = (*f)[headerSize:]
+func (f *frame) skipHeader(version uint8) {
+	*f = (*f)[headerProtoSize[version]:]
 }
 }
 
 
 func (f *frame) readInt() int {
 func (f *frame) readInt() int {
 	if len(*f) < 4 {
 	if len(*f) < 4 {
-		panic(NewErrProtocol("Trying to read an int while >4 bytes in the buffer"))
+		panic(NewErrProtocol("Trying to read an int while <4 bytes in the buffer"))
 	}
 	}
 	v := uint32((*f)[0])<<24 | uint32((*f)[1])<<16 | uint32((*f)[2])<<8 | uint32((*f)[3])
 	v := uint32((*f)[0])<<24 | uint32((*f)[1])<<16 | uint32((*f)[2])<<8 | uint32((*f)[3])
 	*f = (*f)[4:]
 	*f = (*f)[4:]
@@ -173,7 +235,7 @@ func (f *frame) readInt() int {
 
 
 func (f *frame) readShort() uint16 {
 func (f *frame) readShort() uint16 {
 	if len(*f) < 2 {
 	if len(*f) < 2 {
-		panic(NewErrProtocol("Trying to read a short while >2 bytes in the buffer"))
+		panic(NewErrProtocol("Trying to read a short while <2 bytes in the buffer"))
 	}
 	}
 	v := uint16((*f)[0])<<8 | uint16((*f)[1])
 	v := uint16((*f)[0])<<8 | uint16((*f)[1])
 	*f = (*f)[2:]
 	*f = (*f)[2:]
@@ -251,16 +313,19 @@ func (f *frame) readTypeInfo() *TypeInfo {
 func (f *frame) readMetaData() ([]ColumnInfo, []byte) {
 func (f *frame) readMetaData() ([]ColumnInfo, []byte) {
 	flags := f.readInt()
 	flags := f.readInt()
 	numColumns := f.readInt()
 	numColumns := f.readInt()
+
 	var pageState []byte
 	var pageState []byte
 	if flags&2 != 0 {
 	if flags&2 != 0 {
 		pageState = f.readBytes()
 		pageState = f.readBytes()
 	}
 	}
+
 	globalKeyspace := ""
 	globalKeyspace := ""
 	globalTable := ""
 	globalTable := ""
 	if flags&1 != 0 {
 	if flags&1 != 0 {
 		globalKeyspace = f.readString()
 		globalKeyspace = f.readString()
 		globalTable = f.readString()
 		globalTable = f.readString()
 	}
 	}
+
 	columns := make([]ColumnInfo, numColumns)
 	columns := make([]ColumnInfo, numColumns)
 	for i := 0; i < numColumns; i++ {
 	for i := 0; i < numColumns; i++ {
 		columns[i].Keyspace = globalKeyspace
 		columns[i].Keyspace = globalKeyspace
@@ -381,19 +446,32 @@ type startupFrame struct {
 	Compression string
 	Compression string
 }
 }
 
 
+func (op *startupFrame) String() string {
+	return fmt.Sprintf("[startup cqlversion=%q compression=%q]", op.CQLVersion, op.Compression)
+}
+
 func (op *startupFrame) encodeFrame(version uint8, f frame) (frame, error) {
 func (op *startupFrame) encodeFrame(version uint8, f frame) (frame, error) {
 	if f == nil {
 	if f == nil {
-		f = make(frame, headerSize, defaultFrameSize)
+		f = newFrame(version)
 	}
 	}
+
 	f.setHeader(version, 0, 0, opStartup)
 	f.setHeader(version, 0, 0, opStartup)
-	f.writeShort(1)
+
+	// TODO: fix this, this is actually a StringMap
+	var size uint16 = 1
+	if op.Compression != "" {
+		size++
+	}
+
+	f.writeShort(size)
 	f.writeString("CQL_VERSION")
 	f.writeString("CQL_VERSION")
 	f.writeString(op.CQLVersion)
 	f.writeString(op.CQLVersion)
+
 	if op.Compression != "" {
 	if op.Compression != "" {
-		f[headerSize+1] += 1
 		f.writeString("COMPRESSION")
 		f.writeString("COMPRESSION")
 		f.writeString(op.Compression)
 		f.writeString(op.Compression)
 	}
 	}
+
 	return f, nil
 	return f, nil
 }
 }
 
 
@@ -406,14 +484,20 @@ type queryFrame struct {
 	PageState []byte
 	PageState []byte
 }
 }
 
 
+func (op *queryFrame) String() string {
+	return fmt.Sprintf("[query statement=%q prepared=%x cons=%v ...]", op.Stmt, op.Prepared, op.Cons)
+}
+
 func (op *queryFrame) encodeFrame(version uint8, f frame) (frame, error) {
 func (op *queryFrame) encodeFrame(version uint8, f frame) (frame, error) {
 	if version == 1 && (op.PageSize != 0 || len(op.PageState) > 0 ||
 	if version == 1 && (op.PageSize != 0 || len(op.PageState) > 0 ||
 		(len(op.Values) > 0 && len(op.Prepared) == 0)) {
 		(len(op.Values) > 0 && len(op.Prepared) == 0)) {
 		return nil, ErrUnsupported
 		return nil, ErrUnsupported
 	}
 	}
+
 	if f == nil {
 	if f == nil {
-		f = make(frame, headerSize, defaultFrameSize)
+		f = newFrame(version)
 	}
 	}
+
 	if len(op.Prepared) > 0 {
 	if len(op.Prepared) > 0 {
 		f.setHeader(version, 0, 0, opExecute)
 		f.setHeader(version, 0, 0, opExecute)
 		f.writeShortBytes(op.Prepared)
 		f.writeShortBytes(op.Prepared)
@@ -421,10 +505,12 @@ func (op *queryFrame) encodeFrame(version uint8, f frame) (frame, error) {
 		f.setHeader(version, 0, 0, opQuery)
 		f.setHeader(version, 0, 0, opQuery)
 		f.writeLongString(op.Stmt)
 		f.writeLongString(op.Stmt)
 	}
 	}
+
 	if version >= 2 {
 	if version >= 2 {
 		f.writeConsistency(op.Cons)
 		f.writeConsistency(op.Cons)
 		flagPos := len(f)
 		flagPos := len(f)
 		f.writeByte(0)
 		f.writeByte(0)
+
 		if len(op.Values) > 0 {
 		if len(op.Values) > 0 {
 			f[flagPos] |= flagQueryValues
 			f[flagPos] |= flagQueryValues
 			f.writeShort(uint16(len(op.Values)))
 			f.writeShort(uint16(len(op.Values)))
@@ -432,10 +518,12 @@ func (op *queryFrame) encodeFrame(version uint8, f frame) (frame, error) {
 				f.writeBytes(value)
 				f.writeBytes(value)
 			}
 			}
 		}
 		}
+
 		if op.PageSize > 0 {
 		if op.PageSize > 0 {
 			f[flagPos] |= flagPageSize
 			f[flagPos] |= flagPageSize
 			f.writeInt(int32(op.PageSize))
 			f.writeInt(int32(op.PageSize))
 		}
 		}
+
 		if len(op.PageState) > 0 {
 		if len(op.PageState) > 0 {
 			f[flagPos] |= flagPageState
 			f[flagPos] |= flagPageState
 			f.writeBytes(op.PageState)
 			f.writeBytes(op.PageState)
@@ -449,6 +537,7 @@ func (op *queryFrame) encodeFrame(version uint8, f frame) (frame, error) {
 		}
 		}
 		f.writeConsistency(op.Cons)
 		f.writeConsistency(op.Cons)
 	}
 	}
+
 	return f, nil
 	return f, nil
 }
 }
 
 
@@ -456,9 +545,13 @@ type prepareFrame struct {
 	Stmt string
 	Stmt string
 }
 }
 
 
+func (op *prepareFrame) String() string {
+	return fmt.Sprintf("[prepare statement=%q]", op.Stmt)
+}
+
 func (op *prepareFrame) encodeFrame(version uint8, f frame) (frame, error) {
 func (op *prepareFrame) encodeFrame(version uint8, f frame) (frame, error) {
 	if f == nil {
 	if f == nil {
-		f = make(frame, headerSize, defaultFrameSize)
+		f = newFrame(version)
 	}
 	}
 	f.setHeader(version, 0, 0, opPrepare)
 	f.setHeader(version, 0, 0, opPrepare)
 	f.writeLongString(op.Stmt)
 	f.writeLongString(op.Stmt)
@@ -467,9 +560,13 @@ func (op *prepareFrame) encodeFrame(version uint8, f frame) (frame, error) {
 
 
 type optionsFrame struct{}
 type optionsFrame struct{}
 
 
+func (op *optionsFrame) String() string {
+	return "[options]"
+}
+
 func (op *optionsFrame) encodeFrame(version uint8, f frame) (frame, error) {
 func (op *optionsFrame) encodeFrame(version uint8, f frame) (frame, error) {
 	if f == nil {
 	if f == nil {
-		f = make(frame, headerSize, defaultFrameSize)
+		f = newFrame(version)
 	}
 	}
 	f.setHeader(version, 0, 0, opOptions)
 	f.setHeader(version, 0, 0, opOptions)
 	return f, nil
 	return f, nil
@@ -479,13 +576,21 @@ type authenticateFrame struct {
 	Authenticator string
 	Authenticator string
 }
 }
 
 
+func (op *authenticateFrame) String() string {
+	return fmt.Sprintf("[authenticate authenticator=%q]", op.Authenticator)
+}
+
 type authResponseFrame struct {
 type authResponseFrame struct {
 	Data []byte
 	Data []byte
 }
 }
 
 
+func (op *authResponseFrame) String() string {
+	return fmt.Sprintf("[auth_response data=%q]", op.Data)
+}
+
 func (op *authResponseFrame) encodeFrame(version uint8, f frame) (frame, error) {
 func (op *authResponseFrame) encodeFrame(version uint8, f frame) (frame, error) {
 	if f == nil {
 	if f == nil {
-		f = make(frame, headerSize, defaultFrameSize)
+		f = newFrame(version)
 	}
 	}
 	f.setHeader(version, 0, 0, opAuthResponse)
 	f.setHeader(version, 0, 0, opAuthResponse)
 	f.writeBytes(op.Data)
 	f.writeBytes(op.Data)
@@ -496,6 +601,14 @@ type authSuccessFrame struct {
 	Data []byte
 	Data []byte
 }
 }
 
 
+func (op *authSuccessFrame) String() string {
+	return fmt.Sprintf("[auth_success data=%q]", op.Data)
+}
+
 type authChallengeFrame struct {
 type authChallengeFrame struct {
 	Data []byte
 	Data []byte
 }
 }
+
+func (op *authChallengeFrame) String() string {
+	return fmt.Sprintf("[auth_challenge data=%q]", op.Data)
+}