2 Commits

Author SHA1 Message Date
Ryan Hamamura
b7acfa6302 feat: add automatic CSRF protection for action calls
Generate a per-context CSRF token (128-bit, crypto/rand) and inject it
as a Datastar signal (via-csrf) alongside via-ctx. Validate with
constant-time comparison on /_action/{id} before executing, returning
403 on mismatch. Transparent to users since Datastar sends all signals
automatically.

Closes #9
2026-02-06 11:17:41 -10:00
Ryan Hamamura
8aa91c577c feat: add event types OnSubmit, OnInput, OnFocus, OnBlur, OnMouseEnter, OnMouseLeave, OnScroll, OnDblClick 2026-02-06 10:54:27 -10:00
4 changed files with 118 additions and 1 deletions

View File

@@ -107,6 +107,54 @@ func (a *actionTrigger) OnChange(options ...ActionTriggerOption) h.H {
return h.Data("on:change__debounce.200ms", buildOnExpr(actionURL(a.id), &opts)) return h.Data("on:change__debounce.200ms", buildOnExpr(actionURL(a.id), &opts))
} }
// OnSubmit returns a via.h DOM attribute that triggers on form submit.
func (a *actionTrigger) OnSubmit(options ...ActionTriggerOption) h.H {
opts := applyOptions(options...)
return h.Data("on:submit", buildOnExpr(actionURL(a.id), &opts))
}
// OnInput returns a via.h DOM attribute that triggers on input (without debounce).
func (a *actionTrigger) OnInput(options ...ActionTriggerOption) h.H {
opts := applyOptions(options...)
return h.Data("on:input", buildOnExpr(actionURL(a.id), &opts))
}
// OnFocus returns a via.h DOM attribute that triggers when the element gains focus.
func (a *actionTrigger) OnFocus(options ...ActionTriggerOption) h.H {
opts := applyOptions(options...)
return h.Data("on:focus", buildOnExpr(actionURL(a.id), &opts))
}
// OnBlur returns a via.h DOM attribute that triggers when the element loses focus.
func (a *actionTrigger) OnBlur(options ...ActionTriggerOption) h.H {
opts := applyOptions(options...)
return h.Data("on:blur", buildOnExpr(actionURL(a.id), &opts))
}
// OnMouseEnter returns a via.h DOM attribute that triggers when the mouse enters the element.
func (a *actionTrigger) OnMouseEnter(options ...ActionTriggerOption) h.H {
opts := applyOptions(options...)
return h.Data("on:mouseenter", buildOnExpr(actionURL(a.id), &opts))
}
// OnMouseLeave returns a via.h DOM attribute that triggers when the mouse leaves the element.
func (a *actionTrigger) OnMouseLeave(options ...ActionTriggerOption) h.H {
opts := applyOptions(options...)
return h.Data("on:mouseleave", buildOnExpr(actionURL(a.id), &opts))
}
// OnScroll returns a via.h DOM attribute that triggers on scroll.
func (a *actionTrigger) OnScroll(options ...ActionTriggerOption) h.H {
opts := applyOptions(options...)
return h.Data("on:scroll", buildOnExpr(actionURL(a.id), &opts))
}
// OnDblClick returns a via.h DOM attribute that triggers on double click.
func (a *actionTrigger) OnDblClick(options ...ActionTriggerOption) h.H {
opts := applyOptions(options...)
return h.Data("on:dblclick", buildOnExpr(actionURL(a.id), &opts))
}
// OnKeyDown returns a via.h DOM attribute that triggers when a key is pressed. // OnKeyDown returns a via.h DOM attribute that triggers when a key is pressed.
// key: optional, see https://developer.mozilla.org/en-US/docs/Web/API/KeyboardEvent/key // key: optional, see https://developer.mozilla.org/en-US/docs/Web/API/KeyboardEvent/key
// Example: OnKeyDown("Enter") // Example: OnKeyDown("Enter")

View File

