Add exclude punctuation option

Change-Id: Ie90a59f77a92b8007af92411bcbaf00a8c910722
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2c443e2..70a7c18 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,7 @@
 # Changelog
 
+- added option `--exclude-punctuation`
+
 ## [2.1.0] - 2022-12-01
 - added script `GeneratePseudonymKey.groovy` to compute pseudonyms
 - added script `Pseudonymize.groovy` to pseudonymize tokens (and lemmas)
diff --git a/Readme.md b/Readme.md
index 1afddce..dc74c12 100644
--- a/Readme.md
+++ b/Readme.md
@@ -16,6 +16,10 @@
       <inputFiles>...    input files
   -d, --downcase         Convert all token characters into lower case (default:
                            false)
+      --exclude-punctuation
+                         Ignore all tokens tagged as punctuation (according to
+                           STTS tags set, i.e. starting with '$') (default:
+                           false)
   -f, --fold=<fold>      current fold (default: 1)
   -F, --folds=<FOLDS>    number of random folds (default: 1)
       --force            Force overwrite (default: false)
diff --git a/src/main/java/org/ids_mannheim/TotalNGrams.java b/src/main/java/org/ids_mannheim/TotalNGrams.java
index bbbee83..4efea34 100644
--- a/src/main/java/org/ids_mannheim/TotalNGrams.java
+++ b/src/main/java/org/ids_mannheim/TotalNGrams.java
@@ -98,6 +98,11 @@
                     + PaddedSlidingWindowQueue.TEXT_END_SYMBOL + " symbols at text edges  (default: ${DEFAULT-VALUE})")
     boolean addPadding = false;
 
