blob: 6ebb61e958c270ee4d97ad6c9c81979ddd7ba844 [file] [log] [blame]
package service
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/korap/korap-mcp/auth"
"github.com/korap/korap-mcp/config"
"github.com/stretchr/testify/assert"
"golang.org/x/oauth2"
)
// TestNewClient checks NewClient's validation and defaults. Invalid options
// (no base URL, or OAuth enabled without its required fields) must produce an
// error whose message contains errMsg; valid options must yield a client whose
// base URL ends with "/" and whose HTTP timeout defaults to 30 seconds when
// ClientOptions.Timeout is left at zero.
func TestNewClient(t *testing.T) {
	tests := []struct {
		name    string
		opts    ClientOptions
		wantErr bool
		errMsg  string // substring expected in err.Error() when wantErr is true
	}{
		{
			name: "valid client without oauth",
			opts: ClientOptions{
				BaseURL: "https://example.com",
				Timeout: 10 * time.Second,
			},
			wantErr: false,
		},
		{
			name: "valid client with oauth",
			opts: ClientOptions{
				BaseURL: "https://example.com",
				OAuthConfig: &config.OAuthConfig{
					ClientID:     "test-client",
					ClientSecret: "test-secret",
					AuthURL:      "https://example.com/auth",
					TokenURL:     "https://example.com/token",
					Enabled:      true,
				},
			},
			wantErr: false,
		},
		{
			name: "invalid client - no base URL",
			opts: ClientOptions{
				Timeout: 10 * time.Second,
			},
			wantErr: true,
			errMsg:  "base URL is required",
		},
		{
			name: "invalid oauth config",
			opts: ClientOptions{
				BaseURL: "https://example.com",
				OAuthConfig: &config.OAuthConfig{
					Enabled: true, // Missing required fields
				},
			},
			wantErr: true,
			errMsg:  "failed to create OAuth client",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			client, err := NewClient(tt.opts)
			if tt.wantErr {
				// Only inspect the message when an error was actually returned;
				// err.Error() on a nil error would panic and abort the test binary.
				if assert.Error(t, err) {
					assert.Contains(t, err.Error(), tt.errMsg)
				}
				return
			}
			// Stop if construction failed: the checks below dereference client.
			if !assert.NoError(t, err) || !assert.NotNil(t, client) {
				return
			}
			defer client.Close()
			// Check base URL normalization
			assert.True(t, strings.HasSuffix(client.baseURL, "/"), "Base URL should end with /")
			// Check default timeout
			if tt.opts.Timeout == 0 {
				assert.Equal(t, 30*time.Second, client.httpClient.Timeout, "Default timeout should be 30 seconds")
			}
		})
	}
}
// TestClientMethods exercises the client's HTTP helpers against a local mock
// server. It checks that:
//   - raw Get/Post return the server's status code;
//   - GetJSON/PostJSON decode the JSON body;
//   - a 400 response surfaces as an *APIError carrying the status code and the
//     server's "error" field as Message;
//   - Ping succeeds against the root path.
func TestClientMethods(t *testing.T) {
	// Mock server routes:
	//   /test  GET  -> 200 {"message":"success"}
	//          POST -> 201 {"created":"true"}
	//   /error      -> 400 ErrorResponse{Error: "bad_request", ...}
	//   /           -> 200 "OK" (root endpoint for the ping test)
	//   otherwise   -> 404
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/test":
			switch r.Method {
			case http.MethodGet:
				w.Header().Set("Content-Type", "application/json")
				json.NewEncoder(w).Encode(map[string]string{"message": "success"})
			case http.MethodPost:
				w.Header().Set("Content-Type", "application/json")
				w.WriteHeader(http.StatusCreated)
				json.NewEncoder(w).Encode(map[string]string{"created": "true"})
			}
		case "/error":
			w.WriteHeader(http.StatusBadRequest)
			json.NewEncoder(w).Encode(ErrorResponse{
				Error:            "bad_request",
				ErrorDescription: "Invalid request parameter",
			})
		case "/":
			// Root endpoint for ping test
			w.WriteHeader(http.StatusOK)
			w.Write([]byte("OK"))
		default:
			w.WriteHeader(http.StatusNotFound)
		}
	}))
	defer server.Close()

	client, err := NewClient(ClientOptions{
		BaseURL: server.URL,
		Timeout: 5 * time.Second,
	})
	if !assert.NoError(t, err) {
		return // client is unusable; every subtest would nil-deref
	}
	defer client.Close()
	ctx := context.Background()

	t.Run("GET request", func(t *testing.T) {
		resp, err := client.Get(ctx, "/test")
		if !assert.NoError(t, err) {
			return // resp is nil on error; closing its body would panic
		}
		defer resp.Body.Close()
		assert.Equal(t, http.StatusOK, resp.StatusCode)
	})
	t.Run("POST request", func(t *testing.T) {
		body := map[string]string{"key": "value"}
		resp, err := client.Post(ctx, "/test", body)
		if !assert.NoError(t, err) {
			return // resp is nil on error; closing its body would panic
		}
		defer resp.Body.Close()
		assert.Equal(t, http.StatusCreated, resp.StatusCode)
	})
	t.Run("GetJSON", func(t *testing.T) {
		var result map[string]string
		err := client.GetJSON(ctx, "/test", &result)
		assert.NoError(t, err)
		assert.Equal(t, "success", result["message"])
	})
	t.Run("PostJSON", func(t *testing.T) {
		body := map[string]string{"key": "value"}
		var result map[string]string
		err := client.PostJSON(ctx, "/test", body, &result)
		assert.NoError(t, err)
		assert.Equal(t, "true", result["created"])
	})
	t.Run("Error handling", func(t *testing.T) {
		var result map[string]any
		err := client.GetJSON(ctx, "/error", &result)
		if !assert.Error(t, err) {
			return
		}
		apiErr, ok := err.(*APIError)
		if !assert.True(t, ok, "Expected APIError") {
			return // apiErr is nil when the type assertion fails
		}
		assert.Equal(t, http.StatusBadRequest, apiErr.StatusCode)
		assert.Equal(t, "bad_request", apiErr.Message)
	})
	t.Run("Ping", func(t *testing.T) {
		err := client.Ping(ctx)
		assert.NoError(t, err)
	})
}
// TestClientAuthentication checks the OAuth2 client-credentials flow end to
// end:
//  1. a request without a token is rejected with 401, surfaced as an *APIError;
//  2. the client fetches a token from a mock token endpoint;
//  3. the same request then succeeds with the bearer token attached.
//
// The subtests share one client and depend on each other's state, so they
// must run in the order written (none of them call t.Parallel).
func TestClientAuthentication(t *testing.T) {
	// Mock OAuth2 server: /token issues a fixed bearer token.
	authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/token" {
			// Mock token endpoint
			w.Header().Set("Content-Type", "application/json")
			json.NewEncoder(w).Encode(map[string]any{
				"access_token": "test-access-token",
				"token_type":   "Bearer",
				"expires_in":   3600,
			})
		}
	}))
	defer authServer.Close()

	// Mock API server: only "Bearer test-access-token" is accepted.
	apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Named authHeader so it does not shadow the imported auth package.
		authHeader := r.Header.Get("Authorization")
		if authHeader != "Bearer test-access-token" {
			w.WriteHeader(http.StatusUnauthorized)
			json.NewEncoder(w).Encode(ErrorResponse{
				Error:            "unauthorized",
				ErrorDescription: "Invalid or missing access token",
			})
			return
		}
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(map[string]string{"authenticated": "true"})
	}))
	defer apiServer.Close()

	// Create client with OAuth2 configuration
	oauthConfig := &config.OAuthConfig{
		ClientID:     "test-client",
		ClientSecret: "test-secret",
		TokenURL:     authServer.URL + "/token",
		Enabled:      true,
	}
	client, err := NewClient(ClientOptions{
		BaseURL:     apiServer.URL,
		OAuthConfig: oauthConfig,
	})
	if !assert.NoError(t, err) {
		return // client is unusable; every subtest would nil-deref
	}
	defer client.Close()
	ctx := context.Background()

	t.Run("Unauthenticated request", func(t *testing.T) {
		var result map[string]any
		err := client.GetJSON(ctx, "/", &result)
		if !assert.Error(t, err) {
			return
		}
		apiErr, ok := err.(*APIError)
		if !assert.True(t, ok, "Expected APIError") {
			return // apiErr is nil when the type assertion fails
		}
		assert.Equal(t, http.StatusUnauthorized, apiErr.StatusCode)
	})
	t.Run("Client credentials authentication", func(t *testing.T) {
		err := client.AuthenticateWithClientCredentials(ctx)
		assert.NoError(t, err)
		assert.True(t, client.IsAuthenticated(), "Client should be authenticated")
	})
	t.Run("Authenticated request", func(t *testing.T) {
		var result map[string]any
		err := client.GetJSON(ctx, "/", &result)
		assert.NoError(t, err)
		assert.Equal(t, "true", result["authenticated"])
	})
}
// TestURLBuilding checks that buildURL appends endpoints to a base URL that
// carries a path prefix. The endpoint may be given with or without a leading
// slash. An empty endpoint resolves to the base URL with a trailing slash.
func TestURLBuilding(t *testing.T) {
	client, err := NewClient(ClientOptions{
		BaseURL: "https://example.com/api/v1",
	})
	if !assert.NoError(t, err) {
		return // client is nil; buildURL below would panic
	}
	defer client.Close()

	tests := []struct {
		endpoint string
		expected string
	}{
		{"/search", "https://example.com/api/v1/search"},
		{"search", "https://example.com/api/v1/search"},
		{"/corpus/info", "https://example.com/api/v1/corpus/info"},
		{"", "https://example.com/api/v1/"},
	}
	for _, tt := range tests {
		t.Run(fmt.Sprintf("endpoint_%s", tt.endpoint), func(t *testing.T) {
			url, err := client.buildURL(tt.endpoint)
			assert.NoError(t, err)
			assert.Equal(t, tt.expected, url)
		})
	}
}
// TestOAuthClientOperations checks two things:
//   - attaching an OAuth client that holds an unexpired token via
//     SetOAuthClient makes the client report itself as authenticated;
//   - GetBaseURL returns the base URL normalized with a trailing slash.
func TestOAuthClientOperations(t *testing.T) {
	client, err := NewClient(ClientOptions{
		BaseURL: "https://example.com",
	})
	if !assert.NoError(t, err) {
		return // client is nil; the calls below would panic
	}
	defer client.Close()

	// Test setting OAuth client
	oauthConfig := &config.OAuthConfig{
		ClientID:     "test-client",
		ClientSecret: "test-secret",
		TokenURL:     "https://example.com/token",
		Enabled:      true,
	}
	oauthClient, err := auth.NewOAuthClient(oauthConfig)
	if !assert.NoError(t, err) {
		return // oauthClient is unusable; SetToken would panic
	}

	// Set a mock token that stays valid for the duration of the test.
	token := &oauth2.Token{
		AccessToken: "test-token",
		TokenType:   "Bearer",
		Expiry:      time.Now().Add(time.Hour),
	}
	oauthClient.SetToken(token)
	client.SetOAuthClient(oauthClient)

	assert.True(t, client.IsAuthenticated(), "Client should be authenticated after setting valid token")
	assert.Equal(t, "https://example.com/", client.GetBaseURL())
}
// TestClientCaching checks how GetJSON uses the response cache:
//   - with caching enabled, a repeated GET is served from the cache;
//   - with caching disabled, every GET reaches the server;
//   - a cached entry is fetched again once its TTL has elapsed.
//
// The mock server embeds a running request counter in every response, so a
// response served from the cache is recognisable by a repeated count.
func TestClientCaching(t *testing.T) {
	// requestCount is incremented by the handler and reset by the subtests.
	// Every request below is issued and completed sequentially, which this
	// unsynchronized counter relies on; issuing requests concurrently would
	// require making it atomic.
	requestCount := 0

	// Mock server that counts requests
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		requestCount++
		// Use URL.Path to get clean path without query parameters
		path := r.URL.Path
		switch path {
		case "/cached":
			w.Header().Set("Content-Type", "application/json")
			json.NewEncoder(w).Encode(map[string]any{
				"data":          "cached response",
				"request_count": requestCount,
			})
		case "/search":
			w.Header().Set("Content-Type", "application/json")
			query := r.URL.Query().Get("q")
			json.NewEncoder(w).Encode(map[string]any{
				"query":         query,
				"results":       []string{"result1", "result2"},
				"request_count": requestCount,
			})
		default:
			w.WriteHeader(http.StatusNotFound)
			json.NewEncoder(w).Encode(map[string]string{
				"error": fmt.Sprintf("Not found: %s", path),
			})
		}
	}))
	defer server.Close()

	t.Run("cache enabled", func(t *testing.T) {
		cacheConfig := DefaultCacheConfig()
		cacheConfig.DefaultTTL = 1 * time.Minute
		client, err := NewClient(ClientOptions{
			BaseURL:     server.URL,
			CacheConfig: &cacheConfig,
		})
		if !assert.NoError(t, err) {
			return // client is nil; Close/GetJSON would panic
		}
		defer client.Close()
		ctx := context.Background()
		requestCount = 0

		// First request - should hit the server.
		// JSON numbers decode to float64 when the target is map[string]any.
		var result1 map[string]any
		err = client.GetJSON(ctx, "/cached", &result1)
		assert.NoError(t, err)
		assert.Equal(t, float64(1), result1["request_count"])

		// Second request - should hit the cache
		var result2 map[string]any
		err = client.GetJSON(ctx, "/cached", &result2)
		assert.NoError(t, err)
		assert.Equal(t, float64(1), result2["request_count"]) // Same as first request

		// Verify cache statistics. Compare through assert rather than bare
		// type assertions so a missing/mistyped stat fails instead of panicking.
		cache := client.GetCache()
		if !assert.NotNil(t, cache) {
			return
		}
		stats := cache.Stats()
		assert.Equal(t, true, stats["enabled"])
		assert.Equal(t, 1, stats["size"])
		hits, ok := stats["hits"].(int64)
		assert.True(t, ok && hits > 0, "expected a positive int64 hit count, got %v", stats["hits"])
	})

	t.Run("cache disabled", func(t *testing.T) {
		cacheConfig := CacheConfig{Enabled: false}
		client, err := NewClient(ClientOptions{
			BaseURL:     server.URL,
			CacheConfig: &cacheConfig,
		})
		if !assert.NoError(t, err) {
			return // client is nil; Close/GetJSON would panic
		}
		defer client.Close()
		ctx := context.Background()
		requestCount = 0

		// Both requests should hit the server
		var result1 map[string]any
		err = client.GetJSON(ctx, "/cached", &result1)
		assert.NoError(t, err)
		assert.Equal(t, float64(1), result1["request_count"])

		var result2 map[string]any
		err = client.GetJSON(ctx, "/cached", &result2)
		assert.NoError(t, err)
		assert.Equal(t, float64(2), result2["request_count"]) // Different from first request

		// Verify cache is disabled
		cache := client.GetCache()
		if !assert.NotNil(t, cache) {
			return
		}
		stats := cache.Stats()
		assert.Equal(t, false, stats["enabled"])
	})

	t.Run("cache expiry", func(t *testing.T) {
		cacheConfig := CacheConfig{
			Enabled:     true,
			DefaultTTL:  50 * time.Millisecond,
			SearchTTL:   50 * time.Millisecond,
			MetadataTTL: 50 * time.Millisecond,
			MaxSize:     100,
		}
		client, err := NewClient(ClientOptions{
			BaseURL:     server.URL,
			CacheConfig: &cacheConfig,
		})
		if !assert.NoError(t, err) {
			return // client is nil; Close/GetJSON would panic
		}
		defer client.Close()
		ctx := context.Background()
		requestCount = 0

		// First request
		var result1 map[string]any
		err = client.GetJSON(ctx, "/cached", &result1)
		assert.NoError(t, err)
		assert.Equal(t, float64(1), result1["request_count"])

		// Wait (2x the TTL) for the cached entry to expire
		time.Sleep(100 * time.Millisecond)

		// Second request should hit server again
		var result2 map[string]any
		err = client.GetJSON(ctx, "/cached", &result2)
		assert.NoError(t, err)
		assert.Equal(t, float64(2), result2["request_count"])
	})
}
// TestClientDefaultCache checks that a client created without a CacheConfig
// still gets a cache, and that this cache is enabled with a max_size of 1000.
func TestClientDefaultCache(t *testing.T) {
	client, err := NewClient(ClientOptions{
		BaseURL: "https://example.com",
	})
	if !assert.NoError(t, err) {
		return // client is nil; Close/GetCache would panic
	}
	defer client.Close()

	cache := client.GetCache()
	if !assert.NotNil(t, cache) {
		return
	}
	// Compare through assert rather than bare type assertions so a missing or
	// differently-typed stat fails the test instead of panicking.
	stats := cache.Stats()
	assert.Equal(t, true, stats["enabled"])
	assert.Equal(t, 1000, stats["max_size"])
}