Add configurable CORS
Change-Id: I6b902e028c1f192987b4d7d6415aad456137c520
diff --git a/cmd/koralmapper/main.go b/cmd/koralmapper/main.go
index eac2984..50bdace 100644
--- a/cmd/koralmapper/main.go
+++ b/cmd/koralmapper/main.go
@@ -20,6 +20,7 @@
"github.com/KorAP/Koral-Mapper/mapper"
"github.com/alecthomas/kong"
"github.com/gofiber/fiber/v2"
+ "github.com/gofiber/fiber/v2/middleware/cors"
"github.com/gofiber/fiber/v2/middleware/limiter"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
@@ -335,6 +336,18 @@
return c.Next()
})
+ // CORS middleware to allow cross-origin requests from trusted
+ // origins. Required because the service is designed to be
+ // called as a KorAP/Kalamar plugin from cross-origin iframes.
+ // Configurable via the "allowOrigins" YAML key or the
+ // KORAL_MAPPER_ALLOW_ORIGINS environment variable
+ // (default: "https://korap.ids-mannheim.de").
+ app.Use(cors.New(cors.Config{
+ AllowOrigins: yamlConfig.AllowOrigins,
+ AllowMethods: "GET,POST",
+ AllowHeaders: "Content-Type",
+ }))
+
// Rate limiting middleware to prevent resource exhaustion from
// request floods. The maximum number of requests per minute
// per IP is configurable via the "rateLimit" YAML key or the
diff --git a/cmd/koralmapper/main_test.go b/cmd/koralmapper/main_test.go
index 92318a3..fbabbb7 100644
--- a/cmd/koralmapper/main_test.go
+++ b/cmd/koralmapper/main_test.go
@@ -2742,3 +2742,251 @@
assert.Greater(t, idxA, idxZ, "mapper-a should appear after mapper-z")
assert.Greater(t, idxM, idxA, "mapper-m should appear after mapper-a")
}
+
+// TestCORSHeadersDefault checks that, with no AllowOrigins set on the config, a preflight OPTIONS and a plain
+// GET to /health from https://korap.ids-mannheim.de both get that origin back in Access-Control-Allow-Origin.
+func TestCORSHeadersDefault(t *testing.T) {
+ mappingList := tmconfig.MappingList{
+ ID: "test-mapper",
+ Mappings: []tmconfig.MappingRule{"[A] <> [B]"},
+ }
+
+ m, err := mapper.NewMapper([]tmconfig.MappingList{mappingList})
+ require.NoError(t, err)
+
+ mockConfig := &tmconfig.MappingConfig{Lists: []tmconfig.MappingList{mappingList}} // AllowOrigins is not set here
+ tmconfig.ApplyDefaults(mockConfig) // presumably fills in the default AllowOrigins; the assertions below rely on it
+
+ app := fiber.New()
+ setupRoutes(app, m, mockConfig)
+
+ // Preflight: OPTIONS plus Access-Control-Request-Method, sent from the default allowed origin.
+ req := httptest.NewRequest(http.MethodOptions, "/health", nil)
+ req.Header.Set("Origin", "https://korap.ids-mannheim.de")
+ req.Header.Set("Access-Control-Request-Method", "GET")
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer resp.Body.Close() // bound to the preflight response; resp is reassigned further down
+
+ assert.Equal(t, "https://korap.ids-mannheim.de",
+ resp.Header.Get("Access-Control-Allow-Origin"),
+ "default AllowOrigins should include korap.ids-mannheim.de")
+ assert.NotEmpty(t, resp.Header.Get("Access-Control-Allow-Methods")) // presence only; the exact list is not checked here
+
+ // A non-preflight GET from the same origin must carry the CORS origin header as well.
+ req = httptest.NewRequest(http.MethodGet, "/health", nil)
+ req.Header.Set("Origin", "https://korap.ids-mannheim.de")
+ resp, err = app.Test(req)
+ require.NoError(t, err)
+ defer resp.Body.Close() // closes the second (GET) response
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+ assert.Equal(t, "https://korap.ids-mannheim.de",
+ resp.Header.Get("Access-Control-Allow-Origin"))
+}
+
+// TestCORSHeadersCustomOrigin checks that an explicitly configured AllowOrigins value is honoured: a GET to
+// /health from https://custom.example.com must get that same origin back in Access-Control-Allow-Origin.
+func TestCORSHeadersCustomOrigin(t *testing.T) {
+ mappingList := tmconfig.MappingList{
+ ID: "test-mapper",
+ Mappings: []tmconfig.MappingRule{"[A] <> [B]"},
+ }
+
+ m, err := mapper.NewMapper([]tmconfig.MappingList{mappingList})
+ require.NoError(t, err)
+
+ mockConfig := &tmconfig.MappingConfig{
+ AllowOrigins: "https://custom.example.com", // the single origin this app instance should accept
+ Lists: []tmconfig.MappingList{mappingList},
+ }
+ tmconfig.ApplyDefaults(mockConfig) // presumably leaves an explicitly set AllowOrigins untouched; asserted below
+
+ app := fiber.New()
+ setupRoutes(app, m, mockConfig)
+
+ req := httptest.NewRequest(http.MethodGet, "/health", nil)
+ req.Header.Set("Origin", "https://custom.example.com") // same value as the configured AllowOrigins
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+ assert.Equal(t, "https://custom.example.com",
+ resp.Header.Get("Access-Control-Allow-Origin"),
+ "custom AllowOrigins should be reflected in the response")
+}
+
+// TestCORSRejectsDisallowedOrigin configures https://allowed.example.com as the only allowed origin and
+// sends a GET to /health from https://evil.example.com. The response must carry no
+// Access-Control-Allow-Origin header; the HTTP status code itself is not asserted.
+func TestCORSRejectsDisallowedOrigin(t *testing.T) {
+ mappingList := tmconfig.MappingList{
+ ID: "test-mapper",
+ Mappings: []tmconfig.MappingRule{"[A] <> [B]"},
+ }
+
+ m, err := mapper.NewMapper([]tmconfig.MappingList{mappingList})
+ require.NoError(t, err)
+
+ mockConfig := &tmconfig.MappingConfig{
+ AllowOrigins: "https://allowed.example.com", // the only origin on the allow list
+ Lists: []tmconfig.MappingList{mappingList},
+ }
+ tmconfig.ApplyDefaults(mockConfig)
+
+ app := fiber.New()
+ setupRoutes(app, m, mockConfig)
+
+ req := httptest.NewRequest(http.MethodGet, "/health", nil)
+ req.Header.Set("Origin", "https://evil.example.com") // differs from the configured origin
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ assert.Empty(t, resp.Header.Get("Access-Control-Allow-Origin"), // Header.Get returns "" when the header is absent
+ "disallowed origin must not receive Access-Control-Allow-Origin")
+}
+
+// TestCORSAllowsMultipleOrigins configures two comma-separated origins in AllowOrigins and checks, in one
+// subtest per origin, that a GET to /health from each of them gets its own origin echoed back.
+func TestCORSAllowsMultipleOrigins(t *testing.T) {
+ mappingList := tmconfig.MappingList{
+ ID: "test-mapper",
+ Mappings: []tmconfig.MappingRule{"[A] <> [B]"},
+ }
+
+ m, err := mapper.NewMapper([]tmconfig.MappingList{mappingList})
+ require.NoError(t, err)
+
+ mockConfig := &tmconfig.MappingConfig{
+ AllowOrigins: "https://first.example.com,https://second.example.com", // comma-separated, no spaces
+ Lists: []tmconfig.MappingList{mappingList},
+ }
+ tmconfig.ApplyDefaults(mockConfig)
+
+ app := fiber.New() // one app instance is shared by both subtests
+ setupRoutes(app, m, mockConfig)
+
+ for _, origin := range []string{
+ "https://first.example.com",
+ "https://second.example.com",
+ } {
+ t.Run(origin, func(t *testing.T) { // subtest is named after the origin; no t.Parallel, so iterations run one after another
+ req := httptest.NewRequest(http.MethodGet, "/health", nil)
+ req.Header.Set("Origin", origin)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer resp.Body.Close() // runs when this subtest function returns, not at the end of the loop
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+ assert.Equal(t, origin,
+ resp.Header.Get("Access-Control-Allow-Origin"),
+ "both configured origins should be accepted")
+ })
+ }
+}
+
+// TestCORSPreflightAllowedMethods verifies that a CORS preflight from the default allowed origin is answered
+// with an Access-Control-Allow-Methods header that lists exactly GET and POST, and no other method.
+func TestCORSPreflightAllowedMethods(t *testing.T) {
+ mappingList := tmconfig.MappingList{
+ ID: "test-mapper",
+ Mappings: []tmconfig.MappingRule{"[A] <> [B]"},
+ }
+
+ m, err := mapper.NewMapper([]tmconfig.MappingList{mappingList})
+ require.NoError(t, err)
+
+ mockConfig := &tmconfig.MappingConfig{Lists: []tmconfig.MappingList{mappingList}}
+ tmconfig.ApplyDefaults(mockConfig) // presumably fills in the default AllowOrigins used as Origin below
+
+ app := fiber.New()
+ setupRoutes(app, m, mockConfig)
+
+ req := httptest.NewRequest(http.MethodOptions, "/test-mapper/query", nil)
+ req.Header.Set("Origin", "https://korap.ids-mannheim.de")
+ req.Header.Set("Access-Control-Request-Method", "POST") // marks the OPTIONS request as a CORS preflight
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ allowedMethods := resp.Header.Get("Access-Control-Allow-Methods")
+ assert.Equal(t, "GET,POST", allowedMethods, // exact match: Contains would also accept "GET,POST,PUT,DELETE"
+ "preflight must advertise exactly GET and POST")
+}
+
+// TestCORSPreflightAllowedHeaders sends a preflight for a POST with a Content-Type header to /test-mapper/query
+// and checks that Access-Control-Allow-Headers in the response contains Content-Type (substring match only).
+func TestCORSPreflightAllowedHeaders(t *testing.T) {
+ mappingList := tmconfig.MappingList{
+ ID: "test-mapper",
+ Mappings: []tmconfig.MappingRule{"[A] <> [B]"},
+ }
+
+ m, err := mapper.NewMapper([]tmconfig.MappingList{mappingList})
+ require.NoError(t, err)
+
+ mockConfig := &tmconfig.MappingConfig{Lists: []tmconfig.MappingList{mappingList}}
+ tmconfig.ApplyDefaults(mockConfig) // presumably fills in the default AllowOrigins used as Origin below
+
+ app := fiber.New()
+ setupRoutes(app, m, mockConfig)
+
+ req := httptest.NewRequest(http.MethodOptions, "/test-mapper/query", nil)
+ req.Header.Set("Origin", "https://korap.ids-mannheim.de")
+ req.Header.Set("Access-Control-Request-Method", "POST")
+ req.Header.Set("Access-Control-Request-Headers", "Content-Type") // the header the simulated browser asks permission for
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ allowedHeaders := resp.Header.Get("Access-Control-Allow-Headers")
+ assert.Contains(t, allowedHeaders, "Content-Type") // does not rule out additional allowed headers
+}
+
+// TestCORSOnTransformEndpoint posts a KoralQuery token to /test-mapper/query?dir=atob from the origin
+// https://korap.ids-mannheim.de and checks for a 200 response that echoes that origin in
+// Access-Control-Allow-Origin, so CORS coverage is not limited to the /health route.
+func TestCORSOnTransformEndpoint(t *testing.T) {
+ cfg := loadConfigFromYAML(t, `
+lists:
+ - id: test-mapper
+ foundryA: opennlp
+ layerA: p
+ foundryB: upos
+ layerB: p
+ mappings:
+ - "[PIDAT] <> [DET]"
+`) // no allowOrigins key; presumably loadConfigFromYAML applies the default, which the assertion below expects
+
+ m, err := mapper.NewMapper(cfg.Lists)
+ require.NoError(t, err)
+
+ app := fiber.New()
+ setupRoutes(app, m, cfg)
+
+ input := `{
+ "@type": "koral:token",
+ "wrap": {
+ "@type": "koral:term",
+ "foundry": "opennlp",
+ "key": "PIDAT",
+ "layer": "p",
+ "match": "match:eq"
+ }
+ }` // request body: one opennlp/p term with key PIDAT, the left-hand side of the mapping rule above
+
+ req := httptest.NewRequest(http.MethodPost, "/test-mapper/query?dir=atob",
+ bytes.NewBufferString(input))
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Origin", "https://korap.ids-mannheim.de")
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode) // the transformed body itself is not inspected in this test
+ assert.Equal(t, "https://korap.ids-mannheim.de",
+ resp.Header.Get("Access-Control-Allow-Origin"),
+ "transform endpoints should include CORS headers")
+}