/*
 * Decompiled with CFR 0.152.
 */
package experimental.analyzer.simple;

import cc.mallet.optimize.Optimizable;
import experimental.analyzer.AnalyzerTag;
import experimental.analyzer.simple.SimpleAnalyzer;
import experimental.analyzer.simple.SimpleAnalyzerInstance;
import experimental.analyzer.simple.SimpleAnalyzerModel;
import experimental.analyzer.simple.SimpleAnalyzerTrainer;
import java.util.Arrays;
import java.util.Collection;
import java.util.Map;
import marmot.util.Mutable;
import marmot.util.Numerics;

/**
 * Training objective for a {@link SimpleAnalyzerModel}, exposed to MALLET optimizers through
 * {@link Optimizable.ByGradientValue}.
 *
 * <p>The value is a log-likelihood over {@code instances_} minus an L2 penalty
 * {@code penalty * ||w||^2}. The gradient is accumulated alongside it. Value and gradient are
 * recomputed for the full instance collection whenever the parameters are set.
 *
 * <p>In {@link SimpleAnalyzer.Mode#binary} mode every tag is an independent logistic decision.
 * In {@link SimpleAnalyzer.Mode#classifier} mode the tags share one softmax normaliser.
 *
 * <p>The weight array is obtained from {@code model.getWeights()} and is shared with the model,
 * so parameter changes are visible to the model directly.
 */
