Sfoglia il codice sorgente

Add maxAllowedPacket DSN Parameter

Allows to set the driver-side max_packet_allowed value manually

* add docs

* add author, add test

* fix AUTHOR file and README

* change the param name from MaxPacketAllowed to MaxAllowedPacket

* fix maxAllowedPacket param name in README

* rename all maxPacketAllowed to maxAllowedPacket

* fix doc format

* fix doc
Xie Zhenye 9 anni fa
parent
commit
b81e73ccad
9 ha cambiato i file con 66 aggiunte e 35 eliminazioni
  1. 1 0
      AUTHORS
  2. 9 0
      README.md
  3. 1 1
      benchmark_test.go
  4. 2 2
      connection.go
  5. 3 3
      connection_test.go
  6. 13 9
      driver.go
  7. 31 14
      dsn.go
  8. 2 2
      dsn_test.go
  9. 4 4
      packets.go

+ 1 - 0
AUTHORS

@@ -45,6 +45,7 @@ Stan Putrya <root.vagner at gmail.com>
 Stanley Gunawan <gunawan.stanley at gmail.com>
 Xiaobing Jiang <s7v7nislands at gmail.com>
 Xiuming Chen <cc at cxm.cc>
+Zhenye Xie <xiezhenye at gmail.com>
 
 # Organizations
 

+ 9 - 0
README.md

@@ -299,6 +299,15 @@ Default:        0
 I/O write timeout. The value must be a decimal number with an unit suffix ( *"ms"*, *"s"*, *"m"*, *"h"* ), such as *"30s"*, *"0.5m"* or *"1m30s"*.
 
 
+##### `maxAllowedPacket`
+```
+Type:          decimal number
+Default:       0
+```
+
+Max packet size allowed in bytes. Use `maxAllowedPacket=0` to automatically fetch the `max_allowed_packet` variable from server.
+
+
 ##### System Variables
 
 All other parameters are interpreted as system variables:

+ 1 - 1
benchmark_test.go

@@ -220,7 +220,7 @@ func BenchmarkInterpolation(b *testing.B) {
 			InterpolateParams: true,
 			Loc:               time.UTC,
 		},
-		maxPacketAllowed: maxPacketSize,
+		maxAllowedPacket: maxPacketSize,
 		maxWriteSize:     maxPacketSize - 1,
 		buf:              newBuffer(nil),
 	}

+ 2 - 2
connection.go

@@ -22,7 +22,7 @@ type mysqlConn struct {
 	affectedRows     uint64
 	insertId         uint64
 	cfg              *Config
-	maxPacketAllowed int
+	maxAllowedPacket int
 	maxWriteSize     int
 	writeTimeout     time.Duration
 	flags            clientFlag
@@ -246,7 +246,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
 			return "", driver.ErrSkip
 		}
 
-		if len(buf)+4 > mc.maxPacketAllowed {
+		if len(buf)+4 > mc.maxAllowedPacket {
 			return "", driver.ErrSkip
 		}
 	}

+ 3 - 3
connection_test.go

