/*
 * Decompiled with CFR 0.152.
 */
package experimental.analyzer.simple;

import experimental.analyzer.Analyzer;
import experimental.analyzer.AnalyzerInstance;
import experimental.analyzer.AnalyzerResult;
import experimental.analyzer.simple.SimpleAnalyzer;
import experimental.analyzer.simple.SimpleAnalyzerInstance;
import experimental.analyzer.simple.SimpleAnalyzerModel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import marmot.util.Numerics;

/**
 * Chooses the probability threshold a {@link SimpleAnalyzer} uses to decide which tags to
 * emit for an instance.
 *
 * <p>Two strategies are available. One is a coarse grid search over a fixed list of
 * thresholds, scored by F1. The other is an exhaustive sweep that maximises the number of
 * correct per-(instance, tag) binary decisions.
 */
public class SimpleThresholdOptimizer {
    /**
     * If true, {@link #findTreshold} uses the fixed-grid F1 search; otherwise it uses the
     * exhaustive accuracy sweep.
     */
    boolean simple_;

    /**
     * @param simple {@code true} for the fixed-grid F1 search, {@code false} for the
     *               exhaustive accuracy sweep
     */
    public SimpleThresholdOptimizer(boolean simple) {
        this.simple_ = simple;
    }

    /**
     * Finds a tag-probability threshold for {@code model} on {@code instances}.
     *
     * <p>In simple mode this delegates to a grid search scored by F1. Otherwise:
     * <ol>
     *   <li>every (instance, tag) pair is scored;</li>
     *   <li>each score is converted to a probability (see {@link #toProbabilities});</li>
     *   <li>the threshold just above some probability that maximises the number of correct
     *       active/inactive decisions is returned.</li>
     * </ol>
     * The count of correct decisions is printed to stderr.
     *
     * @param model     model used to score each instance
     * @param instances instances with gold tags to tune against
     * @param mode      how raw scores are turned into probabilities
     * @return the chosen threshold; 0.0 when no higher candidate yields more correct
     *         decisions than predicting every tag
     */
    public double findTreshold(SimpleAnalyzerModel model, Collection<AnalyzerInstance> instances, SimpleAnalyzer.Mode mode) {
        if (this.simple_) {
            return this.simpleFindTreshold(model, instances, mode);
        }
        int num_tags = model.getNumTags();
        int num_actives = 0;
        // Typed list: the former raw LinkedList could not be iterated as Entry.
        List<Entry> entries = new ArrayList<Entry>();
        for (AnalyzerInstance instance : instances) {
            SimpleAnalyzerInstance simple_instance = model.getInstance(instance);
            // Fresh array per instance.
            // NOTE(review): unclear whether model.score overwrites or accumulates into it.
            double[] scores = new double[num_tags];
            model.score(simple_instance, scores);
            double[] probs = toProbabilities(scores, mode);

            List<Entry> current_entries = new ArrayList<Entry>(num_tags);
            for (double prob : probs) {
                // A probability of exactly 1.0 is legitimate: for example, a single tag in
                // classifier mode, or a very large sigmoid score rounding up.
                assert prob >= 0.0 && prob <= 1.0 : prob;
                Entry entry = new Entry();
                entry.active = false;
                entry.prob = prob;
                current_entries.add(entry);
            }
            for (int tag_index : simple_instance.getTagIndexes()) {
                Entry entry = current_entries.get(tag_index);
                // Count each gold tag once, so num_actives equals the number of active
                // entries. The sweep below relies on that invariant.
                if (!entry.active) {
                    entry.active = true;
                    ++num_actives;
                }
            }
            entries.addAll(current_entries);
        }
        Collections.sort(entries);

        // Start with the threshold below every probability. Every tag is then predicted,
        // so exactly the active entries are correct. Raising the threshold past an entry
        // flips its prediction to "inactive": a loss if it was active, a gain otherwise.
        int correct = num_actives;
        int best_correct = num_actives;
        double best_threshold = 0.0;
        int size = entries.size();
        for (int i = 0; i < size; ++i) {
            Entry entry = entries.get(i);
            correct += entry.active ? -1 : 1;
            // One threshold cannot separate entries with equal probabilities, so only the
            // state after the last entry of a tie group is a realisable candidate.
            boolean end_of_group = i + 1 == size || Double.compare(entries.get(i + 1).prob, entry.prob) != 0;
            if (end_of_group && correct > best_correct) {
                best_correct = correct;
                // Just above entry.prob, so this entry (and its ties) fall below the threshold.
                // NOTE(review): assumes distinct neighbouring probabilities differ by more
                // than 1e-10.
                best_threshold = entry.prob + 1.0E-10;
            }
        }
        System.err.println("Correct: " + best_correct);
        return best_threshold;
    }

