Use coroutines instead of threads for annotation pipes

Change-Id: I7e17b731f26524ecdaf0347ea89368efb6c3eb93
diff --git a/app/build.gradle b/app/build.gradle
index dbce1b5..7337882 100644
--- a/app/build.gradle
+++ b/app/build.gradle
@@ -29,6 +29,7 @@
 
     // Use the Kotlin JDK 8 standard library.
     implementation 'org.jetbrains.kotlin:kotlin-stdlib'
+    implementation 'org.jetbrains.kotlinx:kotlinx-coroutines-core:1.8.0'
 
     // This dependency is used by the application.
     implementation 'com.google.guava:guava:33.0.0-jre'
diff --git a/app/src/main/kotlin/de/ids_mannheim/korapxmltools/AnnotationWorkerPool.kt b/app/src/main/kotlin/de/ids_mannheim/korapxmltools/AnnotationWorkerPool.kt
index f4f2bba..7de7cf1 100644
--- a/app/src/main/kotlin/de/ids_mannheim/korapxmltools/AnnotationWorkerPool.kt
+++ b/app/src/main/kotlin/de/ids_mannheim/korapxmltools/AnnotationWorkerPool.kt
@@ -1,91 +1,137 @@
 package de.ids_mannheim.korapxmltools
 
+import kotlinx.coroutines.*
 import java.io.*
+import java.lang.Thread.currentThread
 import java.lang.Thread.sleep
 import java.util.concurrent.BlockingQueue
 import java.util.concurrent.LinkedBlockingQueue
 import java.util.concurrent.TimeUnit
 import java.util.logging.Logger
 
