Browse Source

be explicit about reading and writing shorts and ints

Chris Bannister 10 years ago
parent
commit
e1b6bdfc84
1 changed files with 19 additions and 22 deletions
  1. 19 22
      frame.go

+ 19 - 22
frame.go

@@ -183,7 +183,7 @@ const (
 	apacheCassandraTypePrefix = "org.apache.cassandra.db.marshal."
 )
 
-func writeInt(p []byte, n int) {
+func writeInt(p []byte, n int32) {
 	p[0] = byte(n >> 24)
 	p[1] = byte(n >> 16)
 	p[2] = byte(n >> 8)
@@ -199,8 +199,8 @@ func writeShort(p []byte, n uint16) {
 	p[1] = byte(n)
 }
 
-func readShort(p []byte) int {
-	return int(p[0])<<8 | int(p[1])
+func readShort(p []byte) uint16 {
+	return uint16(p[0])<<8 | uint16(p[1])
 }
 
 type frameHeader struct {
@@ -291,7 +291,7 @@ func readHeader(r io.Reader, p []byte) (frameHeader, error) {
 
 	head.flags = p[1]
 	if version > protoVersion2 {
-		head.stream = readShort(p[2:])
+		head.stream = int(readShort(p[2:]))
 		head.op = frameOp(p[4])
 		head.length = int(readInt(p[5:]))
 	} else {
@@ -923,7 +923,7 @@ func (f *framer) writeQueryParams(opts *queryParams) {
 	f.writeByte(flags)
 
 	if n := len(opts.values); n > 0 {
-		f.writeShort(n)
+		f.writeShort(uint16(n))
 		for i := 0; i < n; i++ {
 			if names {
 				f.writeString(opts.values[i].name)
@@ -933,7 +933,7 @@ func (f *framer) writeQueryParams(opts *queryParams) {
 	}
 
 	if opts.pageSize > 0 {
-		f.writeInt(opts.pageSize)
+		f.writeInt(int32(opts.pageSize))
 	}
 
 	if len(opts.pagingState) > 0 {
@@ -998,7 +998,7 @@ func (f *framer) writeExecuteFrame(streamID int, preparedID []byte, params *quer
 		f.writeQueryParams(params)
 	} else {
 		n := len(params.values)
-		f.writeShort(n)
+		f.writeShort(uint16(n))
 		for i := 0; i < n; i++ {
 			f.writeBytes(params.values[i].value)
 		}
@@ -1033,7 +1033,7 @@ func (f *framer) writeBatchFrame(streamID int, w *writeBatchFrame) error {
 	f.writeByte(byte(w.typ))
 
 	n := len(w.statements)
-	f.writeShort(n)
+	f.writeShort(uint16(n))
 
 	var flags byte
 
@@ -1047,7 +1047,7 @@ func (f *framer) writeBatchFrame(streamID int, w *writeBatchFrame) error {
 
 		f.writeByte(1)
 		f.writeShortBytes(b.preparedID)
-		f.writeShort(len(b.values))
+		f.writeShort(uint16(len(b.values)))
 		for j := range b.values {
 			col := &b.values[j]
 			if f.proto > protoVersion2 && col.name != "" {
@@ -1212,7 +1212,7 @@ func (f *framer) writeByte(b byte) {
 }
 
 // these are protocol level binary types
-func (f *framer) writeInt(n int) {
+func (f *framer) writeInt(n int32) {
 	f.buf = append(f.buf,
 		byte(n>>24),
 		byte(n>>16),
@@ -1221,7 +1221,7 @@ func (f *framer) writeInt(n int) {
 	)
 }
 
-func (f *framer) writeShort(n int) {
+func (f *framer) writeShort(n uint16) {
 	f.buf = append(f.buf,
 		byte(n>>8),
 		byte(n),
@@ -1242,12 +1242,12 @@ func (f *framer) writeLong(n int64) {
 }
 
 func (f *framer) writeString(s string) {
-	f.writeShort(len(s))
+	f.writeShort(uint16(len(s)))
 	f.buf = append(f.buf, s...)
 }
 
 func (f *framer) writeLongString(s string) {
-	f.writeInt(len(s))
+	f.writeInt(int32(len(s)))
 	f.buf = append(f.buf, s...)
 }
 
@@ -1256,7 +1256,7 @@ func (f *framer) writeUUID(u *UUID) {
 }
 
 func (f *framer) writeStringList(l []string) {
-	f.writeShort(len(l))
+	f.writeShort(uint16(len(l)))
 	for _, s := range l {
 		f.writeString(s)
 	}
@@ -1269,19 +1269,16 @@ func (f *framer) writeBytes(p []byte) {
 	if p == nil {
 		f.writeInt(-1)
 	} else {
-		f.writeInt(len(p))
+		f.writeInt(int32(len(p)))
 		f.buf = append(f.buf, p...)
 	}
 }
 
 func (f *framer) writeShortBytes(p []byte) {
-	f.writeShort(len(p))
+	f.writeShort(uint16(len(p)))
 	f.buf = append(f.buf, p...)
 }
 
-// TODO: add writeOption, though no frame actually writes an option so probably
-// just need a read
-
 func (f *framer) writeInet(ip net.IP, port int) {
 	f.buf = append(f.buf,
 		byte(len(ip)),
@@ -1291,15 +1288,15 @@ func (f *framer) writeInet(ip net.IP, port int) {
 		[]byte(ip)...,
 	)
 
-	f.writeInt(port)
+	f.writeInt(int32(port))
 }
 
 func (f *framer) writeConsistency(cons Consistency) {
-	f.writeShort(int(cons))
+	f.writeShort(uint16(cons))
 }
 
 func (f *framer) writeStringMap(m map[string]string) {
-	f.writeShort(len(m))
+	f.writeShort(uint16(len(m)))
 	for k, v := range m {
 		f.writeString(k)
 		f.writeString(v)