mempipe_test.go 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
  4. package ssh
  5. import (
  6. "io"
  7. "sync"
  8. "testing"
  9. )
// An in-memory packetConn. It is safe to call Close and writePacket
// from different goroutines.
//
// Each memTransport owns an inbound queue (pending) guarded by its own
// mutex. writePacket appends to the peer's queue through the write
// pointer; readPacket drains the local queue.
type memTransport struct {
	// eof is set once this side's inbound queue has been closed. Readers
	// still receive any packets left in pending before they see io.EOF.
	eof bool
	// pending holds packets queued for this side that have not been read
	// yet, oldest first.
	pending [][]byte
	// write is the peer transport; writePacket enqueues onto write.pending.
	write *memTransport
	// Mutex guards eof and pending.
	sync.Mutex
	// Cond is built on the embedded Mutex (see memPipe) and wakes readers
	// waiting for either a packet or a close.
	*sync.Cond
}
  19. func (t *memTransport) readPacket() ([]byte, error) {
  20. t.Lock()
  21. defer t.Unlock()
  22. for {
  23. if len(t.pending) > 0 {
  24. r := t.pending[0]
  25. t.pending = t.pending[1:]
  26. return r, nil
  27. }
  28. if t.eof {
  29. return nil, io.EOF
  30. }
  31. t.Cond.Wait()
  32. }
  33. }
  34. func (t *memTransport) closeSelf() error {
  35. t.Lock()
  36. defer t.Unlock()
  37. if t.eof {
  38. return io.EOF
  39. }
  40. t.eof = true
  41. t.Cond.Broadcast()
  42. return nil
  43. }
  44. func (t *memTransport) Close() error {
  45. err := t.write.closeSelf()
  46. t.closeSelf()
  47. return err
  48. }
  49. func (t *memTransport) writePacket(p []byte) error {
  50. t.write.Lock()
  51. defer t.write.Unlock()
  52. if t.write.eof {
  53. return io.EOF
  54. }
  55. t.write.pending = append(t.write.pending, p)
  56. t.write.Cond.Signal()
  57. return nil
  58. }
  59. func memPipe() (a, b packetConn) {
  60. t1 := memTransport{}
  61. t2 := memTransport{}
  62. t1.write = &t2
  63. t2.write = &t1
  64. t1.Cond = sync.NewCond(&t1.Mutex)
  65. t2.Cond = sync.NewCond(&t2.Mutex)
  66. return &t1, &t2
  67. }
  68. func TestmemPipe(t *testing.T) {
  69. a, b := memPipe()
  70. if err := a.writePacket([]byte{42}); err != nil {
  71. t.Fatalf("writePacket: %v", err)
  72. }
  73. if err := a.Close(); err != nil {
  74. t.Fatal("Close: ", err)
  75. }
  76. p, err := b.readPacket()
  77. if err != nil {
  78. t.Fatal("readPacket: ", err)
  79. }
  80. if len(p) != 1 || p[0] != 42 {
  81. t.Fatalf("got %v, want {42}", p)
  82. }
  83. p, err = b.readPacket()
  84. if err != io.EOF {
  85. t.Fatalf("got %v, %v, want EOF", p, err)
  86. }
  87. }
  88. func TestDoubleClose(t *testing.T) {
  89. a, _ := memPipe()
  90. err := a.Close()
  91. if err != nil {
  92. t.Errorf("Close: %v", err)
  93. }
  94. err = a.Close()
  95. if err != io.EOF {
  96. t.Errorf("expect EOF on double close.")
  97. }
  98. }