Add cache mechanism to KorAP client
Change-Id: Ie3f3d48611f039904f22a19cf6299a3c43fe8bbe
diff --git a/service/cache.go b/service/cache.go
new file mode 100644
index 0000000..43ec439
--- /dev/null
+++ b/service/cache.go
@@ -0,0 +1,223 @@
+package service
+
+import (
+ "context"
+ "crypto/md5"
+ "encoding/hex"
+ "encoding/json"
+ "fmt"
+ "strings"
+ "time"
+
+ cfg "github.com/korap/korap-mcp/config"
+ "github.com/korap/korap-mcp/logger"
+ "github.com/maypok86/otter"
+ "github.com/rs/zerolog"
+)
+
// CacheEntry is a single cached API response together with the bookkeeping
// needed to decide whether it is still fresh.
type CacheEntry struct {
	Data      []byte        `json:"data"`
	Timestamp time.Time     `json:"timestamp"`
	TTL       time.Duration `json:"ttl"`
}

// IsExpired reports whether the entry's lifetime has elapsed, i.e. the
// current time is past Timestamp+TTL.
func (ce *CacheEntry) IsExpired() bool {
	deadline := ce.Timestamp.Add(ce.TTL)
	return time.Now().After(deadline)
}
+
// Cache represents the response cache system. It wraps an otter in-memory
// cache; when caching is disabled the inner cache pointer is nil and all
// operations degrade to no-ops.
type Cache struct {
	// cache points at the underlying otter store; nil when caching is
	// disabled via CacheConfig.Enabled = false.
	cache *otter.Cache[string, *CacheEntry]
	// logger emits debug-level hit/miss/store/delete diagnostics.
	logger zerolog.Logger
	// config holds the TTL and sizing settings this cache was built with.
	config CacheConfig
}
+
// CacheConfig configures the cache behavior.
type CacheConfig struct {
	// Enabled controls whether caching is active
	Enabled bool
	// DefaultTTL is the default time-to-live for cache entries
	DefaultTTL time.Duration
	// SearchTTL is the TTL for search results
	SearchTTL time.Duration
	// MetadataTTL is the TTL for metadata and corpus information
	MetadataTTL time.Duration
	// MaxSize is the maximum number of cache entries
	MaxSize int
}

