blob: 343e8206cb84cacdee359bf11e37a4cc8e6e78a1 [file] [log] [blame]
Akron81f709c2025-06-12 17:30:55 +02001package validation
2
import (
	"fmt"
	"net/url"
	"regexp"
	"sort"
	"strings"

	"github.com/korap/korap-mcp/service"
	"github.com/rs/zerolog"
)
12
13// Validator provides input validation and response schema validation
14type Validator struct {
15 logger zerolog.Logger
16}
17
18// New creates a new validator instance
19func New(logger zerolog.Logger) *Validator {
20 return &Validator{
21 logger: logger.With().Str("component", "validator").Logger(),
22 }
23}
24
// SearchRequest carries the search parameters checked by ValidateSearchRequest.
type SearchRequest struct {
	// Query is the search expression; it must not be blank.
	Query string `json:"query"`
	// QueryLanguage is optional; when set it must be a supported language name.
	QueryLanguage string `json:"query_language,omitempty"`
	// Corpus is an optional corpus identifier or collection query.
	Corpus string `json:"corpus,omitempty"`
	// Count must lie in 0..10000; 0 means "use the default".
	Count int `json:"count,omitempty"`
}
32
// MetadataRequest carries the parameters checked by ValidateMetadataRequest.
type MetadataRequest struct {
	// Action names the metadata operation; it is required.
	Action string `json:"action"`
	// Corpus is an optional corpus identifier or collection query.
	Corpus string `json:"corpus,omitempty"`
}
38
// ValidationError describes a single failed check: which field was rejected,
// the offending value rendered as a string, and a human-readable reason.
type ValidationError struct {
	Field   string `json:"field"`
	Value   string `json:"value"`
	Message string `json:"message"`
}
45
46func (e ValidationError) Error() string {
47 return fmt.Sprintf("validation error for field '%s' (value: '%s'): %s", e.Field, e.Value, e.Message)
48}
49
50// ValidationErrors represents multiple validation errors
51type ValidationErrors struct {
52 Errors []ValidationError `json:"errors"`
53}
54
55func (e ValidationErrors) Error() string {
56 if len(e.Errors) == 0 {
57 return "validation errors occurred"
58 }
59 var messages []string
60 for _, err := range e.Errors {
61 messages = append(messages, err.Error())
62 }
63 return strings.Join(messages, "; ")
64}
65
// Lookup tables and patterns shared by the request and response validators.
var (
	// validQueryLanguages lists the accepted query language names.
	validQueryLanguages = []string{"poliqarp", "poliqarpplus", "cosmas2", "annis", "cql", "cqp", "fcsql"}

	// corpusIDRegex whitelists the characters allowed in a corpus ID or KorAP
	// collection query: alphanumerics, dot, underscore, hyphen, whitespace,
	// the operators & | ! = < >, parentheses, slash, asterisk and quotes.
	corpusIDRegex = regexp.MustCompile(`^[a-zA-Z0-9._\-\s&|!=<>()/*"']+$`)

	// validMetadataActions is the set of actions accepted in metadata requests.
	validMetadataActions = map[string]bool{
		"list":       true,
		"statistics": true,
	}
)
80
81// ValidateSearchRequest validates a search request
82func (v *Validator) ValidateSearchRequest(req SearchRequest) error {
83 var errors []ValidationError
84
85 // Validate query - required and non-empty
86 if strings.TrimSpace(req.Query) == "" {
87 errors = append(errors, ValidationError{
88 Field: "query",
89 Value: req.Query,
90 Message: "query is required and cannot be empty",
91 })
92 } else {
93 // Basic query validation - check for potentially dangerous patterns
94 if err := v.validateQuerySafety(req.Query); err != nil {
95 errors = append(errors, ValidationError{
96 Field: "query",
97 Value: req.Query,
98 Message: err.Error(),
99 })
100 }
101 }
102
103 // Validate query language if provided
Akron8db31c32025-06-17 12:22:41 +0200104 if req.QueryLanguage != "" && !contains(validQueryLanguages, req.QueryLanguage) {
Akron81f709c2025-06-12 17:30:55 +0200105 errors = append(errors, ValidationError{
106 Field: "query_language",
107 Value: req.QueryLanguage,
Akron8db31c32025-06-17 12:22:41 +0200108 Message: fmt.Sprintf("invalid query language, must be one of: %s", strings.Join(validQueryLanguages, ", ")),
Akron81f709c2025-06-12 17:30:55 +0200109 })
110 }
111
112 // Validate corpus if provided
113 if req.Corpus != "" {
114 if err := v.validateCorpusID(req.Corpus); err != nil {
115 errors = append(errors, ValidationError{
116 Field: "corpus",
117 Value: req.Corpus,
118 Message: err.Error(),
119 })
120 }
121 }
122
123 // Validate count if provided (0 means use default, so only validate non-zero values)
124 if req.Count < 0 || req.Count > 10000 {
125 errors = append(errors, ValidationError{
126 Field: "count",
127 Value: fmt.Sprintf("%d", req.Count),
128 Message: "count must be between 0 and 10000 (0 means use default)",
129 })
130 }
131
132 if len(errors) > 0 {
133 v.logger.Warn().Interface("errors", errors).Msg("Search request validation failed")
134 return ValidationErrors{Errors: errors}
135 }
136
137 v.logger.Debug().Interface("request", req).Msg("Search request validation passed")
138 return nil
139}
140
141// ValidateMetadataRequest validates a metadata request
142func (v *Validator) ValidateMetadataRequest(req MetadataRequest) error {
143 var errors []ValidationError
144
145 // Validate action - required
146 if strings.TrimSpace(req.Action) == "" {
147 errors = append(errors, ValidationError{
148 Field: "action",
149 Value: req.Action,
150 Message: "action is required and cannot be empty",
151 })
152 } else if !validMetadataActions[req.Action] {
153 var validActions []string
154 for action := range validMetadataActions {
155 validActions = append(validActions, action)
156 }
157 errors = append(errors, ValidationError{
158 Field: "action",
159 Value: req.Action,
160 Message: fmt.Sprintf("invalid action, must be one of: %s", strings.Join(validActions, ", ")),
161 })
162 }
163
164 // Validate corpus if provided
165 if req.Corpus != "" {
166 if err := v.validateCorpusID(req.Corpus); err != nil {
167 errors = append(errors, ValidationError{
168 Field: "corpus",
169 Value: req.Corpus,
170 Message: err.Error(),
171 })
172 }
173 }
174
175 if len(errors) > 0 {
176 v.logger.Warn().Interface("errors", errors).Msg("Metadata request validation failed")
177 return ValidationErrors{Errors: errors}
178 }
179
180 v.logger.Debug().Interface("request", req).Msg("Metadata request validation passed")
181 return nil
182}
183
184// ValidateSearchResponse validates a search response
185func (v *Validator) ValidateSearchResponse(resp *service.SearchResponse) error {
186 if resp == nil {
187 return fmt.Errorf("search response is nil")
188 }
189
190 var errors []ValidationError
191
192 // Validate meta structure
193 if resp.Meta.TotalResults < 0 {
194 errors = append(errors, ValidationError{
195 Field: "meta.totalResults",
196 Value: fmt.Sprintf("%d", resp.Meta.TotalResults),
197 Message: "totalResults cannot be negative",
198 })
199 }
200
201 if resp.Meta.Count < 0 {
202 errors = append(errors, ValidationError{
203 Field: "meta.count",
204 Value: fmt.Sprintf("%d", resp.Meta.Count),
205 Message: "count cannot be negative",
206 })
207 }
208
209 if resp.Meta.StartIndex < 0 {
210 errors = append(errors, ValidationError{
211 Field: "meta.startIndex",
212 Value: fmt.Sprintf("%d", resp.Meta.StartIndex),
213 Message: "startIndex cannot be negative",
214 })
215 }
216
217 if resp.Meta.ItemsPerPage < 0 {
218 errors = append(errors, ValidationError{
219 Field: "meta.itemsPerPage",
220 Value: fmt.Sprintf("%d", resp.Meta.ItemsPerPage),
221 Message: "itemsPerPage cannot be negative",
222 })
223 }
224
225 // Validate matches if present
226 if resp.Matches != nil {
227 for i, match := range resp.Matches {
228 if match.MatchID == "" {
229 errors = append(errors, ValidationError{
230 Field: fmt.Sprintf("matches[%d].matchID", i),
231 Value: "",
232 Message: "match ID is required",
233 })
234 }
235
236 if match.TextSigle == "" {
237 errors = append(errors, ValidationError{
238 Field: fmt.Sprintf("matches[%d].textSigle", i),
239 Value: "",
240 Message: "textSigle is required",
241 })
242 }
243
244 if match.Position < 0 {
245 errors = append(errors, ValidationError{
246 Field: fmt.Sprintf("matches[%d].position", i),
247 Value: fmt.Sprintf("%d", match.Position),
248 Message: "position cannot be negative",
249 })
250 }
251 }
252 }
253
254 if len(errors) > 0 {
255 v.logger.Warn().Interface("errors", errors).Msg("Search response validation failed")
256 return ValidationErrors{Errors: errors}
257 }
258
259 v.logger.Debug().Msg("Search response validation passed")
260 return nil
261}
262
263// ValidateCorpusListResponse validates a corpus list response
264func (v *Validator) ValidateCorpusListResponse(resp *service.CorpusListResponse) error {
265 if resp == nil {
266 return fmt.Errorf("corpus list response is nil")
267 }
268
269 var errors []ValidationError
270
271 // Validate corpus entries
272 if resp.Corpora != nil {
273 for i, corpus := range resp.Corpora {
274 if corpus.ID == "" {
275 errors = append(errors, ValidationError{
276 Field: fmt.Sprintf("corpora[%d].id", i),
277 Value: "",
278 Message: "corpus ID is required",
279 })
280 } else if err := v.validateCorpusID(corpus.ID); err != nil {
281 errors = append(errors, ValidationError{
282 Field: fmt.Sprintf("corpora[%d].id", i),
283 Value: corpus.ID,
284 Message: err.Error(),
285 })
286 }
287
288 if corpus.Name == "" {
289 errors = append(errors, ValidationError{
290 Field: fmt.Sprintf("corpora[%d].name", i),
291 Value: "",
292 Message: "corpus name is required",
293 })
294 }
295
296 if corpus.Documents < 0 {
297 errors = append(errors, ValidationError{
298 Field: fmt.Sprintf("corpora[%d].documents", i),
299 Value: fmt.Sprintf("%d", corpus.Documents),
300 Message: "document count cannot be negative",
301 })
302 }
303
304 if corpus.Tokens < 0 {
305 errors = append(errors, ValidationError{
306 Field: fmt.Sprintf("corpora[%d].tokens", i),
307 Value: fmt.Sprintf("%d", corpus.Tokens),
308 Message: "token count cannot be negative",
309 })
310 }
311 }
312 }
313
314 if len(errors) > 0 {
315 v.logger.Warn().Interface("errors", errors).Msg("Corpus list response validation failed")
316 return ValidationErrors{Errors: errors}
317 }
318
319 v.logger.Debug().Msg("Corpus list response validation passed")
320 return nil
321}
322
323// ValidateStatisticsResponse validates a statistics response
324func (v *Validator) ValidateStatisticsResponse(resp *service.StatisticsResponse) error {
325 if resp == nil {
326 return fmt.Errorf("statistics response is nil")
327 }
328
329 var errors []ValidationError
330
331 if resp.Documents < 0 {
332 errors = append(errors, ValidationError{
333 Field: "documents",
334 Value: fmt.Sprintf("%d", resp.Documents),
335 Message: "document count cannot be negative",
336 })
337 }
338
339 if resp.Tokens < 0 {
340 errors = append(errors, ValidationError{
341 Field: "tokens",
342 Value: fmt.Sprintf("%d", resp.Tokens),
343 Message: "token count cannot be negative",
344 })
345 }
346
347 if resp.Sentences < 0 {
348 errors = append(errors, ValidationError{
349 Field: "sentences",
350 Value: fmt.Sprintf("%d", resp.Sentences),
351 Message: "sentence count cannot be negative",
352 })
353 }
354
355 if resp.Paragraphs < 0 {
356 errors = append(errors, ValidationError{
357 Field: "paragraphs",
358 Value: fmt.Sprintf("%d", resp.Paragraphs),
359 Message: "paragraph count cannot be negative",
360 })
361 }
362
363 if len(errors) > 0 {
364 v.logger.Warn().Interface("errors", errors).Msg("Statistics response validation failed")
365 return ValidationErrors{Errors: errors}
366 }
367
368 v.logger.Debug().Msg("Statistics response validation passed")
369 return nil
370}
371
372// validateQuerySafety performs basic security validation on queries
373func (v *Validator) validateQuerySafety(query string) error {
374 // Check for extremely long queries that could cause DoS
375 if len(query) > 10000 {
376 return fmt.Errorf("query is too long (max 10000 characters)")
377 }
378
379 // Check for potentially dangerous URL patterns
380 if strings.Contains(query, "://") {
381 if _, err := url.Parse(query); err == nil {
382 return fmt.Errorf("query appears to contain a URL which is not allowed")
383 }
384 }
385
386 // Check for excessive nesting that could cause parser issues
387 openParens := strings.Count(query, "(")
388 closeParens := strings.Count(query, ")")
389 if openParens != closeParens {
390 return fmt.Errorf("unmatched parentheses in query")
391 }
392 if openParens > 100 {
393 return fmt.Errorf("query has too many nested levels (max 100)")
394 }
395
396 return nil
397}
398
Akron8db31c32025-06-17 12:22:41 +0200399// validateCorpusID validates a corpus identifier or collection query
400// This supports both simple corpus sigles (e.g., "DeReKo-2023-I") and complex
401// collection queries with metadata fields (e.g., "textClass = \"politics\" & pubDate in 2020")
Akron81f709c2025-06-12 17:30:55 +0200402func (v *Validator) validateCorpusID(corpusID string) error {
403 if len(corpusID) == 0 {
404 return fmt.Errorf("corpus ID cannot be empty")
405 }
406
407 if len(corpusID) > 100 {
408 return fmt.Errorf("corpus ID is too long (max 100 characters)")
409 }
410
411 if !corpusIDRegex.MatchString(corpusID) {
Akron8db31c32025-06-17 12:22:41 +0200412 return fmt.Errorf("collection query contains invalid characters (supports alphanumeric, dots, hyphens, underscores, spaces, quotes, operators & | ! = < > in, parentheses, and regex /pattern/)")
Akron81f709c2025-06-12 17:30:55 +0200413 }
414
415 return nil
416}
417
418// SanitizeQuery performs basic sanitization on search queries
419func (v *Validator) SanitizeQuery(query string) string {
420 // Trim whitespace
421 sanitized := strings.TrimSpace(query)
422
423 // Remove any null bytes
424 sanitized = strings.ReplaceAll(sanitized, "\x00", "")
425
426 // Normalize whitespace
427 sanitized = regexp.MustCompile(`\s+`).ReplaceAllString(sanitized, " ")
428
429 v.logger.Debug().
430 Str("original", query).
431 Str("sanitized", sanitized).
432 Msg("Query sanitized")
433
434 return sanitized
435}
436
437// SanitizeCorpusID performs basic sanitization on corpus IDs
438func (v *Validator) SanitizeCorpusID(corpusID string) string {
439 // Trim whitespace
440 sanitized := strings.TrimSpace(corpusID)
441
442 // Remove any null bytes
443 sanitized = strings.ReplaceAll(sanitized, "\x00", "")
444
445 // Convert to lowercase for consistency
446 sanitized = strings.ToLower(sanitized)
447
448 v.logger.Debug().
449 Str("original", corpusID).
450 Str("sanitized", sanitized).
451 Msg("Corpus ID sanitized")
452
453 return sanitized
454}
Akron8db31c32025-06-17 12:22:41 +0200455
// contains reports whether item occurs in slice.
func contains(slice []string, item string) bool {
	for i := 0; i < len(slice); i++ {
		if slice[i] == item {
			return true
		}
	}
	return false
}