Jelajahi Sumber

Fixed rebinding Bug
+ more clean up

Julien Schmidt 13 tahun lalu
induk
melakukan
59d433e881
2 mengubah file dengan 77 tambahan dan 99 penghapusan
  1. 75 97
      packets.go
  2. 2 2
      statement.go

+ 75 - 97
packets.go

@@ -49,7 +49,6 @@ func (mc *mysqlConn) readPacket() (data []byte, e error) {
 	data = make([]byte, pktLen)
 	n, e := mc.netConn.Read(data)
 	if e != nil || n != int(pktLen) {
-		fmt.Println(e)
 		e = driver.ErrBadConn
 		return
 	}
@@ -77,7 +76,6 @@ func (mc *mysqlConn) writePacket(data []byte) (e error) {
 	// Write packet
 	n, e := mc.netConn.Write(pktData)
 	if e != nil || n != len(pktData) {
-		fmt.Println("BadConn:", e)
 		e = driver.ErrBadConn
 		return
 	}
@@ -93,7 +91,6 @@ func (mc *mysqlConn) readNumber(n uint8) (num uint64, e error) {
 
 	nr, err := io.ReadFull(mc.netConn, buf)
 	if err != nil || nr != int(n) {
-		fmt.Println(e)
 		e = driver.ErrBadConn
 		return
 	}
@@ -141,7 +138,10 @@ func (mc *mysqlConn) readInitPacket() (e error) {
 	// Protocol version [8 bit uint]
 	mc.server.protocol = data[pos]
 	if mc.server.protocol < MIN_PROTOCOL_VERSION {
-		e = errors.New(fmt.Sprintf("Unsupported MySQL Protocol Version %d. Protocol Version %d or higher is required", mc.server.protocol, MIN_PROTOCOL_VERSION))
+		e = fmt.Errorf(
+			"Unsupported MySQL Protocol Version %d. Protocol Version %d or higher is required",
+			mc.server.protocol,
+			MIN_PROTOCOL_VERSION)
 	}
 	pos++
 
@@ -275,24 +275,24 @@ func (mc *mysqlConn) writeCommandPacket(command commandType, args ...interface{}
 	// Commands without args
 	case COM_QUIT, COM_PING:
 		if len(args) > 0 {
-			return errors.New(fmt.Sprintf("Too much arguments (Got: %d Has:0)", len(args)))
+			return fmt.Errorf("Too much arguments (Got: %d Has:0)", len(args))
 		}
 
 	// Commands with 1 arg unterminated string
 	case COM_QUERY, COM_STMT_PREPARE:
 		if len(args) != 1 {
-			return errors.New(fmt.Sprintf("Invalid arguments count (Got:%d Need: 1)", len(args)))
+			return fmt.Errorf("Invalid arguments count (Got:%d Need: 1)", len(args))
 		}
 		data = append(data, []byte(args[0].(string))...)
 
 	// Commands with 1 arg 32 bit uint
 	case COM_STMT_CLOSE:
 		if len(args) != 1 {
-			return errors.New(fmt.Sprintf("Invalid arguments count (Got:%d Need: 1)", len(args)))
+			return fmt.Errorf("Invalid arguments count (Got:%d Need: 1)", len(args))
 		}
 		data = append(data, uint32ToBytes(args[0].(uint32))...)
 	default:
-		return errors.New(fmt.Sprintf("Unknown command: %d", command))
+		return fmt.Errorf("Unknown command: %d", command)
 	}
 
 	// Send CMD packet
@@ -448,7 +448,7 @@ func (mc *mysqlConn) readColumns(n int) (columns []*mysqlField, e error) {
 		// EOF Packet
 		if data[0] == 254 && len(data) == 5 {
 			if len(columns) != n {
-				e = errors.New(fmt.Sprintf("ColumnsCount mismatch n:%d len:%d", n, len(columns)))
+				e = fmt.Errorf("ColumnsCount mismatch n:%d len:%d", n, len(columns))
 			}
 			return
 		}
@@ -716,7 +716,6 @@ func (stmt mysqlStmt) buildExecutePacket(args *[]driver.Value) (e error) {
 		// Check for NULL fields
 		for i = 0; i < stmt.paramCount; i++ {
 			if (*args)[i] == nil {
-				fmt.Println("nil", i, (*args)[i])
 				bitMask += 1 << uint(i)
 			}
 		}
@@ -728,100 +727,79 @@ func (stmt mysqlStmt) buildExecutePacket(args *[]driver.Value) (e error) {
 		// append nullBitMap [(param_count+7)/8 bytes]
 		data = append(data, nullBitMap...)
 
-		// Check for changed Params
-		newParamsBound := true
-		if stmt.args != nil {
-			for i := 0; i < len(*args); i++ {
-				if (*args)[i] != (*stmt.args)[i] {
-					fmt.Println((*args)[i], "!=", (*stmt.args)[i])
-					newParamsBound = false
-					break
-				}
+		// newParameterBoundFlag 1 [1 byte]
+		data = append(data, byte(1))
+
+		// append types and cache values
+		paramValues := make([]byte, 0)
+		var pv reflect.Value
+		for i = 0; i < stmt.paramCount; i++ {
+			switch (*args)[i].(type) {
+			case nil:
+				data = append(data, []byte{
+					byte(FIELD_TYPE_NULL),
+					0x0}...)
+				continue
+
+			case []byte:
+				data = append(data, []byte{
+					byte(FIELD_TYPE_STRING),
+					0x0}...)
+				val := (*args)[i].([]byte)
+				paramValues = append(paramValues, lengthCodedBinaryToBytes(uint64(len(val)))...)
+				paramValues = append(paramValues, val...)
+				continue
+
+			case time.Time:
+				// Format to string for time+date Fields
+				// Data is packed in case reflect.String below
+				(*args)[i] = (*args)[i].(time.Time).Format(TIME_FORMAT)
 			}
-		}
 
-		// No (new) Parameters bound or rebound
-		if !newParamsBound {
-			//newParameterBoundFlag 0 [1 byte]
-			data = append(data, byte(0))
-		} else {
-			// newParameterBoundFlag 1 [1 byte]
-			data = append(data, byte(1))
-
-			// append types and cache values
-			paramValues := make([]byte, 0)
-			var pv reflect.Value
-			for i = 0; i < stmt.paramCount; i++ {
-				switch (*args)[i].(type) {
-				case nil:
-					data = append(data, []byte{
-						byte(FIELD_TYPE_NULL),
-						0x0}...)
-					continue
-				
-				case []byte:
-					data = append(data, []byte{
-						byte(FIELD_TYPE_STRING),
-						0x0}...)
-					val := (*args)[i].([]byte)
-					paramValues = append(paramValues, lengthCodedBinaryToBytes(uint64(len(val)))...)
-					paramValues = append(paramValues, val...)
-					continue
-				
-				case time.Time:
-					// Format to string for time+date Fields
-					// Data is packed in case reflect.String below
-					(*args)[i] = (*args)[i].(time.Time).Format(TIME_FORMAT)
-				}
+			pv = reflect.ValueOf((*args)[i])
+			switch pv.Kind() {
+			case reflect.Int64:
+				data = append(data, []byte{
+					byte(FIELD_TYPE_LONGLONG),
+					0x0}...)
+				paramValues = append(paramValues, int64ToBytes(pv.Int())...)
+				continue
 
-				pv = reflect.ValueOf((*args)[i])
-				switch pv.Kind() {
-				case reflect.Int64:
-					data = append(data, []byte{
-						byte(FIELD_TYPE_LONGLONG),
-						0x0}...)
-					paramValues = append(paramValues, int64ToBytes(pv.Int())...)
-					continue
-
-				case reflect.Float64:
-					data = append(data, []byte{
-						byte(FIELD_TYPE_DOUBLE),
-						0x0}...)
-					paramValues = append(paramValues, float64ToBytes(pv.Float())...)
-					continue
-
-				case reflect.Bool:
-					data = append(data, []byte{
-						byte(FIELD_TYPE_TINY),
-						0x0}...)
-					val := pv.Bool()
-					if val {
-						paramValues = append(paramValues, byte(1))
-					} else {
-						paramValues = append(paramValues, byte(0))
-					}
-					continue
-
-				case reflect.String:
-					data = append(data, []byte{
-						byte(FIELD_TYPE_STRING),
-						0x0}...)
-					val := pv.String()
-					paramValues = append(paramValues, lengthCodedBinaryToBytes(uint64(len(val)))...)
-					paramValues = append(paramValues, []byte(val)...)
-					continue
-
-				default:
-					return fmt.Errorf("Invalid Value: %s", pv.Kind().String())
+			case reflect.Float64:
+				data = append(data, []byte{
+					byte(FIELD_TYPE_DOUBLE),
+					0x0}...)
+				paramValues = append(paramValues, float64ToBytes(pv.Float())...)
+				continue
+
+			case reflect.Bool:
+				data = append(data, []byte{
+					byte(FIELD_TYPE_TINY),
+					0x0}...)
+				val := pv.Bool()
+				if val {
+					paramValues = append(paramValues, byte(1))
+				} else {
+					paramValues = append(paramValues, byte(0))
 				}
-			}
+				continue
 
-			// append cached values
-			data = append(data, paramValues...)
+			case reflect.String:
+				data = append(data, []byte{
+					byte(FIELD_TYPE_STRING),
+					0x0}...)
+				val := pv.String()
+				paramValues = append(paramValues, lengthCodedBinaryToBytes(uint64(len(val)))...)
+				paramValues = append(paramValues, []byte(val)...)
+				continue
+
+			default:
+				return fmt.Errorf("Invalid Value: %s", pv.Kind().String())
+			}
 		}
 
-		// Save args
-		stmt.args = args
+		// append cached values
+		data = append(data, paramValues...)
 	}
 	return stmt.mc.writePacket(data)
 }

+ 2 - 2
statement.go

@@ -18,8 +18,6 @@ type stmtContent struct {
 	query          string
 	paramCount     int
 	params         []*mysqlField
-	args           *[]driver.Value
-	newParamsBound bool
 }
 
 type mysqlStmt struct {
@@ -55,11 +53,13 @@ func (stmt mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
 	}
 
 	if resLen > 0 {
+		// Columns
 		_, e = stmt.mc.readUntilEOF()
 		if e != nil {
 			return nil, e
 		}
 
+		// Rows
 		stmt.mc.affectedRows, e = stmt.mc.readUntilEOF()
 		if e != nil {
 			return nil, e