Browse Source

provide some details when encountering a protocol error

mihasya 11 years ago
parent
commit
818bd933bd
3 changed files with 26 additions and 19 deletions
  1. 14 12
      conn.go
  2. 6 6
      frame.go
  3. 6 1
      session.go

+ 14 - 12
conn.go

@@ -174,7 +174,7 @@ func (c *Conn) startup(cfg *ConnConfig) error {
 			}
 			}
 			return nil
 			return nil
 		default:
 		default:
-			return ErrProtocol
+			return NewErrProtocol("Unknown type of response to startup frame: %s", x)
 		}
 		}
 	}
 	}
 }
 }
@@ -235,7 +235,7 @@ func (c *Conn) recv() (frame, error) {
 		}
 		}
 		if n == headerSize && len(resp) == headerSize {
 		if n == headerSize && len(resp) == headerSize {
 			if resp[0] != c.version|flagResponse {
 			if resp[0] != c.version|flagResponse {
-				return nil, ErrProtocol
+				return nil, NewErrProtocol("recv: Response protocol version does not match connection protocol version (%d != %d)", resp[0], c.version|flagResponse)
 			}
 			}
 			resp.grow(resp.Length())
 			resp.grow(resp.Length())
 		}
 		}
@@ -342,7 +342,7 @@ func (c *Conn) prepareStatement(stmt string, trace Tracer) (*queryInfo, error) {
 		case error:
 		case error:
 			flight.err = x
 			flight.err = x
 		default:
 		default:
-			flight.err = ErrProtocol
+			flight.err = NewErrProtocol("Unknown type in response to prepare frame: %s", x)
 		}
 		}
 	}
 	}
 
 
@@ -418,7 +418,7 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 	case error:
 	case error:
 		return &Iter{err: x}
 		return &Iter{err: x}
 	default:
 	default:
-		return &Iter{err: ErrProtocol}
+		return &Iter{err: NewErrProtocol("Unknown type in response to execute query: %s", x)}
 	}
 	}
 }
 }
 
 
@@ -462,7 +462,7 @@ func (c *Conn) UseKeyspace(keyspace string) error {
 	case error:
 	case error:
 		return x
 		return x
 	default:
 	default:
-		return ErrProtocol
+		return NewErrProtocol("Unknown type in response to USE: %s", x)
 	}
 	}
 	return nil
 	return nil
 }
 }
@@ -511,22 +511,24 @@ func (c *Conn) executeBatch(batch *Batch) error {
 	case error:
 	case error:
 		return x
 		return x
 	default:
 	default:
-		return ErrProtocol
+		return NewErrProtocol("Unknown type in response to batch statement: %s", x)
 	}
 	}
 }
 }
 
 
 func (c *Conn) decodeFrame(f frame, trace Tracer) (rval interface{}, err error) {
 func (c *Conn) decodeFrame(f frame, trace Tracer) (rval interface{}, err error) {
 	defer func() {
 	defer func() {
 		if r := recover(); r != nil {
 		if r := recover(); r != nil {
-			if e, ok := r.(error); ok && e == ErrProtocol {
+			if e, ok := r.(ErrProtocol); ok {
 				err = e
 				err = e
 				return
 				return
 			}
 			}
 			panic(r)
 			panic(r)
 		}
 		}
 	}()
 	}()
-	if len(f) < headerSize || (f[0] != c.version|flagResponse) {
-		return nil, ErrProtocol
+	if len(f) < headerSize {
+		return nil, NewErrProtocol("Decoding frame: less data received than required for header: %d < %d", len(f), headerSize)
+	} else if f[0] != c.version|flagResponse {
+		return nil, NewErrProtocol("Decoding frame: response protocol version does not match connection protocol version (%d != %d)", f[0], c.version|flagResponse)
 	}
 	}
 	flags, op, f := f[1], f[3], f[headerSize:]
 	flags, op, f := f[1], f[3], f[headerSize:]
 	if flags&flagCompress != 0 && len(f) > 0 && c.compressor != nil {
 	if flags&flagCompress != 0 && len(f) > 0 && c.compressor != nil {
@@ -538,7 +540,7 @@ func (c *Conn) decodeFrame(f frame, trace Tracer) (rval interface{}, err error)
 	}
 	}
 	if flags&flagTrace != 0 {
 	if flags&flagTrace != 0 {
 		if len(f) < 16 {
 		if len(f) < 16 {
-			return nil, ErrProtocol
+			return nil, NewErrProtocol("Decoding frame: length of frame less than 16 while tracing is enabled")
 		}
 		}
 		traceId := []byte(f[:16])
 		traceId := []byte(f[:16])
 		f = f[16:]
 		f = f[16:]
@@ -574,7 +576,7 @@ func (c *Conn) decodeFrame(f frame, trace Tracer) (rval interface{}, err error)
 		case resultKindSchemaChanged:
 		case resultKindSchemaChanged:
 			return resultVoidFrame{}, nil
 			return resultVoidFrame{}, nil
 		default:
 		default:
-			return nil, ErrProtocol
+			return nil, NewErrProtocol("Decoding frame: unknown result kind %s", kind)
 		}
 		}
 	case opAuthenticate:
 	case opAuthenticate:
 		return authenticateFrame{f.readString()}, nil
 		return authenticateFrame{f.readString()}, nil
@@ -589,7 +591,7 @@ func (c *Conn) decodeFrame(f frame, trace Tracer) (rval interface{}, err error)
 		msg := f.readString()
 		msg := f.readString()
 		return errorFrame{code, msg}, nil
 		return errorFrame{code, msg}, nil
 	default:
 	default:
-		return nil, ErrProtocol
+		return nil, NewErrProtocol("Decoding frame: unknown op", op)
 	}
 	}
 }
 }
 
 

