Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e636970f7b | ||
|
|
f5158b866c | ||
|
|
2f6c5916ce | ||
|
|
0762ddbbc2 | ||
|
|
b7acfa6302 | ||
|
|
8aa91c577c | ||
|
|
6dcd54c88b | ||
|
|
2c44671d0e |
75
README.md
75
README.md
@@ -1,30 +1,33 @@
|
|||||||
# ⚡Via
|
# Via
|
||||||
|
|
||||||
Real-time engine for building reactive web applications in pure Go.
|
Real-time engine for building reactive web applications in pure Go.
|
||||||
|
|
||||||
|
|
||||||
## Why Via?
|
## Why Via?
|
||||||
Somewhere along the way, the web became tangled in layers of JavaScript, build chains, and frameworks stacked on frameworks.
|
|
||||||
|
|
||||||
Via takes a radical stance:
|
The web became tangled in layers of JavaScript, build chains, and frameworks stacked on frameworks. Via takes a different path.
|
||||||
|
|
||||||
- No templates.
|
**Philosophy**
|
||||||
- No JavaScript.
|
- No templates. No JavaScript. No transpilation. No hydration.
|
||||||
- No transpilation.
|
- Views are pure Go functions. HTML is composed with a type-safe DSL.
|
||||||
- No hydration.
|
- A single SSE stream carries all reactivity — no WebSocket juggling, no polling.
|
||||||
- No front-end fatigue.
|
|
||||||
- Single SSE stream.
|
|
||||||
- Full reactivity.
|
|
||||||
- Built-in Brotli compression.
|
|
||||||
- Pure Go.
|
|
||||||
|
|
||||||
|
**Batteries included**
|
||||||
|
- Automatic CSRF protection on every action call
|
||||||
|
- Token-bucket rate limiting (global defaults + per-action overrides)
|
||||||
|
- Cookie-based sessions backed by SQLite
|
||||||
|
- Pub/sub messaging with an embedded NATS backend
|
||||||
|
- Structured logging via zerolog
|
||||||
|
- Graceful shutdown with context draining
|
||||||
|
- Brotli compression out of the box
|
||||||
|
|
||||||
## Example
|
## Example
|
||||||
|
|
||||||
```go
|
```go
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/go-via/via"
|
"github.com/ryanhamamura/via"
|
||||||
"github.com/go-via/via/h"
|
"github.com/ryanhamamura/via/h"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Counter struct{ Count int }
|
type Counter struct{ Count int }
|
||||||
@@ -57,25 +60,43 @@ func main() {
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## What's built in
|
||||||
|
|
||||||
## 🚧 Experimental
|
- **Reactive views + signals** — bind state to the DOM; changes push over SSE automatically
|
||||||
<s>Via is still a newborn.</s> Via is taking its first steps!
|
- **Components** — self-contained subcontexts with their own data, actions, and signals
|
||||||
- Version `0.1.0` released.
|
- **Sessions** — cookie-based, backed by SQLite via `scs`
|
||||||
- Expect a little less chaos.
|
- **Pub/sub** — embedded NATS server with JetStream; generic `Publish[T]` / `Subscribe[T]` helpers
|
||||||
|
- **CSRF protection** — automatic token generation and validation on every action
|
||||||
|
- **Rate limiting** — token-bucket algorithm, configurable globally and per-action
|
||||||
|
- **Event handling** — `OnClick`, `OnChange`, `OnSubmit`, `OnInput`, `OnFocus`, `OnBlur`, `OnMouseEnter`, `OnMouseLeave`, `OnScroll`, `OnDblClick`, `OnKeyDown`, and `OnKeyDownMap` for multi-key bindings
|
||||||
|
- **Timed routines** — `OnInterval` with start/stop/update controls, tied to context lifecycle
|
||||||
|
- **Redirects** — `Redirect`, `ReplaceURL`, and format-string variants
|
||||||
|
- **Plugin system** — `func(v *V)` hooks for integrating CSS/JS libraries
|
||||||
|
- **Structured logging** — zerolog with configurable levels; console output in dev, JSON in production
|
||||||
|
- **Graceful shutdown** — listens for SIGINT/SIGTERM, drains contexts, closes pub/sub
|
||||||
|
- **Context lifecycle** — background reaper cleans up disconnected contexts; configurable TTL
|
||||||
|
- **HTML DSL** — the `h` package provides type-safe Go-native HTML composition
|
||||||
|
|
||||||
|
## Examples
|
||||||
|
|
||||||
|
The `internal/examples/` directory contains 14 runnable examples:
|
||||||
|
|
||||||
|
`chatroom` · `counter` · `countercomp` · `greeter` · `keyboard` · `livereload` · `nats-chatroom` · `pathparams` · `picocss` · `plugins` · `pubsub-crud` · `realtimechart` · `session` · `shakespeare`
|
||||||
|
|
||||||
|
## Experimental
|
||||||
|
|
||||||
|
Via is maturing — sessions, CSRF, rate limiting, pub/sub, and graceful shutdown are in place — but the API is still evolving. Expect breaking changes before `v1`.
|
||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
|
|
||||||
- Via is intentionally minimal and opinionated — and so is contributing.
|
- Via is intentionally minimal and opinionated — and so is contributing.
|
||||||
- If you love Go, simplicity, and meaningful abstractions — Come along for the ride!
|
- Fork, branch, build, tinker, submit a pull request.
|
||||||
- Fork, branch, build, tinker with things, submit a pull request.
|
|
||||||
- Keep every line purposeful.
|
- Keep every line purposeful.
|
||||||
- Share feedback: open an issue or start a discussion.
|
- Share feedback: open an issue or start a discussion.
|
||||||
|
|
||||||
|
|
||||||
## Credits
|
## Credits
|
||||||
|
|
||||||
Via builds upon the work of these amazing projects:
|
Via builds upon the work of these projects:
|
||||||
|
|
||||||
- 🚀 [Datastar](https://data-star.dev) - The hypermedia powerhouse at the core of Via. It powers browser reactivity through Signals and enables real-time HTML/Signal patches over an always-on SSE event stream.
|
- [Datastar](https://data-star.dev) — the hypermedia framework powering browser reactivity through signals and real-time HTML patches over SSE.
|
||||||
- 🧩 [Gomponents](https://maragu.dev/gomponents) - The awesome project that gifts Via with Go-native HTML composition superpowers through the `via/h` package.
|
- [Gomponents](https://maragu.dev/gomponents) — Go-native HTML composition that powers the `via/h` package.
|
||||||
|
|
||||||
> Thank you for building something that doesn’t just function — it inspires. 🫶
|
|
||||||
|
|||||||
@@ -107,6 +107,54 @@ func (a *actionTrigger) OnChange(options ...ActionTriggerOption) h.H {
|
|||||||
return h.Data("on:change__debounce.200ms", buildOnExpr(actionURL(a.id), &opts))
|
return h.Data("on:change__debounce.200ms", buildOnExpr(actionURL(a.id), &opts))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OnSubmit returns a via.h DOM attribute that triggers on form submit.
|
||||||
|
func (a *actionTrigger) OnSubmit(options ...ActionTriggerOption) h.H {
|
||||||
|
opts := applyOptions(options...)
|
||||||
|
return h.Data("on:submit", buildOnExpr(actionURL(a.id), &opts))
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnInput returns a via.h DOM attribute that triggers on input (without debounce).
|
||||||
|
func (a *actionTrigger) OnInput(options ...ActionTriggerOption) h.H {
|
||||||
|
opts := applyOptions(options...)
|
||||||
|
return h.Data("on:input", buildOnExpr(actionURL(a.id), &opts))
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnFocus returns a via.h DOM attribute that triggers when the element gains focus.
|
||||||
|
func (a *actionTrigger) OnFocus(options ...ActionTriggerOption) h.H {
|
||||||
|
opts := applyOptions(options...)
|
||||||
|
return h.Data("on:focus", buildOnExpr(actionURL(a.id), &opts))
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnBlur returns a via.h DOM attribute that triggers when the element loses focus.
|
||||||
|
func (a *actionTrigger) OnBlur(options ...ActionTriggerOption) h.H {
|
||||||
|
opts := applyOptions(options...)
|
||||||
|
return h.Data("on:blur", buildOnExpr(actionURL(a.id), &opts))
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnMouseEnter returns a via.h DOM attribute that triggers when the mouse enters the element.
|
||||||
|
func (a *actionTrigger) OnMouseEnter(options ...ActionTriggerOption) h.H {
|
||||||
|
opts := applyOptions(options...)
|
||||||
|
return h.Data("on:mouseenter", buildOnExpr(actionURL(a.id), &opts))
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnMouseLeave returns a via.h DOM attribute that triggers when the mouse leaves the element.
|
||||||
|
func (a *actionTrigger) OnMouseLeave(options ...ActionTriggerOption) h.H {
|
||||||
|
opts := applyOptions(options...)
|
||||||
|
return h.Data("on:mouseleave", buildOnExpr(actionURL(a.id), &opts))
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnScroll returns a via.h DOM attribute that triggers on scroll.
|
||||||
|
func (a *actionTrigger) OnScroll(options ...ActionTriggerOption) h.H {
|
||||||
|
opts := applyOptions(options...)
|
||||||
|
return h.Data("on:scroll", buildOnExpr(actionURL(a.id), &opts))
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnDblClick returns a via.h DOM attribute that triggers on double click.
|
||||||
|
func (a *actionTrigger) OnDblClick(options ...ActionTriggerOption) h.H {
|
||||||
|
opts := applyOptions(options...)
|
||||||
|
return h.Data("on:dblclick", buildOnExpr(actionURL(a.id), &opts))
|
||||||
|
}
|
||||||
|
|
||||||
// OnKeyDown returns a via.h DOM attribute that triggers when a key is pressed.
|
// OnKeyDown returns a via.h DOM attribute that triggers when a key is pressed.
|
||||||
// key: optional, see https://developer.mozilla.org/en-US/docs/Web/API/KeyboardEvent/key
|
// key: optional, see https://developer.mozilla.org/en-US/docs/Web/API/KeyboardEvent/key
|
||||||
// Example: OnKeyDown("Enter")
|
// Example: OnKeyDown("Enter")
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package via
|
package via
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/alexedwards/scs/v2"
|
"github.com/alexedwards/scs/v2"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
)
|
)
|
||||||
@@ -54,4 +56,14 @@ type Options struct {
|
|||||||
// PubSub enables publish/subscribe messaging. Use vianats.New() for an
|
// PubSub enables publish/subscribe messaging. Use vianats.New() for an
|
||||||
// embedded NATS backend, or supply any PubSub implementation.
|
// embedded NATS backend, or supply any PubSub implementation.
|
||||||
PubSub PubSub
|
PubSub PubSub
|
||||||
|
|
||||||
|
// ContextTTL is the maximum time a context may exist without an SSE
|
||||||
|
// connection before the background reaper disposes it.
|
||||||
|
// Default: 30s. Negative value disables the reaper.
|
||||||
|
ContextTTL time.Duration
|
||||||
|
|
||||||
|
// ActionRateLimit configures the default token-bucket rate limiter for
|
||||||
|
// action endpoints. Zero values use built-in defaults (10 req/s, burst 20).
|
||||||
|
// Set Rate to -1 to disable rate limiting entirely.
|
||||||
|
ActionRateLimit RateLimitConfig
|
||||||
}
|
}
|
||||||
|
|||||||
58
context.go
58
context.go
@@ -5,12 +5,13 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"maps"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ryanhamamura/via/h"
|
"github.com/ryanhamamura/via/h"
|
||||||
|
"golang.org/x/time/rate"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Context is the living bridge between Go and the browser.
|
// Context is the living bridge between Go and the browser.
|
||||||
@@ -19,13 +20,14 @@ import (
|
|||||||
type Context struct {
|
type Context struct {
|
||||||
id string
|
id string
|
||||||
route string
|
route string
|
||||||
|
csrfToken string
|
||||||
app *V
|
app *V
|
||||||
view func() h.H
|
view func() h.H
|
||||||
routeParams map[string]string
|
routeParams map[string]string
|
||||||
componentRegistry map[string]*Context
|
parentPageCtx *Context
|
||||||
parentPageCtx *Context
|
|
||||||
patchChan chan patch
|
patchChan chan patch
|
||||||
actionRegistry map[string]func()
|
actionLimiter *rate.Limiter
|
||||||
|
actionRegistry map[string]actionEntry
|
||||||
signals *sync.Map
|
signals *sync.Map
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
ctxDisposedChan chan struct{}
|
ctxDisposedChan chan struct{}
|
||||||
@@ -33,6 +35,8 @@ type Context struct {
|
|||||||
subscriptions []Subscription
|
subscriptions []Subscription
|
||||||
subsMu sync.Mutex
|
subsMu sync.Mutex
|
||||||
disposeOnce sync.Once
|
disposeOnce sync.Once
|
||||||
|
createdAt time.Time
|
||||||
|
sseConnected atomic.Bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// View defines the UI rendered by this context.
|
// View defines the UI rendered by this context.
|
||||||
@@ -75,7 +79,6 @@ func (c *Context) Component(initCtx func(c *Context)) func() h.H {
|
|||||||
compCtx.parentPageCtx = c
|
compCtx.parentPageCtx = c
|
||||||
}
|
}
|
||||||
initCtx(compCtx)
|
initCtx(compCtx)
|
||||||
c.componentRegistry[id] = compCtx
|
|
||||||
return compCtx.view
|
return compCtx.view
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -100,26 +103,31 @@ func (c *Context) isComponent() bool {
|
|||||||
// h.Button(h.Text("Increment n"), increment.OnClick()),
|
// h.Button(h.Text("Increment n"), increment.OnClick()),
|
||||||
// )
|
// )
|
||||||
// })
|
// })
|
||||||
func (c *Context) Action(f func()) *actionTrigger {
|
func (c *Context) Action(f func(), opts ...ActionOption) *actionTrigger {
|
||||||
id := genRandID()
|
id := genRandID()
|
||||||
if f == nil {
|
if f == nil {
|
||||||
c.app.logErr(c, "failed to bind action '%s' to context: nil func", id)
|
c.app.logErr(c, "failed to bind action '%s' to context: nil func", id)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
entry := actionEntry{fn: f}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(&entry)
|
||||||
|
}
|
||||||
|
|
||||||
if c.isComponent() {
|
if c.isComponent() {
|
||||||
c.parentPageCtx.actionRegistry[id] = f
|
c.parentPageCtx.actionRegistry[id] = entry
|
||||||
} else {
|
} else {
|
||||||
c.actionRegistry[id] = f
|
c.actionRegistry[id] = entry
|
||||||
}
|
}
|
||||||
return &actionTrigger{id}
|
return &actionTrigger{id}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Context) getActionFn(id string) (func(), error) {
|
func (c *Context) getAction(id string) (actionEntry, error) {
|
||||||
if f, ok := c.actionRegistry[id]; ok {
|
if e, ok := c.actionRegistry[id]; ok {
|
||||||
return f, nil
|
return e, nil
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("action '%s' not found", id)
|
return actionEntry{}, fmt.Errorf("action '%s' not found", id)
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnInterval starts a go routine that sets a time.Ticker with the given duration and executes
|
// OnInterval starts a go routine that sets a time.Ticker with the given duration and executes
|
||||||
@@ -197,14 +205,14 @@ func (c *Context) injectSignals(sigs map[string]any) {
|
|||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
for sigID, val := range sigs {
|
for sigID, val := range sigs {
|
||||||
if _, ok := c.signals.Load(sigID); !ok {
|
item, ok := c.signals.Load(sigID)
|
||||||
|
if !ok {
|
||||||
c.signals.Store(sigID, &signal{
|
c.signals.Store(sigID, &signal{
|
||||||
id: sigID,
|
id: sigID,
|
||||||
val: val,
|
val: val,
|
||||||
})
|
})
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
item, _ := c.signals.Load(sigID)
|
|
||||||
if sig, ok := item.(*signal); ok {
|
if sig, ok := item.(*signal); ok {
|
||||||
sig.val = val
|
sig.val = val
|
||||||
sig.changed = false
|
sig.changed = false
|
||||||
@@ -255,7 +263,7 @@ func (c *Context) sendPatch(p patch) {
|
|||||||
// Sync pushes the current view state and signal changes to the browser immediately
|
// Sync pushes the current view state and signal changes to the browser immediately
|
||||||
// over the live SSE event stream.
|
// over the live SSE event stream.
|
||||||
func (c *Context) Sync() {
|
func (c *Context) Sync() {
|
||||||
elemsPatch := bytes.NewBuffer(make([]byte, 0))
|
elemsPatch := new(bytes.Buffer)
|
||||||
if err := c.view().Render(elemsPatch); err != nil {
|
if err := c.view().Render(elemsPatch); err != nil {
|
||||||
c.app.logErr(c, "sync view failed: %v", err)
|
c.app.logErr(c, "sync view failed: %v", err)
|
||||||
return
|
return
|
||||||
@@ -320,6 +328,15 @@ func (c *Context) ExecScript(s string) {
|
|||||||
c.sendPatch(patch{patchTypeScript, s})
|
c.sendPatch(patch{patchTypeScript, s})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RedirectView sets a view that redirects the browser to the given URL.
|
||||||
|
// Use this in middleware to abort the chain and redirect in one step.
|
||||||
|
func (c *Context) RedirectView(url string) {
|
||||||
|
c.View(func() h.H {
|
||||||
|
c.Redirect(url)
|
||||||
|
return h.Div()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Redirect navigates the browser to the given URL.
|
// Redirect navigates the browser to the given URL.
|
||||||
// This triggers a full page navigation - the current context will be disposed
|
// This triggers a full page navigation - the current context will be disposed
|
||||||
// and a new context created at the destination URL.
|
// and a new context created at the destination URL.
|
||||||
@@ -375,12 +392,9 @@ func (c *Context) injectRouteParams(params map[string]string) {
|
|||||||
if params == nil {
|
if params == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
m := make(map[string]string)
|
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
maps.Copy(m, params)
|
c.routeParams = params
|
||||||
c.routeParams = m
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPathParam retrieves the value from the page request URL for the given parameter name
|
// GetPathParam retrieves the value from the page request URL for the given parameter name
|
||||||
@@ -474,12 +488,14 @@ func newContext(id string, route string, v *V) *Context {
|
|||||||
return &Context{
|
return &Context{
|
||||||
id: id,
|
id: id,
|
||||||
route: route,
|
route: route,
|
||||||
|
csrfToken: genCSRFToken(),
|
||||||
routeParams: make(map[string]string),
|
routeParams: make(map[string]string),
|
||||||
app: v,
|
app: v,
|
||||||
componentRegistry: make(map[string]*Context),
|
actionLimiter: newLimiter(v.actionRateLimit, defaultActionRate, defaultActionBurst),
|
||||||
actionRegistry: make(map[string]func()),
|
actionRegistry: make(map[string]actionEntry),
|
||||||
signals: new(sync.Map),
|
signals: new(sync.Map),
|
||||||
patchChan: make(chan patch, 1),
|
patchChan: make(chan patch, 1),
|
||||||
ctxDisposedChan: make(chan struct{}, 1),
|
ctxDisposedChan: make(chan struct{}, 1),
|
||||||
|
createdAt: time.Now(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
3
go.mod
3
go.mod
@@ -14,6 +14,7 @@ require (
|
|||||||
github.com/rs/zerolog v1.34.0
|
github.com/rs/zerolog v1.34.0
|
||||||
github.com/starfederation/datastar-go v1.0.3
|
github.com/starfederation/datastar-go v1.0.3
|
||||||
github.com/stretchr/testify v1.11.1
|
github.com/stretchr/testify v1.11.1
|
||||||
|
golang.org/x/time v0.14.0
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
@@ -37,6 +38,6 @@ require (
|
|||||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||||
golang.org/x/crypto v0.45.0 // indirect
|
golang.org/x/crypto v0.45.0 // indirect
|
||||||
golang.org/x/sys v0.38.0 // indirect
|
golang.org/x/sys v0.38.0 // indirect
|
||||||
golang.org/x/time v0.14.0 // indirect
|
golang.org/x/time v0.14.0
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
)
|
)
|
||||||
|
|||||||
151
internal/examples/middleware/main.go
Normal file
151
internal/examples/middleware/main.go
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ryanhamamura/via"
|
||||||
|
"github.com/ryanhamamura/via/h"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
v := via.New()
|
||||||
|
v.Config(via.Options{
|
||||||
|
ServerAddress: ":8080",
|
||||||
|
DocumentTitle: "Middleware Example",
|
||||||
|
})
|
||||||
|
|
||||||
|
// --- Middleware definitions ---
|
||||||
|
|
||||||
|
// requestLogger logs every page request to stdout.
|
||||||
|
requestLogger := func(c *via.Context, next func()) {
|
||||||
|
fmt.Printf("[%s] request\n", time.Now().Format("15:04:05"))
|
||||||
|
next()
|
||||||
|
}
|
||||||
|
|
||||||
|
// authRequired redirects unauthenticated users to /login.
|
||||||
|
authRequired := func(c *via.Context, next func()) {
|
||||||
|
if c.Session().GetString("role") == "" {
|
||||||
|
c.RedirectView("/login")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next()
|
||||||
|
}
|
||||||
|
|
||||||
|
// auditLog prints the authenticated username to stdout.
|
||||||
|
auditLog := func(c *via.Context, next func()) {
|
||||||
|
fmt.Printf("[audit] user=%s\n", c.Session().GetString("username"))
|
||||||
|
next()
|
||||||
|
}
|
||||||
|
|
||||||
|
// superAdminOnly rejects non-superadmin users with a forbidden view.
|
||||||
|
superAdminOnly := func(c *via.Context, next func()) {
|
||||||
|
if c.Session().GetString("role") != "superadmin" {
|
||||||
|
c.View(func() h.H {
|
||||||
|
return h.Div(
|
||||||
|
h.H1(h.Text("Forbidden")),
|
||||||
|
h.P(h.Text("Super-admin access required.")),
|
||||||
|
h.A(h.Href("/admin/dashboard"), h.Text("Back to dashboard")),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next()
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Route registration ---
|
||||||
|
|
||||||
|
v.Use(requestLogger) // global middleware
|
||||||
|
|
||||||
|
admin := v.Group("/admin", authRequired) // prefixed group
|
||||||
|
admin.Use(auditLog) // Group.Use()
|
||||||
|
superAdmin := admin.Group("/super", superAdminOnly) // nested group
|
||||||
|
|
||||||
|
// Public: redirect root to login
|
||||||
|
v.Page("/", func(c *via.Context) {
|
||||||
|
c.View(func() h.H {
|
||||||
|
c.Redirect("/login")
|
||||||
|
return h.Div()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// Public: login page with role-selection buttons
|
||||||
|
v.Page("/login", func(c *via.Context) {
|
||||||
|
loginAdmin := c.Action(func() {
|
||||||
|
c.Session().Set("role", "admin")
|
||||||
|
c.Session().Set("username", "alice")
|
||||||
|
c.Session().RenewToken()
|
||||||
|
c.Redirect("/admin/dashboard")
|
||||||
|
})
|
||||||
|
|
||||||
|
loginSuper := c.Action(func() {
|
||||||
|
c.Session().Set("role", "superadmin")
|
||||||
|
c.Session().Set("username", "bob")
|
||||||
|
c.Session().RenewToken()
|
||||||
|
c.Redirect("/admin/dashboard")
|
||||||
|
})
|
||||||
|
|
||||||
|
c.View(func() h.H {
|
||||||
|
return h.Div(
|
||||||
|
h.H1(h.Text("Login")),
|
||||||
|
h.P(h.Text("Choose a role:")),
|
||||||
|
h.Button(h.Text("Login as Admin"), loginAdmin.OnClick()),
|
||||||
|
h.Raw(" "),
|
||||||
|
h.Button(h.Text("Login as Super Admin"), loginSuper.OnClick()),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// Per-action middleware: only superadmins can invoke this action.
|
||||||
|
requireSuperAdmin := func(c *via.Context, next func()) {
|
||||||
|
if c.Session().GetString("role") != "superadmin" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Admin: dashboard (requires authRequired + auditLog)
|
||||||
|
admin.Page("/dashboard", func(c *via.Context) {
|
||||||
|
logout := c.Action(func() {
|
||||||
|
c.Session().Delete("role")
|
||||||
|
c.Session().Delete("username")
|
||||||
|
c.Redirect("/login")
|
||||||
|
})
|
||||||
|
|
||||||
|
dangerAction := c.Action(func() {
|
||||||
|
fmt.Printf("[danger] executed by %s\n", c.Session().GetString("username"))
|
||||||
|
c.Sync()
|
||||||
|
}, via.WithMiddleware(requireSuperAdmin))
|
||||||
|
|
||||||
|
c.View(func() h.H {
|
||||||
|
username := c.Session().GetString("username")
|
||||||
|
role := c.Session().GetString("role")
|
||||||
|
return h.Div(
|
||||||
|
h.H1(h.Textf("Dashboard — %s (%s)", username, role)),
|
||||||
|
h.Ul(
|
||||||
|
h.Li(h.A(h.Href("/admin/super/settings"), h.Text("Super Admin Settings"))),
|
||||||
|
),
|
||||||
|
h.H2(h.Text("Danger Zone")),
|
||||||
|
h.P(h.Text("This action is protected by per-action middleware (superadmin only):")),
|
||||||
|
h.Button(h.Text("Delete Everything"), dangerAction.OnClick()),
|
||||||
|
h.Br(),
|
||||||
|
h.Br(),
|
||||||
|
h.Button(h.Text("Logout"), logout.OnClick()),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// Super-admin: settings (requires authRequired + auditLog + superAdminOnly)
|
||||||
|
superAdmin.Page("/settings", func(c *via.Context) {
|
||||||
|
c.View(func() h.H {
|
||||||
|
username := c.Session().GetString("username")
|
||||||
|
return h.Div(
|
||||||
|
h.H1(h.Textf("Super Admin Settings — %s", username)),
|
||||||
|
h.P(h.Text("Only super-admins can see this page.")),
|
||||||
|
h.A(h.Href("/admin/dashboard"), h.Text("Back to dashboard")),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
v.Start()
|
||||||
|
}
|
||||||
@@ -2,13 +2,11 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"log"
|
"log"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/nats-io/nats.go"
|
|
||||||
"github.com/ryanhamamura/via"
|
"github.com/ryanhamamura/via"
|
||||||
"github.com/ryanhamamura/via/h"
|
"github.com/ryanhamamura/via/h"
|
||||||
"github.com/ryanhamamura/via/vianats"
|
"github.com/ryanhamamura/via/vianats"
|
||||||
@@ -46,15 +44,15 @@ func main() {
|
|||||||
}
|
}
|
||||||
defer ps.Close()
|
defer ps.Close()
|
||||||
|
|
||||||
// Create JetStream stream for message durability
|
err = vianats.EnsureStream(ps, vianats.StreamConfig{
|
||||||
js := ps.JetStream()
|
Name: "CHAT",
|
||||||
js.AddStream(&nats.StreamConfig{
|
Subjects: []string{"chat.>"},
|
||||||
Name: "CHAT",
|
MaxMsgs: 1000,
|
||||||
Subjects: []string{"chat.>"},
|
MaxAge: 24 * time.Hour,
|
||||||
Retention: nats.LimitsPolicy,
|
|
||||||
MaxMsgs: 1000,
|
|
||||||
MaxAge: 24 * time.Hour,
|
|
||||||
})
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to ensure stream: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
v := via.New()
|
v := via.New()
|
||||||
v.Config(via.Options{
|
v.Config(via.Options{
|
||||||
@@ -147,30 +145,14 @@ func main() {
|
|||||||
currentSub.Unsubscribe()
|
currentSub.Unsubscribe()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Replay history from JetStream before subscribing for real-time
|
|
||||||
subject := "chat.room." + room
|
subject := "chat.room." + room
|
||||||
if hist, err := js.SubscribeSync(subject, nats.DeliverAll(), nats.OrderedConsumer()); err == nil {
|
|
||||||
for {
|
// Replay history from JetStream
|
||||||
msg, err := hist.NextMsg(200 * time.Millisecond)
|
if hist, err := vianats.ReplayHistory[ChatMessage](ps, subject, 50); err == nil {
|
||||||
if err != nil {
|
messages = hist
|
||||||
break
|
|
||||||
}
|
|
||||||
var chatMsg ChatMessage
|
|
||||||
if json.Unmarshal(msg.Data, &chatMsg) == nil {
|
|
||||||
messages = append(messages, chatMsg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
hist.Unsubscribe()
|
|
||||||
if len(messages) > 50 {
|
|
||||||
messages = messages[len(messages)-50:]
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
sub, _ := c.Subscribe(subject, func(data []byte) {
|
sub, _ := via.Subscribe(c, subject, func(msg ChatMessage) {
|
||||||
var msg ChatMessage
|
|
||||||
if err := json.Unmarshal(data, &msg); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
messagesMu.Lock()
|
messagesMu.Lock()
|
||||||
messages = append(messages, msg)
|
messages = append(messages, msg)
|
||||||
if len(messages) > 50 {
|
if len(messages) > 50 {
|
||||||
@@ -203,12 +185,11 @@ func main() {
|
|||||||
}
|
}
|
||||||
statement.SetValue("")
|
statement.SetValue("")
|
||||||
|
|
||||||
data, _ := json.Marshal(ChatMessage{
|
via.Publish(c, "chat.room."+currentRoom, ChatMessage{
|
||||||
User: currentUser,
|
User: currentUser,
|
||||||
Message: msg,
|
Message: msg,
|
||||||
Time: time.Now().UnixMilli(),
|
Time: time.Now().UnixMilli(),
|
||||||
})
|
})
|
||||||
c.Publish("chat.room."+currentRoom, data)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
c.View(func() h.H {
|
c.View(func() h.H {
|
||||||
|
|||||||
284
internal/examples/pubsub-crud/main.go
Normal file
284
internal/examples/pubsub-crud/main.go
Normal file
@@ -0,0 +1,284 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"fmt"
|
||||||
|
"html"
|
||||||
|
"log"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ryanhamamura/via"
|
||||||
|
"github.com/ryanhamamura/via/h"
|
||||||
|
"github.com/ryanhamamura/via/vianats"
|
||||||
|
)
|
||||||
|
|
||||||
|
var WithSignal = via.WithSignal
|
||||||
|
|
||||||
|
type Bookmark struct {
|
||||||
|
ID string
|
||||||
|
Title string
|
||||||
|
URL string
|
||||||
|
}
|
||||||
|
|
||||||
|
type CRUDEvent struct {
|
||||||
|
Action string `json:"action"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
bookmarks []Bookmark
|
||||||
|
bookmarksMu sync.RWMutex
|
||||||
|
)
|
||||||
|
|
||||||
|
func randomHex(n int) string {
|
||||||
|
b := make([]byte, n)
|
||||||
|
rand.Read(b)
|
||||||
|
return fmt.Sprintf("%x", b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func findBookmark(id string) (Bookmark, int) {
|
||||||
|
for i, bm := range bookmarks {
|
||||||
|
if bm.ID == id {
|
||||||
|
return bm, i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Bookmark{}, -1
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
ps, err := vianats.New(ctx, "./data/nats")
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to start embedded NATS: %v", err)
|
||||||
|
}
|
||||||
|
defer ps.Close()
|
||||||
|
|
||||||
|
err = vianats.EnsureStream(ps, vianats.StreamConfig{
|
||||||
|
Name: "BOOKMARKS",
|
||||||
|
Subjects: []string{"bookmarks.>"},
|
||||||
|
MaxMsgs: 1000,
|
||||||
|
MaxAge: 24 * time.Hour,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to ensure stream: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
v := via.New()
|
||||||
|
v.Config(via.Options{
|
||||||
|
DevMode: true,
|
||||||
|
DocumentTitle: "Bookmarks",
|
||||||
|
LogLevel: via.LogLevelInfo,
|
||||||
|
ServerAddress: ":7331",
|
||||||
|
PubSub: ps,
|
||||||
|
})
|
||||||
|
|
||||||
|
v.AppendToHead(
|
||||||
|
h.Link(h.Rel("stylesheet"), h.Href("https://cdn.jsdelivr.net/npm/daisyui@4/dist/full.min.css")),
|
||||||
|
h.Script(h.Src("https://cdn.tailwindcss.com")),
|
||||||
|
)
|
||||||
|
|
||||||
|
v.Page("/", func(c *via.Context) {
|
||||||
|
userID := randomHex(8)
|
||||||
|
|
||||||
|
titleSignal := c.Signal("")
|
||||||
|
urlSignal := c.Signal("")
|
||||||
|
targetIDSignal := c.Signal("")
|
||||||
|
|
||||||
|
via.Subscribe(c, "bookmarks.events", func(evt CRUDEvent) {
|
||||||
|
if evt.UserID == userID {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
safeTitle := html.EscapeString(evt.Title)
|
||||||
|
var alertClass string
|
||||||
|
switch evt.Action {
|
||||||
|
case "created":
|
||||||
|
alertClass = "alert-success"
|
||||||
|
case "updated":
|
||||||
|
alertClass = "alert-info"
|
||||||
|
case "deleted":
|
||||||
|
alertClass = "alert-error"
|
||||||
|
}
|
||||||
|
c.ExecScript(fmt.Sprintf(`(function(){
|
||||||
|
var tc = document.getElementById('toast-container');
|
||||||
|
if (!tc) return;
|
||||||
|
var d = document.createElement('div');
|
||||||
|
d.className = 'alert %s';
|
||||||
|
d.innerHTML = '<span>Bookmark "%s" %s</span>';
|
||||||
|
tc.appendChild(d);
|
||||||
|
setTimeout(function(){ d.remove(); }, 3000);
|
||||||
|
})()`, alertClass, safeTitle, evt.Action))
|
||||||
|
c.Sync()
|
||||||
|
})
|
||||||
|
|
||||||
|
save := c.Action(func() {
|
||||||
|
title := titleSignal.String()
|
||||||
|
url := urlSignal.String()
|
||||||
|
if title == "" || url == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
targetID := targetIDSignal.String()
|
||||||
|
action := "created"
|
||||||
|
|
||||||
|
bookmarksMu.Lock()
|
||||||
|
if targetID != "" {
|
||||||
|
if _, idx := findBookmark(targetID); idx >= 0 {
|
||||||
|
bookmarks[idx].Title = title
|
||||||
|
bookmarks[idx].URL = url
|
||||||
|
action = "updated"
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
bookmarks = append(bookmarks, Bookmark{
|
||||||
|
ID: randomHex(8),
|
||||||
|
Title: title,
|
||||||
|
URL: url,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
bookmarksMu.Unlock()
|
||||||
|
|
||||||
|
titleSignal.SetValue("")
|
||||||
|
urlSignal.SetValue("")
|
||||||
|
targetIDSignal.SetValue("")
|
||||||
|
|
||||||
|
via.Publish(c, "bookmarks.events", CRUDEvent{
|
||||||
|
Action: action,
|
||||||
|
Title: title,
|
||||||
|
UserID: userID,
|
||||||
|
})
|
||||||
|
c.Sync()
|
||||||
|
})
|
||||||
|
|
||||||
|
edit := c.Action(func() {
|
||||||
|
id := targetIDSignal.String()
|
||||||
|
bookmarksMu.RLock()
|
||||||
|
bm, idx := findBookmark(id)
|
||||||
|
bookmarksMu.RUnlock()
|
||||||
|
if idx < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
titleSignal.SetValue(bm.Title)
|
||||||
|
urlSignal.SetValue(bm.URL)
|
||||||
|
})
|
||||||
|
|
||||||
|
del := c.Action(func() {
|
||||||
|
id := targetIDSignal.String()
|
||||||
|
bookmarksMu.Lock()
|
||||||
|
bm, idx := findBookmark(id)
|
||||||
|
if idx >= 0 {
|
||||||
|
bookmarks = append(bookmarks[:idx], bookmarks[idx+1:]...)
|
||||||
|
}
|
||||||
|
bookmarksMu.Unlock()
|
||||||
|
if idx < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
targetIDSignal.SetValue("")
|
||||||
|
|
||||||
|
via.Publish(c, "bookmarks.events", CRUDEvent{
|
||||||
|
Action: "deleted",
|
||||||
|
Title: bm.Title,
|
||||||
|
UserID: userID,
|
||||||
|
})
|
||||||
|
c.Sync()
|
||||||
|
})
|
||||||
|
|
||||||
|
cancelEdit := c.Action(func() {
|
||||||
|
titleSignal.SetValue("")
|
||||||
|
urlSignal.SetValue("")
|
||||||
|
targetIDSignal.SetValue("")
|
||||||
|
})
|
||||||
|
|
||||||
|
c.View(func() h.H {
|
||||||
|
isEditing := targetIDSignal.String() != ""
|
||||||
|
|
||||||
|
// Build table rows
|
||||||
|
bookmarksMu.RLock()
|
||||||
|
var rows []h.H
|
||||||
|
for _, bm := range bookmarks {
|
||||||
|
rows = append(rows, h.Tr(
|
||||||
|
h.Td(h.Text(bm.Title)),
|
||||||
|
h.Td(h.A(h.Href(bm.URL), h.Attr("target", "_blank"), h.Class("link link-primary"), h.Text(bm.URL))),
|
||||||
|
h.Td(
|
||||||
|
h.Div(h.Class("flex gap-1"),
|
||||||
|
h.Button(h.Class("btn btn-xs btn-ghost"), h.Text("Edit"),
|
||||||
|
edit.OnClick(WithSignal(targetIDSignal, bm.ID)),
|
||||||
|
),
|
||||||
|
h.Button(h.Class("btn btn-xs btn-ghost text-error"), h.Text("Delete"),
|
||||||
|
del.OnClick(WithSignal(targetIDSignal, bm.ID)),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
bookmarksMu.RUnlock()
|
||||||
|
|
||||||
|
saveLabel := "Add Bookmark"
|
||||||
|
if isEditing {
|
||||||
|
saveLabel = "Update Bookmark"
|
||||||
|
}
|
||||||
|
|
||||||
|
return h.Div(h.Class("min-h-screen bg-base-200"),
|
||||||
|
// Navbar
|
||||||
|
h.Div(h.Class("navbar bg-base-100 shadow-sm"),
|
||||||
|
h.Div(h.Class("flex-1"),
|
||||||
|
h.A(h.Class("btn btn-ghost text-xl"), h.Text("Bookmarks")),
|
||||||
|
),
|
||||||
|
h.Div(h.Class("flex-none"),
|
||||||
|
h.Div(h.Class("badge badge-outline"), h.Text(userID[:8])),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
|
||||||
|
h.Div(h.Class("container mx-auto p-4 max-w-3xl flex flex-col gap-4"),
|
||||||
|
// Form card
|
||||||
|
h.Div(h.Class("card bg-base-100 shadow"),
|
||||||
|
h.Div(h.Class("card-body"),
|
||||||
|
h.H2(h.Class("card-title"), h.Text(saveLabel)),
|
||||||
|
h.Div(h.Class("flex flex-col gap-2"),
|
||||||
|
h.Input(h.Class("input input-bordered w-full"), h.Type("text"), h.Placeholder("Title"), titleSignal.Bind()),
|
||||||
|
h.Input(h.Class("input input-bordered w-full"), h.Type("text"), h.Placeholder("https://example.com"), urlSignal.Bind()),
|
||||||
|
h.Div(h.Class("card-actions justify-end"),
|
||||||
|
h.If(isEditing,
|
||||||
|
h.Button(h.Class("btn btn-ghost"), h.Text("Cancel"), cancelEdit.OnClick()),
|
||||||
|
),
|
||||||
|
h.Button(h.Class("btn btn-primary"), h.Text(saveLabel), save.OnClick()),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
|
||||||
|
// Table card
|
||||||
|
h.Div(h.Class("card bg-base-100 shadow"),
|
||||||
|
h.Div(h.Class("card-body"),
|
||||||
|
h.H2(h.Class("card-title"), h.Text("All Bookmarks")),
|
||||||
|
h.If(len(rows) == 0,
|
||||||
|
h.P(h.Class("text-base-content/60"), h.Text("No bookmarks yet. Add one above!")),
|
||||||
|
),
|
||||||
|
h.If(len(rows) > 0,
|
||||||
|
h.Div(h.Class("overflow-x-auto"),
|
||||||
|
h.Table(h.Class("table"),
|
||||||
|
h.THead(h.Tr(
|
||||||
|
h.Th(h.Text("Title")),
|
||||||
|
h.Th(h.Text("URL")),
|
||||||
|
h.Th(h.Text("Actions")),
|
||||||
|
)),
|
||||||
|
h.TBody(rows...),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
|
||||||
|
// Toast container — ignored by morph so Sync() doesn't wipe active toasts
|
||||||
|
h.Div(h.ID("toast-container"), h.Class("toast toast-end toast-top"), h.DataIgnoreMorph()),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
log.Println("Starting pubsub-crud example on :7331")
|
||||||
|
v.Start()
|
||||||
|
}
|
||||||
@@ -29,7 +29,17 @@ func main() {
|
|||||||
SessionManager: sm,
|
SessionManager: sm,
|
||||||
})
|
})
|
||||||
|
|
||||||
// Login page
|
// Auth middleware — redirects unauthenticated users to /login
|
||||||
|
authRequired := func(c *via.Context, next func()) {
|
||||||
|
if c.Session().GetString("username") == "" {
|
||||||
|
c.Session().Set("flash", "Please log in first")
|
||||||
|
c.RedirectView("/login")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Login page (public)
|
||||||
v.Page("/login", func(c *via.Context) {
|
v.Page("/login", func(c *via.Context) {
|
||||||
flash := c.Session().PopString("flash")
|
flash := c.Session().PopString("flash")
|
||||||
usernameInput := c.Signal("")
|
usernameInput := c.Signal("")
|
||||||
@@ -64,8 +74,10 @@ func main() {
|
|||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
// Dashboard page (protected)
|
// Protected pages
|
||||||
v.Page("/dashboard", func(c *via.Context) {
|
protected := v.Group("", authRequired)
|
||||||
|
|
||||||
|
protected.Page("/dashboard", func(c *via.Context) {
|
||||||
logout := c.Action(func() {
|
logout := c.Action(func() {
|
||||||
c.Session().Set("flash", "Goodbye!")
|
c.Session().Set("flash", "Goodbye!")
|
||||||
c.Session().Delete("username")
|
c.Session().Delete("username")
|
||||||
@@ -74,14 +86,6 @@ func main() {
|
|||||||
|
|
||||||
c.View(func() h.H {
|
c.View(func() h.H {
|
||||||
username := c.Session().GetString("username")
|
username := c.Session().GetString("username")
|
||||||
|
|
||||||
// Not logged in? Redirect to login
|
|
||||||
if username == "" {
|
|
||||||
c.Session().Set("flash", "Please log in first")
|
|
||||||
c.Redirect("/login")
|
|
||||||
return h.Div()
|
|
||||||
}
|
|
||||||
|
|
||||||
flash := c.Session().PopString("flash")
|
flash := c.Session().PopString("flash")
|
||||||
var flashMsg h.H
|
var flashMsg h.H
|
||||||
if flash != "" {
|
if flash != "" {
|
||||||
|
|||||||
82
middleware.go
Normal file
82
middleware.go
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
package via
|
||||||
|
|
||||||
|
// Middleware wraps a page init function. Call next to continue the chain;
|
||||||
|
// return without calling next to abort (set a view first, e.g. RedirectView).
|
||||||
|
type Middleware func(c *Context, next func())
|
||||||
|
|
||||||
|
// Group is a route group with a shared prefix and middleware stack.
|
||||||
|
type Group struct {
|
||||||
|
v *V
|
||||||
|
prefix string
|
||||||
|
middleware []Middleware
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use appends middleware to the global stack.
|
||||||
|
// Global middleware runs before every page handler.
|
||||||
|
func (v *V) Use(mw ...Middleware) {
|
||||||
|
v.middleware = append(v.middleware, mw...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Group creates a route group with the given path prefix and middleware.
|
||||||
|
// Routes registered on the group are prefixed and run the group's middleware
|
||||||
|
// after any global middleware.
|
||||||
|
func (v *V) Group(prefix string, mw ...Middleware) *Group {
|
||||||
|
return &Group{
|
||||||
|
v: v,
|
||||||
|
prefix: prefix,
|
||||||
|
middleware: mw,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Page registers a route on this group. The full route is the group prefix
|
||||||
|
// concatenated with route.
|
||||||
|
func (g *Group) Page(route string, initContextFn func(c *Context)) {
|
||||||
|
fullRoute := g.prefix + route
|
||||||
|
allMw := make([]Middleware, 0, len(g.v.middleware)+len(g.middleware))
|
||||||
|
allMw = append(allMw, g.v.middleware...)
|
||||||
|
allMw = append(allMw, g.middleware...)
|
||||||
|
wrapped := chainMiddleware(allMw, initContextFn)
|
||||||
|
g.v.page(fullRoute, initContextFn, wrapped)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Group creates a nested sub-group that inherits this group's prefix and
|
||||||
|
// middleware, then adds its own.
|
||||||
|
func (g *Group) Group(prefix string, mw ...Middleware) *Group {
|
||||||
|
combined := make([]Middleware, len(g.middleware), len(g.middleware)+len(mw))
|
||||||
|
copy(combined, g.middleware)
|
||||||
|
combined = append(combined, mw...)
|
||||||
|
return &Group{
|
||||||
|
v: g.v,
|
||||||
|
prefix: g.prefix + prefix,
|
||||||
|
middleware: combined,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use appends middleware to this group's stack.
|
||||||
|
func (g *Group) Use(mw ...Middleware) {
|
||||||
|
g.middleware = append(g.middleware, mw...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithMiddleware returns an ActionOption that attaches middleware to an action.
|
||||||
|
// Action middleware runs after CSRF/rate-limit checks and signal injection.
|
||||||
|
func WithMiddleware(mw ...Middleware) ActionOption {
|
||||||
|
return func(e *actionEntry) {
|
||||||
|
e.middleware = append(e.middleware, mw...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// chainMiddleware wraps handler with the given middleware, outer-first.
|
||||||
|
func chainMiddleware(mws []Middleware, handler func(*Context)) func(*Context) {
|
||||||
|
if len(mws) == 0 {
|
||||||
|
return handler
|
||||||
|
}
|
||||||
|
chained := handler
|
||||||
|
for i := len(mws) - 1; i >= 0; i-- {
|
||||||
|
mw := mws[i]
|
||||||
|
next := chained
|
||||||
|
chained = func(c *Context) {
|
||||||
|
mw(c, func() { next(c) })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return chained
|
||||||
|
}
|
||||||
340
middleware_test.go
Normal file
340
middleware_test.go
Normal file
@@ -0,0 +1,340 @@
|
|||||||
|
package via
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ryanhamamura/via/h"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMiddlewareRunsBeforeHandler(t *testing.T) {
|
||||||
|
var order []string
|
||||||
|
|
||||||
|
v := New()
|
||||||
|
v.Use(func(c *Context, next func()) {
|
||||||
|
order = append(order, "mw")
|
||||||
|
next()
|
||||||
|
})
|
||||||
|
v.Page("/", func(c *Context) {
|
||||||
|
order = append(order, "handler")
|
||||||
|
c.View(func() h.H { return h.Div() })
|
||||||
|
})
|
||||||
|
|
||||||
|
// Reset after registration (panic-check runs the raw handler)
|
||||||
|
order = nil
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/", nil))
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Equal(t, []string{"mw", "handler"}, order)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMiddlewareAbortSkipsHandler(t *testing.T) {
|
||||||
|
handlerCalled := false
|
||||||
|
|
||||||
|
v := New()
|
||||||
|
v.Use(func(c *Context, next func()) {
|
||||||
|
c.RedirectView("/other")
|
||||||
|
})
|
||||||
|
v.Page("/", func(c *Context) {
|
||||||
|
handlerCalled = true
|
||||||
|
c.View(func() h.H { return h.Div() })
|
||||||
|
})
|
||||||
|
|
||||||
|
handlerCalled = false
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/", nil))
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.False(t, handlerCalled)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMiddlewareChainOrder(t *testing.T) {
|
||||||
|
var order []string
|
||||||
|
|
||||||
|
v := New()
|
||||||
|
for _, label := range []string{"A", "B", "C"} {
|
||||||
|
l := label
|
||||||
|
v.Use(func(c *Context, next func()) {
|
||||||
|
order = append(order, l)
|
||||||
|
next()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
v.Page("/", func(c *Context) {
|
||||||
|
order = append(order, "handler")
|
||||||
|
c.View(func() h.H { return h.Div() })
|
||||||
|
})
|
||||||
|
|
||||||
|
order = nil
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/", nil))
|
||||||
|
|
||||||
|
assert.Equal(t, []string{"A", "B", "C", "handler"}, order)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGroupPrefixRouting(t *testing.T) {
|
||||||
|
v := New()
|
||||||
|
g := v.Group("/admin")
|
||||||
|
g.Page("/dashboard", func(c *Context) {
|
||||||
|
c.View(func() h.H { return h.Div(h.Text("admin dashboard")) })
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/admin/dashboard", nil))
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), "admin dashboard")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGroupMiddlewareAppliesToGroupOnly(t *testing.T) {
|
||||||
|
var groupMwCalled bool
|
||||||
|
|
||||||
|
v := New()
|
||||||
|
g := v.Group("/admin", func(c *Context, next func()) {
|
||||||
|
groupMwCalled = true
|
||||||
|
next()
|
||||||
|
})
|
||||||
|
g.Page("/panel", func(c *Context) {
|
||||||
|
c.View(func() h.H { return h.Div(h.Text("panel")) })
|
||||||
|
})
|
||||||
|
v.Page("/public", func(c *Context) {
|
||||||
|
c.View(func() h.H { return h.Div(h.Text("public")) })
|
||||||
|
})
|
||||||
|
|
||||||
|
// Hit public page — group middleware should NOT run
|
||||||
|
groupMwCalled = false
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/public", nil))
|
||||||
|
assert.False(t, groupMwCalled)
|
||||||
|
assert.Contains(t, w.Body.String(), "public")
|
||||||
|
|
||||||
|
// Hit group page — group middleware should run
|
||||||
|
groupMwCalled = false
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/admin/panel", nil))
|
||||||
|
assert.True(t, groupMwCalled)
|
||||||
|
assert.Contains(t, w.Body.String(), "panel")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGlobalMiddlewareAppliesToGroupPages(t *testing.T) {
|
||||||
|
var globalCalled bool
|
||||||
|
|
||||||
|
v := New()
|
||||||
|
v.Use(func(c *Context, next func()) {
|
||||||
|
globalCalled = true
|
||||||
|
next()
|
||||||
|
})
|
||||||
|
g := v.Group("/admin")
|
||||||
|
g.Page("/dash", func(c *Context) {
|
||||||
|
c.View(func() h.H { return h.Div(h.Text("dash")) })
|
||||||
|
})
|
||||||
|
|
||||||
|
globalCalled = false
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/admin/dash", nil))
|
||||||
|
|
||||||
|
assert.True(t, globalCalled)
|
||||||
|
assert.Contains(t, w.Body.String(), "dash")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNestedGroupInheritsPrefixAndMiddleware(t *testing.T) {
|
||||||
|
var order []string
|
||||||
|
|
||||||
|
v := New()
|
||||||
|
admin := v.Group("/admin", func(c *Context, next func()) {
|
||||||
|
order = append(order, "admin")
|
||||||
|
next()
|
||||||
|
})
|
||||||
|
superAdmin := admin.Group("/super", func(c *Context, next func()) {
|
||||||
|
order = append(order, "super")
|
||||||
|
next()
|
||||||
|
})
|
||||||
|
superAdmin.Page("/secret", func(c *Context) {
|
||||||
|
order = append(order, "handler")
|
||||||
|
c.View(func() h.H { return h.Div(h.Text("secret")) })
|
||||||
|
})
|
||||||
|
|
||||||
|
order = nil
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/admin/super/secret", nil))
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Equal(t, []string{"admin", "super", "handler"}, order)
|
||||||
|
assert.Contains(t, w.Body.String(), "secret")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGroupUse(t *testing.T) {
|
||||||
|
var order []string
|
||||||
|
|
||||||
|
v := New()
|
||||||
|
g := v.Group("/api")
|
||||||
|
g.Use(func(c *Context, next func()) {
|
||||||
|
order = append(order, "added-later")
|
||||||
|
next()
|
||||||
|
})
|
||||||
|
g.Page("/items", func(c *Context) {
|
||||||
|
order = append(order, "handler")
|
||||||
|
c.View(func() h.H { return h.Div() })
|
||||||
|
})
|
||||||
|
|
||||||
|
order = nil
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/api/items", nil))
|
||||||
|
|
||||||
|
assert.Equal(t, []string{"added-later", "handler"}, order)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRedirectViewSetsValidView(t *testing.T) {
|
||||||
|
v := New()
|
||||||
|
v.Page("/test", func(c *Context) {
|
||||||
|
c.RedirectView("/somewhere")
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/test", nil))
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), "<!doctype html>")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGlobalAndGroupMiddlewareOrder(t *testing.T) {
|
||||||
|
var order []string
|
||||||
|
|
||||||
|
v := New()
|
||||||
|
v.Use(func(c *Context, next func()) {
|
||||||
|
order = append(order, "global")
|
||||||
|
next()
|
||||||
|
})
|
||||||
|
g := v.Group("/g", func(c *Context, next func()) {
|
||||||
|
order = append(order, "group")
|
||||||
|
next()
|
||||||
|
})
|
||||||
|
g.Page("/page", func(c *Context) {
|
||||||
|
order = append(order, "handler")
|
||||||
|
c.View(func() h.H { return h.Div() })
|
||||||
|
})
|
||||||
|
|
||||||
|
order = nil
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/g/page", nil))
|
||||||
|
|
||||||
|
assert.Equal(t, []string{"global", "group", "handler"}, order)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Action middleware tests ---
|
||||||
|
|
||||||
|
func TestActionMiddlewareRunsBeforeAction(t *testing.T) {
|
||||||
|
var order []string
|
||||||
|
|
||||||
|
v := New()
|
||||||
|
c := newContext("test", "/", v)
|
||||||
|
|
||||||
|
mw := func(_ *Context, next func()) {
|
||||||
|
order = append(order, "mw")
|
||||||
|
next()
|
||||||
|
}
|
||||||
|
|
||||||
|
trigger := c.Action(func() {
|
||||||
|
order = append(order, "action")
|
||||||
|
}, WithMiddleware(mw))
|
||||||
|
|
||||||
|
entry, err := c.getAction(trigger.id)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
chainMiddleware(entry.middleware, func(_ *Context) { entry.fn() })(c)
|
||||||
|
|
||||||
|
assert.Equal(t, []string{"mw", "action"}, order)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestActionMiddlewareAbortSkipsAction(t *testing.T) {
|
||||||
|
actionCalled := false
|
||||||
|
|
||||||
|
v := New()
|
||||||
|
c := newContext("test", "/", v)
|
||||||
|
|
||||||
|
mw := func(_ *Context, next func()) {
|
||||||
|
// don't call next — action should not run
|
||||||
|
}
|
||||||
|
|
||||||
|
trigger := c.Action(func() {
|
||||||
|
actionCalled = true
|
||||||
|
}, WithMiddleware(mw))
|
||||||
|
|
||||||
|
entry, err := c.getAction(trigger.id)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
chainMiddleware(entry.middleware, func(_ *Context) { entry.fn() })(c)
|
||||||
|
|
||||||
|
assert.False(t, actionCalled)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestActionMiddlewareChainOrder(t *testing.T) {
|
||||||
|
var order []string
|
||||||
|
|
||||||
|
v := New()
|
||||||
|
c := newContext("test", "/", v)
|
||||||
|
|
||||||
|
var mws []Middleware
|
||||||
|
for _, label := range []string{"A", "B", "C"} {
|
||||||
|
l := label
|
||||||
|
mws = append(mws, func(_ *Context, next func()) {
|
||||||
|
order = append(order, l)
|
||||||
|
next()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
trigger := c.Action(func() {
|
||||||
|
order = append(order, "action")
|
||||||
|
}, WithMiddleware(mws...))
|
||||||
|
|
||||||
|
entry, err := c.getAction(trigger.id)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
chainMiddleware(entry.middleware, func(_ *Context) { entry.fn() })(c)
|
||||||
|
|
||||||
|
assert.Equal(t, []string{"A", "B", "C", "action"}, order)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestActionMiddlewareCombinedWithRateLimit(t *testing.T) {
|
||||||
|
v := New()
|
||||||
|
c := newContext("test", "/", v)
|
||||||
|
|
||||||
|
mw := func(_ *Context, next func()) { next() }
|
||||||
|
|
||||||
|
trigger := c.Action(func() {}, WithRateLimit(5, 10), WithMiddleware(mw))
|
||||||
|
|
||||||
|
entry, err := c.getAction(trigger.id)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, entry.limiter)
|
||||||
|
assert.Len(t, entry.middleware, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGroupWithEmptyPrefix(t *testing.T) {
|
||||||
|
var mwCalled bool
|
||||||
|
|
||||||
|
v := New()
|
||||||
|
g := v.Group("", func(c *Context, next func()) {
|
||||||
|
mwCalled = true
|
||||||
|
next()
|
||||||
|
})
|
||||||
|
g.Page("/dashboard", func(c *Context) {
|
||||||
|
c.View(func() h.H { return h.Div(h.Text("dash")) })
|
||||||
|
})
|
||||||
|
|
||||||
|
mwCalled = false
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/dashboard", nil))
|
||||||
|
|
||||||
|
assert.True(t, mwCalled)
|
||||||
|
assert.Contains(t, w.Body.String(), "dash")
|
||||||
|
}
|
||||||
23
pubsub_helpers.go
Normal file
23
pubsub_helpers.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package via
|
||||||
|
|
||||||
|
import "encoding/json"
|
||||||
|
|
||||||
|
// Publish JSON-marshals msg and publishes to subject.
|
||||||
|
func Publish[T any](c *Context, subject string, msg T) error {
|
||||||
|
data, err := json.Marshal(msg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.Publish(subject, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subscribe JSON-unmarshals each message as T and calls handler.
|
||||||
|
func Subscribe[T any](c *Context, subject string, handler func(T)) (Subscription, error) {
|
||||||
|
return c.Subscribe(subject, func(data []byte) {
|
||||||
|
var msg T
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
handler(msg)
|
||||||
|
})
|
||||||
|
}
|
||||||
66
pubsub_helpers_test.go
Normal file
66
pubsub_helpers_test.go
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
package via
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ryanhamamura/via/h"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPublishSubscribe_RoundTrip(t *testing.T) {
|
||||||
|
ps := newMockPubSub()
|
||||||
|
v := New()
|
||||||
|
v.Config(Options{PubSub: ps})
|
||||||
|
|
||||||
|
type event struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Count int `json:"count"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var got event
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
|
||||||
|
c := newContext("typed-ctx", "/", v)
|
||||||
|
c.View(func() h.H { return h.Div() })
|
||||||
|
|
||||||
|
_, err := Subscribe(c, "events", func(e event) {
|
||||||
|
got = e
|
||||||
|
wg.Done()
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = Publish(c, "events", event{Name: "click", Count: 42})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
assert.Equal(t, "click", got.Name)
|
||||||
|
assert.Equal(t, 42, got.Count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubscribe_SkipsBadJSON(t *testing.T) {
|
||||||
|
ps := newMockPubSub()
|
||||||
|
v := New()
|
||||||
|
v.Config(Options{PubSub: ps})
|
||||||
|
|
||||||
|
type msg struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
}
|
||||||
|
|
||||||
|
called := false
|
||||||
|
c := newContext("bad-json-ctx", "/", v)
|
||||||
|
c.View(func() h.H { return h.Div() })
|
||||||
|
|
||||||
|
_, err := Subscribe(c, "topic", func(m msg) {
|
||||||
|
called = true
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Publish raw invalid JSON — handler should silently skip
|
||||||
|
err = c.Publish("topic", []byte("not json"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.False(t, called)
|
||||||
|
}
|
||||||
49
ratelimit.go
Normal file
49
ratelimit.go
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
package via
|
||||||
|
|
||||||
|
import "golang.org/x/time/rate"
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultActionRate float64 = 10.0
|
||||||
|
defaultActionBurst int = 20
|
||||||
|
)
|
||||||
|
|
||||||
|
// RateLimitConfig configures token-bucket rate limiting for actions.
|
||||||
|
// Zero values fall back to defaults. Rate of -1 disables limiting entirely.
|
||||||
|
type RateLimitConfig struct {
|
||||||
|
Rate float64
|
||||||
|
Burst int
|
||||||
|
}
|
||||||
|
|
||||||
|
// ActionOption configures per-action behaviour when passed to Context.Action.
|
||||||
|
type ActionOption func(*actionEntry)
|
||||||
|
|
||||||
|
type actionEntry struct {
|
||||||
|
fn func()
|
||||||
|
limiter *rate.Limiter // nil = use context default
|
||||||
|
middleware []Middleware
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithRateLimit returns an ActionOption that gives this action its own
|
||||||
|
// token-bucket limiter, overriding the context-level default.
|
||||||
|
func WithRateLimit(r float64, burst int) ActionOption {
|
||||||
|
return func(e *actionEntry) {
|
||||||
|
e.limiter = newLimiter(RateLimitConfig{Rate: r, Burst: burst}, defaultActionRate, defaultActionBurst)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// newLimiter creates a *rate.Limiter from cfg, substituting defaults for zero
|
||||||
|
// values. A Rate of -1 disables limiting (returns nil).
|
||||||
|
func newLimiter(cfg RateLimitConfig, defaultRate float64, defaultBurst int) *rate.Limiter {
|
||||||
|
r := cfg.Rate
|
||||||
|
b := cfg.Burst
|
||||||
|
if r == -1 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if r == 0 {
|
||||||
|
r = defaultRate
|
||||||
|
}
|
||||||
|
if b == 0 {
|
||||||
|
b = defaultBurst
|
||||||
|
}
|
||||||
|
return rate.NewLimiter(rate.Limit(r), b)
|
||||||
|
}
|
||||||
101
ratelimit_test.go
Normal file
101
ratelimit_test.go
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
package via
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewLimiter_Defaults(t *testing.T) {
|
||||||
|
l := newLimiter(RateLimitConfig{}, defaultActionRate, defaultActionBurst)
|
||||||
|
require.NotNil(t, l)
|
||||||
|
assert.InDelta(t, defaultActionRate, float64(l.Limit()), 0.001)
|
||||||
|
assert.Equal(t, defaultActionBurst, l.Burst())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewLimiter_CustomValues(t *testing.T) {
|
||||||
|
l := newLimiter(RateLimitConfig{Rate: 5, Burst: 10}, defaultActionRate, defaultActionBurst)
|
||||||
|
require.NotNil(t, l)
|
||||||
|
assert.InDelta(t, 5.0, float64(l.Limit()), 0.001)
|
||||||
|
assert.Equal(t, 10, l.Burst())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewLimiter_DisabledWithNegativeRate(t *testing.T) {
|
||||||
|
l := newLimiter(RateLimitConfig{Rate: -1}, defaultActionRate, defaultActionBurst)
|
||||||
|
assert.Nil(t, l)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenBucket_AllowsBurstThenRejects(t *testing.T) {
|
||||||
|
l := newLimiter(RateLimitConfig{Rate: 1, Burst: 3}, 1, 3)
|
||||||
|
require.NotNil(t, l)
|
||||||
|
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
assert.True(t, l.Allow(), "request %d should be allowed within burst", i)
|
||||||
|
}
|
||||||
|
assert.False(t, l.Allow(), "request beyond burst should be rejected")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithRateLimit_CreatesLimiter(t *testing.T) {
|
||||||
|
entry := actionEntry{fn: func() {}}
|
||||||
|
opt := WithRateLimit(2, 4)
|
||||||
|
opt(&entry)
|
||||||
|
|
||||||
|
require.NotNil(t, entry.limiter)
|
||||||
|
assert.InDelta(t, 2.0, float64(entry.limiter.Limit()), 0.001)
|
||||||
|
assert.Equal(t, 4, entry.limiter.Burst())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextAction_WithRateLimit(t *testing.T) {
|
||||||
|
v := New()
|
||||||
|
c := newContext("test-rl", "/", v)
|
||||||
|
|
||||||
|
called := false
|
||||||
|
c.Action(func() { called = true }, WithRateLimit(1, 2))
|
||||||
|
|
||||||
|
// Verify the entry has its own limiter
|
||||||
|
for _, entry := range c.actionRegistry {
|
||||||
|
require.NotNil(t, entry.limiter)
|
||||||
|
assert.InDelta(t, 1.0, float64(entry.limiter.Limit()), 0.001)
|
||||||
|
assert.Equal(t, 2, entry.limiter.Burst())
|
||||||
|
}
|
||||||
|
assert.False(t, called)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextAction_DefaultNoPerActionLimiter(t *testing.T) {
|
||||||
|
v := New()
|
||||||
|
c := newContext("test-no-rl", "/", v)
|
||||||
|
|
||||||
|
c.Action(func() {})
|
||||||
|
|
||||||
|
for _, entry := range c.actionRegistry {
|
||||||
|
assert.Nil(t, entry.limiter, "entry without WithRateLimit should have nil limiter")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextLimiter_DefaultsApplied(t *testing.T) {
|
||||||
|
v := New()
|
||||||
|
c := newContext("test-ctx-limiter", "/", v)
|
||||||
|
|
||||||
|
require.NotNil(t, c.actionLimiter)
|
||||||
|
assert.InDelta(t, defaultActionRate, float64(c.actionLimiter.Limit()), 0.001)
|
||||||
|
assert.Equal(t, defaultActionBurst, c.actionLimiter.Burst())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextLimiter_DisabledViaConfig(t *testing.T) {
|
||||||
|
v := New()
|
||||||
|
v.actionRateLimit = RateLimitConfig{Rate: -1}
|
||||||
|
c := newContext("test-disabled", "/", v)
|
||||||
|
|
||||||
|
assert.Nil(t, c.actionLimiter)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextLimiter_CustomConfig(t *testing.T) {
|
||||||
|
v := New()
|
||||||
|
v.Config(Options{ActionRateLimit: RateLimitConfig{Rate: 50, Burst: 100}})
|
||||||
|
c := newContext("test-custom", "/", v)
|
||||||
|
|
||||||
|
require.NotNil(t, c.actionLimiter)
|
||||||
|
assert.InDelta(t, 50.0, float64(c.actionLimiter.Limit()), 0.001)
|
||||||
|
assert.Equal(t, 100, c.actionLimiter.Burst())
|
||||||
|
}
|
||||||
23
signal.go
23
signal.go
@@ -81,26 +81,3 @@ func (s *signal) Int() int {
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Int64 tries to read the signal value as an int64.
|
|
||||||
// Returns the value or 0 on failure.
|
|
||||||
func (s *signal) Int64() int64 {
|
|
||||||
if n, err := strconv.ParseInt(s.String(), 10, 64); err == nil {
|
|
||||||
return n
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// Float64 tries to read the signal value as a float64.
|
|
||||||
// Returns the value or 0.0 on failure.
|
|
||||||
func (s *signal) Float() float64 {
|
|
||||||
if n, err := strconv.ParseFloat(s.String(), 64); err == nil {
|
|
||||||
return n
|
|
||||||
}
|
|
||||||
return 0.0
|
|
||||||
}
|
|
||||||
|
|
||||||
// Bytes tries to read the signal value as a []byte
|
|
||||||
// Returns the value or an empty []byte on failure.
|
|
||||||
func (s *signal) Bytes() []byte {
|
|
||||||
return []byte(s.String())
|
|
||||||
}
|
|
||||||
|
|||||||
143
static_test.go
Normal file
143
static_test.go
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
package via
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io/fs"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"testing/fstest"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestStatic(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
os.MkdirAll(filepath.Join(dir, "sub"), 0755)
|
||||||
|
os.WriteFile(filepath.Join(dir, "hello.txt"), []byte("hello world"), 0644)
|
||||||
|
os.WriteFile(filepath.Join(dir, "sub", "nested.txt"), []byte("nested"), 0644)
|
||||||
|
|
||||||
|
v := New()
|
||||||
|
v.Static("/assets/", dir)
|
||||||
|
|
||||||
|
t.Run("serves file", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r := httptest.NewRequest("GET", "/assets/hello.txt", nil)
|
||||||
|
v.mux.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Equal(t, "hello world", w.Body.String())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("serves nested file", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r := httptest.NewRequest("GET", "/assets/sub/nested.txt", nil)
|
||||||
|
v.mux.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Equal(t, "nested", w.Body.String())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("directory listing returns 404", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r := httptest.NewRequest("GET", "/assets/", nil)
|
||||||
|
v.mux.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("subdirectory listing returns 404", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r := httptest.NewRequest("GET", "/assets/sub/", nil)
|
||||||
|
v.mux.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing file returns 404", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r := httptest.NewRequest("GET", "/assets/nope.txt", nil)
|
||||||
|
v.mux.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStaticAutoSlash(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
os.WriteFile(filepath.Join(dir, "ok.txt"), []byte("ok"), 0644)
|
||||||
|
|
||||||
|
v := New()
|
||||||
|
v.Static("/files", dir) // no trailing slash
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r := httptest.NewRequest("GET", "/files/ok.txt", nil)
|
||||||
|
v.mux.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Equal(t, "ok", w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStaticFS(t *testing.T) {
|
||||||
|
fsys := fstest.MapFS{
|
||||||
|
"style.css": {Data: []byte("body{}")},
|
||||||
|
"js/app.js": {Data: []byte("console.log('hi')")},
|
||||||
|
}
|
||||||
|
|
||||||
|
v := New()
|
||||||
|
v.StaticFS("/static/", fsys)
|
||||||
|
|
||||||
|
t.Run("serves file", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r := httptest.NewRequest("GET", "/static/style.css", nil)
|
||||||
|
v.mux.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Equal(t, "body{}", w.Body.String())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("serves nested file", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r := httptest.NewRequest("GET", "/static/js/app.js", nil)
|
||||||
|
v.mux.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Equal(t, "console.log('hi')", w.Body.String())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("directory listing returns 404", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r := httptest.NewRequest("GET", "/static/", nil)
|
||||||
|
v.mux.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing file returns 404", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r := httptest.NewRequest("GET", "/static/nope.css", nil)
|
||||||
|
v.mux.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStaticFSAutoSlash(t *testing.T) {
|
||||||
|
fsys := fstest.MapFS{
|
||||||
|
"ok.txt": {Data: []byte("ok")},
|
||||||
|
}
|
||||||
|
|
||||||
|
v := New()
|
||||||
|
v.StaticFS("/embed", fsys) // no trailing slash
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r := httptest.NewRequest("GET", "/embed/ok.txt", nil)
|
||||||
|
v.mux.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Equal(t, "ok", w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify StaticFS accepts the fs.FS interface (compile-time check).
|
||||||
|
var _ fs.FS = fstest.MapFS{}
|
||||||
225
via.go
225
via.go
@@ -9,11 +9,13 @@ package via
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
"crypto/subtle"
|
||||||
_ "embed"
|
_ "embed"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"io/fs"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
@@ -36,20 +38,23 @@ var datastarJS []byte
|
|||||||
// V is the root application.
|
// V is the root application.
|
||||||
// It manages page routing, user sessions, and SSE connections for live updates.
|
// It manages page routing, user sessions, and SSE connections for live updates.
|
||||||
type V struct {
|
type V struct {
|
||||||
cfg Options
|
cfg Options
|
||||||
mux *http.ServeMux
|
mux *http.ServeMux
|
||||||
server *http.Server
|
server *http.Server
|
||||||
logger zerolog.Logger
|
logger zerolog.Logger
|
||||||
contextRegistry map[string]*Context
|
contextRegistry map[string]*Context
|
||||||
contextRegistryMutex sync.RWMutex
|
contextRegistryMutex sync.RWMutex
|
||||||
documentHeadIncludes []h.H
|
documentHeadIncludes []h.H
|
||||||
documentFootIncludes []h.H
|
documentFootIncludes []h.H
|
||||||
devModePageInitFnMap map[string]func(*Context)
|
devModePageInitFnMap map[string]func(*Context)
|
||||||
sessionManager *scs.SessionManager
|
sessionManager *scs.SessionManager
|
||||||
pubsub PubSub
|
pubsub PubSub
|
||||||
datastarPath string
|
actionRateLimit RateLimitConfig
|
||||||
datastarContent []byte
|
datastarPath string
|
||||||
datastarOnce sync.Once
|
datastarContent []byte
|
||||||
|
datastarOnce sync.Once
|
||||||
|
reaperStop chan struct{}
|
||||||
|
middleware []Middleware
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *V) logEvent(evt *zerolog.Event, c *Context) *zerolog.Event {
|
func (v *V) logEvent(evt *zerolog.Event, c *Context) *zerolog.Event {
|
||||||
@@ -127,6 +132,12 @@ func (v *V) Config(cfg Options) {
|
|||||||
if cfg.PubSub != nil {
|
if cfg.PubSub != nil {
|
||||||
v.pubsub = cfg.PubSub
|
v.pubsub = cfg.PubSub
|
||||||
}
|
}
|
||||||
|
if cfg.ContextTTL != 0 {
|
||||||
|
v.cfg.ContextTTL = cfg.ContextTTL
|
||||||
|
}
|
||||||
|
if cfg.ActionRateLimit.Rate != 0 || cfg.ActionRateLimit.Burst != 0 {
|
||||||
|
v.actionRateLimit = cfg.ActionRateLimit
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AppendToHead appends the given h.H nodes to the head of the base HTML document.
|
// AppendToHead appends the given h.H nodes to the head of the base HTML document.
|
||||||
@@ -160,8 +171,16 @@ func (v *V) AppendToFoot(elements ...h.H) {
|
|||||||
// })
|
// })
|
||||||
// })
|
// })
|
||||||
func (v *V) Page(route string, initContextFn func(c *Context)) {
|
func (v *V) Page(route string, initContextFn func(c *Context)) {
|
||||||
|
wrapped := chainMiddleware(v.middleware, initContextFn)
|
||||||
|
v.page(route, initContextFn, wrapped)
|
||||||
|
}
|
||||||
|
|
||||||
|
// page registers a route with separate raw and wrapped init functions.
|
||||||
|
// raw is used for the panic-check at registration time; wrapped includes
|
||||||
|
// any middleware and is used as the live handler.
|
||||||
|
func (v *V) page(route string, raw, wrapped func(*Context)) {
|
||||||
v.ensureDatastarHandler()
|
v.ensureDatastarHandler()
|
||||||
// check for panics
|
// check for panics using the raw handler (no middleware)
|
||||||
func() {
|
func() {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := recover(); err != nil {
|
if err := recover(); err != nil {
|
||||||
@@ -170,14 +189,13 @@ func (v *V) Page(route string, initContextFn func(c *Context)) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
c := newContext("", "", v)
|
c := newContext("", "", v)
|
||||||
initContextFn(c)
|
raw(c)
|
||||||
c.view()
|
c.view()
|
||||||
c.stopAllRoutines()
|
c.stopAllRoutines()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// save page init function allows devmode to restore persisted ctx later
|
|
||||||
if v.cfg.DevMode {
|
if v.cfg.DevMode {
|
||||||
v.devModePageInitFnMap[route] = initContextFn
|
v.devModePageInitFnMap[route] = wrapped
|
||||||
}
|
}
|
||||||
v.mux.HandleFunc("GET "+route, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
v.mux.HandleFunc("GET "+route, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
v.logDebug(nil, "GET %s", r.URL.String())
|
v.logDebug(nil, "GET %s", r.URL.String())
|
||||||
@@ -191,7 +209,7 @@ func (v *V) Page(route string, initContextFn func(c *Context)) {
|
|||||||
c.reqCtx = r.Context()
|
c.reqCtx = r.Context()
|
||||||
routeParams := extractParams(route, r.URL.Path)
|
routeParams := extractParams(route, r.URL.Path)
|
||||||
c.injectRouteParams(routeParams)
|
c.injectRouteParams(routeParams)
|
||||||
initContextFn(c)
|
wrapped(c)
|
||||||
v.registerCtx(c)
|
v.registerCtx(c)
|
||||||
if v.cfg.DevMode {
|
if v.cfg.DevMode {
|
||||||
v.devModePersist(c)
|
v.devModePersist(c)
|
||||||
@@ -199,7 +217,7 @@ func (v *V) Page(route string, initContextFn func(c *Context)) {
|
|||||||
headElements := []h.H{h.Script(h.Type("module"), h.Src(v.datastarPath))}
|
headElements := []h.H{h.Script(h.Type("module"), h.Src(v.datastarPath))}
|
||||||
headElements = append(headElements, v.documentHeadIncludes...)
|
headElements = append(headElements, v.documentHeadIncludes...)
|
||||||
headElements = append(headElements,
|
headElements = append(headElements,
|
||||||
h.Meta(h.Data("signals", fmt.Sprintf("{'via-ctx':'%s'}", id))),
|
h.Meta(h.Data("signals", fmt.Sprintf("{'via-ctx':'%s','via-csrf':'%s'}", id, c.csrfToken))),
|
||||||
h.Meta(h.Data("init", "@get('/_sse')")),
|
h.Meta(h.Data("init", "@get('/_sse')")),
|
||||||
h.Meta(h.Data("init", fmt.Sprintf(`window.addEventListener('beforeunload', (evt) => {
|
h.Meta(h.Data("init", fmt.Sprintf(`window.addEventListener('beforeunload', (evt) => {
|
||||||
navigator.sendBeacon('/_session/close', '%s');});`, c.id))),
|
navigator.sendBeacon('/_session/close', '%s');});`, c.id))),
|
||||||
@@ -216,8 +234,7 @@ func (v *V) Page(route string, initContextFn func(c *Context)) {
|
|||||||
Title: v.cfg.DocumentTitle,
|
Title: v.cfg.DocumentTitle,
|
||||||
Head: headElements,
|
Head: headElements,
|
||||||
Body: bodyElements,
|
Body: bodyElements,
|
||||||
HTMLAttrs: []h.H{},
|
})
|
||||||
})
|
|
||||||
_ = view.Render(w)
|
_ = view.Render(w)
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
@@ -225,17 +242,17 @@ func (v *V) Page(route string, initContextFn func(c *Context)) {
|
|||||||
func (v *V) registerCtx(c *Context) {
|
func (v *V) registerCtx(c *Context) {
|
||||||
v.contextRegistryMutex.Lock()
|
v.contextRegistryMutex.Lock()
|
||||||
defer v.contextRegistryMutex.Unlock()
|
defer v.contextRegistryMutex.Unlock()
|
||||||
if c == nil {
|
|
||||||
v.logErr(c, "failed to add nil context to registry")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
v.contextRegistry[c.id] = c
|
v.contextRegistry[c.id] = c
|
||||||
v.logDebug(c, "new context added to registry")
|
v.logDebug(c, "new context added to registry")
|
||||||
v.logDebug(nil, "number of sessions in registry: %d", v.currSessionNum())
|
v.logDebug(nil, "number of sessions in registry: %d", len(v.contextRegistry))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *V) currSessionNum() int {
|
func (v *V) cleanupCtx(c *Context) {
|
||||||
return len(v.contextRegistry)
|
c.dispose()
|
||||||
|
if v.cfg.DevMode {
|
||||||
|
v.devModeRemovePersisted(c)
|
||||||
|
}
|
||||||
|
v.unregisterCtx(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *V) unregisterCtx(c *Context) {
|
func (v *V) unregisterCtx(c *Context) {
|
||||||
@@ -247,7 +264,7 @@ func (v *V) unregisterCtx(c *Context) {
|
|||||||
defer v.contextRegistryMutex.Unlock()
|
defer v.contextRegistryMutex.Unlock()
|
||||||
v.logDebug(c, "ctx removed from registry")
|
v.logDebug(c, "ctx removed from registry")
|
||||||
delete(v.contextRegistry, c.id)
|
delete(v.contextRegistry, c.id)
|
||||||
v.logDebug(nil, "number of sessions in registry: %d", v.currSessionNum())
|
v.logDebug(nil, "number of sessions in registry: %d", len(v.contextRegistry))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *V) getCtx(id string) (*Context, error) {
|
func (v *V) getCtx(id string) (*Context, error) {
|
||||||
@@ -259,6 +276,50 @@ func (v *V) getCtx(id string) (*Context, error) {
|
|||||||
return nil, fmt.Errorf("ctx '%s' not found", id)
|
return nil, fmt.Errorf("ctx '%s' not found", id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (v *V) startReaper() {
|
||||||
|
ttl := v.cfg.ContextTTL
|
||||||
|
if ttl < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if ttl == 0 {
|
||||||
|
ttl = 30 * time.Second
|
||||||
|
}
|
||||||
|
interval := ttl / 3
|
||||||
|
if interval < 5*time.Second {
|
||||||
|
interval = 5 * time.Second
|
||||||
|
}
|
||||||
|
v.reaperStop = make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
ticker := time.NewTicker(interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-v.reaperStop:
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
v.reapOrphanedContexts(ttl)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *V) reapOrphanedContexts(ttl time.Duration) {
|
||||||
|
now := time.Now()
|
||||||
|
v.contextRegistryMutex.RLock()
|
||||||
|
var orphans []*Context
|
||||||
|
for _, c := range v.contextRegistry {
|
||||||
|
if !c.sseConnected.Load() && now.Sub(c.createdAt) > ttl {
|
||||||
|
orphans = append(orphans, c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
v.contextRegistryMutex.RUnlock()
|
||||||
|
|
||||||
|
for _, c := range orphans {
|
||||||
|
v.logInfo(c, "reaping orphaned context (no SSE connection after %s)", ttl)
|
||||||
|
v.cleanupCtx(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Start starts the Via HTTP server and blocks until a SIGINT or SIGTERM
|
// Start starts the Via HTTP server and blocks until a SIGINT or SIGTERM
|
||||||
// signal is received, then performs a graceful shutdown.
|
// signal is received, then performs a graceful shutdown.
|
||||||
func (v *V) Start() {
|
func (v *V) Start() {
|
||||||
@@ -271,6 +332,8 @@ func (v *V) Start() {
|
|||||||
Handler: handler,
|
Handler: handler,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
v.startReaper()
|
||||||
|
|
||||||
errCh := make(chan error, 1)
|
errCh := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
errCh <- v.server.ListenAndServe()
|
errCh <- v.server.ListenAndServe()
|
||||||
@@ -291,16 +354,15 @@ func (v *V) Start() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
v.shutdown()
|
v.Shutdown()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Shutdown gracefully shuts down the server and all contexts.
|
// Shutdown gracefully shuts down the server and all contexts.
|
||||||
// Safe for programmatic or test use.
|
// Safe for programmatic or test use.
|
||||||
func (v *V) Shutdown() {
|
func (v *V) Shutdown() {
|
||||||
v.shutdown()
|
if v.reaperStop != nil {
|
||||||
}
|
close(v.reaperStop)
|
||||||
|
}
|
||||||
func (v *V) shutdown() {
|
|
||||||
v.logInfo(nil, "draining all contexts")
|
v.logInfo(nil, "draining all contexts")
|
||||||
v.drainAllContexts()
|
v.drainAllContexts()
|
||||||
|
|
||||||
@@ -346,6 +408,46 @@ func (v *V) HTTPServeMux() *http.ServeMux {
|
|||||||
return v.mux
|
return v.mux
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Static serves files from a filesystem directory at the given URL prefix.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// v.Static("/assets/", "./public")
|
||||||
|
func (v *V) Static(urlPrefix, dir string) {
|
||||||
|
if !strings.HasSuffix(urlPrefix, "/") {
|
||||||
|
urlPrefix += "/"
|
||||||
|
}
|
||||||
|
fileServer := http.StripPrefix(urlPrefix, http.FileServer(http.Dir(dir)))
|
||||||
|
v.mux.Handle("GET "+urlPrefix, noDirListing(fileServer))
|
||||||
|
}
|
||||||
|
|
||||||
|
// StaticFS serves files from an [fs.FS] at the given URL prefix.
|
||||||
|
// This is useful with //go:embed filesystems.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// //go:embed static
|
||||||
|
// var staticFiles embed.FS
|
||||||
|
// v.StaticFS("/assets/", staticFiles)
|
||||||
|
func (v *V) StaticFS(urlPrefix string, fsys fs.FS) {
|
||||||
|
if !strings.HasSuffix(urlPrefix, "/") {
|
||||||
|
urlPrefix += "/"
|
||||||
|
}
|
||||||
|
fileServer := http.StripPrefix(urlPrefix, http.FileServerFS(fsys))
|
||||||
|
v.mux.Handle("GET "+urlPrefix, noDirListing(fileServer))
|
||||||
|
}
|
||||||
|
|
||||||
|
// noDirListing wraps a file server handler to return 404 for directory requests.
|
||||||
|
func noDirListing(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if strings.HasSuffix(r.URL.Path, "/") {
|
||||||
|
http.NotFound(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (v *V) ensureDatastarHandler() {
|
func (v *V) ensureDatastarHandler() {
|
||||||
v.datastarOnce.Do(func() {
|
v.datastarOnce.Do(func() {
|
||||||
v.mux.HandleFunc("GET "+v.datastarPath, func(w http.ResponseWriter, r *http.Request) {
|
v.mux.HandleFunc("GET "+v.datastarPath, func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -400,10 +502,7 @@ func (v *V) devModeRemovePersisted(c *Context) {
|
|||||||
}
|
}
|
||||||
file.Close()
|
file.Close()
|
||||||
|
|
||||||
// remove ctx to persisted list
|
delete(ctxRegMap, c.id)
|
||||||
if _, ok := ctxRegMap[c.id]; !ok {
|
|
||||||
delete(ctxRegMap, c.id)
|
|
||||||
}
|
|
||||||
|
|
||||||
// write persisted list to file
|
// write persisted list to file
|
||||||
file, err = os.Create(p)
|
file, err = os.Create(p)
|
||||||
@@ -507,16 +606,16 @@ func New() *V {
|
|||||||
// use last-event-id to tell if request is a sse reconnect
|
// use last-event-id to tell if request is a sse reconnect
|
||||||
sse.Send(datastar.EventTypePatchElements, []string{}, datastar.WithSSEEventId("via"))
|
sse.Send(datastar.EventTypePatchElements, []string{}, datastar.WithSSEEventId("via"))
|
||||||
|
|
||||||
|
c.sseConnected.Store(true)
|
||||||
v.logDebug(c, "SSE connection established")
|
v.logDebug(c, "SSE connection established")
|
||||||
|
|
||||||
go func() {
|
go c.Sync()
|
||||||
c.Sync()
|
|
||||||
}()
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-sse.Context().Done():
|
case <-sse.Context().Done():
|
||||||
v.logDebug(c, "SSE connection ended")
|
v.logDebug(c, "SSE connection ended")
|
||||||
|
v.cleanupCtx(c)
|
||||||
return
|
return
|
||||||
case <-c.ctxDisposedChan:
|
case <-c.ctxDisposedChan:
|
||||||
v.logDebug(c, "context disposed, closing SSE")
|
v.logDebug(c, "context disposed, closing SSE")
|
||||||
@@ -572,13 +671,29 @@ func New() *V {
|
|||||||
v.logErr(nil, "action '%s' failed: %v", actionID, err)
|
v.logErr(nil, "action '%s' failed: %v", actionID, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
csrfToken, _ := sigs["via-csrf"].(string)
|
||||||
|
if subtle.ConstantTimeCompare([]byte(csrfToken), []byte(c.csrfToken)) != 1 {
|
||||||
|
v.logWarn(c, "action '%s' rejected: invalid CSRF token", actionID)
|
||||||
|
http.Error(w, "invalid CSRF token", http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if c.actionLimiter != nil && !c.actionLimiter.Allow() {
|
||||||
|
v.logWarn(c, "action '%s' rate limited", actionID)
|
||||||
|
http.Error(w, "rate limited", http.StatusTooManyRequests)
|
||||||
|
return
|
||||||
|
}
|
||||||
c.reqCtx = r.Context()
|
c.reqCtx = r.Context()
|
||||||
actionFn, err := c.getActionFn(actionID)
|
entry, err := c.getAction(actionID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
v.logDebug(c, "action '%s' failed: %v", actionID, err)
|
v.logDebug(c, "action '%s' failed: %v", actionID, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// log err if actionFn panics
|
if entry.limiter != nil && !entry.limiter.Allow() {
|
||||||
|
v.logWarn(c, "action '%s' rate limited (per-action)", actionID)
|
||||||
|
http.Error(w, "rate limited", http.StatusTooManyRequests)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// log err if action panics
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
v.logErr(c, "action '%s' failed: %v", actionID, r)
|
v.logErr(c, "action '%s' failed: %v", actionID, r)
|
||||||
@@ -586,7 +701,11 @@ func New() *V {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
c.injectSignals(sigs)
|
c.injectSignals(sigs)
|
||||||
actionFn()
|
if len(entry.middleware) > 0 {
|
||||||
|
chainMiddleware(entry.middleware, func(_ *Context) { entry.fn() })(c)
|
||||||
|
} else {
|
||||||
|
entry.fn()
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
v.mux.HandleFunc("POST /_session/close", func(w http.ResponseWriter, r *http.Request) {
|
v.mux.HandleFunc("POST /_session/close", func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -603,20 +722,22 @@ func New() *V {
|
|||||||
v.logErr(c, "failed to handle session close: %v", err)
|
v.logErr(c, "failed to handle session close: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.dispose()
|
|
||||||
v.logDebug(c, "session close event triggered")
|
v.logDebug(c, "session close event triggered")
|
||||||
if v.cfg.DevMode {
|
v.cleanupCtx(c)
|
||||||
v.devModeRemovePersisted(c)
|
|
||||||
}
|
|
||||||
v.unregisterCtx(c)
|
|
||||||
})
|
})
|
||||||
return v
|
return v
|
||||||
}
|
}
|
||||||
|
|
||||||
func genRandID() string {
|
func genRandID() string {
|
||||||
|
b := make([]byte, 4)
|
||||||
|
rand.Read(b)
|
||||||
|
return hex.EncodeToString(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func genCSRFToken() string {
|
||||||
b := make([]byte, 16)
|
b := make([]byte, 16)
|
||||||
rand.Read(b)
|
rand.Read(b)
|
||||||
return hex.EncodeToString(b)[:8]
|
return hex.EncodeToString(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
func extractParams(pattern, path string) map[string]string {
|
func extractParams(pattern, path string) map[string]string {
|
||||||
@@ -631,7 +752,7 @@ func extractParams(pattern, path string) map[string]string {
|
|||||||
key := p[i][1 : len(p[i])-1] // remove {}
|
key := p[i][1 : len(p[i])-1] // remove {}
|
||||||
params[key] = u[i]
|
params[key] = u[i]
|
||||||
} else if p[i] != u[i] {
|
} else if p[i] != u[i] {
|
||||||
continue
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return params
|
return params
|
||||||
|
|||||||
148
via_test.go
148
via_test.go
@@ -1,9 +1,13 @@
|
|||||||
package via
|
package via
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/ryanhamamura/via/h"
|
"github.com/ryanhamamura/via/h"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -128,6 +132,60 @@ func TestAction(t *testing.T) {
|
|||||||
assert.Contains(t, body, "/_action/")
|
assert.Contains(t, body, "/_action/")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEventTypes(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
attr string
|
||||||
|
buildEl func(trigger *actionTrigger) h.H
|
||||||
|
}{
|
||||||
|
{"OnSubmit", "data-on:submit", func(tr *actionTrigger) h.H { return h.Form(tr.OnSubmit()) }},
|
||||||
|
{"OnInput", "data-on:input", func(tr *actionTrigger) h.H { return h.Input(tr.OnInput()) }},
|
||||||
|
{"OnFocus", "data-on:focus", func(tr *actionTrigger) h.H { return h.Input(tr.OnFocus()) }},
|
||||||
|
{"OnBlur", "data-on:blur", func(tr *actionTrigger) h.H { return h.Input(tr.OnBlur()) }},
|
||||||
|
{"OnMouseEnter", "data-on:mouseenter", func(tr *actionTrigger) h.H { return h.Div(tr.OnMouseEnter()) }},
|
||||||
|
{"OnMouseLeave", "data-on:mouseleave", func(tr *actionTrigger) h.H { return h.Div(tr.OnMouseLeave()) }},
|
||||||
|
{"OnScroll", "data-on:scroll", func(tr *actionTrigger) h.H { return h.Div(tr.OnScroll()) }},
|
||||||
|
{"OnDblClick", "data-on:dblclick", func(tr *actionTrigger) h.H { return h.Div(tr.OnDblClick()) }},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var trigger *actionTrigger
|
||||||
|
v := New()
|
||||||
|
v.Page("/", func(c *Context) {
|
||||||
|
trigger = c.Action(func() {})
|
||||||
|
c.View(func() h.H { return tt.buildEl(trigger) })
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
v.mux.ServeHTTP(w, req)
|
||||||
|
body := w.Body.String()
|
||||||
|
assert.Contains(t, body, tt.attr)
|
||||||
|
assert.Contains(t, body, "/_action/"+trigger.id)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("WithSignal", func(t *testing.T) {
|
||||||
|
var trigger *actionTrigger
|
||||||
|
var sig *signal
|
||||||
|
v := New()
|
||||||
|
v.Page("/", func(c *Context) {
|
||||||
|
trigger = c.Action(func() {})
|
||||||
|
sig = c.Signal("val")
|
||||||
|
c.View(func() h.H {
|
||||||
|
return h.Div(trigger.OnDblClick(WithSignal(sig, "x")))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
v.mux.ServeHTTP(w, req)
|
||||||
|
body := w.Body.String()
|
||||||
|
assert.Contains(t, body, "data-on:dblclick")
|
||||||
|
assert.Contains(t, body, "$"+sig.ID()+"='x'")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestOnKeyDownWithWindow(t *testing.T) {
|
func TestOnKeyDownWithWindow(t *testing.T) {
|
||||||
var trigger *actionTrigger
|
var trigger *actionTrigger
|
||||||
v := New()
|
v := New()
|
||||||
@@ -235,3 +293,93 @@ func TestPage_PanicsOnNoView(t *testing.T) {
|
|||||||
v.Page("/", func(c *Context) {})
|
v.Page("/", func(c *Context) {})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestReaperCleansOrphanedContexts(t *testing.T) {
|
||||||
|
v := New()
|
||||||
|
c := newContext("orphan-1", "/", v)
|
||||||
|
c.createdAt = time.Now().Add(-time.Minute) // created 1 min ago
|
||||||
|
v.registerCtx(c)
|
||||||
|
|
||||||
|
_, err := v.getCtx("orphan-1")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
v.reapOrphanedContexts(10 * time.Second)
|
||||||
|
|
||||||
|
_, err = v.getCtx("orphan-1")
|
||||||
|
assert.Error(t, err, "orphaned context should have been reaped")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReaperIgnoresConnectedContexts(t *testing.T) {
|
||||||
|
v := New()
|
||||||
|
c := newContext("connected-1", "/", v)
|
||||||
|
c.createdAt = time.Now().Add(-time.Minute)
|
||||||
|
c.sseConnected.Store(true)
|
||||||
|
v.registerCtx(c)
|
||||||
|
|
||||||
|
v.reapOrphanedContexts(10 * time.Second)
|
||||||
|
|
||||||
|
_, err := v.getCtx("connected-1")
|
||||||
|
assert.NoError(t, err, "connected context should survive reaping")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReaperDisabledWithNegativeTTL(t *testing.T) {
|
||||||
|
v := New()
|
||||||
|
v.cfg.ContextTTL = -1
|
||||||
|
v.startReaper()
|
||||||
|
assert.Nil(t, v.reaperStop, "reaper should not start with negative TTL")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCleanupCtxIdempotent(t *testing.T) {
|
||||||
|
v := New()
|
||||||
|
c := newContext("idempotent-1", "/", v)
|
||||||
|
v.registerCtx(c)
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
v.cleanupCtx(c)
|
||||||
|
v.cleanupCtx(c)
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err := v.getCtx("idempotent-1")
|
||||||
|
assert.Error(t, err, "context should be removed after cleanup")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDevModeRemovePersistedFix(t *testing.T) {
|
||||||
|
v := New()
|
||||||
|
v.cfg.DevMode = true
|
||||||
|
|
||||||
|
dir := filepath.Join(t.TempDir(), ".via", "devmode")
|
||||||
|
p := filepath.Join(dir, "ctx.json")
|
||||||
|
assert.NoError(t, os.MkdirAll(dir, 0755))
|
||||||
|
|
||||||
|
// Write a persisted context
|
||||||
|
ctxRegMap := map[string]string{"test-ctx-1": "/"}
|
||||||
|
f, err := os.Create(p)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NoError(t, json.NewEncoder(f).Encode(ctxRegMap))
|
||||||
|
f.Close()
|
||||||
|
|
||||||
|
// Patch devModeRemovePersisted to use our temp path by calling it
|
||||||
|
// directly — we need to override the path. Instead, test via the
|
||||||
|
// actual function by temporarily changing the working dir.
|
||||||
|
origDir, _ := os.Getwd()
|
||||||
|
assert.NoError(t, os.Chdir(t.TempDir()))
|
||||||
|
defer os.Chdir(origDir)
|
||||||
|
|
||||||
|
// Re-create the structure in the temp dir
|
||||||
|
assert.NoError(t, os.MkdirAll(filepath.Join(".via", "devmode"), 0755))
|
||||||
|
p2 := filepath.Join(".via", "devmode", "ctx.json")
|
||||||
|
f2, _ := os.Create(p2)
|
||||||
|
json.NewEncoder(f2).Encode(map[string]string{"test-ctx-1": "/"})
|
||||||
|
f2.Close()
|
||||||
|
|
||||||
|
c := newContext("test-ctx-1", "/", v)
|
||||||
|
v.devModeRemovePersisted(c)
|
||||||
|
|
||||||
|
// Read back and verify
|
||||||
|
f3, err := os.Open(p2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
defer f3.Close()
|
||||||
|
var result map[string]string
|
||||||
|
assert.NoError(t, json.NewDecoder(f3).Decode(&result))
|
||||||
|
assert.Empty(t, result, "persisted context should be removed")
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,7 +4,9 @@ package vianats
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/delaneyj/toolbelt/embeddednats"
|
"github.com/delaneyj/toolbelt/embeddednats"
|
||||||
"github.com/nats-io/nats.go"
|
"github.com/nats-io/nats.go"
|
||||||
@@ -76,3 +78,50 @@ func (n *NATS) Conn() *nats.Conn {
|
|||||||
func (n *NATS) JetStream() nats.JetStreamContext {
|
func (n *NATS) JetStream() nats.JetStreamContext {
|
||||||
return n.js
|
return n.js
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StreamConfig holds the parameters for creating or updating a JetStream stream.
|
||||||
|
type StreamConfig struct {
|
||||||
|
Name string
|
||||||
|
Subjects []string
|
||||||
|
MaxMsgs int64
|
||||||
|
MaxAge time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnsureStream creates or updates a JetStream stream matching cfg.
|
||||||
|
func EnsureStream(n *NATS, cfg StreamConfig) error {
|
||||||
|
_, err := n.js.AddStream(&nats.StreamConfig{
|
||||||
|
Name: cfg.Name,
|
||||||
|
Subjects: cfg.Subjects,
|
||||||
|
Retention: nats.LimitsPolicy,
|
||||||
|
MaxMsgs: cfg.MaxMsgs,
|
||||||
|
MaxAge: cfg.MaxAge,
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReplayHistory fetches the last limit messages from subject,
|
||||||
|
// deserializing each as T. Returns an empty slice if nothing is available.
|
||||||
|
func ReplayHistory[T any](n *NATS, subject string, limit int) ([]T, error) {
|
||||||
|
sub, err := n.js.SubscribeSync(subject, nats.DeliverAll(), nats.OrderedConsumer())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer sub.Unsubscribe()
|
||||||
|
|
||||||
|
var msgs []T
|
||||||
|
for {
|
||||||
|
raw, err := sub.NextMsg(200 * time.Millisecond)
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
var msg T
|
||||||
|
if json.Unmarshal(raw.Data, &msg) == nil {
|
||||||
|
msgs = append(msgs, msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if limit > 0 && len(msgs) > limit {
|
||||||
|
msgs = msgs[len(msgs)-limit:]
|
||||||
|
}
|
||||||
|
return msgs, nil
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user