/*
 * Decompiled with CFR 0.152.
 */
package marmot.core.lattice;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.logging.Logger;
import lemming.lemma.ranker.RankerCandidate;
import marmot.core.State;
import marmot.core.Transition;
import marmot.core.WeightVector;
import marmot.core.lattice.BackwardSequenceLattice;
import marmot.core.lattice.ForwardSequenceLattice;
import marmot.core.lattice.SumLattice;
import marmot.util.Check;
import marmot.util.Numerics;

/**
 * A {@link SumLattice} over a sequence of per-position candidate state lists.
 *
 * <p>Marginal (log-space) state scores are obtained by combining a
 * {@link ForwardSequenceLattice} and a {@link BackwardSequenceLattice}. They are
 * used to prune low-scoring states ({@link #pruneStates()}) and to push
 * expected-count updates into the model ({@link #update(WeightVector, double)}).
 */
public class SequenceSumLattice implements SumLattice {
    /** Logger for forward/backward consistency warnings. */
    private static final Logger LOGGER = Logger.getLogger(SequenceSumLattice.class.getName());

    /**
     * Once more than this many states have been kept at one position, further
     * non-oracle states are dropped (so at most this many + 1 pass the threshold
     * normally, plus a possible oracle state).
     */
    private static final int MAX_KEPT_STATES = 50;

    /** Tolerated absolute difference between per-position and global log partition sums. */
    private static final double FB_TOLERANCE = 1.0E-5;

    private final ForwardSequenceLattice forward_;
    private final BackwardSequenceLattice backward_;
    private final List<List<State>> candidates_;
    /** Natural log of the pruning threshold passed to the constructor. */
    private final double log_threshold_;
    private boolean initialized_;
    private final State boundary_;
    private final int order_;
    /** Index of the gold state at each position; null until {@link #setGoldCandidates} is called. */
    private List<Integer> gold_candidate_indexes_;
    /** When true (and gold candidates are set), pruning always keeps the gold state. */
    private final boolean oracle_;

    /**
     * Creates a lattice over the given candidates. No computation happens until
     * {@link #init()} is called, either directly or by prune/update.
     *
     * @param candidates per-position candidate states
     * @param boundary   boundary state handed to the forward/backward lattices; its
     *                   index also marks where zero-order candidate extraction stops
     * @param threshold  pruning threshold; a state is kept when its log score minus
     *                   the log partition function exceeds {@code Math.log(threshold)}
     * @param order      value reported by {@link #getOrder()}
     * @param oracle     if true, the gold state (when set) is never pruned
     */
    public SequenceSumLattice(List<List<State>> candidates, State boundary, double threshold, int order, boolean oracle) {
        this.forward_ = new ForwardSequenceLattice(candidates, boundary);
        this.backward_ = new BackwardSequenceLattice(candidates, boundary);
        this.candidates_ = candidates;
        this.log_threshold_ = Math.log(threshold);
        this.initialized_ = false;
        this.boundary_ = boundary;
        this.order_ = order;
        this.oracle_ = oracle;
    }

    /** Returns the (unpruned) candidate lists this lattice was built over. */
    @Override
    public List<List<State>> getCandidates() {
        return this.candidates_;
    }

    /** Runs the forward and backward passes once; later calls are no-ops. */
    public void init() {
        if (this.initialized_) {
            return;
        }
        this.initialized_ = true;
        this.forward_.init();
        this.backward_.init();
    }

    /** Same as {@link #pruneStates()}. */
    @Override
    public List<List<State>> prune() {
        return this.pruneStates();
    }

