totalngrams: allow using a worker node pool
diff --git a/src/main/java/org/ids_mannheim/TotalNGram.java b/src/main/java/org/ids_mannheim/TotalNGram.java
index 992b03c..752307c 100644
--- a/src/main/java/org/ids_mannheim/TotalNGram.java
+++ b/src/main/java/org/ids_mannheim/TotalNGram.java
@@ -40,6 +40,11 @@
int max_threads = MAX_THREADS;
@SuppressWarnings("CanBeFinal")
+    @CommandLine.Option(names = {"-p",
+            "--worker-pool"}, description = "Run preprocessing on external hosts, e.g. '10*local,5*host1,3*smith@host2' (default: ${DEFAULT-VALUE})")
+    String worker_pool_specification = ""; // empty: no remote hosts, preprocessing runs locally
+
+ @SuppressWarnings("CanBeFinal")
@CommandLine.Option(names = {"-n",
"--ngram-size"}, description = "n-gram size (default: ${DEFAULT-VALUE})")
int ngram_size = 1;
@@ -82,8 +87,12 @@
etaPrinter = new Progressbar(totalFilesSizes, "MB");
BlockingQueue<Integer> queue = new LinkedBlockingQueue<>(inputFiles.size());
ExecutorService es = Executors.newCachedThreadPool();
+ WorkerNodePool workerNodePool = new WorkerNodePool(worker_pool_specification);
+ if (!worker_pool_specification.equals("")) {
+ max_threads = workerNodePool.size;
+ }
int threads = Math.min(max_threads, inputFiles.size());
- IntStream.range(0, threads).forEach(unused -> es.execute(new Worker(queue, inputFiles, ngram_size, FOLDS, map, etaPrinter)));
+ IntStream.range(0, threads).forEach(unused -> es.execute(new Worker(queue, inputFiles, ngram_size, FOLDS, map, workerNodePool, etaPrinter)));
queue.addAll(IntStream.range(0, inputFiles.size()).boxed().collect(Collectors.toList()));
IntStream.range(0, threads).forEach(unused -> {
try {
diff --git a/src/main/java/org/ids_mannheim/Worker.java b/src/main/java/org/ids_mannheim/Worker.java
index 8ce5288..a82b0c7 100644
--- a/src/main/java/org/ids_mannheim/Worker.java
+++ b/src/main/java/org/ids_mannheim/Worker.java
@@ -7,6 +7,7 @@
import java.util.ArrayList;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ThreadLocalRandom;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@@ -18,13 +19,17 @@
private final int folds;
private final Progressbar etaPrinter;
private final int ngram_size;
-
- public Worker(BlockingQueue<Integer> queue, ArrayList<String> fnames, int ngram_size, int folds, ConcurrentHashMap<String, FoldedEntry> map, ETAPrinter etaPrinter) {
+    private final WorkerNodePool pool; // supplies the host and command prefix for each file index
+ public Worker(BlockingQueue<Integer> queue, ArrayList<String> fnames, int ngram_size, int folds,
+ ConcurrentHashMap<String, FoldedEntry> map,
+ WorkerNodePool pool,
+ Progressbar etaPrinter) {
this.queue = queue;
this.fnames = fnames;
this.map = map;
this.ngram_size = ngram_size;
this.folds = folds;
+ this.pool = pool;
this.etaPrinter = etaPrinter;
}
@@ -35,11 +40,12 @@
while (index >= 0) {
String fname = fnames.get(index);
long file_size = new File(fname).length();
- System.err.println(Thread.currentThread().getName() + " - processing: " + fname);
+ int poolIndex = index % pool.size;
+ System.err.println(String.format("%4d/%4d %-10s %-10s", index, fnames.size(), pool.getHost(poolIndex), fname));
String[] cmd = {
"/bin/sh",
"-c",
- "/usr/local/kl/bin/korapxml2conllu " + fname
+ pool.getExec(poolIndex) + "/usr/local/kl/bin/korapxml2conllu " + fname
};
Process p = Runtime.getRuntime().exec(cmd);
BufferedReader in = new BufferedReader(new InputStreamReader(p.getInputStream()));
diff --git a/src/main/java/org/ids_mannheim/WorkerNodePool.java b/src/main/java/org/ids_mannheim/WorkerNodePool.java
new file mode 100644
index 0000000..255809d
--- /dev/null
+++ b/src/main/java/org/ids_mannheim/WorkerNodePool.java
@@ -0,0 +1,106 @@
+package org.ids_mannheim;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.stream.IntStream;
+
+/**
+ * Pool of worker slots built from a specification such as
+ * {@code 10*local,5*host1,3*smith@host2}.
+ * <p>
+ * Every comma-separated entry is either {@code host} (one slot) or
+ * {@code count*host} ({@code count} slots). The empty host name and
+ * {@code local} stand for the local machine; any other host is reached
+ * through {@code ssh}. Whitespace around entries, counts and hosts is ignored.
+ */
+public class WorkerNodePool {
+    /** Command prefix for each slot, index-aligned with {@link #hostPool}. */
+    final ArrayList<String> execPool = new ArrayList<>();
+    /** Host name for each slot, index-aligned with {@link #execPool}. */
+    final ArrayList<String> hostPool = new ArrayList<>();
+    /** Number of worker slots in the pool. */
+    public final int size;
+
+    /**
+     * Returns the command prefix for slot {@code i}.
+     *
+     * @param i slot index, {@code 0 <= i < size}
+     * @return shell command prefix ending in a space
+     * @throws IndexOutOfBoundsException if {@code i} is out of range
+     */
+    public String getExec(int i) {
+        return execPool.get(i);
+    }
+
+    /**
+     * Returns the host name of slot {@code i}.
+     *
+     * @param i slot index, {@code 0 <= i < size}
+     * @return host name, possibly empty for the local machine
+     * @throws IndexOutOfBoundsException if {@code i} is out of range
+     */
+    public String getHost(int i) {
+        return hostPool.get(i);
+    }
+
+    /**
+     * Builds the command prefix for a host. The host is inserted verbatim, so
+     * it must not contain characters that are special to the invoking shell.
+     */
+    private static String hostToExec(String host) {
+        // equals() rather than matches(): the host is a literal, not a regex.
+        if (host.isEmpty() || host.equals("local")) {
+            return "nice -5 ";
+        }
+        return " ssh " + host + " nice -5 ";
+    }
+
+    /** Parses a slot count; returns -1 when {@code text} is not an integer. */
+    private static int parseCount(String text) {
+        try {
+            return Integer.parseInt(text);
+        } catch (NumberFormatException notANumber) {
+            return -1;
+        }
+    }
+
+    /** Reports a malformed specification on stderr and terminates the JVM. */
+    private static void exitWithFormatError(String entry) {
+        System.err.println("ERROR: Wrong worker node format: " + entry);
+        System.exit(-1);
+    }
+
+    /**
+     * Parses {@code description} and fills the pool.
+     * <p>
+     * Malformed entries (more than one {@code *}, a count that is not a
+     * positive integer, a count without a host) and specifications that yield
+     * no slot at all are reported on stderr and terminate the JVM with exit
+     * status -1.
+     *
+     * @param description worker pool specification; the empty string yields a
+     *                    single local slot
+     */
+    public WorkerNodePool(String description) {
+        for (String entry : description.split(",")) {
+            // Limit -1 keeps a trailing empty token, so "5*" is seen as a
+            // count without a host instead of as a host named "5".
+            String[] parts = entry.split("\\*", -1);
+            String host = parts[parts.length - 1].trim();
+            int procs = parts.length == 2 ? parseCount(parts[0].trim()) : 1;
+            boolean missingHost = parts.length == 2 && host.isEmpty();
+            if (parts.length > 2 || procs < 1 || missingHost) {
+                exitWithFormatError(entry);
+            }
+            for (int i = 0; i < procs; i++) {
+                hostPool.add(host);
+                execPool.add(hostToExec(host));
+            }
+        }
+        if (hostPool.isEmpty()) {
+            // e.g. ",": an empty pool would break callers computing index % size.
+            exitWithFormatError(description);
+        }
+        size = hostPool.size();
+    }
+}