Browse Source

Remove unnecessary struct wrapping

Julien Schmidt 12 years ago
parent
commit
8416bd00a6
5 changed files with 35 additions and 43 deletions
  1. 5 4
      connection.go
  2. 3 3
      packets.go
  3. 3 3
      result.go
  4. 16 21
      rows.go
  5. 8 12
      statement.go

+ 5 - 4
connection.go

@@ -108,8 +108,9 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
 		return nil, err
 	}
 
-	stmt := mysqlStmt{new(stmtContent)}
-	stmt.mc = mc
+	stmt := &mysqlStmt{
+		mc: mc,
+	}
 
 	// Read Result
 	var columnCount uint16
@@ -202,11 +203,11 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
 		return nil, err
 	}
 
-	rows := mysqlRows{&rowsContent{mc, false, nil, false}}
+	rows := &mysqlRows{mc, false, nil, false}
 
 	if resLen > 0 {
 		// Columns
-		rows.content.columns, err = mc.readColumns(resLen)
+		rows.columns, err = mc.readColumns(resLen)
 		if err != nil {
 			return nil, err
 		}

+ 3 - 3
packets.go

@@ -658,7 +658,7 @@ Prepare OK Packet
         (EOF packet)
 
 */
-func (stmt mysqlStmt) readPrepareResultPacket() (columnCount uint16, err error) {
+func (stmt *mysqlStmt) readPrepareResultPacket() (columnCount uint16, err error) {
 	data, err := stmt.mc.readPacket()
 	if err != nil {
 		return
@@ -704,7 +704,7 @@ Bytes                Name
 n*2                  type of parameters
 n                    values for the parameters
 */
-func (stmt mysqlStmt) buildExecutePacket(args *[]driver.Value) error {
+func (stmt *mysqlStmt) buildExecutePacket(args *[]driver.Value) error {
 	argsLen := len(*args)
 	if argsLen < stmt.paramCount {
 		return fmt.Errorf(
@@ -838,7 +838,7 @@ func (stmt mysqlStmt) buildExecutePacket(args *[]driver.Value) error {
 }
 
 // http://dev.mysql.com/doc/internals/en/prepared-statements.html#packet-ProtocolBinary::ResultsetRow
-func (mc *mysqlConn) readBinaryRow(rc *rowsContent) (*[]*[]byte, error) {
+func (mc *mysqlConn) readBinaryRow(rc *mysqlRows) (*[]*[]byte, error) {
 	data, err := mc.readPacket()
 	if err != nil {
 		return nil, err

+ 3 - 3
result.go

@@ -2,7 +2,7 @@
 //
 // Copyright 2012 Julien Schmidt. All rights reserved.
 // http://www.julienschmidt.com
-// 
+//
 // This Source Code Form is subject to the terms of the Mozilla Public
 // License, v. 2.0. If a copy of the MPL was not distributed with this file,
 // You can obtain one at http://mozilla.org/MPL/2.0/.
@@ -14,10 +14,10 @@ type mysqlResult struct {
 	insertId     int64
 }
 
-func (res mysqlResult) LastInsertId() (int64, error) {
+func (res *mysqlResult) LastInsertId() (int64, error) {
 	return res.insertId, nil
 }
 
-func (res mysqlResult) RowsAffected() (int64, error) {
+func (res *mysqlResult) RowsAffected() (int64, error) {
 	return res.affectedRows, nil
 }

+ 16 - 21
rows.go

@@ -21,38 +21,33 @@ type mysqlField struct {
 	flags     FieldFlag
 }
 
-type rowsContent struct {
+type mysqlRows struct {
 	mc      *mysqlConn
 	binary  bool
 	columns []mysqlField
 	eof     bool
 }
 
-type mysqlRows struct {
-	content *rowsContent
-}
-
-func (rows mysqlRows) Columns() (columns []string) {
-	columns = make([]string, len(rows.content.columns))
+func (rows *mysqlRows) Columns() (columns []string) {
+	columns = make([]string, len(rows.columns))
 	for i := 0; i < cap(columns); i++ {
-		columns[i] = rows.content.columns[i].name
+		columns[i] = rows.columns[i].name
 	}
 	return
 }
 
-func (rows mysqlRows) Close() (err error) {
+func (rows *mysqlRows) Close() (err error) {
 	defer func() {
-		rows.content.mc = nil
-		rows.content = nil
+		rows.mc = nil
 	}()
 
 	// Remove unread packets from stream
-	if !rows.content.eof {
-		if rows.content.mc == nil {
+	if !rows.eof {
+		if rows.mc == nil {
 			return errors.New("Invalid Connection")
 		}
 
-		_, err = rows.content.mc.readUntilEOF()
+		_, err = rows.mc.readUntilEOF()
 		if err != nil {
 			return
 		}
@@ -65,12 +60,12 @@ func (rows mysqlRows) Close() (err error) {
 // or []byte's for all other entries. Type conversion is done on rows.scan(),
 // when the dest type is know, which makes type conversion easier and avoids
 // unnecessary conversions.
-func (rows mysqlRows) Next(dest []driver.Value) error {
-	if rows.content.eof {
+func (rows *mysqlRows) Next(dest []driver.Value) error {
+	if rows.eof {
 		return io.EOF
 	}
 
-	if rows.content.mc == nil {
+	if rows.mc == nil {
 		return errors.New("Invalid Connection")
 	}
 
@@ -79,15 +74,15 @@ func (rows mysqlRows) Next(dest []driver.Value) error {
 	// Fetch next row from stream
 	var row *[]*[]byte
 	var err error
-	if rows.content.binary {
-		row, err = rows.content.mc.readBinaryRow(rows.content)
+	if rows.binary {
+		row, err = rows.mc.readBinaryRow(rows)
 	} else {
-		row, err = rows.content.mc.readRow(columnsCount)
+		row, err = rows.mc.readRow(columnsCount)
 	}
 
 	if err != nil {
 		if err == io.EOF {
-			rows.content.eof = true
+			rows.eof = true
 		}
 		return err
 	}

+ 8 - 12
statement.go

@@ -14,28 +14,24 @@ import (
 	"errors"
 )
 
-type stmtContent struct {
+type mysqlStmt struct {
 	mc         *mysqlConn
 	id         uint32
 	paramCount int
 	params     []mysqlField
 }
 
-type mysqlStmt struct {
-	*stmtContent
-}
-
-func (stmt mysqlStmt) Close() (err error) {
+func (stmt *mysqlStmt) Close() (err error) {
 	err = stmt.mc.writeCommandPacket(COM_STMT_CLOSE, stmt.id)
 	stmt.mc = nil
 	return
 }
 
-func (stmt mysqlStmt) NumInput() int {
+func (stmt *mysqlStmt) NumInput() int {
 	return stmt.paramCount
 }
 
-func (stmt mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
+func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
 	if stmt.mc == nil {
 		return nil, errors.New(`Invalid Statement`)
 	}
@@ -72,13 +68,13 @@ func (stmt mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
 		return nil, err
 	}
 
-	return mysqlResult{
+	return &mysqlResult{
 			affectedRows: int64(stmt.mc.affectedRows),
 			insertId:     int64(stmt.mc.insertId)},
 		nil
 }
 
-func (stmt mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
+func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
 	if stmt.mc == nil {
 		return nil, errors.New(`Invalid Statement`)
 	}
@@ -96,11 +92,11 @@ func (stmt mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
 		return nil, err
 	}
 
-	rows := mysqlRows{&rowsContent{stmt.mc, true, nil, false}}
+	rows := &mysqlRows{stmt.mc, true, nil, false}
 
 	if resLen > 0 {
 		// Columns
-		rows.content.columns, err = stmt.mc.readColumns(resLen)
+		rows.columns, err = stmt.mc.readColumns(resLen)
 		if err != nil {
 			return nil, err
 		}