/*
 * Decompiled with CFR 0.152.
 */
package lemming.lemma.ranker;

import cc.mallet.optimize.Optimizable;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import lemming.lemma.LemmaCandidate;
import lemming.lemma.LemmaCandidateSet;
import lemming.lemma.ranker.RankerInstance;
import lemming.lemma.ranker.RankerModel;
import lemming.lemma.ranker.RankerTrainer;
import marmot.util.Numerics;

/**
 * L2-regularized log-likelihood objective for the lemma candidate ranker, exposed to MALLET's
 * gradient-based optimizers through {@link Optimizable.ByGradientValue}.
 *
 * <p>For each training instance the candidate scores produced by the {@link RankerModel} are
 * normalized with a softmax over the instance's {@link LemmaCandidateSet}. The objective value
 * is the count-weighted log-probability of the gold lemma summed over all instances, minus
 * {@code penalty * sum(w_i^2)}. The gradient is accumulated in a separate buffer of the same
 * length as the weight vector.
 *
 * <p>The weight array is obtained once from {@link RankerModel#getWeights()} and is
 * re-installed into the model with {@code setWeights} around every instance update.
 * NOTE(review): whether {@code getWeights()} returns the model's own array or a copy is not
 * visible here — confirm.
 *
 * <p>This class performs no synchronization and temporarily swaps the model's active weight
 * vector, so an instance must not be used from several threads at once.
 */
