/*
 * Decompiled with CFR 0.152.
 */
package lemming.lemma.ranker;

import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import cc.mallet.optimize.OptimizationException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;
import lemming.lemma.LemmaCandidateGenerator;
import lemming.lemma.LemmaCandidateGeneratorTrainer;
import lemming.lemma.LemmaInstance;
import lemming.lemma.LemmaOptions;
import lemming.lemma.LemmatizerGenerator;
import lemming.lemma.LemmatizerGeneratorTrainer;
import lemming.lemma.SimpleLemmatizerTrainer;
import lemming.lemma.edit.EditTreeGeneratorTrainer;
import lemming.lemma.ranker.Ranker;
import lemming.lemma.ranker.RankerInstance;
import lemming.lemma.ranker.RankerModel;
import lemming.lemma.ranker.RankerObjective;
import lemming.lemma.toutanova.EditTreeAligner;
import lemming.lemma.toutanova.EditTreeAlignerTrainer;
import marmot.util.Sys;

/**
 * Trains a {@link Ranker}: a lemmatizer that chooses a lemma for each instance from the
 * candidates proposed by a list of {@link LemmaCandidateGenerator}s.
 *
 * <p>Training first fits the candidate generators configured in {@link RankerTrainerOptions},
 * then extracts features and learns the {@link RankerModel} weights with one of three optimizers:
 * a (optionally averaged) perceptron, SGD, or MALLET's L-BFGS. The options object selects the
 * optimizer and defaults to the perceptron.
 */
public class RankerTrainer implements LemmatizerGeneratorTrainer {
    /** Upper bound on how many times a single instance is replicated (by its count) for SGD. */
    private static final int MAX_NUM_DUPLICATES_ = 3;
    /** Maximum number of L-BFGS iterations run after the first one in {@link #runMallet}. */
    private static final int MAX_MALLET_ITERATIONS_ = 200;
    /** SGD learning rate before any decay is applied. */
    private static final double INITIAL_SGD_STEP_WIDTH_ = 0.1;

    private RankerTrainerOptions options_ = new RankerTrainerOptions();

    /**
     * Trains the candidate generators and then the ranker on {@code train_instances}.
     *
     * @param train_instances training data for both the generators and the ranker
     * @param test_instances not used by this trainer
     * @return a {@link Ranker} wrapping the trained model and generators
     */
    @Override
    public LemmatizerGenerator train(List<LemmaInstance> train_instances, List<LemmaInstance> test_instances) {
        List<LemmaCandidateGenerator> generators = options_.getGenerators(train_instances);
        return trainReranker(generators, train_instances);
    }

    /**
     * Builds ranker instances for {@code simple_instances} using {@code generators}, initialises the
     * model's features (with an edit-tree aligner trained on the same data) and fits the weights
     * with the configured optimizer.
     */
    private LemmatizerGenerator trainReranker(List<LemmaCandidateGenerator> generators, List<LemmaInstance> simple_instances) {
        List<RankerInstance> instances = RankerInstance.getInstances(simple_instances, generators);
        RankerModel model = new RankerModel();
        // NOTE(review): the meaning of the (false, 1, -1) arguments is defined by EditTreeAlignerTrainer.
        EditTreeAligner aligner = (EditTreeAligner) new EditTreeAlignerTrainer(options_.getRandom(), false, 1, -1).train(simple_instances);
        logger().info("Extracting features");
        model.init(options_, instances, aligner);
        if (options_.getUsePerceptron()) {
            runPerceptron(model, instances);
        } else {
            runMaxEnt(model, instances);
        }
        return new Ranker(model, generators);
    }

    /** Fits a log-linear model either with MALLET L-BFGS or with the built-in SGD. */
    private void runMaxEnt(RankerModel model, List<RankerInstance> instances) {
        if (options_.getUseMallet()) {
            runMallet(model, instances);
        } else {
            runSgd(model, instances);
        }
    }

    /**
     * Runs {@code getNumIterations()} epochs of SGD over a count-weighted copy of {@code instances}.
     * The step width decays as {@code initial / (1 + processed / epochSize)}.
     */
    private void runSgd(RankerModel model, List<RankerInstance> instances) {
        Logger logger = logger();
        List<RankerInstance> duplicates = duplicateByCount(instances);
        logger.info(String.format("Created duplicates. Increased num instances from %d to %d.\n", instances.size(), duplicates.size()));
        // NOTE(review): the last argument's meaning is defined in RankerObjective; it equals
        // MAX_NUM_DUPLICATES_ here -- confirm whether it should reference that constant.
        RankerObjective objective = new RankerObjective(options_, model, duplicates, 3);
        Random random = options_.getRandom();
        int number = 0;
        for (int step = 0; step < options_.getNumIterations(); ++step) {
            logger.info("SGD step: " + step);
            Collections.shuffle(duplicates, random);
            for (RankerInstance instance : duplicates) {
                double step_width = INITIAL_SGD_STEP_WIDTH_ / (1.0 + (double) number / (double) duplicates.size());
                objective.update(instance, true, step_width);
                ++number;
            }
        }
    }

