Ver código fonte

fix concurrent map write bug

dehui.kdh 7 anos atrás
pai
commit
bd4a026ad7

+ 10 - 0
sdk/endpoints/location_resolver.go

@@ -22,14 +22,17 @@ import (
 )
 
 const (
+	// EndpointCacheExpireTime ...
 	EndpointCacheExpireTime = 3600 //Seconds
 )
 
+// Cache caches endpoint for specific product and region
 type Cache struct {
 	sync.RWMutex
 	cache map[string]interface{}
 }
 
+// Get ...
 func (c *Cache) Get(k string) (v interface{}) {
 	c.RLock()
 	v = c.cache[k]
@@ -37,6 +40,7 @@ func (c *Cache) Get(k string) (v interface{}) {
 	return
 }
 
+// Set ...
 func (c *Cache) Set(k string, v interface{}) {
 	c.Lock()
 	c.cache[k] = v
@@ -46,6 +50,7 @@ func (c *Cache) Set(k string, v interface{}) {
 var lastClearTimePerProduct = &Cache{cache: make(map[string]interface{})}
 var endpointCache = &Cache{cache: make(map[string]interface{})}
 
+// LocationResolver ...
 type LocationResolver struct {
 }
 
@@ -54,6 +59,7 @@ func (resolver *LocationResolver) GetName() (name string) {
 	return
 }
 
+// TryResolve resolves endpoint giving product and region
 func (resolver *LocationResolver) TryResolve(param *ResolveParam) (endpoint string, support bool, err error) {
 	if len(param.LocationProduct) <= 0 {
 		support = false
@@ -126,6 +132,7 @@ func (resolver *LocationResolver) TryResolve(param *ResolveParam) (endpoint stri
 	return
 }
 
+// CheckCacheIsExpire ...
 func CheckCacheIsExpire(cacheKey string) bool {
 	lastClearTime, ok := lastClearTimePerProduct.Get(cacheKey).(int64)
 	if !ok {
@@ -146,16 +153,19 @@ func CheckCacheIsExpire(cacheKey string) bool {
 	return false
 }
 
+// GetEndpointResponse ...
 type GetEndpointResponse struct {
 	Endpoints *EndpointsObj
 	RequestId string
 	Success   bool
 }
 
+// EndpointsObj ...
 type EndpointsObj struct {
 	Endpoint []EndpointObj
 }
 
+// EndpointObj ...
 type EndpointObj struct {
 	// Protocols   map[string]string
 	Type        string

+ 99 - 65
sdk/endpoints/location_resolver_test.go

@@ -2,16 +2,17 @@ package endpoints
 
 import (
 	"bytes"
+	"fmt"
 	"io/ioutil"
 	"net/http"
 	"strconv"
+	"sync"
 	"testing"
 
-	"github.com/stretchr/testify/assert"
-
 	"github.com/aliyun/alibaba-cloud-sdk-go/sdk/errors"
 	"github.com/aliyun/alibaba-cloud-sdk-go/sdk/requests"
 	"github.com/aliyun/alibaba-cloud-sdk-go/sdk/responses"
+	"github.com/stretchr/testify/assert"
 )
 
 func TestLocationResolver_GetName(t *testing.T) {
@@ -19,6 +20,7 @@ func TestLocationResolver_GetName(t *testing.T) {
 	assert.Equal(t, "location resolver", resolver.GetName())
 }
 
+// cases from later commit
 func TestLocationResolver_TryResolve_EmptyLocationProduct(t *testing.T) {
 	resolver := &LocationResolver{}
 	resolveParam := &ResolveParam{}
@@ -27,7 +29,6 @@ func TestLocationResolver_TryResolve_EmptyLocationProduct(t *testing.T) {
 	assert.Equal(t, "", endpoint)
 	assert.Equal(t, false, support)
 }
-
 func TestLocationResolver_TryResolve_LocationWithError(t *testing.T) {
 	resolver := &LocationResolver{}
 	resolveParam := &ResolveParam{
@@ -60,43 +61,65 @@ func makeHTTPResponse(statusCode int, content string) (res *http.Response) {
 	return
 }
 
-func TestLocationResolver_TryResolve_LocationWith404(t *testing.T) {
+func TestLocationResolver_TryResolve_Location_With_Endpoint2(t *testing.T) {
 	resolver := &LocationResolver{}
 	resolveParam := &ResolveParam{
-		LocationProduct: "ecs",
+		LocationProduct: "ecs3",
 		RegionId:        "cn-hangzhou",
-		Product:         "ecs",
+		Product:         "ecs3",
 		CommonApi: func(request *requests.CommonRequest) (response *responses.CommonResponse, err error) {
 			response = responses.NewCommonResponse()
-			responses.Unmarshal(response, makeHTTPResponse(404, "content"), "JSON")
+			responses.Unmarshal(response, makeHTTPResponse(200, `{
+  "Endpoints":{
+    "Endpoint":[
+      {
+        "Protocols":{
+          "Protocols":["HTTP","HTTPS"]
+        },
+        "Type":"openAPI",
+        "Namespace":"26842",
+        "Id":"cn-beijing",
+        "SerivceCode":"ecs",
+        "Endpoint":"ecs-cn-hangzhou.aliyuncs.com"
+      }
+    ]
+  },
+  "RequestId":"B3B36D8E-5029-42E3-B1FB-9B687F7591DA",
+  "Success":true
+}`), "JSON")
 			return
 		},
 	}
 	endpoint, support, err := resolver.TryResolve(resolveParam)
-	assert.Equal(t, "", endpoint)
-	assert.Equal(t, false, support)
+	assert.Equal(t, "ecs-cn-hangzhou.aliyuncs.com", endpoint)
+	assert.Equal(t, true, support)
+	assert.Nil(t, err)
+	// hit the cache
+	endpoint, support, err = resolver.TryResolve(resolveParam)
+	assert.Equal(t, "ecs-cn-hangzhou.aliyuncs.com", endpoint)
+	assert.Equal(t, true, support)
 	assert.Nil(t, err)
 }
 
-func TestLocationResolver_TryResolve_LocationWith200InvalidJSON(t *testing.T) {
+func TestLocationResolver_TryResolve_Location_With_EmptyEndpoint(t *testing.T) {
 	resolver := &LocationResolver{}
 	resolveParam := &ResolveParam{
-		LocationProduct: "ecs",
+		LocationProduct: "ecs2",
 		RegionId:        "cn-hangzhou",
-		Product:         "ecs",
+		Product:         "ecs2",
 		CommonApi: func(request *requests.CommonRequest) (response *responses.CommonResponse, err error) {
 			response = responses.NewCommonResponse()
-			responses.Unmarshal(response, makeHTTPResponse(200, "content"), "JSON")
+			responses.Unmarshal(response, makeHTTPResponse(200, `{"Success":true,"RequestId":"request id","Endpoints":{"Endpoint":[{"Endpoint":""}]}}`), "JSON")
 			return
 		},
 	}
 	endpoint, support, err := resolver.TryResolve(resolveParam)
 	assert.Equal(t, "", endpoint)
 	assert.Equal(t, false, support)
-	assert.Equal(t, "invalid character 'c' looking for beginning of value", err.Error())
+	assert.Nil(t, err)
 }
 
-func TestLocationResolver_TryResolve_LocationWith200ValidJSON(t *testing.T) {
+func TestLocationResolver_TryResolve_LocationWith404(t *testing.T) {
 	resolver := &LocationResolver{}
 	resolveParam := &ResolveParam{
 		LocationProduct: "ecs",
@@ -104,7 +127,7 @@ func TestLocationResolver_TryResolve_LocationWith200ValidJSON(t *testing.T) {
 		Product:         "ecs",
 		CommonApi: func(request *requests.CommonRequest) (response *responses.CommonResponse, err error) {
 			response = responses.NewCommonResponse()
-			responses.Unmarshal(response, makeHTTPResponse(200, `{"Code":"Success","RequestId":"request id"}`), "JSON")
+			responses.Unmarshal(response, makeHTTPResponse(404, "content"), "JSON")
 			return
 		},
 	}
@@ -112,10 +135,9 @@ func TestLocationResolver_TryResolve_LocationWith200ValidJSON(t *testing.T) {
 	assert.Equal(t, "", endpoint)
 	assert.Equal(t, false, support)
 	assert.Nil(t, err)
-	// assert.Equal(t, "json: cannot unmarshal array into Go struct field GetEndpointResponse.Endpoints of type endpoints.EndpointsObj", err.Error())
 }
 
-func TestLocationResolver_TryResolve_LocationWith200(t *testing.T) {
+func TestLocationResolver_TryResolve_LocationWith200InvalidJSON(t *testing.T) {
 	resolver := &LocationResolver{}
 	resolveParam := &ResolveParam{
 		LocationProduct: "ecs",
@@ -123,17 +145,17 @@ func TestLocationResolver_TryResolve_LocationWith200(t *testing.T) {
 		Product:         "ecs",
 		CommonApi: func(request *requests.CommonRequest) (response *responses.CommonResponse, err error) {
 			response = responses.NewCommonResponse()
-			responses.Unmarshal(response, makeHTTPResponse(200, `{"Success":true,"RequestId":"request id","Endpoints":{"Endpoint":[]}}`), "JSON")
+			responses.Unmarshal(response, makeHTTPResponse(200, "content"), "JSON")
 			return
 		},
 	}
 	endpoint, support, err := resolver.TryResolve(resolveParam)
 	assert.Equal(t, "", endpoint)
 	assert.Equal(t, false, support)
-	assert.Nil(t, err)
+	assert.Equal(t, "invalid character 'c' looking for beginning of value", err.Error())
 }
 
-func TestLocationResolver_TryResolve_Location_With_Endpoint(t *testing.T) {
+func TestLocationResolver_TryResolve_LocationWith200ValidJSON(t *testing.T) {
 	resolver := &LocationResolver{}
 	resolveParam := &ResolveParam{
 		LocationProduct: "ecs",
@@ -141,75 +163,87 @@ func TestLocationResolver_TryResolve_Location_With_Endpoint(t *testing.T) {
 		Product:         "ecs",
 		CommonApi: func(request *requests.CommonRequest) (response *responses.CommonResponse, err error) {
 			response = responses.NewCommonResponse()
-			responses.Unmarshal(response, makeHTTPResponse(200, `{"Success":true,"RequestId":"request id","Endpoints":{"Endpoint":[{"Endpoint":"domain.com"}]}}`), "JSON")
+			responses.Unmarshal(response, makeHTTPResponse(200, `{"Code":"Success","RequestId":"request id"}`), "JSON")
 			return
 		},
 	}
 	endpoint, support, err := resolver.TryResolve(resolveParam)
-	assert.Equal(t, "domain.com", endpoint)
-	assert.Equal(t, true, support)
-	assert.Nil(t, err)
-	// hit the cache
-	endpoint, support, err = resolver.TryResolve(resolveParam)
-	assert.Equal(t, "domain.com", endpoint)
-	assert.Equal(t, true, support)
+	assert.Equal(t, "", endpoint)
+	assert.Equal(t, false, support)
 	assert.Nil(t, err)
+	// assert.Equal(t, "json: cannot unmarshal array into Go struct field GetEndpointResponse.Endpoints of type endpoints.EndpointsObj", err.Error())
 }
 
-func TestLocationResolver_TryResolve_Location_With_Endpoint2(t *testing.T) {
+func TestLocationResolver_TryResolve_LocationWith200(t *testing.T) {
 	resolver := &LocationResolver{}
 	resolveParam := &ResolveParam{
-		LocationProduct: "ecs3",
+		LocationProduct: "ecs",
 		RegionId:        "cn-hangzhou",
-		Product:         "ecs3",
+		Product:         "ecs",
 		CommonApi: func(request *requests.CommonRequest) (response *responses.CommonResponse, err error) {
 			response = responses.NewCommonResponse()
-			responses.Unmarshal(response, makeHTTPResponse(200, `{
-  "Endpoints":{
-    "Endpoint":[
-      {
-        "Protocols":{
-          "Protocols":["HTTP","HTTPS"]
-        },
-        "Type":"openAPI",
-        "Namespace":"26842",
-        "Id":"cn-beijing",
-        "SerivceCode":"ecs",
-        "Endpoint":"ecs-cn-hangzhou.aliyuncs.com"
-      }
-    ]
-  },
-  "RequestId":"B3B36D8E-5029-42E3-B1FB-9B687F7591DA",
-  "Success":true
-}`), "JSON")
+			responses.Unmarshal(response, makeHTTPResponse(200, `{"Success":true,"RequestId":"request id","Endpoints":{"Endpoint":[]}}`), "JSON")
 			return
 		},
 	}
 	endpoint, support, err := resolver.TryResolve(resolveParam)
-	assert.Equal(t, "ecs-cn-hangzhou.aliyuncs.com", endpoint)
-	assert.Equal(t, true, support)
-	assert.Nil(t, err)
-	// hit the cache
-	endpoint, support, err = resolver.TryResolve(resolveParam)
-	assert.Equal(t, "ecs-cn-hangzhou.aliyuncs.com", endpoint)
-	assert.Equal(t, true, support)
+	assert.Equal(t, "", endpoint)
+	assert.Equal(t, false, support)
 	assert.Nil(t, err)
 }
 
-func TestLocationResolver_TryResolve_Location_With_EmptyEndpoint(t *testing.T) {
+func resovleSucc(i int) (ep string, isSupport bool, err error) {
 	resolver := &LocationResolver{}
 	resolveParam := &ResolveParam{
-		LocationProduct: "ecs2",
-		RegionId:        "cn-hangzhou",
-		Product:         "ecs2",
+		LocationProduct: "ecs",
+		RegionId:        fmt.Sprintf("cn-hangzhou%d", i),
+		Product:         "ecs",
 		CommonApi: func(request *requests.CommonRequest) (response *responses.CommonResponse, err error) {
 			response = responses.NewCommonResponse()
-			responses.Unmarshal(response, makeHTTPResponse(200, `{"Success":true,"RequestId":"request id","Endpoints":{"Endpoint":[{"Endpoint":""}]}}`), "JSON")
+			responses.Unmarshal(response, makeHTTPResponse(200, `{"Success":true,"RequestId":"request id","Endpoints":{"Endpoint":[{"Endpoint":"domain.com"}]}}`), "JSON")
 			return
 		},
 	}
 	endpoint, support, err := resolver.TryResolve(resolveParam)
-	assert.Equal(t, "", endpoint)
-	assert.Equal(t, false, support)
-	assert.Nil(t, err)
+	return endpoint, support, err
+}
+
+// concurrent cases
+func TestResolveConcurrent(t *testing.T) {
+	current := len(endpointCache.cache)
+	cnt := 50
+	var wg sync.WaitGroup
+	for i := 0; i < cnt; i++ {
+		wg.Add(1)
+		go func(k int) {
+			defer wg.Done()
+			cachedKey := fmt.Sprintf("ecs#cn-hangzhou%d", k)
+			for j := 0; j < 50; j++ {
+				endpoint, support, err := resovleSucc(k)
+				assert.Equal(t, "domain.com", endpointCache.Get(cachedKey))
+				assert.Equal(t, "domain.com", endpoint)
+				assert.Equal(t, true, support)
+				assert.Nil(t, err)
+			}
+		}(i)
+	}
+	wg.Wait()
+	assert.Equal(t, (current + cnt), len(endpointCache.cache))
+	// hit cache and concurrent get
+	for i := 0; i < cnt; i++ {
+		wg.Add(1)
+		go func(k int) {
+			defer wg.Done()
+			cachedKey := fmt.Sprintf("ecs#cn-hangzhou%d", k)
+			for j := 0; j < cnt; j++ {
+				assert.Equal(t, "domain.com", endpointCache.Get(cachedKey))
+				endpoint, support, err := resovleSucc(k)
+				assert.Equal(t, "domain.com", endpoint)
+				assert.Equal(t, true, support)
+				assert.Nil(t, err)
+			}
+		}(i)
+		wg.Wait()
+	}
+	assert.Equal(t, (current + cnt), len(endpointCache.cache))
 }