/*
 * Decompiled with CFR 0.152.
 */
package marmot.core.lattice;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Set;
import lemming.lemma.ranker.RankerCandidate;
import marmot.core.State;
import marmot.core.Transition;
import marmot.core.lattice.Hypothesis;
import marmot.core.lattice.LatticeEntry;
import marmot.core.lattice.ViterbiLattice;
import marmot.util.HashableIntArray;

/**
 * Viterbi lattice over a sequence of candidate-state lists.
 *
 * <p>For every state the lattice keeps up to {@code beam_size} ranked back-pointer entries
 * (predecessor index plus accumulated score). Paths are decoded by starting at state 0 of the
 * last position and following back-pointers down to position 0, whose only predecessor is the
 * boundary state. A path is identified by a "signature": one rank per position 1..last,
 * selecting which stored back-pointer to follow.
 */
public class SequenceViterbiLattice
implements ViterbiLattice {
    /** lattice_[position][state][rank]: rank-th stored back-pointer entry of that state. */
    private LatticeEntry[][][] lattice_;
    private final List<List<State>> candidates_;
    private final State boundary_;
    /** Max. back-pointers kept per state and max. hypotheses returned by getNbestSequences(). */
    private final int beam_size_;
    private boolean initialized_;
    private final boolean marginalize_lemmas_;

    /**
     * @param candidates candidate states per position; position 0 transitions from {@code boundary}
     * @param boundary the single predecessor state of position 0
     * @param beam_size number of back-pointers kept per state and of n-best hypotheses
     * @param marginalize_lemmas if false, states with lemma candidates are scored with their best
     *     lemma candidate (see {@link #getEffectiveStateScore(State)})
     */
    public SequenceViterbiLattice(List<List<State>> candidates, State boundary, int beam_size, boolean marginalize_lemmas) {
        this.candidates_ = candidates;
        this.boundary_ = boundary;
        this.beam_size_ = beam_size;
        this.initialized_ = false;
        this.marginalize_lemmas_ = marginalize_lemmas;
    }

    /**
     * Builds the lattice on first use; subsequent calls return immediately.
     *
     * <p>For every position and state, each non-null transition from the previous position is
     * scored as effective state score + transition score + score of the predecessor's rank-0
     * entry. The first {@code min(beam_size, #transitions)} entries polled from a priority queue
     * are stored. The queue order is {@code LatticeEntry}'s natural order (presumably best score
     * first; confirm against {@code LatticeEntry.compareTo}).
     */
    public void init() {
        if (this.initialized_) {
            return;
        }
        this.initialized_ = true;
        this.lattice_ = new LatticeEntry[this.candidates_.size()][][];
        PriorityQueue<LatticeEntry> queue = new PriorityQueue<LatticeEntry>();
        List<State> previous_states = Collections.singletonList(this.boundary_);
        for (int index = 0; index < this.candidates_.size(); ++index) {
            List<State> states = this.candidates_.get(index);
            this.lattice_[index] = new LatticeEntry[states.size()][];
            int state_index = 0;
            for (State state : states) {
                queue.clear();
                double state_score = this.getEffectiveStateScore(state);
                for (int previous_state_index = 0; previous_state_index < previous_states.size(); ++previous_state_index) {
                    Transition transition = state.getTransition(previous_state_index);
                    if (transition == null) {
                        continue;
                    }
                    double score = state_score + transition.getScore();
                    if (index > 0) {
                        score += this.lattice_[index - 1][previous_state_index][0].getScore();
                    }
                    queue.add(new LatticeEntry(score, previous_state_index));
                }
                int length = Math.min(this.beam_size_, queue.size());
                // Later positions read [0] of every predecessor, so each state needs at least one entry.
                assert (length > 0);
                LatticeEntry[] entries = new LatticeEntry[length];
                for (int rank = 0; rank < length; ++rank) {
                    // length <= queue.size(), so poll() cannot return null here.
                    entries[rank] = queue.poll();
                }
                this.lattice_[index][state_index] = entries;
                ++state_index;
            }
            previous_states = states;
        }
    }

    /**
     * Returns the score used for {@code state} in the lattice.
     *
     * <p>If the state's zero-order state has lemma candidates and lemmas are not marginalized, the
     * zero-order score is replaced by its real score and the maximum lemma candidate score is
     * added. The result is negative infinity if the candidate collection is empty. Otherwise
     * {@code state.getScore()} is returned unchanged.
     */
    private double getEffectiveStateScore(State state) {
        double state_score = state.getScore();
        State zero_order_state = state.getZeroOrderState();
        if (zero_order_state.getLemmaCandidates() == null || this.marginalize_lemmas_) {
            return state_score;
        }
        double base = state_score - zero_order_state.getScore() + zero_order_state.getRealScore();
        double best = Double.NEGATIVE_INFINITY;
        for (RankerCandidate candidate : zero_order_state.getLemmaCandidates()) {
            best = Math.max(best, base + candidate.getScore());
        }
        return best;
    }

    /**
     * Returns the path that follows the rank-0 back-pointer at every position, or null if the
     * lattice has fewer than two positions or no such path exists.
     */
    @Override
    public Hypothesis getViterbiSequence() {
        this.init();
        // Math.max guards against a negative array size for an empty candidate list.
        int[] signature_array = new int[Math.max(0, this.candidates_.size() - 1)];
        HashableIntArray signature = new HashableIntArray(signature_array);
        return this.getSequenceBySignature(signature);
    }

    /**
     * Decodes the path selected by {@code signature}.
     *
     * <p>The score is the last position's selected entry score (which already contains that
     * rank's deviation) plus, for each earlier position with a non-zero rank, the difference
     * between the selected entry and the rank-0 entry of the state reached on the path.
     *
     * @param signature ranks for positions 1..last; entry {@code i - 1} is the rank used at position {@code i}
     * @return the hypothesis (state indices from position 0 to last), or null if a rank exceeds the
     *     stored entries or the lattice has fewer than two positions
     */
    public Hypothesis getSequenceBySignature(HashableIntArray signature) {
        this.init();
        int last = this.candidates_.size() - 1;
        if (last < 1) {
            // No position with selectable back-pointers: there is no path to decode.
            return null;
        }
        int[] signature_array = signature.getArray();
        LinkedList<Integer> list = new LinkedList<Integer>();
        int state_index = 0;
        list.addFirst(state_index);
        double score = 0.0;
        for (int index = last; index >= 1; --index) {
            int rank = signature_array[index - 1];
            LatticeEntry[] entries = this.lattice_[index][state_index];
            if (rank >= entries.length) {
                return null;
            }
            LatticeEntry entry = entries[rank];
            if (entry == null) {
                return null;
            }
            if (index == last) {
                // Already includes the deviation of this rank; adding the delta again would double-count it.
                score = entry.getScore();
            } else if (rank != 0) {
                score += entry.getScore() - entries[0].getScore();
            }
            state_index = entry.getPreviousStateIndex();
            // Walking backwards, so prepend to keep position order.
            list.addFirst(state_index);
        }
        return new Hypothesis(list, score, signature);
    }

    /**
     * Returns up to {@code beam_size} hypotheses, popped from a priority queue seeded with the
     * Viterbi path. Each popped hypothesis spawns neighbours whose signatures increase one rank by
     * one. The queue order is {@code Hypothesis}'s natural order (presumably best score first;
     * confirm).
     *
     * @return the hypotheses in pop order; empty if there is no Viterbi path
     */
    @Override
    public List<Hypothesis> getNbestSequences() {
        this.init();
        LinkedList<Hypothesis> nbest = new LinkedList<Hypothesis>();
        HashableIntArray signature = new HashableIntArray(new int[Math.max(0, this.candidates_.size() - 1)]);
        Hypothesis viterbi = this.getSequenceBySignature(signature);
        if (viterbi == null) {
            // PriorityQueue rejects null; without a Viterbi path there is nothing to expand.
            return nbest;
        }
        PriorityQueue<Hypothesis> queue = new PriorityQueue<Hypothesis>();
        Set<HashableIntArray> used_signatures = new HashSet<HashableIntArray>();
        queue.add(viterbi);
        used_signatures.add(signature);
        Hypothesis best;
        while (nbest.size() < this.beam_size_ && (best = queue.poll()) != null) {
            nbest.add(best);
            int[] signature_array = best.getSignature().getArray();
            for (int index = 0; index < signature_array.length; ++index) {
                int[] new_signature_array = Arrays.copyOf(signature_array, signature_array.length);
                new_signature_array[index] += 1;
                HashableIntArray new_signature = new HashableIntArray(new_signature_array);
                if (!used_signatures.add(new_signature)) {
                    continue;
                }
                Hypothesis neighbour = this.getSequenceBySignature(new_signature);
                if (neighbour != null) {
                    queue.add(neighbour);
                }
            }
        }
        return nbest;
    }

    /**
     * Debug check. For each position from the last down to 1, checks whether the gold predecessor
     * {@code path.get(index - 1)} appears among the stored back-pointers of the gold state
     * {@code path.get(index)}. If it does not, a diagnostic line is printed to stderr.
     *
     * @param path gold state index for every position
     */
    public void findGoldSequence(List<Integer> path) {
        this.init();
        assert (path.size() == this.candidates_.size());
        assert (path.size() == this.lattice_.length);
        for (int index = path.size() - 1; index > 0; --index) {
            int state_index = path.get(index);
            int real_previous_state_index = path.get(index - 1);
            boolean found_index = false;
            for (LatticeEntry entry : this.lattice_[index][state_index]) {
                if (entry == null) {
                    break;
                }
                if (entry.getPreviousStateIndex() == real_previous_state_index) {
                    found_index = true;
                    break;
                }
            }
            if (!found_index) {
                System.err.format("%s index = %d p_index = %d lattice entries = %s\n", this.candidates_.get(index).get(state_index), index, real_previous_state_index, Arrays.toString(this.lattice_[index][state_index]));
            }
        }
    }

    /**
     * Restricts the candidates to the states and transitions used by the n-best hypotheses.
     *
     * <p>States at positions &gt; 0 are copied and receive a new transition array sized to the
     * pruned previous position. Only the transitions used by some hypothesis are copied over;
     * the remaining slots stay null. States at position 0 are reused as-is.
     *
     * @return the pruned candidate lists, one per position
     */
    @Override
    public List<List<State>> prune() {
        this.init();
        List<List<State>> candidates = this.getCandidates();
        int num_positions = candidates.size();

        // candidate_sets.get(i): (state, predecessor) pairs used at position i by some hypothesis,
        // encoded as state_index * |states at i - 1| + previous_state_index.
        List<Set<Integer>> candidate_sets = new ArrayList<Set<Integer>>(num_positions);
        for (int index = 0; index < num_positions; ++index) {
            candidate_sets.add(new HashSet<Integer>());
        }
        for (Hypothesis h : this.getNbestSequences()) {
            int index = 0;
            int previous_state_index = 0;
            for (int state_index : h.getStates()) {
                int previous_num_candidates = numPreviousCandidates(candidates, index);
                candidate_sets.get(index).add(state_index * previous_num_candidates + previous_state_index);
                previous_state_index = state_index;
                ++index;
            }
        }

        List<List<State>> new_candidates = new ArrayList<List<State>>(num_positions);
        // Maps old state index -> new state index at the previous position.
        int[] index_map = null;
        for (int index = 0; index < num_positions; ++index) {
            List<State> old_states = candidates.get(index);
            int previous_num_candidates = numPreviousCandidates(candidates, index);
            int[] new_index_map = new int[old_states.size()];
            Arrays.fill(new_index_map, -1);
            Set<Integer> candidate_set = candidate_sets.get(index);
            List<State> states = new ArrayList<State>(candidate_set.size());
            for (int encoded_indexes : candidate_set) {
                int state_index = encoded_indexes / previous_num_candidates;
                int previous_state_index = encoded_indexes % previous_num_candidates;
                int new_state_index = new_index_map[state_index];
                if (new_state_index < 0) {
                    new_state_index = states.size();
                    new_index_map[state_index] = new_state_index;
                    State state = old_states.get(state_index);
                    if (index > 0) {
                        state = state.copy();
                        state.setTransitions(new Transition[new_candidates.get(index - 1).size()]);
                    }
                    states.add(state);
                }
                if (index > 0) {
                    Transition[] old_transitions = old_states.get(state_index).getTransitions();
                    Transition[] new_transitions = states.get(new_state_index).getTransitions();
                    new_transitions[index_map[previous_state_index]] = old_transitions[previous_state_index];
                }
            }
            new_candidates.add(states);
            index_map = new_index_map;
        }
        return new_candidates;
    }

    /** Number of candidates at position {@code index - 1}, or 1 (the boundary) for position 0. */
    private static int numPreviousCandidates(List<List<State>> candidates, int index) {
        return index - 1 >= 0 ? candidates.get(index - 1).size() : 1;
    }

    @Override
    public List<List<State>> getCandidates() {
        return this.candidates_;
    }
}