public class SimpleAnalyzerObjective
implements Optimizable.ByGradientValue {
    private final SimpleAnalyzerModel model_;
    private final Collection<SimpleAnalyzerInstance> instances_;
    private double value_;
    private final double[] gradient_;
    private final double[] weights_;
    private final double penalty_;
    private final SimpleAnalyzer.Mode mode_;
    /** Per-tag score scratch buffer, reused across instances. */
    private final double[] scores;
    /** Per-tag update scratch buffer, reused across instances. */
    private final double[] updates;
    /** Related-tag counts, consulted only when {@code pair_constraint_ != none}. */
    private final Map<AnalyzerTag, Map<AnalyzerTag, Mutable<Double>>> relative_counts_;
    private final SimpleAnalyzerTrainer.PairConstraint pair_constraint_;

    /**
     * @param penalty         L2 regularisation strength; {@code penalty * w^2} is subtracted per weight
     * @param model           model whose weight array is optimised in place
     * @param instances       training instances that {@link #update()} iterates over
     * @param mode            loss formulation (binary or classifier)
     * @param relative_counts related-tag counts per observed tag; only read when
     *                        {@code pair_constraint} is not {@code none}
     * @param pair_constraint how related tags modify the classifier loss
     */
    public SimpleAnalyzerObjective(double penalty, SimpleAnalyzerModel model, Collection<SimpleAnalyzerInstance> instances, SimpleAnalyzer.Mode mode, Map<AnalyzerTag, Map<AnalyzerTag, Mutable<Double>>> relative_counts, SimpleAnalyzerTrainer.PairConstraint pair_constraint) {
        this.model_ = model;
        this.instances_ = instances;
        this.weights_ = model.getWeights();
        this.gradient_ = new double[this.weights_.length];
        this.penalty_ = penalty;
        this.mode_ = mode;
        this.relative_counts_ = relative_counts;
        int num_tags = this.model_.getNumTags();
        this.scores = new double[num_tags];
        this.updates = new double[num_tags];
        this.pair_constraint_ = pair_constraint;
    }

    /**
     * Recomputes {@code value_} and {@code gradient_} from scratch over all instances at the
     * current weights. Applies the L2 penalty {@code -penalty * w^2} to the value and
     * {@code -2 * penalty * w} to the gradient.
     */
    public void update() {
        this.value_ = 0.0;
        Arrays.fill(this.gradient_, 0.0);
        for (SimpleAnalyzerInstance instance : this.instances_) {
            this.update(instance, 1.0, false);
        }
        for (int i = 0; i < this.weights_.length; ++i) {
            double w = this.weights_[i];
            this.value_ -= this.penalty_ * w * w;
            this.gradient_[i] -= 2.0 * this.penalty_ * w;
        }
    }

    /**
     * Scores one instance and adds its loss term to {@code value_}. The per-tag update (scaled by
     * {@code step_width}) is then applied through {@code model_.update}.
     *
     * <p>With {@code sgd == false}, the model is temporarily pointed at {@code gradient_}, so the
     * update accumulates into the gradient. With {@code sgd == true}, the update is applied to the
     * weights directly. Either way the model is left pointing at {@code weights_} afterwards, even
     * if {@code model_.update} throws.
     *
     * @param instance   instance to process
     * @param step_width multiplier applied to the per-tag updates (skipped when approximately 1)
     * @param sgd        whether to update the weights in place instead of the gradient buffer
     * @throws RuntimeException if the configured mode is not supported
     */
    public void update(SimpleAnalyzerInstance instance, double step_width, boolean sgd) {
        Arrays.fill(this.scores, 0.0);
        Arrays.fill(this.updates, 0.0);
        int num_tags = this.model_.getNumTags();
        this.model_.setWeights(this.weights_);
        this.model_.score(instance, this.scores);
        switch (this.mode_) {
            case binary:
                this.value_ += this.binaryUpdate(this.scores, this.updates, num_tags, instance);
                break;
            case classifier:
                this.value_ += this.classifierUpdate(this.scores, this.updates, num_tags, instance);
                break;
            default:
                throw new RuntimeException("Unsupported mode: " + this.mode_);
        }
        if (!Numerics.approximatelyEqual(step_width, 1.0)) {
            for (int i = 0; i < num_tags; ++i) {
                this.updates[i] *= step_width;
            }
        }
        // Presumably model_.update adds into whichever array was last passed to setWeights;
        // swapping in gradient_ redirects the update there (confirm in SimpleAnalyzerModel).
        if (!sgd) {
            this.model_.setWeights(this.gradient_);
        }
        try {
            this.model_.update(instance, this.updates);
        } finally {
            // Never leave the model wired to the gradient buffer.
            this.model_.setWeights(this.weights_);
        }
    }

    /**
     * Softmax log-likelihood term for one instance, with optional related-tag adjustments.
     *
     * <p>Every observed tag shares the same normaliser, so {@code log Z} is counted once per
     * observed tag. Each tag's update starts at {@code -numObserved * p(tag)}.
     *
     * <p>Pair constraints (applied per observed tag that has an entry in {@code relative_counts_}):
     * <ul>
     *   <li>{@code weighted}: each related tag contributes {@code count * score} to the value and
     *       {@code count} to its update; the observed tag itself gets no separate +1.</li>
     *   <li>otherwise (non-{@code none}): updates of related tags other than the observed tag are
     *       zeroed. The observed tag gets its +1 only in {@code simple} mode.</li>
     * </ul>
     * Observed tags with no entry always get the plain {@code +score} / {@code +1}.
     *
     * <p>NOTE(review): in the zeroing path the value still contains the full normaliser, so the
     * gradient is no longer the gradient of the value; line-search optimizers may react badly.
     * Zeroing can also erase a +1 already added for another observed tag, which makes the result
     * depend on the iteration order of {@code getTagIndexes()}. Confirm this is intended.
     *
     * @return this instance's contribution to the objective value
     */
    private double classifierUpdate(double[] scores, double[] updates, int num_tags, SimpleAnalyzerInstance instance) {
        int num_tag_indexes = instance.getTagIndexes().size();
        // Presumably Numerics.sumLogProb(a, b) = log(e^a + e^b), making log_z the log-partition.
        double log_z = Double.NEGATIVE_INFINITY;
        for (int tag_index = 0; tag_index < num_tags; ++tag_index) {
            log_z = Numerics.sumLogProb(scores[tag_index], log_z);
        }
        double value = 0.0;
        value -= (double) num_tag_indexes * log_z;
        for (int tag_index = 0; tag_index < num_tags; ++tag_index) {
            updates[tag_index] = (double) (-num_tag_indexes) * Math.exp(scores[tag_index] - log_z);
        }

        if (this.pair_constraint_ == SimpleAnalyzerTrainer.PairConstraint.none) {
            for (int gold_index : instance.getTagIndexes()) {
                value += scores[gold_index];
                updates[gold_index] += 1.0;
            }
            return value;
        }

        for (int gold_index : instance.getTagIndexes()) {
            AnalyzerTag tag = this.model_.getTagTable().toSymbol(gold_index);
            Map<AnalyzerTag, Mutable<Double>> related = this.relative_counts_.get(tag);
            if (related != null) {
                for (Map.Entry<AnalyzerTag, Mutable<Double>> entry : related.entrySet()) {
                    int related_index = this.model_.getTagTable().toIndex(entry.getKey());
                    double count = entry.getValue().get();
                    if (this.pair_constraint_ == SimpleAnalyzerTrainer.PairConstraint.weighted) {
                        value += count * scores[related_index];
                        updates[related_index] += count;
                    } else if (related_index != gold_index) {
                        updates[related_index] = 0.0;
                    }
                }
            }
            if (related == null || this.pair_constraint_ == SimpleAnalyzerTrainer.PairConstraint.simple) {
                value += scores[gold_index];
                updates[gold_index] += 1.0;
            }
        }
        return value;
    }

    /**
     * Independent-logistic log-likelihood term for one instance. For every tag it adds
     * {@code -log(1 + e^score)} to the value and {@code -sigmoid(score)} to the update. Observed
     * tags additionally get {@code +score} and {@code +1}.
     *
     * @return this instance's contribution to the objective value
     */
    private double binaryUpdate(double[] scores, double[] updates, int num_tags, SimpleAnalyzerInstance instance) {
        double value = 0.0;
        for (int tag_index = 0; tag_index < num_tags; ++tag_index) {
            // log(1 + e^s), assuming sumLogProb(a, b) = log(e^a + e^b) and e^0 = 1.
            double log_z = Numerics.sumLogProb(scores[tag_index], 0.0);
            value -= log_z;
            updates[tag_index] = -Math.exp(scores[tag_index] - log_z);
        }
        for (int tag_index : instance.getTagIndexes()) {
            value += scores[tag_index];
            updates[tag_index] += 1.0;
        }
        return value;
    }

    /** @return number of weights being optimised */
    public int getNumParameters() {
        return this.weights_.length;
    }

    /**
     * Returns a single weight.
     *
     * @param index weight index
     * @return current value of weight {@code index}
     * @throws ArrayIndexOutOfBoundsException if {@code index} is out of range
     */
    public double getParameter(int index) {
        return this.weights_[index];
    }

    /** Copies all current weights into {@code params}, which must hold at least {@link #getNumParameters()} entries. */
    public void getParameters(double[] params) {
        System.arraycopy(this.weights_, 0, params, 0, this.weights_.length);
    }

    /**
     * Sets a single weight and recomputes value and gradient, matching
     * {@link #setParameters(double[])}.
     *
     * @param index weight index
     * @param value new weight value
     * @throws ArrayIndexOutOfBoundsException if {@code index} is out of range
     */
    public void setParameter(int index, double value) {
        this.weights_[index] = value;
        this.update();
    }

    /** Overwrites all weights from {@code params} and recomputes value and gradient. */
    public void setParameters(double[] params) {
        System.arraycopy(params, 0, this.weights_, 0, this.weights_.length);
        this.update();
    }

    /** @return the objective value computed by the last update */
    public double getValue() {
        return this.value_;
    }

    /** Copies the gradient computed by the last {@link #update()} into {@code gradient}. */
    public void getValueGradient(double[] gradient) {
        System.arraycopy(this.gradient_, 0, gradient, 0, this.gradient_.length);
    }
}

