Browse Source

Postgres dialect parse password with spaces

xormplus 8 years ago
parent
commit
e0e2d418e1
2 changed files with 62 additions and 46 deletions
  1. 18 46
      dialect_postgres.go
  2. 44 0
      dialect_postgres_test.go

+ 18 - 46
dialect_postgres.go

@@ -8,7 +8,6 @@ import (
 	"errors"
 	"fmt"
 	"net/url"
-	"sort"
 	"strconv"
 	"strings"
 
@@ -1117,10 +1116,6 @@ func (vs values) Get(k string) (v string) {
 	return vs[k]
 }
 
-func errorf(s string, args ...interface{}) {
-	panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...)))
-}
-
 func parseURL(connstr string) (string, error) {
 	u, err := url.Parse(connstr)
 	if err != nil {
@@ -1131,46 +1126,18 @@ func parseURL(connstr string) (string, error) {
 		return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme)
 	}
 
-	var kvs []string
 	escaper := strings.NewReplacer(` `, `\ `, `'`, `\'`, `\`, `\\`)
-	accrue := func(k, v string) {
-		if v != "" {
-			kvs = append(kvs, k+"="+escaper.Replace(v))
-		}
-	}
-
-	if u.User != nil {
-		v := u.User.Username()
-		accrue("user", v)
-
-		v, _ = u.User.Password()
-		accrue("password", v)
-	}
-
-	i := strings.Index(u.Host, ":")
-	if i < 0 {
-		accrue("host", u.Host)
-	} else {
-		accrue("host", u.Host[:i])
-		accrue("port", u.Host[i+1:])
-	}
 
 	if u.Path != "" {
-		accrue("dbname", u.Path[1:])
+		return escaper.Replace(u.Path[1:]), nil
 	}
 
-	q := u.Query()
-	for k := range q {
-		accrue(k, q.Get(k))
-	}
-
-	sort.Strings(kvs) // Makes testing easier (not a performance concern)
-	return strings.Join(kvs, " "), nil
+	return "", nil
 }
 
-func parseOpts(name string, o values) {
+func parseOpts(name string, o values) error {
 	if len(name) == 0 {
-		return
+		return fmt.Errorf("invalid options: %s", name)
 	}
 
 	name = strings.TrimSpace(name)
@@ -1179,31 +1146,36 @@ func parseOpts(name string, o values) {
 	for _, p := range ps {
 		kv := strings.Split(p, "=")
 		if len(kv) < 2 {
-			errorf("invalid option: %q", p)
+			return fmt.Errorf("invalid option: %q", p)
 		}
 		o.Set(kv[0], kv[1])
 	}
+
+	return nil
 }
 
 func (p *pqDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) {
 	db := &core.Uri{DbType: core.POSTGRES}
-	o := make(values)
 	var err error
+
 	if strings.HasPrefix(dataSourceName, "postgresql://") || strings.HasPrefix(dataSourceName, "postgres://") {
-		dataSourceName, err = parseURL(dataSourceName)
+		db.DbName, err = parseURL(dataSourceName)
+		if err != nil {
+			return nil, err
+		}
+	} else {
+		o := make(values)
+		err = parseOpts(dataSourceName, o)
 		if err != nil {
 			return nil, err
 		}
+
+		db.DbName = o.Get("dbname")
 	}
-	parseOpts(dataSourceName, o)
 
-	db.DbName = o.Get("dbname")
 	if db.DbName == "" {
 		return nil, errors.New("dbname is empty")
 	}
-	/*db.Schema = o.Get("schema")
-	if len(db.Schema) == 0 {
-		db.Schema = "public"
-	}*/
+
 	return db, nil
 }

+ 44 - 0
dialect_postgres_test.go

@@ -0,0 +1,44 @@
+package xorm
+
+import (
+	"reflect"
+	"testing"
+
+	"github.com/xormplus/core"
+)
+
+func TestPostgresDialect(t *testing.T) {
+	TestParse(t)
+}
+
+func TestParse(t *testing.T) {
+	tests := []struct {
+		in       string
+		expected string
+		valid    bool
+	}{
+		{"postgres://auser:password@localhost:5432/db?sslmode=disable", "db", true},
+		{"postgresql://auser:password@localhost:5432/db?sslmode=disable", "db", true},
+		{"postg://auser:password@localhost:5432/db?sslmode=disable", "db", false},
+		{"postgres://auser:pass with space@localhost:5432/db?sslmode=disable", "db", true},
+		{"postgres:// auser : password@localhost:5432/db?sslmode=disable", "db", true},
+		{"postgres://%20auser%20:pass%20with%20space@localhost:5432/db?sslmode=disable", "db", true},
+		{"postgres://auser:パスワード@localhost:5432/データベース?sslmode=disable", "データベース", true},
+		{"dbname=db sslmode=disable", "db", true},
+		{"user=auser password=password dbname=db sslmode=disable", "db", true},
+		{"", "db", false},
+		{"dbname=db =disable", "db", false},
+	}
+
+	driver := core.QueryDriver("postgres")
+
+	for _, test := range tests {
+		uri, err := driver.Parse("postgres", test.in)
+
+		if err != nil && test.valid {
+			t.Errorf("%q got unexpected error: %s", test.in, err)
+		} else if err == nil && !reflect.DeepEqual(test.expected, uri.DbName) {
+			t.Errorf("%q got: %#v want: %#v", test.in, uri.DbName, test.expected)
+		}
+	}
+}