Browse Source

pkg: add URLsFromFlags

Brian Waldon 11 years ago
parent
commit
11582b0f5f
2 changed files with 126 additions and 0 deletions
  1. 39 0
      pkg/flag.go
  2. 87 0
      pkg/flag_test.go

+ 39 - 0
pkg/flag.go

@@ -4,8 +4,12 @@ import (
 	"flag"
 	"flag"
 	"fmt"
 	"fmt"
 	"log"
 	"log"
+	"net/url"
 	"os"
 	"os"
 	"strings"
 	"strings"
+
+	"github.com/coreos/etcd/pkg/flags"
+	"github.com/coreos/etcd/pkg/transport"
 )
 )
 
 
 type DeprecatedFlag struct {
 type DeprecatedFlag struct {
@@ -64,3 +68,38 @@ func SetFlagsFromEnv(fs *flag.FlagSet) {
 		}
 		}
 	})
 	})
 }
 }
+
+// URLsFromFlags decides what URLs should be using two different flags
+// as datasources. The first flag's Value must be of type URLs, while
+// the second must be of type IPAddressPort. If both of these flags
+// are set, an error will be returned. If only the first flag is set,
+// the underlying url.URL objects will be returned unmodified. If the
+// second flag happens to be set, the underlying IPAddressPort will be
+// converted to a url.URL and returned. The Scheme of the returned
+// url.URL will be http unless the provided TLSInfo object is non-empty.
+// If neither of the flags have been explicitly set, the default value
+// of the first flag will be returned unmodified.
+func URLsFromFlags(fs *flag.FlagSet, urlsFlagName string, addrFlagName string, tlsInfo transport.TLSInfo) ([]url.URL, error) {
+	visited := make(map[string]struct{})
+	fs.Visit(func(f *flag.Flag) {
+		visited[f.Name] = struct{}{}
+	})
+
+	_, urlsFlagIsSet := visited[urlsFlagName]
+	_, addrFlagIsSet := visited[addrFlagName]
+
+	if addrFlagIsSet {
+		if urlsFlagIsSet {
+			return nil, fmt.Errorf("Set only one of flags -%s and -%s", urlsFlagName, addrFlagName)
+		}
+
+		addr := *fs.Lookup(addrFlagName).Value.(*flags.IPAddressPort)
+		addrURL := url.URL{Scheme: "http", Host: addr.String()}
+		if !tlsInfo.Empty() {
+			addrURL.Scheme = "https"
+		}
+		return []url.URL{addrURL}, nil
+	}
+
+	return []url.URL(*fs.Lookup(urlsFlagName).Value.(*flags.URLs)), nil
+}

+ 87 - 0
pkg/flag_test.go

@@ -2,8 +2,13 @@ package pkg
 
 
 import (
 import (
 	"flag"
 	"flag"
+	"net/url"
 	"os"
 	"os"
+	"reflect"
 	"testing"
 	"testing"
+
+	"github.com/coreos/etcd/pkg/flags"
+	"github.com/coreos/etcd/pkg/transport"
 )
 )
 
 
 func TestSetFlagsFromEnv(t *testing.T) {
 func TestSetFlagsFromEnv(t *testing.T) {
@@ -49,3 +54,85 @@ func TestSetFlagsFromEnv(t *testing.T) {
 		}
 		}
 	}
 	}
 }
 }
+
+func TestURLsFromFlags(t *testing.T) {
+	tests := []struct {
+		args     []string
+		tlsInfo  transport.TLSInfo
+		wantURLs []url.URL
+		wantFail bool
+	}{
+		// use -urls default when no flags defined
+		{
+			args:    []string{},
+			tlsInfo: transport.TLSInfo{},
+			wantURLs: []url.URL{
+				url.URL{Scheme: "http", Host: "127.0.0.1:2379"},
+			},
+			wantFail: false,
+		},
+
+		// explicitly setting -urls should carry through
+		{
+			args:    []string{"-urls=https://192.0.3.17:2930,http://127.0.0.1:1024"},
+			tlsInfo: transport.TLSInfo{},
+			wantURLs: []url.URL{
+				url.URL{Scheme: "https", Host: "192.0.3.17:2930"},
+				url.URL{Scheme: "http", Host: "127.0.0.1:1024"},
+			},
+			wantFail: false,
+		},
+
+		// explicitly setting -addr should carry through
+		{
+			args:    []string{"-addr=192.0.2.3:1024"},
+			tlsInfo: transport.TLSInfo{},
+			wantURLs: []url.URL{
+				url.URL{Scheme: "http", Host: "192.0.2.3:1024"},
+			},
+			wantFail: false,
+		},
+
+		// scheme prepended to -addr should be https if TLSInfo non-empty
+		{
+			args: []string{"-addr=192.0.2.3:1024"},
+			tlsInfo: transport.TLSInfo{
+				CertFile: "/tmp/foo",
+				KeyFile:  "/tmp/bar",
+			},
+			wantURLs: []url.URL{
+				url.URL{Scheme: "https", Host: "192.0.2.3:1024"},
+			},
+			wantFail: false,
+		},
+
+		// explicitly setting both -urls and -addr should fail
+		{
+			args:     []string{"-urls=https://127.0.0.1:1024", "-addr=192.0.2.3:1024"},
+			tlsInfo:  transport.TLSInfo{},
+			wantURLs: nil,
+			wantFail: true,
+		},
+	}
+
+	for i, tt := range tests {
+		fs := flag.NewFlagSet("test", flag.PanicOnError)
+		fs.Var(flags.NewURLs("http://127.0.0.1:2379"), "urls", "")
+		fs.Var(&flags.IPAddressPort{}, "addr", "")
+
+		if err := fs.Parse(tt.args); err != nil {
+			t.Errorf("#%d: failed to parse flags: %v", i, err)
+			continue
+		}
+
+		gotURLs, err := URLsFromFlags(fs, "urls", "addr", tt.tlsInfo)
+		if tt.wantFail != (err != nil) {
+			t.Errorf("#%d: wantFail=%t, got err=%v", i, tt.wantFail, err)
+			continue
+		}
+
+		if !reflect.DeepEqual(tt.wantURLs, gotURLs) {
+			t.Errorf("#%d: incorrect URLs\nwant=%#v\ngot=%#v", i, tt.wantURLs, gotURLs)
+		}
+	}
+}