Balance annotation worker backpressure

Change-Id: I93db9fcd7840c633c76d808cc241eccf12e63324
diff --git a/app/src/main/kotlin/de/ids_mannheim/korapxmltools/AnnotationWorkerPool.kt b/app/src/main/kotlin/de/ids_mannheim/korapxmltools/AnnotationWorkerPool.kt
index 8736ba7..75c3ed4 100644
--- a/app/src/main/kotlin/de/ids_mannheim/korapxmltools/AnnotationWorkerPool.kt
+++ b/app/src/main/kotlin/de/ids_mannheim/korapxmltools/AnnotationWorkerPool.kt
@@ -6,21 +6,77 @@
 import java.lang.Thread.sleep
 import java.util.concurrent.BlockingQueue
 import java.util.concurrent.LinkedBlockingQueue
+import java.util.concurrent.Semaphore
 import java.util.concurrent.TimeUnit
 import java.util.concurrent.atomic.AtomicInteger
 import java.util.logging.Logger
 
 private const val BUFFER_SIZE = 10000000
 private const val HIGH_WATERMARK = 1000000
+private const val DEFAULT_BUFFER_HEAP_DIVISOR = 16L
+private const val MIN_BUFFER_BYTES = 256L * 1024 * 1024
+private const val MAX_BUFFER_BYTES = 8L * 1024 * 1024 * 1024
+private const val APPROX_BYTES_PER_BUFFER_UNIT = 2L * HIGH_WATERMARK
+
+internal fun parseKorapXmlToolXmxToBytes(spec: String?): Long? {
+    if (spec.isNullOrBlank()) return null
+    val trimmed = spec.trim()
+    val match = Regex("""^(\d+)([kKmMgGtT]?)$""").matchEntire(trimmed) ?: return null
+    val amount = match.groupValues[1].toLongOrNull() ?: return null
+    return when (match.groupValues[2].lowercase()) {
+        "" -> amount * 1024 * 1024
+        "k" -> amount * 1024
+        "m" -> amount * 1024 * 1024
+        "g" -> amount * 1024 * 1024 * 1024
+        "t" -> amount * 1024 * 1024 * 1024 * 1024
+        else -> null
+    }
+}
+
+internal fun annotationWorkerHeapBudgetBytes(
+    env: Map<String, String> = System.getenv(),
+    runtimeMaxBytes: Long = Runtime.getRuntime().maxMemory()
+): Long {
+    val envXmxBytes = parseKorapXmlToolXmxToBytes(env["KORAPXMLTOOL_XMX"])
+    val positiveRuntimeMaxBytes = runtimeMaxBytes.takeIf { it > 0 } ?: (4L * 1024 * 1024 * 1024)
+    return envXmxBytes?.coerceAtMost(positiveRuntimeMaxBytes) ?: positiveRuntimeMaxBytes
+}
+
+internal fun defaultBufferedTaskUnits(
+    numWorkers: Int,
+    env: Map<String, String> = System.getenv(),
+    runtimeMaxBytes: Long = Runtime.getRuntime().maxMemory()
+): Int {
+    val heapBudgetBytes = annotationWorkerHeapBudgetBytes(env, runtimeMaxBytes)
+    val targetBufferBytes = (heapBudgetBytes / DEFAULT_BUFFER_HEAP_DIVISOR)
+        .coerceIn(MIN_BUFFER_BYTES, MAX_BUFFER_BYTES)
+    val heapBasedUnits = (targetBufferBytes / APPROX_BYTES_PER_BUFFER_UNIT).toInt().coerceAtLeast(1)
+    return maxOf(heapBasedUnits, maxOf(numWorkers * 32, 128))
+}
+
+internal fun defaultQueuedTasks(numWorkers: Int, maxBufferedTaskUnits: Int): Int {
+    return maxOf(numWorkers * 64, minOf(maxBufferedTaskUnits * 2, 4096), 256)
+}
 
 class AnnotationWorkerPool(
     private val command: String,
     private val numWorkers: Int,
     private val LOGGER: Logger,
     private val outputHandler: ((String, AnnotationTask?) -> Unit)? = null,
-    private val stderrLogPath: String? = null
+    private val stderrLogPath: String? = null,
+    // Bound buffered task text globally by approximate character count, not by document count.
+    // Size the budget from KORAPXMLTOOL_XMX or the JVM max heap so large heaps can keep more
+    // annotation workers busy without returning to unbounded buffering.
+    private val maxBufferedTaskUnits: Int = defaultBufferedTaskUnits(numWorkers),
+    private val maxQueuedTasks: Int = defaultQueuedTasks(numWorkers, maxBufferedTaskUnits)
 ) {
-    private val queue: BlockingQueue<AnnotationTask> = LinkedBlockingQueue()
+    init {
+        require(maxQueuedTasks > 0) { "maxQueuedTasks must be at least 1" }
+        require(maxBufferedTaskUnits > 0) { "maxBufferedTaskUnits must be at least 1" }
+    }
+
+    private val queue: BlockingQueue<AnnotationTask> = LinkedBlockingQueue(maxQueuedTasks)
+    private val bufferedTaskPermits = Semaphore(maxBufferedTaskUnits, true)
     private val threads = mutableListOf<Thread>()
     private val threadCount = AtomicInteger(0)
     private val threadsLock = Any()
@@ -34,11 +90,39 @@
         null
     }
 
