Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
10b4838f8d | ||
|
|
5362614c3e | ||
|
|
e636970f7b | ||
|
|
f5158b866c |
73
context.go
73
context.go
@@ -5,7 +5,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"maps"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
@@ -25,8 +24,7 @@ type Context struct {
|
|||||||
app *V
|
app *V
|
||||||
view func() h.H
|
view func() h.H
|
||||||
routeParams map[string]string
|
routeParams map[string]string
|
||||||
componentRegistry map[string]*Context
|
parentPageCtx *Context
|
||||||
parentPageCtx *Context
|
|
||||||
patchChan chan patch
|
patchChan chan patch
|
||||||
actionLimiter *rate.Limiter
|
actionLimiter *rate.Limiter
|
||||||
actionRegistry map[string]actionEntry
|
actionRegistry map[string]actionEntry
|
||||||
@@ -34,6 +32,7 @@ type Context struct {
|
|||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
ctxDisposedChan chan struct{}
|
ctxDisposedChan chan struct{}
|
||||||
reqCtx context.Context
|
reqCtx context.Context
|
||||||
|
fields []*Field
|
||||||
subscriptions []Subscription
|
subscriptions []Subscription
|
||||||
subsMu sync.Mutex
|
subsMu sync.Mutex
|
||||||
disposeOnce sync.Once
|
disposeOnce sync.Once
|
||||||
@@ -81,7 +80,6 @@ func (c *Context) Component(initCtx func(c *Context)) func() h.H {
|
|||||||
compCtx.parentPageCtx = c
|
compCtx.parentPageCtx = c
|
||||||
}
|
}
|
||||||
initCtx(compCtx)
|
initCtx(compCtx)
|
||||||
c.componentRegistry[id] = compCtx
|
|
||||||
return compCtx.view
|
return compCtx.view
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -208,14 +206,14 @@ func (c *Context) injectSignals(sigs map[string]any) {
|
|||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
for sigID, val := range sigs {
|
for sigID, val := range sigs {
|
||||||
if _, ok := c.signals.Load(sigID); !ok {
|
item, ok := c.signals.Load(sigID)
|
||||||
|
if !ok {
|
||||||
c.signals.Store(sigID, &signal{
|
c.signals.Store(sigID, &signal{
|
||||||
id: sigID,
|
id: sigID,
|
||||||
val: val,
|
val: val,
|
||||||
})
|
})
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
item, _ := c.signals.Load(sigID)
|
|
||||||
if sig, ok := item.(*signal); ok {
|
if sig, ok := item.(*signal); ok {
|
||||||
sig.val = val
|
sig.val = val
|
||||||
sig.changed = false
|
sig.changed = false
|
||||||
@@ -266,7 +264,7 @@ func (c *Context) sendPatch(p patch) {
|
|||||||
// Sync pushes the current view state and signal changes to the browser immediately
|
// Sync pushes the current view state and signal changes to the browser immediately
|
||||||
// over the live SSE event stream.
|
// over the live SSE event stream.
|
||||||
func (c *Context) Sync() {
|
func (c *Context) Sync() {
|
||||||
elemsPatch := bytes.NewBuffer(make([]byte, 0))
|
elemsPatch := new(bytes.Buffer)
|
||||||
if err := c.view().Render(elemsPatch); err != nil {
|
if err := c.view().Render(elemsPatch); err != nil {
|
||||||
c.app.logErr(c, "sync view failed: %v", err)
|
c.app.logErr(c, "sync view failed: %v", err)
|
||||||
return
|
return
|
||||||
@@ -331,6 +329,15 @@ func (c *Context) ExecScript(s string) {
|
|||||||
c.sendPatch(patch{patchTypeScript, s})
|
c.sendPatch(patch{patchTypeScript, s})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RedirectView sets a view that redirects the browser to the given URL.
|
||||||
|
// Use this in middleware to abort the chain and redirect in one step.
|
||||||
|
func (c *Context) RedirectView(url string) {
|
||||||
|
c.View(func() h.H {
|
||||||
|
c.Redirect(url)
|
||||||
|
return h.Div()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Redirect navigates the browser to the given URL.
|
// Redirect navigates the browser to the given URL.
|
||||||
// This triggers a full page navigation - the current context will be disposed
|
// This triggers a full page navigation - the current context will be disposed
|
||||||
// and a new context created at the destination URL.
|
// and a new context created at the destination URL.
|
||||||
@@ -386,12 +393,9 @@ func (c *Context) injectRouteParams(params map[string]string) {
|
|||||||
if params == nil {
|
if params == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
m := make(map[string]string)
|
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
maps.Copy(m, params)
|
c.routeParams = params
|
||||||
c.routeParams = m
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPathParam retrieves the value from the page request URL for the given parameter name
|
// GetPathParam retrieves the value from the page request URL for the given parameter name
|
||||||
@@ -477,6 +481,50 @@ func (c *Context) unsubscribeAll() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Field creates a signal with validation rules attached.
|
||||||
|
// The initial value seeds both the signal and the reset target.
|
||||||
|
// The field is tracked on the context so ValidateAll/ResetFields
|
||||||
|
// can operate on all fields by default.
|
||||||
|
func (c *Context) Field(initial any, rules ...Rule) *Field {
|
||||||
|
f := &Field{
|
||||||
|
signal: c.Signal(initial),
|
||||||
|
rules: rules,
|
||||||
|
initialVal: initial,
|
||||||
|
}
|
||||||
|
target := c
|
||||||
|
if c.isComponent() {
|
||||||
|
target = c.parentPageCtx
|
||||||
|
}
|
||||||
|
target.fields = append(target.fields, f)
|
||||||
|
return f
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateAll runs Validate on each field, returning true only if all pass.
|
||||||
|
// With no arguments it validates every field tracked on this context.
|
||||||
|
func (c *Context) ValidateAll(fields ...*Field) bool {
|
||||||
|
if len(fields) == 0 {
|
||||||
|
fields = c.fields
|
||||||
|
}
|
||||||
|
ok := true
|
||||||
|
for _, f := range fields {
|
||||||
|
if !f.Validate() {
|
||||||
|
ok = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetFields resets each field to its initial value and clears errors.
|
||||||
|
// With no arguments it resets every field tracked on this context.
|
||||||
|
func (c *Context) ResetFields(fields ...*Field) {
|
||||||
|
if len(fields) == 0 {
|
||||||
|
fields = c.fields
|
||||||
|
}
|
||||||
|
for _, f := range fields {
|
||||||
|
f.Reset()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func newContext(id string, route string, v *V) *Context {
|
func newContext(id string, route string, v *V) *Context {
|
||||||
if v == nil {
|
if v == nil {
|
||||||
panic("create context failed: app pointer is nil")
|
panic("create context failed: app pointer is nil")
|
||||||
@@ -488,8 +536,7 @@ func newContext(id string, route string, v *V) *Context {
|
|||||||
csrfToken: genCSRFToken(),
|
csrfToken: genCSRFToken(),
|
||||||
routeParams: make(map[string]string),
|
routeParams: make(map[string]string),
|
||||||
app: v,
|
app: v,
|
||||||
componentRegistry: make(map[string]*Context),
|
actionLimiter: newLimiter(v.actionRateLimit, defaultActionRate, defaultActionBurst),
|
||||||
actionLimiter: newLimiter(v.actionRateLimit, defaultActionRate, defaultActionBurst),
|
|
||||||
actionRegistry: make(map[string]actionEntry),
|
actionRegistry: make(map[string]actionEntry),
|
||||||
signals: new(sync.Map),
|
signals: new(sync.Map),
|
||||||
patchChan: make(chan patch, 1),
|
patchChan: make(chan patch, 1),
|
||||||
|
|||||||
58
field.go
Normal file
58
field.go
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
package via
|
||||||
|
|
||||||
|
// Field is a signal with built-in validation rules and error state.
|
||||||
|
// It embeds *signal, so all signal methods (Bind, String, Int, Bool, SetValue, Text, ID)
|
||||||
|
// work transparently.
|
||||||
|
type Field struct {
|
||||||
|
*signal
|
||||||
|
rules []Rule
|
||||||
|
errors []string
|
||||||
|
initialVal any
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate runs all rules against the current value.
|
||||||
|
// Clears previous errors, populates new ones, returns true if all rules pass.
|
||||||
|
func (f *Field) Validate() bool {
|
||||||
|
f.errors = nil
|
||||||
|
val := f.String()
|
||||||
|
for _, r := range f.rules {
|
||||||
|
if err := r.validate(val); err != nil {
|
||||||
|
f.errors = append(f.errors, err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return len(f.errors) == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasError returns true if this field has any validation errors.
|
||||||
|
func (f *Field) HasError() bool {
|
||||||
|
return len(f.errors) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// FirstError returns the first validation error message, or "" if valid.
|
||||||
|
func (f *Field) FirstError() string {
|
||||||
|
if len(f.errors) > 0 {
|
||||||
|
return f.errors[0]
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Errors returns all current validation error messages.
|
||||||
|
func (f *Field) Errors() []string {
|
||||||
|
return f.errors
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddError manually adds an error message (useful for server-side or cross-field validation).
|
||||||
|
func (f *Field) AddError(msg string) {
|
||||||
|
f.errors = append(f.errors, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearErrors removes all validation errors from this field.
|
||||||
|
func (f *Field) ClearErrors() {
|
||||||
|
f.errors = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset restores the field value to its initial value and clears all errors.
|
||||||
|
func (f *Field) Reset() {
|
||||||
|
f.SetValue(f.initialVal)
|
||||||
|
f.errors = nil
|
||||||
|
}
|
||||||
206
field_test.go
Normal file
206
field_test.go
Normal file
@@ -0,0 +1,206 @@
|
|||||||
|
package via
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ryanhamamura/via/h"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newTestField(initial any, rules ...Rule) *Field {
|
||||||
|
v := New()
|
||||||
|
var f *Field
|
||||||
|
v.Page("/", func(c *Context) {
|
||||||
|
f = c.Field(initial, rules...)
|
||||||
|
c.View(func() h.H { return h.Div() })
|
||||||
|
})
|
||||||
|
return f
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFieldCreation(t *testing.T) {
|
||||||
|
f := newTestField("hello", Required())
|
||||||
|
assert.Equal(t, "hello", f.String())
|
||||||
|
assert.NotEmpty(t, f.ID())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFieldSignalDelegation(t *testing.T) {
|
||||||
|
f := newTestField(42)
|
||||||
|
assert.Equal(t, "42", f.String())
|
||||||
|
assert.Equal(t, 42, f.Int())
|
||||||
|
|
||||||
|
f.SetValue("new")
|
||||||
|
assert.Equal(t, "new", f.String())
|
||||||
|
|
||||||
|
// Bind returns an h.H element
|
||||||
|
assert.NotNil(t, f.Bind())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFieldValidateSingleRule(t *testing.T) {
|
||||||
|
f := newTestField("", Required())
|
||||||
|
assert.False(t, f.Validate())
|
||||||
|
assert.True(t, f.HasError())
|
||||||
|
assert.Equal(t, "This field is required", f.FirstError())
|
||||||
|
|
||||||
|
f.SetValue("ok")
|
||||||
|
assert.True(t, f.Validate())
|
||||||
|
assert.False(t, f.HasError())
|
||||||
|
assert.Equal(t, "", f.FirstError())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFieldValidateMultipleRules(t *testing.T) {
|
||||||
|
f := newTestField("ab", Required(), MinLen(3))
|
||||||
|
assert.False(t, f.Validate())
|
||||||
|
errs := f.Errors()
|
||||||
|
assert.Len(t, errs, 1)
|
||||||
|
assert.Equal(t, "Must be at least 3 characters", errs[0])
|
||||||
|
|
||||||
|
f.SetValue("")
|
||||||
|
assert.False(t, f.Validate())
|
||||||
|
errs = f.Errors()
|
||||||
|
assert.Len(t, errs, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFieldErrors(t *testing.T) {
|
||||||
|
f := newTestField("")
|
||||||
|
assert.Nil(t, f.Errors())
|
||||||
|
assert.False(t, f.HasError())
|
||||||
|
assert.Equal(t, "", f.FirstError())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFieldAddError(t *testing.T) {
|
||||||
|
f := newTestField("ok")
|
||||||
|
f.AddError("username taken")
|
||||||
|
assert.True(t, f.HasError())
|
||||||
|
assert.Equal(t, "username taken", f.FirstError())
|
||||||
|
assert.Len(t, f.Errors(), 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFieldClearErrors(t *testing.T) {
|
||||||
|
f := newTestField("", Required())
|
||||||
|
f.Validate()
|
||||||
|
assert.True(t, f.HasError())
|
||||||
|
f.ClearErrors()
|
||||||
|
assert.False(t, f.HasError())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFieldReset(t *testing.T) {
|
||||||
|
f := newTestField("initial", Required(), MinLen(3))
|
||||||
|
f.SetValue("changed")
|
||||||
|
f.AddError("some error")
|
||||||
|
|
||||||
|
f.Reset()
|
||||||
|
assert.Equal(t, "initial", f.String())
|
||||||
|
assert.False(t, f.HasError())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateAll(t *testing.T) {
|
||||||
|
v := New()
|
||||||
|
v.Page("/", func(c *Context) {
|
||||||
|
c.Field("", Required(), MinLen(3))
|
||||||
|
c.Field("", Required(), Email())
|
||||||
|
c.View(func() h.H { return h.Div() })
|
||||||
|
|
||||||
|
// both empty → both fail
|
||||||
|
assert.False(t, c.ValidateAll())
|
||||||
|
})
|
||||||
|
|
||||||
|
v2 := New()
|
||||||
|
v2.Page("/", func(c *Context) {
|
||||||
|
c.Field("joe", Required(), MinLen(3))
|
||||||
|
c.Field("joe@x.com", Required(), Email())
|
||||||
|
c.View(func() h.H { return h.Div() })
|
||||||
|
|
||||||
|
assert.True(t, c.ValidateAll())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateAllPartialFailure(t *testing.T) {
|
||||||
|
v := New()
|
||||||
|
v.Page("/", func(c *Context) {
|
||||||
|
good := c.Field("valid", Required())
|
||||||
|
bad := c.Field("", Required())
|
||||||
|
c.View(func() h.H { return h.Div() })
|
||||||
|
|
||||||
|
ok := c.ValidateAll()
|
||||||
|
assert.False(t, ok)
|
||||||
|
assert.False(t, good.HasError())
|
||||||
|
assert.True(t, bad.HasError())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateAllSelectiveArgs(t *testing.T) {
|
||||||
|
v := New()
|
||||||
|
v.Page("/", func(c *Context) {
|
||||||
|
a := c.Field("", Required())
|
||||||
|
b := c.Field("ok", Required())
|
||||||
|
cField := c.Field("", Required())
|
||||||
|
c.View(func() h.H { return h.Div() })
|
||||||
|
|
||||||
|
// only validate a and b — cField should be untouched
|
||||||
|
ok := c.ValidateAll(a, b)
|
||||||
|
assert.False(t, ok)
|
||||||
|
assert.True(t, a.HasError())
|
||||||
|
assert.False(t, b.HasError())
|
||||||
|
assert.False(t, cField.HasError(), "unselected field should not be validated")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResetFields(t *testing.T) {
|
||||||
|
v := New()
|
||||||
|
v.Page("/", func(c *Context) {
|
||||||
|
a := c.Field("a", Required())
|
||||||
|
b := c.Field("b", Required())
|
||||||
|
c.View(func() h.H { return h.Div() })
|
||||||
|
|
||||||
|
a.SetValue("changed-a")
|
||||||
|
b.SetValue("changed-b")
|
||||||
|
a.AddError("err")
|
||||||
|
|
||||||
|
c.ResetFields()
|
||||||
|
assert.Equal(t, "a", a.String())
|
||||||
|
assert.Equal(t, "b", b.String())
|
||||||
|
assert.False(t, a.HasError())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResetFieldsSelectiveArgs(t *testing.T) {
|
||||||
|
v := New()
|
||||||
|
v.Page("/", func(c *Context) {
|
||||||
|
a := c.Field("a")
|
||||||
|
b := c.Field("b")
|
||||||
|
c.View(func() h.H { return h.Div() })
|
||||||
|
|
||||||
|
a.SetValue("changed-a")
|
||||||
|
b.SetValue("changed-b")
|
||||||
|
|
||||||
|
// only reset a
|
||||||
|
c.ResetFields(a)
|
||||||
|
assert.Equal(t, "a", a.String())
|
||||||
|
assert.Equal(t, "changed-b", b.String(), "unselected field should not be reset")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFieldValidateClearsPreviousErrors(t *testing.T) {
|
||||||
|
f := newTestField("", Required())
|
||||||
|
f.Validate()
|
||||||
|
assert.True(t, f.HasError())
|
||||||
|
|
||||||
|
f.SetValue("ok")
|
||||||
|
f.Validate()
|
||||||
|
assert.False(t, f.HasError())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFieldCustomValidator(t *testing.T) {
|
||||||
|
f := newTestField("bad", Custom(func(val string) error {
|
||||||
|
if val == "bad" {
|
||||||
|
return fmt.Errorf("no bad words")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
assert.False(t, f.Validate())
|
||||||
|
assert.Equal(t, "no bad words", f.FirstError())
|
||||||
|
|
||||||
|
f.SetValue("good")
|
||||||
|
assert.True(t, f.Validate())
|
||||||
|
}
|
||||||
151
internal/examples/middleware/main.go
Normal file
151
internal/examples/middleware/main.go
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ryanhamamura/via"
|
||||||
|
"github.com/ryanhamamura/via/h"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
v := via.New()
|
||||||
|
v.Config(via.Options{
|
||||||
|
ServerAddress: ":8080",
|
||||||
|
DocumentTitle: "Middleware Example",
|
||||||
|
})
|
||||||
|
|
||||||
|
// --- Middleware definitions ---
|
||||||
|
|
||||||
|
// requestLogger logs every page request to stdout.
|
||||||
|
requestLogger := func(c *via.Context, next func()) {
|
||||||
|
fmt.Printf("[%s] request\n", time.Now().Format("15:04:05"))
|
||||||
|
next()
|
||||||
|
}
|
||||||
|
|
||||||
|
// authRequired redirects unauthenticated users to /login.
|
||||||
|
authRequired := func(c *via.Context, next func()) {
|
||||||
|
if c.Session().GetString("role") == "" {
|
||||||
|
c.RedirectView("/login")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next()
|
||||||
|
}
|
||||||
|
|
||||||
|
// auditLog prints the authenticated username to stdout.
|
||||||
|
auditLog := func(c *via.Context, next func()) {
|
||||||
|
fmt.Printf("[audit] user=%s\n", c.Session().GetString("username"))
|
||||||
|
next()
|
||||||
|
}
|
||||||
|
|
||||||
|
// superAdminOnly rejects non-superadmin users with a forbidden view.
|
||||||
|
superAdminOnly := func(c *via.Context, next func()) {
|
||||||
|
if c.Session().GetString("role") != "superadmin" {
|
||||||
|
c.View(func() h.H {
|
||||||
|
return h.Div(
|
||||||
|
h.H1(h.Text("Forbidden")),
|
||||||
|
h.P(h.Text("Super-admin access required.")),
|
||||||
|
h.A(h.Href("/admin/dashboard"), h.Text("Back to dashboard")),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next()
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Route registration ---
|
||||||
|
|
||||||
|
v.Use(requestLogger) // global middleware
|
||||||
|
|
||||||
|
admin := v.Group("/admin", authRequired) // prefixed group
|
||||||
|
admin.Use(auditLog) // Group.Use()
|
||||||
|
superAdmin := admin.Group("/super", superAdminOnly) // nested group
|
||||||
|
|
||||||
|
// Public: redirect root to login
|
||||||
|
v.Page("/", func(c *via.Context) {
|
||||||
|
c.View(func() h.H {
|
||||||
|
c.Redirect("/login")
|
||||||
|
return h.Div()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// Public: login page with role-selection buttons
|
||||||
|
v.Page("/login", func(c *via.Context) {
|
||||||
|
loginAdmin := c.Action(func() {
|
||||||
|
c.Session().Set("role", "admin")
|
||||||
|
c.Session().Set("username", "alice")
|
||||||
|
c.Session().RenewToken()
|
||||||
|
c.Redirect("/admin/dashboard")
|
||||||
|
})
|
||||||
|
|
||||||
|
loginSuper := c.Action(func() {
|
||||||
|
c.Session().Set("role", "superadmin")
|
||||||
|
c.Session().Set("username", "bob")
|
||||||
|
c.Session().RenewToken()
|
||||||
|
c.Redirect("/admin/dashboard")
|
||||||
|
})
|
||||||
|
|
||||||
|
c.View(func() h.H {
|
||||||
|
return h.Div(
|
||||||
|
h.H1(h.Text("Login")),
|
||||||
|
h.P(h.Text("Choose a role:")),
|
||||||
|
h.Button(h.Text("Login as Admin"), loginAdmin.OnClick()),
|
||||||
|
h.Raw(" "),
|
||||||
|
h.Button(h.Text("Login as Super Admin"), loginSuper.OnClick()),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// Per-action middleware: only superadmins can invoke this action.
|
||||||
|
requireSuperAdmin := func(c *via.Context, next func()) {
|
||||||
|
if c.Session().GetString("role") != "superadmin" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Admin: dashboard (requires authRequired + auditLog)
|
||||||
|
admin.Page("/dashboard", func(c *via.Context) {
|
||||||
|
logout := c.Action(func() {
|
||||||
|
c.Session().Delete("role")
|
||||||
|
c.Session().Delete("username")
|
||||||
|
c.Redirect("/login")
|
||||||
|
})
|
||||||
|
|
||||||
|
dangerAction := c.Action(func() {
|
||||||
|
fmt.Printf("[danger] executed by %s\n", c.Session().GetString("username"))
|
||||||
|
c.Sync()
|
||||||
|
}, via.WithMiddleware(requireSuperAdmin))
|
||||||
|
|
||||||
|
c.View(func() h.H {
|
||||||
|
username := c.Session().GetString("username")
|
||||||
|
role := c.Session().GetString("role")
|
||||||
|
return h.Div(
|
||||||
|
h.H1(h.Textf("Dashboard — %s (%s)", username, role)),
|
||||||
|
h.Ul(
|
||||||
|
h.Li(h.A(h.Href("/admin/super/settings"), h.Text("Super Admin Settings"))),
|
||||||
|
),
|
||||||
|
h.H2(h.Text("Danger Zone")),
|
||||||
|
h.P(h.Text("This action is protected by per-action middleware (superadmin only):")),
|
||||||
|
h.Button(h.Text("Delete Everything"), dangerAction.OnClick()),
|
||||||
|
h.Br(),
|
||||||
|
h.Br(),
|
||||||
|
h.Button(h.Text("Logout"), logout.OnClick()),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// Super-admin: settings (requires authRequired + auditLog + superAdminOnly)
|
||||||
|
superAdmin.Page("/settings", func(c *via.Context) {
|
||||||
|
c.View(func() h.H {
|
||||||
|
username := c.Session().GetString("username")
|
||||||
|
return h.Div(
|
||||||
|
h.H1(h.Textf("Super Admin Settings — %s", username)),
|
||||||
|
h.P(h.Text("Only super-admins can see this page.")),
|
||||||
|
h.A(h.Href("/admin/dashboard"), h.Text("Back to dashboard")),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
v.Start()
|
||||||
|
}
|
||||||
@@ -29,7 +29,17 @@ func main() {
|
|||||||
SessionManager: sm,
|
SessionManager: sm,
|
||||||
})
|
})
|
||||||
|
|
||||||
// Login page
|
// Auth middleware — redirects unauthenticated users to /login
|
||||||
|
authRequired := func(c *via.Context, next func()) {
|
||||||
|
if c.Session().GetString("username") == "" {
|
||||||
|
c.Session().Set("flash", "Please log in first")
|
||||||
|
c.RedirectView("/login")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Login page (public)
|
||||||
v.Page("/login", func(c *via.Context) {
|
v.Page("/login", func(c *via.Context) {
|
||||||
flash := c.Session().PopString("flash")
|
flash := c.Session().PopString("flash")
|
||||||
usernameInput := c.Signal("")
|
usernameInput := c.Signal("")
|
||||||
@@ -64,8 +74,10 @@ func main() {
|
|||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
// Dashboard page (protected)
|
// Protected pages
|
||||||
v.Page("/dashboard", func(c *via.Context) {
|
protected := v.Group("", authRequired)
|
||||||
|
|
||||||
|
protected.Page("/dashboard", func(c *via.Context) {
|
||||||
logout := c.Action(func() {
|
logout := c.Action(func() {
|
||||||
c.Session().Set("flash", "Goodbye!")
|
c.Session().Set("flash", "Goodbye!")
|
||||||
c.Session().Delete("username")
|
c.Session().Delete("username")
|
||||||
@@ -74,14 +86,6 @@ func main() {
|
|||||||
|
|
||||||
c.View(func() h.H {
|
c.View(func() h.H {
|
||||||
username := c.Session().GetString("username")
|
username := c.Session().GetString("username")
|
||||||
|
|
||||||
// Not logged in? Redirect to login
|
|
||||||
if username == "" {
|
|
||||||
c.Session().Set("flash", "Please log in first")
|
|
||||||
c.Redirect("/login")
|
|
||||||
return h.Div()
|
|
||||||
}
|
|
||||||
|
|
||||||
flash := c.Session().PopString("flash")
|
flash := c.Session().PopString("flash")
|
||||||
var flashMsg h.H
|
var flashMsg h.H
|
||||||
if flash != "" {
|
if flash != "" {
|
||||||
|
|||||||
87
internal/examples/signup/main.go
Normal file
87
internal/examples/signup/main.go
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/ryanhamamura/via"
|
||||||
|
"github.com/ryanhamamura/via/h"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
v := via.New()
|
||||||
|
v.Config(via.Options{
|
||||||
|
DocumentTitle: "Signup",
|
||||||
|
ServerAddress: ":8080",
|
||||||
|
})
|
||||||
|
|
||||||
|
v.AppendToHead(h.StyleEl(h.Raw(`
|
||||||
|
body { font-family: system-ui, sans-serif; max-width: 420px; margin: 2rem auto; padding: 0 1rem; }
|
||||||
|
label { display: block; font-weight: 600; margin-top: 1rem; }
|
||||||
|
input { display: block; width: 100%; padding: 0.4rem; margin-top: 0.25rem; box-sizing: border-box; }
|
||||||
|
.error { color: #c00; font-size: 0.85rem; margin-top: 0.2rem; }
|
||||||
|
.success { color: #080; margin-top: 1rem; }
|
||||||
|
.actions { margin-top: 1.5rem; display: flex; gap: 0.5rem; }
|
||||||
|
`)))
|
||||||
|
|
||||||
|
v.Page("/", func(c *via.Context) {
|
||||||
|
username := c.Field("", via.Required(), via.MinLen(3), via.MaxLen(20))
|
||||||
|
email := c.Field("", via.Required(), via.Email())
|
||||||
|
age := c.Field("", via.Required(), via.Min(13), via.Max(120))
|
||||||
|
// Optional field — only validated when non-empty
|
||||||
|
website := c.Field("", via.Pattern(`^$|^https?://\S+$`, "Must be a valid URL"))
|
||||||
|
|
||||||
|
var success string
|
||||||
|
|
||||||
|
signup := c.Action(func() {
|
||||||
|
success = ""
|
||||||
|
if !c.ValidateAll() {
|
||||||
|
c.Sync()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Server-side check
|
||||||
|
if username.String() == "admin" {
|
||||||
|
username.AddError("Username is already taken")
|
||||||
|
c.Sync()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
success = "Account created for " + username.String() + "!"
|
||||||
|
c.ResetFields()
|
||||||
|
c.Sync()
|
||||||
|
})
|
||||||
|
|
||||||
|
reset := c.Action(func() {
|
||||||
|
success = ""
|
||||||
|
c.ResetFields()
|
||||||
|
c.Sync()
|
||||||
|
})
|
||||||
|
|
||||||
|
c.View(func() h.H {
|
||||||
|
return h.Div(
|
||||||
|
h.H1(h.Text("Sign Up")),
|
||||||
|
|
||||||
|
h.Label(h.Text("Username")),
|
||||||
|
h.Input(h.Type("text"), h.Placeholder("pick a username"), username.Bind()),
|
||||||
|
h.If(username.HasError(), h.Div(h.Class("error"), h.Text(username.FirstError()))),
|
||||||
|
|
||||||
|
h.Label(h.Text("Email")),
|
||||||
|
h.Input(h.Type("email"), h.Placeholder("you@example.com"), email.Bind()),
|
||||||
|
h.If(email.HasError(), h.Div(h.Class("error"), h.Text(email.FirstError()))),
|
||||||
|
|
||||||
|
h.Label(h.Text("Age")),
|
||||||
|
h.Input(h.Type("number"), h.Placeholder("your age"), age.Bind()),
|
||||||
|
h.If(age.HasError(), h.Div(h.Class("error"), h.Text(age.FirstError()))),
|
||||||
|
|
||||||
|
h.Label(h.Text("Website (optional)")),
|
||||||
|
h.Input(h.Type("url"), h.Placeholder("https://example.com"), website.Bind()),
|
||||||
|
h.If(website.HasError(), h.Div(h.Class("error"), h.Text(website.FirstError()))),
|
||||||
|
|
||||||
|
h.Div(h.Class("actions"),
|
||||||
|
h.Button(h.Text("Sign Up"), signup.OnClick()),
|
||||||
|
h.Button(h.Text("Reset"), reset.OnClick()),
|
||||||
|
),
|
||||||
|
|
||||||
|
h.If(success != "", h.P(h.Class("success"), h.Text(success))),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
v.Start()
|
||||||
|
}
|
||||||
82
middleware.go
Normal file
82
middleware.go
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
package via
|
||||||
|
|
||||||
|
// Middleware wraps a page init function. Call next to continue the chain;
|
||||||
|
// return without calling next to abort (set a view first, e.g. RedirectView).
|
||||||
|
type Middleware func(c *Context, next func())
|
||||||
|
|
||||||
|
// Group is a route group with a shared prefix and middleware stack.
|
||||||
|
type Group struct {
|
||||||
|
v *V
|
||||||
|
prefix string
|
||||||
|
middleware []Middleware
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use appends middleware to the global stack.
|
||||||
|
// Global middleware runs before every page handler.
|
||||||
|
func (v *V) Use(mw ...Middleware) {
|
||||||
|
v.middleware = append(v.middleware, mw...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Group creates a route group with the given path prefix and middleware.
|
||||||
|
// Routes registered on the group are prefixed and run the group's middleware
|
||||||
|
// after any global middleware.
|
||||||
|
func (v *V) Group(prefix string, mw ...Middleware) *Group {
|
||||||
|
return &Group{
|
||||||
|
v: v,
|
||||||
|
prefix: prefix,
|
||||||
|
middleware: mw,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Page registers a route on this group. The full route is the group prefix
|
||||||
|
// concatenated with route.
|
||||||
|
func (g *Group) Page(route string, initContextFn func(c *Context)) {
|
||||||
|
fullRoute := g.prefix + route
|
||||||
|
allMw := make([]Middleware, 0, len(g.v.middleware)+len(g.middleware))
|
||||||
|
allMw = append(allMw, g.v.middleware...)
|
||||||
|
allMw = append(allMw, g.middleware...)
|
||||||
|
wrapped := chainMiddleware(allMw, initContextFn)
|
||||||
|
g.v.page(fullRoute, initContextFn, wrapped)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Group creates a nested sub-group that inherits this group's prefix and
|
||||||
|
// middleware, then adds its own.
|
||||||
|
func (g *Group) Group(prefix string, mw ...Middleware) *Group {
|
||||||
|
combined := make([]Middleware, len(g.middleware), len(g.middleware)+len(mw))
|
||||||
|
copy(combined, g.middleware)
|
||||||
|
combined = append(combined, mw...)
|
||||||
|
return &Group{
|
||||||
|
v: g.v,
|
||||||
|
prefix: g.prefix + prefix,
|
||||||
|
middleware: combined,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use appends middleware to this group's stack.
|
||||||
|
func (g *Group) Use(mw ...Middleware) {
|
||||||
|
g.middleware = append(g.middleware, mw...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithMiddleware returns an ActionOption that attaches middleware to an action.
|
||||||
|
// Action middleware runs after CSRF/rate-limit checks and signal injection.
|
||||||
|
func WithMiddleware(mw ...Middleware) ActionOption {
|
||||||
|
return func(e *actionEntry) {
|
||||||
|
e.middleware = append(e.middleware, mw...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// chainMiddleware wraps handler with the given middleware, outer-first.
|
||||||
|
func chainMiddleware(mws []Middleware, handler func(*Context)) func(*Context) {
|
||||||
|
if len(mws) == 0 {
|
||||||
|
return handler
|
||||||
|
}
|
||||||
|
chained := handler
|
||||||
|
for i := len(mws) - 1; i >= 0; i-- {
|
||||||
|
mw := mws[i]
|
||||||
|
next := chained
|
||||||
|
chained = func(c *Context) {
|
||||||
|
mw(c, func() { next(c) })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return chained
|
||||||
|
}
|
||||||
340
middleware_test.go
Normal file
340
middleware_test.go
Normal file
@@ -0,0 +1,340 @@
|
|||||||
|
package via
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ryanhamamura/via/h"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMiddlewareRunsBeforeHandler(t *testing.T) {
|
||||||
|
var order []string
|
||||||
|
|
||||||
|
v := New()
|
||||||
|
v.Use(func(c *Context, next func()) {
|
||||||
|
order = append(order, "mw")
|
||||||
|
next()
|
||||||
|
})
|
||||||
|
v.Page("/", func(c *Context) {
|
||||||
|
order = append(order, "handler")
|
||||||
|
c.View(func() h.H { return h.Div() })
|
||||||
|
})
|
||||||
|
|
||||||
|
// Reset after registration (panic-check runs the raw handler)
|
||||||
|
order = nil
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/", nil))
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Equal(t, []string{"mw", "handler"}, order)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMiddlewareAbortSkipsHandler(t *testing.T) {
|
||||||
|
handlerCalled := false
|
||||||
|
|
||||||
|
v := New()
|
||||||
|
v.Use(func(c *Context, next func()) {
|
||||||
|
c.RedirectView("/other")
|
||||||
|
})
|
||||||
|
v.Page("/", func(c *Context) {
|
||||||
|
handlerCalled = true
|
||||||
|
c.View(func() h.H { return h.Div() })
|
||||||
|
})
|
||||||
|
|
||||||
|
handlerCalled = false
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/", nil))
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.False(t, handlerCalled)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMiddlewareChainOrder(t *testing.T) {
|
||||||
|
var order []string
|
||||||
|
|
||||||
|
v := New()
|
||||||
|
for _, label := range []string{"A", "B", "C"} {
|
||||||
|
l := label
|
||||||
|
v.Use(func(c *Context, next func()) {
|
||||||
|
order = append(order, l)
|
||||||
|
next()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
v.Page("/", func(c *Context) {
|
||||||
|
order = append(order, "handler")
|
||||||
|
c.View(func() h.H { return h.Div() })
|
||||||
|
})
|
||||||
|
|
||||||
|
order = nil
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/", nil))
|
||||||
|
|
||||||
|
assert.Equal(t, []string{"A", "B", "C", "handler"}, order)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGroupPrefixRouting(t *testing.T) {
|
||||||
|
v := New()
|
||||||
|
g := v.Group("/admin")
|
||||||
|
g.Page("/dashboard", func(c *Context) {
|
||||||
|
c.View(func() h.H { return h.Div(h.Text("admin dashboard")) })
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/admin/dashboard", nil))
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), "admin dashboard")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGroupMiddlewareAppliesToGroupOnly(t *testing.T) {
|
||||||
|
var groupMwCalled bool
|
||||||
|
|
||||||
|
v := New()
|
||||||
|
g := v.Group("/admin", func(c *Context, next func()) {
|
||||||
|
groupMwCalled = true
|
||||||
|
next()
|
||||||
|
})
|
||||||
|
g.Page("/panel", func(c *Context) {
|
||||||
|
c.View(func() h.H { return h.Div(h.Text("panel")) })
|
||||||
|
})
|
||||||
|
v.Page("/public", func(c *Context) {
|
||||||
|
c.View(func() h.H { return h.Div(h.Text("public")) })
|
||||||
|
})
|
||||||
|
|
||||||
|
// Hit public page — group middleware should NOT run
|
||||||
|
groupMwCalled = false
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/public", nil))
|
||||||
|
assert.False(t, groupMwCalled)
|
||||||
|
assert.Contains(t, w.Body.String(), "public")
|
||||||
|
|
||||||
|
// Hit group page — group middleware should run
|
||||||
|
groupMwCalled = false
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/admin/panel", nil))
|
||||||
|
assert.True(t, groupMwCalled)
|
||||||
|
assert.Contains(t, w.Body.String(), "panel")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGlobalMiddlewareAppliesToGroupPages(t *testing.T) {
|
||||||
|
var globalCalled bool
|
||||||
|
|
||||||
|
v := New()
|
||||||
|
v.Use(func(c *Context, next func()) {
|
||||||
|
globalCalled = true
|
||||||
|
next()
|
||||||
|
})
|
||||||
|
g := v.Group("/admin")
|
||||||
|
g.Page("/dash", func(c *Context) {
|
||||||
|
c.View(func() h.H { return h.Div(h.Text("dash")) })
|
||||||
|
})
|
||||||
|
|
||||||
|
globalCalled = false
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/admin/dash", nil))
|
||||||
|
|
||||||
|
assert.True(t, globalCalled)
|
||||||
|
assert.Contains(t, w.Body.String(), "dash")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNestedGroupInheritsPrefixAndMiddleware(t *testing.T) {
|
||||||
|
var order []string
|
||||||
|
|
||||||
|
v := New()
|
||||||
|
admin := v.Group("/admin", func(c *Context, next func()) {
|
||||||
|
order = append(order, "admin")
|
||||||
|
next()
|
||||||
|
})
|
||||||
|
superAdmin := admin.Group("/super", func(c *Context, next func()) {
|
||||||
|
order = append(order, "super")
|
||||||
|
next()
|
||||||
|
})
|
||||||
|
superAdmin.Page("/secret", func(c *Context) {
|
||||||
|
order = append(order, "handler")
|
||||||
|
c.View(func() h.H { return h.Div(h.Text("secret")) })
|
||||||
|
})
|
||||||
|
|
||||||
|
order = nil
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/admin/super/secret", nil))
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Equal(t, []string{"admin", "super", "handler"}, order)
|
||||||
|
assert.Contains(t, w.Body.String(), "secret")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGroupUse(t *testing.T) {
|
||||||
|
var order []string
|
||||||
|
|
||||||
|
v := New()
|
||||||
|
g := v.Group("/api")
|
||||||
|
g.Use(func(c *Context, next func()) {
|
||||||
|
order = append(order, "added-later")
|
||||||
|
next()
|
||||||
|
})
|
||||||
|
g.Page("/items", func(c *Context) {
|
||||||
|
order = append(order, "handler")
|
||||||
|
c.View(func() h.H { return h.Div() })
|
||||||
|
})
|
||||||
|
|
||||||
|
order = nil
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/api/items", nil))
|
||||||
|
|
||||||
|
assert.Equal(t, []string{"added-later", "handler"}, order)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRedirectViewSetsValidView(t *testing.T) {
|
||||||
|
v := New()
|
||||||
|
v.Page("/test", func(c *Context) {
|
||||||
|
c.RedirectView("/somewhere")
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/test", nil))
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), "<!doctype html>")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGlobalAndGroupMiddlewareOrder(t *testing.T) {
|
||||||
|
var order []string
|
||||||
|
|
||||||
|
v := New()
|
||||||
|
v.Use(func(c *Context, next func()) {
|
||||||
|
order = append(order, "global")
|
||||||
|
next()
|
||||||
|
})
|
||||||
|
g := v.Group("/g", func(c *Context, next func()) {
|
||||||
|
order = append(order, "group")
|
||||||
|
next()
|
||||||
|
})
|
||||||
|
g.Page("/page", func(c *Context) {
|
||||||
|
order = append(order, "handler")
|
||||||
|
c.View(func() h.H { return h.Div() })
|
||||||
|
})
|
||||||
|
|
||||||
|
order = nil
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/g/page", nil))
|
||||||
|
|
||||||
|
assert.Equal(t, []string{"global", "group", "handler"}, order)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Action middleware tests ---
|
||||||
|
|
||||||
|
func TestActionMiddlewareRunsBeforeAction(t *testing.T) {
|
||||||
|
var order []string
|
||||||
|
|
||||||
|
v := New()
|
||||||
|
c := newContext("test", "/", v)
|
||||||
|
|
||||||
|
mw := func(_ *Context, next func()) {
|
||||||
|
order = append(order, "mw")
|
||||||
|
next()
|
||||||
|
}
|
||||||
|
|
||||||
|
trigger := c.Action(func() {
|
||||||
|
order = append(order, "action")
|
||||||
|
}, WithMiddleware(mw))
|
||||||
|
|
||||||
|
entry, err := c.getAction(trigger.id)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
chainMiddleware(entry.middleware, func(_ *Context) { entry.fn() })(c)
|
||||||
|
|
||||||
|
assert.Equal(t, []string{"mw", "action"}, order)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestActionMiddlewareAbortSkipsAction(t *testing.T) {
|
||||||
|
actionCalled := false
|
||||||
|
|
||||||
|
v := New()
|
||||||
|
c := newContext("test", "/", v)
|
||||||
|
|
||||||
|
mw := func(_ *Context, next func()) {
|
||||||
|
// don't call next — action should not run
|
||||||
|
}
|
||||||
|
|
||||||
|
trigger := c.Action(func() {
|
||||||
|
actionCalled = true
|
||||||
|
}, WithMiddleware(mw))
|
||||||
|
|
||||||
|
entry, err := c.getAction(trigger.id)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
chainMiddleware(entry.middleware, func(_ *Context) { entry.fn() })(c)
|
||||||
|
|
||||||
|
assert.False(t, actionCalled)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestActionMiddlewareChainOrder(t *testing.T) {
|
||||||
|
var order []string
|
||||||
|
|
||||||
|
v := New()
|
||||||
|
c := newContext("test", "/", v)
|
||||||
|
|
||||||
|
var mws []Middleware
|
||||||
|
for _, label := range []string{"A", "B", "C"} {
|
||||||
|
l := label
|
||||||
|
mws = append(mws, func(_ *Context, next func()) {
|
||||||
|
order = append(order, l)
|
||||||
|
next()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
trigger := c.Action(func() {
|
||||||
|
order = append(order, "action")
|
||||||
|
}, WithMiddleware(mws...))
|
||||||
|
|
||||||
|
entry, err := c.getAction(trigger.id)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
chainMiddleware(entry.middleware, func(_ *Context) { entry.fn() })(c)
|
||||||
|
|
||||||
|
assert.Equal(t, []string{"A", "B", "C", "action"}, order)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestActionMiddlewareCombinedWithRateLimit(t *testing.T) {
|
||||||
|
v := New()
|
||||||
|
c := newContext("test", "/", v)
|
||||||
|
|
||||||
|
mw := func(_ *Context, next func()) { next() }
|
||||||
|
|
||||||
|
trigger := c.Action(func() {}, WithRateLimit(5, 10), WithMiddleware(mw))
|
||||||
|
|
||||||
|
entry, err := c.getAction(trigger.id)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, entry.limiter)
|
||||||
|
assert.Len(t, entry.middleware, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGroupWithEmptyPrefix(t *testing.T) {
|
||||||
|
var mwCalled bool
|
||||||
|
|
||||||
|
v := New()
|
||||||
|
g := v.Group("", func(c *Context, next func()) {
|
||||||
|
mwCalled = true
|
||||||
|
next()
|
||||||
|
})
|
||||||
|
g.Page("/dashboard", func(c *Context) {
|
||||||
|
c.View(func() h.H { return h.Div(h.Text("dash")) })
|
||||||
|
})
|
||||||
|
|
||||||
|
mwCalled = false
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
v.mux.ServeHTTP(w, httptest.NewRequest("GET", "/dashboard", nil))
|
||||||
|
|
||||||
|
assert.True(t, mwCalled)
|
||||||
|
assert.Contains(t, w.Body.String(), "dash")
|
||||||
|
}
|
||||||
@@ -18,8 +18,9 @@ type RateLimitConfig struct {
|
|||||||
type ActionOption func(*actionEntry)
|
type ActionOption func(*actionEntry)
|
||||||
|
|
||||||
type actionEntry struct {
|
type actionEntry struct {
|
||||||
fn func()
|
fn func()
|
||||||
limiter *rate.Limiter // nil = use context default
|
limiter *rate.Limiter // nil = use context default
|
||||||
|
middleware []Middleware
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithRateLimit returns an ActionOption that gives this action its own
|
// WithRateLimit returns an ActionOption that gives this action its own
|
||||||
|
|||||||
130
rule.go
Normal file
130
rule.go
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
package via
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"unicode/utf8"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Rule defines a single validation check for a Field.
|
||||||
|
type Rule struct {
|
||||||
|
validate func(val string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Required rejects empty or whitespace-only values.
|
||||||
|
func Required(msg ...string) Rule {
|
||||||
|
m := "This field is required"
|
||||||
|
if len(msg) > 0 {
|
||||||
|
m = msg[0]
|
||||||
|
}
|
||||||
|
return Rule{func(val string) error {
|
||||||
|
if strings.TrimSpace(val) == "" {
|
||||||
|
return errors.New(m)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MinLen rejects values shorter than n characters.
|
||||||
|
func MinLen(n int, msg ...string) Rule {
|
||||||
|
m := fmt.Sprintf("Must be at least %d characters", n)
|
||||||
|
if len(msg) > 0 {
|
||||||
|
m = msg[0]
|
||||||
|
}
|
||||||
|
return Rule{func(val string) error {
|
||||||
|
if utf8.RuneCountInString(val) < n {
|
||||||
|
return errors.New(m)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MaxLen rejects values longer than n characters.
|
||||||
|
func MaxLen(n int, msg ...string) Rule {
|
||||||
|
m := fmt.Sprintf("Must be at most %d characters", n)
|
||||||
|
if len(msg) > 0 {
|
||||||
|
m = msg[0]
|
||||||
|
}
|
||||||
|
return Rule{func(val string) error {
|
||||||
|
if utf8.RuneCountInString(val) > n {
|
||||||
|
return errors.New(m)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Min parses the value as an integer and rejects values less than n.
|
||||||
|
func Min(n int, msg ...string) Rule {
|
||||||
|
m := fmt.Sprintf("Must be at least %d", n)
|
||||||
|
if len(msg) > 0 {
|
||||||
|
m = msg[0]
|
||||||
|
}
|
||||||
|
return Rule{func(val string) error {
|
||||||
|
v, err := strconv.Atoi(val)
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("Must be a valid number")
|
||||||
|
}
|
||||||
|
if v < n {
|
||||||
|
return errors.New(m)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Max parses the value as an integer and rejects values greater than n.
|
||||||
|
func Max(n int, msg ...string) Rule {
|
||||||
|
m := fmt.Sprintf("Must be at most %d", n)
|
||||||
|
if len(msg) > 0 {
|
||||||
|
m = msg[0]
|
||||||
|
}
|
||||||
|
return Rule{func(val string) error {
|
||||||
|
v, err := strconv.Atoi(val)
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("Must be a valid number")
|
||||||
|
}
|
||||||
|
if v > n {
|
||||||
|
return errors.New(m)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pattern rejects values that don't match the regular expression re.
|
||||||
|
func Pattern(re string, msg ...string) Rule {
|
||||||
|
m := "Invalid format"
|
||||||
|
if len(msg) > 0 {
|
||||||
|
m = msg[0]
|
||||||
|
}
|
||||||
|
compiled := regexp.MustCompile(re)
|
||||||
|
return Rule{func(val string) error {
|
||||||
|
if !compiled.MatchString(val) {
|
||||||
|
return errors.New(m)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}}
|
||||||
|
}
|
||||||
|
|
||||||
|
var emailRegexp = regexp.MustCompile(`^[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}$`)
|
||||||
|
|
||||||
|
// Email rejects values that don't look like an email address.
|
||||||
|
func Email(msg ...string) Rule {
|
||||||
|
m := "Invalid email address"
|
||||||
|
if len(msg) > 0 {
|
||||||
|
m = msg[0]
|
||||||
|
}
|
||||||
|
return Rule{func(val string) error {
|
||||||
|
if !emailRegexp.MatchString(val) {
|
||||||
|
return errors.New(m)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Custom creates a rule from a user-provided validation function.
|
||||||
|
// The function should return nil for valid input and an error for invalid input.
|
||||||
|
func Custom(fn func(string) error) Rule {
|
||||||
|
return Rule{validate: fn}
|
||||||
|
}
|
||||||
116
rule_test.go
Normal file
116
rule_test.go
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
package via
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRequired(t *testing.T) {
|
||||||
|
r := Required()
|
||||||
|
assert.NoError(t, r.validate("hello"))
|
||||||
|
assert.Error(t, r.validate(""))
|
||||||
|
assert.Error(t, r.validate(" "))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequiredCustomMessage(t *testing.T) {
|
||||||
|
r := Required("name needed")
|
||||||
|
err := r.validate("")
|
||||||
|
assert.EqualError(t, err, "name needed")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMinLen(t *testing.T) {
|
||||||
|
r := MinLen(3)
|
||||||
|
assert.NoError(t, r.validate("abc"))
|
||||||
|
assert.NoError(t, r.validate("abcd"))
|
||||||
|
assert.Error(t, r.validate("ab"))
|
||||||
|
assert.Error(t, r.validate(""))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMinLenCustomMessage(t *testing.T) {
|
||||||
|
r := MinLen(5, "too short")
|
||||||
|
err := r.validate("ab")
|
||||||
|
assert.EqualError(t, err, "too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMaxLen(t *testing.T) {
|
||||||
|
r := MaxLen(5)
|
||||||
|
assert.NoError(t, r.validate("abc"))
|
||||||
|
assert.NoError(t, r.validate("abcde"))
|
||||||
|
assert.Error(t, r.validate("abcdef"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMaxLenCustomMessage(t *testing.T) {
|
||||||
|
r := MaxLen(2, "too long")
|
||||||
|
err := r.validate("abc")
|
||||||
|
assert.EqualError(t, err, "too long")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMin(t *testing.T) {
|
||||||
|
r := Min(5)
|
||||||
|
assert.NoError(t, r.validate("5"))
|
||||||
|
assert.NoError(t, r.validate("10"))
|
||||||
|
assert.Error(t, r.validate("4"))
|
||||||
|
assert.Error(t, r.validate("abc"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMinCustomMessage(t *testing.T) {
|
||||||
|
r := Min(10, "need 10+")
|
||||||
|
err := r.validate("3")
|
||||||
|
assert.EqualError(t, err, "need 10+")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMax(t *testing.T) {
|
||||||
|
r := Max(10)
|
||||||
|
assert.NoError(t, r.validate("10"))
|
||||||
|
assert.NoError(t, r.validate("5"))
|
||||||
|
assert.Error(t, r.validate("11"))
|
||||||
|
assert.Error(t, r.validate("abc"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMaxCustomMessage(t *testing.T) {
|
||||||
|
r := Max(5, "too big")
|
||||||
|
err := r.validate("6")
|
||||||
|
assert.EqualError(t, err, "too big")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPattern(t *testing.T) {
|
||||||
|
r := Pattern(`^\d{3}$`)
|
||||||
|
assert.NoError(t, r.validate("123"))
|
||||||
|
assert.Error(t, r.validate("12"))
|
||||||
|
assert.Error(t, r.validate("abcd"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPatternCustomMessage(t *testing.T) {
|
||||||
|
r := Pattern(`^\d+$`, "digits only")
|
||||||
|
err := r.validate("abc")
|
||||||
|
assert.EqualError(t, err, "digits only")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEmail(t *testing.T) {
|
||||||
|
r := Email()
|
||||||
|
assert.NoError(t, r.validate("user@example.com"))
|
||||||
|
assert.NoError(t, r.validate("a.b+c@foo.co"))
|
||||||
|
assert.Error(t, r.validate("notanemail"))
|
||||||
|
assert.Error(t, r.validate("@example.com"))
|
||||||
|
assert.Error(t, r.validate("user@"))
|
||||||
|
assert.Error(t, r.validate(""))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEmailCustomMessage(t *testing.T) {
|
||||||
|
r := Email("bad email")
|
||||||
|
err := r.validate("nope")
|
||||||
|
assert.EqualError(t, err, "bad email")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCustom(t *testing.T) {
|
||||||
|
r := Custom(func(val string) error {
|
||||||
|
if val != "magic" {
|
||||||
|
return fmt.Errorf("must be magic")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
assert.NoError(t, r.validate("magic"))
|
||||||
|
assert.EqualError(t, r.validate("other"), "must be magic")
|
||||||
|
}
|
||||||
23
signal.go
23
signal.go
@@ -81,26 +81,3 @@ func (s *signal) Int() int {
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Int64 tries to read the signal value as an int64.
|
|
||||||
// Returns the value or 0 on failure.
|
|
||||||
func (s *signal) Int64() int64 {
|
|
||||||
if n, err := strconv.ParseInt(s.String(), 10, 64); err == nil {
|
|
||||||
return n
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// Float64 tries to read the signal value as a float64.
|
|
||||||
// Returns the value or 0.0 on failure.
|
|
||||||
func (s *signal) Float() float64 {
|
|
||||||
if n, err := strconv.ParseFloat(s.String(), 64); err == nil {
|
|
||||||
return n
|
|
||||||
}
|
|
||||||
return 0.0
|
|
||||||
}
|
|
||||||
|
|
||||||
// Bytes tries to read the signal value as a []byte
|
|
||||||
// Returns the value or an empty []byte on failure.
|
|
||||||
func (s *signal) Bytes() []byte {
|
|
||||||
return []byte(s.String())
|
|
||||||
}
|
|
||||||
|
|||||||
143
static_test.go
Normal file
143
static_test.go
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
package via
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io/fs"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"testing/fstest"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestStatic(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
os.MkdirAll(filepath.Join(dir, "sub"), 0755)
|
||||||
|
os.WriteFile(filepath.Join(dir, "hello.txt"), []byte("hello world"), 0644)
|
||||||
|
os.WriteFile(filepath.Join(dir, "sub", "nested.txt"), []byte("nested"), 0644)
|
||||||
|
|
||||||
|
v := New()
|
||||||
|
v.Static("/assets/", dir)
|
||||||
|
|
||||||
|
t.Run("serves file", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r := httptest.NewRequest("GET", "/assets/hello.txt", nil)
|
||||||
|
v.mux.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Equal(t, "hello world", w.Body.String())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("serves nested file", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r := httptest.NewRequest("GET", "/assets/sub/nested.txt", nil)
|
||||||
|
v.mux.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Equal(t, "nested", w.Body.String())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("directory listing returns 404", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r := httptest.NewRequest("GET", "/assets/", nil)
|
||||||
|
v.mux.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("subdirectory listing returns 404", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r := httptest.NewRequest("GET", "/assets/sub/", nil)
|
||||||
|
v.mux.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing file returns 404", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r := httptest.NewRequest("GET", "/assets/nope.txt", nil)
|
||||||
|
v.mux.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStaticAutoSlash(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
os.WriteFile(filepath.Join(dir, "ok.txt"), []byte("ok"), 0644)
|
||||||
|
|
||||||
|
v := New()
|
||||||
|
v.Static("/files", dir) // no trailing slash
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r := httptest.NewRequest("GET", "/files/ok.txt", nil)
|
||||||
|
v.mux.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Equal(t, "ok", w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStaticFS(t *testing.T) {
|
||||||
|
fsys := fstest.MapFS{
|
||||||
|
"style.css": {Data: []byte("body{}")},
|
||||||
|
"js/app.js": {Data: []byte("console.log('hi')")},
|
||||||
|
}
|
||||||
|
|
||||||
|
v := New()
|
||||||
|
v.StaticFS("/static/", fsys)
|
||||||
|
|
||||||
|
t.Run("serves file", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r := httptest.NewRequest("GET", "/static/style.css", nil)
|
||||||
|
v.mux.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Equal(t, "body{}", w.Body.String())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("serves nested file", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r := httptest.NewRequest("GET", "/static/js/app.js", nil)
|
||||||
|
v.mux.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Equal(t, "console.log('hi')", w.Body.String())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("directory listing returns 404", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r := httptest.NewRequest("GET", "/static/", nil)
|
||||||
|
v.mux.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing file returns 404", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r := httptest.NewRequest("GET", "/static/nope.css", nil)
|
||||||
|
v.mux.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStaticFSAutoSlash(t *testing.T) {
|
||||||
|
fsys := fstest.MapFS{
|
||||||
|
"ok.txt": {Data: []byte("ok")},
|
||||||
|
}
|
||||||
|
|
||||||
|
v := New()
|
||||||
|
v.StaticFS("/embed", fsys) // no trailing slash
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r := httptest.NewRequest("GET", "/embed/ok.txt", nil)
|
||||||
|
v.mux.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Equal(t, "ok", w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify StaticFS accepts the fs.FS interface (compile-time check).
|
||||||
|
var _ fs.FS = fstest.MapFS{}
|
||||||
98
via.go
98
via.go
@@ -9,12 +9,13 @@ package via
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
_ "embed"
|
|
||||||
"crypto/subtle"
|
"crypto/subtle"
|
||||||
|
_ "embed"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"io/fs"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
@@ -53,6 +54,7 @@ type V struct {
|
|||||||
datastarContent []byte
|
datastarContent []byte
|
||||||
datastarOnce sync.Once
|
datastarOnce sync.Once
|
||||||
reaperStop chan struct{}
|
reaperStop chan struct{}
|
||||||
|
middleware []Middleware
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *V) logEvent(evt *zerolog.Event, c *Context) *zerolog.Event {
|
func (v *V) logEvent(evt *zerolog.Event, c *Context) *zerolog.Event {
|
||||||
@@ -169,8 +171,16 @@ func (v *V) AppendToFoot(elements ...h.H) {
|
|||||||
// })
|
// })
|
||||||
// })
|
// })
|
||||||
func (v *V) Page(route string, initContextFn func(c *Context)) {
|
func (v *V) Page(route string, initContextFn func(c *Context)) {
|
||||||
|
wrapped := chainMiddleware(v.middleware, initContextFn)
|
||||||
|
v.page(route, initContextFn, wrapped)
|
||||||
|
}
|
||||||
|
|
||||||
|
// page registers a route with separate raw and wrapped init functions.
|
||||||
|
// raw is used for the panic-check at registration time; wrapped includes
|
||||||
|
// any middleware and is used as the live handler.
|
||||||
|
func (v *V) page(route string, raw, wrapped func(*Context)) {
|
||||||
v.ensureDatastarHandler()
|
v.ensureDatastarHandler()
|
||||||
// check for panics
|
// check for panics using the raw handler (no middleware)
|
||||||
func() {
|
func() {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := recover(); err != nil {
|
if err := recover(); err != nil {
|
||||||
@@ -179,14 +189,13 @@ func (v *V) Page(route string, initContextFn func(c *Context)) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
c := newContext("", "", v)
|
c := newContext("", "", v)
|
||||||
initContextFn(c)
|
raw(c)
|
||||||
c.view()
|
c.view()
|
||||||
c.stopAllRoutines()
|
c.stopAllRoutines()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// save page init function allows devmode to restore persisted ctx later
|
|
||||||
if v.cfg.DevMode {
|
if v.cfg.DevMode {
|
||||||
v.devModePageInitFnMap[route] = initContextFn
|
v.devModePageInitFnMap[route] = wrapped
|
||||||
}
|
}
|
||||||
v.mux.HandleFunc("GET "+route, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
v.mux.HandleFunc("GET "+route, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
v.logDebug(nil, "GET %s", r.URL.String())
|
v.logDebug(nil, "GET %s", r.URL.String())
|
||||||
@@ -200,7 +209,7 @@ func (v *V) Page(route string, initContextFn func(c *Context)) {
|
|||||||
c.reqCtx = r.Context()
|
c.reqCtx = r.Context()
|
||||||
routeParams := extractParams(route, r.URL.Path)
|
routeParams := extractParams(route, r.URL.Path)
|
||||||
c.injectRouteParams(routeParams)
|
c.injectRouteParams(routeParams)
|
||||||
initContextFn(c)
|
wrapped(c)
|
||||||
v.registerCtx(c)
|
v.registerCtx(c)
|
||||||
if v.cfg.DevMode {
|
if v.cfg.DevMode {
|
||||||
v.devModePersist(c)
|
v.devModePersist(c)
|
||||||
@@ -225,8 +234,7 @@ func (v *V) Page(route string, initContextFn func(c *Context)) {
|
|||||||
Title: v.cfg.DocumentTitle,
|
Title: v.cfg.DocumentTitle,
|
||||||
Head: headElements,
|
Head: headElements,
|
||||||
Body: bodyElements,
|
Body: bodyElements,
|
||||||
HTMLAttrs: []h.H{},
|
})
|
||||||
})
|
|
||||||
_ = view.Render(w)
|
_ = view.Render(w)
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
@@ -234,17 +242,9 @@ func (v *V) Page(route string, initContextFn func(c *Context)) {
|
|||||||
func (v *V) registerCtx(c *Context) {
|
func (v *V) registerCtx(c *Context) {
|
||||||
v.contextRegistryMutex.Lock()
|
v.contextRegistryMutex.Lock()
|
||||||
defer v.contextRegistryMutex.Unlock()
|
defer v.contextRegistryMutex.Unlock()
|
||||||
if c == nil {
|
|
||||||
v.logErr(c, "failed to add nil context to registry")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
v.contextRegistry[c.id] = c
|
v.contextRegistry[c.id] = c
|
||||||
v.logDebug(c, "new context added to registry")
|
v.logDebug(c, "new context added to registry")
|
||||||
v.logDebug(nil, "number of sessions in registry: %d", v.currSessionNum())
|
v.logDebug(nil, "number of sessions in registry: %d", len(v.contextRegistry))
|
||||||
}
|
|
||||||
|
|
||||||
func (v *V) currSessionNum() int {
|
|
||||||
return len(v.contextRegistry)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *V) cleanupCtx(c *Context) {
|
func (v *V) cleanupCtx(c *Context) {
|
||||||
@@ -264,7 +264,7 @@ func (v *V) unregisterCtx(c *Context) {
|
|||||||
defer v.contextRegistryMutex.Unlock()
|
defer v.contextRegistryMutex.Unlock()
|
||||||
v.logDebug(c, "ctx removed from registry")
|
v.logDebug(c, "ctx removed from registry")
|
||||||
delete(v.contextRegistry, c.id)
|
delete(v.contextRegistry, c.id)
|
||||||
v.logDebug(nil, "number of sessions in registry: %d", v.currSessionNum())
|
v.logDebug(nil, "number of sessions in registry: %d", len(v.contextRegistry))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *V) getCtx(id string) (*Context, error) {
|
func (v *V) getCtx(id string) (*Context, error) {
|
||||||
@@ -354,16 +354,12 @@ func (v *V) Start() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
v.shutdown()
|
v.Shutdown()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Shutdown gracefully shuts down the server and all contexts.
|
// Shutdown gracefully shuts down the server and all contexts.
|
||||||
// Safe for programmatic or test use.
|
// Safe for programmatic or test use.
|
||||||
func (v *V) Shutdown() {
|
func (v *V) Shutdown() {
|
||||||
v.shutdown()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *V) shutdown() {
|
|
||||||
if v.reaperStop != nil {
|
if v.reaperStop != nil {
|
||||||
close(v.reaperStop)
|
close(v.reaperStop)
|
||||||
}
|
}
|
||||||
@@ -412,6 +408,46 @@ func (v *V) HTTPServeMux() *http.ServeMux {
|
|||||||
return v.mux
|
return v.mux
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Static serves files from a filesystem directory at the given URL prefix.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// v.Static("/assets/", "./public")
|
||||||
|
func (v *V) Static(urlPrefix, dir string) {
|
||||||
|
if !strings.HasSuffix(urlPrefix, "/") {
|
||||||
|
urlPrefix += "/"
|
||||||
|
}
|
||||||
|
fileServer := http.StripPrefix(urlPrefix, http.FileServer(http.Dir(dir)))
|
||||||
|
v.mux.Handle("GET "+urlPrefix, noDirListing(fileServer))
|
||||||
|
}
|
||||||
|
|
||||||
|
// StaticFS serves files from an [fs.FS] at the given URL prefix.
|
||||||
|
// This is useful with //go:embed filesystems.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// //go:embed static
|
||||||
|
// var staticFiles embed.FS
|
||||||
|
// v.StaticFS("/assets/", staticFiles)
|
||||||
|
func (v *V) StaticFS(urlPrefix string, fsys fs.FS) {
|
||||||
|
if !strings.HasSuffix(urlPrefix, "/") {
|
||||||
|
urlPrefix += "/"
|
||||||
|
}
|
||||||
|
fileServer := http.StripPrefix(urlPrefix, http.FileServerFS(fsys))
|
||||||
|
v.mux.Handle("GET "+urlPrefix, noDirListing(fileServer))
|
||||||
|
}
|
||||||
|
|
||||||
|
// noDirListing wraps a file server handler to return 404 for directory requests.
|
||||||
|
func noDirListing(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if strings.HasSuffix(r.URL.Path, "/") {
|
||||||
|
http.NotFound(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (v *V) ensureDatastarHandler() {
|
func (v *V) ensureDatastarHandler() {
|
||||||
v.datastarOnce.Do(func() {
|
v.datastarOnce.Do(func() {
|
||||||
v.mux.HandleFunc("GET "+v.datastarPath, func(w http.ResponseWriter, r *http.Request) {
|
v.mux.HandleFunc("GET "+v.datastarPath, func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -573,9 +609,7 @@ func New() *V {
|
|||||||
c.sseConnected.Store(true)
|
c.sseConnected.Store(true)
|
||||||
v.logDebug(c, "SSE connection established")
|
v.logDebug(c, "SSE connection established")
|
||||||
|
|
||||||
go func() {
|
go c.Sync()
|
||||||
c.Sync()
|
|
||||||
}()
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
@@ -667,7 +701,11 @@ func New() *V {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
c.injectSignals(sigs)
|
c.injectSignals(sigs)
|
||||||
entry.fn()
|
if len(entry.middleware) > 0 {
|
||||||
|
chainMiddleware(entry.middleware, func(_ *Context) { entry.fn() })(c)
|
||||||
|
} else {
|
||||||
|
entry.fn()
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
v.mux.HandleFunc("POST /_session/close", func(w http.ResponseWriter, r *http.Request) {
|
v.mux.HandleFunc("POST /_session/close", func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -691,9 +729,9 @@ func New() *V {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func genRandID() string {
|
func genRandID() string {
|
||||||
b := make([]byte, 16)
|
b := make([]byte, 4)
|
||||||
rand.Read(b)
|
rand.Read(b)
|
||||||
return hex.EncodeToString(b)[:8]
|
return hex.EncodeToString(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
func genCSRFToken() string {
|
func genCSRFToken() string {
|
||||||
@@ -714,7 +752,7 @@ func extractParams(pattern, path string) map[string]string {
|
|||||||
key := p[i][1 : len(p[i])-1] // remove {}
|
key := p[i][1 : len(p[i])-1] // remove {}
|
||||||
params[key] = u[i]
|
params[key] = u[i]
|
||||||
} else if p[i] != u[i] {
|
} else if p[i] != u[i] {
|
||||||
continue
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return params
|
return params
|
||||||
|
|||||||
Reference in New Issue
Block a user