Added request/response validation
diff --git a/validation/validator_test.go b/validation/validator_test.go
new file mode 100644
index 0000000..43254a7
--- /dev/null
+++ b/validation/validator_test.go
@@ -0,0 +1,827 @@
+package validation
+
+import (
+ "testing"
+
+ "github.com/korap/korap-mcp/service"
+ "github.com/rs/zerolog"
+ "github.com/stretchr/testify/assert"
+)
+
+// TestNew checks that New tags the supplied logger with component=validator.
+func TestNew(t *testing.T) {
+	base := zerolog.Nop()
+	got, want := New(base), base.With().Str("component", "validator").Logger()
+	assert.NotNil(t, got)
+	assert.Equal(t, want, got.logger)
+}
+
+func TestValidationError_Error(t *testing.T) {
+ err := ValidationError{
+ Field: "test_field",
+ Value: "test_value",
+ Message: "test message",
+ }
+
+ expected := "validation error for field 'test_field' (value: 'test_value'): test message"
+ assert.Equal(t, expected, err.Error())
+}
+
+// TestValidationErrors_Error checks ValidationErrors.Error for an empty, a
+// single-element and a two-element error list: a fixed fallback text, the lone
+// error's own text, and the two texts joined by "; " respectively.
+func TestValidationErrors_Error(t *testing.T) {
+	first := ValidationError{Field: "field1", Value: "value1", Message: "message1"}
+	second := ValidationError{Field: "field2", Value: "value2", Message: "message2"}
+
+	cases := []struct {
+		errs ValidationErrors
+		want string
+	}{
+		// No collected errors: the generic fallback text is expected.
+		{ValidationErrors{}, "validation errors occurred"},
+		// Exactly one error: its own text, with no separator.
+		{ValidationErrors{Errors: []ValidationError{first}}, "validation error for field 'field1' (value: 'value1'): message1"},
+		// Two errors: the individual texts joined by "; ".
+		{ValidationErrors{Errors: []ValidationError{first, second}}, "validation error for field 'field1' (value: 'value1'): message1; validation error for field 'field2' (value: 'value2'): message2"},
+	}
+
+	// Asserted directly on t (no subtests), so a failing case does not hide the later ones.
+	for _, c := range cases {
+		assert.Equal(t, c.want, c.errs.Error())
+	}
+}
+
+// TestValidateSearchRequest runs ValidateSearchRequest over a table of requests; a case that
+// expects rejection must yield an error whose text contains the case's errorMsg.
+func TestValidateSearchRequest(t *testing.T) {
+	validator := New(zerolog.Nop())
+
+	tests := []struct {
+		name      string
+		request   SearchRequest
+		expectErr bool
+		errorMsg  string
+	}{
+		{
+			name: "valid_request_minimal",
+			request: SearchRequest{
+				Query: "test query",
+			},
+			expectErr: false,
+		},
+		{
+			name: "valid_request_complete",
+			request: SearchRequest{
+				Query:         "test query",
+				QueryLanguage: "poliqarp",
+				Corpus:        "test-corpus",
+				Count:         100,
+			},
+			expectErr: false,
+		},
+		{
+			name: "empty_query",
+			request: SearchRequest{
+				Query: "",
+			},
+			expectErr: true,
+			errorMsg:  "query is required and cannot be empty",
+		},
+		{
+			name: "whitespace_only_query",
+			request: SearchRequest{
+				Query: " ",
+			},
+			expectErr: true,
+			errorMsg:  "query is required and cannot be empty",
+		},
+		{
+			name: "invalid_query_language",
+			request: SearchRequest{
+				Query:         "test query",
+				QueryLanguage: "invalid",
+			},
+			expectErr: true,
+			errorMsg:  "invalid query language",
+		},
+		{
+			name: "invalid_corpus_id",
+			request: SearchRequest{
+				Query:  "test query",
+				Corpus: "invalid corpus!",
+			},
+			expectErr: true,
+			errorMsg:  "corpus ID contains invalid characters",
+		},
+		{
+			name: "count_negative",
+			request: SearchRequest{
+				Query: "test query",
+				Count: -1,
+			},
+			expectErr: true,
+			errorMsg:  "count must be between 0 and 10000",
+		},
+		{
+			name: "count_zero_valid",
+			request: SearchRequest{
+				Query: "test query",
+				Count: 0,
+			},
+			expectErr: false,
+		},
+		{
+			name: "count_too_high",
+			request: SearchRequest{
+				Query: "test query",
+				Count: 10001,
+			},
+			expectErr: true,
+			errorMsg:  "count must be between 0 and 10000",
+		},
+		{
+			name: "unsafe_query_too_long",
+			request: SearchRequest{
+				Query: string(make([]byte, 10001)),
+			},
+			expectErr: true,
+			errorMsg:  "query is too long",
+		},
+		{
+			name: "unsafe_query_url",
+			request: SearchRequest{
+				Query: "http://example.com",
+			},
+			expectErr: true,
+			errorMsg:  "query appears to contain a URL",
+		},
+		{
+			name: "unsafe_query_unmatched_parens",
+			request: SearchRequest{
+				Query: "test (query",
+			},
+			expectErr: true,
+			errorMsg:  "unmatched parentheses",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			err := validator.ValidateSearchRequest(tt.request)
+			if !tt.expectErr {
+				assert.NoError(t, err)
+			} else if assert.Error(t, err) { // only read err.Error() once err is known non-nil
+				assert.Contains(t, err.Error(), tt.errorMsg)
+			}
+		})
+	}
+}
+
+// TestValidateMetadataRequest runs ValidateMetadataRequest over a table of requests; a case that
+// expects rejection must yield an error whose text contains the case's errorMsg.
+func TestValidateMetadataRequest(t *testing.T) {
+	validator := New(zerolog.Nop())
+
+	tests := []struct {
+		name      string
+		request   MetadataRequest
+		expectErr bool
+		errorMsg  string
+	}{
+		{
+			name: "valid_list_action",
+			request: MetadataRequest{
+				Action: "list",
+			},
+			expectErr: false,
+		},
+		{
+			name: "valid_statistics_action",
+			request: MetadataRequest{
+				Action: "statistics",
+				Corpus: "test-corpus",
+			},
+			expectErr: false,
+		},
+		{
+			name: "empty_action",
+			request: MetadataRequest{
+				Action: "",
+			},
+			expectErr: true,
+			errorMsg:  "action is required and cannot be empty",
+		},
+		{
+			name: "whitespace_only_action",
+			request: MetadataRequest{
+				Action: " ",
+			},
+			expectErr: true,
+			errorMsg:  "action is required and cannot be empty",
+		},
+		{
+			name: "invalid_action",
+			request: MetadataRequest{
+				Action: "invalid",
+			},
+			expectErr: true,
+			errorMsg:  "invalid action",
+		},
+		{
+			name: "invalid_corpus_id",
+			request: MetadataRequest{
+				Action: "statistics",
+				Corpus: "invalid corpus!",
+			},
+			expectErr: true,
+			errorMsg:  "corpus ID contains invalid characters",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			err := validator.ValidateMetadataRequest(tt.request)
+			if !tt.expectErr {
+				assert.NoError(t, err)
+			} else if assert.Error(t, err) { // only read err.Error() once err is known non-nil
+				assert.Contains(t, err.Error(), tt.errorMsg)
+			}
+		})
+	}
+}
+
+// TestValidateSearchResponse runs ValidateSearchResponse over a table of responses (nil included);
+// a case that expects rejection must yield an error whose text contains the case's errorMsg.
+func TestValidateSearchResponse(t *testing.T) {
+	validator := New(zerolog.Nop())
+
+	tests := []struct {
+		name      string
+		response  *service.SearchResponse
+		expectErr bool
+		errorMsg  string
+	}{
+		{
+			name:      "nil_response",
+			response:  nil,
+			expectErr: true,
+			errorMsg:  "search response is nil",
+		},
+		{
+			name: "valid_response",
+			response: &service.SearchResponse{
+				Meta: service.SearchMeta{
+					TotalResults: 100,
+					Count:        10,
+					StartIndex:   0,
+					ItemsPerPage: 10,
+				},
+				Query: service.SearchQuery{
+					Query:     "test",
+					QueryLang: "poliqarp",
+				},
+				Matches: []service.SearchMatch{
+					{MatchID: "match1", TextSigle: "text1", Position: 0},
+					{MatchID: "match2", TextSigle: "text2", Position: 1},
+				},
+			},
+			expectErr: false,
+		},
+		{
+			name: "negative_total_results",
+			response: &service.SearchResponse{
+				Meta: service.SearchMeta{
+					TotalResults: -1,
+					Count:        10,
+					StartIndex:   0,
+					ItemsPerPage: 10,
+				},
+			},
+			expectErr: true,
+			errorMsg:  "totalResults cannot be negative",
+		},
+		{
+			name: "negative_count",
+			response: &service.SearchResponse{
+				Meta: service.SearchMeta{
+					TotalResults: 100,
+					Count:        -1,
+					StartIndex:   0,
+					ItemsPerPage: 10,
+				},
+			},
+			expectErr: true,
+			errorMsg:  "count cannot be negative",
+		},
+		{
+			name: "negative_start_index",
+			response: &service.SearchResponse{
+				Meta: service.SearchMeta{
+					TotalResults: 100,
+					Count:        10,
+					StartIndex:   -1,
+					ItemsPerPage: 10,
+				},
+			},
+			expectErr: true,
+			errorMsg:  "startIndex cannot be negative",
+		},
+		{
+			name: "negative_items_per_page",
+			response: &service.SearchResponse{
+				Meta: service.SearchMeta{
+					TotalResults: 100,
+					Count:        10,
+					StartIndex:   0,
+					ItemsPerPage: -1,
+				},
+			},
+			expectErr: true,
+			errorMsg:  "itemsPerPage cannot be negative",
+		},
+		{
+			name: "match_missing_id",
+			response: &service.SearchResponse{
+				Meta: service.SearchMeta{
+					TotalResults: 100,
+					Count:        10,
+					StartIndex:   0,
+					ItemsPerPage: 10,
+				},
+				Matches: []service.SearchMatch{
+					{MatchID: "", TextSigle: "text1", Position: 0},
+				},
+			},
+			expectErr: true,
+			errorMsg:  "match ID is required",
+		},
+		{
+			name: "match_missing_text_sigle",
+			response: &service.SearchResponse{
+				Meta: service.SearchMeta{
+					TotalResults: 100,
+					Count:        10,
+					StartIndex:   0,
+					ItemsPerPage: 10,
+				},
+				Matches: []service.SearchMatch{
+					{MatchID: "match1", TextSigle: "", Position: 0},
+				},
+			},
+			expectErr: true,
+			errorMsg:  "textSigle is required",
+		},
+		{
+			name: "match_negative_position",
+			response: &service.SearchResponse{
+				Meta: service.SearchMeta{
+					TotalResults: 100,
+					Count:        10,
+					StartIndex:   0,
+					ItemsPerPage: 10,
+				},
+				Matches: []service.SearchMatch{
+					{MatchID: "match1", TextSigle: "text1", Position: -1},
+				},
+			},
+			expectErr: true,
+			errorMsg:  "position cannot be negative",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			err := validator.ValidateSearchResponse(tt.response)
+			if !tt.expectErr {
+				assert.NoError(t, err)
+			} else if assert.Error(t, err) { // only read err.Error() once err is known non-nil
+				assert.Contains(t, err.Error(), tt.errorMsg)
+			}
+		})
+	}
+}
+
+// TestValidateCorpusListResponse runs ValidateCorpusListResponse over a table of responses (nil
+// included); a rejected case must yield an error whose text contains the case's errorMsg.
+func TestValidateCorpusListResponse(t *testing.T) {
+	validator := New(zerolog.Nop())
+
+	tests := []struct {
+		name      string
+		response  *service.CorpusListResponse
+		expectErr bool
+		errorMsg  string
+	}{
+		{
+			name:      "nil_response",
+			response:  nil,
+			expectErr: true,
+			errorMsg:  "corpus list response is nil",
+		},
+		{
+			name: "valid_response",
+			response: &service.CorpusListResponse{
+				Corpora: []service.CorpusInfo{
+					{
+						ID:        "corpus1",
+						Name:      "Test Corpus 1",
+						Documents: 100,
+						Tokens:    50000,
+					},
+					{
+						ID:        "corpus2",
+						Name:      "Test Corpus 2",
+						Documents: 200,
+						Tokens:    75000,
+					},
+				},
+			},
+			expectErr: false,
+		},
+		{
+			name: "empty_corpus_list",
+			response: &service.CorpusListResponse{
+				Corpora: []service.CorpusInfo{},
+			},
+			expectErr: false,
+		},
+		{
+			name: "corpus_missing_id",
+			response: &service.CorpusListResponse{
+				Corpora: []service.CorpusInfo{
+					{
+						ID:        "",
+						Name:      "Test Corpus",
+						Documents: 100,
+						Tokens:    50000,
+					},
+				},
+			},
+			expectErr: true,
+			errorMsg:  "corpus ID is required",
+		},
+		{
+			name: "corpus_invalid_id",
+			response: &service.CorpusListResponse{
+				Corpora: []service.CorpusInfo{
+					{
+						ID:        "invalid id!",
+						Name:      "Test Corpus",
+						Documents: 100,
+						Tokens:    50000,
+					},
+				},
+			},
+			expectErr: true,
+			errorMsg:  "corpus ID contains invalid characters",
+		},
+		{
+			name: "corpus_missing_name",
+			response: &service.CorpusListResponse{
+				Corpora: []service.CorpusInfo{
+					{
+						ID:        "corpus1",
+						Name:      "",
+						Documents: 100,
+						Tokens:    50000,
+					},
+				},
+			},
+			expectErr: true,
+			errorMsg:  "corpus name is required",
+		},
+		{
+			name: "corpus_negative_documents",
+			response: &service.CorpusListResponse{
+				Corpora: []service.CorpusInfo{
+					{
+						ID:        "corpus1",
+						Name:      "Test Corpus",
+						Documents: -1,
+						Tokens:    50000,
+					},
+				},
+			},
+			expectErr: true,
+			errorMsg:  "document count cannot be negative",
+		},
+		{
+			name: "corpus_negative_tokens",
+			response: &service.CorpusListResponse{
+				Corpora: []service.CorpusInfo{
+					{
+						ID:        "corpus1",
+						Name:      "Test Corpus",
+						Documents: 100,
+						Tokens:    -1,
+					},
+				},
+			},
+			expectErr: true,
+			errorMsg:  "token count cannot be negative",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			err := validator.ValidateCorpusListResponse(tt.response)
+			if !tt.expectErr {
+				assert.NoError(t, err)
+			} else if assert.Error(t, err) { // only read err.Error() once err is known non-nil
+				assert.Contains(t, err.Error(), tt.errorMsg)
+			}
+		})
+	}
+}
+
+// TestValidateStatisticsResponse runs ValidateStatisticsResponse over a table of responses (nil
+// included); a rejected case must yield an error whose text contains the case's errorMsg.
+func TestValidateStatisticsResponse(t *testing.T) {
+	validator := New(zerolog.Nop())
+
+	tests := []struct {
+		name      string
+		response  *service.StatisticsResponse
+		expectErr bool
+		errorMsg  string
+	}{
+		{
+			name:      "nil_response",
+			response:  nil,
+			expectErr: true,
+			errorMsg:  "statistics response is nil",
+		},
+		{
+			name: "valid_response",
+			response: &service.StatisticsResponse{
+				Documents:  100,
+				Tokens:     50000,
+				Sentences:  2500,
+				Paragraphs: 500,
+			},
+			expectErr: false,
+		},
+		{
+			name: "negative_documents",
+			response: &service.StatisticsResponse{
+				Documents: -1,
+				Tokens:    50000,
+			},
+			expectErr: true,
+			errorMsg:  "document count cannot be negative",
+		},
+		{
+			name: "negative_tokens",
+			response: &service.StatisticsResponse{
+				Documents: 100,
+				Tokens:    -1,
+			},
+			expectErr: true,
+			errorMsg:  "token count cannot be negative",
+		},
+		{
+			name: "negative_sentences",
+			response: &service.StatisticsResponse{
+				Documents: 100,
+				Tokens:    50000,
+				Sentences: -1,
+			},
+			expectErr: true,
+			errorMsg:  "sentence count cannot be negative",
+		},
+		{
+			name: "negative_paragraphs",
+			response: &service.StatisticsResponse{
+				Documents:  100,
+				Tokens:     50000,
+				Paragraphs: -1,
+			},
+			expectErr: true,
+			errorMsg:  "paragraph count cannot be negative",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			err := validator.ValidateStatisticsResponse(tt.response)
+			if !tt.expectErr {
+				assert.NoError(t, err)
+			} else if assert.Error(t, err) { // only read err.Error() once err is known non-nil
+				assert.Contains(t, err.Error(), tt.errorMsg)
+			}
+		})
+	}
+}
+
+// TestValidateQuerySafety runs the unexported validateQuerySafety over a table of query strings;
+// a case that expects rejection must yield an error whose text contains the case's errorMsg.
+func TestValidateQuerySafety(t *testing.T) {
+	validator := New(zerolog.Nop())
+
+	tests := []struct {
+		name      string
+		query     string
+		expectErr bool
+		errorMsg  string
+	}{
+		{
+			name:      "valid_query",
+			query:     "test query",
+			expectErr: false,
+		},
+		{
+			name:      "query_too_long",
+			query:     string(make([]byte, 10001)),
+			expectErr: true,
+			errorMsg:  "query is too long",
+		},
+		{
+			name:      "query_with_url",
+			query:     "http://example.com",
+			expectErr: true,
+			errorMsg:  "query appears to contain a URL",
+		},
+		{
+			name:      "query_with_https_url",
+			query:     "https://example.com",
+			expectErr: true,
+			errorMsg:  "query appears to contain a URL",
+		},
+		{
+			name:      "query_unmatched_open_paren",
+			query:     "test (query",
+			expectErr: true,
+			errorMsg:  "unmatched parentheses",
+		},
+		{
+			name:      "query_unmatched_close_paren",
+			query:     "test query)",
+			expectErr: true,
+			errorMsg:  "unmatched parentheses",
+		},
+		{
+			name:      "query_too_many_nested_parens",
+			query:     "(" + string(make([]byte, 100)) + ")" + "(" + string(make([]byte, 100)) + ")",
+			expectErr: false, // despite the name: two sequential, balanced groups (not nested), so this should be under the limit
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			err := validator.validateQuerySafety(tt.query)
+			if !tt.expectErr {
+				assert.NoError(t, err)
+			} else if assert.Error(t, err) { // only read err.Error() once err is known non-nil
+				assert.Contains(t, err.Error(), tt.errorMsg)
+			}
+		})
+	}
+}
+
+// TestValidateCorpusID runs the unexported validateCorpusID over a table of corpus IDs; a case
+// that expects rejection must yield an error whose text contains the case's errorMsg.
+func TestValidateCorpusID(t *testing.T) {
+	validator := New(zerolog.Nop())
+
+	tests := []struct {
+		name      string
+		corpusID  string
+		expectErr bool
+		errorMsg  string
+	}{
+		{
+			name:      "valid_corpus_id",
+			corpusID:  "test-corpus_1.0",
+			expectErr: false,
+		},
+		{
+			name:      "empty_corpus_id",
+			corpusID:  "",
+			expectErr: true,
+			errorMsg:  "corpus ID cannot be empty",
+		},
+		{
+			name:      "corpus_id_too_long",
+			corpusID:  string(make([]byte, 101)),
+			expectErr: true,
+			errorMsg:  "corpus ID is too long",
+		},
+		{
+			name:      "corpus_id_invalid_chars",
+			corpusID:  "invalid corpus!",
+			expectErr: true,
+			errorMsg:  "corpus ID contains invalid characters",
+		},
+		{
+			name:      "corpus_id_with_space",
+			corpusID:  "corpus with space",
+			expectErr: true,
+			errorMsg:  "corpus ID contains invalid characters",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			err := validator.validateCorpusID(tt.corpusID)
+			if !tt.expectErr {
+				assert.NoError(t, err)
+			} else if assert.Error(t, err) { // only read err.Error() once err is known non-nil
+				assert.Contains(t, err.Error(), tt.errorMsg)
+			}
+		})
+	}
+}
+
+// TestSanitizeQuery runs SanitizeQuery on each raw input in the table and
+// requires the result to equal the expected text exactly.
+func TestSanitizeQuery(t *testing.T) {
+	sanitizer := New(zerolog.Nop())
+
+	cases := []struct {
+		name string
+		raw  string
+		want string
+	}{
+		{
+			name: "trim_whitespace",
+			raw:  " test query ",
+			want: "test query",
+		},
+		{
+			name: "remove_null_bytes",
+			raw:  "test\x00query",
+			want: "testquery",
+		},
+		{
+			name: "normalize_whitespace",
+			raw:  "test query\t\nwith spaces",
+			want: "test query with spaces",
+		},
+		{
+			name: "empty_string",
+			raw:  "",
+			want: "",
+		},
+		{
+			name: "already_clean",
+			raw:  "test query",
+			want: "test query",
+		},
+	}
+
+	for _, c := range cases {
+		t.Run(c.name, func(t *testing.T) {
+			assert.Equal(t, c.want, sanitizer.SanitizeQuery(c.raw))
+		})
+	}
+}
+
+// TestSanitizeCorpusID runs SanitizeCorpusID on each raw input in the table and
+// requires the result to equal the expected text exactly.
+func TestSanitizeCorpusID(t *testing.T) {
+	sanitizer := New(zerolog.Nop())
+
+	cases := []struct {
+		name string
+		raw  string
+		want string
+	}{
+		{
+			name: "trim_whitespace",
+			raw:  " Test-Corpus ",
+			want: "test-corpus",
+		},
+		{
+			name: "remove_null_bytes",
+			raw:  "test\x00corpus",
+			want: "testcorpus",
+		},
+		{
+			name: "lowercase",
+			raw:  "Test-Corpus_1.0",
+			want: "test-corpus_1.0",
+		},
+		{
+			name: "empty_string",
+			raw:  "",
+			want: "",
+		},
+		{
+			name: "already_clean",
+			raw:  "test-corpus",
+			want: "test-corpus",
+		},
+	}
+
+	for _, c := range cases {
+		t.Run(c.name, func(t *testing.T) {
+			assert.Equal(t, c.want, sanitizer.SanitizeCorpusID(c.raw))
+		})
+	}
+}
+
+// boolPtr returns the address of b. Because b is passed by value, each call
+// hands back a pointer to its own independent copy of the argument.
+// NOTE(review): no caller is visible in this file; confirm it is still needed.
+func boolPtr(b bool) *bool { return &b }