125 lines
4.1 KiB
Go
125 lines
4.1 KiB
Go
package auth
|
|
|
|
import (
|
|
"database/sql"
|
|
"net/http"
|
|
"net/url"
|
|
|
|
"github.com/alexedwards/scs/v2"
|
|
"github.com/google/uuid"
|
|
|
|
"github.com/ryanhamamura/games/auth"
|
|
"github.com/ryanhamamura/games/db/repository"
|
|
"github.com/ryanhamamura/games/features/auth/pages"
|
|
appsessions "github.com/ryanhamamura/games/sessions"
|
|
)
|
|
|
|
func HandleLoginPage(sessions *scs.SessionManager) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
// Capture return_url so we can redirect back after login
|
|
if returnURL := r.URL.Query().Get("return_url"); returnURL != "" {
|
|
sessions.Put(r.Context(), "return_url", returnURL)
|
|
}
|
|
|
|
errorMsg := r.URL.Query().Get("error")
|
|
if err := pages.LoginPage(errorMsg).Render(r.Context(), w); err != nil {
|
|
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
|
}
|
|
}
|
|
}
|
|
|
|
// HandleRegisterPage returns a handler that renders the registration form.
// Any "error" query parameter is passed through to the page for display. A
// render failure is reported as a 500 Internal Server Error.
func HandleRegisterPage() http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		msg := r.URL.Query().Get("error")
		err := pages.RegisterPage(msg).Render(r.Context(), w)
		if err == nil {
			return
		}
		code := http.StatusInternalServerError
		http.Error(w, http.StatusText(code), code)
	}
}
|
|
|
|
// authFormMaxBytes caps the request body accepted by the login and
// registration handlers; those forms only carry a few short text fields.
const authFormMaxBytes = 1024

// HandleLogin returns a handler that processes the login form.
//
// It reads "username" and "password" from the submitted form and looks the
// user up by username. It then verifies the password against the stored
// hash. On success it signs the user in (see completeSignIn). On any failure
// it redirects back to /login with a user-facing message in the "error"
// query parameter.
func HandleLogin(queries *repository.Queries, sessions *scs.SessionManager) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		r.Body = http.MaxBytesReader(w, r.Body, authFormMaxBytes)
		// Parse explicitly: FormValue silently discards parse errors (such as an
		// over-limit body), which would otherwise surface as a misleading
		// "Invalid username or password".
		if err := r.ParseForm(); err != nil {
			redirectWithAuthError(w, r, "/login", "An error occurred")
			return
		}
		username := r.FormValue("username")
		password := r.FormValue("password")

		user, err := queries.GetUserByUsername(r.Context(), username)
		switch {
		case err == sql.ErrNoRows:
			// Same message as a wrong password so the response text does not
			// reveal whether the username exists.
			// NOTE(review): this path skips auth.CheckPassword, so if hashing is
			// deliberately slow the response timing may still leak that; consider
			// checking against a dummy hash here.
			redirectWithAuthError(w, r, "/login", "Invalid username or password")
			return
		case err != nil:
			redirectWithAuthError(w, r, "/login", "An error occurred")
			return
		}

		if !auth.CheckPassword(password, user.PasswordHash) {
			redirectWithAuthError(w, r, "/login", "Invalid username or password")
			return
		}

		completeSignIn(w, r, sessions, user.ID, user.Username)
	}
}

// HandleRegister returns a handler that processes the registration form.
//
// It reads "username", "password" and "confirm" from the submitted form. It
// validates the username and password with the auth package and requires the
// password and confirmation to match. It then stores a new user with a random
// UUID and a hashed password, and signs that user in (see completeSignIn).
// Validation and storage failures redirect back to /register with a
// user-facing message in the "error" query parameter.
func HandleRegister(queries *repository.Queries, sessions *scs.SessionManager) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		r.Body = http.MaxBytesReader(w, r.Body, authFormMaxBytes)
		// See HandleLogin: surface parse errors instead of treating them as
		// empty fields.
		if err := r.ParseForm(); err != nil {
			redirectWithAuthError(w, r, "/register", "An error occurred")
			return
		}
		username := r.FormValue("username")
		password := r.FormValue("password")
		confirm := r.FormValue("confirm")

		if err := auth.ValidateUsername(username); err != nil {
			redirectWithAuthError(w, r, "/register", err.Error())
			return
		}
		if err := auth.ValidatePassword(password); err != nil {
			redirectWithAuthError(w, r, "/register", err.Error())
			return
		}
		if password != confirm {
			redirectWithAuthError(w, r, "/register", "Passwords do not match")
			return
		}

		hash, err := auth.HashPassword(password)
		if err != nil {
			redirectWithAuthError(w, r, "/register", "An error occurred")
			return
		}

		user, err := queries.CreateUser(r.Context(), repository.CreateUserParams{
			ID:           uuid.New().String(),
			Username:     username,
			PasswordHash: hash,
		})
		if err != nil {
			// NOTE(review): every insert failure is reported as a taken username,
			// which also masks database outages; consider distinguishing the
			// driver's unique-constraint error.
			redirectWithAuthError(w, r, "/register", "Username already taken")
			return
		}

		completeSignIn(w, r, sessions, user.ID, user.Username)
	}
}

// completeSignIn establishes an authenticated session for the given user and
// redirects the client onward.
//
// It first renews the session token, as scs recommends on login, to prevent
// session fixation. If renewal fails, it stores no identity and redirects to
// /login with an error. Otherwise it records the user ID, username and
// nickname in the session. It then redirects to the "return_url" saved by the
// login page, but only when that is a safe local path (see isSafeReturnURL);
// otherwise it redirects to "/". The saved return URL is removed from the
// session either way.
func completeSignIn(w http.ResponseWriter, r *http.Request, sessions *scs.SessionManager, userID, username string) {
	ctx := r.Context()

	if err := sessions.RenewToken(ctx); err != nil {
		redirectWithAuthError(w, r, "/login", "An error occurred")
		return
	}
	sessions.Put(ctx, appsessions.KeyUserID, userID)
	sessions.Put(ctx, "username", username)
	sessions.Put(ctx, appsessions.KeyNickname, username)

	redirectURL := "/"
	if returnURL := sessions.PopString(ctx, "return_url"); isSafeReturnURL(returnURL) {
		redirectURL = returnURL
	}
	http.Redirect(w, r, redirectURL, http.StatusSeeOther)
}

// isSafeReturnURL reports whether target is an absolute path on this site, so
// that redirecting to it cannot send the user to another origin.
//
// It rejects empty and relative values and absolute URLs with a scheme. It
// also rejects protocol-relative URLs ("//evil.example") and the backslash
// form ("/\evil.example"), which browsers treat like "//". Finally, it rejects
// anything url.Parse refuses, which includes ASCII control characters:
// browsers strip tabs and newlines, so "/\t/evil.example" would otherwise
// become "//evil.example".
func isSafeReturnURL(target string) bool {
	if target == "" || target[0] != '/' {
		return false
	}
	if len(target) > 1 && (target[1] == '/' || target[1] == '\\') {
		return false
	}
	u, err := url.Parse(target)
	return err == nil && u.Scheme == "" && u.Host == ""
}

// redirectWithAuthError sends the client to page (e.g. "/login") with msg in
// the "error" query parameter, which the login/register page handlers render.
func redirectWithAuthError(w http.ResponseWriter, r *http.Request, page, msg string) {
	http.Redirect(w, r, page+"?error="+url.QueryEscape(msg), http.StatusSeeOther)
}
|