/*
 * Decompiled with CFR 0.152.
 */
package marmot.core.lattice;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import lemming.lemma.ranker.RankerCandidate;
import marmot.core.State;
import marmot.core.WeightVector;
import marmot.core.lattice.SumLattice;
import marmot.util.Check;
import marmot.util.Numerics;

/**
 * A {@link SumLattice} of order zero: every position of the sequence owns an
 * independent list of candidate {@link State}s, and each position is
 * normalised on its own.
 *
 * <p>For each position the class lazily computes a per-position normaliser by
 * folding the state scores with {@code Numerics.sumLogProb}. Pruning and the
 * weight update both work with {@code state.getScore() - normaliser}, i.e. the
 * scores are handled as log-domain values.
 */
public class ZeroOrderSumLattice
implements SumLattice {
    /** Candidate states, one inner list per sequence position. */
    private final List<List<State>> candidates_;
    /** Natural logarithm of the pruning threshold passed to the constructor. */
    private final double log_threshold_;
    /** Index of the gold state for each position, or {@code null} if none was set. */
    private List<Integer> gold_candidate_indexes_;
    /** Per-position normalisers; filled by {@link #init()}. */
    private double[] score_sums_;
    /** Whether {@link #init()} has already run. */
    private boolean initialized_;
    /** When {@code true}, gold states survive pruning regardless of their score. */
    private final boolean oracle_;

    /**
     * Creates a lattice over the given candidates.
     *
     * @param candidates candidate states, one list per position; stored by reference
     * @param threshold  pruning threshold; its natural logarithm is stored and
     *                   compared against normalised log scores in {@link #prune()}
     * @param oracle     if {@code true} and gold candidates have been set, the gold
     *                   state of each position is always kept when pruning
     */
    public ZeroOrderSumLattice(List<List<State>> candidates, double threshold, boolean oracle) {
        this.candidates_ = candidates;
        this.log_threshold_ = Math.log(threshold);
        this.initialized_ = false;
        this.oracle_ = oracle;
    }

    /** Computes the per-position normalisers once; later calls are no-ops. */
    private void init() {
        if (this.initialized_) {
            return;
        }
        this.initialized_ = true;
        this.score_sums_ = new double[this.candidates_.size()];
        for (int index = 0; index < this.candidates_.size(); ++index) {
            List<State> states = this.candidates_.get(index);
            assert (!states.isEmpty());
            this.score_sums_[index] = this.getScoreSum(states);
        }
    }

    /** Returns the unpruned candidate lists passed to the constructor. */
    @Override
    public List<List<State>> getCandidates() {
        return this.candidates_;
    }

    /** Prunes with the threshold given at construction time. */
    @Override
    public List<List<State>> prune() {
        return this.prune(this.log_threshold_);
    }

    /**
     * Builds a pruned copy of the candidate lists.
     *
     * <p>A state is kept when its normalised log score
     * ({@code getScore()} minus the position's normaliser) is strictly greater
     * than {@code log_threshold}, or when oracle mode is on and the state is the
     * gold state of its position. If no state of a position qualifies, the
     * highest-scoring state is kept so that no position ends up empty.
     *
     * @param log_threshold threshold in the log domain
     * @return a new list of new per-position lists; the states themselves are shared
     */
    public List<List<State>> prune(double log_threshold) {
        this.init();
        List<List<State>> candidates = new ArrayList<List<State>>(this.candidates_.size());
        for (int index = 0; index < this.candidates_.size(); ++index) {
            List<State> position_states = this.candidates_.get(index);
            int num_states = position_states.size();
            assert (num_states >= 0);
            double score_sum = this.score_sums_[index];
            List<State> states = new ArrayList<State>(num_states);
            State max_state = null;
            double max_score = Double.NEGATIVE_INFINITY;
            int state_index = 0;
            for (State state : position_states) {
                double score = state.getScore() - score_sum;
                assert (Check.isNormal(score));
                boolean is_oracle_state = false;
                if (this.oracle_ && this.gold_candidate_indexes_ != null) {
                    // Integer is unboxed here, so this is a numeric comparison.
                    is_oracle_state = this.gold_candidate_indexes_.get(index) == state_index;
                }
                if (score > log_threshold || is_oracle_state) {
                    states.add(state);
                }
                // Track the best state as a fallback for positions where nothing survives.
                if (score > max_score) {
                    max_score = score;
                    max_state = state;
                }
                ++state_index;
            }
            assert (max_state != null);
            if (states.isEmpty()) {
                states.add(max_state);
            }
            candidates.add(states);
        }
        return candidates;
    }

    /**
     * Folds the scores of {@code states} with {@code Numerics.sumLogProb},
     * starting from negative infinity.
     *
     * @param states states of a single position
     * @return the accumulated value, used as that position's normaliser
     */
    private double getScoreSum(Collection<State> states) {
        double score_sum = Double.NEGATIVE_INFINITY;
        for (State state : states) {
            assert (Check.isNormal(state.getScore()));
            score_sum = Numerics.sumLogProb(score_sum, state.getScore());
        }
        assert (score_sum != Double.NEGATIVE_INFINITY);
        assert (Check.isNormal(score_sum));
        return score_sum;
    }

    /**
     * Updates {@code weights} for every position of the lattice and returns the
     * summed per-position log-likelihood of the gold states.
     *
     * <p>If no gold candidates have been set, a warning is printed to
     * {@code System.err}, no update is performed and {@code 0.0} is returned.
     *
     * @param weights    weight vector to update
     * @param step_width factor applied to every update value
     * @return sum over all positions of the gold state's normalised log score
     */
    @Override
    public double update(WeightVector weights, double step_width) {
        this.init();
        double ll = 0.0;
        if (this.gold_candidate_indexes_ == null) {
            System.err.println("Warning: Gold sequence not in zero order lattice!");
            return ll;
        }
        // Visit every position, including the last one: init() and prune()
        // treat all positions uniformly, so the final position must be
        // updated and counted in the log-likelihood as well.
        for (int index = 0; index < this.candidates_.size(); ++index) {
            int gold_candidate_index = this.gold_candidate_indexes_.get(index);
            List<State> states = this.candidates_.get(index);
            double score_sum = this.score_sums_[index];
            ll += this.update(states, gold_candidate_index, score_sum, weights, step_width);
        }
        return ll;
    }

    /**
     * Performs the update for a single position.
     *
     * <p>Each state receives the update value {@code -exp(score - score_sum)},
     * increased by {@code 1.0} for the gold state, scaled by {@code step_width}.
     * When a state carries lemma candidates, each candidate is updated the same
     * way using {@code candidate.getScore() + state.getRealScore()} as its score;
     * the {@code +1.0} is applied to candidates that are correct and belong to
     * the gold state.
     *
     * @param states               states of the position
     * @param gold_candidate_index index of the gold state within {@code states}
     * @param score_sum            normaliser of the position
     * @param weights              weight vector to update
     * @param step_width           factor applied to every update value
     * @return normalised log score of the gold state, or of the last correct
     *         lemma candidate of the gold state if there is one; {@code 0.0} if
     *         the gold index matches no state
     */
    private double update(List<State> states, int gold_candidate_index, double score_sum, WeightVector weights, double step_width) {
        int candidate_index = 0;
        double ll = 0.0;
        for (State state : states) {
            assert (state.getZeroOrderState() == state);
            double p = Math.exp(state.getScore() - score_sum);
            double value = -p;
            if (candidate_index == gold_candidate_index) {
                value += 1.0;
                ll = states.get(gold_candidate_index).getScore() - score_sum;
            }
            weights.updateWeights(state, value * step_width, false);
            if (state.getLemmaCandidates() != null) {
                double new_score = state.getRealScore();
                for (RankerCandidate candidate : state.getLemmaCandidates()) {
                    double score = candidate.getScore() + new_score;
                    p = Math.exp(score - score_sum);
                    value = -p;
                    if (candidate.isCorrect() && candidate_index == gold_candidate_index) {
                        value += 1.0;
                        // Overrides the state-level value set above.
                        ll = score - score_sum;
                    }
                    candidate.update(state, weights, value * step_width);
                }
            }
            ++candidate_index;
        }
        return ll;
    }

    /** Always {@code 0}: this lattice models no dependencies between positions. */
    @Override
    public int getOrder() {
        return 0;
    }

    /**
     * Returns the candidate lists of this lattice.
     *
     * @param filter2 if {@code true}, the result of {@link #prune()} is returned;
     *                otherwise the unpruned lists
     */
    @Override
    public List<List<State>> getZeroOrderCandidates(boolean filter2) {
        if (filter2) {
            return this.prune();
        }
        return this.candidates_;
    }

    /**
     * Sets the gold state index for each position.
     *
     * @param candidates one index per position; stored by reference, may be {@code null}
     */
    @Override
    public void setGoldCandidates(List<Integer> candidates) {
        this.gold_candidate_indexes_ = candidates;
    }

    /** Returns the level reported by the first state of the first position. */
    @Override
    public int getLevel() {
        return this.candidates_.get(0).get(0).getLevel();
    }

    /** Returns the gold indexes previously set, or {@code null} if none were set. */
    @Override
    public List<Integer> getGoldCandidates() {
        return this.gold_candidate_indexes_;
    }
}

