/*
 * Decompiled with CFR 0.152.
 */
package experimental.analyzer.simple;

import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import cc.mallet.optimize.OptimizationException;
import experimental.analyzer.Analyzer;
import experimental.analyzer.AnalyzerInstance;
import experimental.analyzer.AnalyzerReading;
import experimental.analyzer.AnalyzerTag;
import experimental.analyzer.AnalyzerTrainer;
import experimental.analyzer.simple.SimpleAnalyzer;
import experimental.analyzer.simple.SimpleAnalyzerInstance;
import experimental.analyzer.simple.SimpleAnalyzerModel;
import experimental.analyzer.simple.SimpleAnalyzerObjective;
import experimental.analyzer.simple.SimpleThresholdOptimizer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.Map;
import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;
import marmot.util.Counter;
import marmot.util.Mutable;
import org.javatuples.Pair;

public class SimpleAnalyzerTrainer
extends AnalyzerTrainer {
    // Mode passed to SimpleAnalyzerObjective by run_sgd/run_mallet; assigned in train().
    private SimpleAnalyzer.Mode train_mode_;
    // Mode passed to the threshold optimizer and to the resulting SimpleAnalyzer; assigned in train().
    private SimpleAnalyzer.Mode tag_mode_;
    // Penalty value passed to SimpleAnalyzerObjective; train() sets it to 1.0 unless the "penalty" option is given.
    private double penalty_;
    // Option key: name of a SimpleAnalyzer.Mode constant used for both train_mode_ and tag_mode_.
    public static final String MODE = "mode";
    // Option key: penalty value, parsed with Double.parseDouble.
    public static final String PENALTY = "penalty";
    // Option key: name of a PairConstraint constant.
    public static final String PAIR_CONSTRAINT = "pair-constraint";
    // Option key: value for pair_constraint_threshold_, parsed with Double.parseDouble.
    public static final String PAIR_CONSTRAINT_THRESHOLD = "pair-constraint-threshold";
    // When true, train() searches for a threshold with SimpleThresholdOptimizer instead of using 0.01.
    // NOTE(review): nothing in this class ever sets this to true.
    private boolean optimize_threshold_ = false;
    // When true, train() optimizes with run_mallet (L-BFGS) instead of run_sgd.
    // NOTE(review): nothing in this class ever sets this to true.
    private boolean mallet_ = false;
    // Pair-constraint variant passed to SimpleAnalyzerObjective; "none" makes train() skip preparePairConstraints.
    private PairConstraint pair_constraint_ = PairConstraint.weighted;
    // addRelativeProb only stores a tag pair whose relative probability is strictly greater than this value.
    private double pair_constraint_threshold_ = 0.1;
    // tag -> (other tag -> normalized weight); filled by preparePairConstraints, otherwise stays null.
    private Map<AnalyzerTag, Map<AnalyzerTag, Mutable<Double>>> relative_counts_ = null;

    /**
     * Trains a {@link SimpleAnalyzer} on the given instances.
     *
     * <p>Reads the {@code mode}, {@code pair-constraint}, {@code pair-constraint-threshold},
     * {@code penalty} and {@code float-dict} options, builds a {@link SimpleAnalyzerModel}
     * from the instances, optimizes it (SGD unless {@code mallet_} is set) and wraps the
     * model in an analyzer. Progress information is written to {@code System.err}.
     *
     * @param instances training instances
     * @return the trained analyzer
     */
    @Override
    public Analyzer train(Collection<AnalyzerInstance> instances) {
        System.err.format("Num instances: %d\n", instances.size());

        // Compile-time switches; both features are currently turned off.
        boolean use_simple_optimizer = false;
        boolean couple_tags = false;

        // Defaults are assigned first so they are in place even if option parsing throws.
        this.tag_mode_ = SimpleAnalyzer.Mode.binary;
        this.train_mode_ = SimpleAnalyzer.Mode.binary;
        if (this.options_.containsKey(MODE)) {
            SimpleAnalyzer.Mode requested_mode = SimpleAnalyzer.Mode.valueOf((String)this.options_.get(MODE));
            this.tag_mode_ = requested_mode;
            this.train_mode_ = requested_mode;
        }
        if (this.options_.containsKey(PAIR_CONSTRAINT)) {
            String constraint_name = (String)this.options_.get(PAIR_CONSTRAINT);
            this.pair_constraint_ = PairConstraint.valueOf(constraint_name);
        }
        if (this.options_.containsKey(PAIR_CONSTRAINT_THRESHOLD)) {
            String threshold_text = (String)this.options_.get(PAIR_CONSTRAINT_THRESHOLD);
            this.pair_constraint_threshold_ = Double.parseDouble(threshold_text);
        }
        System.err.format("Modes: %s / %s\n", this.tag_mode_, this.train_mode_);

        this.penalty_ = 1.0;
        if (this.options_.containsKey(PENALTY)) {
            String penalty_text = (String)this.options_.get(PENALTY);
            this.penalty_ = Double.parseDouble(penalty_text);
        }
        System.err.format("Penalty: %g\n", this.penalty_);

        Collection<Pair<AnalyzerTag, AnalyzerTag>> coupled = couple_tags ? this.getCoupledTags(instances) : null;
        if (this.pair_constraint_ != PairConstraint.none) {
            this.preparePairConstraints(instances);
        }

        // Pair every instance with the tag set derived from its readings.
        LinkedList<SimpleAnalyzerInstance> training_data = new LinkedList<>();
        for (AnalyzerInstance original : instances) {
            training_data.add(new SimpleAnalyzerInstance(original, AnalyzerReading.toTags(original.getReadings())));
        }

        SimpleAnalyzerModel model = new SimpleAnalyzerModel();
        String float_dict_file = this.options_.containsKey("float-dict") ? (String)this.options_.get("float-dict") : null;
        model.init(training_data, float_dict_file);

        if (!this.mallet_) {
            this.run_sgd(model, training_data, 10, true, 0.1);
        } else {
            this.run_mallet(model, training_data);
        }

        double threshold = 0.01;
        if (this.optimize_threshold_) {
            SimpleThresholdOptimizer threshold_optimizer = new SimpleThresholdOptimizer(use_simple_optimizer);
            threshold = threshold_optimizer.findTreshold(model, instances, this.tag_mode_);
            System.err.println("Best threshold on train: " + threshold);
        }
        return new SimpleAnalyzer(model, threshold, this.tag_mode_, coupled);
    }

    /**
     * Builds {@code relative_counts_} from tag co-occurrence statistics.
     *
     * <p>For every co-occurring tag pair, the relative probability of the pair given each
     * of its two tags is offered to {@link #addRelativeProb}, which keeps it only when it
     * exceeds {@code pair_constraint_threshold_}. Every tag that ended up with at least
     * one entry then gets a self entry of weight 1.0, and its entries are normalized to
     * sum to 1. The resulting map for each tag is printed to {@code System.err}.
     *
     * @param instances training instances the statistics are collected from
     */
    private void preparePairConstraints(Collection<AnalyzerInstance> instances) {
        TagStats stats = this.getTagStates(instances);
        this.relative_counts_ = new HashMap<AnalyzerTag, Map<AnalyzerTag, Mutable<Double>>>();
        for (Map.Entry<Pair<AnalyzerTag, AnalyzerTag>, Double> entry : stats.tag_tag_counts.entrySet()) {
            Pair<AnalyzerTag, AnalyzerTag> pair = entry.getKey();
            AnalyzerTag first = pair.getValue0();
            AnalyzerTag second = pair.getValue1();
            Double joint_count = entry.getValue();
            // Register the pair in both directions, each relative to its own tag's count.
            this.addRelativeProb(first, second, joint_count, stats.tag_counts.count(first), this.relative_counts_);
            this.addRelativeProb(second, first, joint_count, stats.tag_counts.count(second), this.relative_counts_);
        }
        for (Map.Entry<AnalyzerTag, Map<AnalyzerTag, Mutable<Double>>> entry : this.relative_counts_.entrySet()) {
            AnalyzerTag tag = entry.getKey();
            Map<AnalyzerTag, Mutable<Double>> weights = entry.getValue();
            // The tag itself takes part in the normalization with weight 1.0.
            weights.put(tag, new Mutable<Double>(1.0));
            double sum = 0.0;
            for (Mutable<Double> weight : weights.values()) {
                sum += weight.get();
            }
            // sum >= 1.0 because of the self entry, so the division is safe.
            for (Mutable<Double> weight : weights.values()) {
                weight.set(weight.get() / sum);
            }
            System.err.println(tag + " " + weights);
        }
    }

    /**
     * Stores {@code tag_tag_count / tag_count} as the weight of {@code other_tag} under
     * {@code tag}, but only if that ratio is strictly greater than
     * {@code pair_constraint_threshold_}.
     *
     * @param tag             tag whose entry map is extended
     * @param other_tag       tag that co-occurred with {@code tag}
     * @param tag_tag_count   number of co-occurrences of the two tags
     * @param tag_count       number of occurrences of {@code tag}
     * @param relative_counts map to update; the inner map for {@code tag} is created on demand
     */
    private void addRelativeProb(AnalyzerTag tag, AnalyzerTag other_tag, Double tag_tag_count, Double tag_count, Map<AnalyzerTag, Map<AnalyzerTag, Mutable<Double>>> relative_counts) {
        double prob = tag_tag_count / tag_count;
        if (prob > this.pair_constraint_threshold_) {
            Map<AnalyzerTag, Mutable<Double>> weights = relative_counts.computeIfAbsent(tag, k -> new HashMap<AnalyzerTag, Mutable<Double>>());
            // Each direction of a pair is expected to be registered at most once.
            assert (!weights.containsKey(other_tag));
            weights.put(other_tag, new Mutable<Double>(prob));
        }
    }

    /**
     * Optimizes the model by calling {@code SimpleAnalyzerObjective.update} once per
     * instance in each pass over the data.
     *
     * <p>The instances are reshuffled before every pass using a {@link Random} seeded
     * with 42. The step width passed to each update is
     * {@code step_width_ / (1 + updates_so_far / number_of_instances)}.
     *
     * @param model            model handed to the objective
     * @param simple_instances training instances
     * @param steps_           number of passes over the data
     * @param verbose_         whether to print the pass index to {@code System.err}
     * @param step_width_      initial step width
     */
    private void run_sgd(SimpleAnalyzerModel model, Collection<SimpleAnalyzerInstance> simple_instances, int steps_, boolean verbose_, double step_width_) {
        LinkedList<SimpleAnalyzerInstance> shuffled = new LinkedList<>(simple_instances);
        SimpleAnalyzerObjective objective = new SimpleAnalyzerObjective(this.penalty_, model, simple_instances, this.train_mode_, this.relative_counts_, this.pair_constraint_);
        Random rng = new Random(42L);
        double pass_size = (double)shuffled.size();
        int updates_done = 0;
        int pass = 0;
        while (pass < steps_) {
            if (verbose_) {
                System.err.println("step: " + pass);
            }
            Collections.shuffle(shuffled, rng);
            for (SimpleAnalyzerInstance current : shuffled) {
                // The divisor grows by one per completed pass' worth of updates.
                double decay = 1.0 + (double)updates_done / pass_size;
                objective.update(current, step_width_ / decay, true);
                updates_done++;
            }
            pass++;
        }
    }

    /**
     * Optimizes the model with Mallet's {@link LimitedMemoryBFGS}.
     *
     * <p>Runs one initial iteration and then up to 200 further single iterations,
     * stopping early once the optimizer reports convergence. The optimizer's own logger
     * is switched off; progress is reported through this class's logger.
     *
     * <p>An {@link OptimizationException} or {@link IllegalArgumentException} raised by
     * the optimizer ends the optimization without failing training. The exception is
     * logged at WARNING level rather than discarded.
     *
     * @param model            model whose weights initialize the objective's parameters
     * @param simple_instances training instances
     */
    private void run_mallet(SimpleAnalyzerModel model, Collection<SimpleAnalyzerInstance> simple_instances) {
        final int max_iterations = 200;
        Logger logger = Logger.getLogger(this.getClass().getName());
        logger.info("Start optimization");
        SimpleAnalyzerObjective objective = new SimpleAnalyzerObjective(this.penalty_, model, simple_instances, this.train_mode_, this.relative_counts_, this.pair_constraint_);
        LimitedMemoryBFGS optimizer = new LimitedMemoryBFGS((Optimizable.ByGradientValue)objective);
        Logger.getLogger(optimizer.getClass().getName()).setLevel(Level.OFF);
        objective.setParameters(model.getWeights());
        try {
            optimizer.optimize(1);
            for (int i = 0; i < max_iterations && !optimizer.isConverged(); ++i) {
                optimizer.optimize(1);
                logger.info(String.format("Iteration: %3d / %3d: %g", i + 1, max_iterations, objective.getValue()));
            }
        }
        catch (OptimizationException | IllegalArgumentException e) {
            // Deliberately best-effort: training continues with whatever the optimizer
            // reached, but the failure is made visible.
            logger.log(Level.WARNING, "Optimization stopped early; continuing with the parameters reached so far", e);
        }
    }

    /**
     * Counts tag occurrences and tag-pair co-occurrences over all instances.
     *
     * <p>For every instance, each of its tags increments {@code tag_counts} by one, and
     * each unordered pair of its tags increments {@code tag_tag_counts} by one. A pair
     * is stored with the tag of the larger-or-equal hash code first, so that both
     * orders of the same two tags map to one key.
     *
     * @param instances instances whose readings are converted to tags
     * @return the collected counts
     */
    private TagStats getTagStates(Collection<AnalyzerInstance> instances) {
        TagStats stats = new TagStats();
        for (AnalyzerInstance instance : instances) {
            Collection<AnalyzerTag> tags = AnalyzerReading.toTags(instance.getReadings());
            for (AnalyzerTag single : tags) {
                stats.tag_counts.increment(single, 1.0);
            }
            ArrayList<AnalyzerTag> tag_list = new ArrayList<>(tags);
            int size = tag_list.size();
            for (int i = 0; i < size; ++i) {
                AnalyzerTag tag = tag_list.get(i);
                for (AnalyzerTag other_tag : tag_list.subList(i + 1, size)) {
                    // NOTE(review): two distinct tags with equal hash codes are keyed in
                    // iteration order, so the same pair could end up under two keys.
                    Pair<AnalyzerTag, AnalyzerTag> key;
                    if (tag.hashCode() < other_tag.hashCode()) {
                        key = new Pair<AnalyzerTag, AnalyzerTag>(other_tag, tag);
                    } else {
                        key = new Pair<AnalyzerTag, AnalyzerTag>(tag, other_tag);
                    }
                    stats.tag_tag_counts.increment(key, 1.0);
                }
            }
        }
        return stats;
    }

    /**
     * Finds tag pairs that almost always occur together.
     *
     * <p>A pair is reported when it co-occurs at least 10 times and
     * {@code joint_count / sqrt(tag_count * other_tag_count)} is greater than 0.99.
     * The number of pairs and the pairs themselves are printed to {@code System.err}.
     *
     * @param instances instances the co-occurrence statistics are collected from
     * @return the coupled tag pairs, possibly empty
     */
    private Collection<Pair<AnalyzerTag, AnalyzerTag>> getCoupledTags(Collection<AnalyzerInstance> instances) {
        TagStats stats = this.getTagStates(instances);
        LinkedList<Pair<AnalyzerTag, AnalyzerTag>> coupled = new LinkedList<Pair<AnalyzerTag, AnalyzerTag>>();
        double num_instances = (double)instances.size();
        for (Map.Entry<Pair<AnalyzerTag, AnalyzerTag>, Double> entry : stats.tag_tag_counts.entrySet()) {
            Pair<AnalyzerTag, AnalyzerTag> pair = entry.getKey();
            double tag_count = stats.tag_counts.count(pair.getValue0());
            double other_tag_count = stats.tag_counts.count(pair.getValue1());
            double joint_count = entry.getValue();
            // A tag or pair present in every instance reaches the instance count
            // exactly, so the bound is inclusive.
            // NOTE(review): assumes each tag is counted at most once per instance.
            assert (tag_count <= num_instances);
            assert (other_tag_count <= num_instances);
            assert (joint_count <= num_instances);
            if (joint_count >= 10.0) {
                double pseudo_pmi = joint_count / Math.sqrt(tag_count * other_tag_count);
                if (pseudo_pmi > 0.99) {
                    coupled.add(pair);
                }
            }
        }
        System.err.println("|Coupled|: " + coupled.size());
        System.err.println("Coupled: " + coupled);
        return coupled;
    }

    /**
     * Selects how tag-pair constraints are used during training. The chosen value is
     * passed to {@code SimpleAnalyzerObjective}, which interprets {@code simple} and
     * {@code weighted}; that interpretation is not visible in this file.
     */
    enum PairConstraint {
        simple,
        weighted,
        /** With this value the trainer does not compute relative tag-pair counts. */
        none
    }

    /**
     * Holder for the counts collected by {@code getTagStates}: how often each tag
     * occurs and how often each pair of tags occurs together.
     */
    private static class TagStats {
        // Occurrences per tag.
        final Counter<AnalyzerTag> tag_counts = new Counter<>();
        // Co-occurrences per tag pair.
        final Counter<Pair<AnalyzerTag, AnalyzerTag>> tag_tag_counts = new Counter<>();

        /** Creates a holder with both counters empty. */
        public TagStats() {
            // The counters are created once, by the field initializers above.
        }
    }
}

