Ver Fonte

Merge pull request #106 from go-sql-driver/tls

tls.Config: strict key check + remove data race
Julien Schmidt há 12 anos atrás
pai
commit
2f1342a67c
3 ficheiros alterados com 131 adições e 101 exclusões
  1. 10 2
      connection.go
  2. 1 1
      driver_test.go
  3. 120 98
      utils.go

+ 10 - 2
connection.go

@@ -68,11 +68,19 @@ func (mc *mysqlConn) handleParams() (err error) {
 
 		// time.Time parsing
 		case "parseTime":
-			mc.parseTime = readBool(val)
+			var isBool bool
+			mc.parseTime, isBool = readBool(val)
+			if !isBool {
+				return errors.New("Invalid Bool value: " + val)
+			}
 
 		// Strict mode
 		case "strict":
-			mc.strict = readBool(val)
+			var isBool bool
+			mc.strict, isBool = readBool(val)
+			if !isBool {
+				return errors.New("Invalid Bool value: " + val)
+			}
 
 		// Compression
 		case "compress":

+ 1 - 1
driver_test.go

@@ -1053,7 +1053,7 @@ func TestStmtMultiRows(t *testing.T) {
 }
 
 func TestConcurrent(t *testing.T) {
-	if readBool(os.Getenv("MYSQL_TEST_CONCURRENT")) != true {
+	if enabled, _ := readBool(os.Getenv("MYSQL_TEST_CONCURRENT")); !enabled {
 		t.Skip("CONCURRENT env var not set")
 	}
 

+ 120 - 98
utils.go

@@ -23,63 +23,25 @@ import (
 	"time"
 )
 
-// NullTime represents a time.Time that may be NULL.
-// NullTime implements the Scanner interface so
-// it can be used as a scan destination:
-//
-//  var nt NullTime
-//  err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt)
-//  ...
-//  if nt.Valid {
-//     // use nt.Time
-//  } else {
-//     // NULL value
-//  }
-//
-// This NullTime implementation is not driver-specific
-type NullTime struct {
-	Time  time.Time
-	Valid bool // Valid is true if Time is not NULL
-}
-
-// Scan implements the Scanner interface.
-// The value type must be time.Time or string / []byte (formatted time-string),
-// otherwise Scan fails.
-func (nt *NullTime) Scan(value interface{}) (err error) {
-	if value == nil {
-		nt.Time, nt.Valid = time.Time{}, false
-		return
-	}
+var (
+	errLog            *log.Logger            // Error Logger
+	dsnPattern        *regexp.Regexp         // Data Source Name Parser
+	tlsConfigRegister map[string]*tls.Config // Register for custom tls.Configs
+)
 
-	switch v := value.(type) {
-	case time.Time:
-		nt.Time, nt.Valid = v, true
-		return
-	case []byte:
-		nt.Time, err = parseDateTime(string(v), time.UTC)
-		nt.Valid = (err == nil)
-		return
-	case string:
-		nt.Time, err = parseDateTime(v, time.UTC)
-		nt.Valid = (err == nil)
-		return
-	}
+func init() {
+	errLog = log.New(os.Stderr, "[MySQL] ", log.Ldate|log.Ltime|log.Lshortfile)
 
-	nt.Valid = false
-	return fmt.Errorf("Can't convert %T to time.Time", value)
-}
+	dsnPattern = regexp.MustCompile(
+		`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
+			`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
+			`\/(?P<dbname>.*?)` + // /dbname
+			`(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1&paramN=valueN]
 
-// Value implements the driver Valuer interface.
-func (nt NullTime) Value() (driver.Value, error) {
-	if !nt.Valid {
-		return nil, nil
-	}
-	return nt.Time, nil
+	tlsConfigRegister = make(map[string]*tls.Config)
 }
 
-var tlsConfigMap map[string]*tls.Config
-
-// Registers a custom tls.Config to be used with sql.Open.
+// RegisterTLSConfig registers a custom tls.Config to be used with sql.Open.
 // Use the key as a value in the DSN where tls=value.
 //
 //  rootCertPool := x509.NewCertPool()
@@ -102,39 +64,20 @@ var tlsConfigMap map[string]*tls.Config
 //  })
 //  db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom")
 //
-func RegisterTLSConfig(key string, config *tls.Config) {
-	if tlsConfigMap == nil {
-		tlsConfigMap = make(map[string]*tls.Config)
+func RegisterTLSConfig(key string, config *tls.Config) error {
+	if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" {
+		return fmt.Errorf("Key '%s' is reserved", key)
 	}
-	tlsConfigMap[key] = config
-}
 
-// Removes tls.Config associated with key.
-func DeregisterTLSConfig(key string) {
-	if tlsConfigMap == nil {
-		return
-	}
-	delete(tlsConfigMap, key)
+	tlsConfigRegister[key] = config
+	return nil
 }
 
-// Logger
-var (
-	errLog *log.Logger
-)
-
-func init() {
-	errLog = log.New(os.Stderr, "[MySQL] ", log.Ldate|log.Ltime|log.Lshortfile)
-
-	dsnPattern = regexp.MustCompile(
-		`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
-			`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
-			`\/(?P<dbname>.*?)` + // /dbname
-			`(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1&paramN=valueN]
+// DeregisterTLSConfig removes the tls.Config associated with key.
+func DeregisterTLSConfig(key string) {
+	delete(tlsConfigRegister, key)
 }
 
-// Data Source Name Parser
-var dsnPattern *regexp.Regexp
-
 func parseDSN(dsn string) (cfg *config, err error) {
 	cfg = new(config)
 	cfg.params = make(map[string]string)
@@ -166,11 +109,21 @@ func parseDSN(dsn string) (cfg *config, err error) {
 
 				// Disable INFILE whitelist / enable all files
 				case "allowAllFiles":
-					cfg.allowAllFiles = readBool(value)
+					var isBool bool
+					cfg.allowAllFiles, isBool = readBool(value)
+					if !isBool {
+						err = fmt.Errorf("Invalid Bool value: %s", value)
+						return
+					}
 
 				// Switch "rowsAffected" mode
 				case "clientFoundRows":
-					cfg.clientFoundRows = readBool(value)
+					var isBool bool
+					cfg.clientFoundRows, isBool = readBool(value)
+					if !isBool {
+						err = fmt.Errorf("Invalid Bool value: %s", value)
+						return
+					}
 
 				// Time Location
 				case "loc":
@@ -188,13 +141,20 @@ func parseDSN(dsn string) (cfg *config, err error) {
 
 				// TLS-Encryption
 				case "tls":
-					if readBool(value) {
-						cfg.tls = &tls.Config{}
-					} else if strings.ToLower(value) == "skip-verify" {
-						cfg.tls = &tls.Config{InsecureSkipVerify: true}
-					// TODO: Check for Boolean false
-					} else if tlsConfig, ok := tlsConfigMap[value]; ok {
-						cfg.tls = tlsConfig
+					boolValue, isBool := readBool(value)
+					if isBool {
+						if boolValue {
+							cfg.tls = &tls.Config{}
+						}
+					} else {
+						if strings.ToLower(value) == "skip-verify" {
+							cfg.tls = &tls.Config{InsecureSkipVerify: true}
+						} else if tlsConfig, ok := tlsConfigRegister[value]; ok {
+							cfg.tls = tlsConfig
+						} else {
+							err = fmt.Errorf("Invalid value / unknown config name: %s", value)
+							return
+						}
 					}
 
 				default:
@@ -253,6 +213,78 @@ func scramblePassword(scramble, password []byte) []byte {
 	return scramble
 }
 
+// Returns the bool value of the input.
+// The 2nd return value indicates if the input was a valid bool value
+func readBool(input string) (value bool, valid bool) {
+	switch input {
+	case "1", "true", "TRUE", "True":
+		return true, true
+	case "0", "false", "FALSE", "False":
+		return false, true
+	}
+
+	// Not a valid bool value
+	return
+}
+
+/******************************************************************************
+*                           Time related utils                                *
+******************************************************************************/
+
+// NullTime represents a time.Time that may be NULL.
+// NullTime implements the Scanner interface so
+// it can be used as a scan destination:
+//
+//  var nt NullTime
+//  err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt)
+//  ...
+//  if nt.Valid {
+//     // use nt.Time
+//  } else {
+//     // NULL value
+//  }
+//
+// This NullTime implementation is not driver-specific
+type NullTime struct {
+	Time  time.Time
+	Valid bool // Valid is true if Time is not NULL
+}
+
+// Scan implements the Scanner interface.
+// The value type must be time.Time or string / []byte (formatted time-string),
+// otherwise Scan fails.
+func (nt *NullTime) Scan(value interface{}) (err error) {
+	if value == nil {
+		nt.Time, nt.Valid = time.Time{}, false
+		return
+	}
+
+	switch v := value.(type) {
+	case time.Time:
+		nt.Time, nt.Valid = v, true
+		return
+	case []byte:
+		nt.Time, err = parseDateTime(string(v), time.UTC)
+		nt.Valid = (err == nil)
+		return
+	case string:
+		nt.Time, err = parseDateTime(v, time.UTC)
+		nt.Valid = (err == nil)
+		return
+	}
+
+	nt.Valid = false
+	return fmt.Errorf("Can't convert %T to time.Time", value)
+}
+
+// Value implements the driver Valuer interface.
+func (nt NullTime) Value() (driver.Value, error) {
+	if !nt.Valid {
+		return nil, nil
+	}
+	return nt.Time, nil
+}
+
 func parseDateTime(str string, loc *time.Location) (t time.Time, err error) {
 	switch len(str) {
 	case 10: // YYYY-MM-DD
@@ -369,16 +401,6 @@ func formatBinaryDateTime(num uint64, data []byte) (driver.Value, error) {
 	return nil, fmt.Errorf("Invalid DATETIME-packet length %d", num)
 }
 
-func readBool(value string) bool {
-	switch strings.ToLower(value) {
-	case "true":
-		return true
-	case "1":
-		return true
-	}
-	return false
-}
-
 /******************************************************************************
 *                       Convert from and to bytes                             *
 ******************************************************************************/