| daza | e3bc92e | 2020-11-04 11:06:26 +0100 | [diff] [blame] | 1 | from lib.CoNLL_Annotation import * | 
|  | 2 | from collections import Counter, defaultdict | 
|  | 3 | import pandas as pd | 
|  | 4 | import numpy as np | 
|  | 5 | from sklearn.metrics import precision_recall_fscore_support as eval_f1 | 
|  | 6 | from tabulate import tabulate | 
|  | 7 | import logging, argparse, sys | 
|  | 8 | from datetime import datetime | 
|  | 9 |  | 
|  | 10 |  | 
# Lemma substitutions for TreeTagger output.  The table is only referenced
# from the commented-out lookup in eval_lemma(), whose note says it is meant
# to omit TreeTagger "errors" caused by article lemma disagreement.
# Presumably keys are lemmas as emitted by TreeTagger and values are the
# forms expected by the gold corpus -- verify before re-enabling the lookup.
# (The original literal listed the key "alle" twice; the duplicate is gone.)
tree_tagger_fixes = {
    "die": "der",
    "eine": "ein",
    "dass": "daß",
    "keine": "kein",
    "dies": "dieser",
    "erst": "erster",
    "andere": "anderer",
    "alle": "aller",
    "Sie": "sie",
    "wir": "uns",
    "wenige": "wenig",
}
|  | 25 |  | 
|  | 26 |  | 
def save_evaluated(all_sys, all_gld, out_path, print_gold=True):
    """Write a tab-separated catalogue of label counts to *out_path*.

    Args:
        all_sys: Mapping ``label -> count`` for the system output
            (a ``Counter`` in the calling script).
        all_gld: Mapping ``label -> count`` for the gold annotations.
        out_path: Destination file; it is overwritten.
        print_gold: When true, first write every gold label (sorted) with its
            gold and system counts under ``ORIGINAL_CORPUS_TAGS``.

    The ``SYSTEM_ONLY_TAGS`` section is always written and lists the labels
    that the system produced but whose gold count is zero.
    """
    # Labels are German tags/lemmas, so pin the encoding instead of relying
    # on the platform default.
    with open(out_path, "w", encoding="utf-8") as out:
        if print_gold:
            out.write("ORIGINAL_CORPUS_TAGS\n\nTAG\tGLD_COUNT\tSYS_COUNT\n")
            for g_tag, g_count in sorted(all_gld.items()):
                s_count = all_sys.get(g_tag, 0)
                out.write(f"{g_tag}\t{g_count}\t{s_count}\n")

        out.write("\n\nSYSTEM_ONLY_TAGS\n\nTAG\tG_COUNT\tSYS_COUNT\n")
        for s_tag, s_count in sorted(all_sys.items()):
            g_count = all_gld.get(s_tag, 0)
            if g_count == 0:
                out.write(f"{s_tag}\t{g_count}\t{s_count}\n")
|  | 40 |  | 
|  | 41 |  | 
|  | 42 |  | 
def eval_lemma(sys, gld):
    """Compare system lemmas against gold lemmas for one sentence.

    Args:
        sys: System sentence; ``sys.tokens[i]`` must expose ``word`` and
            ``lemma``.
        gld: Gold sentence; each token must expose ``word`` and ``lemma``.
            The sentences are aligned by token index.

    Returns:
        A tuple ``(y_gld, y_pred, match, err, symbol, mistakes)``:
        gold lemmas, system lemmas, number of lemmas counted as correct,
        number counted as wrong, the non-alphanumeric system lemmas that
        disagreed with gold, and ``(gold_word, gold_lemma, sys_lemma)``
        triples for the alphanumeric disagreements.
    """
    match, err, symbol = 0, 0, []
    y_gld, y_pred, mistakes = [], [], []
    for i, gld_tok in enumerate(gld.tokens):
        sys_tok = sys.tokens[i]
        # sys_lemma = tree_tagger_fixes.get(sys_tok.lemma, sys_tok.lemma)  # Omit TreeTagger "errors" because of article lemma disagreement
        sys_lemma = sys_tok.lemma
        # Collect the gold *lemma* (the original appended the gold POS tag,
        # so the lemma catalogue compared lemmas against POS tags).
        y_gld.append(gld_tok.lemma)
        y_pred.append(sys_lemma)
        if gld_tok.lemma == sys_lemma:
            match += 1
        elif not sys_lemma.isalnum():
            # Turku does not lemmatize symbols (it only copies them)
            # => ERR ((',', '--', ','), 43642).  A copied symbol therefore
            # counts as correct even though gold uses a placeholder lemma.
            symbol.append(sys_lemma)
            if sys_tok.word == sys_lemma:
                match += 1
            else:
                err += 1
        else:
            err += 1
            mistakes.append((gld_tok.word, gld_tok.lemma, sys_lemma))
    return y_gld, y_pred, match, err, symbol, mistakes
