/*
 * Decompiled with CFR 0.152.
 */
package lemming.lemma.toutanova;

import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.logging.Logger;
import lemming.lemma.LemmaInstance;
import lemming.lemma.LemmaOptions;
import lemming.lemma.LemmatizerGenerator;
import lemming.lemma.LemmatizerGeneratorTrainer;
import lemming.lemma.toutanova.Aligner;
import lemming.lemma.toutanova.AlignerTrainer;
import lemming.lemma.toutanova.Decoder;
import lemming.lemma.toutanova.EditTreeAlignerTrainer;
import lemming.lemma.toutanova.Result;
import lemming.lemma.toutanova.ToutanovaInstance;
import lemming.lemma.toutanova.ToutanovaLemmatizer;
import lemming.lemma.toutanova.ToutanovaModel;
import lemming.lemma.toutanova.ZeroOrderDecoder;
import marmot.util.DynamicWeights;

/**
 * Trains a {@link ToutanovaLemmatizer} with a perceptron-style update loop.
 *
 * <p>Training instances are first aligned (form against lemma) by the aligner
 * configured in {@link ToutanovaOptions}; the aligned instances are then decoded
 * repeatedly and the model is updated whenever the decoded output differs from the
 * gold lemma. Optionally the weights are averaged at the end of every iteration.
 */
