/*
 * Decompiled with CFR 0.152.
 */
package lemming.lemma.toutanova;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.Set;
import lemming.lemma.toutanova.NbestDecoder;
import lemming.lemma.toutanova.Result;
import lemming.lemma.toutanova.ToutanovaInstance;
import lemming.lemma.toutanova.ToutanovaModel;
import marmot.util.HashableIntArray;
import marmot.util.Numerics;

/**
 * N-best decoder for a zero-order segmentation model, where each segment's score
 * does not depend on the output chosen for the previous segment.
 *
 * <p>Decoding runs in two phases. {@link #decode} fills a chart in which
 * {@code state_array_[l - 1]} holds up to {@code rank_length_} candidate segments
 * ending at input position {@code l}, sorted best-first. Each candidate's score is
 * its model pair score plus the best (rank 0) score at the position where the
 * segment starts. {@link #backtrace} then enumerates paths through the chart.
 *
 * <p>A path is identified by a "signature": entry {@code k} is the rank taken at
 * the end position of the {@code k}-th segment, counted from the end of the input.
 * Unused trailing entries must be 0.
 *
 * <p>NOTE(review): the signature enumeration is a heuristic. Raising the rank of
 * one segment moves the end positions of all segments after it. So a child
 * signature is not guaranteed to score at or below its parent, and signatures
 * reachable only through an invalid parent are never visited. The assertion and
 * stderr diagnostics in {@link #backtrace} surface the first problem.
 *
 * <p>{@code decode} mutates instance fields, so one instance must not be used by
 * concurrent callers.
 */
public class ZeroOrderNbestDecoder implements NbestDecoder {
    private ToutanovaModel model_;
    private int num_output_symbols_;
    private int input_length_;
    private ToutanovaInstance instance_;
    /** Number of candidates kept per chart position; also the maximum number of results. */
    private final int rank_length_;
    /** Chart: {@code state_array_[l - 1][rank]} is the rank-th best segment ending at position l. */
    private State[][] state_array_;
    /** Scratch queue used to rank all candidates ending at one position. */
    private final PriorityQueue<State> queue_;
    /** Frontier of pending backtrace results, ordered by {@link Result}'s natural ordering. */
    private final Queue<Result> result_queue_;
    /** Signatures already expanded or queued, so no path is produced twice. */
    private final Set<HashableIntArray> used_signatures_;

    /**
     * Creates a decoder that keeps and returns up to {@code queue_size} hypotheses.
     *
     * @param queue_size number of candidates kept per chart position and the maximum
     *        number of results returned by {@link #decode}; must be at least 1
     * @throws IllegalArgumentException if {@code queue_size < 1}
     */
    public ZeroOrderNbestDecoder(int queue_size) {
        if (queue_size < 1) {
            // Rank 0 of every chart cell is always read, so at least one rank is required.
            throw new IllegalArgumentException("queue_size must be >= 1, got " + queue_size);
        }
        this.rank_length_ = queue_size;
        this.queue_ = new PriorityQueue<State>();
        this.result_queue_ = new PriorityQueue<Result>();
        this.used_signatures_ = new HashSet<HashableIntArray>();
    }

    /**
     * Binds the decoder to a model and caches the size of its output symbol table.
     *
     * @param model the model supplying pair scores and the output table
     */
    @Override
    public void init(ToutanovaModel model) {
        this.model_ = model;
        this.num_output_symbols_ = this.model_.getOutputTable().size();
    }

    /**
     * Decodes {@code instance} and returns up to {@code queue_size} results.
     *
     * <p>Every (output symbol, segment start) pair ending at a position is scored.
     * Only the best {@code rank_length_} of those candidates are kept in the chart.
     *
     * @param instance the instance to decode; {@link #init} must have been called first
     * @return the results produced by {@link #backtrace}; empty when the form has no characters
     */
    @Override
    public List<Result> decode(ToutanovaInstance instance) {
        assert (this.model_ != null);
        assert (this.num_output_symbols_ > 0);
        int max_input_segment_length = this.model_.getMaxInputSegmentLength();
        this.input_length_ = instance.getFormCharIndexes().length;
        this.instance_ = instance;
        this.checkArraySize(this.input_length_);
        for (int l = 1; l <= this.input_length_; ++l) {
            this.queue_.clear();
            // Segments ending at l may start no earlier than this (independent of the output symbol).
            int first_start = Math.max(0, l - max_input_segment_length);
            for (int o = 0; o < this.num_output_symbols_; ++o) {
                for (int l_start = first_start; l_start < l; ++l_start) {
                    double score = this.model_.getPairScore(instance, l_start, l, o);
                    if (l_start > 0) {
                        // Zero-order model: extend only the best path ending where this segment starts.
                        score += this.state_array_[l_start - 1][0].score;
                    }
                    State state = new State();
                    state.score = score;
                    state.output = o;
                    state.index = l_start;
                    this.queue_.add(state);
                }
            }
            // Keep the top rank_length_ candidates; cells beyond the number of candidates stay null.
            for (int rank = 0; rank < this.rank_length_; ++rank) {
                State state = this.queue_.poll();
                assert (state == null || state.index < l);
                this.state_array_[l - 1][rank] = state;
            }
        }
        return this.backtrace();
    }

    /** Equivalent to {@code bySignature(signature, false)}. */
    private Result bySignature(HashableIntArray signature) {
        return this.bySignature(signature, false);
    }