    /**
     * Returns a new list in which each instance appears {@code min(MAX_NUM_DUPLICATES_, (int) count)}
     * times, where {@code count} is the wrapped instance's {@code getCount()}. Instances whose count
     * truncates to zero or less do not appear at all.
     */
    private static List<RankerInstance> duplicateByCount(List<RankerInstance> instances) {
        List<RankerInstance> duplicates = new ArrayList<>(instances.size());
        for (RankerInstance instance : instances) {
            double count = instance.getInstance().getCount();
            int copies = Math.min(MAX_NUM_DUPLICATES_, (int) count);
            for (int i = 0; i < copies; ++i) {
                duplicates.add(instance);
            }
        }
        return duplicates;
    }

    /**
     * Fits the weights with MALLET's L-BFGS: one initial iteration, then up to
     * {@link #MAX_MALLET_ITERATIONS_} more until the optimizer reports convergence. Memory usage is
     * logged before and after the first iteration.
     */
    private void runMallet(RankerModel model, List<RankerInstance> instances) {
        Logger logger = logger();
        double memory_used_before_optimization = Sys.getUsedMemoryInMegaBytes();
        // Estimate assuming 64 bits per double: length * 64 / (8 * 1024 * 1024) = megabytes.
        double memory_usage_of_one_weights_array = (double) model.getWeights().length * 64.0 / 8388608.0;
        logger.info(String.format("Memory usage of weights array: %g (%g) MB",
                Sys.getUsedMemoryInMegaBytes((Serializable) model.getWeights(), false),
                memory_usage_of_one_weights_array));
        logger.info(String.format("Memory usage: %g / %g MB", memory_used_before_optimization, Sys.getMaxHeapSizeInMegaBytes()));
        logger.info("Start optimization");
        RankerObjective objective = new RankerObjective(options_, model, instances);
        LimitedMemoryBFGS optimizer = new LimitedMemoryBFGS((Optimizable.ByGradientValue) objective);
        // Silence MALLET's own optimizer logging.
        Logger.getLogger(optimizer.getClass().getName()).setLevel(Level.OFF);
        objective.setParameters(model.getWeights());
        try {
            optimizer.optimize(1);
            double memory_usage_during_optimization = Sys.getUsedMemoryInMegaBytes();
            logger.info(String.format("Memory usage after first iteration: %g / %g MB", memory_usage_during_optimization, Sys.getMaxHeapSizeInMegaBytes()));
            for (int i = 0; i < MAX_MALLET_ITERATIONS_ && !optimizer.isConverged(); ++i) {
                optimizer.optimize(1);
                logger.info(String.format("Iteration: %3d / %3d", i + 1, MAX_MALLET_ITERATIONS_));
            }
        } catch (OptimizationException | IllegalArgumentException e) {
            // Best effort, as before: training proceeds after a failed optimization step,
            // but the cause is now reported instead of being silently discarded.
            logger.warning("Optimization terminated early: " + e);
        }
        logger.info("Finished optimization");
    }

