|
@@ -41,7 +41,7 @@ func init() {
|
|
|
tlsConfigRegister = make(map[string]*tls.Config)
|
|
tlsConfigRegister = make(map[string]*tls.Config)
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-// Registers a custom tls.Config to be used with sql.Open.
|
|
|
|
|
|
|
+// RegisterTLSConfig registers a custom tls.Config to be used with sql.Open.
|
|
|
// Use the key as a value in the DSN where tls=value.
|
|
// Use the key as a value in the DSN where tls=value.
|
|
|
//
|
|
//
|
|
|
// rootCertPool := x509.NewCertPool()
|
|
// rootCertPool := x509.NewCertPool()
|
|
@@ -64,11 +64,16 @@ func init() {
|
|
|
// })
|
|
// })
|
|
|
// db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom")
|
|
// db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom")
|
|
|
//
|
|
//
|
|
|
-func RegisterTLSConfig(key string, config *tls.Config) {
|
|
|
|
|
|
|
+func RegisterTLSConfig(key string, config *tls.Config) error {
|
|
|
|
|
+ if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" {
|
|
|
|
|
+ return fmt.Errorf("Key '%s' is reserved", key)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
tlsConfigRegister[key] = config
|
|
tlsConfigRegister[key] = config
|
|
|
|
|
+ return nil
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-// Removes tls.Config associated with key.
|
|
|
|
|
|
|
+// DeregisterTLSConfig removes the tls.Config associated with key.
|
|
|
func DeregisterTLSConfig(key string) {
|
|
func DeregisterTLSConfig(key string) {
|
|
|
delete(tlsConfigRegister, key)
|
|
delete(tlsConfigRegister, key)
|
|
|
}
|
|
}
|
|
@@ -104,11 +109,21 @@ func parseDSN(dsn string) (cfg *config, err error) {
|
|
|
|
|
|
|
|
// Disable INFILE whitelist / enable all files
|
|
// Disable INFILE whitelist / enable all files
|
|
|
case "allowAllFiles":
|
|
case "allowAllFiles":
|
|
|
- cfg.allowAllFiles = readBool(value)
|
|
|
|
|
|
|
+ var isBool bool
|
|
|
|
|
+ cfg.allowAllFiles, isBool = readBool(value)
|
|
|
|
|
+ if !isBool {
|
|
|
|
|
+ err = fmt.Errorf("Invalid Bool value: %s", value)
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
// Switch "rowsAffected" mode
|
|
// Switch "rowsAffected" mode
|
|
|
case "clientFoundRows":
|
|
case "clientFoundRows":
|
|
|
- cfg.clientFoundRows = readBool(value)
|
|
|
|
|
|
|
+ var isBool bool
|
|
|
|
|
+ cfg.clientFoundRows, isBool = readBool(value)
|
|
|
|
|
+ if !isBool {
|
|
|
|
|
+ err = fmt.Errorf("Invalid Bool value: %s", value)
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
// Time Location
|
|
// Time Location
|
|
|
case "loc":
|
|
case "loc":
|
|
@@ -126,13 +141,20 @@ func parseDSN(dsn string) (cfg *config, err error) {
|
|
|
|
|
|
|
|
// TLS-Encryption
|
|
// TLS-Encryption
|
|
|
case "tls":
|
|
case "tls":
|
|
|
- if readBool(value) {
|
|
|
|
|
- cfg.tls = &tls.Config{}
|
|
|
|
|
- } else if strings.ToLower(value) == "skip-verify" {
|
|
|
|
|
- cfg.tls = &tls.Config{InsecureSkipVerify: true}
|
|
|
|
|
- // TODO: Check for Boolean false
|
|
|
|
|
- } else if tlsConfig, ok := tlsConfigRegister[value]; ok {
|
|
|
|
|
- cfg.tls = tlsConfig
|
|
|
|
|
|
|
+ boolValue, isBool := readBool(value)
|
|
|
|
|
+ if isBool {
|
|
|
|
|
+ if boolValue {
|
|
|
|
|
+ cfg.tls = &tls.Config{}
|
|
|
|
|
+ }
|
|
|
|
|
+ } else {
|
|
|
|
|
+ if strings.ToLower(value) == "skip-verify" {
|
|
|
|
|
+ cfg.tls = &tls.Config{InsecureSkipVerify: true}
|
|
|
|
|
+ } else if tlsConfig, ok := tlsConfigRegister[value]; ok {
|
|
|
|
|
+ cfg.tls = tlsConfig
|
|
|
|
|
+ } else {
|
|
|
|
|
+ err = fmt.Errorf("Invalid value / unknown config name: %s", value)
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
default:
|
|
default:
|
|
@@ -191,14 +213,18 @@ func scramblePassword(scramble, password []byte) []byte {
|
|
|
return scramble
|
|
return scramble
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-func readBool(value string) bool {
|
|
|
|
|
- switch strings.ToLower(value) {
|
|
|
|
|
- case "true":
|
|
|
|
|
- return true
|
|
|
|
|
- case "1":
|
|
|
|
|
- return true
|
|
|
|
|
|
|
+// Returns the bool value of the input.
|
|
|
|
|
+// The 2nd return value indicates if the input was a valid bool value
|
|
|
|
|
+func readBool(input string) (value bool, valid bool) {
|
|
|
|
|
+ switch input {
|
|
|
|
|
+ case "1", "true", "TRUE", "True":
|
|
|
|
|
+ return true, true
|
|
|
|
|
+ case "0", "false", "FALSE", "False":
|
|
|
|
|
+ return false, true
|
|
|
}
|
|
}
|
|
|
- return false
|
|
|
|
|
|
|
+
|
|
|
|
|
+ // Not a valid bool value
|
|
|
|
|
+ return
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
/******************************************************************************
|
|
/******************************************************************************
|