123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387 |
- // Copyright 2016 The etcd Authors
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- package concurrency
- import (
- "context"
- "math"
- v3 "go.etcd.io/etcd/clientv3"
- )
- // STM is an interface for software transactional memory.
- type STM interface {
- // Get returns the value for a key and inserts the key in the txn's read set.
- // If Get fails, it aborts the transaction with an error, never returning.
- Get(key ...string) string
- // Put adds a value for a key to the write set.
- Put(key, val string, opts ...v3.OpOption)
- // Rev returns the revision of a key in the read set.
- Rev(key string) int64
- // Del deletes a key.
- Del(key string)
- // commit attempts to apply the txn's changes to the server.
- commit() *v3.TxnResponse
- reset()
- }
// Isolation enumerates the transactional isolation levels, which describe
// how transactions should interfere and conflict.
type Isolation int

const (
	// SerializableSnapshot provides serializable isolation and also checks
	// for write conflicts. It is the zero value of Isolation.
	SerializableSnapshot = Isolation(iota)
	// Serializable reads within the same transaction attempt return data
	// from the revision of the first read.
	Serializable
	// RepeatableReads reads within the same transaction attempt always
	// return the same data.
	RepeatableReads
	// ReadCommitted reads keys from any committed revision.
	ReadCommitted
)
// stmError wraps an error so that it can be raised with panic and told apart
// from unrelated panics when it is recovered.
type stmError struct {
	err error
}
- type stmOptions struct {
- iso Isolation
- ctx context.Context
- prefetch []string
- }
- type stmOption func(*stmOptions)
- // WithIsolation specifies the transaction isolation level.
- func WithIsolation(lvl Isolation) stmOption {
- return func(so *stmOptions) { so.iso = lvl }
- }
- // WithAbortContext specifies the context for permanently aborting the transaction.
- func WithAbortContext(ctx context.Context) stmOption {
- return func(so *stmOptions) { so.ctx = ctx }
- }
- // WithPrefetch is a hint to prefetch a list of keys before trying to apply.
- // If an STM transaction will unconditionally fetch a set of keys, prefetching
- // those keys will save the round-trip cost from requesting each key one by one
- // with Get().
- func WithPrefetch(keys ...string) stmOption {
- return func(so *stmOptions) { so.prefetch = append(so.prefetch, keys...) }
- }
- // NewSTM initiates a new STM instance, using serializable snapshot isolation by default.
- func NewSTM(c *v3.Client, apply func(STM) error, so ...stmOption) (*v3.TxnResponse, error) {
- opts := &stmOptions{ctx: c.Ctx()}
- for _, f := range so {
- f(opts)
- }
- if len(opts.prefetch) != 0 {
- f := apply
- apply = func(s STM) error {
- s.Get(opts.prefetch...)
- return f(s)
- }
- }
- return runSTM(mkSTM(c, opts), apply)
- }
- func mkSTM(c *v3.Client, opts *stmOptions) STM {
- switch opts.iso {
- case SerializableSnapshot:
- s := &stmSerializable{
- stm: stm{client: c, ctx: opts.ctx},
- prefetch: make(map[string]*v3.GetResponse),
- }
- s.conflicts = func() []v3.Cmp {
- return append(s.rset.cmps(), s.wset.cmps(s.rset.first()+1)...)
- }
- return s
- case Serializable:
- s := &stmSerializable{
- stm: stm{client: c, ctx: opts.ctx},
- prefetch: make(map[string]*v3.GetResponse),
- }
- s.conflicts = func() []v3.Cmp { return s.rset.cmps() }
- return s
- case RepeatableReads:
- s := &stm{client: c, ctx: opts.ctx, getOpts: []v3.OpOption{v3.WithSerializable()}}
- s.conflicts = func() []v3.Cmp { return s.rset.cmps() }
- return s
- case ReadCommitted:
- s := &stm{client: c, ctx: opts.ctx, getOpts: []v3.OpOption{v3.WithSerializable()}}
- s.conflicts = func() []v3.Cmp { return nil }
- return s
- default:
- panic("unsupported stm")
- }
- }
- type stmResponse struct {
- resp *v3.TxnResponse
- err error
- }
- func runSTM(s STM, apply func(STM) error) (*v3.TxnResponse, error) {
- outc := make(chan stmResponse, 1)
- go func() {
- defer func() {
- if r := recover(); r != nil {
- e, ok := r.(stmError)
- if !ok {
- // client apply panicked
- panic(r)
- }
- outc <- stmResponse{nil, e.err}
- }
- }()
- var out stmResponse
- for {
- s.reset()
- if out.err = apply(s); out.err != nil {
- break
- }
- if out.resp = s.commit(); out.resp != nil {
- break
- }
- }
- outc <- out
- }()
- r := <-outc
- return r.resp, r.err
- }
- // stm implements repeatable-read software transactional memory over etcd
- type stm struct {
- client *v3.Client
- ctx context.Context
- // rset holds read key values and revisions
- rset readSet
- // wset holds overwritten keys and their values
- wset writeSet
- // getOpts are the opts used for gets
- getOpts []v3.OpOption
- // conflicts computes the current conflicts on the txn
- conflicts func() []v3.Cmp
- }
- type stmPut struct {
- val string
- op v3.Op
- }
- type readSet map[string]*v3.GetResponse
- func (rs readSet) add(keys []string, txnresp *v3.TxnResponse) {
- for i, resp := range txnresp.Responses {
- rs[keys[i]] = (*v3.GetResponse)(resp.GetResponseRange())
- }
- }
- // first returns the store revision from the first fetch
- func (rs readSet) first() int64 {
- ret := int64(math.MaxInt64 - 1)
- for _, resp := range rs {
- if rev := resp.Header.Revision; rev < ret {
- ret = rev
- }
- }
- return ret
- }
- // cmps guards the txn from updates to read set
- func (rs readSet) cmps() []v3.Cmp {
- cmps := make([]v3.Cmp, 0, len(rs))
- for k, rk := range rs {
- cmps = append(cmps, isKeyCurrent(k, rk))
- }
- return cmps
- }
- type writeSet map[string]stmPut
- func (ws writeSet) get(keys ...string) *stmPut {
- for _, key := range keys {
- if wv, ok := ws[key]; ok {
- return &wv
- }
- }
- return nil
- }
- // cmps returns a cmp list testing no writes have happened past rev
- func (ws writeSet) cmps(rev int64) []v3.Cmp {
- cmps := make([]v3.Cmp, 0, len(ws))
- for key := range ws {
- cmps = append(cmps, v3.Compare(v3.ModRevision(key), "<", rev))
- }
- return cmps
- }
- // puts is the list of ops for all pending writes
- func (ws writeSet) puts() []v3.Op {
- puts := make([]v3.Op, 0, len(ws))
- for _, v := range ws {
- puts = append(puts, v.op)
- }
- return puts
- }
- func (s *stm) Get(keys ...string) string {
- if wv := s.wset.get(keys...); wv != nil {
- return wv.val
- }
- return respToValue(s.fetch(keys...))
- }
- func (s *stm) Put(key, val string, opts ...v3.OpOption) {
- s.wset[key] = stmPut{val, v3.OpPut(key, val, opts...)}
- }
- func (s *stm) Del(key string) { s.wset[key] = stmPut{"", v3.OpDelete(key)} }
- func (s *stm) Rev(key string) int64 {
- if resp := s.fetch(key); resp != nil && len(resp.Kvs) != 0 {
- return resp.Kvs[0].ModRevision
- }
- return 0
- }
- func (s *stm) commit() *v3.TxnResponse {
- txnresp, err := s.client.Txn(s.ctx).If(s.conflicts()...).Then(s.wset.puts()...).Commit()
- if err != nil {
- panic(stmError{err})
- }
- if txnresp.Succeeded {
- return txnresp
- }
- return nil
- }
- func (s *stm) fetch(keys ...string) *v3.GetResponse {
- if len(keys) == 0 {
- return nil
- }
- ops := make([]v3.Op, len(keys))
- for i, key := range keys {
- if resp, ok := s.rset[key]; ok {
- return resp
- }
- ops[i] = v3.OpGet(key, s.getOpts...)
- }
- txnresp, err := s.client.Txn(s.ctx).Then(ops...).Commit()
- if err != nil {
- panic(stmError{err})
- }
- s.rset.add(keys, txnresp)
- return (*v3.GetResponse)(txnresp.Responses[0].GetResponseRange())
- }
- func (s *stm) reset() {
- s.rset = make(map[string]*v3.GetResponse)
- s.wset = make(map[string]stmPut)
- }
- type stmSerializable struct {
- stm
- prefetch map[string]*v3.GetResponse
- }
- func (s *stmSerializable) Get(keys ...string) string {
- if wv := s.wset.get(keys...); wv != nil {
- return wv.val
- }
- firstRead := len(s.rset) == 0
- for _, key := range keys {
- if resp, ok := s.prefetch[key]; ok {
- delete(s.prefetch, key)
- s.rset[key] = resp
- }
- }
- resp := s.stm.fetch(keys...)
- if firstRead {
- // txn's base revision is defined by the first read
- s.getOpts = []v3.OpOption{
- v3.WithRev(resp.Header.Revision),
- v3.WithSerializable(),
- }
- }
- return respToValue(resp)
- }
- func (s *stmSerializable) Rev(key string) int64 {
- s.Get(key)
- return s.stm.Rev(key)
- }
- func (s *stmSerializable) gets() ([]string, []v3.Op) {
- keys := make([]string, 0, len(s.rset))
- ops := make([]v3.Op, 0, len(s.rset))
- for k := range s.rset {
- keys = append(keys, k)
- ops = append(ops, v3.OpGet(k))
- }
- return keys, ops
- }
- func (s *stmSerializable) commit() *v3.TxnResponse {
- keys, getops := s.gets()
- txn := s.client.Txn(s.ctx).If(s.conflicts()...).Then(s.wset.puts()...)
- // use Else to prefetch keys in case of conflict to save a round trip
- txnresp, err := txn.Else(getops...).Commit()
- if err != nil {
- panic(stmError{err})
- }
- if txnresp.Succeeded {
- return txnresp
- }
- // load prefetch with Else data
- s.rset.add(keys, txnresp)
- s.prefetch = s.rset
- s.getOpts = nil
- return nil
- }
- func isKeyCurrent(k string, r *v3.GetResponse) v3.Cmp {
- if len(r.Kvs) != 0 {
- return v3.Compare(v3.ModRevision(k), "=", r.Kvs[0].ModRevision)
- }
- return v3.Compare(v3.ModRevision(k), "=", 0)
- }
- func respToValue(resp *v3.GetResponse) string {
- if resp == nil || len(resp.Kvs) == 0 {
- return ""
- }
- return string(resp.Kvs[0].Value)
- }
- // NewSTMRepeatable is deprecated.
- func NewSTMRepeatable(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) {
- return NewSTM(c, apply, WithAbortContext(ctx), WithIsolation(RepeatableReads))
- }
- // NewSTMSerializable is deprecated.
- func NewSTMSerializable(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) {
- return NewSTM(c, apply, WithAbortContext(ctx), WithIsolation(Serializable))
- }
- // NewSTMReadCommitted is deprecated.
- func NewSTMReadCommitted(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) {
- return NewSTM(c, apply, WithAbortContext(ctx), WithIsolation(ReadCommitted))
- }
|