Bladeren bron

strict tls.Config key check

Julien Schmidt 12 jaren geleden
bovenliggende
commit
c24056d3d5
3 gewijzigde bestanden met toevoegingen van 56 en 22 verwijderingen
  1. 10 2
      connection.go
  2. 1 1
      driver_test.go
  3. 45 19
      utils.go

+ 10 - 2
connection.go

@@ -68,11 +68,19 @@ func (mc *mysqlConn) handleParams() (err error) {
 
 		// time.Time parsing
 		case "parseTime":
-			mc.parseTime = readBool(val)
+			var isBool bool
+			mc.parseTime, isBool = readBool(val)
+			if !isBool {
+				return errors.New("Invalid Bool value: " + val)
+			}
 
 		// Strict mode
 		case "strict":
-			mc.strict = readBool(val)
+			var isBool bool
+			mc.strict, isBool = readBool(val)
+			if !isBool {
+				return errors.New("Invalid Bool value: " + val)
+			}
 
 		// Compression
 		case "compress":

+ 1 - 1
driver_test.go

@@ -1053,7 +1053,7 @@ func TestStmtMultiRows(t *testing.T) {
 }
 
 func TestConcurrent(t *testing.T) {
-	if readBool(os.Getenv("MYSQL_TEST_CONCURRENT")) != true {
+	if enabled, _ := readBool(os.Getenv("MYSQL_TEST_CONCURRENT")); !enabled {
 		t.Skip("CONCURRENT env var not set")
 	}
 

+ 45 - 19
utils.go

@@ -41,7 +41,7 @@ func init() {
 	tlsConfigRegister = make(map[string]*tls.Config)
 }
 
-// Registers a custom tls.Config to be used with sql.Open.
+// RegisterTLSConfig registers a custom tls.Config to be used with sql.Open.
 // Use the key as a value in the DSN where tls=value.
 //
 //  rootCertPool := x509.NewCertPool()
@@ -64,11 +64,16 @@ func init() {
 //  })
 //  db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom")
 //
-func RegisterTLSConfig(key string, config *tls.Config) {
+func RegisterTLSConfig(key string, config *tls.Config) error {
+	if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" {
+		return fmt.Errorf("Key '%s' is reserved", key)
+	}
+
 	tlsConfigRegister[key] = config
+	return nil
 }
 
-// Removes tls.Config associated with key.
+// DeregisterTLSConfig removes the tls.Config associated with key.
 func DeregisterTLSConfig(key string) {
 	delete(tlsConfigRegister, key)
 }
@@ -104,11 +109,21 @@ func parseDSN(dsn string) (cfg *config, err error) {
 
 				// Disable INFILE whitelist / enable all files
 				case "allowAllFiles":
-					cfg.allowAllFiles = readBool(value)
+					var isBool bool
+					cfg.allowAllFiles, isBool = readBool(value)
+					if !isBool {
+						err = fmt.Errorf("Invalid Bool value: %s", value)
+						return
+					}
 
 				// Switch "rowsAffected" mode
 				case "clientFoundRows":
-					cfg.clientFoundRows = readBool(value)
+					var isBool bool
+					cfg.clientFoundRows, isBool = readBool(value)
+					if !isBool {
+						err = fmt.Errorf("Invalid Bool value: %s", value)
+						return
+					}
 
 				// Time Location
 				case "loc":
@@ -126,13 +141,20 @@ func parseDSN(dsn string) (cfg *config, err error) {
 
 				// TLS-Encryption
 				case "tls":
-					if readBool(value) {
-						cfg.tls = &tls.Config{}
-					} else if strings.ToLower(value) == "skip-verify" {
-						cfg.tls = &tls.Config{InsecureSkipVerify: true}
-						// TODO: Check for Boolean false
-					} else if tlsConfig, ok := tlsConfigRegister[value]; ok {
-						cfg.tls = tlsConfig
+					boolValue, isBool := readBool(value)
+					if isBool {
+						if boolValue {
+							cfg.tls = &tls.Config{}
+						}
+					} else {
+						if strings.ToLower(value) == "skip-verify" {
+							cfg.tls = &tls.Config{InsecureSkipVerify: true}
+						} else if tlsConfig, ok := tlsConfigRegister[value]; ok {
+							cfg.tls = tlsConfig
+						} else {
+							err = fmt.Errorf("Invalid value / unknown config name: %s", value)
+							return
+						}
 					}
 
 				default:
@@ -191,14 +213,18 @@ func scramblePassword(scramble, password []byte) []byte {
 	return scramble
 }
 
-func readBool(value string) bool {
-	switch strings.ToLower(value) {
-	case "true":
-		return true
-	case "1":
-		return true
+// Returns the bool value of the input.
+// The 2nd return value indicates if the input was a valid bool value
+func readBool(input string) (value bool, valid bool) {
+	switch input {
+	case "1", "true", "TRUE", "True":
+		return true, true
+	case "0", "false", "FALSE", "False":
+		return false, true
 	}
-	return false
+
+	// Not a valid bool value
+	return
 }
 
 /******************************************************************************