blob: 9116f086a7ca26a1f55c0ab628b6aa21019639d3 [file] [log] [blame]
package tools
import (
"context"
"fmt"
"sync"
"github.com/mark3labs/mcp-go/mcp"
"github.com/rs/zerolog/log"
)
// Tool is the contract every MCP tool must satisfy to be stored in a Registry.
// The Registry keys tools by the value returned from Name.
type Tool interface {
	// Name reports the identifier the tool is registered and looked up under.
	// It must be non-empty and unique within a Registry.
	Name() string

	// Description reports human-readable text describing what the tool does.
	Description() string

	// InputSchema reports the JSON schema describing the tool's arguments.
	InputSchema() map[string]interface{}

	// Execute performs the tool's work for the given call request and returns
	// the result to send back to the client, or an error.
	Execute(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
}
// Registry holds the set of MCP tools available to a server, indexed by name.
// Every method acquires mu before touching tools; construct instances with
// NewRegistry so that the map is initialised.
type Registry struct {
	// tools maps each tool's Name() to the tool itself.
	tools map[string]Tool
	// mu guards tools.
	mu sync.RWMutex
}
// NewRegistry returns an empty Registry whose tool map is ready for use.
func NewRegistry() *Registry {
	reg := new(Registry)
	reg.tools = make(map[string]Tool)
	return reg
}
// Register stores tool under the name returned by tool.Name().
//
// It returns an error if tool is nil, if its name is empty, or if another tool
// is already registered under the same name. In that last case the existing
// registration is left unchanged.
func (r *Registry) Register(tool Tool) error {
	if tool == nil {
		return fmt.Errorf("cannot register nil tool")
	}

	toolName := tool.Name()
	if len(toolName) == 0 {
		return fmt.Errorf("tool name cannot be empty")
	}

	r.mu.Lock()
	defer r.mu.Unlock()

	if _, taken := r.tools[toolName]; taken {
		return fmt.Errorf("tool with name %q already registered", toolName)
	}
	r.tools[toolName] = tool

	log.Debug().Str("tool", toolName).Msg("Tool registered successfully")
	return nil
}
// Unregister deletes the tool registered under name.
// It returns an error if no such tool exists.
func (r *Registry) Unregister(name string) error {
	r.mu.Lock()
	defer r.mu.Unlock()

	if _, found := r.tools[name]; found {
		delete(r.tools, name)
		log.Debug().Str("tool", name).Msg("Tool unregistered successfully")
		return nil
	}
	return fmt.Errorf("tool with name %q not found", name)
}
// Get looks up the tool registered under name.
// The boolean result reports whether a tool was found.
func (r *Registry) Get(name string) (Tool, bool) {
	r.mu.RLock()
	found, ok := r.tools[name]
	r.mu.RUnlock()
	return found, ok
}
// List returns every registered tool as an MCP tool definition.
//
// Each definition carries the tool's name and description. It also carries the
// "type", "properties" and "required" entries of the tool's InputSchema(), so
// clients can see which arguments the tool accepts. The order of the returned
// slice is unspecified because it follows map iteration order.
func (r *Registry) List() []mcp.Tool {
	r.mu.RLock()
	defer r.mu.RUnlock()

	mcpTools := make([]mcp.Tool, 0, len(r.tools))
	for _, tool := range r.tools {
		mcpTool := mcp.NewTool(tool.Name(), mcp.WithDescription(tool.Description()))
		// Propagate the declared schema. Without this, every tool would be
		// advertised with the empty default object schema set by mcp.NewTool.
		applyInputSchema(&mcpTool, tool.InputSchema())
		mcpTools = append(mcpTools, mcpTool)
	}
	return mcpTools
}

// applyInputSchema copies the "type", "properties" and "required" entries of a
// JSON-schema map onto t.InputSchema.
//
// Missing entries, or entries whose Go type is not the expected one, are
// skipped, which leaves the defaults set by mcp.NewTool in place. "required" is
// accepted both as []string and as []interface{} (the form produced by
// json.Unmarshal). Properties are shallow-copied so the definition does not
// alias the tool's own map.
func applyInputSchema(t *mcp.Tool, schema map[string]interface{}) {
	if len(schema) == 0 {
		return
	}
	if typ, ok := schema["type"].(string); ok && typ != "" {
		t.InputSchema.Type = typ
	}
	if props, ok := schema["properties"].(map[string]interface{}); ok {
		copied := make(map[string]interface{}, len(props))
		for k, v := range props {
			copied[k] = v
		}
		t.InputSchema.Properties = copied
	}
	switch req := schema["required"].(type) {
	case []string:
		t.InputSchema.Required = append([]string(nil), req...)
	case []interface{}:
		required := make([]string, 0, len(req))
		for _, v := range req {
			if s, ok := v.(string); ok {
				required = append(required, s)
			}
		}
		t.InputSchema.Required = required
	}
}
// Names returns the names of all registered tools.
// The order is unspecified because it follows map iteration order.
func (r *Registry) Names() []string {
	r.mu.RLock()
	defer r.mu.RUnlock()

	out := make([]string, len(r.tools))
	i := 0
	for key := range r.tools {
		out[i] = key
		i++
	}
	return out
}
// Count reports how many tools are currently registered.
func (r *Registry) Count() int {
	r.mu.RLock()
	n := len(r.tools)
	r.mu.RUnlock()
	return n
}
// Execute looks up the tool registered under name and runs it with request.
//
// It returns an error if the tool is unknown. If the tool itself fails, the
// failure is logged and returned wrapped (so errors.Is/As still see the cause).
// The registry lock is held only during the lookup, not while the tool runs.
//
// NOTE(review): a tool that returns (nil, nil) has its nil result passed
// straight through. Confirm callers tolerate a nil *mcp.CallToolResult.
func (r *Registry) Execute(ctx context.Context, name string, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
	target, ok := r.Get(name)
	if !ok {
		return nil, fmt.Errorf("tool %q not found", name)
	}

	log.Debug().Str("tool", name).Interface("request", request).Msg("Executing tool")

	res, execErr := target.Execute(ctx, request)
	if execErr != nil {
		log.Error().Err(execErr).Str("tool", name).Msg("Tool execution failed")
		return nil, fmt.Errorf("tool %q execution failed: %w", name, execErr)
	}

	log.Debug().Str("tool", name).Msg("Tool execution completed successfully")
	return res, nil
}
// Clear drops every registered tool by swapping in a fresh, empty map.
// It logs how many tools were removed.
func (r *Registry) Clear() {
	fresh := make(map[string]Tool)

	r.mu.Lock()
	defer r.mu.Unlock()

	removed := len(r.tools)
	r.tools = fresh
	log.Debug().Int("count", removed).Msg("Registry cleared")
}