    /**
     * Trains the weights with the perceptron for {@code getNumIterations()} epochs, optionally
     * averaged.
     *
     * <p>Averaging is computed lazily. With {@code N} instances per epoch, a mistake at position
     * {@code p} also adds its update scaled by {@code N - p} to {@code sum_weights}: that is the
     * update's contribution to the sum of the per-step weight vectors over the rest of the epoch.
     *
     * <p>At the end of each epoch {@code k} (0-based), {@code weights} is overwritten with
     * {@code sum_weights / ((k + 1) * N)}, so the next epoch continues from the averaged weights.
     * {@code sum_weights} is then multiplied by {@code (k + 2) / (k + 1)}, which equals adding
     * {@code N} copies of those new starting weights.
     */
    private void runPerceptron(RankerModel model, List<RankerInstance> instances) {
        Logger logger = logger();
        if (instances.isEmpty()) {
            // Without instances the averaging step would compute 0 * (1 / 0) = NaN weights.
            logger.warning("No training instances; skipping perceptron training");
            return;
        }
        double[] weights = model.getWeights();
        double[] sum_weights = options_.getAveraging() ? new double[weights.length] : null;
        for (int iter = 0; iter < options_.getNumIterations(); ++iter) {
            double error = 0.0;
            double total = 0.0;
            int number = 0;
            Collections.shuffle(instances, options_.getRandom());
            for (RankerInstance instance : instances) {
                String gold = instance.getInstance().getLemma();
                String lemma = model.select(instance);
                if (!lemma.equals(gold)) {
                    model.update(instance, lemma, -1.0);
                    model.update(instance, gold, 1.0);
                    if (sum_weights != null) {
                        double amount = instances.size() - number;
                        assert amount > 0.0;
                        // model.update writes into the model's current array, so temporarily
                        // point the model at sum_weights to apply the scaled update there.
                        model.setWeights(sum_weights);
                        model.update(instance, lemma, -amount);
                        model.update(instance, gold, amount);
                        model.setWeights(weights);
                    }
                    error += instance.getInstance().getCount();
                }
                total += instance.getInstance().getCount();
                ++number;
            }
            if (sum_weights != null) {
                double weights_scaling = 1.0 / (((double) iter + 1.0) * (double) instances.size());
                double sum_weights_scaling = ((double) iter + 2.0) / ((double) iter + 1.0);
                for (int i = 0; i < weights.length; ++i) {
                    weights[i] = sum_weights[i] * weights_scaling;
                    sum_weights[i] = sum_weights[i] * sum_weights_scaling;
                }
            }
            logger.info(String.format("Train Accuracy: %g / %g = %g", total - error, total, (total - error) * 100.0 / total));
        }
    }

    /** Returns the logger named after the runtime class of this trainer. */
    private Logger logger() {
        return Logger.getLogger(this.getClass().getName());
    }

    /** Returns the options used by this trainer (a {@link RankerTrainerOptions}). */
    @Override
    public LemmaOptions getOptions() {
        return options_;
    }

    /** Replaces the options used by subsequent calls to {@link #train}. */
    public void setOptions(RankerTrainerOptions roptions) {
        options_ = roptions;
    }

    /**
     * Options for {@link RankerTrainer}, stored in the inherited {@code map_} under the string keys
     * declared below. The no-argument constructor installs the defaults.
     */
    public static class RankerTrainerOptions extends LemmaOptions {
        private static final long serialVersionUID = 1L;
        public static final String GENERATOR_TRAINERS = "generator-trainers";
        public static final String USE_PERCEPTRON = "use-perceptron";
        public static final String QUADRATIC_PENALTY = "quadratic-penalty";
        public static final String UNIGRAM_FILE = "unigram-file";
        public static final String USE_SHAPE_LEXICON = "use-shape-lexicon";
        public static final String ASPELL_LANG = "aspell-lang";
        public static final String ASPELL_PATH = "aspell-path";
        public static final String USE_CORE_FEATURES = "use-core-features";
        public static final String USE_ALIGNMENT_FEATURES = "use-alignment-features";
        public static final String IGNORE_FEATURES = "ignore-features";
        public static final String NUM_EDIT_TREE_STEPS = "num-edit-tree-steps";
        /** Key for the "copy-conjunctions" option (constant name spelling kept for compatibility). */
        public static final String COPY_CONJUNCTONS = "copy-conjunctions";
        public static final String TAG_DEPENDENT = "tag-dependent";
        public static final String EDIT_TREE_MIN_COUNT = "edit-tree-min-count";
        public static final String EDIT_TREE_MAX_DEPTH = "edit-tree-max-depth";
        public static final String USE_HASH_FEATURE_TABLE = "use-hash-feature-table";
        public static final String USE_MALLET = "use-mallet";
        public static final String OFFLINE_FEATURE_EXTRACTION = "offline-feature-extraction";
        public static final String CLUSTER_FILE = "cluster-file";

        /** Creates options populated with this trainer's default values. */
        public RankerTrainerOptions() {
            this.map_.put(GENERATOR_TRAINERS, Arrays.asList(SimpleLemmatizerTrainer.class, EditTreeGeneratorTrainer.class));
            this.map_.put(USE_PERCEPTRON, true);
            this.map_.put(QUADRATIC_PENALTY, 0.0);
            this.map_.put(UNIGRAM_FILE, Arrays.asList(""));
            this.map_.put(USE_SHAPE_LEXICON, false);
            this.map_.put(ASPELL_LANG, "");
            this.map_.put(ASPELL_PATH, "");
            this.map_.put(USE_CORE_FEATURES, true);
            this.map_.put(USE_ALIGNMENT_FEATURES, true);
            this.map_.put(IGNORE_FEATURES, "");
            this.map_.put(NUM_EDIT_TREE_STEPS, 1);
            this.map_.put(COPY_CONJUNCTONS, false);
            this.map_.put(USE_HASH_FEATURE_TABLE, false);
            this.map_.put(TAG_DEPENDENT, false);
            this.map_.put(EDIT_TREE_MIN_COUNT, 0);
            this.map_.put(EDIT_TREE_MAX_DEPTH, -1);
            this.map_.put(USE_MALLET, true);
            this.map_.put(OFFLINE_FEATURE_EXTRACTION, true);
            this.map_.put(CLUSTER_FILE, "");
        }

