Browse Source

main: add address validation for bind-addr flag

Jonathan Boulle 11 years ago
parent
commit
f0789e7349
2 changed files with 62 additions and 4 deletions
  1. 25 4
      main.go
  2. 37 0
      main_test.go

+ 25 - 4
main.go

@@ -5,6 +5,7 @@ import (
 	"flag"
 	"fmt"
 	"log"
+	"net"
 	"net/http"
 	"os"
 	"path"
@@ -200,12 +201,15 @@ func startProxy() {
 type Addrs []string
 
 // Set parses a command line set of listen addresses, formatted like:
-// 127.0.0.1:7001,unix:///var/run/etcd.sock,10.1.1.1:8080
+// 127.0.0.1:7001,10.1.1.2:80
 func (as *Addrs) Set(s string) error {
-	// TODO(jonboulle): validate things.
 	parsed := make([]string, 0)
-	for _, a := range strings.Split(s, ",") {
-		parsed = append(parsed, strings.TrimSpace(a))
+	for _, in := range strings.Split(s, ",") {
+		a := strings.TrimSpace(in)
+		if err := validateAddr(a); err != nil {
+			return err
+		}
+		parsed = append(parsed, a)
 	}
 	if len(parsed) == 0 {
 		return errors.New("no valid addresses given!")
@@ -218,6 +222,23 @@ func (as *Addrs) String() string {
 	return strings.Join(*as, ",")
 }
 
+// validateAddr ensures that the provided string is a valid address. Valid
+// addresses are of the form IP:port.
+// Returns an error if the address is invalid, else nil.
+func validateAddr(s string) error {
+	parts := strings.SplitN(s, ":", 2)
+	if len(parts) != 2 {
+		return errors.New("bad format in address specification")
+	}
+	if net.ParseIP(parts[0]) == nil {
+		return errors.New("bad IP in address specification")
+	}
+	if _, err := strconv.Atoi(parts[1]); err != nil {
+		return errors.New("bad port in address specification")
+	}
+	return nil
+}
+
 // ProxyFlag implements the flag.Value interface.
 type ProxyFlag string
 

+ 37 - 0
main_test.go

@@ -64,3 +64,40 @@ func TestProxyFlagSet(t *testing.T) {
 		}
 	}
 }
+
+func TestBadValidateAddr(t *testing.T) {
+	tests := []string{
+		// bad IP specification
+		":4001",
+		"127.0:8080",
+		"123:456",
+		// bad port specification
+		"127.0.0.1:foo",
+		"127.0.0.1:",
+		// unix sockets not supported
+		"unix://",
+		"unix://tmp/etcd.sock",
+		// bad strings
+		"somewhere",
+		"234#$",
+		"file://foo/bar",
+		"http://hello",
+	}
+	for i, in := range tests {
+		if err := validateAddr(in); err == nil {
+			t.Errorf(`#%d: unexpected nil error for in=%q`, i, in)
+		}
+	}
+}
+
+func TestValidateAddr(t *testing.T) {
+	tests := []string{
+		"1.2.3.4:8080",
+		"10.1.1.1:80",
+	}
+	for i, in := range tests {
+		if err := validateAddr(in); err != nil {
+			t.Errorf("#%d: err=%v, want nil for in=%q", i, err, in)
+		}
+	}
+}