blob: 9c6b0cc3d5548cc21d2fb084924bc0b721b99978 [file] [log] [blame]
Akron90f65212025-06-12 14:32:55 +02001package auth
2
3import (
4 "context"
5 "encoding/json"
6 "net/http"
7 "net/http/httptest"
8 "testing"
9 "time"
10
11 "github.com/korap/korap-mcp/config"
12 "github.com/stretchr/testify/assert"
13 "golang.org/x/oauth2"
14)
15
// TestNewOAuthClient verifies that a fully populated config yields a client
// that keeps a reference equal to the config it was built from.
func TestNewOAuthClient(t *testing.T) {
	cfg := &config.OAuthConfig{
		ClientID:     "test-client",
		ClientSecret: "test-secret",
		AuthURL:      "https://example.com/auth",
		TokenURL:     "https://example.com/token",
		Enabled:      true,
	}

	client, err := NewOAuthClient(cfg)
	// Stop here on failure: dereferencing a nil client below would panic and
	// abort every remaining test in the package.
	if !assert.NoError(t, err) || !assert.NotNil(t, client) {
		t.FailNow()
	}
	assert.Equal(t, cfg, client.config)
}
30
31func TestNewOAuthClient_InvalidConfig(t *testing.T) {
32 tests := []struct {
33 name string
34 config *config.OAuthConfig
35 }{
36 {
37 name: "nil config",
38 config: nil,
39 },
40 {
41 name: "empty client ID",
42 config: &config.OAuthConfig{
43 ClientSecret: "test-secret",
44 AuthURL: "https://example.com/auth",
45 TokenURL: "https://example.com/token",
46 Enabled: true,
47 },
48 },
49 {
50 name: "empty client secret",
51 config: &config.OAuthConfig{
52 ClientID: "test-client",
53 AuthURL: "https://example.com/auth",
54 TokenURL: "https://example.com/token",
55 Enabled: true,
56 },
57 },
58 }
59
60 for _, tt := range tests {
61 t.Run(tt.name, func(t *testing.T) {
62 client, err := NewOAuthClient(tt.config)
63 assert.Error(t, err)
64 assert.Nil(t, client)
65 })
66 }
67}
68
// TestOAuthClient_GetAuthURL verifies that the authorization URL targets the
// configured auth endpoint and carries the client ID, state and redirect URI.
func TestOAuthClient_GetAuthURL(t *testing.T) {
	cfg := &config.OAuthConfig{
		ClientID:     "test-client",
		ClientSecret: "test-secret",
		AuthURL:      "https://example.com/auth",
		TokenURL:     "https://example.com/token",
		RedirectURL:  "https://example.com/callback",
		Enabled:      true,
	}

	client, err := NewOAuthClient(cfg)
	// Stop here rather than panic on a nil client below.
	if !assert.NoError(t, err) || !assert.NotNil(t, client) {
		t.FailNow()
	}

	authURL := client.GetAuthURL("test-state")
	assert.NotEmpty(t, authURL)

	// Check that the URL starts at the auth endpoint and contains the expected
	// query parameters (values are URL-encoded).
	assert.Contains(t, authURL, "https://example.com/auth?")
	assert.Contains(t, authURL, "client_id=test-client")
	assert.Contains(t, authURL, "state=test-state")
	// RedirectURL is configured above; make sure it is actually forwarded.
	assert.Contains(t, authURL, "redirect_uri=https%3A%2F%2Fexample.com%2Fcallback")
}
89
// TestOAuthClient_IsAuthenticated verifies that a new client is not
// authenticated, becomes authenticated with an unexpired token, and stops being
// authenticated once an expired token is set.
func TestOAuthClient_IsAuthenticated(t *testing.T) {
	cfg := &config.OAuthConfig{
		ClientID:     "test-client",
		ClientSecret: "test-secret",
		AuthURL:      "https://example.com/auth",
		TokenURL:     "https://example.com/token",
		Enabled:      true,
	}

	client, err := NewOAuthClient(cfg)
	// Stop here rather than panic on a nil client below.
	if !assert.NoError(t, err) || !assert.NotNil(t, client) {
		t.FailNow()
	}

	// Should not be authenticated initially
	assert.False(t, client.IsAuthenticated(), "client should not be authenticated initially")

	// Set a valid token
	client.SetToken(&oauth2.Token{
		AccessToken: "test-token",
		Expiry:      time.Now().Add(time.Hour),
	})
	assert.True(t, client.IsAuthenticated(), "client should be authenticated with valid token")

	// Set an expired token
	client.SetToken(&oauth2.Token{
		AccessToken: "test-token",
		Expiry:      time.Now().Add(-time.Hour),
	})
	assert.False(t, client.IsAuthenticated(), "client should not be authenticated with expired token")
}
123
// TestOAuthClient_ExchangeCode verifies that ExchangeCode sends an
// authorization_code grant with the given code to the token endpoint and
// stores the returned access token.
func TestOAuthClient_ExchangeCode(t *testing.T) {
	// Mock OAuth2 token endpoint. The handler runs on a server goroutine, so it
	// reports problems with t.Errorf-style assertions, never t.Fatal.
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/token" {
			http.NotFound(w, r)
			return
		}
		if err := r.ParseForm(); err != nil {
			t.Errorf("parsing token request form: %v", err)
		}
		// Without these checks the mock accepts any request, so a wrong grant
		// type or a dropped code would go unnoticed.
		assert.Equal(t, http.MethodPost, r.Method)
		assert.Equal(t, "authorization_code", r.PostForm.Get("grant_type"))
		assert.Equal(t, "test-code", r.PostForm.Get("code"))

		w.Header().Set("Content-Type", "application/json")
		response := map[string]interface{}{
			"access_token": "test-access-token",
			"token_type":   "Bearer",
			"expires_in":   3600,
		}
		if err := json.NewEncoder(w).Encode(response); err != nil {
			t.Errorf("encoding token response: %v", err)
		}
	}))
	defer server.Close()

	cfg := &config.OAuthConfig{
		ClientID:     "test-client",
		ClientSecret: "test-secret",
		AuthURL:      server.URL + "/auth",
		TokenURL:     server.URL + "/token",
		Enabled:      true,
	}

	client, err := NewOAuthClient(cfg)
	// Stop here rather than panic on a nil client below.
	if !assert.NoError(t, err) || !assert.NotNil(t, client) {
		t.FailNow()
	}

	err = client.ExchangeCode(context.Background(), "test-code")
	assert.NoError(t, err)

	token := client.GetToken()
	if assert.NotNil(t, token) {
		assert.Equal(t, "test-access-token", token.AccessToken)
	}
}
159
// TestOAuthClient_ClientCredentialsFlow verifies that ClientCredentialsFlow
// sends a client_credentials grant to the token endpoint and stores the
// returned access token.
func TestOAuthClient_ClientCredentialsFlow(t *testing.T) {
	// Mock OAuth2 token endpoint. The handler runs on a server goroutine, so it
	// reports problems with t.Errorf-style assertions, never t.Fatal.
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/token" {
			http.NotFound(w, r)
			return
		}
		if err := r.ParseForm(); err != nil {
			t.Errorf("parsing token request form: %v", err)
		}
		// Without this check the mock accepts any grant type.
		assert.Equal(t, http.MethodPost, r.Method)
		assert.Equal(t, "client_credentials", r.PostForm.Get("grant_type"))

		w.Header().Set("Content-Type", "application/json")
		response := map[string]interface{}{
			"access_token": "client-credentials-token",
			"token_type":   "Bearer",
			"expires_in":   3600,
		}
		if err := json.NewEncoder(w).Encode(response); err != nil {
			t.Errorf("encoding token response: %v", err)
		}
	}))
	defer server.Close()

	cfg := &config.OAuthConfig{
		ClientID:     "test-client",
		ClientSecret: "test-secret",
		TokenURL:     server.URL + "/token",
		Scopes:       []string{"read"},
		Enabled:      true,
	}

	client, err := NewOAuthClient(cfg)
	// Stop here rather than panic on a nil client below.
	if !assert.NoError(t, err) || !assert.NotNil(t, client) {
		t.FailNow()
	}

	err = client.ClientCredentialsFlow(context.Background())
	assert.NoError(t, err)

	token := client.GetToken()
	if assert.NotNil(t, token) {
		assert.Equal(t, "client-credentials-token", token.AccessToken)
	}
}
195
// TestOAuthClient_GetHTTPClient verifies that an unauthenticated client still
// returns an HTTP client, configured with a 30-second timeout.
func TestOAuthClient_GetHTTPClient(t *testing.T) {
	cfg := &config.OAuthConfig{
		ClientID:     "test-client",
		ClientSecret: "test-secret",
		AuthURL:      "https://example.com/auth",
		TokenURL:     "https://example.com/token",
		Enabled:      true,
	}

	client, err := NewOAuthClient(cfg)
	// Stop here rather than panic on a nil client below.
	if !assert.NoError(t, err) || !assert.NotNil(t, client) {
		t.FailNow()
	}

	httpClient := client.GetHTTPClient()
	// Reading .Timeout from a nil client would panic; fail cleanly instead.
	if !assert.NotNil(t, httpClient) {
		t.FailNow()
	}

	// Should return default client when not authenticated
	assert.Equal(t, 30*time.Second, httpClient.Timeout)
}
214
// TestOAuthClient_AddAuthHeader verifies that AddAuthHeader fails and leaves
// the request untouched when no token is set, and adds a Bearer Authorization
// header once a token is available.
func TestOAuthClient_AddAuthHeader(t *testing.T) {
	cfg := &config.OAuthConfig{
		ClientID:     "test-client",
		ClientSecret: "test-secret",
		AuthURL:      "https://example.com/auth",
		TokenURL:     "https://example.com/token",
		Enabled:      true,
	}

	client, err := NewOAuthClient(cfg)
	// Stop here rather than panic on a nil client below.
	if !assert.NoError(t, err) || !assert.NotNil(t, client) {
		t.FailNow()
	}

	req, err := http.NewRequest(http.MethodGet, "https://example.com/api", nil)
	if !assert.NoError(t, err) {
		t.FailNow()
	}

	// Should fail when no token is set
	err = client.AddAuthHeader(req)
	assert.Error(t, err, "expected error when no token is available")
	assert.Empty(t, req.Header.Get("Authorization"), "no Authorization header should be set on failure")

	// Set a token and try again
	client.SetToken(&oauth2.Token{
		AccessToken: "test-token",
		TokenType:   "Bearer",
	})

	err = client.AddAuthHeader(req)
	assert.NoError(t, err)
	assert.Equal(t, "Bearer test-token", req.Header.Get("Authorization"))
}
246
// TestOAuthClient_RefreshToken verifies that RefreshToken exchanges the stored
// refresh token via a refresh_token grant and stores the new access token.
func TestOAuthClient_RefreshToken(t *testing.T) {
	// Mock OAuth2 token endpoint. The handler runs on a server goroutine, so it
	// reports problems with t.Errorf-style assertions, never t.Fatal.
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/token" {
			http.NotFound(w, r)
			return
		}
		if err := r.ParseForm(); err != nil {
			t.Errorf("parsing token request form: %v", err)
		}
		// Make sure a refresh (not some other grant) was requested with the
		// refresh token stored on the client.
		assert.Equal(t, http.MethodPost, r.Method)
		assert.Equal(t, "refresh_token", r.PostForm.Get("grant_type"))
		assert.Equal(t, "refresh-token", r.PostForm.Get("refresh_token"))

		w.Header().Set("Content-Type", "application/json")
		response := map[string]interface{}{
			"access_token":  "refreshed-access-token",
			"refresh_token": "new-refresh-token",
			"token_type":    "Bearer",
			"expires_in":    3600,
		}
		if err := json.NewEncoder(w).Encode(response); err != nil {
			t.Errorf("encoding token response: %v", err)
		}
	}))
	defer server.Close()

	cfg := &config.OAuthConfig{
		ClientID:     "test-client",
		ClientSecret: "test-secret",
		AuthURL:      server.URL + "/auth",
		TokenURL:     server.URL + "/token",
		Enabled:      true,
	}

	client, err := NewOAuthClient(cfg)
	// Stop here rather than panic on a nil client below.
	if !assert.NoError(t, err) || !assert.NotNil(t, client) {
		t.FailNow()
	}

	// Set an expired initial token that carries a refresh token.
	client.SetToken(&oauth2.Token{
		AccessToken:  "initial-token",
		RefreshToken: "refresh-token",
		Expiry:       time.Now().Add(-time.Hour), // Expired
	})

	err = client.RefreshToken(context.Background())
	assert.NoError(t, err)

	token := client.GetToken()
	if assert.NotNil(t, token) {
		assert.Equal(t, "refreshed-access-token", token.AccessToken)
	}
}
291
// TestOAuthClient_RefreshToken_Errors verifies that RefreshToken fails when no
// token is stored and when the client's oauth2Config is nil.
func TestOAuthClient_RefreshToken_Errors(t *testing.T) {
	cfg := &config.OAuthConfig{
		ClientID:     "test-client",
		ClientSecret: "test-secret",
		AuthURL:      "https://example.com/auth",
		TokenURL:     "https://example.com/token",
		Enabled:      true,
	}

	client, err := NewOAuthClient(cfg)
	// Stop here rather than panic on a nil client below.
	if !assert.NoError(t, err) || !assert.NotNil(t, client) {
		t.FailNow()
	}

	// Test refresh without token
	err = client.RefreshToken(context.Background())
	assert.Error(t, err, "expected error when refreshing without token")

	// Test refresh with unconfigured OAuth2: clear the internal config directly.
	client.oauth2Config = nil
	client.SetToken(&oauth2.Token{AccessToken: "test"})
	err = client.RefreshToken(context.Background())
	assert.Error(t, err, "expected error when OAuth2 not configured")
}
314
// TestOAuthClient_TokenExpiration verifies that IsAuthenticated treats a token
// expiring within the next few minutes as unusable, while a token with more
// remaining lifetime is accepted.
func TestOAuthClient_TokenExpiration(t *testing.T) {
	cfg := &config.OAuthConfig{
		ClientID:     "test-client",
		ClientSecret: "test-secret",
		TokenURL:     "https://example.com/token",
		Enabled:      true,
	}

	client, err := NewOAuthClient(cfg)
	// Stop here rather than panic on a nil client below.
	if !assert.NoError(t, err) || !assert.NotNil(t, client) {
		t.FailNow()
	}

	// Test with token expiring soon (within 5 minute buffer)
	client.SetToken(&oauth2.Token{
		AccessToken: "test-token",
		Expiry:      time.Now().Add(2 * time.Minute), // Expires in 2 minutes
	})
	assert.False(t, client.IsAuthenticated(), "client should not be authenticated with soon-to-expire token")

	// Test with token expiring later (outside 5 minute buffer)
	client.SetToken(&oauth2.Token{
		AccessToken: "test-token",
		Expiry:      time.Now().Add(10 * time.Minute), // Expires in 10 minutes
	})
	assert.True(t, client.IsAuthenticated(), "client should be authenticated with valid token")
}
344
// TestOAuthClient_ErrorHandling verifies that the client credentials flow and
// the code exchange both report an error when the token endpoint is
// unreachable.
func TestOAuthClient_ErrorHandling(t *testing.T) {
	// Start and immediately close a local server so its address refuses
	// connections. This fails fast and deterministically, unlike resolving an
	// external host name, which depends on DNS and network access.
	server := httptest.NewServer(http.NotFoundHandler())
	unreachableURL := server.URL
	server.Close()

	cfg := &config.OAuthConfig{
		ClientID:     "test-client",
		ClientSecret: "test-secret",
		AuthURL:      unreachableURL + "/auth",
		TokenURL:     unreachableURL + "/token",
		Enabled:      true,
	}

	client, err := NewOAuthClient(cfg)
	// Stop here rather than panic on a nil client below.
	if !assert.NoError(t, err) || !assert.NotNil(t, client) {
		t.FailNow()
	}

	// Test client credentials flow with invalid server
	err = client.ClientCredentialsFlow(context.Background())
	assert.Error(t, err, "expected error when connecting to invalid server")

	// Test code exchange with invalid server
	err = client.ExchangeCode(context.Background(), "test-code")
	assert.Error(t, err, "expected error when connecting to invalid server")
}
// TestOAuthClient_DisabledOAuth verifies the behavior of a client built from a
// config with Enabled set to false and no endpoints: it still returns an HTTP
// client, produces an empty auth URL, and rejects the client credentials flow.
func TestOAuthClient_DisabledOAuth(t *testing.T) {
	cfg := &config.OAuthConfig{
		ClientID:     "test-client",
		ClientSecret: "test-secret",
		Enabled:      false,
	}

	client, err := NewOAuthClient(cfg)
	// Stop here rather than panic on a nil client below.
	if !assert.NoError(t, err) || !assert.NotNil(t, client) {
		t.FailNow()
	}

	httpClient := client.GetHTTPClient()
	assert.NotNil(t, httpClient, "HTTP client should not be nil even when OAuth is disabled")

	authURL := client.GetAuthURL("test-state")
	assert.Empty(t, authURL, "auth URL should be empty when OAuth is disabled")

	err = client.ClientCredentialsFlow(context.Background())
	assert.Error(t, err, "expected error when trying client credentials flow with disabled OAuth")
}
// TestOAuthClient_ContextCancellation verifies that ClientCredentialsFlow fails
// when given an already-cancelled context, even though the token endpoint
// would otherwise issue a valid token.
func TestOAuthClient_ContextCancellation(t *testing.T) {
	// The mock always succeeds, so the flow can only fail by honoring the
	// cancelled context. An external TokenURL would make the call fail anyway
	// (network/404), letting the test pass even if cancellation were ignored.
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		response := map[string]interface{}{
			"access_token": "should-not-be-issued",
			"token_type":   "Bearer",
			"expires_in":   3600,
		}
		if err := json.NewEncoder(w).Encode(response); err != nil {
			t.Errorf("encoding token response: %v", err)
		}
	}))
	defer server.Close()

	cfg := &config.OAuthConfig{
		ClientID:     "test-client",
		ClientSecret: "test-secret",
		TokenURL:     server.URL + "/token",
		Enabled:      true,
	}

	client, err := NewOAuthClient(cfg)
	// Stop here rather than panic on a nil client below.
	if !assert.NoError(t, err) || !assert.NotNil(t, client) {
		t.FailNow()
	}

	// Create a cancelled context
	ctx, cancel := context.WithCancel(context.Background())
	cancel()

	err = client.ClientCredentialsFlow(ctx)
	assert.Error(t, err, "expected error with cancelled context")
	assert.False(t, client.IsAuthenticated(), "client should not be authenticated after a cancelled flow")
}