blob: 66b77c8a74c5ba2ca1477b7120b6455aa4f3086f [file] [log] [blame]
package validation
import (
	"fmt"
	"net/url"
	"regexp"
	"sort"
	"strings"

	"github.com/korap/korap-mcp/service"
	"github.com/rs/zerolog"
)
// Validator provides input validation and response schema validation.
// It checks search and metadata request parameters, sanity-checks the
// search, corpus-list and statistics responses from the service package,
// and offers helpers to sanitize raw queries and corpus IDs.
// Construct it with New so that the logger is tagged.
type Validator struct {
// logger is tagged with component=validator by New. Validation failures
// are logged at warn level and successful validations at debug level.
logger zerolog.Logger
}
// New returns a Validator that logs through a child of logger carrying the
// field component=validator, so validation log lines can be told apart from
// those of other components.
func New(logger zerolog.Logger) *Validator {
	componentLogger := logger.With().Str("component", "validator").Logger()
	v := &Validator{logger: componentLogger}
	return v
}
// SearchRequest holds the parameters for a search request validation
// (see ValidateSearchRequest).
type SearchRequest struct {
// Query is the search query text; it must not be blank.
Query string `json:"query"`
// QueryLanguage is optional; when set it must be a key of
// validQueryLanguages.
QueryLanguage string `json:"query_language,omitempty"`
// Corpus is optional; when set it must pass validateCorpusID.
Corpus string `json:"corpus,omitempty"`
// Count must lie in [0, 10000]; 0 means "use the default".
Count int `json:"count,omitempty"`
}
// MetadataRequest holds the parameters for a metadata request validation
// (see ValidateMetadataRequest).
type MetadataRequest struct {
// Action is required and must be a key of validMetadataActions
// ("list" or "statistics").
Action string `json:"action"`
// Corpus is optional; when set it must pass validateCorpusID.
Corpus string `json:"corpus,omitempty"`
}
// ValidationError describes a single failed check on one field of a request
// or response.
type ValidationError struct {
	// Field names the offending field, e.g. "query" or "matches[3].textSigle".
	Field string `json:"field"`
	// Value is the offending value rendered as a string ("" when it was missing).
	Value string `json:"value"`
	// Message explains why the value was rejected.
	Message string `json:"message"`
}

// Error implements the error interface.
func (e ValidationError) Error() string {
	return fmt.Sprintf("validation error for field '%s' (value: '%s'): %s", e.Field, e.Value, e.Message)
}

// ValidationErrors aggregates every ValidationError found while validating a
// single request or response.
type ValidationErrors struct {
	Errors []ValidationError `json:"errors"`
}

// Error joins the messages of all contained errors with "; ". When the list is
// empty a generic message is returned instead of an empty string.
func (e ValidationErrors) Error() string {
	if len(e.Errors) == 0 {
		return "validation errors occurred"
	}
	messages := make([]string, 0, len(e.Errors))
	for _, fieldErr := range e.Errors {
		messages = append(messages, fieldErr.Error())
	}
	return strings.Join(messages, "; ")
}

// Unwrap exposes the individual field errors so that errors.Is and errors.As
// (Go 1.20+) can inspect them; e.g. `var ve ValidationError;
// errors.As(err, &ve)` yields the first field error. It returns nil when there
// are no field errors.
func (e ValidationErrors) Unwrap() []error {
	if len(e.Errors) == 0 {
		return nil
	}
	errs := make([]error, len(e.Errors))
	for i, fieldErr := range e.Errors {
		errs[i] = fieldErr
	}
	return errs
}
// Lookup tables and patterns shared by the validation methods.

// corpusIDRegex accepts corpus IDs built only from ASCII letters, digits,
// dots, underscores and hyphens (at least one character).
var corpusIDRegex = regexp.MustCompile(`^[a-zA-Z0-9._-]+$`)

// validQueryLanguages is the set of query_language values this validator
// accepts.
var validQueryLanguages = map[string]bool{
	"annis":    true,
	"cosmas2":  true,
	"poliqarp": true,
}

