/*
 * Decompiled with CFR 0.152.
 */
package chipmunk.segmenter;

import chipmunk.segmenter.SegmentationInstance;
import chipmunk.segmenter.SegmentationResult;
import chipmunk.segmenter.SegmenterModel;
import java.util.Arrays;
import marmot.util.Numerics;

/**
 * Forward-backward ("sum") lattice over the segmentations of a {@link SegmentationInstance}.
 *
 * <p>A segmentation covers the input positions {@code [0, length)} with contiguous segments
 * {@code [l_start, l_end)} of at most {@code model.getMaxSegmentLength()} positions. Each segment
 * carries one of {@code model.getNumTags()} tags. A segmentation is scored by the model's pair
 * scores for its segments plus the transition scores between adjacent segments. Alternatives are
 * combined in log space with {@link Numerics#sumLogProb}, and posteriors are obtained as
 * {@code exp(score - logPartition)}.
 *
 * <p>The score buffers are mutable scratch space reused across calls, so one instance should not
 * be used from several threads at once without external synchronization.
 */
public class SegmentationSumLattice {
    private final SegmenterModel model_;
    private final int num_tags_;
    private final int max_segment_length;
    /** Entry {@code [tag, l_end - 1]}: log-sum over prefixes {@code [0, l_end)} ending with {@code tag}. */
    private double[] forward_score_array_;
    /** Entry {@code [tag, l_start]}: log-sum over suffixes {@code [l_start, n)} starting with {@code tag}. */
    private double[] backward_score_array_;
    /** Length of the instance currently held in the buffers; used as the row stride. */
    private int input_length_;

    /**
     * Creates a lattice bound to {@code model}.
     *
     * @param model source of tag count, maximum segment length, scores and update sink
     */
    public SegmentationSumLattice(SegmenterModel model) {
        this.model_ = model;
        this.num_tags_ = model.getNumTags();
        this.max_segment_length = model.getMaxSegmentLength();
    }

    /**
     * Runs forward-backward over {@code instance} and, if requested, pushes a gradient step into
     * the model.
     *
     * <p>When {@code do_update} is true, the model receives two kinds of updates:
     * <ul>
     *   <li>Every (segment, tag) and every (segment, previous tag, tag) receives minus its
     *       posterior probability.</li>
     *   <li>Every gold result receives {@code 1 / instance.getResults().size()}.</li>
     * </ul>
     * When {@code do_update} is false, no update of any kind is sent to the model.
     *
     * <p>NOTE(review): the returned value is the unscaled sum over gold results, while the gold
     * updates are weighted by {@code 1/k}. With more than one result the value and the applied
     * step therefore do not belong to the same objective; confirm which one is intended.
     *
     * @param instance the instance to score; must have a positive length
     * @param do_update whether to send updates to the model
     * @return sum over {@code instance.getResults()} of {@code model.getScore(result) - logPartition}
     * @throws IllegalArgumentException if {@code instance.getLength()} is not positive
     */
    public double update(SegmentationInstance instance, boolean do_update) {
        input_length_ = instance.getLength();
        if (input_length_ <= 0) {
            // Without this check sumTag(..., 0) reads outside the meaningful part of the buffers
            // (an AIOOBE on first use, stale -inf afterwards, yielding an infinite objective).
            throw new IllegalArgumentException(
                    "Cannot build a segmentation lattice for an instance of length " + input_length_);
        }
        checkArraySize(num_tags_ * input_length_);
        fillForwardScores(instance);
        fillBackwardScores(instance);

        // Every complete segmentation starts at position 0, so the backward column 0 sums them all.
        double log_partition = sumTag(backward_score_array_, 0);

        if (do_update) {
            applyExpectedCountUpdates(instance, log_partition);
        }

        double gold_weight = 1.0 / (double) instance.getResults().size();
        double real_value = 0.0;
        for (SegmentationResult result : instance.getResults()) {
            // Ordering kept from the original: the gold update precedes getScore, so the score
            // reflects any immediate effect model_.update may have on the weights.
            if (do_update) {
                model_.update(instance, result, gold_weight);
            }
            real_value += model_.getScore(instance, result) - log_partition;
        }
        return real_value;
    }

    /** Fills forward_score_array_ left to right over segment end positions. */
    private void fillForwardScores(SegmentationInstance instance) {
        Arrays.fill(forward_score_array_, Double.NEGATIVE_INFINITY);
        for (int l_end = 1; l_end <= input_length_; ++l_end) {
            for (int tag = 0; tag < num_tags_; ++tag) {
                double score_sum = Double.NEGATIVE_INFINITY;
                for (int l_start = Math.max(0, l_end - max_segment_length); l_start < l_end; ++l_start) {
                    double pair_score = model_.getPairScore(instance, l_start, l_end, tag);
                    if (l_start == 0) {
                        // First segment: no predecessor, hence no transition.
                        score_sum = Numerics.sumLogProb(pair_score, score_sum);
                        continue;
                    }
                    for (int last_tag = 0; last_tag < num_tags_; ++last_tag) {
                        double prev_score = forward_score_array_[getIndex(last_tag, l_start - 1)];
                        // Transition queried with the span of the *later* segment [l_start, l_end).
                        double transition_score =
                                model_.getTransitionScore(instance, last_tag, tag, l_start, l_end);
                        double score = pair_score + transition_score + prev_score;
                        score_sum = Numerics.sumLogProb(score, score_sum);
                    }
                }
                forward_score_array_[getIndex(tag, l_end - 1)] = score_sum;
            }
        }
    }

