/*
 * Decompiled with CFR 0.152.
 */
package lemming.lemma.toutanova;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.logging.Logger;
import lemming.lemma.LemmaInstance;
import lemming.lemma.toutanova.Aligner;
import lemming.lemma.toutanova.IndexConsumer;
import lemming.lemma.toutanova.IndexScorer;
import lemming.lemma.toutanova.IndexUpdater;
import lemming.lemma.toutanova.Result;
import lemming.lemma.toutanova.ToutanovaInstance;
import lemming.lemma.toutanova.ToutanovaTrainer;
import marmot.core.Feature;
import marmot.util.DynamicWeights;
import marmot.util.Encoder;
import marmot.util.SymbolTable;

/**
 * Model state for the {@code lemming.lemma.toutanova} lemmatizer components.
 *
 * <p>The model owns the output-segment symbol table, the character and (optional) POS symbol
 * tables, the training form vocabulary and the weight vector. It encodes transition, output and
 * pair features for a segment of an instance's form and hands each encoded feature to an
 * {@link IndexConsumer} (the scorer when computing scores, the updater when training).
 *
 * <p>The {@link Encoder} scratch objects are {@code transient}; they are rebuilt in
 * {@link #init} and after deserialization.
 *
 * <p>NOTE(review): the encoder, scorer and updater are shared mutable state, so an instance is
 * presumably not safe for concurrent use — confirm against callers.
 */
