| package service |
| |
| import ( |
| "context" |
| "encoding/json" |
| "fmt" |
| "net/http" |
| "net/http/httptest" |
| "strings" |
| "testing" |
| "time" |
| |
| "github.com/korap/korap-mcp/auth" |
| "github.com/korap/korap-mcp/config" |
| "github.com/stretchr/testify/assert" |
| "golang.org/x/oauth2" |
| ) |
| |
| // TestNewClient tests client creation |
| func TestNewClient(t *testing.T) { |
| tests := []struct { |
| name string |
| opts ClientOptions |
| wantErr bool |
| errMsg string |
| }{ |
| { |
| name: "valid client without oauth", |
| opts: ClientOptions{ |
| BaseURL: "https://example.com", |
| Timeout: 10 * time.Second, |
| }, |
| wantErr: false, |
| }, |
| { |
| name: "valid client with oauth", |
| opts: ClientOptions{ |
| BaseURL: "https://example.com", |
| OAuthConfig: &config.OAuthConfig{ |
| ClientID: "test-client", |
| ClientSecret: "test-secret", |
| AuthURL: "https://example.com/auth", |
| TokenURL: "https://example.com/token", |
| Enabled: true, |
| }, |
| }, |
| wantErr: false, |
| }, |
| { |
| name: "invalid client - no base URL", |
| opts: ClientOptions{ |
| Timeout: 10 * time.Second, |
| }, |
| wantErr: true, |
| errMsg: "base URL is required", |
| }, |
| { |
| name: "invalid oauth config", |
| opts: ClientOptions{ |
| BaseURL: "https://example.com", |
| OAuthConfig: &config.OAuthConfig{ |
| Enabled: true, // Missing required fields |
| }, |
| }, |
| wantErr: true, |
| errMsg: "failed to create OAuth client", |
| }, |
| } |
| |
| for _, tt := range tests { |
| t.Run(tt.name, func(t *testing.T) { |
| client, err := NewClient(tt.opts) |
| |
| if tt.wantErr { |
| assert.Error(t, err) |
| assert.Contains(t, err.Error(), tt.errMsg) |
| return |
| } |
| |
| assert.NoError(t, err) |
| assert.NotNil(t, client) |
| |
| // Check base URL normalization |
| assert.True(t, strings.HasSuffix(client.baseURL, "/"), "Base URL should end with /") |
| |
| // Check default timeout |
| if tt.opts.Timeout == 0 { |
| assert.Equal(t, 30*time.Second, client.httpClient.Timeout, "Default timeout should be 30 seconds") |
| } |
| }) |
| } |
| } |
| |
| // TestClientMethods tests basic client HTTP methods |
| func TestClientMethods(t *testing.T) { |
| // Create a mock server |
| server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| switch r.URL.Path { |
| case "/test": |
| if r.Method == http.MethodGet { |
| w.Header().Set("Content-Type", "application/json") |
| json.NewEncoder(w).Encode(map[string]string{"message": "success"}) |
| } else if r.Method == http.MethodPost { |
| w.Header().Set("Content-Type", "application/json") |
| w.WriteHeader(http.StatusCreated) |
| json.NewEncoder(w).Encode(map[string]string{"created": "true"}) |
| } |
| case "/error": |
| w.WriteHeader(http.StatusBadRequest) |
| json.NewEncoder(w).Encode(ErrorResponse{ |
| Error: "bad_request", |
| ErrorDescription: "Invalid request parameter", |
| }) |
| case "/": |
| // Root endpoint for ping test |
| w.WriteHeader(http.StatusOK) |
| w.Write([]byte("OK")) |
| default: |
| w.WriteHeader(http.StatusNotFound) |
| } |
| })) |
| defer server.Close() |
| |
| client, err := NewClient(ClientOptions{ |
| BaseURL: server.URL, |
| Timeout: 5 * time.Second, |
| }) |
| assert.NoError(t, err) |
| |
| ctx := context.Background() |
| |
| t.Run("GET request", func(t *testing.T) { |
| resp, err := client.Get(ctx, "/test") |
| assert.NoError(t, err) |
| defer resp.Body.Close() |
| |
| assert.Equal(t, http.StatusOK, resp.StatusCode) |
| }) |
| |
| t.Run("POST request", func(t *testing.T) { |
| body := map[string]string{"key": "value"} |
| resp, err := client.Post(ctx, "/test", body) |
| assert.NoError(t, err) |
| defer resp.Body.Close() |
| |
| assert.Equal(t, http.StatusCreated, resp.StatusCode) |
| }) |
| |
| t.Run("GetJSON", func(t *testing.T) { |
| var result map[string]string |
| err := client.GetJSON(ctx, "/test", &result) |
| assert.NoError(t, err) |
| |
| assert.Equal(t, "success", result["message"]) |
| }) |
| |
| t.Run("PostJSON", func(t *testing.T) { |
| body := map[string]string{"key": "value"} |
| var result map[string]string |
| err := client.PostJSON(ctx, "/test", body, &result) |
| assert.NoError(t, err) |
| |
| assert.Equal(t, "true", result["created"]) |
| }) |
| |
| t.Run("Error handling", func(t *testing.T) { |
| var result map[string]any |
| err := client.GetJSON(ctx, "/error", &result) |
| |
| assert.Error(t, err) |
| |
| apiErr, ok := err.(*APIError) |
| assert.True(t, ok, "Expected APIError") |
| |
| assert.Equal(t, http.StatusBadRequest, apiErr.StatusCode) |
| assert.Equal(t, "bad_request", apiErr.Message) |
| }) |
| |
| t.Run("Ping", func(t *testing.T) { |
| err := client.Ping(ctx) |
| assert.NoError(t, err) |
| }) |
| } |
| |
| // TestClientAuthentication tests OAuth2 integration |
| func TestClientAuthentication(t *testing.T) { |
| // Create a mock OAuth2 server |
| authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| if r.URL.Path == "/token" { |
| // Mock token endpoint |
| w.Header().Set("Content-Type", "application/json") |
| json.NewEncoder(w).Encode(map[string]any{ |
| "access_token": "test-access-token", |
| "token_type": "Bearer", |
| "expires_in": 3600, |
| }) |
| } |
| })) |
| defer authServer.Close() |
| |
| // Create a mock API server that checks authentication |
| apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| auth := r.Header.Get("Authorization") |
| if auth != "Bearer test-access-token" { |
| w.WriteHeader(http.StatusUnauthorized) |
| json.NewEncoder(w).Encode(ErrorResponse{ |
| Error: "unauthorized", |
| ErrorDescription: "Invalid or missing access token", |
| }) |
| return |
| } |
| |
| w.Header().Set("Content-Type", "application/json") |
| json.NewEncoder(w).Encode(map[string]string{"authenticated": "true"}) |
| })) |
| defer apiServer.Close() |
| |
| // Create client with OAuth2 configuration |
| oauthConfig := &config.OAuthConfig{ |
| ClientID: "test-client", |
| ClientSecret: "test-secret", |
| TokenURL: authServer.URL + "/token", |
| Enabled: true, |
| } |
| |
| client, err := NewClient(ClientOptions{ |
| BaseURL: apiServer.URL, |
| OAuthConfig: oauthConfig, |
| }) |
| assert.NoError(t, err) |
| |
| ctx := context.Background() |
| |
| t.Run("Unauthenticated request", func(t *testing.T) { |
| var result map[string]any |
| err := client.GetJSON(ctx, "/", &result) |
| |
| assert.Error(t, err) |
| |
| apiErr, ok := err.(*APIError) |
| assert.True(t, ok, "Expected APIError") |
| |
| assert.Equal(t, http.StatusUnauthorized, apiErr.StatusCode) |
| }) |
| |
| t.Run("Client credentials authentication", func(t *testing.T) { |
| err := client.AuthenticateWithClientCredentials(ctx) |
| assert.NoError(t, err) |
| |
| assert.True(t, client.IsAuthenticated(), "Client should be authenticated") |
| }) |
| |
| t.Run("Authenticated request", func(t *testing.T) { |
| var result map[string]any |
| err := client.GetJSON(ctx, "/", &result) |
| assert.NoError(t, err) |
| |
| assert.Equal(t, "true", result["authenticated"]) |
| }) |
| } |
| |
| // TestURLBuilding tests URL construction logic |
| func TestURLBuilding(t *testing.T) { |
| client, err := NewClient(ClientOptions{ |
| BaseURL: "https://example.com/api/v1", |
| }) |
| assert.NoError(t, err) |
| |
| tests := []struct { |
| endpoint string |
| expected string |
| }{ |
| {"/search", "https://example.com/api/v1/search"}, |
| {"search", "https://example.com/api/v1/search"}, |
| {"/corpus/info", "https://example.com/api/v1/corpus/info"}, |
| {"", "https://example.com/api/v1/"}, |
| } |
| |
| for _, tt := range tests { |
| t.Run(fmt.Sprintf("endpoint_%s", tt.endpoint), func(t *testing.T) { |
| url, err := client.buildURL(tt.endpoint) |
| assert.NoError(t, err) |
| |
| assert.Equal(t, tt.expected, url) |
| }) |
| } |
| } |
| |
| // TestOAuthClientOperations tests OAuth client operations |
| func TestOAuthClientOperations(t *testing.T) { |
| client, err := NewClient(ClientOptions{ |
| BaseURL: "https://example.com", |
| }) |
| assert.NoError(t, err) |
| |
| // Test setting OAuth client |
| oauthConfig := &config.OAuthConfig{ |
| ClientID: "test-client", |
| ClientSecret: "test-secret", |
| TokenURL: "https://example.com/token", |
| Enabled: true, |
| } |
| |
| oauthClient, err := auth.NewOAuthClient(oauthConfig) |
| assert.NoError(t, err) |
| |
| // Set a mock token |
| token := &oauth2.Token{ |
| AccessToken: "test-token", |
| TokenType: "Bearer", |
| Expiry: time.Now().Add(time.Hour), |
| } |
| oauthClient.SetToken(token) |
| |
| client.SetOAuthClient(oauthClient) |
| |
| assert.True(t, client.IsAuthenticated(), "Client should be authenticated after setting valid token") |
| |
| assert.Equal(t, "https://example.com/", client.GetBaseURL()) |
| } |
| |
| // TestClientCaching tests client caching functionality |
| func TestClientCaching(t *testing.T) { |
| requestCount := 0 |
| |
| // Create a mock server that counts requests |
| server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| requestCount++ |
| // Use URL.Path to get clean path without query parameters |
| path := r.URL.Path |
| switch path { |
| case "/cached": |
| w.Header().Set("Content-Type", "application/json") |
| json.NewEncoder(w).Encode(map[string]interface{}{ |
| "data": "cached response", |
| "request_count": requestCount, |
| }) |
| case "/search": |
| w.Header().Set("Content-Type", "application/json") |
| query := r.URL.Query().Get("q") |
| json.NewEncoder(w).Encode(map[string]interface{}{ |
| "query": query, |
| "results": []string{"result1", "result2"}, |
| "request_count": requestCount, |
| }) |
| default: |
| w.WriteHeader(http.StatusNotFound) |
| json.NewEncoder(w).Encode(map[string]string{ |
| "error": fmt.Sprintf("Not found: %s", path), |
| }) |
| } |
| })) |
| defer server.Close() |
| |
| t.Run("cache enabled", func(t *testing.T) { |
| cacheConfig := DefaultCacheConfig() |
| cacheConfig.DefaultTTL = 1 * time.Minute |
| |
| client, err := NewClient(ClientOptions{ |
| BaseURL: server.URL, |
| CacheConfig: &cacheConfig, |
| }) |
| assert.NoError(t, err) |
| defer client.Close() |
| |
| ctx := context.Background() |
| requestCount = 0 |
| |
| // First request - should hit the server |
| var result1 map[string]interface{} |
| err = client.GetJSON(ctx, "/cached", &result1) |
| assert.NoError(t, err) |
| assert.Equal(t, float64(1), result1["request_count"]) |
| |
| // Second request - should hit the cache |
| var result2 map[string]interface{} |
| err = client.GetJSON(ctx, "/cached", &result2) |
| assert.NoError(t, err) |
| assert.Equal(t, float64(1), result2["request_count"]) // Same as first request |
| |
| // Verify cache statistics |
| cache := client.GetCache() |
| assert.NotNil(t, cache) |
| stats := cache.Stats() |
| assert.True(t, stats["enabled"].(bool)) |
| assert.Equal(t, 1, stats["size"].(int)) |
| assert.True(t, stats["hits"].(int64) > 0) |
| }) |
| |
| t.Run("cache disabled", func(t *testing.T) { |
| cacheConfig := CacheConfig{Enabled: false} |
| |
| client, err := NewClient(ClientOptions{ |
| BaseURL: server.URL, |
| CacheConfig: &cacheConfig, |
| }) |
| assert.NoError(t, err) |
| defer client.Close() |
| |
| ctx := context.Background() |
| requestCount = 0 |
| |
| // Both requests should hit the server |
| var result1 map[string]interface{} |
| err = client.GetJSON(ctx, "/cached", &result1) |
| assert.NoError(t, err) |
| assert.Equal(t, float64(1), result1["request_count"]) |
| |
| var result2 map[string]interface{} |
| err = client.GetJSON(ctx, "/cached", &result2) |
| assert.NoError(t, err) |
| assert.Equal(t, float64(2), result2["request_count"]) // Different from first request |
| |
| // Verify cache is disabled |
| cache := client.GetCache() |
| assert.NotNil(t, cache) |
| stats := cache.Stats() |
| assert.False(t, stats["enabled"].(bool)) |
| }) |
| |
| t.Run("cache expiry", func(t *testing.T) { |
| cacheConfig := CacheConfig{ |
| Enabled: true, |
| DefaultTTL: 50 * time.Millisecond, |
| SearchTTL: 50 * time.Millisecond, |
| MetadataTTL: 50 * time.Millisecond, |
| MaxSize: 100, |
| } |
| |
| client, err := NewClient(ClientOptions{ |
| BaseURL: server.URL, |
| CacheConfig: &cacheConfig, |
| }) |
| assert.NoError(t, err) |
| defer client.Close() |
| |
| ctx := context.Background() |
| requestCount = 0 |
| |
| // First request |
| var result1 map[string]interface{} |
| err = client.GetJSON(ctx, "/cached", &result1) |
| assert.NoError(t, err) |
| assert.Equal(t, float64(1), result1["request_count"]) |
| |
| // Wait for cache to expire |
| time.Sleep(100 * time.Millisecond) |
| |
| // Second request should hit server again |
| var result2 map[string]interface{} |
| err = client.GetJSON(ctx, "/cached", &result2) |
| assert.NoError(t, err) |
| assert.Equal(t, float64(2), result2["request_count"]) |
| }) |
| } |
| |
| // TestClientDefaultCache tests that clients get default cache when no config is provided |
| func TestClientDefaultCache(t *testing.T) { |
| client, err := NewClient(ClientOptions{ |
| BaseURL: "https://example.com", |
| }) |
| assert.NoError(t, err) |
| defer client.Close() |
| |
| cache := client.GetCache() |
| assert.NotNil(t, cache) |
| |
| stats := cache.Stats() |
| assert.True(t, stats["enabled"].(bool)) |
| assert.Equal(t, 1000, stats["max_size"].(int)) |
| } |