    /**
     * Builds a pruned copy of the candidate lists.
     *
     * <p>At each position a state is kept when (a) it is reachable through a
     * non-null transition from some kept state at the previous position (always
     * true at position 0), and (b) its log score {@code forward + backward} minus the
     * log partition function exceeds the log threshold, or it is the gold state in
     * oracle mode. Non-oracle states are also dropped once more than
     * {@link #MAX_KEPT_STATES} states have been kept. If nothing survives at a
     * position, the highest-scoring reachable state is kept instead. Kept states at
     * positions after the first are copies whose transitions are re-indexed to the
     * pruned predecessor list.
     *
     * @return one non-empty list of kept states per position
     */
    public List<List<State>> pruneStates() {
        this.init();
        double score_sum_forward = this.forward_.partitionFunction();
        assert (Check.isNormal(score_sum_forward));
        assert (Check.isNormal(this.backward_.partitionFunction()));
        assert (Numerics.approximatelyEqual(score_sum_forward, this.backward_.partitionFunction()));

        List<List<State>> candidates = new ArrayList<List<State>>(this.candidates_.size());
        // Maps each original state index at the previous position to its index in
        // the pruned previous list (-1 if pruned). Null at position 0.
        int[] index_map = null;
        int num_previous_states = 1;
        for (int index = 0; index < this.candidates_.size(); ++index) {
            List<State> position_states = this.candidates_.get(index);
            int num_states = position_states.size();
            int[] new_index_map = new int[num_states];
            Arrays.fill(new_index_map, -1);
            double score_sum = Double.NEGATIVE_INFINITY;
            List<State> states = new ArrayList<State>(num_states);
            int max_state_index = -1;
            double max_score = Double.NEGATIVE_INFINITY;

            for (int state_index = 0; state_index < num_states; ++state_index) {
                State state = position_states.get(state_index);
                double score = this.forward_.get(index, state_index) + this.backward_.get(index + 1, state_index);
                // Every state counts toward the consistency check, reachable or not.
                score_sum = Numerics.sumLogProb(score_sum, score);

                if (index_map != null && !hasSurvivingPredecessor(state, index_map, num_previous_states)) {
                    continue;
                }

                boolean is_oracle_state = this.oracle_
                        && this.gold_candidate_indexes_ != null
                        && this.gold_candidate_indexes_.get(index) == state_index;

                if (score - score_sum_forward > this.log_threshold_ || is_oracle_state) {
                    // The cap must not evict the gold state in oracle mode.
                    if (states.size() > MAX_KEPT_STATES && !is_oracle_state) {
                        continue;
                    }
                    states.add(this.fixTransitions(state, index_map, num_previous_states));
                    new_index_map[state_index] = states.size() - 1;
                }

                if (score > max_score) {
                    max_score = score;
                    max_state_index = state_index;
                }
            }

            // Summing marginal scores over one position should recover the global
            // log partition function.
            assert (score_sum != Double.NEGATIVE_INFINITY);
            if (Math.abs(score_sum - score_sum_forward) > FB_TOLERANCE) {
                LOGGER.warning(String.format("Difference in FB: %g %g", score_sum, score_sum_forward));
            }
            assert (Math.abs(score_sum - score_sum_forward) < FB_TOLERANCE);

            if (states.isEmpty()) {
                // Nothing passed the threshold: keep the best reachable state so the
                // pruned lattice stays connected.
                assert (max_state_index >= 0);
                states.add(this.fixTransitions(position_states.get(max_state_index), index_map, num_previous_states));
                new_index_map[max_state_index] = 0;
            }
            assert (!states.isEmpty());
            candidates.add(states);
            num_previous_states = num_states;
            index_map = new_index_map;
        }
        assert (candidates.size() == this.candidates_.size());
        return candidates;
    }