+    @SuppressWarnings("CanBeFinal")
+    @CommandLine.Option(names = {
+            "--exclude-punctuation" }, description = "Ignore all tokens tagged as punctuation (according to STTS tags set, i.e. starting with '$') (default: ${DEFAULT-VALUE})")
+    boolean excludePunctuation = false;
+
     private Progressbar etaPrinter;
 
     public TotalNGrams() {
@@ -174,7 +179,7 @@
         logger.info("Processing fold " + fold + "/" + FOLDS);
         logger.info("Using " + threads + " threads");
         IntStream.range(0, threads).forEach(unused -> es.execute(new Worker(queue, inputFiles, ngram_size, fold, FOLDS,
-                map, with_lemma_and_pos, downcase_tokens, workerNodePool, etaPrinter, logger, addPadding)));
+                map, with_lemma_and_pos, downcase_tokens, workerNodePool, etaPrinter, logger, addPadding, excludePunctuation)));
         queue.addAll(IntStream.range(0, inputFiles.size()).boxed().collect(Collectors.toList()));
         IntStream.range(0, threads).forEach(unused -> {
             try {
diff --git a/src/main/java/org/ids_mannheim/Utils.java b/src/main/java/org/ids_mannheim/Utils.java
index e8413b5..8f4e801 100644
--- a/src/main/java/org/ids_mannheim/Utils.java
+++ b/src/main/java/org/ids_mannheim/Utils.java
@@ -55,5 +55,8 @@
                 .replaceAll("^(\\d+)\t&gt;\t--\t[^\t]+", "$1\t>\t&gt;\t\\$(");
     }
 
+    public static boolean isPunctuation(String token, String lemma, String pos) { // lemma is accepted but not consulted
+        return pos.startsWith("$") || token.equals("\"") || token.equals("'") || token.equals("<") || token.equals(">"); // '$'-prefixed POS tag, or one of the tokens that Worker re-tags as "$("
+    }
 }
 
diff --git a/src/main/java/org/ids_mannheim/Worker.java b/src/main/java/org/ids_mannheim/Worker.java
index d571dc6..e5b1f9b 100644
--- a/src/main/java/org/ids_mannheim/Worker.java
+++ b/src/main/java/org/ids_mannheim/Worker.java
@@ -32,11 +32,12 @@
     private final boolean addPadding;
 
     private final DeterministicRandomProvider deterministicRandomProvider;
+    private final boolean excludePunctuation;
 
     public Worker(BlockingQueue<Integer> queue, ArrayList<String> fnames, int ngram_size, int target_fold, int folds,
                   ConcurrentHashMap<String, AtomicInteger> map,
                   boolean with_lemma_and_pos, boolean downcase_tokens, WorkerNodePool pool,
-                  Progressbar etaPrinter, Logger logger, boolean addPadding) {
+                  Progressbar etaPrinter, Logger logger, boolean addPadding, boolean excludePunctuation) {
         this.queue = queue;
         this.fnames = fnames;
         this.map = map;
@@ -49,6 +50,7 @@
         this.deterministicRandomProvider = new DeterministicRandomProvider(folds);
         this.downcase_tokens = downcase_tokens;
         this.addPadding = addPadding;
+        this.excludePunctuation = excludePunctuation;
     }
 
     @Override
@@ -134,31 +136,33 @@
                         String token = ( downcase_tokens?
                                 Utils.unEscapeEntities(strings[1]).toLowerCase(Locale.ROOT) :
                                 Utils.unEscapeEntities(strings[1]));
-                        if (with_lemma_and_pos) {
-                            String lemma, pos;
+                        if(!excludePunctuation || !Utils.isPunctuation(token, strings[2], strings[3])) {
+                            if (with_lemma_and_pos) {
+                                String lemma, pos;
 
-                            if (token.equals("\"") || token.equals("'")) {
-                                lemma = "\"";
-                            } else if (token.equals("&")) {
-                                lemma = "&amp;";
-                            } else if (token.equals("<")) {
-                                lemma = "&lt;";
-                            } else if (token.equals(">")) {
-                                lemma = "&gt;";
+                                if (token.equals("\"") || token.equals("'")) {
+                                    lemma = "\"";
+                                } else if (token.equals("&")) {
+                                    lemma = "&amp;";
+                                } else if (token.equals("<")) {
+                                    lemma = "&lt;";
+                                } else if (token.equals(">")) {
+                                    lemma = "&gt;";
+                                } else {
+                                    lemma = strings[2];
+                                }
+                                if (token.equals("\"") || token.equals("'") || token.equals("<") || token.equals(">")) {
+                                    pos = "$(";
+                                } else if (token.equals("&")) {
+                                    pos = "KON";
+                                } else {
+                                    pos = strings[3];
+                                }
+                                //noinspection ConstantCondition
+                                slidingWindowQueue.add(join("\t", token, lemma, pos));
                             } else {
-                                lemma = strings[2];
+                                slidingWindowQueue.add(token);
                             }
-                            if (token.equals("\"") || token.equals("'") || token.equals("<") || token.equals(">")) {
-                                pos = "$(";
-                            } else if (token.equals("&")) {
-                                pos = "KON";
-                            } else {
-                                pos = strings[3];
-                            }
-                            //noinspection ConstantCondition
-                            slidingWindowQueue.add(join("\t", token, lemma, pos));
-                        } else {
-                            slidingWindowQueue.add(token);
                         }
                     }
                 }
diff --git a/src/test/java/org/ids_mannheim/UtilsTest.java b/src/test/java/org/ids_mannheim/UtilsTest.java
index 65dbca7..3d19ef9 100644
--- a/src/test/java/org/ids_mannheim/UtilsTest.java
+++ b/src/test/java/org/ids_mannheim/UtilsTest.java
@@ -35,4 +35,9 @@
         assertEquals("55555\t&\t&amp;\tKON\txxx", Utils.fixEscapedConlluEntities("55555\t&amp;\t--\tNN\txxx"));
     }
 
+    @Test
+    void isPunctuation() {
+        assert(Utils.isPunctuation(",", ",", "$,")); // recognised through the '$' POS-tag prefix
+        assert(Utils.isPunctuation("\"", "\"", "NN") && !Utils.isPunctuation("und", "und", "KON")); // quote token matches despite its tag; an ordinary word does not
+    }
 }
\ No newline at end of file
diff --git a/src/test/java/org/ids_mannheim/WorkerTest.java b/src/test/java/org/ids_mannheim/WorkerTest.java
index b72919f..f1ec290 100644
--- a/src/test/java/org/ids_mannheim/WorkerTest.java
+++ b/src/test/java/org/ids_mannheim/WorkerTest.java
@@ -73,7 +73,7 @@
                 false,
                 new WorkerNodePool(""),
                 new Progressbar(tempFile.length()),
