Jelajahi Sumber

cleanup config param handling

Julien Schmidt 12 tahun lalu
induk
melakukan
b0d08caea2
5 mengubah file dengan 30 tambahan dan 24 penghapusan
  1. 11 13
      connection.go
  2. 1 1
      infile.go
  3. 1 1
      packets.go
  4. 8 0
      utils.go
  5. 9 9
      utils_test.go

+ 11 - 13
connection.go

@@ -35,15 +35,17 @@ type mysqlConn struct {
 }
 
 type config struct {
-	user    string
-	passwd  string
-	net     string
-	addr    string
-	dbname  string
-	params  map[string]string
-	loc     *time.Location
-	timeout time.Duration
-	tls     *tls.Config
+	user            string
+	passwd          string
+	net             string
+	addr            string
+	dbname          string
+	params          map[string]string
+	loc             *time.Location
+	timeout         time.Duration
+	tls             *tls.Config
+	allowAllFiles   bool
+	clientFoundRows bool
 }
 
 // Handles parameters set in DSN
@@ -64,10 +66,6 @@ func (mc *mysqlConn) handleParams() (err error) {
 				return
 			}
 
-		// handled elsewhere
-		case "allowAllFiles", "clientFoundRows":
-			continue
-
 		// time.Time parsing
 		case "parseTime":
 			mc.parseTime = readBool(val)

+ 1 - 1
infile.go

@@ -74,7 +74,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
 		}
 	} else { // File
 		name = strings.Trim(name, `"`)
-		if fileRegister[name] || mc.cfg.params[`allowAllFiles`] == `true` {
+		if mc.cfg.allowAllFiles || fileRegister[name] {
 			rdr, err = os.Open(name)
 		} else {
 			err = fmt.Errorf("Local File '%s' is not registered. Use the DSN parameter 'allowAllFiles=true' to allow all files", name)

+ 1 - 1
packets.go

@@ -215,7 +215,7 @@ func (mc *mysqlConn) writeAuthPacket() error {
 		clientLocalFiles |
 		mc.flags&clientLongFlag
 
-	if _, ok := mc.cfg.params["clientFoundRows"]; ok {
+	if mc.cfg.clientFoundRows {
 		clientFlags |= clientFoundRows
 	}
 

+ 8 - 0
utils.go

@@ -124,6 +124,14 @@ func parseDSN(dsn string) (cfg *config, err error) {
 				// cfg params
 				switch value := param[1]; param[0] {
 
+				// Disable INFILE whitelist / enable all files
+				case "allowAllFiles":
+					cfg.allowAllFiles = readBool(value)
+
+				// Switch "rowsAffected" mode
+				case "clientFoundRows":
+					cfg.clientFoundRows = readBool(value)
+
 				// Time Location
 				case "loc":
 					cfg.loc, err = time.LoadLocation(value)

+ 9 - 9
utils_test.go

@@ -21,15 +21,15 @@ func TestDSNParser(t *testing.T) {
 		out string
 		loc *time.Location
 	}{
-		{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p timeout:0 tls:<nil>}", time.UTC},
-		{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil>}", time.UTC},
-		{"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil>}", time.UTC},
-		{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p timeout:0 tls:<nil>}", time.UTC},
-		{"user:password@/dbname?loc=UTC&timeout=30s", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:30000000000 tls:<nil>}", time.UTC},
-		{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil>}", time.Local},
-		{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil>}", time.UTC},
-		{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil>}", time.UTC},
-		{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil>}", time.UTC},
+		{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p timeout:0 tls:<nil> allowAllFiles:false clientFoundRows:false}", time.UTC},
+		{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false clientFoundRows:false}", time.UTC},
+		{"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false clientFoundRows:false}", time.UTC},
+		{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false clientFoundRows:false}", time.UTC},
+		{"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:30000000000 tls:<nil> allowAllFiles:true clientFoundRows:true}", time.UTC},
+		{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false clientFoundRows:false}", time.Local},
+		{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false clientFoundRows:false}", time.UTC},
+		{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false clientFoundRows:false}", time.UTC},
+		{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false clientFoundRows:false}", time.UTC},
 	}
 
 	var cfg *config