Pārlūkot izejas kodu

etcdserver, v3rpc: support nested txns

Anthony Romano 8 gadi atpakaļ
vecāks
revīzija
6ed51dc621
3 mainītis faili ar 236 papildinājumiem un 133 dzēšanām
  1. 91 61
      etcdserver/api/v3rpc/key.go
  2. 144 71
      etcdserver/apply.go
  3. 1 1
      etcdserver/server.go

+ 91 - 61
etcdserver/api/v3rpc/key.go

@@ -16,11 +16,10 @@
 package v3rpc
 
 import (
-	"sort"
-
 	"github.com/coreos/etcd/etcdserver"
 	"github.com/coreos/etcd/etcdserver/api/v3rpc/rpctypes"
 	pb "github.com/coreos/etcd/etcdserver/etcdserverpb"
+	"github.com/coreos/etcd/pkg/adt"
 	"github.com/coreos/pkg/capnslog"
 	"golang.org/x/net/context"
 )
@@ -89,6 +88,13 @@ func (s *kvServer) Txn(ctx context.Context, r *pb.TxnRequest) (*pb.TxnResponse,
 	if err := checkTxnRequest(r, int(s.maxTxnOps)); err != nil {
 		return nil, err
 	}
+	// check for forbidden put/del overlaps after checking request to avoid quadratic blowup
+	if _, _, err := checkIntervals(r.Success); err != nil {
+		return nil, err
+	}
+	if _, _, err := checkIntervals(r.Failure); err != nil {
+		return nil, err
+	}
 
 	resp, err := s.kv.Txn(ctx, r)
 	if err != nil {
@@ -137,7 +143,14 @@ func checkDeleteRequest(r *pb.DeleteRangeRequest) error {
 }
 
 func checkTxnRequest(r *pb.TxnRequest, maxTxnOps int) error {
-	if len(r.Compare) > maxTxnOps || len(r.Success) > maxTxnOps || len(r.Failure) > maxTxnOps {
+	opc := len(r.Compare)
+	if opc < len(r.Success) {
+		opc = len(r.Success)
+	}
+	if opc < len(r.Failure) {
+		opc = len(r.Failure)
+	}
+	if opc > maxTxnOps {
 		return rpctypes.ErrGRPCTooManyOps
 	}
 
@@ -146,100 +159,117 @@ func checkTxnRequest(r *pb.TxnRequest, maxTxnOps int) error {
 			return rpctypes.ErrGRPCEmptyKey
 		}
 	}
-
 	for _, u := range r.Success {
-		if err := checkRequestOp(u); err != nil {
+		if err := checkRequestOp(u, maxTxnOps-opc); err != nil {
 			return err
 		}
 	}
-	if err := checkRequestDupKeys(r.Success); err != nil {
-		return err
-	}
-
 	for _, u := range r.Failure {
-		if err := checkRequestOp(u); err != nil {
+		if err := checkRequestOp(u, maxTxnOps-opc); err != nil {
 			return err
 		}
 	}
-	return checkRequestDupKeys(r.Failure)
+
+	return nil
 }
 
-// checkRequestDupKeys gives rpctypes.ErrGRPCDuplicateKey if the same key is modified twice
-func checkRequestDupKeys(reqs []*pb.RequestOp) error {
-	// check put overlap
-	keys := make(map[string]struct{})
-	for _, requ := range reqs {
-		tv, ok := requ.Request.(*pb.RequestOp_RequestPut)
+// checkIntervals tests whether puts and deletes overlap for a list of ops. If
+// there is an overlap, returns an error. If no overlap, return put and delete
+// sets for recursive evaluation.
+func checkIntervals(reqs []*pb.RequestOp) (map[string]struct{}, adt.IntervalTree, error) {
+	var dels adt.IntervalTree
+
+	// collect deletes from this level; build first to check lower level overlapped puts
+	for _, req := range reqs {
+		tv, ok := req.Request.(*pb.RequestOp_RequestDeleteRange)
 		if !ok {
 			continue
 		}
-		preq := tv.RequestPut
-		if preq == nil {
+		dreq := tv.RequestDeleteRange
+		if dreq == nil {
 			continue
 		}
-		if _, ok := keys[string(preq.Key)]; ok {
-			return rpctypes.ErrGRPCDuplicateKey
+		var iv adt.Interval
+		if len(dreq.RangeEnd) != 0 {
+			iv = adt.NewStringAffineInterval(string(dreq.Key), string(dreq.RangeEnd))
+		} else {
+			iv = adt.NewStringAffinePoint(string(dreq.Key))
 		}
-		keys[string(preq.Key)] = struct{}{}
-	}
-
-	// no need to check deletes if no puts; delete overlaps are permitted
-	if len(keys) == 0 {
-		return nil
-	}
-
-	// sort keys for range checking
-	sortedKeys := []string{}
-	for k := range keys {
-		sortedKeys = append(sortedKeys, k)
+		dels.Insert(iv, struct{}{})
 	}
-	sort.Strings(sortedKeys)
 
-	// check put overlap with deletes
-	for _, requ := range reqs {
-		tv, ok := requ.Request.(*pb.RequestOp_RequestDeleteRange)
+	// collect children puts/deletes
+	puts := make(map[string]struct{})
+	for _, req := range reqs {
+		tv, ok := req.Request.(*pb.RequestOp_RequestTxn)
 		if !ok {
 			continue
 		}
-		dreq := tv.RequestDeleteRange
-		if dreq == nil {
-			continue
+		putsThen, delsThen, err := checkIntervals(tv.RequestTxn.Success)
+		if err != nil {
+			return nil, dels, err
+		}
+		putsElse, delsElse, err := checkIntervals(tv.RequestTxn.Failure)
+		if err != nil {
+			return nil, dels, err
 		}
-		if dreq.RangeEnd == nil {
-			if _, found := keys[string(dreq.Key)]; found {
-				return rpctypes.ErrGRPCDuplicateKey
+		for k := range putsThen {
+			if _, ok := puts[k]; ok {
+				return nil, dels, rpctypes.ErrGRPCDuplicateKey
 			}
-		} else {
-			lo := sort.SearchStrings(sortedKeys, string(dreq.Key))
-			hi := sort.SearchStrings(sortedKeys, string(dreq.RangeEnd))
-			if lo != hi {
-				// element between lo and hi => overlap
-				return rpctypes.ErrGRPCDuplicateKey
+			if dels.Intersects(adt.NewStringAffinePoint(k)) {
+				return nil, dels, rpctypes.ErrGRPCDuplicateKey
 			}
+			puts[k] = struct{}{}
 		}
+		for k := range putsElse {
+			if _, ok := puts[k]; ok {
+				// if key is from putsThen, overlap is OK since
+				// either then/else are mutually exclusive
+				if _, isSafe := putsThen[k]; !isSafe {
+					return nil, dels, rpctypes.ErrGRPCDuplicateKey
+				}
+			}
+			if dels.Intersects(adt.NewStringAffinePoint(k)) {
+				return nil, dels, rpctypes.ErrGRPCDuplicateKey
+			}
+			puts[k] = struct{}{}
+		}
+		dels.Union(delsThen, adt.NewStringAffineInterval("\x00", ""))
+		dels.Union(delsElse, adt.NewStringAffineInterval("\x00", ""))
 	}
 
-	return nil
+	// collect and check this level's puts
+	for _, req := range reqs {
+		tv, ok := req.Request.(*pb.RequestOp_RequestPut)
+		if !ok || tv.RequestPut == nil {
+			continue
+		}
+		k := string(tv.RequestPut.Key)
+		if _, ok := puts[k]; ok {
+			return nil, dels, rpctypes.ErrGRPCDuplicateKey
+		}
+		if dels.Intersects(adt.NewStringAffinePoint(k)) {
+			return nil, dels, rpctypes.ErrGRPCDuplicateKey
+		}
+		puts[k] = struct{}{}
+	}
+	return puts, dels, nil
 }
 
-func checkRequestOp(u *pb.RequestOp) error {
+func checkRequestOp(u *pb.RequestOp, maxTxnOps int) error {
 	// TODO: ensure only one of the field is set.
 	switch uv := u.Request.(type) {
 	case *pb.RequestOp_RequestRange:
-		if uv.RequestRange != nil {
-			return checkRangeRequest(uv.RequestRange)
-		}
+		return checkRangeRequest(uv.RequestRange)
 	case *pb.RequestOp_RequestPut:
-		if uv.RequestPut != nil {
-			return checkPutRequest(uv.RequestPut)
-		}
+		return checkPutRequest(uv.RequestPut)
 	case *pb.RequestOp_RequestDeleteRange:
-		if uv.RequestDeleteRange != nil {
-			return checkDeleteRequest(uv.RequestDeleteRange)
-		}
+		return checkDeleteRequest(uv.RequestDeleteRange)
+	case *pb.RequestOp_RequestTxn:
+		return checkTxnRequest(uv.RequestTxn, maxTxnOps)
 	default:
 		// empty op / nil entry
 		return rpctypes.ErrGRPCKeyNotFound
 	}
-	return nil
 }

+ 144 - 71
etcdserver/apply.go

@@ -76,14 +76,30 @@ type applierV3 interface {
 	RoleList(ua *pb.AuthRoleListRequest) (*pb.AuthRoleListResponse, error)
 }
 
+type checkReqFunc func(mvcc.ReadView, *pb.RequestOp) error
+
 type applierV3backend struct {
 	s *EtcdServer
+
+	checkPut   checkReqFunc
+	checkRange checkReqFunc
+}
+
+func (s *EtcdServer) newApplierV3Backend() applierV3 {
+	base := &applierV3backend{s: s}
+	base.checkPut = func(rv mvcc.ReadView, req *pb.RequestOp) error {
+		return base.checkRequestPut(rv, req)
+	}
+	base.checkRange = func(rv mvcc.ReadView, req *pb.RequestOp) error {
+		return base.checkRequestRange(rv, req)
+	}
+	return base
 }
 
 func (s *EtcdServer) newApplierV3() applierV3 {
 	return newAuthApplierV3(
 		s.AuthStore(),
-		newQuotaApplierV3(s, &applierV3backend{s}),
+		newQuotaApplierV3(s, s.newApplierV3Backend()),
 		s.lessor,
 	)
 }
@@ -315,24 +331,19 @@ func (a *applierV3backend) Txn(rt *pb.TxnRequest) (*pb.TxnResponse, error) {
 	isWrite := !isTxnReadonly(rt)
 	txn := mvcc.NewReadOnlyTxnWrite(a.s.KV().Read())
 
-	reqs, ok := a.compareToOps(txn, rt)
+	txnPath := compareToPath(txn, rt)
 	if isWrite {
-		if err := a.checkRequestPut(txn, reqs); err != nil {
+		if _, err := checkRequests(txn, rt, txnPath, a.checkPut); err != nil {
 			txn.End()
 			return nil, err
 		}
 	}
-	if err := checkRequestRange(txn, reqs); err != nil {
+	if _, err := checkRequests(txn, rt, txnPath, a.checkRange); err != nil {
 		txn.End()
 		return nil, err
 	}
 
-	resps := make([]*pb.ResponseOp, len(reqs))
-	txnResp := &pb.TxnResponse{
-		Responses: resps,
-		Succeeded: ok,
-		Header:    &pb.ResponseHeader{},
-	}
+	txnResp, _ := newTxnResp(rt, txnPath)
 
 	// When executing mutable txn ops, etcd must hold the txn lock so
 	// readers do not see any intermediate results. Since writes are
@@ -342,9 +353,7 @@ func (a *applierV3backend) Txn(rt *pb.TxnRequest) (*pb.TxnResponse, error) {
 		txn.End()
 		txn = a.s.KV().Write()
 	}
-	for i := range reqs {
-		resps[i] = a.applyUnion(txn, reqs[i])
-	}
+	a.applyTxn(txn, rt, txnPath, txnResp)
 	rev := txn.Rev()
 	if len(txn.Changes()) != 0 {
 		rev++
@@ -355,13 +364,60 @@ func (a *applierV3backend) Txn(rt *pb.TxnRequest) (*pb.TxnResponse, error) {
 	return txnResp, nil
 }
 
-func (a *applierV3backend) compareToOps(rv mvcc.ReadView, rt *pb.TxnRequest) ([]*pb.RequestOp, bool) {
-	for _, c := range rt.Compare {
+// newTxnResp allocates a txn response for a txn request given a path.
+func newTxnResp(rt *pb.TxnRequest, txnPath []bool) (txnResp *pb.TxnResponse, txnCount int) {
+	reqs := rt.Success
+	if !txnPath[0] {
+		reqs = rt.Failure
+	}
+	resps := make([]*pb.ResponseOp, len(reqs))
+	txnResp = &pb.TxnResponse{
+		Responses: resps,
+		Succeeded: txnPath[0],
+		Header:    &pb.ResponseHeader{},
+	}
+	for i, req := range reqs {
+		switch tv := req.Request.(type) {
+		case *pb.RequestOp_RequestRange:
+			resps[i] = &pb.ResponseOp{Response: &pb.ResponseOp_ResponseRange{}}
+		case *pb.RequestOp_RequestPut:
+			resps[i] = &pb.ResponseOp{Response: &pb.ResponseOp_ResponsePut{}}
+		case *pb.RequestOp_RequestDeleteRange:
+			resps[i] = &pb.ResponseOp{Response: &pb.ResponseOp_ResponseDeleteRange{}}
+		case *pb.RequestOp_RequestTxn:
+			resp, txns := newTxnResp(tv.RequestTxn, txnPath[1:])
+			resps[i] = &pb.ResponseOp{Response: &pb.ResponseOp_ResponseTxn{ResponseTxn: resp}}
+			txnPath = txnPath[1+txns:]
+			txnCount += txns + 1
+		default:
+		}
+	}
+	return txnResp, txnCount
+}
+
+func compareToPath(rv mvcc.ReadView, rt *pb.TxnRequest) []bool {
+	txnPath := make([]bool, 1)
+	ops := rt.Success
+	if txnPath[0] = applyCompares(rv, rt.Compare); !txnPath[0] {
+		ops = rt.Failure
+	}
+	for _, op := range ops {
+		tv, ok := op.Request.(*pb.RequestOp_RequestTxn)
+		if !ok || tv.RequestTxn == nil {
+			continue
+		}
+		txnPath = append(txnPath, compareToPath(rv, tv.RequestTxn)...)
+	}
+	return txnPath
+}
+
+func applyCompares(rv mvcc.ReadView, cmps []*pb.Compare) bool {
+	for _, c := range cmps {
 		if !applyCompare(rv, c) {
-			return rt.Failure, false
+			return false
 		}
 	}
-	return rt.Success, true
+	return true
 }
 
 // applyCompare applies the compare request.
@@ -431,38 +487,42 @@ func compareKV(c *pb.Compare, ckv mvccpb.KeyValue) bool {
 	return true
 }
 
-func (a *applierV3backend) applyUnion(txn mvcc.TxnWrite, union *pb.RequestOp) *pb.ResponseOp {
-	switch tv := union.Request.(type) {
-	case *pb.RequestOp_RequestRange:
-		if tv.RequestRange != nil {
+func (a *applierV3backend) applyTxn(txn mvcc.TxnWrite, rt *pb.TxnRequest, txnPath []bool, tresp *pb.TxnResponse) (txns int) {
+	reqs := rt.Success
+	if !txnPath[0] {
+		reqs = rt.Failure
+	}
+	for i, req := range reqs {
+		respi := tresp.Responses[i].Response
+		switch tv := req.Request.(type) {
+		case *pb.RequestOp_RequestRange:
 			resp, err := a.Range(txn, tv.RequestRange)
 			if err != nil {
 				plog.Panicf("unexpected error during txn: %v", err)
 			}
-			return &pb.ResponseOp{Response: &pb.ResponseOp_ResponseRange{ResponseRange: resp}}
-		}
-	case *pb.RequestOp_RequestPut:
-		if tv.RequestPut != nil {
+			respi.(*pb.ResponseOp_ResponseRange).ResponseRange = resp
+		case *pb.RequestOp_RequestPut:
 			resp, err := a.Put(txn, tv.RequestPut)
 			if err != nil {
 				plog.Panicf("unexpected error during txn: %v", err)
 			}
-			return &pb.ResponseOp{Response: &pb.ResponseOp_ResponsePut{ResponsePut: resp}}
-		}
-	case *pb.RequestOp_RequestDeleteRange:
-		if tv.RequestDeleteRange != nil {
+			respi.(*pb.ResponseOp_ResponsePut).ResponsePut = resp
+		case *pb.RequestOp_RequestDeleteRange:
 			resp, err := a.DeleteRange(txn, tv.RequestDeleteRange)
 			if err != nil {
 				plog.Panicf("unexpected error during txn: %v", err)
 			}
-			return &pb.ResponseOp{Response: &pb.ResponseOp_ResponseDeleteRange{ResponseDeleteRange: resp}}
+			respi.(*pb.ResponseOp_ResponseDeleteRange).ResponseDeleteRange = resp
+		case *pb.RequestOp_RequestTxn:
+			resp := respi.(*pb.ResponseOp_ResponseTxn).ResponseTxn
+			applyTxns := a.applyTxn(txn, tv.RequestTxn, txnPath[1:], resp)
+			txns += applyTxns + 1
+			txnPath = txnPath[applyTxns+1:]
+		default:
+			// empty union
 		}
-	default:
-		// empty union
-		return nil
 	}
-	return nil
-
+	return txns
 }
 
 func (a *applierV3backend) Compaction(compaction *pb.CompactionRequest) (*pb.CompactionResponse, <-chan struct{}, error) {
@@ -768,57 +828,70 @@ func (s *kvSortByValue) Less(i, j int) bool {
 	return bytes.Compare(s.kvs[i].Value, s.kvs[j].Value) < 0
 }
 
-func (a *applierV3backend) checkRequestPut(rv mvcc.ReadView, reqs []*pb.RequestOp) error {
-	for _, requ := range reqs {
-		tv, ok := requ.Request.(*pb.RequestOp_RequestPut)
-		if !ok {
-			continue
-		}
-		preq := tv.RequestPut
-		if preq == nil {
-			continue
-		}
-		if preq.IgnoreValue || preq.IgnoreLease {
-			// expects previous key-value, error if not exist
-			rr, err := rv.Range(preq.Key, nil, mvcc.RangeOptions{})
+func checkRequests(rv mvcc.ReadView, rt *pb.TxnRequest, txnPath []bool, f checkReqFunc) (int, error) {
+	txnCount := 0
+	reqs := rt.Success
+	if !txnPath[0] {
+		reqs = rt.Failure
+	}
+	for _, req := range reqs {
+		if tv, ok := req.Request.(*pb.RequestOp_RequestTxn); ok && tv.RequestTxn != nil {
+			txns, err := checkRequests(rv, tv.RequestTxn, txnPath[1:], f)
 			if err != nil {
-				return err
+				return 0, err
 			}
-			if rr == nil || len(rr.KVs) == 0 {
-				return ErrKeyNotFound
-			}
-		}
-		if lease.LeaseID(preq.Lease) == lease.NoLease {
+			txnCount += txns + 1
+			txnPath = txnPath[txns+1:]
 			continue
 		}
-		if l := a.s.lessor.Lookup(lease.LeaseID(preq.Lease)); l == nil {
-			return lease.ErrLeaseNotFound
+		if err := f(rv, req); err != nil {
+			return 0, err
 		}
 	}
-	return nil
+	return txnCount, nil
 }
 
-func checkRequestRange(rv mvcc.ReadView, reqs []*pb.RequestOp) error {
-	for _, requ := range reqs {
-		tv, ok := requ.Request.(*pb.RequestOp_RequestRange)
-		if !ok {
-			continue
-		}
-		greq := tv.RequestRange
-		if greq == nil || greq.Revision == 0 {
-			continue
+func (a *applierV3backend) checkRequestPut(rv mvcc.ReadView, reqOp *pb.RequestOp) error {
+	tv, ok := reqOp.Request.(*pb.RequestOp_RequestPut)
+	if !ok || tv.RequestPut == nil {
+		return nil
+	}
+	req := tv.RequestPut
+	if req.IgnoreValue || req.IgnoreLease {
+		// expects previous key-value, error if not exist
+		rr, err := rv.Range(req.Key, nil, mvcc.RangeOptions{})
+		if err != nil {
+			return err
 		}
-
-		if greq.Revision > rv.Rev() {
-			return mvcc.ErrFutureRev
+		if rr == nil || len(rr.KVs) == 0 {
+			return ErrKeyNotFound
 		}
-		if greq.Revision < rv.FirstRev() {
-			return mvcc.ErrCompacted
+	}
+	if lease.LeaseID(req.Lease) != lease.NoLease {
+		if l := a.s.lessor.Lookup(lease.LeaseID(req.Lease)); l == nil {
+			return lease.ErrLeaseNotFound
 		}
 	}
 	return nil
 }
 
+func (a *applierV3backend) checkRequestRange(rv mvcc.ReadView, reqOp *pb.RequestOp) error {
+	tv, ok := reqOp.Request.(*pb.RequestOp_RequestRange)
+	if !ok || tv.RequestRange == nil {
+		return nil
+	}
+	req := tv.RequestRange
+	switch {
+	case req.Revision == 0:
+		return nil
+	case req.Revision > rv.Rev():
+		return mvcc.ErrFutureRev
+	case req.Revision < rv.FirstRev():
+		return mvcc.ErrCompacted
+	}
+	return nil
+}
+
 func compareInt64(a, b int64) int {
 	switch {
 	case a < b:

+ 1 - 1
etcdserver/server.go

@@ -474,7 +474,7 @@ func NewServer(cfg ServerConfig) (srv *EtcdServer, err error) {
 		srv.compactor.Run()
 	}
 
-	srv.applyV3Base = &applierV3backend{srv}
+	srv.applyV3Base = srv.newApplierV3Backend()
 	if err = srv.restoreAlarms(); err != nil {
 		return nil, err
 	}