Browse Source

Merge pull request #84 from Zariel/prep-cache-errors

Fix possible infinite wait on prepare error.
Christoph Hack 12 years ago
parent
commit
5d6cb0f98b
1 changed files with 38 additions and 21 deletions
  1. 38 21
      conn.go

+ 38 - 21
conn.go

@@ -50,7 +50,7 @@ type Conn struct {
 	nwait int32
 
 	prepMu sync.Mutex
-	prep   map[string]*queryInfo
+	prep   map[string]*inflightPrepare
 
 	cluster    Cluster
 	compressor Compressor
@@ -76,7 +76,7 @@ func Connect(addr string, cfg ConnConfig, cluster Cluster) (*Conn, error) {
 		r:          bufio.NewReader(conn),
 		uniq:       make(chan uint8, cfg.NumStreams),
 		calls:      make([]callReq, cfg.NumStreams),
-		prep:       make(map[string]*queryInfo),
+		prep:       make(map[string]*inflightPrepare),
 		timeout:    cfg.Timeout,
 		version:    uint8(cfg.ProtoVersion),
 		addr:       conn.RemoteAddr().String(),
@@ -252,32 +252,44 @@ func (c *Conn) ping() error {
 
 func (c *Conn) prepareStatement(stmt string, trace Tracer) (*queryInfo, error) {
 	c.prepMu.Lock()
-	info := c.prep[stmt]
-	if info != nil {
+	flight := c.prep[stmt]
+	if flight != nil {
 		c.prepMu.Unlock()
-		info.wg.Wait()
-		return info, nil
+		flight.wg.Wait()
+		return flight.info, flight.err
 	}
-	info = new(queryInfo)
-	info.wg.Add(1)
-	c.prep[stmt] = info
+
+	flight = new(inflightPrepare)
+	flight.wg.Add(1)
+	c.prep[stmt] = flight
 	c.prepMu.Unlock()
 
 	resp, err := c.exec(&prepareFrame{Stmt: stmt}, trace)
 	if err != nil {
-		return nil, err
+		flight.err = err
+	} else {
+		switch x := resp.(type) {
+		case resultPreparedFrame:
+			flight.info = &queryInfo{
+				id:   x.PreparedId,
+				args: x.Values,
+			}
+		case error:
+			flight.err = x
+		default:
+			flight.err = ErrProtocol
+		}
 	}
-	switch x := resp.(type) {
-	case resultPreparedFrame:
-		info.id = x.PreparedId
-		info.args = x.Values
-		info.wg.Done()
-	case error:
-		return nil, x
-	default:
-		return nil, ErrProtocol
+
+	flight.wg.Done()
+
+	if err != nil {
+		c.prepMu.Lock()
+		delete(c.prep, stmt)
+		c.prepMu.Unlock()
 	}
-	return info, nil
+
+	return flight.info, flight.err
 }
 
 func (c *Conn) executeQuery(qry *Query) *Iter {
@@ -496,7 +508,6 @@ type queryInfo struct {
 	id   []byte
 	args []ColumnInfo
 	rval []ColumnInfo
-	wg   sync.WaitGroup
 }
 
 type callReq struct {
@@ -515,6 +526,12 @@ type Compressor interface {
 	Decode(data []byte) ([]byte, error)
 }
 
+type inflightPrepare struct {
+	info *queryInfo
+	err  error
+	wg   sync.WaitGroup
+}
+
 // SnappyCompressor implements the Compressor interface and can be used to
 // compress incoming and outgoing frames. The snappy compression algorithm
 // aims for very high speeds and reasonable compression.