Переглянути джерело

Do not mask bytes when reading on the client.

- The bytes were masked with zero, a nop.
- Add test for control messages.
Gary Burd 11 роки тому
батько
коміт
0e7b5f878f
2 змінених файлів з 42 додано та 2 видалено
  1. 6 2
      conn.go
  2. 36 0
      conn_test.go

+ 6 - 2
conn.go

@@ -615,7 +615,9 @@ func (c *Conn) advanceFrame() (int, error) {
 		if _, err := io.ReadFull(c.br, payload); err != nil {
 			return noFrame, err
 		}
-		maskBytes(c.readMaskKey, 0, payload)
+		if c.isServer {
+			maskBytes(c.readMaskKey, 0, payload)
+		}
 	}
 
 	// 7. Process control frame payload.
@@ -698,7 +700,9 @@ func (r messageReader) Read(b []byte) (n int, err error) {
 			}
 			n, err := r.c.br.Read(b)
 			r.c.readErr = hideTempErr(err)
-			r.c.readMaskPos = maskBytes(r.c.readMaskKey, r.c.readMaskPos, b[:n])
+			if r.c.isServer {
+				r.c.readMaskPos = maskBytes(r.c.readMaskKey, r.c.readMaskPos, b[:n])
+			}
 			r.c.readRemaining -= int64(n)
 			return n, r.c.readErr
 		}

+ 36 - 0
conn_test.go

@@ -107,6 +107,42 @@ func TestFraming(t *testing.T) {
 	}
 }
 
+func TestControl(t *testing.T) {
+	const message = "this is a ping/pong messsage"
+	for _, isServer := range []bool{true, false} {
+		for _, isWriteControl := range []bool{true, false} {
+			name := fmt.Sprintf("s:%v, wc:%v", isServer, isWriteControl)
+			var connBuf bytes.Buffer
+			wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024)
+			rc := newConn(fakeNetConn{Reader: &connBuf, Writer: nil}, !isServer, 1024, 1024)
+			if isWriteControl {
+				wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second))
+			} else {
+				w, err := wc.NextWriter(PongMessage)
+				if err != nil {
+					t.Errorf("%s: wc.NextWriter() returned %v", name, err)
+					continue
+				}
+				if _, err := w.Write([]byte(message)); err != nil {
+					t.Errorf("%s: w.Write() returned %v", name, err)
+					continue
+				}
+				if err := w.Close(); err != nil {
+					t.Errorf("%s: w.Close() returned %v", name, err)
+					continue
+				}
+				var actualMessage string
+				rc.SetPongHandler(func(s string) error { actualMessage = s; return nil })
+				rc.NextReader()
+				if actualMessage != message {
+					t.Errorf("%s: pong=%q, want %q", name, actualMessage, message)
+					continue
+				}
+			}
+		}
+	}
+}
+
 func TestReadLimit(t *testing.T) {
 
 	const readLimit = 512