소스 검색

prevent 2 callers closing a connection

If multiple callers calls closeWithError they can both see
that Closed() return false, then they both call Close() where
only one will win the race, they will then both call close() on
the request channel, causing a panic.

Replace the mutex based closed checks with an atomic compare and
swap and a load to prevent races.
Chris Bannister 10 년 전
부모
커밋
e136ade6d4
1개의 변경된 파일24개의 추가작업 그리고 36개의 파일을 삭제
  1. 24 36
      conn.go

+ 24 - 36
conn.go

@@ -112,8 +112,7 @@ type Conn struct {
 	currentKeyspace string
 	started         bool
 
-	closedMu sync.RWMutex
-	isClosed bool
+	closed int32
 
 	timeouts int64
 }
@@ -310,30 +309,6 @@ func (c *Conn) serve() {
 	c.closeWithError(err)
 }
 
-func (c *Conn) closeWithError(err error) {
-	if c.Closed() {
-		return
-	}
-
-	c.Close()
-
-	for id := 0; id < len(c.calls); id++ {
-		req := &c.calls[id]
-		// we need to send the error to all waiting queries, put the state
-		// of this conn into not active so that it can not execute any queries.
-		select {
-		case req.resp <- err:
-		default:
-		}
-
-		close(req.resp)
-	}
-
-	if c.started {
-		c.errorHandler.HandleError(c, err, true)
-	}
-}
-
 func (c *Conn) recv() error {
 	// not safe for concurrent reads
 
@@ -605,22 +580,35 @@ func (c *Conn) Pick(qry *Query) *Conn {
 }
 
 func (c *Conn) Closed() bool {
-	c.closedMu.RLock()
-	closed := c.isClosed
-	c.closedMu.RUnlock()
-	return closed
+	return atomic.LoadInt32(&c.closed) == 1
 }
 
-func (c *Conn) Close() {
-	c.closedMu.Lock()
-	if c.isClosed {
-		c.closedMu.Unlock()
+func (c *Conn) closeWithError(err error) {
+	if !atomic.CompareAndSwapInt32(&c.closed, 0, 1) {
 		return
 	}
-	c.isClosed = true
-	c.closedMu.Unlock()
+
+	for id := 0; id < len(c.calls); id++ {
+		req := &c.calls[id]
+		// we need to send the error to all waiting queries, put the state
+		// of this conn into not active so that it can not execute any queries.
+		select {
+		case req.resp <- err:
+		default:
+		}
+
+		close(req.resp)
+	}
 
 	c.conn.Close()
+
+	if c.started && err != nil {
+		c.errorHandler.HandleError(c, err, true)
+	}
+}
+
+func (c *Conn) Close() {
+	c.closeWithError(nil)
 }
 
 func (c *Conn) Address() string {