/*
 * Decompiled with CFR 0.152.
 */
package marmot.core;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import marmot.core.FeatureVector;
import marmot.core.Model;
import marmot.core.Result;
import marmot.core.Sequence;
import marmot.core.State;
import marmot.core.Tagger;
import marmot.core.Token;
import marmot.core.Transition;
import marmot.core.WeightVector;
import marmot.core.lattice.Hypothesis;
import marmot.core.lattice.Lattice;
import marmot.core.lattice.SequenceSumLattice;
import marmot.core.lattice.SequenceViterbiLattice;
import marmot.core.lattice.SumLattice;
import marmot.core.lattice.ViterbiLattice;
import marmot.core.lattice.ZeroOrderSumLattice;
import marmot.core.lattice.ZeroOrderViterbiLattice;

/**
 * Multi-level, higher-order sequence tagger.
 *
 * <p>For every tag level the tagger first builds a zero-order lattice over the
 * candidate states of each token and then, order by order, adds transitions
 * between neighbouring candidates and builds a sequence lattice on top of
 * them. When pruning is enabled the candidate lists are reduced between
 * orders, and the pruning thresholds can be re-tuned with
 * {@link #setThresholds(boolean)}.
 *
 * <p>NOTE(review): this file was recovered with a decompiler; the behaviour of
 * the collaborating types ({@code Model}, {@code WeightVector}, the lattice
 * classes) is not visible here and is only described as far as this class uses
 * them.
 */