        /** Copy constructor: takes a shallow copy of {@code roptions}' option map. */
        public RankerTrainerOptions(RankerTrainerOptions roptions) {
            this.map_ = new HashMap<>(roptions.map_);
        }

        /** Returns the configured unigram file list (default: a single empty string). */
        @SuppressWarnings("unchecked")
        public List<Object> getUnigramFile() {
            return (List<Object>) this.getOption(UNIGRAM_FILE);
        }

        /** Returns the generator-trainer classes to instantiate (see {@link #getGenerators}). */
        @SuppressWarnings("unchecked")
        public List<Object> getGeneratorTrainers() {
            return (List<Object>) this.getOption(GENERATOR_TRAINERS);
        }

        public boolean getUsePerceptron() {
            return (Boolean) this.getOption(USE_PERCEPTRON);
        }

        public double getQuadraticPenalty() {
            return (Double) this.getOption(QUADRATIC_PENALTY);
        }

        /**
         * Instantiates every configured generator trainer and trains it on {@code instances}.
         * Edit-tree trainers additionally receive the edit-tree step count, tag dependence, and
         * min-count / max-depth settings from these options.
         *
         * @return the trained generators, in configuration order
         */
        public List<LemmaCandidateGenerator> getGenerators(List<LemmaInstance> instances) {
            LinkedList<LemmaCandidateGenerator> generators = new LinkedList<LemmaCandidateGenerator>();
            for (Object trainer_class : this.getGeneratorTrainers()) {
                LemmaCandidateGeneratorTrainer trainer = (LemmaCandidateGeneratorTrainer) this.toInstance((Class) trainer_class);
                if (trainer instanceof EditTreeGeneratorTrainer) {
                    trainer.getOptions().setOption("num-steps", this.getNumEditTreeSteps());
                    trainer.getOptions().setOption(TAG_DEPENDENT, this.getTagDependent());
                    trainer.getOptions().setOption("min-count", this.getEditTreeMinCount());
                    trainer.getOptions().setOption("max-depth", this.getEditTreeMaxDepth());
                }
                generators.add(trainer.train(instances, null));
            }
            return generators;
        }

        private Integer getEditTreeMaxDepth() {
            return (Integer) this.getOption(EDIT_TREE_MAX_DEPTH);
        }

        private Integer getEditTreeMinCount() {
            return (Integer) this.getOption(EDIT_TREE_MIN_COUNT);
        }

        public boolean getTagDependent() {
            return (Boolean) this.getOption(TAG_DEPENDENT);
        }

        public boolean getUseShapeLexicon() {
            return (Boolean) this.getOption(USE_SHAPE_LEXICON);
        }

        public String getAspellPath() {
            return (String) this.getOption(ASPELL_PATH);
        }

        public String getAspellLang() {
            return (String) this.getOption(ASPELL_LANG);
        }

        public boolean getUseCoreFeatures() {
            return (Boolean) this.getOption(USE_CORE_FEATURES);
        }

        public boolean getUseAlignmentFeatures() {
            return (Boolean) this.getOption(USE_ALIGNMENT_FEATURES);
        }

        public String getIgnoreFeatures() {
            return (String) this.getOption(IGNORE_FEATURES);
        }

        public int getNumEditTreeSteps() {
            return (Integer) this.getOption(NUM_EDIT_TREE_STEPS);
        }

        public boolean getCopyConjunctions() {
            return (Boolean) this.getOption(COPY_CONJUNCTONS);
        }

        public boolean getUseHashFeatureTable() {
            return (Boolean) this.getOption(USE_HASH_FEATURE_TABLE);
        }

        public boolean getUseMallet() {
            return (Boolean) this.getOption(USE_MALLET);
        }

        public boolean getUseOfflineFeatureExtraction() {
            return (Boolean) this.getOption(OFFLINE_FEATURE_EXTRACTION);
        }

        public String getClusterFile() {
            return (String) this.getOption(CLUSTER_FILE);
        }
    }
}

