/*
 * Decompiled with CFR 0.152.
 */
package marmot.core.lattice;

import java.util.Collections;
import java.util.List;
import marmot.core.State;
import marmot.core.Transition;
import marmot.util.Numerics;

/**
 * Backward (right-to-left) score lattice over a sequence of candidate-state lists.
 *
 * <p>Entry {@code lattice_[index][state_index]} accumulates, via {@link Numerics#sumLogProb}, the
 * scores of every continuation that leaves state {@code state_index} of position {@code index - 1}
 * (or the boundary state when {@code index == 0}) and passes through the candidates at positions
 * {@code index .. candidates.size() - 1}. Accumulation starts from
 * {@link Double#NEGATIVE_INFINITY}, so the values are presumably log-space sums — confirm against
 * {@code Numerics.sumLogProb}.
 *
 * <p>The last candidate list is expected to hold exactly one zero-order state whose index equals
 * the boundary state's index; this is checked with assertions in the constructor.
 */
public class BackwardSequenceLattice {
    /** Maximum absolute difference tolerated by {@link #diffTest(double, double)}. */
    private static final double TOLERANCE = 1.0E-10;

    private double[][] lattice_;
    private final List<List<State>> candidates_;
    private final State boundary_;

    /**
     * Creates a lattice over the given candidates. Call {@link #init()} before querying it.
     *
     * @param candidates candidate states per position; must be non-empty (the last list is read
     *     unconditionally, so an empty list raises {@link IndexOutOfBoundsException})
     * @param boundary the boundary state that acts as the single source of position 0
     */
    public BackwardSequenceLattice(List<List<State>> candidates, State boundary) {
        this.candidates_ = candidates;
        this.boundary_ = boundary;
        assert (this.candidates_.get(this.candidates_.size() - 1).size() == 1);
        State state = this.candidates_.get(this.candidates_.size() - 1).get(0);
        assert (state == state.getZeroOrderState());
        assert (state.getIndex() == this.boundary_.getIndex());
    }

    /** Number of source states feeding position {@code index}: the boundary alone for index 0. */
    private int sourceStateCount(int index) {
        return index == 0 ? 1 : this.candidates_.get(index - 1).size();
    }

    /** Replaces {@code lattice_} with freshly zeroed rows sized by {@link #sourceStateCount}. */
    private void allocate() {
        int length = this.candidates_.size();
        this.lattice_ = new double[length][];
        for (int index = 0; index < length; ++index) {
            this.lattice_[index] = new double[this.sourceStateCount(index)];
        }
    }

    /**
     * Computes one backward cell: sums, over every candidate at {@code index} that has a transition
     * for {@code state_index}, its state score plus transition score plus the already-computed
     * backward value of that candidate at {@code index + 1} (if any).
     *
     * @param print when true, writes one diagnostic line per successor to {@code System.err}
     * @return the accumulated value, {@link Double#NEGATIVE_INFINITY} if no transition exists
     */
    private double sumSuccessors(int index, int state_index, boolean print) {
        double score_sum = Double.NEGATIVE_INFINITY;
        boolean has_next = index + 1 < this.candidates_.size();
        int previous_state_index = 0;
        for (State previous_state : this.candidates_.get(index)) {
            Transition transition = previous_state.getTransition(state_index);
            if (transition == null) {
                if (print) {
                    // Key by the successor, like the non-null diagnostic line below.
                    System.err.format("%d null\n", previous_state_index);
                }
            } else {
                double score = previous_state.getScore() + transition.getScore();
                if (print) {
                    System.err.format("%d %g %g", previous_state_index,
                            previous_state.getScore(), transition.getScore());
                }
                if (has_next) {
                    double backward = this.lattice_[index + 1][previous_state_index];
                    score += backward;
                    if (print) {
                        System.err.format(" %g", backward);
                    }
                }
                if (print) {
                    System.err.format("\n");
                }
                score_sum = Numerics.sumLogProb(score_sum, score);
            }
            ++previous_state_index;
        }
        return score_sum;
    }

