blob: acbc978424628a5742d1b0f422682b0707fa653e [file] [log] [blame]
package auth
import (
"context"
"fmt"
"net/http"
"time"
"golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"
"github.com/korap/korap-mcp/config"
)
// OAuthClient handles OAuth2 authentication for the KorAP API.
//
// It keeps the validated configuration, the oauth2.Config derived from it,
// the most recently obtained or installed token, and an HTTP client that
// attaches that token to outgoing requests. Construct it with
// NewOAuthClient, which validates the configuration first.
//
// NOTE(review): fields are read and written without any synchronization, so
// concurrent use of one OAuthClient presumably needs external locking —
// confirm with callers.
type OAuthClient struct {
	// config is the validated configuration passed to NewOAuthClient; it
	// supplies the credentials used by ClientCredentialsFlow.
	config *config.OAuthConfig
	// oauth2Config is produced by config.ToOAuth2Config and drives the
	// authorization URL, code exchange, and refresh operations.
	oauth2Config *oauth2.Config
	// token is the last token obtained or installed; nil until then.
	token *oauth2.Token
	// httpClient sends requests authenticated with token. When it is nil,
	// GetHTTPClient falls back to a plain client with a 30s timeout.
	httpClient *http.Client
}
// NewOAuthClient builds an OAuthClient from cfg.
//
// cfg must be non-nil and must pass cfg.Validate(); otherwise an error is
// returned, with the validation error wrapped via %w so callers can inspect
// it using errors.Is / errors.As. The returned client holds cfg and the
// oauth2.Config produced by cfg.ToOAuth2Config(), but no token yet.
func NewOAuthClient(cfg *config.OAuthConfig) (*OAuthClient, error) {
	if cfg == nil {
		return nil, fmt.Errorf("oauth config cannot be nil")
	}
	if validationErr := cfg.Validate(); validationErr != nil {
		return nil, fmt.Errorf("invalid oauth config: %w", validationErr)
	}
	return &OAuthClient{
		oauth2Config: cfg.ToOAuth2Config(),
		config:       cfg,
	}, nil
}
// GetAuthURL builds the URL the user should visit to start the
// authorization-code flow. state is passed through to the authorization
// server unchanged, and the URL requests online access
// (oauth2.AccessTypeOnline). An empty string is returned when the client has
// no oauth2 configuration.
func (c *OAuthClient) GetAuthURL(state string) string {
	if cfg := c.oauth2Config; cfg != nil {
		return cfg.AuthCodeURL(state, oauth2.AccessTypeOnline)
	}
	return ""
}
// ExchangeCode trades an authorization code (obtained after the user visited
// the URL from GetAuthURL) for a token at the configured token endpoint.
//
// On success the token is stored on the client and the authenticated HTTP
// client is rebuilt around it. It returns an error if the client has no
// oauth2 configuration or if the exchange fails.
//
// NOTE(review): the HTTP client is built with ctx, which oauth2 keeps for
// later automatic refreshes; a request-scoped ctx that is later cancelled
// would presumably break those refreshes — confirm callers pass a
// long-lived context.
func (c *OAuthClient) ExchangeCode(ctx context.Context, code string) error {
	cfg := c.oauth2Config
	if cfg == nil {
		return fmt.Errorf("oauth2 not configured")
	}

	exchanged, exchangeErr := cfg.Exchange(ctx, code)
	if exchangeErr != nil {
		return fmt.Errorf("failed to exchange code for token: %w", exchangeErr)
	}

	c.token = exchanged
	c.httpClient = cfg.Client(ctx, exchanged)
	return nil
}
// SetToken installs token as the client's current OAuth2 token and rebuilds
// the authenticated HTTP client around it.
//
// Passing nil clears the token. In that case, or when the client has no
// oauth2 configuration to build a client from, the authenticated HTTP client
// is dropped as well. GetHTTPClient then falls back to its plain default
// instead of returning a client that would keep sending stale credentials
// (or fail every request for lack of a token).
//
// The client is built with context.Background() because oauth2's
// Config.Client keeps the context for automatic token refreshes, and the
// client outlives this call.
func (c *OAuthClient) SetToken(token *oauth2.Token) {
	c.token = token
	if token == nil || c.oauth2Config == nil {
		// Any previously built client is bound to the old token; discard it.
		c.httpClient = nil
		return
	}
	c.httpClient = c.oauth2Config.Client(context.Background(), token)
}
// GetToken reports the token currently held by the client, or nil if none
// has been obtained or installed yet.
//
// NOTE(review): within this file c.token is replaced only by ExchangeCode,
// SetToken, ClientCredentialsFlow and RefreshToken. Refreshes that the
// oauth2-backed HTTP client performs on its own are presumably not reflected
// here — confirm before relying on this value's expiry.
func (c *OAuthClient) GetToken() *oauth2.Token {
	current := c.token
	return current
}
// GetHTTPClient returns the HTTP client to use for KorAP API calls.
//
// Once a token has been obtained, this is the authenticated client built
// around it. Otherwise a fresh, unauthenticated client with a 30-second
// timeout is returned on each call.
//
// NOTE(review): the authenticated client comes from oauth2 and presumably
// carries no Timeout, unlike the fallback — confirm whether API calls need a
// timeout there too.
func (c *OAuthClient) GetHTTPClient() *http.Client {
	if c.httpClient == nil {
		// Not authenticated yet: hand out a plain client.
		return &http.Client{Timeout: 30 * time.Second}
	}
	return c.httpClient
}
// IsAuthenticated reports whether the client holds a usable access token.
//
// A token counts as usable when it meets all of these conditions:
//   - it is non-nil;
//   - oauth2's Token.Valid accepts it;
//   - it either has no expiry or expires more than five minutes from now.
//
// A zero Expiry means the token does not expire. The five-minute buffer
// treats tokens that are about to lapse as already expired.
func (c *OAuthClient) IsAuthenticated() bool {
	const expiryBuffer = 5 * time.Minute

	tok := c.token
	if tok == nil || !tok.Valid() {
		return false
	}
	// Comparing a zero Expiry against time.Now() would wrongly report
	// non-expiring tokens (which Valid() accepts) as expired.
	if tok.Expiry.IsZero() {
		return true
	}
	return tok.Expiry.After(time.Now().Add(expiryBuffer))
}
// ClientCredentialsFlow obtains a token with the OAuth2 client-credentials
// grant, using the configured client ID, client secret, token URL and scopes.
//
// It returns an error if OAuth2 is not configured or not enabled, or if the
// token request fails. On success the token is stored on the client. The HTTP
// client is rebuilt to send that same token; once it expires, the client
// fetches a new one through the same client-credentials source.
//
// NOTE(review): ctx is retained by the token source for those later fetches;
// a request-scoped ctx that gets cancelled would presumably break them —
// confirm callers pass a long-lived context.
func (c *OAuthClient) ClientCredentialsFlow(ctx context.Context) error {
	if c.config == nil || !c.config.Enabled {
		return fmt.Errorf("oauth2 not configured")
	}
	ccConfig := &clientcredentials.Config{
		ClientID:     c.config.ClientID,
		ClientSecret: c.config.ClientSecret,
		TokenURL:     c.config.TokenURL,
		Scopes:       c.config.Scopes,
	}
	// Fetch the token through a single caching TokenSource and give that same
	// source to the HTTP client. Building the client with ccConfig.Client(ctx)
	// would create a new source that requests a second token on the first API
	// call, leaving the token fetched here unused.
	source := ccConfig.TokenSource(ctx)
	token, err := source.Token()
	if err != nil {
		return fmt.Errorf("failed to get client credentials token: %w", err)
	}
	c.token = token
	c.httpClient = oauth2.NewClient(ctx, source)
	return nil
}
// RefreshToken obtains a fresh access token and rebuilds the authenticated
// HTTP client around it.
//
// If the current token carries a refresh token, the refresh grant is always
// performed, even while the access token is still valid. This lets callers
// renew a token that IsAuthenticated already rejects because it expires
// within five minutes. Without a refresh token, the current token is
// returned unchanged while still valid, and an error is returned once it has
// expired.
//
// It returns an error if the client has no oauth2 configuration, holds no
// token, or the refresh fails.
//
// NOTE(review): the rebuilt HTTP client keeps ctx for later automatic
// refreshes; a request-scoped ctx that gets cancelled would presumably break
// them — confirm callers pass a long-lived context.
func (c *OAuthClient) RefreshToken(ctx context.Context) error {
	if c.oauth2Config == nil {
		return fmt.Errorf("oauth2 not configured")
	}
	if c.token == nil {
		return fmt.Errorf("no token to refresh")
	}
	// A TokenSource contacts the token endpoint only once the token it was
	// seeded with is no longer Valid(). Seeding it with just the refresh
	// token (no access token) makes the seed invalid, which forces the
	// refresh grant.
	seed := c.token
	if c.token.RefreshToken != "" {
		seed = &oauth2.Token{RefreshToken: c.token.RefreshToken}
	}
	newToken, err := c.oauth2Config.TokenSource(ctx, seed).Token()
	if err != nil {
		return fmt.Errorf("failed to refresh token: %w", err)
	}
	c.token = newToken
	c.httpClient = c.oauth2Config.Client(ctx, newToken)
	return nil
}
// AddAuthHeader writes the Authorization header for the client's current
// token onto req, using oauth2's Token.SetAuthHeader.
//
// The token is used as-is; its expiry is not checked here. An error is
// returned when the client holds no token, and req is then left untouched.
func (c *OAuthClient) AddAuthHeader(req *http.Request) error {
	tok := c.token
	if tok != nil {
		tok.SetAuthHeader(req)
		return nil
	}
	return fmt.Errorf("no authentication token available")
}