public class ToutanovaTrainer
implements LemmatizerGeneratorTrainer {
    /** Options of this trainer; never reassigned after construction. */
    private final ToutanovaOptions options_ = new ToutanovaOptions();

    /**
     * Wraps every {@link LemmaInstance} into a {@link ToutanovaInstance}.
     *
     * @param instances the instances to wrap; must not be {@code null}
     * @param aligner   aligner used to align each form with its lemma, or {@code null}
     *                  to create instances with a {@code null} alignment
     * @return a new list holding one wrapped instance per input instance, in input order
     */
    public static List<ToutanovaInstance> createToutanovaInstances(List<LemmaInstance> instances, Aligner aligner) {
        List<ToutanovaInstance> new_instances = new ArrayList<ToutanovaInstance>(instances.size());
        for (LemmaInstance instance : instances) {
            List<Integer> alignment = null;
            if (aligner != null) {
                alignment = aligner.align(instance.getForm(), instance.getLemma());
                assert (alignment != null);
            }
            new_instances.add(new ToutanovaInstance(instance, alignment));
        }
        return new_instances;
    }

    /**
     * Trains the configured aligner on the training instances, aligns them and
     * delegates to {@link #trainAligned(List, List)}.
     *
     * @param train_instances the training instances
     * @param dev_instances   optional development instances, may be {@code null};
     *                        they are wrapped without an alignment
     * @return the trained lemmatizer
     */
    @Override
    public LemmatizerGenerator train(List<LemmaInstance> train_instances, List<LemmaInstance> dev_instances) {
        AlignerTrainer aligner_trainer = this.options_.getAligner();
        Aligner aligner = aligner_trainer.train(train_instances);
        List<ToutanovaInstance> new_train_instances = ToutanovaTrainer.createToutanovaInstances(train_instances, aligner);
        List<ToutanovaInstance> new_dev_instances = null;
        if (dev_instances != null) {
            // Development instances are not aligned: no aligner is passed.
            new_dev_instances = ToutanovaTrainer.createToutanovaInstances(dev_instances, null);
        }
        return this.trainAligned(new_train_instances, new_dev_instances);
    }

    /**
     * Runs the update loop on already aligned instances.
     *
     * <p>In every iteration the token instances are shuffled with the options' random
     * generator and decoded one by one. When the decoded output differs from the gold
     * lemma, the model is updated by {@code -1} for the decoded result and {@code +1}
     * for the instance's own result. If averaging is enabled, the same update is also
     * applied to a second weight vector, scaled by the number of token instances not yet
     * processed in the current iteration (the current one included). At the end of the
     * iteration the model weights are then overwritten with that vector divided by
     * {@code (iter + 1) * tokenCount}, and the vector itself is multiplied by
     * {@code (iter + 2) / (iter + 1)}.
     *
     * <p>If no token instance is available (empty input or only rare instances), the
     * averaging step is skipped, because its scaling factor would be {@code 1 / 0}.
     *
     * @param train_instances aligned training instances
     * @param dev_instances   development instances handed to the model initialisation,
     *                        may be {@code null}
     * @return a lemmatizer built from the options and the trained model
     */
    public LemmatizerGenerator trainAligned(List<ToutanovaInstance> train_instances, List<ToutanovaInstance> dev_instances) {
        Logger logger = Logger.getLogger(this.getClass().getName());
        ToutanovaModel model = new ToutanovaModel();
        model.init(this.options_, train_instances, dev_instances);
        DynamicWeights weights = model.getWeights();
        DynamicWeights sum_weights = null;
        if (this.options_.getAveraging()) {
            sum_weights = new DynamicWeights(null);
        }
        Decoder decoder = this.options_.getDecoderInstance();
        decoder.init(model);

        List<ToutanovaInstance> token_instances = this.expandToTokenInstances(train_instances);
        if (token_instances.isEmpty()) {
            logger.warning("No non-rare training instances available; the model weights are left untouched.");
        }

        for (int iter = 0; iter < this.options_.getNumIterations(); ++iter) {
            logger.info(String.format("Iter: %3d / %3d", iter + 1, this.options_.getNumIterations()));
            double correct = 0.0;
            double total = 0.0;
            int number = 0;
            Collections.shuffle(token_instances, this.options_.getRandom());
            for (ToutanovaInstance instance : token_instances) {
                Result result = decoder.decode(instance);
                String output = result.getOutput();
                if (output.equals(instance.getInstance().getLemma())) {
                    correct += 1.0;
                } else {
                    model.update(instance, result, -1.0);
                    model.update(instance, instance.getResult(), 1.0);
                    if (sum_weights != null) {
                        // Token instances still to be processed in this iteration,
                        // counting the current one.
                        double amount = token_instances.size() - number;
                        assert (amount > 0.0);
                        // Temporarily swap the accumulator into the model so that the
                        // scaled update lands there. The references are re-read after
                        // each swap -- presumably the model may replace the weights
                        // object while updating; TODO confirm against ToutanovaModel.
                        model.setWeights(sum_weights);
                        sum_weights = model.getWeights();
                        model.update(instance, result, -amount);
                        model.update(instance, instance.getResult(), amount);
                        model.setWeights(weights);
                        weights = model.getWeights();
                    }
                }
                total += 1.0;
                ++number;
                if (number % 1000 == 0 && this.options_.getVerbosity() > 0) {
                    logger.info(String.format("Processed: %3d / %3d", number, token_instances.size()));
                }
            }
            // With zero token instances the scaling factor below would be 1/0 = Infinity
            // and every weight would be overwritten with Infinity or NaN.
            if (sum_weights != null && !token_instances.isEmpty()) {
                double weights_scaling = 1.0 / (((double)iter + 1.0) * (double)token_instances.size());
                double sum_weights_scaling = ((double)iter + 2.0) / ((double)iter + 1.0);
                for (int i = 0; i < weights.getLength(); ++i) {
                    weights.set(i, sum_weights.get(i) * weights_scaling);
                    sum_weights.set(i, sum_weights.get(i) * sum_weights_scaling);
                }
            }
            logger.info(String.format("Train Accuracy: %g / %g = %g", correct, total, correct * 100.0 / total));
        }
        return new ToutanovaLemmatizer(this.options_, model);
    }

    /**
     * Builds the list that is iterated during training: rare instances are dropped and
     * every other instance is repeated once per count, capped by the {@code max-count}
     * option.
     */
    private List<ToutanovaInstance> expandToTokenInstances(List<ToutanovaInstance> train_instances) {
        List<ToutanovaInstance> token_instances = new ArrayList<ToutanovaInstance>();
        for (ToutanovaInstance instance : train_instances) {
            if (instance.isRare()) {
                continue;
            }
            double copies = Math.min((double)this.options_.getMaxCount(), instance.getInstance().getCount());
            for (int i = 0; (double)i < copies; ++i) {
                token_instances.add(instance);
            }
        }
        return token_instances;
    }

    /**
     * @return the options instance used by this trainer
     */
    @Override
    public LemmaOptions getOptions() {
        return this.options_;
    }

    /**
     * Options of the {@link ToutanovaTrainer}. The constructor registers a default
     * value for every option key declared here.
     */
    public static class ToutanovaOptions
    extends LemmaOptions {
        private static final long serialVersionUID = 1L;
        /** Option key; default value {@code 5}. */
        public static final String FILTER_ALPHABET = "filter-alphabet";
        /** Option key for the aligner trainer class; default {@link EditTreeAlignerTrainer}. */
        public static final String ALIGNER_TRAINER = "aligner-trainer";
        /** Option key for the decoder class; default {@link ZeroOrderDecoder}. */
        public static final String DECODER = "decoder";
        /** Option key capping how often one instance is repeated during training; default {@code 1}. */
        public static final String MAX_COUNT = "max-count";
        /** Option key; default value {@code 50}. */
        public static final String NBEST_RANK = "nbest-rank";
        /** Option key; default value {@code 2}. */
        public static final String WINDOW_SIZE = "window-size";

        /** Creates the options and registers the default value of every key. */
        public ToutanovaOptions() {
            this.map_.put(FILTER_ALPHABET, 5);
            this.map_.put(ALIGNER_TRAINER, EditTreeAlignerTrainer.class);
            this.map_.put(DECODER, ZeroOrderDecoder.class);
            this.map_.put(MAX_COUNT, 1);
            this.map_.put(NBEST_RANK, 50);
            this.map_.put(WINDOW_SIZE, 2);
        }

        /**
         * @return a fresh options object holding the default values
         */
        public static ToutanovaOptions newInstance() {
            return new ToutanovaOptions();
        }

        /**
         * @return the value of the {@code filter-alphabet} option
         */
        public int getFilterAlphabet() {
            return (Integer)this.getOption(FILTER_ALPHABET);
        }

        /**
         * @return an aligner trainer obtained via {@code getInstance} for the
         *         {@code aligner-trainer} option
         */
        public AlignerTrainer getAligner() {
            return (AlignerTrainer)this.getInstance(ALIGNER_TRAINER);
        }

        /**
         * @return a decoder obtained via {@code getInstance} for the {@code decoder} option
         */
        public Decoder getDecoderInstance() {
            return (Decoder)this.getInstance(DECODER);
        }

        /**
         * @return the value of the {@code max-count} option
         */
        public int getMaxCount() {
            return (Integer)this.getOption(MAX_COUNT);
        }

        /**
         * @return the value of the {@code nbest-rank} option
         */
        public int getNbestRank() {
            return (Integer)this.getOption(NBEST_RANK);
        }

        /**
         * @return the value of the {@code window-size} option
         */
        public int getMaxWindowSize() {
            return (Integer)this.getOption(WINDOW_SIZE);
        }
    }
}

