diff --git a/configuration.go b/configuration.go index 2358c7a..eac0bcc 100644 --- a/configuration.go +++ b/configuration.go @@ -1,6 +1,8 @@ package via import ( + "time" + "github.com/alexedwards/scs/v2" "github.com/rs/zerolog" ) @@ -54,4 +56,9 @@ type Options struct { // PubSub enables publish/subscribe messaging. Use vianats.New() for an // embedded NATS backend, or supply any PubSub implementation. PubSub PubSub + + // ContextTTL is the maximum time a context may exist without an SSE + // connection before the background reaper disposes it. + // Default: 30s. Negative value disables the reaper. + ContextTTL time.Duration } diff --git a/context.go b/context.go index 102ecd5..a1a0acc 100644 --- a/context.go +++ b/context.go @@ -8,6 +8,7 @@ import ( "maps" "reflect" "sync" + "sync/atomic" "time" "github.com/ryanhamamura/via/h" @@ -33,6 +34,8 @@ type Context struct { subscriptions []Subscription subsMu sync.Mutex disposeOnce sync.Once + createdAt time.Time + sseConnected atomic.Bool } // View defines the UI rendered by this context. @@ -481,5 +484,6 @@ func newContext(id string, route string, v *V) *Context { signals: new(sync.Map), patchChan: make(chan patch, 1), ctxDisposedChan: make(chan struct{}, 1), + createdAt: time.Now(), } } diff --git a/via.go b/via.go index b3af7d7..f960249 100644 --- a/via.go +++ b/via.go @@ -36,20 +36,21 @@ var datastarJS []byte // V is the root application. // It manages page routing, user sessions, and SSE connections for live updates. type V struct { - cfg Options - mux *http.ServeMux - server *http.Server - logger zerolog.Logger - contextRegistry map[string]*Context - contextRegistryMutex sync.RWMutex - documentHeadIncludes []h.H - documentFootIncludes []h.H - devModePageInitFnMap map[string]func(*Context) - sessionManager *scs.SessionManager - pubsub PubSub - datastarPath string - datastarContent []byte - datastarOnce sync.Once + cfg Options + mux *http.ServeMux + server *http.Server + logger zerolog.Logger + contextRegistry map[string]*Context + contextRegistryMutex sync.RWMutex + documentHeadIncludes []h.H + documentFootIncludes []h.H + devModePageInitFnMap map[string]func(*Context) + sessionManager *scs.SessionManager + pubsub PubSub + datastarPath string + datastarContent []byte + datastarOnce sync.Once + reaperStop chan struct{} } func (v *V) logEvent(evt *zerolog.Event, c *Context) *zerolog.Event { @@ -127,6 +128,9 @@ func (v *V) Config(cfg Options) { if cfg.PubSub != nil { v.pubsub = cfg.PubSub } + if cfg.ContextTTL != 0 { + v.cfg.ContextTTL = cfg.ContextTTL + } } // AppendToHead appends the given h.H nodes to the head of the base HTML document. @@ -238,6 +242,14 @@ func (v *V) currSessionNum() int { return len(v.contextRegistry) } +func (v *V) cleanupCtx(c *Context) { + c.dispose() + if v.cfg.DevMode { + v.devModeRemovePersisted(c) + } + v.unregisterCtx(c) +} + func (v *V) unregisterCtx(c *Context) { if c.id == "" { v.logErr(c, "unregister ctx failed: ctx contains empty id") @@ -259,6 +271,50 @@ func (v *V) getCtx(id string) (*Context, error) { return nil, fmt.Errorf("ctx '%s' not found", id) } +func (v *V) startReaper() { + ttl := v.cfg.ContextTTL + if ttl < 0 { + return + } + if ttl == 0 { + ttl = 30 * time.Second + } + interval := ttl / 3 + if interval < 5*time.Second { + interval = 5 * time.Second + } + v.reaperStop = make(chan struct{}) + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-v.reaperStop: + return + case <-ticker.C: + v.reapOrphanedContexts(ttl) + } + } + }() +} + +func (v *V) reapOrphanedContexts(ttl time.Duration) { + now := time.Now() + v.contextRegistryMutex.RLock() + var orphans []*Context + for _, c := range v.contextRegistry { + if !c.sseConnected.Load() && now.Sub(c.createdAt) > ttl { + orphans = append(orphans, c) + } + } + v.contextRegistryMutex.RUnlock() + + for _, c := range orphans { + v.logInfo(c, "reaping orphaned context (no SSE connection after %s)", ttl) + v.cleanupCtx(c) + } +} + // Start starts the Via HTTP server and blocks until a SIGINT or SIGTERM // signal is received, then performs a graceful shutdown. func (v *V) Start() { @@ -271,6 +327,8 @@ func (v *V) Start() { Handler: handler, } + v.startReaper() + errCh := make(chan error, 1) go func() { errCh <- v.server.ListenAndServe() @@ -301,6 +359,9 @@ func (v *V) Shutdown() { } func (v *V) shutdown() { + if v.reaperStop != nil { + close(v.reaperStop) + } v.logInfo(nil, "draining all contexts") v.drainAllContexts() @@ -400,10 +461,7 @@ func (v *V) devModeRemovePersisted(c *Context) { } file.Close() - // remove ctx to persisted list - if _, ok := ctxRegMap[c.id]; !ok { - delete(ctxRegMap, c.id) - } + delete(ctxRegMap, c.id) // write persisted list to file file, err = os.Create(p) @@ -507,6 +565,7 @@ func New() *V { // use last-event-id to tell if request is a sse reconnect sse.Send(datastar.EventTypePatchElements, []string{}, datastar.WithSSEEventId("via")) + c.sseConnected.Store(true) v.logDebug(c, "SSE connection established") go func() { @@ -517,6 +576,7 @@ func New() *V { select { case <-sse.Context().Done(): v.logDebug(c, "SSE connection ended") + v.cleanupCtx(c) return case <-c.ctxDisposedChan: v.logDebug(c, "context disposed, closing SSE") @@ -603,12 +663,8 @@ func New() *V { v.logErr(c, "failed to handle session close: %v", err) return } - c.dispose() v.logDebug(c, "session close event triggered") - if v.cfg.DevMode { - v.devModeRemovePersisted(c) - } - v.unregisterCtx(c) + v.cleanupCtx(c) }) return v } diff --git a/via_test.go b/via_test.go index 4566b1e..4891295 100644 --- a/via_test.go +++ b/via_test.go @@ -1,9 +1,13 @@ package via import ( + "encoding/json" "net/http" "net/http/httptest" + "os" + "path/filepath" "testing" + "time" "github.com/ryanhamamura/via/h" "github.com/stretchr/testify/assert" @@ -235,3 +239,93 @@ func TestPage_PanicsOnNoView(t *testing.T) { v.Page("/", func(c *Context) {}) }) } + +func TestReaperCleansOrphanedContexts(t *testing.T) { + v := New() + c := newContext("orphan-1", "/", v) + c.createdAt = time.Now().Add(-time.Minute) // created 1 min ago + v.registerCtx(c) + + _, err := v.getCtx("orphan-1") + assert.NoError(t, err) + + v.reapOrphanedContexts(10 * time.Second) + + _, err = v.getCtx("orphan-1") + assert.Error(t, err, "orphaned context should have been reaped") +} + +func TestReaperIgnoresConnectedContexts(t *testing.T) { + v := New() + c := newContext("connected-1", "/", v) + c.createdAt = time.Now().Add(-time.Minute) + c.sseConnected.Store(true) + v.registerCtx(c) + + v.reapOrphanedContexts(10 * time.Second) + + _, err := v.getCtx("connected-1") + assert.NoError(t, err, "connected context should survive reaping") +} + +func TestReaperDisabledWithNegativeTTL(t *testing.T) { + v := New() + v.cfg.ContextTTL = -1 + v.startReaper() + assert.Nil(t, v.reaperStop, "reaper should not start with negative TTL") +} + +func TestCleanupCtxIdempotent(t *testing.T) { + v := New() + c := newContext("idempotent-1", "/", v) + v.registerCtx(c) + + assert.NotPanics(t, func() { + v.cleanupCtx(c) + v.cleanupCtx(c) + }) + + _, err := v.getCtx("idempotent-1") + assert.Error(t, err, "context should be removed after cleanup") +} + +func TestDevModeRemovePersistedFix(t *testing.T) { + v := New() + v.cfg.DevMode = true + + dir := filepath.Join(t.TempDir(), ".via", "devmode") + p := filepath.Join(dir, "ctx.json") + assert.NoError(t, os.MkdirAll(dir, 0755)) + + // Write a persisted context + ctxRegMap := map[string]string{"test-ctx-1": "/"} + f, err := os.Create(p) + assert.NoError(t, err) + assert.NoError(t, json.NewEncoder(f).Encode(ctxRegMap)) + f.Close() + + // Patch devModeRemovePersisted to use our temp path by calling it + // directly — we need to override the path. Instead, test via the + // actual function by temporarily changing the working dir. + origDir, _ := os.Getwd() + assert.NoError(t, os.Chdir(t.TempDir())) + defer os.Chdir(origDir) + + // Re-create the structure in the temp dir + assert.NoError(t, os.MkdirAll(filepath.Join(".via", "devmode"), 0755)) + p2 := filepath.Join(".via", "devmode", "ctx.json") + f2, _ := os.Create(p2) + json.NewEncoder(f2).Encode(map[string]string{"test-ctx-1": "/"}) + f2.Close() + + c := newContext("test-ctx-1", "/", v) + v.devModeRemovePersisted(c) + + // Read back and verify + f3, err := os.Open(p2) + assert.NoError(t, err) + defer f3.Close() + var result map[string]string + assert.NoError(t, json.NewDecoder(f3).Decode(&result)) + assert.Empty(t, result, "persisted context should be removed") +}