stm.go 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387
  1. // Copyright 2016 The etcd Authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package concurrency
  15. import (
  16. "context"
  17. "math"
  18. v3 "github.com/coreos/etcd/clientv3"
  19. )
  20. // STM is an interface for software transactional memory.
  21. type STM interface {
  22. // Get returns the value for a key and inserts the key in the txn's read set.
  23. // If Get fails, it aborts the transaction with an error, never returning.
  24. Get(key ...string) string
  25. // Put adds a value for a key to the write set.
  26. Put(key, val string, opts ...v3.OpOption)
  27. // Rev returns the revision of a key in the read set.
  28. Rev(key string) int64
  29. // Del deletes a key.
  30. Del(key string)
  31. // commit attempts to apply the txn's changes to the server.
  32. commit() *v3.TxnResponse
  33. reset()
  34. }
  35. // Isolation is an enumeration of transactional isolation levels which
  36. // describes how transactions should interfere and conflict.
  37. type Isolation int
  38. const (
  39. // SerializableSnapshot provides serializable isolation and also checks
  40. // for write conflicts.
  41. SerializableSnapshot Isolation = iota
  42. // Serializable reads within the same transaction attempt return data
  43. // from the at the revision of the first read.
  44. Serializable
  45. // RepeatableReads reads within the same transaction attempt always
  46. // return the same data.
  47. RepeatableReads
  48. // ReadCommitted reads keys from any committed revision.
  49. ReadCommitted
  50. )
  51. // stmError safely passes STM errors through panic to the STM error channel.
  52. type stmError struct{ err error }
  53. type stmOptions struct {
  54. iso Isolation
  55. ctx context.Context
  56. prefetch []string
  57. }
  58. type stmOption func(*stmOptions)
  59. // WithIsolation specifies the transaction isolation level.
  60. func WithIsolation(lvl Isolation) stmOption {
  61. return func(so *stmOptions) { so.iso = lvl }
  62. }
  63. // WithAbortContext specifies the context for permanently aborting the transaction.
  64. func WithAbortContext(ctx context.Context) stmOption {
  65. return func(so *stmOptions) { so.ctx = ctx }
  66. }
  67. // WithPrefetch is a hint to prefetch a list of keys before trying to apply.
  68. // If an STM transaction will unconditionally fetch a set of keys, prefetching
  69. // those keys will save the round-trip cost from requesting each key one by one
  70. // with Get().
  71. func WithPrefetch(keys ...string) stmOption {
  72. return func(so *stmOptions) { so.prefetch = append(so.prefetch, keys...) }
  73. }
  74. // NewSTM initiates a new STM instance, using serializable snapshot isolation by default.
  75. func NewSTM(c *v3.Client, apply func(STM) error, so ...stmOption) (*v3.TxnResponse, error) {
  76. opts := &stmOptions{ctx: c.Ctx()}
  77. for _, f := range so {
  78. f(opts)
  79. }
  80. if len(opts.prefetch) != 0 {
  81. f := apply
  82. apply = func(s STM) error {
  83. s.Get(opts.prefetch...)
  84. return f(s)
  85. }
  86. }
  87. return runSTM(mkSTM(c, opts), apply)
  88. }
  89. func mkSTM(c *v3.Client, opts *stmOptions) STM {
  90. switch opts.iso {
  91. case SerializableSnapshot:
  92. s := &stmSerializable{
  93. stm: stm{client: c, ctx: opts.ctx},
  94. prefetch: make(map[string]*v3.GetResponse),
  95. }
  96. s.conflicts = func() []v3.Cmp {
  97. return append(s.rset.cmps(), s.wset.cmps(s.rset.first()+1)...)
  98. }
  99. return s
  100. case Serializable:
  101. s := &stmSerializable{
  102. stm: stm{client: c, ctx: opts.ctx},
  103. prefetch: make(map[string]*v3.GetResponse),
  104. }
  105. s.conflicts = func() []v3.Cmp { return s.rset.cmps() }
  106. return s
  107. case RepeatableReads:
  108. s := &stm{client: c, ctx: opts.ctx, getOpts: []v3.OpOption{v3.WithSerializable()}}
  109. s.conflicts = func() []v3.Cmp { return s.rset.cmps() }
  110. return s
  111. case ReadCommitted:
  112. s := &stm{client: c, ctx: opts.ctx, getOpts: []v3.OpOption{v3.WithSerializable()}}
  113. s.conflicts = func() []v3.Cmp { return nil }
  114. return s
  115. default:
  116. panic("unsupported stm")
  117. }
  118. }
  119. type stmResponse struct {
  120. resp *v3.TxnResponse
  121. err error
  122. }
  123. func runSTM(s STM, apply func(STM) error) (*v3.TxnResponse, error) {
  124. outc := make(chan stmResponse, 1)
  125. go func() {
  126. defer func() {
  127. if r := recover(); r != nil {
  128. e, ok := r.(stmError)
  129. if !ok {
  130. // client apply panicked
  131. panic(r)
  132. }
  133. outc <- stmResponse{nil, e.err}
  134. }
  135. }()
  136. var out stmResponse
  137. for {
  138. s.reset()
  139. if out.err = apply(s); out.err != nil {
  140. break
  141. }
  142. if out.resp = s.commit(); out.resp != nil {
  143. break
  144. }
  145. }
  146. outc <- out
  147. }()
  148. r := <-outc
  149. return r.resp, r.err
  150. }
  151. // stm implements repeatable-read software transactional memory over etcd
  152. type stm struct {
  153. client *v3.Client
  154. ctx context.Context
  155. // rset holds read key values and revisions
  156. rset readSet
  157. // wset holds overwritten keys and their values
  158. wset writeSet
  159. // getOpts are the opts used for gets
  160. getOpts []v3.OpOption
  161. // conflicts computes the current conflicts on the txn
  162. conflicts func() []v3.Cmp
  163. }
  164. type stmPut struct {
  165. val string
  166. op v3.Op
  167. }
  168. type readSet map[string]*v3.GetResponse
  169. func (rs readSet) add(keys []string, txnresp *v3.TxnResponse) {
  170. for i, resp := range txnresp.Responses {
  171. rs[keys[i]] = (*v3.GetResponse)(resp.GetResponseRange())
  172. }
  173. }
  174. // first returns the store revision from the first fetch
  175. func (rs readSet) first() int64 {
  176. ret := int64(math.MaxInt64 - 1)
  177. for _, resp := range rs {
  178. if rev := resp.Header.Revision; rev < ret {
  179. ret = rev
  180. }
  181. }
  182. return ret
  183. }
  184. // cmps guards the txn from updates to read set
  185. func (rs readSet) cmps() []v3.Cmp {
  186. cmps := make([]v3.Cmp, 0, len(rs))
  187. for k, rk := range rs {
  188. cmps = append(cmps, isKeyCurrent(k, rk))
  189. }
  190. return cmps
  191. }
  192. type writeSet map[string]stmPut
  193. func (ws writeSet) get(keys ...string) *stmPut {
  194. for _, key := range keys {
  195. if wv, ok := ws[key]; ok {
  196. return &wv
  197. }
  198. }
  199. return nil
  200. }
  201. // cmps returns a cmp list testing no writes have happened past rev
  202. func (ws writeSet) cmps(rev int64) []v3.Cmp {
  203. cmps := make([]v3.Cmp, 0, len(ws))
  204. for key := range ws {
  205. cmps = append(cmps, v3.Compare(v3.ModRevision(key), "<", rev))
  206. }
  207. return cmps
  208. }
  209. // puts is the list of ops for all pending writes
  210. func (ws writeSet) puts() []v3.Op {
  211. puts := make([]v3.Op, 0, len(ws))
  212. for _, v := range ws {
  213. puts = append(puts, v.op)
  214. }
  215. return puts
  216. }
  217. func (s *stm) Get(keys ...string) string {
  218. if wv := s.wset.get(keys...); wv != nil {
  219. return wv.val
  220. }
  221. return respToValue(s.fetch(keys...))
  222. }
  223. func (s *stm) Put(key, val string, opts ...v3.OpOption) {
  224. s.wset[key] = stmPut{val, v3.OpPut(key, val, opts...)}
  225. }
  226. func (s *stm) Del(key string) { s.wset[key] = stmPut{"", v3.OpDelete(key)} }
  227. func (s *stm) Rev(key string) int64 {
  228. if resp := s.fetch(key); resp != nil && len(resp.Kvs) != 0 {
  229. return resp.Kvs[0].ModRevision
  230. }
  231. return 0
  232. }
  233. func (s *stm) commit() *v3.TxnResponse {
  234. txnresp, err := s.client.Txn(s.ctx).If(s.conflicts()...).Then(s.wset.puts()...).Commit()
  235. if err != nil {
  236. panic(stmError{err})
  237. }
  238. if txnresp.Succeeded {
  239. return txnresp
  240. }
  241. return nil
  242. }
  243. func (s *stm) fetch(keys ...string) *v3.GetResponse {
  244. if len(keys) == 0 {
  245. return nil
  246. }
  247. ops := make([]v3.Op, len(keys))
  248. for i, key := range keys {
  249. if resp, ok := s.rset[key]; ok {
  250. return resp
  251. }
  252. ops[i] = v3.OpGet(key, s.getOpts...)
  253. }
  254. txnresp, err := s.client.Txn(s.ctx).Then(ops...).Commit()
  255. if err != nil {
  256. panic(stmError{err})
  257. }
  258. s.rset.add(keys, txnresp)
  259. return (*v3.GetResponse)(txnresp.Responses[0].GetResponseRange())
  260. }
  261. func (s *stm) reset() {
  262. s.rset = make(map[string]*v3.GetResponse)
  263. s.wset = make(map[string]stmPut)
  264. }
  265. type stmSerializable struct {
  266. stm
  267. prefetch map[string]*v3.GetResponse
  268. }
  269. func (s *stmSerializable) Get(keys ...string) string {
  270. if wv := s.wset.get(keys...); wv != nil {
  271. return wv.val
  272. }
  273. firstRead := len(s.rset) == 0
  274. for _, key := range keys {
  275. if resp, ok := s.prefetch[key]; ok {
  276. delete(s.prefetch, key)
  277. s.rset[key] = resp
  278. }
  279. }
  280. resp := s.stm.fetch(keys...)
  281. if firstRead {
  282. // txn's base revision is defined by the first read
  283. s.getOpts = []v3.OpOption{
  284. v3.WithRev(resp.Header.Revision),
  285. v3.WithSerializable(),
  286. }
  287. }
  288. return respToValue(resp)
  289. }
  290. func (s *stmSerializable) Rev(key string) int64 {
  291. s.Get(key)
  292. return s.stm.Rev(key)
  293. }
  294. func (s *stmSerializable) gets() ([]string, []v3.Op) {
  295. keys := make([]string, 0, len(s.rset))
  296. ops := make([]v3.Op, 0, len(s.rset))
  297. for k := range s.rset {
  298. keys = append(keys, k)
  299. ops = append(ops, v3.OpGet(k))
  300. }
  301. return keys, ops
  302. }
  303. func (s *stmSerializable) commit() *v3.TxnResponse {
  304. keys, getops := s.gets()
  305. txn := s.client.Txn(s.ctx).If(s.conflicts()...).Then(s.wset.puts()...)
  306. // use Else to prefetch keys in case of conflict to save a round trip
  307. txnresp, err := txn.Else(getops...).Commit()
  308. if err != nil {
  309. panic(stmError{err})
  310. }
  311. if txnresp.Succeeded {
  312. return txnresp
  313. }
  314. // load prefetch with Else data
  315. s.rset.add(keys, txnresp)
  316. s.prefetch = s.rset
  317. s.getOpts = nil
  318. return nil
  319. }
  320. func isKeyCurrent(k string, r *v3.GetResponse) v3.Cmp {
  321. if len(r.Kvs) != 0 {
  322. return v3.Compare(v3.ModRevision(k), "=", r.Kvs[0].ModRevision)
  323. }
  324. return v3.Compare(v3.ModRevision(k), "=", 0)
  325. }
  326. func respToValue(resp *v3.GetResponse) string {
  327. if resp == nil || len(resp.Kvs) == 0 {
  328. return ""
  329. }
  330. return string(resp.Kvs[0].Value)
  331. }
  332. // NewSTMRepeatable is deprecated.
  333. func NewSTMRepeatable(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) {
  334. return NewSTM(c, apply, WithAbortContext(ctx), WithIsolation(RepeatableReads))
  335. }
  336. // NewSTMSerializable is deprecated.
  337. func NewSTMSerializable(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) {
  338. return NewSTM(c, apply, WithAbortContext(ctx), WithIsolation(Serializable))
  339. }
  340. // NewSTMReadCommitted is deprecated.
  341. func NewSTMReadCommitted(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) {
  342. return NewSTM(c, apply, WithAbortContext(ctx), WithIsolation(ReadCommitted))
  343. }