Added request/response validation
diff --git a/validation/validator.go b/validation/validator.go
new file mode 100644
index 0000000..66b77c8
--- /dev/null
+++ b/validation/validator.go
@@ -0,0 +1,460 @@
+package validation
+
+import (
+	"fmt"
+	"net/url"
+	"regexp"
+	"sort"
+	"strings"
+	"unicode/utf8"
+
+	"github.com/korap/korap-mcp/service"
+	"github.com/rs/zerolog"
+)
+
+// Validator checks incoming request parameters and sanity-checks the
+// response structures produced by the service package. Build one with New so
+// that its log output is tagged with the component name.
+type Validator struct {
+	// logger is the caller's logger scoped with component=validator (see New).
+	logger zerolog.Logger
+}
+
+// New returns a Validator that logs through a child of the supplied logger
+// carrying the field "component" set to "validator".
+func New(logger zerolog.Logger) *Validator {
+	scoped := logger.With().Str("component", "validator").Logger()
+	return &Validator{logger: scoped}
+}
+
+// Request parameter holders accepted by the Validate*Request methods.
+type (
+	// SearchRequest holds the parameters of a search request to validate.
+	// Query is mandatory; QueryLanguage and Corpus are checked only when
+	// non-empty, and a zero Count means "use the default".
+	SearchRequest struct {
+		Query         string `json:"query"`
+		QueryLanguage string `json:"query_language,omitempty"`
+		Corpus        string `json:"corpus,omitempty"`
+		Count         int    `json:"count,omitempty"`
+	}
+
+	// MetadataRequest holds the parameters of a metadata request to
+	// validate. Action is mandatory; Corpus is checked only when non-empty.
+	MetadataRequest struct {
+		Action string `json:"action"`
+		Corpus string `json:"corpus,omitempty"`
+	}
+)
+
+// ValidationError describes one rejected field: the field name, the
+// offending value rendered as a string, and a human-readable reason.
+type ValidationError struct {
+	Field   string `json:"field"`
+	Value   string `json:"value"`
+	Message string `json:"message"`
+}
+
+// Error renders the error as
+// "validation error for field '<Field>' (value: '<Value>'): <Message>".
+func (e ValidationError) Error() string {
+	return "validation error for field '" + e.Field + "' (value: '" + e.Value + "'): " + e.Message
+}
+
+// ValidationErrors collects every ValidationError found in one validation
+// pass.
+type ValidationErrors struct {
+	Errors []ValidationError `json:"errors"`
+}
+
+// Error joins the individual error strings with "; ". An empty list yields a
+// generic message so the text is never blank.
+func (list ValidationErrors) Error() string {
+	if len(list.Errors) == 0 {
+		return "validation errors occurred"
+	}
+	parts := make([]string, len(list.Errors))
+	for i := range list.Errors {
+		parts[i] = list.Errors[i].Error()
+	}
+	return strings.Join(parts, "; ")
+}
+
+// Lookup tables and patterns shared by the validators in this file.
+var (
+	// validMetadataActions lists the accepted MetadataRequest.Action values.
+	validMetadataActions = map[string]bool{
+		"list":       true,
+		"statistics": true,
+	}
+
+	// validQueryLanguages lists the accepted SearchRequest.QueryLanguage
+	// values: KorAP's annis, cosmas2 and poliqarp query languages.
+	validQueryLanguages = map[string]bool{
+		"annis":    true,
+		"cosmas2":  true,
+		"poliqarp": true,
+	}
+
+	// corpusIDRegex accepts one or more ASCII letters, digits, dots, hyphens
+	// or underscores, and nothing else.
+	corpusIDRegex = regexp.MustCompile(`^[a-zA-Z0-9._-]+$`)
+)
+
+// ValidateSearchRequest checks req and returns nil when it is acceptable.
+// Otherwise it returns a ValidationErrors value describing every offending
+// field; all checks run rather than stopping at the first failure:
+//   - query: required; must not be blank and must pass validateQuerySafety.
+//   - query_language: optional; if set, must be a key of
+//     validQueryLanguages (compared case-sensitively).
+//   - corpus: optional; if set, must pass validateCorpusID.
+//   - count: must lie in [0, 10000]; 0 means "use the default".
+func (v *Validator) ValidateSearchRequest(req SearchRequest) error {
+	var errs []ValidationError
+
+	if strings.TrimSpace(req.Query) == "" {
+		errs = append(errs, ValidationError{
+			Field:   "query",
+			Value:   req.Query,
+			Message: "query is required and cannot be empty",
+		})
+	} else if err := v.validateQuerySafety(req.Query); err != nil {
+		errs = append(errs, ValidationError{
+			Field:   "query",
+			Value:   req.Query,
+			Message: err.Error(),
+		})
+	}
+
+	if req.QueryLanguage != "" && !validQueryLanguages[req.QueryLanguage] {
+		// Sort the accepted values so the message is identical on every
+		// call; Go randomizes map iteration order.
+		langs := make([]string, 0, len(validQueryLanguages))
+		for lang := range validQueryLanguages {
+			langs = append(langs, lang)
+		}
+		sort.Strings(langs)
+		errs = append(errs, ValidationError{
+			Field:   "query_language",
+			Value:   req.QueryLanguage,
+			Message: fmt.Sprintf("invalid query language, must be one of: %s", strings.Join(langs, ", ")),
+		})
+	}
+
+	if req.Corpus != "" {
+		if err := v.validateCorpusID(req.Corpus); err != nil {
+			errs = append(errs, ValidationError{
+				Field:   "corpus",
+				Value:   req.Corpus,
+				Message: err.Error(),
+			})
+		}
+	}
+
+	// Zero is valid and selects the default, so only negative values and
+	// values above the cap are rejected.
+	if req.Count < 0 || req.Count > 10000 {
+		errs = append(errs, ValidationError{
+			Field:   "count",
+			Value:   fmt.Sprintf("%d", req.Count),
+			Message: "count must be between 0 and 10000 (0 means use default)",
+		})
+	}
+
+	if len(errs) > 0 {
+		v.logger.Warn().Interface("errors", errs).Msg("Search request validation failed")
+		return ValidationErrors{Errors: errs}
+	}
+
+	v.logger.Debug().Interface("request", req).Msg("Search request validation passed")
+	return nil
+}
+
+// ValidateMetadataRequest checks req and returns nil when it is acceptable.
+// Otherwise it returns a ValidationErrors value describing every offending
+// field:
+//   - action: required; must not be blank and must be a key of
+//     validMetadataActions (compared as-is, case-sensitively).
+//   - corpus: optional; if set, must pass validateCorpusID.
+func (v *Validator) ValidateMetadataRequest(req MetadataRequest) error {
+	var errs []ValidationError
+
+	switch {
+	case strings.TrimSpace(req.Action) == "":
+		errs = append(errs, ValidationError{
+			Field:   "action",
+			Value:   req.Action,
+			Message: "action is required and cannot be empty",
+		})
+	case !validMetadataActions[req.Action]:
+		// Sort the accepted actions so the message is identical on every
+		// call; Go randomizes map iteration order.
+		actions := make([]string, 0, len(validMetadataActions))
+		for action := range validMetadataActions {
+			actions = append(actions, action)
+		}
+		sort.Strings(actions)
+		errs = append(errs, ValidationError{
+			Field:   "action",
+			Value:   req.Action,
+			Message: fmt.Sprintf("invalid action, must be one of: %s", strings.Join(actions, ", ")),
+		})
+	}
+
+	if req.Corpus != "" {
+		if err := v.validateCorpusID(req.Corpus); err != nil {
+			errs = append(errs, ValidationError{
+				Field:   "corpus",
+				Value:   req.Corpus,
+				Message: err.Error(),
+			})
+		}
+	}
+
+	if len(errs) > 0 {
+		v.logger.Warn().Interface("errors", errs).Msg("Metadata request validation failed")
+		return ValidationErrors{Errors: errs}
+	}
+
+	v.logger.Debug().Interface("request", req).Msg("Metadata request validation passed")
+	return nil
+}
+
+// ValidateSearchResponse sanity-checks a search response from the service
+// layer. A nil response yields a plain error. Otherwise:
+//   - the meta counters totalResults, count, startIndex and itemsPerPage must
+//     be non-negative;
+//   - every match must have a MatchID, a TextSigle and a non-negative
+//     Position.
+//
+// All violations are reported together in one ValidationErrors.
+func (v *Validator) ValidateSearchResponse(resp *service.SearchResponse) error {
+	if resp == nil {
+		return fmt.Errorf("search response is nil")
+	}
+
+	var errs []ValidationError
+	reject := func(field, value, message string) {
+		errs = append(errs, ValidationError{Field: field, Value: value, Message: message})
+	}
+
+	if resp.Meta.TotalResults < 0 {
+		reject("meta.totalResults", fmt.Sprintf("%d", resp.Meta.TotalResults), "totalResults cannot be negative")
+	}
+	if resp.Meta.Count < 0 {
+		reject("meta.count", fmt.Sprintf("%d", resp.Meta.Count), "count cannot be negative")
+	}
+	if resp.Meta.StartIndex < 0 {
+		reject("meta.startIndex", fmt.Sprintf("%d", resp.Meta.StartIndex), "startIndex cannot be negative")
+	}
+	if resp.Meta.ItemsPerPage < 0 {
+		reject("meta.itemsPerPage", fmt.Sprintf("%d", resp.Meta.ItemsPerPage), "itemsPerPage cannot be negative")
+	}
+
+	// Ranging over a nil Matches is a no-op, so no separate nil guard is
+	// needed.
+	for i, match := range resp.Matches {
+		prefix := fmt.Sprintf("matches[%d].", i)
+		if match.MatchID == "" {
+			reject(prefix+"matchID", "", "match ID is required")
+		}
+		if match.TextSigle == "" {
+			reject(prefix+"textSigle", "", "textSigle is required")
+		}
+		if match.Position < 0 {
+			reject(prefix+"position", fmt.Sprintf("%d", match.Position), "position cannot be negative")
+		}
+	}
+
+	if len(errs) == 0 {
+		v.logger.Debug().Msg("Search response validation passed")
+		return nil
+	}
+
+	v.logger.Warn().Interface("errors", errs).Msg("Search response validation failed")
+	return ValidationErrors{Errors: errs}
+}
+
+// ValidateCorpusListResponse sanity-checks a corpus list returned by the
+// service layer. A nil response yields a plain error. For each corpus entry:
+//   - the ID must be present and pass validateCorpusID;
+//   - the Name must be present;
+//   - the Documents and Tokens counts must be non-negative.
+//
+// All violations are reported together in one ValidationErrors.
+func (v *Validator) ValidateCorpusListResponse(resp *service.CorpusListResponse) error {
+	if resp == nil {
+		return fmt.Errorf("corpus list response is nil")
+	}
+
+	var errs []ValidationError
+	reject := func(field, value, message string) {
+		errs = append(errs, ValidationError{Field: field, Value: value, Message: message})
+	}
+
+	// Ranging over a nil Corpora is a no-op, so no separate nil guard is
+	// needed.
+	for i, corpus := range resp.Corpora {
+		prefix := fmt.Sprintf("corpora[%d].", i)
+
+		if corpus.ID == "" {
+			reject(prefix+"id", "", "corpus ID is required")
+		} else if err := v.validateCorpusID(corpus.ID); err != nil {
+			reject(prefix+"id", corpus.ID, err.Error())
+		}
+
+		if corpus.Name == "" {
+			reject(prefix+"name", "", "corpus name is required")
+		}
+		if corpus.Documents < 0 {
+			reject(prefix+"documents", fmt.Sprintf("%d", corpus.Documents), "document count cannot be negative")
+		}
+		if corpus.Tokens < 0 {
+			reject(prefix+"tokens", fmt.Sprintf("%d", corpus.Tokens), "token count cannot be negative")
+		}
+	}
+
+	if len(errs) == 0 {
+		v.logger.Debug().Msg("Corpus list response validation passed")
+		return nil
+	}
+
+	v.logger.Warn().Interface("errors", errs).Msg("Corpus list response validation failed")
+	return ValidationErrors{Errors: errs}
+}
+
+// ValidateStatisticsResponse sanity-checks corpus statistics returned by the
+// service layer. A nil response yields a plain error. Otherwise the
+// documents, tokens, sentences and paragraphs counts must all be
+// non-negative; every violation is reported in one ValidationErrors.
+func (v *Validator) ValidateStatisticsResponse(resp *service.StatisticsResponse) error {
+	if resp == nil {
+		return fmt.Errorf("statistics response is nil")
+	}
+
+	var errs []ValidationError
+	reject := func(field, value, message string) {
+		errs = append(errs, ValidationError{Field: field, Value: value, Message: message})
+	}
+
+	if resp.Documents < 0 {
+		reject("documents", fmt.Sprintf("%d", resp.Documents), "document count cannot be negative")
+	}
+	if resp.Tokens < 0 {
+		reject("tokens", fmt.Sprintf("%d", resp.Tokens), "token count cannot be negative")
+	}
+	if resp.Sentences < 0 {
+		reject("sentences", fmt.Sprintf("%d", resp.Sentences), "sentence count cannot be negative")
+	}
+	if resp.Paragraphs < 0 {
+		reject("paragraphs", fmt.Sprintf("%d", resp.Paragraphs), "paragraph count cannot be negative")
+	}
+
+	if len(errs) == 0 {
+		v.logger.Debug().Msg("Statistics response validation passed")
+		return nil
+	}
+
+	v.logger.Warn().Interface("errors", errs).Msg("Statistics response validation failed")
+	return ValidationErrors{Errors: errs}
+}
+
+// validateQuerySafety performs basic security validation on queries
+func (v *Validator) validateQuerySafety(query string) error {
+ // Check for extremely long queries that could cause DoS
+ if len(query) > 10000 {
+ return fmt.Errorf("query is too long (max 10000 characters)")
+ }
+
+ // Check for potentially dangerous URL patterns
+ if strings.Contains(query, "://") {
+ if _, err := url.Parse(query); err == nil {
+ return fmt.Errorf("query appears to contain a URL which is not allowed")
+ }
+ }
+
+ // Check for excessive nesting that could cause parser issues
+ openParens := strings.Count(query, "(")
+ closeParens := strings.Count(query, ")")
+ if openParens != closeParens {
+ return fmt.Errorf("unmatched parentheses in query")
+ }
+ if openParens > 100 {
+ return fmt.Errorf("query has too many nested levels (max 100)")
+ }
+
+ return nil
+}
+
+// validateCorpusID returns nil when corpusID is a usable corpus identifier:
+// non-empty, at most 100 bytes long, and made only of the characters
+// corpusIDRegex permits. Otherwise it returns an error for the first rule
+// that fails, checked in that order.
+func (v *Validator) validateCorpusID(corpusID string) error {
+	switch n := len(corpusID); {
+	case n == 0:
+		return fmt.Errorf("corpus ID cannot be empty")
+	case n > 100:
+		return fmt.Errorf("corpus ID is too long (max 100 characters)")
+	case !corpusIDRegex.MatchString(corpusID):
+		return fmt.Errorf("corpus ID contains invalid characters (only alphanumeric, dots, hyphens, underscores allowed)")
+	default:
+		return nil
+	}
+}
+
+// SanitizeQuery performs basic sanitization on search queries
+func (v *Validator) SanitizeQuery(query string) string {
+ // Trim whitespace
+ sanitized := strings.TrimSpace(query)
+
+ // Remove any null bytes
+ sanitized = strings.ReplaceAll(sanitized, "\x00", "")
+
+ // Normalize whitespace
+ sanitized = regexp.MustCompile(`\s+`).ReplaceAllString(sanitized, " ")
+
+ v.logger.Debug().
+ Str("original", query).
+ Str("sanitized", sanitized).
+ Msg("Query sanitized")
+
+ return sanitized
+}
+
+// SanitizeCorpusID normalizes a corpus identifier. It removes NUL bytes,
+// trims surrounding whitespace and lower-cases the result. It performs no
+// validation.
+//
+// NOTE(review): validateCorpusID accepts upper-case letters, yet this
+// lower-cases every ID. Confirm the backend treats corpus IDs
+// case-insensitively.
+func (v *Validator) SanitizeCorpusID(corpusID string) string {
+	// Remove NUL bytes before trimming. Trimming first let a NUL at either
+	// end shield the adjacent whitespace, e.g. "\x00 wpd17" came out as
+	// " wpd17".
+	sanitized := strings.ReplaceAll(corpusID, "\x00", "")
+	sanitized = strings.TrimSpace(sanitized)
+	sanitized = strings.ToLower(sanitized)
+
+	v.logger.Debug().
+		Str("original", corpusID).
+		Str("sanitized", sanitized).
+		Msg("Corpus ID sanitized")
+
+	return sanitized
+}