Explorar o código

Cron Entry creation + tests

Rob Figueiredo %!s(int64=13) %!d(string=hai) anos
pai
achega
4152b5bf2f
Modificáronse 4 ficheiros con 319 adicións e 1 borrados
  1. 2 1
      README.md
  2. 5 0
      cron.go
  3. 200 0
      entry.go
  4. 112 0
      entry_test.go

+ 2 - 1
README.md

@@ -1,4 +1,5 @@
 cron
 ====
 
-a cron library for go
+A cron library for Go.
+

+ 5 - 0
cron.go

@@ -0,0 +1,5 @@
+package cron
+
+type Cron struct {
+	Entries []*Entry
+}

+ 200 - 0
entry.go

@@ -0,0 +1,200 @@
+package cron
+
+import (
+	"log"
+	"math"
+	"strconv"
+	"strings"
+)
+
+type Entry struct {
+	Minute, Hour, Dom, Month, Dow uint64
+	Func                          func()
+}
+
+type Range struct{ min, max uint }
+
+var (
+	minutes = Range{0, 59}
+	hours   = Range{0, 23}
+	dom     = Range{1, 31}
+	months  = Range{1, 12}
+	dow     = Range{0, 7}
+)
+
+// Returns a new crontab entry representing the given spec.
+// Panics with a descriptive error if the spec is not valid.
+func NewEntry(spec string, cmd func()) *Entry {
+	if spec[0] == '@' {
+		entry := parseDescriptor(spec)
+		entry.Func = cmd
+		return entry
+	}
+
+	// Split on whitespace.  We require 4 or 5 fields.
+	// (minute) (hour) (day of month) (month) (day of week, optional)
+	fields := strings.Fields(spec)
+	if len(fields) != 4 && len(fields) != 5 {
+		log.Panicf("Expected 4 or 5 fields, found %d: %s", len(fields), spec)
+	}
+
+	entry := &Entry{
+		Minute: getField(fields[0], minutes),
+		Hour:   getField(fields[1], hours),
+		Dom:    getField(fields[2], dom),
+		Month:  getField(fields[3], months),
+		Func:   cmd,
+	}
+	if len(fields) == 5 {
+		entry.Dow = getField(fields[4], dow)
+
+		// If either bit 0 or 7 are set, set both.  (both accepted as Sunday)
+		if entry.Dow&1|entry.Dow&1<<7 > 0 {
+			entry.Dow = entry.Dow | 1 | 1<<7
+		}
+	}
+
+	return entry
+}
+
+// Return an Int with the bits set representing all of the times that the field represents.
+// A "field" is a comma-separated list of "ranges".
+func getField(field string, r Range) uint64 {
+	// list = range {"," range}
+	var bits uint64
+	ranges := strings.FieldsFunc(field, func(r rune) bool { return r == ',' })
+	for _, expr := range ranges {
+		bits |= getRange(expr, r)
+	}
+	return bits
+}
+
+func getRange(expr string, r Range) uint64 {
+	// number | number "-" number [ "/" number ]
+	var start, end, step uint
+	rangeAndStep := strings.Split(expr, "/")
+	lowAndHigh := strings.Split(rangeAndStep[0], "-")
+
+	if lowAndHigh[0] == "*" {
+		start = r.min
+		end = r.max
+	} else {
+		start = mustParseInt(lowAndHigh[0])
+		switch len(lowAndHigh) {
+		case 1:
+			end = start
+		case 2:
+			end = mustParseInt(lowAndHigh[1])
+		default:
+			log.Panicf("Too many commas: %s", expr)
+		}
+	}
+
+	switch len(rangeAndStep) {
+	case 1:
+		step = 1
+	case 2:
+		step = mustParseInt(rangeAndStep[1])
+	default:
+		log.Panicf("Too many slashes: %s", expr)
+	}
+
+	if start < r.min {
+		log.Panicf("Beginning of range (%d) below minimum (%d): %s", start, r.min, expr)
+	}
+	if end > r.max {
+		log.Panicf("End of range (%d) above maximum (%d): %s", end, r.max, expr)
+	}
+	if start > end {
+		log.Panicf("Beginning of range (%d) beyond end of range (%d): %s", start, end, expr)
+	}
+
+	return getBits(start, end, step)
+}
+
+func mustParseInt(expr string) uint {
+	num, err := strconv.Atoi(expr)
+	if err != nil {
+		log.Panicf("Failed to parse int from %s: %s", expr, err)
+	}
+	if num < 0 {
+		log.Panicf("Negative number (%d) not allowed: %s", num, expr)
+	}
+
+	return uint(num)
+}
+
+func getBits(min, max, step uint) uint64 {
+	var bits uint64
+
+	// If step is 1, use shifts.
+	if step == 1 {
+		return ^(math.MaxUint64 << (max + 1)) & (math.MaxUint64 << min)
+	}
+
+	// Else, use a simple loop.
+	for i := min; i <= max; i += step {
+		bits |= 1 << i
+	}
+	return bits
+}
+
+func all(r Range) uint64 {
+	return getBits(r.min, r.max, 1)
+}
+
+func first(r Range) uint64 {
+	return getBits(r.min, r.min, 1)
+}
+
+func parseDescriptor(spec string) *Entry {
+	switch spec {
+	case "@yearly", "@annually":
+		return &Entry{
+			Minute: 1 << minutes.min,
+			Hour:   1 << hours.min,
+			Dom:    1 << dom.min,
+			Month:  1 << months.min,
+			Dow:    all(dow),
+		}
+
+	case "@monthly":
+		return &Entry{
+			Minute: 1 << minutes.min,
+			Hour:   1 << hours.min,
+			Dom:    1 << dom.min,
+			Month:  all(months),
+			Dow:    all(dow),
+		}
+
+	case "@weekly":
+		return &Entry{
+			Minute: 1 << minutes.min,
+			Hour:   1 << hours.min,
+			Dom:    all(dom),
+			Month:  all(months),
+			Dow:    1 << dow.min,
+		}
+
+	case "@daily", "@midnight":
+		return &Entry{
+			Minute: 1 << minutes.min,
+			Hour:   1 << hours.min,
+			Dom:    all(dom),
+			Month:  all(months),
+			Dow:    all(dow),
+		}
+
+	case "@hourly":
+		return &Entry{
+			Minute: 1 << minutes.min,
+			Hour:   all(hours),
+			Dom:    all(dom),
+			Month:  all(months),
+			Dow:    all(dow),
+		}
+	}
+
+	log.Panicf("Unrecognized descriptor: %s", spec)
+	return nil
+}

