From e636970f7b602698335923caaa61b23cd1a27811 Mon Sep 17 00:00:00 2001 From: ryanhamamura <58859899+ryanhamamura@users.noreply.github.com> Date: Wed, 11 Feb 2026 13:50:02 -1000 Subject: [PATCH] feat: add middleware, route groups, and codebase cleanup * feat: add middleware example demonstrating route groups Self-contained example covering v.Use(), v.Group(), nested groups, Group.Use(), and middleware chaining with role-based access control. * feat: add per-action middleware via WithMiddleware ActionOption Reuses the existing Middleware type so the same auth/logging functions work at both page and action level. Middleware runs after CSRF and rate-limit checks, with full access to session and signals. * feat: add RedirectView helper and refactor session example to use middleware RedirectView lets middleware abort and redirect in one step. The session example now uses an authRequired middleware on a route group instead of an inline check inside the view. * fix: remove dead code, fix double Load and extractParams mismatch - Remove componentRegistry (written, never read) - Remove unused signal methods: Bytes, Int64, Float - Remove unreachable nil check in registerCtx - Simplify injectRouteParams (extractParams already returns fresh map) - Fix double sync.Map.Load in injectSignals - Merge Shutdown/shutdown into single method - Inline currSessionNum - Fix extractParams: mismatched literal segment now returns nil - Minor: new(bytes.Buffer), go c.Sync(), genRandID reads 4 bytes --- context.go | 28 ++- internal/examples/middleware/main.go | 151 ++++++++++++ internal/examples/session/main.go | 26 +- middleware.go | 82 +++++++ middleware_test.go | 340 +++++++++++++++++++++++++++ ratelimit.go | 5 +- signal.go | 23 -- via.go | 55 ++--- 8 files changed, 632 insertions(+), 78 deletions(-) create mode 100644 internal/examples/middleware/main.go create mode 100644 middleware.go create mode 100644 middleware_test.go diff --git a/context.go b/context.go index c396d07..e2f8df6 100644 --- a/context.go +++ b/context.go @@ -5,7 +5,6 @@ import ( "context" "encoding/json" "fmt" - "maps" "reflect" "sync" "sync/atomic" @@ -25,8 +24,7 @@ type Context struct { app *V view func() h.H routeParams map[string]string - componentRegistry map[string]*Context - parentPageCtx *Context + parentPageCtx *Context patchChan chan patch actionLimiter *rate.Limiter actionRegistry map[string]actionEntry @@ -81,7 +79,6 @@ func (c *Context) Component(initCtx func(c *Context)) func() h.H { compCtx.parentPageCtx = c } initCtx(compCtx) - c.componentRegistry[id] = compCtx return compCtx.view } @@ -208,14 +205,14 @@ func (c *Context) injectSignals(sigs map[string]any) { defer c.mu.Unlock() for sigID, val := range sigs { - if _, ok := c.signals.Load(sigID); !ok { + item, ok := c.signals.Load(sigID) + if !ok { c.signals.Store(sigID, &signal{ id: sigID, val: val, }) continue } - item, _ := c.signals.Load(sigID) if sig, ok := item.(*signal); ok { sig.val = val sig.changed = false @@ -266,7 +263,7 @@ func (c *Context) sendPatch(p patch) { // Sync pushes the current view state and signal changes to the browser immediately // over the live SSE event stream. func (c *Context) Sync() { - elemsPatch := bytes.NewBuffer(make([]byte, 0)) + elemsPatch := new(bytes.Buffer) if err := c.view().Render(elemsPatch); err != nil { c.app.logErr(c, "sync view failed: %v", err) return @@ -331,6 +328,15 @@ func (c *Context) ExecScript(s string) { c.sendPatch(patch{patchTypeScript, s}) } +// RedirectView sets a view that redirects the browser to the given URL. +// Use this in middleware to abort the chain and redirect in one step. +func (c *Context) RedirectView(url string) { + c.View(func() h.H { + c.Redirect(url) + return h.Div() + }) +} + // Redirect navigates the browser to the given URL. // This triggers a full page navigation - the current context will be disposed // and a new context created at the destination URL. @@ -386,12 +392,9 @@ func (c *Context) injectRouteParams(params map[string]string) { if params == nil { return } - m := make(map[string]string) c.mu.Lock() defer c.mu.Unlock() - maps.Copy(m, params) - c.routeParams = m - + c.routeParams = params } // GetPathParam retrieves the value from the page request URL for the given parameter name @@ -488,8 +491,7 @@ func newContext(id string, route string, v *V) *Context { csrfToken: genCSRFToken(), routeParams: make(map[string]string), app: v, - componentRegistry: make(map[string]*Context), - actionLimiter: newLimiter(v.actionRateLimit, defaultActionRate, defaultActionBurst), + actionLimiter: newLimiter(v.actionRateLimit, defaultActionRate, defaultActionBurst), actionRegistry: make(map[string]actionEntry), signals: new(sync.Map), patchChan: make(chan patch, 1), diff --git a/internal/examples/middleware/main.go b/internal/examples/middleware/main.go new file mode 100644 index 0000000..e7672ca --- /dev/null +++ b/internal/examples/middleware/main.go @@ -0,0 +1,151 @@ +package main + +import ( + "fmt" + "time" + + "github.com/ryanhamamura/via" + "github.com/ryanhamamura/via/h" +) + +func main() { + v := via.New() + v.Config(via.Options{ + ServerAddress: ":8080", + DocumentTitle: "Middleware Example", + }) + + // --- Middleware definitions --- + + // requestLogger logs every page request to stdout. + requestLogger := func(c *via.Context, next func()) { + fmt.Printf("[%s] request\n", time.Now().Format("15:04:05")) + next() + } + + // authRequired redirects unauthenticated users to /login. + authRequired := func(c *via.Context, next func()) { + if c.Session().GetString("role") == "" { + c.RedirectView("/login") + return + } + next() + } + + // auditLog prints the authenticated username to stdout. + auditLog := func(c *via.Context, next func()) { + fmt.Printf("[audit] user=%s\n", c.Session().GetString("username")) + next() + } + + // superAdminOnly rejects non-superadmin users with a forbidden view. + superAdminOnly := func(c *via.Context, next func()) { + if c.Session().GetString("role") != "superadmin" { + c.View(func() h.H { + return h.Div( + h.H1(h.Text("Forbidden")), + h.P(h.Text("Super-admin access required.")), + h.A(h.Href("/admin/dashboard"), h.Text("Back to dashboard")), + ) + }) + return + } + next() + } + + // --- Route registration --- + + v.Use(requestLogger) // global middleware + + admin := v.Group("/admin", authRequired) // prefixed group + admin.Use(auditLog) // Group.Use() + superAdmin := admin.Group("/super", superAdminOnly) // nested group + + // Public: redirect root to login + v.Page("/", func(c *via.Context) { + c.View(func() h.H { + c.Redirect("/login") + return h.Div() + }) + }) + + // Public: login page with role-selection buttons + v.Page("/login", func(c *via.Context) { + loginAdmin := c.Action(func() { + c.Session().Set("role", "admin") + c.Session().Set("username", "alice") + c.Session().RenewToken() + c.Redirect("/admin/dashboard") + }) + + loginSuper := c.Action(func() { + c.Session().Set("role", "superadmin") + c.Session().Set("username", "bob") + c.Session().RenewToken() + c.Redirect("/admin/dashboard") + }) + + c.View(func() h.H { + return h.Div( + h.H1(h.Text("Login")), + h.P(h.Text("Choose a role:")), + h.Button(h.Text("Login as Admin"), loginAdmin.OnClick()), + h.Raw(" "), + h.Button(h.Text("Login as Super Admin"), loginSuper.OnClick()), + ) + }) + }) + + // Per-action middleware: only superadmins can invoke this action. + requireSuperAdmin := func(c *via.Context, next func()) { + if c.Session().GetString("role") != "superadmin" { + return + } + next() + } + + // Admin: dashboard (requires authRequired + auditLog) + admin.Page("/dashboard", func(c *via.Context) { + logout := c.Action(func() { + c.Session().Delete("role") + c.Session().Delete("username") + c.Redirect("/login") + }) + + dangerAction := c.Action(func() { + fmt.Printf("[danger] executed by %s\n", c.Session().GetString("username")) + c.Sync() + }, via.WithMiddleware(requireSuperAdmin)) + + c.View(func() h.H { + username := c.Session().GetString("username") + role := c.Session().GetString("role") + return h.Div( + h.H1(h.Textf("Dashboard — %s (%s)", username, role)), + h.Ul( + h.Li(h.A(h.Href("/admin/super/settings"), h.Text("Super Admin Settings"))), + ), + h.H2(h.Text("Danger Zone")), + h.P(h.Text("This action is protected by per-action middleware (superadmin only):")), + h.Button(h.Text("Delete Everything"), dangerAction.OnClick()), + h.Br(), + h.Br(), + h.Button(h.Text("Logout"), logout.OnClick()), + ) + }) + }) + + // Super-admin: settings (requires authRequired + auditLog + superAdminOnly) + superAdmin.Page("/settings", func(c *via.Context) { + c.View(func() h.H { + username := c.Session().GetString("username") + return h.Div( + h.H1(h.Textf("Super Admin Settings — %s", username)), + h.P(h.Text("Only super-admins can see this page.")), + h.A(h.Href("/admin/dashboard"), h.Text("Back to dashboard")), + ) + }) + }) + + v.Start() +} diff --git a/internal/examples/session/main.go b/internal/examples/session/main.go index 0674076..bb82e00 100644 --- a/internal/examples/session/main.go +++ b/internal/examples/session/main.go @@ -29,7 +29,17 @@ func main() { SessionManager: sm, }) - // Login page + // Auth middleware — redirects unauthenticated users to /login + authRequired := func(c *via.Context, next func()) { + if c.Session().GetString("username") == "" { + c.Session().Set("flash", "Please log in first") + c.RedirectView("/login") + return + } + next() + } + + // Login page (public) v.Page("/login", func(c *via.Context) { flash := c.Session().PopString("flash") usernameInput := c.Signal("") @@ -64,8 +74,10 @@ func main() { }) }) - // Dashboard page (protected) - v.Page("/dashboard", func(c *via.Context) { + // Protected pages + protected := v.Group("", authRequired) + + protected.Page("/dashboard", func(c *via.Context) { logout := c.Action(func() { c.Session().Set("flash", "Goodbye!") c.Session().Delete("username") @@ -74,14 +86,6 @@ func main() { c.View(func() h.H { username := c.Session().GetString("username") - - // Not logged in? Redirect to login - if username == "" { - c.Session().Set("flash", "Please log in first") - c.Redirect("/login") - return h.Div() - } - flash := c.Session().PopString("flash") var flashMsg h.H if flash != "" { diff --git a/middleware.go b/middleware.go new file mode 100644 index 0000000..073bbae --- /dev/null +++ b/middleware.go @@ -0,0 +1,82 @@ +package via + +// Middleware wraps a page init function. Call next to continue the chain; +// return without calling next to abort (set a view first, e.g. RedirectView). +type Middleware func(c *Context, next func()) + +// Group is a route group with a shared prefix and middleware stack. +type Group struct { + v *V + prefix string + middleware []Middleware +} + +// Use appends middleware to the global stack. +// Global middleware runs before every page handler. +func (v *V) Use(mw ...Middleware) { + v.middleware = append(v.middleware, mw...) +} + +// Group creates a route group with the given path prefix and middleware. +// Routes registered on the group are prefixed and run the group's middleware +// after any global middleware. +func (v *V) Group(prefix string, mw ...Middleware) *Group { + return &Group{ + v: v, + prefix: prefix, + middleware: mw, + } +} + +// Page registers a route on this group. The full route is the group prefix +// concatenated with route. +func (g *Group) Page(route string, initContextFn func(c *Context)) { + fullRoute := g.prefix + route + allMw := make([]Middleware, 0, len(g.v.middleware)+len(g.middleware)) + allMw = append(allMw, g.v.middleware...) + allMw = append(allMw, g.middleware...) + wrapped := chainMiddleware(allMw, initContextFn) + g.v.page(fullRoute, initContextFn, wrapped) +} + +// Group creates a nested sub-group that inherits this group's prefix and +// middleware, then adds its own. +func (g *Group) Group(prefix string, mw ...Middleware) *Group { + combined := make([]Middleware, len(g.middleware), len(g.middleware)+len(mw)) + copy(combined, g.middleware) + combined = append(combined, mw...) + return &Group{ + v: g.v, + prefix: g.prefix + prefix, + middleware: combined, + } +} + +// Use appends middleware to this group's stack. +func (g *Group) Use(mw ...Middleware) { + g.middleware = append(g.middleware, mw...) +} + +// WithMiddleware returns an ActionOption that attaches middleware to an action. +// Action middleware runs after CSRF/rate-limit checks and signal injection. +func WithMiddleware(mw ...Middleware) ActionOption { + return func(e *actionEntry) { + e.middleware = append(e.middleware, mw...) + } +} + +// chainMiddleware wraps handler with the given middleware, outer-first. +func chainMiddleware(mws []Middleware, handler func(*Context)) func(*Context) { + if len(mws) == 0 { + return handler + } + chained := handler + for i := len(mws) - 1; i >= 0; i-- { + mw := mws[i] + next := chained + chained = func(c *Context) { + mw(c, func() { next(c) }) + } + } + return chained +} diff --git a/middleware_test.go b/middleware_test.go new file mode 100644 index 0000000..513e2e3 --- /dev/null +++ b/middleware_test.go @@ -0,0 +1,340 @@ +package via + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/ryanhamamura/via/h" + "github.com/stretchr/testify/assert" +) + +func TestMiddlewareRunsBeforeHandler(t *testing.T) { + var order []string + + v := New() + v.Use(func(c *Context, next func()) { + order = append(order, "mw") + next() + }) + v.Page("/", func(c *Context) { + order = append(order, "handler") + c.View(func() h.H { return h.Div() }) + }) + + // Reset after registration (panic-check runs the raw handler) + order = nil + + w := httptest.NewRecorder() + v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/", nil)) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, []string{"mw", "handler"}, order) +} + +func TestMiddlewareAbortSkipsHandler(t *testing.T) { + handlerCalled := false + + v := New() + v.Use(func(c *Context, next func()) { + c.RedirectView("/other") + }) + v.Page("/", func(c *Context) { + handlerCalled = true + c.View(func() h.H { return h.Div() }) + }) + + handlerCalled = false + + w := httptest.NewRecorder() + v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/", nil)) + + assert.Equal(t, http.StatusOK, w.Code) + assert.False(t, handlerCalled) +} + +func TestMiddlewareChainOrder(t *testing.T) { + var order []string + + v := New() + for _, label := range []string{"A", "B", "C"} { + l := label + v.Use(func(c *Context, next func()) { + order = append(order, l) + next() + }) + } + v.Page("/", func(c *Context) { + order = append(order, "handler") + c.View(func() h.H { return h.Div() }) + }) + + order = nil + + w := httptest.NewRecorder() + v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/", nil)) + + assert.Equal(t, []string{"A", "B", "C", "handler"}, order) +} + +func TestGroupPrefixRouting(t *testing.T) { + v := New() + g := v.Group("/admin") + g.Page("/dashboard", func(c *Context) { + c.View(func() h.H { return h.Div(h.Text("admin dashboard")) }) + }) + + w := httptest.NewRecorder() + v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/admin/dashboard", nil)) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Contains(t, w.Body.String(), "admin dashboard") +} + +func TestGroupMiddlewareAppliesToGroupOnly(t *testing.T) { + var groupMwCalled bool + + v := New() + g := v.Group("/admin", func(c *Context, next func()) { + groupMwCalled = true + next() + }) + g.Page("/panel", func(c *Context) { + c.View(func() h.H { return h.Div(h.Text("panel")) }) + }) + v.Page("/public", func(c *Context) { + c.View(func() h.H { return h.Div(h.Text("public")) }) + }) + + // Hit public page — group middleware should NOT run + groupMwCalled = false + w := httptest.NewRecorder() + v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/public", nil)) + assert.False(t, groupMwCalled) + assert.Contains(t, w.Body.String(), "public") + + // Hit group page — group middleware should run + groupMwCalled = false + w = httptest.NewRecorder() + v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/admin/panel", nil)) + assert.True(t, groupMwCalled) + assert.Contains(t, w.Body.String(), "panel") +} + +func TestGlobalMiddlewareAppliesToGroupPages(t *testing.T) { + var globalCalled bool + + v := New() + v.Use(func(c *Context, next func()) { + globalCalled = true + next() + }) + g := v.Group("/admin") + g.Page("/dash", func(c *Context) { + c.View(func() h.H { return h.Div(h.Text("dash")) }) + }) + + globalCalled = false + + w := httptest.NewRecorder() + v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/admin/dash", nil)) + + assert.True(t, globalCalled) + assert.Contains(t, w.Body.String(), "dash") +} + +func TestNestedGroupInheritsPrefixAndMiddleware(t *testing.T) { + var order []string + + v := New() + admin := v.Group("/admin", func(c *Context, next func()) { + order = append(order, "admin") + next() + }) + superAdmin := admin.Group("/super", func(c *Context, next func()) { + order = append(order, "super") + next() + }) + superAdmin.Page("/secret", func(c *Context) { + order = append(order, "handler") + c.View(func() h.H { return h.Div(h.Text("secret")) }) + }) + + order = nil + + w := httptest.NewRecorder() + v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/admin/super/secret", nil)) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, []string{"admin", "super", "handler"}, order) + assert.Contains(t, w.Body.String(), "secret") +} + +func TestGroupUse(t *testing.T) { + var order []string + + v := New() + g := v.Group("/api") + g.Use(func(c *Context, next func()) { + order = append(order, "added-later") + next() + }) + g.Page("/items", func(c *Context) { + order = append(order, "handler") + c.View(func() h.H { return h.Div() }) + }) + + order = nil + + w := httptest.NewRecorder() + v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/api/items", nil)) + + assert.Equal(t, []string{"added-later", "handler"}, order) +} + +func TestRedirectViewSetsValidView(t *testing.T) { + v := New() + v.Page("/test", func(c *Context) { + c.RedirectView("/somewhere") + }) + + w := httptest.NewRecorder() + v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/test", nil)) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Contains(t, w.Body.String(), "") +} + +func TestGlobalAndGroupMiddlewareOrder(t *testing.T) { + var order []string + + v := New() + v.Use(func(c *Context, next func()) { + order = append(order, "global") + next() + }) + g := v.Group("/g", func(c *Context, next func()) { + order = append(order, "group") + next() + }) + g.Page("/page", func(c *Context) { + order = append(order, "handler") + c.View(func() h.H { return h.Div() }) + }) + + order = nil + + w := httptest.NewRecorder() + v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/g/page", nil)) + + assert.Equal(t, []string{"global", "group", "handler"}, order) +} + +// --- Action middleware tests --- + +func TestActionMiddlewareRunsBeforeAction(t *testing.T) { + var order []string + + v := New() + c := newContext("test", "/", v) + + mw := func(_ *Context, next func()) { + order = append(order, "mw") + next() + } + + trigger := c.Action(func() { + order = append(order, "action") + }, WithMiddleware(mw)) + + entry, err := c.getAction(trigger.id) + assert.NoError(t, err) + + chainMiddleware(entry.middleware, func(_ *Context) { entry.fn() })(c) + + assert.Equal(t, []string{"mw", "action"}, order) +} + +func TestActionMiddlewareAbortSkipsAction(t *testing.T) { + actionCalled := false + + v := New() + c := newContext("test", "/", v) + + mw := func(_ *Context, next func()) { + // don't call next — action should not run + } + + trigger := c.Action(func() { + actionCalled = true + }, WithMiddleware(mw)) + + entry, err := c.getAction(trigger.id) + assert.NoError(t, err) + + chainMiddleware(entry.middleware, func(_ *Context) { entry.fn() })(c) + + assert.False(t, actionCalled) +} + +func TestActionMiddlewareChainOrder(t *testing.T) { + var order []string + + v := New() + c := newContext("test", "/", v) + + var mws []Middleware + for _, label := range []string{"A", "B", "C"} { + l := label + mws = append(mws, func(_ *Context, next func()) { + order = append(order, l) + next() + }) + } + + trigger := c.Action(func() { + order = append(order, "action") + }, WithMiddleware(mws...)) + + entry, err := c.getAction(trigger.id) + assert.NoError(t, err) + + chainMiddleware(entry.middleware, func(_ *Context) { entry.fn() })(c) + + assert.Equal(t, []string{"A", "B", "C", "action"}, order) +} + +func TestActionMiddlewareCombinedWithRateLimit(t *testing.T) { + v := New() + c := newContext("test", "/", v) + + mw := func(_ *Context, next func()) { next() } + + trigger := c.Action(func() {}, WithRateLimit(5, 10), WithMiddleware(mw)) + + entry, err := c.getAction(trigger.id) + assert.NoError(t, err) + assert.NotNil(t, entry.limiter) + assert.Len(t, entry.middleware, 1) +} + +func TestGroupWithEmptyPrefix(t *testing.T) { + var mwCalled bool + + v := New() + g := v.Group("", func(c *Context, next func()) { + mwCalled = true + next() + }) + g.Page("/dashboard", func(c *Context) { + c.View(func() h.H { return h.Div(h.Text("dash")) }) + }) + + mwCalled = false + + w := httptest.NewRecorder() + v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/dashboard", nil)) + + assert.True(t, mwCalled) + assert.Contains(t, w.Body.String(), "dash") +} diff --git a/ratelimit.go b/ratelimit.go index 57dc20a..3bd66ba 100644 --- a/ratelimit.go +++ b/ratelimit.go @@ -18,8 +18,9 @@ type RateLimitConfig struct { type ActionOption func(*actionEntry) type actionEntry struct { - fn func() - limiter *rate.Limiter // nil = use context default + fn func() + limiter *rate.Limiter // nil = use context default + middleware []Middleware } // WithRateLimit returns an ActionOption that gives this action its own diff --git a/signal.go b/signal.go index 772f0f0..b4e51b2 100644 --- a/signal.go +++ b/signal.go @@ -81,26 +81,3 @@ func (s *signal) Int() int { return 0 } -// Int64 tries to read the signal value as an int64. -// Returns the value or 0 on failure. -func (s *signal) Int64() int64 { - if n, err := strconv.ParseInt(s.String(), 10, 64); err == nil { - return n - } - return 0 -} - -// Float64 tries to read the signal value as a float64. -// Returns the value or 0.0 on failure. -func (s *signal) Float() float64 { - if n, err := strconv.ParseFloat(s.String(), 64); err == nil { - return n - } - return 0.0 -} - -// Bytes tries to read the signal value as a []byte -// Returns the value or an empty []byte on failure. -func (s *signal) Bytes() []byte { - return []byte(s.String()) -} diff --git a/via.go b/via.go index bb52ab5..4d850b0 100644 --- a/via.go +++ b/via.go @@ -54,6 +54,7 @@ type V struct { datastarContent []byte datastarOnce sync.Once reaperStop chan struct{} + middleware []Middleware } func (v *V) logEvent(evt *zerolog.Event, c *Context) *zerolog.Event { @@ -170,8 +171,16 @@ func (v *V) AppendToFoot(elements ...h.H) { // }) // }) func (v *V) Page(route string, initContextFn func(c *Context)) { + wrapped := chainMiddleware(v.middleware, initContextFn) + v.page(route, initContextFn, wrapped) +} + +// page registers a route with separate raw and wrapped init functions. +// raw is used for the panic-check at registration time; wrapped includes +// any middleware and is used as the live handler. +func (v *V) page(route string, raw, wrapped func(*Context)) { v.ensureDatastarHandler() - // check for panics + // check for panics using the raw handler (no middleware) func() { defer func() { if err := recover(); err != nil { @@ -180,14 +189,13 @@ func (v *V) Page(route string, initContextFn func(c *Context)) { } }() c := newContext("", "", v) - initContextFn(c) + raw(c) c.view() c.stopAllRoutines() }() - // save page init function allows devmode to restore persisted ctx later if v.cfg.DevMode { - v.devModePageInitFnMap[route] = initContextFn + v.devModePageInitFnMap[route] = wrapped } v.mux.HandleFunc("GET "+route, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { v.logDebug(nil, "GET %s", r.URL.String()) @@ -201,7 +209,7 @@ func (v *V) Page(route string, initContextFn func(c *Context)) { c.reqCtx = r.Context() routeParams := extractParams(route, r.URL.Path) c.injectRouteParams(routeParams) - initContextFn(c) + wrapped(c) v.registerCtx(c) if v.cfg.DevMode { v.devModePersist(c) @@ -226,8 +234,7 @@ func (v *V) Page(route string, initContextFn func(c *Context)) { Title: v.cfg.DocumentTitle, Head: headElements, Body: bodyElements, - HTMLAttrs: []h.H{}, - }) + }) _ = view.Render(w) })) } @@ -235,17 +242,9 @@ func (v *V) Page(route string, initContextFn func(c *Context)) { func (v *V) registerCtx(c *Context) { v.contextRegistryMutex.Lock() defer v.contextRegistryMutex.Unlock() - if c == nil { - v.logErr(c, "failed to add nil context to registry") - return - } v.contextRegistry[c.id] = c v.logDebug(c, "new context added to registry") - v.logDebug(nil, "number of sessions in registry: %d", v.currSessionNum()) -} - -func (v *V) currSessionNum() int { - return len(v.contextRegistry) + v.logDebug(nil, "number of sessions in registry: %d", len(v.contextRegistry)) } func (v *V) cleanupCtx(c *Context) { @@ -265,7 +264,7 @@ func (v *V) unregisterCtx(c *Context) { defer v.contextRegistryMutex.Unlock() v.logDebug(c, "ctx removed from registry") delete(v.contextRegistry, c.id) - v.logDebug(nil, "number of sessions in registry: %d", v.currSessionNum()) + v.logDebug(nil, "number of sessions in registry: %d", len(v.contextRegistry)) } func (v *V) getCtx(id string) (*Context, error) { @@ -355,16 +354,12 @@ func (v *V) Start() { return } - v.shutdown() + v.Shutdown() } // Shutdown gracefully shuts down the server and all contexts. // Safe for programmatic or test use. func (v *V) Shutdown() { - v.shutdown() -} - -func (v *V) shutdown() { if v.reaperStop != nil { close(v.reaperStop) } @@ -614,9 +609,7 @@ func New() *V { c.sseConnected.Store(true) v.logDebug(c, "SSE connection established") - go func() { - c.Sync() - }() + go c.Sync() for { select { @@ -708,7 +701,11 @@ func New() *V { }() c.injectSignals(sigs) - entry.fn() + if len(entry.middleware) > 0 { + chainMiddleware(entry.middleware, func(_ *Context) { entry.fn() })(c) + } else { + entry.fn() + } }) v.mux.HandleFunc("POST /_session/close", func(w http.ResponseWriter, r *http.Request) { @@ -732,9 +729,9 @@ func New() *V { } func genRandID() string { - b := make([]byte, 16) + b := make([]byte, 4) rand.Read(b) - return hex.EncodeToString(b)[:8] + return hex.EncodeToString(b) } func genCSRFToken() string { @@ -755,7 +752,7 @@ func extractParams(pattern, path string) map[string]string { key := p[i][1 : len(p[i])-1] // remove {} params[key] = u[i] } else if p[i] != u[i] { - continue + return nil } } return params