totalngrams: Calculate only one fold at a time
diff --git a/pom.xml b/pom.xml
index f7ecbdd..06d16eb 100644
--- a/pom.xml
+++ b/pom.xml
@@ -6,7 +6,7 @@
<groupId>groupId</groupId>
<artifactId>nGrammFoldCount</artifactId>
- <version>1.1</version>
+ <version>1.3</version>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
diff --git a/src/main/java/org/ids_mannheim/FoldedEntry.java b/src/main/java/org/ids_mannheim/FoldedEntry.java
index 028561a..68759f3 100644
--- a/src/main/java/org/ids_mannheim/FoldedEntry.java
+++ b/src/main/java/org/ids_mannheim/FoldedEntry.java
@@ -1,7 +1,10 @@
package org.ids_mannheim;
import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicIntegerArray;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.atomic.LongAdder;
import java.util.stream.IntStream;
public class FoldedEntry implements Comparable<FoldedEntry> {
@@ -25,6 +28,16 @@
}
}
+ /**
+  * Increments the counter stored for {@code ngram} in {@code map}, creating
+  * the counter first if the n-gram has no entry yet.
+  *
+  * @param map   n-gram to counter table that is updated in place
+  * @param ngram key whose counter is increased by one
+  */
+ public static void incr(ConcurrentHashMap<String, AtomicInteger> map, String ngram) {
+     map.compute(ngram, (key, counter) -> {
+         // A missing mapping arrives as null; start it from a fresh zero counter.
+         AtomicInteger current = (counter != null) ? counter : new AtomicInteger();
+         current.incrementAndGet();
+         return current;
+     });
+ }
+
public static void incr(ConcurrentHashMap<String, FoldedEntry> map, String ngram, int fold) {
map.compute(ngram, (key, value) -> {
if (value == null) {
diff --git a/src/main/java/org/ids_mannheim/FreqListEntryComparator.java b/src/main/java/org/ids_mannheim/FreqListEntryComparator.java
index fd5600d..412ee2d 100644
--- a/src/main/java/org/ids_mannheim/FreqListEntryComparator.java
+++ b/src/main/java/org/ids_mannheim/FreqListEntryComparator.java
@@ -3,11 +3,11 @@
import java.util.Comparator;
import java.util.Map;
-public class FreqListEntryComparator<K extends Comparable<? super K>,
- V extends Comparable<? super V>>
- implements Comparator<Map.Entry<K, V>> {
+public class FreqListEntryComparator<String extends Comparable<? super String>,
+ AtomicInteger extends Comparable<? super AtomicInteger>>
+ implements Comparator<Map.Entry<String, AtomicInteger>> {
- public int compare(Map.Entry<K, V> a, Map.Entry<K, V> b) {
+ public int compare(Map.Entry<String, AtomicInteger> a, Map.Entry<String, AtomicInteger> b) {
int cmp1 = b.getValue().compareTo(a.getValue());
if (cmp1 != 0) {
return cmp1;
@@ -16,4 +16,5 @@
}
}
+
}
\ No newline at end of file
diff --git a/src/main/java/org/ids_mannheim/SlidingWindowQueue.java b/src/main/java/org/ids_mannheim/SlidingWindowQueue.java
index b5225e3..a0a2cf8 100644
--- a/src/main/java/org/ids_mannheim/SlidingWindowQueue.java
+++ b/src/main/java/org/ids_mannheim/SlidingWindowQueue.java
@@ -8,7 +8,7 @@
public int fold;
interface Increaser {
- void accept(String s, int fold);
+ void accept(String s);
}
public SlidingWindowQueue(int size, Increaser flush) {
diff --git a/src/main/java/org/ids_mannheim/TotalNGram.java b/src/main/java/org/ids_mannheim/TotalNGram.java
index fb0ea3e..0e4a703 100644
--- a/src/main/java/org/ids_mannheim/TotalNGram.java
+++ b/src/main/java/org/ids_mannheim/TotalNGram.java
@@ -7,27 +7,24 @@
import java.io.File;
import java.io.FileOutputStream;
import java.io.PrintStream;
-import java.nio.file.AccessDeniedException;
-import java.nio.file.FileAlreadyExistsException;
-import java.nio.file.Files;
import java.util.ArrayList;
import java.util.Locale;
import java.util.concurrent.*;
+import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.FileHandler;
import java.util.logging.Logger;
import java.util.logging.SimpleFormatter;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
-import java.util.zip.GZIPOutputStream;
@CommandLine.Command(mixinStandardHelpOptions = true,
name = "totalngram", description = "add ngram counts from KorAP-XML or CoNLL-U files")
public class TotalNGram implements Callable<Integer> {
+ static public final Logger logger;
private static final int MAX_THREADS = Runtime.getRuntime().availableProcessors() * 2 / 3;
private static final String DEFAULT_LOGFILE = "totalngram.log";
- static public final Logger logger;
static {
String path = TotalNGram.class.getClassLoader()
@@ -68,9 +65,15 @@
int ngram_size = 1;
@SuppressWarnings("CanBeFinal")
- @CommandLine.Option(names = {"-f",
+ @CommandLine.Option(names = {"-F",
"--folds"}, description = "number of folds (default: ${DEFAULT-VALUE})")
int FOLDS = 10;
+
+ /** Fold handed to every Worker as its target fold. */
+ @SuppressWarnings("CanBeFinal")
+ @CommandLine.Option(names = {"-f",
+ "--fold"}, description = "current fold (default: ${DEFAULT-VALUE})")
+ int fold = 1;
+
private Progressbar etaPrinter;
public TotalNGram() {
@@ -114,8 +117,7 @@
}
- FoldedEntry.setFolds(FOLDS);
- ConcurrentHashMap<String, FoldedEntry> map = new ConcurrentHashMap<>();
+ ConcurrentHashMap<String, AtomicInteger> map = new ConcurrentHashMap<>();
long totalFilesSizes = inputFiles.parallelStream().mapToLong(fname -> new File(fname).length()).sum();
etaPrinter = new Progressbar(totalFilesSizes);
BlockingQueue<Integer> queue = new LinkedBlockingQueue<>(inputFiles.size());
@@ -125,8 +127,9 @@
max_threads = workerNodePool.size;
}
int threads = Math.min(max_threads, inputFiles.size());
+ logger.info("Processing fold " + fold + "/" + FOLDS);
logger.info("Using " + threads + " threads");
- IntStream.range(0, threads).forEach(unused -> es.execute(new Worker(queue, inputFiles, ngram_size, FOLDS, map, workerNodePool, etaPrinter, logger)));
+ IntStream.range(0, threads).forEach(unused -> es.execute(new Worker(queue, inputFiles, ngram_size, fold, FOLDS, map, workerNodePool, etaPrinter, logger)));
queue.addAll(IntStream.range(0, inputFiles.size()).boxed().collect(Collectors.toList()));
IntStream.range(0, threads).forEach(unused -> {
try {
@@ -141,15 +144,16 @@
logger.info("Sorting and writing frequency table.");
System.err.println("Sorting and writing frequency table.");
map.entrySet().parallelStream()
- .sorted(new FreqListEntryComparator<>())
- .forEachOrdered(entry -> output_stream.println(entry.getKey() + entry.getValue().toString()));
- logger.info("Calculating column sums.");
- System.err.println("Calculating column sums.");
- IntStream.rangeClosed(1, FOLDS)
- .parallel()
- .forEachOrdered(i -> output_stream.print("\t" + Long.toUnsignedString(map.values()
- .parallelStream().mapToLong(e -> Integer.toUnsignedLong(e.count.get(i))).sum())));
- output_stream.println("\t" + Long.toUnsignedString(map.values().parallelStream().mapToLong(e -> Integer.toUnsignedLong(e.count.get(0))).sum()));
+ .sorted((a, b) -> {
+ int cmp1 = Integer.compareUnsigned(b.getValue().get(), a.getValue().get());
+ if (cmp1 != 0) {
+ return cmp1;
+ } else {
+ return a.getKey().compareTo(b.getKey());
+ }
+ })
+ .forEachOrdered(entry -> output_stream.println(entry.getKey() + "\t" + entry.getValue().toString()));
+ logger.info("Finished.");
output_stream.close();
return null;
}
diff --git a/src/main/java/org/ids_mannheim/Utils.java b/src/main/java/org/ids_mannheim/Utils.java
index 3ac2966..6af95f0 100644
--- a/src/main/java/org/ids_mannheim/Utils.java
+++ b/src/main/java/org/ids_mannheim/Utils.java
@@ -38,5 +38,9 @@
}
return f;
}
+
+ /**
+  * Maps a text ID to a bucket number derived from its hash code.
+  *
+  * @param id       text ID to hash
+  * @param max_fold divisor applied to the hash code
+  * @return absolute value of {@code id.hashCode() % max_fold}
+  */
+ public static int getFoldFromTextID(String id, int max_fold) {
+     final int remainder = id.hashCode() % max_fold;
+     // The remainder takes the sign of the hash code, so strip it.
+     return (remainder < 0) ? -remainder : remainder;
+ }
}
diff --git a/src/main/java/org/ids_mannheim/Worker.java b/src/main/java/org/ids_mannheim/Worker.java
index e1d052b..18f993c 100644
--- a/src/main/java/org/ids_mannheim/Worker.java
+++ b/src/main/java/org/ids_mannheim/Worker.java
@@ -7,6 +7,7 @@
import java.util.ArrayList;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Logger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@@ -18,15 +19,16 @@
private static final int MAX_RETRIES = 10;
private final ArrayList<String> fnames;
private final BlockingQueue<Integer> queue;
- private final ConcurrentHashMap<String, FoldedEntry> map;
+ private final ConcurrentHashMap<String, AtomicInteger> map;
private final int folds;
private final Progressbar etaPrinter;
private final int ngram_size;
+ private final int target_fold;
private final Logger logger;
private final WorkerNodePool pool;
- public Worker(BlockingQueue<Integer> queue, ArrayList<String> fnames, int ngram_size, int folds,
- ConcurrentHashMap<String, FoldedEntry> map,
+ public Worker(BlockingQueue<Integer> queue, ArrayList<String> fnames, int ngram_size, int target_fold, int folds,
+ ConcurrentHashMap<String, AtomicInteger> map,
WorkerNodePool pool,
Progressbar etaPrinter, Logger logger) {
this.queue = queue;
@@ -34,6 +36,7 @@
this.map = map;
this.ngram_size = ngram_size;
this.folds = folds;
+ this.target_fold = target_fold;
this.pool = pool;
this.etaPrinter = etaPrinter;
this.logger = logger;
@@ -44,7 +47,7 @@
try {
int index = queue.take();
int retries = MAX_RETRIES;
- SlidingWindowQueue slidingWindowQueue = new SlidingWindowQueue(ngram_size, (s, f) -> FoldedEntry.incr(map, s, f));
+ SlidingWindowQueue slidingWindowQueue = new SlidingWindowQueue(ngram_size, s -> FoldedEntry.incr(map, s));
while (index >= 0) {
String fname = fnames.get(index);
File current_file = new File(fname);
@@ -65,18 +68,20 @@
if (line.startsWith("#")) {
Matcher matcher = new_text_pattern.matcher(line);
if (matcher.find()) {
- fold = Math.abs(matcher.group(1).hashCode()) % folds + 1;
+ fold = Utils.getFoldFromTextID(matcher.group(1), folds + 1);
+ texts++;
+ if(fold == target_fold) {
+ slidingWindowQueue.reset(fold);
+ }
}
- slidingWindowQueue.reset(fold);
- texts++;
- continue;
+ } else if (fold == target_fold) {
+ String[] strings = line.split("\\s+");
+ if (strings.length < 4) {
+ continue;
+ }
+ //noinspection ConstantConditions
+ slidingWindowQueue.add(strings[1]);
}
- String[] strings = line.split("\\s+");
- if (strings.length < 4) {
- continue;
- }
- //noinspection ConstantConditions
- slidingWindowQueue.add(strings[1]);
}
pool.markFree(poolIndex);
if (texts > 0) {