/*
 * Decompiled with CFR 0.152.
 */
package marmot.core;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import marmot.core.Evaluator;
import marmot.core.Options;
import marmot.core.Result;
import marmot.core.Sequence;
import marmot.core.Tagger;
import marmot.core.Trainer;
import marmot.core.WeightVector;
import marmot.core.lattice.SumLattice;

/**
 * Online (stochastic-gradient style) trainer for CRF taggers.
 *
 * <p>For every training sequence the trainer asks the tagger for a sum lattice and lets the
 * lattice update the weight vector with a step width that decays as
 * {@code step_width_ / (1 + processed / |data|)}. A quadratic penalty is applied as a
 * multiplicative decay of the weights after each update. A non-zero {@code penalty_} is
 * accumulated and handed to {@link WeightVector#setPenalty} (presumably a cumulative linear/L1
 * penalty applied by the weight vector — confirm against {@code WeightVector}).
 *
 * <p>If {@code optimize_num_iterations_} is set, the weights of the epoch with the best
 * evaluator score are restored at the end of training.
 */
public class CrfTrainer implements Trainer {
    /** Weight of the accumulated penalty; values with magnitude <= 1e-10 disable it. */
    private double penalty_;
    /** Initial step width of the updates before decay. */
    private double step_width_ = 0.1;
    /** Number of passes (epochs) over the training data. */
    private double steps_;
    /** Whether the training data is shuffled before every epoch. */
    private boolean shuffle_;
    /** Print per-epoch progress and evaluation results to stderr. */
    private boolean verbose_;
    /** Print one diagnostic line per processed sequence to stderr. */
    private boolean very_verbose_;
    /** Weight of the quadratic penalty, applied as multiplicative weight decay. */
    private double quadratic_penalty_;
    /** Seed of the shuffling RNG; 0 means an unseeded {@link Random}. */
    private long seed_;
    /** Keep the weights of the best-scoring epoch instead of the last one. */
    private boolean optimize_num_iterations_;

