| package service |
| |
| import ( |
| "context" |
| "encoding/json" |
| "fmt" |
| "net/http" |
| "net/http/httptest" |
| "strings" |
| "testing" |
| "time" |
| |
| "github.com/korap/korap-mcp/auth" |
| "github.com/korap/korap-mcp/config" |
| "github.com/stretchr/testify/assert" |
| "golang.org/x/oauth2" |
| ) |
| |
| // TestNewClient tests client creation |
| func TestNewClient(t *testing.T) { |
| tests := []struct { |
| name string |
| opts ClientOptions |
| wantErr bool |
| errMsg string |
| }{ |
| { |
| name: "valid client without oauth", |
| opts: ClientOptions{ |
| BaseURL: "https://example.com", |
| Timeout: 10 * time.Second, |
| }, |
| wantErr: false, |
| }, |
| { |
| name: "valid client with oauth", |
| opts: ClientOptions{ |
| BaseURL: "https://example.com", |
| OAuthConfig: &config.OAuthConfig{ |
| ClientID: "test-client", |
| ClientSecret: "test-secret", |
| AuthURL: "https://example.com/auth", |
| TokenURL: "https://example.com/token", |
| Enabled: true, |
| }, |
| }, |
| wantErr: false, |
| }, |
| { |
| name: "invalid client - no base URL", |
| opts: ClientOptions{ |
| Timeout: 10 * time.Second, |
| }, |
| wantErr: true, |
| errMsg: "base URL is required", |
| }, |
| { |
| name: "invalid oauth config", |
| opts: ClientOptions{ |
| BaseURL: "https://example.com", |
| OAuthConfig: &config.OAuthConfig{ |
| Enabled: true, // Missing required fields |
| }, |
| }, |
| wantErr: true, |
| errMsg: "failed to create OAuth client", |
| }, |
| } |
| |
| for _, tt := range tests { |
| t.Run(tt.name, func(t *testing.T) { |
| client, err := NewClient(tt.opts) |
| |
| if tt.wantErr { |
| assert.Error(t, err) |
| assert.Contains(t, err.Error(), tt.errMsg) |
| return |
| } |
| |
| assert.NoError(t, err) |
| assert.NotNil(t, client) |
| |
| // Check base URL normalization |
| assert.True(t, strings.HasSuffix(client.baseURL, "/"), "Base URL should end with /") |
| |
| // Check default timeout |
| if tt.opts.Timeout == 0 { |
| assert.Equal(t, 30*time.Second, client.httpClient.Timeout, "Default timeout should be 30 seconds") |
| } |
| }) |
| } |
| } |
| |
| // TestClientMethods tests basic client HTTP methods |
| func TestClientMethods(t *testing.T) { |
| // Create a mock server |
| server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| switch r.URL.Path { |
| case "/test": |
| if r.Method == http.MethodGet { |
| w.Header().Set("Content-Type", "application/json") |
| json.NewEncoder(w).Encode(map[string]string{"message": "success"}) |
| } else if r.Method == http.MethodPost { |
| w.Header().Set("Content-Type", "application/json") |
| w.WriteHeader(http.StatusCreated) |
| json.NewEncoder(w).Encode(map[string]string{"created": "true"}) |
| } |
| case "/error": |
| w.WriteHeader(http.StatusBadRequest) |
| json.NewEncoder(w).Encode(ErrorResponse{ |
| Error: "bad_request", |
| ErrorDescription: "Invalid request parameter", |
| }) |
| case "/": |
| // Root endpoint for ping test |
| w.WriteHeader(http.StatusOK) |
| w.Write([]byte("OK")) |
| default: |
| w.WriteHeader(http.StatusNotFound) |
| } |
| })) |
| defer server.Close() |
| |
| client, err := NewClient(ClientOptions{ |
| BaseURL: server.URL, |
| Timeout: 5 * time.Second, |
| }) |
| assert.NoError(t, err) |
| |
| ctx := context.Background() |
| |
| t.Run("GET request", func(t *testing.T) { |
| resp, err := client.Get(ctx, "/test") |
| assert.NoError(t, err) |
| defer resp.Body.Close() |
| |
| assert.Equal(t, http.StatusOK, resp.StatusCode) |
| }) |
| |
| t.Run("POST request", func(t *testing.T) { |
| body := map[string]string{"key": "value"} |
| resp, err := client.Post(ctx, "/test", body) |
| assert.NoError(t, err) |
| defer resp.Body.Close() |
| |
| assert.Equal(t, http.StatusCreated, resp.StatusCode) |
| }) |
| |
| t.Run("GetJSON", func(t *testing.T) { |
| var result map[string]string |
| err := client.GetJSON(ctx, "/test", &result) |
| assert.NoError(t, err) |
| |
| assert.Equal(t, "success", result["message"]) |
| }) |
| |
| t.Run("PostJSON", func(t *testing.T) { |
| body := map[string]string{"key": "value"} |
| var result map[string]string |
| err := client.PostJSON(ctx, "/test", body, &result) |
| assert.NoError(t, err) |
| |
| assert.Equal(t, "true", result["created"]) |
| }) |
| |
| t.Run("Error handling", func(t *testing.T) { |
| var result map[string]any |
| err := client.GetJSON(ctx, "/error", &result) |
| |
| assert.Error(t, err) |
| |
| apiErr, ok := err.(*APIError) |
| assert.True(t, ok, "Expected APIError") |
| |
| assert.Equal(t, http.StatusBadRequest, apiErr.StatusCode) |
| assert.Equal(t, "bad_request", apiErr.Message) |
| }) |
| |
| t.Run("Ping", func(t *testing.T) { |
| err := client.Ping(ctx) |
| assert.NoError(t, err) |
| }) |
| } |
| |
| // TestClientAuthentication tests OAuth2 integration |
| func TestClientAuthentication(t *testing.T) { |
| // Create a mock OAuth2 server |
| authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| if r.URL.Path == "/token" { |
| // Mock token endpoint |
| w.Header().Set("Content-Type", "application/json") |
| json.NewEncoder(w).Encode(map[string]any{ |
| "access_token": "test-access-token", |
| "token_type": "Bearer", |
| "expires_in": 3600, |
| }) |
| } |
| })) |
| defer authServer.Close() |
| |
| // Create a mock API server that checks authentication |
| apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| auth := r.Header.Get("Authorization") |
| if auth != "Bearer test-access-token" { |
| w.WriteHeader(http.StatusUnauthorized) |
| json.NewEncoder(w).Encode(ErrorResponse{ |
| Error: "unauthorized", |
| ErrorDescription: "Invalid or missing access token", |
| }) |
| return |
| } |
| |
| w.Header().Set("Content-Type", "application/json") |
| json.NewEncoder(w).Encode(map[string]string{"authenticated": "true"}) |
| })) |
| defer apiServer.Close() |
| |
| // Create client with OAuth2 configuration |
| oauthConfig := &config.OAuthConfig{ |
| ClientID: "test-client", |
| ClientSecret: "test-secret", |
| TokenURL: authServer.URL + "/token", |
| Enabled: true, |
| } |
| |
| client, err := NewClient(ClientOptions{ |
| BaseURL: apiServer.URL, |
| OAuthConfig: oauthConfig, |
| }) |
| assert.NoError(t, err) |
| |
| ctx := context.Background() |
| |
| t.Run("Unauthenticated request", func(t *testing.T) { |
| var result map[string]any |
| err := client.GetJSON(ctx, "/", &result) |
| |
| assert.Error(t, err) |
| |
| apiErr, ok := err.(*APIError) |
| assert.True(t, ok, "Expected APIError") |
| |
| assert.Equal(t, http.StatusUnauthorized, apiErr.StatusCode) |
| }) |
| |
| t.Run("Client credentials authentication", func(t *testing.T) { |
| err := client.AuthenticateWithClientCredentials(ctx) |
| assert.NoError(t, err) |
| |
| assert.True(t, client.IsAuthenticated(), "Client should be authenticated") |
| }) |
| |
| t.Run("Authenticated request", func(t *testing.T) { |
| var result map[string]any |
| err := client.GetJSON(ctx, "/", &result) |
| assert.NoError(t, err) |
| |
| assert.Equal(t, "true", result["authenticated"]) |
| }) |
| } |
| |
| // TestURLBuilding tests URL construction logic |
| func TestURLBuilding(t *testing.T) { |
| client, err := NewClient(ClientOptions{ |
| BaseURL: "https://example.com/api/v1", |
| }) |
| assert.NoError(t, err) |
| |
| tests := []struct { |
| endpoint string |
| expected string |
| }{ |
| {"/search", "https://example.com/api/v1/search"}, |
| {"search", "https://example.com/api/v1/search"}, |
| {"/corpus/info", "https://example.com/api/v1/corpus/info"}, |
| {"", "https://example.com/api/v1/"}, |
| } |
| |
| for _, tt := range tests { |
| t.Run(fmt.Sprintf("endpoint_%s", tt.endpoint), func(t *testing.T) { |
| url, err := client.buildURL(tt.endpoint) |
| assert.NoError(t, err) |
| |
| assert.Equal(t, tt.expected, url) |
| }) |
| } |
| } |
| |
| // TestOAuthClientOperations tests OAuth client operations |
| func TestOAuthClientOperations(t *testing.T) { |
| client, err := NewClient(ClientOptions{ |
| BaseURL: "https://example.com", |
| }) |
| assert.NoError(t, err) |
| |
| // Test setting OAuth client |
| oauthConfig := &config.OAuthConfig{ |
| ClientID: "test-client", |
| ClientSecret: "test-secret", |
| TokenURL: "https://example.com/token", |
| Enabled: true, |
| } |
| |
| oauthClient, err := auth.NewOAuthClient(oauthConfig) |
| assert.NoError(t, err) |
| |
| // Set a mock token |
| token := &oauth2.Token{ |
| AccessToken: "test-token", |
| TokenType: "Bearer", |
| Expiry: time.Now().Add(time.Hour), |
| } |
| oauthClient.SetToken(token) |
| |
| client.SetOAuthClient(oauthClient) |
| |
| assert.True(t, client.IsAuthenticated(), "Client should be authenticated after setting valid token") |
| |
| assert.Equal(t, "https://example.com/", client.GetBaseURL()) |
| } |