/*
 * Decompiled with CFR 0.152.
 */
package chipmunk.test.segmenter;

import chipmunk.segmenter.SegmentationInstance;
import chipmunk.segmenter.SegmentationReading;
import chipmunk.segmenter.SegmentationResult;
import chipmunk.segmenter.SegmentationSumLattice;
import chipmunk.segmenter.SegmenterModel;
import chipmunk.segmenter.SegmenterOptions;
import chipmunk.segmenter.Word;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import marmot.util.DynamicWeights;
import marmot.util.Numerics;
import org.junit.Assert;
import org.junit.Test;

/**
 * Compares {@code SegmentationSumLattice.update} with a brute-force reference
 * that explicitly enumerates every segmentation/tagging of a word.
 *
 * <p>For each test word both implementations are run against the same scorer
 * weights; the returned values and the accumulated updater weights (used here
 * as the gradient buffer) must agree approximately.
 */
public class SumLatticeTest {
    /** Number of independent random weight vectors that are tried. */
    private static final int NUM_TRIALS = 10;

    /**
     * Length of the weight and gradient arrays handed to the model.
     * NOTE(review): presumably large enough for every feature index of the
     * tiny model built in {@link #test()} -- confirm against DynamicWeights.
     */
    private static final int NUM_WEIGHTS = 50;

    /** Fixed seed so that the random weights are reproducible. */
    private static final long SEED = 42L;

    /** Tolerance passed to {@code Numerics.approximatelyEqual} for the gradients. */
    private static final double GRADIENT_TOLERANCE = 1.0E-5;

    /**
     * Brute-force counterpart of the lattice update.
     *
     * <p>Enumerates all candidate results of {@code instance}, combines their
     * model scores with {@code Numerics.sumLogProb}, updates the model with
     * {@code -exp(score - score_sum)} for every candidate and with {@code 1.0}
     * for the single gold result of the instance.
     *
     * @param instance instance that must carry exactly one gold result
     * @param model    model used for scoring and updating
     * @return score of the gold result minus the combined score of all candidates
     */
    double explicit_update(SegmentationInstance instance, SegmenterModel model) {
        int max_segment_length = model.getMaxSegmentLength();
        LinkedList<SegmentationResult> results = new LinkedList<SegmentationResult>();
        this.addAllResults(instance, model, max_segment_length, results, 0);

        // Combine the scores of all candidates (presumably a log-sum-exp).
        double score_sum = Double.NEGATIVE_INFINITY;
        for (SegmentationResult result : results) {
            double score = model.getScore(instance, result);
            score_sum = Numerics.sumLogProb(score, score_sum);
        }

        // Subtract every candidate weighted by its normalized probability.
        for (SegmentationResult result : results) {
            double score = model.getScore(instance, result);
            double prob = Math.exp(score - score_sum);
            model.update(instance, result, -prob);
        }

        // A JUnit assertion is used instead of the `assert` keyword so that the
        // precondition is also checked when the JVM runs without -ea.
        Assert.assertEquals("instance must have exactly one gold result", 1, instance.getResults().size());
        SegmentationResult gold = instance.getResults().iterator().next();
        double gold_score = model.getScore(instance, gold);
        double log_prob = gold_score - score_sum;
        model.update(instance, gold, 1.0);
        return log_prob;
    }

    /**
     * Recursively appends to {@code results} every result that covers the word
     * of {@code instance} from character offset {@code start} to its end.
     *
     * <p>Each segment is at most {@code max_segment_length} characters long and
     * is paired with every tag index in {@code [0, model.getNumTags())}. A
     * result stores one tag and one (exclusive) end offset per segment.
     *
     * @param instance           instance whose word is segmented
     * @param model              supplies the number of tags
     * @param max_segment_length maximal number of characters per segment
     * @param results            output list that receives the enumerated results
     * @param start              offset at which the first segment starts
     */
    private void addAllResults(SegmentationInstance instance, SegmenterModel model, int max_segment_length, List<SegmentationResult> results, int start) {
        String word = instance.getWord().getWord();
        int length = word.length();
        int num_tags = model.getNumTags();
        int last_end = Math.min(start + max_segment_length, length);

        for (int end = start + 1; end <= last_end; ++end) {
            if (end == length) {
                // The segment reaches the end of the word: emit single-segment results.
                for (int tag = 0; tag < num_tags; ++tag) {
                    LinkedList<Integer> tags = new LinkedList<Integer>();
                    tags.add(tag);
                    LinkedList<Integer> indexes = new LinkedList<Integer>();
                    indexes.add(end);
                    results.add(new SegmentationResult(tags, indexes));
                }
                continue;
            }

            // Enumerate all completions of the remainder and prepend this segment.
            LinkedList<SegmentationResult> suffixes = new LinkedList<SegmentationResult>();
            this.addAllResults(instance, model, max_segment_length, suffixes, end);
            for (SegmentationResult suffix : suffixes) {
                for (int tag = 0; tag < num_tags; ++tag) {
                    LinkedList<Integer> tags = new LinkedList<Integer>();
                    tags.add(tag);
                    tags.addAll(suffix.getTags());
                    LinkedList<Integer> indexes = new LinkedList<Integer>();
                    indexes.add(end);
                    indexes.addAll(suffix.getInputIndexes());
                    results.add(new SegmentationResult(tags, indexes));
                }
            }
        }
    }

