Compare commits
14 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
785f11e52d | ||
|
|
2f19874c17 | ||
|
|
27b8540b71 | ||
|
|
532651552a | ||
|
|
2310e45d35 | ||
|
|
10b4838f8d | ||
|
|
5362614c3e | ||
|
|
e636970f7b | ||
|
|
f5158b866c | ||
|
|
2f6c5916ce | ||
|
|
0762ddbbc2 | ||
|
|
b7acfa6302 | ||
|
|
8aa91c577c | ||
|
|
6dcd54c88b |
75
README.md
75
README.md
@@ -1,30 +1,33 @@
|
||||
# ⚡Via
|
||||
# Via
|
||||
|
||||
Real-time engine for building reactive web applications in pure Go.
|
||||
|
||||
|
||||
## Why Via?
|
||||
Somewhere along the way, the web became tangled in layers of JavaScript, build chains, and frameworks stacked on frameworks.
|
||||
|
||||
Via takes a radical stance:
|
||||
The web became tangled in layers of JavaScript, build chains, and frameworks stacked on frameworks. Via takes a different path.
|
||||
|
||||
- No templates.
|
||||
- No JavaScript.
|
||||
- No transpilation.
|
||||
- No hydration.
|
||||
- No front-end fatigue.
|
||||
- Single SSE stream.
|
||||
- Full reactivity.
|
||||
- Built-in Brotli compression.
|
||||
- Pure Go.
|
||||
**Philosophy**
|
||||
- No templates. No JavaScript. No transpilation. No hydration.
|
||||
- Views are pure Go functions. HTML is composed with a type-safe DSL.
|
||||
- A single SSE stream carries all reactivity — no WebSocket juggling, no polling.
|
||||
|
||||
**Batteries included**
|
||||
- Automatic CSRF protection on every action call
|
||||
- Token-bucket rate limiting (global defaults + per-action overrides)
|
||||
- Cookie-based sessions backed by SQLite
|
||||
- Pub/sub messaging with an embedded NATS backend
|
||||
- Structured logging via zerolog
|
||||
- Graceful shutdown with context draining
|
||||
- Brotli compression out of the box
|
||||
|
||||
## Example
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/go-via/via"
|
||||
"github.com/go-via/via/h"
|
||||
"github.com/ryanhamamura/via"
|
||||
"github.com/ryanhamamura/via/h"
|
||||
)
|
||||
|
||||
type Counter struct{ Count int }
|
||||
@@ -57,25 +60,43 @@ func main() {
|
||||
}
|
||||
```
|
||||
|
||||
## What's built in
|
||||
|
||||
## 🚧 Experimental
|
||||
<s>Via is still a newborn.</s> Via is taking its first steps!
|
||||
- Version `0.1.0` released.
|
||||
- Expect a little less chaos.
|
||||
- **Reactive views + signals** — bind state to the DOM; changes push over SSE automatically
|
||||
- **Components** — self-contained subcontexts with their own data, actions, and signals
|
||||
- **Sessions** — cookie-based, backed by SQLite via `scs`
|
||||
- **Pub/sub** — embedded NATS server with JetStream; generic `Publish[T]` / `Subscribe[T]` helpers
|
||||
- **CSRF protection** — automatic token generation and validation on every action
|
||||
- **Rate limiting** — token-bucket algorithm, configurable globally and per-action
|
||||
- **Event handling** — `OnClick`, `OnChange`, `OnSubmit`, `OnInput`, `OnFocus`, `OnBlur`, `OnMouseEnter`, `OnMouseLeave`, `OnScroll`, `OnDblClick`, `OnKeyDown`, and `OnKeyDownMap` for multi-key bindings
|
||||
- **Timed routines** — `OnInterval` auto-starts a ticker goroutine, returns a stop function, tied to context lifecycle
|
||||
- **Redirects** — `Redirect`, `ReplaceURL`, and format-string variants
|
||||
- **Plugin system** — `func(v *V)` hooks for integrating CSS/JS libraries
|
||||
- **Structured logging** — zerolog with configurable levels; console output in dev, JSON in production
|
||||
- **Graceful shutdown** — listens for SIGINT/SIGTERM, drains contexts, closes pub/sub
|
||||
- **Context lifecycle** — background reaper cleans up disconnected contexts; configurable TTL
|
||||
- **HTML DSL** — the `h` package provides type-safe Go-native HTML composition
|
||||
|
||||
## Examples
|
||||
|
||||
The `internal/examples/` directory contains 14 runnable examples:
|
||||
|
||||
`chatroom` · `counter` · `countercomp` · `greeter` · `keyboard` · `livereload` · `nats-chatroom` · `pathparams` · `picocss` · `plugins` · `pubsub-crud` · `realtimechart` · `session` · `shakespeare`
|
||||
|
||||
## Experimental
|
||||
|
||||
Via is maturing — sessions, CSRF, rate limiting, pub/sub, and graceful shutdown are in place — but the API is still evolving. Expect breaking changes before `v1`.
|
||||
|
||||
## Contributing
|
||||
|
||||
- Via is intentionally minimal and opinionated — and so is contributing.
|
||||
- If you love Go, simplicity, and meaningful abstractions — Come along for the ride!
|
||||
- Fork, branch, build, tinker with things, submit a pull request.
|
||||
- Fork, branch, build, tinker, submit a pull request.
|
||||
- Keep every line purposeful.
|
||||
- Share feedback: open an issue or start a discussion.
|
||||
|
||||
|
||||
## Credits
|
||||
|
||||
Via builds upon the work of these amazing projects:
|
||||
Via builds upon the work of these projects:
|
||||
|
||||
- 🚀 [Datastar](https://data-star.dev) - The hypermedia powerhouse at the core of Via. It powers browser reactivity through Signals and enables real-time HTML/Signal patches over an always-on SSE event stream.
|
||||
- 🧩 [Gomponents](https://maragu.dev/gomponents) - The awesome project that gifts Via with Go-native HTML composition superpowers through the `via/h` package.
|
||||
|
||||
> Thank you for building something that doesn’t just function — it inspires. 🫶
|
||||
- [Datastar](https://data-star.dev) — the hypermedia framework powering browser reactivity through signals and real-time HTML patches over SSE.
|
||||
- [Gomponents](https://maragu.dev/gomponents) — Go-native HTML composition that powers the `via/h` package.
|
||||
|
||||
@@ -107,6 +107,54 @@ func (a *actionTrigger) OnChange(options ...ActionTriggerOption) h.H {
|
||||
return h.Data("on:change__debounce.200ms", buildOnExpr(actionURL(a.id), &opts))
|
||||
}
|
||||
|
||||
// OnSubmit returns a via.h DOM attribute that triggers on form submit.
|
||||
func (a *actionTrigger) OnSubmit(options ...ActionTriggerOption) h.H {
|
||||
opts := applyOptions(options...)
|
||||
return h.Data("on:submit", buildOnExpr(actionURL(a.id), &opts))
|
||||
}
|
||||
|
||||
// OnInput returns a via.h DOM attribute that triggers on input (without debounce).
|
||||
func (a *actionTrigger) OnInput(options ...ActionTriggerOption) h.H {
|
||||
opts := applyOptions(options...)
|
||||
return h.Data("on:input", buildOnExpr(actionURL(a.id), &opts))
|
||||
}
|
||||
|
||||
// OnFocus returns a via.h DOM attribute that triggers when the element gains focus.
|
||||
func (a *actionTrigger) OnFocus(options ...ActionTriggerOption) h.H {
|
||||
opts := applyOptions(options...)
|
||||
return h.Data("on:focus", buildOnExpr(actionURL(a.id), &opts))
|
||||
}
|
||||
|
||||
// OnBlur returns a via.h DOM attribute that triggers when the element loses focus.
|
||||
func (a *actionTrigger) OnBlur(options ...ActionTriggerOption) h.H {
|
||||
opts := applyOptions(options...)
|
||||
return h.Data("on:blur", buildOnExpr(actionURL(a.id), &opts))
|
||||
}
|
||||
|
||||
// OnMouseEnter returns a via.h DOM attribute that triggers when the mouse enters the element.
|
||||
func (a *actionTrigger) OnMouseEnter(options ...ActionTriggerOption) h.H {
|
||||
opts := applyOptions(options...)
|
||||
return h.Data("on:mouseenter", buildOnExpr(actionURL(a.id), &opts))
|
||||
}
|
||||
|
||||
// OnMouseLeave returns a via.h DOM attribute that triggers when the mouse leaves the element.
|
||||
func (a *actionTrigger) OnMouseLeave(options ...ActionTriggerOption) h.H {
|
||||
opts := applyOptions(options...)
|
||||
return h.Data("on:mouseleave", buildOnExpr(actionURL(a.id), &opts))
|
||||
}
|
||||
|
||||
// OnScroll returns a via.h DOM attribute that triggers on scroll.
|
||||
func (a *actionTrigger) OnScroll(options ...ActionTriggerOption) h.H {
|
||||
opts := applyOptions(options...)
|
||||
return h.Data("on:scroll", buildOnExpr(actionURL(a.id), &opts))
|
||||
}
|
||||
|
||||
// OnDblClick returns a via.h DOM attribute that triggers on double click.
|
||||
func (a *actionTrigger) OnDblClick(options ...ActionTriggerOption) h.H {
|
||||
opts := applyOptions(options...)
|
||||
return h.Data("on:dblclick", buildOnExpr(actionURL(a.id), &opts))
|
||||
}
|
||||
|
||||
// OnKeyDown returns a via.h DOM attribute that triggers when a key is pressed.
|
||||
// key: optional, see https://developer.mozilla.org/en-US/docs/Web/API/KeyboardEvent/key
|
||||
// Example: OnKeyDown("Enter")
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package via
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/alexedwards/scs/v2"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
@@ -54,4 +56,14 @@ type Options struct {
|
||||
// PubSub enables publish/subscribe messaging. Use vianats.New() for an
|
||||
// embedded NATS backend, or supply any PubSub implementation.
|
||||
PubSub PubSub
|
||||
|
||||
// ContextTTL is the maximum time a context may exist without an SSE
|
||||
// connection before the background reaper disposes it.
|
||||
// Default: 30s. Negative value disables the reaper.
|
||||
ContextTTL time.Duration
|
||||
|
||||
// ActionRateLimit configures the default token-bucket rate limiter for
|
||||
// action endpoints. Zero values use built-in defaults (10 req/s, burst 20).
|
||||
// Set Rate to -1 to disable rate limiting entirely.
|
||||
ActionRateLimit RateLimitConfig
|
||||
}
|
||||
|
||||
188
context.go
188
context.go
@@ -5,12 +5,14 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"maps"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/ryanhamamura/via/h"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// Context is the living bridge between Go and the browser.
|
||||
@@ -19,20 +21,26 @@ import (
|
||||
type Context struct {
|
||||
id string
|
||||
route string
|
||||
csrfToken string
|
||||
app *V
|
||||
view func() h.H
|
||||
routeParams map[string]string
|
||||
componentRegistry map[string]*Context
|
||||
parentPageCtx *Context
|
||||
parentPageCtx *Context
|
||||
patchChan chan patch
|
||||
actionRegistry map[string]func()
|
||||
actionLimiter *rate.Limiter
|
||||
actionRegistry map[string]actionEntry
|
||||
signals *sync.Map
|
||||
mu sync.RWMutex
|
||||
navMu sync.Mutex
|
||||
ctxDisposedChan chan struct{}
|
||||
pageStopChan chan struct{}
|
||||
reqCtx context.Context
|
||||
fields []*Field
|
||||
subscriptions []Subscription
|
||||
subsMu sync.Mutex
|
||||
disposeOnce sync.Once
|
||||
createdAt time.Time
|
||||
sseConnected atomic.Bool
|
||||
}
|
||||
|
||||
// View defines the UI rendered by this context.
|
||||
@@ -43,7 +51,11 @@ func (c *Context) View(f func() h.H) {
|
||||
if f == nil {
|
||||
panic("nil viewfn")
|
||||
}
|
||||
c.view = func() h.H { return h.Div(h.ID(c.id), f()) }
|
||||
if c.app.layout != nil {
|
||||
c.view = func() h.H { return h.Div(h.ID(c.id), c.app.layout(f)) }
|
||||
} else {
|
||||
c.view = func() h.H { return h.Div(h.ID(c.id), f()) }
|
||||
}
|
||||
}
|
||||
|
||||
// Component registers a subcontext that has self contained data, actions and signals.
|
||||
@@ -75,7 +87,6 @@ func (c *Context) Component(initCtx func(c *Context)) func() h.H {
|
||||
compCtx.parentPageCtx = c
|
||||
}
|
||||
initCtx(compCtx)
|
||||
c.componentRegistry[id] = compCtx
|
||||
return compCtx.view
|
||||
}
|
||||
|
||||
@@ -100,39 +111,46 @@ func (c *Context) isComponent() bool {
|
||||
// h.Button(h.Text("Increment n"), increment.OnClick()),
|
||||
// )
|
||||
// })
|
||||
func (c *Context) Action(f func()) *actionTrigger {
|
||||
func (c *Context) Action(f func(), opts ...ActionOption) *actionTrigger {
|
||||
id := genRandID()
|
||||
if f == nil {
|
||||
c.app.logErr(c, "failed to bind action '%s' to context: nil func", id)
|
||||
return nil
|
||||
}
|
||||
|
||||
entry := actionEntry{fn: f}
|
||||
for _, opt := range opts {
|
||||
opt(&entry)
|
||||
}
|
||||
|
||||
if c.isComponent() {
|
||||
c.parentPageCtx.actionRegistry[id] = f
|
||||
c.parentPageCtx.actionRegistry[id] = entry
|
||||
} else {
|
||||
c.actionRegistry[id] = f
|
||||
c.actionRegistry[id] = entry
|
||||
}
|
||||
return &actionTrigger{id}
|
||||
}
|
||||
|
||||
func (c *Context) getActionFn(id string) (func(), error) {
|
||||
if f, ok := c.actionRegistry[id]; ok {
|
||||
return f, nil
|
||||
func (c *Context) getAction(id string) (actionEntry, error) {
|
||||
if e, ok := c.actionRegistry[id]; ok {
|
||||
return e, nil
|
||||
}
|
||||
return nil, fmt.Errorf("action '%s' not found", id)
|
||||
return actionEntry{}, fmt.Errorf("action '%s' not found", id)
|
||||
}
|
||||
|
||||
// OnInterval starts a go routine that sets a time.Ticker with the given duration and executes
|
||||
// the given handler func() on every tick. Use *Routine.UpdateInterval to update the interval.
|
||||
func (c *Context) OnInterval(duration time.Duration, handler func()) *OnIntervalRoutine {
|
||||
var cn chan struct{}
|
||||
if c.isComponent() { // components use the chan on the parent page ctx
|
||||
cn = c.parentPageCtx.ctxDisposedChan
|
||||
// OnInterval starts a goroutine that executes handler on every tick of the given duration.
|
||||
// The goroutine is tied to the context lifecycle and will stop when the context is disposed.
|
||||
// Returns a func() that stops the interval when called.
|
||||
func (c *Context) OnInterval(duration time.Duration, handler func()) func() {
|
||||
var disposeCh, pageCh chan struct{}
|
||||
if c.isComponent() {
|
||||
disposeCh = c.parentPageCtx.ctxDisposedChan
|
||||
pageCh = c.parentPageCtx.pageStopChan
|
||||
} else {
|
||||
cn = c.ctxDisposedChan
|
||||
disposeCh = c.ctxDisposedChan
|
||||
pageCh = c.pageStopChan
|
||||
}
|
||||
r := newOnIntervalRoutine(cn, duration, handler)
|
||||
return r
|
||||
return newOnInterval(disposeCh, pageCh, duration, handler)
|
||||
}
|
||||
|
||||
// Signal creates a reactive signal and initializes it with the given value.
|
||||
@@ -197,14 +215,14 @@ func (c *Context) injectSignals(sigs map[string]any) {
|
||||
defer c.mu.Unlock()
|
||||
|
||||
for sigID, val := range sigs {
|
||||
if _, ok := c.signals.Load(sigID); !ok {
|
||||
item, ok := c.signals.Load(sigID)
|
||||
if !ok {
|
||||
c.signals.Store(sigID, &signal{
|
||||
id: sigID,
|
||||
val: val,
|
||||
})
|
||||
continue
|
||||
}
|
||||
item, _ := c.signals.Load(sigID)
|
||||
if sig, ok := item.(*signal); ok {
|
||||
sig.val = val
|
||||
sig.changed = false
|
||||
@@ -255,15 +273,22 @@ func (c *Context) sendPatch(p patch) {
|
||||
// Sync pushes the current view state and signal changes to the browser immediately
|
||||
// over the live SSE event stream.
|
||||
func (c *Context) Sync() {
|
||||
elemsPatch := bytes.NewBuffer(make([]byte, 0))
|
||||
c.syncView(false)
|
||||
}
|
||||
|
||||
func (c *Context) syncView(viewTransition bool) {
|
||||
elemsPatch := new(bytes.Buffer)
|
||||
if err := c.view().Render(elemsPatch); err != nil {
|
||||
c.app.logErr(c, "sync view failed: %v", err)
|
||||
return
|
||||
}
|
||||
c.sendPatch(patch{patchTypeElements, elemsPatch.String()})
|
||||
typ := patchType(patchTypeElements)
|
||||
if viewTransition {
|
||||
typ = patchTypeElementsWithVT
|
||||
}
|
||||
c.sendPatch(patch{typ, elemsPatch.String()})
|
||||
|
||||
updatedSigs := c.prepareSignalsForPatch()
|
||||
|
||||
if len(updatedSigs) != 0 {
|
||||
outgoingSigs, _ := json.Marshal(updatedSigs)
|
||||
c.sendPatch(patch{patchTypeSignals, string(outgoingSigs)})
|
||||
@@ -320,6 +345,15 @@ func (c *Context) ExecScript(s string) {
|
||||
c.sendPatch(patch{patchTypeScript, s})
|
||||
}
|
||||
|
||||
// RedirectView sets a view that redirects the browser to the given URL.
|
||||
// Use this in middleware to abort the chain and redirect in one step.
|
||||
func (c *Context) RedirectView(url string) {
|
||||
c.View(func() h.H {
|
||||
c.Redirect(url)
|
||||
return h.Div()
|
||||
})
|
||||
}
|
||||
|
||||
// Redirect navigates the browser to the given URL.
|
||||
// This triggers a full page navigation - the current context will be disposed
|
||||
// and a new context created at the destination URL.
|
||||
@@ -351,6 +385,46 @@ func (c *Context) ReplaceURLf(format string, a ...any) {
|
||||
c.ReplaceURL(fmt.Sprintf(format, a...))
|
||||
}
|
||||
|
||||
// resetPageState tears down page-specific state (intervals, subscriptions,
|
||||
// actions, signals, fields) without disposing the context itself. The SSE
|
||||
// connection and context lifetime are unaffected.
|
||||
func (c *Context) resetPageState() {
|
||||
close(c.pageStopChan)
|
||||
c.unsubscribeAll()
|
||||
c.mu.Lock()
|
||||
c.actionRegistry = make(map[string]actionEntry)
|
||||
c.signals = new(sync.Map)
|
||||
c.fields = nil
|
||||
c.pageStopChan = make(chan struct{})
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// Navigate performs an SPA navigation to the given path. It resets page state,
|
||||
// runs the target page's init function (with middleware), and pushes the new
|
||||
// view over the existing SSE connection with a view transition animation.
|
||||
// If popstate is true, replaceState is used instead of pushState.
|
||||
func (c *Context) Navigate(path string, popstate bool) {
|
||||
c.navMu.Lock()
|
||||
defer c.navMu.Unlock()
|
||||
|
||||
route, initFn, params := c.app.matchRoute(path)
|
||||
if initFn == nil {
|
||||
c.Redirect(path)
|
||||
return
|
||||
}
|
||||
c.resetPageState()
|
||||
c.route = route
|
||||
c.injectRouteParams(params)
|
||||
initFn(c)
|
||||
c.syncView(true)
|
||||
safe := strings.NewReplacer(`\`, `\\`, `'`, `\'`).Replace(path)
|
||||
if popstate {
|
||||
c.ExecScript(fmt.Sprintf("history.replaceState({},'','%s')", safe))
|
||||
} else {
|
||||
c.ExecScript(fmt.Sprintf("history.pushState({},'','%s')", safe))
|
||||
}
|
||||
}
|
||||
|
||||
// dispose idempotently tears down this context: unsubscribes all pubsub
|
||||
// subscriptions and closes ctxDisposedChan to stop routines and exit the SSE loop.
|
||||
func (c *Context) dispose() {
|
||||
@@ -361,7 +435,7 @@ func (c *Context) dispose() {
|
||||
}
|
||||
|
||||
// stopAllRoutines closes ctxDisposedChan, broadcasting to all listening
|
||||
// goroutines (OnIntervalRoutine, SSE loop) that this context is done.
|
||||
// goroutines (OnInterval, SSE loop) that this context is done.
|
||||
func (c *Context) stopAllRoutines() {
|
||||
select {
|
||||
case <-c.ctxDisposedChan:
|
||||
@@ -375,12 +449,9 @@ func (c *Context) injectRouteParams(params map[string]string) {
|
||||
if params == nil {
|
||||
return
|
||||
}
|
||||
m := make(map[string]string)
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
maps.Copy(m, params)
|
||||
c.routeParams = m
|
||||
|
||||
c.routeParams = params
|
||||
}
|
||||
|
||||
// GetPathParam retrieves the value from the page request URL for the given parameter name
|
||||
@@ -466,6 +537,50 @@ func (c *Context) unsubscribeAll() {
|
||||
}
|
||||
}
|
||||
|
||||
// Field creates a signal with validation rules attached.
|
||||
// The initial value seeds both the signal and the reset target.
|
||||
// The field is tracked on the context so ValidateAll/ResetFields
|
||||
// can operate on all fields by default.
|
||||
func (c *Context) Field(initial any, rules ...Rule) *Field {
|
||||
f := &Field{
|
||||
signal: c.Signal(initial),
|
||||
rules: rules,
|
||||
initialVal: initial,
|
||||
}
|
||||
target := c
|
||||
if c.isComponent() {
|
||||
target = c.parentPageCtx
|
||||
}
|
||||
target.fields = append(target.fields, f)
|
||||
return f
|
||||
}
|
||||
|
||||
// ValidateAll runs Validate on each field, returning true only if all pass.
|
||||
// With no arguments it validates every field tracked on this context.
|
||||
func (c *Context) ValidateAll(fields ...*Field) bool {
|
||||
if len(fields) == 0 {
|
||||
fields = c.fields
|
||||
}
|
||||
ok := true
|
||||
for _, f := range fields {
|
||||
if !f.Validate() {
|
||||
ok = false
|
||||
}
|
||||
}
|
||||
return ok
|
||||
}
|
||||
|
||||
// ResetFields resets each field to its initial value and clears errors.
|
||||
// With no arguments it resets every field tracked on this context.
|
||||
func (c *Context) ResetFields(fields ...*Field) {
|
||||
if len(fields) == 0 {
|
||||
fields = c.fields
|
||||
}
|
||||
for _, f := range fields {
|
||||
f.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
func newContext(id string, route string, v *V) *Context {
|
||||
if v == nil {
|
||||
panic("create context failed: app pointer is nil")
|
||||
@@ -474,12 +589,15 @@ func newContext(id string, route string, v *V) *Context {
|
||||
return &Context{
|
||||
id: id,
|
||||
route: route,
|
||||
csrfToken: genCSRFToken(),
|
||||
routeParams: make(map[string]string),
|
||||
app: v,
|
||||
componentRegistry: make(map[string]*Context),
|
||||
actionRegistry: make(map[string]func()),
|
||||
actionLimiter: newLimiter(v.actionRateLimit, defaultActionRate, defaultActionBurst),
|
||||
actionRegistry: make(map[string]actionEntry),
|
||||
signals: new(sync.Map),
|
||||
patchChan: make(chan patch, 1),
|
||||
patchChan: make(chan patch, 8),
|
||||
ctxDisposedChan: make(chan struct{}, 1),
|
||||
pageStopChan: make(chan struct{}),
|
||||
createdAt: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
58
field.go
Normal file
58
field.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package via
|
||||
|
||||
// Field is a signal with built-in validation rules and error state.
|
||||
// It embeds *signal, so all signal methods (Bind, String, Int, Bool, SetValue, Text, ID)
|
||||
// work transparently.
|
||||
type Field struct {
|
||||
*signal
|
||||
rules []Rule
|
||||
errors []string
|
||||
initialVal any
|
||||
}
|
||||
|
||||
// Validate runs all rules against the current value.
|
||||
// Clears previous errors, populates new ones, returns true if all rules pass.
|
||||
func (f *Field) Validate() bool {
|
||||
f.errors = nil
|
||||
val := f.String()
|
||||
for _, r := range f.rules {
|
||||
if err := r.validate(val); err != nil {
|
||||
f.errors = append(f.errors, err.Error())
|
||||
}
|
||||
}
|
||||
return len(f.errors) == 0
|
||||
}
|
||||
|
||||
// HasError returns true if this field has any validation errors.
|
||||
func (f *Field) HasError() bool {
|
||||
return len(f.errors) > 0
|
||||
}
|
||||
|
||||
// FirstError returns the first validation error message, or "" if valid.
|
||||
func (f *Field) FirstError() string {
|
||||
if len(f.errors) > 0 {
|
||||
return f.errors[0]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Errors returns all current validation error messages.
|
||||
func (f *Field) Errors() []string {
|
||||
return f.errors
|
||||
}
|
||||
|
||||
// AddError manually adds an error message (useful for server-side or cross-field validation).
|
||||
func (f *Field) AddError(msg string) {
|
||||
f.errors = append(f.errors, msg)
|
||||
}
|
||||
|
||||
// ClearErrors removes all validation errors from this field.
|
||||
func (f *Field) ClearErrors() {
|
||||
f.errors = nil
|
||||
}
|
||||
|
||||
// Reset restores the field value to its initial value and clears all errors.
|
||||
func (f *Field) Reset() {
|
||||
f.SetValue(f.initialVal)
|
||||
f.errors = nil
|
||||
}
|
||||
206
field_test.go
Normal file
206
field_test.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package via
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/ryanhamamura/via/h"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func newTestField(initial any, rules ...Rule) *Field {
|
||||
v := New()
|
||||
var f *Field
|
||||
v.Page("/", func(c *Context) {
|
||||
f = c.Field(initial, rules...)
|
||||
c.View(func() h.H { return h.Div() })
|
||||
})
|
||||
return f
|
||||
}
|
||||
|
||||
func TestFieldCreation(t *testing.T) {
|
||||
f := newTestField("hello", Required())
|
||||
assert.Equal(t, "hello", f.String())
|
||||
assert.NotEmpty(t, f.ID())
|
||||
}
|
||||
|
||||
func TestFieldSignalDelegation(t *testing.T) {
|
||||
f := newTestField(42)
|
||||
assert.Equal(t, "42", f.String())
|
||||
assert.Equal(t, 42, f.Int())
|
||||
|
||||
f.SetValue("new")
|
||||
assert.Equal(t, "new", f.String())
|
||||
|
||||
// Bind returns an h.H element
|
||||
assert.NotNil(t, f.Bind())
|
||||
}
|
||||
|
||||
func TestFieldValidateSingleRule(t *testing.T) {
|
||||
f := newTestField("", Required())
|
||||
assert.False(t, f.Validate())
|
||||
assert.True(t, f.HasError())
|
||||
assert.Equal(t, "This field is required", f.FirstError())
|
||||
|
||||
f.SetValue("ok")
|
||||
assert.True(t, f.Validate())
|
||||
assert.False(t, f.HasError())
|
||||
assert.Equal(t, "", f.FirstError())
|
||||
}
|
||||
|
||||
func TestFieldValidateMultipleRules(t *testing.T) {
|
||||
f := newTestField("ab", Required(), MinLen(3))
|
||||
assert.False(t, f.Validate())
|
||||
errs := f.Errors()
|
||||
assert.Len(t, errs, 1)
|
||||
assert.Equal(t, "Must be at least 3 characters", errs[0])
|
||||
|
||||
f.SetValue("")
|
||||
assert.False(t, f.Validate())
|
||||
errs = f.Errors()
|
||||
assert.Len(t, errs, 2)
|
||||
}
|
||||
|
||||
func TestFieldErrors(t *testing.T) {
|
||||
f := newTestField("")
|
||||
assert.Nil(t, f.Errors())
|
||||
assert.False(t, f.HasError())
|
||||
assert.Equal(t, "", f.FirstError())
|
||||
}
|
||||
|
||||
func TestFieldAddError(t *testing.T) {
|
||||
f := newTestField("ok")
|
||||
f.AddError("username taken")
|
||||
assert.True(t, f.HasError())
|
||||
assert.Equal(t, "username taken", f.FirstError())
|
||||
assert.Len(t, f.Errors(), 1)
|
||||
}
|
||||
|
||||
func TestFieldClearErrors(t *testing.T) {
|
||||
f := newTestField("", Required())
|
||||
f.Validate()
|
||||
assert.True(t, f.HasError())
|
||||
f.ClearErrors()
|
||||
assert.False(t, f.HasError())
|
||||
}
|
||||
|
||||
func TestFieldReset(t *testing.T) {
|
||||
f := newTestField("initial", Required(), MinLen(3))
|
||||
f.SetValue("changed")
|
||||
f.AddError("some error")
|
||||
|
||||
f.Reset()
|
||||
assert.Equal(t, "initial", f.String())
|
||||
assert.False(t, f.HasError())
|
||||
}
|
||||
|
||||
func TestValidateAll(t *testing.T) {
|
||||
v := New()
|
||||
v.Page("/", func(c *Context) {
|
||||
c.Field("", Required(), MinLen(3))
|
||||
c.Field("", Required(), Email())
|
||||
c.View(func() h.H { return h.Div() })
|
||||
|
||||
// both empty → both fail
|
||||
assert.False(t, c.ValidateAll())
|
||||
})
|
||||
|
||||
v2 := New()
|
||||
v2.Page("/", func(c *Context) {
|
||||
c.Field("joe", Required(), MinLen(3))
|
||||
c.Field("joe@x.com", Required(), Email())
|
||||
c.View(func() h.H { return h.Div() })
|
||||
|
||||
assert.True(t, c.ValidateAll())
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidateAllPartialFailure(t *testing.T) {
|
||||
v := New()
|
||||
v.Page("/", func(c *Context) {
|
||||
good := c.Field("valid", Required())
|
||||
bad := c.Field("", Required())
|
||||
c.View(func() h.H { return h.Div() })
|
||||
|
||||
ok := c.ValidateAll()
|
||||
assert.False(t, ok)
|
||||
assert.False(t, good.HasError())
|
||||
assert.True(t, bad.HasError())
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidateAllSelectiveArgs(t *testing.T) {
|
||||
v := New()
|
||||
v.Page("/", func(c *Context) {
|
||||
a := c.Field("", Required())
|
||||
b := c.Field("ok", Required())
|
||||
cField := c.Field("", Required())
|
||||
c.View(func() h.H { return h.Div() })
|
||||
|
||||
// only validate a and b — cField should be untouched
|
||||
ok := c.ValidateAll(a, b)
|
||||
assert.False(t, ok)
|
||||
assert.True(t, a.HasError())
|
||||
assert.False(t, b.HasError())
|
||||
assert.False(t, cField.HasError(), "unselected field should not be validated")
|
||||
})
|
||||
}
|
||||
|
||||
func TestResetFields(t *testing.T) {
|
||||
v := New()
|
||||
v.Page("/", func(c *Context) {
|
||||
a := c.Field("a", Required())
|
||||
b := c.Field("b", Required())
|
||||
c.View(func() h.H { return h.Div() })
|
||||
|
||||
a.SetValue("changed-a")
|
||||
b.SetValue("changed-b")
|
||||
a.AddError("err")
|
||||
|
||||
c.ResetFields()
|
||||
assert.Equal(t, "a", a.String())
|
||||
assert.Equal(t, "b", b.String())
|
||||
assert.False(t, a.HasError())
|
||||
})
|
||||
}
|
||||
|
||||
func TestResetFieldsSelectiveArgs(t *testing.T) {
|
||||
v := New()
|
||||
v.Page("/", func(c *Context) {
|
||||
a := c.Field("a")
|
||||
b := c.Field("b")
|
||||
c.View(func() h.H { return h.Div() })
|
||||
|
||||
a.SetValue("changed-a")
|
||||
b.SetValue("changed-b")
|
||||
|
||||
// only reset a
|
||||
c.ResetFields(a)
|
||||
assert.Equal(t, "a", a.String())
|
||||
assert.Equal(t, "changed-b", b.String(), "unselected field should not be reset")
|
||||
})
|
||||
}
|
||||
|
||||
func TestFieldValidateClearsPreviousErrors(t *testing.T) {
|
||||
f := newTestField("", Required())
|
||||
f.Validate()
|
||||
assert.True(t, f.HasError())
|
||||
|
||||
f.SetValue("ok")
|
||||
f.Validate()
|
||||
assert.False(t, f.HasError())
|
||||
}
|
||||
|
||||
func TestFieldCustomValidator(t *testing.T) {
|
||||
f := newTestField("bad", Custom(func(val string) error {
|
||||
if val == "bad" {
|
||||
return fmt.Errorf("no bad words")
|
||||
}
|
||||
return nil
|
||||
}))
|
||||
assert.False(t, f.Validate())
|
||||
assert.Equal(t, "no bad words", f.FirstError())
|
||||
|
||||
f.SetValue("good")
|
||||
assert.True(t, f.Validate())
|
||||
}
|
||||
2
go.mod
2
go.mod
@@ -14,6 +14,7 @@ require (
|
||||
github.com/rs/zerolog v1.34.0
|
||||
github.com/starfederation/datastar-go v1.0.3
|
||||
github.com/stretchr/testify v1.11.1
|
||||
golang.org/x/time v0.14.0
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -37,6 +38,5 @@ require (
|
||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||
golang.org/x/crypto v0.45.0 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/time v0.14.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
@@ -11,3 +11,11 @@ func DataEffect(expression string) H {
|
||||
func DataIgnoreMorph() H {
|
||||
return Attr("data-ignore-morph")
|
||||
}
|
||||
|
||||
// DataViewTransition sets the view-transition-name CSS property on an element
|
||||
// via an inline style. Elements with matching names animate between pages
|
||||
// during SPA navigation. If the element also needs other inline styles,
|
||||
// include view-transition-name directly in the Style() call instead.
|
||||
func DataViewTransition(name string) H {
|
||||
return Attr("style", "view-transition-name: "+name)
|
||||
}
|
||||
|
||||
151
internal/examples/middleware/main.go
Normal file
151
internal/examples/middleware/main.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/ryanhamamura/via"
|
||||
"github.com/ryanhamamura/via/h"
|
||||
)
|
||||
|
||||
func main() {
|
||||
v := via.New()
|
||||
v.Config(via.Options{
|
||||
ServerAddress: ":8080",
|
||||
DocumentTitle: "Middleware Example",
|
||||
})
|
||||
|
||||
// --- Middleware definitions ---
|
||||
|
||||
// requestLogger logs every page request to stdout.
|
||||
requestLogger := func(c *via.Context, next func()) {
|
||||
fmt.Printf("[%s] request\n", time.Now().Format("15:04:05"))
|
||||
next()
|
||||
}
|
||||
|
||||
// authRequired redirects unauthenticated users to /login.
|
||||
authRequired := func(c *via.Context, next func()) {
|
||||
if c.Session().GetString("role") == "" {
|
||||
c.RedirectView("/login")
|
||||
return
|
||||
}
|
||||
next()
|
||||
}
|
||||
|
||||
// auditLog prints the authenticated username to stdout.
|
||||
auditLog := func(c *via.Context, next func()) {
|
||||
fmt.Printf("[audit] user=%s\n", c.Session().GetString("username"))
|
||||
next()
|
||||
}
|
||||
|
||||
// superAdminOnly rejects non-superadmin users with a forbidden view.
|
||||
superAdminOnly := func(c *via.Context, next func()) {
|
||||
if c.Session().GetString("role") != "superadmin" {
|
||||
c.View(func() h.H {
|
||||
return h.Div(
|
||||
h.H1(h.Text("Forbidden")),
|
||||
h.P(h.Text("Super-admin access required.")),
|
||||
h.A(h.Href("/admin/dashboard"), h.Text("Back to dashboard")),
|
||||
)
|
||||
})
|
||||
return
|
||||
}
|
||||
next()
|
||||
}
|
||||
|
||||
// --- Route registration ---
|
||||
|
||||
v.Use(requestLogger) // global middleware
|
||||
|
||||
admin := v.Group("/admin", authRequired) // prefixed group
|
||||
admin.Use(auditLog) // Group.Use()
|
||||
superAdmin := admin.Group("/super", superAdminOnly) // nested group
|
||||
|
||||
// Public: redirect root to login
|
||||
v.Page("/", func(c *via.Context) {
|
||||
c.View(func() h.H {
|
||||
c.Redirect("/login")
|
||||
return h.Div()
|
||||
})
|
||||
})
|
||||
|
||||
// Public: login page with role-selection buttons
|
||||
v.Page("/login", func(c *via.Context) {
|
||||
loginAdmin := c.Action(func() {
|
||||
c.Session().Set("role", "admin")
|
||||
c.Session().Set("username", "alice")
|
||||
c.Session().RenewToken()
|
||||
c.Redirect("/admin/dashboard")
|
||||
})
|
||||
|
||||
loginSuper := c.Action(func() {
|
||||
c.Session().Set("role", "superadmin")
|
||||
c.Session().Set("username", "bob")
|
||||
c.Session().RenewToken()
|
||||
c.Redirect("/admin/dashboard")
|
||||
})
|
||||
|
||||
c.View(func() h.H {
|
||||
return h.Div(
|
||||
h.H1(h.Text("Login")),
|
||||
h.P(h.Text("Choose a role:")),
|
||||
h.Button(h.Text("Login as Admin"), loginAdmin.OnClick()),
|
||||
h.Raw(" "),
|
||||
h.Button(h.Text("Login as Super Admin"), loginSuper.OnClick()),
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
// Per-action middleware: only superadmins can invoke this action.
|
||||
requireSuperAdmin := func(c *via.Context, next func()) {
|
||||
if c.Session().GetString("role") != "superadmin" {
|
||||
return
|
||||
}
|
||||
next()
|
||||
}
|
||||
|
||||
// Admin: dashboard (requires authRequired + auditLog)
|
||||
admin.Page("/dashboard", func(c *via.Context) {
|
||||
logout := c.Action(func() {
|
||||
c.Session().Delete("role")
|
||||
c.Session().Delete("username")
|
||||
c.Redirect("/login")
|
||||
})
|
||||
|
||||
dangerAction := c.Action(func() {
|
||||
fmt.Printf("[danger] executed by %s\n", c.Session().GetString("username"))
|
||||
c.Sync()
|
||||
}, via.WithMiddleware(requireSuperAdmin))
|
||||
|
||||
c.View(func() h.H {
|
||||
username := c.Session().GetString("username")
|
||||
role := c.Session().GetString("role")
|
||||
return h.Div(
|
||||
h.H1(h.Textf("Dashboard — %s (%s)", username, role)),
|
||||
h.Ul(
|
||||
h.Li(h.A(h.Href("/admin/super/settings"), h.Text("Super Admin Settings"))),
|
||||
),
|
||||
h.H2(h.Text("Danger Zone")),
|
||||
h.P(h.Text("This action is protected by per-action middleware (superadmin only):")),
|
||||
h.Button(h.Text("Delete Everything"), dangerAction.OnClick()),
|
||||
h.Br(),
|
||||
h.Br(),
|
||||
h.Button(h.Text("Logout"), logout.OnClick()),
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
// Super-admin: settings (requires authRequired + auditLog + superAdminOnly)
|
||||
superAdmin.Page("/settings", func(c *via.Context) {
|
||||
c.View(func() h.H {
|
||||
username := c.Session().GetString("username")
|
||||
return h.Div(
|
||||
h.H1(h.Textf("Super Admin Settings — %s", username)),
|
||||
h.P(h.Text("Only super-admins can see this page.")),
|
||||
h.A(h.Href("/admin/dashboard"), h.Text("Back to dashboard")),
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
v.Start()
|
||||
}
|
||||
@@ -1,7 +1,6 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"math/rand"
|
||||
"sync"
|
||||
@@ -9,7 +8,6 @@ import (
|
||||
|
||||
"github.com/ryanhamamura/via"
|
||||
"github.com/ryanhamamura/via/h"
|
||||
"github.com/ryanhamamura/via/vianats"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -36,15 +34,15 @@ func (u *UserInfo) Avatar() h.H {
|
||||
var roomNames = []string{"Go", "Rust", "Python", "JavaScript", "Clojure"}
|
||||
|
||||
func main() {
|
||||
ctx := context.Background()
|
||||
v := via.New()
|
||||
v.Config(via.Options{
|
||||
DevMode: true,
|
||||
DocumentTitle: "NATS Chat",
|
||||
LogLevel: via.LogLevelInfo,
|
||||
ServerAddress: ":7331",
|
||||
})
|
||||
|
||||
ps, err := vianats.New(ctx, "./data/nats")
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to start embedded NATS: %v", err)
|
||||
}
|
||||
defer ps.Close()
|
||||
|
||||
err = vianats.EnsureStream(ps, vianats.StreamConfig{
|
||||
err := via.EnsureStream(v, via.StreamConfig{
|
||||
Name: "CHAT",
|
||||
Subjects: []string{"chat.>"},
|
||||
MaxMsgs: 1000,
|
||||
@@ -54,15 +52,6 @@ func main() {
|
||||
log.Fatalf("Failed to ensure stream: %v", err)
|
||||
}
|
||||
|
||||
v := via.New()
|
||||
v.Config(via.Options{
|
||||
DevMode: true,
|
||||
DocumentTitle: "NATS Chat",
|
||||
LogLevel: via.LogLevelInfo,
|
||||
ServerAddress: ":7331",
|
||||
PubSub: ps,
|
||||
})
|
||||
|
||||
v.AppendToHead(
|
||||
h.Link(h.Rel("stylesheet"), h.Href("https://cdn.jsdelivr.net/npm/@picocss/pico@2/css/pico.min.css")),
|
||||
h.StyleEl(h.Raw(`
|
||||
@@ -148,7 +137,7 @@ func main() {
|
||||
subject := "chat.room." + room
|
||||
|
||||
// Replay history from JetStream
|
||||
if hist, err := vianats.ReplayHistory[ChatMessage](ps, subject, 50); err == nil {
|
||||
if hist, err := via.ReplayHistory[ChatMessage](v, subject, 50); err == nil {
|
||||
messages = hist
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"html"
|
||||
@@ -11,7 +10,6 @@ import (
|
||||
|
||||
"github.com/ryanhamamura/via"
|
||||
"github.com/ryanhamamura/via/h"
|
||||
"github.com/ryanhamamura/via/vianats"
|
||||
)
|
||||
|
||||
var WithSignal = via.WithSignal
|
||||
@@ -49,15 +47,15 @@ func findBookmark(id string) (Bookmark, int) {
|
||||
}
|
||||
|
||||
func main() {
|
||||
ctx := context.Background()
|
||||
v := via.New()
|
||||
v.Config(via.Options{
|
||||
DevMode: true,
|
||||
DocumentTitle: "Bookmarks",
|
||||
LogLevel: via.LogLevelInfo,
|
||||
ServerAddress: ":7331",
|
||||
})
|
||||
|
||||
ps, err := vianats.New(ctx, "./data/nats")
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to start embedded NATS: %v", err)
|
||||
}
|
||||
defer ps.Close()
|
||||
|
||||
err = vianats.EnsureStream(ps, vianats.StreamConfig{
|
||||
err := via.EnsureStream(v, via.StreamConfig{
|
||||
Name: "BOOKMARKS",
|
||||
Subjects: []string{"bookmarks.>"},
|
||||
MaxMsgs: 1000,
|
||||
@@ -67,15 +65,6 @@ func main() {
|
||||
log.Fatalf("Failed to ensure stream: %v", err)
|
||||
}
|
||||
|
||||
v := via.New()
|
||||
v.Config(via.Options{
|
||||
DevMode: true,
|
||||
DocumentTitle: "Bookmarks",
|
||||
LogLevel: via.LogLevelInfo,
|
||||
ServerAddress: ":7331",
|
||||
PubSub: ps,
|
||||
})
|
||||
|
||||
v.AppendToHead(
|
||||
h.Link(h.Rel("stylesheet"), h.Href("https://cdn.jsdelivr.net/npm/daisyui@4/dist/full.min.css")),
|
||||
h.Script(h.Src("https://cdn.tailwindcss.com")),
|
||||
|
||||
@@ -37,29 +37,33 @@ func main() {
|
||||
return 1000 / time.Duration(refreshRate.Int()) * time.Millisecond
|
||||
}
|
||||
|
||||
updateData := c.OnInterval(computedTickDuration(), func() {
|
||||
ts := time.Now().UnixMilli()
|
||||
val := rand.ExpFloat64() * 10
|
||||
var stopUpdate func()
|
||||
startInterval := func() {
|
||||
stopUpdate = c.OnInterval(computedTickDuration(), func() {
|
||||
ts := time.Now().UnixMilli()
|
||||
val := rand.ExpFloat64() * 10
|
||||
|
||||
c.ExecScript(fmt.Sprintf(`
|
||||
if (myChart) {
|
||||
myChart.appendData({seriesIndex: 0, data: [[%d, %f]]});
|
||||
myChart.setOption({},{notMerge:false,lazyUpdate:true});
|
||||
};
|
||||
`, ts, val))
|
||||
})
|
||||
updateData.Start()
|
||||
c.ExecScript(fmt.Sprintf(`
|
||||
if (myChart) {
|
||||
myChart.appendData({seriesIndex: 0, data: [[%d, %f]]});
|
||||
myChart.setOption({},{notMerge:false,lazyUpdate:true});
|
||||
};
|
||||
`, ts, val))
|
||||
})
|
||||
}
|
||||
startInterval()
|
||||
|
||||
updateRefreshRate := c.Action(func() {
|
||||
updateData.UpdateInterval(computedTickDuration())
|
||||
stopUpdate()
|
||||
startInterval()
|
||||
})
|
||||
|
||||
toggleIsLive := c.Action(func() {
|
||||
isLive = isLiveSig.Bool()
|
||||
if isLive {
|
||||
updateData.Start()
|
||||
startInterval()
|
||||
} else {
|
||||
updateData.Stop()
|
||||
stopUpdate()
|
||||
}
|
||||
})
|
||||
c.View(func() h.H {
|
||||
|
||||
@@ -29,7 +29,17 @@ func main() {
|
||||
SessionManager: sm,
|
||||
})
|
||||
|
||||
// Login page
|
||||
// Auth middleware — redirects unauthenticated users to /login
|
||||
authRequired := func(c *via.Context, next func()) {
|
||||
if c.Session().GetString("username") == "" {
|
||||
c.Session().Set("flash", "Please log in first")
|
||||
c.RedirectView("/login")
|
||||
return
|
||||
}
|
||||
next()
|
||||
}
|
||||
|
||||
// Login page (public)
|
||||
v.Page("/login", func(c *via.Context) {
|
||||
flash := c.Session().PopString("flash")
|
||||
usernameInput := c.Signal("")
|
||||
@@ -64,8 +74,10 @@ func main() {
|
||||
})
|
||||
})
|
||||
|
||||
// Dashboard page (protected)
|
||||
v.Page("/dashboard", func(c *via.Context) {
|
||||
// Protected pages
|
||||
protected := v.Group("", authRequired)
|
||||
|
||||
protected.Page("/dashboard", func(c *via.Context) {
|
||||
logout := c.Action(func() {
|
||||
c.Session().Set("flash", "Goodbye!")
|
||||
c.Session().Delete("username")
|
||||
@@ -74,14 +86,6 @@ func main() {
|
||||
|
||||
c.View(func() h.H {
|
||||
username := c.Session().GetString("username")
|
||||
|
||||
// Not logged in? Redirect to login
|
||||
if username == "" {
|
||||
c.Session().Set("flash", "Please log in first")
|
||||
c.Redirect("/login")
|
||||
return h.Div()
|
||||
}
|
||||
|
||||
flash := c.Session().PopString("flash")
|
||||
var flashMsg h.H
|
||||
if flash != "" {
|
||||
|
||||
87
internal/examples/signup/main.go
Normal file
87
internal/examples/signup/main.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/ryanhamamura/via"
|
||||
"github.com/ryanhamamura/via/h"
|
||||
)
|
||||
|
||||
func main() {
|
||||
v := via.New()
|
||||
v.Config(via.Options{
|
||||
DocumentTitle: "Signup",
|
||||
ServerAddress: ":8080",
|
||||
})
|
||||
|
||||
v.AppendToHead(h.StyleEl(h.Raw(`
|
||||
body { font-family: system-ui, sans-serif; max-width: 420px; margin: 2rem auto; padding: 0 1rem; }
|
||||
label { display: block; font-weight: 600; margin-top: 1rem; }
|
||||
input { display: block; width: 100%; padding: 0.4rem; margin-top: 0.25rem; box-sizing: border-box; }
|
||||
.error { color: #c00; font-size: 0.85rem; margin-top: 0.2rem; }
|
||||
.success { color: #080; margin-top: 1rem; }
|
||||
.actions { margin-top: 1.5rem; display: flex; gap: 0.5rem; }
|
||||
`)))
|
||||
|
||||
v.Page("/", func(c *via.Context) {
|
||||
username := c.Field("", via.Required(), via.MinLen(3), via.MaxLen(20))
|
||||
email := c.Field("", via.Required(), via.Email())
|
||||
age := c.Field("", via.Required(), via.Min(13), via.Max(120))
|
||||
// Optional field — only validated when non-empty
|
||||
website := c.Field("", via.Pattern(`^$|^https?://\S+$`, "Must be a valid URL"))
|
||||
|
||||
var success string
|
||||
|
||||
signup := c.Action(func() {
|
||||
success = ""
|
||||
if !c.ValidateAll() {
|
||||
c.Sync()
|
||||
return
|
||||
}
|
||||
// Server-side check
|
||||
if username.String() == "admin" {
|
||||
username.AddError("Username is already taken")
|
||||
c.Sync()
|
||||
return
|
||||
}
|
||||
success = "Account created for " + username.String() + "!"
|
||||
c.ResetFields()
|
||||
c.Sync()
|
||||
})
|
||||
|
||||
reset := c.Action(func() {
|
||||
success = ""
|
||||
c.ResetFields()
|
||||
c.Sync()
|
||||
})
|
||||
|
||||
c.View(func() h.H {
|
||||
return h.Div(
|
||||
h.H1(h.Text("Sign Up")),
|
||||
|
||||
h.Label(h.Text("Username")),
|
||||
h.Input(h.Type("text"), h.Placeholder("pick a username"), username.Bind()),
|
||||
h.If(username.HasError(), h.Div(h.Class("error"), h.Text(username.FirstError()))),
|
||||
|
||||
h.Label(h.Text("Email")),
|
||||
h.Input(h.Type("email"), h.Placeholder("you@example.com"), email.Bind()),
|
||||
h.If(email.HasError(), h.Div(h.Class("error"), h.Text(email.FirstError()))),
|
||||
|
||||
h.Label(h.Text("Age")),
|
||||
h.Input(h.Type("number"), h.Placeholder("your age"), age.Bind()),
|
||||
h.If(age.HasError(), h.Div(h.Class("error"), h.Text(age.FirstError()))),
|
||||
|
||||
h.Label(h.Text("Website (optional)")),
|
||||
h.Input(h.Type("url"), h.Placeholder("https://example.com"), website.Bind()),
|
||||
h.If(website.HasError(), h.Div(h.Class("error"), h.Text(website.FirstError()))),
|
||||
|
||||
h.Div(h.Class("actions"),
|
||||
h.Button(h.Text("Sign Up"), signup.OnClick()),
|
||||
h.Button(h.Text("Reset"), reset.OnClick()),
|
||||
),
|
||||
|
||||
h.If(success != "", h.P(h.Class("success"), h.Text(success))),
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
v.Start()
|
||||
}
|
||||
91
internal/examples/spa/main.go
Normal file
91
internal/examples/spa/main.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/ryanhamamura/via"
|
||||
. "github.com/ryanhamamura/via/h"
|
||||
)
|
||||
|
||||
func main() {
|
||||
v := via.New()
|
||||
v.Config(via.Options{
|
||||
DocumentTitle: "SPA Navigation",
|
||||
ServerAddress: ":7331",
|
||||
})
|
||||
|
||||
v.AppendToHead(
|
||||
Raw(`<link rel="preconnect" href="https://fonts.googleapis.com">`),
|
||||
Raw(`<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>`),
|
||||
Raw(`<link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;600;700&display=swap" rel="stylesheet">`),
|
||||
Raw(`<style>body{font-family:'Inter',sans-serif;margin:0;background:#111;color:#eee}</style>`),
|
||||
)
|
||||
|
||||
v.Layout(func(content func() H) H {
|
||||
return Div(
|
||||
Nav(
|
||||
Style("display:flex;gap:1rem;padding:1rem;background:#222;"),
|
||||
A(Href("/"), Text("Home"), Style("color:#fff")),
|
||||
A(Href("/counter"), Text("Counter"), Style("color:#fff")),
|
||||
A(Href("/clock"), Text("Clock"), Style("color:#fff")),
|
||||
A(Href("https://github.com"), Text("GitHub (external)"), Style("color:#888")),
|
||||
A(Href("/"), Text("Full Reload"), Attr("data-via-no-boost"), Style("color:#f88")),
|
||||
),
|
||||
Main(Style("padding:1rem"), content()),
|
||||
)
|
||||
})
|
||||
|
||||
// Home page
|
||||
v.Page("/", func(c *via.Context) {
|
||||
goCounter := c.Action(func() { c.Navigate("/counter", false) })
|
||||
|
||||
c.View(func() H {
|
||||
return Div(
|
||||
H1(Text("Home"), DataViewTransition("page-title")),
|
||||
P(Text("Click the nav links above — no page reload, no white flash.")),
|
||||
P(Text("Or navigate programmatically:")),
|
||||
Button(Text("Go to Counter"), goCounter.OnClick()),
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
// Counter page — demonstrates signals and actions survive within a page,
|
||||
// but reset on navigate away and back.
|
||||
v.Page("/counter", func(c *via.Context) {
|
||||
count := 0
|
||||
increment := c.Action(func() { count++; c.Sync() })
|
||||
goHome := c.Action(func() { c.Navigate("/", false) })
|
||||
|
||||
c.View(func() H {
|
||||
return Div(
|
||||
H1(Text("Counter"), DataViewTransition("page-title")),
|
||||
P(Textf("Count: %d", count)),
|
||||
Button(Text("+1"), increment.OnClick()),
|
||||
Button(Text("Go Home"), goHome.OnClick(), Style("margin-left:0.5rem")),
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
// Clock page — demonstrates OnInterval cleanup on navigate.
|
||||
v.Page("/clock", func(c *via.Context) {
|
||||
now := time.Now().Format("15:04:05")
|
||||
c.OnInterval(time.Second, func() {
|
||||
now = time.Now().Format("15:04:05")
|
||||
c.Sync()
|
||||
})
|
||||
|
||||
c.View(func() H {
|
||||
return Div(
|
||||
H1(Text("Clock"), DataViewTransition("page-title")),
|
||||
P(Text("This page has an OnInterval that ticks every second.")),
|
||||
P(Textf("Current time: %s", now)),
|
||||
P(Text("Navigate away and back — the old interval stops, a new one starts.")),
|
||||
P(Textf("Proof this is a fresh page init: random = %d", time.Now().UnixNano()%1000)),
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
fmt.Println("SPA example running at http://localhost:7331")
|
||||
v.Start()
|
||||
}
|
||||
82
middleware.go
Normal file
82
middleware.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package via
|
||||
|
||||
// Middleware wraps a page init function. Call next to continue the chain;
|
||||
// return without calling next to abort (set a view first, e.g. RedirectView).
|
||||
type Middleware func(c *Context, next func())
|
||||
|
||||
// Group is a route group with a shared prefix and middleware stack.
|
||||
type Group struct {
|
||||
v *V
|
||||
prefix string
|
||||
middleware []Middleware
|
||||
}
|
||||
|
||||
// Use appends middleware to the global stack.
|
||||
// Global middleware runs before every page handler.
|
||||
func (v *V) Use(mw ...Middleware) {
|
||||
v.middleware = append(v.middleware, mw...)
|
||||
}
|
||||
|
||||
// Group creates a route group with the given path prefix and middleware.
|
||||
// Routes registered on the group are prefixed and run the group's middleware
|
||||
// after any global middleware.
|
||||
func (v *V) Group(prefix string, mw ...Middleware) *Group {
|
||||
return &Group{
|
||||
v: v,
|
||||
prefix: prefix,
|
||||
middleware: mw,
|
||||
}
|
||||
}
|
||||
|
||||
// Page registers a route on this group. The full route is the group prefix
|
||||
// concatenated with route.
|
||||
func (g *Group) Page(route string, initContextFn func(c *Context)) {
|
||||
fullRoute := g.prefix + route
|
||||
allMw := make([]Middleware, 0, len(g.v.middleware)+len(g.middleware))
|
||||
allMw = append(allMw, g.v.middleware...)
|
||||
allMw = append(allMw, g.middleware...)
|
||||
wrapped := chainMiddleware(allMw, initContextFn)
|
||||
g.v.page(fullRoute, initContextFn, wrapped)
|
||||
}
|
||||
|
||||
// Group creates a nested sub-group that inherits this group's prefix and
|
||||
// middleware, then adds its own.
|
||||
func (g *Group) Group(prefix string, mw ...Middleware) *Group {
|
||||
combined := make([]Middleware, len(g.middleware), len(g.middleware)+len(mw))
|
||||
copy(combined, g.middleware)
|
||||
combined = append(combined, mw...)
|
||||
return &Group{
|
||||
v: g.v,
|
||||
prefix: g.prefix + prefix,
|
||||
middleware: combined,
|
||||
}
|
||||
}
|
||||
|
||||
// Use appends middleware to this group's stack.
|
||||
func (g *Group) Use(mw ...Middleware) {
|
||||
g.middleware = append(g.middleware, mw...)
|
||||
}
|
||||
|
||||
// WithMiddleware returns an ActionOption that attaches middleware to an action.
|
||||
// Action middleware runs after CSRF/rate-limit checks and signal injection.
|
||||
func WithMiddleware(mw ...Middleware) ActionOption {
|
||||
return func(e *actionEntry) {
|
||||
e.middleware = append(e.middleware, mw...)
|
||||
}
|
||||
}
|
||||
|
||||
// chainMiddleware wraps handler with the given middleware, outer-first.
|
||||
func chainMiddleware(mws []Middleware, handler func(*Context)) func(*Context) {
|
||||
if len(mws) == 0 {
|
||||
return handler
|
||||
}
|
||||
chained := handler
|
||||
for i := len(mws) - 1; i >= 0; i-- {
|
||||
mw := mws[i]
|
||||
next := chained
|
||||
chained = func(c *Context) {
|
||||
mw(c, func() { next(c) })
|
||||
}
|
||||
}
|
||||
return chained
|
||||
}
|
||||
340
middleware_test.go
Normal file
340
middleware_test.go
Normal file
@@ -0,0 +1,340 @@
|
||||
package via
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/ryanhamamura/via/h"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestMiddlewareRunsBeforeHandler(t *testing.T) {
|
||||
var order []string
|
||||
|
||||
v := New()
|
||||
v.Use(func(c *Context, next func()) {
|
||||
order = append(order, "mw")
|
||||
next()
|
||||
})
|
||||
v.Page("/", func(c *Context) {
|
||||
order = append(order, "handler")
|
||||
c.View(func() h.H { return h.Div() })
|
||||
})
|
||||
|
||||
// Reset after registration (panic-check runs the raw handler)
|
||||
order = nil
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/", nil))
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, []string{"mw", "handler"}, order)
|
||||
}
|
||||
|
||||
func TestMiddlewareAbortSkipsHandler(t *testing.T) {
|
||||
handlerCalled := false
|
||||
|
||||
v := New()
|
||||
v.Use(func(c *Context, next func()) {
|
||||
c.RedirectView("/other")
|
||||
})
|
||||
v.Page("/", func(c *Context) {
|
||||
handlerCalled = true
|
||||
c.View(func() h.H { return h.Div() })
|
||||
})
|
||||
|
||||
handlerCalled = false
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/", nil))
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.False(t, handlerCalled)
|
||||
}
|
||||
|
||||
func TestMiddlewareChainOrder(t *testing.T) {
|
||||
var order []string
|
||||
|
||||
v := New()
|
||||
for _, label := range []string{"A", "B", "C"} {
|
||||
l := label
|
||||
v.Use(func(c *Context, next func()) {
|
||||
order = append(order, l)
|
||||
next()
|
||||
})
|
||||
}
|
||||
v.Page("/", func(c *Context) {
|
||||
order = append(order, "handler")
|
||||
c.View(func() h.H { return h.Div() })
|
||||
})
|
||||
|
||||
order = nil
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/", nil))
|
||||
|
||||
assert.Equal(t, []string{"A", "B", "C", "handler"}, order)
|
||||
}
|
||||
|
||||
func TestGroupPrefixRouting(t *testing.T) {
|
||||
v := New()
|
||||
g := v.Group("/admin")
|
||||
g.Page("/dashboard", func(c *Context) {
|
||||
c.View(func() h.H { return h.Div(h.Text("admin dashboard")) })
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/admin/dashboard", nil))
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "admin dashboard")
|
||||
}
|
||||
|
||||
func TestGroupMiddlewareAppliesToGroupOnly(t *testing.T) {
|
||||
var groupMwCalled bool
|
||||
|
||||
v := New()
|
||||
g := v.Group("/admin", func(c *Context, next func()) {
|
||||
groupMwCalled = true
|
||||
next()
|
||||
})
|
||||
g.Page("/panel", func(c *Context) {
|
||||
c.View(func() h.H { return h.Div(h.Text("panel")) })
|
||||
})
|
||||
v.Page("/public", func(c *Context) {
|
||||
c.View(func() h.H { return h.Div(h.Text("public")) })
|
||||
})
|
||||
|
||||
// Hit public page — group middleware should NOT run
|
||||
groupMwCalled = false
|
||||
w := httptest.NewRecorder()
|
||||
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/public", nil))
|
||||
assert.False(t, groupMwCalled)
|
||||
assert.Contains(t, w.Body.String(), "public")
|
||||
|
||||
// Hit group page — group middleware should run
|
||||
groupMwCalled = false
|
||||
w = httptest.NewRecorder()
|
||||
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/admin/panel", nil))
|
||||
assert.True(t, groupMwCalled)
|
||||
assert.Contains(t, w.Body.String(), "panel")
|
||||
}
|
||||
|
||||
func TestGlobalMiddlewareAppliesToGroupPages(t *testing.T) {
|
||||
var globalCalled bool
|
||||
|
||||
v := New()
|
||||
v.Use(func(c *Context, next func()) {
|
||||
globalCalled = true
|
||||
next()
|
||||
})
|
||||
g := v.Group("/admin")
|
||||
g.Page("/dash", func(c *Context) {
|
||||
c.View(func() h.H { return h.Div(h.Text("dash")) })
|
||||
})
|
||||
|
||||
globalCalled = false
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/admin/dash", nil))
|
||||
|
||||
assert.True(t, globalCalled)
|
||||
assert.Contains(t, w.Body.String(), "dash")
|
||||
}
|
||||
|
||||
func TestNestedGroupInheritsPrefixAndMiddleware(t *testing.T) {
|
||||
var order []string
|
||||
|
||||
v := New()
|
||||
admin := v.Group("/admin", func(c *Context, next func()) {
|
||||
order = append(order, "admin")
|
||||
next()
|
||||
})
|
||||
superAdmin := admin.Group("/super", func(c *Context, next func()) {
|
||||
order = append(order, "super")
|
||||
next()
|
||||
})
|
||||
superAdmin.Page("/secret", func(c *Context) {
|
||||
order = append(order, "handler")
|
||||
c.View(func() h.H { return h.Div(h.Text("secret")) })
|
||||
})
|
||||
|
||||
order = nil
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/admin/super/secret", nil))
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, []string{"admin", "super", "handler"}, order)
|
||||
assert.Contains(t, w.Body.String(), "secret")
|
||||
}
|
||||
|
||||
func TestGroupUse(t *testing.T) {
|
||||
var order []string
|
||||
|
||||
v := New()
|
||||
g := v.Group("/api")
|
||||
g.Use(func(c *Context, next func()) {
|
||||
order = append(order, "added-later")
|
||||
next()
|
||||
})
|
||||
g.Page("/items", func(c *Context) {
|
||||
order = append(order, "handler")
|
||||
c.View(func() h.H { return h.Div() })
|
||||
})
|
||||
|
||||
order = nil
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/api/items", nil))
|
||||
|
||||
assert.Equal(t, []string{"added-later", "handler"}, order)
|
||||
}
|
||||
|
||||
func TestRedirectViewSetsValidView(t *testing.T) {
|
||||
v := New()
|
||||
v.Page("/test", func(c *Context) {
|
||||
c.RedirectView("/somewhere")
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/test", nil))
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "<!doctype html>")
|
||||
}
|
||||
|
||||
func TestGlobalAndGroupMiddlewareOrder(t *testing.T) {
|
||||
var order []string
|
||||
|
||||
v := New()
|
||||
v.Use(func(c *Context, next func()) {
|
||||
order = append(order, "global")
|
||||
next()
|
||||
})
|
||||
g := v.Group("/g", func(c *Context, next func()) {
|
||||
order = append(order, "group")
|
||||
next()
|
||||
})
|
||||
g.Page("/page", func(c *Context) {
|
||||
order = append(order, "handler")
|
||||
c.View(func() h.H { return h.Div() })
|
||||
})
|
||||
|
||||
order = nil
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/g/page", nil))
|
||||
|
||||
assert.Equal(t, []string{"global", "group", "handler"}, order)
|
||||
}
|
||||
|
||||
// --- Action middleware tests ---
|
||||
|
||||
func TestActionMiddlewareRunsBeforeAction(t *testing.T) {
|
||||
var order []string
|
||||
|
||||
v := New()
|
||||
c := newContext("test", "/", v)
|
||||
|
||||
mw := func(_ *Context, next func()) {
|
||||
order = append(order, "mw")
|
||||
next()
|
||||
}
|
||||
|
||||
trigger := c.Action(func() {
|
||||
order = append(order, "action")
|
||||
}, WithMiddleware(mw))
|
||||
|
||||
entry, err := c.getAction(trigger.id)
|
||||
assert.NoError(t, err)
|
||||
|
||||
chainMiddleware(entry.middleware, func(_ *Context) { entry.fn() })(c)
|
||||
|
||||
assert.Equal(t, []string{"mw", "action"}, order)
|
||||
}
|
||||
|
||||
func TestActionMiddlewareAbortSkipsAction(t *testing.T) {
|
||||
actionCalled := false
|
||||
|
||||
v := New()
|
||||
c := newContext("test", "/", v)
|
||||
|
||||
mw := func(_ *Context, next func()) {
|
||||
// don't call next — action should not run
|
||||
}
|
||||
|
||||
trigger := c.Action(func() {
|
||||
actionCalled = true
|
||||
}, WithMiddleware(mw))
|
||||
|
||||
entry, err := c.getAction(trigger.id)
|
||||
assert.NoError(t, err)
|
||||
|
||||
chainMiddleware(entry.middleware, func(_ *Context) { entry.fn() })(c)
|
||||
|
||||
assert.False(t, actionCalled)
|
||||
}
|
||||
|
||||
func TestActionMiddlewareChainOrder(t *testing.T) {
|
||||
var order []string
|
||||
|
||||
v := New()
|
||||
c := newContext("test", "/", v)
|
||||
|
||||
var mws []Middleware
|
||||
for _, label := range []string{"A", "B", "C"} {
|
||||
l := label
|
||||
mws = append(mws, func(_ *Context, next func()) {
|
||||
order = append(order, l)
|
||||
next()
|
||||
})
|
||||
}
|
||||
|
||||
trigger := c.Action(func() {
|
||||
order = append(order, "action")
|
||||
}, WithMiddleware(mws...))
|
||||
|
||||
entry, err := c.getAction(trigger.id)
|
||||
assert.NoError(t, err)
|
||||
|
||||
chainMiddleware(entry.middleware, func(_ *Context) { entry.fn() })(c)
|
||||
|
||||
assert.Equal(t, []string{"A", "B", "C", "action"}, order)
|
||||
}
|
||||
|
||||
func TestActionMiddlewareCombinedWithRateLimit(t *testing.T) {
|
||||
v := New()
|
||||
c := newContext("test", "/", v)
|
||||
|
||||
mw := func(_ *Context, next func()) { next() }
|
||||
|
||||
trigger := c.Action(func() {}, WithRateLimit(5, 10), WithMiddleware(mw))
|
||||
|
||||
entry, err := c.getAction(trigger.id)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, entry.limiter)
|
||||
assert.Len(t, entry.middleware, 1)
|
||||
}
|
||||
|
||||
func TestGroupWithEmptyPrefix(t *testing.T) {
|
||||
var mwCalled bool
|
||||
|
||||
v := New()
|
||||
g := v.Group("", func(c *Context, next func()) {
|
||||
mwCalled = true
|
||||
next()
|
||||
})
|
||||
g.Page("/dashboard", func(c *Context) {
|
||||
c.View(func() h.H { return h.Div(h.Text("dash")) })
|
||||
})
|
||||
|
||||
mwCalled = false
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/dashboard", nil))
|
||||
|
||||
assert.True(t, mwCalled)
|
||||
assert.Contains(t, w.Body.String(), "dash")
|
||||
}
|
||||
190
nats.go
Normal file
190
nats.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package via
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/delaneyj/toolbelt/embeddednats"
|
||||
"github.com/nats-io/nats.go"
|
||||
)
|
||||
|
||||
// defaultNATS is the process-scoped embedded NATS server.
|
||||
type defaultNATS struct {
|
||||
server *embeddednats.Server
|
||||
nc *nats.Conn
|
||||
js nats.JetStreamContext
|
||||
cancel context.CancelFunc
|
||||
dataDir string
|
||||
}
|
||||
|
||||
var (
|
||||
sharedNATS *defaultNATS
|
||||
sharedNATSOnce sync.Once
|
||||
sharedNATSErr error
|
||||
)
|
||||
|
||||
// getSharedNATS returns a process-level singleton embedded NATS server.
|
||||
// The server starts once and is reused across all V instances.
|
||||
func getSharedNATS() (*defaultNATS, error) {
|
||||
sharedNATSOnce.Do(func() {
|
||||
sharedNATS, sharedNATSErr = startDefaultNATS()
|
||||
})
|
||||
return sharedNATS, sharedNATSErr
|
||||
}
|
||||
|
||||
func startDefaultNATS() (dn *defaultNATS, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("nats server panic: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
dataDir, err := os.MkdirTemp("", "via-nats-*")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create temp dir: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
ns, err := embeddednats.New(ctx, embeddednats.WithDirectory(dataDir))
|
||||
if err != nil {
|
||||
cancel()
|
||||
os.RemoveAll(dataDir)
|
||||
return nil, fmt.Errorf("start embedded nats: %w", err)
|
||||
}
|
||||
ns.WaitForServer()
|
||||
|
||||
nc, err := ns.Client()
|
||||
if err != nil {
|
||||
ns.Close()
|
||||
cancel()
|
||||
os.RemoveAll(dataDir)
|
||||
return nil, fmt.Errorf("connect nats client: %w", err)
|
||||
}
|
||||
|
||||
js, err := nc.JetStream()
|
||||
if err != nil {
|
||||
nc.Close()
|
||||
ns.Close()
|
||||
cancel()
|
||||
os.RemoveAll(dataDir)
|
||||
return nil, fmt.Errorf("init jetstream: %w", err)
|
||||
}
|
||||
|
||||
return &defaultNATS{
|
||||
server: ns,
|
||||
nc: nc,
|
||||
js: js,
|
||||
cancel: cancel,
|
||||
dataDir: dataDir,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (n *defaultNATS) Publish(subject string, data []byte) error {
|
||||
return n.nc.Publish(subject, data)
|
||||
}
|
||||
|
||||
func (n *defaultNATS) Subscribe(subject string, handler func(data []byte)) (Subscription, error) {
|
||||
sub, err := n.nc.Subscribe(subject, func(msg *nats.Msg) {
|
||||
handler(msg.Data)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sub, nil
|
||||
}
|
||||
|
||||
// natsRef wraps a shared defaultNATS as a PubSub. Close is a no-op because
|
||||
// the underlying server is process-scoped and outlives individual V instances.
|
||||
type natsRef struct {
|
||||
dn *defaultNATS
|
||||
}
|
||||
|
||||
func (r *natsRef) Publish(subject string, data []byte) error {
|
||||
return r.dn.Publish(subject, data)
|
||||
}
|
||||
|
||||
func (r *natsRef) Subscribe(subject string, handler func(data []byte)) (Subscription, error) {
|
||||
return r.dn.Subscribe(subject, handler)
|
||||
}
|
||||
|
||||
func (r *natsRef) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// NATSConn returns the underlying NATS connection from the built-in embedded
|
||||
// server, or nil if a custom PubSub backend is in use.
|
||||
func (v *V) NATSConn() *nats.Conn {
|
||||
if v.defaultNATS != nil {
|
||||
return v.defaultNATS.nc
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// JetStream returns the JetStream context from the built-in embedded server,
|
||||
// or nil if a custom PubSub backend is in use.
|
||||
func (v *V) JetStream() nats.JetStreamContext {
|
||||
if v.defaultNATS != nil {
|
||||
return v.defaultNATS.js
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// StreamConfig holds the parameters for creating or updating a JetStream stream.
|
||||
type StreamConfig struct {
|
||||
Name string
|
||||
Subjects []string
|
||||
MaxMsgs int64
|
||||
MaxAge time.Duration
|
||||
}
|
||||
|
||||
// EnsureStream creates or updates a JetStream stream matching cfg.
|
||||
func EnsureStream(v *V, cfg StreamConfig) error {
|
||||
js := v.JetStream()
|
||||
if js == nil {
|
||||
return fmt.Errorf("jetstream not available")
|
||||
}
|
||||
_, err := js.AddStream(&nats.StreamConfig{
|
||||
Name: cfg.Name,
|
||||
Subjects: cfg.Subjects,
|
||||
Retention: nats.LimitsPolicy,
|
||||
MaxMsgs: cfg.MaxMsgs,
|
||||
MaxAge: cfg.MaxAge,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// ReplayHistory fetches the last limit messages from subject,
|
||||
// deserializing each as T. Returns an empty slice if nothing is available.
|
||||
func ReplayHistory[T any](v *V, subject string, limit int) ([]T, error) {
|
||||
js := v.JetStream()
|
||||
if js == nil {
|
||||
return nil, fmt.Errorf("jetstream not available")
|
||||
}
|
||||
sub, err := js.SubscribeSync(subject, nats.DeliverAll(), nats.OrderedConsumer())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer sub.Unsubscribe()
|
||||
|
||||
var msgs []T
|
||||
for {
|
||||
raw, err := sub.NextMsg(200 * time.Millisecond)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
var msg T
|
||||
if json.Unmarshal(raw.Data, &msg) == nil {
|
||||
msgs = append(msgs, msg)
|
||||
}
|
||||
}
|
||||
|
||||
if limit > 0 && len(msgs) > limit {
|
||||
msgs = msgs[len(msgs)-limit:]
|
||||
}
|
||||
return msgs, nil
|
||||
}
|
||||
108
nats_test.go
108
nats_test.go
@@ -2,7 +2,6 @@ package via
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -11,88 +10,36 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type mockHandler struct {
|
||||
id int64
|
||||
fn func([]byte)
|
||||
active atomic.Bool
|
||||
}
|
||||
|
||||
// mockPubSub implements PubSub for testing without NATS.
|
||||
type mockPubSub struct {
|
||||
mu sync.Mutex
|
||||
subs map[string][]*mockHandler
|
||||
nextID atomic.Int64
|
||||
}
|
||||
|
||||
func newMockPubSub() *mockPubSub {
|
||||
return &mockPubSub{subs: make(map[string][]*mockHandler)}
|
||||
}
|
||||
|
||||
func (m *mockPubSub) Publish(subject string, data []byte) error {
|
||||
m.mu.Lock()
|
||||
handlers := make([]*mockHandler, len(m.subs[subject]))
|
||||
copy(handlers, m.subs[subject])
|
||||
m.mu.Unlock()
|
||||
for _, h := range handlers {
|
||||
if h.active.Load() {
|
||||
h.fn(data)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockPubSub) Subscribe(subject string, handler func(data []byte)) (Subscription, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
mh := &mockHandler{
|
||||
id: m.nextID.Add(1),
|
||||
fn: handler,
|
||||
}
|
||||
mh.active.Store(true)
|
||||
m.subs[subject] = append(m.subs[subject], mh)
|
||||
return &mockSub{handler: mh}, nil
|
||||
}
|
||||
|
||||
func (m *mockPubSub) Close() error { return nil }
|
||||
|
||||
type mockSub struct {
|
||||
handler *mockHandler
|
||||
}
|
||||
|
||||
func (s *mockSub) Unsubscribe() error {
|
||||
s.handler.active.Store(false)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestPubSub_RoundTrip(t *testing.T) {
|
||||
ps := newMockPubSub()
|
||||
v := New()
|
||||
v.Config(Options{PubSub: ps})
|
||||
defer v.Shutdown()
|
||||
|
||||
var received []byte
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
done := make(chan struct{})
|
||||
|
||||
c := newContext("test-ctx", "/", v)
|
||||
c.View(func() h.H { return h.Div() })
|
||||
|
||||
_, err := c.Subscribe("test.topic", func(data []byte) {
|
||||
received = data
|
||||
wg.Done()
|
||||
close(done)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = c.Publish("test.topic", []byte("hello"))
|
||||
require.NoError(t, err)
|
||||
|
||||
wg.Wait()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for message")
|
||||
}
|
||||
assert.Equal(t, []byte("hello"), received)
|
||||
}
|
||||
|
||||
func TestPubSub_MultipleSubscribers(t *testing.T) {
|
||||
ps := newMockPubSub()
|
||||
v := New()
|
||||
v.Config(Options{PubSub: ps})
|
||||
defer v.Shutdown()
|
||||
|
||||
var mu sync.Mutex
|
||||
var results []string
|
||||
@@ -119,7 +66,17 @@ func TestPubSub_MultipleSubscribers(t *testing.T) {
|
||||
})
|
||||
|
||||
c1.Publish("broadcast", []byte("msg"))
|
||||
wg.Wait()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for messages")
|
||||
}
|
||||
|
||||
assert.Len(t, results, 2)
|
||||
assert.Contains(t, results, "c1:msg")
|
||||
@@ -127,9 +84,8 @@ func TestPubSub_MultipleSubscribers(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestPubSub_SubscriptionCleanupOnDispose(t *testing.T) {
|
||||
ps := newMockPubSub()
|
||||
v := New()
|
||||
v.Config(Options{PubSub: ps})
|
||||
defer v.Shutdown()
|
||||
|
||||
c := newContext("cleanup-ctx", "/", v)
|
||||
c.View(func() h.H { return h.Div() })
|
||||
@@ -144,9 +100,8 @@ func TestPubSub_SubscriptionCleanupOnDispose(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestPubSub_ManualUnsubscribe(t *testing.T) {
|
||||
ps := newMockPubSub()
|
||||
v := New()
|
||||
v.Config(Options{PubSub: ps})
|
||||
defer v.Shutdown()
|
||||
|
||||
c := newContext("unsub-ctx", "/", v)
|
||||
c.View(func() h.H { return h.Div() })
|
||||
@@ -160,28 +115,13 @@ func TestPubSub_ManualUnsubscribe(t *testing.T) {
|
||||
sub.Unsubscribe()
|
||||
|
||||
c.Publish("topic", []byte("ignored"))
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
assert.False(t, called)
|
||||
}
|
||||
|
||||
func TestPubSub_NoOpWhenNotConfigured(t *testing.T) {
|
||||
v := New()
|
||||
|
||||
c := newContext("noop-ctx", "/", v)
|
||||
c.View(func() h.H { return h.Div() })
|
||||
|
||||
err := c.Publish("topic", []byte("data"))
|
||||
assert.Error(t, err)
|
||||
|
||||
sub, err := c.Subscribe("topic", func(data []byte) {})
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, sub)
|
||||
}
|
||||
|
||||
func TestPubSub_NoOpDuringPanicCheck(t *testing.T) {
|
||||
ps := newMockPubSub()
|
||||
v := New()
|
||||
v.Config(Options{PubSub: ps})
|
||||
defer v.Shutdown()
|
||||
|
||||
// Panic-check context has id=""
|
||||
c := newContext("", "/", v)
|
||||
|
||||
51
navigate.js
Normal file
51
navigate.js
Normal file
@@ -0,0 +1,51 @@
|
||||
(function() {
|
||||
const meta = document.querySelector('meta[data-signals]');
|
||||
if (!meta) return;
|
||||
const raw = meta.getAttribute('data-signals');
|
||||
const parsed = JSON.parse(raw.replace(/'/g, '"'));
|
||||
const ctxID = parsed['via-ctx'];
|
||||
const csrf = parsed['via-csrf'];
|
||||
if (!ctxID || !csrf) return;
|
||||
|
||||
function navigate(url, popstate) {
|
||||
const params = new URLSearchParams({
|
||||
'via-ctx': ctxID,
|
||||
'via-csrf': csrf,
|
||||
'url': url,
|
||||
});
|
||||
if (popstate) params.set('popstate', '1');
|
||||
fetch('/_navigate', {
|
||||
method: 'POST',
|
||||
headers: {'Content-Type': 'application/x-www-form-urlencoded'},
|
||||
body: params.toString()
|
||||
}).then(function(res) {
|
||||
if (!res.ok) window.location.href = url;
|
||||
}).catch(function() {
|
||||
window.location.href = url;
|
||||
});
|
||||
}
|
||||
|
||||
document.addEventListener('click', function(e) {
|
||||
var el = e.target;
|
||||
while (el && el.tagName !== 'A') el = el.parentElement;
|
||||
if (!el) return;
|
||||
if (e.ctrlKey || e.metaKey || e.shiftKey || e.altKey) return;
|
||||
if (el.hasAttribute('target')) return;
|
||||
if (el.hasAttribute('data-via-no-boost')) return;
|
||||
var href = el.getAttribute('href');
|
||||
if (!href || href.startsWith('#')) return;
|
||||
try {
|
||||
var url = new URL(href, window.location.origin);
|
||||
if (url.origin !== window.location.origin) return;
|
||||
e.preventDefault();
|
||||
navigate(url.pathname + url.search + url.hash);
|
||||
} catch(_) {}
|
||||
});
|
||||
|
||||
var ready = false;
|
||||
window.addEventListener('popstate', function() {
|
||||
if (!ready) return;
|
||||
navigate(window.location.pathname + window.location.search + window.location.hash, true);
|
||||
});
|
||||
setTimeout(function() { ready = true; }, 0);
|
||||
})();
|
||||
@@ -1,7 +1,8 @@
|
||||
package via
|
||||
|
||||
// PubSub is an interface for publish/subscribe messaging backends.
|
||||
// The vianats sub-package provides an embedded NATS implementation.
|
||||
// By default, New() starts an embedded NATS server. Supply a custom
|
||||
// implementation via Config(Options{PubSub: yourBackend}) to override.
|
||||
type PubSub interface {
|
||||
Publish(subject string, data []byte) error
|
||||
Subscribe(subject string, handler func(data []byte)) (Subscription, error)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package via
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ryanhamamura/via/h"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -10,9 +10,8 @@ import (
|
||||
)
|
||||
|
||||
func TestPublishSubscribe_RoundTrip(t *testing.T) {
|
||||
ps := newMockPubSub()
|
||||
v := New()
|
||||
v.Config(Options{PubSub: ps})
|
||||
defer v.Shutdown()
|
||||
|
||||
type event struct {
|
||||
Name string `json:"name"`
|
||||
@@ -20,30 +19,32 @@ func TestPublishSubscribe_RoundTrip(t *testing.T) {
|
||||
}
|
||||
|
||||
var got event
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
done := make(chan struct{})
|
||||
|
||||
c := newContext("typed-ctx", "/", v)
|
||||
c.View(func() h.H { return h.Div() })
|
||||
|
||||
_, err := Subscribe(c, "events", func(e event) {
|
||||
got = e
|
||||
wg.Done()
|
||||
close(done)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = Publish(c, "events", event{Name: "click", Count: 42})
|
||||
require.NoError(t, err)
|
||||
|
||||
wg.Wait()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for message")
|
||||
}
|
||||
assert.Equal(t, "click", got.Name)
|
||||
assert.Equal(t, 42, got.Count)
|
||||
}
|
||||
|
||||
func TestSubscribe_SkipsBadJSON(t *testing.T) {
|
||||
ps := newMockPubSub()
|
||||
v := New()
|
||||
v.Config(Options{PubSub: ps})
|
||||
defer v.Shutdown()
|
||||
|
||||
type msg struct {
|
||||
Text string `json:"text"`
|
||||
@@ -62,5 +63,6 @@ func TestSubscribe_SkipsBadJSON(t *testing.T) {
|
||||
err = c.Publish("topic", []byte("not json"))
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
assert.False(t, called)
|
||||
}
|
||||
|
||||
49
ratelimit.go
Normal file
49
ratelimit.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package via
|
||||
|
||||
import "golang.org/x/time/rate"
|
||||
|
||||
const (
|
||||
defaultActionRate float64 = 10.0
|
||||
defaultActionBurst int = 20
|
||||
)
|
||||
|
||||
// RateLimitConfig configures token-bucket rate limiting for actions.
|
||||
// Zero values fall back to defaults. Rate of -1 disables limiting entirely.
|
||||
type RateLimitConfig struct {
|
||||
Rate float64
|
||||
Burst int
|
||||
}
|
||||
|
||||
// ActionOption configures per-action behaviour when passed to Context.Action.
|
||||
type ActionOption func(*actionEntry)
|
||||
|
||||
type actionEntry struct {
|
||||
fn func()
|
||||
limiter *rate.Limiter // nil = use context default
|
||||
middleware []Middleware
|
||||
}
|
||||
|
||||
// WithRateLimit returns an ActionOption that gives this action its own
|
||||
// token-bucket limiter, overriding the context-level default.
|
||||
func WithRateLimit(r float64, burst int) ActionOption {
|
||||
return func(e *actionEntry) {
|
||||
e.limiter = newLimiter(RateLimitConfig{Rate: r, Burst: burst}, defaultActionRate, defaultActionBurst)
|
||||
}
|
||||
}
|
||||
|
||||
// newLimiter creates a *rate.Limiter from cfg, substituting defaults for zero
|
||||
// values. A Rate of -1 disables limiting (returns nil).
|
||||
func newLimiter(cfg RateLimitConfig, defaultRate float64, defaultBurst int) *rate.Limiter {
|
||||
r := cfg.Rate
|
||||
b := cfg.Burst
|
||||
if r == -1 {
|
||||
return nil
|
||||
}
|
||||
if r == 0 {
|
||||
r = defaultRate
|
||||
}
|
||||
if b == 0 {
|
||||
b = defaultBurst
|
||||
}
|
||||
return rate.NewLimiter(rate.Limit(r), b)
|
||||
}
|
||||
101
ratelimit_test.go
Normal file
101
ratelimit_test.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package via
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewLimiter_Defaults(t *testing.T) {
|
||||
l := newLimiter(RateLimitConfig{}, defaultActionRate, defaultActionBurst)
|
||||
require.NotNil(t, l)
|
||||
assert.InDelta(t, defaultActionRate, float64(l.Limit()), 0.001)
|
||||
assert.Equal(t, defaultActionBurst, l.Burst())
|
||||
}
|
||||
|
||||
func TestNewLimiter_CustomValues(t *testing.T) {
|
||||
l := newLimiter(RateLimitConfig{Rate: 5, Burst: 10}, defaultActionRate, defaultActionBurst)
|
||||
require.NotNil(t, l)
|
||||
assert.InDelta(t, 5.0, float64(l.Limit()), 0.001)
|
||||
assert.Equal(t, 10, l.Burst())
|
||||
}
|
||||
|
||||
func TestNewLimiter_DisabledWithNegativeRate(t *testing.T) {
|
||||
l := newLimiter(RateLimitConfig{Rate: -1}, defaultActionRate, defaultActionBurst)
|
||||
assert.Nil(t, l)
|
||||
}
|
||||
|
||||
func TestTokenBucket_AllowsBurstThenRejects(t *testing.T) {
|
||||
l := newLimiter(RateLimitConfig{Rate: 1, Burst: 3}, 1, 3)
|
||||
require.NotNil(t, l)
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
assert.True(t, l.Allow(), "request %d should be allowed within burst", i)
|
||||
}
|
||||
assert.False(t, l.Allow(), "request beyond burst should be rejected")
|
||||
}
|
||||
|
||||
func TestWithRateLimit_CreatesLimiter(t *testing.T) {
|
||||
entry := actionEntry{fn: func() {}}
|
||||
opt := WithRateLimit(2, 4)
|
||||
opt(&entry)
|
||||
|
||||
require.NotNil(t, entry.limiter)
|
||||
assert.InDelta(t, 2.0, float64(entry.limiter.Limit()), 0.001)
|
||||
assert.Equal(t, 4, entry.limiter.Burst())
|
||||
}
|
||||
|
||||
func TestContextAction_WithRateLimit(t *testing.T) {
|
||||
v := New()
|
||||
c := newContext("test-rl", "/", v)
|
||||
|
||||
called := false
|
||||
c.Action(func() { called = true }, WithRateLimit(1, 2))
|
||||
|
||||
// Verify the entry has its own limiter
|
||||
for _, entry := range c.actionRegistry {
|
||||
require.NotNil(t, entry.limiter)
|
||||
assert.InDelta(t, 1.0, float64(entry.limiter.Limit()), 0.001)
|
||||
assert.Equal(t, 2, entry.limiter.Burst())
|
||||
}
|
||||
assert.False(t, called)
|
||||
}
|
||||
|
||||
func TestContextAction_DefaultNoPerActionLimiter(t *testing.T) {
|
||||
v := New()
|
||||
c := newContext("test-no-rl", "/", v)
|
||||
|
||||
c.Action(func() {})
|
||||
|
||||
for _, entry := range c.actionRegistry {
|
||||
assert.Nil(t, entry.limiter, "entry without WithRateLimit should have nil limiter")
|
||||
}
|
||||
}
|
||||
|
||||
func TestContextLimiter_DefaultsApplied(t *testing.T) {
|
||||
v := New()
|
||||
c := newContext("test-ctx-limiter", "/", v)
|
||||
|
||||
require.NotNil(t, c.actionLimiter)
|
||||
assert.InDelta(t, defaultActionRate, float64(c.actionLimiter.Limit()), 0.001)
|
||||
assert.Equal(t, defaultActionBurst, c.actionLimiter.Burst())
|
||||
}
|
||||
|
||||
func TestContextLimiter_DisabledViaConfig(t *testing.T) {
|
||||
v := New()
|
||||
v.actionRateLimit = RateLimitConfig{Rate: -1}
|
||||
c := newContext("test-disabled", "/", v)
|
||||
|
||||
assert.Nil(t, c.actionLimiter)
|
||||
}
|
||||
|
||||
func TestContextLimiter_CustomConfig(t *testing.T) {
|
||||
v := New()
|
||||
v.Config(Options{ActionRateLimit: RateLimitConfig{Rate: 50, Burst: 100}})
|
||||
c := newContext("test-custom", "/", v)
|
||||
|
||||
require.NotNil(t, c.actionLimiter)
|
||||
assert.InDelta(t, 50.0, float64(c.actionLimiter.Limit()), 0.001)
|
||||
assert.Equal(t, 100, c.actionLimiter.Burst())
|
||||
}
|
||||
74
routine.go
74
routine.go
@@ -1,76 +1,34 @@
|
||||
package via
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// OnIntervalRoutine allows for defining concurrent goroutines safely. Goroutines started by *OnIntervalRoutine
|
||||
// are tied to the *Context lifecycle.
|
||||
type OnIntervalRoutine struct {
|
||||
mu sync.RWMutex
|
||||
ctxDisposed chan struct{}
|
||||
localInterrupt chan struct{}
|
||||
isRunning atomic.Bool
|
||||
routineFn func()
|
||||
tckDuration time.Duration
|
||||
updateTkrChan chan time.Duration
|
||||
}
|
||||
func newOnInterval(ctxDisposedChan, pageStopChan chan struct{}, duration time.Duration, handler func()) func() {
|
||||
localInterrupt := make(chan struct{})
|
||||
var stopped atomic.Bool
|
||||
|
||||
// UpdateInterval sets a new interval duration for the internal *time.Ticker. If the provided
|
||||
// duration is equal of less than 0, UpdateInterval does nothing.
|
||||
func (r *OnIntervalRoutine) UpdateInterval(d time.Duration) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.tckDuration = d
|
||||
r.updateTkrChan <- d
|
||||
|
||||
}
|
||||
|
||||
// Start executes the predifined goroutine. If no predifined goroutine exists, or it already
|
||||
// started, Start does nothing.
|
||||
func (r *OnIntervalRoutine) Start() {
|
||||
if !r.isRunning.CompareAndSwap(false, true) || r.routineFn == nil {
|
||||
return
|
||||
}
|
||||
go r.routineFn()
|
||||
}
|
||||
|
||||
// Stop interrupts the predifined goroutine. If no predifined goroutine exists, or it already
|
||||
// ustopped, Stop does nothing.
|
||||
func (r *OnIntervalRoutine) Stop() {
|
||||
if !r.isRunning.CompareAndSwap(true, false) || r.routineFn == nil {
|
||||
return
|
||||
}
|
||||
r.localInterrupt <- struct{}{}
|
||||
}
|
||||
|
||||
func newOnIntervalRoutine(ctxDisposedChan chan struct{},
|
||||
duration time.Duration, handler func()) *OnIntervalRoutine {
|
||||
r := &OnIntervalRoutine{
|
||||
ctxDisposed: ctxDisposedChan,
|
||||
localInterrupt: make(chan struct{}),
|
||||
updateTkrChan: make(chan time.Duration),
|
||||
}
|
||||
r.tckDuration = duration
|
||||
r.routineFn = func() {
|
||||
r.mu.RLock()
|
||||
tkr := time.NewTicker(r.tckDuration)
|
||||
r.mu.RUnlock()
|
||||
defer tkr.Stop() // clean up the ticker when routine stops
|
||||
go func() {
|
||||
tkr := time.NewTicker(duration)
|
||||
defer tkr.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-r.ctxDisposed: // dispose of the routine when ctx is disposed
|
||||
case <-ctxDisposedChan:
|
||||
return
|
||||
case <-r.localInterrupt: // dispose of the routine on interrupt signal
|
||||
case <-pageStopChan:
|
||||
return
|
||||
case <-localInterrupt:
|
||||
return
|
||||
case d := <-r.updateTkrChan:
|
||||
tkr.Reset(d)
|
||||
case <-tkr.C:
|
||||
handler()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return func() {
|
||||
if stopped.CompareAndSwap(false, true) {
|
||||
close(localInterrupt)
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
130
rule.go
Normal file
130
rule.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package via
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// Rule defines a single validation check for a Field.
|
||||
type Rule struct {
|
||||
validate func(val string) error
|
||||
}
|
||||
|
||||
// Required rejects empty or whitespace-only values.
|
||||
func Required(msg ...string) Rule {
|
||||
m := "This field is required"
|
||||
if len(msg) > 0 {
|
||||
m = msg[0]
|
||||
}
|
||||
return Rule{func(val string) error {
|
||||
if strings.TrimSpace(val) == "" {
|
||||
return errors.New(m)
|
||||
}
|
||||
return nil
|
||||
}}
|
||||
}
|
||||
|
||||
// MinLen rejects values shorter than n characters.
|
||||
func MinLen(n int, msg ...string) Rule {
|
||||
m := fmt.Sprintf("Must be at least %d characters", n)
|
||||
if len(msg) > 0 {
|
||||
m = msg[0]
|
||||
}
|
||||
return Rule{func(val string) error {
|
||||
if utf8.RuneCountInString(val) < n {
|
||||
return errors.New(m)
|
||||
}
|
||||
return nil
|
||||
}}
|
||||
}
|
||||
|
||||
// MaxLen rejects values longer than n characters.
|
||||
func MaxLen(n int, msg ...string) Rule {
|
||||
m := fmt.Sprintf("Must be at most %d characters", n)
|
||||
if len(msg) > 0 {
|
||||
m = msg[0]
|
||||
}
|
||||
return Rule{func(val string) error {
|
||||
if utf8.RuneCountInString(val) > n {
|
||||
return errors.New(m)
|
||||
}
|
||||
return nil
|
||||
}}
|
||||
}
|
||||
|
||||
// Min parses the value as an integer and rejects values less than n.
|
||||
func Min(n int, msg ...string) Rule {
|
||||
m := fmt.Sprintf("Must be at least %d", n)
|
||||
if len(msg) > 0 {
|
||||
m = msg[0]
|
||||
}
|
||||
return Rule{func(val string) error {
|
||||
v, err := strconv.Atoi(val)
|
||||
if err != nil {
|
||||
return errors.New("Must be a valid number")
|
||||
}
|
||||
if v < n {
|
||||
return errors.New(m)
|
||||
}
|
||||
return nil
|
||||
}}
|
||||
}
|
||||
|
||||
// Max parses the value as an integer and rejects values greater than n.
|
||||
func Max(n int, msg ...string) Rule {
|
||||
m := fmt.Sprintf("Must be at most %d", n)
|
||||
if len(msg) > 0 {
|
||||
m = msg[0]
|
||||
}
|
||||
return Rule{func(val string) error {
|
||||
v, err := strconv.Atoi(val)
|
||||
if err != nil {
|
||||
return errors.New("Must be a valid number")
|
||||
}
|
||||
if v > n {
|
||||
return errors.New(m)
|
||||
}
|
||||
return nil
|
||||
}}
|
||||
}
|
||||
|
||||
// Pattern rejects values that don't match the regular expression re.
|
||||
func Pattern(re string, msg ...string) Rule {
|
||||
m := "Invalid format"
|
||||
if len(msg) > 0 {
|
||||
m = msg[0]
|
||||
}
|
||||
compiled := regexp.MustCompile(re)
|
||||
return Rule{func(val string) error {
|
||||
if !compiled.MatchString(val) {
|
||||
return errors.New(m)
|
||||
}
|
||||
return nil
|
||||
}}
|
||||
}
|
||||
|
||||
var emailRegexp = regexp.MustCompile(`^[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}$`)
|
||||
|
||||
// Email rejects values that don't look like an email address.
|
||||
func Email(msg ...string) Rule {
|
||||
m := "Invalid email address"
|
||||
if len(msg) > 0 {
|
||||
m = msg[0]
|
||||
}
|
||||
return Rule{func(val string) error {
|
||||
if !emailRegexp.MatchString(val) {
|
||||
return errors.New(m)
|
||||
}
|
||||
return nil
|
||||
}}
|
||||
}
|
||||
|
||||
// Custom creates a rule from a user-provided validation function.
|
||||
// The function should return nil for valid input and an error for invalid input.
|
||||
func Custom(fn func(string) error) Rule {
|
||||
return Rule{validate: fn}
|
||||
}
|
||||
116
rule_test.go
Normal file
116
rule_test.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package via
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRequired(t *testing.T) {
|
||||
r := Required()
|
||||
assert.NoError(t, r.validate("hello"))
|
||||
assert.Error(t, r.validate(""))
|
||||
assert.Error(t, r.validate(" "))
|
||||
}
|
||||
|
||||
func TestRequiredCustomMessage(t *testing.T) {
|
||||
r := Required("name needed")
|
||||
err := r.validate("")
|
||||
assert.EqualError(t, err, "name needed")
|
||||
}
|
||||
|
||||
func TestMinLen(t *testing.T) {
|
||||
r := MinLen(3)
|
||||
assert.NoError(t, r.validate("abc"))
|
||||
assert.NoError(t, r.validate("abcd"))
|
||||
assert.Error(t, r.validate("ab"))
|
||||
assert.Error(t, r.validate(""))
|
||||
}
|
||||
|
||||
func TestMinLenCustomMessage(t *testing.T) {
|
||||
r := MinLen(5, "too short")
|
||||
err := r.validate("ab")
|
||||
assert.EqualError(t, err, "too short")
|
||||
}
|
||||
|
||||
func TestMaxLen(t *testing.T) {
|
||||
r := MaxLen(5)
|
||||
assert.NoError(t, r.validate("abc"))
|
||||
assert.NoError(t, r.validate("abcde"))
|
||||
assert.Error(t, r.validate("abcdef"))
|
||||
}
|
||||
|
||||
func TestMaxLenCustomMessage(t *testing.T) {
|
||||
r := MaxLen(2, "too long")
|
||||
err := r.validate("abc")
|
||||
assert.EqualError(t, err, "too long")
|
||||
}
|
||||
|
||||
func TestMin(t *testing.T) {
|
||||
r := Min(5)
|
||||
assert.NoError(t, r.validate("5"))
|
||||
assert.NoError(t, r.validate("10"))
|
||||
assert.Error(t, r.validate("4"))
|
||||
assert.Error(t, r.validate("abc"))
|
||||
}
|
||||
|
||||
func TestMinCustomMessage(t *testing.T) {
|
||||
r := Min(10, "need 10+")
|
||||
err := r.validate("3")
|
||||
assert.EqualError(t, err, "need 10+")
|
||||
}
|
||||
|
||||
func TestMax(t *testing.T) {
|
||||
r := Max(10)
|
||||
assert.NoError(t, r.validate("10"))
|
||||
assert.NoError(t, r.validate("5"))
|
||||
assert.Error(t, r.validate("11"))
|
||||
assert.Error(t, r.validate("abc"))
|
||||
}
|
||||
|
||||
func TestMaxCustomMessage(t *testing.T) {
|
||||
r := Max(5, "too big")
|
||||
err := r.validate("6")
|
||||
assert.EqualError(t, err, "too big")
|
||||
}
|
||||
|
||||
func TestPattern(t *testing.T) {
|
||||
r := Pattern(`^\d{3}$`)
|
||||
assert.NoError(t, r.validate("123"))
|
||||
assert.Error(t, r.validate("12"))
|
||||
assert.Error(t, r.validate("abcd"))
|
||||
}
|
||||
|
||||
func TestPatternCustomMessage(t *testing.T) {
|
||||
r := Pattern(`^\d+$`, "digits only")
|
||||
err := r.validate("abc")
|
||||
assert.EqualError(t, err, "digits only")
|
||||
}
|
||||
|
||||
func TestEmail(t *testing.T) {
|
||||
r := Email()
|
||||
assert.NoError(t, r.validate("user@example.com"))
|
||||
assert.NoError(t, r.validate("a.b+c@foo.co"))
|
||||
assert.Error(t, r.validate("notanemail"))
|
||||
assert.Error(t, r.validate("@example.com"))
|
||||
assert.Error(t, r.validate("user@"))
|
||||
assert.Error(t, r.validate(""))
|
||||
}
|
||||
|
||||
func TestEmailCustomMessage(t *testing.T) {
|
||||
r := Email("bad email")
|
||||
err := r.validate("nope")
|
||||
assert.EqualError(t, err, "bad email")
|
||||
}
|
||||
|
||||
func TestCustom(t *testing.T) {
|
||||
r := Custom(func(val string) error {
|
||||
if val != "magic" {
|
||||
return fmt.Errorf("must be magic")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, r.validate("magic"))
|
||||
assert.EqualError(t, r.validate("other"), "must be magic")
|
||||
}
|
||||
23
signal.go
23
signal.go
@@ -81,26 +81,3 @@ func (s *signal) Int() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Int64 tries to read the signal value as an int64.
|
||||
// Returns the value or 0 on failure.
|
||||
func (s *signal) Int64() int64 {
|
||||
if n, err := strconv.ParseInt(s.String(), 10, 64); err == nil {
|
||||
return n
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// Float64 tries to read the signal value as a float64.
|
||||
// Returns the value or 0.0 on failure.
|
||||
func (s *signal) Float() float64 {
|
||||
if n, err := strconv.ParseFloat(s.String(), 64); err == nil {
|
||||
return n
|
||||
}
|
||||
return 0.0
|
||||
}
|
||||
|
||||
// Bytes tries to read the signal value as a []byte
|
||||
// Returns the value or an empty []byte on failure.
|
||||
func (s *signal) Bytes() []byte {
|
||||
return []byte(s.String())
|
||||
}
|
||||
|
||||
143
static_test.go
Normal file
143
static_test.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package via
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestStatic(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
os.MkdirAll(filepath.Join(dir, "sub"), 0755)
|
||||
os.WriteFile(filepath.Join(dir, "hello.txt"), []byte("hello world"), 0644)
|
||||
os.WriteFile(filepath.Join(dir, "sub", "nested.txt"), []byte("nested"), 0644)
|
||||
|
||||
v := New()
|
||||
v.Static("/assets/", dir)
|
||||
|
||||
t.Run("serves file", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/assets/hello.txt", nil)
|
||||
v.mux.ServeHTTP(w, r)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "hello world", w.Body.String())
|
||||
})
|
||||
|
||||
t.Run("serves nested file", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/assets/sub/nested.txt", nil)
|
||||
v.mux.ServeHTTP(w, r)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "nested", w.Body.String())
|
||||
})
|
||||
|
||||
t.Run("directory listing returns 404", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/assets/", nil)
|
||||
v.mux.ServeHTTP(w, r)
|
||||
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
})
|
||||
|
||||
t.Run("subdirectory listing returns 404", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/assets/sub/", nil)
|
||||
v.mux.ServeHTTP(w, r)
|
||||
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
})
|
||||
|
||||
t.Run("missing file returns 404", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/assets/nope.txt", nil)
|
||||
v.mux.ServeHTTP(w, r)
|
||||
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStaticAutoSlash(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
os.WriteFile(filepath.Join(dir, "ok.txt"), []byte("ok"), 0644)
|
||||
|
||||
v := New()
|
||||
v.Static("/files", dir) // no trailing slash
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/files/ok.txt", nil)
|
||||
v.mux.ServeHTTP(w, r)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "ok", w.Body.String())
|
||||
}
|
||||
|
||||
func TestStaticFS(t *testing.T) {
|
||||
fsys := fstest.MapFS{
|
||||
"style.css": {Data: []byte("body{}")},
|
||||
"js/app.js": {Data: []byte("console.log('hi')")},
|
||||
}
|
||||
|
||||
v := New()
|
||||
v.StaticFS("/static/", fsys)
|
||||
|
||||
t.Run("serves file", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/static/style.css", nil)
|
||||
v.mux.ServeHTTP(w, r)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "body{}", w.Body.String())
|
||||
})
|
||||
|
||||
t.Run("serves nested file", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/static/js/app.js", nil)
|
||||
v.mux.ServeHTTP(w, r)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "console.log('hi')", w.Body.String())
|
||||
})
|
||||
|
||||
t.Run("directory listing returns 404", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/static/", nil)
|
||||
v.mux.ServeHTTP(w, r)
|
||||
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
})
|
||||
|
||||
t.Run("missing file returns 404", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/static/nope.css", nil)
|
||||
v.mux.ServeHTTP(w, r)
|
||||
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStaticFSAutoSlash(t *testing.T) {
|
||||
fsys := fstest.MapFS{
|
||||
"ok.txt": {Data: []byte("ok")},
|
||||
}
|
||||
|
||||
v := New()
|
||||
v.StaticFS("/embed", fsys) // no trailing slash
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/embed/ok.txt", nil)
|
||||
v.mux.ServeHTTP(w, r)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "ok", w.Body.String())
|
||||
}
|
||||
|
||||
// Verify StaticFS accepts the fs.FS interface (compile-time check).
|
||||
var _ fs.FS = fstest.MapFS{}
|
||||
308
via.go
308
via.go
@@ -9,11 +9,13 @@ package via
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
_ "embed"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -33,23 +35,32 @@ import (
|
||||
//go:embed datastar.js
|
||||
var datastarJS []byte
|
||||
|
||||
//go:embed navigate.js
|
||||
var navigateJS []byte
|
||||
|
||||
// V is the root application.
|
||||
// It manages page routing, user sessions, and SSE connections for live updates.
|
||||
type V struct {
|
||||
cfg Options
|
||||
mux *http.ServeMux
|
||||
server *http.Server
|
||||
logger zerolog.Logger
|
||||
contextRegistry map[string]*Context
|
||||
contextRegistryMutex sync.RWMutex
|
||||
documentHeadIncludes []h.H
|
||||
documentFootIncludes []h.H
|
||||
devModePageInitFnMap map[string]func(*Context)
|
||||
sessionManager *scs.SessionManager
|
||||
pubsub PubSub
|
||||
datastarPath string
|
||||
datastarContent []byte
|
||||
datastarOnce sync.Once
|
||||
cfg Options
|
||||
mux *http.ServeMux
|
||||
server *http.Server
|
||||
logger zerolog.Logger
|
||||
contextRegistry map[string]*Context
|
||||
contextRegistryMutex sync.RWMutex
|
||||
documentHeadIncludes []h.H
|
||||
documentFootIncludes []h.H
|
||||
devModePageInitFnMap map[string]func(*Context)
|
||||
pageRegistry map[string]func(*Context)
|
||||
sessionManager *scs.SessionManager
|
||||
pubsub PubSub
|
||||
defaultNATS *defaultNATS
|
||||
actionRateLimit RateLimitConfig
|
||||
datastarPath string
|
||||
datastarContent []byte
|
||||
datastarOnce sync.Once
|
||||
reaperStop chan struct{}
|
||||
middleware []Middleware
|
||||
layout func(func() h.H) h.H
|
||||
}
|
||||
|
||||
func (v *V) logEvent(evt *zerolog.Event, c *Context) *zerolog.Event {
|
||||
@@ -125,8 +136,15 @@ func (v *V) Config(cfg Options) {
|
||||
v.datastarPath = cfg.DatastarPath
|
||||
}
|
||||
if cfg.PubSub != nil {
|
||||
v.defaultNATS = nil
|
||||
v.pubsub = cfg.PubSub
|
||||
}
|
||||
if cfg.ContextTTL != 0 {
|
||||
v.cfg.ContextTTL = cfg.ContextTTL
|
||||
}
|
||||
if cfg.ActionRateLimit.Rate != 0 || cfg.ActionRateLimit.Burst != 0 {
|
||||
v.actionRateLimit = cfg.ActionRateLimit
|
||||
}
|
||||
}
|
||||
|
||||
// AppendToHead appends the given h.H nodes to the head of the base HTML document.
|
||||
@@ -160,8 +178,16 @@ func (v *V) AppendToFoot(elements ...h.H) {
|
||||
// })
|
||||
// })
|
||||
func (v *V) Page(route string, initContextFn func(c *Context)) {
|
||||
wrapped := chainMiddleware(v.middleware, initContextFn)
|
||||
v.page(route, initContextFn, wrapped)
|
||||
}
|
||||
|
||||
// page registers a route with separate raw and wrapped init functions.
|
||||
// raw is used for the panic-check at registration time; wrapped includes
|
||||
// any middleware and is used as the live handler.
|
||||
func (v *V) page(route string, raw, wrapped func(*Context)) {
|
||||
v.ensureDatastarHandler()
|
||||
// check for panics
|
||||
// check for panics using the raw handler (no middleware)
|
||||
func() {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
@@ -170,14 +196,14 @@ func (v *V) Page(route string, initContextFn func(c *Context)) {
|
||||
}
|
||||
}()
|
||||
c := newContext("", "", v)
|
||||
initContextFn(c)
|
||||
raw(c)
|
||||
c.view()
|
||||
c.stopAllRoutines()
|
||||
}()
|
||||
|
||||
// save page init function allows devmode to restore persisted ctx later
|
||||
v.pageRegistry[route] = wrapped
|
||||
if v.cfg.DevMode {
|
||||
v.devModePageInitFnMap[route] = initContextFn
|
||||
v.devModePageInitFnMap[route] = wrapped
|
||||
}
|
||||
v.mux.HandleFunc("GET "+route, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
v.logDebug(nil, "GET %s", r.URL.String())
|
||||
@@ -191,7 +217,7 @@ func (v *V) Page(route string, initContextFn func(c *Context)) {
|
||||
c.reqCtx = r.Context()
|
||||
routeParams := extractParams(route, r.URL.Path)
|
||||
c.injectRouteParams(routeParams)
|
||||
initContextFn(c)
|
||||
wrapped(c)
|
||||
v.registerCtx(c)
|
||||
if v.cfg.DevMode {
|
||||
v.devModePersist(c)
|
||||
@@ -199,10 +225,12 @@ func (v *V) Page(route string, initContextFn func(c *Context)) {
|
||||
headElements := []h.H{h.Script(h.Type("module"), h.Src(v.datastarPath))}
|
||||
headElements = append(headElements, v.documentHeadIncludes...)
|
||||
headElements = append(headElements,
|
||||
h.Meta(h.Data("signals", fmt.Sprintf("{'via-ctx':'%s'}", id))),
|
||||
h.Meta(h.Data("signals", fmt.Sprintf("{'via-ctx':'%s','via-csrf':'%s'}", id, c.csrfToken))),
|
||||
h.Meta(h.Data("init", "@get('/_sse')")),
|
||||
h.Meta(h.Data("init", fmt.Sprintf(`window.addEventListener('beforeunload', (evt) => {
|
||||
navigator.sendBeacon('/_session/close', '%s');});`, c.id))),
|
||||
h.Meta(h.Attr("name", "view-transition"), h.Attr("content", "same-origin")),
|
||||
h.Script(h.Raw(string(navigateJS))),
|
||||
)
|
||||
|
||||
bodyElements := []h.H{c.view()}
|
||||
@@ -216,8 +244,7 @@ func (v *V) Page(route string, initContextFn func(c *Context)) {
|
||||
Title: v.cfg.DocumentTitle,
|
||||
Head: headElements,
|
||||
Body: bodyElements,
|
||||
HTMLAttrs: []h.H{},
|
||||
})
|
||||
})
|
||||
_ = view.Render(w)
|
||||
}))
|
||||
}
|
||||
@@ -225,17 +252,17 @@ func (v *V) Page(route string, initContextFn func(c *Context)) {
|
||||
func (v *V) registerCtx(c *Context) {
|
||||
v.contextRegistryMutex.Lock()
|
||||
defer v.contextRegistryMutex.Unlock()
|
||||
if c == nil {
|
||||
v.logErr(c, "failed to add nil context to registry")
|
||||
return
|
||||
}
|
||||
v.contextRegistry[c.id] = c
|
||||
v.logDebug(c, "new context added to registry")
|
||||
v.logDebug(nil, "number of sessions in registry: %d", v.currSessionNum())
|
||||
v.logDebug(nil, "number of sessions in registry: %d", len(v.contextRegistry))
|
||||
}
|
||||
|
||||
func (v *V) currSessionNum() int {
|
||||
return len(v.contextRegistry)
|
||||
func (v *V) cleanupCtx(c *Context) {
|
||||
c.dispose()
|
||||
if v.cfg.DevMode {
|
||||
v.devModeRemovePersisted(c)
|
||||
}
|
||||
v.unregisterCtx(c)
|
||||
}
|
||||
|
||||
func (v *V) unregisterCtx(c *Context) {
|
||||
@@ -247,7 +274,7 @@ func (v *V) unregisterCtx(c *Context) {
|
||||
defer v.contextRegistryMutex.Unlock()
|
||||
v.logDebug(c, "ctx removed from registry")
|
||||
delete(v.contextRegistry, c.id)
|
||||
v.logDebug(nil, "number of sessions in registry: %d", v.currSessionNum())
|
||||
v.logDebug(nil, "number of sessions in registry: %d", len(v.contextRegistry))
|
||||
}
|
||||
|
||||
func (v *V) getCtx(id string) (*Context, error) {
|
||||
@@ -259,6 +286,50 @@ func (v *V) getCtx(id string) (*Context, error) {
|
||||
return nil, fmt.Errorf("ctx '%s' not found", id)
|
||||
}
|
||||
|
||||
func (v *V) startReaper() {
|
||||
ttl := v.cfg.ContextTTL
|
||||
if ttl < 0 {
|
||||
return
|
||||
}
|
||||
if ttl == 0 {
|
||||
ttl = 30 * time.Second
|
||||
}
|
||||
interval := ttl / 3
|
||||
if interval < 5*time.Second {
|
||||
interval = 5 * time.Second
|
||||
}
|
||||
v.reaperStop = make(chan struct{})
|
||||
go func() {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-v.reaperStop:
|
||||
return
|
||||
case <-ticker.C:
|
||||
v.reapOrphanedContexts(ttl)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (v *V) reapOrphanedContexts(ttl time.Duration) {
|
||||
now := time.Now()
|
||||
v.contextRegistryMutex.RLock()
|
||||
var orphans []*Context
|
||||
for _, c := range v.contextRegistry {
|
||||
if !c.sseConnected.Load() && now.Sub(c.createdAt) > ttl {
|
||||
orphans = append(orphans, c)
|
||||
}
|
||||
}
|
||||
v.contextRegistryMutex.RUnlock()
|
||||
|
||||
for _, c := range orphans {
|
||||
v.logInfo(c, "reaping orphaned context (no SSE connection after %s)", ttl)
|
||||
v.cleanupCtx(c)
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the Via HTTP server and blocks until a SIGINT or SIGTERM
|
||||
// signal is received, then performs a graceful shutdown.
|
||||
func (v *V) Start() {
|
||||
@@ -271,6 +342,8 @@ func (v *V) Start() {
|
||||
Handler: handler,
|
||||
}
|
||||
|
||||
v.startReaper()
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- v.server.ListenAndServe()
|
||||
@@ -291,16 +364,15 @@ func (v *V) Start() {
|
||||
return
|
||||
}
|
||||
|
||||
v.shutdown()
|
||||
v.Shutdown()
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the server and all contexts.
|
||||
// Safe for programmatic or test use.
|
||||
func (v *V) Shutdown() {
|
||||
v.shutdown()
|
||||
}
|
||||
|
||||
func (v *V) shutdown() {
|
||||
if v.reaperStop != nil {
|
||||
close(v.reaperStop)
|
||||
}
|
||||
v.logInfo(nil, "draining all contexts")
|
||||
v.drainAllContexts()
|
||||
|
||||
@@ -317,6 +389,7 @@ func (v *V) shutdown() {
|
||||
v.logErr(nil, "pubsub close error: %v", err)
|
||||
}
|
||||
}
|
||||
v.defaultNATS = nil
|
||||
|
||||
v.logInfo(nil, "shutdown complete")
|
||||
}
|
||||
@@ -346,6 +419,51 @@ func (v *V) HTTPServeMux() *http.ServeMux {
|
||||
return v.mux
|
||||
}
|
||||
|
||||
// PubSub returns the configured PubSub backend, or nil if none is set.
|
||||
func (v *V) PubSub() PubSub {
|
||||
return v.pubsub
|
||||
}
|
||||
|
||||
// Static serves files from a filesystem directory at the given URL prefix.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// v.Static("/assets/", "./public")
|
||||
func (v *V) Static(urlPrefix, dir string) {
|
||||
if !strings.HasSuffix(urlPrefix, "/") {
|
||||
urlPrefix += "/"
|
||||
}
|
||||
fileServer := http.StripPrefix(urlPrefix, http.FileServer(http.Dir(dir)))
|
||||
v.mux.Handle("GET "+urlPrefix, noDirListing(fileServer))
|
||||
}
|
||||
|
||||
// StaticFS serves files from an [fs.FS] at the given URL prefix.
|
||||
// This is useful with //go:embed filesystems.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// //go:embed static
|
||||
// var staticFiles embed.FS
|
||||
// v.StaticFS("/assets/", staticFiles)
|
||||
func (v *V) StaticFS(urlPrefix string, fsys fs.FS) {
|
||||
if !strings.HasSuffix(urlPrefix, "/") {
|
||||
urlPrefix += "/"
|
||||
}
|
||||
fileServer := http.StripPrefix(urlPrefix, http.FileServerFS(fsys))
|
||||
v.mux.Handle("GET "+urlPrefix, noDirListing(fileServer))
|
||||
}
|
||||
|
||||
// noDirListing wraps a file server handler to return 404 for directory requests.
|
||||
func noDirListing(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.HasSuffix(r.URL.Path, "/") {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func (v *V) ensureDatastarHandler() {
|
||||
v.datastarOnce.Do(func() {
|
||||
v.mux.HandleFunc("GET "+v.datastarPath, func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -400,10 +518,7 @@ func (v *V) devModeRemovePersisted(c *Context) {
|
||||
}
|
||||
file.Close()
|
||||
|
||||
// remove ctx to persisted list
|
||||
if _, ok := ctxRegMap[c.id]; !ok {
|
||||
delete(ctxRegMap, c.id)
|
||||
}
|
||||
delete(ctxRegMap, c.id)
|
||||
|
||||
// write persisted list to file
|
||||
file, err = os.Create(p)
|
||||
@@ -455,6 +570,7 @@ type patchType int
|
||||
|
||||
const (
|
||||
patchTypeElements = iota
|
||||
patchTypeElementsWithVT
|
||||
patchTypeSignals
|
||||
patchTypeScript
|
||||
patchTypeRedirect
|
||||
@@ -475,6 +591,7 @@ func New() *V {
|
||||
logger: newConsoleLogger(zerolog.InfoLevel),
|
||||
contextRegistry: make(map[string]*Context),
|
||||
devModePageInitFnMap: make(map[string]func(*Context)),
|
||||
pageRegistry: make(map[string]func(*Context)),
|
||||
sessionManager: scs.New(),
|
||||
datastarPath: "/_datastar.js",
|
||||
datastarContent: datastarJS,
|
||||
@@ -507,16 +624,16 @@ func New() *V {
|
||||
// use last-event-id to tell if request is a sse reconnect
|
||||
sse.Send(datastar.EventTypePatchElements, []string{}, datastar.WithSSEEventId("via"))
|
||||
|
||||
c.sseConnected.Store(true)
|
||||
v.logDebug(c, "SSE connection established")
|
||||
|
||||
go func() {
|
||||
c.Sync()
|
||||
}()
|
||||
go c.Sync()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-sse.Context().Done():
|
||||
v.logDebug(c, "SSE connection ended")
|
||||
v.cleanupCtx(c)
|
||||
return
|
||||
case <-c.ctxDisposedChan:
|
||||
v.logDebug(c, "context disposed, closing SSE")
|
||||
@@ -525,11 +642,16 @@ func New() *V {
|
||||
switch patch.typ {
|
||||
case patchTypeElements:
|
||||
if err := sse.PatchElements(patch.content); err != nil {
|
||||
// Only log if connection wasn't closed (avoids noise during shutdown/tests)
|
||||
if sse.Context().Err() == nil {
|
||||
v.logErr(c, "PatchElements failed: %v", err)
|
||||
}
|
||||
}
|
||||
case patchTypeElementsWithVT:
|
||||
if err := sse.PatchElements(patch.content, datastar.WithViewTransitions()); err != nil {
|
||||
if sse.Context().Err() == nil {
|
||||
v.logErr(c, "PatchElements (view transition) failed: %v", err)
|
||||
}
|
||||
}
|
||||
case patchTypeSignals:
|
||||
if err := sse.PatchSignals([]byte(patch.content)); err != nil {
|
||||
if sse.Context().Err() == nil {
|
||||
@@ -572,13 +694,29 @@ func New() *V {
|
||||
v.logErr(nil, "action '%s' failed: %v", actionID, err)
|
||||
return
|
||||
}
|
||||
csrfToken, _ := sigs["via-csrf"].(string)
|
||||
if subtle.ConstantTimeCompare([]byte(csrfToken), []byte(c.csrfToken)) != 1 {
|
||||
v.logWarn(c, "action '%s' rejected: invalid CSRF token", actionID)
|
||||
http.Error(w, "invalid CSRF token", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
if c.actionLimiter != nil && !c.actionLimiter.Allow() {
|
||||
v.logWarn(c, "action '%s' rate limited", actionID)
|
||||
http.Error(w, "rate limited", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
c.reqCtx = r.Context()
|
||||
actionFn, err := c.getActionFn(actionID)
|
||||
entry, err := c.getAction(actionID)
|
||||
if err != nil {
|
||||
v.logDebug(c, "action '%s' failed: %v", actionID, err)
|
||||
return
|
||||
}
|
||||
// log err if actionFn panics
|
||||
if entry.limiter != nil && !entry.limiter.Allow() {
|
||||
v.logWarn(c, "action '%s' rate limited (per-action)", actionID)
|
||||
http.Error(w, "rate limited", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
// log err if action panics
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
v.logErr(c, "action '%s' failed: %v", actionID, r)
|
||||
@@ -586,7 +724,44 @@ func New() *V {
|
||||
}()
|
||||
|
||||
c.injectSignals(sigs)
|
||||
actionFn()
|
||||
if len(entry.middleware) > 0 {
|
||||
chainMiddleware(entry.middleware, func(_ *Context) { entry.fn() })(c)
|
||||
} else {
|
||||
entry.fn()
|
||||
}
|
||||
})
|
||||
|
||||
v.mux.HandleFunc("POST /_navigate", func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = r.ParseForm()
|
||||
cID := r.FormValue("via-ctx")
|
||||
csrfToken := r.FormValue("via-csrf")
|
||||
navURL := r.FormValue("url")
|
||||
popstate := r.FormValue("popstate") == "1"
|
||||
|
||||
if cID == "" || navURL == "" || !strings.HasPrefix(navURL, "/") {
|
||||
http.Error(w, "missing or invalid parameters", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
c, err := v.getCtx(cID)
|
||||
if err != nil {
|
||||
v.logErr(nil, "navigate failed: %v", err)
|
||||
http.Error(w, "context not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
if subtle.ConstantTimeCompare([]byte(csrfToken), []byte(c.csrfToken)) != 1 {
|
||||
v.logWarn(c, "navigate rejected: invalid CSRF token")
|
||||
http.Error(w, "invalid CSRF token", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
if c.actionLimiter != nil && !c.actionLimiter.Allow() {
|
||||
v.logWarn(c, "navigate rate limited")
|
||||
http.Error(w, "rate limited", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
c.reqCtx = r.Context()
|
||||
v.logDebug(c, "SPA navigate to %s", navURL)
|
||||
c.Navigate(navURL, popstate)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
v.mux.HandleFunc("POST /_session/close", func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -603,20 +778,31 @@ func New() *V {
|
||||
v.logErr(c, "failed to handle session close: %v", err)
|
||||
return
|
||||
}
|
||||
c.dispose()
|
||||
v.logDebug(c, "session close event triggered")
|
||||
if v.cfg.DevMode {
|
||||
v.devModeRemovePersisted(c)
|
||||
}
|
||||
v.unregisterCtx(c)
|
||||
v.cleanupCtx(c)
|
||||
})
|
||||
|
||||
dn, err := getSharedNATS()
|
||||
if err != nil {
|
||||
v.logWarn(nil, "embedded NATS unavailable: %v", err)
|
||||
} else {
|
||||
v.defaultNATS = dn
|
||||
v.pubsub = &natsRef{dn: dn}
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
func genRandID() string {
|
||||
b := make([]byte, 4)
|
||||
rand.Read(b)
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
func genCSRFToken() string {
|
||||
b := make([]byte, 16)
|
||||
rand.Read(b)
|
||||
return hex.EncodeToString(b)[:8]
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
func extractParams(pattern, path string) map[string]string {
|
||||
@@ -631,8 +817,24 @@ func extractParams(pattern, path string) map[string]string {
|
||||
key := p[i][1 : len(p[i])-1] // remove {}
|
||||
params[key] = u[i]
|
||||
} else if p[i] != u[i] {
|
||||
continue
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return params
|
||||
}
|
||||
|
||||
// matchRoute finds the registered page init function and extracted params for the given path.
|
||||
func (v *V) matchRoute(path string) (route string, initFn func(*Context), params map[string]string) {
|
||||
for pattern, fn := range v.pageRegistry {
|
||||
if p := extractParams(pattern, path); p != nil {
|
||||
return pattern, fn, p
|
||||
}
|
||||
}
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
// Layout sets a layout function that wraps every page's view.
|
||||
// The layout receives the page content as a function and returns the full view.
|
||||
func (v *V) Layout(f func(func() h.H) h.H) {
|
||||
v.layout = f
|
||||
}
|
||||
|
||||
148
via_test.go
148
via_test.go
@@ -1,9 +1,13 @@
|
||||
package via
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ryanhamamura/via/h"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -128,6 +132,60 @@ func TestAction(t *testing.T) {
|
||||
assert.Contains(t, body, "/_action/")
|
||||
}
|
||||
|
||||
func TestEventTypes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
attr string
|
||||
buildEl func(trigger *actionTrigger) h.H
|
||||
}{
|
||||
{"OnSubmit", "data-on:submit", func(tr *actionTrigger) h.H { return h.Form(tr.OnSubmit()) }},
|
||||
{"OnInput", "data-on:input", func(tr *actionTrigger) h.H { return h.Input(tr.OnInput()) }},
|
||||
{"OnFocus", "data-on:focus", func(tr *actionTrigger) h.H { return h.Input(tr.OnFocus()) }},
|
||||
{"OnBlur", "data-on:blur", func(tr *actionTrigger) h.H { return h.Input(tr.OnBlur()) }},
|
||||
{"OnMouseEnter", "data-on:mouseenter", func(tr *actionTrigger) h.H { return h.Div(tr.OnMouseEnter()) }},
|
||||
{"OnMouseLeave", "data-on:mouseleave", func(tr *actionTrigger) h.H { return h.Div(tr.OnMouseLeave()) }},
|
||||
{"OnScroll", "data-on:scroll", func(tr *actionTrigger) h.H { return h.Div(tr.OnScroll()) }},
|
||||
{"OnDblClick", "data-on:dblclick", func(tr *actionTrigger) h.H { return h.Div(tr.OnDblClick()) }},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var trigger *actionTrigger
|
||||
v := New()
|
||||
v.Page("/", func(c *Context) {
|
||||
trigger = c.Action(func() {})
|
||||
c.View(func() h.H { return tt.buildEl(trigger) })
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
v.mux.ServeHTTP(w, req)
|
||||
body := w.Body.String()
|
||||
assert.Contains(t, body, tt.attr)
|
||||
assert.Contains(t, body, "/_action/"+trigger.id)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("WithSignal", func(t *testing.T) {
|
||||
var trigger *actionTrigger
|
||||
var sig *signal
|
||||
v := New()
|
||||
v.Page("/", func(c *Context) {
|
||||
trigger = c.Action(func() {})
|
||||
sig = c.Signal("val")
|
||||
c.View(func() h.H {
|
||||
return h.Div(trigger.OnDblClick(WithSignal(sig, "x")))
|
||||
})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
v.mux.ServeHTTP(w, req)
|
||||
body := w.Body.String()
|
||||
assert.Contains(t, body, "data-on:dblclick")
|
||||
assert.Contains(t, body, "$"+sig.ID()+"='x'")
|
||||
})
|
||||
}
|
||||
|
||||
func TestOnKeyDownWithWindow(t *testing.T) {
|
||||
var trigger *actionTrigger
|
||||
v := New()
|
||||
@@ -235,3 +293,93 @@ func TestPage_PanicsOnNoView(t *testing.T) {
|
||||
v.Page("/", func(c *Context) {})
|
||||
})
|
||||
}
|
||||
|
||||
func TestReaperCleansOrphanedContexts(t *testing.T) {
|
||||
v := New()
|
||||
c := newContext("orphan-1", "/", v)
|
||||
c.createdAt = time.Now().Add(-time.Minute) // created 1 min ago
|
||||
v.registerCtx(c)
|
||||
|
||||
_, err := v.getCtx("orphan-1")
|
||||
assert.NoError(t, err)
|
||||
|
||||
v.reapOrphanedContexts(10 * time.Second)
|
||||
|
||||
_, err = v.getCtx("orphan-1")
|
||||
assert.Error(t, err, "orphaned context should have been reaped")
|
||||
}
|
||||
|
||||
func TestReaperIgnoresConnectedContexts(t *testing.T) {
|
||||
v := New()
|
||||
c := newContext("connected-1", "/", v)
|
||||
c.createdAt = time.Now().Add(-time.Minute)
|
||||
c.sseConnected.Store(true)
|
||||
v.registerCtx(c)
|
||||
|
||||
v.reapOrphanedContexts(10 * time.Second)
|
||||
|
||||
_, err := v.getCtx("connected-1")
|
||||
assert.NoError(t, err, "connected context should survive reaping")
|
||||
}
|
||||
|
||||
func TestReaperDisabledWithNegativeTTL(t *testing.T) {
|
||||
v := New()
|
||||
v.cfg.ContextTTL = -1
|
||||
v.startReaper()
|
||||
assert.Nil(t, v.reaperStop, "reaper should not start with negative TTL")
|
||||
}
|
||||
|
||||
func TestCleanupCtxIdempotent(t *testing.T) {
|
||||
v := New()
|
||||
c := newContext("idempotent-1", "/", v)
|
||||
v.registerCtx(c)
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
v.cleanupCtx(c)
|
||||
v.cleanupCtx(c)
|
||||
})
|
||||
|
||||
_, err := v.getCtx("idempotent-1")
|
||||
assert.Error(t, err, "context should be removed after cleanup")
|
||||
}
|
||||
|
||||
func TestDevModeRemovePersistedFix(t *testing.T) {
|
||||
v := New()
|
||||
v.cfg.DevMode = true
|
||||
|
||||
dir := filepath.Join(t.TempDir(), ".via", "devmode")
|
||||
p := filepath.Join(dir, "ctx.json")
|
||||
assert.NoError(t, os.MkdirAll(dir, 0755))
|
||||
|
||||
// Write a persisted context
|
||||
ctxRegMap := map[string]string{"test-ctx-1": "/"}
|
||||
f, err := os.Create(p)
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, json.NewEncoder(f).Encode(ctxRegMap))
|
||||
f.Close()
|
||||
|
||||
// Patch devModeRemovePersisted to use our temp path by calling it
|
||||
// directly — we need to override the path. Instead, test via the
|
||||
// actual function by temporarily changing the working dir.
|
||||
origDir, _ := os.Getwd()
|
||||
assert.NoError(t, os.Chdir(t.TempDir()))
|
||||
defer os.Chdir(origDir)
|
||||
|
||||
// Re-create the structure in the temp dir
|
||||
assert.NoError(t, os.MkdirAll(filepath.Join(".via", "devmode"), 0755))
|
||||
p2 := filepath.Join(".via", "devmode", "ctx.json")
|
||||
f2, _ := os.Create(p2)
|
||||
json.NewEncoder(f2).Encode(map[string]string{"test-ctx-1": "/"})
|
||||
f2.Close()
|
||||
|
||||
c := newContext("test-ctx-1", "/", v)
|
||||
v.devModeRemovePersisted(c)
|
||||
|
||||
// Read back and verify
|
||||
f3, err := os.Open(p2)
|
||||
assert.NoError(t, err)
|
||||
defer f3.Close()
|
||||
var result map[string]string
|
||||
assert.NoError(t, json.NewDecoder(f3).Decode(&result))
|
||||
assert.Empty(t, result, "persisted context should be removed")
|
||||
}
|
||||
|
||||
@@ -1,127 +0,0 @@
|
||||
// Package vianats provides an embedded NATS server with JetStream as a
|
||||
// pub/sub backend for Via applications.
|
||||
package vianats
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/delaneyj/toolbelt/embeddednats"
|
||||
"github.com/nats-io/nats.go"
|
||||
"github.com/ryanhamamura/via"
|
||||
)
|
||||
|
||||
// NATS implements via.PubSub using an embedded NATS server with JetStream.
|
||||
type NATS struct {
|
||||
server *embeddednats.Server
|
||||
nc *nats.Conn
|
||||
js nats.JetStreamContext
|
||||
}
|
||||
|
||||
// New starts an embedded NATS server with JetStream enabled and returns a
|
||||
// ready-to-use NATS instance. The server stores data in dataDir and shuts
|
||||
// down when ctx is cancelled.
|
||||
func New(ctx context.Context, dataDir string) (*NATS, error) {
|
||||
ns, err := embeddednats.New(ctx, embeddednats.WithDirectory(dataDir))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("vianats: start server: %w", err)
|
||||
}
|
||||
ns.WaitForServer()
|
||||
|
||||
nc, err := ns.Client()
|
||||
if err != nil {
|
||||
ns.Close()
|
||||
return nil, fmt.Errorf("vianats: connect client: %w", err)
|
||||
}
|
||||
|
||||
js, err := nc.JetStream()
|
||||
if err != nil {
|
||||
nc.Close()
|
||||
ns.Close()
|
||||
return nil, fmt.Errorf("vianats: init jetstream: %w", err)
|
||||
}
|
||||
|
||||
return &NATS{server: ns, nc: nc, js: js}, nil
|
||||
}
|
||||
|
||||
// Publish sends data to the given subject using core NATS publish.
|
||||
// JetStream captures messages automatically if a matching stream exists.
|
||||
func (n *NATS) Publish(subject string, data []byte) error {
|
||||
return n.nc.Publish(subject, data)
|
||||
}
|
||||
|
||||
// Subscribe creates a core NATS subscription for real-time fan-out delivery.
|
||||
func (n *NATS) Subscribe(subject string, handler func(data []byte)) (via.Subscription, error) {
|
||||
sub, err := n.nc.Subscribe(subject, func(msg *nats.Msg) {
|
||||
handler(msg.Data)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sub, nil
|
||||
}
|
||||
|
||||
// Close shuts down the client connection and embedded server.
|
||||
func (n *NATS) Close() error {
|
||||
n.nc.Close()
|
||||
return n.server.Close()
|
||||
}
|
||||
|
||||
// Conn returns the underlying NATS connection for advanced usage.
|
||||
func (n *NATS) Conn() *nats.Conn {
|
||||
return n.nc
|
||||
}
|
||||
|
||||
// JetStream returns the JetStream context for stream configuration and replay.
|
||||
func (n *NATS) JetStream() nats.JetStreamContext {
|
||||
return n.js
|
||||
}
|
||||
|
||||
// StreamConfig holds the parameters for creating or updating a JetStream stream.
|
||||
type StreamConfig struct {
|
||||
Name string
|
||||
Subjects []string
|
||||
MaxMsgs int64
|
||||
MaxAge time.Duration
|
||||
}
|
||||
|
||||
// EnsureStream creates or updates a JetStream stream matching cfg.
|
||||
func EnsureStream(n *NATS, cfg StreamConfig) error {
|
||||
_, err := n.js.AddStream(&nats.StreamConfig{
|
||||
Name: cfg.Name,
|
||||
Subjects: cfg.Subjects,
|
||||
Retention: nats.LimitsPolicy,
|
||||
MaxMsgs: cfg.MaxMsgs,
|
||||
MaxAge: cfg.MaxAge,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// ReplayHistory fetches the last limit messages from subject,
|
||||
// deserializing each as T. Returns an empty slice if nothing is available.
|
||||
func ReplayHistory[T any](n *NATS, subject string, limit int) ([]T, error) {
|
||||
sub, err := n.js.SubscribeSync(subject, nats.DeliverAll(), nats.OrderedConsumer())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer sub.Unsubscribe()
|
||||
|
||||
var msgs []T
|
||||
for {
|
||||
raw, err := sub.NextMsg(200 * time.Millisecond)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
var msg T
|
||||
if json.Unmarshal(raw.Data, &msg) == nil {
|
||||
msgs = append(msgs, msg)
|
||||
}
|
||||
}
|
||||
|
||||
if limit > 0 && len(msgs) > limit {
|
||||
msgs = msgs[len(msgs)-limit:]
|
||||
}
|
||||
return msgs, nil
|
||||
}
|
||||
Reference in New Issue
Block a user