Initial minimal MCP server for KorAP
diff --git a/auth/oauth.go b/auth/oauth.go
new file mode 100644
index 0000000..acbc978
--- /dev/null
+++ b/auth/oauth.go
@@ -0,0 +1,152 @@
+package auth
+
+import (
+ "context"
+ "fmt"
+ "net/http"
+ "time"
+
+ "golang.org/x/oauth2"
+ "golang.org/x/oauth2/clientcredentials"
+
+ "github.com/korap/korap-mcp/config"
+)
+
+// OAuthClient manages OAuth2 authentication for the KorAP API. It does no locking; synchronize concurrent use externally.
+type OAuthClient struct {
+	config       *config.OAuthConfig // settings given to NewOAuthClient (validated there before being stored)
+	oauth2Config *oauth2.Config      // from cfg.ToOAuth2Config(); may be nil, which the methods check for
+	token        *oauth2.Token       // last token obtained or set; nil until then
+	httpClient   *http.Client        // client bound to token; nil until a token flow or SetToken assigns it
+}
+
+// NewOAuthClient validates cfg and returns an OAuthClient built from it.
+// No token is obtained here; call ExchangeCode, ClientCredentialsFlow or
+// SetToken afterwards. An error is returned when cfg is nil or when
+// cfg.Validate rejects it (the validation error is wrapped).
+func NewOAuthClient(cfg *config.OAuthConfig) (*OAuthClient, error) {
+	if cfg == nil {
+		return nil, fmt.Errorf("oauth config cannot be nil")
+	}
+	if verr := cfg.Validate(); verr != nil {
+		return nil, fmt.Errorf("invalid oauth config: %w", verr)
+	}
+	// Derive the oauth2.Config once; the token-flow methods reuse it.
+	return &OAuthClient{
+		config:       cfg,
+		oauth2Config: cfg.ToOAuth2Config(),
+	}, nil
+}
+
+// GetAuthURL returns the authorization-code URL for state (use an unguessable per-request value), or "" when unconfigured.
+func (c *OAuthClient) GetAuthURL(state string) string {
+	if oc := c.oauth2Config; oc != nil {
+		return oc.AuthCodeURL(state, oauth2.AccessTypeOnline)
+	}
+	return ""
+}
+
+// ExchangeCode trades an authorization code for a token at the token endpoint
+// and stores both the token and an auto-refreshing HTTP client built from it.
+// NOTE(review): oauth2.Config.Client keeps ctx for later refreshes — confirm a request-scoped ctx is intended.
+func (c *OAuthClient) ExchangeCode(ctx context.Context, code string) error {
+	oc := c.oauth2Config
+	if oc == nil {
+		return fmt.Errorf("oauth2 not configured")
+	}
+	tok, exErr := oc.Exchange(ctx, code)
+	if exErr != nil {
+		return fmt.Errorf("failed to exchange code for token: %w", exErr)
+	}
+	c.token, c.httpClient = tok, oc.Client(ctx, tok)
+	return nil
+}
+
+// SetToken stores token and rebuilds the authenticated HTTP client from it.
+func (c *OAuthClient) SetToken(token *oauth2.Token) {
+	c.token, c.httpClient = token, nil // never keep a client bound to an earlier token
+	if token != nil && c.oauth2Config != nil { // nil token/config: GetHTTPClient falls back to its default client
+		c.httpClient = c.oauth2Config.Client(context.Background(), token)
+	}
+}
+
+// GetToken returns the stored OAuth2 token.
+// It is nil until ExchangeCode, ClientCredentialsFlow, RefreshToken or
+// SetToken stores one; the pointer is shared with c, not copied.
+func (c *OAuthClient) GetToken() *oauth2.Token { return c.token }
+
+// GetHTTPClient returns the token-bearing client when one exists, otherwise
+// a fresh unauthenticated client with a 30 second timeout.
+// NOTE(review): the authenticated client comes from the oauth2 packages and
+// presumably has no Timeout of its own (built from context.Background) — confirm.
+func (c *OAuthClient) GetHTTPClient() *http.Client {
+	if c.httpClient == nil {
+		// Not authenticated yet: hand out a plain client so callers still work.
+		return &http.Client{Timeout: 30 * time.Second}
+	}
+	return c.httpClient
+}
+
+// IsAuthenticated reports whether c holds a usable token; a token with no expiry (zero Expiry) counts as usable.
+func (c *OAuthClient) IsAuthenticated() bool {
+	if !c.token.Valid() { // nil-safe: false for nil, an empty AccessToken, or an expired token
+		return false
+	}
+	// oauth2 treats a zero Expiry as "never expires"; otherwise require the
+	// token to outlive a 5 minute buffer so callers don't race its expiry.
+	return c.token.Expiry.IsZero() || time.Until(c.token.Expiry) > 5*time.Minute
+}
+
+// ClientCredentialsFlow runs the OAuth2 client-credentials grant with the
+// configured ClientID, ClientSecret, TokenURL and Scopes; it errors when the
+// config is nil or not Enabled. The fetched token seeds the HTTP client so the
+// first API request reuses it instead of fetching a second one; expiry re-fetches.
+func (c *OAuthClient) ClientCredentialsFlow(ctx context.Context) error {
+	if c.config == nil || !c.config.Enabled {
+		return fmt.Errorf("oauth2 not configured")
+	}
+	ccConfig := &clientcredentials.Config{
+		ClientID:     c.config.ClientID,
+		ClientSecret: c.config.ClientSecret,
+		TokenURL:     c.config.TokenURL,
+		Scopes:       c.config.Scopes,
+	}
+	token, err := ccConfig.Token(ctx)
+	if err != nil {
+		return fmt.Errorf("failed to get client credentials token: %w", err)
+	}
+	c.token = token
+	c.httpClient = oauth2.NewClient(ctx, oauth2.ReuseTokenSource(token, ccConfig.TokenSource(ctx)))
+	return nil
+}
+
+// RefreshToken renews the stored token. If it has a refresh token, the refresh grant is always sent (so tokens
+// inside IsAuthenticated's 5 min buffer get renewed); otherwise a still-valid token is kept and an expired one errors.
+func (c *OAuthClient) RefreshToken(ctx context.Context) error {
+	if c.oauth2Config == nil {
+		return fmt.Errorf("oauth2 not configured")
+	}
+	if c.token == nil {
+		return fmt.Errorf("no token to refresh")
+	}
+	seed := c.token
+	if seed.RefreshToken != "" {
+		seed = &oauth2.Token{RefreshToken: seed.RefreshToken} // no access token => the source must refresh
+	}
+	newToken, err := c.oauth2Config.TokenSource(ctx, seed).Token()
+	if err != nil {
+		return fmt.Errorf("failed to refresh token: %w", err)
+	}
+	c.token, c.httpClient = newToken, c.oauth2Config.Client(ctx, newToken)
+	return nil
+}
+
+// AddAuthHeader sets the Authorization header on req from the stored token.
+// It returns an error rather than sending a header that cannot authenticate.
+func (c *OAuthClient) AddAuthHeader(req *http.Request) error {
+	if !c.token.Valid() { // nil-safe; also rejects an empty AccessToken or an expired token
+		return fmt.Errorf("no authentication token available")
+	}
+	c.token.SetAuthHeader(req) // writes "<Type> <AccessToken>", Type defaulting to "Bearer"
+	return nil
+}
diff --git a/auth/oauth_test.go b/auth/oauth_test.go
new file mode 100644
index 0000000..9c6b0cc
--- /dev/null
+++ b/auth/oauth_test.go
@@ -0,0 +1,403 @@
+package auth
+
+import (
+ "context"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ "github.com/korap/korap-mcp/config"
+ "github.com/stretchr/testify/assert"
+ "golang.org/x/oauth2"
+)
+
+// TestNewOAuthClient checks that a complete, enabled config is accepted and stored on the client.
+func TestNewOAuthClient(t *testing.T) {
+	cfg := &config.OAuthConfig{
+		Enabled:      true,
+		ClientID:     "test-client",
+		ClientSecret: "test-secret",
+		AuthURL:      "https://example.com/auth",
+		TokenURL:     "https://example.com/token",
+	}
+	c, err := NewOAuthClient(cfg)
+	assert.NoError(t, err)
+	assert.NotNil(t, c)
+	assert.Equal(t, cfg, c.config)
+}
+
+// TestNewOAuthClient_InvalidConfig ensures NewOAuthClient rejects configs
+// that are nil or lack client credentials, returning no client.
+func TestNewOAuthClient_InvalidConfig(t *testing.T) {
+	// base returns a fresh, otherwise-valid config that each case then breaks.
+	base := func() *config.OAuthConfig {
+		return &config.OAuthConfig{
+			ClientID:     "test-client",
+			ClientSecret: "test-secret",
+			AuthURL:      "https://example.com/auth",
+			TokenURL:     "https://example.com/token",
+			Enabled:      true,
+		}
+	}
+	noID := base()
+	noID.ClientID = ""
+	noSecret := base()
+	noSecret.ClientSecret = ""
+
+	// Every case must yield an error and a nil client.
+	cases := []struct {
+		name string
+		cfg  *config.OAuthConfig
+	}{
+		{name: "nil config", cfg: nil},
+		{name: "empty client ID", cfg: noID},
+		{name: "empty client secret", cfg: noSecret},
+	}
+
+	// Subtests run synchronously (no t.Parallel), so capturing tc is safe.
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			got, err := NewOAuthClient(tc.cfg)
+			assert.Error(t, err)
+			assert.Nil(t, got)
+		})
+	}
+}
+
+// TestOAuthClient_GetAuthURL checks that the authorization URL is produced
+// and carries both the client ID and the caller-supplied state.
+func TestOAuthClient_GetAuthURL(t *testing.T) {
+	client, err := NewOAuthClient(&config.OAuthConfig{
+		ClientID:     "test-client",
+		ClientSecret: "test-secret",
+		AuthURL:      "https://example.com/auth",
+		TokenURL:     "https://example.com/token",
+		RedirectURL:  "https://example.com/callback",
+		Enabled:      true,
+	})
+	assert.NoError(t, err)
+
+	u := client.GetAuthURL("test-state")
+	assert.NotEmpty(t, u)
+	// Both values are plain ASCII, so they appear unescaped in the query.
+	for _, want := range []string{"client_id=test-client", "state=test-state"} {
+		assert.Contains(t, u, want)
+	}
+}
+
+// TestOAuthClient_IsAuthenticated walks one client through three states:
+// no token, a token valid for another hour, and a token that expired an
+// hour ago.
+func TestOAuthClient_IsAuthenticated(t *testing.T) {
+	client, err := NewOAuthClient(&config.OAuthConfig{
+		ClientID:     "test-client",
+		ClientSecret: "test-secret",
+		AuthURL:      "https://example.com/auth",
+		TokenURL:     "https://example.com/token",
+		Enabled:      true,
+	})
+	assert.NoError(t, err)
+
+	assert.False(t, client.IsAuthenticated(), "client should not be authenticated initially")
+
+	steps := []struct {
+		expiresIn time.Duration
+		want      bool
+		msg       string
+	}{
+		{time.Hour, true, "client should be authenticated with valid token"},
+		{-time.Hour, false, "client should not be authenticated with expired token"},
+	}
+
+	for _, s := range steps {
+		// Build each token right before use so its relative expiry is accurate.
+		client.SetToken(&oauth2.Token{
+			AccessToken: "test-token",
+			Expiry:      time.Now().Add(s.expiresIn),
+		})
+		assert.Equal(t, s.want, client.IsAuthenticated(), s.msg)
+	}
+}
+
+// TestOAuthClient_ExchangeCode runs the authorization-code exchange against a
+// mock token endpoint that also verifies the grant type and code it receives.
+func TestOAuthClient_ExchangeCode(t *testing.T) {
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if r.URL.Path != "/token" {
+			http.NotFound(w, r)
+			return
+		}
+		if got := r.FormValue("grant_type"); got != "authorization_code" {
+			t.Errorf("grant_type = %q, want %q", got, "authorization_code") // Errorf: runs off the test goroutine
+		}
+		if got := r.FormValue("code"); got != "test-code" {
+			t.Errorf("code = %q, want %q", got, "test-code")
+		}
+		w.Header().Set("Content-Type", "application/json")
+		json.NewEncoder(w).Encode(map[string]interface{}{
+			"access_token": "test-access-token",
+			"token_type":   "Bearer",
+			"expires_in":   3600,
+		})
+	}))
+	defer server.Close()
+	client, err := NewOAuthClient(&config.OAuthConfig{
+		ClientID:     "test-client",
+		ClientSecret: "test-secret",
+		AuthURL:      server.URL + "/auth",
+		TokenURL:     server.URL + "/token",
+		Enabled:      true,
+	})
+	assert.NoError(t, err)
+	assert.NoError(t, client.ExchangeCode(context.Background(), "test-code"))
+	token := client.GetToken()
+	assert.NotNil(t, token)
+	assert.Equal(t, "test-access-token", token.AccessToken)
+}
+
+// TestOAuthClient_ClientCredentialsFlow fetches a token via the client-credentials
+// grant from a mock endpoint that verifies the grant type and requested scope.
+func TestOAuthClient_ClientCredentialsFlow(t *testing.T) {
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if r.URL.Path != "/token" {
+			http.NotFound(w, r)
+			return
+		}
+		if got := r.FormValue("grant_type"); got != "client_credentials" {
+			t.Errorf("grant_type = %q, want %q", got, "client_credentials") // Errorf: runs off the test goroutine
+		}
+		if got := r.FormValue("scope"); got != "read" {
+			t.Errorf("scope = %q, want %q", got, "read")
+		}
+		w.Header().Set("Content-Type", "application/json")
+		json.NewEncoder(w).Encode(map[string]interface{}{
+			"access_token": "client-credentials-token",
+			"token_type":   "Bearer",
+			"expires_in":   3600,
+		})
+	}))
+	defer server.Close()
+	client, err := NewOAuthClient(&config.OAuthConfig{
+		ClientID:     "test-client",
+		ClientSecret: "test-secret",
+		TokenURL:     server.URL + "/token",
+		Scopes:       []string{"read"},
+		Enabled:      true,
+	})
+	assert.NoError(t, err)
+	assert.NoError(t, client.ClientCredentialsFlow(context.Background()))
+	token := client.GetToken()
+	assert.NotNil(t, token)
+	assert.Equal(t, "client-credentials-token", token.AccessToken)
+}
+
+// TestOAuthClient_GetHTTPClient checks the fallback client handed out before
+// any token exists: non-nil and limited by a 30 second timeout.
+func TestOAuthClient_GetHTTPClient(t *testing.T) {
+	cfg := &config.OAuthConfig{
+		ClientID:     "test-client",
+		ClientSecret: "test-secret",
+		AuthURL:      "https://example.com/auth",
+		TokenURL:     "https://example.com/token",
+		Enabled:      true,
+	}
+	client, err := NewOAuthClient(cfg)
+	assert.NoError(t, err)
+
+	// No token has been set, so this must be the default client.
+	hc := client.GetHTTPClient()
+	assert.NotNil(t, hc)
+	assert.Equal(t, time.Second*30, hc.Timeout)
+}
+
+// TestOAuthClient_AddAuthHeader checks that AddAuthHeader refuses to run
+// without a token and writes "Bearer <access token>" once one is installed.
+func TestOAuthClient_AddAuthHeader(t *testing.T) {
+	client, err := NewOAuthClient(&config.OAuthConfig{
+		ClientID:     "test-client",
+		ClientSecret: "test-secret",
+		AuthURL:      "https://example.com/auth",
+		TokenURL:     "https://example.com/token",
+		Enabled:      true,
+	})
+	assert.NoError(t, err)
+
+	// The request is never sent; it only carries headers.
+	req, _ := http.NewRequest(http.MethodGet, "https://example.com/api", nil)
+
+	// Phase 1: no token yet, so the call must fail.
+	err = client.AddAuthHeader(req)
+	assert.Error(t, err, "expected error when no token is available")
+
+	// Phase 2: install a bearer token and retry on the same request.
+	tok := &oauth2.Token{
+		AccessToken: "test-token",
+		TokenType:   "Bearer",
+	}
+	client.SetToken(tok)
+
+	err = client.AddAuthHeader(req)
+	assert.NoError(t, err)
+	got := req.Header.Get("Authorization")
+	assert.Equal(t, "Bearer test-token", got)
+}
+
+// TestOAuthClient_RefreshToken renews an expired token through a mock token
+// endpoint that checks it receives a refresh-token grant for the stored token.
+func TestOAuthClient_RefreshToken(t *testing.T) {
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if r.URL.Path != "/token" {
+			http.NotFound(w, r)
+			return
+		}
+		if got := r.FormValue("grant_type"); got != "refresh_token" {
+			t.Errorf("grant_type = %q, want %q", got, "refresh_token") // Errorf: runs off the test goroutine
+		}
+		if got := r.FormValue("refresh_token"); got != "refresh-token" {
+			t.Errorf("refresh_token = %q, want %q", got, "refresh-token")
+		}
+		w.Header().Set("Content-Type", "application/json")
+		json.NewEncoder(w).Encode(map[string]interface{}{
+			"access_token":  "refreshed-access-token",
+			"refresh_token": "new-refresh-token",
+			"token_type":    "Bearer",
+			"expires_in":    3600,
+		})
+	}))
+	defer server.Close()
+	cfg := &config.OAuthConfig{
+		ClientID:     "test-client",
+		ClientSecret: "test-secret",
+		AuthURL:      server.URL + "/auth",
+		TokenURL:     server.URL + "/token",
+		Enabled:      true,
+	}
+	client, err := NewOAuthClient(cfg)
+	assert.NoError(t, err)
+
+	// Start from an expired access token that still carries a refresh token.
+	client.SetToken(&oauth2.Token{
+		AccessToken:  "initial-token",
+		RefreshToken: "refresh-token",
+		Expiry:       time.Now().Add(-time.Hour),
+	})
+	assert.NoError(t, client.RefreshToken(context.Background()))
+	token := client.GetToken()
+	assert.NotNil(t, token)
+	assert.Equal(t, "refreshed-access-token", token.AccessToken)
+}
+
+// TestOAuthClient_RefreshToken_Errors covers the two guard clauses of
+// RefreshToken: no stored token, and no oauth2 configuration.
+func TestOAuthClient_RefreshToken_Errors(t *testing.T) {
+	client, err := NewOAuthClient(&config.OAuthConfig{
+		ClientID:     "test-client",
+		ClientSecret: "test-secret",
+		AuthURL:      "https://example.com/auth",
+		TokenURL:     "https://example.com/token",
+		Enabled:      true,
+	})
+	assert.NoError(t, err)
+	ctx := context.Background()
+
+	// Guard 1: nothing has been stored yet.
+	assert.Error(t, client.RefreshToken(ctx), "expected error when refreshing without token")
+
+	// Guard 2: strip the oauth2 config (white-box) and give the client a token,
+	// so the missing configuration is what the call trips over.
+	client.oauth2Config = nil
+	client.SetToken(&oauth2.Token{AccessToken: "test"})
+	assert.Error(t, client.RefreshToken(ctx), "expected error when OAuth2 not configured")
+}
+
+// TestOAuthClient_TokenExpiration probes the 5 minute safety buffer used by
+// IsAuthenticated: a token expiring inside it is rejected, one beyond it is
+// accepted.
+func TestOAuthClient_TokenExpiration(t *testing.T) {
+	client, err := NewOAuthClient(&config.OAuthConfig{
+		ClientID:     "test-client",
+		ClientSecret: "test-secret",
+		TokenURL:     "https://example.com/token",
+		Enabled:      true,
+	})
+	assert.NoError(t, err)
+
+	// tokenExpiringIn builds a token whose expiry lies d from now; it is
+	// called right before SetToken so the relative expiry stays accurate.
+	tokenExpiringIn := func(d time.Duration) *oauth2.Token {
+		return &oauth2.Token{
+			AccessToken: "test-token",
+			Expiry:      time.Now().Add(d),
+		}
+	}
+
+	// Inside the buffer: not yet expired, but too close to expiry.
+	client.SetToken(tokenExpiringIn(2 * time.Minute))
+	assert.False(t, client.IsAuthenticated(), "client should not be authenticated with soon-to-expire token")
+
+	// Outside the buffer: comfortably valid.
+	client.SetToken(tokenExpiringIn(10 * time.Minute))
+	assert.True(t, client.IsAuthenticated(), "client should be authenticated with valid token")
+}
+
+// TestOAuthClient_ErrorHandling checks that transport failures surface as
+// errors. The endpoints point at a closed local server, so the test fails
+// fast without relying on DNS or outbound network access.
+func TestOAuthClient_ErrorHandling(t *testing.T) {
+	dead := httptest.NewServer(http.NotFoundHandler())
+	dead.Close() // connections to dead.URL now fail
+	client, err := NewOAuthClient(&config.OAuthConfig{
+		ClientID:     "test-client",
+		ClientSecret: "test-secret",
+		AuthURL:      dead.URL + "/auth",
+		TokenURL:     dead.URL + "/token",
+		Enabled:      true,
+	})
+	assert.NoError(t, err)
+
+	err = client.ClientCredentialsFlow(context.Background())
+	assert.Error(t, err, "expected error when connecting to invalid server")
+	err = client.ExchangeCode(context.Background(), "test-code")
+	assert.Error(t, err, "expected error when connecting to invalid server")
+}
+
+// TestOAuthClient_DisabledOAuth checks a client built from a disabled config:
+// it still hands out an HTTP client, yields no auth URL, and refuses the
+// client-credentials flow.
+func TestOAuthClient_DisabledOAuth(t *testing.T) {
+	client, err := NewOAuthClient(&config.OAuthConfig{
+		ClientID:     "test-client",
+		ClientSecret: "test-secret",
+		Enabled:      false,
+	})
+	assert.NoError(t, err)
+
+	assert.NotNil(t, client.GetHTTPClient(), "HTTP client should not be nil even when OAuth is disabled")
+
+	// NOTE(review): relies on ToOAuth2Config returning nil when disabled — confirm.
+	assert.Empty(t, client.GetAuthURL("test-state"), "auth URL should be empty when OAuth is disabled")
+
+	flowErr := client.ClientCredentialsFlow(context.Background())
+	assert.Error(t, flowErr, "expected error when trying client credentials flow with disabled OAuth")
+}
+
+// TestOAuthClient_ContextCancellation verifies that an already-cancelled
+// context makes the client-credentials flow fail; the cancellation must be
+// honoured before any token request can complete.
+func TestOAuthClient_ContextCancellation(t *testing.T) {
+	client, err := NewOAuthClient(&config.OAuthConfig{
+		ClientID:     "test-client",
+		ClientSecret: "test-secret",
+		TokenURL:     "https://example.com/token",
+		Enabled:      true,
+	})
+	assert.NoError(t, err)
+
+	// Cancel up front so the flow starts with a dead context.
+	ctx, stop := context.WithCancel(context.Background())
+	stop()
+
+	assert.Error(t, client.ClientCredentialsFlow(ctx), "expected error with cancelled context")
+}