totalngram: Support > 1-grams
diff --git a/src/main/java/org/ids_mannheim/SlidingWindowQueue.java b/src/main/java/org/ids_mannheim/SlidingWindowQueue.java
new file mode 100644
index 0000000..4f07362
--- /dev/null
+++ b/src/main/java/org/ids_mannheim/SlidingWindowQueue.java
@@ -0,0 +1,23 @@
+package org.ids_mannheim;
+
+import java.util.LinkedList;
+import java.util.function.Consumer;
+
+public class SlidingWindowQueue extends LinkedList {
+ private final int maxSize;
+ private final Consumer<String> flush;
+
+ public SlidingWindowQueue(int size, Consumer<String> flush) {
+ this.maxSize = size;
+ this.flush = flush;
+ }
+
+ public boolean add(String k) {
+ boolean r = super.add(k);
+ if (size() == maxSize) {
+ this.flush.accept(String.join(" ", this));
+ remove(0);
+ }
+ return r;
+ }
+}
diff --git a/src/main/java/org/ids_mannheim/TotalNGram.java b/src/main/java/org/ids_mannheim/TotalNGram.java
index 6db5308..a56580f 100644
--- a/src/main/java/org/ids_mannheim/TotalNGram.java
+++ b/src/main/java/org/ids_mannheim/TotalNGram.java
@@ -34,6 +34,11 @@
int max_threads = MAX_THREADS;
@SuppressWarnings("CanBeFinal")
+ @CommandLine.Option(names = {"-n",
+ "--ngram-size"}, description = "n-gram size (default: ${DEFAULT-VALUE})")
+ int ngram_size = 1;
+
+ @SuppressWarnings("CanBeFinal")
@CommandLine.Option(names = {"-f",
"--folds"}, description = "number of folds (default: ${DEFAULT-VALUE})")
int FOLDS = 10;
@@ -55,7 +60,7 @@
BlockingQueue<Integer> queue = new LinkedBlockingQueue<>(inputFiles.size());
ExecutorService es = Executors.newCachedThreadPool();
int threads = Math.min(max_threads, inputFiles.size());
- IntStream.range(0, threads).forEach(unused -> es.execute(new Worker(queue, inputFiles, FOLDS, map, etaPrinter)));
+ IntStream.range(0, threads).forEach(unused -> es.execute(new Worker(queue, inputFiles, ngram_size, FOLDS, map, etaPrinter)));
queue.addAll(IntStream.range(0, inputFiles.size()).boxed().collect(Collectors.toList()));
IntStream.range(0, threads).forEach(unused -> {
try {
diff --git a/src/main/java/org/ids_mannheim/Worker.java b/src/main/java/org/ids_mannheim/Worker.java
index 0619233..5bf951b 100644
--- a/src/main/java/org/ids_mannheim/Worker.java
+++ b/src/main/java/org/ids_mannheim/Worker.java
@@ -19,11 +19,13 @@
private final ConcurrentHashMap<String, FoldedEntry> map;
private final int folds;
private final ETAPrinter etaPrinter;
+ private final int ngram_size;
- public Worker(BlockingQueue<Integer> queue, ArrayList<String> fnames, int folds, ConcurrentHashMap<String, FoldedEntry> map, ETAPrinter etaPrinter) {
+ public Worker(BlockingQueue<Integer> queue, ArrayList<String> fnames, int ngram_size, int folds, ConcurrentHashMap<String, FoldedEntry> map, ETAPrinter etaPrinter) {
this.queue = queue;
this.fnames = fnames;
this.map = map;
+ this.ngram_size = ngram_size;
this.folds = folds;
this.etaPrinter = etaPrinter;
}
@@ -45,19 +47,24 @@
BufferedReader in = new BufferedReader(new InputStreamReader(p.getInputStream()));
String line;
int fold = -1;
+ SlidingWindowQueue slidingWindowQueue = null;
while ((line = in.readLine()) != null) {
if (line.startsWith("#")) {
Matcher matcher = new_text_pattern.matcher(line);
if (matcher.find()) {
fold = Math.abs(matcher.group(1).hashCode()) % folds +1;
}
+ int finalFold = fold;
+ slidingWindowQueue = new SlidingWindowQueue(ngram_size, s -> {
+ FoldedEntry.incr(map, s, finalFold);
+ });
continue;
}
String[] strings = line.split("\\s+");
if (strings.length < 4) {
continue;
}
- FoldedEntry.incr(map, strings[1], fold);
+ slidingWindowQueue.add(strings[1]);
}
etaPrinter.update(file_size);
index = queue.take();