@@ -20,6 +20,7 @@ import (
type Context struct { type Context struct {
id string id string
route string route string
csrfToken string
app *V app *V
view func() h.H view func() h.H
routeParams map[string]string routeParams map[string]string
@@ -477,6 +478,7 @@ func newContext(id string, route string, v *V) *Context {
return &Context{ return &Context{
id: id, id: id,
route: route, route: route,
csrfToken: genCSRFToken(),
routeParams: make(map[string]string), routeParams: make(map[string]string),
app: v, app: v,
componentRegistry: make(map[string]*Context), componentRegistry: make(map[string]*Context),

15
via.go
View File

@@ -10,6 +10,7 @@ import (
"context" "context"
"crypto/rand" "crypto/rand"
_ "embed" _ "embed"
"crypto/subtle"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"fmt" "fmt"
@@ -203,7 +204,7 @@ func (v *V) Page(route string, initContextFn func(c *Context)) {
headElements := []h.H{h.Script(h.Type("module"), h.Src(v.datastarPath))} headElements := []h.H{h.Script(h.Type("module"), h.Src(v.datastarPath))}
headElements = append(headElements, v.documentHeadIncludes...) headElements = append(headElements, v.documentHeadIncludes...)
headElements = append(headElements, headElements = append(headElements,
h.Meta(h.Data("signals", fmt.Sprintf("{'via-ctx':'%s'}", id))), h.Meta(h.Data("signals", fmt.Sprintf("{'via-ctx':'%s','via-csrf':'%s'}", id, c.csrfToken))),
h.Meta(h.Data("init", "@get('/_sse')")), h.Meta(h.Data("init", "@get('/_sse')")),
h.Meta(h.Data("init", fmt.Sprintf(`window.addEventListener('beforeunload', (evt) => { h.Meta(h.Data("init", fmt.Sprintf(`window.addEventListener('beforeunload', (evt) => {
navigator.sendBeacon('/_session/close', '%s');});`, c.id))), navigator.sendBeacon('/_session/close', '%s');});`, c.id))),
@@ -632,6 +633,12 @@ func New() *V {
v.logErr(nil, "action '%s' failed: %v", actionID, err) v.logErr(nil, "action '%s' failed: %v", actionID, err)
return return
} }
csrfToken, _ := sigs["via-csrf"].(string)
if subtle.ConstantTimeCompare([]byte(csrfToken), []byte(c.csrfToken)) != 1 {
v.logWarn(c, "action '%s' rejected: invalid CSRF token", actionID)
http.Error(w, "invalid CSRF token", http.StatusForbidden)
return
}
c.reqCtx = r.Context() c.reqCtx = r.Context()
actionFn, err := c.getActionFn(actionID) actionFn, err := c.getActionFn(actionID)
if err != nil { if err != nil {
@@ -675,6 +682,12 @@ func genRandID() string {
return hex.EncodeToString(b)[:8] return hex.EncodeToString(b)[:8]
} }
func genCSRFToken() string {
b := make([]byte, 16)
rand.Read(b)
return hex.EncodeToString(b)
}
func extractParams(pattern, path string) map[string]string { func extractParams(pattern, path string) map[string]string {
p := strings.Split(strings.Trim(pattern, "/"), "/") p := strings.Split(strings.Trim(pattern, "/"), "/")
u := strings.Split(strings.Trim(path, "/"), "/") u := strings.Split(strings.Trim(path, "/"), "/")

View File

@@ -132,6 +132,60 @@ func TestAction(t *testing.T) {
assert.Contains(t, body, "/_action/") assert.Contains(t, body, "/_action/")
} }
func TestEventTypes(t *testing.T) {
tests := []struct {
name string
attr string
buildEl func(trigger *actionTrigger) h.H
}{
{"OnSubmit", "data-on:submit", func(tr *actionTrigger) h.H { return h.Form(tr.OnSubmit()) }},
{"OnInput", "data-on:input", func(tr *actionTrigger) h.H { return h.Input(tr.OnInput()) }},
{"OnFocus", "data-on:focus", func(tr *actionTrigger) h.H { return h.Input(tr.OnFocus()) }},
{"OnBlur", "data-on:blur", func(tr *actionTrigger) h.H { return h.Input(tr.OnBlur()) }},
{"OnMouseEnter", "data-on:mouseenter", func(tr *actionTrigger) h.H { return h.Div(tr.OnMouseEnter()) }},
{"OnMouseLeave", "data-on:mouseleave", func(tr *actionTrigger) h.H { return h.Div(tr.OnMouseLeave()) }},
{"OnScroll", "data-on:scroll", func(tr *actionTrigger) h.H { return h.Div(tr.OnScroll()) }},
{"OnDblClick", "data-on:dblclick", func(tr *actionTrigger) h.H { return h.Div(tr.OnDblClick()) }},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var trigger *actionTrigger
v := New()
v.Page("/", func(c *Context) {
trigger = c.Action(func() {})
c.View(func() h.H { return tt.buildEl(trigger) })
})
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
v.mux.ServeHTTP(w, req)
body := w.Body.String()
assert.Contains(t, body, tt.attr)
assert.Contains(t, body, "/_action/"+trigger.id)
})
}
t.Run("WithSignal", func(t *testing.T) {
var trigger *actionTrigger
var sig *signal
v := New()
v.Page("/", func(c *Context) {
trigger = c.Action(func() {})
sig = c.Signal("val")
c.View(func() h.H {
return h.Div(trigger.OnDblClick(WithSignal(sig, "x")))
})
})
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
v.mux.ServeHTTP(w, req)
body := w.Body.String()
assert.Contains(t, body, "data-on:dblclick")
assert.Contains(t, body, "$"+sig.ID()+"='x'")
})
}
func TestOnKeyDownWithWindow(t *testing.T) { func TestOnKeyDownWithWindow(t *testing.T) {
var trigger *actionTrigger var trigger *actionTrigger
v := New() v := New()