    /**
     * Brute-force reference for the backward recursion: enumerates every path through
     * {@code candidates} (exponential in its length) starting from source state {@code index}.
     * As a side effect it stores each intermediate sum into the matching {@code lattice_} cell.
     *
     * @param index source-state index used to look up transitions in {@code candidates.get(0)}
     * @param candidates a suffix of {@code candidates_}
     * @param print when true, writes diagnostics for the first level to {@code System.err}
     * @return the accumulated value; 0.0 for an empty suffix
     */
    private double bruteForceSum(int index, List<List<State>> candidates, boolean print) {
        if (candidates.isEmpty()) {
            return 0.0;
        }
        double score_sum = Double.NEGATIVE_INFINITY;
        List<List<State>> rest = candidates.subList(1, candidates.size());
        int state_index = 0;
        for (State state : candidates.get(0)) {
            Transition transition = state.getTransition(index);
            if (transition != null) {
                double rec = this.bruteForceSum(state_index, rest, false);
                if (print) {
                    System.err.format("%d %g %g %g\n", state_index, state.getScore(),
                            transition.getScore(), rec);
                }
                double score = transition.getScore() + state.getScore() + rec;
                score_sum = Numerics.sumLogProb(score_sum, score);
            } else if (print) {
                // Previously printed unconditionally, flooding stderr during reinit().
                System.err.format("%d null\n", state_index);
            }
            ++state_index;
        }
        this.lattice_[this.candidates_.size() - candidates.size()][index] = score_sum;
        return score_sum;
    }

    /** Fills the lattice from the last position back to position 0. */
    public void init() {
        this.allocate();
        for (int index = this.candidates_.size() - 1; index >= 0; --index) {
            double[] row = this.lattice_[index];
            for (int state_index = 0; state_index < row.length; ++state_index) {
                row[state_index] = this.sumSuccessors(index, state_index, false);
            }
        }
        assert (this.lattice_[0].length == 1);
    }

    /** Returns the value of the single cell at position 0 (the boundary state). */
    double partitionFunction() {
        return this.lattice_[0][0];
    }

    /**
     * Returns the backward value of {@code state_index} at {@code index}.
     *
     * @return the stored cell, or 0.0 when {@code index} is one past the last position
     */
    public double get(int index, int state_index) {
        if (index == this.candidates_.size()) {
            return 0.0;
        }
        return this.lattice_[index][state_index];
    }

    /**
     * Debug variant of {@link #init()}. It first fills the lattice with the brute-force reference,
     * then recomputes every cell with the backward recursion. Each recomputed cell is checked
     * against the reference with {@link #diffTest(double, double)} before it is stored. Verbose
     * diagnostics go to {@code System.err}.
     *
     * @throws RuntimeException rethrown from {@code diffTest} when a cell disagrees
     */
    public void reinit() {
        this.allocate();
        System.err.println(this.bruteForceSum(0, this.candidates_, false));
        System.err.println(this.partitionFunction());
        for (int index = this.candidates_.size() - 1; index >= 0; --index) {
            double[] row = this.lattice_[index];
            for (int state_index = 0; state_index < row.length; ++state_index) {
                System.err.format("STATE\n");
                double score_sum = this.sumSuccessors(index, state_index, true);
                try {
                    this.diffTest(row[state_index], score_sum);
                } catch (RuntimeException e) {
                    this.reportMismatch(index, state_index);
                    throw e;
                }
                row[state_index] = score_sum;
            }
        }
        assert (this.lattice_[0].length == 1);
    }

    /** Dumps the failing cell and a verbose brute-force recomputation of it to stderr. */
    private void reportMismatch(int index, int state_index) {
        System.err.println();
        System.err.println(this.candidates_.size());
        System.err.println(index);
        System.err.println(state_index);
        System.err.println();
        System.err.println();
        System.err.println(this.lattice_[index][state_index]);
        double f = this.bruteForceSum(state_index,
                this.candidates_.subList(index, this.candidates_.size()), true);
        System.err.println(this.lattice_[index][state_index]);
        System.err.println(f);
        System.err.println();
        System.err.println();
    }

    /**
     * Checks that two values agree within {@link #TOLERANCE}.
     *
     * @return the absolute difference
     * @throws RuntimeException if the difference exceeds the tolerance (a NaN difference, e.g.
     *     from two negative infinities, does not exceed it and passes)
     */
    protected double diffTest(double a, double b) {
        double diff = Math.abs(a - b);
        if (diff > TOLERANCE) {
            throw new RuntimeException(String.format("test failed: %g %g : %g", a, b, diff));
        }
        return diff;
    }
}

