Browse Source

xsrftoken: update token implementation

This changes the behaviour of the XSRF tokens slightly:
  - timestamps are rounded up to the nearest ms (not ns)
  - the timestamp is appended to the taken after base64 encoding

Change-Id: Iaa32bd055bd76db77ef47f07b975e75c5ecaadf0
Reviewed-on: https://go-review.googlesource.com/17664
Reviewed-by: David Symonds <dsymonds@golang.org>
Dave Day 10 years ago
parent
commit
da05ccad07
2 changed files with 19 additions and 17 deletions
  1. 16 15
      xsrftoken/xsrf.go
  2. 3 2
      xsrftoken/xsrf_test.go

+ 16 - 15
xsrftoken/xsrf.go

@@ -6,7 +6,6 @@
 package xsrftoken
 
 import (
-	"bytes"
 	"crypto/hmac"
 	"crypto/sha1"
 	"crypto/subtle"
@@ -37,35 +36,37 @@ func Generate(key, userID, actionID string) string {
 
 // generateTokenAtTime is like Generate, but returns a token that expires 24 hours from now.
 func generateTokenAtTime(key, userID, actionID string, now time.Time) string {
+	// Round time up and convert to milliseconds.
+	milliTime := (now.UnixNano() + 1e6 - 1) / 1e6
+
 	h := hmac.New(sha1.New, []byte(key))
-	fmt.Fprintf(h, "%s:%s:%d", clean(userID), clean(actionID), now.UnixNano())
-	tok := fmt.Sprintf("%s:%d", h.Sum(nil), now.UnixNano())
-	return base64.URLEncoding.EncodeToString([]byte(tok))
+	fmt.Fprintf(h, "%s:%s:%d", clean(userID), clean(actionID), milliTime)
+
+	// Get the padded base64 string then removing the padding.
+	tok := string(h.Sum(nil))
+	tok = base64.URLEncoding.EncodeToString([]byte(tok))
+	tok = strings.TrimRight(tok, "=")
+
+	return fmt.Sprintf("%s:%d", tok, milliTime)
 }
 
-// Valid returns true if token is a valid, unexpired token returned by Generate.
+// Valid reports whether a token is a valid, unexpired token returned by Generate.
 func Valid(token, key, userID, actionID string) bool {
 	return validTokenAtTime(token, key, userID, actionID, time.Now())
 }
 
-// validTokenAtTime is like Valid, but it uses now to check if the token is expired.
+// validTokenAtTime reports whether a token is valid at the given time.
 func validTokenAtTime(token, key, userID, actionID string, now time.Time) bool {
-	// Decode the token.
-	data, err := base64.URLEncoding.DecodeString(token)
-	if err != nil {
-		return false
-	}
-
 	// Extract the issue time of the token.
-	sep := bytes.LastIndex(data, []byte{':'})
+	sep := strings.LastIndex(token, ":")
 	if sep < 0 {
 		return false
 	}
-	nanos, err := strconv.ParseInt(string(data[sep+1:]), 10, 64)
+	millis, err := strconv.ParseInt(token[sep+1:], 10, 64)
 	if err != nil {
 		return false
 	}
-	issueTime := time.Unix(0, nanos)
+	issueTime := time.Unix(0, millis*1e6)
 
 	// Check that the token is not expired.
 	if now.Sub(issueTime) >= Timeout {

+ 3 - 2
xsrftoken/xsrf_test.go

@@ -29,7 +29,7 @@ func TestValidToken(t *testing.T) {
 	if !validTokenAtTime(tok, key, userID, actionID, now.Add(Timeout-1*time.Nanosecond)) {
 		t.Error("Just before timeout: Expected token to be valid")
 	}
-	if !validTokenAtTime(tok, key, userID, actionID, now.Add(-1*time.Minute)) {
+	if !validTokenAtTime(tok, key, userID, actionID, now.Add(-1*time.Minute+1*time.Millisecond)) {
 		t.Error("One minute in the past: Expected token to be valid")
 	}
 }
@@ -51,7 +51,7 @@ func TestInvalidToken(t *testing.T) {
 		{"Bad key", "foobar", userID, actionID, oneMinuteFromNow},
 		{"Bad userID", key, "foobar", actionID, oneMinuteFromNow},
 		{"Bad actionID", key, userID, "foobar", oneMinuteFromNow},
-		{"Expired", key, userID, actionID, now.Add(Timeout)},
+		{"Expired", key, userID, actionID, now.Add(Timeout + 1*time.Millisecond)},
 		{"More than 1 minute from the future", key, userID, actionID, now.Add(-1*time.Nanosecond - 1*time.Minute)},
 	}
 
@@ -72,6 +72,7 @@ func TestValidateBadData(t *testing.T) {
 		{"Invalid Base64", "ASDab24(@)$*=="},
 		{"No delimiter", base64.URLEncoding.EncodeToString([]byte("foobar12345678"))},
 		{"Invalid time", base64.URLEncoding.EncodeToString([]byte("foobar:foobar"))},
+		{"Wrong length", "1234" + generateTokenAtTime(key, userID, actionID, now)},
 	}
 
 	for _, bdt := range badDataTests {