dial_test.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. // Copyright 2018 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package socks_test
  5. import (
  6. "context"
  7. "io"
  8. "math/rand"
  9. "net"
  10. "os"
  11. "testing"
  12. "time"
  13. "golang.org/x/net/internal/socks"
  14. "golang.org/x/net/internal/sockstest"
  15. )
  16. func TestDial(t *testing.T) {
  17. t.Run("Connect", func(t *testing.T) {
  18. ss, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
  19. if err != nil {
  20. t.Fatal(err)
  21. }
  22. defer ss.Close()
  23. d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
  24. d.AuthMethods = []socks.AuthMethod{
  25. socks.AuthMethodNotRequired,
  26. socks.AuthMethodUsernamePassword,
  27. }
  28. d.Authenticate = (&socks.UsernamePassword{
  29. Username: "username",
  30. Password: "password",
  31. }).Authenticate
  32. c, err := d.DialContext(context.Background(), ss.TargetAddr().Network(), ss.TargetAddr().String())
  33. if err != nil {
  34. t.Fatal(err)
  35. }
  36. c.(*socks.Conn).BoundAddr()
  37. c.Close()
  38. })
  39. t.Run("ConnectWithConn", func(t *testing.T) {
  40. ss, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
  41. if err != nil {
  42. t.Fatal(err)
  43. }
  44. defer ss.Close()
  45. c, err := net.Dial(ss.Addr().Network(), ss.Addr().String())
  46. if err != nil {
  47. t.Fatal(err)
  48. }
  49. defer c.Close()
  50. d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
  51. d.AuthMethods = []socks.AuthMethod{
  52. socks.AuthMethodNotRequired,
  53. socks.AuthMethodUsernamePassword,
  54. }
  55. d.Authenticate = (&socks.UsernamePassword{
  56. Username: "username",
  57. Password: "password",
  58. }).Authenticate
  59. a, err := d.DialWithConn(context.Background(), c, ss.TargetAddr().Network(), ss.TargetAddr().String())
  60. if err != nil {
  61. t.Fatal(err)
  62. }
  63. if _, ok := a.(*socks.Addr); !ok {
  64. t.Fatalf("got %+v; want socks.Addr", a)
  65. }
  66. })
  67. t.Run("Cancel", func(t *testing.T) {
  68. ss, err := sockstest.NewServer(sockstest.NoAuthRequired, blackholeCmdFunc)
  69. if err != nil {
  70. t.Fatal(err)
  71. }
  72. defer ss.Close()
  73. d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
  74. ctx, cancel := context.WithCancel(context.Background())
  75. defer cancel()
  76. dialErr := make(chan error)
  77. go func() {
  78. c, err := d.DialContext(ctx, ss.TargetAddr().Network(), ss.TargetAddr().String())
  79. if err == nil {
  80. c.Close()
  81. }
  82. dialErr <- err
  83. }()
  84. time.Sleep(100 * time.Millisecond)
  85. cancel()
  86. err = <-dialErr
  87. if perr, nerr := parseDialError(err); perr != context.Canceled && nerr == nil {
  88. t.Fatalf("got %v; want context.Canceled or equivalent", err)
  89. }
  90. })
  91. t.Run("Deadline", func(t *testing.T) {
  92. ss, err := sockstest.NewServer(sockstest.NoAuthRequired, blackholeCmdFunc)
  93. if err != nil {
  94. t.Fatal(err)
  95. }
  96. defer ss.Close()
  97. d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
  98. ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(100*time.Millisecond))
  99. defer cancel()
  100. c, err := d.DialContext(ctx, ss.TargetAddr().Network(), ss.TargetAddr().String())
  101. if err == nil {
  102. c.Close()
  103. }
  104. if perr, nerr := parseDialError(err); perr != context.DeadlineExceeded && nerr == nil {
  105. t.Fatalf("got %v; want context.DeadlineExceeded or equivalent", err)
  106. }
  107. })
  108. t.Run("WithRogueServer", func(t *testing.T) {
  109. ss, err := sockstest.NewServer(sockstest.NoAuthRequired, rogueCmdFunc)
  110. if err != nil {
  111. t.Fatal(err)
  112. }
  113. defer ss.Close()
  114. d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
  115. for i := 0; i < 2*len(rogueCmdList); i++ {
  116. ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(100*time.Millisecond))
  117. defer cancel()
  118. c, err := d.DialContext(ctx, ss.TargetAddr().Network(), ss.TargetAddr().String())
  119. if err == nil {
  120. t.Log(c.(*socks.Conn).BoundAddr())
  121. c.Close()
  122. t.Error("should fail")
  123. }
  124. }
  125. })
  126. }
  127. func blackholeCmdFunc(rw io.ReadWriter, b []byte) error {
  128. if _, err := sockstest.ParseCmdRequest(b); err != nil {
  129. return err
  130. }
  131. var bb [1]byte
  132. for {
  133. if _, err := rw.Read(bb[:]); err != nil {
  134. return err
  135. }
  136. }
  137. }
  138. func rogueCmdFunc(rw io.ReadWriter, b []byte) error {
  139. if _, err := sockstest.ParseCmdRequest(b); err != nil {
  140. return err
  141. }
  142. rw.Write(rogueCmdList[rand.Intn(len(rogueCmdList))])
  143. return nil
  144. }
  // rogueCmdList holds the raw replies that rogueCmdFunc picks from at
  // random.
  //
  // NOTE(review): the entries look like malformed SOCKS5 command replies
  // (a lone version byte, an unexpected version, a non-zero reserved
  // byte, and addresses with no port) — confirm against RFC 1928.
  var rogueCmdList = [][]byte{
  	[]byte{0x05},
  	[]byte{0x06, 0x00, 0x00, 0x01, 192, 0, 2, 1, 0x17, 0x4b},
  	[]byte{0x05, 0x00, 0xff, 0x01, 192, 0, 2, 2, 0x17, 0x4b},
  	[]byte{0x05, 0x00, 0x00, 0x01, 192, 0, 2, 3},
  	[]byte{0x05, 0x00, 0x00, 0x03, 0x04, 'F', 'Q', 'D', 'N'},
  }
  // parseDialError peels the wrappers off a dial error.
  //
  // If err is a *net.OpError, nerr is set to that OpError and unwrapping
  // continues with its Err field; otherwise nerr is nil. If the remaining
  // error is an *os.SyscallError, its Err field is taken as well. perr is
  // whatever is left after these steps (err itself when neither wrapper is
  // present).
  func parseDialError(err error) (perr, nerr error) {
  	perr = err
  	if opErr, ok := perr.(*net.OpError); ok {
  		nerr = opErr
  		perr = opErr.Err
  	}
  	if sysErr, ok := perr.(*os.SyscallError); ok {
  		perr = sysErr.Err
  	}
  	return perr, nerr
  }