Add exclude punctuation option
Change-Id: Ie90a59f77a92b8007af92411bcbaf00a8c910722
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2c443e2..70a7c18 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,7 @@
# Changelog
+- added option `--exclude-punctuation`
+
## [2.1.0] - 2022-12-01
- added script `GeneratePseudonymKey.groovy` to compute pseudonyms
- added script `Pseudonymize.groovy` to pseudonymize tokens (and lemmas)
diff --git a/Readme.md b/Readme.md
index 1afddce..dc74c12 100644
--- a/Readme.md
+++ b/Readme.md
@@ -16,6 +16,10 @@
<inputFiles>... input files
-d, --downcase Convert all token characters into lower case (default:
false)
+ --exclude-punctuation
+ Ignore all tokens tagged as punctuation (according to
+ STTS tags set, i.e. starting with '$') (default:
+ false)
-f, --fold=<fold> current fold (default: 1)
-F, --folds=<FOLDS> number of random folds (default: 1)
--force Force overwrite (default: false)
diff --git a/src/main/java/org/ids_mannheim/TotalNGrams.java b/src/main/java/org/ids_mannheim/TotalNGrams.java
index bbbee83..4efea34 100644
--- a/src/main/java/org/ids_mannheim/TotalNGrams.java
+++ b/src/main/java/org/ids_mannheim/TotalNGrams.java
@@ -98,6 +98,11 @@
+ PaddedSlidingWindowQueue.TEXT_END_SYMBOL + " symbols at text edges (default: ${DEFAULT-VALUE})")
boolean addPadding = false;
+ @SuppressWarnings("CanBeFinal")
+ @CommandLine.Option(names = {
+ "--exclude-punctuation" }, description = "Ignore all tokens tagged as punctuation (according to STTS tags set, i.e. starting with '$') (default: ${DEFAULT-VALUE})")
+ boolean excludePunctuation = false;
+
private Progressbar etaPrinter;
public TotalNGrams() {
@@ -174,7 +179,7 @@
logger.info("Processing fold " + fold + "/" + FOLDS);
logger.info("Using " + threads + " threads");
IntStream.range(0, threads).forEach(unused -> es.execute(new Worker(queue, inputFiles, ngram_size, fold, FOLDS,
- map, with_lemma_and_pos, downcase_tokens, workerNodePool, etaPrinter, logger, addPadding)));
+ map, with_lemma_and_pos, downcase_tokens, workerNodePool, etaPrinter, logger, addPadding, excludePunctuation)));
queue.addAll(IntStream.range(0, inputFiles.size()).boxed().collect(Collectors.toList()));
IntStream.range(0, threads).forEach(unused -> {
try {
diff --git a/src/main/java/org/ids_mannheim/Utils.java b/src/main/java/org/ids_mannheim/Utils.java
index e8413b5..8f4e801 100644
--- a/src/main/java/org/ids_mannheim/Utils.java
+++ b/src/main/java/org/ids_mannheim/Utils.java
@@ -55,5 +55,20 @@
.replaceAll("^(\\d+)\t>\t--\t[^\t]+", "$1\t>\t>\t\\$(");
}
+ /**
+ * Tells whether a token counts as punctuation.
+ * <p>
+ * A token qualifies if its part-of-speech tag starts with {@code $}, or if the token itself is
+ * one of {@code "}, {@code '}, {@code <} or {@code >}, whatever its tag.
+ *
+ * @param token surface form of the token
+ * @param lemma lemma of the token; currently not consulted
+ * @param pos part-of-speech tag; {@code null} is treated as "no tag"
+ * @return {@code true} if the token is punctuation by the rules above
+ */
+ public static boolean isPunctuation(String token, String lemma, String pos) {
+ // Constant-first equals and the null check keep a null argument from throwing.
+ return (pos != null && pos.startsWith("$")) || "\"".equals(token) || "'".equals(token) || "<".equals(token) || ">".equals(token);
+ }
}
diff --git a/src/main/java/org/ids_mannheim/Worker.java b/src/main/java/org/ids_mannheim/Worker.java
index d571dc6..e5b1f9b 100644
--- a/src/main/java/org/ids_mannheim/Worker.java
+++ b/src/main/java/org/ids_mannheim/Worker.java
@@ -32,11 +32,12 @@
private final boolean addPadding;
private final DeterministicRandomProvider deterministicRandomProvider;
+ private final boolean excludePunctuation;
public Worker(BlockingQueue<Integer> queue, ArrayList<String> fnames, int ngram_size, int target_fold, int folds,
ConcurrentHashMap<String, AtomicInteger> map,
boolean with_lemma_and_pos, boolean downcase_tokens, WorkerNodePool pool,
- Progressbar etaPrinter, Logger logger, boolean addPadding) {
+ Progressbar etaPrinter, Logger logger, boolean addPadding, boolean excludePunctuation) {
this.queue = queue;
this.fnames = fnames;
this.map = map;
@@ -49,6 +50,7 @@
this.deterministicRandomProvider = new DeterministicRandomProvider(folds);
this.downcase_tokens = downcase_tokens;
this.addPadding = addPadding;
+ this.excludePunctuation = excludePunctuation;
}
@Override
@@ -134,31 +136,33 @@
String token = ( downcase_tokens?
Utils.unEscapeEntities(strings[1]).toLowerCase(Locale.ROOT) :
Utils.unEscapeEntities(strings[1]));
- if (with_lemma_and_pos) {
- String lemma, pos;
+ if(!excludePunctuation || !Utils.isPunctuation(token, strings[2], strings[3])) {
+ if (with_lemma_and_pos) {
+ String lemma, pos;
- if (token.equals("\"") || token.equals("'")) {
- lemma = "\"";
- } else if (token.equals("&")) {
- lemma = "&";
- } else if (token.equals("<")) {
- lemma = "<";
- } else if (token.equals(">")) {
- lemma = ">";
+ if (token.equals("\"") || token.equals("'")) {
+ lemma = "\"";
+ } else if (token.equals("&")) {
+ lemma = "&";
+ } else if (token.equals("<")) {
+ lemma = "<";
+ } else if (token.equals(">")) {
+ lemma = ">";
+ } else {
+ lemma = strings[2];
+ }
+ if (token.equals("\"") || token.equals("'") || token.equals("<") || token.equals(">")) {
+ pos = "$(";
+ } else if (token.equals("&")) {
+ pos = "KON";
+ } else {
+ pos = strings[3];
+ }
+ //noinspection ConstantCondition
+ slidingWindowQueue.add(join("\t", token, lemma, pos));
} else {
- lemma = strings[2];
+ slidingWindowQueue.add(token);
}
- if (token.equals("\"") || token.equals("'") || token.equals("<") || token.equals(">")) {
- pos = "$(";
- } else if (token.equals("&")) {
- pos = "KON";
- } else {
- pos = strings[3];
- }
- //noinspection ConstantCondition
- slidingWindowQueue.add(join("\t", token, lemma, pos));
- } else {
- slidingWindowQueue.add(token);
}
}
}
diff --git a/src/test/java/org/ids_mannheim/UtilsTest.java b/src/test/java/org/ids_mannheim/UtilsTest.java
index 65dbca7..3d19ef9 100644
--- a/src/test/java/org/ids_mannheim/UtilsTest.java
+++ b/src/test/java/org/ids_mannheim/UtilsTest.java
@@ -35,4 +35,13 @@
assertEquals("55555\t&\t&\tKON\txxx", Utils.fixEscapedConlluEntities("55555\t&\t--\tNN\txxx"));
}
+ @Test
+ void isPunctuation() {
+ // A tag starting with '$' marks punctuation.
+ assert(Utils.isPunctuation(",", ",", "$,"));
+ // Quote tokens count as punctuation whatever their tag is.
+ assert(Utils.isPunctuation("\"", "\"", "NN"));
+ // An ordinary word with a non-'$' tag must not be reported.
+ assert(!Utils.isPunctuation("Haus", "Haus", "NN"));
+ }
}
\ No newline at end of file
diff --git a/src/test/java/org/ids_mannheim/WorkerTest.java b/src/test/java/org/ids_mannheim/WorkerTest.java
index b72919f..f1ec290 100644
--- a/src/test/java/org/ids_mannheim/WorkerTest.java
+++ b/src/test/java/org/ids_mannheim/WorkerTest.java
@@ -73,7 +73,7 @@
false,
new WorkerNodePool(""),
new Progressbar(tempFile.length()),
- Logger.getLogger(TotalNGrams.class.getSimpleName()), false);
+ Logger.getLogger(TotalNGrams.class.getSimpleName()), false, false);
queue.add(0);
queue.add(-1);
@@ -87,6 +87,66 @@
}
@Test
+ void resultAndOutputAreCorrectWithExcludedPunctuation() throws IOException {
+ // With excludePunctuation enabled, the punctuation bigrams listed with count 0 must be
+ // absent from the map, while the ordinary bigrams keep their expected counts.
+ Map<String, Integer> gold = Map.of(
+ "und und KON Fluchen Fluchen NN", 1,
+ "Bürger Bürger NN sich sich PRF", 1,
+ "dieses dies PDAT würdigen würdig ADJA", 1,
+ "im in APPRART Kriegshandwerk Kriegshandwerk NN", 1,
+ "man man PIS nur nur ADV", 1,
+ ",\t,\t$,\taber\taber\tKON", 0,
+ ",\t,\t$,\tan\tan\tAPP", 0,
+ ",\t,\t$,\tdas\tdie\tPRE", 0,
+ ",\t,\t$,\tgab\tgeben\tVVF", 0,
+ ".\t.\t$.\tauf\tauf\tAPP", 0
+ );
+
+ File tempFile = File.createTempFile("goe_sample", ".conllu.gz");
+ tempFile.deleteOnExit();
+ try (FileOutputStream out = new FileOutputStream(tempFile)) {
+ IOUtils.copy(Objects.requireNonNull(Thread.currentThread().getContextClassLoader()
+ .getResourceAsStream("goe_sample.conllu.gz")), out);
+ }
+ ArrayList<String> fnames = new ArrayList<>();
+ fnames.add(tempFile.getAbsolutePath());
+ map = new ConcurrentHashMap<>();
+ LinkedBlockingQueue<Integer> queue = new LinkedBlockingQueue<>(2);
+ worker = new Worker(
+ queue,
+ fnames,
+ 2,
+ 8,
+ 10,
+ map,
+ true,
+ false,
+ new WorkerNodePool(""),
+ new Progressbar(tempFile.length()),
+ Logger.getLogger(TotalNGrams.class.getSimpleName()), false, true);
+
+ queue.add(0);
+ queue.add(-1);
+ worker.run();
+ gold.forEach((key, value) -> {
+ if (value == 0) {
+ assertNull(map.get(key));
+ } else {
+ AtomicInteger observed = map.get(key);
+ assertEquals(value, observed.intValue());
+ }
+ });
+
+ // '$' has to be escaped: unescaped it is the end-of-input anchor, so the pattern could
+ // never match and the check was vacuous. The optional tail also covers a tag ending the key.
+ map.forEach((k,v) -> {
+ assert(!k.matches(".*\t\\$.(\t.*)?"));
+ });
+ assertTrue(errContent.toString().contains("100%"));
+ }
+
+ @Test
void downcasedResultAndOutputAreCorrect() throws IOException {
Map<String, Integer> gold = Map.of(
"und und KON fluchen Fluchen NN", 1,
@@ -123,7 +179,7 @@
new WorkerNodePool(""),
new Progressbar(tempFile.length()),
Logger.getLogger(TotalNGrams.class.getSimpleName()),
- false);
+ false, false);
queue.add(0);
queue.add(-1);
@@ -169,7 +225,7 @@
new WorkerNodePool(""),
new Progressbar(tempFile.length()),
Logger.getLogger(TotalNGrams.class.getSimpleName()),
- false);
+ false, false);
queue.add(0);
queue.add(-1);
@@ -215,7 +271,7 @@
new WorkerNodePool(""),
new Progressbar(tempFile.length()),
Logger.getLogger(TotalNGrams.class.getSimpleName()),
- false);
+ false, false);
queue.add(0);
queue.add(-1);
@@ -266,7 +322,7 @@
new WorkerNodePool(""),
new Progressbar(tempFile.length()),
Logger.getLogger(TotalNGrams.class.getSimpleName()),
- with_padding);
+ with_padding, false);
queue.add(0);
queue.add(-1);