wenzuochao 6 роки тому
батько
коміт
53f19b3c6b
5 змінених файлів з 180 додано та 21 видалено
  1. 23 0
      integration/core_test.go
  2. 79 1
      sdk/client.go
  3. 31 10
      sdk/client_test.go
  4. 28 5
      sdk/requests/acs_request.go
  5. 19 5
      sdk/requests/acs_request_test.go

+ 23 - 0
integration/core_test.go

@@ -3,6 +3,7 @@ package integration
 import (
 	"os"
 	"testing"
+	"time"
 
 	"github.com/aliyun/alibaba-cloud-sdk-go/sdk"
 
@@ -138,6 +139,28 @@ func Test_DescribeClusterDetailWithCommonRequestWithROAWithHTTPS(t *testing.T) {
 	assert.Contains(t, err.Error(), "Request url is invalid")
 }
 
+func Test_DescribeClusterDetailWithCommonRequestWithTimeout(t *testing.T) {
+	client, err := sdk.NewClientWithAccessKey(os.Getenv("REGION_ID"), os.Getenv("ACCESS_KEY_ID"), os.Getenv("ACCESS_KEY_SECRET"))
+	assert.Nil(t, err)
+	request := requests.NewCommonRequest()
+	request.Domain = "cs.aliyuncs.com"
+	request.Version = "2015-12-15"
+	request.SetScheme("HTTPS")
+	request.PathPattern = "/clusters/[ClusterId]"
+	request.QueryParams["RegionId"] = os.Getenv("REGION_ID")
+	request.ReadTimeout = 1 * time.Millisecond
+	request.ConnectTimeout = 1 * time.Millisecond
+	request.TransToAcsRequest()
+	_, err = client.ProcessCommonRequest(request)
+	assert.NotNil(t, err)
+	assert.Contains(t, err.Error(), "Connect timeout. Please set a valid ConnectTimeout.")
+
+	request.ConnectTimeout = 1 * time.Second
+	_, err = client.ProcessCommonRequest(request)
+	assert.NotNil(t, err)
+	assert.Contains(t, err.Error(), "Read timeout. Please set a valid ReadTimeout.")
+}
+
 func Test_CreateInstanceWithCommonRequestWithPolicy(t *testing.T) {
 	err := createAttachPolicyToRole()
 	assert.Nil(t, err)

+ 79 - 1
sdk/client.go

@@ -15,12 +15,15 @@
 package sdk
 
 import (
+	"context"
 	"fmt"
+	"net"
 	"net/http"
 	"runtime"
 	"strconv"
 	"strings"
 	"sync"
+	"time"
 
 	"github.com/aliyun/alibaba-cloud-sdk-go/sdk/auth"
 	"github.com/aliyun/alibaba-cloud-sdk-go/sdk/auth/credentials"
@@ -39,6 +42,8 @@ func init() {
 
 // Version this value will be replaced while build: -ldflags="-X sdk.version=x.x.x"
 var Version = "0.0.1"
+var defaultConnectTimeout = 10 * time.Second
+var defaultReadTimeout = 5 * time.Second
 
 var DefaultUserAgent = fmt.Sprintf("AlibabaCloud (%s; %s) Golang/%s Core/%s", runtime.GOOS, runtime.GOARCH, strings.Trim(runtime.Version(), "go"), Version)
 
@@ -54,6 +59,8 @@ type Client struct {
 	signer         auth.Signer
 	httpClient     *http.Client
 	asyncTaskQueue chan func()
+	readTimeout    time.Duration
+	connectTimeout time.Duration
 
 	debug     bool
 	isRunning bool
@@ -89,6 +96,22 @@ func (client *Client) InitWithOptions(regionId string, config *Config, credentia
 	return
 }
 
+func (client *Client) SetReadTimeout(readTimeout time.Duration) {
+	client.readTimeout = readTimeout
+}
+
+func (client *Client) SetConnectTimeout(connectTimeout time.Duration) {
+	client.connectTimeout = connectTimeout
+}
+
+func (client *Client) GetReadTimeout() time.Duration {
+	return client.readTimeout
+}
+
+func (client *Client) GetConnectTimeout() time.Duration {
+	return client.connectTimeout
+}
+
 // EnableAsync enable the async task queue
 func (client *Client) EnableAsync(routinePoolSize, maxTaskQueueSize int) {
 	client.asyncTaskQueue = make(chan func(), maxTaskQueueSize)
@@ -272,11 +295,59 @@ func (client *Client) BuildRequestWithSigner(request requests.AcsRequest, signer
 	return
 }
 
+func (client *Client) getTimeout(request requests.AcsRequest) (time.Duration, time.Duration) {
+	readTimeout := defaultReadTimeout
+	connectTimeout := defaultConnectTimeout
+
+	reqReadTimeout := request.GetReadTimeout()
+	reqConnectTimeout := request.GetConnectTimeout()
+	if reqReadTimeout != 0*time.Millisecond {
+		readTimeout = reqReadTimeout
+	} else if client.readTimeout != 0*time.Millisecond {
+		readTimeout = client.readTimeout
+	}
+
+	if reqConnectTimeout != 0*time.Millisecond {
+		connectTimeout = reqConnectTimeout
+	} else if client.connectTimeout != 0*time.Millisecond {
+		connectTimeout = client.connectTimeout
+	}
+	return readTimeout, connectTimeout
+}
+
+func Timeout(connectTimeout, readTimeout time.Duration) func(cxt context.Context, net, addr string) (c net.Conn, err error) {
+	return func(ctx context.Context, network, address string) (net.Conn, error) {
+		conn, err := (&net.Dialer{
+			Timeout:   connectTimeout,
+			KeepAlive: 0 * time.Second,
+			DualStack: true,
+		}).DialContext(ctx, network, address)
+
+		if err == nil {
+			conn.SetDeadline(time.Now().Add(readTimeout))
+		}
+
+		return conn, err
+	}
+}
+func (client *Client) setTimeout(request requests.AcsRequest) {
+	readTimeout, connectTimeout := client.getTimeout(request)
+	if trans, ok := client.httpClient.Transport.(*http.Transport); ok && trans != nil {
+		trans.DialContext = Timeout(connectTimeout, readTimeout)
+		client.httpClient.Transport = trans
+	} else {
+		client.httpClient.Transport = &http.Transport{
+			DialContext: Timeout(connectTimeout, readTimeout),
+		}
+	}
+}
 func (client *Client) DoActionWithSigner(request requests.AcsRequest, response responses.AcsResponse, signer auth.Signer) (err error) {
 	httpRequest, err := client.buildRequestWithSigner(request, signer)
 	if err != nil {
 		return
 	}
+	client.setTimeout(request)
+
 	var httpResponse *http.Response
 	for retryTimes := 0; retryTimes <= client.config.MaxRetryTime; retryTimes++ {
 		debug("> %s %s %s", httpRequest.Method, httpRequest.URL.RequestURI(), httpRequest.Proto)
@@ -299,13 +370,19 @@ func (client *Client) DoActionWithSigner(request requests.AcsRequest, response r
 				return
 			} else if retryTimes >= client.config.MaxRetryTime {
 				// timeout but reached the max retry times, return
-				timeoutErrorMsg := fmt.Sprintf(errors.TimeoutErrorMessage, strconv.Itoa(retryTimes+1), strconv.Itoa(retryTimes+1))
+				var timeoutErrorMsg string
+				if strings.Contains(err.Error(), "read tcp") {
+					timeoutErrorMsg = fmt.Sprintf(errors.TimeoutErrorMessage, strconv.Itoa(retryTimes+1), strconv.Itoa(retryTimes+1)) + " Read timeout. Please set a valid ReadTimeout."
+				} else {
+					timeoutErrorMsg = fmt.Sprintf(errors.TimeoutErrorMessage, strconv.Itoa(retryTimes+1), strconv.Itoa(retryTimes+1)) + " Connect timeout. Please set a valid ConnectTimeout."
+				}
 				err = errors.NewClientError(errors.TimeoutErrorCode, timeoutErrorMsg, err)
 				return
 			}
 		}
 		//  if status code >= 500 or timeout, will trigger retry
 		if client.config.AutoRetry && (err != nil || isServerError(httpResponse)) {
+			client.setTimeout(request)
 			// rewrite signatureNonce and signature
 			httpRequest, err = client.buildRequestWithSigner(request, signer)
 			// buildHttpRequest(request, finalSigner, regionId)
@@ -316,6 +393,7 @@ func (client *Client) DoActionWithSigner(request requests.AcsRequest, response r
 		}
 		break
 	}
+
 	err = responses.Unmarshal(response, httpResponse, request.GetAcceptFormat())
 	// wrap server errors
 	if serverErr, ok := err.(*errors.ServerError); ok {

+ 31 - 10
sdk/client_test.go

@@ -175,17 +175,38 @@ func Test_DoAction_Timeout(t *testing.T) {
 	request.QueryParams["PageSize"] = "30"
 	request.TransToAcsRequest()
 	response := responses.NewCommonResponse()
-	origTestHookDo := hookDo
-	defer func() { hookDo = origTestHookDo }()
-	hookDo = func(fn func(req *http.Request) (*http.Response, error)) func(req *http.Request) (*http.Response, error) {
-		return func(req *http.Request) (*http.Response, error) {
-			return mockResponse(200, "")
-		}
-	}
 	err = client.DoAction(request, response)
-	assert.Nil(t, err)
-	assert.Equal(t, 200, response.GetHttpStatus())
-	assert.Equal(t, "", response.GetHttpContentString())
+	assert.NotNil(t, err)
+	assert.Contains(t, err.Error(), "Specified access key is not found.")
+
+	client.SetReadTimeout(1 * time.Millisecond)
+	assert.Equal(t, 1*time.Millisecond, client.GetReadTimeout())
+	err = client.DoAction(request, response)
+	assert.NotNil(t, err)
+	assert.Contains(t, err.Error(), "Read timeout. Please set a valid ReadTimeout.")
+
+	client.SetConnectTimeout(1 * time.Millisecond)
+	assert.Equal(t, 1*time.Millisecond, client.GetConnectTimeout())
+	err = client.DoAction(request, response)
+	assert.NotNil(t, err)
+	assert.Contains(t, err.Error(), "Connect timeout. Please set a valid ConnectTimeout.")
+
+	client.SetReadTimeout(10 * time.Second)
+	client.SetConnectTimeout(10 * time.Second)
+	err = client.DoAction(request, response)
+	assert.NotNil(t, err)
+	assert.Contains(t, err.Error(), "Specified access key is not found.")
+
+	request.SetReadTimeout(1 * time.Millisecond)
+	err = client.DoAction(request, response)
+	assert.NotNil(t, err)
+	assert.Contains(t, err.Error(), "Read timeout. Please set a valid ReadTimeout.")
+
+	request.SetConnectTimeout(1 * time.Millisecond)
+	err = client.DoAction(request, response)
+	assert.NotNil(t, err)
+	assert.Contains(t, err.Error(), "Connect timeout. Please set a valid ConnectTimeout.")
+
 	client.Shutdown()
 	assert.Equal(t, false, client.isRunning)
 }

+ 28 - 5
sdk/requests/acs_request.go

@@ -20,6 +20,7 @@ import (
 	"reflect"
 	"strconv"
 	"strings"
+	"time"
 
 	"github.com/aliyun/alibaba-cloud-sdk-go/sdk/errors"
 )
@@ -72,6 +73,10 @@ type AcsRequest interface {
 	GetAcceptFormat() string
 	GetLocationServiceCode() string
 	GetLocationEndpointType() string
+	GetReadTimeout() time.Duration
+	GetConnectTimeout() time.Duration
+	SetReadTimeout(readTimeout time.Duration)
+	SetConnectTimeout(connectTimeout time.Duration)
 
 	GetUserAgent() map[string]string
 
@@ -92,11 +97,13 @@ type AcsRequest interface {
 
 // base class
 type baseRequest struct {
-	Scheme   string
-	Method   string
-	Domain   string
-	Port     string
-	RegionId string
+	Scheme         string
+	Method         string
+	Domain         string
+	Port           string
+	RegionId       string
+	ReadTimeout    time.Duration
+	ConnectTimeout time.Duration
 
 	userAgent map[string]string
 	product   string
@@ -127,6 +134,22 @@ func (request *baseRequest) GetFormParams() map[string]string {
 	return request.FormParams
 }
 
+func (request *baseRequest) GetReadTimeout() time.Duration {
+	return request.ReadTimeout
+}
+
+func (request *baseRequest) GetConnectTimeout() time.Duration {
+	return request.ConnectTimeout
+}
+
+func (request *baseRequest) SetReadTimeout(readTimeout time.Duration) {
+	request.ReadTimeout = readTimeout
+}
+
+func (request *baseRequest) SetConnectTimeout(connectTimeout time.Duration) {
+	request.ConnectTimeout = connectTimeout
+}
+
 func (request *baseRequest) GetContent() []byte {
 	return request.Content
 }

+ 19 - 5
sdk/requests/acs_request_test.go

@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"io"
 	"testing"
+	"time"
 
 	"github.com/stretchr/testify/assert"
 )
@@ -85,12 +86,25 @@ func Test_AcsRequest(t *testing.T) {
 	r.SetScheme("HTTPS")
 	assert.Equal(t, "HTTPS", r.GetScheme())
 
+	// GetReadTimeout
+	assert.Equal(t, 0*time.Second, r.GetReadTimeout())
+	r.SetReadTimeout(5 * time.Second)
+	assert.Equal(t, 5*time.Second, r.GetReadTimeout())
+
+	// GetConnectTimeout
+	assert.Equal(t, 0*time.Second, r.GetConnectTimeout())
+	r.SetConnectTimeout(5 * time.Second)
+	assert.Equal(t, 5*time.Second, r.GetConnectTimeout())
+
 	// GetPort
 	assert.Equal(t, "", r.GetPort())
 
 	// GetUserAgent
 	r.AppendUserAgent("cli", "1.01")
 	assert.Equal(t, "1.01", r.GetUserAgent()["cli"])
+	// GetUserAgent
+	r.AppendUserAgent("cli", "2.02")
+	assert.Equal(t, "2.02", r.GetUserAgent()["cli"])
 	// Content
 	assert.Equal(t, []byte(nil), r.GetContent())
 	r.SetContent([]byte("The Content"))
@@ -100,11 +114,11 @@ func Test_AcsRequest(t *testing.T) {
 type AcsRequestTest struct {
 	*baseRequest
 	Ontology AcsRequest
-	Query    string      `position:"Query" name:"Query"`
-	Header   string      `position:"Header" name:"Header"`
-	Path     string      `position:"Path" name:"Path"`
-	Body     string      `position:"Body" name:"Body"`
-	TypeAcs  *[]string   `position:"type" name:"type" type:"Repeated"`
+	Query    string    `position:"Query" name:"Query"`
+	Header   string    `position:"Header" name:"Header"`
+	Path     string    `position:"Path" name:"Path"`
+	Body     string    `position:"Body" name:"Body"`
+	TypeAcs  *[]string `position:"type" name:"type" type:"Repeated"`
 }
 
 func (r AcsRequestTest) BuildQueries() string {