Balance annotation worker backpressure
Change-Id: I93db9fcd7840c633c76d808cc241eccf12e63324
diff --git a/app/src/main/kotlin/de/ids_mannheim/korapxmltools/AnnotationWorkerPool.kt b/app/src/main/kotlin/de/ids_mannheim/korapxmltools/AnnotationWorkerPool.kt
index 8736ba7..75c3ed4 100644
--- a/app/src/main/kotlin/de/ids_mannheim/korapxmltools/AnnotationWorkerPool.kt
+++ b/app/src/main/kotlin/de/ids_mannheim/korapxmltools/AnnotationWorkerPool.kt
@@ -6,21 +6,77 @@
import java.lang.Thread.sleep
import java.util.concurrent.BlockingQueue
import java.util.concurrent.LinkedBlockingQueue
+import java.util.concurrent.Semaphore
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger
import java.util.logging.Logger
private const val BUFFER_SIZE = 10000000
private const val HIGH_WATERMARK = 1000000
+private const val DEFAULT_BUFFER_HEAP_DIVISOR = 16L
+private const val MIN_BUFFER_BYTES = 256L * 1024 * 1024
+private const val MAX_BUFFER_BYTES = 8L * 1024 * 1024 * 1024
+private const val APPROX_BYTES_PER_BUFFER_UNIT = 2L * HIGH_WATERMARK
+
+internal fun parseKorapXmlToolXmxToBytes(spec: String?): Long? {
+ if (spec.isNullOrBlank()) return null
+ val trimmed = spec.trim()
+ val match = Regex("""^(\d+)([kKmMgGtT]?)$""").matchEntire(trimmed) ?: return null
+ val amount = match.groupValues[1].toLongOrNull() ?: return null
+ return when (match.groupValues[2].lowercase()) {
+ "" -> amount * 1024 * 1024
+ "k" -> amount * 1024
+ "m" -> amount * 1024 * 1024
+ "g" -> amount * 1024 * 1024 * 1024
+ "t" -> amount * 1024 * 1024 * 1024 * 1024
+ else -> null
+ }
+}
+
+internal fun annotationWorkerHeapBudgetBytes(
+ env: Map<String, String> = System.getenv(),
+ runtimeMaxBytes: Long = Runtime.getRuntime().maxMemory()
+): Long {
+ val envXmxBytes = parseKorapXmlToolXmxToBytes(env["KORAPXMLTOOL_XMX"])
+ val positiveRuntimeMaxBytes = runtimeMaxBytes.takeIf { it > 0 } ?: (4L * 1024 * 1024 * 1024)
+ return envXmxBytes?.coerceAtMost(positiveRuntimeMaxBytes) ?: positiveRuntimeMaxBytes
+}
+
+internal fun defaultBufferedTaskUnits(
+ numWorkers: Int,
+ env: Map<String, String> = System.getenv(),
+ runtimeMaxBytes: Long = Runtime.getRuntime().maxMemory()
+): Int {
+ val heapBudgetBytes = annotationWorkerHeapBudgetBytes(env, runtimeMaxBytes)
+ val targetBufferBytes = (heapBudgetBytes / DEFAULT_BUFFER_HEAP_DIVISOR)
+ .coerceIn(MIN_BUFFER_BYTES, MAX_BUFFER_BYTES)
+ val heapBasedUnits = (targetBufferBytes / APPROX_BYTES_PER_BUFFER_UNIT).toInt().coerceAtLeast(1)
+ return maxOf(heapBasedUnits, maxOf(numWorkers * 32, 128))
+}
+
+internal fun defaultQueuedTasks(numWorkers: Int, maxBufferedTaskUnits: Int): Int {
+ return maxOf(numWorkers * 64, minOf(maxBufferedTaskUnits * 2, 4096), 256)
+}
class AnnotationWorkerPool(
private val command: String,
private val numWorkers: Int,
private val LOGGER: Logger,
private val outputHandler: ((String, AnnotationTask?) -> Unit)? = null,
- private val stderrLogPath: String? = null
+ private val stderrLogPath: String? = null,
+ // Bound buffered task text globally by approximate character count, not by document count.
+ // Size the budget from KORAPXMLTOOL_XMX or the JVM max heap so large heaps can keep more
+ // annotation workers busy without returning to unbounded buffering.
+ private val maxBufferedTaskUnits: Int = defaultBufferedTaskUnits(numWorkers),
+ private val maxQueuedTasks: Int = defaultQueuedTasks(numWorkers, maxBufferedTaskUnits)
) {
- private val queue: BlockingQueue<AnnotationTask> = LinkedBlockingQueue()
+ init {
+ require(maxQueuedTasks > 0) { "maxQueuedTasks must be at least 1" }
+ require(maxBufferedTaskUnits > 0) { "maxBufferedTaskUnits must be at least 1" }
+ }
+
+ private val queue: BlockingQueue<AnnotationTask> = LinkedBlockingQueue(maxQueuedTasks)
+ private val bufferedTaskPermits = Semaphore(maxBufferedTaskUnits, true)
private val threads = mutableListOf<Thread>()
private val threadCount = AtomicInteger(0)
private val threadsLock = Any()
@@ -34,11 +90,39 @@
null
}
- data class AnnotationTask(val text: String, val docId: String?, val entryPath: String?)
+ data class AnnotationTask(
+ val text: String,
+ val docId: String?,
+ val entryPath: String?,
+ val bufferedUnits: Int = 0
+ )
init {
openWorkerPool()
- LOGGER.info("Annotation worker pool with ${numWorkers} threads opened")
+ LOGGER.info(
+ "Annotation worker pool with ${numWorkers} threads opened " +
+ "(queueCapacity=$maxQueuedTasks, bufferedTaskUnits=$maxBufferedTaskUnits, " +
+ "heapBudgetMB=${annotationWorkerHeapBudgetBytes() / (1024 * 1024)})"
+ )
+ }
+
+ private fun unitsForText(text: String): Int {
+ if (text == "#eof") return 0
+ return maxOf(1, (text.length + HIGH_WATERMARK - 1) / HIGH_WATERMARK)
+ }
+
+ private fun newTask(text: String, docId: String?, entryPath: String?): AnnotationTask {
+ val bufferedUnits = unitsForText(text)
+ if (bufferedUnits > 0) {
+ bufferedTaskPermits.acquire(bufferedUnits)
+ }
+ return AnnotationTask(text, docId, entryPath, bufferedUnits)
+ }
+
+ private fun releaseTaskBuffer(task: AnnotationTask?) {
+ if (task != null && task.bufferedUnits > 0) {
+ bufferedTaskPermits.release(task.bufferedUnits)
+ }
}
private fun openWorkerPool() {
@@ -76,7 +160,7 @@
return@Thread // Exits thread, finally block will run
}
- // Declare pendingTasks here so it's accessible after process exits
+ // pendingTasks tracks tasks already sent to the external process and awaiting output
val pendingTasks: BlockingQueue<AnnotationTask> = LinkedBlockingQueue()
// Using try-with-resources for streams to ensure they are closed
@@ -113,17 +197,18 @@
LOGGER.info("Worker $workerIndex (thread ${self.threadId()}) sent EOF to process and writer is stopping.")
break // Exit while loop
}
- pendingTasks.put(task)
try {
- val trimmed = task.text.trimEnd()
- val dataToSend = if (trimmed.isEmpty()) {
- "# eot\n"
- } else {
- trimmed + "\n\n# eot\n"
+ pendingTasks.put(task)
+ LOGGER.fine("Worker $workerIndex: Sending ${task.text.length} chars to external process")
+ LOGGER.finer("Worker $workerIndex: First 500 chars of data to send:\n${task.text.take(500)}")
+ if (task.text.isNotEmpty()) {
+ outputStreamWriter.write(task.text)
+ if (!task.text.endsWith('\n')) {
+ outputStreamWriter.write('\n'.code)
+ }
+ outputStreamWriter.write('\n'.code)
}
- LOGGER.fine("Worker $workerIndex: Sending ${dataToSend.length} chars to external process")
- LOGGER.finer("Worker $workerIndex: First 500 chars of data to send:\n${dataToSend.take(500)}")
- outputStreamWriter.write(dataToSend)
+ outputStreamWriter.write("# eot\n")
outputStreamWriter.flush()
LOGGER.fine("Worker $workerIndex: Data sent and flushed")
} catch (e: IOException) {
@@ -176,9 +261,11 @@
outputHandler.invoke(output.toString(), task)
} finally {
pendingOutputHandlers.decrementAndGet()
+ releaseTaskBuffer(task)
}
} else {
printOutput(output.toString())
+ releaseTaskBuffer(task)
}
output.clear()
}
@@ -196,10 +283,12 @@
outputHandler.invoke(output.toString(), task)
} finally {
pendingOutputHandlers.decrementAndGet()
+ releaseTaskBuffer(task)
}
} else {
LOGGER.fine("Worker $workerIndex: Printing output (${output.length} chars)")
printOutput(output.toString())
+ releaseTaskBuffer(task)
}
output.clear()
lastLineWasEmpty = false
@@ -220,6 +309,7 @@
outputHandler.invoke(output.toString(), task)
} finally {
pendingOutputHandlers.decrementAndGet()
+ releaseTaskBuffer(task)
}
output.clear()
lastLineWasEmpty = false
@@ -252,9 +342,11 @@
outputHandler.invoke(output.toString(), task)
} finally {
pendingOutputHandlers.decrementAndGet()
+ releaseTaskBuffer(task)
}
} else {
printOutput(output.toString())
+ releaseTaskBuffer(task)
}
}
} catch (e: Exception) {
@@ -375,21 +467,30 @@
}
fun pushToQueue(text: String, docId: String? = null, entryPath: String? = null) {
+ var task: AnnotationTask? = null
try {
- LOGGER.fine("pushToQueue called: text length=${text.length}, docId=$docId, entryPath=$entryPath")
- queue.put(AnnotationTask(text, docId, entryPath))
+ task = newTask(text, docId, entryPath)
+ LOGGER.fine(
+ "pushToQueue called: text length=${text.length}, docId=$docId, entryPath=$entryPath, " +
+ "queueSize=${queue.size}/$maxQueuedTasks, buffered=${maxBufferedTaskUnits - bufferedTaskPermits.availablePermits()}/$maxBufferedTaskUnits"
+ )
+ queue.put(task)
} catch (e: InterruptedException) {
Thread.currentThread().interrupt()
+ releaseTaskBuffer(task)
LOGGER.warning("Interrupted while trying to push text to queue.")
}
}
fun pushToQueue(texts: List<String>) {
texts.forEach { text ->
+ var task: AnnotationTask? = null
try {
- queue.put(AnnotationTask(text, null, null))
+ task = newTask(text, null, null)
+ queue.put(task)
} catch (e: InterruptedException) {
Thread.currentThread().interrupt()
+ releaseTaskBuffer(task)
LOGGER.warning("Interrupted while trying to push texts to queue. Some texts may not have been added.")
return // Exit early if interrupted
}
@@ -405,7 +506,7 @@
// to ensure we send enough EOF markers even if some threads haven't started yet
for (i in 0 until numWorkers) {
try {
- queue.put(AnnotationTask("#eof", null, null))
+ queue.put(AnnotationTask("#eof", null, null, 0))
LOGGER.info("Sent EOF marker ${i+1}/$numWorkers to queue")
} catch (e: InterruptedException) {
Thread.currentThread().interrupt()
diff --git a/app/src/main/kotlin/de/ids_mannheim/korapxmltools/KorapXmlTool.kt b/app/src/main/kotlin/de/ids_mannheim/korapxmltools/KorapXmlTool.kt
index e7583d7..67e26de 100644
--- a/app/src/main/kotlin/de/ids_mannheim/korapxmltools/KorapXmlTool.kt
+++ b/app/src/main/kotlin/de/ids_mannheim/korapxmltools/KorapXmlTool.kt
@@ -4015,7 +4015,7 @@
LOGGER.fine("parseAndWriteAnnotatedConllu called with ${annotatedConllu.length} chars, task=$task")
val docId = task?.docId
- val entryPathAndFoundry = task?.entryPath?.split("|") ?: listOf(null, null)
+ val entryPathAndFoundry = task?.entryPath?.split('|', limit = 2) ?: listOf(null, null)
val entryPath = entryPathAndFoundry.getOrNull(0)
val foundry = entryPathAndFoundry.getOrNull(1) ?: "base"
@@ -4025,7 +4025,6 @@
}
val morphoSpans = mutableMapOf<String, MorphoSpan>()
- val lines = annotatedConllu.lines()
var currentStartOffsets: List<Int>? = null
var currentEndOffsets: List<Int>? = null
var tokenIndexInSentence = 0
@@ -4034,7 +4033,7 @@
var sentenceEndOffset: Int? = null
var extractedFoundry: String? = null
- for (line in lines) {
+ for (line in annotatedConllu.lineSequence()) {
when {
line.startsWith("# foundry =") -> {
val foundryStr = line.substring("# foundry =".length).trim()
@@ -5846,4 +5845,3 @@
try { Locale.setDefault(Locale.ROOT) } catch (_: Exception) {}
return CommandLine(KorapXmlTool()).execute(*args)
}
-
diff --git a/app/src/test/kotlin/de/ids_mannheim/korapxmltools/AnnotationWorkerPoolTest.kt b/app/src/test/kotlin/de/ids_mannheim/korapxmltools/AnnotationWorkerPoolTest.kt
new file mode 100644
index 0000000..156a0fb
--- /dev/null
+++ b/app/src/test/kotlin/de/ids_mannheim/korapxmltools/AnnotationWorkerPoolTest.kt
@@ -0,0 +1,62 @@
+package de.ids_mannheim.korapxmltools
+
+import org.junit.Test
+import java.util.concurrent.CountDownLatch
+import java.util.concurrent.TimeUnit
+import java.util.concurrent.atomic.AtomicBoolean
+import java.util.logging.Logger
+import kotlin.test.assertFalse
+import kotlin.test.assertTrue
+import kotlin.test.assertEquals
+
+class AnnotationWorkerPoolTest {
+
+ @Test
+ fun defaultBufferedTaskUnitsScaleWithKorapXmlToolXmx() {
+ val smallHeapUnits = defaultBufferedTaskUnits(
+ numWorkers = 16,
+ env = mapOf("KORAPXMLTOOL_XMX" to "4g"),
+ runtimeMaxBytes = 4L * 1024 * 1024 * 1024
+ )
+ val largeHeapUnits = defaultBufferedTaskUnits(
+ numWorkers = 16,
+ env = mapOf("KORAPXMLTOOL_XMX" to "100g"),
+ runtimeMaxBytes = 100L * 1024 * 1024 * 1024
+ )
+
+ assertEquals(512, smallHeapUnits, "4g heap should keep enough buffered work for 16 workers")
+ assertTrue(largeHeapUnits >= 3200, "100g heap should allow a much larger buffered-text budget")
+ assertTrue(largeHeapUnits > smallHeapUnits, "Larger KORAPXMLTOOL_XMX should increase buffered task units")
+ assertEquals(4096, defaultQueuedTasks(16, largeHeapUnits), "Queue size should scale up but stay capped")
+ }
+
+ @Test
+ fun pushToQueueBlocksWhenQueueCapacityIsReached() {
+ val pool = AnnotationWorkerPool(
+ command = "cat",
+ numWorkers = 0,
+ LOGGER = Logger.getLogger("AnnotationWorkerPoolTest"),
+ maxQueuedTasks = 1
+ )
+
+ val firstQueued = CountDownLatch(1)
+ val secondQueued = AtomicBoolean(false)
+
+ val producer = Thread {
+ pool.pushToQueue("first")
+ firstQueued.countDown()
+ pool.pushToQueue("second")
+ secondQueued.set(true)
+ }
+
+ producer.start()
+
+ assertTrue(firstQueued.await(2, TimeUnit.SECONDS), "First task should be queued quickly")
+ Thread.sleep(200)
+ assertFalse(secondQueued.get(), "Second task should block once the bounded queue is full")
+
+ producer.interrupt()
+ producer.join(2000)
+ assertFalse(producer.isAlive, "Producer thread should stop after interruption")
+ }
+}