    /**
     * Returns true if {@code state} has a non-null transition from a
     * previous-position state that survived pruning.
     */
    private static boolean hasSurvivingPredecessor(State state, int[] index_map, int num_previous_states) {
        for (int transition_index = 0; transition_index < num_previous_states; ++transition_index) {
            if (state.getTransition(transition_index) != null && index_map[transition_index] >= 0) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns {@code state} unchanged at position 0 ({@code index_map == null}).
     * Otherwise returns a copy whose transition array is re-indexed through
     * {@code index_map}: transitions from pruned predecessors are dropped, and kept
     * ones move to the predecessor's pruned index.
     *
     * @param state     state to adapt
     * @param index_map previous-position map from original to pruned index (-1 = pruned)
     * @param num_states length of the new transition array
     */
    private State fixTransitions(State state, int[] index_map, int num_states) {
        if (index_map == null) {
            return state;
        }
        state = state.copy();
        Transition[] old_transitions = state.getTransitions();
        // NOTE(review): callers pass the unpruned previous-position size, so the new
        // array may have trailing null slots beyond the pruned predecessor count.
        Transition[] new_transitions = new Transition[num_states];
        for (int index = 0; index < old_transitions.length; ++index) {
            int new_index = index_map[index];
            if (new_index >= 0) {
                new_transitions[new_index] = old_transitions[index];
            }
        }
        state.setTransitions(new_transitions);
        return state;
    }

    /**
     * Performs one update pass over the lattice.
     *
     * <p>For every non-null transition, {@code weights.updateWeights} receives
     * {@code (indicator - p) * step_width}. Here {@code p} is the transition's score
     * normalised by the log partition function, and the indicator is 1 only for the
     * gold transition. Each state gets the same kind of value through
     * {@code incrementEstimatedCounts}, and so does each lemma candidate of a state's
     * zero-order state (the indicator there requires a gold state and a correct
     * candidate). After each position, {@code State.updateWeights(weights)} is called
     * on every state at that position.
     *
     * <p>Requires {@link #setGoldCandidates(List)} to have been called.
     *
     * @return the summed scores of the gold states and gold transitions minus the
     *         log partition function
     */
    @Override
    public double update(WeightVector weights, double step_width) {
        this.init();
        double ll = 0.0;
        double score_sum = this.forward_.partitionFunction();
        int last_gold_candidate_index = 0;
        for (int index = 0; index < this.candidates_.size(); ++index) {
            List<State> position_states = this.candidates_.get(index);
            int gold_candidate_index = this.gold_candidate_indexes_.get(index);
            int state_index = 0;
            for (State state : position_states) {
                boolean is_gold_sequence_state = state_index == gold_candidate_index;

                int trans_index = 0;
                for (Transition transition : state.getTransitions()) {
                    if (transition != null) {
                        // NOTE(review): at index 0 this reads forward_.get(-1, ...);
                        // assumes ForwardSequenceLattice handles the boundary row.
                        double trans_score = this.forward_.get(index - 1, trans_index) + state.getScore()
                                + transition.getScore() + this.backward_.get(index + 1, state_index);
                        double trans_p = Math.exp(trans_score - score_sum);
                        boolean is_gold_transition = is_gold_sequence_state && trans_index == last_gold_candidate_index;
                        if (is_gold_transition) {
                            ll += transition.getScore();
                            weights.updateWeights(transition, (1.0 - trans_p) * step_width, true);
                        } else {
                            weights.updateWeights(transition, -trans_p * step_width, true);
                        }
                    }
                    ++trans_index;
                }

                double state_score = this.forward_.get(index, state_index) + this.backward_.get(index + 1, state_index);
                double state_p = Math.exp(state_score - score_sum);
                if (is_gold_sequence_state) {
                    ll += state.getScore();
                }
                double state_value = is_gold_sequence_state ? 1.0 - state_p : -state_p;
                state.incrementEstimatedCounts(state_value * step_width);

                State zero_order_state = state.getZeroOrderState();
                if (zero_order_state.getLemmaCandidates() != null) {
                    // Replace the zero-order state's score by its "real" score
                    // before adding each lemma candidate's own score.
                    double new_state_score = state_score - zero_order_state.getScore() + zero_order_state.getRealScore();
                    for (RankerCandidate candidate : zero_order_state.getLemmaCandidates()) {
                        double lemma_p = Math.exp(new_state_score + candidate.getScore() - score_sum);
                        double lemma_value = (is_gold_sequence_state && candidate.isCorrect()) ? 1.0 - lemma_p : -lemma_p;
                        candidate.incrementEstimatedCounts(lemma_value * step_width);
                    }
                }
                ++state_index;
            }
            for (State state : position_states) {
                state.updateWeights(weights);
            }
            last_gold_candidate_index = gold_candidate_index;
        }
        return ll - score_sum;
    }

    /** Returns the order given at construction. */
    @Override
    public int getOrder() {
        return this.order_;
    }

    /**
     * Collapses each position's states into distinct zero-order states.
     *
     * <p>For every state, its zero-order state is copied with transitions cleared.
     * Duplicates by {@code equalIndexes} are skipped. Processing stops after the
     * first position containing a zero-order state whose index equals
     * {@code boundary_index}. That position is included and is expected to hold
     * exactly one state.
     *
     * @param candidates     per-position candidate states
     * @param boundary_index index identifying the boundary state
     * @return the zero-order candidate lists up to and including the boundary position
     */
    public static List<List<State>> getZeroOrderCandidates(List<List<State>> candidates, int boundary_index) {
        List<List<State>> new_candidates = new ArrayList<List<State>>(candidates.size());
        boolean found_boundary = false;
        for (List<State> states : candidates) {
            List<State> new_states = new ArrayList<State>();
            for (State state : states) {
                State zero_order_state = state.getZeroOrderState();
                assert (!(zero_order_state instanceof Transition));
                if (zero_order_state.getIndex() == boundary_index) {
                    found_boundary = true;
                }
                boolean contains = false;
                for (State new_state : new_states) {
                    if (new_state.equalIndexes(zero_order_state)) {
                        contains = true;
                        break;
                    }
                }
                if (contains) {
                    continue;
                }
                State new_state = zero_order_state.copy();
                new_state.setTransitions(null);
                new_states.add(new_state);
                assert (new_state.getIndex() >= 0);
                assert (new_state.getTransitions() == null);
            }
            new_candidates.add(new_states);
            if (found_boundary) {
                assert (new_states.size() == 1);
                break;
            }
            assert (!new_states.isEmpty());
        }
        assert (!new_candidates.isEmpty());
        for (List<State> states : new_candidates) {
            for (State state : states) {
                assert (state.getTransitions() == null);
            }
        }
        assert (found_boundary);
        return new_candidates;
    }

    /**
     * Returns the zero-order candidates of this lattice, using this lattice's
     * boundary state as the stop marker.
     *
     * @param filter2 if true, the pruned candidates from {@link #prune()} are used
     *                instead of the original ones
     */
    @Override
    public List<List<State>> getZeroOrderCandidates(boolean filter2) {
        List<List<State>> candidates = filter2 ? this.prune() : this.candidates_;
        return getZeroOrderCandidates(candidates, this.boundary_.getIndex());
    }

    /** Sets the per-position gold state indexes used by pruning (oracle mode) and update. */
    @Override
    public void setGoldCandidates(List<Integer> candidates) {
        this.gold_candidate_indexes_ = candidates;
    }

    /** Returns the level of the first state at the first position. */
    @Override
    public int getLevel() {
        return this.candidates_.get(0).get(0).getLevel();
    }

    /** Returns the gold state indexes, or null if none were set. */
    @Override
    public List<Integer> getGoldCandidates() {
        return this.gold_candidate_indexes_;
    }
}