-    data class AnnotationTask(val text: String, val docId: String?, val entryPath: String?)
+    data class AnnotationTask(
+        val text: String,
+        val docId: String?,
+        val entryPath: String?,
+        val bufferedUnits: Int = 0
+    )
 
     init {
         openWorkerPool()
-        LOGGER.info("Annotation worker pool with ${numWorkers} threads opened")
+        LOGGER.info(
+            "Annotation worker pool with ${numWorkers} threads opened " +
+                "(queueCapacity=$maxQueuedTasks, bufferedTaskUnits=$maxBufferedTaskUnits, " +
+                "heapBudgetMB=${annotationWorkerHeapBudgetBytes() / (1024 * 1024)})"
+        )
+    }
+
+    private fun unitsForText(text: String): Int {
+        if (text == "#eof") return 0
+        return maxOf(1, (text.length + HIGH_WATERMARK - 1) / HIGH_WATERMARK)
+    }
+
+    private fun newTask(text: String, docId: String?, entryPath: String?): AnnotationTask {
+        val bufferedUnits = unitsForText(text)
+        if (bufferedUnits > 0) {
+            bufferedTaskPermits.acquire(bufferedUnits)
+        }
+        return AnnotationTask(text, docId, entryPath, bufferedUnits)
+    }
+
+    private fun releaseTaskBuffer(task: AnnotationTask?) {
+        if (task != null && task.bufferedUnits > 0) {
+            bufferedTaskPermits.release(task.bufferedUnits)
+        }
     }
 
     private fun openWorkerPool() {
@@ -76,7 +160,7 @@
                         return@Thread // Exits thread, finally block will run
                     }
 
-                    // Declare pendingTasks here so it's accessible after process exits
+                    // pendingTasks tracks tasks already sent to the external process and awaiting output
                     val pendingTasks: BlockingQueue<AnnotationTask> = LinkedBlockingQueue()
 
                     // Using try-with-resources for streams to ensure they are closed
@@ -113,17 +197,18 @@
                                             LOGGER.info("Worker $workerIndex (thread ${self.threadId()}) sent EOF to process and writer is stopping.")
                                             break // Exit while loop
                                         }
-                                        pendingTasks.put(task)
                                         try {
-                                            val trimmed = task.text.trimEnd()
-                                            val dataToSend = if (trimmed.isEmpty()) {
-                                                "# eot\n"
-                                            } else {
-                                                trimmed + "\n\n# eot\n"
+                                            pendingTasks.put(task)
+                                            LOGGER.fine("Worker $workerIndex: Sending ${task.text.length} chars to external process")
+                                            LOGGER.finer("Worker $workerIndex: First 500 chars of data to send:\n${task.text.take(500)}")
+                                            if (task.text.isNotEmpty()) {
+                                                outputStreamWriter.write(task.text)
+                                                if (!task.text.endsWith('\n')) {
+                                                    outputStreamWriter.write('\n'.code)
+                                                }
+                                                outputStreamWriter.write('\n'.code)
                                             }
-                                            LOGGER.fine("Worker $workerIndex: Sending ${dataToSend.length} chars to external process")
-                                            LOGGER.finer("Worker $workerIndex: First 500 chars of data to send:\n${dataToSend.take(500)}")
-                                            outputStreamWriter.write(dataToSend)
+                                            outputStreamWriter.write("# eot\n")
                                             outputStreamWriter.flush()
                                             LOGGER.fine("Worker $workerIndex: Data sent and flushed")
                                         } catch (e: IOException) {
@@ -176,9 +261,11 @@
                                                                 outputHandler.invoke(output.toString(), task)
                                                             } finally {
                                                                 pendingOutputHandlers.decrementAndGet()
+                                                                releaseTaskBuffer(task)
                                                             }
                                                         } else {
                                                             printOutput(output.toString())
+                                                            releaseTaskBuffer(task)
                                                         }
                                                         output.clear()
                                                     }
