Browse Source

feat(context): add cast helpers to c.Keys (#856)

* feat(context): add cast helpers to c.Keys

* Add tests for cast helpers to c.Keys
Javier Provecho Fernandez 8 năm trước cách đây
mục cha
commit
5eea51b6c9
2 tập tin đã thay đổi với 167 bổ sung0 xóa
  1. 88 0
      context.go
  2. 79 0
      context_test.go

+ 88 - 0
context.go

@@ -187,6 +187,94 @@ func (c *Context) MustGet(key string) interface{} {
 	panic("Key \"" + key + "\" does not exist")
 }
 
+// GetString returns the value associated with the key as a string.
+func (c *Context) GetString(key string) (s string) {
+	if val, ok := c.Get(key); ok && val != nil {
+		s, _ = val.(string)
+	}
+	return
+}
+
+// GetBool returns the value associated with the key as a boolean.
+func (c *Context) GetBool(key string) (b bool) {
+	if val, ok := c.Get(key); ok && val != nil {
+		b, _ = val.(bool)
+	}
+	return
+}
+
+// GetInt returns the value associated with the key as an integer.
+func (c *Context) GetInt(key string) (i int) {
+	if val, ok := c.Get(key); ok && val != nil {
+		i, _ = val.(int)
+	}
+	return
+}
+
+// GetInt64 returns the value associated with the key as an integer.
+func (c *Context) GetInt64(key string) (i64 int64) {
+	if val, ok := c.Get(key); ok && val != nil {
+		i64, _ = val.(int64)
+	}
+	return
+}
+
+// GetFloat64 returns the value associated with the key as a float64.
+func (c *Context) GetFloat64(key string) (f64 float64) {
+	if val, ok := c.Get(key); ok && val != nil {
+		f64, _ = val.(float64)
+	}
+	return
+}
+
+// GetTime returns the value associated with the key as time.
+func (c *Context) GetTime(key string) (t time.Time) {
+	if val, ok := c.Get(key); ok && val != nil {
+		t, _ = val.(time.Time)
+	}
+	return
+}
+
+// GetDuration returns the value associated with the key as a duration.
+func (c *Context) GetDuration(key string) (d time.Duration) {
+	if val, ok := c.Get(key); ok && val != nil {
+		d, _ = val.(time.Duration)
+	}
+	return
+}
+
+// GetStringSlice returns the value associated with the key as a slice of strings.
+func (c *Context) GetStringSlice(key string) (ss []string) {
+	if val, ok := c.Get(key); ok && val != nil {
+		ss, _ = val.([]string)
+	}
+	return
+}
+
+// GetStringMap returns the value associated with the key as a map of interfaces.
+func (c *Context) GetStringMap(key string) (sm map[string]interface{}) {
+	if val, ok := c.Get(key); ok && val != nil {
+		sm, _ = val.(map[string]interface{})
+	}
+	return
+}
+
+// GetStringMapString returns the value associated with the key as a map of strings.
+func (c *Context) GetStringMapString(key string) (sms map[string]string) {
+	if val, ok := c.Get(key); ok && val != nil {
+		sms, _ = val.(map[string]string)
+	}
+	return
+}
+
+// GetStringMapStringSlice returns the value associated with the key as a map to a slice of strings.
+func (c *Context) GetStringMapStringSlice(key string) (smss map[string][]string) {
+	if val, ok := c.Get(key); ok && val != nil {
+		smss, _ = val.(map[string][]string)
+	}
+	return
+}
+
 /************************************/
 /************ INPUT DATA ************/
 /************************************/

+ 79 - 0
context_test.go

@@ -168,6 +168,85 @@ func TestContextSetGetValues(t *testing.T) {
 
 }
 
+func TestContextGetString(t *testing.T) {
+	c, _ := CreateTestContext(httptest.NewRecorder())
+	c.Set("string", "this is a string")
+	assert.Equal(t, "this is a string", c.GetString("string"))
+}
+
+func TestContextSetGetBool(t *testing.T) {
+	c, _ := CreateTestContext(httptest.NewRecorder())
+	c.Set("bool", true)
+	assert.Equal(t, true, c.GetBool("bool"))
+}
+
+func TestContextGetInt(t *testing.T) {
+	c, _ := CreateTestContext(httptest.NewRecorder())
+	c.Set("int", 1)
+	assert.Equal(t, 1, c.GetInt("int"))
+}
+
+func TestContextGetInt64(t *testing.T) {
+	c, _ := CreateTestContext(httptest.NewRecorder())
+	c.Set("int64", int64(42424242424242))
+	assert.Equal(t, int64(42424242424242), c.GetInt64("int64"))
+}
+
+func TestContextGetFloat64(t *testing.T) {
+	c, _ := CreateTestContext(httptest.NewRecorder())
+	c.Set("float64", 4.2)
+	assert.Equal(t, 4.2, c.GetFloat64("float64"))
+}
+
+func TestContextGetTime(t *testing.T) {
+	c, _ := CreateTestContext(httptest.NewRecorder())
+	t1, _ := time.Parse("1/2/2006 15:04:05", "01/01/2017 12:00:00")
+	c.Set("time", t1)
+	assert.Equal(t, t1, c.GetTime("time"))
+}
+
+func TestContextGetDuration(t *testing.T) {
+	c, _ := CreateTestContext(httptest.NewRecorder())
+	c.Set("duration", time.Second)
+	assert.Equal(t, time.Second, c.GetDuration("duration"))
+}
+
+func TestContextGetStringSlice(t *testing.T) {
+	c, _ := CreateTestContext(httptest.NewRecorder())
+	c.Set("slice", []string{"foo"})
+	assert.Equal(t, []string{"foo"}, c.GetStringSlice("slice"))
+}
+
+func TestContextGetStringMap(t *testing.T) {
+	c, _ := CreateTestContext(httptest.NewRecorder())
+	var m = make(map[string]interface{})
+	m["foo"] = 1
+	c.Set("map", m)
+
+	assert.Equal(t, m, c.GetStringMap("map"))
+	assert.Equal(t, 1, c.GetStringMap("map")["foo"])
+}
+
+func TestContextGetStringMapString(t *testing.T) {
+	c, _ := CreateTestContext(httptest.NewRecorder())
+	var m = make(map[string]string)
+	m["foo"] = "bar"
+	c.Set("map", m)
+
+	assert.Equal(t, m, c.GetStringMapString("map"))
+	assert.Equal(t, "bar", c.GetStringMapString("map")["foo"])
+}
+
+func TestContextGetStringMapStringSlice(t *testing.T) {
+	c, _ := CreateTestContext(httptest.NewRecorder())
+	var m = make(map[string][]string)
+	m["foo"] = []string{"foo"}
+	c.Set("map", m)
+
+	assert.Equal(t, m, c.GetStringMapStringSlice("map"))
+	assert.Equal(t, []string{"foo"}, c.GetStringMapStringSlice("map")["foo"])
+}
+
 func TestContextCopy(t *testing.T) {
 	c, _ := CreateTestContext(httptest.NewRecorder())
 	c.index = 2