    /**
     * Reconstructs the path described by {@code signature}, walking backwards from
     * the end of the input.
     *
     * <p>The path score is the best overall score minus, for each segment, the gap
     * between the best candidate and the chosen-rank candidate at that segment's end
     * position. This equals the sum of the chosen segments' pair scores, because
     * every chart score already includes the best score at its start position.
     *
     * @param signature rank to take for each segment, counted from the end of the input
     * @param debug when true, prints the running score after each segment to stderr
     * @return the reconstructed result, or {@code null} if a chosen rank has no
     *         candidate, or if a signature entry beyond the path's last segment is
     *         non-zero (that signature duplicates the one with those entries zeroed)
     * @throws IllegalStateException if the signature is shorter than the path
     */
    private Result bySignature(HashableIntArray signature, boolean debug) {
        LinkedList<Integer> outputs = new LinkedList<Integer>();
        LinkedList<Integer> inputs = new LinkedList<Integer>();
        int[] signature_array = signature.getArray();
        int end_index = this.input_length_;
        double score = this.state_array_[end_index - 1][0].score;
        int signature_index = 0;
        if (debug) {
            System.err.println(score);
        }
        while (true) {
            if (signature_index >= signature_array.length) {
                // Not reachable for signatures built by backtrace(): each segment covers at least
                // one position, so a path never has more than input_length_ segments.
                throw new IllegalStateException("Signature " + signature + " is too short for form "
                        + this.instance_.getInstance().getForm());
            }
            int rank = signature_array[signature_index++];
            State state = this.state_array_[end_index - 1][rank];
            if (state == null) {
                return null;
            }
            inputs.add(end_index);
            outputs.add(state.output);
            double diff_to_best = this.state_array_[end_index - 1][0].score - state.score;
            assert (diff_to_best >= 0.0);
            score -= diff_to_best;
            if (debug) {
                System.err.println(score + "  " + diff_to_best);
            }
            if (state.index == 0) {
                break;
            }
            end_index = state.index;
        }
        // Non-zero ranks past the last segment would make distinct signatures describe the same path.
        for (int i = signature_index; i < signature_array.length; ++i) {
            if (signature_array[i] > 0) {
                return null;
            }
        }
        Collections.reverse(outputs);
        Collections.reverse(inputs);
        return new Result(this.model_, outputs, inputs, this.instance_.getInstance().getForm(), score)
                .setSignature(signature);
    }

    /**
     * Enumerates up to {@code rank_length_} results from the chart filled by the
     * last {@link #decode} call.
     *
     * <p>Starts from the all-zero (Viterbi) signature. Each polled result is expanded
     * by raising the rank of one of its segments, as long as the new rank stays below
     * {@code rank_length_}.
     *
     * @return results in the order they are polled from {@code result_queue_}; empty
     *         when the last decoded form had no characters
     */
    public List<Result> backtrace() {
        LinkedList<Result> list = new LinkedList<Result>();
        this.result_queue_.clear();
        this.used_signatures_.clear();
        if (this.input_length_ == 0) {
            // No chart cells exist; bySignature would index position -1.
            return list;
        }
        HashableIntArray signature = new HashableIntArray(new int[this.input_length_]);
        this.result_queue_.add(this.bySignature(signature));
        this.used_signatures_.add(signature);
        Result result;
        while (list.size() < this.rank_length_ && (result = this.result_queue_.poll()) != null) {
            signature = result.getSignature();
            int[] signature_array = signature.getArray();
            result.setSignature(null);
            list.add(result);
            for (int index = 0; index < result.getOutputs().size(); ++index) {
                int new_rank = signature_array[index] + 1;
                if (new_rank >= this.rank_length_) {
                    continue;
                }
                int[] new_signature_array = Arrays.copyOf(signature_array, signature_array.length);
                new_signature_array[index] = new_rank;
                HashableIntArray new_signature = new HashableIntArray(new_signature_array);
                if (!this.used_signatures_.add(new_signature)) {
                    continue;
                }
                Result new_result = this.bySignature(new_signature);
                if (new_result == null) {
                    continue;
                }
                if (!Numerics.approximatelyLesserEqual(new_result.getScore(), result.getScore())) {
                    // Diagnostics for a child that outscores its parent (see class NOTE).
                    System.err.println(signature + " " + new_signature);
                    this.bySignature(signature, true);
                    this.bySignature(new_signature, true);
                }
                assert (Numerics.approximatelyLesserEqual(new_result.getScore(), result.getScore()));
                this.result_queue_.add(new_result);
            }
        }
        return list;
    }

    /** Grows the chart so it has at least {@code required_length} positions; never shrinks it. */
    private void checkArraySize(int required_length) {
        if (this.state_array_ == null || this.state_array_.length < required_length) {
            this.state_array_ = new State[required_length][this.rank_length_];
        }
    }

    /** A chart entry: one scored segment ending at a given position. */
    private static final class State implements Comparable<State> {
        /** Pair score plus the best score at {@code index}. */
        private double score;
        /** Index of the output symbol assigned to the segment. */
        private int output;
        /** Input position where the segment starts. */
        private int index;

        private State() {
        }

        /** Orders states by descending score, so the PriorityQueue head is the best state. */
        @Override
        public int compareTo(State state) {
            return -Double.compare(this.score, state.score);
        }
    }
}

