Browse Source

Add atomic wrappers for bool and error (#612)

* Add atomic wrappers for bool and error

Improves #608

* Drop Go 1.2 and Go 1.3 support

* "test" noCopy.Lock()
Julien Schmidt 8 years ago
parent
commit
72e0ac3f5f
9 changed files with 169 additions and 47 deletions
  1. 0 2
      .travis.yml
  2. 2 2
      README.md
  3. 14 34
      connection.go
  4. 1 1
      connection_go18.go
  5. 3 3
      packets.go
  6. 3 3
      statement.go
  7. 2 2
      transaction.go
  8. 64 0
      utils.go
  9. 80 0
      utils_test.go

+ 0 - 2
.travis.yml

@@ -1,8 +1,6 @@
 sudo: false
 language: go
 go:
-  - 1.2
-  - 1.3
   - 1.4
   - 1.5
   - 1.6

+ 2 - 2
README.md

@@ -39,7 +39,7 @@ A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) pac
   * Optional placeholder interpolation
 
 ## Requirements
-  * Go 1.2 or higher
+  * Go 1.4 or higher
   * MySQL (4.1+), MariaDB, Percona Server, Google CloudSQL or Sphinx (2.2.3+)
 
 ---------------------------------------
@@ -279,7 +279,7 @@ Default:        false
 
 `rejectreadOnly=true` causes the driver to reject read-only connections. This
 is for a possible race condition during an automatic failover, where the mysql
-client gets connected to a read-only replica after the failover. 
+client gets connected to a read-only replica after the failover.
 
 Note that this should be a fairly rare case, as an automatic failover normally
 happens when the primary is down, and the race condition shouldn't happen

+ 14 - 34
connection.go

@@ -14,17 +14,15 @@ import (
 	"net"
 	"strconv"
 	"strings"
-	"sync"
-	"sync/atomic"
 	"time"
 )
 
-// a copy of context.Context for Go 1.7 and later.
+// a copy of context.Context for Go 1.7 and earlier
 type mysqlContext interface {
 	Done() <-chan struct{}
 	Err() error
 
-	// They are defined in context.Context, but go-mysql-driver does not use them.
+	// defined in context.Context, but not used in this driver:
 	// Deadline() (deadline time.Time, ok bool)
 	// Value(key interface{}) interface{}
 }
@@ -44,18 +42,13 @@ type mysqlConn struct {
 	parseTime        bool
 	strict           bool
 
-	// for context support (From Go 1.8)
+	// for context support (Go 1.8+)
 	watching bool
 	watcher  chan<- mysqlContext
 	closech  chan struct{}
 	finished chan<- struct{}
-
-	// set non-zero when conn is closed, before closech is closed.
-	// accessed atomically.
-	closed int32
-
-	mu          sync.Mutex // guards following fields
-	canceledErr error      // set non-nil if conn is canceled
+	canceled atomicError // set non-nil if conn is canceled
+	closed   atomicBool  // set when conn is closed, before closech is closed
 }
 
 // Handles parameters set in DSN after the connection is established
