diff --git a/features/auth/handlers.go b/features/auth/handlers.go index 3ca7f6b..8a581a3 100644 --- a/features/auth/handlers.go +++ b/features/auth/handlers.go @@ -3,10 +3,10 @@ package auth import ( "database/sql" "net/http" + "net/url" "github.com/alexedwards/scs/v2" "github.com/google/uuid" - "github.com/starfederation/datastar-go/datastar" "github.com/ryanhamamura/games/auth" "github.com/ryanhamamura/games/db/repository" @@ -14,20 +14,15 @@ import ( appsessions "github.com/ryanhamamura/games/sessions" ) -type LoginSignals struct { - Username string `json:"username"` - Password string `json:"password"` //nolint:gosec // form input, not stored -} - -type RegisterSignals struct { - Username string `json:"username"` - Password string `json:"password"` //nolint:gosec // form input, not stored - Confirm string `json:"confirm"` -} - -func HandleLoginPage() http.HandlerFunc { +func HandleLoginPage(sessions *scs.SessionManager) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - if err := pages.LoginPage().Render(r.Context(), w); err != nil { + // Capture return_url so we can redirect back after login + if returnURL := r.URL.Query().Get("return_url"); returnURL != "" { + sessions.Put(r.Context(), "return_url", returnURL) + } + + errorMsg := r.URL.Query().Get("error") + if err := pages.LoginPage(errorMsg).Render(r.Context(), w); err != nil { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) } } @@ -35,7 +30,8 @@ func HandleLoginPage() http.HandlerFunc { func HandleRegisterPage() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - if err := pages.RegisterPage().Render(r.Context(), w); err != nil { + errorMsg := r.URL.Query().Get("error") + if err := pages.RegisterPage(errorMsg).Render(r.Context(), w); err != nil { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) } } @@ -43,25 +39,20 @@ func HandleRegisterPage() http.HandlerFunc { func HandleLogin(queries *repository.Queries, sessions *scs.SessionManager) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - var signals LoginSignals - if err := datastar.ReadSignals(r, &signals); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } + username := r.FormValue("username") + password := r.FormValue("password") - sse := datastar.NewSSE(w, r) - - user, err := queries.GetUserByUsername(r.Context(), signals.Username) + user, err := queries.GetUserByUsername(r.Context(), username) if err == sql.ErrNoRows { - sse.MarshalAndPatchSignals(map[string]any{"error": "Invalid username or password"}) //nolint:errcheck + http.Redirect(w, r, "/login?error="+url.QueryEscape("Invalid username or password"), http.StatusSeeOther) return } if err != nil { - sse.MarshalAndPatchSignals(map[string]any{"error": "An error occurred"}) //nolint:errcheck + http.Redirect(w, r, "/login?error="+url.QueryEscape("An error occurred"), http.StatusSeeOther) return } - if !auth.CheckPassword(signals.Password, user.PasswordHash) { - sse.MarshalAndPatchSignals(map[string]any{"error": "Invalid username or password"}) //nolint:errcheck + if !auth.CheckPassword(password, user.PasswordHash) { + http.Redirect(w, r, "/login?error="+url.QueryEscape("Invalid username or password"), http.StatusSeeOther) return } @@ -76,46 +67,42 @@ func HandleLogin(queries *repository.Queries, sessions *scs.SessionManager) http redirectURL = returnURL } - sse.Redirect(redirectURL) //nolint:errcheck + http.Redirect(w, r, redirectURL, http.StatusSeeOther) } } func HandleRegister(queries *repository.Queries, sessions *scs.SessionManager) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - var signals RegisterSignals - if err := datastar.ReadSignals(r, &signals); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + username := r.FormValue("username") + password := r.FormValue("password") + confirm := r.FormValue("confirm") + + if err := auth.ValidateUsername(username); err != nil { + http.Redirect(w, r, "/register?error="+url.QueryEscape(err.Error()), http.StatusSeeOther) + return + } + if err := auth.ValidatePassword(password); err != nil { + http.Redirect(w, r, "/register?error="+url.QueryEscape(err.Error()), http.StatusSeeOther) + return + } + if password != confirm { + http.Redirect(w, r, "/register?error="+url.QueryEscape("Passwords do not match"), http.StatusSeeOther) return } - sse := datastar.NewSSE(w, r) - - if err := auth.ValidateUsername(signals.Username); err != nil { - sse.MarshalAndPatchSignals(map[string]any{"error": err.Error()}) //nolint:errcheck - return - } - if err := auth.ValidatePassword(signals.Password); err != nil { - sse.MarshalAndPatchSignals(map[string]any{"error": err.Error()}) //nolint:errcheck - return - } - if signals.Password != signals.Confirm { - sse.MarshalAndPatchSignals(map[string]any{"error": "Passwords do not match"}) //nolint:errcheck - return - } - - hash, err := auth.HashPassword(signals.Password) + hash, err := auth.HashPassword(password) if err != nil { - sse.MarshalAndPatchSignals(map[string]any{"error": "An error occurred"}) //nolint:errcheck + http.Redirect(w, r, "/register?error="+url.QueryEscape("An error occurred"), http.StatusSeeOther) return } user, err := queries.CreateUser(r.Context(), repository.CreateUserParams{ ID: uuid.New().String(), - Username: signals.Username, + Username: username, PasswordHash: hash, }) if err != nil { - sse.MarshalAndPatchSignals(map[string]any{"error": "Username already taken"}) //nolint:errcheck + http.Redirect(w, r, "/register?error="+url.QueryEscape("Username already taken"), http.StatusSeeOther) return } @@ -130,6 +117,6 @@ func HandleRegister(queries *repository.Queries, sessions *scs.SessionManager) h redirectURL = returnURL } - sse.Redirect(redirectURL) //nolint:errcheck + http.Redirect(w, r, redirectURL, http.StatusSeeOther) } } diff --git a/features/auth/handlers_test.go b/features/auth/handlers_test.go new file mode 100644 index 0000000..149b1cd --- /dev/null +++ b/features/auth/handlers_test.go @@ -0,0 +1,351 @@ +package auth_test + +import ( + "context" + "database/sql" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/alexedwards/scs/v2" + "github.com/google/uuid" + + "github.com/ryanhamamura/games/auth" + "github.com/ryanhamamura/games/db/repository" + featauth "github.com/ryanhamamura/games/features/auth" + "github.com/ryanhamamura/games/features/lobby" + appsessions "github.com/ryanhamamura/games/sessions" + "github.com/ryanhamamura/games/testutil" +) + +// sessionCookieName is the default SCS cookie name used in tests. +const sessionCookieName = "session" + +type testSetup struct { + db *sql.DB + queries *repository.Queries + sm *scs.SessionManager +} + +func (s *testSetup) ctx() context.Context { + return context.Background() +} + +func newTestSetup(t *testing.T) *testSetup { + t.Helper() + db, queries := testutil.NewTestDB(t) + sm := testutil.NewTestSessionManager(t, db) + return &testSetup{db: db, queries: queries, sm: sm} +} + +// createTestUser inserts a user into the test database and returns the user ID. +func createTestUser(t *testing.T, setup *testSetup, username, password string) string { + t.Helper() + hash, err := auth.HashPassword(password) + if err != nil { + t.Fatalf("hashing password: %v", err) + } + id := uuid.New().String() + _, err = setup.queries.CreateUser(setup.ctx(), repository.CreateUserParams{ + ID: id, + Username: username, + PasswordHash: hash, + }) + if err != nil { + t.Fatalf("creating test user: %v", err) + } + return id +} + +// postForm sends a POST request with form-encoded body through the session middleware, +// forwarding any cookies from a previous response. +func postForm(handler http.Handler, path string, values url.Values, cookies []*http.Cookie) *httptest.ResponseRecorder { + body := strings.NewReader(values.Encode()) + req := httptest.NewRequest(http.MethodPost, path, body) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + for _, c := range cookies { + req.AddCookie(c) + } + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + return rec +} + +// getPage sends a GET request through the session middleware, forwarding cookies. +func getPage(handler http.Handler, path string, cookies []*http.Cookie) *httptest.ResponseRecorder { + req := httptest.NewRequest(http.MethodGet, path, nil) + for _, c := range cookies { + req.AddCookie(c) + } + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + return rec +} + +// extractSessionValue makes a GET request with the given cookies to a test endpoint +// that reads a session value, verifying the session was persisted correctly. +func extractSessionValue(t *testing.T, setup *testSetup, cookies []*http.Cookie, key string) string { + t.Helper() + var value string + handler := setup.sm.LoadAndSave(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + value = setup.sm.GetString(r.Context(), key) + })) + req := httptest.NewRequest(http.MethodGet, "/check-session", nil) + for _, c := range cookies { + req.AddCookie(c) + } + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("session check returned %d", rec.Code) + } + return value +} + +func TestHandleLogin_Success(t *testing.T) { + setup := newTestSetup(t) + createTestUser(t, setup, "alice", "password123") + + handler := setup.sm.LoadAndSave(featauth.HandleLogin(setup.queries, setup.sm)) + rec := postForm(handler, "/auth/login", url.Values{ + "username": {"alice"}, + "password": {"password123"}, + }, nil) + + if rec.Code != http.StatusSeeOther { + t.Errorf("expected status %d, got %d", http.StatusSeeOther, rec.Code) + } + if loc := rec.Header().Get("Location"); loc != "/" { + t.Errorf("expected redirect to /, got %q", loc) + } + + // Verify the response sets a session cookie + cookies := rec.Result().Cookies() + if !hasCookie(cookies, sessionCookieName) { + t.Fatal("response did not set a session cookie") + } + + // Verify session contains user data by reading it back + userID := extractSessionValue(t, setup, cookies, appsessions.KeyUserID) + if userID == "" { + t.Error("session does not contain user_id after login") + } + nickname := extractSessionValue(t, setup, cookies, appsessions.KeyNickname) + if nickname != "alice" { + t.Errorf("expected nickname %q, got %q", "alice", nickname) + } +} + +func TestHandleLogin_InvalidPassword(t *testing.T) { + setup := newTestSetup(t) + createTestUser(t, setup, "alice", "password123") + + handler := setup.sm.LoadAndSave(featauth.HandleLogin(setup.queries, setup.sm)) + rec := postForm(handler, "/auth/login", url.Values{ + "username": {"alice"}, + "password": {"wrongpassword"}, + }, nil) + + if rec.Code != http.StatusSeeOther { + t.Errorf("expected status %d, got %d", http.StatusSeeOther, rec.Code) + } + loc := rec.Header().Get("Location") + if !strings.HasPrefix(loc, "/login?error=") { + t.Errorf("expected redirect to /login?error=..., got %q", loc) + } +} + +func TestHandleLogin_UnknownUser(t *testing.T) { + setup := newTestSetup(t) + + handler := setup.sm.LoadAndSave(featauth.HandleLogin(setup.queries, setup.sm)) + rec := postForm(handler, "/auth/login", url.Values{ + "username": {"nonexistent"}, + "password": {"password123"}, + }, nil) + + if rec.Code != http.StatusSeeOther { + t.Errorf("expected status %d, got %d", http.StatusSeeOther, rec.Code) + } + loc := rec.Header().Get("Location") + if !strings.HasPrefix(loc, "/login?error=") { + t.Errorf("expected redirect to /login?error=..., got %q", loc) + } +} + +func TestHandleLogin_ReturnURL(t *testing.T) { + setup := newTestSetup(t) + createTestUser(t, setup, "alice", "password123") + + // First, visit the login page with a return_url to store it in the session + loginPageHandler := setup.sm.LoadAndSave(featauth.HandleLoginPage(setup.sm)) + pageRec := getPage(loginPageHandler, "/login?return_url=/games/abc", nil) + cookies := pageRec.Result().Cookies() + + // Now log in with those cookies so the handler can read return_url from session + loginHandler := setup.sm.LoadAndSave(featauth.HandleLogin(setup.queries, setup.sm)) + rec := postForm(loginHandler, "/auth/login", url.Values{ + "username": {"alice"}, + "password": {"password123"}, + }, cookies) + + if rec.Code != http.StatusSeeOther { + t.Errorf("expected status %d, got %d", http.StatusSeeOther, rec.Code) + } + if loc := rec.Header().Get("Location"); loc != "/games/abc" { + t.Errorf("expected redirect to /games/abc, got %q", loc) + } +} + +func TestHandleRegister_Success(t *testing.T) { + setup := newTestSetup(t) + + handler := setup.sm.LoadAndSave(featauth.HandleRegister(setup.queries, setup.sm)) + rec := postForm(handler, "/auth/register", url.Values{ + "username": {"newuser"}, + "password": {"password123"}, + "confirm": {"password123"}, + }, nil) + + if rec.Code != http.StatusSeeOther { + t.Errorf("expected status %d, got %d", http.StatusSeeOther, rec.Code) + } + if loc := rec.Header().Get("Location"); loc != "/" { + t.Errorf("expected redirect to /, got %q", loc) + } + + cookies := rec.Result().Cookies() + if !hasCookie(cookies, sessionCookieName) { + t.Fatal("response did not set a session cookie") + } + + userID := extractSessionValue(t, setup, cookies, appsessions.KeyUserID) + if userID == "" { + t.Error("session does not contain user_id after registration") + } +} + +func TestHandleRegister_PasswordMismatch(t *testing.T) { + setup := newTestSetup(t) + + handler := setup.sm.LoadAndSave(featauth.HandleRegister(setup.queries, setup.sm)) + rec := postForm(handler, "/auth/register", url.Values{ + "username": {"newuser"}, + "password": {"password123"}, + "confirm": {"differentpassword"}, + }, nil) + + if rec.Code != http.StatusSeeOther { + t.Errorf("expected status %d, got %d", http.StatusSeeOther, rec.Code) + } + loc := rec.Header().Get("Location") + if !strings.Contains(loc, "Passwords+do+not+match") { + t.Errorf("expected error about password mismatch, got %q", loc) + } +} + +func TestHandleRegister_InvalidUsername(t *testing.T) { + setup := newTestSetup(t) + + handler := setup.sm.LoadAndSave(featauth.HandleRegister(setup.queries, setup.sm)) + rec := postForm(handler, "/auth/register", url.Values{ + "username": {"ab"}, // too short + "password": {"password123"}, + "confirm": {"password123"}, + }, nil) + + if rec.Code != http.StatusSeeOther { + t.Errorf("expected status %d, got %d", http.StatusSeeOther, rec.Code) + } + loc := rec.Header().Get("Location") + if !strings.HasPrefix(loc, "/register?error=") { + t.Errorf("expected redirect to /register?error=..., got %q", loc) + } +} + +func TestHandleRegister_ShortPassword(t *testing.T) { + setup := newTestSetup(t) + + handler := setup.sm.LoadAndSave(featauth.HandleRegister(setup.queries, setup.sm)) + rec := postForm(handler, "/auth/register", url.Values{ + "username": {"validuser"}, + "password": {"short"}, + "confirm": {"short"}, + }, nil) + + if rec.Code != http.StatusSeeOther { + t.Errorf("expected status %d, got %d", http.StatusSeeOther, rec.Code) + } + loc := rec.Header().Get("Location") + if !strings.HasPrefix(loc, "/register?error=") { + t.Errorf("expected redirect to /register?error=..., got %q", loc) + } +} + +func TestHandleRegister_DuplicateUsername(t *testing.T) { + setup := newTestSetup(t) + createTestUser(t, setup, "taken", "password123") + + handler := setup.sm.LoadAndSave(featauth.HandleRegister(setup.queries, setup.sm)) + rec := postForm(handler, "/auth/register", url.Values{ + "username": {"taken"}, + "password": {"password123"}, + "confirm": {"password123"}, + }, nil) + + if rec.Code != http.StatusSeeOther { + t.Errorf("expected status %d, got %d", http.StatusSeeOther, rec.Code) + } + loc := rec.Header().Get("Location") + if !strings.Contains(loc, "Username+already+taken") { + t.Errorf("expected error about duplicate username, got %q", loc) + } +} + +func TestHandleLogout(t *testing.T) { + setup := newTestSetup(t) + createTestUser(t, setup, "alice", "password123") + + // Log in first to establish a session + loginHandler := setup.sm.LoadAndSave(featauth.HandleLogin(setup.queries, setup.sm)) + loginRec := postForm(loginHandler, "/auth/login", url.Values{ + "username": {"alice"}, + "password": {"password123"}, + }, nil) + cookies := loginRec.Result().Cookies() + + // Verify we're logged in + userID := extractSessionValue(t, setup, cookies, appsessions.KeyUserID) + if userID == "" { + t.Fatal("expected to be logged in before testing logout") + } + + // Now log out + logoutHandler := setup.sm.LoadAndSave(lobby.HandleLogout(setup.sm)) + logoutRec := postForm(logoutHandler, "/logout", nil, cookies) + + if logoutRec.Code != http.StatusSeeOther { + t.Errorf("expected status %d, got %d", http.StatusSeeOther, logoutRec.Code) + } + if loc := logoutRec.Header().Get("Location"); loc != "/" { + t.Errorf("expected redirect to /, got %q", loc) + } + + // Verify session is cleared — use the cookies from the logout response + logoutCookies := logoutRec.Result().Cookies() + userID = extractSessionValue(t, setup, logoutCookies, appsessions.KeyUserID) + if userID != "" { + t.Errorf("expected empty user_id after logout, got %q", userID) + } +} + +func hasCookie(cookies []*http.Cookie, name string) bool { + for _, c := range cookies { + if c.Name == name { + return true + } + } + return false +} diff --git a/features/auth/pages/login.templ b/features/auth/pages/login.templ index b159492..2be459a 100644 --- a/features/auth/pages/login.templ +++ b/features/auth/pages/login.templ @@ -1,45 +1,39 @@ package pages -import ( - "github.com/ryanhamamura/games/features/common/layouts" - "github.com/starfederation/datastar-go/datastar" -) +import "github.com/ryanhamamura/games/features/common/layouts" -templ LoginPage() { +templ LoginPage(errorMsg string) { @layouts.Base("Login") { -
+

Login

Sign in to your account

-
-
+ if errorMsg != "" { +
{ errorMsg }
+ } +
+ autofocus + />
- -
+

Don't have an account? Register

diff --git a/features/auth/pages/register.templ b/features/auth/pages/register.templ index dab31e1..5e6a40b 100644 --- a/features/auth/pages/register.templ +++ b/features/auth/pages/register.templ @@ -1,54 +1,47 @@ package pages -import ( - "github.com/ryanhamamura/games/features/common/layouts" - "github.com/starfederation/datastar-go/datastar" -) +import "github.com/ryanhamamura/games/features/common/layouts" -templ RegisterPage() { +templ RegisterPage(errorMsg string) { @layouts.Base("Register") { -
+

Register

Create a new account

-
-
+ if errorMsg != "" { +
{ errorMsg }
+ } +
+ autofocus + /> + />
- -
+

Already have an account? Login

diff --git a/features/auth/routes.go b/features/auth/routes.go index bb39e44..1726269 100644 --- a/features/auth/routes.go +++ b/features/auth/routes.go @@ -9,7 +9,7 @@ import ( ) func SetupRoutes(router chi.Router, queries *repository.Queries, sessions *scs.SessionManager) { - router.Get("/login", HandleLoginPage()) + router.Get("/login", HandleLoginPage(sessions)) router.Get("/register", HandleRegisterPage()) router.Post("/auth/login", HandleLogin(queries, sessions)) router.Post("/auth/register", HandleRegister(queries, sessions)) diff --git a/features/lobby/handlers.go b/features/lobby/handlers.go index 074b0aa..f2b5a10 100644 --- a/features/lobby/handlers.go +++ b/features/lobby/handlers.go @@ -171,7 +171,6 @@ func HandleLogout(sessions *scs.SessionManager) http.HandlerFunc { return } - sse := datastar.NewSSE(w, r) - sse.ExecuteScript("window.location.href='/'") //nolint:errcheck + http.Redirect(w, r, "/", http.StatusSeeOther) } } diff --git a/features/lobby/pages/lobby.templ b/features/lobby/pages/lobby.templ index b6186df..43d0230 100644 --- a/features/lobby/pages/lobby.templ +++ b/features/lobby/pages/lobby.templ @@ -20,13 +20,11 @@ templ LobbyPage(data LobbyData) { if data.IsLoggedIn {
Logged in as { data.Username } - +
+ +
} else {