    /**
     * Trains the weights of {@code tagger} on {@code in_sequences}.
     *
     * <p>The input collection is copied, so shuffling never reorders the caller's collection.
     *
     * @param tagger tagger whose weight vector is trained in place
     * @param in_sequences training sequences
     * @param evaluator optional held-out evaluator (may be {@code null}). It is used for verbose
     *     reporting and is required when {@code optimize_num_iterations_} is set.
     */
    @Override
    public void train(Tagger tagger, Collection<Sequence> in_sequences, Evaluator evaluator) {
        assert !this.optimize_num_iterations_ || evaluator != null
                : "Set optimize_num_iterations but did not provide test data.";

        Random rng = this.shuffle_ ? this.createRandom() : null;
        List<Sequence> sequences = new ArrayList<Sequence>(in_sequences);
        int num_sequences = sequences.size();

        // Progress is reported after every quarter of an epoch.
        int fraction = Math.max(num_sequences / 4, 1);
        // Thresholds are refreshed often at first, then less often as training stabilises.
        int smaller_fraction = Math.max(num_sequences / 4000, 1);
        int small_factor = 1;

        WeightVector weights = tagger.getWeightVector();
        assert weights != null;

        double[] best_float_params = null;
        double[] best_params = null;
        // Only a strictly positive score can become the best one.
        double best_score = 0.0;
        double accumulated_penalty = 0.0;
        // Global update counter driving step-width decay; long so many large epochs cannot overflow.
        long number = 0L;

        for (int step = 0; step < this.steps_; ++step) {
            if (this.verbose_) {
                System.err.println("step: " + step);
            }
            if (this.shuffle_) {
                Collections.shuffle(sequences, rng);
            }
            int current_sentence = 0;
            long train_time = System.currentTimeMillis();
            for (Sequence sequence : sequences) {
                double step_width = this.step_width_ / (1.0 + (double) number / num_sequences);
                double scale_factor = 1.0 - 2.0 * step_width * this.quadratic_penalty_ / num_sequences;
                assert !Double.isNaN(scale_factor);
                assert !Double.isInfinite(scale_factor);
                assert scale_factor > 1.0E-10;
                assert scale_factor < 1.0000000001;
                // The weights are scaled by scale_factor right after the update. Pre-dividing
                // keeps the gradient contribution at the undecayed step width, assuming scaleBy
                // multiplies the whole vector.
                step_width /= scale_factor;
                if (Math.abs(this.penalty_) > 1.0E-10) {
                    accumulated_penalty += step_width * this.penalty_ / num_sequences;
                    weights.setPenalty(true, accumulated_penalty);
                }
                SumLattice lattice = tagger.getSumLattice(true, sequence);
                assert lattice != null;
                if (this.very_verbose_) {
                    System.err.format("vv %d %d %d %d\n", number,
                            lattice.getOrder() + lattice.getLevel() * (tagger.getModel().getOrder() + 1),
                            lattice.getLevel(), lattice.getOrder());
                }
                lattice.update(weights, step_width);
                weights.scaleBy(scale_factor);
                ++current_sentence;
                if (current_sentence % fraction == 0) {
                    if (this.verbose_) {
                        logThroughput(current_sentence, train_time);
                    }
                    if (small_factor < 100) {
                        small_factor *= 10;
                        // Long arithmetic: small_factor * num_sequences overflows int on huge corpora.
                        smaller_fraction = (int) Math.max((long) small_factor * num_sequences / 400L, 1L);
                    }
                }
                if (current_sentence % smaller_fraction == 0) {
                    tagger.setThresholds(false);
                }
                ++number;
            }

            if (evaluator != null && (this.verbose_ || this.optimize_num_iterations_)) {
                // Freeze the feature set while evaluating, presumably so that tagging the
                // held-out data cannot introduce new features. It is re-enabled for training.
                weights.setExtendFeatureSet(false);
                Result result = evaluator.eval(tagger);
                weights.setExtendFeatureSet(true);
                tagger.setResult(result);
                if (this.verbose_) {
                    System.err.println(result);
                }
                if (this.optimize_num_iterations_) {
                    double score = result.getScore();
                    if (score > best_score) {
                        best_score = score;
                        best_params = weights.getWeights().clone();
                        best_float_params = weights.getFloatWeights().clone();
                    }
                }
            }
        }

        weights.setPenalty(false, 0.0);
        weights.setExtendFeatureSet(false);
        if (this.optimize_num_iterations_) {
            // Roll back to the best epoch's snapshot. If no epoch scored above 0, keep the final weights.
            if (best_params != null) {
                assert weights.getWeights().length == best_params.length;
                weights.setWeights(best_params);
            }
            if (best_float_params != null) {
                weights.setFloatWeights(best_float_params);
            }
            if (evaluator != null) {
                Result result = evaluator.eval(tagger);
                tagger.setResult(result);
            }
        }
    }

    /** Returns the shuffling RNG: seeded with {@code seed_}, or unseeded when the seed is 0. */
    private Random createRandom() {
        return this.seed_ == 0L ? new Random() : new Random(this.seed_);
    }

    /**
     * Prints the number of processed sentences and the processing rate to stderr.
     *
     * @param processed sentences processed so far in the current epoch
     * @param start_millis wall-clock time at which the epoch started
     */
    private static void logThroughput(int processed, long start_millis) {
        // Clamp to 1 ms so a very fast chunk does not report an infinite rate.
        double elapsed_seconds = Math.max(System.currentTimeMillis() - start_millis, 1L) / 1000.0;
        System.err.format("Processed %d sentences at %g sentence/s \n", processed, processed / elapsed_seconds);
    }

    /**
     * Copies the training hyper-parameters from {@code options}.
     *
     * @param options source of penalty, quadratic penalty, iteration count, shuffling, verbosity,
     *     seed and the optimize-num-iterations flag
     */
    @Override
    public void setOptions(Options options) {
        this.setOptions(options.getPenalty(), options.getQuadraticPenalty(), options.getNumIterations(),
                options.getShuffle(), options.getVerbose(), options.getVeryVerbose(), options.getSeed(),
                options.getOptimizeNumIterations());
    }

    /** Stores the given hyper-parameters in the corresponding fields. */
    private void setOptions(double penalty, double quadratic_penalty, int steps, boolean shuffle, boolean verbose,
            boolean very_verbose, long seed, boolean optimize_num_iterations) {
        this.penalty_ = penalty;
        this.steps_ = steps;
        this.shuffle_ = shuffle;
        this.verbose_ = verbose;
        this.very_verbose_ = very_verbose;
        this.quadratic_penalty_ = quadratic_penalty;
        this.seed_ = seed;
        this.optimize_num_iterations_ = optimize_num_iterations;
    }
}

