Browse Source

etcdctlv3: use clientv3 api for txn command

Anthony Romano 9 years ago
parent
commit
54d15256e7
1 changed files with 82 additions and 135 deletions
  1. 82 135
      etcdctlv3/command/txn_command.go

+ 82 - 135
etcdctlv3/command/txn_command.go

@@ -23,7 +23,7 @@ import (
 
 	"github.com/coreos/etcd/Godeps/_workspace/src/github.com/spf13/cobra"
 	"github.com/coreos/etcd/Godeps/_workspace/src/golang.org/x/net/context"
-	pb "github.com/coreos/etcd/etcdserver/etcdserverpb"
+	"github.com/coreos/etcd/clientv3"
 )
 
 // NewTxnCommand returns the cobra command for "txn".
@@ -43,13 +43,15 @@ func txnCommandFunc(cmd *cobra.Command, args []string) {
 
 	reader := bufio.NewReader(os.Stdin)
 
-	next := compareState
-	txn := &pb.TxnRequest{}
-	for next != nil {
-		next = next(txn, reader)
-	}
+	txn := clientv3.NewKV(mustClientFromCmd(cmd)).Txn(context.Background())
+	fmt.Println("entry comparison[key target expected_result compare_value] (end with empty line):")
+	txn.If(readCompares(reader)...)
+	fmt.Println("entry success request[method key value(end_range)] (end with empty line):")
+	txn.Then(readOps(reader)...)
+	fmt.Println("entry failure request[method key value(end_range)] (end with empty line):")
+	txn.Else(readOps(reader)...)
 
-	resp, err := mustClientFromCmd(cmd).KV.Txn(context.Background(), txn)
+	resp, err := txn.Commit()
 	if err != nil {
 		ExitWithError(ExitError, err)
 	}
@@ -60,179 +62,124 @@ func txnCommandFunc(cmd *cobra.Command, args []string) {
 	}
 }
 
