10 Commits

Author SHA1 Message Date
Ryan Hamamura
10b4838f8d feat: auto-track fields on context for zero-arg ValidateAll/ResetFields
Fields created via Context.Field are now tracked on the page context,
so ValidateAll() and ResetFields() with no arguments operate on all
fields by default. Explicit field args still work for selective use.

Also switches MinLen/MaxLen to utf8.RuneCountInString for correct
unicode handling and replaces fmt.Errorf with errors.New where
format strings are unnecessary.
2026-02-11 19:57:13 -10:00
Ryan Hamamura
5362614c3e feat: add field validation API with signup form example
Introduces Field, Rule, ValidateAll, ResetFields, and AddError for
declarative input validation. Includes built-in rules (Required,
MinLen, MaxLen, Min, Max, Email, Pattern, Custom) and a signup
example exercising the full API surface.
2026-02-11 14:42:44 -10:00
ryanhamamura
e636970f7b feat: add middleware, route groups, and codebase cleanup
* feat: add middleware example demonstrating route groups

Self-contained example covering v.Use(), v.Group(), nested groups,
Group.Use(), and middleware chaining with role-based access control.

* feat: add per-action middleware via WithMiddleware ActionOption

Reuses the existing Middleware type so the same auth/logging functions
work at both page and action level. Middleware runs after CSRF and
rate-limit checks, with full access to session and signals.

* feat: add RedirectView helper and refactor session example to use middleware

RedirectView lets middleware abort and redirect in one step. The session
example now uses an authRequired middleware on a route group instead of
an inline check inside the view.

* fix: remove dead code, fix double Load and extractParams mismatch

- Remove componentRegistry (written, never read)
- Remove unused signal methods: Bytes, Int64, Float
- Remove unreachable nil check in registerCtx
- Simplify injectRouteParams (extractParams already returns fresh map)
- Fix double sync.Map.Load in injectSignals
- Merge Shutdown/shutdown into single method
- Inline currSessionNum
- Fix extractParams: mismatched literal segment now returns nil
- Minor: new(bytes.Buffer), go c.Sync(), genRandID reads 4 bytes
2026-02-11 13:50:02 -10:00
Ryan Hamamura
f5158b866c feat: add Static and StaticFS helpers for serving static files
One-liner static file serving: v.Static("/assets/", "./public") for
filesystem directories and v.StaticFS("/assets/", fsys) for embed.FS.
Both auto-normalize the URL prefix and disable directory listings.
2026-02-06 13:22:00 -10:00
Ryan Hamamura
2f6c5916ce docs: rewrite README with correct import paths and current feature set 2026-02-06 12:56:31 -10:00
Ryan Hamamura
0762ddbbc2 feat: add token-bucket rate limiting for action endpoints
Add per-context and per-action rate limiting using golang.org/x/time/rate.
Configure globally via Options.ActionRateLimit or per-action with
WithRateLimit(). Defaults to 10 req/s with burst of 20.
2026-02-06 11:52:07 -10:00
Ryan Hamamura
b7acfa6302 feat: add automatic CSRF protection for action calls
Generate a per-context CSRF token (128-bit, crypto/rand) and inject it
as a Datastar signal (via-csrf) alongside via-ctx. Validate with
constant-time comparison on /_action/{id} before executing, returning
403 on mismatch. Transparent to users since Datastar sends all signals
automatically.

Closes #9
2026-02-06 11:17:41 -10:00
Ryan Hamamura
8aa91c577c feat: add event types OnSubmit, OnInput, OnFocus, OnBlur, OnMouseEnter, OnMouseLeave, OnScroll, OnDblClick 2026-02-06 10:54:27 -10:00
Ryan Hamamura
6dcd54c88b fix: clean up leaked contexts on SSE disconnect and add orphan reaper
When clients disconnect without beforeunload firing (network drops,
mobile kills, crashes), contexts leaked in the registry permanently.

- Extract cleanupCtx helper for dispose/unregister sequence
- Call cleanupCtx on SSE disconnect (sse.Context().Done())
- Add background reaper for contexts where SSE never connected
- Add ContextTTL config option (default 30s, negative disables)
- Fix inverted condition in devModeRemovePersisted
2026-02-06 10:34:28 -10:00
Ryan Hamamura
2c44671d0e feat: add generic pub/sub helpers and pubsub-crud example
Add typed Publish[T] and Subscribe[T] generic helpers that handle
JSON marshaling, along with vianats.EnsureStream and ReplayHistory
helpers. Refactor nats-chatroom to use the new APIs.

Add pubsub-crud example demonstrating CRUD operations with DaisyUI
toast notifications broadcast to all connected clients via NATS.
2026-02-06 09:47:39 -10:00
25 changed files with 2427 additions and 168 deletions

View File

