diff --git a/internal/web/csrf.go b/internal/web/csrf.go new file mode 100644 index 0000000..dd1993d --- /dev/null +++ b/internal/web/csrf.go @@ -0,0 +1,41 @@ +package web + +import ( + "crypto/rand" + "encoding/hex" + "net/http" +) + +const csrfCookieName = "_csrf" + +// ensureCSRFToken returns the current CSRF token from the cookie, generating +// and setting a new one if the cookie is absent. +func ensureCSRFToken(w http.ResponseWriter, r *http.Request, secure bool) string { + if cookie, err := r.Cookie(csrfCookieName); err == nil && cookie.Value != "" { + return cookie.Value + } + raw := make([]byte, 32) + _, _ = rand.Read(raw) + token := hex.EncodeToString(raw) + http.SetCookie(w, &http.Cookie{ + Name: csrfCookieName, + Value: token, + Path: "/", + MaxAge: 86400, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteStrictMode, + }) + return token +} + +// validateCSRF returns true when the form's csrf_token field matches the +// _csrf cookie. Call after r.ParseForm(). +func validateCSRF(r *http.Request) bool { + cookie, err := r.Cookie(csrfCookieName) + if err != nil || cookie.Value == "" { + return false + } + formToken := r.FormValue("csrf_token") + return formToken != "" && formToken == cookie.Value +} diff --git a/internal/web/handler.go b/internal/web/handler.go index 477a8f6..15d80e0 100644 --- a/internal/web/handler.go +++ b/internal/web/handler.go @@ -4,6 +4,7 @@ import ( "database/sql" "embed" "errors" + "fmt" "html/template" "io" "log/slog" @@ -126,6 +127,7 @@ type dashboardData struct { Subscription *subscriptionRow Invoices []invoiceRow Flash string + CSRFToken string } // ---- DB helpers ---- @@ -181,37 +183,16 @@ func loadRecentInvoices(db *sql.DB, customerID int64) ([]invoiceRow, error) { return result, rows.Err() } +func validEmail(s string) bool { + at := strings.Index(s, "@") + return at > 0 && at < len(s)-1 && strings.Contains(s[at+1:], ".") +} + func formatCurrency(dollars, cents int64, currency string) string { if currency == "USD" || currency == "" { - return "$" + itoa(dollars) + "." + pad2(cents) + return fmt.Sprintf("$%d.%02d", dollars, cents) } - return itoa(dollars) + "." + pad2(cents) + " " + currency -} - -func itoa(n int64) string { - if n == 0 { - return "0" - } - s := "" - neg := n < 0 - if neg { - n = -n - } - for n > 0 { - s = string(rune('0'+n%10)) + s - n /= 10 - } - if neg { - s = "-" + s - } - return s -} - -func pad2(n int64) string { - if n < 10 { - return "0" + itoa(n) - } - return itoa(n) + return fmt.Sprintf("%d.%02d %s", dollars, cents, currency) } // ---- session cookie helpers ---- @@ -265,8 +246,11 @@ func (h *Handler) IndexHandler(w http.ResponseWriter, r *http.Request) { } func (h *Handler) LoginGET(w http.ResponseWriter, r *http.Request) { + token := ensureCSRFToken(w, r, sessionSecure()) h.ts.render(w, "login.html", map[string]any{ - "Error": r.URL.Query().Get("error"), + "Error": r.URL.Query().Get("error"), + "reset": r.URL.Query().Get("reset"), + "CSRFToken": token, }) } @@ -275,6 +259,10 @@ func (h *Handler) LoginPOST(w http.ResponseWriter, r *http.Request) { http.Error(w, "bad request", http.StatusBadRequest) return } + if !validateCSRF(r) { + http.Error(w, "invalid request", http.StatusForbidden) + return + } email := strings.TrimSpace(strings.ToLower(r.FormValue("email"))) password := r.FormValue("password") @@ -313,8 +301,10 @@ func (h *Handler) LoginPOST(w http.ResponseWriter, r *http.Request) { } func (h *Handler) RegisterGET(w http.ResponseWriter, r *http.Request) { + token := ensureCSRFToken(w, r, sessionSecure()) h.ts.render(w, "register.html", map[string]any{ - "Error": r.URL.Query().Get("error"), + "Error": r.URL.Query().Get("error"), + "CSRFToken": token, }) } @@ -323,6 +313,10 @@ func (h *Handler) RegisterPOST(w http.ResponseWriter, r *http.Request) { http.Error(w, "bad request", http.StatusBadRequest) return } + if !validateCSRF(r) { + http.Error(w, "invalid request", http.StatusForbidden) + return + } firstName := strings.TrimSpace(r.FormValue("first_name")) lastName := strings.TrimSpace(r.FormValue("last_name")) @@ -334,6 +328,10 @@ func (h *Handler) RegisterPOST(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, "/register?error=missing_fields", http.StatusSeeOther) return } + if !validEmail(email) { + http.Redirect(w, r, "/register?error=invalid_email", http.StatusSeeOther) + return + } if password != confirm { http.Redirect(w, r, "/register?error=password_mismatch", http.StatusSeeOther) return @@ -416,22 +414,41 @@ func (h *Handler) DashboardGET(w http.ResponseWriter, r *http.Request) { case "cancelled": flash = "Checkout was cancelled. No charge was made." } - if r.URL.Query().Get("cancelled") == "1" { + switch r.URL.Query().Get("cancelled") { + case "1": flash = "Your subscription has been cancelled and will not renew. You retain access until the end of the current billing period." } - if r.URL.Query().Get("error") == "cancel_failed" { + switch r.URL.Query().Get("error") { + case "cancel_failed": flash = "Could not cancel subscription. Please contact support." + case "already_cancelling": + flash = "Your subscription is already scheduled for cancellation." + case "no_subscription": + flash = "No active subscription found." + case "already_subscribed": + flash = "You already have an active subscription." } + csrfToken := ensureCSRFToken(w, r, sessionSecure()) + h.ts.render(w, "dashboard.html", dashboardData{ Customer: c, Subscription: sub, Invoices: invoices, Flash: flash, + CSRFToken: csrfToken, }) } func (h *Handler) LogoutPOST(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + http.Error(w, "bad request", http.StatusBadRequest) + return + } + if !validateCSRF(r) { + http.Error(w, "invalid request", http.StatusForbidden) + return + } cookie, err := r.Cookie("session") if err == nil { _ = auth.DeleteSession(h.DB, cookie.Value) @@ -441,9 +458,11 @@ func (h *Handler) LogoutPOST(w http.ResponseWriter, r *http.Request) { } func (h *Handler) ResetGET(w http.ResponseWriter, r *http.Request) { + token := ensureCSRFToken(w, r, sessionSecure()) h.ts.render(w, "reset-request.html", map[string]any{ - "Sent": r.URL.Query().Get("sent"), - "Error": r.URL.Query().Get("error"), + "Sent": r.URL.Query().Get("sent"), + "Error": r.URL.Query().Get("error"), + "CSRFToken": token, }) } @@ -452,6 +471,10 @@ func (h *Handler) ResetPOST(w http.ResponseWriter, r *http.Request) { http.Error(w, "bad request", http.StatusBadRequest) return } + if !validateCSRF(r) { + http.Error(w, "invalid request", http.StatusForbidden) + return + } email := strings.TrimSpace(strings.ToLower(r.FormValue("email"))) if email == "" { @@ -480,10 +503,12 @@ func (h *Handler) ResetPOST(w http.ResponseWriter, r *http.Request) { } func (h *Handler) ResetConfirmGET(w http.ResponseWriter, r *http.Request) { - token := r.PathValue("token") + pathToken := r.PathValue("token") + csrfToken := ensureCSRFToken(w, r, sessionSecure()) h.ts.render(w, "reset-confirm.html", map[string]any{ - "Token": token, - "Error": r.URL.Query().Get("error"), + "Token": pathToken, + "Error": r.URL.Query().Get("error"), + "CSRFToken": csrfToken, }) } @@ -494,6 +519,10 @@ func (h *Handler) ResetConfirmPOST(w http.ResponseWriter, r *http.Request) { http.Error(w, "bad request", http.StatusBadRequest) return } + if !validateCSRF(r) { + http.Error(w, "invalid request", http.StatusForbidden) + return + } password := r.FormValue("password") confirm := r.FormValue("confirm_password") @@ -552,6 +581,30 @@ func (h *Handler) CheckoutGET(w http.ResponseWriter, r *http.Request) { return } + // Reject price IDs not in the configured set. + validPrice := false + for _, id := range h.Stripe.PriceIDs { + if id == priceID { + validPrice = true + break + } + } + if !validPrice { + http.Error(w, "invalid plan", http.StatusBadRequest) + return + } + + // Block customers who already have an active or cancelling subscription. + var existingCount int + _ = h.DB.QueryRow( + `SELECT COUNT(*) FROM subscriptions WHERE customer_id = ? AND status IN ('active', 'cancelling')`, + customerID, + ).Scan(&existingCount) + if existingCount > 0 { + http.Redirect(w, r, "/dashboard?error=already_subscribed", http.StatusSeeOther) + return + } + var stripeCustomerID string err := h.DB.QueryRow( `SELECT stripe_customer_id FROM customers WHERE id = ?`, customerID, @@ -620,20 +673,33 @@ func (h *Handler) WebhookPOST(w http.ResponseWriter, r *http.Request) { } func (h *Handler) CancelPOST(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + http.Error(w, "bad request", http.StatusBadRequest) + return + } + if !validateCSRF(r) { + http.Error(w, "invalid request", http.StatusForbidden) + return + } + customerID := auth.CustomerIDFromContext(r.Context()) - var stripeSubID string + var stripeSubID, subStatus string err := h.DB.QueryRow( - `SELECT stripe_subscription_id FROM subscriptions - WHERE customer_id = ? AND status = 'active' + `SELECT stripe_subscription_id, status FROM subscriptions + WHERE customer_id = ? AND status IN ('active', 'cancelling') ORDER BY created_at DESC LIMIT 1`, customerID, - ).Scan(&stripeSubID) + ).Scan(&stripeSubID, &subStatus) if err != nil { slog.Error("cancel: find subscription", "err", err) http.Redirect(w, r, "/dashboard?error=no_subscription", http.StatusSeeOther) return } + if subStatus == "cancelling" { + http.Redirect(w, r, "/dashboard?error=already_cancelling", http.StatusSeeOther) + return + } if err := payments.CancelSubscription(stripeSubID); err != nil { slog.Error("cancel: stripe cancel", "err", err) diff --git a/internal/web/templates/dashboard.html b/internal/web/templates/dashboard.html index facd561..da35838 100644 --- a/internal/web/templates/dashboard.html +++ b/internal/web/templates/dashboard.html @@ -4,6 +4,7 @@ {{define "nav-actions"}}
+
{{end}} @@ -60,6 +61,7 @@
+
@@ -109,8 +111,8 @@ {{range .Invoices}} - {{.CreatedAt | slice 0 10}} - {{.PeriodStart | slice 0 10}} – {{.PeriodEnd | slice 0 10}} + {{.CreatedAt | fmtDate}} + {{.PeriodStart | fmtDate}} – {{.PeriodEnd | fmtDate}} {{.AmountDisplay}} {{.Status}} diff --git a/internal/web/templates/login.html b/internal/web/templates/login.html index 056e170..4e7753d 100644 --- a/internal/web/templates/login.html +++ b/internal/web/templates/login.html @@ -25,6 +25,7 @@ {{end}}
+
{{end}} +
diff --git a/internal/web/templates/reset-confirm.html b/internal/web/templates/reset-confirm.html index cff8720..3d8495f 100644 --- a/internal/web/templates/reset-confirm.html +++ b/internal/web/templates/reset-confirm.html @@ -23,6 +23,7 @@ {{end}} +
+