Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0762ddbbc2 |
@@ -61,4 +61,9 @@ type Options struct {
|
||||
// connection before the background reaper disposes it.
|
||||
// Default: 30s. Negative value disables the reaper.
|
||||
ContextTTL time.Duration
|
||||
|
||||
// ActionRateLimit configures the default token-bucket rate limiter for
|
||||
// action endpoints. Zero values use built-in defaults (10 req/s, burst 20).
|
||||
// Set Rate to -1 to disable rate limiting entirely.
|
||||
ActionRateLimit RateLimitConfig
|
||||
}
|
||||
|
||||
26
context.go
26
context.go
@@ -12,6 +12,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ryanhamamura/via/h"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// Context is the living bridge between Go and the browser.
|
||||
@@ -27,7 +28,8 @@ type Context struct {
|
||||
componentRegistry map[string]*Context
|
||||
parentPageCtx *Context
|
||||
patchChan chan patch
|
||||
actionRegistry map[string]func()
|
||||
actionLimiter *rate.Limiter
|
||||
actionRegistry map[string]actionEntry
|
||||
signals *sync.Map
|
||||
mu sync.RWMutex
|
||||
ctxDisposedChan chan struct{}
|
||||
@@ -104,26 +106,31 @@ func (c *Context) isComponent() bool {
|
||||
// h.Button(h.Text("Increment n"), increment.OnClick()),
|
||||
// )
|
||||
// })
|
||||
func (c *Context) Action(f func()) *actionTrigger {
|
||||
func (c *Context) Action(f func(), opts ...ActionOption) *actionTrigger {
|
||||
id := genRandID()
|
||||
if f == nil {
|
||||
c.app.logErr(c, "failed to bind action '%s' to context: nil func", id)
|
||||
return nil
|
||||
}
|
||||
|
||||
entry := actionEntry{fn: f}
|
||||
for _, opt := range opts {
|
||||
opt(&entry)
|
||||
}
|
||||
|
||||
if c.isComponent() {
|
||||
c.parentPageCtx.actionRegistry[id] = f
|
||||
c.parentPageCtx.actionRegistry[id] = entry
|
||||
} else {
|
||||
c.actionRegistry[id] = f
|
||||
c.actionRegistry[id] = entry
|
||||
}
|
||||
return &actionTrigger{id}
|
||||
}
|
||||
|
||||
func (c *Context) getActionFn(id string) (func(), error) {
|
||||
if f, ok := c.actionRegistry[id]; ok {
|
||||
return f, nil
|
||||
func (c *Context) getAction(id string) (actionEntry, error) {
|
||||
if e, ok := c.actionRegistry[id]; ok {
|
||||
return e, nil
|
||||
}
|
||||
return nil, fmt.Errorf("action '%s' not found", id)
|
||||
return actionEntry{}, fmt.Errorf("action '%s' not found", id)
|
||||
}
|
||||
|
||||
// OnInterval starts a go routine that sets a time.Ticker with the given duration and executes
|
||||
@@ -482,7 +489,8 @@ func newContext(id string, route string, v *V) *Context {
|
||||
routeParams: make(map[string]string),
|
||||
app: v,
|
||||
componentRegistry: make(map[string]*Context),
|
||||
actionRegistry: make(map[string]func()),
|
||||
actionLimiter: newLimiter(v.actionRateLimit, defaultActionRate, defaultActionBurst),
|
||||
actionRegistry: make(map[string]actionEntry),
|
||||
signals: new(sync.Map),
|
||||
patchChan: make(chan patch, 1),
|
||||
ctxDisposedChan: make(chan struct{}, 1),
|
||||
|
||||
3
go.mod
3
go.mod
@@ -14,6 +14,7 @@ require (
|
||||
github.com/rs/zerolog v1.34.0
|
||||
github.com/starfederation/datastar-go v1.0.3
|
||||
github.com/stretchr/testify v1.11.1
|
||||
golang.org/x/time v0.14.0
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -37,6 +38,6 @@ require (
|
||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||
golang.org/x/crypto v0.45.0 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/time v0.14.0 // indirect
|
||||
golang.org/x/time v0.14.0
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
48
ratelimit.go
Normal file
48
ratelimit.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package via
|
||||
|
||||
import "golang.org/x/time/rate"
|
||||
|
||||
const (
|
||||
defaultActionRate float64 = 10.0
|
||||
defaultActionBurst int = 20
|
||||
)
|
||||
|
||||
// RateLimitConfig configures token-bucket rate limiting for actions.
|
||||
// Zero values fall back to defaults. Rate of -1 disables limiting entirely.
|
||||
type RateLimitConfig struct {
|
||||
Rate float64
|
||||
Burst int
|
||||
}
|
||||
|
||||
// ActionOption configures per-action behaviour when passed to Context.Action.
|
||||
type ActionOption func(*actionEntry)
|
||||
|
||||
type actionEntry struct {
|
||||
fn func()
|
||||
limiter *rate.Limiter // nil = use context default
|
||||
}
|
||||
|
||||
// WithRateLimit returns an ActionOption that gives this action its own
|
||||
// token-bucket limiter, overriding the context-level default.
|
||||
func WithRateLimit(r float64, burst int) ActionOption {
|
||||
return func(e *actionEntry) {
|
||||
e.limiter = newLimiter(RateLimitConfig{Rate: r, Burst: burst}, defaultActionRate, defaultActionBurst)
|
||||
}
|
||||
}
|
||||
|
||||
// newLimiter creates a *rate.Limiter from cfg, substituting defaults for zero
|
||||
// values. A Rate of -1 disables limiting (returns nil).
|
||||
func newLimiter(cfg RateLimitConfig, defaultRate float64, defaultBurst int) *rate.Limiter {
|
||||
r := cfg.Rate
|
||||
b := cfg.Burst
|
||||
if r == -1 {
|
||||
return nil
|
||||
}
|
||||
if r == 0 {
|
||||
r = defaultRate
|
||||
}
|
||||
if b == 0 {
|
||||
b = defaultBurst
|
||||
}
|
||||
return rate.NewLimiter(rate.Limit(r), b)
|
||||
}
|
||||
101
ratelimit_test.go
Normal file
101
ratelimit_test.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package via
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewLimiter_Defaults(t *testing.T) {
|
||||
l := newLimiter(RateLimitConfig{}, defaultActionRate, defaultActionBurst)
|
||||
require.NotNil(t, l)
|
||||
assert.InDelta(t, defaultActionRate, float64(l.Limit()), 0.001)
|
||||
assert.Equal(t, defaultActionBurst, l.Burst())
|
||||
}
|
||||
|
||||
func TestNewLimiter_CustomValues(t *testing.T) {
|
||||
l := newLimiter(RateLimitConfig{Rate: 5, Burst: 10}, defaultActionRate, defaultActionBurst)
|
||||
require.NotNil(t, l)
|
||||
assert.InDelta(t, 5.0, float64(l.Limit()), 0.001)
|
||||
assert.Equal(t, 10, l.Burst())
|
||||
}
|
||||
|
||||
func TestNewLimiter_DisabledWithNegativeRate(t *testing.T) {
|
||||
l := newLimiter(RateLimitConfig{Rate: -1}, defaultActionRate, defaultActionBurst)
|
||||
assert.Nil(t, l)
|
||||
}
|
||||
|
||||
func TestTokenBucket_AllowsBurstThenRejects(t *testing.T) {
|
||||
l := newLimiter(RateLimitConfig{Rate: 1, Burst: 3}, 1, 3)
|
||||
require.NotNil(t, l)
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
assert.True(t, l.Allow(), "request %d should be allowed within burst", i)
|
||||
}
|
||||
assert.False(t, l.Allow(), "request beyond burst should be rejected")
|
||||
}
|
||||
|
||||
func TestWithRateLimit_CreatesLimiter(t *testing.T) {
|
||||
entry := actionEntry{fn: func() {}}
|
||||
opt := WithRateLimit(2, 4)
|
||||
opt(&entry)
|
||||
|
||||
require.NotNil(t, entry.limiter)
|
||||
assert.InDelta(t, 2.0, float64(entry.limiter.Limit()), 0.001)
|
||||
assert.Equal(t, 4, entry.limiter.Burst())
|
||||
}
|
||||
|
||||
func TestContextAction_WithRateLimit(t *testing.T) {
|
||||
v := New()
|
||||
c := newContext("test-rl", "/", v)
|
||||
|
||||
called := false
|
||||
c.Action(func() { called = true }, WithRateLimit(1, 2))
|
||||
|
||||
// Verify the entry has its own limiter
|
||||
for _, entry := range c.actionRegistry {
|
||||
require.NotNil(t, entry.limiter)
|
||||
assert.InDelta(t, 1.0, float64(entry.limiter.Limit()), 0.001)
|
||||
assert.Equal(t, 2, entry.limiter.Burst())
|
||||
}
|
||||
assert.False(t, called)
|
||||
}
|
||||
|
||||
func TestContextAction_DefaultNoPerActionLimiter(t *testing.T) {
|
||||
v := New()
|
||||
c := newContext("test-no-rl", "/", v)
|
||||
|
||||
c.Action(func() {})
|
||||
|
||||
for _, entry := range c.actionRegistry {
|
||||
assert.Nil(t, entry.limiter, "entry without WithRateLimit should have nil limiter")
|
||||
}
|
||||
}
|
||||
|
||||
func TestContextLimiter_DefaultsApplied(t *testing.T) {
|
||||
v := New()
|
||||
c := newContext("test-ctx-limiter", "/", v)
|
||||
|
||||
require.NotNil(t, c.actionLimiter)
|
||||
assert.InDelta(t, defaultActionRate, float64(c.actionLimiter.Limit()), 0.001)
|
||||
assert.Equal(t, defaultActionBurst, c.actionLimiter.Burst())
|
||||
}
|
||||
|
||||
func TestContextLimiter_DisabledViaConfig(t *testing.T) {
|
||||
v := New()
|
||||
v.actionRateLimit = RateLimitConfig{Rate: -1}
|
||||
c := newContext("test-disabled", "/", v)
|
||||
|
||||
assert.Nil(t, c.actionLimiter)
|
||||
}
|
||||
|
||||
func TestContextLimiter_CustomConfig(t *testing.T) {
|
||||
v := New()
|
||||
v.Config(Options{ActionRateLimit: RateLimitConfig{Rate: 50, Burst: 100}})
|
||||
c := newContext("test-custom", "/", v)
|
||||
|
||||
require.NotNil(t, c.actionLimiter)
|
||||
assert.InDelta(t, 50.0, float64(c.actionLimiter.Limit()), 0.001)
|
||||
assert.Equal(t, 100, c.actionLimiter.Burst())
|
||||
}
|
||||
20
via.go
20
via.go
@@ -48,6 +48,7 @@ type V struct {
|
||||
devModePageInitFnMap map[string]func(*Context)
|
||||
sessionManager *scs.SessionManager
|
||||
pubsub PubSub
|
||||
actionRateLimit RateLimitConfig
|
||||
datastarPath string
|
||||
datastarContent []byte
|
||||
datastarOnce sync.Once
|
||||
@@ -132,6 +133,9 @@ func (v *V) Config(cfg Options) {
|
||||
if cfg.ContextTTL != 0 {
|
||||
v.cfg.ContextTTL = cfg.ContextTTL
|
||||
}
|
||||
if cfg.ActionRateLimit.Rate != 0 || cfg.ActionRateLimit.Burst != 0 {
|
||||
v.actionRateLimit = cfg.ActionRateLimit
|
||||
}
|
||||
}
|
||||
|
||||
// AppendToHead appends the given h.H nodes to the head of the base HTML document.
|
||||
@@ -639,13 +643,23 @@ func New() *V {
|
||||
http.Error(w, "invalid CSRF token", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
if c.actionLimiter != nil && !c.actionLimiter.Allow() {
|
||||
v.logWarn(c, "action '%s' rate limited", actionID)
|
||||
http.Error(w, "rate limited", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
c.reqCtx = r.Context()
|
||||
actionFn, err := c.getActionFn(actionID)
|
||||
entry, err := c.getAction(actionID)
|
||||
if err != nil {
|
||||
v.logDebug(c, "action '%s' failed: %v", actionID, err)
|
||||
return
|
||||
}
|
||||
// log err if actionFn panics
|
||||
if entry.limiter != nil && !entry.limiter.Allow() {
|
||||
v.logWarn(c, "action '%s' rate limited (per-action)", actionID)
|
||||
http.Error(w, "rate limited", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
// log err if action panics
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
v.logErr(c, "action '%s' failed: %v", actionID, r)
|
||||
@@ -653,7 +667,7 @@ func New() *V {
|
||||
}()
|
||||
|
||||
c.injectSignals(sigs)
|
||||
actionFn()
|
||||
entry.fn()
|
||||
})
|
||||
|
||||
v.mux.HandleFunc("POST /_session/close", func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
Reference in New Issue
Block a user