import my_utils.file_utils as fu
from lib.CoNLL_Annotation import read_conll, CoNLL09_Token, TigerNew_Token 
from collections import Counter

ORIGINAL_TIGER = "/home/daza/datasets/TIGER_conll/TIGER_original_data/tiger_release_aug07.corrected.16012013.conll09"
NEW_ORTH = "/vol/work/kupietz/Tiger_2_2/data/german/tiger/train/german_tiger_new_orthography.csv"


def get_confident_mapping(common_counts):
	"""Pick, for every old token, the new token it was most often changed to.

	Args:
		common_counts: iterable of ``((old, new), count)`` entries, e.g. the
			output of ``Counter(...).most_common()``.

	Returns:
		dict mapping each ``old`` token to the ``new`` token with the highest
		count. On a tie the entry seen first wins, because a later entry only
		replaces the current choice when its count is strictly larger.
	"""
	best = {}
	for pair, freq in common_counts:
		source, target = pair
		current = best.get(source)
		# Stored values are (target, count) tuples, never None, so None means "unseen".
		if current is None or freq > current[1]:
			best[source] = (target, freq)
	return dict((source, choice) for source, (choice, _) in best.items())
		

def check_orthography(s_old,s_new):
	"""Record word-level orthography changes between two aligned sentences.

	Words are compared position by position. A pair counts as an orthography
	change when the two words differ but their first characters match
	case-insensitively. NOTE(review): this first-character test presumably
	filters out misaligned tokens; confirm intent.

	Side effects (module globals, which must already exist):
		- ``total_tokens`` is increased by the number of words in ``s_old``.
		- every changed ``(old_word, new_word)`` pair is appended to
		  ``token_changes``.

	Args:
		s_old: sentence object (old orthography) exposing ``get_words()``.
		s_new: sentence object (new orthography) exposing ``get_words()``.

	Returns:
		True if no word pair was recorded as a change, False otherwise.
	"""
	global total_tokens
	old_words = s_old.get_words()
	new_words = s_new.get_words()
	assert len(old_words) == len(new_words)
	total_tokens += len(old_words)
	unchanged = True
	for old_word, new_word in zip(old_words, new_words):
		if old_word == new_word:
			continue
		if old_word[0].lower() != new_word[0].lower():
			continue
		token_changes.append((old_word, new_word))
		unchanged = False
	return unchanged


if __name__ == "__main__":
	# Load the original TIGER corpus (old orthography) and its new-orthography
	# version. NOTE(review): NEW_ORTH lives under a "train" directory, so it is
	# presumably only a prefix/subset of the original corpus — confirm.
	line_generator = fu.file_generator(ORIGINAL_TIGER)
	original_sents, _ = read_conll(line_generator, chunk_size=60000, token_class=CoNLL09_Token, comment_str="#")

	line_generator = fu.file_generator(NEW_ORTH)
	new_orth_sents, _ = read_conll(line_generator, chunk_size=60000, token_class=TigerNew_Token, comment_str="#")

	# Greedy in-order alignment: each original sentence is paired with the next
	# unmatched new-orthography sentence when both have the same word count;
	# otherwise the original sentence goes to the test split.
	# total_tokens and token_changes must remain module-level names because
	# check_orthography() updates them through the module globals.
	new_ix = 0
	train_tiger, test_tiger = [], []
	problematic_sents, token_changes = [], []
	total_tokens = 0
	for s1 in original_sents:
		if new_ix >= len(new_orth_sents):
			# Every new-orthography sentence is already paired; indexing further
			# would raise IndexError, so the remaining originals are test data.
			test_tiger.append(s1)
			continue
		s2 = new_orth_sents[new_ix]
		print(f"--- {new_ix} ---\n{s1.get_sentence()}\n{s2.get_sentence()}\n\n")
		if len(s1.get_words()) == len(s2.get_words()):
			train_tiger.append((s1,s2))
			identical_sents = check_orthography(s1,s2)
			if not identical_sents:
				# new_ix indexes both new_orth_sents and train_tiger at this point.
				problematic_sents.append(new_ix)
			new_ix += 1
		else:
			test_tiger.append(s1)

	if new_ix < len(new_orth_sents):
		# Surface silent misalignment instead of dropping sentences unnoticed.
		print(f"WARNING: {len(new_orth_sents) - new_ix} new-orthography sentences were never paired.")

	# Print Stats (percentages fall back to 0.0 instead of dividing by zero)
	n_train = len(train_tiger)
	sent_pct = len(problematic_sents)*100/n_train if n_train else 0.0
	tok_pct = len(token_changes)*100/total_tokens if total_tokens else 0.0
	print(n_train)
	print(len(test_tiger))
	print(len(new_orth_sents))
	print(f"{len(problematic_sents)}/{n_train} ({sent_pct}%) of sentences have change of orthography.")
	print(f"{len(token_changes)}/{total_tokens} ({tok_pct}%) of tokens have change of orthography.")

	# Save Files
	save_path = "/home/daza/datasets/TIGER_conll"
	# (old_word, new_word) change frequencies, most frequent first.
	new_cases = Counter(token_changes).most_common()
	case_mapping = get_confident_mapping(new_cases)
	# Change statistics, the old->new word mapping, and indices of changed sentences
	fu.counter_to_file(new_cases, f"{save_path}/TigerTokensChangeOrth.train.tsv")
	fu.dict_to_file(case_mapping, f"{save_path}/TigerOrthMapping.train.json")
	fu.list_to_file(problematic_sents, f"{save_path}/NewOrthProblems_Indices.train.txt")
	# Train/Test Splits. zip(*[]) yields nothing, so unpacking it into two
	# names would fail; fall back to empty tuples for an empty training set.
	old_train, new_train = zip(*train_tiger) if train_tiger else ((), ())
	fu.write_conll_file(old_train, out_path=f"{save_path}/Tiger.OldOrth.train.conll")
	fu.write_conll_file(new_train, out_path=f"{save_path}/Tiger.NewOrth.train.conll")
	fu.write_conll_file(test_tiger, out_path=f"{save_path}/Tiger.OldOrth.test.conll")
	
		
