Browse Source

addRWMutex

wenzuochao 6 năm trước cách đây
mục cha
commit
c6223a3fd6

+ 1 - 1
integration/core_test.go

@@ -229,7 +229,7 @@ func Test_CreateInstanceWithCommonRequestWithPolicy(t *testing.T) {
 	request.TransToAcsRequest()
 	_, err = client.ProcessCommonRequest(request)
 	assert.NotNil(t, err)
-	assert.Contains(t, err.Error(), "This resource type is not supported; please try other resource types.")
+	assert.Contains(t, err.Error(), "user order resource type [classic] not exists in [random]")
 
 	policy := `{
     "Version": "1",

+ 15 - 4
sdk/endpoints/mapping_resolver.go

@@ -17,22 +17,33 @@ package endpoints
 import (
 	"fmt"
 	"strings"
+	"sync"
 )
 
 const keyFormatter = "%s::%s"
 
-var endpointMapping = make(map[string]string)
+type EndpointMapping struct {
+	sync.RWMutex
+	endpoint map[string]string
+}
+
+var endpointMapping = EndpointMapping{endpoint: make(map[string]string)}
 
-// AddEndpointMapping Use product id and region id as key to store the endpoint into inner map
+// AddEndpointMapping use productId and regionId as key to store the endpoint into inner map
+// when using the same productId and regionId as key, the endpoint will be covered.
 func AddEndpointMapping(regionId, productId, endpoint string) (err error) {
 	key := fmt.Sprintf(keyFormatter, strings.ToLower(regionId), strings.ToLower(productId))
-	endpointMapping[key] = endpoint
+	endpointMapping.Lock()
+	endpointMapping.endpoint[key] = endpoint
+	endpointMapping.Unlock()
 	return nil
 }
 
 // GetEndpointFromMap use Product and RegionId as key to find endpoint from inner map
 func GetEndpointFromMap(regionId, productId string) string {
 	key := fmt.Sprintf(keyFormatter, strings.ToLower(regionId), strings.ToLower(productId))
-	endpoint, _ := endpointMapping[key]
+	endpointMapping.RLock()
+	endpoint := endpointMapping.endpoint[key]
+	endpointMapping.RUnlock()
 	return endpoint
 }

+ 37 - 0
sdk/endpoints/mapping_resolver_test.go

@@ -1,6 +1,8 @@
 package endpoints
 
 import (
+	"fmt"
+	"sync"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
@@ -18,3 +20,38 @@ func TestMappingResolver_TryResolve(t *testing.T) {
 	endpoint = GetEndpointFromMap(regionId, productId)
 	assert.Equal(t, "unreachable.aliyuncs.com", endpoint)
 }
+
+func Test_MappingResolveConcurrent(t *testing.T) {
+	current := len(endpointMapping.endpoint)
+	cnt := 50
+	var wg sync.WaitGroup
+	for i := 0; i < cnt; i++ {
+		wg.Add(1)
+		go func(k int) {
+			defer wg.Done()
+			endpoint := fmt.Sprintf("ecs#cn-hangzhou%d", k)
+			for j := 0; j < 50; j++ {
+				err := AddEndpointMapping(fmt.Sprintf("cn-hangzhou%d", k), "ecs", endpoint)
+				assert.Nil(t, err)
+				assert.Equal(t, endpoint, GetEndpointFromMap(fmt.Sprintf("cn-hangzhou%d", k), "ecs"))
+			}
+		}(i)
+	}
+	wg.Wait()
+	assert.Equal(t, (current + cnt), len(endpointMapping.endpoint))
+	// hit cache and concurrent get
+	for i := 0; i < cnt; i++ {
+		wg.Add(1)
+		go func(k int) {
+			defer wg.Done()
+			endpoint := fmt.Sprintf("ecs#cn-hangzhou%d", k)
+			for j := 0; j < cnt; j++ {
+				assert.Equal(t, endpoint, GetEndpointFromMap(fmt.Sprintf("cn-hangzhou%d", k), "ecs"))
+				err := AddEndpointMapping(fmt.Sprintf("cn-hangzhou%d", k), "ecs", endpoint)
+				assert.Nil(t, err)
+			}
+		}(i)
+	}
+	wg.Wait()
+	assert.Equal(t, (current + cnt), len(endpointMapping.endpoint))
+}