+ 6 - 6
frame.go

@@ -178,7 +178,7 @@ func (f *frame) skipHeader() {
 
 
 func (f *frame) readInt() int {
 func (f *frame) readInt() int {
 	if len(*f) < 4 {
 	if len(*f) < 4 {
-		panic(ErrProtocol)
+		panic(NewErrProtocol("Trying to read an int while >4 bytes in the buffer"))
 	}
 	}
 	v := uint32((*f)[0])<<24 | uint32((*f)[1])<<16 | uint32((*f)[2])<<8 | uint32((*f)[3])
 	v := uint32((*f)[0])<<24 | uint32((*f)[1])<<16 | uint32((*f)[2])<<8 | uint32((*f)[3])
 	*f = (*f)[4:]
 	*f = (*f)[4:]
@@ -187,7 +187,7 @@ func (f *frame) readInt() int {
 
 
 func (f *frame) readShort() uint16 {
 func (f *frame) readShort() uint16 {
 	if len(*f) < 2 {
 	if len(*f) < 2 {
-		panic(ErrProtocol)
+		panic(NewErrProtocol("Trying to read a short while >2 bytes in the buffer"))
 	}
 	}
 	v := uint16((*f)[0])<<8 | uint16((*f)[1])
 	v := uint16((*f)[0])<<8 | uint16((*f)[1])
 	*f = (*f)[2:]
 	*f = (*f)[2:]
@@ -197,7 +197,7 @@ func (f *frame) readShort() uint16 {
 func (f *frame) readString() string {
 func (f *frame) readString() string {
 	n := int(f.readShort())
 	n := int(f.readShort())
 	if len(*f) < n {
 	if len(*f) < n {
-		panic(ErrProtocol)
+		panic(NewErrProtocol("Trying to read a string of %d bytes from a buffer with %d bytes in it", n, len(*f)))
 	}
 	}
 	v := string((*f)[:n])
 	v := string((*f)[:n])
 	*f = (*f)[n:]
 	*f = (*f)[n:]
@@ -207,7 +207,7 @@ func (f *frame) readString() string {
 func (f *frame) readLongString() string {
 func (f *frame) readLongString() string {
 	n := f.readInt()
 	n := f.readInt()
 	if len(*f) < n {
 	if len(*f) < n {
-		panic(ErrProtocol)
+		panic(NewErrProtocol("Trying to read a string of %d bytes from a buffer with %d bytes in it", n, len(*f)))
 	}
 	}
 	v := string((*f)[:n])
 	v := string((*f)[:n])
 	*f = (*f)[n:]
 	*f = (*f)[n:]
@@ -220,7 +220,7 @@ func (f *frame) readBytes() []byte {
 		return nil
 		return nil
 	}
 	}
 	if len(*f) < n {
 	if len(*f) < n {
-		panic(ErrProtocol)
+		panic(NewErrProtocol("Trying to read %d bytes from a buffer with %d bytes in it", n, len(*f)))
 	}
 	}
 	v := (*f)[:n]
 	v := (*f)[:n]
 	*f = (*f)[n:]
 	*f = (*f)[n:]
@@ -230,7 +230,7 @@ func (f *frame) readBytes() []byte {
 func (f *frame) readShortBytes() []byte {
 func (f *frame) readShortBytes() []byte {
 	n := int(f.readShort())
 	n := int(f.readShort())
 	if len(*f) < n {
 	if len(*f) < n {
-		panic(ErrProtocol)
+		panic(NewErrProtocol("Trying to read %d bytes from a buffer with %d bytes in it", n, len(*f)))
 	}
 	}
 	v := (*f)[:n]
 	v := (*f)[:n]
 	*f = (*f)[n:]
 	*f = (*f)[n:]

+ 6 - 1
session.go

@@ -491,12 +491,17 @@ func (e Error) Error() string {
 var (
 var (
 	ErrNotFound     = errors.New("not found")
 	ErrNotFound     = errors.New("not found")
 	ErrUnavailable  = errors.New("unavailable")
 	ErrUnavailable  = errors.New("unavailable")
-	ErrProtocol     = errors.New("protocol error")
 	ErrUnsupported  = errors.New("feature not supported")
 	ErrUnsupported  = errors.New("feature not supported")
 	ErrTooManyStmts = errors.New("too many statements")
 	ErrTooManyStmts = errors.New("too many statements")
 	ErrUseStmt      = errors.New("use statements aren't supported. Please see https://github.com/gocql/gocql for explaination.")
 	ErrUseStmt      = errors.New("use statements aren't supported. Please see https://github.com/gocql/gocql for explaination.")
 )
 )
 
 
+type ErrProtocol struct{ error }
+
+func NewErrProtocol(format string, args ...interface{}) error {
+	return ErrProtocol{fmt.Errorf(format, args...)}
+}
+
 // BatchSizeMaximum is the maximum number of statements a batch operation can have.
 // BatchSizeMaximum is the maximum number of statements a batch operation can have.
 // This limit is set by cassandra and could change in the future.
 // This limit is set by cassandra and could change in the future.
 const BatchSizeMaximum = 65535
 const BatchSizeMaximum = 65535