public class SimpleTagger
implements Tagger {
    private static final long serialVersionUID = 1L;
    /** Model that supplies tag tables, tag/lemma candidates and boundary states. */
    private Model model_;
    /** Weight vector used for feature extraction and scoring. */
    private WeightVector weight_vector_;
    /** Number of tag levels processed by {@link #getSumLattice(boolean, Sequence)}. */
    private int num_level_;
    /** Output format flag: 0 by default, 1 after {@code setFormat("conllu")}. */
    private int format_ = 0;
    /** Pruning thresholds, indexed by [level][order]. */
    private double[][] threshs_;
    /** Desired average number of candidates per state, indexed by order. */
    private double[] candidates_per_state_;
    /** Accumulated number of candidate states, indexed by [level][order]. */
    private double[][] num_states_;
    /** Accumulated candidate-list lengths, indexed by [level][order]. */
    private double[][] length_;
    /** Maximal model order. */
    private int order_;
    /** Whether lattices are pruned between orders. */
    private boolean prune_;
    /** Order used on every level except the last one (the last level uses {@link #order_}). */
    private int effective_order_;
    /** Beam size handed to the Viterbi lattices. */
    private int beam_size_;
    /** Oracle flag taken from the model options. */
    private boolean oracle_;
    /** Capacity hint: expected number of new candidates per state when moving up a level. */
    private final int AVERAGE_NUMBER_OF_CANDIDATES = 5;
    /** Whether extracted state feature vectors are stored back on the token. */
    private boolean cache_feature_vector_ = false;
    /** Result object attached via {@link #setResult(Result)}. */
    private Result result_;

    /**
     * Creates a tagger that reads its configuration from the options of {@code model}.
     *
     * @param model         model providing options, tag tables and candidates
     * @param order         maximal order of the tagger
     * @param weight_vector weight vector used for feature extraction and scoring
     */
    public SimpleTagger(Model model, int order, WeightVector weight_vector) {
        this.order_ = order;
        this.model_ = model;
        this.prune_ = model.getOptions().getPrune();
        this.beam_size_ = model.getOptions().getBeamSize();
        this.oracle_ = model.getOptions().getOracle();
        this.effective_order_ = Math.min(order, model.getOptions().getEffectiveOrder());
        this.weight_vector_ = weight_vector;
        this.candidates_per_state_ = model.getOptions().getCandidatesPerState();

        int levels = this.model_.getTagTables().size();
        this.num_level_ = levels;

        // One slot per order 0..getOrder() on every level. Freshly allocated
        // double arrays are already zero, so only the thresholds need filling.
        int slots = this.getOrder() + 1;
        this.threshs_ = new double[levels][slots];
        this.length_ = new double[levels][slots];
        this.num_states_ = new double[levels][slots];
        for (double[] level_thresholds : this.threshs_) {
            Arrays.fill(level_thresholds, model.getOptions().getProbThreshold());
        }
    }

    /**
     * Creates and scores the transitions between the candidates of neighbouring
     * positions and attaches them to the target states.
     *
     * <p>The predecessor of the first position is the boundary state of
     * {@code level}. Each state receives one array slot per predecessor; a slot
     * stays {@code null} when the predecessor cannot transition to the state.
     *
     * @param states candidate states per position
     * @param level  tag level whose boundary state precedes the first position
     * @param order  order passed to every created {@link Transition}
     */
    private void addTransitions(List<List<State>> states, int level, int order) {
        List<State> previous_states = Collections.singletonList(this.model_.getBoundaryState(level));
        for (List<State> current_states : states) {
            // First pass: build the full (predecessor x state) transition table.
            Transition[][] table = new Transition[previous_states.size()][current_states.size()];
            int row = 0;
            for (State previous_state : previous_states) {
                FeatureVector vector = this.weight_vector_.extractTransitionFeatures(previous_state);
                int column = 0;
                for (State state : current_states) {
                    if (previous_state.canTransitionTo(state)) {
                        Transition transition = new Transition(previous_state, state, order);
                        transition.setVector(vector);
                        // The score sums the contributions of the state and all of its sub-level states.
                        double score = 0.0;
                        State run = state;
                        while (run != null) {
                            score += this.weight_vector_.dotProduct(run, vector);
                            run = run.getSubLevelState();
                        }
                        transition.setScore(score);
                        table[row][column] = transition;
                    }
                    ++column;
                }
                ++row;
            }

            // Second pass: hand every state the column of transitions that lead into it.
            int column = 0;
            for (State state : current_states) {
                boolean found_transition = false;
                Transition[] incoming = new Transition[previous_states.size()];
                for (int from = 0; from < incoming.length; ++from) {
                    incoming[from] = table[from][column];
                    if (incoming[from] != null) {
                        found_transition = true;
                    }
                }
                assert (found_transition);
                state.setTransitions(incoming);
                ++column;
            }
            previous_states = current_states;
        }
    }

    /**
     * Builds the candidate lists of the next higher order: every non-null
     * transition of every state becomes a candidate of the new lists.
     *
     * <p>The score of the target state is added to the transition score, and
     * the transition arrays of the visited states (and of the transitions'
     * sub-order states) are cleared. A list holding only the boundary state of
     * {@code level} is appended at the end.
     *
     * @param states candidate states per position, each carrying its transitions
     * @param level  tag level used to look up the boundary state
     * @return the new candidate lists; one element longer than {@code states}
     */
    protected List<List<State>> increaseOrder(List<List<State>> states, int level) {
        List<List<State>> new_state_candidates = new ArrayList<List<State>>(states.size() + 1);
        for (int index = 0; index < states.size(); ++index) {
            // The first position is only preceded by the single boundary state.
            int num_previous_states = index == 0 ? 1 : states.get(index - 1).size();
            List<State> current_states = states.get(index);
            List<State> new_states = new ArrayList<State>(current_states.size() * num_previous_states);
            for (State state : current_states) {
                Transition[] transitions = state.getTransitions();
                state.setTransitions(null);
                assert (num_previous_states <= transitions.length);
                for (int previous = 0; previous < num_previous_states; ++previous) {
                    Transition transition = transitions[previous];
                    if (transition == null) {
                        continue;
                    }
                    transition.setScore(transition.getScore() + state.getScore());
                    new_states.add(transition);
                    transition.getSubOrderState().setTransitions(null);
                    assert (transition.check());
                }
            }
            assert (!new_states.isEmpty());
            new_state_candidates.add(new_states);
        }
        new_state_candidates.add(Collections.singletonList(this.model_.getBoundaryState(level)));
        return new_state_candidates;
    }

    /**
     * Creates the level-0 candidate states for every token of {@code sequence}.
     *
     * <p>A token's feature vector is extracted when the token does not carry
     * one yet. Tag candidates are read up to the first {@code -1} entry. A list
     * holding only the level-0 boundary state is appended at the end.
     *
     * @param sequence sequence to create candidates for
     * @param training forwarded to the model when setting lemma candidates
     * @return candidate states per token, followed by the boundary list
     */
    protected List<List<State>> getStates(Sequence sequence, boolean training) {
        List<List<State>> candidates = new ArrayList<List<State>>(sequence.size() + 1);
        for (int index = 0; index < sequence.size(); ++index) {
            Token token = (Token)sequence.get(index);
            FeatureVector vector = token.getVector();
            if (vector == null) {
                vector = this.weight_vector_.extractStateFeatures(sequence, index);
                if (this.cache_feature_vector_) {
                    token.setVector(vector);
                }
            }
            int[] tag_indexes = this.model_.getTagCandidates(sequence, index, null);
            List<State> states = new ArrayList<State>(tag_indexes.length);
            for (int tag_index : tag_indexes) {
                if (tag_index == -1) {
                    // -1 marks the end of the valid candidates.
                    break;
                }
                State state = new State(tag_index);
                state.setVector(vector);
                state.setScore(this.weight_vector_.dotProduct(state, vector));
                this.model_.setLemmaCandidates(token, state, true, training);
                states.add(state);
            }
            assert (states.size() > 0);
            candidates.add(states);
        }
        candidates.add(Collections.singletonList(this.model_.getBoundaryState(0)));
        return candidates;
    }

    /**
     * Adjusts the pruning thresholds from the statistics collected by
     * {@link #incrementStateCounter(int, int, List)} and resets those statistics.
     *
     * <p>For each level/order with collected data, the average number of states
     * per position is compared with the desired number for that order (orders
     * beyond the configured array use its last entry). If they differ by more
     * than 0.1, the threshold is raised by 10% when there are too many states
     * and lowered by 10% otherwise.
     *
     * @param print whether to build a textual report of the averages
     * @return one line per level listing the averages, or {@code null} when
     *         {@code print} is {@code false}
     */
    @Override
    public String setThresholds(boolean print) {
        StringBuilder sb = print ? new StringBuilder() : null;
        for (int level = 0; level < this.num_states_.length; ++level) {
            for (int order = 0; order < this.num_states_[level].length; ++order) {
                if (this.length_[level][order] > 0.0) {
                    double num_states = this.num_states_[level][order] / this.length_[level][order];
                    int effective_order = Math.min(order, this.candidates_per_state_.length - 1);
                    double want = this.candidates_per_state_[effective_order];
                    if (Math.abs(num_states - want) > 0.1) {
                        double delta = 0.1 * this.threshs_[level][order];
                        if (num_states > want) {
                            this.threshs_[level][order] += delta;
                        } else {
                            this.threshs_[level][order] -= delta;
                        }
                    }
                    if (print) {
                        sb.append(' ');
                        sb.append(num_states);
                    }
                    this.num_states_[level][order] = 0.0;
                    this.length_[level][order] = 0.0;
                }
            }
            if (print) {
                sb.append('\n');
            }
        }
        return print ? sb.toString() : null;
    }

    /**
     * Expands every candidate state into states of the next tag level.
     *
     * <p>Each new state wraps the state it was derived from, and its score is
     * its own dot product plus the real score of that state. The last list of
     * {@code candidates} is replaced by a list holding only the boundary state
     * of the next level.
     *
     * @param candidates candidate states per position on the current level
     * @param sentence   sequence the candidates belong to
     * @return candidate lists of the next level, same length as {@code candidates}
     */
    private List<List<State>> increaseLevel(List<List<State>> candidates, Sequence sentence) {
        List<List<State>> new_candidates = new ArrayList<List<State>>(candidates.size());
        int last_index = candidates.size() - 1;
        int index = 0;
        for (List<State> current_states : candidates) {
            List<State> new_current_states;
            if (index < last_index) {
                new_current_states = new ArrayList<State>(current_states.size() * this.AVERAGE_NUMBER_OF_CANDIDATES);
                for (State state : current_states) {
                    FeatureVector vector = this.weight_vector_.extractStateFeatures(state);
                    assert (state.getTransitions() == null);
                    int[] tag_indexes = this.model_.getTagCandidates(sentence, index, state);
                    for (int tag_index : tag_indexes) {
                        if (tag_index == -1) {
                            // -1 marks the end of the valid candidates.
                            break;
                        }
                        assert (state.getOrder() == 1);
                        State new_state = new State(tag_index, state);
                        new_state.setVector(vector);
                        new_state.setScore(this.weight_vector_.dotProduct(new_state, vector) + state.getRealScore());
                        this.model_.setLemmaCandidates(new_state, true);
                        new_current_states.add(new_state);
                    }
                }
            } else {
                // The last position is the sentence boundary.
                int next_level = current_states.get(0).getLevel() + 1;
                new_current_states = Collections.singletonList(this.model_.getBoundaryState(next_level));
            }
            new_candidates.add(new_current_states);
            ++index;
        }
        return new_candidates;
    }

    /**
     * Adds the size of {@code candidates} to the statistics later consumed by
     * {@link #setThresholds(boolean)}.
     *
     * @param level      tag level the candidates belong to
     * @param order      order the candidates belong to
     * @param candidates candidate states per position
     */
    protected void incrementStateCounter(int level, int order, List<List<State>> candidates) {
        int num_states = 0;
        for (List<State> states : candidates) {
            num_states += states.size();
        }
        this.num_states_[level][order] += (double)num_states;
        this.length_[level][order] += (double)candidates.size();
    }

    /**
     * Builds the sum lattice for {@code sentence}, level by level and order by order.
     *
     * <p>During training the method returns early with the lattice built so
     * far as soon as the gold sequence is no longer contained in the candidates.
     *
     * @param training whether the lattice is built for training
     * @param sentence sequence to build the lattice for
     * @return the lattice of the highest level/order reached
     */
    @Override
    public SumLattice getSumLattice(boolean training, Sequence sentence) {
        int order = this.getOrder();
        List<List<State>> candidates = null;
        SumLattice lattice = null;
        for (int level = 0; level < this.getNumLevels(); ++level) {
            if (level == 0) {
                candidates = this.getStates(sentence, training);
            } else {
                candidates = lattice.getZeroOrderCandidates(this.prune_);
                this.incrementStateCounter(level - 1, lattice.getOrder(), candidates);
                if (training && this.testForGoldCandidates(sentence, candidates, lattice) == null) {
                    return lattice;
                }
                int old_size = candidates.size();
                candidates = this.increaseLevel(candidates, sentence);
                assert (candidates.size() == old_size);
                for (List<State> states : candidates) {
                    assert (!states.isEmpty());
                }
            }

            lattice = new ZeroOrderSumLattice(candidates, this.threshs_[level][0], this.oracle_);
            if (this.oracle_ || training) {
                lattice.setGoldCandidates(this.getGoldIndexes(sentence, lattice.getCandidates()));
            }

            // Only the last level is run up to the full order.
            boolean last_level = level + 1 == this.getNumLevels();
            int effective_order = last_level ? order : this.effective_order_;

            for (int current_order = 0; current_order < effective_order; ++current_order) {
                if (this.prune_) {
                    candidates = lattice.prune();
                    this.incrementStateCounter(level, current_order, lattice.getZeroOrderCandidates(true));
                    assert (candidates.size() > 0);
                }
                if (current_order == 0) {
                    this.resetLemmaCandidates(sentence, candidates, level, training);
                }
                if (training && this.testForGoldCandidates(sentence, candidates, lattice) == null) {
                    return lattice;
                }
                if (current_order > 0) {
                    candidates = this.increaseOrder(candidates, level);
                }
                this.addTransitions(candidates, level, current_order + 2);
                lattice = new SequenceSumLattice(candidates, this.model_.getBoundaryState(level), this.threshs_[level][current_order + 1], current_order + 1, false);
                if (this.oracle_ || training) {
                    lattice.setGoldCandidates(this.getGoldIndexes(sentence, lattice.getCandidates()));
                }
            }
        }
        assert (lattice.getCandidates().size() >= sentence.size());
        return lattice;
    }

    /**
     * Asks the model to set the lemma candidates (with the flag {@code false})
     * for all candidate states except those of the final boundary position.
     *
     * <p>Only the first and the last tag level are handled; on any other level
     * the method does nothing. The first level uses the token-based model call,
     * the last level the state-only one.
     */
    private void resetLemmaCandidates(Sequence sentence, List<List<State>> candidates, int level, boolean training) {
        boolean first_level = level == 0;
        boolean last_level = level + 1 == this.getNumLevels();
        if (!first_level && !last_level) {
            return;
        }
        int index = 0;
        for (List<State> states : candidates) {
            if (index + 1 < candidates.size()) {
                for (State state : states) {
                    if (first_level) {
                        this.model_.setLemmaCandidates((Token)sentence.get(index), state, false, training);
                    } else {
                        this.model_.setLemmaCandidates(state, false);
                    }
                }
            }
            ++index;
        }
    }

    /**
     * Looks up the gold candidate indexes for {@code candidates}.
     *
     * @param sentence   sequence carrying the gold tags
     * @param candidates candidate states per position
     * @param lattice    unused; kept for signature compatibility
     * @return the gold indexes, or {@code null} when the gold sequence is not
     *         among the candidates
     */
    private List<Integer> testForGoldCandidates(Sequence sentence, List<List<State>> candidates, SumLattice lattice) {
        return this.getGoldIndexes(sentence, candidates);
    }

    /**
     * @return the maximal order of this tagger
     */
    public int getOrder() {
        return this.order_;
    }

    /**
     * @return the number of tag levels this tagger processes
     */
    @Override
    public int getNumLevels() {
        return this.num_level_;
    }

    /**
     * Determines, for every position, the index of the candidate that matches
     * the gold tags of {@code sequence} on all levels.
     *
     * <p>Positions beyond the end of {@code sequence} are matched against the
     * model's boundary index. On the top level a candidate that has transitions
     * is only accepted if it has a transition from the gold candidate of the
     * previous position.
     *
     * @param sequence   sequence carrying the gold tag indexes
     * @param candidates candidate states per position
     * @return one candidate index per position, or {@code null} when some
     *         position has no matching candidate
     */
    @Override
    public List<Integer> getGoldIndexes(Sequence sequence, List<List<State>> candidates) {
        List<Integer> gold_indexes = new ArrayList<Integer>(candidates.size());
        int last_candidate_index = 0;
        for (int index = 0; index < candidates.size(); ++index) {
            List<State> current_candidates = candidates.get(index);

            // Start with all candidates and narrow them down level by level.
            List<Integer> remaining = new ArrayList<Integer>(current_candidates.size());
            for (int candidate_index = 0; candidate_index < current_candidates.size(); ++candidate_index) {
                remaining.add(candidate_index);
            }

            int max_level = current_candidates.get(0).getZeroOrderState().getLevel();
            for (int level = max_level; level >= 0; --level) {
                int gold_tag_index = index < sequence.size()
                        ? ((Token)sequence.get(index)).getTagIndexes()[level]
                        : this.model_.getBoundaryIndex();
                List<Integer> matching = new ArrayList<Integer>(remaining.size());
                for (int candidate_index : remaining) {
                    State state = current_candidates.get(candidate_index);
                    // On the top level the candidate must be reachable from the previous gold candidate.
                    if (level == max_level
                            && state.getTransitions() != null
                            && state.getTransition(last_candidate_index) == null) {
                        continue;
                    }
                    int tag_index = state.getZeroOrderState().getSubLevel(max_level - level).getIndex();
                    if (gold_tag_index != tag_index) {
                        continue;
                    }
                    matching.add(candidate_index);
                }
                remaining = matching;
                if (remaining.isEmpty()) {
                    return null;
                }
            }
            assert (remaining.size() == 1);
            int gold_candidate_index = remaining.get(0);
            gold_indexes.add(gold_candidate_index);
            last_candidate_index = gold_candidate_index;
        }
        return gold_indexes;
    }

    /**
     * @return the model this tagger was created with
     */
    @Override
    public Model getModel() {
        return this.model_;
    }

    /**
     * @return the weight vector this tagger was created with
     */
    @Override
    public WeightVector getWeightVector() {
        return this.weight_vector_;
    }

    /**
     * Tags {@code sentence} and returns the tag symbols of every token.
     *
     * @param sentence sequence to tag
     * @return for every token the list of tag symbols, one per level
     */
    @Override
    public List<List<String>> tag(Sequence sentence) {
        List<int[]> indexes = this.tag_(sentence);
        List<List<String>> strings = new ArrayList<List<String>>(indexes.size());
        for (int[] token_indexes : indexes) {
            strings.add(this.indexesToStrings(token_indexes));
        }
        return strings;
    }

    /**
     * Converts tag indexes to symbols, using the tag table of level {@code i}
     * for the {@code i}-th entry.
     *
     * @param indexes tag indexes, one per level
     * @return the corresponding symbols
     */
    protected List<String> indexesToStrings(int[] indexes) {
        List<String> symbols = new ArrayList<String>(indexes.length);
        for (int level = 0; level < indexes.length; ++level) {
            symbols.add(this.model_.getTagTables().get(level).toSymbol(indexes[level]));
        }
        return symbols;
    }

    /**
     * Collects the tag indexes of {@code state} and its chain of sub-level
     * states, ordered from level 0 up to the level of {@code state}.
     *
     * @param state state of the highest level
     * @return tag indexes, one per level
     */
    protected int[] stateToIndexes(State state) {
        int num_levels = state.getLevel() + 1;
        int[] indexes = new int[num_levels];
        State current = state;
        for (int level = num_levels - 1; level >= 0; --level) {
            assert (current != null);
            assert (current.getIndex() >= 0);
            indexes[level] = current.getIndex();
            current = current.getSubLevelState();
        }
        return indexes;
    }

    /**
     * Decodes the Viterbi sequence for {@code sequence} and returns the
     * zero-order state chosen for every token.
     *
     * @param sequence sequence to tag
     * @return one state per token
     */
    protected List<State> tag_states(Sequence sequence) {
        List<State> chosen_states = new ArrayList<State>(sequence.size());
        SumLattice sum_lattice = this.getSumLattice(false, sequence);
        List<List<State>> candidates = sum_lattice.getCandidates();

        ViterbiLattice lattice;
        if (sum_lattice instanceof ZeroOrderSumLattice) {
            lattice = new ZeroOrderViterbiLattice(candidates, this.beam_size_, this.model_.getMarganlizeLemmas());
        } else {
            State boundary = this.model_.getBoundaryState(this.getNumLevels() - 1);
            lattice = new SequenceViterbiLattice(candidates, boundary, this.beam_size_, this.model_.getMarganlizeLemmas());
        }

        Hypothesis hypothesis = lattice.getViterbiSequence();
        List<Integer> state_indexes = hypothesis.getStates();
        for (int index = 0; index < sequence.size(); ++index) {
            int candidate_index = state_indexes.get(index);
            State state = candidates.get(index).get(candidate_index);
            chosen_states.add(state.getZeroOrderState());
        }
        return chosen_states;
    }

    /**
     * Switches the format flag to 1 when {@code format} is {@code "conllu"};
     * any other value (including {@code null}) leaves the flag unchanged.
     *
     * @param format format name
     */
    public void setFormat(String format) {
        // Compare by value: a reference comparison misses equal strings that
        // were created at runtime (e.g. parsed from the command line).
        if ("conllu".equals(format)) {
            this.format_ = 1;
        }
    }

    /**
     * @return the format flag (0 by default, 1 for "conllu")
     */
    public int getFormat() {
        return this.format_;
    }

    /**
     * Tags {@code sequence} and returns the tag indexes of every token.
     *
     * @param sequence sequence to tag
     * @return for every token the tag indexes, one per level
     */
    protected List<int[]> tag_(Sequence sequence) {
        List<int[]> tag_indexes = new ArrayList<int[]>(sequence.size());
        for (State state : this.tag_states(sequence)) {
            tag_indexes.add(this.stateToIndexes(state));
        }
        return tag_indexes;
    }

    /**
     * Overrides the number of levels reported by {@link #getNumLevels()}.
     *
     * @param level new number of levels
     */
    public void setMaxLevel(int level) {
        this.num_level_ = level;
    }

    /**
     * @param result result object to attach to this tagger
     */
    @Override
    public void setResult(Result result) {
        this.result_ = result;
    }

    /**
     * @return the result object attached via {@link #setResult(Result)}
     */
    @Override
    public Result getResult() {
        return this.result_;
    }
}

