stm.go 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386
  1. // Copyright 2016 The etcd Authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package concurrency
  15. import (
  16. "math"
  17. v3 "github.com/coreos/etcd/clientv3"
  18. "golang.org/x/net/context"
  19. )
  20. // STM is an interface for software transactional memory.
  21. type STM interface {
  22. // Get returns the value for a key and inserts the key in the txn's read set.
  23. // If Get fails, it aborts the transaction with an error, never returning.
  24. Get(key ...string) string
  25. // Put adds a value for a key to the write set.
  26. Put(key, val string, opts ...v3.OpOption)
  27. // Rev returns the revision of a key in the read set.
  28. Rev(key string) int64
  29. // Del deletes a key.
  30. Del(key string)
  31. // commit attempts to apply the txn's changes to the server.
  32. commit() *v3.TxnResponse
  33. reset()
  34. }
  35. // Isolation is an enumeration of transactional isolation levels which
  36. // describes how transactions should interfere and conflict.
  37. type Isolation int
  38. const (
  39. // SerializableSnapshot provides serializable isolation and also checks
  40. // for write conflicts.
  41. SerializableSnapshot Isolation = iota
  42. // Serializable reads within the same transactiona attempt return data
  43. // from the at the revision of the first read.
  44. Serializable
  45. // RepeatableReads reads within the same transaction attempt always
  46. // return the same data.
  47. RepeatableReads
  48. // ReadCommitted reads keys from any committed revision.
  49. ReadCommitted
  50. )
  51. // stmError safely passes STM errors through panic to the STM error channel.
  52. type stmError struct{ err error }
  53. type stmOptions struct {
  54. iso Isolation
  55. ctx context.Context
  56. prefetch []string
  57. }
  58. type stmOption func(*stmOptions)
  59. // WithIsolation specifies the transaction isolation level.
  60. func WithIsolation(lvl Isolation) stmOption {
  61. return func(so *stmOptions) { so.iso = lvl }
  62. }
  63. // WithAbortContext specifies the context for permanently aborting the transaction.
  64. func WithAbortContext(ctx context.Context) stmOption {
  65. return func(so *stmOptions) { so.ctx = ctx }
  66. }
  67. // WithPrefetch is a hint to prefetch a list of keys before trying to apply.
  68. // If an STM transaction will unconditionally fetch a set of keys, prefetching
  69. // those keys will save the round-trip cost from requesting each key one by one
  70. // with Get().
  71. func WithPrefetch(keys ...string) stmOption {
  72. return func(so *stmOptions) { so.prefetch = append(so.prefetch, keys...) }
  73. }
  74. // NewSTM initiates a new STM instance, using serializable snapshot isolation by default.
  75. func NewSTM(c *v3.Client, apply func(STM) error, so ...stmOption) (*v3.TxnResponse, error) {
  76. opts := &stmOptions{ctx: c.Ctx()}
  77. for _, f := range so {
  78. f(opts)
  79. }
  80. if len(opts.prefetch) != 0 {
  81. f := apply
  82. apply = func(s STM) error {
  83. s.Get(opts.prefetch...)
  84. return f(s)
  85. }
  86. }
  87. return runSTM(mkSTM(c, opts), apply)
  88. }
  89. func mkSTM(c *v3.Client, opts *stmOptions) STM {
  90. switch opts.iso {
  91. case SerializableSnapshot:
  92. s := &stmSerializable{
  93. stm: stm{client: c, ctx: opts.ctx},
  94. prefetch: make(map[string]*v3.GetResponse),
  95. }
  96. s.conflicts = func() []v3.Cmp {
  97. return append(s.rset.cmps(), s.wset.cmps(s.rset.first()+1)...)
  98. }
  99. return s
  100. case Serializable:
  101. s := &stmSerializable{
  102. stm: stm{client: c, ctx: opts.ctx},
  103. prefetch: make(map[string]*v3.GetResponse),
  104. }
  105. s.conflicts = func() []v3.Cmp { return s.rset.cmps() }
  106. return s
  107. case RepeatableReads:
  108. s := &stm{client: c, ctx: opts.ctx, getOpts: []v3.OpOption{v3.WithSerializable()}}
  109. s.conflicts = func() []v3.Cmp { return s.rset.cmps() }
  110. return s
  111. case ReadCommitted:
  112. s := &stm{client: c, ctx: opts.ctx, getOpts: []v3.OpOption{v3.WithSerializable()}}
  113. s.conflicts = func() []v3.Cmp { return nil }
  114. return s
  115. default:
  116. panic("unsupported stm")
  117. }
  118. }
  119. type stmResponse struct {
  120. resp *v3.TxnResponse
  121. err error
  122. }
  123. func runSTM(s STM, apply func(STM) error) (*v3.TxnResponse, error) {
  124. outc := make(chan stmResponse, 1)
  125. go func() {
  126. defer func() {
  127. if r := recover(); r != nil {
  128. e, ok := r.(stmError)
  129. if !ok {
  130. // client apply panicked
  131. panic(r)
  132. }
  133. outc <- stmResponse{nil, e.err}
  134. }
  135. }()
  136. var out stmResponse
  137. for {
  138. s.reset()
  139. if out.err = apply(s); out.err != nil {
  140. break
  141. }
  142. if out.resp = s.commit(); out.resp != nil {
  143. break
  144. }
  145. }
  146. outc <- out
  147. }()
  148. r := <-outc
  149. return r.resp, r.err
  150. }
  151. // stm implements repeatable-read software transactional memory over etcd
  152. type stm struct {
  153. client *v3.Client
  154. ctx context.Context
  155. // rset holds read key values and revisions
  156. rset readSet
  157. // wset holds overwritten keys and their values
  158. wset writeSet
  159. // getOpts are the opts used for gets
  160. getOpts []v3.OpOption
  161. // conflicts computes the current conflicts on the txn
  162. conflicts func() []v3.Cmp
  163. }