// validMetadataActions is the set of accepted metadata request actions.
var validMetadataActions = map[string]bool{
	"statistics": true,
	"list":       true,
}
// ValidateSearchRequest validates the parameters of a search request.
//
// Every violation is collected (it does not stop at the first problem):
//   - Query must not be blank; a non-blank query must also pass
//     validateQuerySafety.
//   - QueryLanguage, if set, must be a key of validQueryLanguages.
//   - Corpus, if set, must pass validateCorpusID.
//   - Count must lie in [0, 10000]; 0 means "use the default".
//
// It returns a ValidationErrors value listing all violations, or nil when the
// request is valid. Failures are logged at warn level, successes at debug level.
func (v *Validator) ValidateSearchRequest(req SearchRequest) error {
	// Named errs (not errors) to avoid shadowing the standard errors package.
	var errs []ValidationError

	if strings.TrimSpace(req.Query) == "" {
		errs = append(errs, ValidationError{
			Field:   "query",
			Value:   req.Query,
			Message: "query is required and cannot be empty",
		})
	} else if err := v.validateQuerySafety(req.Query); err != nil {
		errs = append(errs, ValidationError{
			Field:   "query",
			Value:   req.Query,
			Message: err.Error(),
		})
	}

	if req.QueryLanguage != "" && !validQueryLanguages[req.QueryLanguage] {
		// Map iteration order is randomized; sort so the message is stable.
		validLangs := make([]string, 0, len(validQueryLanguages))
		for lang := range validQueryLanguages {
			validLangs = append(validLangs, lang)
		}
		sort.Strings(validLangs)
		errs = append(errs, ValidationError{
			Field:   "query_language",
			Value:   req.QueryLanguage,
			Message: fmt.Sprintf("invalid query language, must be one of: %s", strings.Join(validLangs, ", ")),
		})
	}

	if req.Corpus != "" {
		if err := v.validateCorpusID(req.Corpus); err != nil {
			errs = append(errs, ValidationError{
				Field:   "corpus",
				Value:   req.Corpus,
				Message: err.Error(),
			})
		}
	}

	// 0 is accepted (use the default); negative or oversized counts are not.
	if req.Count < 0 || req.Count > 10000 {
		errs = append(errs, ValidationError{
			Field:   "count",
			Value:   fmt.Sprintf("%d", req.Count),
			Message: "count must be between 0 and 10000 (0 means use default)",
		})
	}

	if len(errs) > 0 {
		v.logger.Warn().Interface("errors", errs).Msg("Search request validation failed")
		return ValidationErrors{Errors: errs}
	}
	v.logger.Debug().Interface("request", req).Msg("Search request validation passed")
	return nil
}
// ValidateMetadataRequest validates the parameters of a metadata request.
//
// Every violation is collected:
//   - Action is required (not blank) and must be a key of validMetadataActions.
//   - Corpus, if set, must pass validateCorpusID.
//
// It returns a ValidationErrors value listing all violations, or nil when the
// request is valid. Failures are logged at warn level, successes at debug level.
func (v *Validator) ValidateMetadataRequest(req MetadataRequest) error {
	// Named errs (not errors) to avoid shadowing the standard errors package.
	var errs []ValidationError

	switch {
	case strings.TrimSpace(req.Action) == "":
		errs = append(errs, ValidationError{
			Field:   "action",
			Value:   req.Action,
			Message: "action is required and cannot be empty",
		})
	case !validMetadataActions[req.Action]:
		// Map iteration order is randomized; sort so the message is stable.
		validActions := make([]string, 0, len(validMetadataActions))
		for action := range validMetadataActions {
			validActions = append(validActions, action)
		}
		sort.Strings(validActions)
		errs = append(errs, ValidationError{
			Field:   "action",
			Value:   req.Action,
			Message: fmt.Sprintf("invalid action, must be one of: %s", strings.Join(validActions, ", ")),
		})
	}

	if req.Corpus != "" {
		if err := v.validateCorpusID(req.Corpus); err != nil {
			errs = append(errs, ValidationError{
				Field:   "corpus",
				Value:   req.Corpus,
				Message: err.Error(),
			})
		}
	}

	if len(errs) > 0 {
		v.logger.Warn().Interface("errors", errs).Msg("Metadata request validation failed")
		return ValidationErrors{Errors: errs}
	}
	v.logger.Debug().Interface("request", req).Msg("Metadata request validation passed")
	return nil
}
// ValidateSearchResponse sanity-checks a search response returned by the
// service layer.
//
// A nil response yields a plain error. Otherwise it requires the meta counters
// (TotalResults, Count, StartIndex, ItemsPerPage) to be non-negative. For each
// match it requires a non-empty MatchID and TextSigle and a non-negative
// Position. All problems are collected into a ValidationErrors value; nil is
// returned when none were found.
func (v *Validator) ValidateSearchResponse(resp *service.SearchResponse) error {
	if resp == nil {
		return fmt.Errorf("search response is nil")
	}

	var problems []ValidationError
	// flagNegative records a counter that is below zero, rendering it with %d.
	flagNegative := func(negative bool, field string, value interface{}, message string) {
		if negative {
			problems = append(problems, ValidationError{
				Field:   field,
				Value:   fmt.Sprintf("%d", value),
				Message: message,
			})
		}
	}
	// flagMissing records a required string field that was empty.
	flagMissing := func(missing bool, field, message string) {
		if missing {
			problems = append(problems, ValidationError{Field: field, Value: "", Message: message})
		}
	}

	flagNegative(resp.Meta.TotalResults < 0, "meta.totalResults", resp.Meta.TotalResults, "totalResults cannot be negative")
	flagNegative(resp.Meta.Count < 0, "meta.count", resp.Meta.Count, "count cannot be negative")
	flagNegative(resp.Meta.StartIndex < 0, "meta.startIndex", resp.Meta.StartIndex, "startIndex cannot be negative")
	flagNegative(resp.Meta.ItemsPerPage < 0, "meta.itemsPerPage", resp.Meta.ItemsPerPage, "itemsPerPage cannot be negative")

	// Ranging over an empty or nil collection is a no-op.
	for idx, m := range resp.Matches {
		flagMissing(m.MatchID == "", fmt.Sprintf("matches[%d].matchID", idx), "match ID is required")
		flagMissing(m.TextSigle == "", fmt.Sprintf("matches[%d].textSigle", idx), "textSigle is required")
		flagNegative(m.Position < 0, fmt.Sprintf("matches[%d].position", idx), m.Position, "position cannot be negative")
	}

	if len(problems) == 0 {
		v.logger.Debug().Msg("Search response validation passed")
		return nil
	}
	v.logger.Warn().Interface("errors", problems).Msg("Search response validation failed")
	return ValidationErrors{Errors: problems}
}
// ValidateCorpusListResponse sanity-checks a corpus list response returned by
// the service layer.
//
// A nil response yields a plain error. For every corpus entry it requires:
//   - an ID that is non-empty and passes validateCorpusID,
//   - a non-empty Name,
//   - non-negative Documents and Tokens counts.
//
// All problems are collected into a ValidationErrors value; nil is returned
// when none were found.
func (v *Validator) ValidateCorpusListResponse(resp *service.CorpusListResponse) error {
	if resp == nil {
		return fmt.Errorf("corpus list response is nil")
	}

	var problems []ValidationError
	report := func(field, value, message string) {
		problems = append(problems, ValidationError{Field: field, Value: value, Message: message})
	}

	// Ranging over an empty or nil collection is a no-op.
	for idx, entry := range resp.Corpora {
		idField := fmt.Sprintf("corpora[%d].id", idx)
		if entry.ID == "" {
			report(idField, "", "corpus ID is required")
		} else if err := v.validateCorpusID(entry.ID); err != nil {
			report(idField, entry.ID, err.Error())
		}

		if entry.Name == "" {
			report(fmt.Sprintf("corpora[%d].name", idx), "", "corpus name is required")
		}
		if entry.Documents < 0 {
			report(fmt.Sprintf("corpora[%d].documents", idx), fmt.Sprintf("%d", entry.Documents), "document count cannot be negative")
		}
		if entry.Tokens < 0 {
			report(fmt.Sprintf("corpora[%d].tokens", idx), fmt.Sprintf("%d", entry.Tokens), "token count cannot be negative")
		}
	}

	if len(problems) == 0 {
		v.logger.Debug().Msg("Corpus list response validation passed")
		return nil
	}
	v.logger.Warn().Interface("errors", problems).Msg("Corpus list response validation failed")
	return ValidationErrors{Errors: problems}
}
// ValidateStatisticsResponse sanity-checks a statistics response returned by
// the service layer.
//
// A nil response yields a plain error. Otherwise the Documents, Tokens,
// Sentences and Paragraphs counts must all be non-negative. All violations are
// collected into a ValidationErrors value; nil is returned when none were found.
func (v *Validator) ValidateStatisticsResponse(resp *service.StatisticsResponse) error {
	if resp == nil {
		return fmt.Errorf("statistics response is nil")
	}

	var problems []ValidationError
	// requireNonNegative records a counter that is below zero, rendering it with %d.
	requireNonNegative := func(negative bool, field string, count interface{}, message string) {
		if !negative {
			return
		}
		problems = append(problems, ValidationError{
			Field:   field,
			Value:   fmt.Sprintf("%d", count),
			Message: message,
		})
	}

	requireNonNegative(resp.Documents < 0, "documents", resp.Documents, "document count cannot be negative")
	requireNonNegative(resp.Tokens < 0, "tokens", resp.Tokens, "token count cannot be negative")
	requireNonNegative(resp.Sentences < 0, "sentences", resp.Sentences, "sentence count cannot be negative")
	requireNonNegative(resp.Paragraphs < 0, "paragraphs", resp.Paragraphs, "paragraph count cannot be negative")

	if len(problems) == 0 {
		v.logger.Debug().Msg("Statistics response validation passed")
		return nil
	}
	v.logger.Warn().Interface("errors", problems).Msg("Statistics response validation failed")
	return ValidationErrors{Errors: problems}
}
// validateQuerySafety performs coarse, query-language-agnostic safety checks on
// a non-empty query and returns the first problem found, in this order:
//   - the query is longer than 10000 characters (counted as runes, so
//     multi-byte UTF-8 text is not penalized);
//   - the query contains "://" and url.Parse accepts the whole string;
//   - parentheses are unbalanced (a ')' without a preceding open '(' or an
//     unclosed '(');
//   - parentheses are nested more than 100 levels deep.
//
// Parentheses are counted everywhere in the string, including inside quoted
// strings or regular expressions of the query language.
func (v *Validator) validateQuerySafety(query string) error {
	const (
		maxQueryRunes    = 10000
		maxNestingLevels = 100
	)

	// The byte length bounds the rune count from above, so runes only need
	// counting when the byte length exceeds the limit. Counting stops as soon
	// as the limit is passed, keeping the work bounded for huge inputs.
	if len(query) > maxQueryRunes {
		runes := 0
		for range query {
			runes++
			if runes > maxQueryRunes {
				return fmt.Errorf("query is too long (max %d characters)", maxQueryRunes)
			}
		}
	}

	// NOTE(review): url.Parse accepts most strings, so this rejects nearly every
	// query containing "://". It lets a query through only when url.Parse
	// fails. One example is a ':' before the first '/' without a valid scheme,
	// as in [orth="http://x"]. Consider checking u.Scheme/u.Host instead.
	if strings.Contains(query, "://") {
		if _, err := url.Parse(query); err == nil {
			return fmt.Errorf("query appears to contain a URL which is not allowed")
		}
	}

	// Track the actual nesting depth instead of comparing total counts, so
	// ")(" is reported as unmatched and many sibling groups are not
	// mistaken for deep nesting. '(' and ')' are ASCII and never occur
	// inside multi-byte UTF-8 sequences, so a byte scan is safe.
	depth, deepest := 0, 0
	for i := 0; i < len(query); i++ {
		switch query[i] {
		case '(':
			depth++
			if depth > deepest {
				deepest = depth
			}
		case ')':
			if depth == 0 {
				return fmt.Errorf("unmatched parentheses in query")
			}
			depth--
		}
	}
	if depth != 0 {
		return fmt.Errorf("unmatched parentheses in query")
	}
	if deepest > maxNestingLevels {
		return fmt.Errorf("query has too many nested levels (max %d)", maxNestingLevels)
	}
	return nil
}
// validateCorpusID checks that corpusID is non-empty, at most 100 bytes long
// and matches corpusIDRegex. The first failing rule determines the returned
// error; nil means the ID is acceptable.
func (v *Validator) validateCorpusID(corpusID string) error {
	switch size := len(corpusID); {
	case size == 0:
		return fmt.Errorf("corpus ID cannot be empty")
	case size > 100:
		return fmt.Errorf("corpus ID is too long (max 100 characters)")
	case !corpusIDRegex.MatchString(corpusID):
		return fmt.Errorf("corpus ID contains invalid characters (only alphanumeric, dots, hyphens, underscores allowed)")
	default:
		return nil
	}
}
// whitespaceRunRegex matches runs of RE2 `\s` characters (space, \t, \n, \f,
// \r). It is compiled once at package initialization rather than on every
// SanitizeQuery call.
var whitespaceRunRegex = regexp.MustCompile(`\s+`)

