Browse Source

buffer: Use a double-buffering scheme to prevent data races (#943)

Fixes #903

Co-Authored-By: vmg <vicent@github.com>
Vicent Martí 6 years ago
parent
commit
df597a2343
4 changed files with 151 additions and 13 deletions
  1. 54 0
      benchmark_test.go
  2. 35 13
      buffer.go
  3. 55 0
      driver_test.go
  4. 7 0
      rows.go

+ 54 - 0
benchmark_test.go

@@ -317,3 +317,57 @@ func BenchmarkExecContext(b *testing.B) {
 		})
 	}
 }
+
+// BenchmarkQueryRawBytes benchmarks fetching 100 blobs using sql.RawBytes.
+// "size=" means size of each blobs.
+func BenchmarkQueryRawBytes(b *testing.B) {
+	var sizes []int = []int{100, 1000, 2000, 4000, 8000, 12000, 16000, 32000, 64000, 256000}
+	db := initDB(b,
+		"DROP TABLE IF EXISTS bench_rawbytes",
+		"CREATE TABLE bench_rawbytes (id INT PRIMARY KEY, val LONGBLOB)",
+	)
+	defer db.Close()
+
+	blob := make([]byte, sizes[len(sizes)-1])
+	for i := range blob {
+		blob[i] = 42
+	}
+	for i := 0; i < 100; i++ {
+		_, err := db.Exec("INSERT INTO bench_rawbytes VALUES (?, ?)", i, blob)
+		if err != nil {
+			b.Fatal(err)
+		}
+	}
+
+	for _, s := range sizes {
+		b.Run(fmt.Sprintf("size=%v", s), func(b *testing.B) {
+			db.SetMaxIdleConns(0)
+			db.SetMaxIdleConns(1)
+			b.ReportAllocs()
+			b.ResetTimer()
+
+			for j := 0; j < b.N; j++ {
+				rows, err := db.Query("SELECT LEFT(val, ?) as v FROM bench_rawbytes", s)
+				if err != nil {
+					b.Fatal(err)
+				}
+				nrows := 0
+				for rows.Next() {
+					var buf sql.RawBytes
+					err := rows.Scan(&buf)
+					if err != nil {
+						b.Fatal(err)
+					}
+					if len(buf) != s {
+						b.Fatalf("size mismatch: expected %v, got %v", s, len(buf))
+					}
+					nrows++
+				}
+				rows.Close()
+				if nrows != 100 {
+					b.Fatalf("numbers of rows mismatch: expected %v, got %v", 100, nrows)
+				}
+			}
+		})
+	}
+}

+ 35 - 13
buffer.go

@@ -15,47 +15,69 @@ import (
 )
 
 const defaultBufSize = 4096
+const maxCachedBufSize = 256 * 1024
 
 // A buffer which is used for both reading and writing.
 // This is possible since communication on each connection is synchronous.
 // In other words, we can't write and read simultaneously on the same connection.
 // The buffer is similar to bufio.Reader / Writer but zero-copy-ish
 // Also highly optimized for this particular use case.
+// This buffer is backed by two byte slices in a double-buffering scheme
 type buffer struct {
 	buf     []byte // buf is a byte buffer who's length and capacity are equal.
 	nc      net.Conn
 	idx     int
 	length  int
 	timeout time.Duration
+	dbuf    [2][]byte // dbuf is an array with the two byte slices that back this buffer
+	flipcnt uint      // flipccnt is the current buffer counter for double-buffering
 }
 
 // newBuffer allocates and returns a new buffer.
 func newBuffer(nc net.Conn) buffer {
+	fg := make([]byte, defaultBufSize)
 	return buffer{
-		buf: make([]byte, defaultBufSize),
-		nc:  nc,
+		buf:  fg,
+		nc:   nc,
+		dbuf: [2][]byte{fg, nil},
 	}
 }
 
