Add request/response validation and input sanitization to the search tool
diff --git a/tools/search.go b/tools/search.go
index c96f407..2a78f8e 100644
--- a/tools/search.go
+++ b/tools/search.go
@@ -6,19 +6,22 @@
"strings"
"github.com/korap/korap-mcp/service"
+ "github.com/korap/korap-mcp/validation"
"github.com/mark3labs/mcp-go/mcp"
"github.com/rs/zerolog/log"
)
// SearchTool implements the Tool interface for KorAP corpus search
type SearchTool struct {
- client *service.Client
+ client *service.Client // KorAP API client; may be nil, so it is nil-checked before a search is attempted
+ validator *validation.Validator // validates search requests/responses and sanitizes the query and corpus ID inputs
}
// NewSearchTool creates a new search tool instance
func NewSearchTool(client *service.Client) *SearchTool {
return &SearchTool{
- client: client,
+ client: client,
+ validator: validation.New(log.Logger), // each tool gets its own validator, built from the global zerolog logger
}
}
@@ -80,12 +83,34 @@
corpus := request.GetString("corpus", "")
count := request.GetInt("count", 25)
+ // Validate the search request using the validation package
+ searchReq := validation.SearchRequest{
+ Query: query,
+ QueryLanguage: queryLang,
+ Corpus: corpus,
+ Count: count,
+ }
+
+ if err := s.validator.ValidateSearchRequest(searchReq); err != nil {
+ log.Warn().
+ Err(err).
+ Interface("request", searchReq).
+ Msg("Search request validation failed")
+ return nil, fmt.Errorf("invalid search request: %w", err)
+ }
+
+ // Sanitize inputs
+ query = s.validator.SanitizeQuery(query)
+ if corpus != "" {
+ corpus = s.validator.SanitizeCorpusID(corpus)
+ }
+
log.Debug().
Str("query", query).
Str("query_language", queryLang).
Str("corpus", corpus).
Int("count", count).
- Msg("Parsed search parameters")
+ Msg("Parsed and validated search parameters")
// Check if client is available and authenticated
if s.client == nil {
@@ -100,7 +125,7 @@
}
// Prepare search request
- searchReq := service.SearchRequest{
+ korapSearchReq := service.SearchRequest{
Query: query,
QueryLang: queryLang,
Collection: corpus,
@@ -109,7 +134,7 @@
// Perform the search
var searchResp service.SearchResponse
- err = s.client.PostJSON(ctx, "search", searchReq, &searchResp)
+ err = s.client.PostJSON(ctx, "search", korapSearchReq, &searchResp)
if err != nil {
log.Error().
Err(err).
@@ -118,6 +143,14 @@
return nil, fmt.Errorf("search failed: %w", err)
}
+ // Validate the response
+ if err := s.validator.ValidateSearchResponse(&searchResp); err != nil {
+ log.Warn().
+ Err(err).
+ Msg("Search response validation failed, but continuing with potentially invalid data")
+ // Continue processing despite validation errors to be resilient
+ }
+
log.Info().
Str("query", query).
Int("total_results", searchResp.Meta.TotalResults).