Browse Source

Export function to initialize Config with defaults (#679)

* dsn: export NewConfig

* dsn: move DSN normalization to separate func

* dsn: add godoc

* dsn: add missing return
Julien Schmidt 8 years ago
parent
commit
ee359f9587
2 changed files with 44 additions and 30 deletions
  1. 43 30
      dsn.go
  2. 1 0
      dsn_test.go

+ 43 - 30
dsn.go

@@ -28,7 +28,9 @@ var (
 	errInvalidDSNUnsafeCollation = errors.New("invalid DSN: interpolateParams can not be used with unsafe collations")
 )
 
-// Config is a configuration parsed from a DSN string
+// Config is a configuration parsed from a DSN string.
+// If a new Config is created instead of being parsed from a DSN string,
+// the NewConfig function should be used, which sets default values.
 type Config struct {
 	User             string            // Username
 	Passwd           string            // Password (requires User)
@@ -57,6 +59,43 @@ type Config struct {
 	RejectReadOnly          bool // Reject read-only connections
 }
 
+// NewConfig creates a new Config and sets default values.
+func NewConfig() *Config {
+	return &Config{
+		Collation:            defaultCollation,
+		Loc:                  time.UTC,
+		AllowNativePasswords: true,
+	}
+}
+
+func (cfg *Config) normalize() error {
+	if cfg.InterpolateParams && unsafeCollations[cfg.Collation] {
+		return errInvalidDSNUnsafeCollation
+	}
+
+	// Set default network if empty
+	if cfg.Net == "" {
+		cfg.Net = "tcp"
+	}
+
+	// Set default address if empty
+	if cfg.Addr == "" {
+		switch cfg.Net {
+		case "tcp":
+			cfg.Addr = "127.0.0.1:3306"
+		case "unix":
+			cfg.Addr = "/tmp/mysql.sock"
+		default:
+			return errors.New("default addr for network '" + cfg.Net + "' unknown")
+		}
+
+	} else if cfg.Net == "tcp" {
+		cfg.Addr = ensureHavePort(cfg.Addr)
+	}
+
+	return nil
+}
+
 // FormatDSN formats the given Config into a DSN string which can be passed to
 // the driver.
 func (cfg *Config) FormatDSN() string {
@@ -273,11 +312,7 @@ func (cfg *Config) FormatDSN() string {
 // ParseDSN parses the DSN string to a Config
 func ParseDSN(dsn string) (cfg *Config, err error) {
 	// New config with some default values
-	cfg = &Config{
-		Loc:                  time.UTC,
-		Collation:            defaultCollation,
-		AllowNativePasswords: true,
-	}
+	cfg = NewConfig()
 
 	// [user[:password]@][net[(addr)]]/dbname[?param1=value1&paramN=valueN]
 	// Find the last '/' (since the password or the net addr might contain a '/')
@@ -345,31 +380,9 @@ func ParseDSN(dsn string) (cfg *Config, err error) {
 		return nil, errInvalidDSNNoSlash
 	}
 
-	if cfg.InterpolateParams && unsafeCollations[cfg.Collation] {
-		return nil, errInvalidDSNUnsafeCollation
-	}
-
-	// Set default network if empty
-	if cfg.Net == "" {
-		cfg.Net = "tcp"
+	if err = cfg.normalize(); err != nil {
+		return nil, err
 	}
-
-	// Set default address if empty
-	if cfg.Addr == "" {
-		switch cfg.Net {
-		case "tcp":
-			cfg.Addr = "127.0.0.1:3306"
-		case "unix":
-			cfg.Addr = "/tmp/mysql.sock"
-		default:
-			return nil, errors.New("default addr for network '" + cfg.Net + "' unknown")
-		}
-
-	}
-	if cfg.Net == "tcp" {
-		cfg.Addr = ensureHavePort(cfg.Addr)
-	}
-
 	return
 }
 

+ 1 - 0
dsn_test.go

@@ -98,6 +98,7 @@ func TestDSNParserInvalid(t *testing.T) {
 		"(/",                          // no closing brace
 		"net(addr)//",                 // unescaped
 		"User:pass@tcp(1.2.3.4:3306)", // no trailing slash
+		"net()/",                      // unknown default addr
 		//"/dbname?arg=/some/unescaped/path",
 	}