/*
 * Decompiled with CFR 0.152.
 */
package marmot.test.util.edit;

import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.logging.Logger;
import lemming.lemma.LemmaInstance;
import lemming.lemma.toutanova.Aligner;
import lemming.lemma.toutanova.EditTreeAligner;
import marmot.morph.io.SentenceReader;
import marmot.util.Numerics;
import marmot.util.edit.EditTree;
import marmot.util.edit.EditTreeBuilder;
import marmot.util.edit.EditTreeBuilderTrainer;
import org.junit.Assert;
import org.junit.Test;

public class EditTreeBuilderTrainerTest {
    /** Column layout prefix for the tab-separated test resources (form, lemma and tag column indexes). */
    private static final String INDEXES = "form-index=4,lemma-index=5,tag-index=2,";

    /**
     * Trains an edit-tree builder on {@code trn_mod.tsv} and checks that an
     * {@link EditTreeAligner} backed by it splits three German past-participle/lemma
     * pairs into the expected input/output segments.
     */
    @Test
    public void test() {
        EditTreeBuilder builder = trainBuilder(readInstances("trn_mod.tsv"));
        EditTreeAligner aligner = new EditTreeAligner(builder, true);
        testAligner(aligner, "umgezogen", "umziehen",
                Arrays.asList("u", "m", "ge", "z", "og", "e", "n"),
                Arrays.asList("u", "m", "", "z", "ieh", "e", "n"));
        testAligner(aligner, "gebissen", "bei\u00dfen",
                Arrays.asList("ge", "b", "i", "ss", "e", "n"),
                Arrays.asList("", "be", "i", "\u00df", "e", "n"));
        testAligner(aligner, "gebogen", "biegen",
                Arrays.asList("ge", "b", "o", "g", "e", "n"),
                Arrays.asList("", "b", "ie", "g", "e", "n"));
    }

    /**
     * Checks tree equality/hashing for analogous edits, that every training tree
     * reproduces the lemma it was built from, and that the set of training trees
     * covers the training data fully and the dev data up to a bounded miss rate.
     */
    @Test
    public void testApply() {
        List<LemmaInstance> instances = readInstances("trn_mod.tsv");
        EditTreeBuilder builder = trainBuilder(instances);
        testHashAndEquals(builder, "loves", "love", "hates", "hate", true);
        testHashAndEquals(builder, "lachen", "gelacht", "machen", "gemacht", true);
        testHashAndEquals(builder, "lachen", "gelacht", "aaaaaaaaen", "geaaaaaaaat", true);

        // Group training instances by the tree that maps their form to their lemma;
        // each tree must map its own form back to exactly that lemma.
        Map<EditTree, List<LemmaInstance>> instancesByTree = new HashMap<>();
        for (LemmaInstance instance : instances) {
            String form = instance.getForm();
            String lemma = instance.getLemma();
            EditTree tree = builder.build(form, lemma);
            Assert.assertEquals(lemma, tree.apply(form, 0, form.length()));
            instancesByTree.computeIfAbsent(tree, k -> new LinkedList<>()).add(instance);
        }

        applyTest(instancesByTree, instances, false, 0.0);
        // NOTE(review): unlike the training data, the dev set is read through the String
        // overload of getInstances; confirm both overloads parse the file identically.
        applyTest(instancesByTree, LemmaInstance.getInstances(INDEXES + getResourceFile("dev.tsv")), false, 0.02526);
    }

    /** Reads lemma instances from the named test resource using the shared column layout. */
    private List<LemmaInstance> readInstances(String resourceName) {
        return LemmaInstance.getInstances(new SentenceReader(INDEXES + getResourceFile(resourceName)));
    }

    /** Trains an edit-tree builder on {@code instances} using a fixed random seed (42). */
    private static EditTreeBuilder trainBuilder(List<LemmaInstance> instances) {
        EditTreeBuilderTrainer trainer = new EditTreeBuilderTrainer(new Random(42L), 1, -1);
        return trainer.train(instances);
    }