+// flip replaces the active buffer with the background buffer
+// this is a delayed flip that simply increases the buffer counter;
+// the actual flip will be performed the next time we call `buffer.fill`
+func (b *buffer) flip() {
+	b.flipcnt += 1
+}
+
 // fill reads into the buffer until at least _need_ bytes are in it
 func (b *buffer) fill(need int) error {
 	n := b.length
+	// fill data into its double-buffering target: if we've called
+	// flip on this buffer, we'll be copying to the background buffer,
+	// and then filling it with network data; otherwise we'll just move
+	// the contents of the current buffer to the front before filling it
+	dest := b.dbuf[b.flipcnt&1]
+
+	// grow buffer if necessary to fit the whole packet.
+	if need > len(dest) {
+		// Round up to the next multiple of the default size
+		dest = make([]byte, ((need/defaultBufSize)+1)*defaultBufSize)
 
-	// move existing data to the beginning
-	if n > 0 && b.idx > 0 {
-		copy(b.buf[0:n], b.buf[b.idx:])
+		// if the allocated buffer is not too large, move it to backing storage
+		// to prevent extra allocations on applications that perform large reads
+		if len(dest) <= maxCachedBufSize {
+			b.dbuf[b.flipcnt&1] = dest
+		}
 	}
 
-	// grow buffer if necessary
-	// TODO: let the buffer shrink again at some point
-	//       Maybe keep the org buf slice and swap back?
-	if need > len(b.buf) {
-		// Round up to the next multiple of the default size
-		newBuf := make([]byte, ((need/defaultBufSize)+1)*defaultBufSize)
-		copy(newBuf, b.buf)
-		b.buf = newBuf
+	// if we're filling the fg buffer, move the existing data to the start of it.
+	// if we're filling the bg buffer, copy over the data
+	if n > 0 {
+		copy(dest[:n], b.buf[b.idx:])
 	}
 
+	b.buf = dest
 	b.idx = 0
 
 	for {

+ 55 - 0
driver_test.go

@@ -2938,3 +2938,58 @@ func TestValuerWithValueReceiverGivenNilValue(t *testing.T) {
 		// This test will panic on the INSERT if ConvertValue() does not check for typed nil before calling Value()
 	})
 }
+
+// TestRawBytesAreNotModified checks for a race condition that arises when a query context
+// is canceled while a user is calling rows.Scan. This is a more stringent test than the one
+// proposed in https://github.com/golang/go/issues/23519. Here we're explicitly using
+// `sql.RawBytes` to check the contents of our internal buffers are not modified after an implicit
+// call to `Rows.Close`, so Context cancellation should **not** invalidate the backing buffers.
+func TestRawBytesAreNotModified(t *testing.T) {
+	const blob = "abcdefghijklmnop"
+	const contextRaceIterations = 20
+	const blobSize = defaultBufSize * 3 / 4 // Second row overwrites first row.
+	const insertRows = 4
+
+	var sqlBlobs = [2]string{
+		strings.Repeat(blob, blobSize/len(blob)),
+		strings.Repeat(strings.ToUpper(blob), blobSize/len(blob)),
+	}
+
+	runTests(t, dsn, func(dbt *DBTest) {
+		dbt.mustExec("CREATE TABLE test (id int, value BLOB) CHARACTER SET utf8")
+		for i := 0; i < insertRows; i++ {
+			dbt.mustExec("INSERT INTO test VALUES (?, ?)", i+1, sqlBlobs[i&1])
+		}
+
+		for i := 0; i < contextRaceIterations; i++ {
+			func() {
+				ctx, cancel := context.WithCancel(context.Background())
+				defer cancel()
+
+				rows, err := dbt.db.QueryContext(ctx, `SELECT id, value FROM test`)
+				if err != nil {
+					t.Fatal(err)
+				}
+
+				var b int
+				var raw sql.RawBytes
+				for rows.Next() {
+					if err := rows.Scan(&b, &raw); err != nil {
+						t.Fatal(err)
+					}
+
+					before := string(raw)
+					// Ensure cancelling the query does not corrupt the contents of `raw`
+					cancel()
+					time.Sleep(time.Microsecond * 100)
+					after := string(raw)
+
+					if before != after {
+						t.Fatalf("the backing storage for sql.RawBytes has been modified (i=%v)", i)
+					}
+				}
+				rows.Close()
+			}()
+		}
+	})
+}

+ 7 - 0
rows.go

@@ -111,6 +111,13 @@ func (rows *mysqlRows) Close() (err error) {
 		return err
 	}
 
+	// flip the buffer for this connection if we need to drain it.
+	// note that for a successful query (i.e. one where rows.next()
+	// has been called until it returns false), `rows.mc` will be nil
+	// by the time the user calls `(*Rows).Close`, so we won't reach this
+	// see: https://github.com/golang/go/commit/651ddbdb5056ded455f47f9c494c67b389622a47
+	mc.buf.flip()
+
 	// Remove unread packets from stream
 	if !rows.rs.done {
 		err = mc.readUntilEOF()