@@ -196,10 +283,12 @@
                                                             outputHandler.invoke(output.toString(), task)
                                                         } finally {
                                                             pendingOutputHandlers.decrementAndGet()
+                                                            releaseTaskBuffer(task)
                                                         }
                                                     } else {
                                                         LOGGER.fine("Worker $workerIndex: Printing output (${output.length} chars)")
                                                         printOutput(output.toString())
+                                                        releaseTaskBuffer(task)
                                                     }
                                                     output.clear()
                                                     lastLineWasEmpty = false
@@ -220,6 +309,7 @@
                                                                 outputHandler.invoke(output.toString(), task)
                                                             } finally {
                                                                 pendingOutputHandlers.decrementAndGet()
+                                                                releaseTaskBuffer(task)
                                                             }
                                                             output.clear()
                                                             lastLineWasEmpty = false
@@ -252,9 +342,11 @@
                                                 outputHandler.invoke(output.toString(), task)
                                             } finally {
                                                 pendingOutputHandlers.decrementAndGet()
+                                                releaseTaskBuffer(task)
                                             }
                                         } else {
                                             printOutput(output.toString())
+                                            releaseTaskBuffer(task)
                                         }
                                     }
                                 } catch (e: Exception) {
@@ -375,21 +467,30 @@
     }
 
     fun pushToQueue(text: String, docId: String? = null, entryPath: String? = null) {
+        var task: AnnotationTask? = null
         try {
-            LOGGER.fine("pushToQueue called: text length=${text.length}, docId=$docId, entryPath=$entryPath")
-            queue.put(AnnotationTask(text, docId, entryPath))
+            task = newTask(text, docId, entryPath)
+            LOGGER.fine(
+                "pushToQueue called: text length=${text.length}, docId=$docId, entryPath=$entryPath, " +
+                    "queueSize=${queue.size}/$maxQueuedTasks, buffered=${maxBufferedTaskUnits - bufferedTaskPermits.availablePermits()}/$maxBufferedTaskUnits"
+            )
+            queue.put(task)
         } catch (e: InterruptedException) {
             Thread.currentThread().interrupt()
+            releaseTaskBuffer(task)
             LOGGER.warning("Interrupted while trying to push text to queue.")
         }
     }
 
     fun pushToQueue(texts: List<String>) {
         texts.forEach { text ->
+            var task: AnnotationTask? = null
             try {
-                queue.put(AnnotationTask(text, null, null))
+                task = newTask(text, null, null)
+                queue.put(task)
             } catch (e: InterruptedException) {
                 Thread.currentThread().interrupt()
+                releaseTaskBuffer(task)
                 LOGGER.warning("Interrupted while trying to push texts to queue. Some texts may not have been added.")
                 return // Exit early if interrupted
             }
@@ -405,7 +506,7 @@
         // to ensure we send enough EOF markers even if some threads haven't started yet
         for (i in 0 until numWorkers) {
             try {
-                queue.put(AnnotationTask("#eof", null, null))
+                queue.put(AnnotationTask("#eof", null, null, 0))
                 LOGGER.info("Sent EOF marker ${i+1}/$numWorkers to queue")
             } catch (e: InterruptedException) {
                 Thread.currentThread().interrupt()
diff --git a/app/src/main/kotlin/de/ids_mannheim/korapxmltools/KorapXmlTool.kt b/app/src/main/kotlin/de/ids_mannheim/korapxmltools/KorapXmlTool.kt
index e7583d7..67e26de 100644
--- a/app/src/main/kotlin/de/ids_mannheim/korapxmltools/KorapXmlTool.kt
+++ b/app/src/main/kotlin/de/ids_mannheim/korapxmltools/KorapXmlTool.kt
@@ -4015,7 +4015,7 @@
         LOGGER.fine("parseAndWriteAnnotatedConllu called with ${annotatedConllu.length} chars, task=$task")
 
         val docId = task?.docId
-        val entryPathAndFoundry = task?.entryPath?.split("|") ?: listOf(null, null)
+        val entryPathAndFoundry = task?.entryPath?.split('|', limit = 2) ?: listOf(null, null)
         val entryPath = entryPathAndFoundry.getOrNull(0)
         val foundry = entryPathAndFoundry.getOrNull(1) ?: "base"
 
@@ -4025,7 +4025,6 @@
         }
 
         val morphoSpans = mutableMapOf<String, MorphoSpan>()
-        val lines = annotatedConllu.lines()
         var currentStartOffsets: List<Int>? = null
         var currentEndOffsets: List<Int>? = null
         var tokenIndexInSentence = 0
@@ -4034,7 +4033,7 @@
         var sentenceEndOffset: Int? = null
         var extractedFoundry: String? = null
 
-        for (line in lines) {
+        for (line in annotatedConllu.lineSequence()) {
             when {
                 line.startsWith("# foundry =") -> {
                     val foundryStr = line.substring("# foundry =".length).trim()
@@ -5846,4 +5845,3 @@
     try { Locale.setDefault(Locale.ROOT) } catch (_: Exception) {}
     return CommandLine(KorapXmlTool()).execute(*args)
 }
-
diff --git a/app/src/test/kotlin/de/ids_mannheim/korapxmltools/AnnotationWorkerPoolTest.kt b/app/src/test/kotlin/de/ids_mannheim/korapxmltools/AnnotationWorkerPoolTest.kt
new file mode 100644
index 0000000..156a0fb
--- /dev/null
+++ b/app/src/test/kotlin/de/ids_mannheim/korapxmltools/AnnotationWorkerPoolTest.kt
@@ -0,0 +1,62 @@
+package de.ids_mannheim.korapxmltools
+
+import org.junit.Test
+import java.util.concurrent.CountDownLatch
+import java.util.concurrent.TimeUnit
+import java.util.concurrent.atomic.AtomicBoolean
+import java.util.logging.Logger
+import kotlin.test.assertFalse
+import kotlin.test.assertTrue
+import kotlin.test.assertEquals
+
+class AnnotationWorkerPoolTest {
+
+    @Test
+    fun defaultBufferedTaskUnitsScaleWithKorapXmlToolXmx() {
+        val smallHeapUnits = defaultBufferedTaskUnits(
+            numWorkers = 16,
+            env = mapOf("KORAPXMLTOOL_XMX" to "4g"),
+            runtimeMaxBytes = 4L * 1024 * 1024 * 1024
+        )
+        val largeHeapUnits = defaultBufferedTaskUnits(
+            numWorkers = 16,
+            env = mapOf("KORAPXMLTOOL_XMX" to "100g"),
+            runtimeMaxBytes = 100L * 1024 * 1024 * 1024
+        )
+
+        assertEquals(512, smallHeapUnits, "4g heap should keep enough buffered work for 16 workers")
+        assertTrue(largeHeapUnits >= 3200, "100g heap should allow a much larger buffered-text budget")
+        assertTrue(largeHeapUnits > smallHeapUnits, "Larger KORAPXMLTOOL_XMX should increase buffered task units")
+        assertEquals(4096, defaultQueuedTasks(16, largeHeapUnits), "Queue size should scale up but stay capped")
+    }
+
+    @Test
+    fun pushToQueueBlocksWhenQueueCapacityIsReached() {
+        val pool = AnnotationWorkerPool(
+            command = "cat",
+            numWorkers = 0,
+            LOGGER = Logger.getLogger("AnnotationWorkerPoolTest"),
+            maxQueuedTasks = 1
+        )
+
+        val firstQueued = CountDownLatch(1)
+        val secondQueued = AtomicBoolean(false)
+
+        val producer = Thread {
+            pool.pushToQueue("first")
+            firstQueued.countDown()
+            pool.pushToQueue("second")
+            secondQueued.set(true)
+        }
+
+        producer.start()
+
+        assertTrue(firstQueued.await(2, TimeUnit.SECONDS), "First task should be queued quickly")
+        Thread.sleep(200)
+        assertFalse(secondQueued.get(), "Second task should block once the bounded queue is full")
+
+        producer.interrupt()
+        producer.join(2000)
+        assertFalse(producer.isAlive, "Producer thread should stop after interruption")
+    }
+}