Browse Source

Merge pull request #2341 from yichengq/326

migrate/starter: fix flag parsing
Yicheng Qin 10 years ago
parent
commit
4228c703a7
2 changed files with 80 additions and 3 deletions
  1. 10 3
      migrate/starter/starter.go
  2. 70 0
      migrate/starter/starter_test.go

+ 10 - 3
migrate/starter/starter.go

@@ -353,7 +353,8 @@ func newDefaultClient(tls *TLSInfo) (*http.Client, error) {
 }
 
 type value struct {
-	s string
+	isBoolFlag bool
+	s          string
 }
 
 func (v *value) String() string { return v.s }
@@ -363,14 +364,20 @@ func (v *value) Set(s string) error {
 	return nil
 }
 
-func (v *value) IsBoolFlag() bool { return true }
+func (v *value) IsBoolFlag() bool { return v.isBoolFlag }
+
+type boolFlag interface {
+	flag.Value
+	IsBoolFlag() bool
+}
 
 // parseConfig parses out the input config from cmdline arguments and
 // environment variables.
 func parseConfig(args []string) (*flag.FlagSet, error) {
 	fs := flag.NewFlagSet("full flagset", flag.ContinueOnError)
 	etcdmain.NewConfig().VisitAll(func(f *flag.Flag) {
-		fs.Var(&value{}, f.Name, "")
+		_, isBoolFlag := f.Value.(boolFlag)
+		fs.Var(&value{isBoolFlag: isBoolFlag}, f.Name, "")
 	})
 	if err := fs.Parse(args); err != nil {
 		return nil, err

+ 70 - 0
migrate/starter/starter_test.go

@@ -0,0 +1,70 @@
+// Copyright 2015 CoreOS, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package starter
+
+import (
+	"flag"
+	"reflect"
+	"testing"
+)
+
+func TestParseConfig(t *testing.T) {
+	tests := []struct {
+		args  []string
+		wvals map[string]string
+	}{
+		{
+			[]string{"--name", "etcd", "--data-dir", "dir"},
+			map[string]string{
+				"name":     "etcd",
+				"data-dir": "dir",
+			},
+		},
+		{
+			[]string{"--name=etcd", "--data-dir=dir"},
+			map[string]string{
+				"name":     "etcd",
+				"data-dir": "dir",
+			},
+		},
+		{
+			[]string{"--version", "--name", "etcd"},
+			map[string]string{
+				"version": "true",
+				"name":    "etcd",
+			},
+		},
+		{
+			[]string{"--version=true", "--name", "etcd"},
+			map[string]string{
+				"version": "true",
+				"name":    "etcd",
+			},
+		},
+	}
+	for i, tt := range tests {
+		fs, err := parseConfig(tt.args)
+		if err != nil {
+			t.Fatalf("#%d: unexpected parseConfig error: %v", i, err)
+		}
+		vals := make(map[string]string)
+		fs.Visit(func(f *flag.Flag) {
+			vals[f.Name] = f.Value.String()
+		})
+		if !reflect.DeepEqual(vals, tt.wvals) {
+			t.Errorf("#%d: vals = %+v, want %+v", i, vals, tt.wvals)
+		}
+	}
+}