stm.go 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. // Copyright 2016 CoreOS, Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package concurrency
  15. import (
  16. "github.com/coreos/etcd/Godeps/_workspace/src/golang.org/x/net/context"
  17. v3 "github.com/coreos/etcd/clientv3"
  18. )
  19. // STM is an interface for software transactional memory.
  20. type STM interface {
  21. // Get returns the value for a key and inserts the key in the txn's read set.
  22. // If Get fails, it aborts the transaction with an error, never returning.
  23. Get(key string) string
  24. // Put adds a value for a key to the write set.
  25. Put(key, val string, opts ...v3.OpOption)
  26. // Rev returns the revision of a key in the read set.
  27. Rev(key string) int64
  28. // Del deletes a key.
  29. Del(key string)
  30. // commit attempts to apply the txn's changes to the server.
  31. commit() *v3.TxnResponse
  32. reset()
  33. }
  34. // stmError safely passes STM errors through panic to the STM error channel.
  35. type stmError struct{ err error }
  36. // NewSTMRepeatable initiates new repeatable read transaction; reads within
  37. // the same transaction attempt always return the same data.
  38. func NewSTMRepeatable(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) {
  39. s := &stm{client: c, ctx: ctx, getOpts: []v3.OpOption{v3.WithSerializable()}}
  40. return runSTM(s, apply)
  41. }
  42. // NewSTMSerializable initiates a new serialized transaction; reads within the
  43. // same transactiona attempt return data from the revision of the first read.
  44. func NewSTMSerializable(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) {
  45. s := &stmSerializable{
  46. stm: stm{client: c, ctx: ctx},
  47. prefetch: make(map[string]*v3.GetResponse),
  48. }
  49. return runSTM(s, apply)
  50. }
  51. type stmResponse struct {
  52. resp *v3.TxnResponse
  53. err error
  54. }
  55. func runSTM(s STM, apply func(STM) error) (*v3.TxnResponse, error) {
  56. outc := make(chan stmResponse, 1)
  57. go func() {
  58. defer func() {
  59. if r := recover(); r != nil {
  60. e, ok := r.(stmError)
  61. if !ok {
  62. // client apply panicked
  63. panic(r)
  64. }
  65. outc <- stmResponse{nil, e.err}
  66. }
  67. }()
  68. var out stmResponse
  69. for {
  70. s.reset()
  71. if out.err = apply(s); out.err != nil {
  72. break
  73. }
  74. if out.resp = s.commit(); out.resp != nil {
  75. break
  76. }
  77. }
  78. outc <- out
  79. }()
  80. r := <-outc
  81. return r.resp, r.err
  82. }
  83. // stm implements repeatable-read software transactional memory over etcd
  84. type stm struct {
  85. client *v3.Client
  86. ctx context.Context
  87. // rset holds read key values and revisions
  88. rset map[string]*v3.GetResponse
  89. // wset holds overwritten keys and their values
  90. wset map[string]stmPut
  91. // getOpts are the opts used for gets
  92. getOpts []v3.OpOption
  93. }
  94. type stmPut struct {
  95. val string
  96. op v3.Op
  97. }
  98. func (s *stm) Get(key string) string {
  99. if wv, ok := s.wset[key]; ok {
  100. return wv.val
  101. }
  102. return respToValue(s.fetch(key))
  103. }
  104. func (s *stm) Put(key, val string, opts ...v3.OpOption) {
  105. s.wset[key] = stmPut{val, v3.OpPut(key, val, opts...)}
  106. }
  107. func (s *stm) Del(key string) { s.wset[key] = stmPut{"", v3.OpDelete(key)} }
  108. func (s *stm) Rev(key string) int64 {
  109. if resp := s.fetch(key); resp != nil && len(resp.Kvs) != 0 {
  110. return resp.Kvs[0].ModRevision
  111. }
  112. return 0
  113. }
  114. func (s *stm) commit() *v3.TxnResponse {
  115. txnresp, err := s.client.Txn(s.ctx).If(s.cmps()...).Then(s.puts()...).Commit()
  116. if err != nil {
  117. panic(stmError{err})
  118. }
  119. if txnresp.Succeeded {
  120. return txnresp
  121. }
  122. return nil
  123. }
  124. // cmps guards the txn from updates to read set
  125. func (s *stm) cmps() (cmps []v3.Cmp) {
  126. for k, rk := range s.rset {
  127. cmps = append(cmps, isKeyCurrent(k, rk))
  128. }
  129. return
  130. }
  131. func (s *stm) fetch(key string) *v3.GetResponse {
  132. if resp, ok := s.rset[key]; ok {
  133. return resp
  134. }
  135. resp, err := s.client.Get(s.ctx, key, s.getOpts...)
  136. if err != nil {
  137. panic(stmError{err})
  138. }
  139. s.rset[key] = resp
  140. return resp
  141. }
  142. // puts is the list of ops for all pending writes
  143. func (s *stm) puts() (puts []v3.Op) {
  144. for _, v := range s.wset {
  145. puts = append(puts, v.op)
  146. }
  147. return
  148. }
  149. func (s *stm) reset() {
  150. s.rset = make(map[string]*v3.GetResponse)
  151. s.wset = make(map[string]stmPut)
  152. }
  153. type stmSerializable struct {
  154. stm
  155. prefetch map[string]*v3.GetResponse
  156. }
  157. func (s *stmSerializable) Get(key string) string {
  158. if wv, ok := s.wset[key]; ok {
  159. return wv.val
  160. }
  161. firstRead := len(s.rset) == 0
  162. if resp, ok := s.prefetch[key]; ok {
  163. delete(s.prefetch, key)
  164. s.rset[key] = resp
  165. }
  166. resp := s.stm.fetch(key)
  167. if firstRead {
  168. // txn's base revision is defined by the first read
  169. s.getOpts = []v3.OpOption{
  170. v3.WithRev(resp.Header.Revision),
  171. v3.WithSerializable(),
  172. }
  173. }
  174. return respToValue(resp)
  175. }
  176. func (s *stmSerializable) Rev(key string) int64 {
  177. s.Get(key)
  178. return s.stm.Rev(key)
  179. }
  180. func (s *stmSerializable) gets() (keys []string, ops []v3.Op) {
  181. for k := range s.rset {
  182. keys = append(keys, k)
  183. ops = append(ops, v3.OpGet(k))
  184. }
  185. return
  186. }
  187. func (s *stmSerializable) commit() *v3.TxnResponse {
  188. keys, getops := s.gets()
  189. txn := s.client.Txn(s.ctx).If(s.cmps()...).Then(s.puts()...)
  190. // use Else to prefetch keys in case of conflict to save a round trip
  191. txnresp, err := txn.Else(getops...).Commit()
  192. if err != nil {
  193. panic(stmError{err})
  194. }
  195. if txnresp.Succeeded {
  196. return txnresp
  197. }
  198. // load prefetch with Else data
  199. for i := range keys {
  200. resp := txnresp.Responses[i].GetResponseRange()
  201. s.rset[keys[i]] = (*v3.GetResponse)(resp)
  202. }
  203. s.prefetch = s.rset
  204. s.getOpts = nil
  205. return nil
  206. }
  207. func isKeyCurrent(k string, r *v3.GetResponse) v3.Cmp {
  208. rev := r.Header.Revision + 1
  209. if len(r.Kvs) != 0 {
  210. rev = r.Kvs[0].ModRevision + 1
  211. }
  212. return v3.Compare(v3.ModifiedRevision(k), "<", rev)
  213. }
  214. func respToValue(resp *v3.GetResponse) string {
  215. if len(resp.Kvs) == 0 {
  216. return ""
  217. }
  218. return string(resp.Kvs[0].Value)
  219. }