public class RankerObjective
implements Optimizable.ByGradientValue {
    private final RankerModel model_;
    private final List<RankerInstance> instances_;
    /** Objective value accumulated by the most recent {@link #update()} (or SGD updates). */
    private double value_;
    private final double[] gradient_;
    private final double[] weights_;
    private final double penalty_;
    private final RankerTrainer.RankerTrainerOptions options_;
    /** In SGD mode, instance counts at or above this value are divided by it before damping. */
    private final int max_num_duplicates_;

    /**
     * Creates an objective over the given training instances.
     *
     * @param options trainer options (quadratic penalty, offline feature extraction flag)
     * @param model the ranker model whose weights are optimized
     * @param instances the training instances
     * @param max_num_duplicates divisor applied to large instance counts in SGD mode.
     *     NOTE(review): values {@code <= 0} are not rejected and would divide by zero or flip
     *     signs in SGD mode.
     */
    public RankerObjective(RankerTrainer.RankerTrainerOptions options, RankerModel model, List<RankerInstance> instances, int max_num_duplicates) {
        this.options_ = options;
        this.model_ = model;
        this.instances_ = instances;
        this.weights_ = model.getWeights();
        this.gradient_ = new double[this.weights_.length];
        this.penalty_ = options.getQuadraticPenalty();
        this.max_num_duplicates_ = max_num_duplicates;
    }

    /** Creates an objective with {@code max_num_duplicates = 1}. */
    public RankerObjective(RankerTrainer.RankerTrainerOptions options, RankerModel model, List<RankerInstance> instances) {
        this(options, model, instances, 1);
    }

    /**
     * Processes a single training instance.
     *
     * <p>All candidates are scored with the current weights and normalized with a softmax. The
     * instance's count-weighted gold log-probability is added to {@code value_}; the value is
     * not reset here. For each candidate, {@code (1[gold] - p) * weight * step_width} is pushed
     * into the model via {@code model_.update}.
     *
     * <p>In batch mode ({@code sgd == false}) the model's weight vector is temporarily replaced
     * by the gradient buffer, so the updates land in the gradient (assuming
     * {@code model_.update} targets the installed array), and {@code weight} is the raw instance
     * count. In SGD mode the updates go to the weight vector itself and the count is damped
     * (see {@link #sgdCountWeight(double)}).
     *
     * <p>The model's weight vector and, when feature extraction is online, its feature indexes
     * are restored even if scoring or updating throws.
     *
     * @param instance the training instance; its gold lemma is expected among the candidates
     * @param sgd whether to apply updates directly to the weights (SGD) instead of the gradient
     * @param step_width multiplier applied to every update (1.0 for batch gradient computation)
     */
    public void update(RankerInstance instance, boolean sgd, double step_width) {
        boolean onlineFeatures = !this.options_.getUseOfflineFeatureExtraction();
        if (onlineFeatures) {
            this.model_.addIndexes(instance, instance.getCandidateSet(), true);
        }
        try {
            int posIndex = instance.getPosIndex(this.model_.getPosTable(), false);
            int[] morphIndexes = instance.getMorphIndexes(this.model_.getMorphTable(), false);
            this.model_.setWeights(this.weights_);
            LemmaCandidateSet set = instance.getCandidateSet();
            double logSum = scoreCandidates(set, posIndex, morphIndexes);

            // Loop invariants: gold lemma, raw count and the per-instance update weight.
            String goldLemma = instance.getInstance().getLemma();
            double count = instance.getInstance().getCount();
            double weight = sgd ? sgdCountWeight(count) : count;

            if (!sgd) {
                // Redirect model_.update into the gradient buffer for batch accumulation.
                this.model_.setWeights(this.gradient_);
            }
            double targetProb = Double.NEGATIVE_INFINITY;
            for (Map.Entry<String, LemmaCandidate> entry : set) {
                String plemma = entry.getKey();
                double score = entry.getValue().getScore();
                double prob = Math.exp(score - logSum);
                double update = -prob;
                if (plemma.equals(goldLemma)) {
                    update += 1.0;
                    targetProb = prob;
                    this.value_ += (score - logSum) * count;
                }
                update *= weight;
                this.model_.update(instance, plemma, update * step_width);
            }
            assert targetProb != Double.NEGATIVE_INFINITY : "gold lemma not among candidates";
        } finally {
            // Always undo temporary model state so a failure cannot leave the gradient buffer
            // installed as the model's weights or leak online feature indexes.
            if (onlineFeatures) {
                this.model_.removeIndexes(instance.getCandidateSet());
            }
            this.model_.setWeights(this.weights_);
        }
    }

    /**
     * Scores every candidate with the currently installed weights, caches each score on the
     * candidate, and returns the scores combined with {@code Numerics.sumLogProb}, i.e. the
     * log-partition value (assumes sumLogProb is log-domain addition).
     */
    private double scoreCandidates(LemmaCandidateSet set, int posIndex, int[] morphIndexes) {
        double logSum = Double.NEGATIVE_INFINITY;
        for (Map.Entry<String, LemmaCandidate> entry : set) {
            LemmaCandidate candidate = entry.getValue();
            double score = this.model_.score(candidate, posIndex, morphIndexes);
            candidate.setScore(score);
            logSum = Numerics.sumLogProb(logSum, score);
        }
        return logSum;
    }

    /**
     * Computes the damped SGD weight of an instance count. Counts approximately at or above
     * {@code max_num_duplicates_} are first divided by it; the result is mapped to
     * {@code log(count * e)}.
     */
    private double sgdCountWeight(double count) {
        double effective = count;
        if (Numerics.approximatelyGreaterEqual(effective, this.max_num_duplicates_)) {
            effective /= (double) this.max_num_duplicates_;
        }
        return Math.log(effective * Math.E);
    }

    /**
     * Recomputes the full objective value and gradient at the current weights.
     *
     * <p>The method first accumulates the data term over all instances in batch mode. It then
     * applies the L2 penalty: {@code value -= penalty * w^2} and
     * {@code gradient -= 2 * penalty * w}.
     */
    public void update() {
        this.value_ = 0.0;
        Arrays.fill(this.gradient_, 0.0);
        for (RankerInstance instance : this.instances_) {
            this.update(instance, false, 1.0);
        }
        for (int i = 0; i < this.weights_.length; i++) {
            double w = this.weights_[i];
            this.value_ -= this.penalty_ * w * w;
            this.gradient_[i] -= 2.0 * this.penalty_ * w;
        }
    }

    /** Returns the number of model weights. */
    public int getNumParameters() {
        return this.weights_.length;
    }

    /**
     * Returns a single weight.
     *
     * @param index weight index in {@code [0, getNumParameters())}
     * @return the current value of that weight
     */
    public double getParameter(int index) {
        return this.weights_[index];
    }

    /** Copies the current weights into {@code params}, which must hold at least that many. */
    public void getParameters(double[] params) {
        System.arraycopy(this.weights_, 0, params, 0, this.weights_.length);
    }

    /**
     * Not supported. Every weight change requires a full pass over the training data; use
     * {@link #setParameters(double[])} instead.
     *
     * @throws UnsupportedOperationException always
     */
    public void setParameter(int index, double value) {
        throw new UnsupportedOperationException("use setParameters(double[]) instead");
    }

    /** Overwrites the weights with {@code params} and recomputes the value and gradient. */
    public void setParameters(double[] params) {
        System.arraycopy(params, 0, this.weights_, 0, this.weights_.length);
        this.update();
    }

    /** Returns the objective value computed by the most recent update. */
    public double getValue() {
        return this.value_;
    }

    /** Copies the gradient computed by the most recent update into {@code gradient}. */
    public void getValueGradient(double[] gradient) {
        System.arraycopy(this.gradient_, 0, gradient, 0, this.gradient_.length);
    }
}

