|
|
@@ -56,6 +56,7 @@ const (
|
|
|
opResult byte = 0x08
|
|
|
opPrepare byte = 0x09
|
|
|
opExecute byte = 0x0A
|
|
|
+ opLAST byte = 0x0A // not a real opcode -- used to check for valid opcodes
|
|
|
|
|
|
flagCompressed byte = 0x01
|
|
|
|
|
|
@@ -172,7 +173,17 @@ func Open(name string) (*connection, error) {
|
|
|
return cn, nil
|
|
|
}
|
|
|
|
|
|
+// close a connection actively, typically used when there's an error and we want to ensure
|
|
|
+// we don't repeatedly try to use the broken connection
|
|
|
+func (cn *connection) close() {
|
|
|
+ cn.c.Close()
|
|
|
+ cn.c = nil // ensure we generate ErrBadConn when cn gets reused
|
|
|
+}
|
|
|
+
|
|
|
func (cn *connection) send(opcode byte, body []byte) error {
|
|
|
+ if cn.c == nil {
|
|
|
+ return driver.ErrBadConn
|
|
|
+ }
|
|
|
frame := make([]byte, len(body)+8)
|
|
|
frame[0] = protoRequest
|
|
|
frame[1] = 0
|
|
|
@@ -187,16 +198,42 @@ func (cn *connection) send(opcode byte, body []byte) error {
|
|
|
}
|
|
|
|
|
|
func (cn *connection) recv() (byte, []byte, error) {
|
|
|
+ if cn.c == nil {
|
|
|
+ return 0, nil, driver.ErrBadConn
|
|
|
+ }
|
|
|
header := make([]byte, 8)
|
|
|
- if _, err := cn.c.Read(header); err != nil {
|
|
|
+ if _, err := io.ReadFull(cn.c, header); err != nil {
|
|
|
+ cn.close() // better assume that the connection is broken (may have read some bytes)
|
|
|
return 0, nil, err
|
|
|
}
|
|
|
+ // verify that the frame starts with version==1 and req/resp flag==response
|
|
|
+ // this may be overly conservative in that future versions may be backwards compatible
|
|
|
+ // in that case simply amend the check...
|
|
|
+ if header[0] != protoResponse {
|
|
|
+ cn.close()
|
|
|
+ return 0, nil, fmt.Errorf("unsupported frame version or not a response: 0x%x (header=%v)", header[0], header)
|
|
|
+ }
|
|
|
+ // verify that the flags field has only a single flag set, again, this may
|
|
|
+ // be overly conservative if additional flags are backwards-compatible
|
|
|
+ if header[1] > 1 {
|
|
|
+ cn.close()
|
|
|
+ return 0, nil, fmt.Errorf("unsupported frame flags: 0x%x (header=%v)", header[1], header)
|
|
|
+ }
|
|
|
opcode := header[3]
|
|
|
+ if opcode > opLAST {
|
|
|
+ cn.close()
|
|
|
+ return 0, nil, fmt.Errorf("unknown opcode: 0x%x (header=%v)", opcode, header)
|
|
|
+ }
|
|
|
length := binary.BigEndian.Uint32(header[4:8])
|
|
|
var body []byte
|
|
|
if length > 0 {
|
|
|
+ if length > 256*1024*1024 { // spec says 256MB is max
|
|
|
+ cn.close()
|
|
|
+ return 0, nil, fmt.Errorf("frame too large: %d (header=%v)", length, header)
|
|
|
+ }
|
|
|
body = make([]byte, length)
|
|
|
- if _, err := cn.c.Read(body); err != nil {
|
|
|
+ if _, err := io.ReadFull(cn.c, body); err != nil {
|
|
|
+ cn.close() // better assume that the connection is broken
|
|
|
return 0, nil, err
|
|
|
}
|
|
|
}
|
|
|
@@ -204,6 +241,7 @@ func (cn *connection) recv() (byte, []byte, error) {
|
|
|
var err error
|
|
|
body, err = snappy.Decode(nil, body)
|
|
|
if err != nil {
|
|
|
+ cn.close()
|
|
|
return 0, nil, err
|
|
|
}
|
|
|
}
|
|
|
@@ -217,18 +255,31 @@ func (cn *connection) recv() (byte, []byte, error) {
|
|
|
}
|
|
|
|
|
|
func (cn *connection) Begin() (driver.Tx, error) {
|
|
|
+ if cn.c == nil {
|
|
|
+ return nil, driver.ErrBadConn
|
|
|
+ }
|
|
|
return cn, nil
|
|
|
}
|
|
|
|
|
|
func (cn *connection) Commit() error {
|
|
|
+ if cn.c == nil {
|
|
|
+ return driver.ErrBadConn
|
|
|
+ }
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
func (cn *connection) Close() error {
|
|
|
- return cn.c.Close()
|
|
|
+ if cn.c == nil {
|
|
|
+ return driver.ErrBadConn
|
|
|
+ }
|
|
|
+ cn.close()
|
|
|
+ return nil
|
|
|
}
|
|
|
|
|
|
func (cn *connection) Rollback() error {
|
|
|
+ if cn.c == nil {
|
|
|
+ return driver.ErrBadConn
|
|
|
+ }
|
|
|
return nil
|
|
|
}
|
|
|
|