stm.go 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370
  1. // Copyright 2016 The etcd Authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package concurrency
  15. import (
  16. "math"
  17. v3 "github.com/coreos/etcd/clientv3"
  18. "golang.org/x/net/context"
  19. )
  20. // STM is an interface for software transactional memory.
  21. type STM interface {
  22. // Get returns the value for a key and inserts the key in the txn's read set.
  23. // If Get fails, it aborts the transaction with an error, never returning.
  24. Get(key ...string) string
  25. // Put adds a value for a key to the write set.
  26. Put(key, val string, opts ...v3.OpOption)
  27. // Rev returns the revision of a key in the read set.
  28. Rev(key string) int64
  29. // Del deletes a key.
  30. Del(key string)
  31. // commit attempts to apply the txn's changes to the server.
  32. commit() *v3.TxnResponse
  33. reset()
  34. }
  35. // Isolation is an enumeration of transactional isolation levels which
  36. // describes how transactions should interfere and conflict.
  37. type Isolation int
  38. const (
  39. // Snapshot is serializable but also checks writes for conflicts.
  40. Snapshot Isolation = iota
  41. // Serializable reads within the same transactiona attempt return data
  42. // from the at the revision of the first read.
  43. Serializable
  44. // RepeatableReads reads within the same transaction attempt always
  45. // return the same data.
  46. RepeatableReads
  47. // ReadCommitted reads keys from any committed revision.
  48. ReadCommitted
  49. )
  50. // stmError safely passes STM errors through panic to the STM error channel.
  51. type stmError struct{ err error }
  52. type stmOptions struct {
  53. iso Isolation
  54. ctx context.Context
  55. prefetch []string
  56. }
  57. type stmOption func(*stmOptions)
  58. // WithIsolation specifies the transaction isolation level.
  59. func WithIsolation(lvl Isolation) stmOption {
  60. return func(so *stmOptions) { so.iso = lvl }
  61. }
  62. // WithAbortContext specifies the context for permanently aborting the transaction.
  63. func WithAbortContext(ctx context.Context) stmOption {
  64. return func(so *stmOptions) { so.ctx = ctx }
  65. }
  66. // WithPrefetch is a hint to prefetch a list of keys before trying to apply.
  67. // If an STM transaction will unconditionally fetch a set of keys, prefetching
  68. // those keys will save the round-trip cost from requesting each key one by one
  69. // with Get().
  70. func WithPrefetch(keys ...string) stmOption {
  71. return func(so *stmOptions) { so.prefetch = append(so.prefetch, keys...) }
  72. }
  73. // NewSTM initiates a new STM instance, using snapshot isolation by default.
  74. func NewSTM(c *v3.Client, apply func(STM) error, so ...stmOption) (*v3.TxnResponse, error) {
  75. opts := &stmOptions{ctx: c.Ctx()}
  76. for _, f := range so {
  77. f(opts)
  78. }
  79. if len(opts.prefetch) != 0 {
  80. f := apply
  81. apply = func(s STM) error {
  82. s.Get(opts.prefetch...)
  83. return f(s)
  84. }
  85. }
  86. return runSTM(mkSTM(c, opts), apply)
  87. }
  88. func mkSTM(c *v3.Client, opts *stmOptions) STM {
  89. switch opts.iso {
  90. case Snapshot:
  91. s := &stmSerializable{
  92. stm: stm{client: c, ctx: opts.ctx},
  93. prefetch: make(map[string]*v3.GetResponse),
  94. }
  95. s.conflicts = func() []v3.Cmp {
  96. return append(s.rset.cmps(), s.wset.cmps(s.rset.first()+1)...)
  97. }
  98. return s
  99. case Serializable:
  100. s := &stmSerializable{
  101. stm: stm{client: c, ctx: opts.ctx},
  102. prefetch: make(map[string]*v3.GetResponse),
  103. }
  104. s.conflicts = func() []v3.Cmp { return s.rset.cmps() }
  105. return s
  106. case RepeatableReads:
  107. s := &stm{client: c, ctx: opts.ctx, getOpts: []v3.OpOption{v3.WithSerializable()}}
  108. s.conflicts = func() []v3.Cmp { return s.rset.cmps() }
  109. return s
  110. case ReadCommitted:
  111. s := &stm{client: c, ctx: opts.ctx, getOpts: []v3.OpOption{v3.WithSerializable()}}
  112. s.conflicts = func() []v3.Cmp { return nil }
  113. return s
  114. default:
  115. panic("unsupported stm")
  116. }
  117. }
  118. type stmResponse struct {
  119. resp *v3.TxnResponse
  120. err error
  121. }
  122. func runSTM(s STM, apply func(STM) error) (*v3.TxnResponse, error) {
  123. outc := make(chan stmResponse, 1)
  124. go func() {
  125. defer func() {
  126. if r := recover(); r != nil {
  127. e, ok := r.(stmError)
  128. if !ok {
  129. // client apply panicked
  130. panic(r)
  131. }
  132. outc <- stmResponse{nil, e.err}
  133. }
  134. }()
  135. var out stmResponse
  136. for {
  137. s.reset()
  138. if out.err = apply(s); out.err != nil {
  139. break
  140. }
  141. if out.resp = s.commit(); out.resp != nil {
  142. break
  143. }
  144. }
  145. outc <- out
  146. }()
  147. r := <-outc
  148. return r.resp, r.err
  149. }
  150. // stm implements repeatable-read software transactional memory over etcd
  151. type stm struct {
  152. client *v3.Client
  153. ctx context.Context
  154. // rset holds read key values and revisions
  155. rset readSet
  156. // wset holds overwritten keys and their values
  157. wset writeSet
  158. // getOpts are the opts used for gets
  159. getOpts []v3.OpOption
  160. // conflicts computes the current conflicts on the txn
  161. conflicts func() []v3.Cmp
  162. }
  163. type stmPut struct {
  164. val string
  165. op v3.Op
  166. }
  167. type readSet map[string]*v3.GetResponse
  168. func (rs readSet) add(keys []string, txnresp *v3.TxnResponse) {
  169. for i, resp := range txnresp.Responses {
  170. rs[keys[i]] = (*v3.GetResponse)(resp.GetResponseRange())
  171. }
  172. }
  173. func (rs readSet) first() int64 {
  174. ret := int64(math.MaxInt64 - 1)
  175. for _, resp := range rs {
  176. if len(resp.Kvs) > 0 && resp.Kvs[0].ModRevision < ret {
  177. ret = resp.Kvs[0].ModRevision
  178. }
  179. }
  180. return ret
  181. }
  182. // cmps guards the txn from updates to read set
  183. func (rs readSet) cmps() []v3.Cmp {
  184. cmps := make([]v3.Cmp, 0, len(rs))
  185. for k, rk := range rs {
  186. cmps = append(cmps, isKeyCurrent(k, rk))
  187. }
  188. return cmps
  189. }
  190. type writeSet map[string]stmPut
  191. func (ws writeSet) get(keys ...string) *stmPut {
  192. for _, key := range keys {
  193. if wv, ok := ws[key]; ok {
  194. return &wv
  195. }
  196. }
  197. return nil
  198. }
  199. // cmps returns a cmp list testing no writes have happened past rev
  200. func (ws writeSet) cmps(rev int64) []v3.Cmp {
  201. cmps := make([]v3.Cmp, 0, len(ws))
  202. for key := range ws {
  203. cmps = append(cmps, v3.Compare(v3.ModRevision(key), "<", rev))
  204. }
  205. return cmps
  206. }
  207. // puts is the list of ops for all pending writes
  208. func (ws writeSet) puts() []v3.Op {
  209. puts := make([]v3.Op, 0, len(ws))
  210. for _, v := range ws {
  211. puts = append(puts, v.op)
  212. }
  213. return puts
  214. }
  215. func (s *stm) Get(keys ...string) string {
  216. if wv := s.wset.get(keys...); wv != nil {
  217. return wv.val
  218. }
  219. return respToValue(s.fetch(keys...))
  220. }
  221. func (s *stm) Put(key, val string, opts ...v3.OpOption) {
  222. s.wset[key] = stmPut{val, v3.OpPut(key, val, opts...)}
  223. }
  224. func (s *stm) Del(key string) { s.wset[key] = stmPut{"", v3.OpDelete(key)} }
  225. func (s *stm) Rev(key string) int64 {
  226. if resp := s.fetch(key); resp != nil && len(resp.Kvs) != 0 {
  227. return resp.Kvs[0].ModRevision
  228. }
  229. return 0
  230. }
  231. func (s *stm) commit() *v3.TxnResponse {
  232. txnresp, err := s.client.Txn(s.ctx).If(s.conflicts()...).Then(s.wset.puts()...).Commit()
  233. if err != nil {
  234. panic(stmError{err})
  235. }
  236. if txnresp.Succeeded {
  237. return txnresp
  238. }
  239. return nil
  240. }
  241. func (s *stm) fetch(keys ...string) *v3.GetResponse {
  242. if len(keys) == 0 {
  243. return nil
  244. }
  245. ops := make([]v3.Op, len(keys))
  246. for i, key := range keys {
  247. if resp, ok := s.rset[key]; ok {
  248. return resp
  249. }
  250. ops[i] = v3.OpGet(key, s.getOpts...)
  251. }
  252. txnresp, err := s.client.Txn(s.ctx).Then(ops...).Commit()
  253. if err != nil {
  254. panic(stmError{err})
  255. }
  256. s.rset.add(keys, txnresp)
  257. return (*v3.GetResponse)(txnresp.Responses[0].GetResponseRange())
  258. }
  259. func (s *stm) reset() {
  260. s.rset = make(map[string]*v3.GetResponse)
  261. s.wset = make(map[string]stmPut)
  262. }
  263. type stmSerializable struct {
  264. stm
  265. prefetch map[string]*v3.GetResponse
  266. }
  267. func (s *stmSerializable) Get(keys ...string) string {
  268. if wv := s.wset.get(keys...); wv != nil {
  269. return wv.val
  270. }
  271. firstRead := len(s.rset) == 0
  272. for _, key := range keys {
  273. if resp, ok := s.prefetch[key]; ok {
  274. delete(s.prefetch, key)
  275. s.rset[key] = resp
  276. }
  277. }
  278. resp := s.stm.fetch(keys...)
  279. if firstRead {
  280. // txn's base revision is defined by the first read
  281. s.getOpts = []v3.OpOption{
  282. v3.WithRev(resp.Header.Revision),
  283. v3.WithSerializable(),
  284. }
  285. }
  286. return respToValue(resp)
  287. }
  288. func (s *stmSerializable) Rev(key string) int64 {
  289. s.Get(key)
  290. return s.stm.Rev(key)
  291. }
  292. func (s *stmSerializable) gets() ([]string, []v3.Op) {
  293. keys := make([]string, 0, len(s.rset))
  294. ops := make([]v3.Op, 0, len(s.rset))
  295. for k := range s.rset {
  296. keys = append(keys, k)
  297. ops = append(ops, v3.OpGet(k))
  298. }
  299. return keys, ops
  300. }
  301. func (s *stmSerializable) commit() *v3.TxnResponse {
  302. keys, getops := s.gets()
  303. txn := s.client.Txn(s.ctx).If(s.conflicts()...).Then(s.wset.puts()...)
  304. // use Else to prefetch keys in case of conflict to save a round trip
  305. txnresp, err := txn.Else(getops...).Commit()
  306. if err != nil {
  307. panic(stmError{err})
  308. }
  309. if txnresp.Succeeded {
  310. return txnresp
  311. }
  312. // load prefetch with Else data
  313. s.rset.add(keys, txnresp)
  314. s.prefetch = s.rset
  315. s.getOpts = nil
  316. return nil
  317. }
  318. func isKeyCurrent(k string, r *v3.GetResponse) v3.Cmp {
  319. if len(r.Kvs) != 0 {
  320. return v3.Compare(v3.ModRevision(k), "=", r.Kvs[0].ModRevision)
  321. }
  322. return v3.Compare(v3.ModRevision(k), "=", 0)
  323. }
  324. func respToValue(resp *v3.GetResponse) string {
  325. if resp == nil || len(resp.Kvs) == 0 {
  326. return ""
  327. }
  328. return string(resp.Kvs[0].Value)
  329. }