/*
 * Decompiled with CFR 0.152.
 */
package marmot.core.lattice;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.PriorityQueue;
import lemming.lemma.ranker.RankerCandidate;
import marmot.core.State;
import marmot.core.lattice.Hypothesis;
import marmot.core.lattice.LatticeEntry;
import marmot.core.lattice.ViterbiLattice;
import marmot.util.HashableIntArray;

/**
 * Viterbi lattice for a zero-order model: every position's candidate states are
 * ranked independently, and a sequence's score is the sum of the scores of the
 * entries chosen at each position (no transition scores are involved).
 *
 * <p>The per-position beams are built lazily by {@link #init()} on first use.
 * {@code init()} is not synchronized.
 */
public class ZeroOrderViterbiLattice
implements ViterbiLattice {
    /** {@code lattice_[position][rank]}: the beam entries kept for each position. */
    private LatticeEntry[][] lattice_;
    /** Candidate states for each position. */
    private final List<List<State>> candidates_;
    /** Max entries kept per position, and max sequences returned by {@link #getNbestSequences()}. */
    private final int beam_size_;
    /** Set once {@link #init()} has finished building {@link #lattice_}. */
    private boolean initialized_;
    /** When false, states carrying lemma candidates are scored via their best lemma candidate. */
    private final boolean marginalize_lemmas_;

    /**
     * @param candidates candidate states for each position
     * @param beam_size maximum number of entries kept per position, and maximum
     *        number of sequences returned by {@link #getNbestSequences()}
     * @param marginalize_lemmas if {@code false}, a state whose
     *        {@code getLemmaCandidates()} is non-null is scored as the best lemma
     *        candidate's score plus {@code getRealScore()}; otherwise
     *        {@code getScore()} is used as-is
     */
    public ZeroOrderViterbiLattice(List<List<State>> candidates, int beam_size, boolean marginalize_lemmas) {
        this.candidates_ = candidates;
        this.beam_size_ = beam_size;
        this.initialized_ = false;
        this.marginalize_lemmas_ = marginalize_lemmas;
    }

    /**
     * Builds the per-position beams; subsequent calls are no-ops.
     *
     * <p>For every position, each candidate is scored (see {@link #scoreState}),
     * wrapped in a {@link LatticeEntry} together with its index in the candidate
     * list, and pushed into a {@link PriorityQueue}. The first
     * {@code min(beam_size_, #candidates)} entries polled are stored as ranks
     * 0, 1, ... — presumably best score first, per {@code LatticeEntry}'s
     * natural ordering (TODO confirm).
     */
    public void init() {
        if (this.initialized_) {
            return;
        }
        this.lattice_ = new LatticeEntry[this.candidates_.size()][];
        PriorityQueue<LatticeEntry> queue = new PriorityQueue<LatticeEntry>();
        int index = 0;
        for (List<State> states : this.candidates_) {
            queue.clear();
            int state_index = 0;
            for (State state : states) {
                queue.add(new LatticeEntry(this.scoreState(state), state_index));
                ++state_index;
            }
            int length = Math.min(this.beam_size_, queue.size());
            assert (length > 0);
            LatticeEntry[] beam = new LatticeEntry[length];
            // length <= queue.size(), so poll() never returns null here.
            for (int rank = 0; rank < length; ++rank) {
                beam[rank] = queue.poll();
            }
            this.lattice_[index] = beam;
            ++index;
        }
        // Only mark as initialized once the lattice is complete, so a failure
        // part-way through does not leave a half-built lattice in place.
        this.initialized_ = true;
    }

    /**
     * Returns the score used to rank {@code state}: the best lemma candidate's
     * score plus {@code state.getRealScore()} when lemmas are not marginalized
     * and the state carries lemma candidates, otherwise {@code state.getScore()}.
     */
    private double scoreState(State state) {
        if (state.getLemmaCandidates() != null && !this.marginalize_lemmas_) {
            RankerCandidate candidate = RankerCandidate.bestCandidate(state.getLemmaCandidates());
            return candidate.getScore() + state.getRealScore();
        }
        return state.getScore();
    }

    /**
     * Returns the hypothesis that takes the rank-0 entry at every position.
     *
     * @return that hypothesis, or {@code null} if some position has an empty beam
     */
    @Override
    public Hypothesis getViterbiSequence() {
        this.init();
        int[] signature_array = new int[this.candidates_.size()];
        HashableIntArray signature = new HashableIntArray(signature_array);
        return this.getSequenceBySignature(signature);
    }

    /**
     * Builds the hypothesis that picks, at each position {@code i}, the beam
     * entry of rank {@code signature.getArray()[i]}.
     *
     * @param signature one beam rank per position
     * @return a hypothesis holding the chosen candidate index per position, the
     *         summed entry scores and {@code signature}; or {@code null} if any
     *         rank is negative or outside that position's beam
     */
    public Hypothesis getSequenceBySignature(HashableIntArray signature) {
        this.init();
        LinkedList<Integer> list = new LinkedList<Integer>();
        double score = 0.0;
        int[] signature_array = signature.getArray();
        for (int index = 0; index < signature_array.length; ++index) {
            int rank = signature_array[index];
            if (rank < 0 || rank >= this.lattice_[index].length) {
                return null;
            }
            LatticeEntry entry = this.lattice_[index][rank];
            if (entry == null) {
                return null;
            }
            score += entry.getScore();
            // In this lattice the entry's "previous state index" is the index of
            // the candidate within candidates_.get(index) (see init()).
            list.add(entry.getPreviousStateIndex());
        }
        return new Hypothesis(list, score, signature);
    }

    /**
     * Returns, for each position, the candidate states kept in that position's
     * beam, in rank order.
     */
    @Override
    public List<List<State>> prune() {
        this.init();
        ArrayList<List<State>> candidates = new ArrayList<List<State>>(this.candidates_.size());
        for (int index = 0; index < this.candidates_.size(); ++index) {
            List<State> position_candidates = this.candidates_.get(index);
            LatticeEntry[] beam = this.lattice_[index];
            ArrayList<State> states = new ArrayList<State>(beam.length);
            for (LatticeEntry entry : beam) {
                states.add(position_candidates.get(entry.getPreviousStateIndex()));
            }
            candidates.add(states);
        }
        assert (candidates.size() > 0);
        return candidates;
    }

    /**
     * Enumerates up to {@code beam_size_} sequences in best-first order.
     *
     * <p>Starting from the all-zero signature, each hypothesis popped from the
     * priority queue is appended to the result, and its successors — its own
     * signature with exactly one position's rank increased by one — are queued
     * if not seen before and still inside the beams. This assumes that
     * {@code Hypothesis}'s natural ordering polls the highest-scoring hypothesis
     * first (TODO confirm).
     *
     * @return the n-best hypotheses; empty if some position has an empty beam
     */
    @Override
    public List<Hypothesis> getNbestSequences() {
        this.init();
        ArrayList<Hypothesis> nbest = new ArrayList<Hypothesis>();
        HashableIntArray start = new HashableIntArray(new int[this.candidates_.size()]);
        Hypothesis best = this.getSequenceBySignature(start);
        if (best == null) {
            // Some position has no entries, so no complete sequence exists.
            return nbest;
        }
        PriorityQueue<Hypothesis> queue = new PriorityQueue<Hypothesis>();
        HashSet<HashableIntArray> used_signatures = new HashSet<HashableIntArray>();
        queue.add(best);
        used_signatures.add(start);
        Hypothesis hypothesis;
        while (nbest.size() < this.beam_size_ && (hypothesis = queue.poll()) != null) {
            nbest.add(hypothesis);
            // Expand from the popped hypothesis' OWN signature (not the initial
            // all-zero one), so combinations of non-best choices are reachable.
            int[] current = hypothesis.getSignature().getArray();
            for (int index = 0; index < current.length; ++index) {
                int[] next_array = current.clone();
                next_array[index] += 1;
                HashableIntArray next = new HashableIntArray(next_array);
                // HashSet.add returns false if the signature was already queued.
                if (!used_signatures.add(next)) {
                    continue;
                }
                Hypothesis successor = this.getSequenceBySignature(next);
                if (successor != null) {
                    queue.add(successor);
                }
            }
        }
        return nbest;
    }

    /** Returns the (unpruned) candidate states this lattice was built from. */
    @Override
    public List<List<State>> getCandidates() {
        return this.candidates_;
    }
}

