123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246 |
- // Copyright 2016 CoreOS, Inc.
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- package concurrency
- import (
- "github.com/coreos/etcd/Godeps/_workspace/src/golang.org/x/net/context"
- v3 "github.com/coreos/etcd/clientv3"
- )
- // STM is an interface for software transactional memory.
- type STM interface {
- // Get returns the value for a key and inserts the key in the txn's read set.
- // If Get fails, it aborts the transaction with an error, never returning.
- Get(key string) string
- // Put adds a value for a key to the write set.
- Put(key, val string, opts ...v3.OpOption)
- // Rev returns the revision of a key in the read set.
- Rev(key string) int64
- // Del deletes a key.
- Del(key string)
- // commit attempts to apply the txn's changes to the server.
- commit() *v3.TxnResponse
- reset()
- }
// stmError wraps an error raised inside an STM operation so it can be carried
// out of apply by panic and distinguished from unrelated panics when runSTM
// recovers it.
type stmError struct {
	err error
}
- // NewSTMRepeatable initiates new repeatable read transaction; reads within
- // the same transaction attempt always return the same data.
- func NewSTMRepeatable(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) {
- s := &stm{client: c, ctx: ctx, getOpts: []v3.OpOption{v3.WithSerializable()}}
- return runSTM(s, apply)
- }
- // NewSTMSerializable initiates a new serialized transaction; reads within the
- // same transactiona attempt return data from the revision of the first read.
- func NewSTMSerializable(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) {
- s := &stmSerializable{
- stm: stm{client: c, ctx: ctx},
- prefetch: make(map[string]*v3.GetResponse),
- }
- return runSTM(s, apply)
- }
- type stmResponse struct {
- resp *v3.TxnResponse
- err error
- }
- func runSTM(s STM, apply func(STM) error) (*v3.TxnResponse, error) {
- outc := make(chan stmResponse, 1)
- go func() {
- defer func() {
- if r := recover(); r != nil {
- e, ok := r.(stmError)
- if !ok {
- // client apply panicked
- panic(r)
- }
- outc <- stmResponse{nil, e.err}
- }
- }()
- var out stmResponse
- for {
- s.reset()
- if out.err = apply(s); out.err != nil {
- break
- }
- if out.resp = s.commit(); out.resp != nil {
- break
- }
- }
- outc <- out
- }()
- r := <-outc
- return r.resp, r.err
- }
- // stm implements repeatable-read software transactional memory over etcd
- type stm struct {
- client *v3.Client
- ctx context.Context
- // rset holds read key values and revisions
- rset map[string]*v3.GetResponse
- // wset holds overwritten keys and their values
- wset map[string]stmPut
- // getOpts are the opts used for gets
- getOpts []v3.OpOption
- }
- type stmPut struct {
- val string
- op v3.Op
- }
- func (s *stm) Get(key string) string {
- if wv, ok := s.wset[key]; ok {
- return wv.val
- }
- return respToValue(s.fetch(key))
- }
- func (s *stm) Put(key, val string, opts ...v3.OpOption) {
- s.wset[key] = stmPut{val, v3.OpPut(key, val, opts...)}
- }
- func (s *stm) Del(key string) { s.wset[key] = stmPut{"", v3.OpDelete(key)} }
- func (s *stm) Rev(key string) int64 {
- if resp := s.fetch(key); resp != nil && len(resp.Kvs) != 0 {
- return resp.Kvs[0].ModRevision
- }
- return 0
- }
- func (s *stm) commit() *v3.TxnResponse {
- txnresp, err := s.client.Txn(s.ctx).If(s.cmps()...).Then(s.puts()...).Commit()
- if err != nil {
- panic(stmError{err})
- }
- if txnresp.Succeeded {
- return txnresp
- }
- return nil
- }
- // cmps guards the txn from updates to read set
- func (s *stm) cmps() (cmps []v3.Cmp) {
- for k, rk := range s.rset {
- cmps = append(cmps, isKeyCurrent(k, rk))
- }
- return
- }
- func (s *stm) fetch(key string) *v3.GetResponse {
- if resp, ok := s.rset[key]; ok {
- return resp
- }
- resp, err := s.client.Get(s.ctx, key, s.getOpts...)
- if err != nil {
- panic(stmError{err})
- }
- s.rset[key] = resp
- return resp
- }
- // puts is the list of ops for all pending writes
- func (s *stm) puts() (puts []v3.Op) {
- for _, v := range s.wset {
- puts = append(puts, v.op)
- }
- return
- }
- func (s *stm) reset() {
- s.rset = make(map[string]*v3.GetResponse)
- s.wset = make(map[string]stmPut)
- }
- type stmSerializable struct {
- stm
- prefetch map[string]*v3.GetResponse
- }
- func (s *stmSerializable) Get(key string) string {
- if wv, ok := s.wset[key]; ok {
- return wv.val
- }
- firstRead := len(s.rset) == 0
- if resp, ok := s.prefetch[key]; ok {
- delete(s.prefetch, key)
- s.rset[key] = resp
- }
- resp := s.stm.fetch(key)
- if firstRead {
- // txn's base revision is defined by the first read
- s.getOpts = []v3.OpOption{
- v3.WithRev(resp.Header.Revision),
- v3.WithSerializable(),
- }
- }
- return respToValue(resp)
- }
- func (s *stmSerializable) Rev(key string) int64 {
- s.Get(key)
- return s.stm.Rev(key)
- }
- func (s *stmSerializable) gets() (keys []string, ops []v3.Op) {
- for k := range s.rset {
- keys = append(keys, k)
- ops = append(ops, v3.OpGet(k))
- }
- return
- }
- func (s *stmSerializable) commit() *v3.TxnResponse {
- keys, getops := s.gets()
- txn := s.client.Txn(s.ctx).If(s.cmps()...).Then(s.puts()...)
- // use Else to prefetch keys in case of conflict to save a round trip
- txnresp, err := txn.Else(getops...).Commit()
- if err != nil {
- panic(stmError{err})
- }
- if txnresp.Succeeded {
- return txnresp
- }
- // load prefetch with Else data
- for i := range keys {
- resp := txnresp.Responses[i].GetResponseRange()
- s.rset[keys[i]] = (*v3.GetResponse)(resp)
- }
- s.prefetch = s.rset
- s.getOpts = nil
- return nil
- }
- func isKeyCurrent(k string, r *v3.GetResponse) v3.Cmp {
- rev := r.Header.Revision + 1
- if len(r.Kvs) != 0 {
- rev = r.Kvs[0].ModRevision + 1
- }
- return v3.Compare(v3.ModifiedRevision(k), "<", rev)
- }
- func respToValue(resp *v3.GetResponse) string {
- if len(resp.Kvs) == 0 {
- return ""
- }
- return string(resp.Kvs[0].Value)
- }
|