totalngram: compute marginals in parallel
diff --git a/src/main/java/org/ids_mannheim/ColumnAdder.java b/src/main/java/org/ids_mannheim/ColumnAdder.java
new file mode 100644
index 0000000..ace7329
--- /dev/null
+++ b/src/main/java/org/ids_mannheim/ColumnAdder.java
@@ -0,0 +1,19 @@
+package org.ids_mannheim;
+
+import java.util.concurrent.ConcurrentHashMap;
+
+public class ColumnAdder implements Runnable {
+ private final ConcurrentHashMap<String, FoldedEntry> map;
+ private final int fold;
+ private long[] columnSums;
+
+ public ColumnAdder(ConcurrentHashMap<String, FoldedEntry> map, int fold, long[] columnSums) {
+ this.map = map;
+ this.fold = fold;
+ this.columnSums = columnSums;
+ }
+ @Override
+ public void run() {
+ columnSums[fold] = map.values().stream().mapToLong(e -> e.count.get(fold)).sum();
+ }
+}
diff --git a/src/main/java/org/ids_mannheim/TotalNGram.java b/src/main/java/org/ids_mannheim/TotalNGram.java
index a56580f..c111c2d 100644
--- a/src/main/java/org/ids_mannheim/TotalNGram.java
+++ b/src/main/java/org/ids_mannheim/TotalNGram.java
@@ -4,6 +4,7 @@
import java.io.FileNotFoundException;
import java.io.PrintStream;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.Collections;
import java.util.concurrent.*;
import java.util.stream.Collectors;
@@ -60,28 +61,34 @@
BlockingQueue<Integer> queue = new LinkedBlockingQueue<>(inputFiles.size());
ExecutorService es = Executors.newCachedThreadPool();
int threads = Math.min(max_threads, inputFiles.size());
+ long[] columnsSums = new long[FOLDS+1];
IntStream.range(0, threads).forEach(unused -> es.execute(new Worker(queue, inputFiles, ngram_size, FOLDS, map, etaPrinter)));
queue.addAll(IntStream.range(0, inputFiles.size()).boxed().collect(Collectors.toList()));
- IntStream.range(0, threads).forEach(unused -> {
+ IntStream.range(0, threads).forEach(i -> {
try {
- queue.put(-1);
+ queue.put(-i-1);
} catch (InterruptedException e) {
e.printStackTrace();
}
});
es.shutdown();
boolean finished = es.awaitTermination(120, TimeUnit.HOURS);
+
+ ExecutorService columnAddExecutor = Executors.newCachedThreadPool();
+ IntStream.range(1, FOLDS+1).forEach(fold -> columnAddExecutor.execute(new ColumnAdder(map, fold, columnsSums)));
+ columnAddExecutor.shutdown();
map.entrySet().stream()
.sorted(Collections.reverseOrder(new ValueThenKeyComparator<>()))
.forEach(entry -> output_stream.println(entry.getKey() + entry.getValue().toString()));
+ finished = columnAddExecutor.awaitTermination(120, TimeUnit.HOURS);
+
IntStream.rangeClosed(1, FOLDS)
- .forEach(i -> output_stream.print("\t" + map.values()
- .parallelStream().mapToLong(e -> e.count.get(i)).sum()));
- output_stream.println("\t" + map.values().parallelStream().mapToLong(e -> e.count.get(0)).sum());
+ .forEach(i -> output_stream.print("\t" + columnsSums[i]));
+ output_stream.println("\t" + Arrays.stream(columnsSums).sum());
return null;
}
- public static void main(String[] args) throws FileNotFoundException {
+ public static void main(String[] args) {
System.exit(new CommandLine(new TotalNGram()).execute(args));
}
}
diff --git a/src/main/java/org/ids_mannheim/Worker.java b/src/main/java/org/ids_mannheim/Worker.java
index 5bf951b..1df9eba 100644
--- a/src/main/java/org/ids_mannheim/Worker.java
+++ b/src/main/java/org/ids_mannheim/Worker.java
@@ -21,7 +21,8 @@
private final ETAPrinter etaPrinter;
private final int ngram_size;
- public Worker(BlockingQueue<Integer> queue, ArrayList<String> fnames, int ngram_size, int folds, ConcurrentHashMap<String, FoldedEntry> map, ETAPrinter etaPrinter) {
+ public Worker(BlockingQueue<Integer> queue, ArrayList<String> fnames, int ngram_size, int folds, ConcurrentHashMap<String, FoldedEntry> map,
+ ETAPrinter etaPrinter) {
this.queue = queue;
this.fnames = fnames;
this.map = map;