Added request/response validation
diff --git a/tools/metadata.go b/tools/metadata.go
index 6fd84b1..c3da813 100644
--- a/tools/metadata.go
+++ b/tools/metadata.go
@@ -6,19 +6,22 @@
"strings"
"github.com/korap/korap-mcp/service"
+ "github.com/korap/korap-mcp/validation"
"github.com/mark3labs/mcp-go/mcp"
"github.com/rs/zerolog/log"
)
// MetadataTool implements the Tool interface for KorAP corpus metadata retrieval
type MetadataTool struct {
- client *service.Client
+ client *service.Client
+ validator *validation.Validator
}
// NewMetadataTool creates a new metadata tool instance
func NewMetadataTool(client *service.Client) *MetadataTool {
return &MetadataTool{
- client: client,
+ client: client,
+ validator: validation.New(log.Logger),
}
}
@@ -67,20 +70,29 @@
// Extract optional corpus parameter
corpus := request.GetString("corpus", "")
+ // Validate the metadata request using the validation package
+ metadataReq := validation.MetadataRequest{
+ Action: action,
+ Corpus: corpus,
+ }
+
+ if err := m.validator.ValidateMetadataRequest(metadataReq); err != nil {
+ log.Warn().
+ Err(err).
+ Interface("request", metadataReq).
+ Msg("Metadata request validation failed")
+ return nil, fmt.Errorf("invalid metadata request: %w", err)
+ }
+
+ // Sanitize inputs
+ if corpus != "" {
+ corpus = m.validator.SanitizeCorpusID(corpus)
+ }
+
log.Debug().
Str("action", action).
Str("corpus", corpus).
- Msg("Parsed metadata parameters")
-
- // Validate parameters before authentication
- switch action {
- case "list":
- // No additional validation needed for list
- case "statistics":
- // No additional validation needed for statistics - corpus is optional
- default:
- return nil, fmt.Errorf("unknown action: %s", action)
- }
+ Msg("Parsed and validated metadata parameters")
// Check if client is available and authenticated
if m.client == nil {
@@ -119,6 +131,14 @@
return nil, fmt.Errorf("failed to retrieve corpus list: %w", err)
}
+ // Validate the response
+ if err := m.validator.ValidateCorpusListResponse(&corpusListResp); err != nil {
+ log.Warn().
+ Err(err).
+ Msg("Corpus list response validation failed, but continuing with potentially invalid data")
+ // Continue processing despite validation errors to be resilient
+ }
+
log.Info().
Int("corpus_count", len(corpusListResp.Corpora)).
Msg("Corpus list retrieved successfully")
@@ -150,6 +170,14 @@
return nil, fmt.Errorf("failed to retrieve corpus statistics: %w", err)
}
+ // Validate the response
+ if err := m.validator.ValidateStatisticsResponse(&statsResp); err != nil {
+ log.Warn().
+ Err(err).
+ Msg("Statistics response validation failed, but continuing with potentially invalid data")
+ // Continue processing despite validation errors to be resilient
+ }
+
log.Info().
Str("corpus", corpus).
Int("documents", statsResp.Documents).
diff --git a/tools/metadata_test.go b/tools/metadata_test.go
index 88c7e6d..e5d60cd 100644
--- a/tools/metadata_test.go
+++ b/tools/metadata_test.go
@@ -110,8 +110,9 @@
_, err := tool.Execute(context.Background(), request)
assert.Error(t, err)
- // The unknown action error should come before authentication
- assert.Contains(t, err.Error(), "unknown action: unknown")
+ // The validation error should come before authentication
+ assert.Contains(t, err.Error(), "invalid metadata request")
+ assert.Contains(t, err.Error(), "invalid action")
}
func TestMetadataTool_Execute_StatisticsWithoutCorpus(t *testing.T) {
diff --git a/tools/search.go b/tools/search.go
index c96f407..2a78f8e 100644
--- a/tools/search.go
+++ b/tools/search.go
@@ -6,19 +6,22 @@
"strings"
"github.com/korap/korap-mcp/service"
+ "github.com/korap/korap-mcp/validation"
"github.com/mark3labs/mcp-go/mcp"
"github.com/rs/zerolog/log"
)
// SearchTool implements the Tool interface for KorAP corpus search
type SearchTool struct {
- client *service.Client
+ client *service.Client
+ validator *validation.Validator
}
// NewSearchTool creates a new search tool instance
func NewSearchTool(client *service.Client) *SearchTool {
return &SearchTool{
- client: client,
+ client: client,
+ validator: validation.New(log.Logger),
}
}
@@ -80,12 +83,34 @@
corpus := request.GetString("corpus", "")
count := request.GetInt("count", 25)
+ // Validate the search request using the validation package
+ searchReq := validation.SearchRequest{
+ Query: query,
+ QueryLanguage: queryLang,
+ Corpus: corpus,
+ Count: count,
+ }
+
+ if err := s.validator.ValidateSearchRequest(searchReq); err != nil {
+ log.Warn().
+ Err(err).
+ Interface("request", searchReq).
+ Msg("Search request validation failed")
+ return nil, fmt.Errorf("invalid search request: %w", err)
+ }
+
+ // Sanitize inputs
+ query = s.validator.SanitizeQuery(query)
+ if corpus != "" {
+ corpus = s.validator.SanitizeCorpusID(corpus)
+ }
+
log.Debug().
Str("query", query).
Str("query_language", queryLang).
Str("corpus", corpus).
Int("count", count).
- Msg("Parsed search parameters")
+ Msg("Parsed and validated search parameters")
// Check if client is available and authenticated
if s.client == nil {
@@ -100,7 +125,7 @@
}
// Prepare search request
- searchReq := service.SearchRequest{
+ korapSearchReq := service.SearchRequest{
Query: query,
QueryLang: queryLang,
Collection: corpus,
@@ -109,7 +134,7 @@
// Perform the search
var searchResp service.SearchResponse
- err = s.client.PostJSON(ctx, "search", searchReq, &searchResp)
+ err = s.client.PostJSON(ctx, "search", korapSearchReq, &searchResp)
if err != nil {
log.Error().
Err(err).
@@ -118,6 +143,14 @@
return nil, fmt.Errorf("search failed: %w", err)
}
+ // Validate the response
+ if err := s.validator.ValidateSearchResponse(&searchResp); err != nil {
+ log.Warn().
+ Err(err).
+ Msg("Search response validation failed, but continuing with potentially invalid data")
+ // Continue processing despite validation errors to be resilient
+ }
+
log.Info().
Str("query", query).
Int("total_results", searchResp.Meta.TotalResults).
diff --git a/validation/validator.go b/validation/validator.go
new file mode 100644
index 0000000..66b77c8
--- /dev/null
+++ b/validation/validator.go
@@ -0,0 +1,460 @@
+package validation
+
+import (
+ "fmt"
+ "net/url"
+ "regexp"
+ "strings"
+
+ "github.com/korap/korap-mcp/service"
+ "github.com/rs/zerolog"
+)
+
+// Validator provides input validation and response schema validation
+type Validator struct {
+ logger zerolog.Logger
+}
+
+// New creates a new validator instance
+func New(logger zerolog.Logger) *Validator {
+ return &Validator{
+ logger: logger.With().Str("component", "validator").Logger(),
+ }
+}
+
+// SearchRequest holds the parameters for a search request validation
+type SearchRequest struct {
+ Query string `json:"query"`
+ QueryLanguage string `json:"query_language,omitempty"`
+ Corpus string `json:"corpus,omitempty"`
+ Count int `json:"count,omitempty"`
+}
+
+// MetadataRequest holds the parameters for a metadata request validation
+type MetadataRequest struct {
+ Action string `json:"action"`
+ Corpus string `json:"corpus,omitempty"`
+}
+
+// ValidationError represents a validation error with details
+type ValidationError struct {
+ Field string `json:"field"`
+ Value string `json:"value"`
+ Message string `json:"message"`
+}
+
+func (e ValidationError) Error() string {
+ return fmt.Sprintf("validation error for field '%s' (value: '%s'): %s", e.Field, e.Value, e.Message)
+}
+
+// ValidationErrors represents multiple validation errors
+type ValidationErrors struct {
+ Errors []ValidationError `json:"errors"`
+}
+
+func (e ValidationErrors) Error() string {
+ if len(e.Errors) == 0 {
+ return "validation errors occurred"
+ }
+ var messages []string
+ for _, err := range e.Errors {
+ messages = append(messages, err.Error())
+ }
+ return strings.Join(messages, "; ")
+}
+
+// Regular expressions for validation
+var (
+ // Query language validation - KorAP supports poliqarp, cosmas2, annis
+ validQueryLanguages = map[string]bool{
+ "poliqarp": true,
+ "cosmas2": true,
+ "annis": true,
+ }
+
+ // Corpus ID validation - alphanumeric with dots, hyphens, underscores
+ corpusIDRegex = regexp.MustCompile(`^[a-zA-Z0-9._-]+$`)
+
+ // Action validation for metadata requests
+ validMetadataActions = map[string]bool{
+ "list": true,
+ "statistics": true,
+ }
+)
+
+// ValidateSearchRequest validates a search request
+func (v *Validator) ValidateSearchRequest(req SearchRequest) error {
+ var errors []ValidationError
+
+ // Validate query - required and non-empty
+ if strings.TrimSpace(req.Query) == "" {
+ errors = append(errors, ValidationError{
+ Field: "query",
+ Value: req.Query,
+ Message: "query is required and cannot be empty",
+ })
+ } else {
+ // Basic query validation - check for potentially dangerous patterns
+ if err := v.validateQuerySafety(req.Query); err != nil {
+ errors = append(errors, ValidationError{
+ Field: "query",
+ Value: req.Query,
+ Message: err.Error(),
+ })
+ }
+ }
+
+ // Validate query language if provided
+ if req.QueryLanguage != "" && !validQueryLanguages[req.QueryLanguage] {
+ var validLangs []string
+ for lang := range validQueryLanguages {
+ validLangs = append(validLangs, lang)
+ }
+ errors = append(errors, ValidationError{
+ Field: "query_language",
+ Value: req.QueryLanguage,
+ Message: fmt.Sprintf("invalid query language, must be one of: %s", strings.Join(validLangs, ", ")),
+ })
+ }
+
+ // Validate corpus if provided
+ if req.Corpus != "" {
+ if err := v.validateCorpusID(req.Corpus); err != nil {
+ errors = append(errors, ValidationError{
+ Field: "corpus",
+ Value: req.Corpus,
+ Message: err.Error(),
+ })
+ }
+ }
+
+ // Validate count if provided (0 means use default, so only validate non-zero values)
+ if req.Count < 0 || req.Count > 10000 {
+ errors = append(errors, ValidationError{
+ Field: "count",
+ Value: fmt.Sprintf("%d", req.Count),
+ Message: "count must be between 0 and 10000 (0 means use default)",
+ })
+ }
+
+ if len(errors) > 0 {
+ v.logger.Warn().Interface("errors", errors).Msg("Search request validation failed")
+ return ValidationErrors{Errors: errors}
+ }
+
+ v.logger.Debug().Interface("request", req).Msg("Search request validation passed")
+ return nil
+}
+
+// ValidateMetadataRequest validates a metadata request
+func (v *Validator) ValidateMetadataRequest(req MetadataRequest) error {
+ var errors []ValidationError
+
+ // Validate action - required
+ if strings.TrimSpace(req.Action) == "" {
+ errors = append(errors, ValidationError{
+ Field: "action",
+ Value: req.Action,
+ Message: "action is required and cannot be empty",
+ })
+ } else if !validMetadataActions[req.Action] {
+ var validActions []string
+ for action := range validMetadataActions {
+ validActions = append(validActions, action)
+ }
+ errors = append(errors, ValidationError{
+ Field: "action",
+ Value: req.Action,
+ Message: fmt.Sprintf("invalid action, must be one of: %s", strings.Join(validActions, ", ")),
+ })
+ }
+
+ // Validate corpus if provided
+ if req.Corpus != "" {
+ if err := v.validateCorpusID(req.Corpus); err != nil {
+ errors = append(errors, ValidationError{
+ Field: "corpus",
+ Value: req.Corpus,
+ Message: err.Error(),
+ })
+ }
+ }
+
+ if len(errors) > 0 {
+ v.logger.Warn().Interface("errors", errors).Msg("Metadata request validation failed")
+ return ValidationErrors{Errors: errors}
+ }
+
+ v.logger.Debug().Interface("request", req).Msg("Metadata request validation passed")
+ return nil
+}
+
+// ValidateSearchResponse validates a search response
+func (v *Validator) ValidateSearchResponse(resp *service.SearchResponse) error {
+ if resp == nil {
+ return fmt.Errorf("search response is nil")
+ }
+
+ var errors []ValidationError
+
+ // Validate meta structure
+ if resp.Meta.TotalResults < 0 {
+ errors = append(errors, ValidationError{
+ Field: "meta.totalResults",
+ Value: fmt.Sprintf("%d", resp.Meta.TotalResults),
+ Message: "totalResults cannot be negative",
+ })
+ }
+
+ if resp.Meta.Count < 0 {
+ errors = append(errors, ValidationError{
+ Field: "meta.count",
+ Value: fmt.Sprintf("%d", resp.Meta.Count),
+ Message: "count cannot be negative",
+ })
+ }
+
+ if resp.Meta.StartIndex < 0 {
+ errors = append(errors, ValidationError{
+ Field: "meta.startIndex",
+ Value: fmt.Sprintf("%d", resp.Meta.StartIndex),
+ Message: "startIndex cannot be negative",
+ })
+ }
+
+ if resp.Meta.ItemsPerPage < 0 {
+ errors = append(errors, ValidationError{
+ Field: "meta.itemsPerPage",
+ Value: fmt.Sprintf("%d", resp.Meta.ItemsPerPage),
+ Message: "itemsPerPage cannot be negative",
+ })
+ }
+
+ // Validate matches if present
+ if resp.Matches != nil {
+ for i, match := range resp.Matches {
+ if match.MatchID == "" {
+ errors = append(errors, ValidationError{
+ Field: fmt.Sprintf("matches[%d].matchID", i),
+ Value: "",
+ Message: "match ID is required",
+ })
+ }
+
+ if match.TextSigle == "" {
+ errors = append(errors, ValidationError{
+ Field: fmt.Sprintf("matches[%d].textSigle", i),
+ Value: "",
+ Message: "textSigle is required",
+ })
+ }
+
+ if match.Position < 0 {
+ errors = append(errors, ValidationError{
+ Field: fmt.Sprintf("matches[%d].position", i),
+ Value: fmt.Sprintf("%d", match.Position),
+ Message: "position cannot be negative",
+ })
+ }
+ }
+ }
+
+ if len(errors) > 0 {
+ v.logger.Warn().Interface("errors", errors).Msg("Search response validation failed")
+ return ValidationErrors{Errors: errors}
+ }
+
+ v.logger.Debug().Msg("Search response validation passed")
+ return nil
+}
+
+// ValidateCorpusListResponse validates a corpus list response
+func (v *Validator) ValidateCorpusListResponse(resp *service.CorpusListResponse) error {
+ if resp == nil {
+ return fmt.Errorf("corpus list response is nil")
+ }
+
+ var errors []ValidationError
+
+ // Validate corpus entries
+ if resp.Corpora != nil {
+ for i, corpus := range resp.Corpora {
+ if corpus.ID == "" {
+ errors = append(errors, ValidationError{
+ Field: fmt.Sprintf("corpora[%d].id", i),
+ Value: "",
+ Message: "corpus ID is required",
+ })
+ } else if err := v.validateCorpusID(corpus.ID); err != nil {
+ errors = append(errors, ValidationError{
+ Field: fmt.Sprintf("corpora[%d].id", i),
+ Value: corpus.ID,
+ Message: err.Error(),
+ })
+ }
+
+ if corpus.Name == "" {
+ errors = append(errors, ValidationError{
+ Field: fmt.Sprintf("corpora[%d].name", i),
+ Value: "",
+ Message: "corpus name is required",
+ })
+ }
+
+ if corpus.Documents < 0 {
+ errors = append(errors, ValidationError{
+ Field: fmt.Sprintf("corpora[%d].documents", i),
+ Value: fmt.Sprintf("%d", corpus.Documents),
+ Message: "document count cannot be negative",
+ })
+ }
+
+ if corpus.Tokens < 0 {
+ errors = append(errors, ValidationError{
+ Field: fmt.Sprintf("corpora[%d].tokens", i),
+ Value: fmt.Sprintf("%d", corpus.Tokens),
+ Message: "token count cannot be negative",
+ })
+ }
+ }
+ }
+
+ if len(errors) > 0 {
+ v.logger.Warn().Interface("errors", errors).Msg("Corpus list response validation failed")
+ return ValidationErrors{Errors: errors}
+ }
+
+ v.logger.Debug().Msg("Corpus list response validation passed")
+ return nil
+}
+
+// ValidateStatisticsResponse validates a statistics response
+func (v *Validator) ValidateStatisticsResponse(resp *service.StatisticsResponse) error {
+ if resp == nil {
+ return fmt.Errorf("statistics response is nil")
+ }
+
+ var errors []ValidationError
+
+ if resp.Documents < 0 {
+ errors = append(errors, ValidationError{
+ Field: "documents",
+ Value: fmt.Sprintf("%d", resp.Documents),
+ Message: "document count cannot be negative",
+ })
+ }
+
+ if resp.Tokens < 0 {
+ errors = append(errors, ValidationError{
+ Field: "tokens",
+ Value: fmt.Sprintf("%d", resp.Tokens),
+ Message: "token count cannot be negative",
+ })
+ }
+
+ if resp.Sentences < 0 {
+ errors = append(errors, ValidationError{
+ Field: "sentences",
+ Value: fmt.Sprintf("%d", resp.Sentences),
+ Message: "sentence count cannot be negative",
+ })
+ }
+
+ if resp.Paragraphs < 0 {
+ errors = append(errors, ValidationError{
+ Field: "paragraphs",
+ Value: fmt.Sprintf("%d", resp.Paragraphs),
+ Message: "paragraph count cannot be negative",
+ })
+ }
+
+ if len(errors) > 0 {
+ v.logger.Warn().Interface("errors", errors).Msg("Statistics response validation failed")
+ return ValidationErrors{Errors: errors}
+ }
+
+ v.logger.Debug().Msg("Statistics response validation passed")
+ return nil
+}
+
+// validateQuerySafety performs basic security validation on queries
+func (v *Validator) validateQuerySafety(query string) error {
+ // Check for extremely long queries that could cause DoS
+ if len(query) > 10000 {
+ return fmt.Errorf("query is too long (max 10000 characters)")
+ }
+
+ // Check for potentially dangerous URL patterns
+ if strings.Contains(query, "://") {
+ if _, err := url.Parse(query); err == nil {
+ return fmt.Errorf("query appears to contain a URL which is not allowed")
+ }
+ }
+
+ // Check for excessive nesting that could cause parser issues
+ openParens := strings.Count(query, "(")
+ closeParens := strings.Count(query, ")")
+ if openParens != closeParens {
+ return fmt.Errorf("unmatched parentheses in query")
+ }
+ if openParens > 100 {
+ return fmt.Errorf("query has too many nested levels (max 100)")
+ }
+
+ return nil
+}
+
+// validateCorpusID validates a corpus identifier
+func (v *Validator) validateCorpusID(corpusID string) error {
+ if len(corpusID) == 0 {
+ return fmt.Errorf("corpus ID cannot be empty")
+ }
+
+ if len(corpusID) > 100 {
+ return fmt.Errorf("corpus ID is too long (max 100 characters)")
+ }
+
+ if !corpusIDRegex.MatchString(corpusID) {
+ return fmt.Errorf("corpus ID contains invalid characters (only alphanumeric, dots, hyphens, underscores allowed)")
+ }
+
+ return nil
+}
+
+// SanitizeQuery performs basic sanitization on search queries
+func (v *Validator) SanitizeQuery(query string) string {
+ // Trim whitespace
+ sanitized := strings.TrimSpace(query)
+
+ // Remove any null bytes
+ sanitized = strings.ReplaceAll(sanitized, "\x00", "")
+
+ // Normalize whitespace
+ sanitized = regexp.MustCompile(`\s+`).ReplaceAllString(sanitized, " ")
+
+ v.logger.Debug().
+ Str("original", query).
+ Str("sanitized", sanitized).
+ Msg("Query sanitized")
+
+ return sanitized
+}
+
+// SanitizeCorpusID performs basic sanitization on corpus IDs
+func (v *Validator) SanitizeCorpusID(corpusID string) string {
+ // Trim whitespace
+ sanitized := strings.TrimSpace(corpusID)
+
+ // Remove any null bytes
+ sanitized = strings.ReplaceAll(sanitized, "\x00", "")
+
+ // Convert to lowercase for consistency
+ sanitized = strings.ToLower(sanitized)
+
+ v.logger.Debug().
+ Str("original", corpusID).
+ Str("sanitized", sanitized).
+ Msg("Corpus ID sanitized")
+
+ return sanitized
+}
diff --git a/validation/validator_test.go b/validation/validator_test.go
new file mode 100644
index 0000000..43254a7
--- /dev/null
+++ b/validation/validator_test.go
@@ -0,0 +1,827 @@
+package validation
+
+import (
+ "testing"
+
+ "github.com/korap/korap-mcp/service"
+ "github.com/rs/zerolog"
+ "github.com/stretchr/testify/assert"
+)
+
+func TestNew(t *testing.T) {
+ logger := zerolog.Nop()
+ validator := New(logger)
+
+ assert.NotNil(t, validator)
+ assert.Equal(t, logger.With().Str("component", "validator").Logger(), validator.logger)
+}
+
+func TestValidationError_Error(t *testing.T) {
+ err := ValidationError{
+ Field: "test_field",
+ Value: "test_value",
+ Message: "test message",
+ }
+
+ expected := "validation error for field 'test_field' (value: 'test_value'): test message"
+ assert.Equal(t, expected, err.Error())
+}
+
+func TestValidationErrors_Error(t *testing.T) {
+ // Test empty errors
+ emptyErrors := ValidationErrors{}
+ assert.Equal(t, "validation errors occurred", emptyErrors.Error())
+
+ // Test single error
+ singleError := ValidationErrors{
+ Errors: []ValidationError{
+ {Field: "field1", Value: "value1", Message: "message1"},
+ },
+ }
+ expected := "validation error for field 'field1' (value: 'value1'): message1"
+ assert.Equal(t, expected, singleError.Error())
+
+ // Test multiple errors
+ multipleErrors := ValidationErrors{
+ Errors: []ValidationError{
+ {Field: "field1", Value: "value1", Message: "message1"},
+ {Field: "field2", Value: "value2", Message: "message2"},
+ },
+ }
+ expected = "validation error for field 'field1' (value: 'value1'): message1; validation error for field 'field2' (value: 'value2'): message2"
+ assert.Equal(t, expected, multipleErrors.Error())
+}
+
+func TestValidateSearchRequest(t *testing.T) {
+ logger := zerolog.Nop()
+ validator := New(logger)
+
+ tests := []struct {
+ name string
+ request SearchRequest
+ expectErr bool
+ errorMsg string
+ }{
+ {
+ name: "valid_request_minimal",
+ request: SearchRequest{
+ Query: "test query",
+ },
+ expectErr: false,
+ },
+ {
+ name: "valid_request_complete",
+ request: SearchRequest{
+ Query: "test query",
+ QueryLanguage: "poliqarp",
+ Corpus: "test-corpus",
+ Count: 100,
+ },
+ expectErr: false,
+ },
+ {
+ name: "empty_query",
+ request: SearchRequest{
+ Query: "",
+ },
+ expectErr: true,
+ errorMsg: "query is required and cannot be empty",
+ },
+ {
+ name: "whitespace_only_query",
+ request: SearchRequest{
+ Query: " ",
+ },
+ expectErr: true,
+ errorMsg: "query is required and cannot be empty",
+ },
+ {
+ name: "invalid_query_language",
+ request: SearchRequest{
+ Query: "test query",
+ QueryLanguage: "invalid",
+ },
+ expectErr: true,
+ errorMsg: "invalid query language",
+ },
+ {
+ name: "invalid_corpus_id",
+ request: SearchRequest{
+ Query: "test query",
+ Corpus: "invalid corpus!",
+ },
+ expectErr: true,
+ errorMsg: "corpus ID contains invalid characters",
+ },
+ {
+ name: "count_negative",
+ request: SearchRequest{
+ Query: "test query",
+ Count: -1,
+ },
+ expectErr: true,
+ errorMsg: "count must be between 0 and 10000",
+ },
+ {
+ name: "count_zero_valid",
+ request: SearchRequest{
+ Query: "test query",
+ Count: 0,
+ },
+ expectErr: false,
+ },
+ {
+ name: "count_too_high",
+ request: SearchRequest{
+ Query: "test query",
+ Count: 10001,
+ },
+ expectErr: true,
+ errorMsg: "count must be between 0 and 10000",
+ },
+ {
+ name: "unsafe_query_too_long",
+ request: SearchRequest{
+ Query: string(make([]byte, 10001)),
+ },
+ expectErr: true,
+ errorMsg: "query is too long",
+ },
+ {
+ name: "unsafe_query_url",
+ request: SearchRequest{
+ Query: "http://example.com",
+ },
+ expectErr: true,
+ errorMsg: "query appears to contain a URL",
+ },
+ {
+ name: "unsafe_query_unmatched_parens",
+ request: SearchRequest{
+ Query: "test (query",
+ },
+ expectErr: true,
+ errorMsg: "unmatched parentheses",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := validator.ValidateSearchRequest(tt.request)
+ if tt.expectErr {
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), tt.errorMsg)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValidateMetadataRequest(t *testing.T) {
+ logger := zerolog.Nop()
+ validator := New(logger)
+
+ tests := []struct {
+ name string
+ request MetadataRequest
+ expectErr bool
+ errorMsg string
+ }{
+ {
+ name: "valid_list_action",
+ request: MetadataRequest{
+ Action: "list",
+ },
+ expectErr: false,
+ },
+ {
+ name: "valid_statistics_action",
+ request: MetadataRequest{
+ Action: "statistics",
+ Corpus: "test-corpus",
+ },
+ expectErr: false,
+ },
+ {
+ name: "empty_action",
+ request: MetadataRequest{
+ Action: "",
+ },
+ expectErr: true,
+ errorMsg: "action is required and cannot be empty",
+ },
+ {
+ name: "whitespace_only_action",
+ request: MetadataRequest{
+ Action: " ",
+ },
+ expectErr: true,
+ errorMsg: "action is required and cannot be empty",
+ },
+ {
+ name: "invalid_action",
+ request: MetadataRequest{
+ Action: "invalid",
+ },
+ expectErr: true,
+ errorMsg: "invalid action",
+ },
+ {
+ name: "invalid_corpus_id",
+ request: MetadataRequest{
+ Action: "statistics",
+ Corpus: "invalid corpus!",
+ },
+ expectErr: true,
+ errorMsg: "corpus ID contains invalid characters",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := validator.ValidateMetadataRequest(tt.request)
+ if tt.expectErr {
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), tt.errorMsg)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValidateSearchResponse(t *testing.T) {
+ logger := zerolog.Nop()
+ validator := New(logger)
+
+ tests := []struct {
+ name string
+ response *service.SearchResponse
+ expectErr bool
+ errorMsg string
+ }{
+ {
+ name: "nil_response",
+ response: nil,
+ expectErr: true,
+ errorMsg: "search response is nil",
+ },
+ {
+ name: "valid_response",
+ response: &service.SearchResponse{
+ Meta: service.SearchMeta{
+ TotalResults: 100,
+ Count: 10,
+ StartIndex: 0,
+ ItemsPerPage: 10,
+ },
+ Query: service.SearchQuery{
+ Query: "test",
+ QueryLang: "poliqarp",
+ },
+ Matches: []service.SearchMatch{
+ {MatchID: "match1", TextSigle: "text1", Position: 0},
+ {MatchID: "match2", TextSigle: "text2", Position: 1},
+ },
+ },
+ expectErr: false,
+ },
+ {
+ name: "negative_total_results",
+ response: &service.SearchResponse{
+ Meta: service.SearchMeta{
+ TotalResults: -1,
+ Count: 10,
+ StartIndex: 0,
+ ItemsPerPage: 10,
+ },
+ },
+ expectErr: true,
+ errorMsg: "totalResults cannot be negative",
+ },
+ {
+ name: "negative_count",
+ response: &service.SearchResponse{
+ Meta: service.SearchMeta{
+ TotalResults: 100,
+ Count: -1,
+ StartIndex: 0,
+ ItemsPerPage: 10,
+ },
+ },
+ expectErr: true,
+ errorMsg: "count cannot be negative",
+ },
+ {
+ name: "negative_start_index",
+ response: &service.SearchResponse{
+ Meta: service.SearchMeta{
+ TotalResults: 100,
+ Count: 10,
+ StartIndex: -1,
+ ItemsPerPage: 10,
+ },
+ },
+ expectErr: true,
+ errorMsg: "startIndex cannot be negative",
+ },
+ {
+ name: "negative_items_per_page",
+ response: &service.SearchResponse{
+ Meta: service.SearchMeta{
+ TotalResults: 100,
+ Count: 10,
+ StartIndex: 0,
+ ItemsPerPage: -1,
+ },
+ },
+ expectErr: true,
+ errorMsg: "itemsPerPage cannot be negative",
+ },
+ {
+ name: "match_missing_id",
+ response: &service.SearchResponse{
+ Meta: service.SearchMeta{
+ TotalResults: 100,
+ Count: 10,
+ StartIndex: 0,
+ ItemsPerPage: 10,
+ },
+ Matches: []service.SearchMatch{
+ {MatchID: "", TextSigle: "text1", Position: 0},
+ },
+ },
+ expectErr: true,
+ errorMsg: "match ID is required",
+ },
+ {
+ name: "match_missing_text_sigle",
+ response: &service.SearchResponse{
+ Meta: service.SearchMeta{
+ TotalResults: 100,
+ Count: 10,
+ StartIndex: 0,
+ ItemsPerPage: 10,
+ },
+ Matches: []service.SearchMatch{
+ {MatchID: "match1", TextSigle: "", Position: 0},
+ },
+ },
+ expectErr: true,
+ errorMsg: "textSigle is required",
+ },
+ {
+ name: "match_negative_position",
+ response: &service.SearchResponse{
+ Meta: service.SearchMeta{
+ TotalResults: 100,
+ Count: 10,
+ StartIndex: 0,
+ ItemsPerPage: 10,
+ },
+ Matches: []service.SearchMatch{
+ {MatchID: "match1", TextSigle: "text1", Position: -1},
+ },
+ },
+ expectErr: true,
+ errorMsg: "position cannot be negative",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := validator.ValidateSearchResponse(tt.response)
+ if tt.expectErr {
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), tt.errorMsg)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValidateCorpusListResponse(t *testing.T) {
+ logger := zerolog.Nop()
+ validator := New(logger)
+
+ tests := []struct {
+ name string
+ response *service.CorpusListResponse
+ expectErr bool
+ errorMsg string
+ }{
+ {
+ name: "nil_response",
+ response: nil,
+ expectErr: true,
+ errorMsg: "corpus list response is nil",
+ },
+ {
+ name: "valid_response",
+ response: &service.CorpusListResponse{
+ Corpora: []service.CorpusInfo{
+ {
+ ID: "corpus1",
+ Name: "Test Corpus 1",
+ Documents: 100,
+ Tokens: 50000,
+ },
+ {
+ ID: "corpus2",
+ Name: "Test Corpus 2",
+ Documents: 200,
+ Tokens: 75000,
+ },
+ },
+ },
+ expectErr: false,
+ },
+ {
+ name: "empty_corpus_list",
+ response: &service.CorpusListResponse{
+ Corpora: []service.CorpusInfo{},
+ },
+ expectErr: false,
+ },
+ {
+ name: "corpus_missing_id",
+ response: &service.CorpusListResponse{
+ Corpora: []service.CorpusInfo{
+ {
+ ID: "",
+ Name: "Test Corpus",
+ Documents: 100,
+ Tokens: 50000,
+ },
+ },
+ },
+ expectErr: true,
+ errorMsg: "corpus ID is required",
+ },
+ {
+ name: "corpus_invalid_id",
+ response: &service.CorpusListResponse{
+ Corpora: []service.CorpusInfo{
+ {
+ ID: "invalid id!",
+ Name: "Test Corpus",
+ Documents: 100,
+ Tokens: 50000,
+ },
+ },
+ },
+ expectErr: true,
+ errorMsg: "corpus ID contains invalid characters",
+ },
+ {
+ name: "corpus_missing_name",
+ response: &service.CorpusListResponse{
+ Corpora: []service.CorpusInfo{
+ {
+ ID: "corpus1",
+ Name: "",
+ Documents: 100,
+ Tokens: 50000,
+ },
+ },
+ },
+ expectErr: true,
+ errorMsg: "corpus name is required",
+ },
+ {
+ name: "corpus_negative_documents",
+ response: &service.CorpusListResponse{
+ Corpora: []service.CorpusInfo{
+ {
+ ID: "corpus1",
+ Name: "Test Corpus",
+ Documents: -1,
+ Tokens: 50000,
+ },
+ },
+ },
+ expectErr: true,
+ errorMsg: "document count cannot be negative",
+ },
+ {
+ name: "corpus_negative_tokens",
+ response: &service.CorpusListResponse{
+ Corpora: []service.CorpusInfo{
+ {
+ ID: "corpus1",
+ Name: "Test Corpus",
+ Documents: 100,
+ Tokens: -1,
+ },
+ },
+ },
+ expectErr: true,
+ errorMsg: "token count cannot be negative",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := validator.ValidateCorpusListResponse(tt.response)
+ if tt.expectErr {
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), tt.errorMsg)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValidateStatisticsResponse(t *testing.T) {
+ logger := zerolog.Nop()
+ validator := New(logger)
+
+ tests := []struct {
+ name string
+ response *service.StatisticsResponse
+ expectErr bool
+ errorMsg string
+ }{
+ {
+ name: "nil_response",
+ response: nil,
+ expectErr: true,
+ errorMsg: "statistics response is nil",
+ },
+ {
+ name: "valid_response",
+ response: &service.StatisticsResponse{
+ Documents: 100,
+ Tokens: 50000,
+ Sentences: 2500,
+ Paragraphs: 500,
+ },
+ expectErr: false,
+ },
+ {
+ name: "negative_documents",
+ response: &service.StatisticsResponse{
+ Documents: -1,
+ Tokens: 50000,
+ },
+ expectErr: true,
+ errorMsg: "document count cannot be negative",
+ },
+ {
+ name: "negative_tokens",
+ response: &service.StatisticsResponse{
+ Documents: 100,
+ Tokens: -1,
+ },
+ expectErr: true,
+ errorMsg: "token count cannot be negative",
+ },
+ {
+ name: "negative_sentences",
+ response: &service.StatisticsResponse{
+ Documents: 100,
+ Tokens: 50000,
+ Sentences: -1,
+ },
+ expectErr: true,
+ errorMsg: "sentence count cannot be negative",
+ },
+ {
+ name: "negative_paragraphs",
+ response: &service.StatisticsResponse{
+ Documents: 100,
+ Tokens: 50000,
+ Paragraphs: -1,
+ },
+ expectErr: true,
+ errorMsg: "paragraph count cannot be negative",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := validator.ValidateStatisticsResponse(tt.response)
+ if tt.expectErr {
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), tt.errorMsg)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValidateQuerySafety(t *testing.T) {
+ logger := zerolog.Nop()
+ validator := New(logger)
+
+ tests := []struct {
+ name string
+ query string
+ expectErr bool
+ errorMsg string
+ }{
+ {
+ name: "valid_query",
+ query: "test query",
+ expectErr: false,
+ },
+ {
+ name: "query_too_long",
+ query: string(make([]byte, 10001)),
+ expectErr: true,
+ errorMsg: "query is too long",
+ },
+ {
+ name: "query_with_url",
+ query: "http://example.com",
+ expectErr: true,
+ errorMsg: "query appears to contain a URL",
+ },
+ {
+ name: "query_with_https_url",
+ query: "https://example.com",
+ expectErr: true,
+ errorMsg: "query appears to contain a URL",
+ },
+ {
+ name: "query_unmatched_open_paren",
+ query: "test (query",
+ expectErr: true,
+ errorMsg: "unmatched parentheses",
+ },
+ {
+ name: "query_unmatched_close_paren",
+ query: "test query)",
+ expectErr: true,
+ errorMsg: "unmatched parentheses",
+ },
+ {
+ name: "query_too_many_nested_parens",
+ query: "(" + string(make([]byte, 100)) + ")" + "(" + string(make([]byte, 100)) + ")",
+ expectErr: false, // This should be under the limit
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := validator.validateQuerySafety(tt.query)
+ if tt.expectErr {
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), tt.errorMsg)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValidateCorpusID(t *testing.T) {
+ logger := zerolog.Nop()
+ validator := New(logger)
+
+ tests := []struct {
+ name string
+ corpusID string
+ expectErr bool
+ errorMsg string
+ }{
+ {
+ name: "valid_corpus_id",
+ corpusID: "test-corpus_1.0",
+ expectErr: false,
+ },
+ {
+ name: "empty_corpus_id",
+ corpusID: "",
+ expectErr: true,
+ errorMsg: "corpus ID cannot be empty",
+ },
+ {
+ name: "corpus_id_too_long",
+ corpusID: string(make([]byte, 101)),
+ expectErr: true,
+ errorMsg: "corpus ID is too long",
+ },
+ {
+ name: "corpus_id_invalid_chars",
+ corpusID: "invalid corpus!",
+ expectErr: true,
+ errorMsg: "corpus ID contains invalid characters",
+ },
+ {
+ name: "corpus_id_with_space",
+ corpusID: "corpus with space",
+ expectErr: true,
+ errorMsg: "corpus ID contains invalid characters",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := validator.validateCorpusID(tt.corpusID)
+ if tt.expectErr {
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), tt.errorMsg)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestSanitizeQuery(t *testing.T) {
+ logger := zerolog.Nop()
+ validator := New(logger)
+
+ tests := []struct {
+ name string
+ input string
+ expected string
+ }{
+ {
+ name: "trim_whitespace",
+ input: " test query ",
+ expected: "test query",
+ },
+ {
+ name: "remove_null_bytes",
+ input: "test\x00query",
+ expected: "testquery",
+ },
+ {
+ name: "normalize_whitespace",
+ input: "test query\t\nwith spaces",
+ expected: "test query with spaces",
+ },
+ {
+ name: "empty_string",
+ input: "",
+ expected: "",
+ },
+ {
+ name: "already_clean",
+ input: "test query",
+ expected: "test query",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := validator.SanitizeQuery(tt.input)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestSanitizeCorpusID(t *testing.T) {
+ logger := zerolog.Nop()
+ validator := New(logger)
+
+ tests := []struct {
+ name string
+ input string
+ expected string
+ }{
+ {
+ name: "trim_whitespace",
+ input: " Test-Corpus ",
+ expected: "test-corpus",
+ },
+ {
+ name: "remove_null_bytes",
+ input: "test\x00corpus",
+ expected: "testcorpus",
+ },
+ {
+ name: "lowercase",
+ input: "Test-Corpus_1.0",
+ expected: "test-corpus_1.0",
+ },
+ {
+ name: "empty_string",
+ input: "",
+ expected: "",
+ },
+ {
+ name: "already_clean",
+ input: "test-corpus",
+ expected: "test-corpus",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := validator.SanitizeCorpusID(tt.input)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+// Helper function to create bool pointers
+func boolPtr(b bool) *bool {
+ return &b
+}