feat: migrate Flask API to Go with JWT auth
This commit is contained in:
153
internal/auth/handler.go
Normal file
153
internal/auth/handler.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"rideaware/internal/config"
|
||||
"rideaware/internal/user"
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
userService *user.Service
|
||||
}
|
||||
|
||||
func NewHandler() *Handler {
|
||||
return &Handler{
|
||||
userService: user.NewService(),
|
||||
}
|
||||
}
|
||||
|
||||
type SignupRequest struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
Email string `json:"email"`
|
||||
FirstName string `json:"first_name"`
|
||||
LastName string `json:"last_name"`
|
||||
}
|
||||
|
||||
type LoginRequest struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
type TokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
UserID uint `json:"user_id"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
func (h *Handler) Signup(w http.ResponseWriter, r *http.Request) {
|
||||
var req SignupRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "invalid request"})
|
||||
return
|
||||
}
|
||||
|
||||
newUser, err := h.userService.CreateUser(req.Username, req.Password, req.Email, req.FirstName, req.LastName)
|
||||
if err != nil {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
accessToken, _ := config.GenerateAccessToken(newUser.ID, newUser.Email, newUser.Username)
|
||||
refreshToken, _ := config.GenerateRefreshToken(newUser.ID, newUser.Email, newUser.Username)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
json.NewEncoder(w).Encode(TokenResponse{
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
ExpiresIn: 900,
|
||||
UserID: newUser.ID,
|
||||
Username: newUser.Username,
|
||||
Email: newUser.Email,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
|
||||
var req LoginRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "invalid request"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.userService.VerifyUser(req.Username, req.Password)
|
||||
if err != nil {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
accessToken, _ := config.GenerateAccessToken(user.ID, user.Email, user.Username)
|
||||
refreshToken, _ := config.GenerateRefreshToken(user.ID, user.Email, user.Username)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(TokenResponse{
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
ExpiresIn: 900,
|
||||
UserID: user.ID,
|
||||
Username: user.Username,
|
||||
Email: user.Email,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) RequestPasswordReset(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
Email string `json:"email"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "invalid request"})
|
||||
return
|
||||
}
|
||||
|
||||
h.userService.RequestPasswordReset(req.Email)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"message": "If email exists, reset link has been sent",
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) ConfirmPasswordReset(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
Token string `json:"token"`
|
||||
NewPassword string `json:"new_password"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "invalid request"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.userService.ResetPassword(req.Token, req.NewPassword); err != nil {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"message": "Password reset successful",
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]string{"message": "Logout successful"})
|
||||
}
|
||||
97
internal/config/jwt.go
Normal file
97
internal/config/jwt.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
type JWTConfig struct {
|
||||
SecretKey string
|
||||
AccessTokenDuration time.Duration
|
||||
RefreshTokenDuration time.Duration
|
||||
ResetTokenDuration time.Duration
|
||||
}
|
||||
|
||||
var JWT *JWTConfig
|
||||
|
||||
func InitJWT() {
|
||||
JWT = &JWTConfig{
|
||||
SecretKey: os.Getenv("JWT_SECRET_KEY"),
|
||||
AccessTokenDuration: 15 * time.Minute,
|
||||
RefreshTokenDuration: 7 * 24 * time.Hour,
|
||||
ResetTokenDuration: 1 * time.Hour,
|
||||
}
|
||||
|
||||
if JWT.SecretKey == "" {
|
||||
panic("JWT_SECRET_KEY not set in environment")
|
||||
}
|
||||
}
|
||||
|
||||
type CustomClaims struct {
|
||||
UserID uint `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
TokenType string `json:"token_type"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
func GenerateAccessToken(userID uint, email, username string) (string, error) {
|
||||
claims := CustomClaims{
|
||||
UserID: userID,
|
||||
Email: email,
|
||||
Username: username,
|
||||
TokenType: "access",
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(JWT.AccessTokenDuration)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
Issuer: "rideaware",
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString([]byte(JWT.SecretKey))
|
||||
}
|
||||
|
||||
func GenerateRefreshToken(userID uint, email, username string) (string, error) {
|
||||
claims := CustomClaims{
|
||||
UserID: userID,
|
||||
Email: email,
|
||||
Username: username,
|
||||
TokenType: "refresh",
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(JWT.RefreshTokenDuration)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
Issuer: "rideaware",
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString([]byte(JWT.SecretKey))
|
||||
}
|
||||
|
||||
func VerifyToken(tokenString string) (*CustomClaims, error) {
|
||||
claims := &CustomClaims{}
|
||||
token, err := jwt.ParseWithClaims(
|
||||
tokenString,
|
||||
claims,
|
||||
func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return []byte(JWT.SecretKey), nil
|
||||
},
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !token.Valid {
|
||||
return nil, fmt.Errorf("invalid token")
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
82
internal/email/service.go
Normal file
82
internal/email/service.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package email
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/resend/resend-go/v2"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
client *resend.Client
|
||||
from string
|
||||
}
|
||||
|
||||
func NewService() *Service {
|
||||
senderEmail := os.Getenv("SENDER_EMAIL")
|
||||
if senderEmail == "" {
|
||||
senderEmail = "noreply@rideaware.app"
|
||||
}
|
||||
|
||||
apiKey := os.Getenv("RESEND_API_KEY")
|
||||
if apiKey == "" {
|
||||
apiKey = "re_test"
|
||||
}
|
||||
|
||||
return &Service{
|
||||
client: resend.NewClient(apiKey),
|
||||
from: senderEmail,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) SendPasswordResetEmail(email, username, resetLink string) error {
|
||||
params := &resend.SendEmailRequest{
|
||||
From: s.from,
|
||||
To: []string{email},
|
||||
Subject: "Reset Your RideAware Password",
|
||||
Html: fmt.Sprintf(`
|
||||
<h2>Password Reset Request</h2>
|
||||
<p>Hi %s,</p>
|
||||
<p>We received a request to reset your password. Click the link below to create a new password:</p>
|
||||
<p><a href="%s">Reset Password</a></p>
|
||||
<p>This link will expire in 1 hour.</p>
|
||||
<p>If you didn't request this, you can ignore this email.</p>
|
||||
`, username, resetLink),
|
||||
}
|
||||
|
||||
sent, err := s.client.Emails.Send(params)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send email: %w", err)
|
||||
}
|
||||
|
||||
if sent.Id == "" {
|
||||
return fmt.Errorf("failed to send email")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) SendWelcomeEmail(email, username string) error {
|
||||
params := &resend.SendEmailRequest{
|
||||
From: s.from,
|
||||
To: []string{email},
|
||||
Subject: "Welcome to RideAware",
|
||||
Html: fmt.Sprintf(`
|
||||
<h2>Welcome to RideAware</h2>
|
||||
<p>Hi %s,</p>
|
||||
<p>Your account has been created successfully!</p>
|
||||
<p>Start tracking your rides and improve your performance.</p>
|
||||
`, username),
|
||||
}
|
||||
|
||||
sent, err := s.client.Emails.Send(params)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send email: %w", err)
|
||||
}
|
||||
|
||||
if sent.Id == "" {
|
||||
return fmt.Errorf("failed to send email")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
65
internal/middleware/auth.go
Normal file
65
internal/middleware/auth.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"rideaware/internal/config"
|
||||
)
|
||||
|
||||
const UserContextKey = "user"
|
||||
|
||||
type AuthMiddleware struct{}
|
||||
|
||||
func NewAuthMiddleware() *AuthMiddleware {
|
||||
return &AuthMiddleware{}
|
||||
}
|
||||
|
||||
func (am *AuthMiddleware) ProtectedRoute(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "missing authorization header",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "invalid authorization header format",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
token := parts[1]
|
||||
claims, err := config.VerifyToken(token)
|
||||
if err != nil {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "invalid or expired token",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if claims.TokenType != "access" {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "refresh token cannot be used for access",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), UserContextKey, claims)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
29
internal/profile/model.go
Normal file
29
internal/profile/model.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package profile
|
||||
|
||||
import "time"
|
||||
|
||||
type Equipment struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
UserID uint `gorm:"not null;index" json:"user_id"`
|
||||
Name string `gorm:"not null" json:"name"`
|
||||
Type string `gorm:"not null" json:"type"` // "bike", "shoes", "helmet", etc.
|
||||
Brand string `gorm:"default:''" json:"brand"`
|
||||
Model string `gorm:"default:''" json:"model"`
|
||||
Weight float64 `gorm:"default:0" json:"weight"` // grams
|
||||
Notes string `gorm:"default:''" json:"notes"`
|
||||
Active bool `gorm:"default:true" json:"active"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type Stats struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
UserID uint `gorm:"not null;uniqueIndex" json:"user_id"`
|
||||
TotalRides int `gorm:"default:0" json:"total_rides"`
|
||||
TotalDistance float64 `gorm:"default:0" json:"total_distance"`
|
||||
TotalTime int `gorm:"default:0" json:"total_time"`
|
||||
AverageSpeed float64 `gorm:"default:0" json:"average_speed"`
|
||||
MaxSpeed float64 `gorm:"default:0" json:"max_speed"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
93
internal/user/handler.go
Normal file
93
internal/user/handler.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"rideaware/internal/config"
|
||||
"rideaware/internal/middleware"
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
service *Service
|
||||
}
|
||||
|
||||
func NewHandler() *Handler {
|
||||
return &Handler{
|
||||
service: NewService(),
|
||||
}
|
||||
}
|
||||
|
||||
type GetProfileResponse struct {
|
||||
User *User `json:"user"`
|
||||
Profile *Profile `json:"profile"`
|
||||
}
|
||||
|
||||
func (h *Handler) GetProfile(w http.ResponseWriter, r *http.Request) {
|
||||
claims := r.Context().Value(middleware.UserContextKey).(*config.CustomClaims)
|
||||
|
||||
user, err := h.service.repo.GetUserByID(claims.UserID)
|
||||
if err != nil {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "user not found"})
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(GetProfileResponse{
|
||||
User: user,
|
||||
Profile: user.Profile,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) UpdateProfile(w http.ResponseWriter, r *http.Request) {
|
||||
claims := r.Context().Value(middleware.UserContextKey).(*config.CustomClaims)
|
||||
|
||||
var req struct {
|
||||
FirstName string `json:"first_name"`
|
||||
LastName string `json:"last_name"`
|
||||
Bio string `json:"bio"`
|
||||
FTP int `json:"ftp"`
|
||||
MaxHR int `json:"max_hr"`
|
||||
Weight float64 `json:"weight"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "invalid request"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.service.repo.GetUserByID(claims.UserID)
|
||||
if err != nil {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "user not found"})
|
||||
return
|
||||
}
|
||||
|
||||
// Update profile
|
||||
if user.Profile != nil {
|
||||
user.Profile.FirstName = req.FirstName
|
||||
user.Profile.LastName = req.LastName
|
||||
user.Profile.Bio = req.Bio
|
||||
user.Profile.FTP = req.FTP
|
||||
user.Profile.MaxHR = req.MaxHR
|
||||
user.Profile.Weight = req.Weight
|
||||
|
||||
if err := h.service.repo.UpdateUser(user); err != nil {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "failed to update profile"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(GetProfileResponse{
|
||||
User: user,
|
||||
Profile: user.Profile,
|
||||
})
|
||||
}
|
||||
105
internal/user/model.go
Normal file
105
internal/user/model.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
Username string `gorm:"uniqueIndex;not null" json:"username"`
|
||||
Email string `gorm:"uniqueIndex;not null" json:"email"`
|
||||
Password string `gorm:"not null" json:"-"`
|
||||
IsActive bool `gorm:"default:true" json:"is_active"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
Profile *Profile `gorm:"foreignKey:UserID;constraint:OnDelete:Cascade" json:"profile,omitempty"`
|
||||
PasswordResets []PasswordReset `gorm:"foreignKey:UserID;constraint:OnDelete:Cascade" json:"password_resets,omitempty"`
|
||||
Sessions []Session `gorm:"foreignKey:UserID;constraint:OnDelete:Cascade" json:"sessions,omitempty"`
|
||||
}
|
||||
|
||||
type Profile struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
UserID uint `gorm:"not null;uniqueIndex" json:"user_id"`
|
||||
FirstName string `gorm:"default:''" json:"first_name"`
|
||||
LastName string `gorm:"default:''" json:"last_name"`
|
||||
Bio string `gorm:"default:''" json:"bio"`
|
||||
ProfilePicture string `gorm:"default:''" json:"profile_picture"`
|
||||
RestingHR int `gorm:"default:0" json:"resting_hr"`
|
||||
MaxHR int `gorm:"default:0" json:"max_hr"`
|
||||
FTP int `gorm:"default:0" json:"ftp"`
|
||||
Weight float64 `gorm:"default:0" json:"weight"`
|
||||
TotalRides int `gorm:"default:0" json:"total_rides"`
|
||||
TotalDistance float64 `gorm:"default:0" json:"total_distance"`
|
||||
TotalTime int `gorm:"default:0" json:"total_time"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type PasswordReset struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
UserID uint `gorm:"not null" json:"user_id"`
|
||||
Token string `gorm:"uniqueIndex;not null" json:"token"`
|
||||
ExpiresAt time.Time `gorm:"not null" json:"expires_at"`
|
||||
UsedAt *time.Time `json:"used_at"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
UserID uint `gorm:"not null;index" json:"user_id"`
|
||||
Token string `gorm:"uniqueIndex;not null" json:"token"`
|
||||
ExpiresAt time.Time `gorm:"not null;index" json:"expires_at"`
|
||||
DeviceName string `gorm:"default:''" json:"device_name"`
|
||||
UserAgent string `gorm:"default:''" json:"user_agent"`
|
||||
IPAddress string `gorm:"default:''" json:"ip_address"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// ===== Methods =====
|
||||
|
||||
// SetPassword hashes and sets the password
|
||||
func (u *User) SetPassword(rawPassword string) error {
|
||||
if len(rawPassword) < 8 {
|
||||
return errors.New("password must be at least 8 characters long")
|
||||
}
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword(
|
||||
[]byte(rawPassword),
|
||||
bcrypt.DefaultCost,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u.Password = string(hashedPassword)
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckPassword verifies the password
|
||||
func (u *User) CheckPassword(password string) bool {
|
||||
return bcrypt.CompareHashAndPassword(
|
||||
[]byte(u.Password),
|
||||
[]byte(password),
|
||||
) == nil
|
||||
}
|
||||
|
||||
// AfterCreate hook: automatically create profile after user insert
|
||||
func (u *User) AfterCreate(tx *gorm.DB) error {
|
||||
profile := &Profile{
|
||||
UserID: u.ID,
|
||||
}
|
||||
return tx.Create(profile).Error
|
||||
}
|
||||
|
||||
// IsPasswordResetTokenValid checks if token exists and is not expired
|
||||
func (prt *PasswordReset) IsValid() bool {
|
||||
return prt.UsedAt == nil && time.Now().Before(prt.ExpiresAt)
|
||||
}
|
||||
|
||||
// IsSessionValid checks if session is not expired
|
||||
func (s *Session) IsValid() bool {
|
||||
return time.Now().Before(s.ExpiresAt)
|
||||
}
|
||||
62
internal/user/repository.go
Normal file
62
internal/user/repository.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"rideaware/pkg/database"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Repository struct{}
|
||||
|
||||
func NewRepository() *Repository {
|
||||
return &Repository{}
|
||||
}
|
||||
|
||||
func (r *Repository) CreateUser(user *User) error {
|
||||
return database.DB.Create(user).Error
|
||||
}
|
||||
|
||||
func (r *Repository) GetUserByUsername(username string) (*User, error) {
|
||||
var user User
|
||||
if err := database.DB.Where("username = ?", username).First(&user).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("user not found")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (r *Repository) GetUserByEmail(email string) (*User, error) {
|
||||
var user User
|
||||
if err := database.DB.Where("email = ?", email).First(&user).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("user not found")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (r *Repository) GetUserByID(id uint) (*User, error) {
|
||||
var user User
|
||||
if err := database.DB.Preload("Profile").Where("id = ?", id).First(&user).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("user not found")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (r *Repository) UpdateUser(user *User) error {
|
||||
return database.DB.Save(user).Error
|
||||
}
|
||||
|
||||
func (r *Repository) UserExists(username, email string) (bool, error) {
|
||||
var count int64
|
||||
err := database.DB.Model(&User{}).
|
||||
Where("username = ? OR email = ?", username, email).
|
||||
Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
159
internal/user/service.go
Normal file
159
internal/user/service.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
"rideaware/internal/config"
|
||||
"rideaware/internal/email"
|
||||
"rideaware/pkg/database"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
repo *Repository
|
||||
email *email.Service
|
||||
}
|
||||
|
||||
func NewService() *Service {
|
||||
return &Service{
|
||||
repo: NewRepository(),
|
||||
email: email.NewService(),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) CreateUser(username, password, email, firstName, lastName string) (*User, error) {
|
||||
if username == "" || password == "" {
|
||||
return nil, errors.New("username and password are required")
|
||||
}
|
||||
|
||||
if email != "" {
|
||||
if !isValidEmail(email) {
|
||||
return nil, errors.New("invalid email format")
|
||||
}
|
||||
}
|
||||
|
||||
exists, err := s.repo.UserExists(username, email)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if exists {
|
||||
return nil, errors.New("username or email already exists")
|
||||
}
|
||||
|
||||
user := &User{
|
||||
Username: username,
|
||||
Email: email,
|
||||
}
|
||||
|
||||
if err := user.SetPassword(password); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.repo.CreateUser(user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_ = s.email.SendWelcomeEmail(email, username)
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *Service) VerifyUser(username, password string) (*User, error) {
|
||||
user, err := s.repo.GetUserByUsername(username)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid username or password")
|
||||
}
|
||||
|
||||
if !user.CheckPassword(password) {
|
||||
return nil, errors.New("invalid username or password")
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *Service) RequestPasswordReset(email string) error {
|
||||
user, err := s.repo.GetUserByEmail(email)
|
||||
if err != nil {
|
||||
// Don't leak if email exists
|
||||
return nil
|
||||
}
|
||||
|
||||
token, err := generateSecureToken(32)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resetToken := &PasswordReset{
|
||||
UserID: user.ID,
|
||||
Token: token,
|
||||
ExpiresAt: time.Now().Add(config.JWT.ResetTokenDuration),
|
||||
}
|
||||
|
||||
if err := database.DB.Create(resetToken).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resetLink := "https://rideaware.app/reset-password?token=" + token
|
||||
return s.email.SendPasswordResetEmail(user.Email, user.Username, resetLink)
|
||||
}
|
||||
|
||||
func (s *Service) ResetPassword(token, newPassword string) error {
|
||||
if len(newPassword) < 8 {
|
||||
return errors.New("password must be at least 8 characters long")
|
||||
}
|
||||
|
||||
var resetToken PasswordReset
|
||||
if err := database.DB.Where("token = ?", token).First(&resetToken).Error; err != nil {
|
||||
return errors.New("invalid or expired reset token")
|
||||
}
|
||||
|
||||
if !resetToken.IsValid() {
|
||||
return errors.New("reset token has expired")
|
||||
}
|
||||
|
||||
user, err := s.repo.GetUserByID(resetToken.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := user.SetPassword(newPassword); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
tx := database.DB.Begin()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
if err := tx.Model(user).Update("password", user.Password).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.Model(&resetToken).Update("used_at", now).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit().Error
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
func isValidEmail(email string) bool {
|
||||
regex := regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)
|
||||
return regex.MatchString(email)
|
||||
}
|
||||
|
||||
func generateSecureToken(length int) (string, error) {
|
||||
b := make([]byte, length)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.URLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
Reference in New Issue
Block a user