blob: 9c6b0cc3d5548cc21d2fb084924bc0b721b99978 [file] [log] [blame]
package auth
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/korap/korap-mcp/config"
"github.com/stretchr/testify/assert"
"golang.org/x/oauth2"
)
// TestNewOAuthClient verifies that a complete, enabled configuration produces
// a client that keeps the configuration it was built from.
func TestNewOAuthClient(t *testing.T) {
	cfg := &config.OAuthConfig{
		ClientID:     "test-client",
		ClientSecret: "test-secret",
		AuthURL:      "https://example.com/auth",
		TokenURL:     "https://example.com/token",
		Enabled:      true,
	}

	client, err := NewOAuthClient(cfg)
	// Stop on failure: the next assertion dereferences client and would
	// otherwise nil-panic and abort the whole test binary.
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	if !assert.NotNil(t, client) {
		t.FailNow()
	}
	assert.Equal(t, cfg, client.config)
}
// TestNewOAuthClient_InvalidConfig checks that NewOAuthClient refuses a nil
// configuration as well as configurations lacking a client ID or a client
// secret, returning an error and no client in each case.
func TestNewOAuthClient_InvalidConfig(t *testing.T) {
	cases := []struct {
		name string
		cfg  *config.OAuthConfig
	}{
		{name: "nil config", cfg: nil},
		{
			name: "empty client ID",
			cfg: &config.OAuthConfig{
				ClientSecret: "test-secret",
				AuthURL:      "https://example.com/auth",
				TokenURL:     "https://example.com/token",
				Enabled:      true,
			},
		},
		{
			name: "empty client secret",
			cfg: &config.OAuthConfig{
				ClientID: "test-client",
				AuthURL:  "https://example.com/auth",
				TokenURL: "https://example.com/token",
				Enabled:  true,
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := NewOAuthClient(tc.cfg)
			assert.Error(t, err)
			assert.Nil(t, got)
		})
	}
}
// TestOAuthClient_GetAuthURL verifies that the authorization URL built by the
// client carries the client ID, the caller-supplied state and the
// authorization-code response type.
func TestOAuthClient_GetAuthURL(t *testing.T) {
	cfg := &config.OAuthConfig{
		ClientID:     "test-client",
		ClientSecret: "test-secret",
		AuthURL:      "https://example.com/auth",
		TokenURL:     "https://example.com/token",
		RedirectURL:  "https://example.com/callback",
		Enabled:      true,
	}

	client, err := NewOAuthClient(cfg)
	// client is dereferenced below; stop instead of nil-panicking.
	if !assert.NoError(t, err) {
		t.FailNow()
	}

	authURL := client.GetAuthURL("test-state")
	assert.NotEmpty(t, authURL)

	// The URL must identify the client, echo the state back for CSRF
	// protection, and request an authorization code.
	assert.Contains(t, authURL, "client_id=test-client")
	assert.Contains(t, authURL, "state=test-state")
	assert.Contains(t, authURL, "response_type=code")
}
// TestOAuthClient_IsAuthenticated verifies that a fresh client is not
// authenticated, becomes authenticated once a token expiring in an hour is
// set, and stops being authenticated when that token is replaced by one that
// expired an hour ago.
func TestOAuthClient_IsAuthenticated(t *testing.T) {
	cfg := &config.OAuthConfig{
		ClientID:     "test-client",
		ClientSecret: "test-secret",
		AuthURL:      "https://example.com/auth",
		TokenURL:     "https://example.com/token",
		Enabled:      true,
	}

	client, err := NewOAuthClient(cfg)
	// client is dereferenced below; stop instead of nil-panicking.
	if !assert.NoError(t, err) {
		t.FailNow()
	}

	// No token has been set yet.
	assert.False(t, client.IsAuthenticated(), "client should not be authenticated initially")

	client.SetToken(&oauth2.Token{
		AccessToken: "test-token",
		Expiry:      time.Now().Add(time.Hour),
	})
	assert.True(t, client.IsAuthenticated(), "client should be authenticated with valid token")

	client.SetToken(&oauth2.Token{
		AccessToken: "test-token",
		Expiry:      time.Now().Add(-time.Hour),
	})
	assert.False(t, client.IsAuthenticated(), "client should not be authenticated with expired token")
}
// TestOAuthClient_ExchangeCode runs an authorization-code exchange against a
// local mock token endpoint. It verifies that the request uses the
// authorization_code grant with the supplied code, and that the returned
// access token is stored on the client.
func TestOAuthClient_ExchangeCode(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/token" {
			http.NotFound(w, r)
			return
		}
		// t.Errorf (not t.Fatalf) is safe to call from the server goroutine.
		if got := r.FormValue("grant_type"); got != "authorization_code" {
			t.Errorf("grant_type = %q, want %q", got, "authorization_code")
		}
		if got := r.FormValue("code"); got != "test-code" {
			t.Errorf("code = %q, want %q", got, "test-code")
		}
		w.Header().Set("Content-Type", "application/json")
		response := map[string]interface{}{
			"access_token": "test-access-token",
			"token_type":   "Bearer",
			"expires_in":   3600,
		}
		if err := json.NewEncoder(w).Encode(response); err != nil {
			t.Errorf("encoding token response: %v", err)
		}
	}))
	defer server.Close()

	cfg := &config.OAuthConfig{
		ClientID:     "test-client",
		ClientSecret: "test-secret",
		AuthURL:      server.URL + "/auth",
		TokenURL:     server.URL + "/token",
		Enabled:      true,
	}

	client, err := NewOAuthClient(cfg)
	if !assert.NoError(t, err) {
		t.FailNow()
	}

	err = client.ExchangeCode(context.Background(), "test-code")
	assert.NoError(t, err)

	token := client.GetToken()
	// token.AccessToken is read next; stop instead of nil-panicking.
	if !assert.NotNil(t, token) {
		t.FailNow()
	}
	assert.Equal(t, "test-access-token", token.AccessToken)
}
// TestOAuthClient_ClientCredentialsFlow runs the client-credentials grant
// against a local mock token endpoint. It verifies that the request uses the
// client_credentials grant type and that the issued token is stored on the
// client.
func TestOAuthClient_ClientCredentialsFlow(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/token" {
			http.NotFound(w, r)
			return
		}
		// t.Errorf (not t.Fatalf) is safe to call from the server goroutine.
		if got := r.FormValue("grant_type"); got != "client_credentials" {
			t.Errorf("grant_type = %q, want %q", got, "client_credentials")
		}
		w.Header().Set("Content-Type", "application/json")
		response := map[string]interface{}{
			"access_token": "client-credentials-token",
			"token_type":   "Bearer",
			"expires_in":   3600,
		}
		if err := json.NewEncoder(w).Encode(response); err != nil {
			t.Errorf("encoding token response: %v", err)
		}
	}))
	defer server.Close()

	cfg := &config.OAuthConfig{
		ClientID:     "test-client",
		ClientSecret: "test-secret",
		TokenURL:     server.URL + "/token",
		Scopes:       []string{"read"},
		Enabled:      true,
	}

	client, err := NewOAuthClient(cfg)
	if !assert.NoError(t, err) {
		t.FailNow()
	}

	err = client.ClientCredentialsFlow(context.Background())
	assert.NoError(t, err)

	token := client.GetToken()
	// token.AccessToken is read next; stop instead of nil-panicking.
	if !assert.NotNil(t, token) {
		t.FailNow()
	}
	assert.Equal(t, "client-credentials-token", token.AccessToken)
}
// TestOAuthClient_GetHTTPClient verifies that an unauthenticated client still
// hands out an HTTP client, and that this client has a 30-second timeout.
func TestOAuthClient_GetHTTPClient(t *testing.T) {
	cfg := &config.OAuthConfig{
		ClientID:     "test-client",
		ClientSecret: "test-secret",
		AuthURL:      "https://example.com/auth",
		TokenURL:     "https://example.com/token",
		Enabled:      true,
	}

	client, err := NewOAuthClient(cfg)
	if !assert.NoError(t, err) {
		t.FailNow()
	}

	httpClient := client.GetHTTPClient()
	// httpClient.Timeout is read next; stop instead of nil-panicking.
	if !assert.NotNil(t, httpClient) {
		t.FailNow()
	}
	// No token is set, so this is expected to be the default client.
	assert.Equal(t, 30*time.Second, httpClient.Timeout)
}
// TestOAuthClient_AddAuthHeader verifies that AddAuthHeader fails while no
// token is set. It then checks that, once a Bearer token is set, the header
// is written as "Bearer <access token>".
func TestOAuthClient_AddAuthHeader(t *testing.T) {
	cfg := &config.OAuthConfig{
		ClientID:     "test-client",
		ClientSecret: "test-secret",
		AuthURL:      "https://example.com/auth",
		TokenURL:     "https://example.com/token",
		Enabled:      true,
	}

	client, err := NewOAuthClient(cfg)
	if !assert.NoError(t, err) {
		t.FailNow()
	}

	// Check the construction error instead of discarding it; a nil req
	// would make AddAuthHeader panic.
	req, err := http.NewRequest(http.MethodGet, "https://example.com/api", nil)
	if !assert.NoError(t, err) {
		t.FailNow()
	}

	// Without a token there is nothing to put in the header.
	err = client.AddAuthHeader(req)
	assert.Error(t, err, "expected error when no token is available")

	client.SetToken(&oauth2.Token{
		AccessToken: "test-token",
		TokenType:   "Bearer",
	})
	err = client.AddAuthHeader(req)
	assert.NoError(t, err)
	assert.Equal(t, "Bearer test-token", req.Header.Get("Authorization"))
}
// TestOAuthClient_RefreshToken starts from an expired token that carries a
// refresh token, then calls RefreshToken. It verifies that the client sends
// a refresh_token grant with that refresh token to the mock endpoint and
// stores the newly issued access token.
func TestOAuthClient_RefreshToken(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/token" {
			http.NotFound(w, r)
			return
		}
		// t.Errorf (not t.Fatalf) is safe to call from the server goroutine.
		if got := r.FormValue("grant_type"); got != "refresh_token" {
			t.Errorf("grant_type = %q, want %q", got, "refresh_token")
		}
		if got := r.FormValue("refresh_token"); got != "refresh-token" {
			t.Errorf("refresh_token = %q, want %q", got, "refresh-token")
		}
		w.Header().Set("Content-Type", "application/json")
		response := map[string]interface{}{
			"access_token":  "refreshed-access-token",
			"refresh_token": "new-refresh-token",
			"token_type":    "Bearer",
			"expires_in":    3600,
		}
		if err := json.NewEncoder(w).Encode(response); err != nil {
			t.Errorf("encoding token response: %v", err)
		}
	}))
	defer server.Close()

	cfg := &config.OAuthConfig{
		ClientID:     "test-client",
		ClientSecret: "test-secret",
		AuthURL:      server.URL + "/auth",
		TokenURL:     server.URL + "/token",
		Enabled:      true,
	}

	client, err := NewOAuthClient(cfg)
	if !assert.NoError(t, err) {
		t.FailNow()
	}

	// Start from an already-expired token so a refresh is actually needed.
	client.SetToken(&oauth2.Token{
		AccessToken:  "initial-token",
		RefreshToken: "refresh-token",
		Expiry:       time.Now().Add(-time.Hour),
	})

	err = client.RefreshToken(context.Background())
	assert.NoError(t, err)

	token := client.GetToken()
	// token.AccessToken is read next; stop instead of nil-panicking.
	if !assert.NotNil(t, token) {
		t.FailNow()
	}
	assert.Equal(t, "refreshed-access-token", token.AccessToken)
}
// TestOAuthClient_RefreshToken_Errors verifies that RefreshToken returns an
// error in two cases: when no token has been set, and when the client's
// oauth2Config is nil.
func TestOAuthClient_RefreshToken_Errors(t *testing.T) {
	cfg := &config.OAuthConfig{
		ClientID:     "test-client",
		ClientSecret: "test-secret",
		AuthURL:      "https://example.com/auth",
		TokenURL:     "https://example.com/token",
		Enabled:      true,
	}

	client, err := NewOAuthClient(cfg)
	// client's fields are written below; stop instead of nil-panicking.
	if !assert.NoError(t, err) {
		t.FailNow()
	}

	err = client.RefreshToken(context.Background())
	assert.Error(t, err, "expected error when refreshing without token")

	// Clear the internal OAuth2 configuration to simulate an unconfigured
	// client that nevertheless holds a token.
	client.oauth2Config = nil
	client.SetToken(&oauth2.Token{AccessToken: "test"})
	err = client.RefreshToken(context.Background())
	assert.Error(t, err, "expected error when OAuth2 not configured")
}
// TestOAuthClient_TokenExpiration probes the expiry margin of
// IsAuthenticated. A token expiring in 2 minutes must count as
// unauthenticated; one expiring in 10 minutes must count as authenticated.
// The test is written against an expected 5-minute buffer, per its original
// comments.
func TestOAuthClient_TokenExpiration(t *testing.T) {
	cfg := &config.OAuthConfig{
		ClientID:     "test-client",
		ClientSecret: "test-secret",
		TokenURL:     "https://example.com/token",
		Enabled:      true,
	}

	client, err := NewOAuthClient(cfg)
	// client is dereferenced below; stop instead of nil-panicking.
	if !assert.NoError(t, err) {
		t.FailNow()
	}

	// Inside the buffer: treated as already expired.
	client.SetToken(&oauth2.Token{
		AccessToken: "test-token",
		Expiry:      time.Now().Add(2 * time.Minute),
	})
	assert.False(t, client.IsAuthenticated(), "client should not be authenticated with soon-to-expire token")

	// Outside the buffer: still valid.
	client.SetToken(&oauth2.Token{
		AccessToken: "test-token",
		Expiry:      time.Now().Add(10 * time.Minute),
	})
	assert.True(t, client.IsAuthenticated(), "client should be authenticated with valid token")
}
// TestOAuthClient_ErrorHandling verifies that both the client-credentials
// flow and the authorization-code exchange return an error when the token
// endpoint cannot be reached.
func TestOAuthClient_ErrorHandling(t *testing.T) {
	// Use the address of a local server that has already been shut down, so
	// every request fails immediately with a refused connection. Unlike an
	// external hostname, this needs no DNS or outbound network access and
	// cannot hang or flake.
	server := httptest.NewServer(http.NotFoundHandler())
	deadURL := server.URL
	server.Close()

	cfg := &config.OAuthConfig{
		ClientID:     "test-client",
		ClientSecret: "test-secret",
		AuthURL:      deadURL + "/auth",
		TokenURL:     deadURL + "/token",
		Enabled:      true,
	}

	client, err := NewOAuthClient(cfg)
	// client is dereferenced below; stop instead of nil-panicking.
	if !assert.NoError(t, err) {
		t.FailNow()
	}

	err = client.ClientCredentialsFlow(context.Background())
	assert.Error(t, err, "expected error when connecting to invalid server")

	err = client.ExchangeCode(context.Background(), "test-code")
	assert.Error(t, err, "expected error when connecting to invalid server")
}
// TestOAuthClient_DisabledOAuth verifies the behavior of a client built with
// Enabled=false. It still hands out a non-nil HTTP client, it produces an
// empty authorization URL, and it rejects the client-credentials flow.
func TestOAuthClient_DisabledOAuth(t *testing.T) {
	cfg := &config.OAuthConfig{
		ClientID:     "test-client",
		ClientSecret: "test-secret",
		Enabled:      false,
	}

	client, err := NewOAuthClient(cfg)
	// client is dereferenced below; stop instead of nil-panicking.
	if !assert.NoError(t, err) {
		t.FailNow()
	}

	httpClient := client.GetHTTPClient()
	assert.NotNil(t, httpClient, "HTTP client should not be nil even when OAuth is disabled")

	authURL := client.GetAuthURL("test-state")
	assert.Empty(t, authURL, "auth URL should be empty when OAuth is disabled")

	err = client.ClientCredentialsFlow(context.Background())
	assert.Error(t, err, "expected error when trying client credentials flow with disabled OAuth")
}
// TestOAuthClient_ContextCancellation verifies that ClientCredentialsFlow
// respects an already-cancelled context.
func TestOAuthClient_ContextCancellation(t *testing.T) {
	// The token endpoint is a local server that always succeeds. The flow
	// would therefore succeed with a live context, so any error can only come
	// from the cancellation. An external URL could fail for unrelated
	// reasons, which would let the test pass even if cancellation were
	// ignored.
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		response := map[string]interface{}{
			"access_token": "unexpected-token",
			"token_type":   "Bearer",
			"expires_in":   3600,
		}
		if err := json.NewEncoder(w).Encode(response); err != nil {
			t.Errorf("encoding token response: %v", err)
		}
	}))
	defer server.Close()

	cfg := &config.OAuthConfig{
		ClientID:     "test-client",
		ClientSecret: "test-secret",
		TokenURL:     server.URL + "/token",
		Enabled:      true,
	}

	client, err := NewOAuthClient(cfg)
	// client is dereferenced below; stop instead of nil-panicking.
	if !assert.NoError(t, err) {
		t.FailNow()
	}

	ctx, cancel := context.WithCancel(context.Background())
	cancel()

	err = client.ClientCredentialsFlow(ctx)
	assert.Error(t, err, "expected error with cancelled context")
}