Browse Source

Merge pull request #1608 from jonboulle/flags

pkg: move to more generic StringsFlag
Jonathan Boulle 11 years ago
parent
commit
ab00d23cd3
4 changed files with 53 additions and 96 deletions
  1. 24 10
      etcdmain/etcd.go
  2. 0 55
      pkg/flags/proxy.go
  3. 19 24
      pkg/flags/strings.go
  4. 10 7
      pkg/flags/strings_test.go

+ 24 - 10
etcdmain/etcd.go

@@ -40,6 +40,13 @@ import (
 const (
 	// the owner can make/remove files inside the directory
 	privateDirMode = 0700
+
+	proxyFlagOff      = "off"
+	proxyFlagReadonly = "readonly"
+	proxyFlagOn       = "on"
+
+	fallbackFlagExit  = "exit"
+	fallbackFlagProxy = "proxy"
 )
 
 var (
@@ -47,7 +54,6 @@ var (
 	name         = fs.String("name", "default", "Unique human-readable name for this node")
 	dir          = fs.String("data-dir", "", "Path to the data directory")
 	durl         = fs.String("discovery", "", "Discovery service used to bootstrap the cluster")
-	dfallback    = new(flags.Fallback)
 	snapCount    = fs.Uint64("snapshot-count", etcdserver.DefaultSnapCount, "Number of committed transactions to trigger a snapshot")
 	printVersion = fs.Bool("version", false, "Print the version and exit")
 
@@ -56,7 +62,15 @@ var (
 	clusterState        = new(etcdserver.ClusterState)
 
 	corsInfo  = &cors.CORSInfo{}
-	proxyFlag = new(flags.Proxy)
+	proxyFlag = flags.NewStringsFlag(
+		proxyFlagOff,
+		proxyFlagReadonly,
+		proxyFlagOn,
+	)
+	fallbackFlag = flags.NewStringsFlag(
+		fallbackFlagExit,
+		fallbackFlagProxy,
+	)
 
 	clientTLSInfo = transport.TLSInfo{}
 	peerTLSInfo   = transport.TLSInfo{}
@@ -92,13 +106,13 @@ func init() {
 
 	fs.Var(corsInfo, "cors", "Comma-separated white list of origins for CORS (cross-origin resource sharing).")
 
-	fs.Var(proxyFlag, "proxy", fmt.Sprintf("Valid values include %s", strings.Join(flags.ProxyValues, ", ")))
-	if err := proxyFlag.Set(flags.ProxyValueOff); err != nil {
+	fs.Var(proxyFlag, "proxy", fmt.Sprintf("Valid values include %s", strings.Join(proxyFlag.Values, ", ")))
+	if err := proxyFlag.Set(proxyFlagOff); err != nil {
 		// Should never happen.
 		log.Panicf("unexpected error setting up proxyFlag: %v", err)
 	}
-	fs.Var(dfallback, "discovery-fallback", fmt.Sprintf("Valid values include %s", strings.Join(flags.FallbackValues, ", ")))
-	if err := dfallback.Set(flags.FallbackProxy); err != nil {
+	fs.Var(fallbackFlag, "discovery-fallback", fmt.Sprintf("Valid values include %s", strings.Join(fallbackFlag.Values, ", ")))
+	if err := fallbackFlag.Set(fallbackFlagProxy); err != nil {
 		// Should never happen.
 		log.Panicf("unexpected error setting up discovery-fallback flag: %v", err)
 	}
@@ -143,13 +157,13 @@ func Main() {
 
 	flags.SetFlagsFromEnv(fs)
 
-	if string(*proxyFlag) == flags.ProxyValueOff {
+	if proxyFlag.String() == proxyFlagOff {
 		if err := startEtcd(); err == nil {
 			// Block indefinitely
 			<-make(chan struct{})
 		} else {
-			if err == discovery.ErrFullCluster && *dfallback == flags.FallbackProxy {
-				fmt.Printf("etcd: dicovery cluster full, falling back to %s", flags.FallbackProxy)
+			if err == discovery.ErrFullCluster && fallbackFlag.String() == fallbackFlagProxy {
+				fmt.Printf("etcd: discovery cluster full, falling back to %s", fallbackFlagProxy)
 			} else {
 				log.Fatalf("etcd: %v", err)
 			}
@@ -318,7 +332,7 @@ func startProxy() error {
 		Info:    corsInfo,
 	}
 
-	if string(*proxyFlag) == flags.ProxyValueReadonly {
+	if proxyFlag.String() == proxyFlagReadonly {
 		ph = proxy.NewReadonlyHandler(ph)
 	}
 	lcurls, err := flags.URLsFromFlags(fs, "listen-client-urls", "bind-addr", clientTLSInfo)

+ 0 - 55
pkg/flags/proxy.go

@@ -1,55 +0,0 @@
-/*
-   Copyright 2014 CoreOS, Inc.
-
-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
-
-       http://www.apache.org/licenses/LICENSE-2.0
-
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License.
-*/
-
-package flags
-
-import (
-	"errors"
-)
-
-const (
-	ProxyValueOff      = "off"
-	ProxyValueReadonly = "readonly"
-	ProxyValueOn       = "on"
-)
-
-var (
-	ProxyValues = []string{
-		ProxyValueOff,
-		ProxyValueReadonly,
-		ProxyValueOn,
-	}
-)
-
-// ProxyFlag implements the flag.Value interface.
-type Proxy string
-
-// Set verifies the argument to be a valid member of proxyFlagValues
-// before setting the underlying flag value.
-func (pf *Proxy) Set(s string) error {
-	for _, v := range ProxyValues {
-		if s == v {
-			*pf = Proxy(s)
-			return nil
-		}
-	}
-
-	return errors.New("invalid value")
-}
-
-func (pf *Proxy) String() string {
-	return string(*pf)
-}

+ 19 - 24
pkg/flags/fallback.go → pkg/flags/strings.go

@@ -16,38 +16,33 @@
 
 package flags
 
-import (
-	"errors"
-)
-
-const (
-	FallbackExit  = "exit"
-	FallbackProxy = "proxy"
-)
-
-var (
-	FallbackValues = []string{
-		FallbackExit,
-		FallbackProxy,
-	}
-)
+import "errors"
 
-// FallbackFlag implements the flag.Value interface.
-type Fallback string
+// NewStringsFlag creates a new string flag for which any one of the given
+// strings is a valid value, and any other value is an error.
+func NewStringsFlag(valids ...string) *StringsFlag {
+	return &StringsFlag{Values: valids}
+}
 
-// Set verifies the argument to be a valid member of FallbackFlagValues
+// StringsFlag implements the flag.Value interface.
+type StringsFlag struct {
+	Values []string
+	val    string
+}
+
+// Set verifies the argument to be a valid member of the allowed values
 // before setting the underlying flag value.
-func (fb *Fallback) Set(s string) error {
-	for _, v := range FallbackValues {
+func (ss *StringsFlag) Set(s string) error {
+	for _, v := range ss.Values {
 		if s == v {
-			*fb = Fallback(s)
+			ss.val = s
 			return nil
 		}
 	}
-
 	return errors.New("invalid value")
 }
 
-func (fb *Fallback) String() string {
-	return string(*fb)
+// String returns the set value (if any) of the StringsFlag
+func (ss *StringsFlag) String() string {
+	return ss.val
 }

+ 10 - 7
pkg/flags/proxy_test.go → pkg/flags/strings_test.go

@@ -20,23 +20,26 @@ import (
 	"testing"
 )
 
-func TestProxySet(t *testing.T) {
+func TestStringsSet(t *testing.T) {
 	tests := []struct {
+		vals []string
+
 		val  string
 		pass bool
 	}{
 		// known values
-		{"on", true},
-		{"off", true},
+		{[]string{"abc", "def"}, "abc", true},
+		{[]string{"on", "off", "false"}, "on", true},
 
 		// unrecognized values
-		{"foo", false},
-		{"", false},
+		{[]string{"abc", "def"}, "ghi", false},
+		{[]string{"on", "off"}, "", false},
+		{[]string{}, "asdf", false},
 	}
 
 	for i, tt := range tests {
-		pf := new(Proxy)
-		err := pf.Set(tt.val)
+		sf := NewStringsFlag(tt.vals...)
+		err := sf.Set(tt.val)
 		if tt.pass != (err == nil) {
 			t.Errorf("#%d: want pass=%t, but got err=%v", i, tt.pass, err)
 		}