Pārlūkot izejas kodu

conn: tidy up frame coalescer (#1181)

Chris Bannister 7 gadi atpakaļ
vecāks
revīzija
e898b2baaf
2 mainītis faili ar 51 papildinājumiem un 49 dzēšanām
  1. 47 45
      conn.go
  2. 4 4
      conn_test.go

+ 47 - 45
conn.go

@@ -126,8 +126,7 @@ var TimeoutLimit int64 = 0
 type Conn struct {
 	conn net.Conn
 	r    *bufio.Reader
-
-	w *writeCoalescer
+	w    io.Writer
 
 	timeout       time.Duration
 	cfg           *ConnConfig
@@ -211,6 +210,10 @@ func (s *Session) dial(host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHa
 		streams:       streams.New(cfg.ProtoVersion),
 		host:          host,
 		frameObserver: s.frameObserver,
+		w: &deadlineWriter{
+			w:       conn,
+			timeout: cfg.Timeout,
+		},
 	}
 
 	var (
@@ -263,44 +266,20 @@ func (s *Session) dial(host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHa
 	// dont coalesce startup frames
 	if s.cfg.WriteCoalesceWaitTime > 0 {
 		w := &writeCoalescer{
-			w:       conn,
-			timeout: cfg.Timeout,
+			cond: sync.NewCond(&sync.Mutex{}),
+			w:    c.w,
 		}
-		w.cond = sync.NewCond(&w.mu)
+		go w.writeFlusher(s.cfg.WriteCoalesceWaitTime, c.quit)
 		c.w = w
-		go c.writeFlusher()
 	}
+
 	go c.serve()
 
 	return c, nil
 }
 
-func (c *Conn) writeFlusher() {
-	ticker := time.NewTicker(c.session.cfg.WriteCoalesceWaitTime)
-	defer ticker.Stop()
-	defer c.w.flush()
-
-	for {
-		select {
-		case <-c.quit:
-			return
-		case <-ticker.C:
-		}
-
-		c.w.flush()
-	}
-}
-
 func (c *Conn) Write(p []byte) (n int, err error) {
-	if c.w != nil {
-		n, err = c.w.write(p)
-	} else {
-		if c.timeout > 0 {
-			c.conn.SetWriteDeadline(time.Now().Add(c.timeout))
-		}
-		n, err = c.conn.Write(p)
-	}
-	return n, err
+	return c.w.Write(p)
 }
 
 func (c *Conn) Read(p []byte) (n int, err error) {
@@ -617,12 +596,25 @@ type callReq struct {
 	timer *time.Timer
 }
 
-type writeCoalescer struct {
-	w       io.Writer
+type deadlineWriter struct {
+	w interface {
+		SetWriteDeadline(time.Time) error
+		io.Writer
+	}
 	timeout time.Duration
+}
+
+func (c *deadlineWriter) Write(p []byte) (int, error) {
+	if c.timeout > 0 {
+		c.w.SetWriteDeadline(time.Now().Add(c.timeout))
+	}
+	return c.w.Write(p)
+}
+
+type writeCoalescer struct {
+	w io.Writer
 
 	cond    *sync.Cond
-	mu      sync.Mutex
 	buffers net.Buffers
 
 	// result of the write
@@ -630,15 +622,9 @@ type writeCoalescer struct {
 }
 
 func (w *writeCoalescer) flush() {
-	if w.timeout > 0 {
-		type deadliner interface {
-			SetWriteDeadline(time.Time) error
-		}
-		w.w.(deadliner).SetWriteDeadline(time.Now().Add(w.timeout))
-	}
+	w.cond.L.Lock()
+	defer w.cond.L.Unlock()
 
-	w.mu.Lock()
-	defer w.mu.Unlock()
 	if len(w.buffers) == 0 {
 		return
 	}
@@ -653,15 +639,15 @@ func (w *writeCoalescer) flush() {
 	w.cond.Broadcast()
 }
 
-func (w *writeCoalescer) write(p []byte) (int, error) {
-	w.mu.Lock()
+func (w *writeCoalescer) Write(p []byte) (int, error) {
+	w.cond.L.Lock()
 	w.buffers = append(w.buffers, p)
 	for len(w.buffers) != 0 {
 		w.cond.Wait()
 	}
 
 	err := w.err
-	w.mu.Unlock()
+	w.cond.L.Unlock()
 
 	if err != nil {
 		return 0, err
@@ -669,6 +655,22 @@ func (w *writeCoalescer) write(p []byte) (int, error) {
 	return len(p), nil
 }
 
+func (w *writeCoalescer) writeFlusher(interval time.Duration, quit chan struct{}) {
+	ticker := time.NewTicker(interval)
+	defer ticker.Stop()
+	defer w.flush()
+
+	for {
+		select {
+		case <-quit:
+			return
+		case <-ticker.C:
+		}
+
+		w.flush()
+	}
+}
+
 func (c *Conn) exec(ctx context.Context, req frameWriter, tracer Tracer) (*framer, error) {
 	// TODO: move tracer onto conn
 	stream, ok := c.streams.GetStream()

+ 4 - 4
conn_test.go

@@ -754,16 +754,16 @@ func TestContext_Timeout(t *testing.T) {
 func TestWriteCoalescing(t *testing.T) {
 	var buf bytes.Buffer
 	w := &writeCoalescer{
-		w: &buf,
+		w:    &buf,
+		cond: sync.NewCond(&sync.Mutex{}),
 	}
-	w.cond = sync.NewCond(&w.mu)
 
 	var wg sync.WaitGroup
 
 	wg.Add(1)
 	go func() {
 		wg.Done()
-		if _, err := w.write([]byte("one")); err != nil {
+		if _, err := w.Write([]byte("one")); err != nil {
 			t.Error(err)
 		}
 	}()
@@ -772,7 +772,7 @@ func TestWriteCoalescing(t *testing.T) {
 	wg.Add(1)
 	go func() {
 		wg.Done()
-		if _, err := w.write([]byte("two")); err != nil {
+		if _, err := w.Write([]byte("two")); err != nil {
 			t.Error(err)
 		}
 	}()