|
|
@@ -438,13 +438,14 @@ func NewSSLTestServer(t testing.TB, protocol uint8) *TestServer {
|
|
|
}
|
|
|
|
|
|
type TestServer struct {
|
|
|
- Address string
|
|
|
- t testing.TB
|
|
|
- nreq uint64
|
|
|
- listen net.Listener
|
|
|
- nKillReq int64
|
|
|
-
|
|
|
- protocol uint8
|
|
|
+ Address string
|
|
|
+ t testing.TB
|
|
|
+ nreq uint64
|
|
|
+ listen net.Listener
|
|
|
+ nKillReq int64
|
|
|
+ compressor Compressor
|
|
|
+
|
|
|
+ protocol byte
|
|
|
headerSize int
|
|
|
}
|
|
|
|
|
|
@@ -458,16 +459,19 @@ func (srv *TestServer) serve() {
|
|
|
go func(conn net.Conn) {
|
|
|
defer conn.Close()
|
|
|
for {
|
|
|
- frame, err := srv.readFrame(conn)
|
|
|
- if err == io.EOF {
|
|
|
- return
|
|
|
- } else if err != nil {
|
|
|
+ framer, err := srv.readFrame(conn)
|
|
|
+ if err != nil {
|
|
|
+ if err == io.EOF {
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
srv.t.Error(err)
|
|
|
- continue
|
|
|
+ return
|
|
|
}
|
|
|
|
|
|
atomic.AddUint64(&srv.nreq, 1)
|
|
|
- go srv.process(frame, conn)
|
|
|
+
|
|
|
+ go srv.process(framer)
|
|
|
}
|
|
|
}(conn)
|
|
|
}
|
|
|
@@ -477,24 +481,21 @@ func (srv *TestServer) Stop() {
|
|
|
srv.listen.Close()
|
|
|
}
|
|
|
|
|
|
-func (srv *TestServer) process(f frame, conn net.Conn) {
|
|
|
- headerSize := headerProtoSize[srv.protocol]
|
|
|
- stream := f.Stream(srv.protocol)
|
|
|
+func (srv *TestServer) process(f *framer) {
|
|
|
+ head := f.header
|
|
|
+ if head == nil {
|
|
|
+ srv.t.Error("process frame with a nil header")
|
|
|
+ return
|
|
|
+ }
|
|
|
|
|
|
- switch f.Op(srv.protocol) {
|
|
|
+ switch head.op {
|
|
|
case opStartup:
|
|
|
- f = f[:headerSize]
|
|
|
- f.setHeader(protoDirectionMask|srv.protocol, 0, stream, opReady)
|
|
|
+ f.writeHeader(0, opReady, head.stream)
|
|
|
case opOptions:
|
|
|
- f = f[:headerSize]
|
|
|
- f.setHeader(protoDirectionMask|srv.protocol, 0, stream, opSupported)
|
|
|
+ f.writeHeader(0, opSupported, head.stream)
|
|
|
f.writeShort(0)
|
|
|
case opQuery:
|
|
|
- input := f
|
|
|
- input.skipHeader(srv.protocol)
|
|
|
- query := strings.TrimSpace(input.readLongString())
|
|
|
- f = f[:headerSize]
|
|
|
- f.setHeader(protoDirectionMask|srv.protocol, 0, stream, opResult)
|
|
|
+ query := f.readLongString()
|
|
|
first := query
|
|
|
if n := strings.Index(query, " "); n > 0 {
|
|
|
first = first[:n]
|
|
|
@@ -502,62 +503,63 @@ func (srv *TestServer) process(f frame, conn net.Conn) {
|
|
|
switch strings.ToLower(first) {
|
|
|
case "kill":
|
|
|
atomic.AddInt64(&srv.nKillReq, 1)
|
|
|
- f = f[:headerSize]
|
|
|
- f.setHeader(protoDirectionMask|srv.protocol, 0, stream, opError)
|
|
|
+ f.writeHeader(0, opError, head.stream)
|
|
|
f.writeInt(0x1001)
|
|
|
f.writeString("query killed")
|
|
|
case "slow":
|
|
|
go func() {
|
|
|
<-time.After(1 * time.Second)
|
|
|
+ f.writeHeader(0, opResult, head.stream)
|
|
|
+ f.buf[0] = srv.protocol | 0x80
|
|
|
f.writeInt(resultKindVoid)
|
|
|
- f.setLength(len(f)-headerSize, srv.protocol)
|
|
|
- if _, err := conn.Write(f); err != nil {
|
|
|
- return
|
|
|
+ if err := f.finishWrite(); err != nil {
|
|
|
+ srv.t.Error(err)
|
|
|
}
|
|
|
}()
|
|
|
+
|
|
|
return
|
|
|
case "use":
|
|
|
- f.writeInt(3)
|
|
|
+ f.writeInt(resultKindKeyspace)
|
|
|
f.writeString(strings.TrimSpace(query[3:]))
|
|
|
case "void":
|
|
|
+ f.writeHeader(0, opResult, head.stream)
|
|
|
f.writeInt(resultKindVoid)
|
|
|
default:
|
|
|
+ f.writeHeader(0, opResult, head.stream)
|
|
|
f.writeInt(resultKindVoid)
|
|
|
}
|
|
|
default:
|
|
|
- f = f[:headerSize]
|
|
|
- f.setHeader(protoDirectionMask|srv.protocol, 0, stream, opError)
|
|
|
+ f.writeHeader(0, opError, head.stream)
|
|
|
f.writeInt(0)
|
|
|
f.writeString("not supported")
|
|
|
}
|
|
|
|
|
|
- f.setLength(len(f)-headerSize, srv.protocol)
|
|
|
- if _, err := conn.Write(f); err != nil {
|
|
|
- srv.t.Log(err)
|
|
|
- return
|
|
|
+ f.buf[0] = srv.protocol | 0x80
|
|
|
+
|
|
|
+ if err := f.finishWrite(); err != nil {
|
|
|
+ srv.t.Error(err)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func (srv *TestServer) readFrame(conn net.Conn) (frame, error) {
|
|
|
- frame := make(frame, srv.headerSize, srv.headerSize+512)
|
|
|
- if _, err := io.ReadFull(conn, frame); err != nil {
|
|
|
+func (srv *TestServer) readFrame(conn net.Conn) (*framer, error) {
|
|
|
+ buf := make([]byte, srv.headerSize)
|
|
|
+ head, err := readHeader(conn, buf)
|
|
|
+ if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
+ framer := newFramer(conn, conn, nil, srv.protocol)
|
|
|
|
|
|
- // should be a request frame
|
|
|
- if frame[0]&protoDirectionMask != 0 {
|
|
|
- return nil, fmt.Errorf("expected to read a request frame got version: 0x%x", frame[0])
|
|
|
- }
|
|
|
- if v := frame[0] & protoVersionMask; v != srv.protocol {
|
|
|
- return nil, fmt.Errorf("expected to read protocol version 0x%x got 0x%x", srv.protocol, v)
|
|
|
+ err = framer.readFrame(&head)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
}
|
|
|
|
|
|
- if n := frame.Length(srv.protocol); n > 0 {
|
|
|
- frame.grow(n)
|
|
|
- if _, err := io.ReadFull(conn, frame[srv.headerSize:]); err != nil {
|
|
|
- return nil, err
|
|
|
- }
|
|
|
+ // should be a request frame
|
|
|
+ if head.version.response() {
|
|
|
+ return nil, fmt.Errorf("expected to read a request frame got version: %v", head.version)
|
|
|
+ } else if head.version.version() != srv.protocol {
|
|
|
+ return nil, fmt.Errorf("expected to read protocol version 0x%x got 0x%x", srv.protocol, head.version.version())
|
|
|
}
|
|
|
|
|
|
- return frame, nil
|
|
|
+ return framer, nil
|
|
|
}
|