    /**
     * Fills backward_score_array_ right to left over segment start positions.
     *
     * <p>NOTE(review): here the transition is queried with the span of the *earlier* segment,
     * whereas {@link #fillForwardScores} uses the span of the *later* segment. The two passes (and
     * the posteriors built from them) agree only if getTransitionScore ignores its span
     * arguments. Confirm against SegmenterModel.
     */
    private void fillBackwardScores(SegmentationInstance instance) {
        Arrays.fill(backward_score_array_, Double.NEGATIVE_INFINITY);
        for (int l_start = input_length_ - 1; l_start >= 0; --l_start) {
            for (int tag = 0; tag < num_tags_; ++tag) {
                double score_sum = Double.NEGATIVE_INFINITY;
                for (int l_end = Math.min(input_length_, l_start + max_segment_length); l_end > l_start; --l_end) {
                    double pair_score = model_.getPairScore(instance, l_start, l_end, tag);
                    if (l_end == input_length_) {
                        // Last segment: no successor, hence no transition.
                        score_sum = Numerics.sumLogProb(pair_score, score_sum);
                        continue;
                    }
                    for (int next_tag = 0; next_tag < num_tags_; ++next_tag) {
                        double next_score = backward_score_array_[getIndex(next_tag, l_end)];
                        double transition_score =
                                model_.getTransitionScore(instance, tag, next_tag, l_start, l_end);
                        double score = pair_score + transition_score + next_score;
                        score_sum = Numerics.sumLogProb(score, score_sum);
                    }
                }
                backward_score_array_[getIndex(tag, l_start)] = score_sum;
            }
        }
    }

    /**
     * Returns the log-sum of everything to the right of segment {@code [l_start, l_end)} tagged
     * {@code tag}: each successor's backward score plus the transition into it. Returns 0.0 when
     * the segment ends the input.
     */
    private double suffixScore(SegmentationInstance instance, int tag, int l_start, int l_end) {
        if (l_end >= input_length_) {
            return 0.0;
        }
        double score_sum = Double.NEGATIVE_INFINITY;
        for (int next_tag = 0; next_tag < num_tags_; ++next_tag) {
            double transition_score = model_.getTransitionScore(instance, tag, next_tag, l_start, l_end);
            double next_tag_score = backward_score_array_[getIndex(next_tag, l_end)] + transition_score;
            score_sum = Numerics.sumLogProb(next_tag_score, score_sum);
        }
        return score_sum;
    }

    /**
     * Sends minus the posterior probability of every segment and every transition to the model.
     * The pair update of a non-initial segment is the sum of its transition updates, i.e. minus
     * the segment's marginal.
     */
    private void applyExpectedCountUpdates(SegmentationInstance instance, double log_partition) {
        for (int l_end = 1; l_end <= input_length_; ++l_end) {
            for (int tag = 0; tag < num_tags_; ++tag) {
                for (int l_start = Math.max(0, l_end - max_segment_length); l_start < l_end; ++l_start) {
                    double pair_score = model_.getPairScore(instance, l_start, l_end, tag);
                    double suffix_score = suffixScore(instance, tag, l_start, l_end);
                    if (l_start == 0) {
                        double prob = Math.exp(suffix_score + pair_score - log_partition);
                        model_.update(instance, l_start, l_end, tag, -prob);
                        continue;
                    }
                    double pair_update = 0.0;
                    for (int last_tag = 0; last_tag < num_tags_; ++last_tag) {
                        double forward_score = forward_score_array_[getIndex(last_tag, l_start - 1)];
                        double transition_score =
                                model_.getTransitionScore(instance, last_tag, tag, l_start, l_end);
                        double prob = Math.exp(
                                forward_score + pair_score + transition_score + suffix_score - log_partition);
                        double transition_update = -prob;
                        model_.update(instance, l_start, l_end, last_tag, tag, transition_update);
                        pair_update += transition_update;
                    }
                    model_.update(instance, l_start, l_end, tag, pair_update);
                }
            }
        }
    }

    /** Log-sums column {@code l} of {@code score_array} over all tags. */
    private double sumTag(double[] score_array, int l) {
        double score_sum = Double.NEGATIVE_INFINITY;
        for (int tag = 0; tag < num_tags_; ++tag) {
            score_sum = Numerics.sumLogProb(score_array[getIndex(tag, l)], score_sum);
        }
        return score_sum;
    }

    /** Row-major index: one row of {@code input_length_} cells per tag. */
    private int getIndex(int tag, int index) {
        return tag * input_length_ + index;
    }

    /** Grows (never shrinks) both score buffers so they hold at least {@code required_length} cells. */
    private void checkArraySize(int required_length) {
        if (forward_score_array_ == null || forward_score_array_.length < required_length) {
            forward_score_array_ = new double[required_length];
            backward_score_array_ = new double[required_length];
        }
    }
}