// stmPut is a pending write held in the write set. Field order matters:
// Put and Del build it with positional literals.
type stmPut struct {
	// val is the value Get returns for the key while the write is pending.
	val string
	// op is the operation sent to the server for the key on commit.
	op v3.Op
}
  168. type readSet map[string]*v3.GetResponse
  169. func (rs readSet) add(keys []string, txnresp *v3.TxnResponse) {
  170. for i, resp := range txnresp.Responses {
  171. rs[keys[i]] = (*v3.GetResponse)(resp.GetResponseRange())
  172. }
  173. }
  174. func (rs readSet) first() int64 {
  175. ret := int64(math.MaxInt64 - 1)
  176. for _, resp := range rs {
  177. if len(resp.Kvs) > 0 && resp.Kvs[0].ModRevision < ret {
  178. ret = resp.Kvs[0].ModRevision
  179. }
  180. }
  181. return ret
  182. }
  183. // cmps guards the txn from updates to read set
  184. func (rs readSet) cmps() []v3.Cmp {
  185. cmps := make([]v3.Cmp, 0, len(rs))
  186. for k, rk := range rs {
  187. cmps = append(cmps, isKeyCurrent(k, rk))
  188. }
  189. return cmps
  190. }
  191. type writeSet map[string]stmPut
  192. func (ws writeSet) get(keys ...string) *stmPut {
  193. for _, key := range keys {
  194. if wv, ok := ws[key]; ok {
  195. return &wv
  196. }
  197. }
  198. return nil
  199. }
  200. // cmps returns a cmp list testing no writes have happened past rev
  201. func (ws writeSet) cmps(rev int64) []v3.Cmp {
  202. cmps := make([]v3.Cmp, 0, len(ws))
  203. for key := range ws {
  204. cmps = append(cmps, v3.Compare(v3.ModRevision(key), "<", rev))
  205. }
  206. return cmps
  207. }
  208. // puts is the list of ops for all pending writes
  209. func (ws writeSet) puts() []v3.Op {
  210. puts := make([]v3.Op, 0, len(ws))
  211. for _, v := range ws {
  212. puts = append(puts, v.op)
  213. }
  214. return puts
  215. }
  216. func (s *stm) Get(keys ...string) string {
  217. if wv := s.wset.get(keys...); wv != nil {
  218. return wv.val
  219. }
  220. return respToValue(s.fetch(keys...))
  221. }
  222. func (s *stm) Put(key, val string, opts ...v3.OpOption) {
  223. s.wset[key] = stmPut{val, v3.OpPut(key, val, opts...)}
  224. }
  225. func (s *stm) Del(key string) { s.wset[key] = stmPut{"", v3.OpDelete(key)} }
  226. func (s *stm) Rev(key string) int64 {
  227. if resp := s.fetch(key); resp != nil && len(resp.Kvs) != 0 {
  228. return resp.Kvs[0].ModRevision
  229. }
  230. return 0
  231. }
  232. func (s *stm) commit() *v3.TxnResponse {
  233. txnresp, err := s.client.Txn(s.ctx).If(s.conflicts()...).Then(s.wset.puts()...).Commit()
  234. if err != nil {
  235. panic(stmError{err})
  236. }
  237. if txnresp.Succeeded {
  238. return txnresp
  239. }
  240. return nil
  241. }
  242. func (s *stm) fetch(keys ...string) *v3.GetResponse {
  243. if len(keys) == 0 {
  244. return nil
  245. }
  246. ops := make([]v3.Op, len(keys))
  247. for i, key := range keys {
  248. if resp, ok := s.rset[key]; ok {
  249. return resp
  250. }
  251. ops[i] = v3.OpGet(key, s.getOpts...)
  252. }
  253. txnresp, err := s.client.Txn(s.ctx).Then(ops...).Commit()
  254. if err != nil {
  255. panic(stmError{err})
  256. }
  257. s.rset.add(keys, txnresp)
  258. return (*v3.GetResponse)(txnresp.Responses[0].GetResponseRange())
  259. }
  260. func (s *stm) reset() {
  261. s.rset = make(map[string]*v3.GetResponse)
  262. s.wset = make(map[string]stmPut)
  263. }
