Pārlūkot izejas kodu

Update Conn to support different protocol headers

Update Conn so that it supports reading and writing frame headers
of varying sizes, add support for >v3 stream counts.

Simplify reading frames in recv, check that we get the expected
version back and direction.
Chris Bannister 10 gadi atpakaļ
vecāks
revīzija
0c6c1e0c2b
1 mainītis faili ar 89 papildinājumiem un 49 dzēšanām
  1. 89 49
      conn.go

+ 89 - 49
conn.go

@@ -10,7 +10,9 @@ import (
 	"crypto/x509"
 	"crypto/x509"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
+	"io"
 	"io/ioutil"
 	"io/ioutil"
+	"log"
 	"net"
 	"net"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
@@ -88,7 +90,7 @@ type Conn struct {
 	r       *bufio.Reader
 	r       *bufio.Reader
 	timeout time.Duration
 	timeout time.Duration
 
 
-	uniq  chan uint8
+	uniq  chan int
 	calls []callReq
 	calls []callReq
 	nwait int32
 	nwait int32
 
 
@@ -123,14 +125,17 @@ func Connect(addr string, cfg ConnConfig, pool ConnectionPool) (*Conn, error) {
 				return nil, errors.New("Failed parsing or appending certs")
 				return nil, errors.New("Failed parsing or appending certs")
 			}
 			}
 		}
 		}
+
 		mycert, err := tls.LoadX509KeyPair(cfg.SslOpts.CertPath, cfg.SslOpts.KeyPath)
 		mycert, err := tls.LoadX509KeyPair(cfg.SslOpts.CertPath, cfg.SslOpts.KeyPath)
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
+
 		config := tls.Config{
 		config := tls.Config{
 			Certificates: []tls.Certificate{mycert},
 			Certificates: []tls.Certificate{mycert},
 			RootCAs:      certPool,
 			RootCAs:      certPool,
 		}
 		}
