浏览代码

fixed cn.c.Read vs. ReadFull bug; added framing checks to guard against bad frames

References titanous#1, titanous#2
Closes #8
Thorsten von Eicken 13 年之前
父节点
当前提交
b13b599472
共有 1 个文件被更改,包括 54 次插入3 次删除
  1. 54 3
      gocql.go

+ 54 - 3
gocql.go

@@ -56,6 +56,7 @@ const (
 	opResult       byte = 0x08
 	opPrepare      byte = 0x09
 	opExecute      byte = 0x0A
+	opLAST         byte = 0x0A // not a real opcode -- used to check for valid opcodes
 
 	flagCompressed byte = 0x01
 
@@ -172,7 +173,17 @@ func Open(name string) (*connection, error) {
 	return cn, nil
 }
 
+// close a connection actively, typically used when there's an error and we want to ensure
+// we don't repeatedly try to use the broken connection
+func (cn *connection) close() {
+	cn.c.Close()
+	cn.c = nil // ensure we generate ErrBadConn when cn gets reused
+}
+
 func (cn *connection) send(opcode byte, body []byte) error {
+	if cn.c == nil {
+		return driver.ErrBadConn
+	}
 	frame := make([]byte, len(body)+8)
 	frame[0] = protoRequest
 	frame[1] = 0
@@ -187,16 +198,42 @@ func (cn *connection) send(opcode byte, body []byte) error {
 }
 
 func (cn *connection) recv() (byte, []byte, error) {
+	if cn.c == nil {
+		return 0, nil, driver.ErrBadConn
+	}
 	header := make([]byte, 8)
-	if _, err := cn.c.Read(header); err != nil {
+	if _, err := io.ReadFull(cn.c, header); err != nil {
+		cn.close() // better assume that the connection is broken (may have read some bytes)
 		return 0, nil, err
 	}
+	// verify that the frame starts with version==1 and req/resp flag==response
+	// this may be overly conservative in that future versions may be backwards compatible
+	// in that case simply amend the check...
+	if header[0] != protoResponse {
+		cn.close()
+		return 0, nil, fmt.Errorf("unsupported frame version or not a response: 0x%x (header=%v)", header[0], header)
+	}
+	// verify that the flags field has only a single flag set, again, this may
+	// be overly conservative if additional flags are backwards-compatible
+	if header[1] > 1 {
+		cn.close()
+		return 0, nil, fmt.Errorf("unsupported frame flags: 0x%x (header=%v)", header[1], header)
+	}
 	opcode := header[3]
+	if opcode > opLAST {
+		cn.close()
+		return 0, nil, fmt.Errorf("unknown opcode: 0x%x (header=%v)", opcode, header)
+	}
 	length := binary.BigEndian.Uint32(header[4:8])
 	var body []byte
 	if length > 0 {
+		if length > 256*1024*1024 { // spec says 256MB is max
+			cn.close()
+			return 0, nil, fmt.Errorf("frame too large: %d (header=%v)", length, header)
+		}
 		body = make([]byte, length)
-		if _, err := cn.c.Read(body); err != nil {
+		if _, err := io.ReadFull(cn.c, body); err != nil {
+			cn.close() // better assume that the connection is broken
 			return 0, nil, err
 		}
 	}
@@ -204,6 +241,7 @@ func (cn *connection) recv() (byte, []byte, error) {
 		var err error
 		body, err = snappy.Decode(nil, body)
 		if err != nil {
+			cn.close()
 			return 0, nil, err
 		}
 	}
@@ -217,18 +255,31 @@ func (cn *connection) recv() (byte, []byte, error) {
 }
 
 func (cn *connection) Begin() (driver.Tx, error) {
+	if cn.c == nil {
+		return nil, driver.ErrBadConn
+	}
 	return cn, nil
 }
 
 func (cn *connection) Commit() error {
+	if cn.c == nil {
+		return driver.ErrBadConn
+	}
 	return nil
 }
 
 func (cn *connection) Close() error {
-	return cn.c.Close()
+	if cn.c == nil {
+		return driver.ErrBadConn
+	}
+	cn.close()
+	return nil
 }
 
 func (cn *connection) Rollback() error {
+	if cn.c == nil {
+		return driver.ErrBadConn
+	}
 	return nil
 }