-                Logger.getLogger(TotalNGrams.class.getSimpleName()), false);
+                Logger.getLogger(TotalNGrams.class.getSimpleName()), false, false);
 
         queue.add(0);
         queue.add(-1);
@@ -87,6 +87,62 @@
     }
 
     @Test
+    void resultAndOutputAreCorrectWithExcludedPunctuation() throws IOException {
+        Map<String, Integer> gold = Map.of(
+                "und	und	KON	Fluchen	Fluchen	NN", 1,
+                "Bürger	Bürger	NN	sich	sich	PRF", 1,
+                "dieses	dies	PDAT	würdigen	würdig	ADJA", 1,
+                "im	in	APPRART	Kriegshandwerk	Kriegshandwerk	NN", 1,
+                "man	man	PIS	nur	nur	ADV", 1,
+                ",\t,\t$,\taber\taber\tKON", 0,
+                ",\t,\t$,\tan\tan\tAPP", 0,
+                ",\t,\t$,\tdas\tdie\tPRE", 0,
+                ",\t,\t$,\tgab\tgeben\tVVF", 0,
+                ".\t.\t$.\tauf\tauf\tAPP", 0
+                );
+
+        File tempFile = File.createTempFile("goe_sample", ".conllu.gz");
+        tempFile.deleteOnExit();
+        try (FileOutputStream out = new FileOutputStream(tempFile)) {
+            IOUtils.copy(Objects.requireNonNull(Thread.currentThread().getContextClassLoader()
+                    .getResourceAsStream("goe_sample.conllu.gz")), out);
+        }
+        ArrayList<String> fnames = new ArrayList<>();
+        fnames.add(tempFile.getAbsolutePath());
+        map = new ConcurrentHashMap<>();
+        LinkedBlockingQueue<Integer> queue = new LinkedBlockingQueue<>(2);
+        worker = new Worker(
+                queue,
+                fnames,
+                2,
+                8,
+                10,
+                map,
+                true,
+                false,
+                new WorkerNodePool(""),
+                new Progressbar(tempFile.length()),
+                Logger.getLogger(TotalNGrams.class.getSimpleName()), false, true);
+
+        queue.add(0);
+        queue.add(-1);
+        worker.run();
+        gold.forEach((key, value) -> {
+            if (value == 0) {
+                assertNull(map.get(key));
+            } else {
+                AtomicInteger observed = map.get(key);
+                assertEquals(value, observed.intValue());
+            }
+        });
+
+        map.forEach((k, v) -> { // no surviving n-gram key may contain a field starting with '$' (punctuation tag)
+            assert !k.matches(".*\t\\$[^\t]*(\t.*)?") : "punctuation n-gram not excluded: " + k; // '$' must be escaped: bare, it is the end anchor and the pattern can never match
+        });
+        assertTrue(errContent.toString().contains("100%"));
+    }
+
+    @Test
     void downcasedResultAndOutputAreCorrect() throws IOException {
         Map<String, Integer> gold = Map.of(
                 "und	und	KON	fluchen	Fluchen	NN", 1,
@@ -123,7 +179,7 @@
                 new WorkerNodePool(""),
                 new Progressbar(tempFile.length()),
                 Logger.getLogger(TotalNGrams.class.getSimpleName()),
-                false);
+                false, false);
 
         queue.add(0);
         queue.add(-1);
@@ -169,7 +225,7 @@
                 new WorkerNodePool(""),
                 new Progressbar(tempFile.length()),
                 Logger.getLogger(TotalNGrams.class.getSimpleName()),
-                false);
+                false, false);
 
         queue.add(0);
         queue.add(-1);
@@ -215,7 +271,7 @@
                 new WorkerNodePool(""),
                 new Progressbar(tempFile.length()),
                 Logger.getLogger(TotalNGrams.class.getSimpleName()),
-                false);
+                false, false);
 
         queue.add(0);
         queue.add(-1);
@@ -266,7 +322,7 @@
                             new WorkerNodePool(""),
                             new Progressbar(tempFile.length()),
                             Logger.getLogger(TotalNGrams.class.getSimpleName()),
-                            with_padding);
+                            with_padding, false);
 
                     queue.add(0);
                     queue.add(-1);