Ver Fonte

Refactor Scan/ScanCAS/MapScanCAS common logics

ono_matope há 11 anos atrás
pai
commit
db419170f4
1 ficheiros alterados com 16 adições e 15 exclusões
  1. 16 15
      session.go

+ 16 - 15
session.go

@@ -322,11 +322,8 @@ func (q *Query) Iter() *Iter {
 // were selected, ErrNotFound is returned.
 // were selected, ErrNotFound is returned.
 func (q *Query) Scan(dest ...interface{}) error {
 func (q *Query) Scan(dest ...interface{}) error {
 	iter := q.Iter()
 	iter := q.Iter()
-	if iter.err != nil {
-		return iter.err
-	}
-	if len(iter.rows) == 0 {
-		return ErrNotFound
+	if err := iter.checkErrAndNotFound(); err != nil {
+		return err
 	}
 	}
 	iter.Scan(dest...)
 	iter.Scan(dest...)
 	return iter.Close()
 	return iter.Close()
@@ -338,11 +335,8 @@ func (q *Query) Scan(dest ...interface{}) error {
 // in dest.
 // in dest.
 func (q *Query) ScanCAS(dest ...interface{}) (applied bool, err error) {
 func (q *Query) ScanCAS(dest ...interface{}) (applied bool, err error) {
 	iter := q.Iter()
 	iter := q.Iter()
-	if iter.err != nil {
-		return false, iter.err
-	}
-	if len(iter.rows) == 0 {
-		return false, ErrNotFound
+	if err := iter.checkErrAndNotFound(); err != nil {
+		return false, err
 	}
 	}
 	if len(iter.Columns()) > 1 {
 	if len(iter.Columns()) > 1 {
 		dest = append([]interface{}{&applied}, dest...)
 		dest = append([]interface{}{&applied}, dest...)
@@ -359,11 +353,8 @@ func (q *Query) ScanCAS(dest ...interface{}) (applied bool, err error) {
 // in dest.
 // in dest.
 func (q *Query) MapScanCAS(dest map[string]interface{}) (applied bool, err error) {
 func (q *Query) MapScanCAS(dest map[string]interface{}) (applied bool, err error) {
 	iter := q.Iter()
 	iter := q.Iter()
-	if iter.err != nil {
-		return false, iter.err
-	}
-	if len(iter.rows) == 0 {
-		return false, ErrNotFound
+	if err := iter.checkErrAndNotFound(); err != nil {
+		return false, err
 	}
 	}
 	iter.MapScan(dest)
 	iter.MapScan(dest)
 	applied = dest["[applied]"].(bool)
 	applied = dest["[applied]"].(bool)
@@ -434,6 +425,16 @@ func (iter *Iter) Close() error {
 	return iter.err
 	return iter.err
 }
 }
 
 
+// checkErrAndNotFound handle error and NotFound in one method.
+func (iter *Iter) checkErrAndNotFound() error {
+	if iter.err != nil {
+		return iter.err
+	} else if len(iter.rows) == 0 {
+		return ErrNotFound
+	}
+	return nil
+}
+
 type nextIter struct {
 type nextIter struct {
 	qry  Query
 	qry  Query
 	pos  int
 	pos  int