blob: 66b77c8a74c5ba2ca1477b7120b6455aa4f3086f [file] [log] [blame]
Akron81f709c2025-06-12 17:30:55 +02001package validation
2
import (
	"fmt"
	"net/url"
	"regexp"
	"sort"
	"strings"
	"unicode/utf8"

	"github.com/korap/korap-mcp/service"
	"github.com/rs/zerolog"
)
12
13// Validator provides input validation and response schema validation
14type Validator struct {
15 logger zerolog.Logger
16}
17
18// New creates a new validator instance
19func New(logger zerolog.Logger) *Validator {
20 return &Validator{
21 logger: logger.With().Str("component", "validator").Logger(),
22 }
23}
24
// SearchRequest holds the parameters of a search request that
// ValidateSearchRequest checks before the search is issued.
type SearchRequest struct {
	// Query is the search expression; it is required and must not be blank.
	Query string `json:"query"`
	// QueryLanguage optionally names the query dialect; when set it must be
	// a key of validQueryLanguages.
	QueryLanguage string `json:"query_language,omitempty"`
	// Corpus optionally names a corpus ID; when set it must pass
	// validateCorpusID.
	Corpus string `json:"corpus,omitempty"`
	// Count is the requested number of results; 0 means "use the default"
	// and the accepted range is 0 to 10000 inclusive.
	Count int `json:"count,omitempty"`
}
32
// MetadataRequest holds the parameters of a metadata request that
// ValidateMetadataRequest checks before the request is issued.
type MetadataRequest struct {
	// Action selects the metadata operation; it is required and must be a
	// key of validMetadataActions.
	Action string `json:"action"`
	// Corpus optionally names a corpus ID; when set it must pass
	// validateCorpusID.
	Corpus string `json:"corpus,omitempty"`
}
38
// ValidationError describes one rejected field: the field's name, the
// offending value rendered as a string, and a human-readable reason.
type ValidationError struct {
	Field   string `json:"field"`
	Value   string `json:"value"`
	Message string `json:"message"`
}

// Error formats the failure as
// "validation error for field '<field>' (value: '<value>'): <message>".
func (ve ValidationError) Error() string {
	field, value, reason := ve.Field, ve.Value, ve.Message
	return fmt.Sprintf("validation error for field '%s' (value: '%s'): %s", field, value, reason)
}

// ValidationErrors aggregates every ValidationError found in a single
// validation pass so callers can see all problems at once.
type ValidationErrors struct {
	Errors []ValidationError `json:"errors"`
}