    /**
     * Runs the lattice update and the brute-force update for a handful of
     * words under several random weight vectors and requires both the returned
     * values and the accumulated gradients to match approximately.
     */
    @Test
    public void test() {
        LinkedList<Word> words = new LinkedList<Word>();
        words.add(this.toWord(Arrays.asList("b"), Arrays.asList("B")));
        words.add(this.toWord(Arrays.asList("aa"), Arrays.asList("A")));
        words.add(this.toWord(Arrays.asList("a", "bb"), Arrays.asList("A", "B")));
        words.add(this.toWord(Arrays.asList("aa", "bb"), Arrays.asList("A", "B")));
        words.add(this.toWord(Arrays.asList("a", "b"), Arrays.asList("A", "B")));
        words.add(this.toWord(Arrays.asList("aa", "b"), Arrays.asList("A", "B")));
        words.add(this.toWord(Arrays.asList("aa", "c"), Arrays.asList("A", "C")));

        SegmenterModel model = new SegmenterModel();
        SegmenterOptions options = new SegmenterOptions();
        options.setOption("use-character-feature", false);
        options.setOption("use-segment-context", false);
        model.init(options, words);

        SegmentationSumLattice lattice = new SegmentationSumLattice(model);
        Random random = new Random(SEED);

        for (int trial = 0; trial < NUM_TRIALS; ++trial) {
            double[] weights = new double[NUM_WEIGHTS];
            for (int i = 0; i < weights.length; ++i) {
                weights[i] = random.nextGaussian();
            }

            // Scores are read from `weights`; updates are accumulated in `gradient`.
            double[] gradient = new double[weights.length];
            model.setScorerWeights(new DynamicWeights(weights, false, false));
            model.setUpdaterWeights(new DynamicWeights(gradient, false, false));

            for (Word word : words) {
                SegmentationInstance instance = model.getInstance(word);

                // Snapshot and reset the shared gradient buffer after each variant.
                double act_value = lattice.update(instance, true);
                double[] act_gradient = gradient.clone();
                Arrays.fill(gradient, 0.0);

                double real_value = this.explicit_update(instance, model);
                double[] real_gradient = gradient.clone();
                Arrays.fill(gradient, 0.0);

                boolean equal_gradient = Numerics.approximatelyEqual(act_gradient, real_gradient, GRADIENT_TOLERANCE);
                if (!equal_gradient) {
                    System.err.println(Arrays.toString(act_gradient) + "\n" + Arrays.toString(real_gradient));
                }
                boolean equal_value = Numerics.approximatelyEqual(act_value, real_value);
                if (!equal_value) {
                    System.err.println(word + " " + act_value + "\n" + real_value);
                }
                Assert.assertTrue("lattice and brute-force update disagree for word " + word, equal_gradient && equal_value);
            }
        }
    }

    /**
     * Builds a {@link Word} whose form is the concatenation of {@code segments}
     * and that carries one reading made of the given segments and tags.
     *
     * @param segments surface segments, in order
     * @param tags     tags handed to the reading together with {@code segments}
     * @return the new word with a single reading added
     */
    private Word toWord(List<String> segments, List<String> tags) {
        StringBuilder form = new StringBuilder();
        for (String segment : segments) {
            form.append(segment);
        }
        Word w = new Word(form.toString());
        w.add(new SegmentationReading(segments, tags));
        return w;
    }
}

