소스 검색

continue with next KDC on communication failure (#399)

refactor send to KDC with trying subsequent KDC on failure
Jonathan Turner 5 년 전
부모
커밋
260a581c95
3개의 변경된 파일115개의 추가작업 그리고 103개의 파일을 삭제
  1. 29 0
      v8/client/client_integration_test.go
  2. 76 67
      v8/client/network.go
  3. 10 36
      v8/client/passwd.go

+ 29 - 0
v8/client/client_integration_test.go

@@ -253,6 +253,35 @@ func TestClient_NetworkTimeout(t *testing.T) {
 	}
 }
 
+func TestClient_NetworkTryNextKDC(t *testing.T) {
+	test.Integration(t)
+
+	b, _ := hex.DecodeString(testdata.KEYTAB_TESTUSER1_TEST_GOKRB5)
+	kt := keytab.New()
+	kt.Unmarshal(b)
+	c, _ := config.NewFromString(testdata.KRB5_CONF)
+	addr := os.Getenv("TEST_KDC_ADDR")
+	if addr == "" {
+		addr = testdata.KDC_IP_TEST_GOKRB5
+	}
+	// Two out fo three times this should fail the first time.
+	// So will run login twice to expect at least once the first time it will be to a bad KDC
+	c.Realms[0].KDC = []string{testdata.KDC_IP_TEST_GOKRB5_BADADDR + ":88",
+		testdata.KDC_IP_TEST_GOKRB5_BADADDR + ":88",
+		addr + ":" + testdata.KDC_PORT_TEST_GOKRB5,
+	}
+	cl := client.NewWithKeytab("testuser1", "TEST.GOKRB5", kt, c)
+
+	err := cl.Login()
+	if err != nil {
+		t.Fatal("login failed")
+	}
+	err = cl.Login()
+	if err != nil {
+		t.Fatal("login failed")
+	}
+}
+
 func TestClient_GetServiceTicket(t *testing.T) {
 	test.Integration(t)
 

+ 76 - 67
v8/client/network.go

@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"io"
 	"net"
+	"strings"
 	"time"
 
 	"github.com/jcmturner/gokrb5/v8/iana/errorcode"
@@ -67,88 +68,52 @@ func (cl *Client) sendToKDC(b []byte, realm string) ([]byte, error) {
 	return rb, nil
 }
 
-// dialKDCTCP establishes a UDP connection to a KDC.
-func dialKDCUDP(count int, kdcs map[int]string) (*net.UDPConn, error) {
-	i := 1
-	for i <= count {
-		udpAddr, err := net.ResolveUDPAddr("udp", kdcs[i])
-		if err != nil {
-			return nil, fmt.Errorf("error resolving KDC address: %v", err)
-		}
-
-		conn, err := net.DialTimeout("udp", udpAddr.String(), 5*time.Second)
-		if err == nil {
-			if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
-				return nil, err
-			}
-			// conn is guaranteed to be a UDPConn
-			return conn.(*net.UDPConn), nil
-		}
-		i++
-	}
-	return nil, errors.New("error in getting a UDP connection to any of the KDCs")
-}
-
-// dialKDCTCP establishes a TCP connection to a KDC.
-func dialKDCTCP(count int, kdcs map[int]string) (*net.TCPConn, error) {
-	i := 1
-	for i <= count {
-		tcpAddr, err := net.ResolveTCPAddr("tcp", kdcs[i])
-		if err != nil {
-			return nil, fmt.Errorf("error resolving KDC address: %v", err)
-		}
-
-		conn, err := net.DialTimeout("tcp", tcpAddr.String(), 5*time.Second)
-		if err == nil {
-			if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
-				return nil, err
-			}
-			// conn is guaranteed to be a TCPConn
-			return conn.(*net.TCPConn), nil
-		}
-		i++
-	}
-	return nil, errors.New("error in getting a TCP connection to any of the KDCs")
-}
-
 // sendKDCUDP sends bytes to the KDC via UDP.
 func (cl *Client) sendKDCUDP(realm string, b []byte) ([]byte, error) {
 	var r []byte
-	count, kdcs, err := cl.Config.GetKDCs(realm, false)
-	if err != nil {
-		return r, err
-	}
-	conn, err := dialKDCUDP(count, kdcs)
+	_, kdcs, err := cl.Config.GetKDCs(realm, false)
 	if err != nil {
 		return r, err
 	}
-	r, err = cl.sendUDP(conn, b)
+	r, err = dialSendUDP(kdcs, b)
 	if err != nil {
 		return r, err
 	}
 	return checkForKRBError(r)
 }
 
