Implemented rate limit for authenticated users
[AI assisted]
Change-Id: Iebd249c8f9d4ef6ede4f2acb2b651d63d3a43fe0
diff --git a/Changes b/Changes
index a3b2545..e3672a4 100644
--- a/Changes
+++ b/Changes
@@ -6,10 +6,11 @@
- Change Userdata to use String username instead of integer userId
- Allow admin to create groups with name length less than 3 characters
to support existing groups from C2
+- Implemented rate limit for authenticated users (with AI assistance)
# version 1.1
-- Change VC and query names to lowercase
+- Change VC and query names to lowercase (with AI assistance)
- Implemented pipe response rewriting for match info API (#814)
- Add pipe response for the metadata web-service.
diff --git a/src/main/java/de/ids_mannheim/korap/core/web/controller/SearchController.java b/src/main/java/de/ids_mannheim/korap/core/web/controller/SearchController.java
index 45010a7..4692cd8 100644
--- a/src/main/java/de/ids_mannheim/korap/core/web/controller/SearchController.java
+++ b/src/main/java/de/ids_mannheim/korap/core/web/controller/SearchController.java
@@ -28,6 +28,7 @@
import de.ids_mannheim.korap.web.filter.APIDeprecationFilter;
import de.ids_mannheim.korap.web.filter.APIVersionFilter;
import de.ids_mannheim.korap.web.filter.AdminFilter;
+import de.ids_mannheim.korap.web.filter.RateLimitFilter;
import de.ids_mannheim.korap.web.filter.AuthenticationFilter;
import de.ids_mannheim.korap.web.filter.DemoUserFilter;
import de.ids_mannheim.korap.web.utils.ResourceFilters;
@@ -58,7 +59,7 @@
@Controller
@Path("/")
@ResourceFilters({ APIVersionFilter.class, AuthenticationFilter.class,
- DemoUserFilter.class})
+ DemoUserFilter.class, RateLimitFilter.class })
public class SearchController {
private static final boolean DEBUG = false;
diff --git a/src/main/java/de/ids_mannheim/korap/web/filter/RateLimitFilter.java b/src/main/java/de/ids_mannheim/korap/web/filter/RateLimitFilter.java
new file mode 100644
index 0000000..184c233
--- /dev/null
+++ b/src/main/java/de/ids_mannheim/korap/web/filter/RateLimitFilter.java
@@ -0,0 +1,211 @@
+package de.ids_mannheim.korap.web.filter;
+
+import java.nio.charset.StandardCharsets;
+import java.security.MessageDigest;
+import java.time.Duration;
+import java.util.Base64;
+import java.util.Objects;
+import java.util.concurrent.ConcurrentHashMap;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.springframework.stereotype.Component;
+
+import jakarta.annotation.Priority;
+import jakarta.ws.rs.Priorities;
+import jakarta.ws.rs.WebApplicationException;
+import jakarta.ws.rs.container.ContainerRequestContext;
+import jakarta.ws.rs.container.ContainerRequestFilter;
+import jakarta.ws.rs.core.HttpHeaders;
+import jakarta.ws.rs.core.Response;
+
+/** Implemented with AI assistance
+ *
+ * Simple in-memory rate limitation for authenticated users.
+ * <p>
+ * Keyed by bearer token (preferred) or username (fallback).
+ * <p>
+ * Note: In-memory means per-JVM only. For clustered deployments, use Redis/etc.
+ */
+@Component
+@Priority(Priorities.AUTHORIZATION)
+public class RateLimitFilter implements ContainerRequestFilter {
+
+ private static final Logger jlog = LogManager
+ .getLogger(RateLimitFilter.class);
+
+ // Defaults: 60 requests per minute per key
+ // Keep these conservative and easy to change later via config injection.
+ private static final long REFILL_TOKENS = 60;
+ private static final Duration REFILL_PERIOD = Duration.ofMinutes(1);
+ public static final long BURST_CAPACITY = 60;
+
+ /**
+ * Prevent unbounded growth: keep at most this many distinct keys in-memory.
+ */
+ private static final int MAX_BUCKETS = 10_000;
+
+ /**
+ * Evict buckets that haven't been seen for this long.
+ */
+ private static final Duration BUCKET_TTL = Duration.ofHours(6);
+
+ private final ConcurrentHashMap<String, BucketEntry> buckets = new ConcurrentHashMap<>();
+
+ @Override
+ public void filter (ContainerRequestContext request) {
+ // Only apply to authenticated requests
+ if (request.getSecurityContext() == null
+ || request.getSecurityContext().getUserPrincipal() == null) {
+ return;
+ }
+
+ String key = resolveKey(request);
+ if (key == null) {
+ return;
+ }
+
+ long now = System.currentTimeMillis();
+
+ // Opportunistic cleanup to avoid memory growth.
+ // Do it only on inserts or if we grow too large.
+ if (buckets.size() > MAX_BUCKETS) {
+ cleanupOldEntries(now);
+ }
+
+ BucketEntry entry = buckets.compute(key, (k, existing) -> {
+ if (existing == null) {
+ // If we're still too large, try another cleanup pass before adding.
+ if (buckets.size() > MAX_BUCKETS) {
+ cleanupOldEntries(now);
+ }
+ return new BucketEntry(new TokenBucket(BURST_CAPACITY,
+ REFILL_TOKENS, REFILL_PERIOD.toMillis()), now);
+ }
+ existing.lastSeenAtMillis = now;
+ return existing;
+ });
+
+ if (!entry.bucket.tryConsume(1)) {
+ long retryAfterSeconds = Math
+ .max(1, entry.bucket.millisUntilNextToken() / 1000);
+
+ throw new WebApplicationException(Response.status(429)
+ .header("Retry-After", String.valueOf(retryAfterSeconds))
+ .entity("Rate limit exceeded")
+ .build());
+ }
+ }
+
+ private void cleanupOldEntries (long nowMillis) {
+ final long cutoff = nowMillis - BUCKET_TTL.toMillis();
+ buckets.entrySet().removeIf(e -> e.getValue().lastSeenAtMillis < cutoff);
+
+ // Still too big? Remove arbitrary entries (best-effort bound).
+ if (buckets.size() > MAX_BUCKETS) {
+ int toRemove = buckets.size() - MAX_BUCKETS;
+ for (String k : buckets.keySet()) {
+ buckets.remove(k);
+ if (--toRemove <= 0)
+ break;
+ }
+ }
+ }
+
+ private String resolveKey (ContainerRequestContext request) {
+ // Prefer bearer token if present
+ String authorization = request.getHeaderString(HttpHeaders.AUTHORIZATION);
+ if (authorization != null
+ && authorization.regionMatches(true, 0, "Bearer ", 0, 7)) {
+ String token = authorization.substring(7).trim();
+ if (!token.isEmpty()) {
+ return "bearer:" + shortHash(token);
+ }
+ }
+
+ // Fallback to username/principal name
+// String name = request.getSecurityContext().getUserPrincipal().getName();
+// if (name != null && !name.isBlank()) {
+// return "user:" + name;
+// }
+
+ return null;
+ }
+
+ private String shortHash (String token) {
+ try {
+ MessageDigest md = MessageDigest.getInstance("SHA-256");
+ byte[] digest = md.digest(token.getBytes(StandardCharsets.UTF_8));
+ // short stable identifier; never store raw token
+ String b64 = Base64.getUrlEncoder().withoutPadding()
+ .encodeToString(digest);
+ return b64.substring(0, Math.min(16, b64.length()));
+ }
+ catch (Exception e) {
+ // extremely unlikely; fallback to deterministic-ish hash
+ String fallback = Integer.toHexString(Objects.hashCode(token));
+ jlog.warn("Could not hash token securely, using fallback hash");
+ return fallback;
+ }
+ }
+
+ /**
+ * Minimal token bucket with lazy refill.
+ */
+ static final class TokenBucket {
+ private final long capacity;
+ private final long refillTokens;
+ private final long refillPeriodMillis;
+
+ private long tokens;
+ private long lastRefillAtMillis;
+
+ TokenBucket (long capacity, long refillTokens, long refillPeriodMillis) {
+ this.capacity = capacity;
+ this.refillTokens = refillTokens;
+ this.refillPeriodMillis = refillPeriodMillis;
+ this.tokens = capacity;
+ this.lastRefillAtMillis = System.currentTimeMillis();
+ }
+
+ synchronized boolean tryConsume (long n) {
+ refill();
+ if (tokens >= n) {
+ tokens -= n;
+ return true;
+ }
+ return false;
+ }
+
+ synchronized long millisUntilNextToken () {
+ refill();
+ if (tokens > 0)
+ return 0;
+ long now = System.currentTimeMillis();
+ long elapsed = now - lastRefillAtMillis;
+ return Math.max(0, refillPeriodMillis - elapsed);
+ }
+
+ private void refill () {
+ long now = System.currentTimeMillis();
+ long elapsed = now - lastRefillAtMillis;
+ if (elapsed < refillPeriodMillis)
+ return;
+
+ long periods = elapsed / refillPeriodMillis;
+ long add = periods * refillTokens;
+ tokens = Math.min(capacity, tokens + add);
+ lastRefillAtMillis += periods * refillPeriodMillis;
+ }
+ }
+
+ private static final class BucketEntry {
+ final TokenBucket bucket;
+ volatile long lastSeenAtMillis;
+
+ BucketEntry (TokenBucket bucket, long lastSeenAtMillis) {
+ this.bucket = bucket;
+ this.lastSeenAtMillis = lastSeenAtMillis;
+ }
+ }
+}
diff --git a/src/test/java/de/ids_mannheim/korap/web/controller/RateLimitTest.java b/src/test/java/de/ids_mannheim/korap/web/controller/RateLimitTest.java
new file mode 100644
index 0000000..e60d278
--- /dev/null
+++ b/src/test/java/de/ids_mannheim/korap/web/controller/RateLimitTest.java
@@ -0,0 +1,41 @@
+package de.ids_mannheim.korap.web.controller;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import org.junit.jupiter.api.Test;
+
+import com.fasterxml.jackson.databind.JsonNode;
+
+import de.ids_mannheim.korap.exceptions.KustvaktException;
+import de.ids_mannheim.korap.utils.JsonUtils;
+import de.ids_mannheim.korap.web.controller.oauth2.OAuth2TestBase;
+import de.ids_mannheim.korap.web.filter.RateLimitFilter;
+import jakarta.ws.rs.core.Response;
+import jakarta.ws.rs.core.Response.Status;
+
+/**
+ * Verifies authenticated rate limiting (HTTP 429) is applied after
+ * auth.
+ */
+public class RateLimitTest extends OAuth2TestBase {
+
+ @Test
+ public void testAuthenticatedRateLimitBearerToken ()
+ throws KustvaktException {
+ Response response = requestTokenWithDoryPassword(superClientId,
+ clientSecret);
+ JsonNode node = JsonUtils.readTree(response.readEntity(String.class));
+ String accessToken = node.at("/access_token").asText();
+
+ for (long i = 0; i < RateLimitFilter.BURST_CAPACITY; i++) {
+ Response r = searchWithAccessToken(accessToken);
+ assertEquals(Status.OK.getStatusCode(), r.getStatus(),
+ "request " + i);
+ r.close();
+ }
+
+ Response limited = searchWithAccessToken(accessToken);
+ assertEquals(429, limited.getStatus());
+ limited.close();
+ }
+}
\ No newline at end of file
diff --git a/src/test/java/de/ids_mannheim/korap/web/lite/RateLimitAnonymousTest.java b/src/test/java/de/ids_mannheim/korap/web/lite/RateLimitAnonymousTest.java
new file mode 100644
index 0000000..f78c269
--- /dev/null
+++ b/src/test/java/de/ids_mannheim/korap/web/lite/RateLimitAnonymousTest.java
@@ -0,0 +1,29 @@
+package de.ids_mannheim.korap.web.lite;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import org.junit.jupiter.api.Test;
+
+import de.ids_mannheim.korap.config.LiteJerseyTest;
+import jakarta.ws.rs.core.Response;
+import jakarta.ws.rs.core.Response.Status;
+
+/**
+ * Verifies unauthenticated requests are not rate-limited.
+ */
+public class RateLimitAnonymousTest extends LiteJerseyTest {
+
+ @Test
+ public void testUnauthenticatedNotRateLimited () {
+ // No Authorization header: should remain unauthenticated and not be limited.
+ for (int i = 0; i < 80; i++) {
+ Response r = target().path(API_VERSION).path("search")
+ .queryParam("q", "[orth=das]")
+ .queryParam("ql", "poliqarp")
+ .request().get();
+ assertEquals(Status.OK.getStatusCode(), r.getStatus(),
+ "request " + i);
+ r.close();
+ }
+ }
+}
\ No newline at end of file