Browse Source

Remove possible data race in tls.Config map init

also resorted code
Julien Schmidt 12 năm trước cách đây
mục cha
commit
abd1799c82
1 tập tin đã thay đổi với 85 bổ sung89 xóa
  1. 85 89
      utils.go

+ 85 - 89
utils.go

@@ -23,62 +23,24 @@ import (
 	"time"
 )
 
-// NullTime represents a time.Time that may be NULL.
-// NullTime implements the Scanner interface so
-// it can be used as a scan destination:
-//
-//  var nt NullTime
-//  err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt)
-//  ...
-//  if nt.Valid {
-//     // use nt.Time
-//  } else {
-//     // NULL value
-//  }
-//
-// This NullTime implementation is not driver-specific
-type NullTime struct {
-	Time  time.Time
-	Valid bool // Valid is true if Time is not NULL
-}
-
-// Scan implements the Scanner interface.
-// The value type must be time.Time or string / []byte (formatted time-string),
-// otherwise Scan fails.
-func (nt *NullTime) Scan(value interface{}) (err error) {
-	if value == nil {
-		nt.Time, nt.Valid = time.Time{}, false
-		return
-	}
+var (
+	errLog            *log.Logger            // Error Logger
+	dsnPattern        *regexp.Regexp         // Data Source Name Parser
+	tlsConfigRegister map[string]*tls.Config // Register for custom tls.Configs
+)
 
-	switch v := value.(type) {
-	case time.Time:
-		nt.Time, nt.Valid = v, true
-		return
-	case []byte:
-		nt.Time, err = parseDateTime(string(v), time.UTC)
-		nt.Valid = (err == nil)
-		return
-	case string:
-		nt.Time, err = parseDateTime(v, time.UTC)
-		nt.Valid = (err == nil)
-		return
-	}
+func init() {
+	errLog = log.New(os.Stderr, "[MySQL] ", log.Ldate|log.Ltime|log.Lshortfile)
 
-	nt.Valid = false
-	return fmt.Errorf("Can't convert %T to time.Time", value)
-}
+	dsnPattern = regexp.MustCompile(
+		`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
+			`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
+			`\/(?P<dbname>.*?)` + // /dbname
+			`(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1&paramN=valueN]
 
-// Value implements the driver Valuer interface.
-func (nt NullTime) Value() (driver.Value, error) {
-	if !nt.Valid {
-		return nil, nil
-	}
-	return nt.Time, nil
+	tlsConfigRegister = make(map[string]*tls.Config)
 }
 
-var tlsConfigMap map[string]*tls.Config
-
 // Registers a custom tls.Config to be used with sql.Open.
 // Use the key as a value in the DSN where tls=value.
 //
@@ -103,38 +65,14 @@ var tlsConfigMap map[string]*tls.Config
 //  db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom")
 //
 func RegisterTLSConfig(key string, config *tls.Config) {
-	if tlsConfigMap == nil {
-		tlsConfigMap = make(map[string]*tls.Config)
-	}
-	tlsConfigMap[key] = config
+	tlsConfigRegister[key] = config
 }
 
 // Removes tls.Config associated with key.
 func DeregisterTLSConfig(key string) {
-	if tlsConfigMap == nil {
-		return
-	}
-	delete(tlsConfigMap, key)
+	delete(tlsConfigRegister, key)
 }
 
-// Logger
-var (
-	errLog *log.Logger
-)
-
-func init() {
-	errLog = log.New(os.Stderr, "[MySQL] ", log.Ldate|log.Ltime|log.Lshortfile)
-
-	dsnPattern = regexp.MustCompile(
-		`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
-			`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
-			`\/(?P<dbname>.*?)` + // /dbname
-			`(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1&paramN=valueN]
-}
-
-// Data Source Name Parser
-var dsnPattern *regexp.Regexp
-
 func parseDSN(dsn string) (cfg *config, err error) {
 	cfg = new(config)
 	cfg.params = make(map[string]string)
@@ -192,8 +130,8 @@ func parseDSN(dsn string) (cfg *config, err error) {
 						cfg.tls = &tls.Config{}
 					} else if strings.ToLower(value) == "skip-verify" {
 						cfg.tls = &tls.Config{InsecureSkipVerify: true}
-					// TODO: Check for Boolean false
-					} else if tlsConfig, ok := tlsConfigMap[value]; ok {
+						// TODO: Check for Boolean false
+					} else if tlsConfig, ok := tlsConfigRegister[value]; ok {
 						cfg.tls = tlsConfig
 					}
 
@@ -253,6 +191,74 @@ func scramblePassword(scramble, password []byte) []byte {
 	return scramble
 }
 
+func readBool(value string) bool {
+	switch strings.ToLower(value) {
+	case "true":
+		return true
+	case "1":
+		return true
+	}
+	return false
+}
+
+/******************************************************************************
+*                           Time related utils                                *
+******************************************************************************/
+
+// NullTime represents a time.Time that may be NULL.
+// NullTime implements the Scanner interface so
+// it can be used as a scan destination:
+//
+//  var nt NullTime
+//  err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt)
+//  ...
+//  if nt.Valid {
+//     // use nt.Time
+//  } else {
+//     // NULL value
+//  }
+//
+// This NullTime implementation is not driver-specific
+type NullTime struct {
+	Time  time.Time
+	Valid bool // Valid is true if Time is not NULL
+}
+
+// Scan implements the Scanner interface.
+// The value type must be time.Time or string / []byte (formatted time-string),
+// otherwise Scan fails.
+func (nt *NullTime) Scan(value interface{}) (err error) {
+	if value == nil {
+		nt.Time, nt.Valid = time.Time{}, false
+		return
+	}
+
+	switch v := value.(type) {
+	case time.Time:
+		nt.Time, nt.Valid = v, true
+		return
+	case []byte:
+		nt.Time, err = parseDateTime(string(v), time.UTC)
+		nt.Valid = (err == nil)
+		return
+	case string:
+		nt.Time, err = parseDateTime(v, time.UTC)
+		nt.Valid = (err == nil)
+		return
+	}
+
+	nt.Valid = false
+	return fmt.Errorf("Can't convert %T to time.Time", value)
+}
+
+// Value implements the driver Valuer interface.
+func (nt NullTime) Value() (driver.Value, error) {
+	if !nt.Valid {
+		return nil, nil
+	}
+	return nt.Time, nil
+}
+
 func parseDateTime(str string, loc *time.Location) (t time.Time, err error) {
 	switch len(str) {
 	case 10: // YYYY-MM-DD
@@ -369,16 +375,6 @@ func formatBinaryDateTime(num uint64, data []byte) (driver.Value, error) {
 	return nil, fmt.Errorf("Invalid DATETIME-packet length %d", num)
 }
 
-func readBool(value string) bool {
-	switch strings.ToLower(value) {
-	case "true":
-		return true
-	case "1":
-		return true
-	}
-	return false
-}
-
 /******************************************************************************
 *                       Convert from and to bytes                             *
 ******************************************************************************/