stm.go 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. // Copyright 2016 The etcd Authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package concurrency
  15. import (
  16. v3 "github.com/coreos/etcd/clientv3"
  17. "golang.org/x/net/context"
  18. )
  19. // STM is an interface for software transactional memory.
  20. type STM interface {
  21. // Get returns the value for a key and inserts the key in the txn's read set.
  22. // If Get fails, it aborts the transaction with an error, never returning.
  23. Get(key string) string
  24. // Put adds a value for a key to the write set.
  25. Put(key, val string, opts ...v3.OpOption)
  26. // Rev returns the revision of a key in the read set.
  27. Rev(key string) int64
  28. // Del deletes a key.
  29. Del(key string)
  30. // commit attempts to apply the txn's changes to the server.
  31. commit() *v3.TxnResponse
  32. reset()
  33. }
// stmError wraps an error raised inside a transaction (for example a failed
// read or commit request) so it can be passed through panic. runSTM's
// recover recognizes this type and returns err to the caller instead of
// re-panicking.
type stmError struct{ err error }
  36. // NewSTMRepeatable initiates new repeatable read transaction; reads within
  37. // the same transaction attempt always return the same data.
  38. func NewSTMRepeatable(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) {
  39. s := &stm{client: c, ctx: ctx, getOpts: []v3.OpOption{v3.WithSerializable()}}
  40. return runSTM(s, apply)
  41. }
  42. // NewSTMSerializable initiates a new serialized transaction; reads within the
  43. // same transactiona attempt return data from the revision of the first read.
  44. func NewSTMSerializable(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) {
  45. s := &stmSerializable{
  46. stm: stm{client: c, ctx: ctx},
  47. prefetch: make(map[string]*v3.GetResponse),
  48. }
  49. return runSTM(s, apply)
  50. }
  51. // NewSTMReadCommitted initiates a new read committed transaction.
  52. func NewSTMReadCommitted(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) {
  53. s := &stmReadCommitted{stm{client: c, ctx: ctx, getOpts: []v3.OpOption{v3.WithSerializable()}}}
  54. return runSTM(s, apply)
  55. }
// stmResponse carries the outcome of a transaction run: the txn response
// when the commit succeeded, or the error that ended the run.
type stmResponse struct {
	resp *v3.TxnResponse
	err  error
}
  60. func runSTM(s STM, apply func(STM) error) (*v3.TxnResponse, error) {
  61. outc := make(chan stmResponse, 1)
  62. go func() {
  63. defer func() {
  64. if r := recover(); r != nil {
  65. e, ok := r.(stmError)
  66. if !ok {
  67. // client apply panicked
  68. panic(r)
  69. }
  70. outc <- stmResponse{nil, e.err}
  71. }
  72. }()
  73. var out stmResponse
  74. for {
  75. s.reset()
  76. if out.err = apply(s); out.err != nil {
  77. break
  78. }
  79. if out.resp = s.commit(); out.resp != nil {
  80. break
  81. }
  82. }
  83. outc <- out
  84. }()
  85. r := <-outc
  86. return r.resp, r.err
  87. }
// stm implements repeatable-read software transactional memory over etcd.
// stmSerializable and stmReadCommitted embed it and override parts of its
// behavior.
type stm struct {
	// client issues the reads and the commit txn.
	client *v3.Client
	// ctx is passed to every read and to the commit txn.
	ctx context.Context
	// rset holds read key values and revisions, keyed by key; commit guards
	// the txn with one comparison per entry.
	rset map[string]*v3.GetResponse
	// wset holds overwritten keys and their values
	wset map[string]stmPut
	// getOpts are the opts used for gets
	getOpts []v3.OpOption
}
// stmPut is one staged write. val is what Get reports for the key until
// commit ("" for a delete), and op is the operation sent in the commit txn.
type stmPut struct {
	val string
	op  v3.Op
}
  103. func (s *stm) Get(key string) string {
  104. if wv, ok := s.wset[key]; ok {
  105. return wv.val
  106. }
  107. return respToValue(s.fetch(key))
  108. }
  109. func (s *stm) Put(key, val string, opts ...v3.OpOption) {
  110. s.wset[key] = stmPut{val, v3.OpPut(key, val, opts...)}
  111. }
  112. func (s *stm) Del(key string) { s.wset[key] = stmPut{"", v3.OpDelete(key)} }
  113. func (s *stm) Rev(key string) int64 {
  114. if resp := s.fetch(key); resp != nil && len(resp.Kvs) != 0 {
  115. return resp.Kvs[0].ModRevision
  116. }
  117. return 0
  118. }
  119. func (s *stm) commit() *v3.TxnResponse {
  120. txnresp, err := s.client.Txn(s.ctx).If(s.cmps()...).Then(s.puts()...).Commit()
  121. if err != nil {
  122. panic(stmError{err})
  123. }
  124. if txnresp.Succeeded {
  125. return txnresp
  126. }
  127. return nil
  128. }
  129. // cmps guards the txn from updates to read set
  130. func (s *stm) cmps() []v3.Cmp {
  131. cmps := make([]v3.Cmp, 0, len(s.rset))
  132. for k, rk := range s.rset {
  133. cmps = append(cmps, isKeyCurrent(k, rk))
  134. }
  135. return cmps
  136. }
  137. func (s *stm) fetch(key string) *v3.GetResponse {
  138. if resp, ok := s.rset[key]; ok {
  139. return resp
  140. }
  141. resp, err := s.client.Get(s.ctx, key, s.getOpts...)
  142. if err != nil {
  143. panic(stmError{err})
  144. }
  145. s.rset[key] = resp
  146. return resp
  147. }
  148. // puts is the list of ops for all pending writes
  149. func (s *stm) puts() []v3.Op {
  150. puts := make([]v3.Op, 0, len(s.wset))
  151. for _, v := range s.wset {
  152. puts = append(puts, v.op)
  153. }
  154. return puts
  155. }
  156. func (s *stm) reset() {
  157. s.rset = make(map[string]*v3.GetResponse)
  158. s.wset = make(map[string]stmPut)
  159. }
