Add cache mechanism to KorAP client
Change-Id: Ie3f3d48611f039904f22a19cf6299a3c43fe8bbe
diff --git a/service/cache.go b/service/cache.go
new file mode 100644
index 0000000..43ec439
--- /dev/null
+++ b/service/cache.go
@@ -0,0 +1,223 @@
+package service
+
+import (
+ "context"
+ "crypto/md5"
+ "encoding/hex"
+ "encoding/json"
+ "fmt"
+ "strings"
+ "time"
+
+ cfg "github.com/korap/korap-mcp/config"
+ "github.com/korap/korap-mcp/logger"
+ "github.com/maypok86/otter"
+ "github.com/rs/zerolog"
+)
+
+// CacheEntry represents a cached API response: the payload, when it was stored, and how long it is valid
+type CacheEntry struct {
+ Data []byte `json:"data"` // cached response payload
+ Timestamp time.Time `json:"timestamp"` // storage time; Set fills this with time.Now()
+ TTL time.Duration `json:"ttl"` // lifetime measured from Timestamp, checked by IsExpired
+}
+
+// IsExpired reports whether the time elapsed since Timestamp exceeds the entry's TTL.
+func (e *CacheEntry) IsExpired() bool {
+	return e.TTL < time.Since(e.Timestamp)
+}
+
+// Cache represents the response cache system: an otter store of CacheEntry values plus logger and settings
+type Cache struct {
+ cache *otter.Cache[string, *CacheEntry] // backing store; nil when the cache was created disabled
+ logger zerolog.Logger // receives the debug messages emitted by the cache operations
+ config CacheConfig // settings the cache was created with
+}
+
+// CacheConfig configures the cache behavior
+type CacheConfig struct {
+ // Enabled controls whether caching is active; when false, lookups always miss and writes are ignored
+ Enabled bool
+ // DefaultTTL is the time-to-live for endpoints that match neither the search nor the metadata rules
+ DefaultTTL time.Duration
+ // SearchTTL is the TTL for search results (endpoints whose name contains "search" or "query")
+ SearchTTL time.Duration
+ // MetadataTTL is the TTL for endpoints containing "corpus", "metadata", "statistics" or "info"
+ MetadataTTL time.Duration
+ // MaxSize is the maximum number of cache entries (the capacity handed to the otter builder)
+ MaxSize int
+}
+
+// DefaultCacheConfig returns the stock configuration: caching enabled, at most
+// 1000 entries, and TTLs of 5 min (default), 2 min (search) and 15 min (metadata).
+func DefaultCacheConfig() CacheConfig {
+	defaults := CacheConfig{Enabled: true, MaxSize: 1000}
+	// Search results get the shortest lifetime of the three TTLs.
+	defaults.SearchTTL = 2 * time.Minute
+	defaults.DefaultTTL = 5 * time.Minute
+	defaults.MetadataTTL = 15 * time.Minute // metadata is more stable, so it lives longest
+	return defaults
+}
+
+// NewCache creates a cache from config. When config.Enabled is false no backing
+// store is built and the returned Cache has a nil store, which the cache
+// operations treat as "caching off".
+// Invalid settings reported by otter (for example a non-positive MaxSize) are
+// returned as errors wrapped with "failed to create cache" instead of panicking.
+func NewCache(config CacheConfig) (*Cache, error) {
+	// The cache logs through its own logger with a fixed info/text configuration.
+	cacheLog := logger.GetLogger(&cfg.LoggingConfig{Level: "info", Format: "text"})
+	if !config.Enabled {
+		return &Cache{logger: cacheLog, config: config}, nil
+	}
+
+	// otter removes every entry after one fixed TTL, so that TTL must be the longest
+	// configured one; shorter per-entry TTLs are still enforced by Get via IsExpired.
+	storeTTL := config.DefaultTTL
+	for _, ttl := range []time.Duration{config.SearchTTL, config.MetadataTTL} {
+		if ttl > storeTTL {
+			storeTTL = ttl
+		}
+	}
+
+	builder, err := otter.NewBuilder[string, *CacheEntry](config.MaxSize)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create cache: %w", err)
+	}
+	store, err := builder.CollectStats().WithTTL(storeTTL).Build()
+	if err != nil {
+		return nil, fmt.Errorf("failed to create cache: %w", err)
+	}
+	return &Cache{cache: &store, logger: cacheLog, config: config}, nil
+}
+
+// generateCacheKey creates a unique cache key for a request
+func (c *Cache) generateCacheKey(method, endpoint string, params map[string]any) string {
+ // Create a deterministic key by combining method, endpoint, and parameters
+ var keyParts []string
+ keyParts = append(keyParts, method, endpoint)
+
+ // Add sorted parameters to ensure deterministic cache keys
+ // Note: json.Marshal automatically sorts map keys lexicographically,
+ // providing deterministic JSON output regardless of map iteration order
+ if params != nil {
+ paramsJSON, _ := json.Marshal(params)
+ keyParts = append(keyParts, string(paramsJSON))
+ }
+
+ key := strings.Join(keyParts, "|")
+
+ // Hash the key to keep it reasonable length and provide privacy
+ hash := md5.Sum([]byte(key))
+ return hex.EncodeToString(hash[:])
+}
+
+// Get looks up key and returns the cached payload together with true on a hit.
+// It reports a miss (nil, false) when caching is disabled, when the key is
+// absent, or when the stored entry has outlived its own TTL; such a stale entry
+// is deleted from the underlying store before returning. ctx is currently unused.
+func (c *Cache) Get(ctx context.Context, key string) ([]byte, bool) {
+	if c.cache == nil || !c.config.Enabled {
+		return nil, false
+	}
+	store := c.cache
+	entry, ok := store.Get(key)
+	switch {
+	case !ok:
+		c.logger.Debug().Str("key", key).Msg("Cache miss")
+		return nil, false
+	case entry.IsExpired():
+		c.logger.Debug().Str("key", key).Msg("Cache entry expired")
+		store.Delete(key)
+		return nil, false
+	}
+	c.logger.Debug().Str("key", key).Msg("Cache hit")
+	return entry.Data, true
+}
+
+// Set stores data under key in a new CacheEntry stamped with the current time
+// and the given ttl; that ttl is what IsExpired later checks when the entry is
+// read. The data slice is stored as-is, not copied. Nothing is stored when
+// caching is disabled. ctx is currently unused.
+func (c *Cache) Set(ctx context.Context, key string, data []byte, ttl time.Duration) {
+	if c.cache == nil || !c.config.Enabled {
+		return
+	}
+
+	stored := new(CacheEntry)
+	stored.Data, stored.TTL = data, ttl
+	stored.Timestamp = time.Now()
+	c.cache.Set(key, stored)
+	c.logger.Debug().Str("key", key).Dur("ttl", ttl).Msg("Cache entry stored")
+}
+
+// Delete removes the entry stored under key from the underlying store and logs
+// the removal at debug level, whether or not the key was present. It does
+// nothing when caching is disabled. ctx is currently unused.
+func (c *Cache) Delete(ctx context.Context, key string) {
+	if c.cache != nil && c.config.Enabled {
+		c.cache.Delete(key)
+		c.logger.Debug().Str("key", key).Msg("Cache entry deleted")
+	}
+}
+
+// Clear drops every entry from the underlying store and logs that at debug
+// level. It does nothing when caching is disabled.
+func (c *Cache) Clear() {
+	if c.cache == nil || !c.config.Enabled {
+		return
+	}
+	c.cache.Clear()
+	c.logger.Debug().Msg("Cache cleared")
+}
+
+// Stats reports the cache state as a generic map. A disabled cache yields only
+// {"enabled": false}; otherwise the map holds the current size, the hit, miss
+// and eviction counters plus the hit ratio from otter, and the configured limits.
+func (c *Cache) Stats() map[string]interface{} {
+	if c.cache == nil || !c.config.Enabled {
+		return map[string]interface{}{"enabled": false}
+	}
+	counters := c.cache.Stats()
+	report := make(map[string]interface{}, 10)
+	report["enabled"] = true
+	report["size"] = c.cache.Size()
+	report["hits"] = counters.Hits()
+	report["misses"] = counters.Misses()
+	report["hit_ratio"] = counters.Ratio()
+	report["evictions"] = counters.EvictedCount()
+	report["max_size"] = c.config.MaxSize
+	// TTLs are reported in time.Duration's String form rather than as numbers.
+	report["default_ttl"] = c.config.DefaultTTL.String()
+	report["search_ttl"] = c.config.SearchTTL.String()
+	report["metadata_ttl"] = c.config.MetadataTTL.String()
+	return report
+}
+
+// GetTTLForEndpoint picks the TTL for endpoint by case-insensitive substring
+// match: "search" or "query" selects SearchTTL; otherwise "corpus",
+// "metadata", "statistics" or "info" selects MetadataTTL; any other endpoint
+// gets DefaultTTL.
+func (c *Cache) GetTTLForEndpoint(endpoint string) time.Duration {
+	path := strings.ToLower(endpoint)
+	has := func(word string) bool { return strings.Contains(path, word) }
+
+	// Cases are tried top to bottom, so a search word wins over a metadata word.
+	switch {
+	case has("search"), has("query"):
+		return c.config.SearchTTL
+	case has("corpus"), has("metadata"), has("statistics"), has("info"):
+		return c.config.MetadataTTL
+	default:
+		return c.config.DefaultTTL
+	}
+}
+
+// Close clears and then closes the underlying store if one exists; the returned error is always nil.
+func (c *Cache) Close() error {
+	if store := c.cache; store != nil {
+		store.Clear()
+		store.Close()
+	}
+	return nil
+}