Przeglądaj źródła

prevent 2 callers closing a connection

If multiple callers calls closeWithError they can both see
that Closed() return false, then they both call Close() where
only one will win the race, they will then both call close() on
the request channel, causing a panic.

Replace the mutex based closed checks with an atomic compare and
swap and a load to prevent races.
Chris Bannister 10 lat temu
rodzic
commit
e136ade6d4
1 zmienionych plików z 24 dodań i 36 usunięć
  1. 24 36
      conn.go

+ 24 - 36
conn.go

@@ -112,8 +112,7 @@ type Conn struct {
 	currentKeyspace string
 	currentKeyspace string
 	started         bool
 	started         bool
 
 
-	closedMu sync.RWMutex
-	isClosed bool
+	closed int32
 
 
 	timeouts int64
 	timeouts int64
 }
 }
@@ -310,30 +309,6 @@ func (c *Conn) serve() {
 	c.closeWithError(err)
 	c.closeWithError(err)
 }
 }
 
 
-func (c *Conn) closeWithError(err error) {
-	if c.Closed() {
-		return
-	}
-
-	c.Close()
-
-	for id := 0; id < len(c.calls); id++ {
-		req := &c.calls[id]
-		// we need to send the error to all waiting queries, put the state
-		// of this conn into not active so that it can not execute any queries.
-		select {
-		case req.resp <- err:
-		default:
-		}
-
-		close(req.resp)
-	}
-
-	if c.started {
-		c.errorHandler.HandleError(c, err, true)
-	}
-}
-
 func (c *Conn) recv() error {
 func (c *Conn) recv() error {
 	// not safe for concurrent reads
 	// not safe for concurrent reads
 
 
@@ -605,22 +580,35 @@ func (c *Conn) Pick(qry *Query) *Conn {
 }
 }
 
 
 func (c *Conn) Closed() bool {
 func (c *Conn) Closed() bool {
-	c.closedMu.RLock()
-	closed := c.isClosed
-	c.closedMu.RUnlock()
-	return closed
+	return atomic.LoadInt32(&c.closed) == 1
 }
 }
 
 
-func (c *Conn) Close() {
-	c.closedMu.Lock()
-	if c.isClosed {
-		c.closedMu.Unlock()
+func (c *Conn) closeWithError(err error) {
+	if !atomic.CompareAndSwapInt32(&c.closed, 0, 1) {
 		return
 		return
 	}
 	}
-	c.isClosed = true
-	c.closedMu.Unlock()
+
+	for id := 0; id < len(c.calls); id++ {
+		req := &c.calls[id]
+		// we need to send the error to all waiting queries, put the state
+		// of this conn into not active so that it can not execute any queries.
+		select {
+		case req.resp <- err:
+		default:
+		}
+
+		close(req.resp)
+	}
 
 
 	c.conn.Close()
 	c.conn.Close()
+
+	if c.started && err != nil {
+		c.errorHandler.HandleError(c, err, true)
+	}
+}
+
+func (c *Conn) Close() {
+	c.closeWithError(nil)
 }
 }
 
 
 func (c *Conn) Address() string {
 func (c *Conn) Address() string {