+
 		config.InsecureSkipVerify = !cfg.SslOpts.EnableHostVerification
 		config.InsecureSkipVerify = !cfg.SslOpts.EnableHostVerification
 		if conn, err = tls.Dial("tcp", addr, &config); err != nil {
 		if conn, err = tls.Dial("tcp", addr, &config); err != nil {
 			return nil, err
 			return nil, err
@@ -139,16 +144,25 @@ func Connect(addr string, cfg ConnConfig, pool ConnectionPool) (*Conn, error) {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
-	if cfg.NumStreams <= 0 || cfg.NumStreams > 128 {
-		cfg.NumStreams = 128
-	}
-	if cfg.ProtoVersion != 1 && cfg.ProtoVersion != 2 {
+	// going to default to proto 2
+	if cfg.ProtoVersion < protoVersion1 || cfg.ProtoVersion > protoVersion3 {
+		log.Printf("unsupported protocol version: %d using 2\n", cfg.ProtoVersion)
 		cfg.ProtoVersion = 2
 		cfg.ProtoVersion = 2
 	}
 	}
+
+	maxStreams := 128
+	if cfg.ProtoVersion > protoVersion2 {
+		maxStreams = 32768
+	}
+
+	if cfg.NumStreams <= 0 || cfg.NumStreams > maxStreams {
+		cfg.NumStreams = maxStreams
+	}
+
 	c := &Conn{
 	c := &Conn{
 		conn:       conn,
 		conn:       conn,
 		r:          bufio.NewReader(conn),
 		r:          bufio.NewReader(conn),
-		uniq:       make(chan uint8, cfg.NumStreams),
+		uniq:       make(chan int, cfg.NumStreams),
 		calls:      make([]callReq, cfg.NumStreams),
 		calls:      make([]callReq, cfg.NumStreams),
 		timeout:    cfg.Timeout,
 		timeout:    cfg.Timeout,
 		version:    uint8(cfg.ProtoVersion),
 		version:    uint8(cfg.ProtoVersion),
@@ -162,8 +176,8 @@ func Connect(addr string, cfg ConnConfig, pool ConnectionPool) (*Conn, error) {
 		c.setKeepalive(cfg.Keepalive)
 		c.setKeepalive(cfg.Keepalive)
 	}
 	}
 
 
-	for i := 0; i < cap(c.uniq); i++ {
-		c.uniq <- uint8(i)
+	for i := 0; i < cfg.NumStreams; i++ {
+		c.uniq <- i
 	}
 	}
 
 
 	if err := c.startup(&cfg); err != nil {
 	if err := c.startup(&cfg); err != nil {
@@ -254,47 +268,69 @@ func (c *Conn) serve() {
 	c.pool.HandleError(c, err, true)
 	c.pool.HandleError(c, err, true)
 }
 }
 
 
+func (c *Conn) Write(p []byte) (int, error) {
+	c.conn.SetWriteDeadline(time.Now().Add(c.timeout))
+	return c.conn.Write(p)
+}
+
+func (c *Conn) Read(p []byte) (int, error) {
+	return c.r.Read(p)
+}
+
 func (c *Conn) recv() (frame, error) {
 func (c *Conn) recv() (frame, error) {
-	resp := make(frame, headerSize, headerSize+512)
-	c.conn.SetReadDeadline(time.Now().Add(c.timeout))
-	n, last, pinged := 0, 0, false
-	for n < len(resp) {
-		nn, err := c.r.Read(resp[n:])
-		n += nn
-		if err != nil {
-			if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
-				if n > last {
-					// we hit the deadline but we made progress.
-					// simply extend the deadline
-					c.conn.SetReadDeadline(time.Now().Add(c.timeout))
-					last = n
-				} else if n == 0 && !pinged {
-					c.conn.SetReadDeadline(time.Now().Add(c.timeout))
-					if atomic.LoadInt32(&c.nwait) > 0 {
-						go c.ping()
-						pinged = true
-					}
-				} else {
-					return nil, err
-				}
-			} else {
-				return nil, err
-			}
+	size := headerProtoSize[c.version]
+	resp := make(frame, size, size+512)
+
+	// read a full header, ignore timeouts, as this is being ran in a loop
+	c.conn.SetReadDeadline(time.Time{})
+	_, err := io.ReadFull(c.r, resp[:size])
+	if err != nil {
+		return nil, err
+	}
+
+	if v := c.version | flagResponse; resp[0] != v {
+		return nil, NewErrProtocol("recv: response protocol version does not match connection protocol version (%d != %d)", resp[0], v)
+	}
+
+	bodySize := resp.Length(c.version)
+	if bodySize == 0 {
+		return resp, nil
+	}
+	resp.grow(bodySize)
+
+	const maxAttempts = 5
+
+	n := size
+	for i := 0; i < maxAttempts; i++ {
+		var nn int
+		c.conn.SetReadDeadline(time.Now().Add(c.timeout))
+		nn, err = c.Read(resp[n : size+bodySize])
+		if err == nil {
+			break
 		}
 		}
-		if n == headerSize && len(resp) == headerSize {
-			if resp[0] != c.version|flagResponse {
-				return nil, NewErrProtocol("recv: Response protocol version does not match connection protocol version (%d != %d)", resp[0], c.version|flagResponse)
-			}
-			resp.grow(resp.Length())
+		n += nn
+
+		if verr, ok := err.(net.Error); !ok || !verr.Temporary() {
+			break
 		}
 		}
 	}
 	}
+
+	if err != nil {
+		return nil, err
+	}
+
 	return resp, nil
 	return resp, nil
 }
 }
 
 
 func (c *Conn) execSimple(op operation) (interface{}, error) {
 func (c *Conn) execSimple(op operation) (interface{}, error) {
 	f, err := op.encodeFrame(c.version, nil)
 	f, err := op.encodeFrame(c.version, nil)
-	f.setLength(len(f) - headerSize)
-	if _, err := c.conn.Write([]byte(f)); err != nil {
+	if err != nil {
+		// this should be a noop err
+		return nil, err
+	}
+	f.setLength(len(f)-headerProtoSize[c.version], c.version)
+
+	if _, err := c.Write([]byte(f)); err != nil {
 		c.Close()
 		c.Close()
 		return nil, err
 		return nil, err
 	}
 	}
@@ -312,6 +348,8 @@ func (c *Conn) exec(op operation, trace Tracer) (interface{}, error) {
 	if trace != nil {
 	if trace != nil {
 		req[1] |= flagTrace
 		req[1] |= flagTrace
 	}
 	}
+
+	headerSize := headerProtoSize[c.version]
 	if len(req) > headerSize && c.compressor != nil {
 	if len(req) > headerSize && c.compressor != nil {
 		body, err := c.compressor.Encode([]byte(req[headerSize:]))
 		body, err := c.compressor.Encode([]byte(req[headerSize:]))
 		if err != nil {
 		if err != nil {
@@ -320,16 +358,16 @@ func (c *Conn) exec(op operation, trace Tracer) (interface{}, error) {
 		req = append(req[:headerSize], frame(body)...)
 		req = append(req[:headerSize], frame(body)...)
 		req[1] |= flagCompress
 		req[1] |= flagCompress
 	}
 	}
-	req.setLength(len(req) - headerSize)
+	req.setLength(len(req)-headerSize, c.version)
 
 
 	id := <-c.uniq
 	id := <-c.uniq
-	req[2] = id
+	req.setStream(id, c.version)
 	call := &c.calls[id]
 	call := &c.calls[id]
 	call.resp = make(chan callResp, 1)
 	call.resp = make(chan callResp, 1)
 	atomic.AddInt32(&c.nwait, 1)
 	atomic.AddInt32(&c.nwait, 1)
 	atomic.StoreInt32(&call.active, 1)
 	atomic.StoreInt32(&call.active, 1)
 
 
-	if _, err := c.conn.Write(req); err != nil {
+	if _, err := c.Write(req); err != nil {
 		c.uniq <- id
 		c.uniq <- id
 		c.Close()
 		c.Close()
 		return nil, err
 		return nil, err
@@ -346,7 +384,7 @@ func (c *Conn) exec(op operation, trace Tracer) (interface{}, error) {
 }
 }
 
 
 func (c *Conn) dispatch(resp frame) {
 func (c *Conn) dispatch(resp frame) {
-	id := int(resp[2])
+	id := resp.Stream(c.version)
 	if id >= len(c.calls) {
 	if id >= len(c.calls) {
 		return
 		return
 	}
 	}
@@ -543,10 +581,10 @@ func (c *Conn) UseKeyspace(keyspace string) error {
 }
 }
 
 
 func (c *Conn) executeBatch(batch *Batch) error {
 func (c *Conn) executeBatch(batch *Batch) error {
-	if c.version == 1 {
+	if c.version == protoVersion1 {
 		return ErrUnsupported
 		return ErrUnsupported
 	}
 	}
-	f := make(frame, headerSize, defaultFrameSize)
+	f := make(frame, headerProtoSize[c.version], defaultFrameSize)
 	f.setHeader(c.version, 0, 0, opBatch)
 	f.setHeader(c.version, 0, 0, opBatch)
 	f.writeByte(byte(batch.Type))
 	f.writeByte(byte(batch.Type))
 	f.writeShort(uint16(len(batch.Entries)))
 	f.writeShort(uint16(len(batch.Entries)))
@@ -631,12 +669,14 @@ func (c *Conn) decodeFrame(f frame, trace Tracer) (rval interface{}, err error)
 			panic(r)
 			panic(r)
 		}
 		}
 	}()
 	}()
-	if len(f) < headerSize {
-		return nil, NewErrProtocol("Decoding frame: less data received than required for header: %d < %d", len(f), headerSize)
+
+	if len(f) < headerProtoSize[c.version] {
+		return nil, NewErrProtocol("Decoding frame: less data received than required for header: %d < %d", len(f), headerProtoSize[c.version])
 	} else if f[0] != c.version|flagResponse {
 	} else if f[0] != c.version|flagResponse {
 		return nil, NewErrProtocol("Decoding frame: response protocol version does not match connection protocol version (%d != %d)", f[0], c.version|flagResponse)
 		return nil, NewErrProtocol("Decoding frame: response protocol version does not match connection protocol version (%d != %d)", f[0], c.version|flagResponse)
 	}
 	}
-	flags, op, f := f[1], f[3], f[headerSize:]
+
+	flags, op, f := f[1], f.Op(c.version), f[headerProtoSize[c.version]:]
 	if flags&flagCompress != 0 && len(f) > 0 && c.compressor != nil {
 	if flags&flagCompress != 0 && len(f) > 0 && c.compressor != nil {
 		if buf, err := c.compressor.Decode([]byte(f)); err != nil {
 		if buf, err := c.compressor.Decode([]byte(f)); err != nil {
 			return nil, err
 			return nil, err