فهرست منبع

prevent 2 callers closing a connection

If multiple callers calls closeWithError they can both see
that Closed() return false, then they both call Close() where
only one will win the race, they will then both call close() on
the request channel, causing a panic.

Replace the mutex based closed checks with an atomic compare and
swap and a load to prevent races.
Chris Bannister 10 سال پیش
والد
کامیت
e136ade6d4
1فایلهای تغییر یافته به همراه24 افزوده شده و 36 حذف شده
  1. 24 36
      conn.go

+ 24 - 36
conn.go

@@ -112,8 +112,7 @@ type Conn struct {
 	currentKeyspace string
 	started         bool
 
-	closedMu sync.RWMutex
-	isClosed bool
+	closed int32
 
 	timeouts int64
 }
@@ -310,30 +309,6 @@ func (c *Conn) serve() {
 	c.closeWithError(err)
 }
 
-func (c *Conn) closeWithError(err error) {
-	if c.Closed() {
-		return
-	}
-
-	c.Close()
-
-	for id := 0; id < len(c.calls); id++ {
-		req := &c.calls[id]
-		// we need to send the error to all waiting queries, put the state
-		// of this conn into not active so that it can not execute any queries.
-		select {
-		case req.resp <- err:
-		default:
-		}
-
-		close(req.resp)
-	}
-
-	if c.started {
-		c.errorHandler.HandleError(c, err, true)
-	}
-}
-
 func (c *Conn) recv() error {
 	// not safe for concurrent reads
 
@@ -605,22 +580,35 @@ func (c *Conn) Pick(qry *Query) *Conn {
 }
 
 func (c *Conn) Closed() bool {
-	c.closedMu.RLock()
-	closed := c.isClosed
-	c.closedMu.RUnlock()
-	return closed
+	return atomic.LoadInt32(&c.closed) == 1
 }
 
-func (c *Conn) Close() {
-	c.closedMu.Lock()
-	if c.isClosed {
-		c.closedMu.Unlock()
+func (c *Conn) closeWithError(err error) {
+	if !atomic.CompareAndSwapInt32(&c.closed, 0, 1) {
 		return
 	}
-	c.isClosed = true
-	c.closedMu.Unlock()
+
+	for id := 0; id < len(c.calls); id++ {
+		req := &c.calls[id]
+		// we need to send the error to all waiting queries, put the state
+		// of this conn into not active so that it can not execute any queries.
+		select {
+		case req.resp <- err:
+		default:
+		}
+
+		close(req.resp)
+	}
 
 	c.conn.Close()
+
+	if c.started && err != nil {
+		c.errorHandler.HandleError(c, err, true)
+	}
+}
+
+func (c *Conn) Close() {
+	c.closeWithError(nil)
 }
 
 func (c *Conn) Address() string {