// stmSerializable layers serializable reads on top of stm. mkSTM uses it
// for the SerializableSnapshot and Serializable isolation levels.
type stmSerializable struct {
	stm
	// prefetch holds responses loaded by a failed commit; Get moves them
	// into the read set instead of fetching the keys again.
	prefetch map[string]*v3.GetResponse
}
  268. func (s *stmSerializable) Get(keys ...string) string {
  269. if wv := s.wset.get(keys...); wv != nil {
  270. return wv.val
  271. }
  272. firstRead := len(s.rset) == 0
  273. for _, key := range keys {
  274. if resp, ok := s.prefetch[key]; ok {
  275. delete(s.prefetch, key)
  276. s.rset[key] = resp
  277. }
  278. }
  279. resp := s.stm.fetch(keys...)
  280. if firstRead {
  281. // txn's base revision is defined by the first read
  282. s.getOpts = []v3.OpOption{
  283. v3.WithRev(resp.Header.Revision),
  284. v3.WithSerializable(),
  285. }
  286. }
  287. return respToValue(resp)
  288. }
  289. func (s *stmSerializable) Rev(key string) int64 {
  290. s.Get(key)
  291. return s.stm.Rev(key)
  292. }
  293. func (s *stmSerializable) gets() ([]string, []v3.Op) {
  294. keys := make([]string, 0, len(s.rset))
  295. ops := make([]v3.Op, 0, len(s.rset))
  296. for k := range s.rset {
  297. keys = append(keys, k)
  298. ops = append(ops, v3.OpGet(k))
  299. }
  300. return keys, ops
  301. }
  302. func (s *stmSerializable) commit() *v3.TxnResponse {
  303. keys, getops := s.gets()
  304. txn := s.client.Txn(s.ctx).If(s.conflicts()...).Then(s.wset.puts()...)
  305. // use Else to prefetch keys in case of conflict to save a round trip
  306. txnresp, err := txn.Else(getops...).Commit()
  307. if err != nil {
  308. panic(stmError{err})
  309. }
  310. if txnresp.Succeeded {
  311. return txnresp
  312. }
  313. // load prefetch with Else data
  314. s.rset.add(keys, txnresp)
  315. s.prefetch = s.rset
  316. s.getOpts = nil
  317. return nil
  318. }
  319. func isKeyCurrent(k string, r *v3.GetResponse) v3.Cmp {
  320. if len(r.Kvs) != 0 {
  321. return v3.Compare(v3.ModRevision(k), "=", r.Kvs[0].ModRevision)
  322. }
  323. return v3.Compare(v3.ModRevision(k), "=", 0)
  324. }
  325. func respToValue(resp *v3.GetResponse) string {
  326. if resp == nil || len(resp.Kvs) == 0 {
  327. return ""
  328. }
  329. return string(resp.Kvs[0].Value)
  330. }
  331. // NewSTMRepeatable is deprecated.
  332. func NewSTMRepeatable(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) {
  333. return NewSTM(c, apply, WithAbortContext(ctx), WithIsolation(RepeatableReads))
  334. }
  335. // NewSTMSerializable is deprecated.
  336. func NewSTMSerializable(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) {
  337. return NewSTM(c, apply, WithAbortContext(ctx), WithIsolation(Serializable))
  338. }
  339. // NewSTMReadCommitted is deprecated.
  340. func NewSTMReadCommitted(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) {
  341. return NewSTM(c, apply, WithAbortContext(ctx), WithIsolation(ReadCommitted))
  342. }