// SanitizeQuery performs basic sanitization on a raw search query:
//   - all NUL bytes are removed,
//   - leading and trailing white space is trimmed (strings.TrimSpace),
//   - every remaining run of `\s` characters is collapsed to a single space.
//
// NUL bytes are removed before trimming so that a NUL at either end cannot
// shield adjacent white space from being trimmed.
// The original and sanitized strings are logged at debug level.
func (v *Validator) SanitizeQuery(query string) string {
	sanitized := strings.ReplaceAll(query, "\x00", "")
	sanitized = strings.TrimSpace(sanitized)
	sanitized = whitespaceRunRegex.ReplaceAllString(sanitized, " ")

	v.logger.Debug().
		Str("original", query).
		Str("sanitized", sanitized).
		Msg("Query sanitized")
	return sanitized
}
// SanitizeCorpusID performs basic sanitization on a corpus ID. It removes all
// NUL bytes, trims leading and trailing white space, and lowercases the result.
//
// NUL bytes are removed before trimming so that a NUL at either end cannot
// shield adjacent white space from being trimmed.
// The original and sanitized values are logged at debug level.
//
// NOTE(review): lowercasing assumes corpus IDs are case-insensitive on the
// server side; confirm this holds for the backend in use.
func (v *Validator) SanitizeCorpusID(corpusID string) string {
	sanitized := strings.ReplaceAll(corpusID, "\x00", "")
	sanitized = strings.TrimSpace(sanitized)
	sanitized = strings.ToLower(sanitized)

	v.logger.Debug().
		Str("original", corpusID).
		Str("sanitized", sanitized).
		Msg("Corpus ID sanitized")
	return sanitized
}