public class ToutanovaModel
implements Serializable {
    private static final long serialVersionUID = 1L;
    /** Index-to-symbol view of {@link #output_table_}, filled at the end of {@link #init}. */
    private String[] alphabet_;
    private SymbolTable<String> output_table_;
    /** Only created when the options request POS usage; {@code null} otherwise. */
    private SymbolTable<String> pos_table_;
    private int max_input_segment_length_;
    private int num_output_bits;
    private SymbolTable<Character> char_table;
    /** Forms seen in the training instances; backs {@link #isOOV}. */
    private Set<String> form_vocab_;
    private transient Encoder encoder;
    private transient Encoder.State encoder_state;
    private int num_char_bits;
    private int num_pos_bits;
    private IndexScorer scorer_;
    private IndexUpdater updater_;
    /** When set, transition features are skipped entirely. */
    private boolean use_zero_order_;
    private int max_input_segment_length_bits_;
    private DynamicWeights weights_;
    /** Number of bits used to encode the length of an affix segment. */
    private static final int length_bits_ = 6;
    /** Bits reserved for the feature-type id; 2 is the largest id ({@link #PAIR_FEAT}). */
    private static final int FEATURE_BITS = Encoder.bitsNeeded(2);
    private static final int TRANS_FEAT = 0;
    private static final int OUTPUT_FEAT = 1;
    private static final int PAIR_FEAT = 2;
    private static final String COPY_SYMBOL = "<COPY>";
    /** Largest affix window (in characters) used by the context features. */
    private int max_window = 2;

    /**
     * Builds all tables and the scorer/updater from the training data.
     *
     * @param options         trainer options (window size, alphabet filter, POS usage, random
     *                        source, decoder)
     * @param train_instances instances used to build the output table, character/POS tables and
     *                        the form vocabulary
     * @param test_instances  optional instances that are only indexed (no new symbols are
     *                        inserted); may be {@code null}
     */
    public void init(ToutanovaTrainer.ToutanovaOptions options, List<ToutanovaInstance> train_instances, List<ToutanovaInstance> test_instances) {
        Logger logger = Logger.getLogger(this.getClass().getName());
        this.max_window = options.getMaxWindowSize();
        this.createOutputTable(options, train_instances);
        this.logOutputTableStats(logger);
        if (options.getFilterAlphabet() > 0) {
            // Mark instances that use rare output symbols, then rebuild the table without them.
            this.filterRareOutputSymbols(options, train_instances);
            this.createOutputTable(options, train_instances);
            this.logOutputTableStats(logger);
        }
        this.char_table = new SymbolTable<Character>();
        if (options.getUsePos()) {
            this.pos_table_ = new SymbolTable<String>();
        }
        this.form_vocab_ = new HashSet<String>();
        for (ToutanovaInstance instance : train_instances) {
            this.form_vocab_.add(instance.getInstance().getForm());
        }
        this.addIndexes(train_instances, true);
        if (test_instances != null) {
            this.addIndexes(test_instances, false);
        }
        this.num_output_bits = Encoder.bitsNeeded(this.output_table_.size());
        // Snapshot the index -> symbol mapping before the table is made unidirectional.
        this.alphabet_ = new String[this.output_table_.size()];
        for (Map.Entry<String, Integer> entry : this.output_table_.entrySet()) {
            this.alphabet_[entry.getValue().intValue()] = entry.getKey();
        }
        this.output_table_.setBidirectional(false);
        this.num_char_bits = Encoder.bitsNeeded(this.char_table.size());
        this.num_pos_bits = -1;
        if (this.pos_table_ != null) {
            this.num_pos_bits = Encoder.bitsNeeded(this.pos_table_.size());
        }
        this.weights_ = new DynamicWeights(options.getRandom());
        // Scorer and updater share one feature map and one weight vector.
        SymbolTable<Feature> feature_map = new SymbolTable<Feature>();
        this.scorer_ = new IndexScorer(this.weights_, feature_map, this.num_pos_bits);
        this.updater_ = new IndexUpdater(this.weights_, feature_map, this.num_pos_bits);
        this.use_zero_order_ = options.getDecoderInstance().getOrder() < 1;
        this.setupTemp();
    }

    /** Logs the current output alphabet size and maximum input segment length. */
    private void logOutputTableStats(Logger logger) {
        logger.info("Output alphabet size: " + this.output_table_.size());
        logger.info("Max input segment length: " + this.max_input_segment_length_);
    }

    /** (Re)creates the transient encoder scratch objects. */
    private void setupTemp() {
        this.encoder = new Encoder(10);
        this.encoder_state = new Encoder.State();
    }

    /** Restores the serialized fields and rebuilds the transient encoder state. */
    private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException {
        ois.defaultReadObject();
        this.setupTemp();
    }

    /**
     * Rebuilds the output table from the non-rare training instances and attaches a gold
     * {@link Result} to each of them; rare instances get a {@code null} result. Also recomputes
     * the maximum input segment length and the number of bits needed for it.
     */
    private void createOutputTable(ToutanovaTrainer.ToutanovaOptions options, List<ToutanovaInstance> train_instances) {
        this.output_table_ = new SymbolTable<String>(true);
        this.output_table_.insert(COPY_SYMBOL);
        this.max_input_segment_length_ = 0;
        for (ToutanovaInstance instance : train_instances) {
            if (instance.isRare()) {
                instance.setResult(null);
                continue;
            }
            String form = instance.getInstance().getForm();
            assert (instance.getAlignment() != null);
            List<Aligner.Pair> pairs = Aligner.Pair.toPairs(form, instance.getInstance().getLemma(), instance.getAlignment());
            ArrayList<Integer> form_indexes = new ArrayList<Integer>(pairs.size());
            ArrayList<Integer> lemma_segments = new ArrayList<Integer>(pairs.size());
            int end_index = 0;
            for (Aligner.Pair pair : pairs) {
                int current_input_length = pair.getInputSegment().length();
                this.max_input_segment_length_ = Math.max(this.max_input_segment_length_, current_input_length);
                // Each entry is the running end offset of the segment within the form.
                end_index += current_input_length;
                form_indexes.add(end_index);
                // Identical input/output segments keep index 0; presumably that is the index
                // assigned to COPY_SYMBOL, the first symbol inserted above — confirm in SymbolTable.
                int output_segment_index = 0;
                if (!pair.getInputSegment().equals(pair.getOutputSegment())) {
                    output_segment_index = this.output_table_.toIndex(pair.getOutputSegment(), true);
                }
                lemma_segments.add(output_segment_index);
            }
            Result result = new Result(this, lemma_segments, form_indexes, form);
            assert (result.getOutput().equals(instance.getInstance().getLemma()));
            instance.setResult(result);
        }
        this.max_input_segment_length_bits_ = Encoder.bitsNeeded(this.max_input_segment_length_);
    }

    /**
     * Counts how often each output symbol is used by the training results and flags every
     * instance that uses at least one symbol whose count does not exceed
     * {@code options.getFilterAlphabet()} as rare.
     *
     * <p>Requires every instance to carry a non-null result (as produced by a preceding
     * {@link #createOutputTable} call on instances that are not yet marked rare).
     */
    private void filterRareOutputSymbols(ToutanovaTrainer.ToutanovaOptions options, List<ToutanovaInstance> train_instances) {
        Logger logger = Logger.getLogger(this.getClass().getName());
        int[] counts = new int[this.output_table_.size()];
        for (ToutanovaInstance instance : train_instances) {
            for (int output_index : instance.getResult().getOutputs()) {
                counts[output_index]++;
            }
        }
        // Report the symbols that will actually trigger filtering: same threshold test as below,
        // ignoring symbols that never occur (they cannot flag any instance).
        int rare_output_symbols = 0;
        for (int count : counts) {
            if (count > 0 && count <= options.getFilterAlphabet()) {
                ++rare_output_symbols;
            }
        }
        logger.info(String.format("Num rare output symbols (<= %d): %d", options.getFilterAlphabet(), rare_output_symbols));
        for (ToutanovaInstance instance : train_instances) {
            boolean instance_is_rare = false;
            for (int output_index : instance.getResult().getOutputs()) {
                if (counts[output_index] <= options.getFilterAlphabet()) {
                    instance_is_rare = true;
                    break;
                }
            }
            instance.setRare(instance_is_rare);
        }
    }

    /** Returns the output-segment symbol table. */
    public SymbolTable<String> getOutputTable() {
        return this.output_table_;
    }

    /** Returns the longest input segment length seen while building the output table. */
    public int getMaxInputSegmentLength() {
        return this.max_input_segment_length_;
    }

    /**
     * Maps an output index back to its symbol, using the cached alphabet array when available and
     * the symbol table otherwise (i.e. before {@link #init} has finished).
     */
    public String getOutput(int o) {
        if (this.alphabet_ == null) {
            return this.output_table_.toSymbol(o);
        }
        return this.alphabet_[o];
    }

    /**
     * Emits the transition feature for the output pair {@code (last_o, o)} plus its affix-context
     * variants. Does nothing when there is no previous output ({@code last_o < 0}).
     */
    public void consumeTransitionFeature(IndexConsumer consumer, ToutanovaInstance instance, int l_start, int l_end, int last_o, int o) {
        if (last_o < 0) {
            return;
        }
        this.encoder.reset();
        this.encoder.append(TRANS_FEAT, FEATURE_BITS);
        this.encoder.append(last_o, this.num_output_bits);
        this.encoder.append(o, this.num_output_bits);
        consumer.consume(instance, this.encoder);
        this.addAffixes(instance, consumer, l_start, l_end);
    }

    /**
     * Extends the feature currently held in the encoder with left/right character context and
     * emits one feature per window size: first left+right context, then left only, then right
     * only. The encoder is restored to its incoming state after every emitted feature.
     *
     * <p>NOTE(review): the right context starts at {@code l_end + 1}, although callers in this
     * class treat {@code l_end} as an exclusive end (see the loop in {@code consumePairFeature}),
     * so the character at {@code l_end} itself is skipped. Left as is because changing it would
     * alter the feature encoding of existing models — confirm intent.
     */
    private void addAffixes(ToutanovaInstance instance, IndexConsumer consumer, int l_start, int l_end) {
        int window;
        for (window = 1; window <= this.max_window; ++window) {
            this.encoder.storeState(this.encoder_state);
            this.addSegment(instance.getFormCharIndexes(), l_start - window, l_start);
            this.addSegment(instance.getFormCharIndexes(), l_end + 1, l_end + window + 1);
            consumer.consume(instance, this.encoder);
            this.encoder.restoreState(this.encoder_state);
        }
        for (window = 1; window <= this.max_window; ++window) {
            this.encoder.storeState(this.encoder_state);
            this.addSegment(instance.getFormCharIndexes(), l_start - window, l_start);
            consumer.consume(instance, this.encoder);
            this.encoder.restoreState(this.encoder_state);
        }
        for (window = 1; window <= this.max_window; ++window) {
            this.encoder.storeState(this.encoder_state);
            this.addSegment(instance.getFormCharIndexes(), l_end + 1, l_end + window + 1);
            consumer.consume(instance, this.encoder);
            this.encoder.restoreState(this.encoder_state);
        }
    }

    /**
     * Appends the segment length followed by the character indexes of {@code chars[start, end)}
     * to the encoder. Positions outside the array are encoded as {@code char_table.size()}, which
     * acts as a padding symbol. Encoding stops at the first negative character index, leaving the
     * segment truncated.
     */
    private void addSegment(int[] chars, int start, int end) {
        this.encoder.append(end - start, length_bits_);
        for (int i = start; i < end; ++i) {
            int c = i >= 0 && i < chars.length ? chars[i] : this.char_table.size();
            if (c < 0) {
                return;
            }
            this.encoder.append(c, this.num_char_bits);
        }
    }

    /** Emits the output-only feature for {@code o} plus its affix-context variants. */
    public void consumeOutputFeature(IndexConsumer consumer, ToutanovaInstance instance, int l_start, int l_end, int o) {
        this.encoder.reset();
        this.encoder.append(OUTPUT_FEAT, FEATURE_BITS);
        this.encoder.append(o, this.num_output_bits);
        consumer.consume(instance, this.encoder);
        this.addAffixes(instance, consumer, l_start, l_end);
    }

    /**
     * Emits the feature pairing output {@code o} with the input characters in
     * {@code [l_start, l_end)}, plus its affix-context variants. Nothing is emitted when the
     * segment contains a negative character index.
     */
    public void consumePairFeature(IndexConsumer consumer, ToutanovaInstance instance, int l_start, int l_end, int o) {
        int[] chars = instance.getFormCharIndexes();
        this.encoder.reset();
        this.encoder.append(PAIR_FEAT, FEATURE_BITS);
        this.encoder.append(o, this.num_output_bits);
        this.encoder.append(l_end - l_start, this.max_input_segment_length_bits_);
        // NOTE(review): the length is appended a second time with a fixed 4-bit width; kept
        // unchanged to preserve the existing feature encoding — confirm whether it is intended.
        this.encoder.append(l_end - l_start, 4);
        for (int l = l_start; l < l_end; ++l) {
            int c = chars[l];
            if (c < 0) {
                return;
            }
            this.encoder.append(c, this.num_char_bits);
        }
        consumer.consume(instance, this.encoder);
        this.addAffixes(instance, consumer, l_start, l_end);
    }

    /** Emits the pair feature followed by the output feature for one segment. */
    private void consumeOutputPair(IndexConsumer consumer, ToutanovaInstance instance, int l_start, int l_end, int o) {
        this.consumePairFeature(consumer, instance, l_start, l_end, o);
        this.consumeOutputFeature(consumer, instance, l_start, l_end, o);
    }

    /** Emits the transition feature unless the model runs in zero-order mode. */
    private void consumeTransition(IndexConsumer consumer, ToutanovaInstance instance, int l_start, int l_end, int last_o, int o) {
        if (this.use_zero_order_) {
            return;
        }
        this.consumeTransitionFeature(consumer, instance, l_start, l_end, last_o, o);
    }

    /**
     * Feeds all features of a complete segmentation to {@code consumer}: for every
     * (output, segment end) pair of {@code result}, the transition from the previous output (if
     * any) and then the pair/output features. Segment starts are the previous segment's end,
     * beginning at 0.
     */
    private void consumeResult(IndexConsumer consumer, ToutanovaInstance instance, Result result) {
        Iterator<Integer> output_iterator = result.getOutputs().iterator();
        Iterator<Integer> input_iterator = result.getInputs().iterator();
        int last_o = -1;
        int l_start = 0;
        while (output_iterator.hasNext()) {
            int o = output_iterator.next();
            int l_end = input_iterator.next();
            if (last_o >= 0) {
                this.consumeTransition(consumer, instance, l_start, l_end, last_o, o);
            }
            this.consumeOutputPair(consumer, instance, l_start, l_end, o);
            last_o = o;
            l_start = l_end;
        }
    }

    /** Returns the score of the pair and output features for one segment. */
    public double getPairScore(ToutanovaInstance instance, int l_start, int l_end, int o) {
        this.scorer_.reset();
        this.consumeOutputPair(this.scorer_, instance, l_start, l_end, o);
        return this.scorer_.getScore();
    }

    /**
     * Returns the score of the transition from {@code last_o} to {@code o}; no features are
     * consumed in zero-order mode or when {@code last_o} is negative.
     */
    public double getTransitionScore(ToutanovaInstance instance, int last_o, int o, int l_start, int l_end) {
        this.scorer_.reset();
        this.consumeTransition(this.scorer_, instance, l_start, l_end, last_o, o);
        return this.scorer_.getScore();
    }

    /** Returns the total score of a complete segmentation {@code result2} of {@code instance}. */
    public double getScore(ToutanovaInstance instance, Result result2) {
        this.scorer_.reset();
        this.consumeResult(this.scorer_, instance, result2);
        return this.scorer_.getScore();
    }

    /**
     * Applies {@code update} to the weights of every feature fired by the segmentation
     * {@code result2} of {@code instance}.
     */
    public void update(ToutanovaInstance instance, Result result2, double update) {
        this.updater_.setUpdate(update);
        this.consumeResult(this.updater_, instance, result2);
    }

    /** Indexes every instance in {@code instances}; see {@link #addIndexes(ToutanovaInstance, boolean)}. */
    public void addIndexes(List<ToutanovaInstance> instances, boolean insert) {
        for (ToutanovaInstance instance : instances) {
            this.addIndexes(instance, insert);
        }
    }

    /**
     * Stores the character indexes of the instance's form and, when a POS table exists and the
     * instance has a tag, its POS tag index. Rare instances are left untouched.
     *
     * @param insert whether unseen characters/tags are added to the tables; lookups pass -1 as
     *               the default index, presumably returned for symbols that are absent and not
     *               inserted — confirm in SymbolTable
     */
    public void addIndexes(ToutanovaInstance instance, boolean insert) {
        if (instance.isRare()) {
            return;
        }
        String form = instance.getInstance().getForm();
        int[] char_indexes = new int[form.length()];
        for (int i = 0; i < form.length(); ++i) {
            char_indexes[i] = this.char_table.toIndex(Character.valueOf(form.charAt(i)), -1, insert);
        }
        instance.setFormCharIndexes(char_indexes);
        if (this.pos_table_ == null) {
            return;
        }
        String pos_tag = instance.getInstance().getPosTag();
        if (pos_tag != null) {
            int index = this.pos_table_.toIndex(pos_tag, -1, insert);
            instance.setPosTagIndex(index);
        }
    }

    /** Returns the current weight vector. */
    public DynamicWeights getWeights() {
        return this.weights_;
    }

    /** Replaces the weight vector on the model, the scorer and the updater. */
    public void setWeights(DynamicWeights weights) {
        this.weights_ = weights;
        this.scorer_.setWeights(weights);
        this.updater_.setWeights(weights);
    }

    /** Returns {@code true} when the instance's form was not seen in the training instances. */
    public boolean isOOV(LemmaInstance instance) {
        return !this.form_vocab_.contains(instance.getForm());
    }
}

