Prechádzať zdrojové kódy

Merge pull request #1349 from mrsinham/sasl_client_datarace

[sasl] use a SCRAM client for each connection
Vlad Gorodetsky 6 rokov pred
rodič
commit
2a49b70a5c
5 zmenil súbory, kde vykonal 12 pridanie a 12 odobranie
  1. 1 1
      broker.go
  2. 1 1
      broker_test.go
  3. 4 4
      config.go
  4. 4 4
      config_test.go
  5. 2 2
      examples/sasl_scram_client/main.go

+ 1 - 1
broker.go

@@ -1004,7 +1004,7 @@ func (b *Broker) sendAndReceiveSASLSCRAMv1() error {
 		return err
 	}
 
-	scramClient := b.conf.Net.SASL.SCRAMClient
+	scramClient := b.conf.Net.SASL.SCRAMClientGeneratorFunc()
 	if err := scramClient.Begin(b.conf.Net.SASL.User, b.conf.Net.SASL.Password, b.conf.Net.SASL.SCRAMAuthzID); err != nil {
 		return fmt.Errorf("failed to start SCRAM exchange with the server: %s", err.Error())
 	}

+ 1 - 1
broker_test.go

@@ -340,7 +340,7 @@ func TestSASLSCRAMSHAXXX(t *testing.T) {
 
 		conf := NewConfig()
 		conf.Net.SASL.Mechanism = SASLTypeSCRAMSHA512
-		conf.Net.SASL.SCRAMClient = test.scramClient
+		conf.Net.SASL.SCRAMClientGeneratorFunc = func() SCRAMClient { return test.scramClient }
 
 		broker.conf = conf
 		dialer := net.Dialer{

+ 4 - 4
config.go

@@ -67,9 +67,9 @@ type Config struct {
 			Password string
 			// authz id used for SASL/SCRAM authentication
 			SCRAMAuthzID string
-			// SCRAMClient is a user provided implementation of a SCRAM
+			// SCRAMClientGeneratorFunc is a generator of a user provided implementation of a SCRAM
 			// client used to perform the SCRAM exchange with the server.
-			SCRAMClient SCRAMClient
+			SCRAMClientGeneratorFunc func() SCRAMClient
 			// TokenProvider is a user-defined callback for generating
 			// access tokens for SASL/OAUTHBEARER auth. See the
 			// AccessTokenProvider interface docs for proper implementation
@@ -517,8 +517,8 @@ func (c *Config) Validate() error {
 			if c.Net.SASL.Password == "" {
 				return ConfigurationError("Net.SASL.Password must not be empty when SASL is enabled")
 			}
-			if c.Net.SASL.SCRAMClient == nil {
-				return ConfigurationError("A SCRAMClient instance must be provided to Net.SASL.SCRAMClient")
+			if c.Net.SASL.SCRAMClientGeneratorFunc == nil {
+				return ConfigurationError("A SCRAMClientGeneratorFunc function must be provided to Net.SASL.SCRAMClientGeneratorFunc")
 			}
 		default:
 			msg := fmt.Sprintf("The SASL mechanism configuration is invalid. Possible values are `%s`, `%s`, `%s` and `%s`",

+ 4 - 4
config_test.go

@@ -103,20 +103,20 @@ func TestNetConfigValidates(t *testing.T) {
 			func(cfg *Config) {
 				cfg.Net.SASL.Enable = true
 				cfg.Net.SASL.Mechanism = SASLTypeSCRAMSHA256
-				cfg.Net.SASL.SCRAMClient = nil
+				cfg.Net.SASL.SCRAMClientGeneratorFunc = nil
 				cfg.Net.SASL.User = "user"
 				cfg.Net.SASL.Password = "stong_password"
 			},
-			"A SCRAMClient instance must be provided to Net.SASL.SCRAMClient"},
+			"A SCRAMClientGeneratorFunc function must be provided to Net.SASL.SCRAMClientGeneratorFunc"},
 		{"SASL.Mechanism SCRAM-SHA-512 - Missing SCRAM client",
 			func(cfg *Config) {
 				cfg.Net.SASL.Enable = true
 				cfg.Net.SASL.Mechanism = SASLTypeSCRAMSHA512
-				cfg.Net.SASL.SCRAMClient = nil
+				cfg.Net.SASL.SCRAMClientGeneratorFunc = nil
 				cfg.Net.SASL.User = "user"
 				cfg.Net.SASL.Password = "stong_password"
 			},
-			"A SCRAMClient instance must be provided to Net.SASL.SCRAMClient"},
+			"A SCRAMClientGeneratorFunc function must be provided to Net.SASL.SCRAMClientGeneratorFunc"},
 	}
 
 	for i, test := range tests {

+ 2 - 2
examples/sasl_scram_client/main.go

@@ -86,10 +86,10 @@ func main() {
 	conf.Net.SASL.Password = *passwd
 	conf.Net.SASL.Handshake = true
 	if *algorithm == "sha512" {
-		conf.Net.SASL.SCRAMClient = &XDGSCRAMClient{HashGeneratorFcn: SHA512}
+		conf.Net.SASL.SCRAMClientGeneratorFunc = func() sarama.SCRAMClient { return &XDGSCRAMClient{HashGeneratorFcn: SHA512} }
 		conf.Net.SASL.Mechanism = sarama.SASLMechanism(sarama.SASLTypeSCRAMSHA512)
 	} else if *algorithm == "sha256" {
-		conf.Net.SASL.SCRAMClient = &XDGSCRAMClient{HashGeneratorFcn: SHA256}
+		conf.Net.SASL.SCRAMClientGeneratorFunc = func() sarama.SCRAMClient { return &XDGSCRAMClient{HashGeneratorFcn: SHA256} }
 		conf.Net.SASL.Mechanism = sarama.SASLMechanism(sarama.SASLTypeSCRAMSHA256)
 
 	} else {