Browse Source

Fix the TestConcurrent test to pass Go's race detection.

Running the race detection against the previous version of this test reveals
race conditions surrounding the "max" variable. Subsequent work in this patch
also revealed race conditions surrounding canStop.

Additionally, it is not valid to call *Testing.FailNow in goroutines started
by the tests.

This should retain the meaning of the original test while cleaning up the race
conditions and guaranteeing the *Testing.Fatal call occurs in the correct
goroutine.
Jeremy Bowers 12 years ago
parent
commit
fb0dc84878
2 changed files with 51 additions and 15 deletions
  1. 1 0
      AUTHORS
  2. 50 15
      driver_test.go

+ 1 - 0
AUTHORS

@@ -27,4 +27,5 @@ Xiuming Chen <cc at cxm.cc>
 
 # Organizations
 
+Barracuda Networks, Inc.
 Google Inc.

+ 50 - 15
driver_test.go

@@ -19,6 +19,8 @@ import (
 	"net/url"
 	"os"
 	"strings"
+	"sync"
+	"sync/atomic"
 	"testing"
 	"time"
 )
@@ -1220,40 +1222,73 @@ func TestConcurrent(t *testing.T) {
 			dbt.Fatalf("%s", err.Error())
 		}
 		dbt.Logf("Testing up to %d concurrent connections \r\n", max)
-		canStop := false
-		c := make(chan struct{}, max)
+
+		canStopVal := false
+		var canStopM sync.Mutex
+		canStop := func() bool {
+			canStopM.Lock()
+			defer canStopM.Unlock()
+			return canStopVal
+		}
+		setCanStop := func() {
+			canStopM.Lock()
+			defer canStopM.Unlock()
+			canStopVal = true
+		}
+
+		var succeeded int32
+
+		var wg sync.WaitGroup
+		wg.Add(max)
+
+		var fatalError string
+		var once sync.Once
+		fatal := func(s string, vals ...interface{}) {
+			once.Do(func() {
+				fatalError = fmt.Sprintf(s, vals...)
+			})
+		}
+
 		for i := 0; i < max; i++ {
 			go func(id int) {
+				defer wg.Done()
+
 				tx, err := dbt.db.Begin()
 				if err != nil {
-					canStop = true
+					setCanStop()
 					if err.Error() == "Error 1040: Too many connections" {
-						max--
 						return
 					} else {
-						dbt.Fatalf("Error on Con %d: %s", id, err.Error())
+						fatal("Error on Con %d: %s", id, err.Error())
 					}
 				}
-				c <- struct{}{}
-				for !canStop {
+
+				atomic.AddInt32(&succeeded, 1)
+
+				var hasSelected bool
+				for !canStop() || !hasSelected {
 					_, err = tx.Exec("SELECT 1")
 					if err != nil {
-						canStop = true
-						dbt.Fatalf("Error on Con %d: %s", id, err.Error())
+						setCanStop()
+						fatal("Error on Con %d: %s", id, err.Error())
 					}
+					hasSelected = true
 				}
 				err = tx.Commit()
 				if err != nil {
-					canStop = true
-					dbt.Fatalf("Error on Con %d: %s", id, err.Error())
+					fatal("Error on Con %d: %s", id, err.Error())
 				}
 			}(i)
 		}
-		for i := 0; i < max; i++ {
-			<-c
+
+		setCanStop()
+
+		wg.Wait()
+
+		if fatalError != "" {
+			dbt.Fatal(fatalError)
 		}
-		canStop = true
 
-		dbt.Logf("Reached %d concurrent connections \r\n", max)
+		dbt.Logf("Reached %d concurrent connections \r\n", succeeded)
 	})
 }