package service
import (
"context"
"crypto/md5"
"encoding/hex"
"encoding/json"
"fmt"
"strings"
"time"
cfg "github.com/korap/korap-mcp/config"
"github.com/korap/korap-mcp/logger"
"github.com/maypok86/otter"
"github.com/rs/zerolog"
)
// CacheEntry represents a cached API response
type CacheEntry struct {
Data []byte `json:"data"`
Timestamp time.Time `json:"timestamp"`
TTL time.Duration `json:"ttl"`
}
// IsExpired checks if the cache entry has expired
func (ce *CacheEntry) IsExpired() bool {
return time.Since(ce.Timestamp) > ce.TTL
}
// Cache represents the response cache system
type Cache struct {
cache *otter.Cache[string, *CacheEntry]
logger zerolog.Logger
config CacheConfig
}
// CacheConfig configures the cache behavior
type CacheConfig struct {
// Enabled controls whether caching is active
Enabled bool
// DefaultTTL is the default time-to-live for cache entries
DefaultTTL time.Duration
// SearchTTL is the TTL for search results
SearchTTL time.Duration
// MetadataTTL is the TTL for metadata and corpus information
MetadataTTL time.Duration
// MaxSize is the maximum number of cache entries
MaxSize int
}
// DefaultCacheConfig returns a default cache configuration
func DefaultCacheConfig() CacheConfig {
return CacheConfig{
Enabled: true,
DefaultTTL: 5 * time.Minute,
		SearchTTL: 2 * time.Minute, // Search results can change more often, so they get a shorter TTL
MetadataTTL: 15 * time.Minute, // Metadata is more stable
MaxSize: 1000,
}
}
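
// Illustrative sketch (not part of the package API): callers that need other
// limits can start from DefaultCacheConfig and override individual fields.
// The values below are hypothetical examples:
//
//	cacheCfg := DefaultCacheConfig()
//	cacheCfg.MaxSize = 5000
//	cacheCfg.MetadataTTL = 30 * time.Minute
//	cache, err := NewCache(cacheCfg)
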
// NewCache creates a new cache instance
func NewCache(config CacheConfig) (*Cache, error) {
// Create default logging config for cache
logConfig := &cfg.LoggingConfig{
Level: "info",
Format: "text",
}
if !config.Enabled {
return &Cache{
cache: nil,
logger: logger.GetLogger(logConfig),
config: config,
}, nil
}
// Create otter cache with specified capacity
cache, err := otter.MustBuilder[string, *CacheEntry](config.MaxSize).
CollectStats().
WithTTL(config.DefaultTTL).
Build()
if err != nil {
return nil, fmt.Errorf("failed to create cache: %w", err)
}
return &Cache{
cache: &cache,
logger: logger.GetLogger(logConfig),
config: config,
}, nil
}
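
// Illustrative sketch of constructing and tearing down a cache; error
// handling is abbreviated and the surrounding function is hypothetical:
//
//	cache, err := NewCache(DefaultCacheConfig())
//	if err != nil {
//		return err
//	}
//	defer cache.Close()
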
// generateCacheKey creates a unique cache key for a request
func (c *Cache) generateCacheKey(method, endpoint string, params map[string]any) string {
// Create a deterministic key by combining method, endpoint, and parameters
var keyParts []string
keyParts = append(keyParts, method, endpoint)
	// Serialize the parameters to keep the key deterministic.
	// json.Marshal sorts map keys lexicographically, so the JSON output does
	// not depend on map iteration order. The marshal error below is ignored
	// deliberately: request parameters are expected to be plain
	// JSON-encodable values and should not fail to serialize.
if params != nil {
paramsJSON, _ := json.Marshal(params)
keyParts = append(keyParts, string(paramsJSON))
}
key := strings.Join(keyParts, "|")
	// Hash the combined key so cache keys have a fixed, manageable length.
	// MD5 is used here only for key derivation, not for any security purpose.
hash := md5.Sum([]byte(key))
return hex.EncodeToString(hash[:])
}
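
// Illustrative sketch: because json.Marshal sorts map keys, logically equal
// parameter maps produce identical cache keys no matter how they were built.
// The endpoint and parameters below are hypothetical:
//
//	k1 := c.generateCacheKey("GET", "/search", map[string]any{"q": "Baum", "count": 25})
//	k2 := c.generateCacheKey("GET", "/search", map[string]any{"count": 25, "q": "Baum"})
//	// k1 == k2
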
// Get retrieves a cached response
func (c *Cache) Get(ctx context.Context, key string) ([]byte, bool) {
if !c.config.Enabled || c.cache == nil {
return nil, false
}
entry, found := (*c.cache).Get(key)
if !found {
c.logger.Debug().Str("key", key).Msg("Cache miss")
return nil, false
}
// Check if entry has expired
if entry.IsExpired() {
c.logger.Debug().Str("key", key).Msg("Cache entry expired")
(*c.cache).Delete(key)
return nil, false
}
c.logger.Debug().Str("key", key).Msg("Cache hit")
return entry.Data, true
}
// Set stores a response in the cache
func (c *Cache) Set(ctx context.Context, key string, data []byte, ttl time.Duration) {
if !c.config.Enabled || c.cache == nil {
return
}
entry := &CacheEntry{
Data: data,
Timestamp: time.Now(),
TTL: ttl,
}
(*c.cache).Set(key, entry)
c.logger.Debug().Str("key", key).Dur("ttl", ttl).Msg("Cache entry stored")
}
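
// Illustrative read-through pattern combining Get, Set, and GetTTLForEndpoint;
// fetchFromAPI, endpoint, and params are hypothetical names, not part of this
// package:
//
//	key := c.generateCacheKey("GET", endpoint, params)
//	if data, ok := c.Get(ctx, key); ok {
//		return data, nil
//	}
//	data, err := fetchFromAPI(ctx, endpoint, params)
//	if err != nil {
//		return nil, err
//	}
//	c.Set(ctx, key, data, c.GetTTLForEndpoint(endpoint))
//	return data, nil
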
// Delete removes an entry from the cache
func (c *Cache) Delete(ctx context.Context, key string) {
if !c.config.Enabled || c.cache == nil {
return
}
(*c.cache).Delete(key)
c.logger.Debug().Str("key", key).Msg("Cache entry deleted")
}
// Clear removes all entries from the cache
func (c *Cache) Clear() {
if !c.config.Enabled || c.cache == nil {
return
}
(*c.cache).Clear()
c.logger.Debug().Msg("Cache cleared")
}
// Stats returns cache statistics
func (c *Cache) Stats() map[string]any {
if !c.config.Enabled || c.cache == nil {
		return map[string]any{
"enabled": false,
}
}
stats := (*c.cache).Stats()
	return map[string]any{
"enabled": true,
"size": (*c.cache).Size(),
"hits": stats.Hits(),
"misses": stats.Misses(),
"hit_ratio": stats.Ratio(),
"evictions": stats.EvictedCount(),
"max_size": c.config.MaxSize,
"default_ttl": c.config.DefaultTTL.String(),
"search_ttl": c.config.SearchTTL.String(),
"metadata_ttl": c.config.MetadataTTL.String(),
}
}
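
// Illustrative sketch: the returned map contains only plain values, so it can
// be logged or serialized directly; the zerolog call below is just one
// possible use:
//
//	stats := c.Stats()
//	c.logger.Info().Interface("cache_stats", stats).Msg("cache statistics")
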
// GetTTLForEndpoint returns the appropriate TTL for a given endpoint
func (c *Cache) GetTTLForEndpoint(endpoint string) time.Duration {
endpoint = strings.ToLower(endpoint)
// Search endpoints get shorter TTL
if strings.Contains(endpoint, "search") || strings.Contains(endpoint, "query") {
return c.config.SearchTTL
}
// Metadata and corpus endpoints get longer TTL
if strings.Contains(endpoint, "corpus") || strings.Contains(endpoint, "metadata") ||
strings.Contains(endpoint, "statistics") || strings.Contains(endpoint, "info") {
return c.config.MetadataTTL
}
// Default TTL for other endpoints
return c.config.DefaultTTL
}
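
// Illustrative examples of the TTL selection above; the endpoint paths are
// hypothetical:
//
//	c.GetTTLForEndpoint("/api/v1.0/search")            // SearchTTL
//	c.GetTTLForEndpoint("/api/v1.0/corpus/statistics") // MetadataTTL
//	c.GetTTLForEndpoint("/api/v1.0/annotation")        // DefaultTTL
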
// Close closes the cache and cleans up resources
func (c *Cache) Close() error {
if c.cache != nil {
(*c.cache).Clear()
(*c.cache).Close()
}
return nil
}