    /**
     * Converts raw per-tag scores into per-tag probabilities.
     *
     * <p>In {@code classifier} mode the scores are normalised jointly:
     * {@code exp(score - logsumexp(scores))}. In any other mode each score is squashed
     * independently as {@code exp(score - sumLogProb(score, 0.0))}.
     * NOTE(review): this assumes {@code Numerics.sumLogProb(a, b)} is {@code log(e^a + e^b)},
     * which would make the non-classifier branch a logistic sigmoid. Confirm against Numerics.
     */
    private static double[] toProbabilities(double[] scores, SimpleAnalyzer.Mode mode) {
        double[] probs = new double[scores.length];
        if (mode == SimpleAnalyzer.Mode.classifier) {
            double log_z = Double.NEGATIVE_INFINITY;
            for (double score : scores) {
                log_z = Numerics.sumLogProb(score, log_z);
            }
            for (int i = 0; i < scores.length; ++i) {
                probs[i] = Math.exp(scores[i] - log_z);
            }
        } else {
            for (int i = 0; i < scores.length; ++i) {
                probs[i] = Math.exp(scores[i] - Numerics.sumLogProb(scores[i], 0.0));
            }
        }
        return probs;
    }

    /**
     * Tries a fixed list of thresholds and returns the one with the highest F-score on
     * {@code instances}. On ties, the earlier threshold in the list wins. Each candidate's
     * score is printed to stderr.
     */
    private double simpleFindTreshold(SimpleAnalyzerModel model, Collection<AnalyzerInstance> instances, SimpleAnalyzer.Mode mode) {
        double[] thresholds = new double[]{0.5, 0.35, 0.3, 0.25, 0.2, 0.15, 0.1, 0.05};
        System.err.println("Thresholds: " + Arrays.toString(thresholds));
        double best_threshold = 0.0;
        double best_fscore = -1.0;
        for (double threshold : thresholds) {
            double fscore = this.getFscore(model, instances, threshold, mode);
            if (fscore > best_fscore) {
                best_fscore = fscore;
                best_threshold = threshold;
            }
            System.err.format("Threshold: %g F1-Score on train: %g\n", threshold, fscore);
        }
        return best_threshold;
    }

    /**
     * Builds a {@link SimpleAnalyzer} with the given threshold and mode, and returns the
     * F-score that {@link AnalyzerResult#test} reports for it on {@code instances}.
     */
    private double getFscore(SimpleAnalyzerModel model, Collection<AnalyzerInstance> instances, double threshold, SimpleAnalyzer.Mode tag_mode) {
        SimpleAnalyzer analyzer = new SimpleAnalyzer(model, threshold, tag_mode, null);
        AnalyzerResult result = AnalyzerResult.test((Analyzer) analyzer, instances);
        return result.getFscore();
    }

    /** One (instance, tag) decision: its predicted probability and whether the tag is gold. */
    private static class Entry
    implements Comparable<Entry> {
        double prob;
        boolean active;

        private Entry() {
        }

        /** Orders entries by ascending probability. */
        @Override
        public int compareTo(Entry o) {
            return Double.compare(this.prob, o.prob);
        }
    }
}

