|
@@ -10,6 +10,7 @@ package mysql
|
|
|
|
|
|
|
|
import (
|
|
import (
|
|
|
"bytes"
|
|
"bytes"
|
|
|
|
|
+ "crypto/tls"
|
|
|
"encoding/binary"
|
|
"encoding/binary"
|
|
|
"fmt"
|
|
"fmt"
|
|
|
"testing"
|
|
"testing"
|
|
@@ -74,6 +75,46 @@ func TestDSNParserInvalid(t *testing.T) {
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+func TestDSNWithCustomTLS(t *testing.T) {
|
|
|
|
|
+ baseDSN := "user:password@tcp(localhost:5555)/dbname?tls="
|
|
|
|
|
+ tlsCfg := tls.Config{}
|
|
|
|
|
+
|
|
|
|
|
+ RegisterTLSConfig("utils_test", &tlsCfg)
|
|
|
|
|
+
|
|
|
|
|
+ // Custom TLS is missing
|
|
|
|
|
+ tst := baseDSN + "invalid_tls"
|
|
|
|
|
+ cfg, err := parseDSN(tst)
|
|
|
|
|
+ if err == nil {
|
|
|
|
|
+ t.Errorf("Invalid custom TLS in DSN (%s) but did not error. Got config: %#v", tst, cfg)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ tst = baseDSN + "utils_test"
|
|
|
|
|
+
|
|
|
|
|
+ // Custom TLS with a server name
|
|
|
|
|
+ name := "foohost"
|
|
|
|
|
+ tlsCfg.ServerName = name
|
|
|
|
|
+ cfg, err = parseDSN(tst)
|
|
|
|
|
+
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ t.Error(err.Error())
|
|
|
|
|
+ } else if cfg.tls.ServerName != name {
|
|
|
|
|
+ t.Errorf("Did not get the correct TLS ServerName (%s) parsing DSN (%s).", name, tst)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Custom TLS without a server name
|
|
|
|
|
+ name = "localhost"
|
|
|
|
|
+ tlsCfg.ServerName = ""
|
|
|
|
|
+ cfg, err = parseDSN(tst)
|
|
|
|
|
+
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ t.Error(err.Error())
|
|
|
|
|
+ } else if cfg.tls.ServerName != name {
|
|
|
|
|
+ t.Errorf("Did not get the correct ServerName (%s) parsing DSN (%s).", name, tst)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ DeregisterTLSConfig("utils_test")
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
func BenchmarkParseDSN(b *testing.B) {
|
|
func BenchmarkParseDSN(b *testing.B) {
|
|
|
b.ReportAllocs()
|
|
b.ReportAllocs()
|
|
|
|
|
|