blob: 3827b092d87e10ead12b184588400097e1de76f6 [file] [log] [blame]
Akron90f65212025-06-12 14:32:55 +02001package service
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "net/http"
8 "net/http/httptest"
9 "strings"
10 "testing"
11 "time"
12
13 "github.com/korap/korap-mcp/auth"
14 "github.com/korap/korap-mcp/config"
15 "github.com/stretchr/testify/assert"
16 "golang.org/x/oauth2"
17)
18
19// TestNewClient tests client creation
20func TestNewClient(t *testing.T) {
21 tests := []struct {
22 name string
23 opts ClientOptions
24 wantErr bool
25 errMsg string
26 }{
27 {
28 name: "valid client without oauth",
29 opts: ClientOptions{
30 BaseURL: "https://example.com",
31 Timeout: 10 * time.Second,
32 },
33 wantErr: false,
34 },
35 {
36 name: "valid client with oauth",
37 opts: ClientOptions{
38 BaseURL: "https://example.com",
39 OAuthConfig: &config.OAuthConfig{
40 ClientID: "test-client",
41 ClientSecret: "test-secret",
42 AuthURL: "https://example.com/auth",
43 TokenURL: "https://example.com/token",
44 Enabled: true,
45 },
46 },
47 wantErr: false,
48 },
49 {
50 name: "invalid client - no base URL",
51 opts: ClientOptions{
52 Timeout: 10 * time.Second,
53 },
54 wantErr: true,
55 errMsg: "base URL is required",
56 },
57 {
58 name: "invalid oauth config",
59 opts: ClientOptions{
60 BaseURL: "https://example.com",
61 OAuthConfig: &config.OAuthConfig{
62 Enabled: true, // Missing required fields
63 },
64 },
65 wantErr: true,
66 errMsg: "failed to create OAuth client",
67 },
68 }
69
70 for _, tt := range tests {
71 t.Run(tt.name, func(t *testing.T) {
72 client, err := NewClient(tt.opts)
73
74 if tt.wantErr {
75 assert.Error(t, err)
76 assert.Contains(t, err.Error(), tt.errMsg)
77 return
78 }
79
80 assert.NoError(t, err)
81 assert.NotNil(t, client)
82
83 // Check base URL normalization
84 assert.True(t, strings.HasSuffix(client.baseURL, "/"), "Base URL should end with /")
85
86 // Check default timeout
87 if tt.opts.Timeout == 0 {
88 assert.Equal(t, 30*time.Second, client.httpClient.Timeout, "Default timeout should be 30 seconds")
89 }
90 })
91 }
92}
93
94// TestClientMethods tests basic client HTTP methods
95func TestClientMethods(t *testing.T) {
96 // Create a mock server
97 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
98 switch r.URL.Path {
99 case "/test":
100 if r.Method == http.MethodGet {
101 w.Header().Set("Content-Type", "application/json")
102 json.NewEncoder(w).Encode(map[string]string{"message": "success"})
103 } else if r.Method == http.MethodPost {
104 w.Header().Set("Content-Type", "application/json")
105 w.WriteHeader(http.StatusCreated)
106 json.NewEncoder(w).Encode(map[string]string{"created": "true"})
107 }
108 case "/error":
109 w.WriteHeader(http.StatusBadRequest)
110 json.NewEncoder(w).Encode(ErrorResponse{
111 Error: "bad_request",
112 ErrorDescription: "Invalid request parameter",
113 })
114 case "/":
115 // Root endpoint for ping test
116 w.WriteHeader(http.StatusOK)
117 w.Write([]byte("OK"))
118 default:
119 w.WriteHeader(http.StatusNotFound)
120 }
121 }))
122 defer server.Close()
123
124 client, err := NewClient(ClientOptions{
125 BaseURL: server.URL,
126 Timeout: 5 * time.Second,
127 })
128 assert.NoError(t, err)
129
130 ctx := context.Background()
131
132 t.Run("GET request", func(t *testing.T) {
133 resp, err := client.Get(ctx, "/test")
134 assert.NoError(t, err)
135 defer resp.Body.Close()
136
137 assert.Equal(t, http.StatusOK, resp.StatusCode)
138 })
139
140 t.Run("POST request", func(t *testing.T) {
141 body := map[string]string{"key": "value"}
142 resp, err := client.Post(ctx, "/test", body)
143 assert.NoError(t, err)
144 defer resp.Body.Close()
145
146 assert.Equal(t, http.StatusCreated, resp.StatusCode)
147 })
148
149 t.Run("GetJSON", func(t *testing.T) {
150 var result map[string]string
151 err := client.GetJSON(ctx, "/test", &result)
152 assert.NoError(t, err)
153
154 assert.Equal(t, "success", result["message"])
155 })
156
157 t.Run("PostJSON", func(t *testing.T) {
158 body := map[string]string{"key": "value"}
159 var result map[string]string
160 err := client.PostJSON(ctx, "/test", body, &result)
161 assert.NoError(t, err)
162
163 assert.Equal(t, "true", result["created"])
164 })
165
166 t.Run("Error handling", func(t *testing.T) {
167 var result map[string]interface{}
168 err := client.GetJSON(ctx, "/error", &result)
169
170 assert.Error(t, err)
171
172 apiErr, ok := err.(*APIError)
173 assert.True(t, ok, "Expected APIError")
174
175 assert.Equal(t, http.StatusBadRequest, apiErr.StatusCode)
176 assert.Equal(t, "bad_request", apiErr.Message)
177 })
178
179 t.Run("Ping", func(t *testing.T) {
180 err := client.Ping(ctx)
181 assert.NoError(t, err)
182 })
183}
184
185// TestClientAuthentication tests OAuth2 integration
186func TestClientAuthentication(t *testing.T) {
187 // Create a mock OAuth2 server
188 authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
189 if r.URL.Path == "/token" {
190 // Mock token endpoint
191 w.Header().Set("Content-Type", "application/json")
192 json.NewEncoder(w).Encode(map[string]interface{}{
193 "access_token": "test-access-token",
194 "token_type": "Bearer",
195 "expires_in": 3600,
196 })
197 }
198 }))
199 defer authServer.Close()
200
201 // Create a mock API server that checks authentication
202 apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
203 auth := r.Header.Get("Authorization")
204 if auth != "Bearer test-access-token" {
205 w.WriteHeader(http.StatusUnauthorized)
206 json.NewEncoder(w).Encode(ErrorResponse{
207 Error: "unauthorized",
208 ErrorDescription: "Invalid or missing access token",
209 })
210 return
211 }
212
213 w.Header().Set("Content-Type", "application/json")
214 json.NewEncoder(w).Encode(map[string]string{"authenticated": "true"})
215 }))
216 defer apiServer.Close()
217
218 // Create client with OAuth2 configuration
219 oauthConfig := &config.OAuthConfig{
220 ClientID: "test-client",
221 ClientSecret: "test-secret",
222 TokenURL: authServer.URL + "/token",
223 Enabled: true,
224 }
225
226 client, err := NewClient(ClientOptions{
227 BaseURL: apiServer.URL,
228 OAuthConfig: oauthConfig,
229 })
230 assert.NoError(t, err)
231
232 ctx := context.Background()
233
234 t.Run("Unauthenticated request", func(t *testing.T) {
235 var result map[string]interface{}
236 err := client.GetJSON(ctx, "/", &result)
237
238 assert.Error(t, err)
239
240 apiErr, ok := err.(*APIError)
241 assert.True(t, ok, "Expected APIError")
242
243 assert.Equal(t, http.StatusUnauthorized, apiErr.StatusCode)
244 })
245
246 t.Run("Client credentials authentication", func(t *testing.T) {
247 err := client.AuthenticateWithClientCredentials(ctx)
248 assert.NoError(t, err)
249
250 assert.True(t, client.IsAuthenticated(), "Client should be authenticated")
251 })
252
253 t.Run("Authenticated request", func(t *testing.T) {
254 var result map[string]interface{}
255 err := client.GetJSON(ctx, "/", &result)
256 assert.NoError(t, err)
257
258 assert.Equal(t, "true", result["authenticated"])
259 })
260}
261
262// TestURLBuilding tests URL construction logic
263func TestURLBuilding(t *testing.T) {
264 client, err := NewClient(ClientOptions{
265 BaseURL: "https://example.com/api/v1",
266 })
267 assert.NoError(t, err)
268
269 tests := []struct {
270 endpoint string
271 expected string
272 }{
273 {"/search", "https://example.com/api/v1/search"},
274 {"search", "https://example.com/api/v1/search"},
275 {"/corpus/info", "https://example.com/api/v1/corpus/info"},
276 {"", "https://example.com/api/v1/"},
277 }
278
279 for _, tt := range tests {
280 t.Run(fmt.Sprintf("endpoint_%s", tt.endpoint), func(t *testing.T) {
281 url, err := client.buildURL(tt.endpoint)
282 assert.NoError(t, err)
283
284 assert.Equal(t, tt.expected, url)
285 })
286 }
287}
288
289// TestOAuthClientOperations tests OAuth client operations
290func TestOAuthClientOperations(t *testing.T) {
291 client, err := NewClient(ClientOptions{
292 BaseURL: "https://example.com",
293 })
294 assert.NoError(t, err)
295
296 // Test setting OAuth client
297 oauthConfig := &config.OAuthConfig{
298 ClientID: "test-client",
299 ClientSecret: "test-secret",
300 TokenURL: "https://example.com/token",
301 Enabled: true,
302 }
303
304 oauthClient, err := auth.NewOAuthClient(oauthConfig)
305 assert.NoError(t, err)
306
307 // Set a mock token
308 token := &oauth2.Token{
309 AccessToken: "test-token",
310 TokenType: "Bearer",
311 Expiry: time.Now().Add(time.Hour),
312 }
313 oauthClient.SetToken(token)
314
315 client.SetOAuthClient(oauthClient)
316
317 assert.True(t, client.IsAuthenticated(), "Client should be authenticated after setting valid token")
318
319 assert.Equal(t, "https://example.com/", client.GetBaseURL())
320}