blob: 0009d713522f5ca46293013a39d59fe7a70f59fd [file] [log] [blame]
package service
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/korap/korap-mcp/auth"
"github.com/korap/korap-mcp/config"
"github.com/stretchr/testify/assert"
"golang.org/x/oauth2"
)
// TestNewClient checks how NewClient handles valid and invalid options:
//   - it rejects invalid configurations with the expected error substrings;
//   - it normalizes the base URL so it ends with "/";
//   - it falls back to a 30-second HTTP timeout when no Timeout is given.
func TestNewClient(t *testing.T) {
	tests := []struct {
		name    string
		opts    ClientOptions
		wantErr bool
		errMsg  string // substring expected in the error message when wantErr is true
	}{
		{
			name: "valid client without oauth",
			opts: ClientOptions{
				BaseURL: "https://example.com",
				Timeout: 10 * time.Second,
			},
			wantErr: false,
		},
		{
			name: "valid client with oauth",
			opts: ClientOptions{
				BaseURL: "https://example.com",
				OAuthConfig: &config.OAuthConfig{
					ClientID:     "test-client",
					ClientSecret: "test-secret",
					AuthURL:      "https://example.com/auth",
					TokenURL:     "https://example.com/token",
					Enabled:      true,
				},
			},
			wantErr: false,
		},
		{
			name: "invalid client - no base URL",
			opts: ClientOptions{
				Timeout: 10 * time.Second,
			},
			wantErr: true,
			errMsg:  "base URL is required",
		},
		{
			name: "invalid oauth config",
			opts: ClientOptions{
				BaseURL: "https://example.com",
				OAuthConfig: &config.OAuthConfig{
					Enabled: true, // Missing required fields
				},
			},
			wantErr: true,
			errMsg:  "failed to create OAuth client",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			client, err := NewClient(tt.opts)

			if tt.wantErr {
				// Only inspect the message when an error was actually returned;
				// calling err.Error() on nil would panic instead of failing cleanly.
				if assert.Error(t, err) {
					assert.Contains(t, err.Error(), tt.errMsg)
				}
				return
			}

			// Stop this subtest before dereferencing client if construction failed.
			if !assert.NoError(t, err) || !assert.NotNil(t, client) {
				return
			}

			// Check base URL normalization.
			assert.True(t, strings.HasSuffix(client.baseURL, "/"), "Base URL should end with /")

			// With no explicit timeout, the client must use the default.
			if tt.opts.Timeout == 0 {
				assert.Equal(t, 30*time.Second, client.httpClient.Timeout, "Default timeout should be 30 seconds")
			}
		})
	}
}
// TestClientMethods runs the client's HTTP helpers (Get, Post, GetJSON,
// PostJSON, Ping) and its APIError mapping against an in-process
// httptest server.
func TestClientMethods(t *testing.T) {
	// Mock server routes:
	//   /test  – GET returns 200 {"message":"success"}, POST returns 201 {"created":"true"}
	//   /error – 400 with an ErrorResponse body
	//   /      – 200 "OK" (used by Ping)
	//   other  – 404
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/test":
			switch r.Method {
			case http.MethodGet:
				w.Header().Set("Content-Type", "application/json")
				json.NewEncoder(w).Encode(map[string]string{"message": "success"})
			case http.MethodPost:
				w.Header().Set("Content-Type", "application/json")
				w.WriteHeader(http.StatusCreated)
				json.NewEncoder(w).Encode(map[string]string{"created": "true"})
			default:
				// Reject unexpected verbs explicitly. An empty 200 here could
				// hide a client that sends the wrong method.
				w.WriteHeader(http.StatusMethodNotAllowed)
			}
		case "/error":
			w.WriteHeader(http.StatusBadRequest)
			json.NewEncoder(w).Encode(ErrorResponse{
				Error:            "bad_request",
				ErrorDescription: "Invalid request parameter",
			})
		case "/":
			// Root endpoint for ping test
			w.WriteHeader(http.StatusOK)
			w.Write([]byte("OK"))
		default:
			w.WriteHeader(http.StatusNotFound)
		}
	}))
	defer server.Close()

	client, err := NewClient(ClientOptions{
		BaseURL: server.URL,
		Timeout: 5 * time.Second,
	})
	// Every subtest uses client, so abort now rather than panic on a nil client.
	if !assert.NoError(t, err) {
		t.FailNow()
	}

	ctx := context.Background()

	t.Run("GET request", func(t *testing.T) {
		resp, err := client.Get(ctx, "/test")
		if !assert.NoError(t, err) {
			return // resp is nil; closing its body would panic
		}
		defer resp.Body.Close()
		assert.Equal(t, http.StatusOK, resp.StatusCode)
	})

	t.Run("POST request", func(t *testing.T) {
		body := map[string]string{"key": "value"}
		resp, err := client.Post(ctx, "/test", body)
		if !assert.NoError(t, err) {
			return // resp is nil; closing its body would panic
		}
		defer resp.Body.Close()
		assert.Equal(t, http.StatusCreated, resp.StatusCode)
	})

	t.Run("GetJSON", func(t *testing.T) {
		var result map[string]string
		err := client.GetJSON(ctx, "/test", &result)
		assert.NoError(t, err)
		assert.Equal(t, "success", result["message"])
	})

	t.Run("PostJSON", func(t *testing.T) {
		body := map[string]string{"key": "value"}
		var result map[string]string
		err := client.PostJSON(ctx, "/test", body, &result)
		assert.NoError(t, err)
		assert.Equal(t, "true", result["created"])
	})

	t.Run("Error handling", func(t *testing.T) {
		var result map[string]any
		err := client.GetJSON(ctx, "/error", &result)
		assert.Error(t, err)
		apiErr, ok := err.(*APIError)
		if !assert.True(t, ok, "Expected APIError") {
			return // apiErr is nil; reading its fields would panic
		}
		assert.Equal(t, http.StatusBadRequest, apiErr.StatusCode)
		assert.Equal(t, "bad_request", apiErr.Message)
	})

	t.Run("Ping", func(t *testing.T) {
		err := client.Ping(ctx)
		assert.NoError(t, err)
	})
}
// TestClientAuthentication checks the OAuth2 client-credentials flow:
//   - a request without a token is rejected with a 401 APIError;
//   - AuthenticateWithClientCredentials obtains a token from the mock token
//     endpoint;
//   - later requests carry that bearer token and succeed.
//
// The subtests share one client and must run in the declared order.
func TestClientAuthentication(t *testing.T) {
	// Mock OAuth2 server: only /token is served. It always issues the same
	// bearer token.
	authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/token" {
			http.NotFound(w, r)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(map[string]any{
			"access_token": "test-access-token",
			"token_type":   "Bearer",
			"expires_in":   3600,
		})
	}))
	defer authServer.Close()

	// Mock API server: accepts only the token issued above.
	apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Named authHeader so it does not shadow the imported auth package.
		authHeader := r.Header.Get("Authorization")
		if authHeader != "Bearer test-access-token" {
			w.WriteHeader(http.StatusUnauthorized)
			json.NewEncoder(w).Encode(ErrorResponse{
				Error:            "unauthorized",
				ErrorDescription: "Invalid or missing access token",
			})
			return
		}
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(map[string]string{"authenticated": "true"})
	}))
	defer apiServer.Close()

	// Create client with OAuth2 configuration
	oauthConfig := &config.OAuthConfig{
		ClientID:     "test-client",
		ClientSecret: "test-secret",
		TokenURL:     authServer.URL + "/token",
		Enabled:      true,
	}
	client, err := NewClient(ClientOptions{
		BaseURL:     apiServer.URL,
		OAuthConfig: oauthConfig,
	})
	// Every subtest uses client, so abort now rather than panic on a nil client.
	if !assert.NoError(t, err) {
		t.FailNow()
	}

	ctx := context.Background()

	t.Run("Unauthenticated request", func(t *testing.T) {
		var result map[string]any
		err := client.GetJSON(ctx, "/", &result)
		assert.Error(t, err)
		apiErr, ok := err.(*APIError)
		if !assert.True(t, ok, "Expected APIError") {
			return // apiErr is nil; reading its fields would panic
		}
		assert.Equal(t, http.StatusUnauthorized, apiErr.StatusCode)
	})

	t.Run("Client credentials authentication", func(t *testing.T) {
		err := client.AuthenticateWithClientCredentials(ctx)
		assert.NoError(t, err)
		assert.True(t, client.IsAuthenticated(), "Client should be authenticated")
	})

	t.Run("Authenticated request", func(t *testing.T) {
		var result map[string]any
		err := client.GetJSON(ctx, "/", &result)
		assert.NoError(t, err)
		assert.Equal(t, "true", result["authenticated"])
	})
}
// TestURLBuilding checks that buildURL joins endpoints onto a base URL with a
// path component. Endpoints with and without a leading "/", and the empty
// endpoint, should all produce the expected absolute URL.
func TestURLBuilding(t *testing.T) {
	client, err := NewClient(ClientOptions{
		BaseURL: "https://example.com/api/v1",
	})
	// buildURL is called on client below, so abort instead of panicking on nil.
	if !assert.NoError(t, err) {
		t.FailNow()
	}

	tests := []struct {
		endpoint string
		expected string
	}{
		{"/search", "https://example.com/api/v1/search"},
		{"search", "https://example.com/api/v1/search"},
		{"/corpus/info", "https://example.com/api/v1/corpus/info"},
		{"", "https://example.com/api/v1/"},
	}

	for _, tt := range tests {
		t.Run(fmt.Sprintf("endpoint_%s", tt.endpoint), func(t *testing.T) {
			got, err := client.buildURL(tt.endpoint)
			assert.NoError(t, err)
			assert.Equal(t, tt.expected, got)
		})
	}
}
// TestOAuthClientOperations checks that attaching an OAuth client that already
// holds an unexpired token makes the Client report itself as authenticated. It
// also checks that GetBaseURL returns the normalized base URL, with a
// trailing "/".
func TestOAuthClientOperations(t *testing.T) {
	client, err := NewClient(ClientOptions{
		BaseURL: "https://example.com",
	})
	if !assert.NoError(t, err) {
		t.FailNow() // client is dereferenced below
	}

	// Test setting OAuth client
	oauthConfig := &config.OAuthConfig{
		ClientID:     "test-client",
		ClientSecret: "test-secret",
		TokenURL:     "https://example.com/token",
		Enabled:      true,
	}
	oauthClient, err := auth.NewOAuthClient(oauthConfig)
	if !assert.NoError(t, err) {
		t.FailNow() // oauthClient is dereferenced below
	}

	// Inject a token that expires an hour from now, so no token endpoint is
	// contacted.
	token := &oauth2.Token{
		AccessToken: "test-token",
		TokenType:   "Bearer",
		Expiry:      time.Now().Add(time.Hour),
	}
	oauthClient.SetToken(token)
	client.SetOAuthClient(oauthClient)

	assert.True(t, client.IsAuthenticated(), "Client should be authenticated after setting valid token")
	assert.Equal(t, "https://example.com/", client.GetBaseURL())
}