Bläddra i källkod

Add atomic wrappers for bool and error (#612)

* Add atomic wrappers for bool and error

Improves #608

* Drop Go 1.2 and Go 1.3 support

* "test" noCopy.Lock()
Julien Schmidt 8 år sedan
förälder
incheckning
72e0ac3f5f
9 ändrade filer med 169 tillägg och 47 borttagningar
  1. 0 2
      .travis.yml
  2. 2 2
      README.md
  3. 14 34
      connection.go
  4. 1 1
      connection_go18.go
  5. 3 3
      packets.go
  6. 3 3
      statement.go
  7. 2 2
      transaction.go
  8. 64 0
      utils.go
  9. 80 0
      utils_test.go

+ 0 - 2
.travis.yml

@@ -1,8 +1,6 @@
 sudo: false
 sudo: false
 language: go
 language: go
 go:
 go:
-  - 1.2
-  - 1.3
   - 1.4
   - 1.4
   - 1.5
   - 1.5
   - 1.6
   - 1.6

+ 2 - 2
README.md

@@ -39,7 +39,7 @@ A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) pac
   * Optional placeholder interpolation
   * Optional placeholder interpolation
 
 
 ## Requirements
 ## Requirements
-  * Go 1.2 or higher
+  * Go 1.4 or higher
   * MySQL (4.1+), MariaDB, Percona Server, Google CloudSQL or Sphinx (2.2.3+)
   * MySQL (4.1+), MariaDB, Percona Server, Google CloudSQL or Sphinx (2.2.3+)
 
 
 ---------------------------------------
 ---------------------------------------
@@ -279,7 +279,7 @@ Default:        false
 
 
 `rejectreadOnly=true` causes the driver to reject read-only connections. This
 `rejectreadOnly=true` causes the driver to reject read-only connections. This
 is for a possible race condition during an automatic failover, where the mysql
 is for a possible race condition during an automatic failover, where the mysql
-client gets connected to a read-only replica after the failover. 
+client gets connected to a read-only replica after the failover.
 
 
 Note that this should be a fairly rare case, as an automatic failover normally
 Note that this should be a fairly rare case, as an automatic failover normally
 happens when the primary is down, and the race condition shouldn't happen
 happens when the primary is down, and the race condition shouldn't happen

+ 14 - 34
connection.go

@@ -14,17 +14,15 @@ import (
 	"net"
 	"net"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
-	"sync"
-	"sync/atomic"
 	"time"
 	"time"
 )
 )
 
 
-// a copy of context.Context for Go 1.7 and later.
+// a copy of context.Context for Go 1.7 and earlier
 type mysqlContext interface {
 type mysqlContext interface {
 	Done() <-chan struct{}
 	Done() <-chan struct{}
 	Err() error
 	Err() error
 
 
-	// They are defined in context.Context, but go-mysql-driver does not use them.
+	// defined in context.Context, but not used in this driver:
 	// Deadline() (deadline time.Time, ok bool)
 	// Deadline() (deadline time.Time, ok bool)
 	// Value(key interface{}) interface{}
 	// Value(key interface{}) interface{}
 }
 }
@@ -44,18 +42,13 @@ type mysqlConn struct {
 	parseTime        bool
 	parseTime        bool
 	strict           bool
 	strict           bool
 
 
-	// for context support (From Go 1.8)
+	// for context support (Go 1.8+)
 	watching bool
 	watching bool
 	watcher  chan<- mysqlContext
 	watcher  chan<- mysqlContext
 	closech  chan struct{}
 	closech  chan struct{}
 	finished chan<- struct{}
 	finished chan<- struct{}
-
-	// set non-zero when conn is closed, before closech is closed.
-	// accessed atomically.
-	closed int32
-
-	mu          sync.Mutex // guards following fields
-	canceledErr error      // set non-nil if conn is canceled
+	canceled atomicError // set non-nil if conn is canceled
+	closed   atomicBool  // set when conn is closed, before closech is closed
 }
 }
 
 
 // Handles parameters set in DSN after the connection is established
 // Handles parameters set in DSN after the connection is established
