diff --git a/context.go b/context.go index 934aca9..102ecd5 100644 --- a/context.go +++ b/context.go @@ -32,6 +32,7 @@ type Context struct { reqCtx context.Context subscriptions []Subscription subsMu sync.Mutex + disposeOnce sync.Once } // View defines the UI rendered by this context. @@ -350,11 +351,23 @@ func (c *Context) ReplaceURLf(format string, a ...any) { c.ReplaceURL(fmt.Sprintf(format, a...)) } -// stopAllRoutines stops all go routines tied to this Context preventing goroutine leaks. +// dispose idempotently tears down this context: unsubscribes all pubsub +// subscriptions and closes ctxDisposedChan to stop routines and exit the SSE loop. +func (c *Context) dispose() { + c.disposeOnce.Do(func() { + c.unsubscribeAll() + c.stopAllRoutines() + }) +} + +// stopAllRoutines closes ctxDisposedChan, broadcasting to all listening +// goroutines (OnIntervalRoutine, SSE loop) that this context is done. func (c *Context) stopAllRoutines() { select { - case c.ctxDisposedChan <- struct{}{}: + case <-c.ctxDisposedChan: + // already closed default: + close(c.ctxDisposedChan) } } diff --git a/via.go b/via.go index ee488e3..b3af7d7 100644 --- a/via.go +++ b/via.go @@ -7,6 +7,7 @@ package via import ( + "context" "crypto/rand" _ "embed" "encoding/hex" @@ -16,9 +17,12 @@ import ( "net/http" "net/url" "os" + ossignal "os/signal" "path/filepath" "strings" "sync" + "syscall" + "time" "github.com/alexedwards/scs/v2" "github.com/rs/zerolog" @@ -34,6 +38,7 @@ var datastarJS []byte type V struct { cfg Options mux *http.ServeMux + server *http.Server logger zerolog.Logger contextRegistry map[string]*Context contextRegistryMutex sync.RWMutex @@ -254,14 +259,82 @@ func (v *V) getCtx(id string) (*Context, error) { return nil, fmt.Errorf("ctx '%s' not found", id) } -// Start starts the Via HTTP server on the given address. +// Start starts the Via HTTP server and blocks until a SIGINT or SIGTERM +// signal is received, then performs a graceful shutdown. func (v *V) Start() { - v.logInfo(nil, "via started at [%s]", v.cfg.ServerAddress) handler := http.Handler(v.mux) if v.sessionManager != nil { handler = v.sessionManager.LoadAndSave(v.mux) } - v.logger.Fatal().Err(http.ListenAndServe(v.cfg.ServerAddress, handler)).Msg("http server failed") + v.server = &http.Server{ + Addr: v.cfg.ServerAddress, + Handler: handler, + } + + errCh := make(chan error, 1) + go func() { + errCh <- v.server.ListenAndServe() + }() + + v.logInfo(nil, "via started at [%s]", v.cfg.ServerAddress) + + sigCh := make(chan os.Signal, 1) + ossignal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + + select { + case sig := <-sigCh: + v.logInfo(nil, "received signal %v, shutting down", sig) + case err := <-errCh: + if err != nil && err != http.ErrServerClosed { + v.logger.Fatal().Err(err).Msg("http server failed") + } + return + } + + v.shutdown() +} + +// Shutdown gracefully shuts down the server and all contexts. +// Safe for programmatic or test use. +func (v *V) Shutdown() { + v.shutdown() +} + +func (v *V) shutdown() { + v.logInfo(nil, "draining all contexts") + v.drainAllContexts() + + if v.server != nil { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := v.server.Shutdown(ctx); err != nil { + v.logErr(nil, "http server shutdown error: %v", err) + } + } + + if v.pubsub != nil { + if err := v.pubsub.Close(); err != nil { + v.logErr(nil, "pubsub close error: %v", err) + } + } + + v.logInfo(nil, "shutdown complete") +} + +func (v *V) drainAllContexts() { + v.contextRegistryMutex.Lock() + contexts := make([]*Context, 0, len(v.contextRegistry)) + for _, c := range v.contextRegistry { + contexts = append(contexts, c) + } + v.contextRegistry = make(map[string]*Context) + v.contextRegistryMutex.Unlock() + + for _, c := range contexts { + v.logDebug(c, "disposing context") + c.dispose() + } + v.logInfo(nil, "drained %d context(s)", len(contexts)) } // HTTPServeMux returns the underlying HTTP request multiplexer to enable user extentions, middleware and @@ -445,10 +518,10 @@ func New() *V { case <-sse.Context().Done(): v.logDebug(c, "SSE connection ended") return - case patch, ok := <-c.patchChan: - if !ok { - continue - } + case <-c.ctxDisposedChan: + v.logDebug(c, "context disposed, closing SSE") + return + case patch := <-c.patchChan: switch patch.typ { case patchTypeElements: if err := sse.PatchElements(patch.content); err != nil { @@ -530,8 +603,7 @@ func New() *V { v.logErr(c, "failed to handle session close: %v", err) return } - c.unsubscribeAll() - c.stopAllRoutines() + c.dispose() v.logDebug(c, "session close event triggered") if v.cfg.DevMode { v.devModeRemovePersisted(c)