+ 112 - 0
entry_test.go

@@ -0,0 +1,112 @@
+package cron
+
+import (
+	"reflect"
+	"testing"
+)
+
+func TestRange(t *testing.T) {
+	ranges := []struct {
+		expr     string
+		min, max uint
+		expected uint64
+	}{
+		{"5", 0, 7, 1 << 5},
+		{"0", 0, 7, 1 << 0},
+		{"7", 0, 7, 1 << 7},
+
+		{"5-5", 0, 7, 1 << 5},
+		{"5-6", 0, 7, 1<<5 | 1<<6},
+		{"5-7", 0, 7, 1<<5 | 1<<6 | 1<<7},
+
+		{"5-6/2", 0, 7, 1 << 5},
+		{"5-7/2", 0, 7, 1<<5 | 1<<7},
+		{"5-7/1", 0, 7, 1<<5 | 1<<6 | 1<<7},
+
+		{"*", 1, 3, 1<<1 | 1<<2 | 1<<3},
+		{"*/2", 1, 3, 1<<1 | 1<<3},
+	}
+
+	for _, c := range ranges {
+		actual := getRange(c.expr, Range{c.min, c.max})
+		if actual != c.expected {
+			t.Errorf("%s => (expected) %d != %d (actual)", c.expr, c.expected, actual)
+		}
+	}
+}
+
+func TestField(t *testing.T) {
+	fields := []struct {
+		expr     string
+		min, max uint
+		expected uint64
+	}{
+		{"5", 1, 7, 1 << 5},
+		{"5,6", 1, 7, 1<<5 | 1<<6},
+		{"5,6,7", 1, 7, 1<<5 | 1<<6 | 1<<7},
+		{"1,5-7/2,3", 1, 7, 1<<1 | 1<<5 | 1<<7 | 1<<3},
+	}
+
+	for _, c := range fields {
+		actual := getField(c.expr, Range{c.min, c.max})
+		if actual != c.expected {
+			t.Errorf("%s => (expected) %d != %d (actual)", c.expr, c.expected, actual)
+		}
+	}
+}
+
+func TestBits(t *testing.T) {
+	allBits := []struct {
+		r        Range
+		expected uint64
+	}{
+		{minutes, 0xfffffffffffffff}, // 0-59: 60 ones
+		{hours, 0xffffff},            // 0-23: 24 ones
+		{dom, 0xfffffffe},            // 1-31: 31 ones, 1 zero
+		{months, 0x1ffe},             // 1-12: 12 ones, 1 zero
+		{dow, 0xff},                  // 0-7: 8 ones
+	}
+
+	for _, c := range allBits {
+		actual := all(c.r)
+		if c.expected != actual {
+			t.Errorf("%d-%d/%d => (expected) %b != %b (actual)",
+				c.r.min, c.r.max, 1, c.expected, actual)
+		}
+	}
+
+	bits := []struct {
+		min, max, step uint
+		expected       uint64
+	}{
+
+		{0, 0, 1, 0x1},
+		{1, 1, 1, 0x2},
+		{1, 5, 2, 0x2a}, // 101010
+		{1, 4, 2, 0xa},  // 1010
+	}
+
+	for _, c := range bits {
+		actual := getBits(c.min, c.max, c.step)
+		if c.expected != actual {
+			t.Errorf("%d-%d/%d => (expected) %b != %b (actual)",
+				c.min, c.max, c.step, c.expected, actual)
+		}
+	}
+}
+
+func TestEntry(t *testing.T) {
+	entries := []struct {
+		expr     string
+		expected Entry
+	}{
+		{"5 * * * *", Entry{1 << 5, all(hours), all(dom), all(months), all(dow), nil}},
+	}
+
+	for _, c := range entries {
+		actual := *NewEntry(c.expr, nil)
+		if !reflect.DeepEqual(actual, c.expected) {
+			t.Errorf("%s => (expected) %b != %b (actual)", c.expr, c.expected, actual)
+		}
+	}
+}