Browse Source

Merge pull request #166 from thejerf/master

Fix the TestConcurrent test to pass Go's race detection.
Julien Schmidt 12 years ago
parent
commit
587def8198
2 changed files with 44 additions and 22 deletions
  1. 1 0
      AUTHORS
  2. 43 22
      driver_test.go

+ 1 - 0
AUTHORS

@@ -27,4 +27,5 @@ Xiuming Chen <cc at cxm.cc>
 
 # Organizations
 
+Barracuda Networks, Inc.
 Google Inc.

+ 43 - 22
driver_test.go

@@ -19,6 +19,8 @@ import (
 	"net/url"
 	"os"
 	"strings"
+	"sync"
+	"sync/atomic"
 	"testing"
 	"time"
 )
@@ -1220,40 +1222,59 @@ func TestConcurrent(t *testing.T) {
 			dbt.Fatalf("%s", err.Error())
 		}
 		dbt.Logf("Testing up to %d concurrent connections \r\n", max)
-		canStop := false
-		c := make(chan struct{}, max)
+
+		var remaining, succeeded int32 = int32(max), 0
+
+		var wg sync.WaitGroup
+		wg.Add(max)
+
+		var fatalError string
+		var once sync.Once
+		fatal := func(s string, vals ...interface{}) {
+			once.Do(func() {
+				fatalError = fmt.Sprintf(s, vals...)
+			})
+		}
+
 		for i := 0; i < max; i++ {
 			go func(id int) {
+				defer wg.Done()
+
 				tx, err := dbt.db.Begin()
+				atomic.AddInt32(&remaining, -1)
+
 				if err != nil {
-					canStop = true
-					if err.Error() == "Error 1040: Too many connections" {
-						max--
-						return
-					} else {
-						dbt.Fatalf("Error on Con %d: %s", id, err.Error())
+					if err.Error() != "Error 1040: Too many connections" {
+						fatal("Error on Conn %d: %s", id, err.Error())
 					}
+					return
 				}
-				c <- struct{}{}
-				for !canStop {
-					_, err = tx.Exec("SELECT 1")
-					if err != nil {
-						canStop = true
-						dbt.Fatalf("Error on Con %d: %s", id, err.Error())
+
+				// keep the connection busy until all connections are open
+				for remaining > 0 {
+					if _, err = tx.Exec("DO 1"); err != nil {
+						fatal("Error on Conn %d: %s", id, err.Error())
+						return
 					}
 				}
-				err = tx.Commit()
-				if err != nil {
-					canStop = true
-					dbt.Fatalf("Error on Con %d: %s", id, err.Error())
+
+				if err = tx.Commit(); err != nil {
+					fatal("Error on Conn %d: %s", id, err.Error())
+					return
 				}
+
+				// everything went fine with this connection
+				atomic.AddInt32(&succeeded, 1)
 			}(i)
 		}
-		for i := 0; i < max; i++ {
-			<-c
+
+		// wait until all conections are open
+		wg.Wait()
+
+		if fatalError != "" {
+			dbt.Fatal(fatalError)
 		}
-		canStop = true
 
-		dbt.Logf("Reached %d concurrent connections \r\n", max)
+		dbt.Logf("Reached %d concurrent connections\r\n", succeeded)
 	})
 }