From 9a2318897312e5b2b6344afbb4bf0840a77fca98 Mon Sep 17 00:00:00 2001 From: ryanhamamura <58859899+ryanhamamura@users.noreply.github.com> Date: Fri, 9 Jan 2026 06:59:26 -1000 Subject: [PATCH] feat: add cookie-based session support using alexedwards/scs (#1) - Add Session wrapper with typed getters (GetString, GetInt, GetBool, etc.) - Add flash message support via Pop methods (PopString, PopInt, etc.) - Add session utilities: Exists, Keys, ID, Clear, Destroy, RenewToken - Create default session manager in New() for zero-config usage - Allow custom session manager via Options.SessionManager - Wrap mux with scs LoadAndSave middleware in Start() - Add session example demonstrating login/logout with flash messages --- configuration.go | 7 ++ context.go | 12 ++ go.mod | 1 + go.sum | 2 + internal/examples/session/main.go | 57 +++++++++ session.go | 191 ++++++++++++++++++++++++++++++ via.go | 15 ++- 7 files changed, 284 insertions(+), 1 deletion(-) create mode 100644 internal/examples/session/main.go create mode 100644 session.go diff --git a/configuration.go b/configuration.go index 8316344..eb9e4ac 100644 --- a/configuration.go +++ b/configuration.go @@ -1,5 +1,7 @@ package via +import "github.com/alexedwards/scs/v2" + type LogLevel int const ( @@ -30,4 +32,9 @@ type Options struct { // Plugins to extend the capabilities of the `Via` application. Plugins []Plugin + + // SessionManager enables cookie-based sessions. If set, Via wraps handlers + // with scs LoadAndSave middleware. Configure the session manager before + // passing it (lifetime, cookie settings, store, etc). + SessionManager *scs.SessionManager } diff --git a/context.go b/context.go index 498d9e0..d558f8d 100644 --- a/context.go +++ b/context.go @@ -2,6 +2,7 @@ package via import ( "bytes" + "context" "encoding/json" "fmt" "log" @@ -29,6 +30,7 @@ type Context struct { signals *sync.Map mu sync.RWMutex ctxDisposedChan chan struct{} + reqCtx context.Context } // View defines the UI rendered by this context. @@ -360,6 +362,16 @@ func (c *Context) GetPathParam(param string) string { return "" } +// Session returns the session for this context. +// Session data persists across page views for the same browser. +// Returns a no-op session if no SessionManager is configured. +func (c *Context) Session() *Session { + return &Session{ + ctx: c.reqCtx, + manager: c.app.sessionManager, + } +} + func newContext(id string, route string, v *V) *Context { if v == nil { log.Fatal("create context failed: app pointer is nil") diff --git a/go.mod b/go.mod index 28f2ab6..771232e 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( require ( github.com/CAFxX/httpcompression v0.0.9 // indirect + github.com/alexedwards/scs/v2 v2.9.0 // indirect github.com/andybalholm/brotli v1.2.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/klauspost/compress v1.18.0 // indirect diff --git a/go.sum b/go.sum index fb49a09..5913295 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ github.com/CAFxX/httpcompression v0.0.9 h1:0ue2X8dOLEpxTm8tt+OdHcgA+gbDge0OqFQWG github.com/CAFxX/httpcompression v0.0.9/go.mod h1:XX8oPZA+4IDcfZ0A71Hz0mZsv/YJOgYygkFhizVPilM= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= +github.com/alexedwards/scs/v2 v2.9.0 h1:xa05mVpwTBm1iLeTMNFfAWpKUm4fXAW7CeAViqBVS90= +github.com/alexedwards/scs/v2 v2.9.0/go.mod h1:ToaROZxyKukJKT/xLcVQAChi5k6+Pn1Gvmdl7h3RRj8= github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= diff --git a/internal/examples/session/main.go b/internal/examples/session/main.go new file mode 100644 index 0000000..82e5c0a --- /dev/null +++ b/internal/examples/session/main.go @@ -0,0 +1,57 @@ +package main + +import ( + "github.com/go-via/via" + "github.com/go-via/via/h" +) + +func main() { + v := via.New() + + v.Page("/", func(c *via.Context) { + username := c.Session().GetString("username") + flash := c.Session().PopString("flash") + + usernameInput := c.Signal("") + + login := c.Action(func() { + name := usernameInput.String() + if name != "" { + c.Session().Set("username", name) + c.Session().Set("flash", "Welcome, "+name+"!") + c.Session().RenewToken() + } + c.Sync() + }) + + logout := c.Action(func() { + c.Session().Set("flash", "Goodbye!") + c.Session().Delete("username") + c.Sync() + }) + + c.View(func() h.H { + var flashMsg h.H + if flash != "" { + flashMsg = h.P(h.Text(flash), h.Style("color: green")) + } + + if username == "" { + return h.Div( + flashMsg, + h.H1(h.Text("Login")), + h.Input(h.Type("text"), h.Placeholder("Username"), usernameInput.Bind()), + h.Button(h.Text("Login"), login.OnClick()), + ) + } + return h.Div( + flashMsg, + h.H1(h.Textf("Hello, %s!", username)), + h.P(h.Text("Your session persists across page refreshes.")), + h.Button(h.Text("Logout"), logout.OnClick()), + ) + }) + }) + + v.Start() +} diff --git a/session.go b/session.go new file mode 100644 index 0000000..2fa1608 --- /dev/null +++ b/session.go @@ -0,0 +1,191 @@ +package via + +import ( + "context" + "time" + + "github.com/alexedwards/scs/v2" +) + +// Session provides access to the user's session data. +// Session data persists across page views for the same browser. +type Session struct { + ctx context.Context + manager *scs.SessionManager +} + +// Get retrieves a value from the session. +func (s *Session) Get(key string) any { + if s.manager == nil || s.ctx == nil { + return nil + } + return s.manager.Get(s.ctx, key) +} + +// GetString retrieves a string value from the session. +func (s *Session) GetString(key string) string { + if s.manager == nil || s.ctx == nil { + return "" + } + return s.manager.GetString(s.ctx, key) +} + +// GetInt retrieves an int value from the session. +func (s *Session) GetInt(key string) int { + if s.manager == nil || s.ctx == nil { + return 0 + } + return s.manager.GetInt(s.ctx, key) +} + +// GetBool retrieves a bool value from the session. +func (s *Session) GetBool(key string) bool { + if s.manager == nil || s.ctx == nil { + return false + } + return s.manager.GetBool(s.ctx, key) +} + +// Set stores a value in the session. +func (s *Session) Set(key string, val any) { + if s.manager == nil || s.ctx == nil { + return + } + s.manager.Put(s.ctx, key, val) +} + +// Delete removes a value from the session. +func (s *Session) Delete(key string) { + if s.manager == nil || s.ctx == nil { + return + } + s.manager.Remove(s.ctx, key) +} + +// Clear removes all data from the session. +func (s *Session) Clear() error { + if s.manager == nil || s.ctx == nil { + return nil + } + return s.manager.Clear(s.ctx) +} + +// Destroy destroys the session entirely (use for logout). +func (s *Session) Destroy() error { + if s.manager == nil || s.ctx == nil { + return nil + } + return s.manager.Destroy(s.ctx) +} + +// RenewToken regenerates the session token (use after login to prevent session fixation). +func (s *Session) RenewToken() error { + if s.manager == nil || s.ctx == nil { + return nil + } + return s.manager.RenewToken(s.ctx) +} + +// Exists returns true if the key exists in the session. +func (s *Session) Exists(key string) bool { + if s.manager == nil || s.ctx == nil { + return false + } + return s.manager.Exists(s.ctx, key) +} + +// Keys returns all keys in the session. +func (s *Session) Keys() []string { + if s.manager == nil || s.ctx == nil { + return nil + } + return s.manager.Keys(s.ctx) +} + +// ID returns the session token (cookie value). +func (s *Session) ID() string { + if s.manager == nil || s.ctx == nil { + return "" + } + return s.manager.Token(s.ctx) +} + +// Pop retrieves a value and deletes it from the session (flash message pattern). +func (s *Session) Pop(key string) any { + if s.manager == nil || s.ctx == nil { + return nil + } + return s.manager.Pop(s.ctx, key) +} + +// PopString retrieves a string value and deletes it from the session. +func (s *Session) PopString(key string) string { + if s.manager == nil || s.ctx == nil { + return "" + } + return s.manager.PopString(s.ctx, key) +} + +// PopInt retrieves an int value and deletes it from the session. +func (s *Session) PopInt(key string) int { + if s.manager == nil || s.ctx == nil { + return 0 + } + return s.manager.PopInt(s.ctx, key) +} + +// PopBool retrieves a bool value and deletes it from the session. +func (s *Session) PopBool(key string) bool { + if s.manager == nil || s.ctx == nil { + return false + } + return s.manager.PopBool(s.ctx, key) +} + +// GetFloat64 retrieves a float64 value from the session. +func (s *Session) GetFloat64(key string) float64 { + if s.manager == nil || s.ctx == nil { + return 0 + } + return s.manager.GetFloat(s.ctx, key) +} + +// PopFloat64 retrieves a float64 value and deletes it from the session. +func (s *Session) PopFloat64(key string) float64 { + if s.manager == nil || s.ctx == nil { + return 0 + } + return s.manager.PopFloat(s.ctx, key) +} + +// GetTime retrieves a time.Time value from the session. +func (s *Session) GetTime(key string) time.Time { + if s.manager == nil || s.ctx == nil { + return time.Time{} + } + return s.manager.GetTime(s.ctx, key) +} + +// PopTime retrieves a time.Time value and deletes it from the session. +func (s *Session) PopTime(key string) time.Time { + if s.manager == nil || s.ctx == nil { + return time.Time{} + } + return s.manager.PopTime(s.ctx, key) +} + +// GetBytes retrieves a []byte value from the session. +func (s *Session) GetBytes(key string) []byte { + if s.manager == nil || s.ctx == nil { + return nil + } + return s.manager.GetBytes(s.ctx, key) +} + +// PopBytes retrieves a []byte value and deletes it from the session. +func (s *Session) PopBytes(key string) []byte { + if s.manager == nil || s.ctx == nil { + return nil + } + return s.manager.PopBytes(s.ctx, key) +} diff --git a/via.go b/via.go index caf46fe..5517a8d 100644 --- a/via.go +++ b/via.go @@ -20,6 +20,7 @@ import ( "strings" "sync" + "github.com/alexedwards/scs/v2" "github.com/go-via/via/h" "github.com/starfederation/datastar-go/datastar" ) @@ -37,6 +38,7 @@ type V struct { documentHeadIncludes []h.H documentFootIncludes []h.H devModePageInitFnMap map[string]func(*Context) + sessionManager *scs.SessionManager } func (v *V) logFatal(format string, a ...any) { @@ -102,6 +104,9 @@ func (v *V) Config(cfg Options) { if cfg.ServerAddress != "" { v.cfg.ServerAddress = cfg.ServerAddress } + if cfg.SessionManager != nil { + v.sessionManager = cfg.SessionManager + } } // AppendToHead appends the given h.H nodes to the head of the base HTML document. @@ -162,6 +167,7 @@ func (v *V) Page(route string, initContextFn func(c *Context)) { } id := fmt.Sprintf("%s_/%s", route, genRandID()) c := newContext(id, route, v) + c.reqCtx = r.Context() routeParams := extractParams(route, r.URL.Path) c.injectRouteParams(routeParams) initContextFn(c) @@ -235,7 +241,11 @@ func (v *V) getCtx(id string) (*Context, error) { // Start starts the Via HTTP server on the given address. func (v *V) Start() { v.logInfo(nil, "via started at [%s]", v.cfg.ServerAddress) - log.Fatalf("[fatal] %v", http.ListenAndServe(v.cfg.ServerAddress, v.mux)) + handler := http.Handler(v.mux) + if v.sessionManager != nil { + handler = v.sessionManager.LoadAndSave(v.mux) + } + log.Fatalf("[fatal] %v", http.ListenAndServe(v.cfg.ServerAddress, handler)) } // HTTPServeMux returns the underlying HTTP request multiplexer to enable user extentions, middleware and @@ -364,6 +374,7 @@ func New() *V { mux: mux, contextRegistry: make(map[string]*Context), devModePageInitFnMap: make(map[string]func(*Context)), + sessionManager: scs.New(), cfg: Options{ DevMode: false, ServerAddress: ":3000", @@ -396,6 +407,7 @@ func New() *V { v.logErr(nil, "sse stream failed to start: %v", err) return } + c.reqCtx = r.Context() sse := datastar.NewSSE(w, r, datastar.WithCompression(datastar.WithBrotli(datastar.WithBrotliLevel(5)))) @@ -456,6 +468,7 @@ func New() *V { v.logErr(nil, "action '%s' failed: %v", actionID, err) return } + c.reqCtx = r.Context() actionFn, err := c.getActionFn(actionID) if err != nil { v.logDebug(c, "action '%s' failed: %v", actionID, err)