-// sendKDCTCP sends bytes to the KDC via TCP.
-func (cl *Client) sendKDCTCP(realm string, b []byte) ([]byte, error) {
-	var r []byte
-	count, kdcs, err := cl.Config.GetKDCs(realm, true)
-	if err != nil {
-		return r, err
-	}
-	conn, err := dialKDCTCP(count, kdcs)
-	if err != nil {
-		return r, err
-	}
-	rb, err := cl.sendTCP(conn, b)
-	if err != nil {
-		return r, err
+// dialSendUDP establishes a UDP connection to a KDC.
+func dialSendUDP(kdcs map[int]string, b []byte) ([]byte, error) {
+	var errs []string
+	for i := 1; i <= len(kdcs); i++ {
+		udpAddr, err := net.ResolveUDPAddr("udp", kdcs[i])
+		if err != nil {
+			errs = append(errs, fmt.Sprintf("error resolving KDC address: %v", err))
+			continue
+		}
+
+		conn, err := net.DialTimeout("udp", udpAddr.String(), 5*time.Second)
+		if err != nil {
+			errs = append(errs, fmt.Sprintf("error setting dial timeout on connection to %s: %v", kdcs[i], err))
+			continue
+		}
+		if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
+			errs = append(errs, fmt.Sprintf("error setting deadline on connection to %s: %v", kdcs[i], err))
+			continue
+		}
+		// conn is guaranteed to be a UDPConn
+		rb, err := sendUDP(conn.(*net.UDPConn), b)
+		if err != nil {
+			errs = append(errs, fmt.Sprintf("error sneding to %s: %v", kdcs[i], err))
+			continue
+		}
+		return rb, nil
 	}
-	return checkForKRBError(rb)
+	return nil, fmt.Errorf("error sending to a KDC: %s", strings.Join(errs, "; "))
 }
 
 // sendUDP sends bytes to connection over UDP.
-func (cl *Client) sendUDP(conn *net.UDPConn, b []byte) ([]byte, error) {
+func sendUDP(conn *net.UDPConn, b []byte) ([]byte, error) {
 	var r []byte
 	defer conn.Close()
 	_, err := conn.Write(b)
@@ -167,8 +132,52 @@ func (cl *Client) sendUDP(conn *net.UDPConn, b []byte) ([]byte, error) {
 	return r, nil
 }
 
+// sendKDCTCP sends bytes to the KDC via TCP.
+func (cl *Client) sendKDCTCP(realm string, b []byte) ([]byte, error) {
+	var r []byte
+	_, kdcs, err := cl.Config.GetKDCs(realm, true)
+	if err != nil {
+		return r, err
+	}
+	r, err = dialSendTCP(kdcs, b)
+	if err != nil {
+		return r, err
+	}
+	return checkForKRBError(r)
+}
+
+// dialKDCTCP establishes a TCP connection to a KDC.
+func dialSendTCP(kdcs map[int]string, b []byte) ([]byte, error) {
+	var errs []string
+	for i := 1; i <= len(kdcs); i++ {
+		tcpAddr, err := net.ResolveTCPAddr("tcp", kdcs[i])
+		if err != nil {
+			errs = append(errs, fmt.Sprintf("error resolving KDC address: %v", err))
+			continue
+		}
+
+		conn, err := net.DialTimeout("tcp", tcpAddr.String(), 5*time.Second)
+		if err != nil {
+			errs = append(errs, fmt.Sprintf("error setting dial timeout on connection to %s: %v", kdcs[i], err))
+			continue
+		}
+		if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
+			errs = append(errs, fmt.Sprintf("error setting deadline on connection to %s: %v", kdcs[i], err))
+			continue
+		}
+		// conn is guaranteed to be a TCPConn
+		rb, err := sendTCP(conn.(*net.TCPConn), b)
+		if err != nil {
+			errs = append(errs, fmt.Sprintf("error sneding to %s: %v", kdcs[i], err))
+			continue
+		}
+		return rb, nil
+	}
+	return nil, errors.New("error in getting a TCP connection to any of the KDCs")
+}
+
 // sendTCP sends bytes to connection over TCP.
-func (cl *Client) sendTCP(conn *net.TCPConn, b []byte) ([]byte, error) {
+func sendTCP(conn *net.TCPConn, b []byte) ([]byte, error) {
 	defer conn.Close()
 	var r []byte
 	// RFC 4120 7.2.2 specifies the first 4 bytes indicate the length of the message in big endian order.

+ 10 - 36
v8/client/passwd.go

@@ -2,7 +2,6 @@ package client
 
 import (
 	"fmt"
-	"net"
 
 	"github.com/jcmturner/gokrb5/v8/kadmin"
 	"github.com/jcmturner/gokrb5/v8/messages"
@@ -55,46 +54,21 @@ func (cl *Client) sendToKPasswd(msg kadmin.Request) (r kadmin.Reply, err error)
 	if err != nil {
 		return
 	}
-	addr := kps[1]
 	b, err := msg.Marshal()
 	if err != nil {
 		return
 	}
+	var rb []byte
 	if len(b) <= cl.Config.LibDefaults.UDPPreferenceLimit {
-		return cl.sendKPasswdUDP(b, addr)
-	}
-	return cl.sendKPasswdTCP(b, addr)
-}
-
-func (cl *Client) sendKPasswdTCP(b []byte, kadmindAddr string) (r kadmin.Reply, err error) {
-	tcpAddr, err := net.ResolveTCPAddr("tcp", kadmindAddr)
-	if err != nil {
-		return
-	}
-	conn, err := net.DialTCP("tcp", nil, tcpAddr)
-	if err != nil {
-		return
-	}
-	rb, err := cl.sendTCP(conn, b)
-	if err != nil {
-		return
-	}
-	err = r.Unmarshal(rb)
-	return
-}
-
-func (cl *Client) sendKPasswdUDP(b []byte, kadmindAddr string) (r kadmin.Reply, err error) {
-	udpAddr, err := net.ResolveUDPAddr("udp", kadmindAddr)
-	if err != nil {
-		return
-	}
-	conn, err := net.DialUDP("udp", nil, udpAddr)
-	if err != nil {
-		return
-	}
-	rb, err := cl.sendUDP(conn, b)
-	if err != nil {
-		return
+		rb, err = dialSendUDP(kps, b)
+		if err != nil {
+			return
+		}
+	} else {
+		rb, err = dialSendTCP(kps, b)
+		if err != nil {
+			return
+		}
 	}
 	err = r.Unmarshal(rb)
 	return