9 Commits

Author SHA1 Message Date
Ryan Hamamura
785f11e52d fix: harden SPA navigation with race protection and correctness fixes
- Add navMu to serialize concurrent navigations on the same context
- Replace url.PathEscape with targeted JS string escaper (PathEscape
  mangles full paths and doesn't escape single quotes)
- Collapse syncWithViewTransition into syncView(bool) to remove duplication
- Simplify popstate ready guard in navigate.js
- Preserve URL hash during SPA navigation
2026-02-12 14:41:50 -10:00
Ryan Hamamura
2f19874c17 feat: add PubSub() accessor to V struct 2026-02-12 14:32:05 -10:00
Ryan Hamamura
27b8540b71 feat: add SPA navigation with view transitions
Swap page content over the existing SSE connection without full page
loads. A persistent Context resets its page-specific state (signals,
actions, intervals, subscriptions) on navigate while preserving the
SSE stream, CSRF token, and session.

- c.Navigate(path) for programmatic SPA navigation from actions
- Injected JS intercepts same-origin <a> clicks (opt out with
  data-via-no-boost) and handles popstate for back/forward
- v.Layout() wraps pages in a shared shell for DRY nav/chrome
- View Transition API integration via WithViewTransitions() on
  PatchElements and h.DataViewTransition() helper
- POST /_navigate endpoint with CSRF validation and rate limiting
- pageStopChan cancels page-level OnInterval goroutines on navigate
- Includes SPA example with layout, counter, and live clock pages
2026-02-12 13:52:47 -10:00
Ryan Hamamura
532651552a refactor: simplify OnInterval API to auto-start and return stop func
Replace the exported OnIntervalRoutine struct (Start/Stop/UpdateInterval)
with a single function that auto-starts the goroutine and returns an
idempotent stop closure. Uses close(channel) instead of send-on-channel,
fixing a potential deadlock when the goroutine exits via context disposal.

Closes #5 item 4.
2026-02-12 12:27:50 -10:00
Ryan Hamamura
2310e45d35 feat: auto-start embedded NATS server in New()
Pub/sub now works out of the box — New() starts a process-scoped
embedded NATS server with JetStream. The PubSub interface remains
for custom backends via Config(Options{PubSub: ...}).

- Move vianats functionality into nats.go (eliminates circular import)
- Add NATSConn(), JetStream(), EnsureStream(), ReplayHistory[T]() to via
- Delete vianats/ package
- Simplify nats-chatroom and pubsub-crud examples
- Rewrite pubsub tests to use real embedded NATS
2026-02-12 08:54:44 -10:00
Ryan Hamamura
10b4838f8d feat: auto-track fields on context for zero-arg ValidateAll/ResetFields
Fields created via Context.Field are now tracked on the page context,
so ValidateAll() and ResetFields() with no arguments operate on all
fields by default. Explicit field args still work for selective use.

Also switches MinLen/MaxLen to utf8.RuneCountInString for correct
unicode handling and replaces fmt.Errorf with errors.New where
format strings are unnecessary.
2026-02-11 19:57:13 -10:00
Ryan Hamamura
5362614c3e feat: add field validation API with signup form example
Introduces Field, Rule, ValidateAll, ResetFields, and AddError for
declarative input validation. Includes built-in rules (Required,
MinLen, MaxLen, Min, Max, Email, Pattern, Custom) and a signup
example exercising the full API surface.
2026-02-11 14:42:44 -10:00
ryanhamamura
e636970f7b feat: add middleware, route groups, and codebase cleanup
* feat: add middleware example demonstrating route groups

Self-contained example covering v.Use(), v.Group(), nested groups,
Group.Use(), and middleware chaining with role-based access control.

* feat: add per-action middleware via WithMiddleware ActionOption

Reuses the existing Middleware type so the same auth/logging functions
work at both page and action level. Middleware runs after CSRF and
rate-limit checks, with full access to session and signals.

* feat: add RedirectView helper and refactor session example to use middleware

RedirectView lets middleware abort and redirect in one step. The session
example now uses an authRequired middleware on a route group instead of
an inline check inside the view.

* fix: remove dead code, fix double Load and extractParams mismatch

- Remove componentRegistry (written, never read)
- Remove unused signal methods: Bytes, Int64, Float
- Remove unreachable nil check in registerCtx
- Simplify injectRouteParams (extractParams already returns fresh map)
- Fix double sync.Map.Load in injectSignals
- Merge Shutdown/shutdown into single method
- Inline currSessionNum
- Fix extractParams: mismatched literal segment now returns nil
- Minor: new(bytes.Buffer), go c.Sync(), genRandID reads 4 bytes
2026-02-11 13:50:02 -10:00
Ryan Hamamura
f5158b866c feat: add Static and StaticFS helpers for serving static files
One-liner static file serving: v.Static("/assets/", "./public") for
filesystem directories and v.StaticFS("/assets/", fsys) for embed.FS.
Both auto-normalize the URL prefix and disable directory listings.
2026-02-06 13:22:00 -10:00
28 changed files with 2041 additions and 428 deletions

View File

@@ -69,7 +69,7 @@ func main() {
- **CSRF protection** — automatic token generation and validation on every action - **CSRF protection** — automatic token generation and validation on every action
- **Rate limiting** — token-bucket algorithm, configurable globally and per-action - **Rate limiting** — token-bucket algorithm, configurable globally and per-action
- **Event handling** — `OnClick`, `OnChange`, `OnSubmit`, `OnInput`, `OnFocus`, `OnBlur`, `OnMouseEnter`, `OnMouseLeave`, `OnScroll`, `OnDblClick`, `OnKeyDown`, and `OnKeyDownMap` for multi-key bindings - **Event handling** — `OnClick`, `OnChange`, `OnSubmit`, `OnInput`, `OnFocus`, `OnBlur`, `OnMouseEnter`, `OnMouseLeave`, `OnScroll`, `OnDblClick`, `OnKeyDown`, and `OnKeyDownMap` for multi-key bindings
- **Timed routines** — `OnInterval` with start/stop/update controls, tied to context lifecycle - **Timed routines** — `OnInterval` auto-starts a ticker goroutine, returns a stop function, tied to context lifecycle
- **Redirects** — `Redirect`, `ReplaceURL`, and format-string variants - **Redirects** — `Redirect`, `ReplaceURL`, and format-string variants
- **Plugin system** — `func(v *V)` hooks for integrating CSS/JS libraries - **Plugin system** — `func(v *V)` hooks for integrating CSS/JS libraries
- **Structured logging** — zerolog with configurable levels; console output in dev, JSON in production - **Structured logging** — zerolog with configurable levels; console output in dev, JSON in production

View File

@@ -5,8 +5,8 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"maps"
"reflect" "reflect"
"strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -25,15 +25,17 @@ type Context struct {
app *V app *V
view func() h.H view func() h.H
routeParams map[string]string routeParams map[string]string
componentRegistry map[string]*Context parentPageCtx *Context
parentPageCtx *Context
patchChan chan patch patchChan chan patch
actionLimiter *rate.Limiter actionLimiter *rate.Limiter
actionRegistry map[string]actionEntry actionRegistry map[string]actionEntry
signals *sync.Map signals *sync.Map
mu sync.RWMutex mu sync.RWMutex
navMu sync.Mutex
ctxDisposedChan chan struct{} ctxDisposedChan chan struct{}
pageStopChan chan struct{}
reqCtx context.Context reqCtx context.Context
fields []*Field
subscriptions []Subscription subscriptions []Subscription
subsMu sync.Mutex subsMu sync.Mutex
disposeOnce sync.Once disposeOnce sync.Once
@@ -49,7 +51,11 @@ func (c *Context) View(f func() h.H) {
if f == nil { if f == nil {
panic("nil viewfn") panic("nil viewfn")
} }
c.view = func() h.H { return h.Div(h.ID(c.id), f()) } if c.app.layout != nil {
c.view = func() h.H { return h.Div(h.ID(c.id), c.app.layout(f)) }
} else {
c.view = func() h.H { return h.Div(h.ID(c.id), f()) }
}
} }
// Component registers a subcontext that has self contained data, actions and signals. // Component registers a subcontext that has self contained data, actions and signals.
@@ -81,7 +87,6 @@ func (c *Context) Component(initCtx func(c *Context)) func() h.H {
compCtx.parentPageCtx = c compCtx.parentPageCtx = c
} }
initCtx(compCtx) initCtx(compCtx)
c.componentRegistry[id] = compCtx
return compCtx.view return compCtx.view
} }
@@ -133,17 +138,19 @@ func (c *Context) getAction(id string) (actionEntry, error) {
return actionEntry{}, fmt.Errorf("action '%s' not found", id) return actionEntry{}, fmt.Errorf("action '%s' not found", id)
} }
// OnInterval starts a go routine that sets a time.Ticker with the given duration and executes // OnInterval starts a goroutine that executes handler on every tick of the given duration.
// the given handler func() on every tick. Use *Routine.UpdateInterval to update the interval. // The goroutine is tied to the context lifecycle and will stop when the context is disposed.
func (c *Context) OnInterval(duration time.Duration, handler func()) *OnIntervalRoutine { // Returns a func() that stops the interval when called.
var cn chan struct{} func (c *Context) OnInterval(duration time.Duration, handler func()) func() {
if c.isComponent() { // components use the chan on the parent page ctx var disposeCh, pageCh chan struct{}
cn = c.parentPageCtx.ctxDisposedChan if c.isComponent() {
disposeCh = c.parentPageCtx.ctxDisposedChan
pageCh = c.parentPageCtx.pageStopChan
} else { } else {
cn = c.ctxDisposedChan disposeCh = c.ctxDisposedChan
pageCh = c.pageStopChan
} }
r := newOnIntervalRoutine(cn, duration, handler) return newOnInterval(disposeCh, pageCh, duration, handler)
return r
} }
// Signal creates a reactive signal and initializes it with the given value. // Signal creates a reactive signal and initializes it with the given value.
@@ -208,14 +215,14 @@ func (c *Context) injectSignals(sigs map[string]any) {
defer c.mu.Unlock() defer c.mu.Unlock()
for sigID, val := range sigs { for sigID, val := range sigs {
if _, ok := c.signals.Load(sigID); !ok { item, ok := c.signals.Load(sigID)
if !ok {
c.signals.Store(sigID, &signal{ c.signals.Store(sigID, &signal{
id: sigID, id: sigID,
val: val, val: val,
}) })
continue continue
} }
item, _ := c.signals.Load(sigID)
if sig, ok := item.(*signal); ok { if sig, ok := item.(*signal); ok {
sig.val = val sig.val = val
sig.changed = false sig.changed = false
@@ -266,15 +273,22 @@ func (c *Context) sendPatch(p patch) {
// Sync pushes the current view state and signal changes to the browser immediately // Sync pushes the current view state and signal changes to the browser immediately
// over the live SSE event stream. // over the live SSE event stream.
func (c *Context) Sync() { func (c *Context) Sync() {
elemsPatch := bytes.NewBuffer(make([]byte, 0)) c.syncView(false)
}
func (c *Context) syncView(viewTransition bool) {
elemsPatch := new(bytes.Buffer)
if err := c.view().Render(elemsPatch); err != nil { if err := c.view().Render(elemsPatch); err != nil {
c.app.logErr(c, "sync view failed: %v", err) c.app.logErr(c, "sync view failed: %v", err)
return return
} }
c.sendPatch(patch{patchTypeElements, elemsPatch.String()}) typ := patchType(patchTypeElements)
if viewTransition {
typ = patchTypeElementsWithVT
}
c.sendPatch(patch{typ, elemsPatch.String()})
updatedSigs := c.prepareSignalsForPatch() updatedSigs := c.prepareSignalsForPatch()
if len(updatedSigs) != 0 { if len(updatedSigs) != 0 {
outgoingSigs, _ := json.Marshal(updatedSigs) outgoingSigs, _ := json.Marshal(updatedSigs)
c.sendPatch(patch{patchTypeSignals, string(outgoingSigs)}) c.sendPatch(patch{patchTypeSignals, string(outgoingSigs)})
@@ -331,6 +345,15 @@ func (c *Context) ExecScript(s string) {
c.sendPatch(patch{patchTypeScript, s}) c.sendPatch(patch{patchTypeScript, s})
} }
// RedirectView sets a view that redirects the browser to the given URL.
// Use this in middleware to abort the chain and redirect in one step.
func (c *Context) RedirectView(url string) {
c.View(func() h.H {
c.Redirect(url)
return h.Div()
})
}
// Redirect navigates the browser to the given URL. // Redirect navigates the browser to the given URL.
// This triggers a full page navigation - the current context will be disposed // This triggers a full page navigation - the current context will be disposed
// and a new context created at the destination URL. // and a new context created at the destination URL.
@@ -362,6 +385,46 @@ func (c *Context) ReplaceURLf(format string, a ...any) {
c.ReplaceURL(fmt.Sprintf(format, a...)) c.ReplaceURL(fmt.Sprintf(format, a...))
} }
// resetPageState tears down page-specific state (intervals, subscriptions,
// actions, signals, fields) without disposing the context itself. The SSE
// connection and context lifetime are unaffected.
func (c *Context) resetPageState() {
close(c.pageStopChan)
c.unsubscribeAll()
c.mu.Lock()
c.actionRegistry = make(map[string]actionEntry)
c.signals = new(sync.Map)
c.fields = nil
c.pageStopChan = make(chan struct{})
c.mu.Unlock()
}
// Navigate performs an SPA navigation to the given path. It resets page state,
// runs the target page's init function (with middleware), and pushes the new
// view over the existing SSE connection with a view transition animation.
// If popstate is true, replaceState is used instead of pushState.
func (c *Context) Navigate(path string, popstate bool) {
c.navMu.Lock()
defer c.navMu.Unlock()
route, initFn, params := c.app.matchRoute(path)
if initFn == nil {
c.Redirect(path)
return
}
c.resetPageState()
c.route = route
c.injectRouteParams(params)
initFn(c)
c.syncView(true)
safe := strings.NewReplacer(`\`, `\\`, `'`, `\'`).Replace(path)
if popstate {
c.ExecScript(fmt.Sprintf("history.replaceState({},'','%s')", safe))
} else {
c.ExecScript(fmt.Sprintf("history.pushState({},'','%s')", safe))
}
}
// dispose idempotently tears down this context: unsubscribes all pubsub // dispose idempotently tears down this context: unsubscribes all pubsub
// subscriptions and closes ctxDisposedChan to stop routines and exit the SSE loop. // subscriptions and closes ctxDisposedChan to stop routines and exit the SSE loop.
func (c *Context) dispose() { func (c *Context) dispose() {
@@ -372,7 +435,7 @@ func (c *Context) dispose() {
} }
// stopAllRoutines closes ctxDisposedChan, broadcasting to all listening // stopAllRoutines closes ctxDisposedChan, broadcasting to all listening
// goroutines (OnIntervalRoutine, SSE loop) that this context is done. // goroutines (OnInterval, SSE loop) that this context is done.
func (c *Context) stopAllRoutines() { func (c *Context) stopAllRoutines() {
select { select {
case <-c.ctxDisposedChan: case <-c.ctxDisposedChan:
@@ -386,12 +449,9 @@ func (c *Context) injectRouteParams(params map[string]string) {
if params == nil { if params == nil {
return return
} }
m := make(map[string]string)
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
maps.Copy(m, params) c.routeParams = params
c.routeParams = m
} }
// GetPathParam retrieves the value from the page request URL for the given parameter name // GetPathParam retrieves the value from the page request URL for the given parameter name
@@ -477,6 +537,50 @@ func (c *Context) unsubscribeAll() {
} }
} }
// Field creates a signal with validation rules attached.
// The initial value seeds both the signal and the reset target.
// The field is tracked on the context so ValidateAll/ResetFields
// can operate on all fields by default.
func (c *Context) Field(initial any, rules ...Rule) *Field {
f := &Field{
signal: c.Signal(initial),
rules: rules,
initialVal: initial,
}
target := c
if c.isComponent() {
target = c.parentPageCtx
}
target.fields = append(target.fields, f)
return f
}
// ValidateAll runs Validate on each field, returning true only if all pass.
// With no arguments it validates every field tracked on this context.
func (c *Context) ValidateAll(fields ...*Field) bool {
if len(fields) == 0 {
fields = c.fields
}
ok := true
for _, f := range fields {
if !f.Validate() {
ok = false
}
}
return ok
}
// ResetFields resets each field to its initial value and clears errors.
// With no arguments it resets every field tracked on this context.
func (c *Context) ResetFields(fields ...*Field) {
if len(fields) == 0 {
fields = c.fields
}
for _, f := range fields {
f.Reset()
}
}
func newContext(id string, route string, v *V) *Context { func newContext(id string, route string, v *V) *Context {
if v == nil { if v == nil {
panic("create context failed: app pointer is nil") panic("create context failed: app pointer is nil")
@@ -488,12 +592,12 @@ func newContext(id string, route string, v *V) *Context {
csrfToken: genCSRFToken(), csrfToken: genCSRFToken(),
routeParams: make(map[string]string), routeParams: make(map[string]string),
app: v, app: v,
componentRegistry: make(map[string]*Context), actionLimiter: newLimiter(v.actionRateLimit, defaultActionRate, defaultActionBurst),
actionLimiter: newLimiter(v.actionRateLimit, defaultActionRate, defaultActionBurst),
actionRegistry: make(map[string]actionEntry), actionRegistry: make(map[string]actionEntry),
signals: new(sync.Map), signals: new(sync.Map),
patchChan: make(chan patch, 1), patchChan: make(chan patch, 8),
ctxDisposedChan: make(chan struct{}, 1), ctxDisposedChan: make(chan struct{}, 1),
pageStopChan: make(chan struct{}),
createdAt: time.Now(), createdAt: time.Now(),
} }
} }

58
field.go Normal file
View File

@@ -0,0 +1,58 @@
package via
// Field is a signal with built-in validation rules and error state.
// It embeds *signal, so all signal methods (Bind, String, Int, Bool, SetValue, Text, ID)
// work transparently.
type Field struct {
*signal
rules []Rule
errors []string
initialVal any
}
// Validate runs all rules against the current value.
// Clears previous errors, populates new ones, returns true if all rules pass.
func (f *Field) Validate() bool {
f.errors = nil
val := f.String()
for _, r := range f.rules {
if err := r.validate(val); err != nil {
f.errors = append(f.errors, err.Error())
}
}
return len(f.errors) == 0
}
// HasError returns true if this field has any validation errors.
func (f *Field) HasError() bool {
return len(f.errors) > 0
}
// FirstError returns the first validation error message, or "" if valid.
func (f *Field) FirstError() string {
if len(f.errors) > 0 {
return f.errors[0]
}
return ""
}
// Errors returns all current validation error messages.
func (f *Field) Errors() []string {
return f.errors
}
// AddError manually adds an error message (useful for server-side or cross-field validation).
func (f *Field) AddError(msg string) {
f.errors = append(f.errors, msg)
}
// ClearErrors removes all validation errors from this field.
func (f *Field) ClearErrors() {
f.errors = nil
}
// Reset restores the field value to its initial value and clears all errors.
func (f *Field) Reset() {
f.SetValue(f.initialVal)
f.errors = nil
}

206
field_test.go Normal file
View File

@@ -0,0 +1,206 @@
package via
import (
"fmt"
"testing"
"github.com/ryanhamamura/via/h"
"github.com/stretchr/testify/assert"
)
func newTestField(initial any, rules ...Rule) *Field {
v := New()
var f *Field
v.Page("/", func(c *Context) {
f = c.Field(initial, rules...)
c.View(func() h.H { return h.Div() })
})
return f
}
func TestFieldCreation(t *testing.T) {
f := newTestField("hello", Required())
assert.Equal(t, "hello", f.String())
assert.NotEmpty(t, f.ID())
}
func TestFieldSignalDelegation(t *testing.T) {
f := newTestField(42)
assert.Equal(t, "42", f.String())
assert.Equal(t, 42, f.Int())
f.SetValue("new")
assert.Equal(t, "new", f.String())
// Bind returns an h.H element
assert.NotNil(t, f.Bind())
}
func TestFieldValidateSingleRule(t *testing.T) {
f := newTestField("", Required())
assert.False(t, f.Validate())
assert.True(t, f.HasError())
assert.Equal(t, "This field is required", f.FirstError())
f.SetValue("ok")
assert.True(t, f.Validate())
assert.False(t, f.HasError())
assert.Equal(t, "", f.FirstError())
}
func TestFieldValidateMultipleRules(t *testing.T) {
f := newTestField("ab", Required(), MinLen(3))
assert.False(t, f.Validate())
errs := f.Errors()
assert.Len(t, errs, 1)
assert.Equal(t, "Must be at least 3 characters", errs[0])
f.SetValue("")
assert.False(t, f.Validate())
errs = f.Errors()
assert.Len(t, errs, 2)
}
func TestFieldErrors(t *testing.T) {
f := newTestField("")
assert.Nil(t, f.Errors())
assert.False(t, f.HasError())
assert.Equal(t, "", f.FirstError())
}
func TestFieldAddError(t *testing.T) {
f := newTestField("ok")
f.AddError("username taken")
assert.True(t, f.HasError())
assert.Equal(t, "username taken", f.FirstError())
assert.Len(t, f.Errors(), 1)
}
func TestFieldClearErrors(t *testing.T) {
f := newTestField("", Required())
f.Validate()
assert.True(t, f.HasError())
f.ClearErrors()
assert.False(t, f.HasError())
}
func TestFieldReset(t *testing.T) {
f := newTestField("initial", Required(), MinLen(3))
f.SetValue("changed")
f.AddError("some error")
f.Reset()
assert.Equal(t, "initial", f.String())
assert.False(t, f.HasError())
}
func TestValidateAll(t *testing.T) {
v := New()
v.Page("/", func(c *Context) {
c.Field("", Required(), MinLen(3))
c.Field("", Required(), Email())
c.View(func() h.H { return h.Div() })
// both empty → both fail
assert.False(t, c.ValidateAll())
})
v2 := New()
v2.Page("/", func(c *Context) {
c.Field("joe", Required(), MinLen(3))
c.Field("joe@x.com", Required(), Email())
c.View(func() h.H { return h.Div() })
assert.True(t, c.ValidateAll())
})
}
func TestValidateAllPartialFailure(t *testing.T) {
v := New()
v.Page("/", func(c *Context) {
good := c.Field("valid", Required())
bad := c.Field("", Required())
c.View(func() h.H { return h.Div() })
ok := c.ValidateAll()
assert.False(t, ok)
assert.False(t, good.HasError())
assert.True(t, bad.HasError())
})
}
func TestValidateAllSelectiveArgs(t *testing.T) {
v := New()
v.Page("/", func(c *Context) {
a := c.Field("", Required())
b := c.Field("ok", Required())
cField := c.Field("", Required())
c.View(func() h.H { return h.Div() })
// only validate a and b — cField should be untouched
ok := c.ValidateAll(a, b)
assert.False(t, ok)
assert.True(t, a.HasError())
assert.False(t, b.HasError())
assert.False(t, cField.HasError(), "unselected field should not be validated")
})
}
func TestResetFields(t *testing.T) {
v := New()
v.Page("/", func(c *Context) {
a := c.Field("a", Required())
b := c.Field("b", Required())
c.View(func() h.H { return h.Div() })
a.SetValue("changed-a")
b.SetValue("changed-b")
a.AddError("err")
c.ResetFields()
assert.Equal(t, "a", a.String())
assert.Equal(t, "b", b.String())
assert.False(t, a.HasError())
})
}
func TestResetFieldsSelectiveArgs(t *testing.T) {
v := New()
v.Page("/", func(c *Context) {
a := c.Field("a")
b := c.Field("b")
c.View(func() h.H { return h.Div() })
a.SetValue("changed-a")
b.SetValue("changed-b")
// only reset a
c.ResetFields(a)
assert.Equal(t, "a", a.String())
assert.Equal(t, "changed-b", b.String(), "unselected field should not be reset")
})
}
func TestFieldValidateClearsPreviousErrors(t *testing.T) {
f := newTestField("", Required())
f.Validate()
assert.True(t, f.HasError())
f.SetValue("ok")
f.Validate()
assert.False(t, f.HasError())
}
func TestFieldCustomValidator(t *testing.T) {
f := newTestField("bad", Custom(func(val string) error {
if val == "bad" {
return fmt.Errorf("no bad words")
}
return nil
}))
assert.False(t, f.Validate())
assert.Equal(t, "no bad words", f.FirstError())
f.SetValue("good")
assert.True(t, f.Validate())
}

1
go.mod
View File

@@ -38,6 +38,5 @@ require (
github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect
golang.org/x/crypto v0.45.0 // indirect golang.org/x/crypto v0.45.0 // indirect
golang.org/x/sys v0.38.0 // indirect golang.org/x/sys v0.38.0 // indirect
golang.org/x/time v0.14.0
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
) )

View File

@@ -11,3 +11,11 @@ func DataEffect(expression string) H {
func DataIgnoreMorph() H { func DataIgnoreMorph() H {
return Attr("data-ignore-morph") return Attr("data-ignore-morph")
} }
// DataViewTransition sets the view-transition-name CSS property on an element
// via an inline style. Elements with matching names animate between pages
// during SPA navigation. If the element also needs other inline styles,
// include view-transition-name directly in the Style() call instead.
func DataViewTransition(name string) H {
return Attr("style", "view-transition-name: "+name)
}

View File

@@ -0,0 +1,151 @@
package main
import (
"fmt"
"time"
"github.com/ryanhamamura/via"
"github.com/ryanhamamura/via/h"
)
func main() {
v := via.New()
v.Config(via.Options{
ServerAddress: ":8080",
DocumentTitle: "Middleware Example",
})
// --- Middleware definitions ---
// requestLogger logs every page request to stdout.
requestLogger := func(c *via.Context, next func()) {
fmt.Printf("[%s] request\n", time.Now().Format("15:04:05"))
next()
}
// authRequired redirects unauthenticated users to /login.
authRequired := func(c *via.Context, next func()) {
if c.Session().GetString("role") == "" {
c.RedirectView("/login")
return
}
next()
}
// auditLog prints the authenticated username to stdout.
auditLog := func(c *via.Context, next func()) {
fmt.Printf("[audit] user=%s\n", c.Session().GetString("username"))
next()
}
// superAdminOnly rejects non-superadmin users with a forbidden view.
superAdminOnly := func(c *via.Context, next func()) {
if c.Session().GetString("role") != "superadmin" {
c.View(func() h.H {
return h.Div(
h.H1(h.Text("Forbidden")),
h.P(h.Text("Super-admin access required.")),
h.A(h.Href("/admin/dashboard"), h.Text("Back to dashboard")),
)
})
return
}
next()
}
// --- Route registration ---
v.Use(requestLogger) // global middleware
admin := v.Group("/admin", authRequired) // prefixed group
admin.Use(auditLog) // Group.Use()
superAdmin := admin.Group("/super", superAdminOnly) // nested group
// Public: redirect root to login
v.Page("/", func(c *via.Context) {
c.View(func() h.H {
c.Redirect("/login")
return h.Div()
})
})
// Public: login page with role-selection buttons
v.Page("/login", func(c *via.Context) {
loginAdmin := c.Action(func() {
c.Session().Set("role", "admin")
c.Session().Set("username", "alice")
c.Session().RenewToken()
c.Redirect("/admin/dashboard")
})
loginSuper := c.Action(func() {
c.Session().Set("role", "superadmin")
c.Session().Set("username", "bob")
c.Session().RenewToken()
c.Redirect("/admin/dashboard")
})
c.View(func() h.H {
return h.Div(
h.H1(h.Text("Login")),
h.P(h.Text("Choose a role:")),
h.Button(h.Text("Login as Admin"), loginAdmin.OnClick()),
h.Raw(" "),
h.Button(h.Text("Login as Super Admin"), loginSuper.OnClick()),
)
})
})
// Per-action middleware: only superadmins can invoke this action.
requireSuperAdmin := func(c *via.Context, next func()) {
if c.Session().GetString("role") != "superadmin" {
return
}
next()
}
// Admin: dashboard (requires authRequired + auditLog)
admin.Page("/dashboard", func(c *via.Context) {
logout := c.Action(func() {
c.Session().Delete("role")
c.Session().Delete("username")
c.Redirect("/login")
})
dangerAction := c.Action(func() {
fmt.Printf("[danger] executed by %s\n", c.Session().GetString("username"))
c.Sync()
}, via.WithMiddleware(requireSuperAdmin))
c.View(func() h.H {
username := c.Session().GetString("username")
role := c.Session().GetString("role")
return h.Div(
h.H1(h.Textf("Dashboard — %s (%s)", username, role)),
h.Ul(
h.Li(h.A(h.Href("/admin/super/settings"), h.Text("Super Admin Settings"))),
),
h.H2(h.Text("Danger Zone")),
h.P(h.Text("This action is protected by per-action middleware (superadmin only):")),
h.Button(h.Text("Delete Everything"), dangerAction.OnClick()),
h.Br(),
h.Br(),
h.Button(h.Text("Logout"), logout.OnClick()),
)
})
})
// Super-admin: settings (requires authRequired + auditLog + superAdminOnly)
superAdmin.Page("/settings", func(c *via.Context) {
c.View(func() h.H {
username := c.Session().GetString("username")
return h.Div(
h.H1(h.Textf("Super Admin Settings — %s", username)),
h.P(h.Text("Only super-admins can see this page.")),
h.A(h.Href("/admin/dashboard"), h.Text("Back to dashboard")),
)
})
})
v.Start()
}

View File

@@ -1,7 +1,6 @@
package main package main
import ( import (
"context"
"log" "log"
"math/rand" "math/rand"
"sync" "sync"
@@ -9,7 +8,6 @@ import (
"github.com/ryanhamamura/via" "github.com/ryanhamamura/via"
"github.com/ryanhamamura/via/h" "github.com/ryanhamamura/via/h"
"github.com/ryanhamamura/via/vianats"
) )
var ( var (
@@ -36,15 +34,15 @@ func (u *UserInfo) Avatar() h.H {
var roomNames = []string{"Go", "Rust", "Python", "JavaScript", "Clojure"} var roomNames = []string{"Go", "Rust", "Python", "JavaScript", "Clojure"}
func main() { func main() {
ctx := context.Background() v := via.New()
v.Config(via.Options{
DevMode: true,
DocumentTitle: "NATS Chat",
LogLevel: via.LogLevelInfo,
ServerAddress: ":7331",
})
ps, err := vianats.New(ctx, "./data/nats") err := via.EnsureStream(v, via.StreamConfig{
if err != nil {
log.Fatalf("Failed to start embedded NATS: %v", err)
}
defer ps.Close()
err = vianats.EnsureStream(ps, vianats.StreamConfig{
Name: "CHAT", Name: "CHAT",
Subjects: []string{"chat.>"}, Subjects: []string{"chat.>"},
MaxMsgs: 1000, MaxMsgs: 1000,
@@ -54,15 +52,6 @@ func main() {
log.Fatalf("Failed to ensure stream: %v", err) log.Fatalf("Failed to ensure stream: %v", err)
} }
v := via.New()
v.Config(via.Options{
DevMode: true,
DocumentTitle: "NATS Chat",
LogLevel: via.LogLevelInfo,
ServerAddress: ":7331",
PubSub: ps,
})
v.AppendToHead( v.AppendToHead(
h.Link(h.Rel("stylesheet"), h.Href("https://cdn.jsdelivr.net/npm/@picocss/pico@2/css/pico.min.css")), h.Link(h.Rel("stylesheet"), h.Href("https://cdn.jsdelivr.net/npm/@picocss/pico@2/css/pico.min.css")),
h.StyleEl(h.Raw(` h.StyleEl(h.Raw(`
@@ -148,7 +137,7 @@ func main() {
subject := "chat.room." + room subject := "chat.room." + room
// Replay history from JetStream // Replay history from JetStream
if hist, err := vianats.ReplayHistory[ChatMessage](ps, subject, 50); err == nil { if hist, err := via.ReplayHistory[ChatMessage](v, subject, 50); err == nil {
messages = hist messages = hist
} }

View File

@@ -1,7 +1,6 @@
package main package main
import ( import (
"context"
"crypto/rand" "crypto/rand"
"fmt" "fmt"
"html" "html"
@@ -11,7 +10,6 @@ import (
"github.com/ryanhamamura/via" "github.com/ryanhamamura/via"
"github.com/ryanhamamura/via/h" "github.com/ryanhamamura/via/h"
"github.com/ryanhamamura/via/vianats"
) )
var WithSignal = via.WithSignal var WithSignal = via.WithSignal
@@ -49,15 +47,15 @@ func findBookmark(id string) (Bookmark, int) {
} }
func main() { func main() {
ctx := context.Background() v := via.New()
v.Config(via.Options{
DevMode: true,
DocumentTitle: "Bookmarks",
LogLevel: via.LogLevelInfo,
ServerAddress: ":7331",
})
ps, err := vianats.New(ctx, "./data/nats") err := via.EnsureStream(v, via.StreamConfig{
if err != nil {
log.Fatalf("Failed to start embedded NATS: %v", err)
}
defer ps.Close()
err = vianats.EnsureStream(ps, vianats.StreamConfig{
Name: "BOOKMARKS", Name: "BOOKMARKS",
Subjects: []string{"bookmarks.>"}, Subjects: []string{"bookmarks.>"},
MaxMsgs: 1000, MaxMsgs: 1000,
@@ -67,15 +65,6 @@ func main() {
log.Fatalf("Failed to ensure stream: %v", err) log.Fatalf("Failed to ensure stream: %v", err)
} }
v := via.New()
v.Config(via.Options{
DevMode: true,
DocumentTitle: "Bookmarks",
LogLevel: via.LogLevelInfo,
ServerAddress: ":7331",
PubSub: ps,
})
v.AppendToHead( v.AppendToHead(
h.Link(h.Rel("stylesheet"), h.Href("https://cdn.jsdelivr.net/npm/daisyui@4/dist/full.min.css")), h.Link(h.Rel("stylesheet"), h.Href("https://cdn.jsdelivr.net/npm/daisyui@4/dist/full.min.css")),
h.Script(h.Src("https://cdn.tailwindcss.com")), h.Script(h.Src("https://cdn.tailwindcss.com")),

View File

@@ -37,29 +37,33 @@ func main() {
return 1000 / time.Duration(refreshRate.Int()) * time.Millisecond return 1000 / time.Duration(refreshRate.Int()) * time.Millisecond
} }
updateData := c.OnInterval(computedTickDuration(), func() { var stopUpdate func()
ts := time.Now().UnixMilli() startInterval := func() {
val := rand.ExpFloat64() * 10 stopUpdate = c.OnInterval(computedTickDuration(), func() {
ts := time.Now().UnixMilli()
val := rand.ExpFloat64() * 10
c.ExecScript(fmt.Sprintf(` c.ExecScript(fmt.Sprintf(`
if (myChart) { if (myChart) {
myChart.appendData({seriesIndex: 0, data: [[%d, %f]]}); myChart.appendData({seriesIndex: 0, data: [[%d, %f]]});
myChart.setOption({},{notMerge:false,lazyUpdate:true}); myChart.setOption({},{notMerge:false,lazyUpdate:true});
}; };
`, ts, val)) `, ts, val))
}) })
updateData.Start() }
startInterval()
updateRefreshRate := c.Action(func() { updateRefreshRate := c.Action(func() {
updateData.UpdateInterval(computedTickDuration()) stopUpdate()
startInterval()
}) })
toggleIsLive := c.Action(func() { toggleIsLive := c.Action(func() {
isLive = isLiveSig.Bool() isLive = isLiveSig.Bool()
if isLive { if isLive {
updateData.Start() startInterval()
} else { } else {
updateData.Stop() stopUpdate()
} }
}) })
c.View(func() h.H { c.View(func() h.H {

View File

@@ -29,7 +29,17 @@ func main() {
SessionManager: sm, SessionManager: sm,
}) })
// Login page // Auth middleware — redirects unauthenticated users to /login
authRequired := func(c *via.Context, next func()) {
if c.Session().GetString("username") == "" {
c.Session().Set("flash", "Please log in first")
c.RedirectView("/login")
return
}
next()
}
// Login page (public)
v.Page("/login", func(c *via.Context) { v.Page("/login", func(c *via.Context) {
flash := c.Session().PopString("flash") flash := c.Session().PopString("flash")
usernameInput := c.Signal("") usernameInput := c.Signal("")
@@ -64,8 +74,10 @@ func main() {
}) })
}) })
// Dashboard page (protected) // Protected pages
v.Page("/dashboard", func(c *via.Context) { protected := v.Group("", authRequired)
protected.Page("/dashboard", func(c *via.Context) {
logout := c.Action(func() { logout := c.Action(func() {
c.Session().Set("flash", "Goodbye!") c.Session().Set("flash", "Goodbye!")
c.Session().Delete("username") c.Session().Delete("username")
@@ -74,14 +86,6 @@ func main() {
c.View(func() h.H { c.View(func() h.H {
username := c.Session().GetString("username") username := c.Session().GetString("username")
// Not logged in? Redirect to login
if username == "" {
c.Session().Set("flash", "Please log in first")
c.Redirect("/login")
return h.Div()
}
flash := c.Session().PopString("flash") flash := c.Session().PopString("flash")
var flashMsg h.H var flashMsg h.H
if flash != "" { if flash != "" {

View File

@@ -0,0 +1,87 @@
package main
import (
"github.com/ryanhamamura/via"
"github.com/ryanhamamura/via/h"
)
func main() {
v := via.New()
v.Config(via.Options{
DocumentTitle: "Signup",
ServerAddress: ":8080",
})
v.AppendToHead(h.StyleEl(h.Raw(`
body { font-family: system-ui, sans-serif; max-width: 420px; margin: 2rem auto; padding: 0 1rem; }
label { display: block; font-weight: 600; margin-top: 1rem; }
input { display: block; width: 100%; padding: 0.4rem; margin-top: 0.25rem; box-sizing: border-box; }
.error { color: #c00; font-size: 0.85rem; margin-top: 0.2rem; }
.success { color: #080; margin-top: 1rem; }
.actions { margin-top: 1.5rem; display: flex; gap: 0.5rem; }
`)))
v.Page("/", func(c *via.Context) {
username := c.Field("", via.Required(), via.MinLen(3), via.MaxLen(20))
email := c.Field("", via.Required(), via.Email())
age := c.Field("", via.Required(), via.Min(13), via.Max(120))
// Optional field — only validated when non-empty
website := c.Field("", via.Pattern(`^$|^https?://\S+$`, "Must be a valid URL"))
var success string
signup := c.Action(func() {
success = ""
if !c.ValidateAll() {
c.Sync()
return
}
// Server-side check
if username.String() == "admin" {
username.AddError("Username is already taken")
c.Sync()
return
}
success = "Account created for " + username.String() + "!"
c.ResetFields()
c.Sync()
})
reset := c.Action(func() {
success = ""
c.ResetFields()
c.Sync()
})
c.View(func() h.H {
return h.Div(
h.H1(h.Text("Sign Up")),
h.Label(h.Text("Username")),
h.Input(h.Type("text"), h.Placeholder("pick a username"), username.Bind()),
h.If(username.HasError(), h.Div(h.Class("error"), h.Text(username.FirstError()))),
h.Label(h.Text("Email")),
h.Input(h.Type("email"), h.Placeholder("you@example.com"), email.Bind()),
h.If(email.HasError(), h.Div(h.Class("error"), h.Text(email.FirstError()))),
h.Label(h.Text("Age")),
h.Input(h.Type("number"), h.Placeholder("your age"), age.Bind()),
h.If(age.HasError(), h.Div(h.Class("error"), h.Text(age.FirstError()))),
h.Label(h.Text("Website (optional)")),
h.Input(h.Type("url"), h.Placeholder("https://example.com"), website.Bind()),
h.If(website.HasError(), h.Div(h.Class("error"), h.Text(website.FirstError()))),
h.Div(h.Class("actions"),
h.Button(h.Text("Sign Up"), signup.OnClick()),
h.Button(h.Text("Reset"), reset.OnClick()),
),
h.If(success != "", h.P(h.Class("success"), h.Text(success))),
)
})
})
v.Start()
}

View File

@@ -0,0 +1,91 @@
package main
import (
"fmt"
"time"
"github.com/ryanhamamura/via"
. "github.com/ryanhamamura/via/h"
)
func main() {
v := via.New()
v.Config(via.Options{
DocumentTitle: "SPA Navigation",
ServerAddress: ":7331",
})
v.AppendToHead(
Raw(`<link rel="preconnect" href="https://fonts.googleapis.com">`),
Raw(`<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>`),
Raw(`<link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;600;700&display=swap" rel="stylesheet">`),
Raw(`<style>body{font-family:'Inter',sans-serif;margin:0;background:#111;color:#eee}</style>`),
)
v.Layout(func(content func() H) H {
return Div(
Nav(
Style("display:flex;gap:1rem;padding:1rem;background:#222;"),
A(Href("/"), Text("Home"), Style("color:#fff")),
A(Href("/counter"), Text("Counter"), Style("color:#fff")),
A(Href("/clock"), Text("Clock"), Style("color:#fff")),
A(Href("https://github.com"), Text("GitHub (external)"), Style("color:#888")),
A(Href("/"), Text("Full Reload"), Attr("data-via-no-boost"), Style("color:#f88")),
),
Main(Style("padding:1rem"), content()),
)
})
// Home page
v.Page("/", func(c *via.Context) {
goCounter := c.Action(func() { c.Navigate("/counter", false) })
c.View(func() H {
return Div(
H1(Text("Home"), DataViewTransition("page-title")),
P(Text("Click the nav links above — no page reload, no white flash.")),
P(Text("Or navigate programmatically:")),
Button(Text("Go to Counter"), goCounter.OnClick()),
)
})
})
// Counter page — demonstrates signals and actions survive within a page,
// but reset on navigate away and back.
v.Page("/counter", func(c *via.Context) {
count := 0
increment := c.Action(func() { count++; c.Sync() })
goHome := c.Action(func() { c.Navigate("/", false) })
c.View(func() H {
return Div(
H1(Text("Counter"), DataViewTransition("page-title")),
P(Textf("Count: %d", count)),
Button(Text("+1"), increment.OnClick()),
Button(Text("Go Home"), goHome.OnClick(), Style("margin-left:0.5rem")),
)
})
})
// Clock page — demonstrates OnInterval cleanup on navigate.
v.Page("/clock", func(c *via.Context) {
now := time.Now().Format("15:04:05")
c.OnInterval(time.Second, func() {
now = time.Now().Format("15:04:05")
c.Sync()
})
c.View(func() H {
return Div(
H1(Text("Clock"), DataViewTransition("page-title")),
P(Text("This page has an OnInterval that ticks every second.")),
P(Textf("Current time: %s", now)),
P(Text("Navigate away and back — the old interval stops, a new one starts.")),
P(Textf("Proof this is a fresh page init: random = %d", time.Now().UnixNano()%1000)),
)
})
})
fmt.Println("SPA example running at http://localhost:7331")
v.Start()
}

82
middleware.go Normal file
View File

@@ -0,0 +1,82 @@
package via
// Middleware wraps a page init function. Call next to continue the chain;
// return without calling next to abort (set a view first, e.g. RedirectView).
type Middleware func(c *Context, next func())
// Group is a route group with a shared prefix and middleware stack.
type Group struct {
v *V
prefix string
middleware []Middleware
}
// Use appends middleware to the global stack.
// Global middleware runs before every page handler.
func (v *V) Use(mw ...Middleware) {
v.middleware = append(v.middleware, mw...)
}
// Group creates a route group with the given path prefix and middleware.
// Routes registered on the group are prefixed and run the group's middleware
// after any global middleware.
func (v *V) Group(prefix string, mw ...Middleware) *Group {
return &Group{
v: v,
prefix: prefix,
middleware: mw,
}
}
// Page registers a route on this group. The full route is the group prefix
// concatenated with route.
func (g *Group) Page(route string, initContextFn func(c *Context)) {
fullRoute := g.prefix + route
allMw := make([]Middleware, 0, len(g.v.middleware)+len(g.middleware))
allMw = append(allMw, g.v.middleware...)
allMw = append(allMw, g.middleware...)
wrapped := chainMiddleware(allMw, initContextFn)
g.v.page(fullRoute, initContextFn, wrapped)
}
// Group creates a nested sub-group that inherits this group's prefix and
// middleware, then adds its own.
func (g *Group) Group(prefix string, mw ...Middleware) *Group {
combined := make([]Middleware, len(g.middleware), len(g.middleware)+len(mw))
copy(combined, g.middleware)
combined = append(combined, mw...)
return &Group{
v: g.v,
prefix: g.prefix + prefix,
middleware: combined,
}
}
// Use appends middleware to this group's stack.
func (g *Group) Use(mw ...Middleware) {
g.middleware = append(g.middleware, mw...)
}
// WithMiddleware returns an ActionOption that attaches middleware to an action.
// Action middleware runs after CSRF/rate-limit checks and signal injection.
func WithMiddleware(mw ...Middleware) ActionOption {
return func(e *actionEntry) {
e.middleware = append(e.middleware, mw...)
}
}
// chainMiddleware wraps handler with the given middleware, outer-first.
func chainMiddleware(mws []Middleware, handler func(*Context)) func(*Context) {
if len(mws) == 0 {
return handler
}
chained := handler
for i := len(mws) - 1; i >= 0; i-- {
mw := mws[i]
next := chained
chained = func(c *Context) {
mw(c, func() { next(c) })
}
}
return chained
}

340
middleware_test.go Normal file
View File

@@ -0,0 +1,340 @@
package via
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/ryanhamamura/via/h"
"github.com/stretchr/testify/assert"
)
func TestMiddlewareRunsBeforeHandler(t *testing.T) {
var order []string
v := New()
v.Use(func(c *Context, next func()) {
order = append(order, "mw")
next()
})
v.Page("/", func(c *Context) {
order = append(order, "handler")
c.View(func() h.H { return h.Div() })
})
// Reset after registration (panic-check runs the raw handler)
order = nil
w := httptest.NewRecorder()
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/", nil))
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, []string{"mw", "handler"}, order)
}
func TestMiddlewareAbortSkipsHandler(t *testing.T) {
handlerCalled := false
v := New()
v.Use(func(c *Context, next func()) {
c.RedirectView("/other")
})
v.Page("/", func(c *Context) {
handlerCalled = true
c.View(func() h.H { return h.Div() })
})
handlerCalled = false
w := httptest.NewRecorder()
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/", nil))
assert.Equal(t, http.StatusOK, w.Code)
assert.False(t, handlerCalled)
}
func TestMiddlewareChainOrder(t *testing.T) {
var order []string
v := New()
for _, label := range []string{"A", "B", "C"} {
l := label
v.Use(func(c *Context, next func()) {
order = append(order, l)
next()
})
}
v.Page("/", func(c *Context) {
order = append(order, "handler")
c.View(func() h.H { return h.Div() })
})
order = nil
w := httptest.NewRecorder()
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/", nil))
assert.Equal(t, []string{"A", "B", "C", "handler"}, order)
}
func TestGroupPrefixRouting(t *testing.T) {
v := New()
g := v.Group("/admin")
g.Page("/dashboard", func(c *Context) {
c.View(func() h.H { return h.Div(h.Text("admin dashboard")) })
})
w := httptest.NewRecorder()
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/admin/dashboard", nil))
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "admin dashboard")
}
func TestGroupMiddlewareAppliesToGroupOnly(t *testing.T) {
var groupMwCalled bool
v := New()
g := v.Group("/admin", func(c *Context, next func()) {
groupMwCalled = true
next()
})
g.Page("/panel", func(c *Context) {
c.View(func() h.H { return h.Div(h.Text("panel")) })
})
v.Page("/public", func(c *Context) {
c.View(func() h.H { return h.Div(h.Text("public")) })
})
// Hit public page — group middleware should NOT run
groupMwCalled = false
w := httptest.NewRecorder()
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/public", nil))
assert.False(t, groupMwCalled)
assert.Contains(t, w.Body.String(), "public")
// Hit group page — group middleware should run
groupMwCalled = false
w = httptest.NewRecorder()
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/admin/panel", nil))
assert.True(t, groupMwCalled)
assert.Contains(t, w.Body.String(), "panel")
}
func TestGlobalMiddlewareAppliesToGroupPages(t *testing.T) {
var globalCalled bool
v := New()
v.Use(func(c *Context, next func()) {
globalCalled = true
next()
})
g := v.Group("/admin")
g.Page("/dash", func(c *Context) {
c.View(func() h.H { return h.Div(h.Text("dash")) })
})
globalCalled = false
w := httptest.NewRecorder()
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/admin/dash", nil))
assert.True(t, globalCalled)
assert.Contains(t, w.Body.String(), "dash")
}
func TestNestedGroupInheritsPrefixAndMiddleware(t *testing.T) {
var order []string
v := New()
admin := v.Group("/admin", func(c *Context, next func()) {
order = append(order, "admin")
next()
})
superAdmin := admin.Group("/super", func(c *Context, next func()) {
order = append(order, "super")
next()
})
superAdmin.Page("/secret", func(c *Context) {
order = append(order, "handler")
c.View(func() h.H { return h.Div(h.Text("secret")) })
})
order = nil
w := httptest.NewRecorder()
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/admin/super/secret", nil))
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, []string{"admin", "super", "handler"}, order)
assert.Contains(t, w.Body.String(), "secret")
}
func TestGroupUse(t *testing.T) {
var order []string
v := New()
g := v.Group("/api")
g.Use(func(c *Context, next func()) {
order = append(order, "added-later")
next()
})
g.Page("/items", func(c *Context) {
order = append(order, "handler")
c.View(func() h.H { return h.Div() })
})
order = nil
w := httptest.NewRecorder()
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/api/items", nil))
assert.Equal(t, []string{"added-later", "handler"}, order)
}
func TestRedirectViewSetsValidView(t *testing.T) {
v := New()
v.Page("/test", func(c *Context) {
c.RedirectView("/somewhere")
})
w := httptest.NewRecorder()
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/test", nil))
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "<!doctype html>")
}
func TestGlobalAndGroupMiddlewareOrder(t *testing.T) {
var order []string
v := New()
v.Use(func(c *Context, next func()) {
order = append(order, "global")
next()
})
g := v.Group("/g", func(c *Context, next func()) {
order = append(order, "group")
next()
})
g.Page("/page", func(c *Context) {
order = append(order, "handler")
c.View(func() h.H { return h.Div() })
})
order = nil
w := httptest.NewRecorder()
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/g/page", nil))
assert.Equal(t, []string{"global", "group", "handler"}, order)
}
// --- Action middleware tests ---
func TestActionMiddlewareRunsBeforeAction(t *testing.T) {
var order []string
v := New()
c := newContext("test", "/", v)
mw := func(_ *Context, next func()) {
order = append(order, "mw")
next()
}
trigger := c.Action(func() {
order = append(order, "action")
}, WithMiddleware(mw))
entry, err := c.getAction(trigger.id)
assert.NoError(t, err)
chainMiddleware(entry.middleware, func(_ *Context) { entry.fn() })(c)
assert.Equal(t, []string{"mw", "action"}, order)
}
func TestActionMiddlewareAbortSkipsAction(t *testing.T) {
actionCalled := false
v := New()
c := newContext("test", "/", v)
mw := func(_ *Context, next func()) {
// don't call next — action should not run
}
trigger := c.Action(func() {
actionCalled = true
}, WithMiddleware(mw))
entry, err := c.getAction(trigger.id)
assert.NoError(t, err)
chainMiddleware(entry.middleware, func(_ *Context) { entry.fn() })(c)
assert.False(t, actionCalled)
}
func TestActionMiddlewareChainOrder(t *testing.T) {
var order []string
v := New()
c := newContext("test", "/", v)
var mws []Middleware
for _, label := range []string{"A", "B", "C"} {
l := label
mws = append(mws, func(_ *Context, next func()) {
order = append(order, l)
next()
})
}
trigger := c.Action(func() {
order = append(order, "action")
}, WithMiddleware(mws...))
entry, err := c.getAction(trigger.id)
assert.NoError(t, err)
chainMiddleware(entry.middleware, func(_ *Context) { entry.fn() })(c)
assert.Equal(t, []string{"A", "B", "C", "action"}, order)
}
func TestActionMiddlewareCombinedWithRateLimit(t *testing.T) {
v := New()
c := newContext("test", "/", v)
mw := func(_ *Context, next func()) { next() }
trigger := c.Action(func() {}, WithRateLimit(5, 10), WithMiddleware(mw))
entry, err := c.getAction(trigger.id)
assert.NoError(t, err)
assert.NotNil(t, entry.limiter)
assert.Len(t, entry.middleware, 1)
}
func TestGroupWithEmptyPrefix(t *testing.T) {
var mwCalled bool
v := New()
g := v.Group("", func(c *Context, next func()) {
mwCalled = true
next()
})
g.Page("/dashboard", func(c *Context) {
c.View(func() h.H { return h.Div(h.Text("dash")) })
})
mwCalled = false
w := httptest.NewRecorder()
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/dashboard", nil))
assert.True(t, mwCalled)
assert.Contains(t, w.Body.String(), "dash")
}

190
nats.go Normal file
View File

@@ -0,0 +1,190 @@
package via
import (
"context"
"encoding/json"
"fmt"
"os"
"sync"
"time"
"github.com/delaneyj/toolbelt/embeddednats"
"github.com/nats-io/nats.go"
)
// defaultNATS is the process-scoped embedded NATS server.
type defaultNATS struct {
server *embeddednats.Server
nc *nats.Conn
js nats.JetStreamContext
cancel context.CancelFunc
dataDir string
}
var (
sharedNATS *defaultNATS
sharedNATSOnce sync.Once
sharedNATSErr error
)
// getSharedNATS returns a process-level singleton embedded NATS server.
// The server starts once and is reused across all V instances.
func getSharedNATS() (*defaultNATS, error) {
sharedNATSOnce.Do(func() {
sharedNATS, sharedNATSErr = startDefaultNATS()
})
return sharedNATS, sharedNATSErr
}
func startDefaultNATS() (dn *defaultNATS, err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("nats server panic: %v", r)
}
}()
dataDir, err := os.MkdirTemp("", "via-nats-*")
if err != nil {
return nil, fmt.Errorf("create temp dir: %w", err)
}
ctx, cancel := context.WithCancel(context.Background())
ns, err := embeddednats.New(ctx, embeddednats.WithDirectory(dataDir))
if err != nil {
cancel()
os.RemoveAll(dataDir)
return nil, fmt.Errorf("start embedded nats: %w", err)
}
ns.WaitForServer()
nc, err := ns.Client()
if err != nil {
ns.Close()
cancel()
os.RemoveAll(dataDir)
return nil, fmt.Errorf("connect nats client: %w", err)
}
js, err := nc.JetStream()
if err != nil {
nc.Close()
ns.Close()
cancel()
os.RemoveAll(dataDir)
return nil, fmt.Errorf("init jetstream: %w", err)
}
return &defaultNATS{
server: ns,
nc: nc,
js: js,
cancel: cancel,
dataDir: dataDir,
}, nil
}
func (n *defaultNATS) Publish(subject string, data []byte) error {
return n.nc.Publish(subject, data)
}
func (n *defaultNATS) Subscribe(subject string, handler func(data []byte)) (Subscription, error) {
sub, err := n.nc.Subscribe(subject, func(msg *nats.Msg) {
handler(msg.Data)
})
if err != nil {
return nil, err
}
return sub, nil
}
// natsRef wraps a shared defaultNATS as a PubSub. Close is a no-op because
// the underlying server is process-scoped and outlives individual V instances.
type natsRef struct {
dn *defaultNATS
}
func (r *natsRef) Publish(subject string, data []byte) error {
return r.dn.Publish(subject, data)
}
func (r *natsRef) Subscribe(subject string, handler func(data []byte)) (Subscription, error) {
return r.dn.Subscribe(subject, handler)
}
func (r *natsRef) Close() error {
return nil
}
// NATSConn returns the underlying NATS connection from the built-in embedded
// server, or nil if a custom PubSub backend is in use.
func (v *V) NATSConn() *nats.Conn {
if v.defaultNATS != nil {
return v.defaultNATS.nc
}
return nil
}
// JetStream returns the JetStream context from the built-in embedded server,
// or nil if a custom PubSub backend is in use.
func (v *V) JetStream() nats.JetStreamContext {
if v.defaultNATS != nil {
return v.defaultNATS.js
}
return nil
}
// StreamConfig holds the parameters for creating or updating a JetStream stream.
type StreamConfig struct {
Name string
Subjects []string
MaxMsgs int64
MaxAge time.Duration
}
// EnsureStream creates or updates a JetStream stream matching cfg.
func EnsureStream(v *V, cfg StreamConfig) error {
js := v.JetStream()
if js == nil {
return fmt.Errorf("jetstream not available")
}
_, err := js.AddStream(&nats.StreamConfig{
Name: cfg.Name,
Subjects: cfg.Subjects,
Retention: nats.LimitsPolicy,
MaxMsgs: cfg.MaxMsgs,
MaxAge: cfg.MaxAge,
})
return err
}
// ReplayHistory fetches the last limit messages from subject,
// deserializing each as T. Returns an empty slice if nothing is available.
func ReplayHistory[T any](v *V, subject string, limit int) ([]T, error) {
js := v.JetStream()
if js == nil {
return nil, fmt.Errorf("jetstream not available")
}
sub, err := js.SubscribeSync(subject, nats.DeliverAll(), nats.OrderedConsumer())
if err != nil {
return nil, err
}
defer sub.Unsubscribe()
var msgs []T
for {
raw, err := sub.NextMsg(200 * time.Millisecond)
if err != nil {
break
}
var msg T
if json.Unmarshal(raw.Data, &msg) == nil {
msgs = append(msgs, msg)
}
}
if limit > 0 && len(msgs) > limit {
msgs = msgs[len(msgs)-limit:]
}
return msgs, nil
}

View File

@@ -2,7 +2,6 @@ package via
import ( import (
"sync" "sync"
"sync/atomic"
"testing" "testing"
"time" "time"
@@ -11,88 +10,36 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
type mockHandler struct {
id int64
fn func([]byte)
active atomic.Bool
}
// mockPubSub implements PubSub for testing without NATS.
type mockPubSub struct {
mu sync.Mutex
subs map[string][]*mockHandler
nextID atomic.Int64
}
func newMockPubSub() *mockPubSub {
return &mockPubSub{subs: make(map[string][]*mockHandler)}
}
func (m *mockPubSub) Publish(subject string, data []byte) error {
m.mu.Lock()
handlers := make([]*mockHandler, len(m.subs[subject]))
copy(handlers, m.subs[subject])
m.mu.Unlock()
for _, h := range handlers {
if h.active.Load() {
h.fn(data)
}
}
return nil
}
func (m *mockPubSub) Subscribe(subject string, handler func(data []byte)) (Subscription, error) {
m.mu.Lock()
defer m.mu.Unlock()
mh := &mockHandler{
id: m.nextID.Add(1),
fn: handler,
}
mh.active.Store(true)
m.subs[subject] = append(m.subs[subject], mh)
return &mockSub{handler: mh}, nil
}
func (m *mockPubSub) Close() error { return nil }
type mockSub struct {
handler *mockHandler
}
func (s *mockSub) Unsubscribe() error {
s.handler.active.Store(false)
return nil
}
func TestPubSub_RoundTrip(t *testing.T) { func TestPubSub_RoundTrip(t *testing.T) {
ps := newMockPubSub()
v := New() v := New()
v.Config(Options{PubSub: ps}) defer v.Shutdown()
var received []byte var received []byte
var wg sync.WaitGroup done := make(chan struct{})
wg.Add(1)
c := newContext("test-ctx", "/", v) c := newContext("test-ctx", "/", v)
c.View(func() h.H { return h.Div() }) c.View(func() h.H { return h.Div() })
_, err := c.Subscribe("test.topic", func(data []byte) { _, err := c.Subscribe("test.topic", func(data []byte) {
received = data received = data
wg.Done() close(done)
}) })
require.NoError(t, err) require.NoError(t, err)
err = c.Publish("test.topic", []byte("hello")) err = c.Publish("test.topic", []byte("hello"))
require.NoError(t, err) require.NoError(t, err)
wg.Wait() select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for message")
}
assert.Equal(t, []byte("hello"), received) assert.Equal(t, []byte("hello"), received)
} }
func TestPubSub_MultipleSubscribers(t *testing.T) { func TestPubSub_MultipleSubscribers(t *testing.T) {
ps := newMockPubSub()
v := New() v := New()
v.Config(Options{PubSub: ps}) defer v.Shutdown()
var mu sync.Mutex var mu sync.Mutex
var results []string var results []string
@@ -119,7 +66,17 @@ func TestPubSub_MultipleSubscribers(t *testing.T) {
}) })
c1.Publish("broadcast", []byte("msg")) c1.Publish("broadcast", []byte("msg"))
wg.Wait()
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for messages")
}
assert.Len(t, results, 2) assert.Len(t, results, 2)
assert.Contains(t, results, "c1:msg") assert.Contains(t, results, "c1:msg")
@@ -127,9 +84,8 @@ func TestPubSub_MultipleSubscribers(t *testing.T) {
} }
func TestPubSub_SubscriptionCleanupOnDispose(t *testing.T) { func TestPubSub_SubscriptionCleanupOnDispose(t *testing.T) {
ps := newMockPubSub()
v := New() v := New()
v.Config(Options{PubSub: ps}) defer v.Shutdown()
c := newContext("cleanup-ctx", "/", v) c := newContext("cleanup-ctx", "/", v)
c.View(func() h.H { return h.Div() }) c.View(func() h.H { return h.Div() })
@@ -144,9 +100,8 @@ func TestPubSub_SubscriptionCleanupOnDispose(t *testing.T) {
} }
func TestPubSub_ManualUnsubscribe(t *testing.T) { func TestPubSub_ManualUnsubscribe(t *testing.T) {
ps := newMockPubSub()
v := New() v := New()
v.Config(Options{PubSub: ps}) defer v.Shutdown()
c := newContext("unsub-ctx", "/", v) c := newContext("unsub-ctx", "/", v)
c.View(func() h.H { return h.Div() }) c.View(func() h.H { return h.Div() })
@@ -160,28 +115,13 @@ func TestPubSub_ManualUnsubscribe(t *testing.T) {
sub.Unsubscribe() sub.Unsubscribe()
c.Publish("topic", []byte("ignored")) c.Publish("topic", []byte("ignored"))
time.Sleep(10 * time.Millisecond) time.Sleep(50 * time.Millisecond)
assert.False(t, called) assert.False(t, called)
} }
func TestPubSub_NoOpWhenNotConfigured(t *testing.T) {
v := New()
c := newContext("noop-ctx", "/", v)
c.View(func() h.H { return h.Div() })
err := c.Publish("topic", []byte("data"))
assert.Error(t, err)
sub, err := c.Subscribe("topic", func(data []byte) {})
assert.Error(t, err)
assert.Nil(t, sub)
}
func TestPubSub_NoOpDuringPanicCheck(t *testing.T) { func TestPubSub_NoOpDuringPanicCheck(t *testing.T) {
ps := newMockPubSub()
v := New() v := New()
v.Config(Options{PubSub: ps}) defer v.Shutdown()
// Panic-check context has id="" // Panic-check context has id=""
c := newContext("", "/", v) c := newContext("", "/", v)

51
navigate.js Normal file
View File

@@ -0,0 +1,51 @@
(function() {
const meta = document.querySelector('meta[data-signals]');
if (!meta) return;
const raw = meta.getAttribute('data-signals');
const parsed = JSON.parse(raw.replace(/'/g, '"'));
const ctxID = parsed['via-ctx'];
const csrf = parsed['via-csrf'];
if (!ctxID || !csrf) return;
function navigate(url, popstate) {
const params = new URLSearchParams({
'via-ctx': ctxID,
'via-csrf': csrf,
'url': url,
});
if (popstate) params.set('popstate', '1');
fetch('/_navigate', {
method: 'POST',
headers: {'Content-Type': 'application/x-www-form-urlencoded'},
body: params.toString()
}).then(function(res) {
if (!res.ok) window.location.href = url;
}).catch(function() {
window.location.href = url;
});
}
document.addEventListener('click', function(e) {
var el = e.target;
while (el && el.tagName !== 'A') el = el.parentElement;
if (!el) return;
if (e.ctrlKey || e.metaKey || e.shiftKey || e.altKey) return;
if (el.hasAttribute('target')) return;
if (el.hasAttribute('data-via-no-boost')) return;
var href = el.getAttribute('href');
if (!href || href.startsWith('#')) return;
try {
var url = new URL(href, window.location.origin);
if (url.origin !== window.location.origin) return;
e.preventDefault();
navigate(url.pathname + url.search + url.hash);
} catch(_) {}
});
var ready = false;
window.addEventListener('popstate', function() {
if (!ready) return;
navigate(window.location.pathname + window.location.search + window.location.hash, true);
});
setTimeout(function() { ready = true; }, 0);
})();

View File

@@ -1,7 +1,8 @@
package via package via
// PubSub is an interface for publish/subscribe messaging backends. // PubSub is an interface for publish/subscribe messaging backends.
// The vianats sub-package provides an embedded NATS implementation. // By default, New() starts an embedded NATS server. Supply a custom
// implementation via Config(Options{PubSub: yourBackend}) to override.
type PubSub interface { type PubSub interface {
Publish(subject string, data []byte) error Publish(subject string, data []byte) error
Subscribe(subject string, handler func(data []byte)) (Subscription, error) Subscribe(subject string, handler func(data []byte)) (Subscription, error)

View File

@@ -1,8 +1,8 @@
package via package via
import ( import (
"sync"
"testing" "testing"
"time"
"github.com/ryanhamamura/via/h" "github.com/ryanhamamura/via/h"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -10,9 +10,8 @@ import (
) )
func TestPublishSubscribe_RoundTrip(t *testing.T) { func TestPublishSubscribe_RoundTrip(t *testing.T) {
ps := newMockPubSub()
v := New() v := New()
v.Config(Options{PubSub: ps}) defer v.Shutdown()
type event struct { type event struct {
Name string `json:"name"` Name string `json:"name"`
@@ -20,30 +19,32 @@ func TestPublishSubscribe_RoundTrip(t *testing.T) {
} }
var got event var got event
var wg sync.WaitGroup done := make(chan struct{})
wg.Add(1)
c := newContext("typed-ctx", "/", v) c := newContext("typed-ctx", "/", v)
c.View(func() h.H { return h.Div() }) c.View(func() h.H { return h.Div() })
_, err := Subscribe(c, "events", func(e event) { _, err := Subscribe(c, "events", func(e event) {
got = e got = e
wg.Done() close(done)
}) })
require.NoError(t, err) require.NoError(t, err)
err = Publish(c, "events", event{Name: "click", Count: 42}) err = Publish(c, "events", event{Name: "click", Count: 42})
require.NoError(t, err) require.NoError(t, err)
wg.Wait() select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for message")
}
assert.Equal(t, "click", got.Name) assert.Equal(t, "click", got.Name)
assert.Equal(t, 42, got.Count) assert.Equal(t, 42, got.Count)
} }
func TestSubscribe_SkipsBadJSON(t *testing.T) { func TestSubscribe_SkipsBadJSON(t *testing.T) {
ps := newMockPubSub()
v := New() v := New()
v.Config(Options{PubSub: ps}) defer v.Shutdown()
type msg struct { type msg struct {
Text string `json:"text"` Text string `json:"text"`
@@ -62,5 +63,6 @@ func TestSubscribe_SkipsBadJSON(t *testing.T) {
err = c.Publish("topic", []byte("not json")) err = c.Publish("topic", []byte("not json"))
require.NoError(t, err) require.NoError(t, err)
time.Sleep(50 * time.Millisecond)
assert.False(t, called) assert.False(t, called)
} }

View File

@@ -18,8 +18,9 @@ type RateLimitConfig struct {
type ActionOption func(*actionEntry) type ActionOption func(*actionEntry)
type actionEntry struct { type actionEntry struct {
fn func() fn func()
limiter *rate.Limiter // nil = use context default limiter *rate.Limiter // nil = use context default
middleware []Middleware
} }
// WithRateLimit returns an ActionOption that gives this action its own // WithRateLimit returns an ActionOption that gives this action its own

View File

@@ -1,76 +1,34 @@
package via package via
import ( import (
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
) )
// OnIntervalRoutine allows for defining concurrent goroutines safely. Goroutines started by *OnIntervalRoutine func newOnInterval(ctxDisposedChan, pageStopChan chan struct{}, duration time.Duration, handler func()) func() {
// are tied to the *Context lifecycle. localInterrupt := make(chan struct{})
type OnIntervalRoutine struct { var stopped atomic.Bool
mu sync.RWMutex
ctxDisposed chan struct{}
localInterrupt chan struct{}
isRunning atomic.Bool
routineFn func()
tckDuration time.Duration
updateTkrChan chan time.Duration
}
// UpdateInterval sets a new interval duration for the internal *time.Ticker. If the provided go func() {
// duration is equal of less than 0, UpdateInterval does nothing. tkr := time.NewTicker(duration)
func (r *OnIntervalRoutine) UpdateInterval(d time.Duration) { defer tkr.Stop()
r.mu.Lock()
defer r.mu.Unlock()
r.tckDuration = d
r.updateTkrChan <- d
}
// Start executes the predifined goroutine. If no predifined goroutine exists, or it already
// started, Start does nothing.
func (r *OnIntervalRoutine) Start() {
if !r.isRunning.CompareAndSwap(false, true) || r.routineFn == nil {
return
}
go r.routineFn()
}
// Stop interrupts the predifined goroutine. If no predifined goroutine exists, or it already
// ustopped, Stop does nothing.
func (r *OnIntervalRoutine) Stop() {
if !r.isRunning.CompareAndSwap(true, false) || r.routineFn == nil {
return
}
r.localInterrupt <- struct{}{}
}
func newOnIntervalRoutine(ctxDisposedChan chan struct{},
duration time.Duration, handler func()) *OnIntervalRoutine {
r := &OnIntervalRoutine{
ctxDisposed: ctxDisposedChan,
localInterrupt: make(chan struct{}),
updateTkrChan: make(chan time.Duration),
}
r.tckDuration = duration
r.routineFn = func() {
r.mu.RLock()
tkr := time.NewTicker(r.tckDuration)
r.mu.RUnlock()
defer tkr.Stop() // clean up the ticker when routine stops
for { for {
select { select {
case <-r.ctxDisposed: // dispose of the routine when ctx is disposed case <-ctxDisposedChan:
return return
case <-r.localInterrupt: // dispose of the routine on interrupt signal case <-pageStopChan:
return
case <-localInterrupt:
return return
case d := <-r.updateTkrChan:
tkr.Reset(d)
case <-tkr.C: case <-tkr.C:
handler() handler()
} }
} }
}()
return func() {
if stopped.CompareAndSwap(false, true) {
close(localInterrupt)
}
} }
return r
} }

130
rule.go Normal file
View File

@@ -0,0 +1,130 @@
package via
import (
"errors"
"fmt"
"regexp"
"strconv"
"strings"
"unicode/utf8"
)
// Rule defines a single validation check for a Field.
type Rule struct {
validate func(val string) error
}
// Required rejects empty or whitespace-only values.
func Required(msg ...string) Rule {
m := "This field is required"
if len(msg) > 0 {
m = msg[0]
}
return Rule{func(val string) error {
if strings.TrimSpace(val) == "" {
return errors.New(m)
}
return nil
}}
}
// MinLen rejects values shorter than n characters.
func MinLen(n int, msg ...string) Rule {
m := fmt.Sprintf("Must be at least %d characters", n)
if len(msg) > 0 {
m = msg[0]
}
return Rule{func(val string) error {
if utf8.RuneCountInString(val) < n {
return errors.New(m)
}
return nil
}}
}
// MaxLen rejects values longer than n characters.
func MaxLen(n int, msg ...string) Rule {
m := fmt.Sprintf("Must be at most %d characters", n)
if len(msg) > 0 {
m = msg[0]
}
return Rule{func(val string) error {
if utf8.RuneCountInString(val) > n {
return errors.New(m)
}
return nil
}}
}
// Min parses the value as an integer and rejects values less than n.
func Min(n int, msg ...string) Rule {
m := fmt.Sprintf("Must be at least %d", n)
if len(msg) > 0 {
m = msg[0]
}
return Rule{func(val string) error {
v, err := strconv.Atoi(val)
if err != nil {
return errors.New("Must be a valid number")
}
if v < n {
return errors.New(m)
}
return nil
}}
}
// Max parses the value as an integer and rejects values greater than n.
func Max(n int, msg ...string) Rule {
m := fmt.Sprintf("Must be at most %d", n)
if len(msg) > 0 {
m = msg[0]
}
return Rule{func(val string) error {
v, err := strconv.Atoi(val)
if err != nil {
return errors.New("Must be a valid number")
}
if v > n {
return errors.New(m)
}
return nil
}}
}
// Pattern rejects values that don't match the regular expression re.
func Pattern(re string, msg ...string) Rule {
m := "Invalid format"
if len(msg) > 0 {
m = msg[0]
}
compiled := regexp.MustCompile(re)
return Rule{func(val string) error {
if !compiled.MatchString(val) {
return errors.New(m)
}
return nil
}}
}
var emailRegexp = regexp.MustCompile(`^[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}$`)
// Email rejects values that don't look like an email address.
func Email(msg ...string) Rule {
m := "Invalid email address"
if len(msg) > 0 {
m = msg[0]
}
return Rule{func(val string) error {
if !emailRegexp.MatchString(val) {
return errors.New(m)
}
return nil
}}
}
// Custom creates a rule from a user-provided validation function.
// The function should return nil for valid input and an error for invalid input.
func Custom(fn func(string) error) Rule {
return Rule{validate: fn}
}

116
rule_test.go Normal file
View File

@@ -0,0 +1,116 @@
package via
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
func TestRequired(t *testing.T) {
r := Required()
assert.NoError(t, r.validate("hello"))
assert.Error(t, r.validate(""))
assert.Error(t, r.validate(" "))
}
func TestRequiredCustomMessage(t *testing.T) {
r := Required("name needed")
err := r.validate("")
assert.EqualError(t, err, "name needed")
}
func TestMinLen(t *testing.T) {
r := MinLen(3)
assert.NoError(t, r.validate("abc"))
assert.NoError(t, r.validate("abcd"))
assert.Error(t, r.validate("ab"))
assert.Error(t, r.validate(""))
}
func TestMinLenCustomMessage(t *testing.T) {
r := MinLen(5, "too short")
err := r.validate("ab")
assert.EqualError(t, err, "too short")
}
func TestMaxLen(t *testing.T) {
r := MaxLen(5)
assert.NoError(t, r.validate("abc"))
assert.NoError(t, r.validate("abcde"))
assert.Error(t, r.validate("abcdef"))
}
func TestMaxLenCustomMessage(t *testing.T) {
r := MaxLen(2, "too long")
err := r.validate("abc")
assert.EqualError(t, err, "too long")
}
func TestMin(t *testing.T) {
r := Min(5)
assert.NoError(t, r.validate("5"))
assert.NoError(t, r.validate("10"))
assert.Error(t, r.validate("4"))
assert.Error(t, r.validate("abc"))
}
func TestMinCustomMessage(t *testing.T) {
r := Min(10, "need 10+")
err := r.validate("3")
assert.EqualError(t, err, "need 10+")
}
func TestMax(t *testing.T) {
r := Max(10)
assert.NoError(t, r.validate("10"))
assert.NoError(t, r.validate("5"))
assert.Error(t, r.validate("11"))
assert.Error(t, r.validate("abc"))
}
func TestMaxCustomMessage(t *testing.T) {
r := Max(5, "too big")
err := r.validate("6")
assert.EqualError(t, err, "too big")
}
func TestPattern(t *testing.T) {
r := Pattern(`^\d{3}$`)
assert.NoError(t, r.validate("123"))
assert.Error(t, r.validate("12"))
assert.Error(t, r.validate("abcd"))
}
func TestPatternCustomMessage(t *testing.T) {
r := Pattern(`^\d+$`, "digits only")
err := r.validate("abc")
assert.EqualError(t, err, "digits only")
}
func TestEmail(t *testing.T) {
r := Email()
assert.NoError(t, r.validate("user@example.com"))
assert.NoError(t, r.validate("a.b+c@foo.co"))
assert.Error(t, r.validate("notanemail"))
assert.Error(t, r.validate("@example.com"))
assert.Error(t, r.validate("user@"))
assert.Error(t, r.validate(""))
}
func TestEmailCustomMessage(t *testing.T) {
r := Email("bad email")
err := r.validate("nope")
assert.EqualError(t, err, "bad email")
}
func TestCustom(t *testing.T) {
r := Custom(func(val string) error {
if val != "magic" {
return fmt.Errorf("must be magic")
}
return nil
})
assert.NoError(t, r.validate("magic"))
assert.EqualError(t, r.validate("other"), "must be magic")
}

View File

@@ -81,26 +81,3 @@ func (s *signal) Int() int {
return 0 return 0
} }
// Int64 tries to read the signal value as an int64.
// Returns the value or 0 on failure.
func (s *signal) Int64() int64 {
if n, err := strconv.ParseInt(s.String(), 10, 64); err == nil {
return n
}
return 0
}
// Float64 tries to read the signal value as a float64.
// Returns the value or 0.0 on failure.
func (s *signal) Float() float64 {
if n, err := strconv.ParseFloat(s.String(), 64); err == nil {
return n
}
return 0.0
}
// Bytes tries to read the signal value as a []byte
// Returns the value or an empty []byte on failure.
func (s *signal) Bytes() []byte {
return []byte(s.String())
}

143
static_test.go Normal file
View File

@@ -0,0 +1,143 @@
package via
import (
"io/fs"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"testing/fstest"
"github.com/stretchr/testify/assert"
)
func TestStatic(t *testing.T) {
dir := t.TempDir()
os.MkdirAll(filepath.Join(dir, "sub"), 0755)
os.WriteFile(filepath.Join(dir, "hello.txt"), []byte("hello world"), 0644)
os.WriteFile(filepath.Join(dir, "sub", "nested.txt"), []byte("nested"), 0644)
v := New()
v.Static("/assets/", dir)
t.Run("serves file", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/assets/hello.txt", nil)
v.mux.ServeHTTP(w, r)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "hello world", w.Body.String())
})
t.Run("serves nested file", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/assets/sub/nested.txt", nil)
v.mux.ServeHTTP(w, r)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "nested", w.Body.String())
})
t.Run("directory listing returns 404", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/assets/", nil)
v.mux.ServeHTTP(w, r)
assert.Equal(t, http.StatusNotFound, w.Code)
})
t.Run("subdirectory listing returns 404", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/assets/sub/", nil)
v.mux.ServeHTTP(w, r)
assert.Equal(t, http.StatusNotFound, w.Code)
})
t.Run("missing file returns 404", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/assets/nope.txt", nil)
v.mux.ServeHTTP(w, r)
assert.Equal(t, http.StatusNotFound, w.Code)
})
}
func TestStaticAutoSlash(t *testing.T) {
dir := t.TempDir()
os.WriteFile(filepath.Join(dir, "ok.txt"), []byte("ok"), 0644)
v := New()
v.Static("/files", dir) // no trailing slash
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/files/ok.txt", nil)
v.mux.ServeHTTP(w, r)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "ok", w.Body.String())
}
func TestStaticFS(t *testing.T) {
fsys := fstest.MapFS{
"style.css": {Data: []byte("body{}")},
"js/app.js": {Data: []byte("console.log('hi')")},
}
v := New()
v.StaticFS("/static/", fsys)
t.Run("serves file", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/static/style.css", nil)
v.mux.ServeHTTP(w, r)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "body{}", w.Body.String())
})
t.Run("serves nested file", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/static/js/app.js", nil)
v.mux.ServeHTTP(w, r)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "console.log('hi')", w.Body.String())
})
t.Run("directory listing returns 404", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/static/", nil)
v.mux.ServeHTTP(w, r)
assert.Equal(t, http.StatusNotFound, w.Code)
})
t.Run("missing file returns 404", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/static/nope.css", nil)
v.mux.ServeHTTP(w, r)
assert.Equal(t, http.StatusNotFound, w.Code)
})
}
func TestStaticFSAutoSlash(t *testing.T) {
fsys := fstest.MapFS{
"ok.txt": {Data: []byte("ok")},
}
v := New()
v.StaticFS("/embed", fsys) // no trailing slash
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/embed/ok.txt", nil)
v.mux.ServeHTTP(w, r)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "ok", w.Body.String())
}
// Verify StaticFS accepts the fs.FS interface (compile-time check).
var _ fs.FS = fstest.MapFS{}

181
via.go
View File

@@ -9,12 +9,13 @@ package via
import ( import (
"context" "context"
"crypto/rand" "crypto/rand"
_ "embed"
"crypto/subtle" "crypto/subtle"
_ "embed"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"io/fs"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
@@ -34,6 +35,9 @@ import (
//go:embed datastar.js //go:embed datastar.js
var datastarJS []byte var datastarJS []byte
//go:embed navigate.js
var navigateJS []byte
// V is the root application. // V is the root application.
// It manages page routing, user sessions, and SSE connections for live updates. // It manages page routing, user sessions, and SSE connections for live updates.
type V struct { type V struct {
@@ -46,13 +50,17 @@ type V struct {
documentHeadIncludes []h.H documentHeadIncludes []h.H
documentFootIncludes []h.H documentFootIncludes []h.H
devModePageInitFnMap map[string]func(*Context) devModePageInitFnMap map[string]func(*Context)
pageRegistry map[string]func(*Context)
sessionManager *scs.SessionManager sessionManager *scs.SessionManager
pubsub PubSub pubsub PubSub
defaultNATS *defaultNATS
actionRateLimit RateLimitConfig actionRateLimit RateLimitConfig
datastarPath string datastarPath string
datastarContent []byte datastarContent []byte
datastarOnce sync.Once datastarOnce sync.Once
reaperStop chan struct{} reaperStop chan struct{}
middleware []Middleware
layout func(func() h.H) h.H
} }
func (v *V) logEvent(evt *zerolog.Event, c *Context) *zerolog.Event { func (v *V) logEvent(evt *zerolog.Event, c *Context) *zerolog.Event {
@@ -128,6 +136,7 @@ func (v *V) Config(cfg Options) {
v.datastarPath = cfg.DatastarPath v.datastarPath = cfg.DatastarPath
} }
if cfg.PubSub != nil { if cfg.PubSub != nil {
v.defaultNATS = nil
v.pubsub = cfg.PubSub v.pubsub = cfg.PubSub
} }
if cfg.ContextTTL != 0 { if cfg.ContextTTL != 0 {
@@ -169,8 +178,16 @@ func (v *V) AppendToFoot(elements ...h.H) {
// }) // })
// }) // })
func (v *V) Page(route string, initContextFn func(c *Context)) { func (v *V) Page(route string, initContextFn func(c *Context)) {
wrapped := chainMiddleware(v.middleware, initContextFn)
v.page(route, initContextFn, wrapped)
}
// page registers a route with separate raw and wrapped init functions.
// raw is used for the panic-check at registration time; wrapped includes
// any middleware and is used as the live handler.
func (v *V) page(route string, raw, wrapped func(*Context)) {
v.ensureDatastarHandler() v.ensureDatastarHandler()
// check for panics // check for panics using the raw handler (no middleware)
func() { func() {
defer func() { defer func() {
if err := recover(); err != nil { if err := recover(); err != nil {
@@ -179,14 +196,14 @@ func (v *V) Page(route string, initContextFn func(c *Context)) {
} }
}() }()
c := newContext("", "", v) c := newContext("", "", v)
initContextFn(c) raw(c)
c.view() c.view()
c.stopAllRoutines() c.stopAllRoutines()
}() }()
// save page init function allows devmode to restore persisted ctx later v.pageRegistry[route] = wrapped
if v.cfg.DevMode { if v.cfg.DevMode {
v.devModePageInitFnMap[route] = initContextFn v.devModePageInitFnMap[route] = wrapped
} }
v.mux.HandleFunc("GET "+route, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { v.mux.HandleFunc("GET "+route, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
v.logDebug(nil, "GET %s", r.URL.String()) v.logDebug(nil, "GET %s", r.URL.String())
@@ -200,7 +217,7 @@ func (v *V) Page(route string, initContextFn func(c *Context)) {
c.reqCtx = r.Context() c.reqCtx = r.Context()
routeParams := extractParams(route, r.URL.Path) routeParams := extractParams(route, r.URL.Path)
c.injectRouteParams(routeParams) c.injectRouteParams(routeParams)
initContextFn(c) wrapped(c)
v.registerCtx(c) v.registerCtx(c)
if v.cfg.DevMode { if v.cfg.DevMode {
v.devModePersist(c) v.devModePersist(c)
@@ -212,6 +229,8 @@ func (v *V) Page(route string, initContextFn func(c *Context)) {
h.Meta(h.Data("init", "@get('/_sse')")), h.Meta(h.Data("init", "@get('/_sse')")),
h.Meta(h.Data("init", fmt.Sprintf(`window.addEventListener('beforeunload', (evt) => { h.Meta(h.Data("init", fmt.Sprintf(`window.addEventListener('beforeunload', (evt) => {
navigator.sendBeacon('/_session/close', '%s');});`, c.id))), navigator.sendBeacon('/_session/close', '%s');});`, c.id))),
h.Meta(h.Attr("name", "view-transition"), h.Attr("content", "same-origin")),
h.Script(h.Raw(string(navigateJS))),
) )
bodyElements := []h.H{c.view()} bodyElements := []h.H{c.view()}
@@ -225,8 +244,7 @@ func (v *V) Page(route string, initContextFn func(c *Context)) {
Title: v.cfg.DocumentTitle, Title: v.cfg.DocumentTitle,
Head: headElements, Head: headElements,
Body: bodyElements, Body: bodyElements,
HTMLAttrs: []h.H{}, })
})
_ = view.Render(w) _ = view.Render(w)
})) }))
} }
@@ -234,17 +252,9 @@ func (v *V) Page(route string, initContextFn func(c *Context)) {
func (v *V) registerCtx(c *Context) { func (v *V) registerCtx(c *Context) {
v.contextRegistryMutex.Lock() v.contextRegistryMutex.Lock()
defer v.contextRegistryMutex.Unlock() defer v.contextRegistryMutex.Unlock()
if c == nil {
v.logErr(c, "failed to add nil context to registry")
return
}
v.contextRegistry[c.id] = c v.contextRegistry[c.id] = c
v.logDebug(c, "new context added to registry") v.logDebug(c, "new context added to registry")
v.logDebug(nil, "number of sessions in registry: %d", v.currSessionNum()) v.logDebug(nil, "number of sessions in registry: %d", len(v.contextRegistry))
}
func (v *V) currSessionNum() int {
return len(v.contextRegistry)
} }
func (v *V) cleanupCtx(c *Context) { func (v *V) cleanupCtx(c *Context) {
@@ -264,7 +274,7 @@ func (v *V) unregisterCtx(c *Context) {
defer v.contextRegistryMutex.Unlock() defer v.contextRegistryMutex.Unlock()
v.logDebug(c, "ctx removed from registry") v.logDebug(c, "ctx removed from registry")
delete(v.contextRegistry, c.id) delete(v.contextRegistry, c.id)
v.logDebug(nil, "number of sessions in registry: %d", v.currSessionNum()) v.logDebug(nil, "number of sessions in registry: %d", len(v.contextRegistry))
} }
func (v *V) getCtx(id string) (*Context, error) { func (v *V) getCtx(id string) (*Context, error) {
@@ -354,16 +364,12 @@ func (v *V) Start() {
return return
} }
v.shutdown() v.Shutdown()
} }
// Shutdown gracefully shuts down the server and all contexts. // Shutdown gracefully shuts down the server and all contexts.
// Safe for programmatic or test use. // Safe for programmatic or test use.
func (v *V) Shutdown() { func (v *V) Shutdown() {
v.shutdown()
}
func (v *V) shutdown() {
if v.reaperStop != nil { if v.reaperStop != nil {
close(v.reaperStop) close(v.reaperStop)
} }
@@ -383,6 +389,7 @@ func (v *V) shutdown() {
v.logErr(nil, "pubsub close error: %v", err) v.logErr(nil, "pubsub close error: %v", err)
} }
} }
v.defaultNATS = nil
v.logInfo(nil, "shutdown complete") v.logInfo(nil, "shutdown complete")
} }
@@ -412,6 +419,51 @@ func (v *V) HTTPServeMux() *http.ServeMux {
return v.mux return v.mux
} }
// PubSub returns the configured PubSub backend, or nil if none is set.
func (v *V) PubSub() PubSub {
return v.pubsub
}
// Static serves files from a filesystem directory at the given URL prefix.
//
// Example:
//
// v.Static("/assets/", "./public")
func (v *V) Static(urlPrefix, dir string) {
if !strings.HasSuffix(urlPrefix, "/") {
urlPrefix += "/"
}
fileServer := http.StripPrefix(urlPrefix, http.FileServer(http.Dir(dir)))
v.mux.Handle("GET "+urlPrefix, noDirListing(fileServer))
}
// StaticFS serves files from an [fs.FS] at the given URL prefix.
// This is useful with //go:embed filesystems.
//
// Example:
//
// //go:embed static
// var staticFiles embed.FS
// v.StaticFS("/assets/", staticFiles)
func (v *V) StaticFS(urlPrefix string, fsys fs.FS) {
if !strings.HasSuffix(urlPrefix, "/") {
urlPrefix += "/"
}
fileServer := http.StripPrefix(urlPrefix, http.FileServerFS(fsys))
v.mux.Handle("GET "+urlPrefix, noDirListing(fileServer))
}
// noDirListing wraps a file server handler to return 404 for directory requests.
func noDirListing(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.URL.Path, "/") {
http.NotFound(w, r)
return
}
next.ServeHTTP(w, r)
})
}
func (v *V) ensureDatastarHandler() { func (v *V) ensureDatastarHandler() {
v.datastarOnce.Do(func() { v.datastarOnce.Do(func() {
v.mux.HandleFunc("GET "+v.datastarPath, func(w http.ResponseWriter, r *http.Request) { v.mux.HandleFunc("GET "+v.datastarPath, func(w http.ResponseWriter, r *http.Request) {
@@ -518,6 +570,7 @@ type patchType int
const ( const (
patchTypeElements = iota patchTypeElements = iota
patchTypeElementsWithVT
patchTypeSignals patchTypeSignals
patchTypeScript patchTypeScript
patchTypeRedirect patchTypeRedirect
@@ -538,6 +591,7 @@ func New() *V {
logger: newConsoleLogger(zerolog.InfoLevel), logger: newConsoleLogger(zerolog.InfoLevel),
contextRegistry: make(map[string]*Context), contextRegistry: make(map[string]*Context),
devModePageInitFnMap: make(map[string]func(*Context)), devModePageInitFnMap: make(map[string]func(*Context)),
pageRegistry: make(map[string]func(*Context)),
sessionManager: scs.New(), sessionManager: scs.New(),
datastarPath: "/_datastar.js", datastarPath: "/_datastar.js",
datastarContent: datastarJS, datastarContent: datastarJS,
@@ -573,9 +627,7 @@ func New() *V {
c.sseConnected.Store(true) c.sseConnected.Store(true)
v.logDebug(c, "SSE connection established") v.logDebug(c, "SSE connection established")
go func() { go c.Sync()
c.Sync()
}()
for { for {
select { select {
@@ -590,11 +642,16 @@ func New() *V {
switch patch.typ { switch patch.typ {
case patchTypeElements: case patchTypeElements:
if err := sse.PatchElements(patch.content); err != nil { if err := sse.PatchElements(patch.content); err != nil {
// Only log if connection wasn't closed (avoids noise during shutdown/tests)
if sse.Context().Err() == nil { if sse.Context().Err() == nil {
v.logErr(c, "PatchElements failed: %v", err) v.logErr(c, "PatchElements failed: %v", err)
} }
} }
case patchTypeElementsWithVT:
if err := sse.PatchElements(patch.content, datastar.WithViewTransitions()); err != nil {
if sse.Context().Err() == nil {
v.logErr(c, "PatchElements (view transition) failed: %v", err)
}
}
case patchTypeSignals: case patchTypeSignals:
if err := sse.PatchSignals([]byte(patch.content)); err != nil { if err := sse.PatchSignals([]byte(patch.content)); err != nil {
if sse.Context().Err() == nil { if sse.Context().Err() == nil {
@@ -667,7 +724,44 @@ func New() *V {
}() }()
c.injectSignals(sigs) c.injectSignals(sigs)
entry.fn() if len(entry.middleware) > 0 {
chainMiddleware(entry.middleware, func(_ *Context) { entry.fn() })(c)
} else {
entry.fn()
}
})
v.mux.HandleFunc("POST /_navigate", func(w http.ResponseWriter, r *http.Request) {
_ = r.ParseForm()
cID := r.FormValue("via-ctx")
csrfToken := r.FormValue("via-csrf")
navURL := r.FormValue("url")
popstate := r.FormValue("popstate") == "1"
if cID == "" || navURL == "" || !strings.HasPrefix(navURL, "/") {
http.Error(w, "missing or invalid parameters", http.StatusBadRequest)
return
}
c, err := v.getCtx(cID)
if err != nil {
v.logErr(nil, "navigate failed: %v", err)
http.Error(w, "context not found", http.StatusNotFound)
return
}
if subtle.ConstantTimeCompare([]byte(csrfToken), []byte(c.csrfToken)) != 1 {
v.logWarn(c, "navigate rejected: invalid CSRF token")
http.Error(w, "invalid CSRF token", http.StatusForbidden)
return
}
if c.actionLimiter != nil && !c.actionLimiter.Allow() {
v.logWarn(c, "navigate rate limited")
http.Error(w, "rate limited", http.StatusTooManyRequests)
return
}
c.reqCtx = r.Context()
v.logDebug(c, "SPA navigate to %s", navURL)
c.Navigate(navURL, popstate)
w.WriteHeader(http.StatusOK)
}) })
v.mux.HandleFunc("POST /_session/close", func(w http.ResponseWriter, r *http.Request) { v.mux.HandleFunc("POST /_session/close", func(w http.ResponseWriter, r *http.Request) {
@@ -687,13 +781,22 @@ func New() *V {
v.logDebug(c, "session close event triggered") v.logDebug(c, "session close event triggered")
v.cleanupCtx(c) v.cleanupCtx(c)
}) })
dn, err := getSharedNATS()
if err != nil {
v.logWarn(nil, "embedded NATS unavailable: %v", err)
} else {
v.defaultNATS = dn
v.pubsub = &natsRef{dn: dn}
}
return v return v
} }
func genRandID() string { func genRandID() string {
b := make([]byte, 16) b := make([]byte, 4)
rand.Read(b) rand.Read(b)
return hex.EncodeToString(b)[:8] return hex.EncodeToString(b)
} }
func genCSRFToken() string { func genCSRFToken() string {
@@ -714,8 +817,24 @@ func extractParams(pattern, path string) map[string]string {
key := p[i][1 : len(p[i])-1] // remove {} key := p[i][1 : len(p[i])-1] // remove {}
params[key] = u[i] params[key] = u[i]
} else if p[i] != u[i] { } else if p[i] != u[i] {
continue return nil
} }
} }
return params return params
} }
// matchRoute finds the registered page init function and extracted params for the given path.
func (v *V) matchRoute(path string) (route string, initFn func(*Context), params map[string]string) {
for pattern, fn := range v.pageRegistry {
if p := extractParams(pattern, path); p != nil {
return pattern, fn, p
}
}
return "", nil, nil
}
// Layout sets a layout function that wraps every page's view.
// The layout receives the page content as a function and returns the full view.
func (v *V) Layout(f func(func() h.H) h.H) {
v.layout = f
}

View File

@@ -1,127 +0,0 @@
// Package vianats provides an embedded NATS server with JetStream as a
// pub/sub backend for Via applications.
package vianats
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/delaneyj/toolbelt/embeddednats"
"github.com/nats-io/nats.go"
"github.com/ryanhamamura/via"
)
// NATS implements via.PubSub using an embedded NATS server with JetStream.
type NATS struct {
server *embeddednats.Server
nc *nats.Conn
js nats.JetStreamContext
}
// New starts an embedded NATS server with JetStream enabled and returns a
// ready-to-use NATS instance. The server stores data in dataDir and shuts
// down when ctx is cancelled.
func New(ctx context.Context, dataDir string) (*NATS, error) {
ns, err := embeddednats.New(ctx, embeddednats.WithDirectory(dataDir))
if err != nil {
return nil, fmt.Errorf("vianats: start server: %w", err)
}
ns.WaitForServer()
nc, err := ns.Client()
if err != nil {
ns.Close()
return nil, fmt.Errorf("vianats: connect client: %w", err)
}
js, err := nc.JetStream()
if err != nil {
nc.Close()
ns.Close()
return nil, fmt.Errorf("vianats: init jetstream: %w", err)
}
return &NATS{server: ns, nc: nc, js: js}, nil
}
// Publish sends data to the given subject using core NATS publish.
// JetStream captures messages automatically if a matching stream exists.
func (n *NATS) Publish(subject string, data []byte) error {
return n.nc.Publish(subject, data)
}
// Subscribe creates a core NATS subscription for real-time fan-out delivery.
func (n *NATS) Subscribe(subject string, handler func(data []byte)) (via.Subscription, error) {
sub, err := n.nc.Subscribe(subject, func(msg *nats.Msg) {
handler(msg.Data)
})
if err != nil {
return nil, err
}
return sub, nil
}
// Close shuts down the client connection and embedded server.
func (n *NATS) Close() error {
n.nc.Close()
return n.server.Close()
}
// Conn returns the underlying NATS connection for advanced usage.
func (n *NATS) Conn() *nats.Conn {
return n.nc
}
// JetStream returns the JetStream context for stream configuration and replay.
func (n *NATS) JetStream() nats.JetStreamContext {
return n.js
}
// StreamConfig holds the parameters for creating or updating a JetStream stream.
type StreamConfig struct {
Name string
Subjects []string
MaxMsgs int64
MaxAge time.Duration
}
// EnsureStream creates or updates a JetStream stream matching cfg.
func EnsureStream(n *NATS, cfg StreamConfig) error {
_, err := n.js.AddStream(&nats.StreamConfig{
Name: cfg.Name,
Subjects: cfg.Subjects,
Retention: nats.LimitsPolicy,
MaxMsgs: cfg.MaxMsgs,
MaxAge: cfg.MaxAge,
})
return err
}
// ReplayHistory fetches the last limit messages from subject,
// deserializing each as T. Returns an empty slice if nothing is available.
func ReplayHistory[T any](n *NATS, subject string, limit int) ([]T, error) {
sub, err := n.js.SubscribeSync(subject, nats.DeliverAll(), nats.OrderedConsumer())
if err != nil {
return nil, err
}
defer sub.Unsubscribe()
var msgs []T
for {
raw, err := sub.NextMsg(200 * time.Millisecond)
if err != nil {
break
}
var msg T
if json.Unmarshal(raw.Data, &msg) == nil {
msgs = append(msgs, msg)
}
}
if limit > 0 && len(msgs) > limit {
msgs = msgs[len(msgs)-limit:]
}
return msgs, nil
}