@@ -89,7 +82,7 @@ func (mc *mysqlConn) handleParams() (err error) {
 }
 
 func (mc *mysqlConn) Begin() (driver.Tx, error) {
-	if mc.isBroken() {
+	if mc.closed.IsSet() {
 		errLog.Print(ErrInvalidConn)
 		return nil, driver.ErrBadConn
 	}
@@ -103,7 +96,7 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) {
 
 func (mc *mysqlConn) Close() (err error) {
 	// Makes Close idempotent
-	if !mc.isBroken() {
+	if !mc.closed.IsSet() {
 		err = mc.writeCommandPacket(comQuit)
 	}
 
@@ -117,7 +110,7 @@ func (mc *mysqlConn) Close() (err error) {
 // is called before auth or on auth failure because MySQL will have already
 // closed the network connection.
 func (mc *mysqlConn) cleanup() {
-	if atomic.SwapInt32(&mc.closed, 1) != 0 {
+	if !mc.closed.TrySet(true) {
 		return
 	}
 
@@ -131,13 +124,9 @@ func (mc *mysqlConn) cleanup() {
 	}
 }
 
-func (mc *mysqlConn) isBroken() bool {
-	return atomic.LoadInt32(&mc.closed) != 0
-}
-
 func (mc *mysqlConn) error() error {
-	if mc.isBroken() {
-		if err := mc.canceled(); err != nil {
+	if mc.closed.IsSet() {
+		if err := mc.canceled.Value(); err != nil {
 			return err
 		}
 		return ErrInvalidConn
@@ -146,7 +135,7 @@ func (mc *mysqlConn) error() error {
 }
 
 func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
-	if mc.isBroken() {
+	if mc.closed.IsSet() {
 		errLog.Print(ErrInvalidConn)
 		return nil, driver.ErrBadConn
 	}
@@ -300,7 +289,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
 }
 
 func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
-	if mc.isBroken() {
+	if mc.closed.IsSet() {
 		errLog.Print(ErrInvalidConn)
 		return nil, driver.ErrBadConn
 	}
@@ -361,7 +350,7 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
 }
 
 func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
-	if mc.isBroken() {
+	if mc.closed.IsSet() {
 		errLog.Print(ErrInvalidConn)
 		return nil, driver.ErrBadConn
 	}
@@ -436,19 +425,10 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
 
 // finish is called when the query has canceled.
 func (mc *mysqlConn) cancel(err error) {
-	mc.mu.Lock()
-	mc.canceledErr = err
-	mc.mu.Unlock()
+	mc.canceled.Set(err)
 	mc.cleanup()
 }
 
-// canceled returns non-nil if the connection was closed due to context cancelation.
-func (mc *mysqlConn) canceled() error {
-	mc.mu.Lock()
-	defer mc.mu.Unlock()
-	return mc.canceledErr
-}
-
 // finish is called when the query has succeeded.
 func (mc *mysqlConn) finish() {
 	if !mc.watching || mc.finished == nil {

+ 1 - 1
connection_go18.go

@@ -19,7 +19,7 @@ import (
 
 // Ping implements driver.Pinger interface
 func (mc *mysqlConn) Ping(ctx context.Context) error {
-	if mc.isBroken() {
+	if mc.closed.IsSet() {
 		errLog.Print(ErrInvalidConn)
 		return driver.ErrBadConn
 	}

+ 3 - 3
packets.go

@@ -30,7 +30,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
 		// read packet header
 		data, err := mc.buf.readNext(4)
 		if err != nil {
-			if cerr := mc.canceled(); cerr != nil {
+			if cerr := mc.canceled.Value(); cerr != nil {
 				return nil, cerr
 			}
 			errLog.Print(err)
@@ -66,7 +66,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
 		// read packet body [pktLen bytes]
 		data, err = mc.buf.readNext(pktLen)
 		if err != nil {
-			if cerr := mc.canceled(); cerr != nil {
+			if cerr := mc.canceled.Value(); cerr != nil {
 				return nil, cerr
 			}
 			errLog.Print(err)
@@ -134,7 +134,7 @@ func (mc *mysqlConn) writePacket(data []byte) error {
 			mc.cleanup()
 			errLog.Print(ErrMalformPkt)
 		} else {
-			if cerr := mc.canceled(); cerr != nil {
+			if cerr := mc.canceled.Value(); cerr != nil {
 				return cerr
 			}
 			mc.cleanup()

+ 3 - 3
statement.go

@@ -23,7 +23,7 @@ type mysqlStmt struct {
 }
 
 func (stmt *mysqlStmt) Close() error {
-	if stmt.mc == nil || stmt.mc.isBroken() {
+	if stmt.mc == nil || stmt.mc.closed.IsSet() {
 		// driver.Stmt.Close can be called more than once, thus this function
 		// has to be idempotent.
 		// See also Issue #450 and golang/go#16019.
@@ -45,7 +45,7 @@ func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter {
 }
 
 func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
-	if stmt.mc.isBroken() {
+	if stmt.mc.closed.IsSet() {
 		errLog.Print(ErrInvalidConn)
 		return nil, driver.ErrBadConn
 	}
@@ -93,7 +93,7 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
 }
 
 func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
-	if stmt.mc.isBroken() {
+	if stmt.mc.closed.IsSet() {
 		errLog.Print(ErrInvalidConn)
 		return nil, driver.ErrBadConn
 	}

+ 2 - 2
transaction.go

@@ -13,7 +13,7 @@ type mysqlTx struct {
 }
 
 func (tx *mysqlTx) Commit() (err error) {
-	if tx.mc == nil || tx.mc.isBroken() {
+	if tx.mc == nil || tx.mc.closed.IsSet() {
 		return ErrInvalidConn
 	}
 	err = tx.mc.exec("COMMIT")
@@ -22,7 +22,7 @@ func (tx *mysqlTx) Commit() (err error) {
 }
 
 func (tx *mysqlTx) Rollback() (err error) {
-	if tx.mc == nil || tx.mc.isBroken() {
+	if tx.mc == nil || tx.mc.closed.IsSet() {
 		return ErrInvalidConn
 	}
 	err = tx.mc.exec("ROLLBACK")

+ 64 - 0
utils.go

@@ -16,6 +16,7 @@ import (
 	"fmt"
 	"io"
 	"strings"
+	"sync/atomic"
 	"time"
 )
 
@@ -740,3 +741,66 @@ func escapeStringQuotes(buf []byte, v string) []byte {
 
 	return buf[:pos]
 }
+
+/******************************************************************************
+*                               Sync utils                                    *
+******************************************************************************/
+// noCopy may be embedded into structs which must not be copied
+// after the first use.
+//
+// See https://github.com/golang/go/issues/8005#issuecomment-190753527
+// for details.
+type noCopy struct{}
+
+// Lock is a no-op used by -copylocks checker from `go vet`.
+func (*noCopy) Lock() {}
+
+// atomicBool is a wrapper around uint32 for usage as a boolean value with
+// atomic access.
+type atomicBool struct {
+	_noCopy noCopy
+	value   uint32
+}
+
+// IsSet returns wether the current boolean value is true
+func (ab *atomicBool) IsSet() bool {
+	return atomic.LoadUint32(&ab.value) > 0
+}
+
+// Set sets the value of the bool regardless of the previous value
+func (ab *atomicBool) Set(value bool) {
+	if value {
+		atomic.StoreUint32(&ab.value, 1)
+	} else {
+		atomic.StoreUint32(&ab.value, 0)
+	}
+}
+
+// TrySet sets the value of the bool and returns wether the value changed
+func (ab *atomicBool) TrySet(value bool) bool {
+	if value {
+		return atomic.SwapUint32(&ab.value, 1) == 0
+	}
+	return atomic.SwapUint32(&ab.value, 0) > 0
+}
+
+// atomicBool is a wrapper for atomically accessed error values
+type atomicError struct {
+	_noCopy noCopy
+	value   atomic.Value
+}
+
+// Set sets the error value regardless of the previous value.
+// The value must not be nil
+func (ae *atomicError) Set(value error) {
+	ae.value.Store(value)
+}
+
+// Value returns the current error value
+func (ae *atomicError) Value() error {
+	if v := ae.value.Load(); v != nil {
+		// this will panic if the value doesn't implement the error interface
+		return v.(error)
+	}
+	return nil
+}

+ 80 - 0
utils_test.go

@@ -195,3 +195,83 @@ func TestEscapeQuotes(t *testing.T) {
 	expect("foo''bar", "foo'bar")      // affected
 	expect("foo\"bar", "foo\"bar")     // not affected
 }
+
+func TestAtomicBool(t *testing.T) {
+	var ab atomicBool
+	if ab.IsSet() {
+		t.Fatal("Expected value to be false")
+	}
+
+	ab.Set(true)
+	if ab.value != 1 {
+		t.Fatal("Set(true) did not set value to 1")
+	}
+	if !ab.IsSet() {
+		t.Fatal("Expected value to be true")
+	}
+
+	ab.Set(true)
+	if !ab.IsSet() {
+		t.Fatal("Expected value to be true")
+	}
+
+	ab.Set(false)
+	if ab.value != 0 {
+		t.Fatal("Set(false) did not set value to 0")
+	}
+	if ab.IsSet() {
+		t.Fatal("Expected value to be false")
+	}
+
+	ab.Set(false)
+	if ab.IsSet() {
+		t.Fatal("Expected value to be false")
+	}
+	if ab.TrySet(false) {
+		t.Fatal("Expected TrySet(false) to fail")
+	}
+	if !ab.TrySet(true) {
+		t.Fatal("Expected TrySet(true) to succeed")
+	}
+	if !ab.IsSet() {
+		t.Fatal("Expected value to be true")
+	}
+
+	ab.Set(true)
+	if !ab.IsSet() {
+		t.Fatal("Expected value to be true")
+	}
+	if ab.TrySet(true) {
+		t.Fatal("Expected TrySet(true) to fail")
+	}
+	if !ab.TrySet(false) {
+		t.Fatal("Expected TrySet(false) to succeed")
+	}
+	if ab.IsSet() {
+		t.Fatal("Expected value to be false")
+	}
+
+	ab._noCopy.Lock() // we've "tested" it ¯\_(ツ)_/¯
+}
+
+func TestAtomicError(t *testing.T) {
+	var ae atomicError
+	if ae.Value() != nil {
+		t.Fatal("Expected value to be nil")
+	}
+
+	ae.Set(ErrMalformPkt)
+	if v := ae.Value(); v != ErrMalformPkt {
+		if v == nil {
+			t.Fatal("Value is still nil")
+		}
+		t.Fatal("Error did not match")
+	}
+	ae.Set(ErrPktSync)
+	if ae.Value() == ErrMalformPkt {
+		t.Fatal("Error still matches old error")
+	}
+	if v := ae.Value(); v != ErrPktSync {
+		t.Fatal("Error did not match")
+	}
+}