Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0762ddbbc2 | ||
|
|
b7acfa6302 | ||
|
|
8aa91c577c | ||
|
|
6dcd54c88b |
@@ -107,6 +107,54 @@ func (a *actionTrigger) OnChange(options ...ActionTriggerOption) h.H {
|
||||
return h.Data("on:change__debounce.200ms", buildOnExpr(actionURL(a.id), &opts))
|
||||
}
|
||||
|
||||
// OnSubmit returns a via.h DOM attribute that triggers on form submit.
|
||||
func (a *actionTrigger) OnSubmit(options ...ActionTriggerOption) h.H {
|
||||
opts := applyOptions(options...)
|
||||
return h.Data("on:submit", buildOnExpr(actionURL(a.id), &opts))
|
||||
}
|
||||
|
||||
// OnInput returns a via.h DOM attribute that triggers on input (without debounce).
|
||||
func (a *actionTrigger) OnInput(options ...ActionTriggerOption) h.H {
|
||||
opts := applyOptions(options...)
|
||||
return h.Data("on:input", buildOnExpr(actionURL(a.id), &opts))
|
||||
}
|
||||
|
||||
// OnFocus returns a via.h DOM attribute that triggers when the element gains focus.
|
||||
func (a *actionTrigger) OnFocus(options ...ActionTriggerOption) h.H {
|
||||
opts := applyOptions(options...)
|
||||
return h.Data("on:focus", buildOnExpr(actionURL(a.id), &opts))
|
||||
}
|
||||
|
||||
// OnBlur returns a via.h DOM attribute that triggers when the element loses focus.
|
||||
func (a *actionTrigger) OnBlur(options ...ActionTriggerOption) h.H {
|
||||
opts := applyOptions(options...)
|
||||
return h.Data("on:blur", buildOnExpr(actionURL(a.id), &opts))
|
||||
}
|
||||
|
||||
// OnMouseEnter returns a via.h DOM attribute that triggers when the mouse enters the element.
|
||||
func (a *actionTrigger) OnMouseEnter(options ...ActionTriggerOption) h.H {
|
||||
opts := applyOptions(options...)
|
||||
return h.Data("on:mouseenter", buildOnExpr(actionURL(a.id), &opts))
|
||||
}
|
||||
|
||||
// OnMouseLeave returns a via.h DOM attribute that triggers when the mouse leaves the element.
|
||||
func (a *actionTrigger) OnMouseLeave(options ...ActionTriggerOption) h.H {
|
||||
opts := applyOptions(options...)
|
||||
return h.Data("on:mouseleave", buildOnExpr(actionURL(a.id), &opts))
|
||||
}
|
||||
|
||||
// OnScroll returns a via.h DOM attribute that triggers on scroll.
|
||||
func (a *actionTrigger) OnScroll(options ...ActionTriggerOption) h.H {
|
||||
opts := applyOptions(options...)
|
||||
return h.Data("on:scroll", buildOnExpr(actionURL(a.id), &opts))
|
||||
}
|
||||
|
||||
// OnDblClick returns a via.h DOM attribute that triggers on double click.
|
||||
func (a *actionTrigger) OnDblClick(options ...ActionTriggerOption) h.H {
|
||||
opts := applyOptions(options...)
|
||||
return h.Data("on:dblclick", buildOnExpr(actionURL(a.id), &opts))
|
||||
}
|
||||
|
||||
// OnKeyDown returns a via.h DOM attribute that triggers when a key is pressed.
|
||||
// key: optional, see https://developer.mozilla.org/en-US/docs/Web/API/KeyboardEvent/key
|
||||
// Example: OnKeyDown("Enter")
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package via
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/alexedwards/scs/v2"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
@@ -54,4 +56,14 @@ type Options struct {
|
||||
// PubSub enables publish/subscribe messaging. Use vianats.New() for an
|
||||
// embedded NATS backend, or supply any PubSub implementation.
|
||||
PubSub PubSub
|
||||
|
||||
// ContextTTL is the maximum time a context may exist without an SSE
|
||||
// connection before the background reaper disposes it.
|
||||
// Default: 30s. Negative value disables the reaper.
|
||||
ContextTTL time.Duration
|
||||
|
||||
// ActionRateLimit configures the default token-bucket rate limiter for
|
||||
// action endpoints. Zero values use built-in defaults (10 req/s, burst 20).
|
||||
// Set Rate to -1 to disable rate limiting entirely.
|
||||
ActionRateLimit RateLimitConfig
|
||||
}
|
||||
|
||||
32
context.go
32
context.go
@@ -8,9 +8,11 @@ import (
|
||||
"maps"
|
||||
"reflect"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/ryanhamamura/via/h"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// Context is the living bridge between Go and the browser.
|
||||
@@ -19,13 +21,15 @@ import (
|
||||
type Context struct {
|
||||
id string
|
||||
route string
|
||||
csrfToken string
|
||||
app *V
|
||||
view func() h.H
|
||||
routeParams map[string]string
|
||||
componentRegistry map[string]*Context
|
||||
parentPageCtx *Context
|
||||
patchChan chan patch
|
||||
actionRegistry map[string]func()
|
||||
actionLimiter *rate.Limiter
|
||||
actionRegistry map[string]actionEntry
|
||||
signals *sync.Map
|
||||
mu sync.RWMutex
|
||||
ctxDisposedChan chan struct{}
|
||||
@@ -33,6 +37,8 @@ type Context struct {
|
||||
subscriptions []Subscription
|
||||
subsMu sync.Mutex
|
||||
disposeOnce sync.Once
|
||||
createdAt time.Time
|
||||
sseConnected atomic.Bool
|
||||
}
|
||||
|
||||
// View defines the UI rendered by this context.
|
||||
@@ -100,26 +106,31 @@ func (c *Context) isComponent() bool {
|
||||
// h.Button(h.Text("Increment n"), increment.OnClick()),
|
||||
// )
|
||||
// })
|
||||
func (c *Context) Action(f func()) *actionTrigger {
|
||||
func (c *Context) Action(f func(), opts ...ActionOption) *actionTrigger {
|
||||
id := genRandID()
|
||||
if f == nil {
|
||||
c.app.logErr(c, "failed to bind action '%s' to context: nil func", id)
|
||||
return nil
|
||||
}
|
||||
|
||||
entry := actionEntry{fn: f}
|
||||
for _, opt := range opts {
|
||||
opt(&entry)
|
||||
}
|
||||
|
||||
if c.isComponent() {
|
||||
c.parentPageCtx.actionRegistry[id] = f
|
||||
c.parentPageCtx.actionRegistry[id] = entry
|
||||
} else {
|
||||
c.actionRegistry[id] = f
|
||||
c.actionRegistry[id] = entry
|
||||
}
|
||||
return &actionTrigger{id}
|
||||
}
|
||||
|
||||
func (c *Context) getActionFn(id string) (func(), error) {
|
||||
if f, ok := c.actionRegistry[id]; ok {
|
||||
return f, nil
|
||||
func (c *Context) getAction(id string) (actionEntry, error) {
|
||||
if e, ok := c.actionRegistry[id]; ok {
|
||||
return e, nil
|
||||
}
|
||||
return nil, fmt.Errorf("action '%s' not found", id)
|
||||
return actionEntry{}, fmt.Errorf("action '%s' not found", id)
|
||||
}
|
||||
|
||||
// OnInterval starts a go routine that sets a time.Ticker with the given duration and executes
|
||||
@@ -474,12 +485,15 @@ func newContext(id string, route string, v *V) *Context {
|
||||
return &Context{
|
||||
id: id,
|
||||
route: route,
|
||||
csrfToken: genCSRFToken(),
|
||||
routeParams: make(map[string]string),
|
||||
app: v,
|
||||
componentRegistry: make(map[string]*Context),
|
||||
actionRegistry: make(map[string]func()),
|
||||
actionLimiter: newLimiter(v.actionRateLimit, defaultActionRate, defaultActionBurst),
|
||||
actionRegistry: make(map[string]actionEntry),
|
||||
signals: new(sync.Map),
|
||||
patchChan: make(chan patch, 1),
|
||||
ctxDisposedChan: make(chan struct{}, 1),
|
||||
createdAt: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
3
go.mod
3
go.mod
@@ -14,6 +14,7 @@ require (
|
||||
github.com/rs/zerolog v1.34.0
|
||||
github.com/starfederation/datastar-go v1.0.3
|
||||
github.com/stretchr/testify v1.11.1
|
||||
golang.org/x/time v0.14.0
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -37,6 +38,6 @@ require (
|
||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||
golang.org/x/crypto v0.45.0 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/time v0.14.0 // indirect
|
||||
golang.org/x/time v0.14.0
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
48
ratelimit.go
Normal file
48
ratelimit.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package via
|
||||
|
||||
import "golang.org/x/time/rate"
|
||||
|
||||
const (
|
||||
defaultActionRate float64 = 10.0
|
||||
defaultActionBurst int = 20
|
||||
)
|
||||
|
||||
// RateLimitConfig configures token-bucket rate limiting for actions.
|
||||
// Zero values fall back to defaults. Rate of -1 disables limiting entirely.
|
||||
type RateLimitConfig struct {
|
||||
Rate float64
|
||||
Burst int
|
||||
}
|
||||
|
||||
// ActionOption configures per-action behaviour when passed to Context.Action.
|
||||
type ActionOption func(*actionEntry)
|
||||
|
||||
type actionEntry struct {
|
||||
fn func()
|
||||
limiter *rate.Limiter // nil = use context default
|
||||
}
|
||||
|
||||
// WithRateLimit returns an ActionOption that gives this action its own
|
||||
// token-bucket limiter, overriding the context-level default.
|
||||
func WithRateLimit(r float64, burst int) ActionOption {
|
||||
return func(e *actionEntry) {
|
||||
e.limiter = newLimiter(RateLimitConfig{Rate: r, Burst: burst}, defaultActionRate, defaultActionBurst)
|
||||
}
|
||||
}
|
||||
|
||||
// newLimiter creates a *rate.Limiter from cfg, substituting defaults for zero
|
||||
// values. A Rate of -1 disables limiting (returns nil).
|
||||
func newLimiter(cfg RateLimitConfig, defaultRate float64, defaultBurst int) *rate.Limiter {
|
||||
r := cfg.Rate
|
||||
b := cfg.Burst
|
||||
if r == -1 {
|
||||
return nil
|
||||
}
|
||||
if r == 0 {
|
||||
r = defaultRate
|
||||
}
|
||||
if b == 0 {
|
||||
b = defaultBurst
|
||||
}
|
||||
return rate.NewLimiter(rate.Limit(r), b)
|
||||
}
|
||||
101
ratelimit_test.go
Normal file
101
ratelimit_test.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package via
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewLimiter_Defaults(t *testing.T) {
|
||||
l := newLimiter(RateLimitConfig{}, defaultActionRate, defaultActionBurst)
|
||||
require.NotNil(t, l)
|
||||
assert.InDelta(t, defaultActionRate, float64(l.Limit()), 0.001)
|
||||
assert.Equal(t, defaultActionBurst, l.Burst())
|
||||
}
|
||||
|
||||
func TestNewLimiter_CustomValues(t *testing.T) {
|
||||
l := newLimiter(RateLimitConfig{Rate: 5, Burst: 10}, defaultActionRate, defaultActionBurst)
|
||||
require.NotNil(t, l)
|
||||
assert.InDelta(t, 5.0, float64(l.Limit()), 0.001)
|
||||
assert.Equal(t, 10, l.Burst())
|
||||
}
|
||||
|
||||
func TestNewLimiter_DisabledWithNegativeRate(t *testing.T) {
|
||||
l := newLimiter(RateLimitConfig{Rate: -1}, defaultActionRate, defaultActionBurst)
|
||||
assert.Nil(t, l)
|
||||
}
|
||||
|
||||
func TestTokenBucket_AllowsBurstThenRejects(t *testing.T) {
|
||||
l := newLimiter(RateLimitConfig{Rate: 1, Burst: 3}, 1, 3)
|
||||
require.NotNil(t, l)
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
assert.True(t, l.Allow(), "request %d should be allowed within burst", i)
|
||||
}
|
||||
assert.False(t, l.Allow(), "request beyond burst should be rejected")
|
||||
}
|
||||
|
||||
func TestWithRateLimit_CreatesLimiter(t *testing.T) {
|
||||
entry := actionEntry{fn: func() {}}
|
||||
opt := WithRateLimit(2, 4)
|
||||
opt(&entry)
|
||||
|
||||
require.NotNil(t, entry.limiter)
|
||||
assert.InDelta(t, 2.0, float64(entry.limiter.Limit()), 0.001)
|
||||
assert.Equal(t, 4, entry.limiter.Burst())
|
||||
}
|
||||
|
||||
func TestContextAction_WithRateLimit(t *testing.T) {
|
||||
v := New()
|
||||
c := newContext("test-rl", "/", v)
|
||||
|
||||
called := false
|
||||
c.Action(func() { called = true }, WithRateLimit(1, 2))
|
||||
|
||||
// Verify the entry has its own limiter
|
||||
for _, entry := range c.actionRegistry {
|
||||
require.NotNil(t, entry.limiter)
|
||||
assert.InDelta(t, 1.0, float64(entry.limiter.Limit()), 0.001)
|
||||
assert.Equal(t, 2, entry.limiter.Burst())
|
||||
}
|
||||
assert.False(t, called)
|
||||
}
|
||||
|
||||
func TestContextAction_DefaultNoPerActionLimiter(t *testing.T) {
|
||||
v := New()
|
||||
c := newContext("test-no-rl", "/", v)
|
||||
|
||||
c.Action(func() {})
|
||||
|
||||
for _, entry := range c.actionRegistry {
|
||||
assert.Nil(t, entry.limiter, "entry without WithRateLimit should have nil limiter")
|
||||
}
|
||||
}
|
||||
|
||||
func TestContextLimiter_DefaultsApplied(t *testing.T) {
|
||||
v := New()
|
||||
c := newContext("test-ctx-limiter", "/", v)
|
||||
|
||||
require.NotNil(t, c.actionLimiter)
|
||||
assert.InDelta(t, defaultActionRate, float64(c.actionLimiter.Limit()), 0.001)
|
||||
assert.Equal(t, defaultActionBurst, c.actionLimiter.Burst())
|
||||
}
|
||||
|
||||
func TestContextLimiter_DisabledViaConfig(t *testing.T) {
|
||||
v := New()
|
||||
v.actionRateLimit = RateLimitConfig{Rate: -1}
|
||||
c := newContext("test-disabled", "/", v)
|
||||
|
||||
assert.Nil(t, c.actionLimiter)
|
||||
}
|
||||
|
||||
func TestContextLimiter_CustomConfig(t *testing.T) {
|
||||
v := New()
|
||||
v.Config(Options{ActionRateLimit: RateLimitConfig{Rate: 50, Burst: 100}})
|
||||
c := newContext("test-custom", "/", v)
|
||||
|
||||
require.NotNil(t, c.actionLimiter)
|
||||
assert.InDelta(t, 50.0, float64(c.actionLimiter.Limit()), 0.001)
|
||||
assert.Equal(t, 100, c.actionLimiter.Burst())
|
||||
}
|
||||
107
via.go
107
via.go
@@ -10,6 +10,7 @@ import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
_ "embed"
|
||||
"crypto/subtle"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -47,9 +48,11 @@ type V struct {
|
||||
devModePageInitFnMap map[string]func(*Context)
|
||||
sessionManager *scs.SessionManager
|
||||
pubsub PubSub
|
||||
actionRateLimit RateLimitConfig
|
||||
datastarPath string
|
||||
datastarContent []byte
|
||||
datastarOnce sync.Once
|
||||
reaperStop chan struct{}
|
||||
}
|
||||
|
||||
func (v *V) logEvent(evt *zerolog.Event, c *Context) *zerolog.Event {
|
||||
@@ -127,6 +130,12 @@ func (v *V) Config(cfg Options) {
|
||||
if cfg.PubSub != nil {
|
||||
v.pubsub = cfg.PubSub
|
||||
}
|
||||
if cfg.ContextTTL != 0 {
|
||||
v.cfg.ContextTTL = cfg.ContextTTL
|
||||
}
|
||||
if cfg.ActionRateLimit.Rate != 0 || cfg.ActionRateLimit.Burst != 0 {
|
||||
v.actionRateLimit = cfg.ActionRateLimit
|
||||
}
|
||||
}
|
||||
|
||||
// AppendToHead appends the given h.H nodes to the head of the base HTML document.
|
||||
@@ -199,7 +208,7 @@ func (v *V) Page(route string, initContextFn func(c *Context)) {
|
||||
headElements := []h.H{h.Script(h.Type("module"), h.Src(v.datastarPath))}
|
||||
headElements = append(headElements, v.documentHeadIncludes...)
|
||||
headElements = append(headElements,
|
||||
h.Meta(h.Data("signals", fmt.Sprintf("{'via-ctx':'%s'}", id))),
|
||||
h.Meta(h.Data("signals", fmt.Sprintf("{'via-ctx':'%s','via-csrf':'%s'}", id, c.csrfToken))),
|
||||
h.Meta(h.Data("init", "@get('/_sse')")),
|
||||
h.Meta(h.Data("init", fmt.Sprintf(`window.addEventListener('beforeunload', (evt) => {
|
||||
navigator.sendBeacon('/_session/close', '%s');});`, c.id))),
|
||||
@@ -238,6 +247,14 @@ func (v *V) currSessionNum() int {
|
||||
return len(v.contextRegistry)
|
||||
}
|
||||
|
||||
func (v *V) cleanupCtx(c *Context) {
|
||||
c.dispose()
|
||||
if v.cfg.DevMode {
|
||||
v.devModeRemovePersisted(c)
|
||||
}
|
||||
v.unregisterCtx(c)
|
||||
}
|
||||
|
||||
func (v *V) unregisterCtx(c *Context) {
|
||||
if c.id == "" {
|
||||
v.logErr(c, "unregister ctx failed: ctx contains empty id")
|
||||
@@ -259,6 +276,50 @@ func (v *V) getCtx(id string) (*Context, error) {
|
||||
return nil, fmt.Errorf("ctx '%s' not found", id)
|
||||
}
|
||||
|
||||
func (v *V) startReaper() {
|
||||
ttl := v.cfg.ContextTTL
|
||||
if ttl < 0 {
|
||||
return
|
||||
}
|
||||
if ttl == 0 {
|
||||
ttl = 30 * time.Second
|
||||
}
|
||||
interval := ttl / 3
|
||||
if interval < 5*time.Second {
|
||||
interval = 5 * time.Second
|
||||
}
|
||||
v.reaperStop = make(chan struct{})
|
||||
go func() {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-v.reaperStop:
|
||||
return
|
||||
case <-ticker.C:
|
||||
v.reapOrphanedContexts(ttl)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (v *V) reapOrphanedContexts(ttl time.Duration) {
|
||||
now := time.Now()
|
||||
v.contextRegistryMutex.RLock()
|
||||
var orphans []*Context
|
||||
for _, c := range v.contextRegistry {
|
||||
if !c.sseConnected.Load() && now.Sub(c.createdAt) > ttl {
|
||||
orphans = append(orphans, c)
|
||||
}
|
||||
}
|
||||
v.contextRegistryMutex.RUnlock()
|
||||
|
||||
for _, c := range orphans {
|
||||
v.logInfo(c, "reaping orphaned context (no SSE connection after %s)", ttl)
|
||||
v.cleanupCtx(c)
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the Via HTTP server and blocks until a SIGINT or SIGTERM
|
||||
// signal is received, then performs a graceful shutdown.
|
||||
func (v *V) Start() {
|
||||
@@ -271,6 +332,8 @@ func (v *V) Start() {
|
||||
Handler: handler,
|
||||
}
|
||||
|
||||
v.startReaper()
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- v.server.ListenAndServe()
|
||||
@@ -301,6 +364,9 @@ func (v *V) Shutdown() {
|
||||
}
|
||||
|
||||
func (v *V) shutdown() {
|
||||
if v.reaperStop != nil {
|
||||
close(v.reaperStop)
|
||||
}
|
||||
v.logInfo(nil, "draining all contexts")
|
||||
v.drainAllContexts()
|
||||
|
||||
@@ -400,10 +466,7 @@ func (v *V) devModeRemovePersisted(c *Context) {
|
||||
}
|
||||
file.Close()
|
||||
|
||||
// remove ctx to persisted list
|
||||
if _, ok := ctxRegMap[c.id]; !ok {
|
||||
delete(ctxRegMap, c.id)
|
||||
}
|
||||
|
||||
// write persisted list to file
|
||||
file, err = os.Create(p)
|
||||
@@ -507,6 +570,7 @@ func New() *V {
|
||||
// use last-event-id to tell if request is a sse reconnect
|
||||
sse.Send(datastar.EventTypePatchElements, []string{}, datastar.WithSSEEventId("via"))
|
||||
|
||||
c.sseConnected.Store(true)
|
||||
v.logDebug(c, "SSE connection established")
|
||||
|
||||
go func() {
|
||||
@@ -517,6 +581,7 @@ func New() *V {
|
||||
select {
|
||||
case <-sse.Context().Done():
|
||||
v.logDebug(c, "SSE connection ended")
|
||||
v.cleanupCtx(c)
|
||||
return
|
||||
case <-c.ctxDisposedChan:
|
||||
v.logDebug(c, "context disposed, closing SSE")
|
||||
@@ -572,13 +637,29 @@ func New() *V {
|
||||
v.logErr(nil, "action '%s' failed: %v", actionID, err)
|
||||
return
|
||||
}
|
||||
csrfToken, _ := sigs["via-csrf"].(string)
|
||||
if subtle.ConstantTimeCompare([]byte(csrfToken), []byte(c.csrfToken)) != 1 {
|
||||
v.logWarn(c, "action '%s' rejected: invalid CSRF token", actionID)
|
||||
http.Error(w, "invalid CSRF token", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
if c.actionLimiter != nil && !c.actionLimiter.Allow() {
|
||||
v.logWarn(c, "action '%s' rate limited", actionID)
|
||||
http.Error(w, "rate limited", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
c.reqCtx = r.Context()
|
||||
actionFn, err := c.getActionFn(actionID)
|
||||
entry, err := c.getAction(actionID)
|
||||
if err != nil {
|
||||
v.logDebug(c, "action '%s' failed: %v", actionID, err)
|
||||
return
|
||||
}
|
||||
// log err if actionFn panics
|
||||
if entry.limiter != nil && !entry.limiter.Allow() {
|
||||
v.logWarn(c, "action '%s' rate limited (per-action)", actionID)
|
||||
http.Error(w, "rate limited", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
// log err if action panics
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
v.logErr(c, "action '%s' failed: %v", actionID, r)
|
||||
@@ -586,7 +667,7 @@ func New() *V {
|
||||
}()
|
||||
|
||||
c.injectSignals(sigs)
|
||||
actionFn()
|
||||
entry.fn()
|
||||
})
|
||||
|
||||
v.mux.HandleFunc("POST /_session/close", func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -603,12 +684,8 @@ func New() *V {
|
||||
v.logErr(c, "failed to handle session close: %v", err)
|
||||
return
|
||||
}
|
||||
c.dispose()
|
||||
v.logDebug(c, "session close event triggered")
|
||||
if v.cfg.DevMode {
|
||||
v.devModeRemovePersisted(c)
|
||||
}
|
||||
v.unregisterCtx(c)
|
||||
v.cleanupCtx(c)
|
||||
})
|
||||
return v
|
||||
}
|
||||
@@ -619,6 +696,12 @@ func genRandID() string {
|
||||
return hex.EncodeToString(b)[:8]
|
||||
}
|
||||
|
||||
func genCSRFToken() string {
|
||||
b := make([]byte, 16)
|
||||
rand.Read(b)
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
func extractParams(pattern, path string) map[string]string {
|
||||
p := strings.Split(strings.Trim(pattern, "/"), "/")
|
||||
u := strings.Split(strings.Trim(path, "/"), "/")
|
||||
|
||||
148
via_test.go
148
via_test.go
@@ -1,9 +1,13 @@
|
||||
package via
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ryanhamamura/via/h"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -128,6 +132,60 @@ func TestAction(t *testing.T) {
|
||||
assert.Contains(t, body, "/_action/")
|
||||
}
|
||||
|
||||
func TestEventTypes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
attr string
|
||||
buildEl func(trigger *actionTrigger) h.H
|
||||
}{
|
||||
{"OnSubmit", "data-on:submit", func(tr *actionTrigger) h.H { return h.Form(tr.OnSubmit()) }},
|
||||
{"OnInput", "data-on:input", func(tr *actionTrigger) h.H { return h.Input(tr.OnInput()) }},
|
||||
{"OnFocus", "data-on:focus", func(tr *actionTrigger) h.H { return h.Input(tr.OnFocus()) }},
|
||||
{"OnBlur", "data-on:blur", func(tr *actionTrigger) h.H { return h.Input(tr.OnBlur()) }},
|
||||
{"OnMouseEnter", "data-on:mouseenter", func(tr *actionTrigger) h.H { return h.Div(tr.OnMouseEnter()) }},
|
||||
{"OnMouseLeave", "data-on:mouseleave", func(tr *actionTrigger) h.H { return h.Div(tr.OnMouseLeave()) }},
|
||||
{"OnScroll", "data-on:scroll", func(tr *actionTrigger) h.H { return h.Div(tr.OnScroll()) }},
|
||||
{"OnDblClick", "data-on:dblclick", func(tr *actionTrigger) h.H { return h.Div(tr.OnDblClick()) }},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var trigger *actionTrigger
|
||||
v := New()
|
||||
v.Page("/", func(c *Context) {
|
||||
trigger = c.Action(func() {})
|
||||
c.View(func() h.H { return tt.buildEl(trigger) })
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
v.mux.ServeHTTP(w, req)
|
||||
body := w.Body.String()
|
||||
assert.Contains(t, body, tt.attr)
|
||||
assert.Contains(t, body, "/_action/"+trigger.id)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("WithSignal", func(t *testing.T) {
|
||||
var trigger *actionTrigger
|
||||
var sig *signal
|
||||
v := New()
|
||||
v.Page("/", func(c *Context) {
|
||||
trigger = c.Action(func() {})
|
||||
sig = c.Signal("val")
|
||||
c.View(func() h.H {
|
||||
return h.Div(trigger.OnDblClick(WithSignal(sig, "x")))
|
||||
})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
v.mux.ServeHTTP(w, req)
|
||||
body := w.Body.String()
|
||||
assert.Contains(t, body, "data-on:dblclick")
|
||||
assert.Contains(t, body, "$"+sig.ID()+"='x'")
|
||||
})
|
||||
}
|
||||
|
||||
func TestOnKeyDownWithWindow(t *testing.T) {
|
||||
var trigger *actionTrigger
|
||||
v := New()
|
||||
@@ -235,3 +293,93 @@ func TestPage_PanicsOnNoView(t *testing.T) {
|
||||
v.Page("/", func(c *Context) {})
|
||||
})
|
||||
}
|
||||
|
||||
func TestReaperCleansOrphanedContexts(t *testing.T) {
|
||||
v := New()
|
||||
c := newContext("orphan-1", "/", v)
|
||||
c.createdAt = time.Now().Add(-time.Minute) // created 1 min ago
|
||||
v.registerCtx(c)
|
||||
|
||||
_, err := v.getCtx("orphan-1")
|
||||
assert.NoError(t, err)
|
||||
|
||||
v.reapOrphanedContexts(10 * time.Second)
|
||||
|
||||
_, err = v.getCtx("orphan-1")
|
||||
assert.Error(t, err, "orphaned context should have been reaped")
|
||||
}
|
||||
|
||||
func TestReaperIgnoresConnectedContexts(t *testing.T) {
|
||||
v := New()
|
||||
c := newContext("connected-1", "/", v)
|
||||
c.createdAt = time.Now().Add(-time.Minute)
|
||||
c.sseConnected.Store(true)
|
||||
v.registerCtx(c)
|
||||
|
||||
v.reapOrphanedContexts(10 * time.Second)
|
||||
|
||||
_, err := v.getCtx("connected-1")
|
||||
assert.NoError(t, err, "connected context should survive reaping")
|
||||
}
|
||||
|
||||
func TestReaperDisabledWithNegativeTTL(t *testing.T) {
|
||||
v := New()
|
||||
v.cfg.ContextTTL = -1
|
||||
v.startReaper()
|
||||
assert.Nil(t, v.reaperStop, "reaper should not start with negative TTL")
|
||||
}
|
||||
|
||||
func TestCleanupCtxIdempotent(t *testing.T) {
|
||||
v := New()
|
||||
c := newContext("idempotent-1", "/", v)
|
||||
v.registerCtx(c)
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
v.cleanupCtx(c)
|
||||
v.cleanupCtx(c)
|
||||
})
|
||||
|
||||
_, err := v.getCtx("idempotent-1")
|
||||
assert.Error(t, err, "context should be removed after cleanup")
|
||||
}
|
||||
|
||||
func TestDevModeRemovePersistedFix(t *testing.T) {
|
||||
v := New()
|
||||
v.cfg.DevMode = true
|
||||
|
||||
dir := filepath.Join(t.TempDir(), ".via", "devmode")
|
||||
p := filepath.Join(dir, "ctx.json")
|
||||
assert.NoError(t, os.MkdirAll(dir, 0755))
|
||||
|
||||
// Write a persisted context
|
||||
ctxRegMap := map[string]string{"test-ctx-1": "/"}
|
||||
f, err := os.Create(p)
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, json.NewEncoder(f).Encode(ctxRegMap))
|
||||
f.Close()
|
||||
|
||||
// Patch devModeRemovePersisted to use our temp path by calling it
|
||||
// directly — we need to override the path. Instead, test via the
|
||||
// actual function by temporarily changing the working dir.
|
||||
origDir, _ := os.Getwd()
|
||||
assert.NoError(t, os.Chdir(t.TempDir()))
|
||||
defer os.Chdir(origDir)
|
||||
|
||||
// Re-create the structure in the temp dir
|
||||
assert.NoError(t, os.MkdirAll(filepath.Join(".via", "devmode"), 0755))
|
||||
p2 := filepath.Join(".via", "devmode", "ctx.json")
|
||||
f2, _ := os.Create(p2)
|
||||
json.NewEncoder(f2).Encode(map[string]string{"test-ctx-1": "/"})
|
||||
f2.Close()
|
||||
|
||||
c := newContext("test-ctx-1", "/", v)
|
||||
v.devModeRemovePersisted(c)
|
||||
|
||||
// Read back and verify
|
||||
f3, err := os.Open(p2)
|
||||
assert.NoError(t, err)
|
||||
defer f3.Close()
|
||||
var result map[string]string
|
||||
assert.NoError(t, json.NewDecoder(f3).Decode(&result))
|
||||
assert.Empty(t, result, "persisted context should be removed")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user