feat: add middleware, route groups, and codebase cleanup
* feat: add middleware example demonstrating route groups Self-contained example covering v.Use(), v.Group(), nested groups, Group.Use(), and middleware chaining with role-based access control. * feat: add per-action middleware via WithMiddleware ActionOption Reuses the existing Middleware type so the same auth/logging functions work at both page and action level. Middleware runs after CSRF and rate-limit checks, with full access to session and signals. * feat: add RedirectView helper and refactor session example to use middleware RedirectView lets middleware abort and redirect in one step. The session example now uses an authRequired middleware on a route group instead of an inline check inside the view. * fix: remove dead code, fix double Load and extractParams mismatch - Remove componentRegistry (written, never read) - Remove unused signal methods: Bytes, Int64, Float - Remove unreachable nil check in registerCtx - Simplify injectRouteParams (extractParams already returns fresh map) - Fix double sync.Map.Load in injectSignals - Merge Shutdown/shutdown into single method - Inline currSessionNum - Fix extractParams: mismatched literal segment now returns nil - Minor: new(bytes.Buffer), go c.Sync(), genRandID reads 4 bytes
This commit is contained in:
24
context.go
24
context.go
@@ -5,7 +5,6 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"maps"
|
||||
"reflect"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -25,7 +24,6 @@ type Context struct {
|
||||
app *V
|
||||
view func() h.H
|
||||
routeParams map[string]string
|
||||
componentRegistry map[string]*Context
|
||||
parentPageCtx *Context
|
||||
patchChan chan patch
|
||||
actionLimiter *rate.Limiter
|
||||
@@ -81,7 +79,6 @@ func (c *Context) Component(initCtx func(c *Context)) func() h.H {
|
||||
compCtx.parentPageCtx = c
|
||||
}
|
||||
initCtx(compCtx)
|
||||
c.componentRegistry[id] = compCtx
|
||||
return compCtx.view
|
||||
}
|
||||
|
||||
@@ -208,14 +205,14 @@ func (c *Context) injectSignals(sigs map[string]any) {
|
||||
defer c.mu.Unlock()
|
||||
|
||||
for sigID, val := range sigs {
|
||||
if _, ok := c.signals.Load(sigID); !ok {
|
||||
item, ok := c.signals.Load(sigID)
|
||||
if !ok {
|
||||
c.signals.Store(sigID, &signal{
|
||||
id: sigID,
|
||||
val: val,
|
||||
})
|
||||
continue
|
||||
}
|
||||
item, _ := c.signals.Load(sigID)
|
||||
if sig, ok := item.(*signal); ok {
|
||||
sig.val = val
|
||||
sig.changed = false
|
||||
@@ -266,7 +263,7 @@ func (c *Context) sendPatch(p patch) {
|
||||
// Sync pushes the current view state and signal changes to the browser immediately
|
||||
// over the live SSE event stream.
|
||||
func (c *Context) Sync() {
|
||||
elemsPatch := bytes.NewBuffer(make([]byte, 0))
|
||||
elemsPatch := new(bytes.Buffer)
|
||||
if err := c.view().Render(elemsPatch); err != nil {
|
||||
c.app.logErr(c, "sync view failed: %v", err)
|
||||
return
|
||||
@@ -331,6 +328,15 @@ func (c *Context) ExecScript(s string) {
|
||||
c.sendPatch(patch{patchTypeScript, s})
|
||||
}
|
||||
|
||||
// RedirectView sets a view that redirects the browser to the given URL.
|
||||
// Use this in middleware to abort the chain and redirect in one step.
|
||||
func (c *Context) RedirectView(url string) {
|
||||
c.View(func() h.H {
|
||||
c.Redirect(url)
|
||||
return h.Div()
|
||||
})
|
||||
}
|
||||
|
||||
// Redirect navigates the browser to the given URL.
|
||||
// This triggers a full page navigation - the current context will be disposed
|
||||
// and a new context created at the destination URL.
|
||||
@@ -386,12 +392,9 @@ func (c *Context) injectRouteParams(params map[string]string) {
|
||||
if params == nil {
|
||||
return
|
||||
}
|
||||
m := make(map[string]string)
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
maps.Copy(m, params)
|
||||
c.routeParams = m
|
||||
|
||||
c.routeParams = params
|
||||
}
|
||||
|
||||
// GetPathParam retrieves the value from the page request URL for the given parameter name
|
||||
@@ -488,7 +491,6 @@ func newContext(id string, route string, v *V) *Context {
|
||||
csrfToken: genCSRFToken(),
|
||||
routeParams: make(map[string]string),
|
||||
app: v,
|
||||
componentRegistry: make(map[string]*Context),
|
||||
actionLimiter: newLimiter(v.actionRateLimit, defaultActionRate, defaultActionBurst),
|
||||
actionRegistry: make(map[string]actionEntry),
|
||||
signals: new(sync.Map),
|
||||
|
||||
151
internal/examples/middleware/main.go
Normal file
151
internal/examples/middleware/main.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/ryanhamamura/via"
|
||||
"github.com/ryanhamamura/via/h"
|
||||
)
|
||||
|
||||
func main() {
|
||||
v := via.New()
|
||||
v.Config(via.Options{
|
||||
ServerAddress: ":8080",
|
||||
DocumentTitle: "Middleware Example",
|
||||
})
|
||||
|
||||
// --- Middleware definitions ---
|
||||
|
||||
// requestLogger logs every page request to stdout.
|
||||
requestLogger := func(c *via.Context, next func()) {
|
||||
fmt.Printf("[%s] request\n", time.Now().Format("15:04:05"))
|
||||
next()
|
||||
}
|
||||
|
||||
// authRequired redirects unauthenticated users to /login.
|
||||
authRequired := func(c *via.Context, next func()) {
|
||||
if c.Session().GetString("role") == "" {
|
||||
c.RedirectView("/login")
|
||||
return
|
||||
}
|
||||
next()
|
||||
}
|
||||
|
||||
// auditLog prints the authenticated username to stdout.
|
||||
auditLog := func(c *via.Context, next func()) {
|
||||
fmt.Printf("[audit] user=%s\n", c.Session().GetString("username"))
|
||||
next()
|
||||
}
|
||||
|
||||
// superAdminOnly rejects non-superadmin users with a forbidden view.
|
||||
superAdminOnly := func(c *via.Context, next func()) {
|
||||
if c.Session().GetString("role") != "superadmin" {
|
||||
c.View(func() h.H {
|
||||
return h.Div(
|
||||
h.H1(h.Text("Forbidden")),
|
||||
h.P(h.Text("Super-admin access required.")),
|
||||
h.A(h.Href("/admin/dashboard"), h.Text("Back to dashboard")),
|
||||
)
|
||||
})
|
||||
return
|
||||
}
|
||||
next()
|
||||
}
|
||||
|
||||
// --- Route registration ---
|
||||
|
||||
v.Use(requestLogger) // global middleware
|
||||
|
||||
admin := v.Group("/admin", authRequired) // prefixed group
|
||||
admin.Use(auditLog) // Group.Use()
|
||||
superAdmin := admin.Group("/super", superAdminOnly) // nested group
|
||||
|
||||
// Public: redirect root to login
|
||||
v.Page("/", func(c *via.Context) {
|
||||
c.View(func() h.H {
|
||||
c.Redirect("/login")
|
||||
return h.Div()
|
||||
})
|
||||
})
|
||||
|
||||
// Public: login page with role-selection buttons
|
||||
v.Page("/login", func(c *via.Context) {
|
||||
loginAdmin := c.Action(func() {
|
||||
c.Session().Set("role", "admin")
|
||||
c.Session().Set("username", "alice")
|
||||
c.Session().RenewToken()
|
||||
c.Redirect("/admin/dashboard")
|
||||
})
|
||||
|
||||
loginSuper := c.Action(func() {
|
||||
c.Session().Set("role", "superadmin")
|
||||
c.Session().Set("username", "bob")
|
||||
c.Session().RenewToken()
|
||||
c.Redirect("/admin/dashboard")
|
||||
})
|
||||
|
||||
c.View(func() h.H {
|
||||
return h.Div(
|
||||
h.H1(h.Text("Login")),
|
||||
h.P(h.Text("Choose a role:")),
|
||||
h.Button(h.Text("Login as Admin"), loginAdmin.OnClick()),
|
||||
h.Raw(" "),
|
||||
h.Button(h.Text("Login as Super Admin"), loginSuper.OnClick()),
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
// Per-action middleware: only superadmins can invoke this action.
|
||||
requireSuperAdmin := func(c *via.Context, next func()) {
|
||||
if c.Session().GetString("role") != "superadmin" {
|
||||
return
|
||||
}
|
||||
next()
|
||||
}
|
||||
|
||||
// Admin: dashboard (requires authRequired + auditLog)
|
||||
admin.Page("/dashboard", func(c *via.Context) {
|
||||
logout := c.Action(func() {
|
||||
c.Session().Delete("role")
|
||||
c.Session().Delete("username")
|
||||
c.Redirect("/login")
|
||||
})
|
||||
|
||||
dangerAction := c.Action(func() {
|
||||
fmt.Printf("[danger] executed by %s\n", c.Session().GetString("username"))
|
||||
c.Sync()
|
||||
}, via.WithMiddleware(requireSuperAdmin))
|
||||
|
||||
c.View(func() h.H {
|
||||
username := c.Session().GetString("username")
|
||||
role := c.Session().GetString("role")
|
||||
return h.Div(
|
||||
h.H1(h.Textf("Dashboard — %s (%s)", username, role)),
|
||||
h.Ul(
|
||||
h.Li(h.A(h.Href("/admin/super/settings"), h.Text("Super Admin Settings"))),
|
||||
),
|
||||
h.H2(h.Text("Danger Zone")),
|
||||
h.P(h.Text("This action is protected by per-action middleware (superadmin only):")),
|
||||
h.Button(h.Text("Delete Everything"), dangerAction.OnClick()),
|
||||
h.Br(),
|
||||
h.Br(),
|
||||
h.Button(h.Text("Logout"), logout.OnClick()),
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
// Super-admin: settings (requires authRequired + auditLog + superAdminOnly)
|
||||
superAdmin.Page("/settings", func(c *via.Context) {
|
||||
c.View(func() h.H {
|
||||
username := c.Session().GetString("username")
|
||||
return h.Div(
|
||||
h.H1(h.Textf("Super Admin Settings — %s", username)),
|
||||
h.P(h.Text("Only super-admins can see this page.")),
|
||||
h.A(h.Href("/admin/dashboard"), h.Text("Back to dashboard")),
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
v.Start()
|
||||
}
|
||||
@@ -29,7 +29,17 @@ func main() {
|
||||
SessionManager: sm,
|
||||
})
|
||||
|
||||
// Login page
|
||||
// Auth middleware — redirects unauthenticated users to /login
|
||||
authRequired := func(c *via.Context, next func()) {
|
||||
if c.Session().GetString("username") == "" {
|
||||
c.Session().Set("flash", "Please log in first")
|
||||
c.RedirectView("/login")
|
||||
return
|
||||
}
|
||||
next()
|
||||
}
|
||||
|
||||
// Login page (public)
|
||||
v.Page("/login", func(c *via.Context) {
|
||||
flash := c.Session().PopString("flash")
|
||||
usernameInput := c.Signal("")
|
||||
@@ -64,8 +74,10 @@ func main() {
|
||||
})
|
||||
})
|
||||
|
||||
// Dashboard page (protected)
|
||||
v.Page("/dashboard", func(c *via.Context) {
|
||||
// Protected pages
|
||||
protected := v.Group("", authRequired)
|
||||
|
||||
protected.Page("/dashboard", func(c *via.Context) {
|
||||
logout := c.Action(func() {
|
||||
c.Session().Set("flash", "Goodbye!")
|
||||
c.Session().Delete("username")
|
||||
@@ -74,14 +86,6 @@ func main() {
|
||||
|
||||
c.View(func() h.H {
|
||||
username := c.Session().GetString("username")
|
||||
|
||||
// Not logged in? Redirect to login
|
||||
if username == "" {
|
||||
c.Session().Set("flash", "Please log in first")
|
||||
c.Redirect("/login")
|
||||
return h.Div()
|
||||
}
|
||||
|
||||
flash := c.Session().PopString("flash")
|
||||
var flashMsg h.H
|
||||
if flash != "" {
|
||||
|
||||
82
middleware.go
Normal file
82
middleware.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package via
|
||||
|
||||
// Middleware wraps a page init function. Call next to continue the chain;
|
||||
// return without calling next to abort (set a view first, e.g. RedirectView).
|
||||
type Middleware func(c *Context, next func())
|
||||
|
||||
// Group is a route group with a shared prefix and middleware stack.
|
||||
type Group struct {
|
||||
v *V
|
||||
prefix string
|
||||
middleware []Middleware
|
||||
}
|
||||
|
||||
// Use appends middleware to the global stack.
|
||||
// Global middleware runs before every page handler.
|
||||
func (v *V) Use(mw ...Middleware) {
|
||||
v.middleware = append(v.middleware, mw...)
|
||||
}
|
||||
|
||||
// Group creates a route group with the given path prefix and middleware.
|
||||
// Routes registered on the group are prefixed and run the group's middleware
|
||||
// after any global middleware.
|
||||
func (v *V) Group(prefix string, mw ...Middleware) *Group {
|
||||
return &Group{
|
||||
v: v,
|
||||
prefix: prefix,
|
||||
middleware: mw,
|
||||
}
|
||||
}
|
||||
|
||||
// Page registers a route on this group. The full route is the group prefix
|
||||
// concatenated with route.
|
||||
func (g *Group) Page(route string, initContextFn func(c *Context)) {
|
||||
fullRoute := g.prefix + route
|
||||
allMw := make([]Middleware, 0, len(g.v.middleware)+len(g.middleware))
|
||||
allMw = append(allMw, g.v.middleware...)
|
||||
allMw = append(allMw, g.middleware...)
|
||||
wrapped := chainMiddleware(allMw, initContextFn)
|
||||
g.v.page(fullRoute, initContextFn, wrapped)
|
||||
}
|
||||
|
||||
// Group creates a nested sub-group that inherits this group's prefix and
|
||||
// middleware, then adds its own.
|
||||
func (g *Group) Group(prefix string, mw ...Middleware) *Group {
|
||||
combined := make([]Middleware, len(g.middleware), len(g.middleware)+len(mw))
|
||||
copy(combined, g.middleware)
|
||||
combined = append(combined, mw...)
|
||||
return &Group{
|
||||
v: g.v,
|
||||
prefix: g.prefix + prefix,
|
||||
middleware: combined,
|
||||
}
|
||||
}
|
||||
|
||||
// Use appends middleware to this group's stack.
|
||||
func (g *Group) Use(mw ...Middleware) {
|
||||
g.middleware = append(g.middleware, mw...)
|
||||
}
|
||||
|
||||
// WithMiddleware returns an ActionOption that attaches middleware to an action.
|
||||
// Action middleware runs after CSRF/rate-limit checks and signal injection.
|
||||
func WithMiddleware(mw ...Middleware) ActionOption {
|
||||
return func(e *actionEntry) {
|
||||
e.middleware = append(e.middleware, mw...)
|
||||
}
|
||||
}
|
||||
|
||||
// chainMiddleware wraps handler with the given middleware, outer-first.
|
||||
func chainMiddleware(mws []Middleware, handler func(*Context)) func(*Context) {
|
||||
if len(mws) == 0 {
|
||||
return handler
|
||||
}
|
||||
chained := handler
|
||||
for i := len(mws) - 1; i >= 0; i-- {
|
||||
mw := mws[i]
|
||||
next := chained
|
||||
chained = func(c *Context) {
|
||||
mw(c, func() { next(c) })
|
||||
}
|
||||
}
|
||||
return chained
|
||||
}
|
||||
340
middleware_test.go
Normal file
340
middleware_test.go
Normal file
@@ -0,0 +1,340 @@
|
||||
package via
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/ryanhamamura/via/h"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestMiddlewareRunsBeforeHandler(t *testing.T) {
|
||||
var order []string
|
||||
|
||||
v := New()
|
||||
v.Use(func(c *Context, next func()) {
|
||||
order = append(order, "mw")
|
||||
next()
|
||||
})
|
||||
v.Page("/", func(c *Context) {
|
||||
order = append(order, "handler")
|
||||
c.View(func() h.H { return h.Div() })
|
||||
})
|
||||
|
||||
// Reset after registration (panic-check runs the raw handler)
|
||||
order = nil
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/", nil))
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, []string{"mw", "handler"}, order)
|
||||
}
|
||||
|
||||
func TestMiddlewareAbortSkipsHandler(t *testing.T) {
|
||||
handlerCalled := false
|
||||
|
||||
v := New()
|
||||
v.Use(func(c *Context, next func()) {
|
||||
c.RedirectView("/other")
|
||||
})
|
||||
v.Page("/", func(c *Context) {
|
||||
handlerCalled = true
|
||||
c.View(func() h.H { return h.Div() })
|
||||
})
|
||||
|
||||
handlerCalled = false
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/", nil))
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.False(t, handlerCalled)
|
||||
}
|
||||
|
||||
func TestMiddlewareChainOrder(t *testing.T) {
|
||||
var order []string
|
||||
|
||||
v := New()
|
||||
for _, label := range []string{"A", "B", "C"} {
|
||||
l := label
|
||||
v.Use(func(c *Context, next func()) {
|
||||
order = append(order, l)
|
||||
next()
|
||||
})
|
||||
}
|
||||
v.Page("/", func(c *Context) {
|
||||
order = append(order, "handler")
|
||||
c.View(func() h.H { return h.Div() })
|
||||
})
|
||||
|
||||
order = nil
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/", nil))
|
||||
|
||||
assert.Equal(t, []string{"A", "B", "C", "handler"}, order)
|
||||
}
|
||||
|
||||
func TestGroupPrefixRouting(t *testing.T) {
|
||||
v := New()
|
||||
g := v.Group("/admin")
|
||||
g.Page("/dashboard", func(c *Context) {
|
||||
c.View(func() h.H { return h.Div(h.Text("admin dashboard")) })
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/admin/dashboard", nil))
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "admin dashboard")
|
||||
}
|
||||
|
||||
func TestGroupMiddlewareAppliesToGroupOnly(t *testing.T) {
|
||||
var groupMwCalled bool
|
||||
|
||||
v := New()
|
||||
g := v.Group("/admin", func(c *Context, next func()) {
|
||||
groupMwCalled = true
|
||||
next()
|
||||
})
|
||||
g.Page("/panel", func(c *Context) {
|
||||
c.View(func() h.H { return h.Div(h.Text("panel")) })
|
||||
})
|
||||
v.Page("/public", func(c *Context) {
|
||||
c.View(func() h.H { return h.Div(h.Text("public")) })
|
||||
})
|
||||
|
||||
// Hit public page — group middleware should NOT run
|
||||
groupMwCalled = false
|
||||
w := httptest.NewRecorder()
|
||||
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/public", nil))
|
||||
assert.False(t, groupMwCalled)
|
||||
assert.Contains(t, w.Body.String(), "public")
|
||||
|
||||
// Hit group page — group middleware should run
|
||||
groupMwCalled = false
|
||||
w = httptest.NewRecorder()
|
||||
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/admin/panel", nil))
|
||||
assert.True(t, groupMwCalled)
|
||||
assert.Contains(t, w.Body.String(), "panel")
|
||||
}
|
||||
|
||||
func TestGlobalMiddlewareAppliesToGroupPages(t *testing.T) {
|
||||
var globalCalled bool
|
||||
|
||||
v := New()
|
||||
v.Use(func(c *Context, next func()) {
|
||||
globalCalled = true
|
||||
next()
|
||||
})
|
||||
g := v.Group("/admin")
|
||||
g.Page("/dash", func(c *Context) {
|
||||
c.View(func() h.H { return h.Div(h.Text("dash")) })
|
||||
})
|
||||
|
||||
globalCalled = false
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/admin/dash", nil))
|
||||
|
||||
assert.True(t, globalCalled)
|
||||
assert.Contains(t, w.Body.String(), "dash")
|
||||
}
|
||||
|
||||
func TestNestedGroupInheritsPrefixAndMiddleware(t *testing.T) {
|
||||
var order []string
|
||||
|
||||
v := New()
|
||||
admin := v.Group("/admin", func(c *Context, next func()) {
|
||||
order = append(order, "admin")
|
||||
next()
|
||||
})
|
||||
superAdmin := admin.Group("/super", func(c *Context, next func()) {
|
||||
order = append(order, "super")
|
||||
next()
|
||||
})
|
||||
superAdmin.Page("/secret", func(c *Context) {
|
||||
order = append(order, "handler")
|
||||
c.View(func() h.H { return h.Div(h.Text("secret")) })
|
||||
})
|
||||
|
||||
order = nil
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/admin/super/secret", nil))
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, []string{"admin", "super", "handler"}, order)
|
||||
assert.Contains(t, w.Body.String(), "secret")
|
||||
}
|
||||
|
||||
func TestGroupUse(t *testing.T) {
|
||||
var order []string
|
||||
|
||||
v := New()
|
||||
g := v.Group("/api")
|
||||
g.Use(func(c *Context, next func()) {
|
||||
order = append(order, "added-later")
|
||||
next()
|
||||
})
|
||||
g.Page("/items", func(c *Context) {
|
||||
order = append(order, "handler")
|
||||
c.View(func() h.H { return h.Div() })
|
||||
})
|
||||
|
||||
order = nil
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/api/items", nil))
|
||||
|
||||
assert.Equal(t, []string{"added-later", "handler"}, order)
|
||||
}
|
||||
|
||||
func TestRedirectViewSetsValidView(t *testing.T) {
|
||||
v := New()
|
||||
v.Page("/test", func(c *Context) {
|
||||
c.RedirectView("/somewhere")
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/test", nil))
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "<!doctype html>")
|
||||
}
|
||||
|
||||
func TestGlobalAndGroupMiddlewareOrder(t *testing.T) {
|
||||
var order []string
|
||||
|
||||
v := New()
|
||||
v.Use(func(c *Context, next func()) {
|
||||
order = append(order, "global")
|
||||
next()
|
||||
})
|
||||
g := v.Group("/g", func(c *Context, next func()) {
|
||||
order = append(order, "group")
|
||||
next()
|
||||
})
|
||||
g.Page("/page", func(c *Context) {
|
||||
order = append(order, "handler")
|
||||
c.View(func() h.H { return h.Div() })
|
||||
})
|
||||
|
||||
order = nil
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/g/page", nil))
|
||||
|
||||
assert.Equal(t, []string{"global", "group", "handler"}, order)
|
||||
}
|
||||
|
||||
// --- Action middleware tests ---
|
||||
|
||||
func TestActionMiddlewareRunsBeforeAction(t *testing.T) {
|
||||
var order []string
|
||||
|
||||
v := New()
|
||||
c := newContext("test", "/", v)
|
||||
|
||||
mw := func(_ *Context, next func()) {
|
||||
order = append(order, "mw")
|
||||
next()
|
||||
}
|
||||
|
||||
trigger := c.Action(func() {
|
||||
order = append(order, "action")
|
||||
}, WithMiddleware(mw))
|
||||
|
||||
entry, err := c.getAction(trigger.id)
|
||||
assert.NoError(t, err)
|
||||
|
||||
chainMiddleware(entry.middleware, func(_ *Context) { entry.fn() })(c)
|
||||
|
||||
assert.Equal(t, []string{"mw", "action"}, order)
|
||||
}
|
||||
|
||||
func TestActionMiddlewareAbortSkipsAction(t *testing.T) {
|
||||
actionCalled := false
|
||||
|
||||
v := New()
|
||||
c := newContext("test", "/", v)
|
||||
|
||||
mw := func(_ *Context, next func()) {
|
||||
// don't call next — action should not run
|
||||
}
|
||||
|
||||
trigger := c.Action(func() {
|
||||
actionCalled = true
|
||||
}, WithMiddleware(mw))
|
||||
|
||||
entry, err := c.getAction(trigger.id)
|
||||
assert.NoError(t, err)
|
||||
|
||||
chainMiddleware(entry.middleware, func(_ *Context) { entry.fn() })(c)
|
||||
|
||||
assert.False(t, actionCalled)
|
||||
}
|
||||
|
||||
func TestActionMiddlewareChainOrder(t *testing.T) {
|
||||
var order []string
|
||||
|
||||
v := New()
|
||||
c := newContext("test", "/", v)
|
||||
|
||||
var mws []Middleware
|
||||
for _, label := range []string{"A", "B", "C"} {
|
||||
l := label
|
||||
mws = append(mws, func(_ *Context, next func()) {
|
||||
order = append(order, l)
|
||||
next()
|
||||
})
|
||||
}
|
||||
|
||||
trigger := c.Action(func() {
|
||||
order = append(order, "action")
|
||||
}, WithMiddleware(mws...))
|
||||
|
||||
entry, err := c.getAction(trigger.id)
|
||||
assert.NoError(t, err)
|
||||
|
||||
chainMiddleware(entry.middleware, func(_ *Context) { entry.fn() })(c)
|
||||
|
||||
assert.Equal(t, []string{"A", "B", "C", "action"}, order)
|
||||
}
|
||||
|
||||
func TestActionMiddlewareCombinedWithRateLimit(t *testing.T) {
|
||||
v := New()
|
||||
c := newContext("test", "/", v)
|
||||
|
||||
mw := func(_ *Context, next func()) { next() }
|
||||
|
||||
trigger := c.Action(func() {}, WithRateLimit(5, 10), WithMiddleware(mw))
|
||||
|
||||
entry, err := c.getAction(trigger.id)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, entry.limiter)
|
||||
assert.Len(t, entry.middleware, 1)
|
||||
}
|
||||
|
||||
func TestGroupWithEmptyPrefix(t *testing.T) {
|
||||
var mwCalled bool
|
||||
|
||||
v := New()
|
||||
g := v.Group("", func(c *Context, next func()) {
|
||||
mwCalled = true
|
||||
next()
|
||||
})
|
||||
g.Page("/dashboard", func(c *Context) {
|
||||
c.View(func() h.H { return h.Div(h.Text("dash")) })
|
||||
})
|
||||
|
||||
mwCalled = false
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/dashboard", nil))
|
||||
|
||||
assert.True(t, mwCalled)
|
||||
assert.Contains(t, w.Body.String(), "dash")
|
||||
}
|
||||
@@ -20,6 +20,7 @@ type ActionOption func(*actionEntry)
|
||||
type actionEntry struct {
|
||||
fn func()
|
||||
limiter *rate.Limiter // nil = use context default
|
||||
middleware []Middleware
|
||||
}
|
||||
|
||||
// WithRateLimit returns an ActionOption that gives this action its own
|
||||
|
||||
23
signal.go
23
signal.go
@@ -81,26 +81,3 @@ func (s *signal) Int() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Int64 tries to read the signal value as an int64.
|
||||
// Returns the value or 0 on failure.
|
||||
func (s *signal) Int64() int64 {
|
||||
if n, err := strconv.ParseInt(s.String(), 10, 64); err == nil {
|
||||
return n
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// Float64 tries to read the signal value as a float64.
|
||||
// Returns the value or 0.0 on failure.
|
||||
func (s *signal) Float() float64 {
|
||||
if n, err := strconv.ParseFloat(s.String(), 64); err == nil {
|
||||
return n
|
||||
}
|
||||
return 0.0
|
||||
}
|
||||
|
||||
// Bytes tries to read the signal value as a []byte
|
||||
// Returns the value or an empty []byte on failure.
|
||||
func (s *signal) Bytes() []byte {
|
||||
return []byte(s.String())
|
||||
}
|
||||
|
||||
51
via.go
51
via.go
@@ -54,6 +54,7 @@ type V struct {
|
||||
datastarContent []byte
|
||||
datastarOnce sync.Once
|
||||
reaperStop chan struct{}
|
||||
middleware []Middleware
|
||||
}
|
||||
|
||||
func (v *V) logEvent(evt *zerolog.Event, c *Context) *zerolog.Event {
|
||||
@@ -170,8 +171,16 @@ func (v *V) AppendToFoot(elements ...h.H) {
|
||||
// })
|
||||
// })
|
||||
func (v *V) Page(route string, initContextFn func(c *Context)) {
|
||||
wrapped := chainMiddleware(v.middleware, initContextFn)
|
||||
v.page(route, initContextFn, wrapped)
|
||||
}
|
||||
|
||||
// page registers a route with separate raw and wrapped init functions.
|
||||
// raw is used for the panic-check at registration time; wrapped includes
|
||||
// any middleware and is used as the live handler.
|
||||
func (v *V) page(route string, raw, wrapped func(*Context)) {
|
||||
v.ensureDatastarHandler()
|
||||
// check for panics
|
||||
// check for panics using the raw handler (no middleware)
|
||||
func() {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
@@ -180,14 +189,13 @@ func (v *V) Page(route string, initContextFn func(c *Context)) {
|
||||
}
|
||||
}()
|
||||
c := newContext("", "", v)
|
||||
initContextFn(c)
|
||||
raw(c)
|
||||
c.view()
|
||||
c.stopAllRoutines()
|
||||
}()
|
||||
|
||||
// save page init function allows devmode to restore persisted ctx later
|
||||
if v.cfg.DevMode {
|
||||
v.devModePageInitFnMap[route] = initContextFn
|
||||
v.devModePageInitFnMap[route] = wrapped
|
||||
}
|
||||
v.mux.HandleFunc("GET "+route, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
v.logDebug(nil, "GET %s", r.URL.String())
|
||||
@@ -201,7 +209,7 @@ func (v *V) Page(route string, initContextFn func(c *Context)) {
|
||||
c.reqCtx = r.Context()
|
||||
routeParams := extractParams(route, r.URL.Path)
|
||||
c.injectRouteParams(routeParams)
|
||||
initContextFn(c)
|
||||
wrapped(c)
|
||||
v.registerCtx(c)
|
||||
if v.cfg.DevMode {
|
||||
v.devModePersist(c)
|
||||
@@ -226,7 +234,6 @@ func (v *V) Page(route string, initContextFn func(c *Context)) {
|
||||
Title: v.cfg.DocumentTitle,
|
||||
Head: headElements,
|
||||
Body: bodyElements,
|
||||
HTMLAttrs: []h.H{},
|
||||
})
|
||||
_ = view.Render(w)
|
||||
}))
|
||||
@@ -235,17 +242,9 @@ func (v *V) Page(route string, initContextFn func(c *Context)) {
|
||||
func (v *V) registerCtx(c *Context) {
|
||||
v.contextRegistryMutex.Lock()
|
||||
defer v.contextRegistryMutex.Unlock()
|
||||
if c == nil {
|
||||
v.logErr(c, "failed to add nil context to registry")
|
||||
return
|
||||
}
|
||||
v.contextRegistry[c.id] = c
|
||||
v.logDebug(c, "new context added to registry")
|
||||
v.logDebug(nil, "number of sessions in registry: %d", v.currSessionNum())
|
||||
}
|
||||
|
||||
func (v *V) currSessionNum() int {
|
||||
return len(v.contextRegistry)
|
||||
v.logDebug(nil, "number of sessions in registry: %d", len(v.contextRegistry))
|
||||
}
|
||||
|
||||
func (v *V) cleanupCtx(c *Context) {
|
||||
@@ -265,7 +264,7 @@ func (v *V) unregisterCtx(c *Context) {
|
||||
defer v.contextRegistryMutex.Unlock()
|
||||
v.logDebug(c, "ctx removed from registry")
|
||||
delete(v.contextRegistry, c.id)
|
||||
v.logDebug(nil, "number of sessions in registry: %d", v.currSessionNum())
|
||||
v.logDebug(nil, "number of sessions in registry: %d", len(v.contextRegistry))
|
||||
}
|
||||
|
||||
func (v *V) getCtx(id string) (*Context, error) {
|
||||
@@ -355,16 +354,12 @@ func (v *V) Start() {
|
||||
return
|
||||
}
|
||||
|
||||
v.shutdown()
|
||||
v.Shutdown()
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the server and all contexts.
|
||||
// Safe for programmatic or test use.
|
||||
func (v *V) Shutdown() {
|
||||
v.shutdown()
|
||||
}
|
||||
|
||||
func (v *V) shutdown() {
|
||||
if v.reaperStop != nil {
|
||||
close(v.reaperStop)
|
||||
}
|
||||
@@ -614,9 +609,7 @@ func New() *V {
|
||||
c.sseConnected.Store(true)
|
||||
v.logDebug(c, "SSE connection established")
|
||||
|
||||
go func() {
|
||||
c.Sync()
|
||||
}()
|
||||
go c.Sync()
|
||||
|
||||
for {
|
||||
select {
|
||||
@@ -708,7 +701,11 @@ func New() *V {
|
||||
}()
|
||||
|
||||
c.injectSignals(sigs)
|
||||
if len(entry.middleware) > 0 {
|
||||
chainMiddleware(entry.middleware, func(_ *Context) { entry.fn() })(c)
|
||||
} else {
|
||||
entry.fn()
|
||||
}
|
||||
})
|
||||
|
||||
v.mux.HandleFunc("POST /_session/close", func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -732,9 +729,9 @@ func New() *V {
|
||||
}
|
||||
|
||||
func genRandID() string {
|
||||
b := make([]byte, 16)
|
||||
b := make([]byte, 4)
|
||||
rand.Read(b)
|
||||
return hex.EncodeToString(b)[:8]
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
func genCSRFToken() string {
|
||||
@@ -755,7 +752,7 @@ func extractParams(pattern, path string) map[string]string {
|
||||
key := p[i][1 : len(p[i])-1] // remove {}
|
||||
params[key] = u[i]
|
||||
} else if p[i] != u[i] {
|
||||
continue
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return params
|
||||
|
||||
Reference in New Issue
Block a user