@@ -89,7 +82,7 @@ func (mc *mysqlConn) handleParams() (err error) {
 }
 }
 
 
 func (mc *mysqlConn) Begin() (driver.Tx, error) {
 func (mc *mysqlConn) Begin() (driver.Tx, error) {
-	if mc.isBroken() {
+	if mc.closed.IsSet() {
 		errLog.Print(ErrInvalidConn)
 		errLog.Print(ErrInvalidConn)
 		return nil, driver.ErrBadConn
 		return nil, driver.ErrBadConn
 	}
 	}
@@ -103,7 +96,7 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) {
 
 
 func (mc *mysqlConn) Close() (err error) {
 func (mc *mysqlConn) Close() (err error) {
 	// Makes Close idempotent
 	// Makes Close idempotent
-	if !mc.isBroken() {
+	if !mc.closed.IsSet() {
 		err = mc.writeCommandPacket(comQuit)
 		err = mc.writeCommandPacket(comQuit)
 	}
 	}
 
 
@@ -117,7 +110,7 @@ func (mc *mysqlConn) Close() (err error) {
 // is called before auth or on auth failure because MySQL will have already
 // is called before auth or on auth failure because MySQL will have already
 // closed the network connection.
 // closed the network connection.
 func (mc *mysqlConn) cleanup() {
 func (mc *mysqlConn) cleanup() {
-	if atomic.SwapInt32(&mc.closed, 1) != 0 {
+	if !mc.closed.TrySet(true) {
 		return
 		return
 	}
 	}
 
 
@@ -131,13 +124,9 @@ func (mc *mysqlConn) cleanup() {
 	}
 	}
 }
 }
 
 
-func (mc *mysqlConn) isBroken() bool {
-	return atomic.LoadInt32(&mc.closed) != 0
-}
-
 func (mc *mysqlConn) error() error {
 func (mc *mysqlConn) error() error {
-	if mc.isBroken() {
-		if err := mc.canceled(); err != nil {
+	if mc.closed.IsSet() {
+		if err := mc.canceled.Value(); err != nil {
 			return err
 			return err
 		}
 		}
 		return ErrInvalidConn
 		return ErrInvalidConn
@@ -146,7 +135,7 @@ func (mc *mysqlConn) error() error {
 }
 }
 
 
 func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
 func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
-	if mc.isBroken() {
+	if mc.closed.IsSet() {
 		errLog.Print(ErrInvalidConn)
 		errLog.Print(ErrInvalidConn)
 		return nil, driver.ErrBadConn
 		return nil, driver.ErrBadConn
 	}
 	}
@@ -300,7 +289,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
 }
 }
 
 
 func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
 func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
-	if mc.isBroken() {
+	if mc.closed.IsSet() {
 		errLog.Print(ErrInvalidConn)
 		errLog.Print(ErrInvalidConn)
 		return nil, driver.ErrBadConn
 		return nil, driver.ErrBadConn
 	}
 	}
@@ -361,7 +350,7 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
 }
 }
 
 
 func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
 func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
-	if mc.isBroken() {
+	if mc.closed.IsSet() {
 		errLog.Print(ErrInvalidConn)
 		errLog.Print(ErrInvalidConn)
 		return nil, driver.ErrBadConn
 		return nil, driver.ErrBadConn
 	}
 	}
@@ -436,19 +425,10 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
 
 
 // finish is called when the query has canceled.
 // finish is called when the query has canceled.
 func (mc *mysqlConn) cancel(err error) {
 func (mc *mysqlConn) cancel(err error) {
-	mc.mu.Lock()
-	mc.canceledErr = err
-	mc.mu.Unlock()
+	mc.canceled.Set(err)
 	mc.cleanup()
 	mc.cleanup()
 }
 }
 
 
-// canceled returns non-nil if the connection was closed due to context cancelation.
-func (mc *mysqlConn) canceled() error {
-	mc.mu.Lock()
-	defer mc.mu.Unlock()
-	return mc.canceledErr
-}
-
 // finish is called when the query has succeeded.
 // finish is called when the query has succeeded.
 func (mc *mysqlConn) finish() {
 func (mc *mysqlConn) finish() {
 	if !mc.watching || mc.finished == nil {
 	if !mc.watching || mc.finished == nil {

+ 1 - 1
connection_go18.go

@@ -19,7 +19,7 @@ import (
 
 
 // Ping implements driver.Pinger interface
 // Ping implements driver.Pinger interface
 func (mc *mysqlConn) Ping(ctx context.Context) error {
 func (mc *mysqlConn) Ping(ctx context.Context) error {
-	if mc.isBroken() {
+	if mc.closed.IsSet() {
 		errLog.Print(ErrInvalidConn)
 		errLog.Print(ErrInvalidConn)
 		return driver.ErrBadConn
 		return driver.ErrBadConn
 	}
 	}

+ 3 - 3
packets.go

@@ -30,7 +30,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
 		// read packet header
 		// read packet header
 		data, err := mc.buf.readNext(4)
 		data, err := mc.buf.readNext(4)
 		if err != nil {
 		if err != nil {
-			if cerr := mc.canceled(); cerr != nil {
+			if cerr := mc.canceled.Value(); cerr != nil {
 				return nil, cerr
 				return nil, cerr
 			}
 			}
 			errLog.Print(err)
 			errLog.Print(err)
@@ -66,7 +66,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
 		// read packet body [pktLen bytes]
 		// read packet body [pktLen bytes]
 		data, err = mc.buf.readNext(pktLen)
 		data, err = mc.buf.readNext(pktLen)
 		if err != nil {
 		if err != nil {
-			if cerr := mc.canceled(); cerr != nil {
+			if cerr := mc.canceled.Value(); cerr != nil {
 				return nil, cerr
 				return nil, cerr
 			}
 			}
 			errLog.Print(err)
 			errLog.Print(err)
@@ -134,7 +134,7 @@ func (mc *mysqlConn) writePacket(data []byte) error {
 			mc.cleanup()
 			mc.cleanup()
 			errLog.Print(ErrMalformPkt)
 			errLog.Print(ErrMalformPkt)
 		} else {
 		} else {
-			if cerr := mc.canceled(); cerr != nil {
+			if cerr := mc.canceled.Value(); cerr != nil {
 				return cerr
 				return cerr
 			}
 			}
 			mc.cleanup()
 			mc.cleanup()

+ 3 - 3
statement.go

@@ -23,7 +23,7 @@ type mysqlStmt struct {
 }
 }
 
 
 func (stmt *mysqlStmt) Close() error {
 func (stmt *mysqlStmt) Close() error {
-	if stmt.mc == nil || stmt.mc.isBroken() {
+	if stmt.mc == nil || stmt.mc.closed.IsSet() {
 		// driver.Stmt.Close can be called more than once, thus this function
 		// driver.Stmt.Close can be called more than once, thus this function
 		// has to be idempotent.
 		// has to be idempotent.
 		// See also Issue #450 and golang/go#16019.
 		// See also Issue #450 and golang/go#16019.
@@ -45,7 +45,7 @@ func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter {
 }
 }
 
 
 func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
 func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
-	if stmt.mc.isBroken() {
+	if stmt.mc.closed.IsSet() {
 		errLog.Print(ErrInvalidConn)
 		errLog.Print(ErrInvalidConn)
 		return nil, driver.ErrBadConn
 		return nil, driver.ErrBadConn
 	}
 	}
@@ -93,7 +93,7 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
 }
 }
 
 
 func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
 func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
-	if stmt.mc.isBroken() {
+	if stmt.mc.closed.IsSet() {
 		errLog.Print(ErrInvalidConn)
 		errLog.Print(ErrInvalidConn)
 		return nil, driver.ErrBadConn
 		return nil, driver.ErrBadConn
 	}
 	}

+ 2 - 2
transaction.go

@@ -13,7 +13,7 @@ type mysqlTx struct {
 }
 }
 
 
 func (tx *mysqlTx) Commit() (err error) {
 func (tx *mysqlTx) Commit() (err error) {
-	if tx.mc == nil || tx.mc.isBroken() {
+	if tx.mc == nil || tx.mc.closed.IsSet() {
 		return ErrInvalidConn
 		return ErrInvalidConn
 	}
 	}
 	err = tx.mc.exec("COMMIT")
 	err = tx.mc.exec("COMMIT")
@@ -22,7 +22,7 @@ func (tx *mysqlTx) Commit() (err error) {
 }
 }
 
 
 func (tx *mysqlTx) Rollback() (err error) {
 func (tx *mysqlTx) Rollback() (err error) {
-	if tx.mc == nil || tx.mc.isBroken() {
+	if tx.mc == nil || tx.mc.closed.IsSet() {
 		return ErrInvalidConn
 		return ErrInvalidConn
 	}
 	}
 	err = tx.mc.exec("ROLLBACK")
 	err = tx.mc.exec("ROLLBACK")

+ 64 - 0
utils.go

@@ -16,6 +16,7 @@ import (
 	"fmt"
 	"fmt"
 	"io"
 	"io"
 	"strings"
 	"strings"
+	"sync/atomic"
 	"time"
 	"time"
 )
 )
 
 
@@ -740,3 +741,66 @@ func escapeStringQuotes(buf []byte, v string) []byte {
 
 
 	return buf[:pos]
 	return buf[:pos]
 }
 }
+
+/******************************************************************************
+*                               Sync utils                                    *
+******************************************************************************/
+// noCopy may be embedded into structs which must not be copied
+// after the first use.
+//
+// See https://github.com/golang/go/issues/8005#issuecomment-190753527
+// for details.
+type noCopy struct{}
+
+// Lock is a no-op used by -copylocks checker from `go vet`.
+func (*noCopy) Lock() {}
+
+// atomicBool is a wrapper around uint32 for usage as a boolean value with
+// atomic access.
+type atomicBool struct {
+	_noCopy noCopy
+	value   uint32
+}
+
+// IsSet returns wether the current boolean value is true
+func (ab *atomicBool) IsSet() bool {
+	return atomic.LoadUint32(&ab.value) > 0
+}
+
+// Set sets the value of the bool regardless of the previous value
+func (ab *atomicBool) Set(value bool) {
+	if value {
+		atomic.StoreUint32(&ab.value, 1)
+	} else {
+		atomic.StoreUint32(&ab.value, 0)
+	}
+}
+
+// TrySet sets the value of the bool and returns wether the value changed
+func (ab *atomicBool) TrySet(value bool) bool {
+	if value {
+		return atomic.SwapUint32(&ab.value, 1) == 0
+	}
+	return atomic.SwapUint32(&ab.value, 0) > 0
+}
+
+// atomicBool is a wrapper for atomically accessed error values
+type atomicError struct {
+	_noCopy noCopy
+	value   atomic.Value
+}
+
+// Set sets the error value regardless of the previous value.
+// The value must not be nil
+func (ae *atomicError) Set(value error) {
+	ae.value.Store(value)
+}
+
+// Value returns the current error value
+func (ae *atomicError) Value() error {
+	if v := ae.value.Load(); v != nil {
+		// this will panic if the value doesn't implement the error interface
+		return v.(error)
+	}
+	return nil
+}

+ 80 - 0
utils_test.go

@@ -195,3 +195,83 @@ func TestEscapeQuotes(t *testing.T) {
 	expect("foo''bar", "foo'bar")      // affected
 	expect("foo''bar", "foo'bar")      // affected
 	expect("foo\"bar", "foo\"bar")     // not affected
 	expect("foo\"bar", "foo\"bar")     // not affected
 }
 }
+
+func TestAtomicBool(t *testing.T) {
+	var ab atomicBool
+	if ab.IsSet() {
+		t.Fatal("Expected value to be false")
+	}
+
+	ab.Set(true)
+	if ab.value != 1 {
+		t.Fatal("Set(true) did not set value to 1")
+	}
+	if !ab.IsSet() {
+		t.Fatal("Expected value to be true")
+	}
+
+	ab.Set(true)
+	if !ab.IsSet() {
+		t.Fatal("Expected value to be true")
+	}
+
+	ab.Set(false)
+	if ab.value != 0 {
+		t.Fatal("Set(false) did not set value to 0")
+	}
+	if ab.IsSet() {
+		t.Fatal("Expected value to be false")
+	}
+
+	ab.Set(false)
+	if ab.IsSet() {
+		t.Fatal("Expected value to be false")
+	}
+	if ab.TrySet(false) {
+		t.Fatal("Expected TrySet(false) to fail")
+	}
+	if !ab.TrySet(true) {
+		t.Fatal("Expected TrySet(true) to succeed")
+	}
+	if !ab.IsSet() {
+		t.Fatal("Expected value to be true")
+	}
+
+	ab.Set(true)
+	if !ab.IsSet() {
+		t.Fatal("Expected value to be true")
+	}
+	if ab.TrySet(true) {
+		t.Fatal("Expected TrySet(true) to fail")
+	}
+	if !ab.TrySet(false) {
+		t.Fatal("Expected TrySet(false) to succeed")
+	}
+	if ab.IsSet() {
+		t.Fatal("Expected value to be false")
+	}
+
+	ab._noCopy.Lock() // we've "tested" it ¯\_(ツ)_/¯
+}
+
+func TestAtomicError(t *testing.T) {
+	var ae atomicError
+	if ae.Value() != nil {
+		t.Fatal("Expected value to be nil")
+	}
+
+	ae.Set(ErrMalformPkt)
+	if v := ae.Value(); v != ErrMalformPkt {
+		if v == nil {
+			t.Fatal("Value is still nil")
+		}
+		t.Fatal("Error did not match")
+	}
+	ae.Set(ErrPktSync)
+	if ae.Value() == ErrMalformPkt {
+		t.Fatal("Error still matches old error")
+	}
+	if v := ae.Value(); v != ErrPktSync {
+		t.Fatal("Error did not match")
+	}
+}