// stmSerializable extends stm with serializable reads. The first read of an
// attempt pins the revision used by later reads in that attempt.
type stmSerializable struct {
	stm
	// prefetch holds read-set responses obtained from a conflicting commit;
	// the next attempt consumes them instead of re-reading those keys.
	prefetch map[string]*v3.GetResponse
}
  164. func (s *stmSerializable) Get(key string) string {
  165. if wv, ok := s.wset[key]; ok {
  166. return wv.val
  167. }
  168. firstRead := len(s.rset) == 0
  169. if resp, ok := s.prefetch[key]; ok {
  170. delete(s.prefetch, key)
  171. s.rset[key] = resp
  172. }
  173. resp := s.stm.fetch(key)
  174. if firstRead {
  175. // txn's base revision is defined by the first read
  176. s.getOpts = []v3.OpOption{
  177. v3.WithRev(resp.Header.Revision),
  178. v3.WithSerializable(),
  179. }
  180. }
  181. return respToValue(resp)
  182. }
  183. func (s *stmSerializable) Rev(key string) int64 {
  184. s.Get(key)
  185. return s.stm.Rev(key)
  186. }
  187. func (s *stmSerializable) gets() ([]string, []v3.Op) {
  188. keys := make([]string, 0, len(s.rset))
  189. ops := make([]v3.Op, 0, len(s.rset))
  190. for k := range s.rset {
  191. keys = append(keys, k)
  192. ops = append(ops, v3.OpGet(k))
  193. }
  194. return keys, ops
  195. }
  196. func (s *stmSerializable) commit() *v3.TxnResponse {
  197. keys, getops := s.gets()
  198. txn := s.client.Txn(s.ctx).If(s.cmps()...).Then(s.puts()...)
  199. // use Else to prefetch keys in case of conflict to save a round trip
  200. txnresp, err := txn.Else(getops...).Commit()
  201. if err != nil {
  202. panic(stmError{err})
  203. }
  204. if txnresp.Succeeded {
  205. return txnresp
  206. }
  207. // load prefetch with Else data
  208. for i := range keys {
  209. resp := txnresp.Responses[i].GetResponseRange()
  210. s.rset[keys[i]] = (*v3.GetResponse)(resp)
  211. }
  212. s.prefetch = s.rset
  213. s.getOpts = nil
  214. return nil
  215. }
// stmReadCommitted is an stm whose commit discards the read set first, so
// its commits carry no read-set comparisons.
type stmReadCommitted struct{ stm }
  217. // commit always goes through when read committed
  218. func (s *stmReadCommitted) commit() *v3.TxnResponse {
  219. s.rset = nil
  220. return s.stm.commit()
  221. }
  222. func isKeyCurrent(k string, r *v3.GetResponse) v3.Cmp {
  223. rev := r.Header.Revision + 1
  224. if len(r.Kvs) != 0 {
  225. rev = r.Kvs[0].ModRevision + 1
  226. }
  227. return v3.Compare(v3.ModRevision(k), "<", rev)
  228. }
  229. func respToValue(resp *v3.GetResponse) string {
  230. if len(resp.Kvs) == 0 {
  231. return ""
  232. }
  233. return string(resp.Kvs[0].Value)
  234. }