From 0762ddbbc2f1c099af68fa5a2e9a9af1dca4d7a8 Mon Sep 17 00:00:00 2001 From: Ryan Hamamura <58859899+ryanhamamura@users.noreply.github.com> Date: Fri, 6 Feb 2026 11:52:07 -1000 Subject: [PATCH] feat: add token-bucket rate limiting for action endpoints Add per-context and per-action rate limiting using golang.org/x/time/rate. Configure globally via Options.ActionRateLimit or per-action with WithRateLimit(). Defaults to 10 req/s with burst of 20. --- configuration.go | 5 +++ context.go | 26 +++++++----- go.mod | 3 +- ratelimit.go | 48 ++++++++++++++++++++++ ratelimit_test.go | 101 ++++++++++++++++++++++++++++++++++++++++++++++ via.go | 20 +++++++-- 6 files changed, 190 insertions(+), 13 deletions(-) create mode 100644 ratelimit.go create mode 100644 ratelimit_test.go diff --git a/configuration.go b/configuration.go index eac0bcc..d2ec955 100644 --- a/configuration.go +++ b/configuration.go @@ -61,4 +61,9 @@ type Options struct { // connection before the background reaper disposes it. // Default: 30s. Negative value disables the reaper. ContextTTL time.Duration + + // ActionRateLimit configures the default token-bucket rate limiter for + // action endpoints. Zero values use built-in defaults (10 req/s, burst 20). + // Set Rate to -1 to disable rate limiting entirely. + ActionRateLimit RateLimitConfig } diff --git a/context.go b/context.go index 44da59c..c396d07 100644 --- a/context.go +++ b/context.go @@ -12,6 +12,7 @@ import ( "time" "github.com/ryanhamamura/via/h" + "golang.org/x/time/rate" ) // Context is the living bridge between Go and the browser. @@ -27,7 +28,8 @@ type Context struct { componentRegistry map[string]*Context parentPageCtx *Context patchChan chan patch - actionRegistry map[string]func() + actionLimiter *rate.Limiter + actionRegistry map[string]actionEntry signals *sync.Map mu sync.RWMutex ctxDisposedChan chan struct{} @@ -104,26 +106,31 @@ func (c *Context) isComponent() bool { // h.Button(h.Text("Increment n"), increment.OnClick()), // ) // }) -func (c *Context) Action(f func()) *actionTrigger { +func (c *Context) Action(f func(), opts ...ActionOption) *actionTrigger { id := genRandID() if f == nil { c.app.logErr(c, "failed to bind action '%s' to context: nil func", id) return nil } + entry := actionEntry{fn: f} + for _, opt := range opts { + opt(&entry) + } + if c.isComponent() { - c.parentPageCtx.actionRegistry[id] = f + c.parentPageCtx.actionRegistry[id] = entry } else { - c.actionRegistry[id] = f + c.actionRegistry[id] = entry } return &actionTrigger{id} } -func (c *Context) getActionFn(id string) (func(), error) { - if f, ok := c.actionRegistry[id]; ok { - return f, nil +func (c *Context) getAction(id string) (actionEntry, error) { + if e, ok := c.actionRegistry[id]; ok { + return e, nil } - return nil, fmt.Errorf("action '%s' not found", id) + return actionEntry{}, fmt.Errorf("action '%s' not found", id) } // OnInterval starts a go routine that sets a time.Ticker with the given duration and executes @@ -482,7 +489,8 @@ func newContext(id string, route string, v *V) *Context { routeParams: make(map[string]string), app: v, componentRegistry: make(map[string]*Context), - actionRegistry: make(map[string]func()), + actionLimiter: newLimiter(v.actionRateLimit, defaultActionRate, defaultActionBurst), + actionRegistry: make(map[string]actionEntry), signals: new(sync.Map), patchChan: make(chan patch, 1), ctxDisposedChan: make(chan struct{}, 1), diff --git a/go.mod b/go.mod index 83072c8..fe3c946 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( github.com/rs/zerolog v1.34.0 github.com/starfederation/datastar-go v1.0.3 github.com/stretchr/testify v1.11.1 + golang.org/x/time v0.14.0 ) require ( @@ -37,6 +38,6 @@ require ( github.com/valyala/bytebufferpool v1.0.0 // indirect golang.org/x/crypto v0.45.0 // indirect golang.org/x/sys v0.38.0 // indirect - golang.org/x/time v0.14.0 // indirect + golang.org/x/time v0.14.0 gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/ratelimit.go b/ratelimit.go new file mode 100644 index 0000000..57dc20a --- /dev/null +++ b/ratelimit.go @@ -0,0 +1,48 @@ +package via + +import "golang.org/x/time/rate" + +const ( + defaultActionRate float64 = 10.0 + defaultActionBurst int = 20 +) + +// RateLimitConfig configures token-bucket rate limiting for actions. +// Zero values fall back to defaults. Rate of -1 disables limiting entirely. +type RateLimitConfig struct { + Rate float64 + Burst int +} + +// ActionOption configures per-action behaviour when passed to Context.Action. +type ActionOption func(*actionEntry) + +type actionEntry struct { + fn func() + limiter *rate.Limiter // nil = use context default +} + +// WithRateLimit returns an ActionOption that gives this action its own +// token-bucket limiter, overriding the context-level default. +func WithRateLimit(r float64, burst int) ActionOption { + return func(e *actionEntry) { + e.limiter = newLimiter(RateLimitConfig{Rate: r, Burst: burst}, defaultActionRate, defaultActionBurst) + } +} + +// newLimiter creates a *rate.Limiter from cfg, substituting defaults for zero +// values. A Rate of -1 disables limiting (returns nil). +func newLimiter(cfg RateLimitConfig, defaultRate float64, defaultBurst int) *rate.Limiter { + r := cfg.Rate + b := cfg.Burst + if r == -1 { + return nil + } + if r == 0 { + r = defaultRate + } + if b == 0 { + b = defaultBurst + } + return rate.NewLimiter(rate.Limit(r), b) +} diff --git a/ratelimit_test.go b/ratelimit_test.go new file mode 100644 index 0000000..d4ed9cf --- /dev/null +++ b/ratelimit_test.go @@ -0,0 +1,101 @@ +package via + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewLimiter_Defaults(t *testing.T) { + l := newLimiter(RateLimitConfig{}, defaultActionRate, defaultActionBurst) + require.NotNil(t, l) + assert.InDelta(t, defaultActionRate, float64(l.Limit()), 0.001) + assert.Equal(t, defaultActionBurst, l.Burst()) +} + +func TestNewLimiter_CustomValues(t *testing.T) { + l := newLimiter(RateLimitConfig{Rate: 5, Burst: 10}, defaultActionRate, defaultActionBurst) + require.NotNil(t, l) + assert.InDelta(t, 5.0, float64(l.Limit()), 0.001) + assert.Equal(t, 10, l.Burst()) +} + +func TestNewLimiter_DisabledWithNegativeRate(t *testing.T) { + l := newLimiter(RateLimitConfig{Rate: -1}, defaultActionRate, defaultActionBurst) + assert.Nil(t, l) +} + +func TestTokenBucket_AllowsBurstThenRejects(t *testing.T) { + l := newLimiter(RateLimitConfig{Rate: 1, Burst: 3}, 1, 3) + require.NotNil(t, l) + + for i := 0; i < 3; i++ { + assert.True(t, l.Allow(), "request %d should be allowed within burst", i) + } + assert.False(t, l.Allow(), "request beyond burst should be rejected") +} + +func TestWithRateLimit_CreatesLimiter(t *testing.T) { + entry := actionEntry{fn: func() {}} + opt := WithRateLimit(2, 4) + opt(&entry) + + require.NotNil(t, entry.limiter) + assert.InDelta(t, 2.0, float64(entry.limiter.Limit()), 0.001) + assert.Equal(t, 4, entry.limiter.Burst()) +} + +func TestContextAction_WithRateLimit(t *testing.T) { + v := New() + c := newContext("test-rl", "/", v) + + called := false + c.Action(func() { called = true }, WithRateLimit(1, 2)) + + // Verify the entry has its own limiter + for _, entry := range c.actionRegistry { + require.NotNil(t, entry.limiter) + assert.InDelta(t, 1.0, float64(entry.limiter.Limit()), 0.001) + assert.Equal(t, 2, entry.limiter.Burst()) + } + assert.False(t, called) +} + +func TestContextAction_DefaultNoPerActionLimiter(t *testing.T) { + v := New() + c := newContext("test-no-rl", "/", v) + + c.Action(func() {}) + + for _, entry := range c.actionRegistry { + assert.Nil(t, entry.limiter, "entry without WithRateLimit should have nil limiter") + } +} + +func TestContextLimiter_DefaultsApplied(t *testing.T) { + v := New() + c := newContext("test-ctx-limiter", "/", v) + + require.NotNil(t, c.actionLimiter) + assert.InDelta(t, defaultActionRate, float64(c.actionLimiter.Limit()), 0.001) + assert.Equal(t, defaultActionBurst, c.actionLimiter.Burst()) +} + +func TestContextLimiter_DisabledViaConfig(t *testing.T) { + v := New() + v.actionRateLimit = RateLimitConfig{Rate: -1} + c := newContext("test-disabled", "/", v) + + assert.Nil(t, c.actionLimiter) +} + +func TestContextLimiter_CustomConfig(t *testing.T) { + v := New() + v.Config(Options{ActionRateLimit: RateLimitConfig{Rate: 50, Burst: 100}}) + c := newContext("test-custom", "/", v) + + require.NotNil(t, c.actionLimiter) + assert.InDelta(t, 50.0, float64(c.actionLimiter.Limit()), 0.001) + assert.Equal(t, 100, c.actionLimiter.Burst()) +} diff --git a/via.go b/via.go index 4d928d6..c327d64 100644 --- a/via.go +++ b/via.go @@ -48,6 +48,7 @@ type V struct { devModePageInitFnMap map[string]func(*Context) sessionManager *scs.SessionManager pubsub PubSub + actionRateLimit RateLimitConfig datastarPath string datastarContent []byte datastarOnce sync.Once @@ -132,6 +133,9 @@ func (v *V) Config(cfg Options) { if cfg.ContextTTL != 0 { v.cfg.ContextTTL = cfg.ContextTTL } + if cfg.ActionRateLimit.Rate != 0 || cfg.ActionRateLimit.Burst != 0 { + v.actionRateLimit = cfg.ActionRateLimit + } } // AppendToHead appends the given h.H nodes to the head of the base HTML document. @@ -639,13 +643,23 @@ func New() *V { http.Error(w, "invalid CSRF token", http.StatusForbidden) return } + if c.actionLimiter != nil && !c.actionLimiter.Allow() { + v.logWarn(c, "action '%s' rate limited", actionID) + http.Error(w, "rate limited", http.StatusTooManyRequests) + return + } c.reqCtx = r.Context() - actionFn, err := c.getActionFn(actionID) + entry, err := c.getAction(actionID) if err != nil { v.logDebug(c, "action '%s' failed: %v", actionID, err) return } - // log err if actionFn panics + if entry.limiter != nil && !entry.limiter.Allow() { + v.logWarn(c, "action '%s' rate limited (per-action)", actionID) + http.Error(w, "rate limited", http.StatusTooManyRequests) + return + } + // log err if action panics defer func() { if r := recover(); r != nil { v.logErr(c, "action '%s' failed: %v", actionID, r) @@ -653,7 +667,7 @@ func New() *V { }() c.injectSignals(sigs) - actionFn() + entry.fn() }) v.mux.HandleFunc("POST /_session/close", func(w http.ResponseWriter, r *http.Request) {