    /**
     * Applies every known edit tree to each instance's form and asserts that the fraction
     * of instances whose gold lemma is not among the produced candidates stays (approximately)
     * at or below {@code expectedMissRate}.
     *
     * @param trees            edit trees to try; only the keys are used
     * @param instances        instances to evaluate; must not be empty
     * @param logMissedOutputs whether to log every instance whose lemma was not produced
     * @param expectedMissRate upper bound for the miss rate, compared via
     *                         {@code Numerics.approximatelyLesserEqual}
     */
    private void applyTest(Map<EditTree, List<LemmaInstance>> trees, List<LemmaInstance> instances,
            boolean logMissedOutputs, double expectedMissRate) {
        // An empty list would make the miss rate 0/0 = NaN; fail with a clear message instead.
        Assert.assertFalse("no instances to evaluate", instances.isEmpty());
        Logger logger = Logger.getLogger(getClass().getName());
        int missedOutputs = 0;
        for (LemmaInstance instance : instances) {
            String form = instance.getForm();
            String lemma = instance.getLemma();

            // A tree returns null when it cannot be applied to this form.
            HashSet<String> candidates = new HashSet<>();
            for (EditTree tree : trees.keySet()) {
                String candidate = tree.apply(form, 0, form.length());
                if (candidate != null) {
                    candidates.add(candidate);
                }
            }

            if (!candidates.contains(lemma)) {
                ++missedOutputs;
                if (logMissedOutputs) {
                    logger.info(String.format("Missed: %s", instance));
                }
            }
            // Some tree must return the form unchanged.
            // NOTE(review): presumably this checks that an identity tree was learned; confirm
            // it is intended rather than a check on the lemma (which is covered by the miss rate).
            Assert.assertTrue("no tree returns the form unchanged for " + instance, candidates.contains(form));
        }

        double missedRate = (double) missedOutputs / instances.size();
        logger.info(Double.toString(missedRate));
        Assert.assertTrue(
                String.format("miss rate %f exceeds expected %f", missedRate, expectedMissRate),
                Numerics.approximatelyLesserEqual(missedRate, expectedMissRate));
    }

    /**
     * Builds trees for two form/lemma pairs and asserts whether they are equal. When they
     * are expected to be equal, their hash codes must match too (the {@code hashCode}
     * contract); unequal trees may legitimately share a hash code, so no hash assertion
     * is made in that case.
     */
    private void testHashAndEquals(EditTreeBuilder builder, String inputA, String outputA,
            String inputB, String outputB, boolean expectedEqual) {
        EditTree treeA = builder.build(inputA, outputA);
        EditTree treeB = builder.build(inputB, outputB);
        Assert.assertEquals("equals() for " + inputA + "->" + outputA + " vs " + inputB + "->" + outputB,
                expectedEqual, treeA.equals(treeB));
        if (expectedEqual) {
            Assert.assertEquals("hashCode() of equal trees", treeA.hashCode(), treeB.hashCode());
        }
    }

    /**
     * Aligns {@code input} with {@code output} and asserts that the resulting segment pairs
     * match {@code input_segments} and {@code output_segments} position by position.
     */
    public void testAligner(EditTreeAligner aligner, String input, String output, List<String> input_segments, List<String> output_segments) {
        List<Integer> indexes = aligner.align(input, output);
        List<String> actualInputSegments = new LinkedList<>();
        List<String> actualOutputSegments = new LinkedList<>();
        for (Aligner.Pair pair : Aligner.Pair.toPairs(input, output, indexes)) {
            actualInputSegments.add(pair.getInputSegment());
            actualOutputSegments.add(pair.getOutputSegment());
        }
        Assert.assertEquals(input_segments, actualInputSegments);
        Assert.assertEquals(output_segments, actualOutputSegments);
    }

    /** Returns the {@code res:///} locator of a resource under {@code marmot/test/lemma}. */
    protected String getResourceFile(String name) {
        return String.format("res:///%s/%s", "marmot/test/lemma", name);
    }
}

