Add --exclude-punctuation option to skip punctuation tokens

Change-Id: Ie90a59f77a92b8007af92411bcbaf00a8c910722
diff --git a/src/main/java/org/ids_mannheim/TotalNGrams.java b/src/main/java/org/ids_mannheim/TotalNGrams.java
index bbbee83..4efea34 100644
--- a/src/main/java/org/ids_mannheim/TotalNGrams.java
+++ b/src/main/java/org/ids_mannheim/TotalNGrams.java
@@ -98,6 +98,11 @@
                     + PaddedSlidingWindowQueue.TEXT_END_SYMBOL + " symbols at text edges  (default: ${DEFAULT-VALUE})")
     boolean addPadding = false;
 
+    @SuppressWarnings("CanBeFinal")
+    @CommandLine.Option(names = {
+            "--exclude-punctuation" }, description = "Ignore all tokens tagged as punctuation (STTS POS tags starting with '$') as well as the tokens \", ', < and > (default: ${DEFAULT-VALUE})")
+    boolean excludePunctuation = false; // passed to every Worker, which then skips tokens matched by Utils.isPunctuation
+
     private Progressbar etaPrinter;
 
     public TotalNGrams() {
@@ -174,7 +179,7 @@
         logger.info("Processing fold " + fold + "/" + FOLDS);
         logger.info("Using " + threads + " threads");
         IntStream.range(0, threads).forEach(unused -> es.execute(new Worker(queue, inputFiles, ngram_size, fold, FOLDS,
-                map, with_lemma_and_pos, downcase_tokens, workerNodePool, etaPrinter, logger, addPadding)));
+                map, with_lemma_and_pos, downcase_tokens, workerNodePool, etaPrinter, logger, addPadding, excludePunctuation))); // each Worker applies the --exclude-punctuation filter itself
         queue.addAll(IntStream.range(0, inputFiles.size()).boxed().collect(Collectors.toList()));
         IntStream.range(0, threads).forEach(unused -> {
             try {
diff --git a/src/main/java/org/ids_mannheim/Utils.java b/src/main/java/org/ids_mannheim/Utils.java
index e8413b5..8f4e801 100644
--- a/src/main/java/org/ids_mannheim/Utils.java
+++ b/src/main/java/org/ids_mannheim/Utils.java
@@ -55,5 +55,8 @@
                 .replaceAll("^(\\d+)\t>\t--\t[^\t]+", "$1\t>\t>\t\\$(");
     }
 
+    public static boolean isPunctuation(String token, String lemma, String pos) { // lemma is currently not consulted
+        return pos.startsWith("$") || (token.length() == 1 && "\"'<>".indexOf(token.charAt(0)) >= 0); // '$'-prefixed STTS punctuation tag, or a lone ", ', < or > token
+    }
 }
 
diff --git a/src/main/java/org/ids_mannheim/Worker.java b/src/main/java/org/ids_mannheim/Worker.java
index d571dc6..e5b1f9b 100644
--- a/src/main/java/org/ids_mannheim/Worker.java
+++ b/src/main/java/org/ids_mannheim/Worker.java
@@ -32,11 +32,12 @@
     private final boolean addPadding;
 
     private final DeterministicRandomProvider deterministicRandomProvider;
+    private final boolean excludePunctuation; // when true, tokens for which Utils.isPunctuation(...) returns true are not added to slidingWindowQueue
 
     public Worker(BlockingQueue<Integer> queue, ArrayList<String> fnames, int ngram_size, int target_fold, int folds,
                   ConcurrentHashMap<String, AtomicInteger> map,
                   boolean with_lemma_and_pos, boolean downcase_tokens, WorkerNodePool pool,
-                  Progressbar etaPrinter, Logger logger, boolean addPadding) {
+                  Progressbar etaPrinter, Logger logger, boolean addPadding, boolean excludePunctuation) {
         this.queue = queue;
         this.fnames = fnames;
         this.map = map;
@@ -49,6 +50,7 @@
         this.deterministicRandomProvider = new DeterministicRandomProvider(folds);
         this.downcase_tokens = downcase_tokens;
         this.addPadding = addPadding;
+        this.excludePunctuation = excludePunctuation;
     }
 
     @Override
@@ -134,31 +136,33 @@
                         String token = ( downcase_tokens?
                                 Utils.unEscapeEntities(strings[1]).toLowerCase(Locale.ROOT) :
                                 Utils.unEscapeEntities(strings[1]));
-                        if (with_lemma_and_pos) {
-                            String lemma, pos;
+                        if (!excludePunctuation || !Utils.isPunctuation(token, strings[2], strings[3])) { // NOTE(review): strings[2]/strings[3] are read here even when with_lemma_and_pos is false; assumes every token line has >= 4 columns - confirm input format
+                            if (with_lemma_and_pos) {
+                                String lemma, pos;
 
-                            if (token.equals("\"") || token.equals("'")) {
-                                lemma = "\"";
-                            } else if (token.equals("&")) {
-                                lemma = "&amp;";
-                            } else if (token.equals("<")) {
-                                lemma = "&lt;";
-                            } else if (token.equals(">")) {
-                                lemma = "&gt;";
+                                if (token.equals("\"") || token.equals("'")) { // fixed lemmas for quote, ampersand and angle-bracket tokens
+                                    lemma = "\"";
+                                } else if (token.equals("&")) {
+                                    lemma = "&amp;";
+                                } else if (token.equals("<")) {
+                                    lemma = "&lt;";
+                                } else if (token.equals(">")) {
+                                    lemma = "&gt;";
+                                } else {
+                                    lemma = strings[2];
+                                }
+                                if (token.equals("\"") || token.equals("'") || token.equals("<") || token.equals(">")) { // quotes and angle brackets are tagged "$(", "&" is tagged "KON"
+                                    pos = "$(";
+                                } else if (token.equals("&")) {
+                                    pos = "KON";
+                                } else {
+                                    pos = strings[3];
+                                }
+                                //noinspection ConstantCondition
+                                slidingWindowQueue.add(join("\t", token, lemma, pos));
                             } else {
-                                lemma = strings[2];
+                                slidingWindowQueue.add(token);
                             }
-                            if (token.equals("\"") || token.equals("'") || token.equals("<") || token.equals(">")) {
-                                pos = "$(";
-                            } else if (token.equals("&")) {
-                                pos = "KON";
-                            } else {
-                                pos = strings[3];
-                            }
-                            //noinspection ConstantCondition
-                            slidingWindowQueue.add(join("\t", token, lemma, pos));
-                        } else {
-                            slidingWindowQueue.add(token);
                         }
                     }
                 }