package integration import ( "crypto/aes" "crypto/cipher" "crypto/rand" "crypto/sha256" "encoding/base64" "encoding/hex" "encoding/json" "errors" "fmt" "io" "net/http" "net/url" "strings" "time" "rideaware/internal/config" ) type OAuthService struct { repo *Repository } func NewOAuthService() *OAuthService { return &OAuthService{ repo: NewRepository(), } } // GenerateState creates a cryptographically random state token for OAuth CSRF protection. func (s *OAuthService) GenerateState(userID uint, provider string) (string, error) { b := make([]byte, 32) if _, err := rand.Read(b); err != nil { return "", fmt.Errorf("failed to generate state: %w", err) } stateToken := base64.URLEncoding.EncodeToString(b) state := &OAuthState{ State: stateToken, UserID: userID, Provider: provider, ExpiresAt: time.Now().Add(10 * time.Minute), } if err := s.repo.CreateState(state); err != nil { return "", err } return stateToken, nil } // GenerateStateWithPKCE creates a state token and PKCE code verifier/challenge for Garmin OAuth. 
func (s *OAuthService) GenerateStateWithPKCE(userID uint, provider string) (stateToken, codeVerifier, codeChallenge string, err error) { b := make([]byte, 32) if _, err := rand.Read(b); err != nil { return "", "", "", fmt.Errorf("failed to generate state: %w", err) } stateToken = base64.URLEncoding.EncodeToString(b) // Generate code verifier (43-128 chars, base64url-safe) verifierBytes := make([]byte, 32) if _, err := rand.Read(verifierBytes); err != nil { return "", "", "", fmt.Errorf("failed to generate code verifier: %w", err) } codeVerifier = base64.RawURLEncoding.EncodeToString(verifierBytes) // code_challenge = base64url(SHA256(code_verifier)) h := sha256.Sum256([]byte(codeVerifier)) codeChallenge = base64.RawURLEncoding.EncodeToString(h[:]) state := &OAuthState{ State: stateToken, UserID: userID, Provider: provider, CodeVerifier: codeVerifier, ExpiresAt: time.Now().Add(10 * time.Minute), } if err := s.repo.CreateState(state); err != nil { return "", "", "", err } return stateToken, codeVerifier, codeChallenge, nil } // ValidateState validates and consumes an OAuth state token. func (s *OAuthService) ValidateState(stateToken string) (*OAuthState, error) { return s.repo.GetAndDeleteState(stateToken) } // TokenResponse is the standard OAuth2 token response. type TokenResponse struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` TokenType string `json:"token_type"` ExpiresIn int `json:"expires_in"` Scope string `json:"scope"` } // ExchangeCode exchanges an authorization code for tokens. 
func (s *OAuthService) ExchangeCode(tokenURL string, params url.Values) (*TokenResponse, error) { resp, err := http.Post(tokenURL, "application/x-www-form-urlencoded", strings.NewReader(params.Encode())) if err != nil { return nil, fmt.Errorf("token exchange request failed: %w", err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read token response: %w", err) } if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("token exchange failed (status %d): %s", resp.StatusCode, string(body)) } var tokenResp TokenResponse if err := json.Unmarshal(body, &tokenResp); err != nil { return nil, fmt.Errorf("failed to parse token response: %w", err) } return &tokenResp, nil } // RefreshAccessToken refreshes an expired access token. func (s *OAuthService) RefreshAccessToken(tokenURL, clientID, clientSecret, refreshToken string) (*TokenResponse, error) { params := url.Values{ "grant_type": {"refresh_token"}, "refresh_token": {refreshToken}, "client_id": {clientID}, "client_secret": {clientSecret}, } return s.ExchangeCode(tokenURL, params) } // SaveConnection encrypts tokens and stores the OAuth connection. 
func (s *OAuthService) SaveConnection(userID uint, provider string, tokenResp *TokenResponse) error { encAccess, err := Encrypt(tokenResp.AccessToken, config.OAuth.EncryptionKey) if err != nil { return fmt.Errorf("failed to encrypt access token: %w", err) } encRefresh := "" if tokenResp.RefreshToken != "" { encRefresh, err = Encrypt(tokenResp.RefreshToken, config.OAuth.EncryptionKey) if err != nil { return fmt.Errorf("failed to encrypt refresh token: %w", err) } } expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) conn := &OAuthConnection{ UserID: userID, Provider: provider, AccessToken: encAccess, RefreshToken: encRefresh, TokenExpiresAt: expiresAt, Scopes: tokenResp.Scope, Status: "active", } return s.repo.UpsertConnection(conn) } // GetValidToken retrieves a connection and ensures the token is valid (refreshing if needed). func (s *OAuthService) GetValidToken(userID uint, provider string) (string, error) { conn, err := s.repo.GetConnection(userID, provider) if err != nil { return "", err } if conn.Status != "active" { return "", fmt.Errorf("%s connection is %s, please reconnect", provider, conn.Status) } accessToken, err := Decrypt(conn.AccessToken, config.OAuth.EncryptionKey) if err != nil { return "", fmt.Errorf("failed to decrypt access token: %w", err) } // Token still valid (with 30s buffer) if time.Now().Before(conn.TokenExpiresAt.Add(-30 * time.Second)) { return accessToken, nil } // Token expired - try refresh if conn.RefreshToken == "" { conn.Status = "expired" s.repo.UpdateConnection(conn) return "", fmt.Errorf("%s token expired and no refresh token available, please reconnect", provider) } refreshToken, err := Decrypt(conn.RefreshToken, config.OAuth.EncryptionKey) if err != nil { return "", fmt.Errorf("failed to decrypt refresh token: %w", err) } var providerConfig config.OAuthProviderConfig switch provider { case "garmin": providerConfig = config.OAuth.Garmin case "wahoo": providerConfig = config.OAuth.Wahoo default: return 
"", errors.New("unknown provider") } tokenResp, err := s.RefreshAccessToken(providerConfig.TokenURL, providerConfig.ClientID, providerConfig.ClientSecret, refreshToken) if err != nil { conn.Status = "expired" s.repo.UpdateConnection(conn) return "", fmt.Errorf("%s token refresh failed, please reconnect: %w", provider, err) } // Save new tokens if err := s.SaveConnection(userID, provider, tokenResp); err != nil { return "", fmt.Errorf("failed to save refreshed tokens: %w", err) } return tokenResp.AccessToken, nil } // GetConnectionStatus returns the connection status for a user+provider. func (s *OAuthService) GetConnectionStatus(userID uint, provider string) (map[string]interface{}, error) { conn, err := s.repo.GetConnection(userID, provider) if err != nil { return map[string]interface{}{ "connected": false, "provider": provider, }, nil } return map[string]interface{}{ "connected": conn.Status == "active", "provider": conn.Provider, "status": conn.Status, "token_expires_at": conn.TokenExpiresAt, "connected_at": conn.CreatedAt, }, nil } // Disconnect removes an OAuth connection. func (s *OAuthService) Disconnect(userID uint, provider string) error { return s.repo.DeleteConnection(userID, provider) } // Encrypt encrypts plaintext using AES-256-GCM with the given hex-encoded key. 
//
// The AES variant follows the decoded key length (aes.NewCipher accepts 16-,
// 24- or 32-byte keys), so hexKey must be 64 hex characters for AES-256. A
// fresh random nonce is generated per call and prepended to the sealed data;
// the result is standard base64 of nonce || ciphertext+tag.
//
// When hexKey is empty the plaintext is only base64-encoded, NOT encrypted
// (development mode).
func Encrypt(plaintext, hexKey string) (string, error) {
	if hexKey == "" {
		// No encryption key configured - store as base64 (development mode)
		return base64.StdEncoding.EncodeToString([]byte(plaintext)), nil
	}

	gcm, err := gcmFromHexKey(hexKey)
	if err != nil {
		return "", err
	}

	nonce := make([]byte, gcm.NonceSize())
	if _, err := rand.Read(nonce); err != nil {
		return "", fmt.Errorf("failed to generate nonce: %w", err)
	}

	// Passing nonce as dst makes Seal append to it, producing nonce || sealed.
	sealed := gcm.Seal(nonce, nonce, []byte(plaintext), nil)
	return base64.StdEncoding.EncodeToString(sealed), nil
}

// Decrypt decrypts ciphertext that was encrypted with Encrypt.
//
// encoded must be the standard-base64 output of Encrypt and hexKey the same
// key that produced it. An error is returned when the key is invalid, the
// input is not valid base64, the data is shorter than the nonce, or
// authentication fails (wrong key or modified data). When hexKey is empty the
// input is treated as plain base64, mirroring Encrypt's development mode.
func Decrypt(encoded, hexKey string) (string, error) {
	if hexKey == "" {
		// No encryption key - stored as plain base64
		decoded, err := base64.StdEncoding.DecodeString(encoded)
		if err != nil {
			return "", fmt.Errorf("invalid ciphertext encoding: %w", err)
		}
		return string(decoded), nil
	}

	gcm, err := gcmFromHexKey(hexKey)
	if err != nil {
		return "", err
	}

	data, err := base64.StdEncoding.DecodeString(encoded)
	if err != nil {
		return "", fmt.Errorf("invalid ciphertext encoding: %w", err)
	}

	nonceSize := gcm.NonceSize()
	if len(data) < nonceSize {
		return "", errors.New("ciphertext too short")
	}

	nonce, sealed := data[:nonceSize], data[nonceSize:]
	plaintext, err := gcm.Open(nil, nonce, sealed, nil)
	if err != nil {
		return "", fmt.Errorf("failed to decrypt: %w", err)
	}
	return string(plaintext), nil
}

// gcmFromHexKey decodes hexKey and returns an AES-GCM AEAD keyed with it.
func gcmFromHexKey(hexKey string) (cipher.AEAD, error) {
	key, err := hex.DecodeString(hexKey)
	if err != nil {
		return nil, fmt.Errorf("invalid encryption key: %w", err)
	}

	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, fmt.Errorf("invalid encryption key: %w", err)
	}

	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return nil, fmt.Errorf("failed to initialise AES-GCM: %w", err)
	}
	return gcm, nil
}