Browse Source

add more tests for mongoc (#443)

Kevin Wan 3 years ago
parent
commit
316195e912
1 changed files with 82 additions and 10 deletions
  1. 82 10
      core/stores/mongoc/cachedcollection_test.go

+ 82 - 10
core/stores/mongoc/cachedcollection_test.go

@@ -1,6 +1,7 @@
 package mongoc
 
 import (
+	"encoding/json"
 	"errors"
 	"io/ioutil"
 	"log"
@@ -21,10 +22,76 @@ import (
 	"github.com/tal-tech/go-zero/core/stores/redis/redistest"
 )
 
+const dummyCount = 10
+
 func init() {
 	stat.SetReporter(nil)
 }
 
+func TestCollection_Count(t *testing.T) {
+	resetStats()
+	r, clean, err := redistest.CreateRedis()
+	assert.Nil(t, err)
+	defer clean()
+
+	cach := cache.NewCacheNode(r, sharedCalls, stats, mgo.ErrNotFound)
+	c := newCollection(dummyConn{}, cach)
+	val, err := c.Count("any")
+	assert.Nil(t, err)
+	assert.Equal(t, dummyCount, val)
+
+	var value string
+	assert.Nil(t, r.Set("any", `"foo"`))
+	assert.Nil(t, c.GetCache("any", &value))
+	assert.Equal(t, "foo", value)
+	assert.Nil(t, c.DelCache("any"))
+
+	assert.Nil(t, c.SetCache("any", "bar"))
+	assert.Nil(t, c.FindAllNoCache(&value, "any", func(query mongo.Query) mongo.Query {
+		return query
+	}))
+	assert.Nil(t, c.FindOne(&value, "any", "foo"))
+	assert.Equal(t, "bar", value)
+	assert.Nil(t, c.DelCache("any"))
+	c = newCollection(dummyConn{val: `"bar"`}, cach)
+	assert.Nil(t, c.FindOne(&value, "any", "foo"))
+	assert.Equal(t, "bar", value)
+	assert.Nil(t, c.FindOneNoCache(&value, "foo"))
+	assert.Equal(t, "bar", value)
+	assert.Nil(t, c.FindOneId(&value, "anyone", "foo"))
+	assert.Equal(t, "bar", value)
+	assert.Nil(t, c.FindOneIdNoCache(&value, "foo"))
+	assert.Equal(t, "bar", value)
+	assert.Nil(t, c.Insert("foo"))
+	assert.Nil(t, c.Pipe("foo"))
+	assert.Nil(t, c.Remove("any"))
+	assert.Nil(t, c.RemoveId("any"))
+	_, err = c.RemoveAll("any")
+	assert.Nil(t, err)
+	assert.Nil(t, c.Update("foo", "bar"))
+	assert.Nil(t, c.UpdateId("foo", "bar"))
+	_, err = c.Upsert("foo", "bar")
+	assert.Nil(t, err)
+
+	c = newCollection(dummyConn{
+		val:       `"bar"`,
+		removeErr: errors.New("any"),
+	}, cach)
+	assert.NotNil(t, c.Remove("any"))
+	_, err = c.RemoveAll("any", "bar")
+	assert.NotNil(t, err)
+	assert.NotNil(t, c.RemoveId("any"))
+
+	c = newCollection(dummyConn{
+		val:       `"bar"`,
+		updateErr: errors.New("any"),
+	}, cach)
+	assert.NotNil(t, c.Update("foo", "bar"))
+	assert.NotNil(t, c.UpdateId("foo", "bar"))
+	_, err = c.Upsert("foo", "bar")
+	assert.NotNil(t, err)
+}
+
 func TestStat(t *testing.T) {
 	resetStats()
 	r, clean, err := redistest.CreateRedis()
@@ -156,14 +223,17 @@ func resetStats() {
 }
 
 type dummyConn struct {
+	val       string
+	removeErr error
+	updateErr error
 }
 
 func (c dummyConn) Find(query interface{}) mongo.Query {
-	return dummyQuery{}
+	return dummyQuery{val: c.val}
 }
 
 func (c dummyConn) FindId(id interface{}) mongo.Query {
-	return dummyQuery{}
+	return dummyQuery{val: c.val}
 }
 
 func (c dummyConn) Insert(docs ...interface{}) error {
@@ -171,7 +241,7 @@ func (c dummyConn) Insert(docs ...interface{}) error {
 }
 
 func (c dummyConn) Remove(selector interface{}) error {
-	return nil
+	return c.removeErr
 }
 
 func (dummyConn) Pipe(pipeline interface{}) mongo.Pipe {
@@ -179,25 +249,27 @@ func (dummyConn) Pipe(pipeline interface{}) mongo.Pipe {
 }
 
 func (c dummyConn) RemoveAll(selector interface{}) (*mgo.ChangeInfo, error) {
-	return nil, nil
+	return nil, c.removeErr
 }
 
 func (c dummyConn) RemoveId(id interface{}) error {
-	return nil
+	return c.removeErr
 }
 
 func (c dummyConn) Update(selector, update interface{}) error {
-	return nil
+	return c.updateErr
 }
 
 func (c dummyConn) UpdateId(id, update interface{}) error {
-	return nil
+	return c.updateErr
 }
+
 func (c dummyConn) Upsert(selector, update interface{}) (*mgo.ChangeInfo, error) {
-	return nil, nil
+	return nil, c.updateErr
 }
 
 type dummyQuery struct {
+	val string
 }
 
 func (d dummyQuery) All(result interface{}) error {
@@ -209,7 +281,7 @@ func (d dummyQuery) Apply(change mgo.Change, result interface{}) (*mgo.ChangeInf
 }
 
 func (d dummyQuery) Count() (int, error) {
-	return 0, nil
+	return dummyCount, nil
 }
 
 func (d dummyQuery) Distinct(key string, result interface{}) error {
@@ -229,7 +301,7 @@ func (d dummyQuery) MapReduce(job *mgo.MapReduce, result interface{}) (*mgo.MapR
 }
 
 func (d dummyQuery) One(result interface{}) error {
-	return nil
+	return json.Unmarshal([]byte(d.val), result)
 }
 
 func (d dummyQuery) Batch(n int) mongo.Query {