|  | 63 |  | 
|  | 64 |  | 
def eval_pos(sys, gld):
    """Score the POS tags of one system sentence against the gold sentence.

    Tokens are aligned by index.  A system tag ``$`` where gold has ``$.``
    is accepted as correct and recorded as ``$.`` in the prediction list.

    Returns:
        ``(gold_tags, predicted_tags, n_correct, errors)`` where ``errors``
        holds ``(gold_word, gold_tag, system_tag)`` for every wrong token.
    """
    gold_tags, predicted_tags = [], []
    n_correct, errors = 0, []
    for idx, gold_token in enumerate(gld.tokens):
        expected = gold_token.pos_tag
        gold_tags.append(expected)
        guessed = sys.tokens[idx].pos_tag
        if expected == guessed:
            predicted_tags.append(guessed)
            n_correct += 1
            continue
        if expected == "$." and guessed == "$":
            # Sentence-final punctuation: normalise the system tag.
            predicted_tags.append("$.")
            n_correct += 1
            continue
        predicted_tags.append(guessed)
        errors.append((gold_token.word, expected, guessed))
    return gold_tags, predicted_tags, n_correct, errors
|  | 81 |  | 
|  | 82 |  | 
|  | 83 |  | 
if __name__ == "__main__":
    # EVALUATIONS:
    #
    # ********** TIGER CORPUS ALL ************
    #
    # python systems/evaluate.py -t Turku --corpus_name Tiger \
    #     --sys_file /home/daza/datasets/TIGER_conll/tiger_turku_parsed.conllu \
    #     --gld_file /home/daza/datasets/TIGER_conll/tiger_release_aug07.corrected.16012013.conll09
    #
    # python systems/evaluate.py -t SpaCy --corpus_name Tiger \
    #     --sys_file /home/daza/datasets/TIGER_conll/tiger_spacy_parsed.conllu \
    #     --gld_file /home/daza/datasets/TIGER_conll/tiger_release_aug07.corrected.16012013.conll09
    #
    # python systems/evaluate.py -t RNNTagger --corpus_name Tiger \
    #     --sys_file /home/daza/datasets/TIGER_conll/tiger_all.parsed.RNNTagger.conll \
    #     --gld_file /home/daza/datasets/TIGER_conll/tiger_release_aug07.corrected.16012013.conll09
    #
    # python systems/evaluate.py -t TreeTagger --corpus_name Tiger \
    #     --sys_file /home/daza/datasets/TIGER_conll/tiger_all.parsed.TreeTagger.conll \
    #     --gld_file /home/daza/datasets/TIGER_conll/tiger_release_aug07.corrected.16012013.conll09
    #
    # ********** UNIVERSAL DEPENDENCIES TEST-SET ************
    #
    # python systems/evaluate.py -t Turku --gld_token_type CoNLLUP_Token --corpus_name DE_GSD \
    #     --sys_file /home/daza/datasets/ud-treebanks-v2.2/UD_German-GSD/de_gsd-ud-test.conllu.parsed.0.conllu \
    #     --gld_file /home/daza/datasets/ud-treebanks-v2.2/UD_German-GSD/de_gsd-ud-test.conllu
    #
    # python systems/evaluate.py -t SpaCyGL --gld_token_type CoNLLUP_Token --corpus_name DE_GSD \
    #     --sys_file /home/daza/datasets/ud-treebanks-v2.2/UD_German-GSD/de_gsd-ud-test.parsed.germalemma.conllu \
    #     --gld_file /home/daza/datasets/ud-treebanks-v2.2/UD_German-GSD/de_gsd-ud-test.conllu
    #
    # python systems/evaluate.py -t SpaCy --gld_token_type CoNLLUP_Token --corpus_name DE_GSD \
    #     --sys_file /home/daza/datasets/ud-treebanks-v2.2/UD_German-GSD/de_gsd-ud-test.parsed.conllu \
    #     --gld_file /home/daza/datasets/ud-treebanks-v2.2/UD_German-GSD/de_gsd-ud-test.conllu
    #
    # python systems/evaluate.py -t RNNTagger --gld_token_type CoNLLUP_Token --corpus_name DE_GSD \
    #     --sys_file /home/daza/datasets/ud-treebanks-v2.2/UD_German-GSD/de_gsd-ud-test.RNNtagger.parsed.conll \
    #     --gld_file /home/daza/datasets/ud-treebanks-v2.2/UD_German-GSD/de_gsd-ud-test.conllu
    #
    # python systems/evaluate.py -t TreeTagger --gld_token_type CoNLLUP_Token --corpus_name DE_GSD \
    #     --sys_file /home/daza/datasets/ud-treebanks-v2.2/UD_German-GSD/de_gsd-ud-test.treetagger.parsed.conll \
    #     --gld_file /home/daza/datasets/ud-treebanks-v2.2/UD_German-GSD/de_gsd-ud-test.conllu

    # Only the script entry point needs these, so they are imported locally.
    import os
    from itertools import zip_longest

    # =====================================================================================
    #                    INPUT PARAMS
    # =====================================================================================
    parser = argparse.ArgumentParser()
    parser.add_argument("-s", "--sys_file", help="System output in CoNLL-U Format", required=True)
    parser.add_argument("-g", "--gld_file", help="Gold Labels to evaluate in CoNLL-U Format", required=True)
    parser.add_argument("-t", "--type_sys", help="Which system produced the outputs", default="system")
    parser.add_argument("-c", "--corpus_name", help="Corpus Name for Gold Labels", required=True)
    parser.add_argument("-gtt", "--gld_token_type", help="CoNLL Format of the Gold Data", default="CoNLL09_Token")
    parser.add_argument("-cs", "--comment_str", help="CoNLL Format of comentaries inside the file", default="#")
    args = parser.parse_args()

    # =====================================================================================
    #                    LOGGING INFO ...
    # =====================================================================================
    # The log file and the reports are written to relative directories; make
    # sure they exist so a fresh checkout does not fail before evaluating.
    os.makedirs("logs", exist_ok=True)
    os.makedirs("outputs", exist_ok=True)

    logger = logging.getLogger(__name__)
    console_hdlr = logging.StreamHandler(sys.stdout)
    file_hdlr = logging.FileHandler(filename=f"logs/Eval_{args.corpus_name}.{args.type_sys}.log")
    logging.basicConfig(level=logging.INFO, handlers=[console_hdlr, file_hdlr])
    now_is = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    logger.info(f"\n\nEvaluating {args.corpus_name} Corpus {now_is}")

    # Read the Original GOLD Annotations [CoNLL09, CoNLLUP]
    gld_generator = read_conll_generator(args.gld_file, token_class=get_token_type(args.gld_token_type), comment_str=args.comment_str)
    # Read the Annotations Generated by the Automatic Parser [Turku, SpaCy, RNNTagger]
    if args.type_sys == "RNNTagger":
        sys_generator = read_conll_generator(args.sys_file, token_class=RNNTagger_Token, comment_str="#")
    elif args.type_sys == "TreeTagger":
        # TreeTagger output is read with the RNNTagger token class but uses
        # "</S>" as the sentence separator.
        sys_generator = read_conll_generator(args.sys_file, token_class=RNNTagger_Token, sent_sep="</S>", comment_str="#")
    else:
        sys_generator = read_conll_generator(args.sys_file, token_class=CoNLLUP_Token, comment_str="#")

    lemma_all_match, lemma_all_err, lemma_all_mistakes = 0, 0, []
    lemma_all_symbols = []
    pos_all_match, pos_all_err, pos_all_mistakes = 0, 0, []
    pos_all_pred, pos_all_gld = [], []
    lemma_all_pred, lemma_all_gld = [], []
    n_sents = 0

    # zip_longest (instead of zip) lets us notice when one file holds fewer
    # sentences than the other; the common prefix is still evaluated.
    for i, (s, g) in enumerate(zip_longest(sys_generator, gld_generator)):
        if s is None or g is None:
            shorter = "system" if s is None else "gold"
            logger.warning(f"Sentence count mismatch: the {shorter} file ended after {n_sents} sentences; "
                           f"the remaining sentences are NOT evaluated")
            break
        # Explicit check rather than `assert`, which is stripped under -O.
        if len(s.tokens) != len(g.tokens):
            raise ValueError(f"Token Mismatch! S={len(s.tokens)} G={len(g.tokens)} IX={i+1}")
        n_sents += 1
        # Lemmas ...
        lemma_gld, lemma_pred, lemma_match, lemma_err, lemma_sym, mistakes = eval_lemma(s, g)
        lemma_all_match += lemma_match
        lemma_all_err += lemma_err
        lemma_all_mistakes += mistakes
        lemma_all_symbols += lemma_sym
        lemma_all_pred += lemma_pred
        lemma_all_gld += lemma_gld
        # POS Tags ...
        pos_gld, pos_pred, pos_match, pos_mistakes = eval_pos(s, g)
        pos_all_pred += pos_pred
        pos_all_gld += pos_gld
        pos_all_match += pos_match
        pos_all_err += len(pos_mistakes)
        pos_all_mistakes += pos_mistakes

    logger.info(f"A total of {n_sents} sentences were analyzed")

    # Without any token the accuracies below would divide by zero.
    lemma_total = lemma_all_match + lemma_all_err
    pos_total = pos_all_match + pos_all_err
    if lemma_total == 0 or pos_total == 0:
        logger.error("No tokens were evaluated; check --sys_file and --gld_file")
        sys.exit(1)

    # Lemmas ...
    logger.info(f"Lemma Matches = {lemma_all_match} || Errors = {lemma_all_err} || Symbol Chars = {len(lemma_all_symbols)}")
    logger.info(f"Lemma Accuracy = {(lemma_all_match*100/lemma_total):.2f}%\n")
    lemma_miss_df = pd.DataFrame(lemma_all_mistakes, columns=['Gold_Word', 'Gold_Lemma', 'Sys_Lemma']).value_counts()
    lemma_miss_df.to_csv(path_or_buf=f"outputs/LemmaErrors.{args.corpus_name}.{args.type_sys}.tsv", sep="\t")
    save_evaluated(Counter(lemma_all_pred), Counter(lemma_all_gld),
                   f"outputs/Lemma-Catalogue.{args.corpus_name}.{args.type_sys}.txt", print_gold=False)

    # POS Tags ...
    logger.info(f"POS Matches = {pos_all_match} || Errors = {pos_all_err}")
    logger.info(f"POS Tagging Accuracy = {(pos_all_match*100/pos_total):.2f}%\n")
    pos_miss_df = pd.DataFrame(pos_all_mistakes, columns=['Gold_Word', 'Gold_POS', 'Sys_POS']).value_counts()
    pos_miss_df.to_csv(path_or_buf=f"outputs/POS-Errors.{args.corpus_name}.{args.type_sys}.tsv", sep="\t")
    save_evaluated(Counter(pos_all_pred), Counter(pos_all_gld), f"outputs/POS-Catalogue.{args.corpus_name}.{args.type_sys}.txt")

    # Per-label scores over the gold tag set.  zero_division=0 matches the
    # macro call below and keeps never-predicted tags at 0 without warnings.
    ordered_labels = sorted(set(pos_all_gld))
    p_labels, r_labels, f_labels, _ = eval_f1(y_true=pos_all_gld, y_pred=pos_all_pred, labels=ordered_labels,
                                              average=None, zero_division=0)
    scores_per_label = zip(ordered_labels, [x*100 for x in p_labels], [x*100 for x in r_labels], [x*100 for x in f_labels])
    logger.info("\n\n")
    logger.info(tabulate(scores_per_label, headers=["POS Tag", "Precision", "Recall", "F1"], floatfmt=".2f"))
    # Macro average over all labels seen in gold or system output.
    p_labels, r_labels, f_labels, _ = eval_f1(y_true=np.array(pos_all_gld), y_pred=np.array(pos_all_pred), average='macro', zero_division=0)
    logger.info(f"Total Prec = {p_labels*100}\tRec = {r_labels*100}\tF1 = {f_labels*100}")
|  | 214 |  | 
|  | 215 |  |