mempipe_test.go 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. // Copyright 2013 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package ssh
  5. import (
  6. "io"
  7. "sync"
  8. "testing"
  9. )
  // memTransport is one endpoint of an in-memory packetConn pair, as wired
// up by memPipe. Packets written through an endpoint are queued on its
// peer (the write field), from which the peer's readPacket dequeues them.
// All mutable state is guarded by the embedded Mutex, so Close and
// writePacket may be called from different goroutines.
type memTransport struct {
	eof     bool          // set by the peer's Close; no more packets will arrive
	pending [][]byte      // packets waiting to be read, oldest first
	write   *memTransport // peer endpoint that receives this endpoint's writes

	sync.Mutex // guards eof and pending
	*sync.Cond // woken when pending grows or eof is set (memPipe binds it to Mutex)
}
  19. func (t *memTransport) readPacket() ([]byte, error) {
  20. t.Lock()
  21. defer t.Unlock()
  22. for {
  23. if len(t.pending) > 0 {
  24. r := t.pending[0]
  25. t.pending = t.pending[1:]
  26. return r, nil
  27. }
  28. if t.eof {
  29. return nil, io.EOF
  30. }
  31. t.Cond.Wait()
  32. }
  33. }
  34. func (t *memTransport) Close() error {
  35. t.write.Lock()
  36. defer t.write.Unlock()
  37. if t.write.eof {
  38. return io.EOF
  39. }
  40. t.write.eof = true
  41. t.write.Cond.Broadcast()
  42. return nil
  43. }
  44. func (t *memTransport) writePacket(p []byte) error {
  45. t.write.Lock()
  46. defer t.write.Unlock()
  47. if t.write.eof {
  48. return io.EOF
  49. }
  50. t.write.pending = append(t.write.pending, p)
  51. t.write.Cond.Signal()
  52. return nil
  53. }
  54. func memPipe() (a, b packetConn) {
  55. t1 := memTransport{}
  56. t2 := memTransport{}
  57. t1.write = &t2
  58. t2.write = &t1
  59. t1.Cond = sync.NewCond(&t1.Mutex)
  60. t2.Cond = sync.NewCond(&t2.Mutex)
  61. return &t1, &t2
  62. }
  63. func TestmemPipe(t *testing.T) {
  64. a, b := memPipe()
  65. if err := a.writePacket([]byte{42}); err != nil {
  66. t.Fatalf("writePacket: %v", err)
  67. }
  68. if err := a.Close(); err != nil {
  69. t.Fatal("Close: ", err)
  70. }
  71. p, err := b.readPacket()
  72. if err != nil {
  73. t.Fatal("readPacket: ", err)
  74. }
  75. if len(p) != 1 || p[0] != 42 {
  76. t.Fatalf("got %v, want {42}", p)
  77. }
  78. p, err = b.readPacket()
  79. if err != io.EOF {
  80. t.Fatalf("got %v, %v, want EOF", p, err)
  81. }
  82. }
  83. func TestDoubleClose(t *testing.T) {
  84. a, _ := memPipe()
  85. err := a.Close()
  86. if err != nil {
  87. t.Errorf("Close: %v", err)
  88. }
  89. err = a.Close()
  90. if err != io.EOF {
  91. t.Errorf("expect EOF on double close.")
  92. }
  93. }