Files
rideaware-api/internal/integration/repository.go

94 lines
2.7 KiB
Go

package integration
import (
"errors"
"rideaware/pkg/database"
"time"
"gorm.io/gorm"
)
// Repository groups the persistence operations for OAuth connections and
// OAuth state tokens. It carries no fields of its own; the methods defined on
// it in this file issue their queries through the package-level database.DB
// handle.
type Repository struct{}
// NewRepository returns a pointer to a freshly allocated, zero-valued
// Repository.
func NewRepository() *Repository {
	return new(Repository)
}
// UpsertConnection stores conn for its (UserID, Provider) pair.
//
// When no row exists for that pair, conn is inserted as given. Otherwise the
// stored row receives conn's access token, refresh token, token expiry,
// provider user ID and scopes, its Status is set to "active", and conn.ID is
// set to the stored row's ID before the row is saved.
//
// Any database error other than "record not found" from the lookup is
// returned unchanged, as are errors from the insert or save.
func (r *Repository) UpsertConnection(conn *OAuthConnection) error {
	var stored OAuthConnection
	lookupErr := database.DB.
		Where("user_id = ? AND provider = ?", conn.UserID, conn.Provider).
		First(&stored).
		Error

	switch {
	case errors.Is(lookupErr, gorm.ErrRecordNotFound):
		// Nothing stored yet for this user+provider: insert the new row.
		return database.DB.Create(conn).Error
	case lookupErr != nil:
		return lookupErr
	}

	// Carry the fresh credentials over onto the row that already exists.
	stored.AccessToken = conn.AccessToken
	stored.RefreshToken = conn.RefreshToken
	stored.TokenExpiresAt = conn.TokenExpiresAt
	stored.ProviderUserID = conn.ProviderUserID
	stored.Scopes = conn.Scopes
	stored.Status = "active"

	// Let the caller see which row was updated.
	conn.ID = stored.ID

	saved := database.DB.Save(&stored)
	return saved.Error
}
// GetConnection fetches the first OAuth connection stored for the given
// user+provider pair. The query matches on user_id and provider only; it does
// not filter on the connection's status.
//
// When no row matches, the returned error reads "no <provider> connection
// found". Any other database error is returned unchanged.
func (r *Repository) GetConnection(userID uint, provider string) (*OAuthConnection, error) {
	found := new(OAuthConnection)
	err := database.DB.
		Where("user_id = ? AND provider = ?", userID, provider).
		First(found).
		Error

	switch {
	case err == nil:
		return found, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil, errors.New("no " + provider + " connection found")
	default:
		return nil, err
	}
}
// UpdateConnection persists conn by passing it to database.DB.Save and
// returns whatever error that call reports.
func (r *Repository) UpdateConnection(conn *OAuthConnection) error {
	saved := database.DB.Save(conn)
	return saved.Error
}
// DeleteConnection deletes the OAuth connection rows matching the given
// user+provider pair and returns the error reported by the delete, if any.
func (r *Repository) DeleteConnection(userID uint, provider string) error {
	scoped := database.DB.Where("user_id = ? AND provider = ?", userID, provider)
	deleted := scoped.Delete(&OAuthConnection{})
	return deleted.Error
}
// CreateState inserts state, the OAuth state token record used for CSRF
// protection, and returns the error reported by the insert, if any.
func (r *Repository) CreateState(state *OAuthState) error {
	created := database.DB.Create(state)
	return created.Error
}
// GetAndDeleteState looks up the OAuth state row whose state column equals
// stateToken and consumes it so the token can be redeemed only once.
//
// It returns the state on success. It returns an error when:
//   - no row matches stateToken ("invalid or expired state token");
//   - the delete fails (the database error is returned unchanged);
//   - the delete removed no row, meaning the token was consumed between the
//     lookup and the delete ("invalid or expired state token");
//   - the row's ExpiresAt is already in the past ("state token has expired").
//
// Lookup errors other than "record not found" are returned unchanged.
func (r *Repository) GetAndDeleteState(stateToken string) (*OAuthState, error) {
	var state OAuthState
	if err := database.DB.Where("state = ?", stateToken).First(&state).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, errors.New("invalid or expired state token")
		}
		return nil, err
	}

	// Consume the token before checking expiry so an expired row is removed
	// as well. The result must be checked: if the delete fails, the token
	// would otherwise stay redeemable.
	deleted := database.DB.Delete(&state)
	if deleted.Error != nil {
		return nil, deleted.Error
	}
	// Zero affected rows means a concurrent caller already consumed this
	// token; accepting it here would break the single-use guarantee.
	if deleted.RowsAffected == 0 {
		return nil, errors.New("invalid or expired state token")
	}

	if time.Now().After(state.ExpiresAt) {
		return nil, errors.New("state token has expired")
	}
	return &state, nil
}
// CleanupExpiredStates deletes every OAuth state row whose expires_at is
// earlier than the current time and returns the error reported by the delete,
// if any.
func (r *Repository) CleanupExpiredStates() error {
	cutoff := time.Now()
	purged := database.DB.Where("expires_at < ?", cutoff).Delete(&OAuthState{})
	return purged.Error
}