// Error joins the individual error messages with "; ". An empty list
// yields a generic placeholder rather than an empty string.
func (ve ValidationErrors) Error() string {
	if len(ve.Errors) == 0 {
		return "validation errors occurred"
	}
	var sb strings.Builder
	for i, item := range ve.Errors {
		if i > 0 {
			sb.WriteString("; ")
		}
		sb.WriteString(item.Error())
	}
	return sb.String()
}
65
// Lookup tables and patterns shared by the request and response validators.
var (
	// validQueryLanguages is the set of accepted SearchRequest.QueryLanguage
	// values (KorAP supports poliqarp, cosmas2, annis).
	validQueryLanguages = map[string]bool{
		"poliqarp": true,
		"cosmas2":  true,
		"annis":    true,
	}

	// corpusIDRegex accepts non-empty corpus IDs made only of ASCII letters,
	// digits, dots, hyphens and underscores.
	corpusIDRegex = regexp.MustCompile(`^[a-zA-Z0-9._-]+$`)

	// validMetadataActions is the set of accepted MetadataRequest.Action values.
	validMetadataActions = map[string]bool{
		"list":       true,
		"statistics": true,
	}
)
84
85// ValidateSearchRequest validates a search request
86func (v *Validator) ValidateSearchRequest(req SearchRequest) error {
87 var errors []ValidationError
88
89 // Validate query - required and non-empty
90 if strings.TrimSpace(req.Query) == "" {
91 errors = append(errors, ValidationError{
92 Field: "query",
93 Value: req.Query,
94 Message: "query is required and cannot be empty",
95 })
96 } else {
97 // Basic query validation - check for potentially dangerous patterns
98 if err := v.validateQuerySafety(req.Query); err != nil {
99 errors = append(errors, ValidationError{
100 Field: "query",
101 Value: req.Query,
102 Message: err.Error(),
103 })
104 }
105 }
106
107 // Validate query language if provided
108 if req.QueryLanguage != "" && !validQueryLanguages[req.QueryLanguage] {
109 var validLangs []string
110 for lang := range validQueryLanguages {
111 validLangs = append(validLangs, lang)
112 }
113 errors = append(errors, ValidationError{
114 Field: "query_language",
115 Value: req.QueryLanguage,
116 Message: fmt.Sprintf("invalid query language, must be one of: %s", strings.Join(validLangs, ", ")),
117 })
118 }
119
120 // Validate corpus if provided
121 if req.Corpus != "" {
122 if err := v.validateCorpusID(req.Corpus); err != nil {
123 errors = append(errors, ValidationError{
124 Field: "corpus",
125 Value: req.Corpus,
126 Message: err.Error(),
127 })
128 }
129 }
130
131 // Validate count if provided (0 means use default, so only validate non-zero values)
132 if req.Count < 0 || req.Count > 10000 {
133 errors = append(errors, ValidationError{
134 Field: "count",
135 Value: fmt.Sprintf("%d", req.Count),
136 Message: "count must be between 0 and 10000 (0 means use default)",
137 })
138 }
139
140 if len(errors) > 0 {
141 v.logger.Warn().Interface("errors", errors).Msg("Search request validation failed")
142 return ValidationErrors{Errors: errors}
143 }
144
145 v.logger.Debug().Interface("request", req).Msg("Search request validation passed")
146 return nil
147}
148
149// ValidateMetadataRequest validates a metadata request
150func (v *Validator) ValidateMetadataRequest(req MetadataRequest) error {
151 var errors []ValidationError
152
153 // Validate action - required
154 if strings.TrimSpace(req.Action) == "" {
155 errors = append(errors, ValidationError{
156 Field: "action",
157 Value: req.Action,
158 Message: "action is required and cannot be empty",
159 })
160 } else if !validMetadataActions[req.Action] {
161 var validActions []string
162 for action := range validMetadataActions {
163 validActions = append(validActions, action)
164 }
165 errors = append(errors, ValidationError{
166 Field: "action",
167 Value: req.Action,
168 Message: fmt.Sprintf("invalid action, must be one of: %s", strings.Join(validActions, ", ")),
169 })
170 }
171
172 // Validate corpus if provided
173 if req.Corpus != "" {
174 if err := v.validateCorpusID(req.Corpus); err != nil {
175 errors = append(errors, ValidationError{
176 Field: "corpus",
177 Value: req.Corpus,
178 Message: err.Error(),
179 })
180 }
181 }
182
183 if len(errors) > 0 {
184 v.logger.Warn().Interface("errors", errors).Msg("Metadata request validation failed")
185 return ValidationErrors{Errors: errors}
186 }
187
188 v.logger.Debug().Interface("request", req).Msg("Metadata request validation passed")
189 return nil
190}
191
192// ValidateSearchResponse validates a search response
193func (v *Validator) ValidateSearchResponse(resp *service.SearchResponse) error {
194 if resp == nil {
195 return fmt.Errorf("search response is nil")
196 }
197
198 var errors []ValidationError
199
200 // Validate meta structure
201 if resp.Meta.TotalResults < 0 {
202 errors = append(errors, ValidationError{
203 Field: "meta.totalResults",
204 Value: fmt.Sprintf("%d", resp.Meta.TotalResults),
205 Message: "totalResults cannot be negative",
206 })
207 }
208
209 if resp.Meta.Count < 0 {
210 errors = append(errors, ValidationError{
211 Field: "meta.count",
212 Value: fmt.Sprintf("%d", resp.Meta.Count),
213 Message: "count cannot be negative",
214 })
215 }
216
217 if resp.Meta.StartIndex < 0 {
218 errors = append(errors, ValidationError{
219 Field: "meta.startIndex",
220 Value: fmt.Sprintf("%d", resp.Meta.StartIndex),
221 Message: "startIndex cannot be negative",
222 })
223 }
224
225 if resp.Meta.ItemsPerPage < 0 {
226 errors = append(errors, ValidationError{
227 Field: "meta.itemsPerPage",
228 Value: fmt.Sprintf("%d", resp.Meta.ItemsPerPage),
229 Message: "itemsPerPage cannot be negative",
230 })
231 }
232
233 // Validate matches if present
234 if resp.Matches != nil {
235 for i, match := range resp.Matches {
236 if match.MatchID == "" {
237 errors = append(errors, ValidationError{
238 Field: fmt.Sprintf("matches[%d].matchID", i),
239 Value: "",
240 Message: "match ID is required",
241 })
242 }
243
244 if match.TextSigle == "" {
245 errors = append(errors, ValidationError{
246 Field: fmt.Sprintf("matches[%d].textSigle", i),
247 Value: "",
248 Message: "textSigle is required",
249 })
250 }
251
252 if match.Position < 0 {
253 errors = append(errors, ValidationError{
254 Field: fmt.Sprintf("matches[%d].position", i),
255 Value: fmt.Sprintf("%d", match.Position),
256 Message: "position cannot be negative",
257 })
258 }
259 }
260 }
261
262 if len(errors) > 0 {
263 v.logger.Warn().Interface("errors", errors).Msg("Search response validation failed")
264 return ValidationErrors{Errors: errors}
265 }
266
267 v.logger.Debug().Msg("Search response validation passed")
268 return nil
269}
270
271// ValidateCorpusListResponse validates a corpus list response
272func (v *Validator) ValidateCorpusListResponse(resp *service.CorpusListResponse) error {
273 if resp == nil {
274 return fmt.Errorf("corpus list response is nil")
275 }
276
277 var errors []ValidationError
278
279 // Validate corpus entries
280 if resp.Corpora != nil {
281 for i, corpus := range resp.Corpora {
282 if corpus.ID == "" {
283 errors = append(errors, ValidationError{
284 Field: fmt.Sprintf("corpora[%d].id", i),
285 Value: "",
286 Message: "corpus ID is required",
287 })
288 } else if err := v.validateCorpusID(corpus.ID); err != nil {
289 errors = append(errors, ValidationError{
290 Field: fmt.Sprintf("corpora[%d].id", i),
291 Value: corpus.ID,
292 Message: err.Error(),
293 })
294 }
295
296 if corpus.Name == "" {
297 errors = append(errors, ValidationError{
298 Field: fmt.Sprintf("corpora[%d].name", i),
299 Value: "",
300 Message: "corpus name is required",
301 })
302 }
303
304 if corpus.Documents < 0 {
305 errors = append(errors, ValidationError{
306 Field: fmt.Sprintf("corpora[%d].documents", i),
307 Value: fmt.Sprintf("%d", corpus.Documents),
308 Message: "document count cannot be negative",
309 })
310 }
311
312 if corpus.Tokens < 0 {
313 errors = append(errors, ValidationError{
314 Field: fmt.Sprintf("corpora[%d].tokens", i),
315 Value: fmt.Sprintf("%d", corpus.Tokens),
316 Message: "token count cannot be negative",
317 })
318 }
319 }
320 }
321
322 if len(errors) > 0 {
323 v.logger.Warn().Interface("errors", errors).Msg("Corpus list response validation failed")
324 return ValidationErrors{Errors: errors}
325 }
326
327 v.logger.Debug().Msg("Corpus list response validation passed")
328 return nil
329}
330
331// ValidateStatisticsResponse validates a statistics response
332func (v *Validator) ValidateStatisticsResponse(resp *service.StatisticsResponse) error {
333 if resp == nil {
334 return fmt.Errorf("statistics response is nil")
335 }
336
337 var errors []ValidationError
338
339 if resp.Documents < 0 {
340 errors = append(errors, ValidationError{
341 Field: "documents",
342 Value: fmt.Sprintf("%d", resp.Documents),
343 Message: "document count cannot be negative",
344 })
345 }
346
347 if resp.Tokens < 0 {
348 errors = append(errors, ValidationError{
349 Field: "tokens",
350 Value: fmt.Sprintf("%d", resp.Tokens),
351 Message: "token count cannot be negative",
352 })
353 }
354
355 if resp.Sentences < 0 {
356 errors = append(errors, ValidationError{
357 Field: "sentences",
358 Value: fmt.Sprintf("%d", resp.Sentences),
359 Message: "sentence count cannot be negative",
360 })
361 }
362
363 if resp.Paragraphs < 0 {
364 errors = append(errors, ValidationError{
365 Field: "paragraphs",
366 Value: fmt.Sprintf("%d", resp.Paragraphs),
367 Message: "paragraph count cannot be negative",
368 })
369 }
370
371 if len(errors) > 0 {
372 v.logger.Warn().Interface("errors", errors).Msg("Statistics response validation failed")
373 return ValidationErrors{Errors: errors}
374 }
375
376 v.logger.Debug().Msg("Statistics response validation passed")
377 return nil
378}
379
380// validateQuerySafety performs basic security validation on queries
381func (v *Validator) validateQuerySafety(query string) error {
382 // Check for extremely long queries that could cause DoS
383 if len(query) > 10000 {
384 return fmt.Errorf("query is too long (max 10000 characters)")
385 }
386
387 // Check for potentially dangerous URL patterns
388 if strings.Contains(query, "://") {
389 if _, err := url.Parse(query); err == nil {
390 return fmt.Errorf("query appears to contain a URL which is not allowed")
391 }
392 }
393
394 // Check for excessive nesting that could cause parser issues
395 openParens := strings.Count(query, "(")
396 closeParens := strings.Count(query, ")")
397 if openParens != closeParens {
398 return fmt.Errorf("unmatched parentheses in query")
399 }
400 if openParens > 100 {
401 return fmt.Errorf("query has too many nested levels (max 100)")
402 }
403
404 return nil
405}
406
407// validateCorpusID validates a corpus identifier
408func (v *Validator) validateCorpusID(corpusID string) error {
409 if len(corpusID) == 0 {
410 return fmt.Errorf("corpus ID cannot be empty")
411 }
412
413 if len(corpusID) > 100 {
414 return fmt.Errorf("corpus ID is too long (max 100 characters)")
415 }
416
417 if !corpusIDRegex.MatchString(corpusID) {
418 return fmt.Errorf("corpus ID contains invalid characters (only alphanumeric, dots, hyphens, underscores allowed)")
419 }
420
421 return nil
422}
423
424// SanitizeQuery performs basic sanitization on search queries
425func (v *Validator) SanitizeQuery(query string) string {
426 // Trim whitespace
427 sanitized := strings.TrimSpace(query)
428
429 // Remove any null bytes
430 sanitized = strings.ReplaceAll(sanitized, "\x00", "")
431
432 // Normalize whitespace
433 sanitized = regexp.MustCompile(`\s+`).ReplaceAllString(sanitized, " ")
434
435 v.logger.Debug().
436 Str("original", query).
437 Str("sanitized", sanitized).
438 Msg("Query sanitized")
439
440 return sanitized
441}
442
443// SanitizeCorpusID performs basic sanitization on corpus IDs
444func (v *Validator) SanitizeCorpusID(corpusID string) string {
445 // Trim whitespace
446 sanitized := strings.TrimSpace(corpusID)
447
448 // Remove any null bytes
449 sanitized = strings.ReplaceAll(sanitized, "\x00", "")
450
451 // Convert to lowercase for consistency
452 sanitized = strings.ToLower(sanitized)
453
454 v.logger.Debug().
455 Str("original", corpusID).
456 Str("sanitized", sanitized).
457 Msg("Corpus ID sanitized")
458
459 return sanitized
460}