1 Commits

Author SHA1 Message Date
Ryan Hamamura
b7acfa6302 feat: add automatic CSRF protection for action calls
Generate a per-context CSRF token (128-bit, crypto/rand) and inject it
as a Datastar signal (via-csrf) alongside via-ctx. Validate with
constant-time comparison on /_action/{id} before executing, returning
403 on mismatch. Transparent to users since Datastar sends all signals
automatically.

Closes #9
2026-02-06 11:17:41 -10:00
2 changed files with 16 additions and 1 deletions

View File

@@ -20,6 +20,7 @@ import (
type Context struct {
id string
route string
csrfToken string
app *V
view func() h.H
routeParams map[string]string
@@ -477,6 +478,7 @@ func newContext(id string, route string, v *V) *Context {
return &Context{
id: id,
route: route,
csrfToken: genCSRFToken(),
routeParams: make(map[string]string),
app: v,
componentRegistry: make(map[string]*Context),

15
via.go
View File

@@ -10,6 +10,7 @@ import (
"context"
"crypto/rand"
_ "embed"
"crypto/subtle"
"encoding/hex"
"encoding/json"
"fmt"
@@ -203,7 +204,7 @@ func (v *V) Page(route string, initContextFn func(c *Context)) {
headElements := []h.H{h.Script(h.Type("module"), h.Src(v.datastarPath))}
headElements = append(headElements, v.documentHeadIncludes...)
headElements = append(headElements,
h.Meta(h.Data("signals", fmt.Sprintf("{'via-ctx':'%s'}", id))),
h.Meta(h.Data("signals", fmt.Sprintf("{'via-ctx':'%s','via-csrf':'%s'}", id, c.csrfToken))),
h.Meta(h.Data("init", "@get('/_sse')")),
h.Meta(h.Data("init", fmt.Sprintf(`window.addEventListener('beforeunload', (evt) => {
navigator.sendBeacon('/_session/close', '%s');});`, c.id))),
@@ -632,6 +633,12 @@ func New() *V {
v.logErr(nil, "action '%s' failed: %v", actionID, err)
return
}
csrfToken, _ := sigs["via-csrf"].(string)
if subtle.ConstantTimeCompare([]byte(csrfToken), []byte(c.csrfToken)) != 1 {
v.logWarn(c, "action '%s' rejected: invalid CSRF token", actionID)
http.Error(w, "invalid CSRF token", http.StatusForbidden)
return
}
c.reqCtx = r.Context()
actionFn, err := c.getActionFn(actionID)
if err != nil {
@@ -675,6 +682,12 @@ func genRandID() string {
return hex.EncodeToString(b)[:8]
}
func genCSRFToken() string {
b := make([]byte, 16)
rand.Read(b)
return hex.EncodeToString(b)
}
func extractParams(pattern, path string) map[string]string {
p := strings.Split(strings.Trim(pattern, "/"), "/")
u := strings.Split(strings.Trim(path, "/"), "/")