// DefaultCacheConfig returns the standard cache configuration: caching
// enabled, a short TTL for volatile search results, a longer TTL for
// comparatively stable metadata, and room for 1000 entries.
func DefaultCacheConfig() CacheConfig {
	var c CacheConfig
	c.Enabled = true
	c.DefaultTTL = 5 * time.Minute
	// Search results change frequently, so they get the shortest TTL;
	// metadata and corpus information are more stable and may live longer.
	c.SearchTTL = 2 * time.Minute
	c.MetadataTTL = 15 * time.Minute
	c.MaxSize = 1000
	return c
}
+
+// NewCache creates a new cache instance
+func NewCache(config CacheConfig) (*Cache, error) {
+ // Create default logging config for cache
+ logConfig := &cfg.LoggingConfig{
+ Level: "info",
+ Format: "text",
+ }
+
+ if !config.Enabled {
+ return &Cache{
+ cache: nil,
+ logger: logger.GetLogger(logConfig),
+ config: config,
+ }, nil
+ }
+
+ // Create otter cache with specified capacity
+ cache, err := otter.MustBuilder[string, *CacheEntry](config.MaxSize).
+ CollectStats().
+ WithTTL(config.DefaultTTL).
+ Build()
+ if err != nil {
+ return nil, fmt.Errorf("failed to create cache: %w", err)
+ }
+
+ return &Cache{
+ cache: &cache,
+ logger: logger.GetLogger(logConfig),
+ config: config,
+ }, nil
+}
+
+// generateCacheKey creates a unique cache key for a request
+func (c *Cache) generateCacheKey(method, endpoint string, params map[string]any) string {
+ // Create a deterministic key by combining method, endpoint, and parameters
+ var keyParts []string
+ keyParts = append(keyParts, method, endpoint)
+
+ // Add sorted parameters to ensure deterministic cache keys
+ // Note: json.Marshal automatically sorts map keys lexicographically,
+ // providing deterministic JSON output regardless of map iteration order
+ if params != nil {
+ paramsJSON, _ := json.Marshal(params)
+ keyParts = append(keyParts, string(paramsJSON))
+ }
+
+ key := strings.Join(keyParts, "|")
+
+ // Hash the key to keep it reasonable length and provide privacy
+ hash := md5.Sum([]byte(key))
+ return hex.EncodeToString(hash[:])
+}
+
+// Get retrieves a cached response
+func (c *Cache) Get(ctx context.Context, key string) ([]byte, bool) {
+ if !c.config.Enabled || c.cache == nil {
+ return nil, false
+ }
+
+ entry, found := (*c.cache).Get(key)
+ if !found {
+ c.logger.Debug().Str("key", key).Msg("Cache miss")
+ return nil, false
+ }
+
+ // Check if entry has expired
+ if entry.IsExpired() {
+ c.logger.Debug().Str("key", key).Msg("Cache entry expired")
+ (*c.cache).Delete(key)
+ return nil, false
+ }
+
+ c.logger.Debug().Str("key", key).Msg("Cache hit")
+ return entry.Data, true
+}
+
+// Set stores a response in the cache
+func (c *Cache) Set(ctx context.Context, key string, data []byte, ttl time.Duration) {
+ if !c.config.Enabled || c.cache == nil {
+ return
+ }
+
+ entry := &CacheEntry{
+ Data: data,
+ Timestamp: time.Now(),
+ TTL: ttl,
+ }
+
+ (*c.cache).Set(key, entry)
+ c.logger.Debug().Str("key", key).Dur("ttl", ttl).Msg("Cache entry stored")
+}
+
+// Delete removes an entry from the cache
+func (c *Cache) Delete(ctx context.Context, key string) {
+ if !c.config.Enabled || c.cache == nil {
+ return
+ }
+
+ (*c.cache).Delete(key)
+ c.logger.Debug().Str("key", key).Msg("Cache entry deleted")
+}
+
+// Clear removes all entries from the cache
+func (c *Cache) Clear() {
+ if !c.config.Enabled || c.cache == nil {
+ return
+ }
+
+ (*c.cache).Clear()
+ c.logger.Debug().Msg("Cache cleared")
+}
+
+// Stats returns cache statistics
+func (c *Cache) Stats() map[string]interface{} {
+ if !c.config.Enabled || c.cache == nil {
+ return map[string]interface{}{
+ "enabled": false,
+ }
+ }
+
+ stats := (*c.cache).Stats()
+ return map[string]interface{}{
+ "enabled": true,
+ "size": (*c.cache).Size(),
+ "hits": stats.Hits(),
+ "misses": stats.Misses(),
+ "hit_ratio": stats.Ratio(),
+ "evictions": stats.EvictedCount(),
+ "max_size": c.config.MaxSize,
+ "default_ttl": c.config.DefaultTTL.String(),
+ "search_ttl": c.config.SearchTTL.String(),
+ "metadata_ttl": c.config.MetadataTTL.String(),
+ }
+}
+
+// GetTTLForEndpoint returns the appropriate TTL for a given endpoint
+func (c *Cache) GetTTLForEndpoint(endpoint string) time.Duration {
+ endpoint = strings.ToLower(endpoint)
+
+ // Search endpoints get shorter TTL
+ if strings.Contains(endpoint, "search") || strings.Contains(endpoint, "query") {
+ return c.config.SearchTTL
+ }
+
+ // Metadata and corpus endpoints get longer TTL
+ if strings.Contains(endpoint, "corpus") || strings.Contains(endpoint, "metadata") ||
+ strings.Contains(endpoint, "statistics") || strings.Contains(endpoint, "info") {
+ return c.config.MetadataTTL
+ }
+
+ // Default TTL for other endpoints
+ return c.config.DefaultTTL
+}
+
+// Close closes the cache and cleans up resources
+func (c *Cache) Close() error {
+ if c.cache != nil {
+ (*c.cache).Clear()
+ (*c.cache).Close()
+ }
+ return nil
+}
diff --git a/service/cache_test.go b/service/cache_test.go
new file mode 100644
index 0000000..037f0b8
--- /dev/null
+++ b/service/cache_test.go
@@ -0,0 +1,443 @@
+package service
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
// TestDefaultCacheConfig pins the documented default values so accidental
// changes to DefaultCacheConfig are caught by the test suite.
func TestDefaultCacheConfig(t *testing.T) {
	config := DefaultCacheConfig()

	assert.True(t, config.Enabled)
	assert.Equal(t, 5*time.Minute, config.DefaultTTL)
	assert.Equal(t, 2*time.Minute, config.SearchTTL)
	assert.Equal(t, 15*time.Minute, config.MetadataTTL)
	assert.Equal(t, 1000, config.MaxSize)
}
+
// TestNewCache covers construction for enabled, disabled, and custom
// configurations; a disabled cache must have a nil inner otter cache.
func TestNewCache(t *testing.T) {
	tests := []struct {
		name           string
		config         CacheConfig
		expectError    bool
		expectNilCache bool
	}{
		{
			name:           "enabled cache",
			config:         DefaultCacheConfig(),
			expectError:    false,
			expectNilCache: false,
		},
		{
			name: "disabled cache",
			config: CacheConfig{
				Enabled: false,
			},
			expectError:    false,
			expectNilCache: true,
		},
		{
			name: "custom configuration",
			config: CacheConfig{
				Enabled:     true,
				DefaultTTL:  1 * time.Minute,
				SearchTTL:   30 * time.Second,
				MetadataTTL: 5 * time.Minute,
				MaxSize:     500,
			},
			expectError:    false,
			expectNilCache: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cache, err := NewCache(tt.config)

			if tt.expectError {
				assert.Error(t, err)
				return
			}

			require.NoError(t, err)
			assert.NotNil(t, cache)

			// The inner otter cache is only allocated when caching is on.
			if tt.expectNilCache {
				assert.Nil(t, cache.cache)
			} else {
				assert.NotNil(t, cache.cache)
			}

			assert.Equal(t, tt.config, cache.config)
		})
	}
}
+
// TestCacheGetSet verifies the basic miss-then-hit round trip: a key is
// absent before Set and returns the stored bytes afterwards.
func TestCacheGetSet(t *testing.T) {
	cache, err := NewCache(DefaultCacheConfig())
	require.NoError(t, err)

	ctx := context.Background()
	key := "test-key"
	data := []byte("test data")
	ttl := 1 * time.Minute

	// Test cache miss
	result, found := cache.Get(ctx, key)
	assert.False(t, found)
	assert.Nil(t, result)

	// Test cache set and hit
	cache.Set(ctx, key, data, ttl)
	result, found = cache.Get(ctx, key)
	assert.True(t, found)
	assert.Equal(t, data, result)
}
+
// TestCacheExpiry verifies the per-entry TTL check in Get: an entry is
// served while fresh and evicted once its TTL has elapsed.
// NOTE(review): this is a real-time test (50ms TTL, 100ms sleep) and could
// flake on a heavily loaded CI machine if the first Get is delayed past
// the TTL — consider widening the margins if that happens.
func TestCacheExpiry(t *testing.T) {
	cache, err := NewCache(DefaultCacheConfig())
	require.NoError(t, err)

	ctx := context.Background()
	key := "test-key"
	data := []byte("test data")
	ttl := 50 * time.Millisecond

	// Set cache entry with short TTL
	cache.Set(ctx, key, data, ttl)

	// Should hit immediately
	result, found := cache.Get(ctx, key)
	assert.True(t, found)
	assert.Equal(t, data, result)

	// Wait for expiry
	time.Sleep(100 * time.Millisecond)

	// Should miss after expiry
	result, found = cache.Get(ctx, key)
	assert.False(t, found)
	assert.Nil(t, result)
}
+
// TestCacheDelete verifies that Delete removes a previously stored entry
// so a subsequent Get misses.
func TestCacheDelete(t *testing.T) {
	cache, err := NewCache(DefaultCacheConfig())
	require.NoError(t, err)

	ctx := context.Background()
	key := "test-key"
	data := []byte("test data")
	ttl := 1 * time.Minute

	// Set and verify
	cache.Set(ctx, key, data, ttl)
	result, found := cache.Get(ctx, key)
	assert.True(t, found)
	assert.Equal(t, data, result)

	// Delete and verify
	cache.Delete(ctx, key)
	result, found = cache.Get(ctx, key)
	assert.False(t, found)
	assert.Nil(t, result)
}
+
// TestCacheClear verifies that Clear evicts every stored entry at once.
func TestCacheClear(t *testing.T) {
	cache, err := NewCache(DefaultCacheConfig())
	require.NoError(t, err)

	ctx := context.Background()
	ttl := 1 * time.Minute

	// Set multiple entries
	cache.Set(ctx, "key1", []byte("data1"), ttl)
	cache.Set(ctx, "key2", []byte("data2"), ttl)
	cache.Set(ctx, "key3", []byte("data3"), ttl)

	// Verify all entries exist
	_, found1 := cache.Get(ctx, "key1")
	_, found2 := cache.Get(ctx, "key2")
	_, found3 := cache.Get(ctx, "key3")
	assert.True(t, found1)
	assert.True(t, found2)
	assert.True(t, found3)

	// Clear cache
	cache.Clear()

	// Verify all entries are gone
	_, found1 = cache.Get(ctx, "key1")
	_, found2 = cache.Get(ctx, "key2")
	_, found3 = cache.Get(ctx, "key3")
	assert.False(t, found1)
	assert.False(t, found2)
	assert.False(t, found3)
}
+
// TestCacheStats verifies the metrics map shape for an enabled cache:
// size tracks stored entries and the hit/miss counters are present.
// NOTE(review): asserting size immediately after Set assumes otter writes
// are synchronously visible to Size() — confirm against otter's semantics
// if this ever flakes.
func TestCacheStats(t *testing.T) {
	cache, err := NewCache(DefaultCacheConfig())
	require.NoError(t, err)

	ctx := context.Background()

	// Test stats with empty cache
	stats := cache.Stats()
	assert.True(t, stats["enabled"].(bool))
	assert.Equal(t, 0, stats["size"].(int))

	// Add some data and check stats
	cache.Set(ctx, "key1", []byte("data1"), 1*time.Minute)
	cache.Set(ctx, "key2", []byte("data2"), 1*time.Minute)

	stats = cache.Stats()
	assert.True(t, stats["enabled"].(bool))
	assert.Equal(t, 2, stats["size"].(int))
	assert.Equal(t, 1000, stats["max_size"].(int))
	assert.Contains(t, stats, "hits")
	assert.Contains(t, stats, "misses")
	assert.Contains(t, stats, "hit_ratio")
}
+
// TestCacheStatsDisabled verifies that a disabled cache reports exactly
// one stats key: {"enabled": false}.
func TestCacheStatsDisabled(t *testing.T) {
	config := CacheConfig{Enabled: false}
	cache, err := NewCache(config)
	require.NoError(t, err)

	stats := cache.Stats()
	assert.False(t, stats["enabled"].(bool))
	assert.Len(t, stats, 1) // Only "enabled" key should be present
}
+
// TestCacheDisabled verifies that every operation on a disabled cache is a
// safe no-op: Set stores nothing, Get always misses, and Delete/Clear/Close
// neither error nor panic.
func TestCacheDisabled(t *testing.T) {
	config := CacheConfig{Enabled: false}
	cache, err := NewCache(config)
	require.NoError(t, err)

	ctx := context.Background()
	key := "test-key"
	data := []byte("test data")
	ttl := 1 * time.Minute

	// All operations should be no-ops
	cache.Set(ctx, key, data, ttl)
	result, found := cache.Get(ctx, key)
	assert.False(t, found)
	assert.Nil(t, result)

	cache.Delete(ctx, key)
	cache.Clear()

	// Should not panic
	assert.NotPanics(t, func() {
		cache.Close()
	})
}
+
// TestGenerateCacheKey verifies that the cache key is a 32-hex-char MD5
// digest, identical for identical (method, endpoint, params) triples and
// distinct when any component differs (including nil vs. non-nil params).
func TestGenerateCacheKey(t *testing.T) {
	cache, err := NewCache(DefaultCacheConfig())
	require.NoError(t, err)

	tests := []struct {
		name     string
		method   string
		endpoint string
		params   map[string]any
		wantSame bool
	}{
		{
			name:     "same parameters",
			method:   "GET",
			endpoint: "/search",
			params:   map[string]any{"q": "test", "count": 10},
			wantSame: true,
		},
		{
			name:     "different method",
			method:   "POST",
			endpoint: "/search",
			params:   map[string]any{"q": "test", "count": 10},
			wantSame: false,
		},
		{
			name:     "different endpoint",
			method:   "GET",
			endpoint: "/corpus",
			params:   map[string]any{"q": "test", "count": 10},
			wantSame: false,
		},
		{
			name:     "different parameters",
			method:   "GET",
			endpoint: "/search",
			params:   map[string]any{"q": "different", "count": 10},
			wantSame: false,
		},
		{
			name:     "nil parameters",
			method:   "GET",
			endpoint: "/search",
			params:   nil,
			wantSame: false,
		},
	}

	// Reference key that the "same parameters" case must reproduce.
	baseKey := cache.generateCacheKey("GET", "/search", map[string]any{"q": "test", "count": 10})

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			key := cache.generateCacheKey(tt.method, tt.endpoint, tt.params)

			assert.NotEmpty(t, key)
			assert.Len(t, key, 32) // MD5 hash length

			if tt.wantSame {
				assert.Equal(t, baseKey, key)
			} else {
				assert.NotEqual(t, baseKey, key)
			}
		})
	}
}
+
// TestCacheKeyDeterministic verifies that cache keys do not depend on map
// insertion order (json.Marshal sorts map keys) and are stable across
// repeated generations.
func TestCacheKeyDeterministic(t *testing.T) {
	cache, err := NewCache(DefaultCacheConfig())
	require.NoError(t, err)

	// Test that map parameter order doesn't affect cache key generation
	// This verifies that json.Marshal provides deterministic ordering
	params1 := map[string]any{
		"query":    "test search",
		"count":    50,
		"offset":   0,
		"corpus":   "news-corpus",
		"language": "poliqarp",
	}

	params2 := map[string]any{
		"language": "poliqarp",
		"corpus":   "news-corpus",
		"offset":   0,
		"count":    50,
		"query":    "test search",
	}

	// Same parameters in different map creation order should produce same cache key
	key1 := cache.generateCacheKey("GET", "/search", params1)
	key2 := cache.generateCacheKey("GET", "/search", params2)

	assert.Equal(t, key1, key2, "Cache keys should be identical regardless of map parameter order")
	assert.Len(t, key1, 32, "Cache key should be MD5 hash length")
	assert.Len(t, key2, 32, "Cache key should be MD5 hash length")

	// Generate keys multiple times to ensure consistency
	for i := 0; i < 10; i++ {
		keyN := cache.generateCacheKey("GET", "/search", params1)
		assert.Equal(t, key1, keyN, "Cache key should be consistent across multiple generations")
	}
}
+
// TestGetTTLForEndpoint verifies the keyword-based TTL classification:
// search/query paths map to SearchTTL, corpus/metadata/statistics/info
// paths to MetadataTTL, anything else to DefaultTTL, case-insensitively.
func TestGetTTLForEndpoint(t *testing.T) {
	config := DefaultCacheConfig()
	cache, err := NewCache(config)
	require.NoError(t, err)

	tests := []struct {
		name     string
		endpoint string
		expected time.Duration
	}{
		{
			name:     "search endpoint",
			endpoint: "/api/v1/search",
			expected: config.SearchTTL,
		},
		{
			name:     "query endpoint",
			endpoint: "/query",
			expected: config.SearchTTL,
		},
		{
			name:     "corpus endpoint",
			endpoint: "/corpus",
			expected: config.MetadataTTL,
		},
		{
			name:     "metadata endpoint",
			endpoint: "/metadata",
			expected: config.MetadataTTL,
		},
		{
			name:     "statistics endpoint",
			endpoint: "/statistics",
			expected: config.MetadataTTL,
		},
		{
			name:     "info endpoint",
			endpoint: "/info",
			expected: config.MetadataTTL,
		},
		{
			name:     "other endpoint",
			endpoint: "/other",
			expected: config.DefaultTTL,
		},
		{
			name:     "case insensitive",
			endpoint: "/API/V1/SEARCH",
			expected: config.SearchTTL,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ttl := cache.GetTTLForEndpoint(tt.endpoint)
			assert.Equal(t, tt.expected, ttl)
		})
	}
}
+
// TestCacheEntry verifies IsExpired on both sides of the TTL boundary:
// a just-created entry is fresh, an entry older than its TTL is expired.
func TestCacheEntry(t *testing.T) {
	t.Run("not expired", func(t *testing.T) {
		entry := &CacheEntry{
			Data:      []byte("test"),
			Timestamp: time.Now(),
			TTL:       1 * time.Minute,
		}

		assert.False(t, entry.IsExpired())
	})

	t.Run("expired", func(t *testing.T) {
		// Backdate the timestamp past the TTL instead of sleeping.
		entry := &CacheEntry{
			Data:      []byte("test"),
			Timestamp: time.Now().Add(-2 * time.Minute),
			TTL:       1 * time.Minute,
		}

		assert.True(t, entry.IsExpired())
	})
}
+
// TestCacheClose verifies that Close succeeds on both a populated enabled
// cache and a disabled (nil inner cache) instance.
func TestCacheClose(t *testing.T) {
	cache, err := NewCache(DefaultCacheConfig())
	require.NoError(t, err)

	// Add some data
	ctx := context.Background()
	cache.Set(ctx, "key1", []byte("data1"), 1*time.Minute)

	// Close should not error
	err = cache.Close()
	assert.NoError(t, err)

	// Disabled cache close should also not error
	disabledCache, err := NewCache(CacheConfig{Enabled: false})
	require.NoError(t, err)

	err = disabledCache.Close()
	assert.NoError(t, err)
}
diff --git a/service/client.go b/service/client.go
index 0ea0b94..a7c532c 100644
--- a/service/client.go
+++ b/service/client.go
@@ -19,6 +19,7 @@
baseURL string
httpClient *http.Client
oauthClient *auth.OAuthClient
+ cache *Cache
}
// ClientOptions configures the KorAP client
@@ -26,6 +27,7 @@
BaseURL string
Timeout time.Duration
OAuthConfig *config.OAuthConfig
+ CacheConfig *CacheConfig
}
// NewClient creates a new KorAP API client
@@ -59,6 +61,23 @@
client.oauthClient = oauthClient
}
+ // Initialize cache if configuration is provided
+ if opts.CacheConfig != nil {
+ cache, err := NewCache(*opts.CacheConfig)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create cache: %w", err)
+ }
+ client.cache = cache
+ } else {
+ // Use default cache configuration
+ defaultConfig := DefaultCacheConfig()
+ cache, err := NewCache(defaultConfig)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create default cache: %w", err)
+ }
+ client.cache = cache
+ }
+
return client, nil
}
@@ -152,8 +171,31 @@
return c.doRequest(ctx, http.MethodPost, endpoint, bodyReader)
}
-// GetJSON performs a GET request and unmarshals the JSON response
+// GetJSON performs a GET request and unmarshals the JSON response with caching
func (c *Client) GetJSON(ctx context.Context, endpoint string, target any) error {
+ // Generate cache key for GET requests
+ cacheKey := ""
+ if c.cache != nil {
+ // For GET requests, we can cache based on endpoint and query parameters
+ // Extract query parameters from endpoint if any
+ endpointURL, _ := url.Parse(endpoint)
+ params := make(map[string]any)
+ for key, values := range endpointURL.Query() {
+ if len(values) > 0 {
+ params[key] = values[0]
+ }
+ }
+ cacheKey = c.cache.generateCacheKey("GET", endpointURL.Path, params)
+
+ // Try to get from cache first
+ if cachedData, found := c.cache.Get(ctx, cacheKey); found {
+ if err := json.Unmarshal(cachedData, target); err == nil {
+ return nil
+ }
+ // If unmarshal fails, continue with API call
+ }
+ }
+
resp, err := c.Get(ctx, endpoint)
if err != nil {
return err
@@ -164,10 +206,23 @@
return c.handleErrorResponse(resp)
}
- if err := json.NewDecoder(resp.Body).Decode(target); err != nil {
+ // Read response body
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return fmt.Errorf("failed to read response body: %w", err)
+ }
+
+ // Unmarshal JSON
+ if err := json.Unmarshal(body, target); err != nil {
return fmt.Errorf("failed to decode JSON response: %w", err)
}
+ // Cache the response for GET requests
+ if c.cache != nil && cacheKey != "" {
+ ttl := c.cache.GetTTLForEndpoint(endpoint)
+ c.cache.Set(ctx, cacheKey, body, ttl)
+ }
+
return nil
}
@@ -236,3 +291,16 @@
return nil
}
+
// GetCache returns the cache instance (for testing and monitoring).
// The returned value may be non-nil even when caching is disabled;
// check Stats()["enabled"] to distinguish the two states.
func (c *Client) GetCache() *Cache {
	return c.cache
}
+
+// Close closes the client and cleans up resources
+func (c *Client) Close() error {
+ if c.cache != nil {
+ return c.cache.Close()
+ }
+ return nil
+}
diff --git a/service/client_test.go b/service/client_test.go
index 0009d71..6ebb61e 100644
--- a/service/client_test.go
+++ b/service/client_test.go
@@ -318,3 +318,154 @@
assert.Equal(t, "https://example.com/", client.GetBaseURL())
}
+
// TestClientCaching tests client caching functionality
func TestClientCaching(t *testing.T) {
	// NOTE(review): requestCount is incremented inside the handler, which
	// httptest runs on server goroutines, and read/reset from the test
	// goroutine without synchronization — this is a data race under
	// `go test -race`. Consider sync/atomic (atomic.Int64) or a mutex.
	requestCount := 0

	// Create a mock server that counts requests
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		requestCount++
		// Use URL.Path to get clean path without query parameters
		path := r.URL.Path
		switch path {
		case "/cached":
			w.Header().Set("Content-Type", "application/json")
			json.NewEncoder(w).Encode(map[string]interface{}{
				"data":          "cached response",
				"request_count": requestCount,
			})
		case "/search":
			w.Header().Set("Content-Type", "application/json")
			query := r.URL.Query().Get("q")
			json.NewEncoder(w).Encode(map[string]interface{}{
				"query":         query,
				"results":       []string{"result1", "result2"},
				"request_count": requestCount,
			})
		default:
			w.WriteHeader(http.StatusNotFound)
			json.NewEncoder(w).Encode(map[string]string{
				"error": fmt.Sprintf("Not found: %s", path),
			})
		}
	}))
	defer server.Close()

	t.Run("cache enabled", func(t *testing.T) {
		cacheConfig := DefaultCacheConfig()
		cacheConfig.DefaultTTL = 1 * time.Minute

		client, err := NewClient(ClientOptions{
			BaseURL:     server.URL,
			CacheConfig: &cacheConfig,
		})
		assert.NoError(t, err)
		defer client.Close()

		ctx := context.Background()
		requestCount = 0

		// First request - should hit the server
		var result1 map[string]interface{}
		err = client.GetJSON(ctx, "/cached", &result1)
		assert.NoError(t, err)
		// JSON numbers decode into interface{} as float64, hence float64(1).
		assert.Equal(t, float64(1), result1["request_count"])

		// Second request - should hit the cache
		var result2 map[string]interface{}
		err = client.GetJSON(ctx, "/cached", &result2)
		assert.NoError(t, err)
		assert.Equal(t, float64(1), result2["request_count"]) // Same as first request

		// Verify cache statistics
		cache := client.GetCache()
		assert.NotNil(t, cache)
		stats := cache.Stats()
		assert.True(t, stats["enabled"].(bool))
		assert.Equal(t, 1, stats["size"].(int))
		assert.True(t, stats["hits"].(int64) > 0)
	})

	t.Run("cache disabled", func(t *testing.T) {
		cacheConfig := CacheConfig{Enabled: false}

		client, err := NewClient(ClientOptions{
			BaseURL:     server.URL,
			CacheConfig: &cacheConfig,
		})
		assert.NoError(t, err)
		defer client.Close()

		ctx := context.Background()
		requestCount = 0

		// Both requests should hit the server
		var result1 map[string]interface{}
		err = client.GetJSON(ctx, "/cached", &result1)
		assert.NoError(t, err)
		assert.Equal(t, float64(1), result1["request_count"])

		var result2 map[string]interface{}
		err = client.GetJSON(ctx, "/cached", &result2)
		assert.NoError(t, err)
		assert.Equal(t, float64(2), result2["request_count"]) // Different from first request

		// Verify cache is disabled
		cache := client.GetCache()
		assert.NotNil(t, cache)
		stats := cache.Stats()
		assert.False(t, stats["enabled"].(bool))
	})

	t.Run("cache expiry", func(t *testing.T) {
		cacheConfig := CacheConfig{
			Enabled:     true,
			DefaultTTL:  50 * time.Millisecond,
			SearchTTL:   50 * time.Millisecond,
			MetadataTTL: 50 * time.Millisecond,
			MaxSize:     100,
		}

		client, err := NewClient(ClientOptions{
			BaseURL:     server.URL,
			CacheConfig: &cacheConfig,
		})
		assert.NoError(t, err)
		defer client.Close()

		ctx := context.Background()
		requestCount = 0

		// First request
		var result1 map[string]interface{}
		err = client.GetJSON(ctx, "/cached", &result1)
		assert.NoError(t, err)
		assert.Equal(t, float64(1), result1["request_count"])

		// Wait for cache to expire
		time.Sleep(100 * time.Millisecond)

		// Second request should hit server again
		var result2 map[string]interface{}
		err = client.GetJSON(ctx, "/cached", &result2)
		assert.NoError(t, err)
		assert.Equal(t, float64(2), result2["request_count"])
	})
}
+
// TestClientDefaultCache tests that clients get default cache when no config is provided
func TestClientDefaultCache(t *testing.T) {
	client, err := NewClient(ClientOptions{
		BaseURL: "https://example.com",
	})
	assert.NoError(t, err)
	defer client.Close()

	// Omitting CacheConfig must yield an enabled cache with the default
	// capacity from DefaultCacheConfig (1000 entries).
	cache := client.GetCache()
	assert.NotNil(t, cache)

	stats := cache.Stats()
	assert.True(t, stats["enabled"].(bool))
	assert.Equal(t, 1000, stats["max_size"].(int))
}