-class WorkerPool(private val command: String, private val numWorkers: Int, private val LOGGER: Logger) {
+private const val BUFFER_SIZE = 10000000
+private const val HIGH_WATERMARK = 1000000
+
+class AnnotationWorkerPool(
+    private val command: String,
+    private val numWorkers: Int,
+    private val LOGGER: Logger
+) {
     private val queue: BlockingQueue<String> = LinkedBlockingQueue()
     private val threads = mutableListOf<Thread>()
+    private var threadCount = 0
+
     init {
         openWorkerPool()
-        LOGGER.info("Worker pool opened")
+        LOGGER.info("Annotation worker pool with ${numWorkers} threads opened")
     }
 
     private fun openWorkerPool() {
-           repeat(numWorkers) {
+        repeat(numWorkers) {
             Thread {
-                    try {
-                        threads.add(Thread.currentThread())
-                        val process = ProcessBuilder("/bin/sh", "-c", command)
-                            //.directory(File("/tmp"))
-                            .redirectOutput(ProcessBuilder.Redirect.PIPE)
-                            .redirectInput(ProcessBuilder.Redirect.PIPE)
-                            .redirectError(ProcessBuilder.Redirect.INHERIT)
-                            //.redirectErrorStream(true) // Merges stderr into stdout
-                            .start()
-                        process.outputStream.buffered(1000000)
-                        process.inputStream.buffered(1000000)
-                        val outputStreamWriter = process.outputStream.bufferedWriter(Charsets.UTF_8)
-                        val output = StringBuilder()
+                try {
+                    threads.add(currentThread())
+                    threadCount++
+                    val process = ProcessBuilder("/bin/sh", "-c", command)
+                        //.directory(File("/tmp"))
+                        .redirectOutput(ProcessBuilder.Redirect.PIPE).redirectInput(ProcessBuilder.Redirect.PIPE)
+                        .redirectError(ProcessBuilder.Redirect.INHERIT)
+                        //.redirectErrorStream(true) // Merges stderr into stdout
+                        .start()
+                    if (process.outputStream == null) {
+                        LOGGER.severe("Worker $it failed to open pipe '$command'")
+                        return@Thread
+                    }
+                    process.outputStream.buffered(BUFFER_SIZE)
+                    process.inputStream.buffered(BUFFER_SIZE)
 
+                    val coroutineScope = CoroutineScope(Dispatchers.IO)
+                    var inputGotEof = false
+                    var readBytes = 0
+                    var writtenBytes = 0
+
+                    coroutineScope.launch {
+                        val outputStreamWriter = OutputStreamWriter(process.outputStream)
                         while (true) {
                             val text = queue.poll(5, TimeUnit.SECONDS)
                             if (text == "#eof" || text == null) {
                                 outputStreamWriter.write("\n# eof\n")
-                                outputStreamWriter.flush()
+                                outputStreamWriter.close()
                                 LOGGER.info("Worker $it received eof")
                                 break
                             }
-
-                            text.split(Regex("\n\n")).forEach {
-                                outputStreamWriter.write(it + "\n\n")
+                            try {
+                                outputStreamWriter.write(text + "\n# eot\n")
+                                /*text.split("\n\n").forEach {
+                                  outputStreamWriter.write(it + "\n\n")
+                                }*/
                                 outputStreamWriter.flush()
-                                readAndPrintAvailable(process, output)
+                                writtenBytes += text.length
+                            } catch (e: IOException) {
+                                LOGGER.severe("Worker $it failed to write to process: ${e.message}")
+                                threads.remove(currentThread())
+                                threadCount--
+                                return@launch //break
+                            }
+
+                        }
+
+                    }
+
+                    coroutineScope.launch {
+                        val output = StringBuilder()
+                        while (!inputGotEof && process.isAlive) {
+                            process.inputStream.bufferedReader().useLines { lines ->
+                                lines.forEach { line ->
+                                    when (line) {
+                                        "# eof" -> {
+                                            LOGGER.info("Worker $it got EOF in output")
+                                            inputGotEof = true;
+                                            return@forEach }
+                                        "# eot" -> {
+                                            printOutput(output.toString())
+                                            output.clear() }
+                                        else -> { output.append(line, "\n")
+                                            readBytes += line.length +1 }
+                                    }
+                                }
+                                printOutput(output.toString())
+                                output.clear()
+                                if (!inputGotEof) {
+                                    LOGGER.info("Worker $it waiting for more output")
+                                    sleep(10)
+                                }
                             }
                         }
 
-                        process.outputStream.close()
-                        while(process.isAlive && output.indexOf("# eof\n") == -1) {
-                            readAndPrintAvailable(process, output)
-                        }
-                        LOGGER.info("Worker $it got eof in output")
-                        output.append(process.inputStream.bufferedReader(Charsets.UTF_8).readText())
-                        synchronized(System.out) {
-                            print(output.replace(Regex("\\s*\n# eof\n\\s*"), ""))
-                        }
-
-                        process.inputStream.close()
-
-                    } catch (e: IOException) {
-                        e.printStackTrace()
-                        LOGGER.warning("Worker $it failed: ${e.message}")
-                        threads.remove(Thread.currentThread())
                     }
+                    //while (!inputGotEof && process.isAlive) {
+                    //    LOGGER.info("Worker $it waiting for EOF output to finish")
+                    //    sleep(1000)
+                   // }
+                    //outputStreamWriter.close()
+                    process.waitFor()
+                    LOGGER.info("Worker $it finished")
 
+
+                } catch (e: IOException) {
+                    e.printStackTrace()
+                    LOGGER.warning("Worker $it failed: ${e.message}")
+                    threads.remove(currentThread())
+                }
             }.start()
-
-
         }
     }
 
-    private fun readAndPrintAvailable(process: Process, output: StringBuilder) {
-        if (process.inputStream.available() > 0) {
-            val readBytes = ByteArray(process.inputStream.available())
-            process.inputStream.read(readBytes)
-            output.append(String(readBytes))
-            val eotOffset = output.lastIndexOf("# eot\n")
-            if (eotOffset > -1) {
-                synchronized(System.out) {
-                    print(output.substring(0, eotOffset).replace(Regex("\n# eot\n\\s*"), ""))
-                }
-                output.delete(0, eotOffset + 6)
+
+    suspend fun printOutput(output: String) {
+        synchronized(System.out) {
+            try {
+                System.out.write(output.toByteArray())
+            } catch (e: IOException) {
+                LOGGER.severe("Failed to write to stdout: ${e.message}")
             }
-        } else {
-            sleep(1)
+            //  println(output)
         }
     }
 
@@ -104,8 +150,9 @@
     }
 
     fun close() {
-        var n = threads.size
-        while(n > 0) {
+        var n = threadCount
+        LOGGER.info("Closing worker pool with $n threads")
+        while (n > 0) {
             if (queue.offer("#eof")) {
                 n--
             } else {
@@ -113,7 +160,8 @@
                 sleep(100)
             }
         }
-        waitForWorkersToFinish()
+        if (threadCount > 0)
+            waitForWorkersToFinish()
     }
 
     private fun waitForWorkersToFinish() {
@@ -129,11 +177,11 @@
 fun main() {
     val command = "cat"
     val numWorkers = 3
-    val workerPool = WorkerPool(command, numWorkers, Logger.getLogger("de.ids_mannheim.korapxmltools.WorkerPool") )
+    val annotationWorkerPool = AnnotationWorkerPool(command, numWorkers, Logger.getLogger("de.ids_mannheim.korapxmltools.WorkerPool"))
 
     val texts = listOf("The", "World", "This", "Is", "A", "Test")
 
-    workerPool.pushToQueue(texts)
+    annotationWorkerPool.pushToQueue(texts)
 
-    workerPool.close()
+    annotationWorkerPool.close()
 }
diff --git a/app/src/main/kotlin/de/ids_mannheim/korapxmltools/KorapXml2Conllu.kt b/app/src/main/kotlin/de/ids_mannheim/korapxmltools/KorapXml2Conllu.kt
index de32dcb..b5be6ba 100644
--- a/app/src/main/kotlin/de/ids_mannheim/korapxmltools/KorapXml2Conllu.kt
+++ b/app/src/main/kotlin/de/ids_mannheim/korapxmltools/KorapXml2Conllu.kt
@@ -112,7 +112,7 @@
         paramLabel = "THREADS",
         description = ["Maximum number of threads to use. Default: ${"$"}{DEFAULT-VALUE}"]
     )
-    var threads: Int = Runtime.getRuntime().availableProcessors()
+    var threads: Int = Runtime.getRuntime().availableProcessors() / 2
 
     override fun call(): Int {
         LOGGER.level = try {
@@ -130,7 +130,7 @@
 
     private val LOGGER: Logger = Logger.getLogger(KorapXml2Conllu::class.java.name)
 
-    private var workerPool : WorkerPool? = null
+    private var annotationWorkerPool : AnnotationWorkerPool? = null
 
     val texts: ConcurrentHashMap<String, String> = ConcurrentHashMap()
     val sentences: ConcurrentHashMap<String, Array<Span>> = ConcurrentHashMap()
@@ -140,11 +140,12 @@
     val metadata: ConcurrentHashMap<String, Array<String>> = ConcurrentHashMap()
     val extraFeatures: ConcurrentHashMap<String, MutableMap<String, String>> = ConcurrentHashMap()
     var waitForMorpho: Boolean = false
+
     fun korapxml2conllu(args: Array<String>) {
         val executor: ExecutorService = Executors.newFixedThreadPool(threads)
 
         if (annotateWith != "") {
-            workerPool = WorkerPool(annotateWith, threads, LOGGER)
+            annotationWorkerPool = AnnotationWorkerPool(annotateWith, threads, LOGGER)
         }
 
         var zips: Array<String> = args
@@ -181,7 +182,7 @@
         }
         if (annotateWith.isNotEmpty()) {
             LOGGER.info("closing worker pool")
-            workerPool?.close()
+            annotationWorkerPool?.close()
         }
     }
 
@@ -402,7 +403,7 @@
         }
 
         if (annotateWith != "") {
-            workerPool?.pushToQueue(output.append("\n# eot\n").toString())
+            annotationWorkerPool?.pushToQueue(output.append("\n# eot\n").toString())
         } else {
             synchronized(System.out) {
                 println(output.toString())
diff --git a/app/src/test/kotlin/de/ids_mannheim/korapxmltools/KorapXml2ConlluTest.kt b/app/src/test/kotlin/de/ids_mannheim/korapxmltools/KorapXml2ConlluTest.kt
index b44e337..a34166c 100644
--- a/app/src/test/kotlin/de/ids_mannheim/korapxmltools/KorapXml2ConlluTest.kt
+++ b/app/src/test/kotlin/de/ids_mannheim/korapxmltools/KorapXml2ConlluTest.kt
@@ -162,7 +162,7 @@
             outContent.toString(),
             "axtomatiqxe"
         )
-        assertTrue("Annotated CoNLL-U should have at least as many lines as the original",
+        assertTrue("Annotated CoNLL-U should have at least as many lines as the original, but only has ${outContent.toString().count { it == '\n'}} lines",
             { outContent.toString().count { it == '\n'} >= 61511 })
     }