tx_test.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. package redis_test
  2. import (
  3. "context"
  4. "strconv"
  5. "sync"
  6. "github.com/go-redis/redis/v7"
  7. . "github.com/onsi/ginkgo"
  8. . "github.com/onsi/gomega"
  9. )
  10. var _ = Describe("Tx", func() {
  11. var client *redis.Client
  12. BeforeEach(func() {
  13. client = redis.NewClient(redisOptions())
  14. Expect(client.FlushDB().Err()).NotTo(HaveOccurred())
  15. })
  16. AfterEach(func() {
  17. Expect(client.Close()).NotTo(HaveOccurred())
  18. })
  19. It("should Watch", func() {
  20. var incr func(string) error
  21. // Transactionally increments key using GET and SET commands.
  22. incr = func(key string) error {
  23. err := client.Watch(func(tx *redis.Tx) error {
  24. n, err := tx.Get(key).Int64()
  25. if err != nil && err != redis.Nil {
  26. return err
  27. }
  28. _, err = tx.TxPipelined(func(pipe redis.Pipeliner) error {
  29. pipe.Set(key, strconv.FormatInt(n+1, 10), 0)
  30. return nil
  31. })
  32. return err
  33. }, key)
  34. if err == redis.TxFailedErr {
  35. return incr(key)
  36. }
  37. return err
  38. }
  39. var wg sync.WaitGroup
  40. for i := 0; i < 100; i++ {
  41. wg.Add(1)
  42. go func() {
  43. defer GinkgoRecover()
  44. defer wg.Done()
  45. err := incr("key")
  46. Expect(err).NotTo(HaveOccurred())
  47. }()
  48. }
  49. wg.Wait()
  50. n, err := client.Get("key").Int64()
  51. Expect(err).NotTo(HaveOccurred())
  52. Expect(n).To(Equal(int64(100)))
  53. })
  54. It("should discard", func() {
  55. err := client.Watch(func(tx *redis.Tx) error {
  56. cmds, err := tx.TxPipelined(func(pipe redis.Pipeliner) error {
  57. pipe.Set("key1", "hello1", 0)
  58. pipe.Discard()
  59. pipe.Set("key2", "hello2", 0)
  60. return nil
  61. })
  62. Expect(err).NotTo(HaveOccurred())
  63. Expect(cmds).To(HaveLen(1))
  64. return err
  65. }, "key1", "key2")
  66. Expect(err).NotTo(HaveOccurred())
  67. get := client.Get("key1")
  68. Expect(get.Err()).To(Equal(redis.Nil))
  69. Expect(get.Val()).To(Equal(""))
  70. get = client.Get("key2")
  71. Expect(get.Err()).NotTo(HaveOccurred())
  72. Expect(get.Val()).To(Equal("hello2"))
  73. })
  74. It("returns no error when there are no commands", func() {
  75. err := client.Watch(func(tx *redis.Tx) error {
  76. _, err := tx.TxPipelined(func(redis.Pipeliner) error { return nil })
  77. return err
  78. })
  79. Expect(err).NotTo(HaveOccurred())
  80. v, err := client.Ping().Result()
  81. Expect(err).NotTo(HaveOccurred())
  82. Expect(v).To(Equal("PONG"))
  83. })
  84. It("should exec bulks", func() {
  85. const N = 20000
  86. err := client.Watch(func(tx *redis.Tx) error {
  87. cmds, err := tx.TxPipelined(func(pipe redis.Pipeliner) error {
  88. for i := 0; i < N; i++ {
  89. pipe.Incr("key")
  90. }
  91. return nil
  92. })
  93. Expect(err).NotTo(HaveOccurred())
  94. Expect(len(cmds)).To(Equal(N))
  95. for _, cmd := range cmds {
  96. Expect(cmd.Err()).NotTo(HaveOccurred())
  97. }
  98. return err
  99. })
  100. Expect(err).NotTo(HaveOccurred())
  101. num, err := client.Get("key").Int64()
  102. Expect(err).NotTo(HaveOccurred())
  103. Expect(num).To(Equal(int64(N)))
  104. })
  105. It("should recover from bad connection", func() {
  106. // Put bad connection in the pool.
  107. cn, err := client.Pool().Get(context.Background())
  108. Expect(err).NotTo(HaveOccurred())
  109. cn.SetNetConn(&badConn{})
  110. client.Pool().Put(cn)
  111. do := func() error {
  112. err := client.Watch(func(tx *redis.Tx) error {
  113. _, err := tx.TxPipelined(func(pipe redis.Pipeliner) error {
  114. pipe.Ping()
  115. return nil
  116. })
  117. return err
  118. })
  119. return err
  120. }
  121. err = do()
  122. Expect(err).To(MatchError("bad connection"))
  123. err = do()
  124. Expect(err).NotTo(HaveOccurred())
  125. })
  126. })