@@ -16,7 +16,7 @@ import (
 func TestInterpolateParams(t *testing.T) {
 	mc := &mysqlConn{
 		buf:              newBuffer(nil),
-		maxPacketAllowed: maxPacketSize,
+		maxAllowedPacket: maxPacketSize,
 		cfg: &Config{
 			InterpolateParams: true,
 		},
@@ -36,7 +36,7 @@ func TestInterpolateParams(t *testing.T) {
 func TestInterpolateParamsTooManyPlaceholders(t *testing.T) {
 	mc := &mysqlConn{
 		buf:              newBuffer(nil),
-		maxPacketAllowed: maxPacketSize,
+		maxAllowedPacket: maxPacketSize,
 		cfg: &Config{
 			InterpolateParams: true,
 		},
@@ -53,7 +53,7 @@ func TestInterpolateParamsTooManyPlaceholders(t *testing.T) {
 func TestInterpolateParamsPlaceholderInString(t *testing.T) {
 	mc := &mysqlConn{
 		buf:              newBuffer(nil),
-		maxPacketAllowed: maxPacketSize,
+		maxAllowedPacket: maxPacketSize,
 		cfg: &Config{
 			InterpolateParams: true,
 		},

+ 13 - 9
driver.go

@@ -50,7 +50,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
 
 	// New mysqlConn
 	mc := &mysqlConn{
-		maxPacketAllowed: maxPacketSize,
+		maxAllowedPacket: maxPacketSize,
 		maxWriteSize:     maxPacketSize - 1,
 	}
 	mc.cfg, err = ParseDSN(dsn)
@@ -109,15 +109,19 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
 		return nil, err
 	}
 
-	// Get max allowed packet size
-	maxap, err := mc.getSystemVar("max_allowed_packet")
-	if err != nil {
-		mc.Close()
-		return nil, err
+	if mc.cfg.MaxAllowedPacket > 0 {
+		mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket
+	} else {
+		// Get max allowed packet size
+		maxap, err := mc.getSystemVar("max_allowed_packet")
+		if err != nil {
+			mc.Close()
+			return nil, err
+		}
+		mc.maxAllowedPacket = stringToInt(maxap) - 1
 	}
-	mc.maxPacketAllowed = stringToInt(maxap) - 1
-	if mc.maxPacketAllowed < maxPacketSize {
-		mc.maxWriteSize = mc.maxPacketAllowed
+	if mc.maxAllowedPacket < maxPacketSize {
+		mc.maxWriteSize = mc.maxAllowedPacket
 	}
 
 	// Handle DSN Params

+ 31 - 14
dsn.go

@@ -15,6 +15,7 @@ import (
 	"fmt"
 	"net"
 	"net/url"
+	"strconv"
 	"strings"
 	"time"
 )
@@ -28,19 +29,20 @@ var (
 
 // Config is a configuration parsed from a DSN string
 type Config struct {
-	User         string            // Username
-	Passwd       string            // Password (requires User)
-	Net          string            // Network type
-	Addr         string            // Network address (requires Net)
-	DBName       string            // Database name
-	Params       map[string]string // Connection parameters
-	Collation    string            // Connection collation
-	Loc          *time.Location    // Location for time.Time values
-	TLSConfig    string            // TLS configuration name
-	tls          *tls.Config       // TLS configuration
-	Timeout      time.Duration     // Dial timeout
-	ReadTimeout  time.Duration     // I/O read timeout
-	WriteTimeout time.Duration     // I/O write timeout
+	User             string            // Username
+	Passwd           string            // Password (requires User)
+	Net              string            // Network type
+	Addr             string            // Network address (requires Net)
+	DBName           string            // Database name
+	Params           map[string]string // Connection parameters
+	Collation        string            // Connection collation
+	Loc              *time.Location    // Location for time.Time values
+	MaxAllowedPacket int               // Max packet size allowed
+	TLSConfig        string            // TLS configuration name
+	tls              *tls.Config       // TLS configuration
+	Timeout          time.Duration     // Dial timeout
+	ReadTimeout      time.Duration     // I/O read timeout
+	WriteTimeout     time.Duration     // I/O write timeout
 
 	AllowAllFiles           bool // Allow all files to be used with LOAD DATA LOCAL INFILE
 	AllowCleartextPasswords bool // Allows the cleartext client side plugin
@@ -222,6 +224,17 @@ func (cfg *Config) FormatDSN() string {
 		buf.WriteString(cfg.WriteTimeout.String())
 	}
 
+	if cfg.MaxAllowedPacket > 0 {
+		if hasParam {
+			buf.WriteString("&maxAllowedPacket=")
+		} else {
+			hasParam = true
+			buf.WriteString("?maxAllowedPacket=")
+		}
+		buf.WriteString(strconv.Itoa(cfg.MaxAllowedPacket))
+
+	}
+
 	// other params
 	if cfg.Params != nil {
 		for param, value := range cfg.Params {
@@ -496,7 +509,11 @@ func parseDSNParams(cfg *Config, params string) (err error) {
 			if err != nil {
 				return
 			}
-
+		case "maxAllowedPacket":
+			cfg.MaxAllowedPacket, err = strconv.Atoi(value)
+			if err != nil {
+				return
+			}
 		default:
 			// lazy init
 			if cfg.Params == nil {

+ 2 - 2
dsn_test.go

@@ -39,8 +39,8 @@ var testDSNs = []struct {
 	"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify",
 	&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8mb4,utf8"}, Collation: "utf8_general_ci", Loc: time.UTC, TLSConfig: "skip-verify"},
 }, {
-	"user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci",
-	&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_unicode_ci", Loc: time.UTC, Timeout: 30 * time.Second, ReadTimeout: time.Second, WriteTimeout: time.Second, AllowAllFiles: true, AllowOldPasswords: true, ClientFoundRows: true},
+	"user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci&maxAllowedPacket=16777216",
+	&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_unicode_ci", Loc: time.UTC, Timeout: 30 * time.Second, ReadTimeout: time.Second, WriteTimeout: time.Second, AllowAllFiles: true, AllowOldPasswords: true, ClientFoundRows: true, MaxAllowedPacket: 16777216},
 }, {
 	"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local",
 	&Config{User: "user", Passwd: "p@ss(word)", Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:80", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.Local},

+ 4 - 4
packets.go

@@ -80,7 +80,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
 func (mc *mysqlConn) writePacket(data []byte) error {
 	pktLen := len(data) - 4
 
-	if pktLen > mc.maxPacketAllowed {
+	if pktLen > mc.maxAllowedPacket {
 		return ErrPktTooLarge
 	}
 
@@ -786,7 +786,7 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) {
 
 // http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html
 func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
-	maxLen := stmt.mc.maxPacketAllowed - 1
+	maxLen := stmt.mc.maxAllowedPacket - 1
 	pktLen := maxLen
 
 	// After the header (bytes 0-3) follows before the data:
@@ -977,7 +977,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 					paramTypes[i+i] = fieldTypeString
 					paramTypes[i+i+1] = 0x00
 
-					if len(v) < mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 {
+					if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 {
 						paramValues = appendLengthEncodedInteger(paramValues,
 							uint64(len(v)),
 						)
@@ -999,7 +999,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 				paramTypes[i+i] = fieldTypeString
 				paramTypes[i+i+1] = 0x00
 
-				if len(v) < mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 {
+				if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 {
 					paramValues = appendLengthEncodedInteger(paramValues,
 						uint64(len(v)),
 					)