/*
 * Decompiled with CFR 0.152.
 */
package chipmunk.segmenter;

import cc.mallet.optimize.Optimizable;
import chipmunk.segmenter.SegmentationInstance;
import chipmunk.segmenter.SegmentationSumLattice;
import chipmunk.segmenter.SegmenterModel;
import chipmunk.segmenter.Word;
import java.util.Arrays;
import java.util.Collection;
import marmot.util.DynamicWeights;

public class SemiCrfObjective
implements Optimizable.ByGradientValue {
    private SegmenterModel model_;
    private Collection<Word> words_;
    private double value_;
    private double[] gradient_;
    private double[] weights_;
    private double penalty_;

    public SemiCrfObjective(SegmenterModel model, Collection<Word> words, double penalty) {
        this.model_ = model;
        this.words_ = words;
        this.penalty_ = penalty;
    }

    public void init() {
        DynamicWeights weights = new DynamicWeights(null);
        this.model_.setScorerWeights(weights);
        DynamicWeights gradient = new DynamicWeights(null);
        this.model_.setUpdaterWeights(gradient);
        this.model_.getUpdater().setInsert(false);
        this.calcLikelihood();
        DynamicWeights scorer = this.model_.getScorer().getWeights();
        DynamicWeights updater = this.model_.getUpdater().getWeights();
        if (scorer.getLength() != updater.getLength()) {
            int length = Math.max(scorer.getLength(), updater.getLength());
            scorer.setLength(length);
            updater.setLength(length);
        }
        this.weights_ = scorer.getWeights();
        scorer.setExapnd(false);
        this.gradient_ = updater.getWeights();
        updater.setExapnd(false);
        assert (this.weights_.length == this.gradient_.length) : this.weights_.length + " " + this.gradient_.length;
        this.calcPenalty();
    }

    public void update() {
        this.value_ = 0.0;
        Arrays.fill(this.gradient_, 0.0);
        this.calcLikelihood();
        this.calcPenalty();
    }

    private void calcPenalty() {
        if (this.penalty_ > 0.0) {
            int i = 0;
            while (i < this.weights_.length) {
                double w = this.weights_[i];
                this.value_ -= this.penalty_ * w * w;
                int n = i++;
                this.gradient_[n] = this.gradient_[n] - 2.0 * this.penalty_ * w;
            }
        }
    }

    private void calcLikelihood() {
        SegmentationSumLattice lattice = new SegmentationSumLattice(this.model_);
        for (Word word : this.words_) {
            SegmentationInstance instance = this.model_.getInstance(word);
            this.value_ += lattice.update(instance, true);
        }
    }

    public int getNumParameters() {
        return this.weights_.length;
    }

    public double getParameter(int arg0) {
        throw new UnsupportedOperationException();
    }

    public void getParameters(double[] params) {
        System.arraycopy(this.weights_, 0, params, 0, this.weights_.length);
    }

    public void setParameter(int arg0, double arg1) {
        throw new UnsupportedOperationException();
    }

    public void setParameters(double[] params) {
        System.arraycopy(params, 0, this.weights_, 0, this.weights_.length);
        this.update();
    }

    public double getValue() {
        return this.value_;
    }

    public void getValueGradient(double[] gradient) {
        System.arraycopy(this.gradient_, 0, gradient, 0, this.gradient_.length);
    }
}

