| package auth |
| |
| import ( |
| "context" |
| "fmt" |
| "net/http" |
| "time" |
| |
| "golang.org/x/oauth2" |
| "golang.org/x/oauth2/clientcredentials" |
| |
| "github.com/korap/korap-mcp/config" |
| ) |
| |
| // OAuthClient handles OAuth2 authentication for KorAP API |
| type OAuthClient struct { |
| config *config.OAuthConfig |
| oauth2Config *oauth2.Config |
| token *oauth2.Token |
| httpClient *http.Client |
| } |
| |
// NewOAuthClient builds an OAuthClient from cfg.
//
// cfg must be non-nil and must pass cfg.Validate(); otherwise an error is
// returned and no client is created. The oauth2.Config used for the
// authorization-code flow is derived once, up front, via cfg.ToOAuth2Config().
// The new client holds no token until ExchangeCode, SetToken or
// ClientCredentialsFlow succeeds.
func NewOAuthClient(cfg *config.OAuthConfig) (*OAuthClient, error) {
	if cfg == nil {
		return nil, fmt.Errorf("oauth config cannot be nil")
	}

	validationErr := cfg.Validate()
	if validationErr != nil {
		return nil, fmt.Errorf("invalid oauth config: %w", validationErr)
	}

	return &OAuthClient{
		config:       cfg,
		oauth2Config: cfg.ToOAuth2Config(),
	}, nil
}
| |
| // GetAuthURL returns the authorization URL for the OAuth2 flow |
| func (c *OAuthClient) GetAuthURL(state string) string { |
| if c.oauth2Config == nil { |
| return "" |
| } |
| return c.oauth2Config.AuthCodeURL(state, oauth2.AccessTypeOnline) |
| } |
| |
// ExchangeCode trades an authorization code for an access token using the
// configured oauth2 settings. On success the token is stored and an
// authenticated HTTP client is built from it. On failure the client's existing
// token and HTTP client are left untouched.
//
// NOTE(review): the stored HTTP client is derived from ctx, and golang.org/x/oauth2
// documents such clients as not valid beyond the lifetime of that context; a
// request-scoped ctx may break later automatic refreshes. Confirm what callers pass.
func (c *OAuthClient) ExchangeCode(ctx context.Context, code string) error {
	cfg := c.oauth2Config
	if cfg == nil {
		return fmt.Errorf("oauth2 not configured")
	}

	tok, exchangeErr := cfg.Exchange(ctx, code)
	if exchangeErr != nil {
		return fmt.Errorf("failed to exchange code for token: %w", exchangeErr)
	}

	c.token, c.httpClient = tok, cfg.Client(ctx, tok)
	return nil
}
| |
// SetToken installs token as the client's current OAuth2 token, for example
// one persisted from an earlier session.
//
// When an oauth2 configuration is present, a new authenticated HTTP client is
// built from token using context.Background(), so it is not tied to any
// request's lifetime. Passing nil clears both the token and the authenticated
// client. GetHTTPClient then falls back to its unauthenticated default instead
// of returning a client that has no token to send.
func (c *OAuthClient) SetToken(token *oauth2.Token) {
	c.token = token
	if token == nil {
		// A client built from a nil token can only fail its requests
		// (nothing to send, nothing to refresh with), so drop it.
		c.httpClient = nil
		return
	}
	if c.oauth2Config != nil {
		c.httpClient = c.oauth2Config.Client(context.Background(), token)
	}
}
| |
| // GetToken returns the current OAuth2 token |
| func (c *OAuthClient) GetToken() *oauth2.Token { |
| return c.token |
| } |
| |
| // GetHTTPClient returns an HTTP client with OAuth2 authentication |
| func (c *OAuthClient) GetHTTPClient() *http.Client { |
| if c.httpClient != nil { |
| return c.httpClient |
| } |
| |
| // Return default client if not authenticated |
| return &http.Client{ |
| Timeout: time.Second * 30, |
| } |
| } |
| |
| // IsAuthenticated checks if the client has a valid token |
| func (c *OAuthClient) IsAuthenticated() bool { |
| if c.token == nil { |
| return false |
| } |
| |
| // Check if token is expired (with 5 minute buffer) |
| return c.token.Valid() && c.token.Expiry.After(time.Now().Add(5*time.Minute)) |
| } |
| |
// ClientCredentialsFlow obtains a token with the OAuth2 client-credentials
// grant, using the client ID, secret, token URL and scopes from the client's
// configuration. It fails if no configuration is present or OAuth is disabled.
//
// On success the token is stored and the HTTP client is built on the same
// token source that produced it. The client therefore starts with that token
// and refreshes it as needed. Building the client with ccConfig.Client instead
// would create a second, empty token source and request another token on
// first use.
//
// NOTE(review): the HTTP client and its token source are tied to ctx (see the
// oauth2.NewClient documentation); confirm callers pass a long-lived context.
func (c *OAuthClient) ClientCredentialsFlow(ctx context.Context) error {
	if c.config == nil || !c.config.Enabled {
		return fmt.Errorf("oauth2 not configured")
	}

	ccConfig := &clientcredentials.Config{
		ClientID:     c.config.ClientID,
		ClientSecret: c.config.ClientSecret,
		TokenURL:     c.config.TokenURL,
		Scopes:       c.config.Scopes,
	}

	// TokenSource caches the token it fetches, so the Token call below and
	// later requests made through the client share one token.
	source := ccConfig.TokenSource(ctx)
	token, err := source.Token()
	if err != nil {
		return fmt.Errorf("failed to get client credentials token: %w", err)
	}

	c.token = token
	c.httpClient = oauth2.NewClient(ctx, source)
	return nil
}
| |
// RefreshToken obtains a new access token via the configured oauth2 settings
// and replaces both the stored token and the authenticated HTTP client.
//
// When the current token has a refresh token, a refresh is always performed,
// even if the access token has not expired yet. oauth2's TokenSource would
// otherwise return the current token unchanged until oauth2 itself considers
// it expired. That is much later than the 5-minute buffer IsAuthenticated
// applies, so a caller doing "if !IsAuthenticated() { RefreshToken() }" would
// never actually get a fresh token. Without a refresh token, the current token
// is returned if still valid, and an error is returned otherwise.
//
// Errors are returned when oauth2 is not configured, when no token is held, or
// when the token endpoint request fails.
//
// NOTE(review): the new HTTP client is tied to ctx (see the oauth2
// Config.Client documentation); confirm callers pass a long-lived context.
func (c *OAuthClient) RefreshToken(ctx context.Context) error {
	if c.oauth2Config == nil {
		return fmt.Errorf("oauth2 not configured")
	}

	if c.token == nil {
		return fmt.Errorf("no token to refresh")
	}

	seed := c.token
	if c.token.RefreshToken != "" {
		// An empty access token makes the seed invalid, which forces the
		// token source to hit the token endpoint with the refresh token.
		seed = &oauth2.Token{RefreshToken: c.token.RefreshToken}
	}

	newToken, err := c.oauth2Config.TokenSource(ctx, seed).Token()
	if err != nil {
		return fmt.Errorf("failed to refresh token: %w", err)
	}

	c.token = newToken
	c.httpClient = c.oauth2Config.Client(ctx, newToken)
	return nil
}
| |
// AddAuthHeader sets the Authorization header on req from the currently held
// token, via oauth2.Token.SetAuthHeader. If no token is held, it returns an
// error and leaves req unchanged. The token's expiry is not checked here.
func (c *OAuthClient) AddAuthHeader(req *http.Request) error {
	tok := c.token
	if tok != nil {
		tok.SetAuthHeader(req)
		return nil
	}
	return fmt.Errorf("no authentication token available")
}