-type stateFunc func(txn *pb.TxnRequest, r *bufio.Reader) stateFunc
-
-func compareState(txn *pb.TxnRequest, r *bufio.Reader) stateFunc {
-	fmt.Println("entry comparison[key target expected_result compare_value] (end with empty line):")
-
-	line, err := r.ReadString('\n')
-	if err != nil {
-		ExitWithError(ExitInvalidInput, err)
-	}
-
-	if len(line) == 1 {
-		return successState
-	}
-
-	// remove trialling \n
-	line = line[:len(line)-1]
-	c, err := parseCompare(line)
-	if err != nil {
-		ExitWithError(ExitInvalidInput, err)
-	}
-
-	txn.Compare = append(txn.Compare, c)
-
-	return compareState
-}
-
-func successState(txn *pb.TxnRequest, r *bufio.Reader) stateFunc {
-	fmt.Println("entry success request[method key value(end_range)] (end with empty line):")
-
-	line, err := r.ReadString('\n')
-	if err != nil {
-		ExitWithError(ExitInvalidInput, err)
-	}
-
-	if len(line) == 1 {
-		return failureState
-	}
+func readCompares(r *bufio.Reader) (cmps []clientv3.Cmp) {
+	for {
+		line, err := r.ReadString('\n')
+		if err != nil {
+			ExitWithError(ExitInvalidInput, err)
+		}
+		if len(line) == 1 {
+			break
+		}
 
-	// remove trialling \n
-	line = line[:len(line)-1]
-	ru, err := parseRequestUnion(line)
-	if err != nil {
-		ExitWithError(ExitInvalidInput, err)
+		// remove trialling \n
+		line = line[:len(line)-1]
+		cmp, err := parseCompare(line)
+		if err != nil {
+			ExitWithError(ExitInvalidInput, err)
+		}
+		cmps = append(cmps, *cmp)
 	}
 
-	txn.Success = append(txn.Success, ru)
-
-	return successState
+	return cmps
 }
 
-func failureState(txn *pb.TxnRequest, r *bufio.Reader) stateFunc {
-	fmt.Println("entry failure request[method key value(end_range)] (end with empty line):")
-
-	line, err := r.ReadString('\n')
-	if err != nil {
-		ExitWithError(ExitInvalidInput, err)
-	}
-
-	if len(line) == 1 {
-		return nil
-	}
+func readOps(r *bufio.Reader) (ops []clientv3.Op) {
+	for {
+		line, err := r.ReadString('\n')
+		if err != nil {
+			ExitWithError(ExitInvalidInput, err)
+		}
+		if len(line) == 1 {
+			break
+		}
 
-	// remove trialling \n
-	line = line[:len(line)-1]
-	ru, err := parseRequestUnion(line)
-	if err != nil {
-		ExitWithError(ExitInvalidInput, err)
+		// remove trialling \n
+		line = line[:len(line)-1]
+		op, err := parseRequestUnion(line)
+		if err != nil {
+			ExitWithError(ExitInvalidInput, err)
+		}
+		ops = append(ops, *op)
 	}
 
-	txn.Failure = append(txn.Failure, ru)
-
-	return failureState
+	return ops
 }
 
-func parseRequestUnion(line string) (*pb.RequestUnion, error) {
+func parseRequestUnion(line string) (*clientv3.Op, error) {
 	parts := strings.Split(line, " ")
 	if len(parts) < 2 {
 		return nil, fmt.Errorf("invalid txn compare request: %s", line)
 	}
 
-	ru := &pb.RequestUnion{}
-	key := []byte(parts[1])
+	op := &clientv3.Op{}
+	key := parts[1]
 	switch parts[0] {
 	case "r", "range":
 		if len(parts) == 3 {
-			ru.Request = &pb.RequestUnion_RequestRange{
-				RequestRange: &pb.RangeRequest{
-					Key:      key,
-					RangeEnd: []byte(parts[2]),
-				}}
+			*op = clientv3.OpGet(key, clientv3.WithRange(parts[2]))
 		} else {
-			ru.Request = &pb.RequestUnion_RequestRange{
-				RequestRange: &pb.RangeRequest{
-					Key: key,
-				}}
+			*op = clientv3.OpGet(key)
 		}
 	case "p", "put":
-		ru.Request = &pb.RequestUnion_RequestPut{
-			RequestPut: &pb.PutRequest{
-				Key:   key,
-				Value: []byte(parts[2]),
-			}}
+		*op = clientv3.OpPut(key, parts[2])
 	case "d", "deleteRange":
 		if len(parts) == 3 {
-			ru.Request = &pb.RequestUnion_RequestDeleteRange{
-				RequestDeleteRange: &pb.DeleteRangeRequest{
-					Key:      key,
-					RangeEnd: []byte(parts[2]),
-				}}
+			*op = clientv3.OpDelete(key, clientv3.WithRange(parts[2]))
 		} else {
-			ru.Request = &pb.RequestUnion_RequestDeleteRange{
-				RequestDeleteRange: &pb.DeleteRangeRequest{
-					Key: key,
-				}}
+			*op = clientv3.OpDelete(key)
 		}
 	default:
 		return nil, fmt.Errorf("invalid txn request: %s", line)
 	}
-	return ru, nil
+	return op, nil
 }
 
-func parseCompare(line string) (*pb.Compare, error) {
+func parseCompare(line string) (*clientv3.Cmp, error) {
 	parts := strings.Split(line, " ")
 	if len(parts) != 4 {
 		return nil, fmt.Errorf("invalid txn compare request: %s", line)
 	}
 
-	var err error
-	c := &pb.Compare{}
-	c.Key = []byte(parts[0])
+	cmpType := ""
+	switch parts[2] {
+	case "g", "greater":
+		cmpType = ">"
+	case "e", "equal":
+		cmpType = "="
+	case "l", "less":
+		cmpType = "<"
+	default:
+		return nil, fmt.Errorf("invalid txn compare request: %s", line)
+	}
+
+	var (
+		v   int64
+		err error
+		cmp clientv3.Cmp
+	)
+
+	key := parts[0]
 	switch parts[1] {
 	case "ver", "version":
-		tv, _ := c.TargetUnion.(*pb.Compare_Version)
-		if tv != nil {
-			tv.Version, err = strconv.ParseInt(parts[3], 10, 64)
-			if err != nil {
-				return nil, fmt.Errorf("invalid txn compare request: %s", line)
-			}
+		if v, err = strconv.ParseInt(parts[3], 10, 64); err != nil {
+			cmp = clientv3.Compare(clientv3.Version(key), cmpType, v)
 		}
 	case "c", "create":
-		tv, _ := c.TargetUnion.(*pb.Compare_CreateRevision)
-		if tv != nil {
-			tv.CreateRevision, err = strconv.ParseInt(parts[3], 10, 64)
-			if err != nil {
-				return nil, fmt.Errorf("invalid txn compare request: %s", line)
-			}
+		if v, err = strconv.ParseInt(parts[3], 10, 64); err != nil {
+			cmp = clientv3.Compare(clientv3.CreatedRevision(key), cmpType, v)
 		}
 	case "m", "mod":
-		tv, _ := c.TargetUnion.(*pb.Compare_ModRevision)
-		if tv != nil {
-			tv.ModRevision, err = strconv.ParseInt(parts[3], 10, 64)
-			if err != nil {
-				return nil, fmt.Errorf("invalid txn compare request: %s", line)
-			}
+		if v, err = strconv.ParseInt(parts[3], 10, 64); err != nil {
+			cmp = clientv3.Compare(clientv3.ModifiedRevision(key), cmpType, v)
 		}
 	case "val", "value":
-		tv, _ := c.TargetUnion.(*pb.Compare_Value)
-		if tv != nil {
-			tv.Value = []byte(parts[3])
-		}
-	default:
-		return nil, fmt.Errorf("invalid txn compare request: %s", line)
+		cmp = clientv3.Compare(clientv3.Value(key), cmpType, parts[3])
 	}
 
-	switch parts[2] {
-	case "g", "greater":
-		c.Result = pb.Compare_GREATER
-	case "e", "equal":
-		c.Result = pb.Compare_EQUAL
-	case "l", "less":
-		c.Result = pb.Compare_LESS
-	default:
+	if err != nil {
 		return nil, fmt.Errorf("invalid txn compare request: %s", line)
 	}
-	return c, nil
+
+	return &cmp, nil
 }