@@ -1,30 +1,33 @@
# Via
# Via
Real-time engine for building reactive web applications in pure Go.
## Why Via?
Somewhere along the way, the web became tangled in layers of JavaScript, build chains, and frameworks stacked on frameworks.
Via takes a radical stance:
The web became tangled in layers of JavaScript, build chains, and frameworks stacked on frameworks. Via takes a different path.
- No templates.
- No JavaScript.
- No transpilation.
- No hydration.
- No front-end fatigue.
- Single SSE stream.
- Full reactivity.
- Built-in Brotli compression.
- Pure Go.
**Philosophy**
- No templates. No JavaScript. No transpilation. No hydration.
- Views are pure Go functions. HTML is composed with a type-safe DSL.
- A single SSE stream carries all reactivity — no WebSocket juggling, no polling.
**Batteries included**
- Automatic CSRF protection on every action call
- Token-bucket rate limiting (global defaults + per-action overrides)
- Cookie-based sessions backed by SQLite
- Pub/sub messaging with an embedded NATS backend
- Structured logging via zerolog
- Graceful shutdown with context draining
- Brotli compression out of the box
## Example
```go
package main
import (
"github.com/go-via/via"
"github.com/go-via/via/h"
"github.com/ryanhamamura/via"
"github.com/ryanhamamura/via/h"
)
type Counter struct{ Count int }
@@ -57,25 +60,43 @@ func main() {
}
```
## What's built in
## 🚧 Experimental
<s>Via is still a newborn.</s> Via is taking its first steps!
- Version `0.1.0` released.
- Expect a little less chaos.
- **Reactive views + signals** — bind state to the DOM; changes push over SSE automatically
- **Components** — self-contained subcontexts with their own data, actions, and signals
- **Sessions** — cookie-based, backed by SQLite via `scs`
- **Pub/sub** — embedded NATS server with JetStream; generic `Publish[T]` / `Subscribe[T]` helpers
- **CSRF protection** — automatic token generation and validation on every action
- **Rate limiting** — token-bucket algorithm, configurable globally and per-action
- **Event handling** — `OnClick`, `OnChange`, `OnSubmit`, `OnInput`, `OnFocus`, `OnBlur`, `OnMouseEnter`, `OnMouseLeave`, `OnScroll`, `OnDblClick`, `OnKeyDown`, and `OnKeyDownMap` for multi-key bindings
- **Timed routines** — `OnInterval` with start/stop/update controls, tied to context lifecycle
- **Redirects** — `Redirect`, `ReplaceURL`, and format-string variants
- **Plugin system** — `func(v *V)` hooks for integrating CSS/JS libraries
- **Structured logging** — zerolog with configurable levels; console output in dev, JSON in production
- **Graceful shutdown** — listens for SIGINT/SIGTERM, drains contexts, closes pub/sub
- **Context lifecycle** — background reaper cleans up disconnected contexts; configurable TTL
- **HTML DSL** — the `h` package provides type-safe Go-native HTML composition
## Examples
The `internal/examples/` directory contains 14 runnable examples:
`chatroom` · `counter` · `countercomp` · `greeter` · `keyboard` · `livereload` · `nats-chatroom` · `pathparams` · `picocss` · `plugins` · `pubsub-crud` · `realtimechart` · `session` · `shakespeare`
## Experimental
Via is maturing — sessions, CSRF, rate limiting, pub/sub, and graceful shutdown are in place — but the API is still evolving. Expect breaking changes before `v1`.
## Contributing
- Via is intentionally minimal and opinionated — and so is contributing.
- If you love Go, simplicity, and meaningful abstractions — Come along for the ride!
- Fork, branch, build, tinker with things, submit a pull request.
- Fork, branch, build, tinker, submit a pull request.
- Keep every line purposeful.
- Share feedback: open an issue or start a discussion.
## Credits
Via builds upon the work of these amazing projects:
Via builds upon the work of these projects:
- 🚀 [Datastar](https://data-star.dev) - The hypermedia powerhouse at the core of Via. It powers browser reactivity through Signals and enables real-time HTML/Signal patches over an always-on SSE event stream.
- 🧩 [Gomponents](https://maragu.dev/gomponents) - The awesome project that gifts Via with Go-native HTML composition superpowers through the `via/h` package.
> Thank you for building something that doesnt just function — it inspires. 🫶
- [Datastar](https://data-star.dev) — the hypermedia framework powering browser reactivity through signals and real-time HTML patches over SSE.
- [Gomponents](https://maragu.dev/gomponents) Go-native HTML composition that powers the `via/h` package.

View File

@@ -107,6 +107,54 @@ func (a *actionTrigger) OnChange(options ...ActionTriggerOption) h.H {
return h.Data("on:change__debounce.200ms", buildOnExpr(actionURL(a.id), &opts))
}
// OnSubmit returns a via.h DOM attribute that triggers on form submit.
func (a *actionTrigger) OnSubmit(options ...ActionTriggerOption) h.H {
opts := applyOptions(options...)
return h.Data("on:submit", buildOnExpr(actionURL(a.id), &opts))
}
// OnInput returns a via.h DOM attribute that triggers on input (without debounce).
func (a *actionTrigger) OnInput(options ...ActionTriggerOption) h.H {
opts := applyOptions(options...)
return h.Data("on:input", buildOnExpr(actionURL(a.id), &opts))
}
// OnFocus returns a via.h DOM attribute that triggers when the element gains focus.
func (a *actionTrigger) OnFocus(options ...ActionTriggerOption) h.H {
opts := applyOptions(options...)
return h.Data("on:focus", buildOnExpr(actionURL(a.id), &opts))
}
// OnBlur returns a via.h DOM attribute that triggers when the element loses focus.
func (a *actionTrigger) OnBlur(options ...ActionTriggerOption) h.H {
opts := applyOptions(options...)
return h.Data("on:blur", buildOnExpr(actionURL(a.id), &opts))
}
// OnMouseEnter returns a via.h DOM attribute that triggers when the mouse enters the element.
func (a *actionTrigger) OnMouseEnter(options ...ActionTriggerOption) h.H {
opts := applyOptions(options...)
return h.Data("on:mouseenter", buildOnExpr(actionURL(a.id), &opts))
}
// OnMouseLeave returns a via.h DOM attribute that triggers when the mouse leaves the element.
func (a *actionTrigger) OnMouseLeave(options ...ActionTriggerOption) h.H {
opts := applyOptions(options...)
return h.Data("on:mouseleave", buildOnExpr(actionURL(a.id), &opts))
}
// OnScroll returns a via.h DOM attribute that triggers on scroll.
func (a *actionTrigger) OnScroll(options ...ActionTriggerOption) h.H {
opts := applyOptions(options...)
return h.Data("on:scroll", buildOnExpr(actionURL(a.id), &opts))
}
// OnDblClick returns a via.h DOM attribute that triggers on double click.
func (a *actionTrigger) OnDblClick(options ...ActionTriggerOption) h.H {
opts := applyOptions(options...)
return h.Data("on:dblclick", buildOnExpr(actionURL(a.id), &opts))
}
// OnKeyDown returns a via.h DOM attribute that triggers when a key is pressed.
// key: optional, see https://developer.mozilla.org/en-US/docs/Web/API/KeyboardEvent/key
// Example: OnKeyDown("Enter")

View File

@@ -1,6 +1,8 @@
package via
import (
"time"
"github.com/alexedwards/scs/v2"
"github.com/rs/zerolog"
)
@@ -54,4 +56,14 @@ type Options struct {
// PubSub enables publish/subscribe messaging. Use vianats.New() for an
// embedded NATS backend, or supply any PubSub implementation.
PubSub PubSub
// ContextTTL is the maximum time a context may exist without an SSE
// connection before the background reaper disposes it.
// Default: 30s. Negative value disables the reaper.
ContextTTL time.Duration
// ActionRateLimit configures the default token-bucket rate limiter for
// action endpoints. Zero values use built-in defaults (10 req/s, burst 20).
// Set Rate to -1 to disable rate limiting entirely.
ActionRateLimit RateLimitConfig
}

View File

@@ -5,12 +5,13 @@ import (
"context"
"encoding/json"
"fmt"
"maps"
"reflect"
"sync"
"sync/atomic"
"time"
"github.com/ryanhamamura/via/h"
"golang.org/x/time/rate"
)
// Context is the living bridge between Go and the browser.
@@ -19,20 +20,24 @@ import (
type Context struct {
id string
route string
csrfToken string
app *V
view func() h.H
routeParams map[string]string
componentRegistry map[string]*Context
parentPageCtx *Context
patchChan chan patch
actionRegistry map[string]func()
actionLimiter *rate.Limiter
actionRegistry map[string]actionEntry
signals *sync.Map
mu sync.RWMutex
ctxDisposedChan chan struct{}
reqCtx context.Context
fields []*Field
subscriptions []Subscription
subsMu sync.Mutex
disposeOnce sync.Once
createdAt time.Time
sseConnected atomic.Bool
}
// View defines the UI rendered by this context.
@@ -75,7 +80,6 @@ func (c *Context) Component(initCtx func(c *Context)) func() h.H {
compCtx.parentPageCtx = c
}
initCtx(compCtx)
c.componentRegistry[id] = compCtx
return compCtx.view
}
@@ -100,26 +104,31 @@ func (c *Context) isComponent() bool {
// h.Button(h.Text("Increment n"), increment.OnClick()),
// )
// })
func (c *Context) Action(f func()) *actionTrigger {
func (c *Context) Action(f func(), opts ...ActionOption) *actionTrigger {
id := genRandID()
if f == nil {
c.app.logErr(c, "failed to bind action '%s' to context: nil func", id)
return nil
}
entry := actionEntry{fn: f}
for _, opt := range opts {
opt(&entry)
}
if c.isComponent() {
c.parentPageCtx.actionRegistry[id] = f
c.parentPageCtx.actionRegistry[id] = entry
} else {
c.actionRegistry[id] = f
c.actionRegistry[id] = entry
}
return &actionTrigger{id}
}
func (c *Context) getActionFn(id string) (func(), error) {
if f, ok := c.actionRegistry[id]; ok {
return f, nil
func (c *Context) getAction(id string) (actionEntry, error) {
if e, ok := c.actionRegistry[id]; ok {
return e, nil
}
return nil, fmt.Errorf("action '%s' not found", id)
return actionEntry{}, fmt.Errorf("action '%s' not found", id)
}
// OnInterval starts a go routine that sets a time.Ticker with the given duration and executes
@@ -197,14 +206,14 @@ func (c *Context) injectSignals(sigs map[string]any) {
defer c.mu.Unlock()
for sigID, val := range sigs {
if _, ok := c.signals.Load(sigID); !ok {
item, ok := c.signals.Load(sigID)
if !ok {
c.signals.Store(sigID, &signal{
id: sigID,
val: val,
})
continue
}
item, _ := c.signals.Load(sigID)
if sig, ok := item.(*signal); ok {
sig.val = val
sig.changed = false
@@ -255,7 +264,7 @@ func (c *Context) sendPatch(p patch) {
// Sync pushes the current view state and signal changes to the browser immediately
// over the live SSE event stream.
func (c *Context) Sync() {
elemsPatch := bytes.NewBuffer(make([]byte, 0))
elemsPatch := new(bytes.Buffer)
if err := c.view().Render(elemsPatch); err != nil {
c.app.logErr(c, "sync view failed: %v", err)
return
@@ -320,6 +329,15 @@ func (c *Context) ExecScript(s string) {
c.sendPatch(patch{patchTypeScript, s})
}
// RedirectView sets a view that redirects the browser to the given URL.
// Use this in middleware to abort the chain and redirect in one step.
func (c *Context) RedirectView(url string) {
c.View(func() h.H {
c.Redirect(url)
return h.Div()
})
}
// Redirect navigates the browser to the given URL.
// This triggers a full page navigation - the current context will be disposed
// and a new context created at the destination URL.
@@ -375,12 +393,9 @@ func (c *Context) injectRouteParams(params map[string]string) {
if params == nil {
return
}
m := make(map[string]string)
c.mu.Lock()
defer c.mu.Unlock()
maps.Copy(m, params)
c.routeParams = m
c.routeParams = params
}
// GetPathParam retrieves the value from the page request URL for the given parameter name
@@ -466,6 +481,50 @@ func (c *Context) unsubscribeAll() {
}
}
// Field creates a signal with validation rules attached.
// The initial value seeds both the signal and the reset target.
// The field is tracked on the context so ValidateAll/ResetFields
// can operate on all fields by default.
func (c *Context) Field(initial any, rules ...Rule) *Field {
f := &Field{
signal: c.Signal(initial),
rules: rules,
initialVal: initial,
}
target := c
if c.isComponent() {
target = c.parentPageCtx
}
target.fields = append(target.fields, f)
return f
}
// ValidateAll runs Validate on each field, returning true only if all pass.
// With no arguments it validates every field tracked on this context.
func (c *Context) ValidateAll(fields ...*Field) bool {
if len(fields) == 0 {
fields = c.fields
}
ok := true
for _, f := range fields {
if !f.Validate() {
ok = false
}
}
return ok
}
// ResetFields resets each field to its initial value and clears errors.
// With no arguments it resets every field tracked on this context.
func (c *Context) ResetFields(fields ...*Field) {
if len(fields) == 0 {
fields = c.fields
}
for _, f := range fields {
f.Reset()
}
}
func newContext(id string, route string, v *V) *Context {
if v == nil {
panic("create context failed: app pointer is nil")
@@ -474,12 +533,14 @@ func newContext(id string, route string, v *V) *Context {
return &Context{
id: id,
route: route,
csrfToken: genCSRFToken(),
routeParams: make(map[string]string),
app: v,
componentRegistry: make(map[string]*Context),
actionRegistry: make(map[string]func()),
actionLimiter: newLimiter(v.actionRateLimit, defaultActionRate, defaultActionBurst),
actionRegistry: make(map[string]actionEntry),
signals: new(sync.Map),
patchChan: make(chan patch, 1),
ctxDisposedChan: make(chan struct{}, 1),
createdAt: time.Now(),
}
}

58
field.go Normal file
View File

@@ -0,0 +1,58 @@
package via
// Field is a signal with built-in validation rules and error state.
// It embeds *signal, so all signal methods (Bind, String, Int, Bool, SetValue, Text, ID)
// work transparently.
type Field struct {
*signal
rules []Rule
errors []string
initialVal any
}
// Validate runs all rules against the current value.
// Clears previous errors, populates new ones, returns true if all rules pass.
func (f *Field) Validate() bool {
f.errors = nil
val := f.String()
for _, r := range f.rules {
if err := r.validate(val); err != nil {
f.errors = append(f.errors, err.Error())
}
}
return len(f.errors) == 0
}
// HasError returns true if this field has any validation errors.
func (f *Field) HasError() bool {
return len(f.errors) > 0
}
// FirstError returns the first validation error message, or "" if valid.
func (f *Field) FirstError() string {
if len(f.errors) > 0 {
return f.errors[0]
}
return ""
}
// Errors returns all current validation error messages.
func (f *Field) Errors() []string {
return f.errors
}
// AddError manually adds an error message (useful for server-side or cross-field validation).
func (f *Field) AddError(msg string) {
f.errors = append(f.errors, msg)
}
// ClearErrors removes all validation errors from this field.
func (f *Field) ClearErrors() {
f.errors = nil
}
// Reset restores the field value to its initial value and clears all errors.
func (f *Field) Reset() {
f.SetValue(f.initialVal)
f.errors = nil
}

206
field_test.go Normal file
View File

@@ -0,0 +1,206 @@
package via
import (
"fmt"
"testing"
"github.com/ryanhamamura/via/h"
"github.com/stretchr/testify/assert"
)
func newTestField(initial any, rules ...Rule) *Field {
v := New()
var f *Field
v.Page("/", func(c *Context) {
f = c.Field(initial, rules...)
c.View(func() h.H { return h.Div() })
})
return f
}
func TestFieldCreation(t *testing.T) {
f := newTestField("hello", Required())
assert.Equal(t, "hello", f.String())
assert.NotEmpty(t, f.ID())
}
func TestFieldSignalDelegation(t *testing.T) {
f := newTestField(42)
assert.Equal(t, "42", f.String())
assert.Equal(t, 42, f.Int())
f.SetValue("new")
assert.Equal(t, "new", f.String())
// Bind returns an h.H element
assert.NotNil(t, f.Bind())
}
func TestFieldValidateSingleRule(t *testing.T) {
f := newTestField("", Required())
assert.False(t, f.Validate())
assert.True(t, f.HasError())
assert.Equal(t, "This field is required", f.FirstError())
f.SetValue("ok")
assert.True(t, f.Validate())
assert.False(t, f.HasError())
assert.Equal(t, "", f.FirstError())
}
func TestFieldValidateMultipleRules(t *testing.T) {
f := newTestField("ab", Required(), MinLen(3))
assert.False(t, f.Validate())
errs := f.Errors()
assert.Len(t, errs, 1)
assert.Equal(t, "Must be at least 3 characters", errs[0])
f.SetValue("")
assert.False(t, f.Validate())
errs = f.Errors()
assert.Len(t, errs, 2)
}
func TestFieldErrors(t *testing.T) {
f := newTestField("")
assert.Nil(t, f.Errors())
assert.False(t, f.HasError())
assert.Equal(t, "", f.FirstError())
}
func TestFieldAddError(t *testing.T) {
f := newTestField("ok")
f.AddError("username taken")
assert.True(t, f.HasError())
assert.Equal(t, "username taken", f.FirstError())
assert.Len(t, f.Errors(), 1)
}
func TestFieldClearErrors(t *testing.T) {
f := newTestField("", Required())
f.Validate()
assert.True(t, f.HasError())
f.ClearErrors()
assert.False(t, f.HasError())
}
func TestFieldReset(t *testing.T) {
f := newTestField("initial", Required(), MinLen(3))
f.SetValue("changed")
f.AddError("some error")
f.Reset()
assert.Equal(t, "initial", f.String())
assert.False(t, f.HasError())
}
func TestValidateAll(t *testing.T) {
v := New()
v.Page("/", func(c *Context) {
c.Field("", Required(), MinLen(3))
c.Field("", Required(), Email())
c.View(func() h.H { return h.Div() })
// both empty → both fail
assert.False(t, c.ValidateAll())
})
v2 := New()
v2.Page("/", func(c *Context) {
c.Field("joe", Required(), MinLen(3))
c.Field("joe@x.com", Required(), Email())
c.View(func() h.H { return h.Div() })
assert.True(t, c.ValidateAll())
})
}
func TestValidateAllPartialFailure(t *testing.T) {
v := New()
v.Page("/", func(c *Context) {
good := c.Field("valid", Required())
bad := c.Field("", Required())
c.View(func() h.H { return h.Div() })
ok := c.ValidateAll()
assert.False(t, ok)
assert.False(t, good.HasError())
assert.True(t, bad.HasError())
})
}
func TestValidateAllSelectiveArgs(t *testing.T) {
v := New()
v.Page("/", func(c *Context) {
a := c.Field("", Required())
b := c.Field("ok", Required())
cField := c.Field("", Required())
c.View(func() h.H { return h.Div() })
// only validate a and b — cField should be untouched
ok := c.ValidateAll(a, b)
assert.False(t, ok)
assert.True(t, a.HasError())
assert.False(t, b.HasError())
assert.False(t, cField.HasError(), "unselected field should not be validated")
})
}
func TestResetFields(t *testing.T) {
v := New()
v.Page("/", func(c *Context) {
a := c.Field("a", Required())
b := c.Field("b", Required())
c.View(func() h.H { return h.Div() })
a.SetValue("changed-a")
b.SetValue("changed-b")
a.AddError("err")
c.ResetFields()
assert.Equal(t, "a", a.String())
assert.Equal(t, "b", b.String())
assert.False(t, a.HasError())
})
}
func TestResetFieldsSelectiveArgs(t *testing.T) {
v := New()
v.Page("/", func(c *Context) {
a := c.Field("a")
b := c.Field("b")
c.View(func() h.H { return h.Div() })
a.SetValue("changed-a")
b.SetValue("changed-b")
// only reset a
c.ResetFields(a)
assert.Equal(t, "a", a.String())
assert.Equal(t, "changed-b", b.String(), "unselected field should not be reset")
})
}
func TestFieldValidateClearsPreviousErrors(t *testing.T) {
f := newTestField("", Required())
f.Validate()
assert.True(t, f.HasError())
f.SetValue("ok")
f.Validate()
assert.False(t, f.HasError())
}
func TestFieldCustomValidator(t *testing.T) {
f := newTestField("bad", Custom(func(val string) error {
if val == "bad" {
return fmt.Errorf("no bad words")
}
return nil
}))
assert.False(t, f.Validate())
assert.Equal(t, "no bad words", f.FirstError())
f.SetValue("good")
assert.True(t, f.Validate())
}

3
go.mod
View File

@@ -14,6 +14,7 @@ require (
github.com/rs/zerolog v1.34.0
github.com/starfederation/datastar-go v1.0.3
github.com/stretchr/testify v1.11.1
golang.org/x/time v0.14.0
)
require (
@@ -37,6 +38,6 @@ require (
github.com/valyala/bytebufferpool v1.0.0 // indirect
golang.org/x/crypto v0.45.0 // indirect
golang.org/x/sys v0.38.0 // indirect
golang.org/x/time v0.14.0 // indirect
golang.org/x/time v0.14.0
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@@ -0,0 +1,151 @@
package main
import (
"fmt"
"time"
"github.com/ryanhamamura/via"
"github.com/ryanhamamura/via/h"
)
func main() {
v := via.New()
v.Config(via.Options{
ServerAddress: ":8080",
DocumentTitle: "Middleware Example",
})
// --- Middleware definitions ---
// requestLogger logs every page request to stdout.
requestLogger := func(c *via.Context, next func()) {
fmt.Printf("[%s] request\n", time.Now().Format("15:04:05"))
next()
}
// authRequired redirects unauthenticated users to /login.
authRequired := func(c *via.Context, next func()) {
if c.Session().GetString("role") == "" {
c.RedirectView("/login")
return
}
next()
}
// auditLog prints the authenticated username to stdout.
auditLog := func(c *via.Context, next func()) {
fmt.Printf("[audit] user=%s\n", c.Session().GetString("username"))
next()
}
// superAdminOnly rejects non-superadmin users with a forbidden view.
superAdminOnly := func(c *via.Context, next func()) {
if c.Session().GetString("role") != "superadmin" {
c.View(func() h.H {
return h.Div(
h.H1(h.Text("Forbidden")),
h.P(h.Text("Super-admin access required.")),
h.A(h.Href("/admin/dashboard"), h.Text("Back to dashboard")),
)
})
return
}
next()
}
// --- Route registration ---
v.Use(requestLogger) // global middleware
admin := v.Group("/admin", authRequired) // prefixed group
admin.Use(auditLog) // Group.Use()
superAdmin := admin.Group("/super", superAdminOnly) // nested group
// Public: redirect root to login
v.Page("/", func(c *via.Context) {
c.View(func() h.H {
c.Redirect("/login")
return h.Div()
})
})
// Public: login page with role-selection buttons
v.Page("/login", func(c *via.Context) {
loginAdmin := c.Action(func() {
c.Session().Set("role", "admin")
c.Session().Set("username", "alice")
c.Session().RenewToken()
c.Redirect("/admin/dashboard")
})
loginSuper := c.Action(func() {
c.Session().Set("role", "superadmin")
c.Session().Set("username", "bob")
c.Session().RenewToken()
c.Redirect("/admin/dashboard")
})
c.View(func() h.H {
return h.Div(
h.H1(h.Text("Login")),
h.P(h.Text("Choose a role:")),
h.Button(h.Text("Login as Admin"), loginAdmin.OnClick()),
h.Raw(" "),
h.Button(h.Text("Login as Super Admin"), loginSuper.OnClick()),
)
})
})
// Per-action middleware: only superadmins can invoke this action.
requireSuperAdmin := func(c *via.Context, next func()) {
if c.Session().GetString("role") != "superadmin" {
return
}
next()
}
// Admin: dashboard (requires authRequired + auditLog)
admin.Page("/dashboard", func(c *via.Context) {
logout := c.Action(func() {
c.Session().Delete("role")
c.Session().Delete("username")
c.Redirect("/login")
})
dangerAction := c.Action(func() {
fmt.Printf("[danger] executed by %s\n", c.Session().GetString("username"))
c.Sync()
}, via.WithMiddleware(requireSuperAdmin))
c.View(func() h.H {
username := c.Session().GetString("username")
role := c.Session().GetString("role")
return h.Div(
h.H1(h.Textf("Dashboard — %s (%s)", username, role)),
h.Ul(
h.Li(h.A(h.Href("/admin/super/settings"), h.Text("Super Admin Settings"))),
),
h.H2(h.Text("Danger Zone")),
h.P(h.Text("This action is protected by per-action middleware (superadmin only):")),
h.Button(h.Text("Delete Everything"), dangerAction.OnClick()),
h.Br(),
h.Br(),
h.Button(h.Text("Logout"), logout.OnClick()),
)
})
})
// Super-admin: settings (requires authRequired + auditLog + superAdminOnly)
superAdmin.Page("/settings", func(c *via.Context) {
c.View(func() h.H {
username := c.Session().GetString("username")
return h.Div(
h.H1(h.Textf("Super Admin Settings — %s", username)),
h.P(h.Text("Only super-admins can see this page.")),
h.A(h.Href("/admin/dashboard"), h.Text("Back to dashboard")),
)
})
})
v.Start()
}

View File

@@ -2,13 +2,11 @@ package main
import (
"context"
"encoding/json"
"log"
"math/rand"
"sync"
"time"
"github.com/nats-io/nats.go"
"github.com/ryanhamamura/via"
"github.com/ryanhamamura/via/h"
"github.com/ryanhamamura/via/vianats"
@@ -46,15 +44,15 @@ func main() {
}
defer ps.Close()
// Create JetStream stream for message durability
js := ps.JetStream()
js.AddStream(&nats.StreamConfig{
err = vianats.EnsureStream(ps, vianats.StreamConfig{
Name: "CHAT",
Subjects: []string{"chat.>"},
Retention: nats.LimitsPolicy,
MaxMsgs: 1000,
MaxAge: 24 * time.Hour,
})
if err != nil {
log.Fatalf("Failed to ensure stream: %v", err)
}
v := via.New()
v.Config(via.Options{
@@ -147,30 +145,14 @@ func main() {
currentSub.Unsubscribe()
}
// Replay history from JetStream before subscribing for real-time
subject := "chat.room." + room
if hist, err := js.SubscribeSync(subject, nats.DeliverAll(), nats.OrderedConsumer()); err == nil {
for {
msg, err := hist.NextMsg(200 * time.Millisecond)
if err != nil {
break
}
var chatMsg ChatMessage
if json.Unmarshal(msg.Data, &chatMsg) == nil {
messages = append(messages, chatMsg)
}
}
hist.Unsubscribe()
if len(messages) > 50 {
messages = messages[len(messages)-50:]
}
// Replay history from JetStream
if hist, err := vianats.ReplayHistory[ChatMessage](ps, subject, 50); err == nil {
messages = hist
}
sub, _ := c.Subscribe(subject, func(data []byte) {
var msg ChatMessage
if err := json.Unmarshal(data, &msg); err != nil {
return
}
sub, _ := via.Subscribe(c, subject, func(msg ChatMessage) {
messagesMu.Lock()
messages = append(messages, msg)
if len(messages) > 50 {
@@ -203,12 +185,11 @@ func main() {
}
statement.SetValue("")
data, _ := json.Marshal(ChatMessage{
via.Publish(c, "chat.room."+currentRoom, ChatMessage{
User: currentUser,
Message: msg,
Time: time.Now().UnixMilli(),
})
c.Publish("chat.room."+currentRoom, data)
})
c.View(func() h.H {

View File

@@ -0,0 +1,284 @@
package main
import (
"context"
"crypto/rand"
"fmt"
"html"
"log"
"sync"
"time"
"github.com/ryanhamamura/via"
"github.com/ryanhamamura/via/h"
"github.com/ryanhamamura/via/vianats"
)
var WithSignal = via.WithSignal
type Bookmark struct {
ID string
Title string
URL string
}
type CRUDEvent struct {
Action string `json:"action"`
Title string `json:"title"`
UserID string `json:"user_id"`
}
var (
bookmarks []Bookmark
bookmarksMu sync.RWMutex
)
func randomHex(n int) string {
b := make([]byte, n)
rand.Read(b)
return fmt.Sprintf("%x", b)
}
func findBookmark(id string) (Bookmark, int) {
for i, bm := range bookmarks {
if bm.ID == id {
return bm, i
}
}
return Bookmark{}, -1
}
func main() {
ctx := context.Background()
ps, err := vianats.New(ctx, "./data/nats")
if err != nil {
log.Fatalf("Failed to start embedded NATS: %v", err)
}
defer ps.Close()
err = vianats.EnsureStream(ps, vianats.StreamConfig{
Name: "BOOKMARKS",
Subjects: []string{"bookmarks.>"},
MaxMsgs: 1000,
MaxAge: 24 * time.Hour,
})
if err != nil {
log.Fatalf("Failed to ensure stream: %v", err)
}
v := via.New()
v.Config(via.Options{
DevMode: true,
DocumentTitle: "Bookmarks",
LogLevel: via.LogLevelInfo,
ServerAddress: ":7331",
PubSub: ps,
})
v.AppendToHead(
h.Link(h.Rel("stylesheet"), h.Href("https://cdn.jsdelivr.net/npm/daisyui@4/dist/full.min.css")),
h.Script(h.Src("https://cdn.tailwindcss.com")),
)
v.Page("/", func(c *via.Context) {
userID := randomHex(8)
titleSignal := c.Signal("")
urlSignal := c.Signal("")
targetIDSignal := c.Signal("")
via.Subscribe(c, "bookmarks.events", func(evt CRUDEvent) {
if evt.UserID == userID {
return
}
safeTitle := html.EscapeString(evt.Title)
var alertClass string
switch evt.Action {
case "created":
alertClass = "alert-success"
case "updated":
alertClass = "alert-info"
case "deleted":
alertClass = "alert-error"
}
c.ExecScript(fmt.Sprintf(`(function(){
var tc = document.getElementById('toast-container');
if (!tc) return;
var d = document.createElement('div');
d.className = 'alert %s';
d.innerHTML = '<span>Bookmark "%s" %s</span>';
tc.appendChild(d);
setTimeout(function(){ d.remove(); }, 3000);
})()`, alertClass, safeTitle, evt.Action))
c.Sync()
})
save := c.Action(func() {
title := titleSignal.String()
url := urlSignal.String()
if title == "" || url == "" {
return
}
targetID := targetIDSignal.String()
action := "created"
bookmarksMu.Lock()
if targetID != "" {
if _, idx := findBookmark(targetID); idx >= 0 {
bookmarks[idx].Title = title
bookmarks[idx].URL = url
action = "updated"
}
} else {
bookmarks = append(bookmarks, Bookmark{
ID: randomHex(8),
Title: title,
URL: url,
})
}
bookmarksMu.Unlock()
titleSignal.SetValue("")
urlSignal.SetValue("")
targetIDSignal.SetValue("")
via.Publish(c, "bookmarks.events", CRUDEvent{
Action: action,
Title: title,
UserID: userID,
})
c.Sync()
})
edit := c.Action(func() {
id := targetIDSignal.String()
bookmarksMu.RLock()
bm, idx := findBookmark(id)
bookmarksMu.RUnlock()
if idx < 0 {
return
}
titleSignal.SetValue(bm.Title)
urlSignal.SetValue(bm.URL)
})
del := c.Action(func() {
id := targetIDSignal.String()
bookmarksMu.Lock()
bm, idx := findBookmark(id)
if idx >= 0 {
bookmarks = append(bookmarks[:idx], bookmarks[idx+1:]...)
}
bookmarksMu.Unlock()
if idx < 0 {
return
}
targetIDSignal.SetValue("")
via.Publish(c, "bookmarks.events", CRUDEvent{
Action: "deleted",
Title: bm.Title,
UserID: userID,
})
c.Sync()
})
cancelEdit := c.Action(func() {
titleSignal.SetValue("")
urlSignal.SetValue("")
targetIDSignal.SetValue("")
})
c.View(func() h.H {
isEditing := targetIDSignal.String() != ""
// Build table rows
bookmarksMu.RLock()
var rows []h.H
for _, bm := range bookmarks {
rows = append(rows, h.Tr(
h.Td(h.Text(bm.Title)),
h.Td(h.A(h.Href(bm.URL), h.Attr("target", "_blank"), h.Class("link link-primary"), h.Text(bm.URL))),
h.Td(
h.Div(h.Class("flex gap-1"),
h.Button(h.Class("btn btn-xs btn-ghost"), h.Text("Edit"),
edit.OnClick(WithSignal(targetIDSignal, bm.ID)),
),
h.Button(h.Class("btn btn-xs btn-ghost text-error"), h.Text("Delete"),
del.OnClick(WithSignal(targetIDSignal, bm.ID)),
),
),
),
))
}
bookmarksMu.RUnlock()
saveLabel := "Add Bookmark"
if isEditing {
saveLabel = "Update Bookmark"
}
return h.Div(h.Class("min-h-screen bg-base-200"),
// Navbar
h.Div(h.Class("navbar bg-base-100 shadow-sm"),
h.Div(h.Class("flex-1"),
h.A(h.Class("btn btn-ghost text-xl"), h.Text("Bookmarks")),
),
h.Div(h.Class("flex-none"),
h.Div(h.Class("badge badge-outline"), h.Text(userID[:8])),
),
),
h.Div(h.Class("container mx-auto p-4 max-w-3xl flex flex-col gap-4"),
// Form card
h.Div(h.Class("card bg-base-100 shadow"),
h.Div(h.Class("card-body"),
h.H2(h.Class("card-title"), h.Text(saveLabel)),
h.Div(h.Class("flex flex-col gap-2"),
h.Input(h.Class("input input-bordered w-full"), h.Type("text"), h.Placeholder("Title"), titleSignal.Bind()),
h.Input(h.Class("input input-bordered w-full"), h.Type("text"), h.Placeholder("https://example.com"), urlSignal.Bind()),
h.Div(h.Class("card-actions justify-end"),
h.If(isEditing,
h.Button(h.Class("btn btn-ghost"), h.Text("Cancel"), cancelEdit.OnClick()),
),
h.Button(h.Class("btn btn-primary"), h.Text(saveLabel), save.OnClick()),
),
),
),
),
// Table card
h.Div(h.Class("card bg-base-100 shadow"),
h.Div(h.Class("card-body"),
h.H2(h.Class("card-title"), h.Text("All Bookmarks")),
h.If(len(rows) == 0,
h.P(h.Class("text-base-content/60"), h.Text("No bookmarks yet. Add one above!")),
),
h.If(len(rows) > 0,
h.Div(h.Class("overflow-x-auto"),
h.Table(h.Class("table"),
h.THead(h.Tr(
h.Th(h.Text("Title")),
h.Th(h.Text("URL")),
h.Th(h.Text("Actions")),
)),
h.TBody(rows...),
),
),
),
),
),
),
// Toast container — ignored by morph so Sync() doesn't wipe active toasts
h.Div(h.ID("toast-container"), h.Class("toast toast-end toast-top"), h.DataIgnoreMorph()),
)
})
})
log.Println("Starting pubsub-crud example on :7331")
v.Start()
}

View File

@@ -29,7 +29,17 @@ func main() {
SessionManager: sm,
})
// Login page
// Auth middleware — redirects unauthenticated users to /login
authRequired := func(c *via.Context, next func()) {
if c.Session().GetString("username") == "" {
c.Session().Set("flash", "Please log in first")
c.RedirectView("/login")
return
}
next()
}
// Login page (public)
v.Page("/login", func(c *via.Context) {
flash := c.Session().PopString("flash")
usernameInput := c.Signal("")
@@ -64,8 +74,10 @@ func main() {
})
})
// Dashboard page (protected)
v.Page("/dashboard", func(c *via.Context) {
// Protected pages
protected := v.Group("", authRequired)
protected.Page("/dashboard", func(c *via.Context) {
logout := c.Action(func() {
c.Session().Set("flash", "Goodbye!")
c.Session().Delete("username")
@@ -74,14 +86,6 @@ func main() {
c.View(func() h.H {
username := c.Session().GetString("username")
// Not logged in? Redirect to login
if username == "" {
c.Session().Set("flash", "Please log in first")
c.Redirect("/login")
return h.Div()
}
flash := c.Session().PopString("flash")
var flashMsg h.H
if flash != "" {

View File

@@ -0,0 +1,87 @@
package main
import (
"github.com/ryanhamamura/via"
"github.com/ryanhamamura/via/h"
)
func main() {
v := via.New()
v.Config(via.Options{
DocumentTitle: "Signup",
ServerAddress: ":8080",
})
v.AppendToHead(h.StyleEl(h.Raw(`
body { font-family: system-ui, sans-serif; max-width: 420px; margin: 2rem auto; padding: 0 1rem; }
label { display: block; font-weight: 600; margin-top: 1rem; }
input { display: block; width: 100%; padding: 0.4rem; margin-top: 0.25rem; box-sizing: border-box; }
.error { color: #c00; font-size: 0.85rem; margin-top: 0.2rem; }
.success { color: #080; margin-top: 1rem; }
.actions { margin-top: 1.5rem; display: flex; gap: 0.5rem; }
`)))
v.Page("/", func(c *via.Context) {
username := c.Field("", via.Required(), via.MinLen(3), via.MaxLen(20))
email := c.Field("", via.Required(), via.Email())
age := c.Field("", via.Required(), via.Min(13), via.Max(120))
// Optional field — only validated when non-empty
website := c.Field("", via.Pattern(`^$|^https?://\S+$`, "Must be a valid URL"))
var success string
signup := c.Action(func() {
success = ""
if !c.ValidateAll() {
c.Sync()
return
}
// Server-side check
if username.String() == "admin" {
username.AddError("Username is already taken")
c.Sync()
return
}
success = "Account created for " + username.String() + "!"
c.ResetFields()
c.Sync()
})
reset := c.Action(func() {
success = ""
c.ResetFields()
c.Sync()
})
c.View(func() h.H {
return h.Div(
h.H1(h.Text("Sign Up")),
h.Label(h.Text("Username")),
h.Input(h.Type("text"), h.Placeholder("pick a username"), username.Bind()),
h.If(username.HasError(), h.Div(h.Class("error"), h.Text(username.FirstError()))),
h.Label(h.Text("Email")),
h.Input(h.Type("email"), h.Placeholder("you@example.com"), email.Bind()),
h.If(email.HasError(), h.Div(h.Class("error"), h.Text(email.FirstError()))),
h.Label(h.Text("Age")),
h.Input(h.Type("number"), h.Placeholder("your age"), age.Bind()),
h.If(age.HasError(), h.Div(h.Class("error"), h.Text(age.FirstError()))),
h.Label(h.Text("Website (optional)")),
h.Input(h.Type("url"), h.Placeholder("https://example.com"), website.Bind()),
h.If(website.HasError(), h.Div(h.Class("error"), h.Text(website.FirstError()))),
h.Div(h.Class("actions"),
h.Button(h.Text("Sign Up"), signup.OnClick()),
h.Button(h.Text("Reset"), reset.OnClick()),
),
h.If(success != "", h.P(h.Class("success"), h.Text(success))),
)
})
})
v.Start()
}

82
middleware.go Normal file
View File

@@ -0,0 +1,82 @@
package via
// Middleware wraps a page init function. Call next to continue the chain;
// return without calling next to abort (set a view first, e.g. RedirectView).
type Middleware func(c *Context, next func())
// Group is a route group with a shared prefix and middleware stack.
type Group struct {
v *V
prefix string
middleware []Middleware
}
// Use appends middleware to the global stack.
// Global middleware runs before every page handler.
func (v *V) Use(mw ...Middleware) {
v.middleware = append(v.middleware, mw...)
}
// Group creates a route group with the given path prefix and middleware.
// Routes registered on the group are prefixed and run the group's middleware
// after any global middleware.
func (v *V) Group(prefix string, mw ...Middleware) *Group {
return &Group{
v: v,
prefix: prefix,
middleware: mw,
}
}
// Page registers a route on this group. The full route is the group prefix
// concatenated with route.
func (g *Group) Page(route string, initContextFn func(c *Context)) {
fullRoute := g.prefix + route
allMw := make([]Middleware, 0, len(g.v.middleware)+len(g.middleware))
allMw = append(allMw, g.v.middleware...)
allMw = append(allMw, g.middleware...)
wrapped := chainMiddleware(allMw, initContextFn)
g.v.page(fullRoute, initContextFn, wrapped)
}
// Group creates a nested sub-group that inherits this group's prefix and
// middleware, then adds its own.
func (g *Group) Group(prefix string, mw ...Middleware) *Group {
combined := make([]Middleware, len(g.middleware), len(g.middleware)+len(mw))
copy(combined, g.middleware)
combined = append(combined, mw...)
return &Group{
v: g.v,
prefix: g.prefix + prefix,
middleware: combined,
}
}
// Use appends middleware to this group's stack.
func (g *Group) Use(mw ...Middleware) {
g.middleware = append(g.middleware, mw...)
}
// WithMiddleware returns an ActionOption that attaches middleware to an action.
// Action middleware runs after CSRF/rate-limit checks and signal injection.
func WithMiddleware(mw ...Middleware) ActionOption {
return func(e *actionEntry) {
e.middleware = append(e.middleware, mw...)
}
}
// chainMiddleware wraps handler with the given middleware, outer-first.
func chainMiddleware(mws []Middleware, handler func(*Context)) func(*Context) {
if len(mws) == 0 {
return handler
}
chained := handler
for i := len(mws) - 1; i >= 0; i-- {
mw := mws[i]
next := chained
chained = func(c *Context) {
mw(c, func() { next(c) })
}
}
return chained
}

340
middleware_test.go Normal file
View File

@@ -0,0 +1,340 @@
package via
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/ryanhamamura/via/h"
"github.com/stretchr/testify/assert"
)
func TestMiddlewareRunsBeforeHandler(t *testing.T) {
var order []string
v := New()
v.Use(func(c *Context, next func()) {
order = append(order, "mw")
next()
})
v.Page("/", func(c *Context) {
order = append(order, "handler")
c.View(func() h.H { return h.Div() })
})
// Reset after registration (panic-check runs the raw handler)
order = nil
w := httptest.NewRecorder()
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/", nil))
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, []string{"mw", "handler"}, order)
}
func TestMiddlewareAbortSkipsHandler(t *testing.T) {
handlerCalled := false
v := New()
v.Use(func(c *Context, next func()) {
c.RedirectView("/other")
})
v.Page("/", func(c *Context) {
handlerCalled = true
c.View(func() h.H { return h.Div() })
})
handlerCalled = false
w := httptest.NewRecorder()
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/", nil))
assert.Equal(t, http.StatusOK, w.Code)
assert.False(t, handlerCalled)
}
func TestMiddlewareChainOrder(t *testing.T) {
var order []string
v := New()
for _, label := range []string{"A", "B", "C"} {
l := label
v.Use(func(c *Context, next func()) {
order = append(order, l)
next()
})
}
v.Page("/", func(c *Context) {
order = append(order, "handler")
c.View(func() h.H { return h.Div() })
})
order = nil
w := httptest.NewRecorder()
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/", nil))
assert.Equal(t, []string{"A", "B", "C", "handler"}, order)
}
func TestGroupPrefixRouting(t *testing.T) {
v := New()
g := v.Group("/admin")
g.Page("/dashboard", func(c *Context) {
c.View(func() h.H { return h.Div(h.Text("admin dashboard")) })
})
w := httptest.NewRecorder()
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/admin/dashboard", nil))
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "admin dashboard")
}
func TestGroupMiddlewareAppliesToGroupOnly(t *testing.T) {
var groupMwCalled bool
v := New()
g := v.Group("/admin", func(c *Context, next func()) {
groupMwCalled = true
next()
})
g.Page("/panel", func(c *Context) {
c.View(func() h.H { return h.Div(h.Text("panel")) })
})
v.Page("/public", func(c *Context) {
c.View(func() h.H { return h.Div(h.Text("public")) })
})
// Hit public page — group middleware should NOT run
groupMwCalled = false
w := httptest.NewRecorder()
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/public", nil))
assert.False(t, groupMwCalled)
assert.Contains(t, w.Body.String(), "public")
// Hit group page — group middleware should run
groupMwCalled = false
w = httptest.NewRecorder()
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/admin/panel", nil))
assert.True(t, groupMwCalled)
assert.Contains(t, w.Body.String(), "panel")
}
func TestGlobalMiddlewareAppliesToGroupPages(t *testing.T) {
var globalCalled bool
v := New()
v.Use(func(c *Context, next func()) {
globalCalled = true
next()
})
g := v.Group("/admin")
g.Page("/dash", func(c *Context) {
c.View(func() h.H { return h.Div(h.Text("dash")) })
})
globalCalled = false
w := httptest.NewRecorder()
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/admin/dash", nil))
assert.True(t, globalCalled)
assert.Contains(t, w.Body.String(), "dash")
}
func TestNestedGroupInheritsPrefixAndMiddleware(t *testing.T) {
var order []string
v := New()
admin := v.Group("/admin", func(c *Context, next func()) {
order = append(order, "admin")
next()
})
superAdmin := admin.Group("/super", func(c *Context, next func()) {
order = append(order, "super")
next()
})
superAdmin.Page("/secret", func(c *Context) {
order = append(order, "handler")
c.View(func() h.H { return h.Div(h.Text("secret")) })
})
order = nil
w := httptest.NewRecorder()
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/admin/super/secret", nil))
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, []string{"admin", "super", "handler"}, order)
assert.Contains(t, w.Body.String(), "secret")
}
func TestGroupUse(t *testing.T) {
var order []string
v := New()
g := v.Group("/api")
g.Use(func(c *Context, next func()) {
order = append(order, "added-later")
next()
})
g.Page("/items", func(c *Context) {
order = append(order, "handler")
c.View(func() h.H { return h.Div() })
})
order = nil
w := httptest.NewRecorder()
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/api/items", nil))
assert.Equal(t, []string{"added-later", "handler"}, order)
}
func TestRedirectViewSetsValidView(t *testing.T) {
v := New()
v.Page("/test", func(c *Context) {
c.RedirectView("/somewhere")
})
w := httptest.NewRecorder()
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/test", nil))
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "<!doctype html>")
}
func TestGlobalAndGroupMiddlewareOrder(t *testing.T) {
var order []string
v := New()
v.Use(func(c *Context, next func()) {
order = append(order, "global")
next()
})
g := v.Group("/g", func(c *Context, next func()) {
order = append(order, "group")
next()
})
g.Page("/page", func(c *Context) {
order = append(order, "handler")
c.View(func() h.H { return h.Div() })
})
order = nil
w := httptest.NewRecorder()
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/g/page", nil))
assert.Equal(t, []string{"global", "group", "handler"}, order)
}
// --- Action middleware tests ---
func TestActionMiddlewareRunsBeforeAction(t *testing.T) {
var order []string
v := New()
c := newContext("test", "/", v)
mw := func(_ *Context, next func()) {
order = append(order, "mw")
next()
}
trigger := c.Action(func() {
order = append(order, "action")
}, WithMiddleware(mw))
entry, err := c.getAction(trigger.id)
assert.NoError(t, err)
chainMiddleware(entry.middleware, func(_ *Context) { entry.fn() })(c)
assert.Equal(t, []string{"mw", "action"}, order)
}
func TestActionMiddlewareAbortSkipsAction(t *testing.T) {
actionCalled := false
v := New()
c := newContext("test", "/", v)
mw := func(_ *Context, next func()) {
// don't call next — action should not run
}
trigger := c.Action(func() {
actionCalled = true
}, WithMiddleware(mw))
entry, err := c.getAction(trigger.id)
assert.NoError(t, err)
chainMiddleware(entry.middleware, func(_ *Context) { entry.fn() })(c)
assert.False(t, actionCalled)
}
func TestActionMiddlewareChainOrder(t *testing.T) {
var order []string
v := New()
c := newContext("test", "/", v)
var mws []Middleware
for _, label := range []string{"A", "B", "C"} {
l := label
mws = append(mws, func(_ *Context, next func()) {
order = append(order, l)
next()
})
}
trigger := c.Action(func() {
order = append(order, "action")
}, WithMiddleware(mws...))
entry, err := c.getAction(trigger.id)
assert.NoError(t, err)
chainMiddleware(entry.middleware, func(_ *Context) { entry.fn() })(c)
assert.Equal(t, []string{"A", "B", "C", "action"}, order)
}
func TestActionMiddlewareCombinedWithRateLimit(t *testing.T) {
v := New()
c := newContext("test", "/", v)
mw := func(_ *Context, next func()) { next() }
trigger := c.Action(func() {}, WithRateLimit(5, 10), WithMiddleware(mw))
entry, err := c.getAction(trigger.id)
assert.NoError(t, err)
assert.NotNil(t, entry.limiter)
assert.Len(t, entry.middleware, 1)
}
func TestGroupWithEmptyPrefix(t *testing.T) {
var mwCalled bool
v := New()
g := v.Group("", func(c *Context, next func()) {
mwCalled = true
next()
})
g.Page("/dashboard", func(c *Context) {
c.View(func() h.H { return h.Div(h.Text("dash")) })
})
mwCalled = false
w := httptest.NewRecorder()
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/dashboard", nil))
assert.True(t, mwCalled)
assert.Contains(t, w.Body.String(), "dash")
}

23
pubsub_helpers.go Normal file
View File

@@ -0,0 +1,23 @@
package via
import "encoding/json"
// Publish JSON-marshals msg and publishes to subject.
func Publish[T any](c *Context, subject string, msg T) error {
data, err := json.Marshal(msg)
if err != nil {
return err
}
return c.Publish(subject, data)
}
// Subscribe JSON-unmarshals each message as T and calls handler.
func Subscribe[T any](c *Context, subject string, handler func(T)) (Subscription, error) {
return c.Subscribe(subject, func(data []byte) {
var msg T
if err := json.Unmarshal(data, &msg); err != nil {
return
}
handler(msg)
})
}

66
pubsub_helpers_test.go Normal file
View File

@@ -0,0 +1,66 @@
package via
import (
"sync"
"testing"
"github.com/ryanhamamura/via/h"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPublishSubscribe_RoundTrip(t *testing.T) {
ps := newMockPubSub()
v := New()
v.Config(Options{PubSub: ps})
type event struct {
Name string `json:"name"`
Count int `json:"count"`
}
var got event
var wg sync.WaitGroup
wg.Add(1)
c := newContext("typed-ctx", "/", v)
c.View(func() h.H { return h.Div() })
_, err := Subscribe(c, "events", func(e event) {
got = e
wg.Done()
})
require.NoError(t, err)
err = Publish(c, "events", event{Name: "click", Count: 42})
require.NoError(t, err)
wg.Wait()
assert.Equal(t, "click", got.Name)
assert.Equal(t, 42, got.Count)
}
func TestSubscribe_SkipsBadJSON(t *testing.T) {
ps := newMockPubSub()
v := New()
v.Config(Options{PubSub: ps})
type msg struct {
Text string `json:"text"`
}
called := false
c := newContext("bad-json-ctx", "/", v)
c.View(func() h.H { return h.Div() })
_, err := Subscribe(c, "topic", func(m msg) {
called = true
})
require.NoError(t, err)
// Publish raw invalid JSON — handler should silently skip
err = c.Publish("topic", []byte("not json"))
require.NoError(t, err)
assert.False(t, called)
}

49
ratelimit.go Normal file
View File

@@ -0,0 +1,49 @@
package via
import "golang.org/x/time/rate"
const (
defaultActionRate float64 = 10.0
defaultActionBurst int = 20
)
// RateLimitConfig configures token-bucket rate limiting for actions.
// Zero values fall back to defaults. Rate of -1 disables limiting entirely.
type RateLimitConfig struct {
Rate float64
Burst int
}
// ActionOption configures per-action behaviour when passed to Context.Action.
type ActionOption func(*actionEntry)
type actionEntry struct {
fn func()
limiter *rate.Limiter // nil = use context default
middleware []Middleware
}
// WithRateLimit returns an ActionOption that gives this action its own
// token-bucket limiter, overriding the context-level default.
func WithRateLimit(r float64, burst int) ActionOption {
return func(e *actionEntry) {
e.limiter = newLimiter(RateLimitConfig{Rate: r, Burst: burst}, defaultActionRate, defaultActionBurst)
}
}
// newLimiter creates a *rate.Limiter from cfg, substituting defaults for zero
// values. A Rate of -1 disables limiting (returns nil).
func newLimiter(cfg RateLimitConfig, defaultRate float64, defaultBurst int) *rate.Limiter {
r := cfg.Rate
b := cfg.Burst
if r == -1 {
return nil
}
if r == 0 {
r = defaultRate
}
if b == 0 {
b = defaultBurst
}
return rate.NewLimiter(rate.Limit(r), b)
}

101
ratelimit_test.go Normal file
View File

@@ -0,0 +1,101 @@
package via
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewLimiter_Defaults(t *testing.T) {
l := newLimiter(RateLimitConfig{}, defaultActionRate, defaultActionBurst)
require.NotNil(t, l)
assert.InDelta(t, defaultActionRate, float64(l.Limit()), 0.001)
assert.Equal(t, defaultActionBurst, l.Burst())
}
func TestNewLimiter_CustomValues(t *testing.T) {
l := newLimiter(RateLimitConfig{Rate: 5, Burst: 10}, defaultActionRate, defaultActionBurst)
require.NotNil(t, l)
assert.InDelta(t, 5.0, float64(l.Limit()), 0.001)
assert.Equal(t, 10, l.Burst())
}
func TestNewLimiter_DisabledWithNegativeRate(t *testing.T) {
l := newLimiter(RateLimitConfig{Rate: -1}, defaultActionRate, defaultActionBurst)
assert.Nil(t, l)
}
func TestTokenBucket_AllowsBurstThenRejects(t *testing.T) {
l := newLimiter(RateLimitConfig{Rate: 1, Burst: 3}, 1, 3)
require.NotNil(t, l)
for i := 0; i < 3; i++ {
assert.True(t, l.Allow(), "request %d should be allowed within burst", i)
}
assert.False(t, l.Allow(), "request beyond burst should be rejected")
}
func TestWithRateLimit_CreatesLimiter(t *testing.T) {
entry := actionEntry{fn: func() {}}
opt := WithRateLimit(2, 4)
opt(&entry)
require.NotNil(t, entry.limiter)
assert.InDelta(t, 2.0, float64(entry.limiter.Limit()), 0.001)
assert.Equal(t, 4, entry.limiter.Burst())
}
func TestContextAction_WithRateLimit(t *testing.T) {
v := New()
c := newContext("test-rl", "/", v)
called := false
c.Action(func() { called = true }, WithRateLimit(1, 2))
// Verify the entry has its own limiter
for _, entry := range c.actionRegistry {
require.NotNil(t, entry.limiter)
assert.InDelta(t, 1.0, float64(entry.limiter.Limit()), 0.001)
assert.Equal(t, 2, entry.limiter.Burst())
}
assert.False(t, called)
}
func TestContextAction_DefaultNoPerActionLimiter(t *testing.T) {
v := New()
c := newContext("test-no-rl", "/", v)
c.Action(func() {})
for _, entry := range c.actionRegistry {
assert.Nil(t, entry.limiter, "entry without WithRateLimit should have nil limiter")
}
}
func TestContextLimiter_DefaultsApplied(t *testing.T) {
v := New()
c := newContext("test-ctx-limiter", "/", v)
require.NotNil(t, c.actionLimiter)
assert.InDelta(t, defaultActionRate, float64(c.actionLimiter.Limit()), 0.001)
assert.Equal(t, defaultActionBurst, c.actionLimiter.Burst())
}
func TestContextLimiter_DisabledViaConfig(t *testing.T) {
v := New()
v.actionRateLimit = RateLimitConfig{Rate: -1}
c := newContext("test-disabled", "/", v)
assert.Nil(t, c.actionLimiter)
}
func TestContextLimiter_CustomConfig(t *testing.T) {
v := New()
v.Config(Options{ActionRateLimit: RateLimitConfig{Rate: 50, Burst: 100}})
c := newContext("test-custom", "/", v)
require.NotNil(t, c.actionLimiter)
assert.InDelta(t, 50.0, float64(c.actionLimiter.Limit()), 0.001)
assert.Equal(t, 100, c.actionLimiter.Burst())
}

130
rule.go Normal file
View File

@@ -0,0 +1,130 @@
package via
import (
"errors"
"fmt"
"regexp"
"strconv"
"strings"
"unicode/utf8"
)
// Rule defines a single validation check for a Field.
type Rule struct {
validate func(val string) error
}
// Required rejects empty or whitespace-only values.
func Required(msg ...string) Rule {
m := "This field is required"
if len(msg) > 0 {
m = msg[0]
}
return Rule{func(val string) error {
if strings.TrimSpace(val) == "" {
return errors.New(m)
}
return nil
}}
}
// MinLen rejects values shorter than n characters.
func MinLen(n int, msg ...string) Rule {
m := fmt.Sprintf("Must be at least %d characters", n)
if len(msg) > 0 {
m = msg[0]
}
return Rule{func(val string) error {
if utf8.RuneCountInString(val) < n {
return errors.New(m)
}
return nil
}}
}
// MaxLen rejects values longer than n characters.
func MaxLen(n int, msg ...string) Rule {
m := fmt.Sprintf("Must be at most %d characters", n)
if len(msg) > 0 {
m = msg[0]
}
return Rule{func(val string) error {
if utf8.RuneCountInString(val) > n {
return errors.New(m)
}
return nil
}}
}
// Min parses the value as an integer and rejects values less than n.
func Min(n int, msg ...string) Rule {
m := fmt.Sprintf("Must be at least %d", n)
if len(msg) > 0 {
m = msg[0]
}
return Rule{func(val string) error {
v, err := strconv.Atoi(val)
if err != nil {
return errors.New("Must be a valid number")
}
if v < n {
return errors.New(m)
}
return nil
}}
}
// Max parses the value as an integer and rejects values greater than n.
func Max(n int, msg ...string) Rule {
m := fmt.Sprintf("Must be at most %d", n)
if len(msg) > 0 {
m = msg[0]
}
return Rule{func(val string) error {
v, err := strconv.Atoi(val)
if err != nil {
return errors.New("Must be a valid number")
}
if v > n {
return errors.New(m)
}
return nil
}}
}
// Pattern rejects values that don't match the regular expression re.
func Pattern(re string, msg ...string) Rule {
m := "Invalid format"
if len(msg) > 0 {
m = msg[0]
}
compiled := regexp.MustCompile(re)
return Rule{func(val string) error {
if !compiled.MatchString(val) {
return errors.New(m)
}
return nil
}}
}
var emailRegexp = regexp.MustCompile(`^[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}$`)
// Email rejects values that don't look like an email address.
func Email(msg ...string) Rule {
m := "Invalid email address"
if len(msg) > 0 {
m = msg[0]
}
return Rule{func(val string) error {
if !emailRegexp.MatchString(val) {
return errors.New(m)
}
return nil
}}
}
// Custom creates a rule from a user-provided validation function.
// The function should return nil for valid input and an error for invalid input.
func Custom(fn func(string) error) Rule {
return Rule{validate: fn}
}

116
rule_test.go Normal file
View File

@@ -0,0 +1,116 @@
package via
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
func TestRequired(t *testing.T) {
r := Required()
assert.NoError(t, r.validate("hello"))
assert.Error(t, r.validate(""))
assert.Error(t, r.validate(" "))
}
func TestRequiredCustomMessage(t *testing.T) {
r := Required("name needed")
err := r.validate("")
assert.EqualError(t, err, "name needed")
}
func TestMinLen(t *testing.T) {
r := MinLen(3)
assert.NoError(t, r.validate("abc"))
assert.NoError(t, r.validate("abcd"))
assert.Error(t, r.validate("ab"))
assert.Error(t, r.validate(""))
}
func TestMinLenCustomMessage(t *testing.T) {
r := MinLen(5, "too short")
err := r.validate("ab")
assert.EqualError(t, err, "too short")
}
func TestMaxLen(t *testing.T) {
r := MaxLen(5)
assert.NoError(t, r.validate("abc"))
assert.NoError(t, r.validate("abcde"))
assert.Error(t, r.validate("abcdef"))
}
func TestMaxLenCustomMessage(t *testing.T) {
r := MaxLen(2, "too long")
err := r.validate("abc")
assert.EqualError(t, err, "too long")
}
func TestMin(t *testing.T) {
r := Min(5)
assert.NoError(t, r.validate("5"))
assert.NoError(t, r.validate("10"))
assert.Error(t, r.validate("4"))
assert.Error(t, r.validate("abc"))
}
func TestMinCustomMessage(t *testing.T) {
r := Min(10, "need 10+")
err := r.validate("3")
assert.EqualError(t, err, "need 10+")
}
func TestMax(t *testing.T) {
r := Max(10)
assert.NoError(t, r.validate("10"))
assert.NoError(t, r.validate("5"))
assert.Error(t, r.validate("11"))
assert.Error(t, r.validate("abc"))
}
func TestMaxCustomMessage(t *testing.T) {
r := Max(5, "too big")
err := r.validate("6")
assert.EqualError(t, err, "too big")
}
func TestPattern(t *testing.T) {
r := Pattern(`^\d{3}$`)
assert.NoError(t, r.validate("123"))
assert.Error(t, r.validate("12"))
assert.Error(t, r.validate("abcd"))
}
func TestPatternCustomMessage(t *testing.T) {
r := Pattern(`^\d+$`, "digits only")
err := r.validate("abc")
assert.EqualError(t, err, "digits only")
}
func TestEmail(t *testing.T) {
r := Email()
assert.NoError(t, r.validate("user@example.com"))
assert.NoError(t, r.validate("a.b+c@foo.co"))
assert.Error(t, r.validate("notanemail"))
assert.Error(t, r.validate("@example.com"))
assert.Error(t, r.validate("user@"))
assert.Error(t, r.validate(""))
}
func TestEmailCustomMessage(t *testing.T) {
r := Email("bad email")
err := r.validate("nope")
assert.EqualError(t, err, "bad email")
}
func TestCustom(t *testing.T) {
r := Custom(func(val string) error {
if val != "magic" {
return fmt.Errorf("must be magic")
}
return nil
})
assert.NoError(t, r.validate("magic"))
assert.EqualError(t, r.validate("other"), "must be magic")
}

View File

@@ -81,26 +81,3 @@ func (s *signal) Int() int {
return 0
}
// Int64 tries to read the signal value as an int64.
// Returns the value or 0 on failure.
func (s *signal) Int64() int64 {
if n, err := strconv.ParseInt(s.String(), 10, 64); err == nil {
return n
}
return 0
}
// Float64 tries to read the signal value as a float64.
// Returns the value or 0.0 on failure.
func (s *signal) Float() float64 {
if n, err := strconv.ParseFloat(s.String(), 64); err == nil {
return n
}
return 0.0
}
// Bytes tries to read the signal value as a []byte
// Returns the value or an empty []byte on failure.
func (s *signal) Bytes() []byte {
return []byte(s.String())
}

143
static_test.go Normal file
View File

@@ -0,0 +1,143 @@
package via
import (
"io/fs"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"testing/fstest"
"github.com/stretchr/testify/assert"
)
func TestStatic(t *testing.T) {
dir := t.TempDir()
os.MkdirAll(filepath.Join(dir, "sub"), 0755)
os.WriteFile(filepath.Join(dir, "hello.txt"), []byte("hello world"), 0644)
os.WriteFile(filepath.Join(dir, "sub", "nested.txt"), []byte("nested"), 0644)
v := New()
v.Static("/assets/", dir)
t.Run("serves file", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/assets/hello.txt", nil)
v.mux.ServeHTTP(w, r)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "hello world", w.Body.String())
})
t.Run("serves nested file", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/assets/sub/nested.txt", nil)
v.mux.ServeHTTP(w, r)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "nested", w.Body.String())
})
t.Run("directory listing returns 404", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/assets/", nil)
v.mux.ServeHTTP(w, r)
assert.Equal(t, http.StatusNotFound, w.Code)
})
t.Run("subdirectory listing returns 404", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/assets/sub/", nil)
v.mux.ServeHTTP(w, r)
assert.Equal(t, http.StatusNotFound, w.Code)
})
t.Run("missing file returns 404", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/assets/nope.txt", nil)
v.mux.ServeHTTP(w, r)
assert.Equal(t, http.StatusNotFound, w.Code)
})
}
func TestStaticAutoSlash(t *testing.T) {
dir := t.TempDir()
os.WriteFile(filepath.Join(dir, "ok.txt"), []byte("ok"), 0644)
v := New()
v.Static("/files", dir) // no trailing slash
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/files/ok.txt", nil)
v.mux.ServeHTTP(w, r)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "ok", w.Body.String())
}
func TestStaticFS(t *testing.T) {
fsys := fstest.MapFS{
"style.css": {Data: []byte("body{}")},
"js/app.js": {Data: []byte("console.log('hi')")},
}
v := New()
v.StaticFS("/static/", fsys)
t.Run("serves file", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/static/style.css", nil)
v.mux.ServeHTTP(w, r)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "body{}", w.Body.String())
})
t.Run("serves nested file", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/static/js/app.js", nil)
v.mux.ServeHTTP(w, r)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "console.log('hi')", w.Body.String())
})
t.Run("directory listing returns 404", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/static/", nil)
v.mux.ServeHTTP(w, r)
assert.Equal(t, http.StatusNotFound, w.Code)
})
t.Run("missing file returns 404", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/static/nope.css", nil)
v.mux.ServeHTTP(w, r)
assert.Equal(t, http.StatusNotFound, w.Code)
})
}
func TestStaticFSAutoSlash(t *testing.T) {
fsys := fstest.MapFS{
"ok.txt": {Data: []byte("ok")},
}
v := New()
v.StaticFS("/embed", fsys) // no trailing slash
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/embed/ok.txt", nil)
v.mux.ServeHTTP(w, r)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "ok", w.Body.String())
}
// Verify StaticFS accepts the fs.FS interface (compile-time check).
var _ fs.FS = fstest.MapFS{}

191
via.go
View File

@@ -9,11 +9,13 @@ package via
import (
"context"
"crypto/rand"
"crypto/subtle"
_ "embed"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"io/fs"
"net/http"
"net/url"
"os"
@@ -47,9 +49,12 @@ type V struct {
devModePageInitFnMap map[string]func(*Context)
sessionManager *scs.SessionManager
pubsub PubSub
actionRateLimit RateLimitConfig
datastarPath string
datastarContent []byte
datastarOnce sync.Once
reaperStop chan struct{}
middleware []Middleware
}
func (v *V) logEvent(evt *zerolog.Event, c *Context) *zerolog.Event {
@@ -127,6 +132,12 @@ func (v *V) Config(cfg Options) {
if cfg.PubSub != nil {
v.pubsub = cfg.PubSub
}
if cfg.ContextTTL != 0 {
v.cfg.ContextTTL = cfg.ContextTTL
}
if cfg.ActionRateLimit.Rate != 0 || cfg.ActionRateLimit.Burst != 0 {
v.actionRateLimit = cfg.ActionRateLimit
}
}
// AppendToHead appends the given h.H nodes to the head of the base HTML document.
@@ -160,8 +171,16 @@ func (v *V) AppendToFoot(elements ...h.H) {
// })
// })
func (v *V) Page(route string, initContextFn func(c *Context)) {
wrapped := chainMiddleware(v.middleware, initContextFn)
v.page(route, initContextFn, wrapped)
}
// page registers a route with separate raw and wrapped init functions.
// raw is used for the panic-check at registration time; wrapped includes
// any middleware and is used as the live handler.
func (v *V) page(route string, raw, wrapped func(*Context)) {
v.ensureDatastarHandler()
// check for panics
// check for panics using the raw handler (no middleware)
func() {
defer func() {
if err := recover(); err != nil {
@@ -170,14 +189,13 @@ func (v *V) Page(route string, initContextFn func(c *Context)) {
}
}()
c := newContext("", "", v)
initContextFn(c)
raw(c)
c.view()
c.stopAllRoutines()
}()
// save page init function allows devmode to restore persisted ctx later
if v.cfg.DevMode {
v.devModePageInitFnMap[route] = initContextFn
v.devModePageInitFnMap[route] = wrapped
}
v.mux.HandleFunc("GET "+route, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
v.logDebug(nil, "GET %s", r.URL.String())
@@ -191,7 +209,7 @@ func (v *V) Page(route string, initContextFn func(c *Context)) {
c.reqCtx = r.Context()
routeParams := extractParams(route, r.URL.Path)
c.injectRouteParams(routeParams)
initContextFn(c)
wrapped(c)
v.registerCtx(c)
if v.cfg.DevMode {
v.devModePersist(c)
@@ -199,7 +217,7 @@ func (v *V) Page(route string, initContextFn func(c *Context)) {
headElements := []h.H{h.Script(h.Type("module"), h.Src(v.datastarPath))}
headElements = append(headElements, v.documentHeadIncludes...)
headElements = append(headElements,
h.Meta(h.Data("signals", fmt.Sprintf("{'via-ctx':'%s'}", id))),
h.Meta(h.Data("signals", fmt.Sprintf("{'via-ctx':'%s','via-csrf':'%s'}", id, c.csrfToken))),
h.Meta(h.Data("init", "@get('/_sse')")),
h.Meta(h.Data("init", fmt.Sprintf(`window.addEventListener('beforeunload', (evt) => {
navigator.sendBeacon('/_session/close', '%s');});`, c.id))),
@@ -216,7 +234,6 @@ func (v *V) Page(route string, initContextFn func(c *Context)) {
Title: v.cfg.DocumentTitle,
Head: headElements,
Body: bodyElements,
HTMLAttrs: []h.H{},
})
_ = view.Render(w)
}))
@@ -225,17 +242,17 @@ func (v *V) Page(route string, initContextFn func(c *Context)) {
func (v *V) registerCtx(c *Context) {
v.contextRegistryMutex.Lock()
defer v.contextRegistryMutex.Unlock()
if c == nil {
v.logErr(c, "failed to add nil context to registry")
return
}
v.contextRegistry[c.id] = c
v.logDebug(c, "new context added to registry")
v.logDebug(nil, "number of sessions in registry: %d", v.currSessionNum())
v.logDebug(nil, "number of sessions in registry: %d", len(v.contextRegistry))
}
func (v *V) currSessionNum() int {
return len(v.contextRegistry)
func (v *V) cleanupCtx(c *Context) {
c.dispose()
if v.cfg.DevMode {
v.devModeRemovePersisted(c)
}
v.unregisterCtx(c)
}
func (v *V) unregisterCtx(c *Context) {
@@ -247,7 +264,7 @@ func (v *V) unregisterCtx(c *Context) {
defer v.contextRegistryMutex.Unlock()
v.logDebug(c, "ctx removed from registry")
delete(v.contextRegistry, c.id)
v.logDebug(nil, "number of sessions in registry: %d", v.currSessionNum())
v.logDebug(nil, "number of sessions in registry: %d", len(v.contextRegistry))
}
func (v *V) getCtx(id string) (*Context, error) {
@@ -259,6 +276,50 @@ func (v *V) getCtx(id string) (*Context, error) {
return nil, fmt.Errorf("ctx '%s' not found", id)
}
func (v *V) startReaper() {
ttl := v.cfg.ContextTTL
if ttl < 0 {
return
}
if ttl == 0 {
ttl = 30 * time.Second
}
interval := ttl / 3
if interval < 5*time.Second {
interval = 5 * time.Second
}
v.reaperStop = make(chan struct{})
go func() {
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-v.reaperStop:
return
case <-ticker.C:
v.reapOrphanedContexts(ttl)
}
}
}()
}
func (v *V) reapOrphanedContexts(ttl time.Duration) {
now := time.Now()
v.contextRegistryMutex.RLock()
var orphans []*Context
for _, c := range v.contextRegistry {
if !c.sseConnected.Load() && now.Sub(c.createdAt) > ttl {
orphans = append(orphans, c)
}
}
v.contextRegistryMutex.RUnlock()
for _, c := range orphans {
v.logInfo(c, "reaping orphaned context (no SSE connection after %s)", ttl)
v.cleanupCtx(c)
}
}
// Start starts the Via HTTP server and blocks until a SIGINT or SIGTERM
// signal is received, then performs a graceful shutdown.
func (v *V) Start() {
@@ -271,6 +332,8 @@ func (v *V) Start() {
Handler: handler,
}
v.startReaper()
errCh := make(chan error, 1)
go func() {
errCh <- v.server.ListenAndServe()
@@ -291,16 +354,15 @@ func (v *V) Start() {
return
}
v.shutdown()
v.Shutdown()
}
// Shutdown gracefully shuts down the server and all contexts.
// Safe for programmatic or test use.
func (v *V) Shutdown() {
v.shutdown()
if v.reaperStop != nil {
close(v.reaperStop)
}
func (v *V) shutdown() {
v.logInfo(nil, "draining all contexts")
v.drainAllContexts()
@@ -346,6 +408,46 @@ func (v *V) HTTPServeMux() *http.ServeMux {
return v.mux
}
// Static serves files from a filesystem directory at the given URL prefix.
//
// Example:
//
// v.Static("/assets/", "./public")
func (v *V) Static(urlPrefix, dir string) {
if !strings.HasSuffix(urlPrefix, "/") {
urlPrefix += "/"
}
fileServer := http.StripPrefix(urlPrefix, http.FileServer(http.Dir(dir)))
v.mux.Handle("GET "+urlPrefix, noDirListing(fileServer))
}
// StaticFS serves files from an [fs.FS] at the given URL prefix.
// This is useful with //go:embed filesystems.
//
// Example:
//
// //go:embed static
// var staticFiles embed.FS
// v.StaticFS("/assets/", staticFiles)
func (v *V) StaticFS(urlPrefix string, fsys fs.FS) {
if !strings.HasSuffix(urlPrefix, "/") {
urlPrefix += "/"
}
fileServer := http.StripPrefix(urlPrefix, http.FileServerFS(fsys))
v.mux.Handle("GET "+urlPrefix, noDirListing(fileServer))
}
// noDirListing wraps a file server handler to return 404 for directory requests.
func noDirListing(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.URL.Path, "/") {
http.NotFound(w, r)
return
}
next.ServeHTTP(w, r)
})
}
func (v *V) ensureDatastarHandler() {
v.datastarOnce.Do(func() {
v.mux.HandleFunc("GET "+v.datastarPath, func(w http.ResponseWriter, r *http.Request) {
@@ -400,10 +502,7 @@ func (v *V) devModeRemovePersisted(c *Context) {
}
file.Close()
// remove ctx to persisted list
if _, ok := ctxRegMap[c.id]; !ok {
delete(ctxRegMap, c.id)
}
// write persisted list to file
file, err = os.Create(p)
@@ -507,16 +606,16 @@ func New() *V {
// use last-event-id to tell if request is a sse reconnect
sse.Send(datastar.EventTypePatchElements, []string{}, datastar.WithSSEEventId("via"))
c.sseConnected.Store(true)
v.logDebug(c, "SSE connection established")
go func() {
c.Sync()
}()
go c.Sync()
for {
select {
case <-sse.Context().Done():
v.logDebug(c, "SSE connection ended")
v.cleanupCtx(c)
return
case <-c.ctxDisposedChan:
v.logDebug(c, "context disposed, closing SSE")
@@ -572,13 +671,29 @@ func New() *V {
v.logErr(nil, "action '%s' failed: %v", actionID, err)
return
}
csrfToken, _ := sigs["via-csrf"].(string)
if subtle.ConstantTimeCompare([]byte(csrfToken), []byte(c.csrfToken)) != 1 {
v.logWarn(c, "action '%s' rejected: invalid CSRF token", actionID)
http.Error(w, "invalid CSRF token", http.StatusForbidden)
return
}
if c.actionLimiter != nil && !c.actionLimiter.Allow() {
v.logWarn(c, "action '%s' rate limited", actionID)
http.Error(w, "rate limited", http.StatusTooManyRequests)
return
}
c.reqCtx = r.Context()
actionFn, err := c.getActionFn(actionID)
entry, err := c.getAction(actionID)
if err != nil {
v.logDebug(c, "action '%s' failed: %v", actionID, err)
return
}
// log err if actionFn panics
if entry.limiter != nil && !entry.limiter.Allow() {
v.logWarn(c, "action '%s' rate limited (per-action)", actionID)
http.Error(w, "rate limited", http.StatusTooManyRequests)
return
}
// log err if action panics
defer func() {
if r := recover(); r != nil {
v.logErr(c, "action '%s' failed: %v", actionID, r)
@@ -586,7 +701,11 @@ func New() *V {
}()
c.injectSignals(sigs)
actionFn()
if len(entry.middleware) > 0 {
chainMiddleware(entry.middleware, func(_ *Context) { entry.fn() })(c)
} else {
entry.fn()
}
})
v.mux.HandleFunc("POST /_session/close", func(w http.ResponseWriter, r *http.Request) {
@@ -603,20 +722,22 @@ func New() *V {
v.logErr(c, "failed to handle session close: %v", err)
return
}
c.dispose()
v.logDebug(c, "session close event triggered")
if v.cfg.DevMode {
v.devModeRemovePersisted(c)
}
v.unregisterCtx(c)
v.cleanupCtx(c)
})
return v
}
func genRandID() string {
b := make([]byte, 4)
rand.Read(b)
return hex.EncodeToString(b)
}
func genCSRFToken() string {
b := make([]byte, 16)
rand.Read(b)
return hex.EncodeToString(b)[:8]
return hex.EncodeToString(b)
}
func extractParams(pattern, path string) map[string]string {
@@ -631,7 +752,7 @@ func extractParams(pattern, path string) map[string]string {
key := p[i][1 : len(p[i])-1] // remove {}
params[key] = u[i]
} else if p[i] != u[i] {
continue
return nil
}
}
return params

View File

@@ -1,9 +1,13 @@
package via
import (
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"time"
"github.com/ryanhamamura/via/h"
"github.com/stretchr/testify/assert"
@@ -128,6 +132,60 @@ func TestAction(t *testing.T) {
assert.Contains(t, body, "/_action/")
}
func TestEventTypes(t *testing.T) {
tests := []struct {
name string
attr string
buildEl func(trigger *actionTrigger) h.H
}{
{"OnSubmit", "data-on:submit", func(tr *actionTrigger) h.H { return h.Form(tr.OnSubmit()) }},
{"OnInput", "data-on:input", func(tr *actionTrigger) h.H { return h.Input(tr.OnInput()) }},
{"OnFocus", "data-on:focus", func(tr *actionTrigger) h.H { return h.Input(tr.OnFocus()) }},
{"OnBlur", "data-on:blur", func(tr *actionTrigger) h.H { return h.Input(tr.OnBlur()) }},
{"OnMouseEnter", "data-on:mouseenter", func(tr *actionTrigger) h.H { return h.Div(tr.OnMouseEnter()) }},
{"OnMouseLeave", "data-on:mouseleave", func(tr *actionTrigger) h.H { return h.Div(tr.OnMouseLeave()) }},
{"OnScroll", "data-on:scroll", func(tr *actionTrigger) h.H { return h.Div(tr.OnScroll()) }},
{"OnDblClick", "data-on:dblclick", func(tr *actionTrigger) h.H { return h.Div(tr.OnDblClick()) }},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var trigger *actionTrigger
v := New()
v.Page("/", func(c *Context) {
trigger = c.Action(func() {})
c.View(func() h.H { return tt.buildEl(trigger) })
})
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
v.mux.ServeHTTP(w, req)
body := w.Body.String()
assert.Contains(t, body, tt.attr)
assert.Contains(t, body, "/_action/"+trigger.id)
})
}
t.Run("WithSignal", func(t *testing.T) {
var trigger *actionTrigger
var sig *signal
v := New()
v.Page("/", func(c *Context) {
trigger = c.Action(func() {})
sig = c.Signal("val")
c.View(func() h.H {
return h.Div(trigger.OnDblClick(WithSignal(sig, "x")))
})
})
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
v.mux.ServeHTTP(w, req)
body := w.Body.String()
assert.Contains(t, body, "data-on:dblclick")
assert.Contains(t, body, "$"+sig.ID()+"=&#39;x&#39;")
})
}
func TestOnKeyDownWithWindow(t *testing.T) {
var trigger *actionTrigger
v := New()
@@ -235,3 +293,93 @@ func TestPage_PanicsOnNoView(t *testing.T) {
v.Page("/", func(c *Context) {})
})
}
func TestReaperCleansOrphanedContexts(t *testing.T) {
v := New()
c := newContext("orphan-1", "/", v)
c.createdAt = time.Now().Add(-time.Minute) // created 1 min ago
v.registerCtx(c)
_, err := v.getCtx("orphan-1")
assert.NoError(t, err)
v.reapOrphanedContexts(10 * time.Second)
_, err = v.getCtx("orphan-1")
assert.Error(t, err, "orphaned context should have been reaped")
}
func TestReaperIgnoresConnectedContexts(t *testing.T) {
v := New()
c := newContext("connected-1", "/", v)
c.createdAt = time.Now().Add(-time.Minute)
c.sseConnected.Store(true)
v.registerCtx(c)
v.reapOrphanedContexts(10 * time.Second)
_, err := v.getCtx("connected-1")
assert.NoError(t, err, "connected context should survive reaping")
}
func TestReaperDisabledWithNegativeTTL(t *testing.T) {
v := New()
v.cfg.ContextTTL = -1
v.startReaper()
assert.Nil(t, v.reaperStop, "reaper should not start with negative TTL")
}
func TestCleanupCtxIdempotent(t *testing.T) {
v := New()
c := newContext("idempotent-1", "/", v)
v.registerCtx(c)
assert.NotPanics(t, func() {
v.cleanupCtx(c)
v.cleanupCtx(c)
})
_, err := v.getCtx("idempotent-1")
assert.Error(t, err, "context should be removed after cleanup")
}
func TestDevModeRemovePersistedFix(t *testing.T) {
v := New()
v.cfg.DevMode = true
dir := filepath.Join(t.TempDir(), ".via", "devmode")
p := filepath.Join(dir, "ctx.json")
assert.NoError(t, os.MkdirAll(dir, 0755))
// Write a persisted context
ctxRegMap := map[string]string{"test-ctx-1": "/"}
f, err := os.Create(p)
assert.NoError(t, err)
assert.NoError(t, json.NewEncoder(f).Encode(ctxRegMap))
f.Close()
// Patch devModeRemovePersisted to use our temp path by calling it
// directly — we need to override the path. Instead, test via the
// actual function by temporarily changing the working dir.
origDir, _ := os.Getwd()
assert.NoError(t, os.Chdir(t.TempDir()))
defer os.Chdir(origDir)
// Re-create the structure in the temp dir
assert.NoError(t, os.MkdirAll(filepath.Join(".via", "devmode"), 0755))
p2 := filepath.Join(".via", "devmode", "ctx.json")
f2, _ := os.Create(p2)
json.NewEncoder(f2).Encode(map[string]string{"test-ctx-1": "/"})
f2.Close()
c := newContext("test-ctx-1", "/", v)
v.devModeRemovePersisted(c)
// Read back and verify
f3, err := os.Open(p2)
assert.NoError(t, err)
defer f3.Close()
var result map[string]string
assert.NoError(t, json.NewDecoder(f3).Decode(&result))
assert.Empty(t, result, "persisted context should be removed")
}

View File

@@ -4,7 +4,9 @@ package vianats
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/delaneyj/toolbelt/embeddednats"
"github.com/nats-io/nats.go"
@@ -76,3 +78,50 @@ func (n *NATS) Conn() *nats.Conn {
func (n *NATS) JetStream() nats.JetStreamContext {
return n.js
}
// StreamConfig holds the parameters for creating or updating a JetStream stream.
type StreamConfig struct {
Name string
Subjects []string
MaxMsgs int64
MaxAge time.Duration
}
// EnsureStream creates or updates a JetStream stream matching cfg.
func EnsureStream(n *NATS, cfg StreamConfig) error {
_, err := n.js.AddStream(&nats.StreamConfig{
Name: cfg.Name,
Subjects: cfg.Subjects,
Retention: nats.LimitsPolicy,
MaxMsgs: cfg.MaxMsgs,
MaxAge: cfg.MaxAge,
})
return err
}
// ReplayHistory fetches the last limit messages from subject,
// deserializing each as T. Returns an empty slice if nothing is available.
func ReplayHistory[T any](n *NATS, subject string, limit int) ([]T, error) {
sub, err := n.js.SubscribeSync(subject, nats.DeliverAll(), nats.OrderedConsumer())
if err != nil {
return nil, err
}
defer sub.Unsubscribe()
var msgs []T
for {
raw, err := sub.NextMsg(200 * time.Millisecond)
if err != nil {
break
}
var msg T
if json.Unmarshal(raw.Data, &msg) == nil {
msgs = append(msgs, msg)
}
}
if limit > 0 && len(msgs) > limit {
msgs = msgs[len(msgs)-limit:]
}
return msgs, nil
}