diff --git a/go.mod b/go.mod index fe3c946..8c6de22 100644 --- a/go.mod +++ b/go.mod @@ -38,6 +38,5 @@ require ( github.com/valyala/bytebufferpool v1.0.0 // indirect golang.org/x/crypto v0.45.0 // indirect golang.org/x/sys v0.38.0 // indirect - golang.org/x/time v0.14.0 gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/internal/examples/nats-chatroom/main.go b/internal/examples/nats-chatroom/main.go index e632b83..6abf378 100644 --- a/internal/examples/nats-chatroom/main.go +++ b/internal/examples/nats-chatroom/main.go @@ -1,7 +1,6 @@ package main import ( - "context" "log" "math/rand" "sync" @@ -9,7 +8,6 @@ import ( "github.com/ryanhamamura/via" "github.com/ryanhamamura/via/h" - "github.com/ryanhamamura/via/vianats" ) var ( @@ -36,15 +34,15 @@ func (u *UserInfo) Avatar() h.H { var roomNames = []string{"Go", "Rust", "Python", "JavaScript", "Clojure"} func main() { - ctx := context.Background() + v := via.New() + v.Config(via.Options{ + DevMode: true, + DocumentTitle: "NATS Chat", + LogLevel: via.LogLevelInfo, + ServerAddress: ":7331", + }) - ps, err := vianats.New(ctx, "./data/nats") - if err != nil { - log.Fatalf("Failed to start embedded NATS: %v", err) - } - defer ps.Close() - - err = vianats.EnsureStream(ps, vianats.StreamConfig{ + err := via.EnsureStream(v, via.StreamConfig{ Name: "CHAT", Subjects: []string{"chat.>"}, MaxMsgs: 1000, @@ -54,15 +52,6 @@ func main() { log.Fatalf("Failed to ensure stream: %v", err) } - v := via.New() - v.Config(via.Options{ - DevMode: true, - DocumentTitle: "NATS Chat", - LogLevel: via.LogLevelInfo, - ServerAddress: ":7331", - PubSub: ps, - }) - v.AppendToHead( h.Link(h.Rel("stylesheet"), h.Href("https://cdn.jsdelivr.net/npm/@picocss/pico@2/css/pico.min.css")), h.StyleEl(h.Raw(` @@ -148,7 +137,7 @@ func main() { subject := "chat.room." + room // Replay history from JetStream - if hist, err := vianats.ReplayHistory[ChatMessage](ps, subject, 50); err == nil { + if hist, err := via.ReplayHistory[ChatMessage](v, subject, 50); err == nil { messages = hist } diff --git a/internal/examples/pubsub-crud/main.go b/internal/examples/pubsub-crud/main.go index df5b1c5..3c6fa57 100644 --- a/internal/examples/pubsub-crud/main.go +++ b/internal/examples/pubsub-crud/main.go @@ -1,7 +1,6 @@ package main import ( - "context" "crypto/rand" "fmt" "html" @@ -11,7 +10,6 @@ import ( "github.com/ryanhamamura/via" "github.com/ryanhamamura/via/h" - "github.com/ryanhamamura/via/vianats" ) var WithSignal = via.WithSignal @@ -49,15 +47,15 @@ func findBookmark(id string) (Bookmark, int) { } func main() { - ctx := context.Background() + v := via.New() + v.Config(via.Options{ + DevMode: true, + DocumentTitle: "Bookmarks", + LogLevel: via.LogLevelInfo, + ServerAddress: ":7331", + }) - ps, err := vianats.New(ctx, "./data/nats") - if err != nil { - log.Fatalf("Failed to start embedded NATS: %v", err) - } - defer ps.Close() - - err = vianats.EnsureStream(ps, vianats.StreamConfig{ + err := via.EnsureStream(v, via.StreamConfig{ Name: "BOOKMARKS", Subjects: []string{"bookmarks.>"}, MaxMsgs: 1000, @@ -67,15 +65,6 @@ func main() { log.Fatalf("Failed to ensure stream: %v", err) } - v := via.New() - v.Config(via.Options{ - DevMode: true, - DocumentTitle: "Bookmarks", - LogLevel: via.LogLevelInfo, - ServerAddress: ":7331", - PubSub: ps, - }) - v.AppendToHead( h.Link(h.Rel("stylesheet"), h.Href("https://cdn.jsdelivr.net/npm/daisyui@4/dist/full.min.css")), h.Script(h.Src("https://cdn.tailwindcss.com")), diff --git a/nats.go b/nats.go new file mode 100644 index 0000000..e22b850 --- /dev/null +++ b/nats.go @@ -0,0 +1,190 @@ +package via + +import ( + "context" + "encoding/json" + "fmt" + "os" + "sync" + "time" + + "github.com/delaneyj/toolbelt/embeddednats" + "github.com/nats-io/nats.go" +) + +// defaultNATS is the process-scoped embedded NATS server. +type defaultNATS struct { + server *embeddednats.Server + nc *nats.Conn + js nats.JetStreamContext + cancel context.CancelFunc + dataDir string +} + +var ( + sharedNATS *defaultNATS + sharedNATSOnce sync.Once + sharedNATSErr error +) + +// getSharedNATS returns a process-level singleton embedded NATS server. +// The server starts once and is reused across all V instances. +func getSharedNATS() (*defaultNATS, error) { + sharedNATSOnce.Do(func() { + sharedNATS, sharedNATSErr = startDefaultNATS() + }) + return sharedNATS, sharedNATSErr +} + +func startDefaultNATS() (dn *defaultNATS, err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("nats server panic: %v", r) + } + }() + + dataDir, err := os.MkdirTemp("", "via-nats-*") + if err != nil { + return nil, fmt.Errorf("create temp dir: %w", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + + ns, err := embeddednats.New(ctx, embeddednats.WithDirectory(dataDir)) + if err != nil { + cancel() + os.RemoveAll(dataDir) + return nil, fmt.Errorf("start embedded nats: %w", err) + } + ns.WaitForServer() + + nc, err := ns.Client() + if err != nil { + ns.Close() + cancel() + os.RemoveAll(dataDir) + return nil, fmt.Errorf("connect nats client: %w", err) + } + + js, err := nc.JetStream() + if err != nil { + nc.Close() + ns.Close() + cancel() + os.RemoveAll(dataDir) + return nil, fmt.Errorf("init jetstream: %w", err) + } + + return &defaultNATS{ + server: ns, + nc: nc, + js: js, + cancel: cancel, + dataDir: dataDir, + }, nil +} + +func (n *defaultNATS) Publish(subject string, data []byte) error { + return n.nc.Publish(subject, data) +} + +func (n *defaultNATS) Subscribe(subject string, handler func(data []byte)) (Subscription, error) { + sub, err := n.nc.Subscribe(subject, func(msg *nats.Msg) { + handler(msg.Data) + }) + if err != nil { + return nil, err + } + return sub, nil +} + +// natsRef wraps a shared defaultNATS as a PubSub. Close is a no-op because +// the underlying server is process-scoped and outlives individual V instances. +type natsRef struct { + dn *defaultNATS +} + +func (r *natsRef) Publish(subject string, data []byte) error { + return r.dn.Publish(subject, data) +} + +func (r *natsRef) Subscribe(subject string, handler func(data []byte)) (Subscription, error) { + return r.dn.Subscribe(subject, handler) +} + +func (r *natsRef) Close() error { + return nil +} + +// NATSConn returns the underlying NATS connection from the built-in embedded +// server, or nil if a custom PubSub backend is in use. +func (v *V) NATSConn() *nats.Conn { + if v.defaultNATS != nil { + return v.defaultNATS.nc + } + return nil +} + +// JetStream returns the JetStream context from the built-in embedded server, +// or nil if a custom PubSub backend is in use. +func (v *V) JetStream() nats.JetStreamContext { + if v.defaultNATS != nil { + return v.defaultNATS.js + } + return nil +} + +// StreamConfig holds the parameters for creating or updating a JetStream stream. +type StreamConfig struct { + Name string + Subjects []string + MaxMsgs int64 + MaxAge time.Duration +} + +// EnsureStream creates or updates a JetStream stream matching cfg. +func EnsureStream(v *V, cfg StreamConfig) error { + js := v.JetStream() + if js == nil { + return fmt.Errorf("jetstream not available") + } + _, err := js.AddStream(&nats.StreamConfig{ + Name: cfg.Name, + Subjects: cfg.Subjects, + Retention: nats.LimitsPolicy, + MaxMsgs: cfg.MaxMsgs, + MaxAge: cfg.MaxAge, + }) + return err +} + +// ReplayHistory fetches the last limit messages from subject, +// deserializing each as T. Returns an empty slice if nothing is available. +func ReplayHistory[T any](v *V, subject string, limit int) ([]T, error) { + js := v.JetStream() + if js == nil { + return nil, fmt.Errorf("jetstream not available") + } + sub, err := js.SubscribeSync(subject, nats.DeliverAll(), nats.OrderedConsumer()) + if err != nil { + return nil, err + } + defer sub.Unsubscribe() + + var msgs []T + for { + raw, err := sub.NextMsg(200 * time.Millisecond) + if err != nil { + break + } + var msg T + if json.Unmarshal(raw.Data, &msg) == nil { + msgs = append(msgs, msg) + } + } + + if limit > 0 && len(msgs) > limit { + msgs = msgs[len(msgs)-limit:] + } + return msgs, nil +} diff --git a/nats_test.go b/nats_test.go index 87f486d..bb44a2f 100644 --- a/nats_test.go +++ b/nats_test.go @@ -2,7 +2,6 @@ package via import ( "sync" - "sync/atomic" "testing" "time" @@ -11,88 +10,36 @@ import ( "github.com/stretchr/testify/require" ) -type mockHandler struct { - id int64 - fn func([]byte) - active atomic.Bool -} - -// mockPubSub implements PubSub for testing without NATS. -type mockPubSub struct { - mu sync.Mutex - subs map[string][]*mockHandler - nextID atomic.Int64 -} - -func newMockPubSub() *mockPubSub { - return &mockPubSub{subs: make(map[string][]*mockHandler)} -} - -func (m *mockPubSub) Publish(subject string, data []byte) error { - m.mu.Lock() - handlers := make([]*mockHandler, len(m.subs[subject])) - copy(handlers, m.subs[subject]) - m.mu.Unlock() - for _, h := range handlers { - if h.active.Load() { - h.fn(data) - } - } - return nil -} - -func (m *mockPubSub) Subscribe(subject string, handler func(data []byte)) (Subscription, error) { - m.mu.Lock() - defer m.mu.Unlock() - mh := &mockHandler{ - id: m.nextID.Add(1), - fn: handler, - } - mh.active.Store(true) - m.subs[subject] = append(m.subs[subject], mh) - return &mockSub{handler: mh}, nil -} - -func (m *mockPubSub) Close() error { return nil } - -type mockSub struct { - handler *mockHandler -} - -func (s *mockSub) Unsubscribe() error { - s.handler.active.Store(false) - return nil -} - func TestPubSub_RoundTrip(t *testing.T) { - ps := newMockPubSub() v := New() - v.Config(Options{PubSub: ps}) + defer v.Shutdown() var received []byte - var wg sync.WaitGroup - wg.Add(1) + done := make(chan struct{}) c := newContext("test-ctx", "/", v) c.View(func() h.H { return h.Div() }) _, err := c.Subscribe("test.topic", func(data []byte) { received = data - wg.Done() + close(done) }) require.NoError(t, err) err = c.Publish("test.topic", []byte("hello")) require.NoError(t, err) - wg.Wait() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for message") + } assert.Equal(t, []byte("hello"), received) } func TestPubSub_MultipleSubscribers(t *testing.T) { - ps := newMockPubSub() v := New() - v.Config(Options{PubSub: ps}) + defer v.Shutdown() var mu sync.Mutex var results []string @@ -119,7 +66,17 @@ func TestPubSub_MultipleSubscribers(t *testing.T) { }) c1.Publish("broadcast", []byte("msg")) - wg.Wait() + + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for messages") + } assert.Len(t, results, 2) assert.Contains(t, results, "c1:msg") @@ -127,9 +84,8 @@ func TestPubSub_MultipleSubscribers(t *testing.T) { } func TestPubSub_SubscriptionCleanupOnDispose(t *testing.T) { - ps := newMockPubSub() v := New() - v.Config(Options{PubSub: ps}) + defer v.Shutdown() c := newContext("cleanup-ctx", "/", v) c.View(func() h.H { return h.Div() }) @@ -144,9 +100,8 @@ func TestPubSub_SubscriptionCleanupOnDispose(t *testing.T) { } func TestPubSub_ManualUnsubscribe(t *testing.T) { - ps := newMockPubSub() v := New() - v.Config(Options{PubSub: ps}) + defer v.Shutdown() c := newContext("unsub-ctx", "/", v) c.View(func() h.H { return h.Div() }) @@ -160,28 +115,13 @@ func TestPubSub_ManualUnsubscribe(t *testing.T) { sub.Unsubscribe() c.Publish("topic", []byte("ignored")) - time.Sleep(10 * time.Millisecond) + time.Sleep(50 * time.Millisecond) assert.False(t, called) } -func TestPubSub_NoOpWhenNotConfigured(t *testing.T) { - v := New() - - c := newContext("noop-ctx", "/", v) - c.View(func() h.H { return h.Div() }) - - err := c.Publish("topic", []byte("data")) - assert.Error(t, err) - - sub, err := c.Subscribe("topic", func(data []byte) {}) - assert.Error(t, err) - assert.Nil(t, sub) -} - func TestPubSub_NoOpDuringPanicCheck(t *testing.T) { - ps := newMockPubSub() v := New() - v.Config(Options{PubSub: ps}) + defer v.Shutdown() // Panic-check context has id="" c := newContext("", "/", v) diff --git a/pubsub.go b/pubsub.go index c1a3c35..594a3dc 100644 --- a/pubsub.go +++ b/pubsub.go @@ -1,7 +1,8 @@ package via // PubSub is an interface for publish/subscribe messaging backends. -// The vianats sub-package provides an embedded NATS implementation. +// By default, New() starts an embedded NATS server. Supply a custom +// implementation via Config(Options{PubSub: yourBackend}) to override. type PubSub interface { Publish(subject string, data []byte) error Subscribe(subject string, handler func(data []byte)) (Subscription, error) diff --git a/pubsub_helpers_test.go b/pubsub_helpers_test.go index 9a18687..497273d 100644 --- a/pubsub_helpers_test.go +++ b/pubsub_helpers_test.go @@ -1,8 +1,8 @@ package via import ( - "sync" "testing" + "time" "github.com/ryanhamamura/via/h" "github.com/stretchr/testify/assert" @@ -10,9 +10,8 @@ import ( ) func TestPublishSubscribe_RoundTrip(t *testing.T) { - ps := newMockPubSub() v := New() - v.Config(Options{PubSub: ps}) + defer v.Shutdown() type event struct { Name string `json:"name"` @@ -20,30 +19,32 @@ func TestPublishSubscribe_RoundTrip(t *testing.T) { } var got event - var wg sync.WaitGroup - wg.Add(1) + done := make(chan struct{}) c := newContext("typed-ctx", "/", v) c.View(func() h.H { return h.Div() }) _, err := Subscribe(c, "events", func(e event) { got = e - wg.Done() + close(done) }) require.NoError(t, err) err = Publish(c, "events", event{Name: "click", Count: 42}) require.NoError(t, err) - wg.Wait() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for message") + } assert.Equal(t, "click", got.Name) assert.Equal(t, 42, got.Count) } func TestSubscribe_SkipsBadJSON(t *testing.T) { - ps := newMockPubSub() v := New() - v.Config(Options{PubSub: ps}) + defer v.Shutdown() type msg struct { Text string `json:"text"` @@ -62,5 +63,6 @@ func TestSubscribe_SkipsBadJSON(t *testing.T) { err = c.Publish("topic", []byte("not json")) require.NoError(t, err) + time.Sleep(50 * time.Millisecond) assert.False(t, called) } diff --git a/via.go b/via.go index 4d850b0..c2c37ed 100644 --- a/via.go +++ b/via.go @@ -49,6 +49,7 @@ type V struct { devModePageInitFnMap map[string]func(*Context) sessionManager *scs.SessionManager pubsub PubSub + defaultNATS *defaultNATS actionRateLimit RateLimitConfig datastarPath string datastarContent []byte @@ -130,6 +131,7 @@ func (v *V) Config(cfg Options) { v.datastarPath = cfg.DatastarPath } if cfg.PubSub != nil { + v.defaultNATS = nil v.pubsub = cfg.PubSub } if cfg.ContextTTL != 0 { @@ -379,6 +381,7 @@ func (v *V) Shutdown() { v.logErr(nil, "pubsub close error: %v", err) } } + v.defaultNATS = nil v.logInfo(nil, "shutdown complete") } @@ -725,6 +728,15 @@ func New() *V { v.logDebug(c, "session close event triggered") v.cleanupCtx(c) }) + + dn, err := getSharedNATS() + if err != nil { + v.logWarn(nil, "embedded NATS unavailable: %v", err) + } else { + v.defaultNATS = dn + v.pubsub = &natsRef{dn: dn} + } + return v } diff --git a/vianats/vianats.go b/vianats/vianats.go deleted file mode 100644 index fde133a..0000000 --- a/vianats/vianats.go +++ /dev/null @@ -1,127 +0,0 @@ -// Package vianats provides an embedded NATS server with JetStream as a -// pub/sub backend for Via applications. -package vianats - -import ( - "context" - "encoding/json" - "fmt" - "time" - - "github.com/delaneyj/toolbelt/embeddednats" - "github.com/nats-io/nats.go" - "github.com/ryanhamamura/via" -) - -// NATS implements via.PubSub using an embedded NATS server with JetStream. -type NATS struct { - server *embeddednats.Server - nc *nats.Conn - js nats.JetStreamContext -} - -// New starts an embedded NATS server with JetStream enabled and returns a -// ready-to-use NATS instance. The server stores data in dataDir and shuts -// down when ctx is cancelled. -func New(ctx context.Context, dataDir string) (*NATS, error) { - ns, err := embeddednats.New(ctx, embeddednats.WithDirectory(dataDir)) - if err != nil { - return nil, fmt.Errorf("vianats: start server: %w", err) - } - ns.WaitForServer() - - nc, err := ns.Client() - if err != nil { - ns.Close() - return nil, fmt.Errorf("vianats: connect client: %w", err) - } - - js, err := nc.JetStream() - if err != nil { - nc.Close() - ns.Close() - return nil, fmt.Errorf("vianats: init jetstream: %w", err) - } - - return &NATS{server: ns, nc: nc, js: js}, nil -} - -// Publish sends data to the given subject using core NATS publish. -// JetStream captures messages automatically if a matching stream exists. -func (n *NATS) Publish(subject string, data []byte) error { - return n.nc.Publish(subject, data) -} - -// Subscribe creates a core NATS subscription for real-time fan-out delivery. -func (n *NATS) Subscribe(subject string, handler func(data []byte)) (via.Subscription, error) { - sub, err := n.nc.Subscribe(subject, func(msg *nats.Msg) { - handler(msg.Data) - }) - if err != nil { - return nil, err - } - return sub, nil -} - -// Close shuts down the client connection and embedded server. -func (n *NATS) Close() error { - n.nc.Close() - return n.server.Close() -} - -// Conn returns the underlying NATS connection for advanced usage. -func (n *NATS) Conn() *nats.Conn { - return n.nc -} - -// JetStream returns the JetStream context for stream configuration and replay. -func (n *NATS) JetStream() nats.JetStreamContext { - return n.js -} - -// StreamConfig holds the parameters for creating or updating a JetStream stream. -type StreamConfig struct { - Name string - Subjects []string - MaxMsgs int64 - MaxAge time.Duration -} - -// EnsureStream creates or updates a JetStream stream matching cfg. -func EnsureStream(n *NATS, cfg StreamConfig) error { - _, err := n.js.AddStream(&nats.StreamConfig{ - Name: cfg.Name, - Subjects: cfg.Subjects, - Retention: nats.LimitsPolicy, - MaxMsgs: cfg.MaxMsgs, - MaxAge: cfg.MaxAge, - }) - return err -} - -// ReplayHistory fetches the last limit messages from subject, -// deserializing each as T. Returns an empty slice if nothing is available. -func ReplayHistory[T any](n *NATS, subject string, limit int) ([]T, error) { - sub, err := n.js.SubscribeSync(subject, nats.DeliverAll(), nats.OrderedConsumer()) - if err != nil { - return nil, err - } - defer sub.Unsubscribe() - - var msgs []T - for { - raw, err := sub.NextMsg(200 * time.Millisecond) - if err != nil { - break - } - var msg T - if json.Unmarshal(raw.Data, &msg) == nil { - msgs = append(msgs, msg) - } - } - - if limit > 0 && len(msgs) > limit { - msgs = msgs[len(msgs)-limit:] - } - return msgs, nil -}