94 lines
2.7 KiB
Go
94 lines
2.7 KiB
Go
package integration
|
|
|
|
import (
|
|
"errors"
|
|
"rideaware/pkg/database"
|
|
"time"
|
|
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// Repository groups the database operations for OAuth connections and OAuth
// state tokens. It carries no fields; its methods all go through database.DB.
type Repository struct{}

// NewRepository returns a ready-to-use *Repository.
func NewRepository() *Repository {
	repo := new(Repository)
	return repo
}
|
|
|
|
// UpsertConnection creates or updates an OAuth connection for a user+provider pair.
|
|
func (r *Repository) UpsertConnection(conn *OAuthConnection) error {
|
|
var existing OAuthConnection
|
|
err := database.DB.Where("user_id = ? AND provider = ?", conn.UserID, conn.Provider).
|
|
First(&existing).Error
|
|
|
|
if err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return database.DB.Create(conn).Error
|
|
}
|
|
return err
|
|
}
|
|
|
|
existing.AccessToken = conn.AccessToken
|
|
existing.RefreshToken = conn.RefreshToken
|
|
existing.TokenExpiresAt = conn.TokenExpiresAt
|
|
existing.ProviderUserID = conn.ProviderUserID
|
|
existing.Scopes = conn.Scopes
|
|
existing.Status = "active"
|
|
|
|
conn.ID = existing.ID
|
|
return database.DB.Save(&existing).Error
|
|
}
|
|
|
|
// GetConnection retrieves an active OAuth connection for a user+provider pair.
|
|
func (r *Repository) GetConnection(userID uint, provider string) (*OAuthConnection, error) {
|
|
var conn OAuthConnection
|
|
if err := database.DB.Where("user_id = ? AND provider = ?", userID, provider).
|
|
First(&conn).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, errors.New("no " + provider + " connection found")
|
|
}
|
|
return nil, err
|
|
}
|
|
return &conn, nil
|
|
}
|
|
|
|
// UpdateConnection updates an existing OAuth connection.
|
|
func (r *Repository) UpdateConnection(conn *OAuthConnection) error {
|
|
return database.DB.Save(conn).Error
|
|
}
|
|
|
|
// DeleteConnection removes an OAuth connection.
|
|
func (r *Repository) DeleteConnection(userID uint, provider string) error {
|
|
return database.DB.Where("user_id = ? AND provider = ?", userID, provider).
|
|
Delete(&OAuthConnection{}).Error
|
|
}
|
|
|
|
// CreateState stores an OAuth state token for CSRF protection.
|
|
func (r *Repository) CreateState(state *OAuthState) error {
|
|
return database.DB.Create(state).Error
|
|
}
|
|
|
|
// GetAndDeleteState retrieves and deletes an OAuth state token. Returns error if expired or not found.
|
|
func (r *Repository) GetAndDeleteState(stateToken string) (*OAuthState, error) {
|
|
var state OAuthState
|
|
if err := database.DB.Where("state = ?", stateToken).First(&state).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, errors.New("invalid or expired state token")
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
// Delete the state immediately (single-use)
|
|
database.DB.Delete(&state)
|
|
|
|
if time.Now().After(state.ExpiresAt) {
|
|
return nil, errors.New("state token has expired")
|
|
}
|
|
|
|
return &state, nil
|
|
}
|
|
|
|
// CleanupExpiredStates removes expired OAuth state tokens.
|
|
func (r *Repository) CleanupExpiredStates() error {
|
|
return database.DB.Where("expires_at < ?", time.Now()).Delete(&OAuthState{}).Error
|
|
}
|