/*
 * Decompiled with CFR 0.152.
 */
package lemming.test.lemma.toutanova;

import java.util.List;
import java.util.logging.Logger;
import junit.framework.AssertionFailedError;
import lemming.lemma.LemmaInstance;
import lemming.lemma.toutanova.Result;
import lemming.lemma.toutanova.ToutanovaInstance;
import lemming.lemma.toutanova.ToutanovaLemmatizer;
import lemming.lemma.toutanova.ToutanovaModel;
import lemming.lemma.toutanova.ToutanovaTrainer;
import lemming.lemma.toutanova.ZeroOrderDecoder;
import lemming.lemma.toutanova.ZeroOrderNbestDecoder;
import marmot.morph.io.SentenceReader;
import marmot.util.Numerics;
import org.junit.Assert;
import org.junit.Test;

/**
 * Checks that {@link ZeroOrderNbestDecoder} agrees with {@link ZeroOrderDecoder}
 * on a trained {@link ToutanovaModel}: the top n-best entry must match the
 * one-best result, every reported score must match the score the model assigns
 * to that result, and the n-best scores must not increase along the list.
 */
public class NbestDecoderTest {
    /** Absolute tolerance used for all score comparisons. */
    private static final double DELTA = 0.01;

    /**
     * Trains a lemmatizer on {@code trainfile} and runs the decoder consistency
     * checks on {@code devfile}.
     *
     * @param trainfile reader specification of the training data, passed to {@link SentenceReader}
     * @param devfile reader specification of the data to decode, passed to {@link SentenceReader}
     * @param num_iters currently unused; kept so existing callers keep compiling
     * @param rank_max value handed to the {@link ZeroOrderNbestDecoder} constructor
     */
    public void trainDecodeTest(String trainfile, String devfile, int num_iters, int rank_max) {
        // NOTE(review): num_iters is never forwarded to the trainer, so both calls in
        // test() train with the trainer's default settings — confirm whether it was
        // meant to configure ToutanovaTrainer.
        List<LemmaInstance> train_instances = LemmaInstance.getInstances(new SentenceReader(trainfile));
        // Typed null instead of a raw (List) cast: training runs without dev instances.
        List<LemmaInstance> no_dev_instances = null;
        ToutanovaTrainer trainer = new ToutanovaTrainer();
        ToutanovaLemmatizer lemmatizer = (ToutanovaLemmatizer)trainer.train(train_instances, no_dev_instances);
        this.testDecoder(lemmatizer, devfile, rank_max);
    }

    /**
     * Decodes every instance read from {@code devfile} with both decoders, asserts
     * that their results are consistent, and logs one-best and n-best accuracy.
     *
     * @param lemmatizer trained lemmatizer whose model is decoded with
     * @param devfile reader specification of the data to decode
     * @param rank_max value handed to the {@link ZeroOrderNbestDecoder} constructor
     */
    private void testDecoder(ToutanovaLemmatizer lemmatizer, String devfile, int rank_max) {
        ToutanovaModel model = lemmatizer.getModel();
        ZeroOrderDecoder decoder = new ZeroOrderDecoder();
        decoder.init(model);
        ZeroOrderNbestDecoder nbest_decoder = new ZeroOrderNbestDecoder(rank_max);
        nbest_decoder.init(model);
        List<LemmaInstance> test_instances = LemmaInstance.getInstances(new SentenceReader(devfile));

        // Instance counts are doubles; accumulate them as doubles so that fractional
        // counts are not truncated on every addition.
        double correct = 0.0;
        double nbest_correct = 0.0;
        double total = 0.0;

        for (LemmaInstance instance : test_instances) {
            ToutanovaInstance tinstance = new ToutanovaInstance(instance, null);
            model.addIndexes(tinstance, false);

            // The score stored in the one-best result must match the model's score for it.
            Result result = decoder.decode(tinstance);
            double first_best_score = result.getScore();
            Assert.assertEquals(model.getScore(tinstance, result), first_best_score, DELTA);

            // The head of the n-best list must coincide with the one-best result.
            List<Result> nbest_results = nbest_decoder.decode(tinstance);
            Assert.assertFalse(nbest_results.isEmpty());
            Result first_nbest_result = nbest_results.get(0);
            Assert.assertEquals(result.getOutput(), first_nbest_result.getOutput());
            Assert.assertEquals(first_best_score, first_nbest_result.getScore(), DELTA);

            Result last_result = null;
            boolean found_lemma = false;
            for (Result nbest_result : nbest_results) {
                Assert.assertEquals(model.getScore(tinstance, nbest_result), nbest_result.getScore(), DELTA);
                // Each score must be (approximately) no greater than the previous one.
                if (last_result != null && !Numerics.approximatelyLesserEqual(nbest_result.getScore(), last_result.getScore())) {
                    throw new AssertionFailedError(String.format("%g <= %g", nbest_result.getScore(), last_result.getScore()));
                }
                last_result = nbest_result;
                if (nbest_result.getOutput().equals(instance.getLemma())) {
                    found_lemma = true;
                }
            }

            double count = instance.getCount();
            if (found_lemma) {
                nbest_correct += count;
            }
            if (result.getOutput().equals(instance.getLemma())) {
                correct += count;
            }
            total += count;
        }

        // Counts are rounded only for display; the percentages use the exact sums.
        Logger logger = Logger.getLogger(this.getClass().getName());
        logger.info(String.format("One-best : %5d %5d = %g", Math.round(correct), Math.round(total), correct * 100.0 / total));
        logger.info(String.format("N-best : %5d %5d = %g", Math.round(nbest_correct), Math.round(total), nbest_correct * 100.0 / total));
    }

    /**
     * Trains on {@code trn_mod.tsv} twice: once decoding the training data itself
     * with {@code rank_max} 5, once decoding {@code dev.tsv} with {@code rank_max} 10.
     */
    @Test
    public void test() {
        // Prefix prepended to each resource path; presumably column options
        // interpreted by SentenceReader — confirm against the reader.
        String indexes = "form-index=4,lemma-index=5,tag-index=2,";
        String train_sml = indexes + this.getResourceFile("trn_mod.tsv");
        String dev = indexes + this.getResourceFile("dev.tsv");
        this.trainDecodeTest(train_sml, train_sml, 1, 5);
        this.trainDecodeTest(train_sml, dev, 10, 10);
    }

    /**
     * Builds the {@code res:///} location of a test data file under
     * {@code marmot/test/lemma}.
     *
     * @param name file name of the resource
     * @return the resource location string for {@code name}
     */
    protected String getResourceFile(String name) {
        return String.format("res:///%s/%s", "marmot/test/lemma", name);
    }
}

