/*
 * Decompiled with CFR 0.152.
 */
package marmot.core;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import marmot.core.Evaluator;
import marmot.core.Model;
import marmot.core.Options;
import marmot.core.Sequence;
import marmot.core.State;
import marmot.core.Tagger;
import marmot.core.Trainer;
import marmot.core.Transition;
import marmot.core.WeightVector;
import marmot.core.lattice.SequenceViterbiLattice;
import marmot.core.lattice.SumLattice;
import marmot.core.lattice.ViterbiLattice;
import marmot.core.lattice.ZeroOrderSumLattice;
import marmot.core.lattice.ZeroOrderViterbiLattice;

/**
 * Structured perceptron trainer for a {@link Tagger}.
 *
 * <p>For every training sequence the highest-scoring (Viterbi) candidate path is decoded. If it
 * differs from the gold path, the gold path's state and transition weights are increased by 1
 * and the predicted path's weights are decreased by 1. Optionally the sequence order is
 * shuffled every epoch and the weights are averaged.
 *
 * <p>Configured through {@link #setOptions(Options)}.
 */
public class PerceptronTrainer
implements Trainer {
    private int steps_;
    private boolean shuffle_;
    private boolean verbose_;
    private boolean averaging_;
    private long seed_;

    /**
     * Trains the tagger's weight vector in place for the configured number of epochs.
     *
     * @param tagger tagger whose weight vector is trained; its weight vector must be non-null
     * @param in_sequences training sequences; copied, so the caller's collection is not reordered
     * @param evaluator optional evaluator; invoked after each epoch only in verbose mode
     */
    @Override
    public void train(Tagger tagger, Collection<Sequence> in_sequences, Evaluator evaluator) {
        Random rng = null;
        if (this.shuffle_) {
            // A seed of 0 means "unseeded": fall back to a default-constructed Random.
            rng = this.seed_ == 0L ? new Random() : new Random(this.seed_);
        }
        ArrayList<Sequence> sequences = new ArrayList<Sequence>(in_sequences);
        int num_sequences = sequences.size();
        // Progress is printed every `fraction` sentences, i.e. roughly four times per epoch.
        int fraction = Math.max(num_sequences / 4, 1);
        WeightVector weights = tagger.getWeightVector();
        assert (weights != null);
        double[] sum_weights = null;
        if (this.averaging_) {
            // NOTE(review): sized once up front; assumes the weight array length does not change
            // while training -- confirm against WeightVector.
            sum_weights = new double[weights.getWeights().length];
        }
        Model model = tagger.getModel();
        for (int step = 0; step < this.steps_; ++step) {
            if (this.verbose_) {
                System.err.println("step: " + step);
            }
            if (this.shuffle_) {
                Collections.shuffle(sequences, rng);
            }
            int current_sentence = 0;
            long train_time = System.currentTimeMillis();
            for (Sequence sequence : sequences) {
                SumLattice sum_lattice = tagger.getSumLattice(true, sequence);
                List<List<State>> candidates = sum_lattice.getCandidates();
                ViterbiLattice lattice = sum_lattice instanceof ZeroOrderSumLattice
                        ? new ZeroOrderViterbiLattice(candidates, 1, false)
                        : new SequenceViterbiLattice(candidates, model.getBoundaryState(tagger.getNumLevels() - 1), 1, false);
                List<Integer> best_sequence = lattice.getViterbiSequence().getStates();
                List<Integer> gold_sequence = sum_lattice.getGoldCandidates();
                if (!gold_sequence.equals(best_sequence)) {
                    this.update(weights, candidates, gold_sequence, 1.0);
                    this.update(weights, candidates, best_sequence, -1.0);
                    if (this.averaging_) {
                        // This update persists for the remaining sentences of the epoch
                        // (including this one), so it contributes that many times to the sum.
                        int amount = num_sequences - current_sentence;
                        assert (amount > 0);
                        this.updateSumWeights(weights, sum_weights, candidates, gold_sequence, best_sequence, amount);
                    }
                }
                ++current_sentence;
                if (this.verbose_ && current_sentence % fraction == 0) {
                    // Clamp to 1 ms so a very fast first chunk does not report "Infinity".
                    long elapsed_ms = Math.max(System.currentTimeMillis() - train_time, 1L);
                    System.err.format("Processed %d sentences at %g sentence/s \n", current_sentence, (double) current_sentence / ((double) elapsed_ms / 1000.0));
                }
            }
            // With no sequences there is nothing to average; normalizing would divide by zero
            // and turn every weight into NaN.
            if (this.averaging_ && num_sequences > 0) {
                averageWeights(weights.getWeights(), sum_weights, step, num_sequences);
            }
            if (evaluator != null && this.verbose_) {
                weights.setExtendFeatureSet(false);
                evaluator.eval(tagger);
                weights.setExtendFeatureSet(true);
            }
        }
        weights.setExtendFeatureSet(false);
    }

    /**
     * Applies the gold (+amount) / predicted (-amount) update to {@code sum_weights}.
     *
     * <p>{@code sum_weights} is temporarily installed as the weight vector's array so that
     * {@link #update} can be reused. The original array is restored even if an update throws,
     * so the tagger is never left holding the accumulator.
     */
    private void updateSumWeights(WeightVector weights, double[] sum_weights, List<List<State>> candidates, List<Integer> gold_sequence, List<Integer> best_sequence, int amount) {
        double[] current_weights = weights.getWeights();
        weights.setWeights(sum_weights);
        try {
            this.update(weights, candidates, gold_sequence, amount);
            this.update(weights, candidates, best_sequence, -amount);
        } finally {
            weights.setWeights(current_weights);
        }
    }

    /**
     * Writes the averaged weights {@code sum_weights / ((step + 1) * num_sequences)} into
     * {@code current_weights} in place.
     *
     * <p>It then rescales {@code sum_weights} by {@code (step + 2) / (step + 1)}, so the sum equals
     * {@code (step + 2) * num_sequences} times the averaged weights that the next epoch starts
     * from. The in-place write only affects the model if {@code current_weights} is the weight
     * vector's backing array (presumably what {@code getWeights()} returns; verify).
     *
     * @param num_sequences number of training sequences; must be positive
     */
    private static void averageWeights(double[] current_weights, double[] sum_weights, int step, int num_sequences) {
        // Multiply in double: the int product (step + 1) * num_sequences could overflow.
        double normalizer = (double) (step + 1) * num_sequences;
        assert (normalizer > 0.0);
        double rescale = (double) (step + 2) / (double) (step + 1);
        assert (rescale > 0.0);
        assert (rescale < 2.00001);
        for (int i = 0; i < current_weights.length; ++i) {
            current_weights[i] = sum_weights[i] / normalizer;
            sum_weights[i] *= rescale;
        }
    }

    /**
     * Adds {@code amount} to the weights of every state on the given path.
     *
     * <p>It also updates the transition into each state from the previous position's chosen
     * candidate; index 0 is used for the first position.
     *
     * @param candidates per-position candidate states
     * @param sequence candidate index chosen at each position
     * @param amount value added to each touched weight (negative to penalize)
     */
    private void update(WeightVector weights, List<List<State>> candidates, List<Integer> sequence, double amount) {
        int previous_index = 0;
        for (int position = 0; position < sequence.size(); ++position) {
            int candidate_index = sequence.get(position);
            State state = candidates.get(position).get(candidate_index);
            weights.updateWeights(state, amount, false);
            Transition transition = state.getTransition(previous_index);
            weights.updateWeights(transition, amount, true);
            previous_index = candidate_index;
        }
    }

    /**
     * Reads the iteration count, shuffle flag, verbosity, averaging flag and shuffle seed from
     * {@code options}.
     */
    @Override
    public void setOptions(Options options) {
        this.steps_ = options.getNumIterations();
        this.shuffle_ = options.getShuffle();
        this.verbose_ = options.getVerbose();
        this.averaging_ = options.getAveraging();
        this.seed_ = options.getSeed();
    }
}

