/*
 * Decompiled with CFR 0.152.
 */
package opennlp.tools.ml.perceptron;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import opennlp.tools.ml.AbstractEventModelSequenceTrainer;
import opennlp.tools.ml.model.AbstractModel;
import opennlp.tools.ml.model.Event;
import opennlp.tools.ml.model.MutableContext;
import opennlp.tools.ml.model.OnePassDataIndexer;
import opennlp.tools.ml.model.Sequence;
import opennlp.tools.ml.model.SequenceStream;
import opennlp.tools.ml.model.SequenceStreamEventStream;
import opennlp.tools.ml.perceptron.PerceptronModel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class SimplePerceptronSequenceTrainer
extends AbstractEventModelSequenceTrainer {
    private static final Logger logger = LoggerFactory.getLogger(SimplePerceptronSequenceTrainer.class);
    public static final String PERCEPTRON_SEQUENCE_VALUE = "PERCEPTRON_SEQUENCE";
    private int iterations;
    private SequenceStream<Event> sequenceStream;
    private int numEvents;
    private int numPreds;
    private int numOutcomes;
    private int[] outcomeList;
    private String[] outcomeLabels;
    private MutableContext[] averageParams;
    private Map<String, Integer> pmap;
    private Map<String, Integer> omap;
    private MutableContext[] params;
    private boolean useAverage;
    private int[][][] updates;
    private static final int VALUE = 0;
    private static final int ITER = 1;
    private static final int EVENT = 2;
    private String[] predLabels;
    private int numSequences;

    @Override
    public void validate() {
        super.validate();
        String algorithmName = this.getAlgorithm();
        if (algorithmName != null && !PERCEPTRON_SEQUENCE_VALUE.equals(algorithmName)) {
            throw new IllegalArgumentException("algorithmName must be PERCEPTRON_SEQUENCE");
        }
    }

    @Override
    public AbstractModel doTrain(SequenceStream<Event> events) throws IOException {
        int iterations = this.getIterations();
        int cutoff = this.getCutoff();
        boolean useAverage = this.trainingParameters.getBooleanParameter("UseAverage", true);
        return this.trainModel(iterations, events, cutoff, useAverage);
    }

    public AbstractModel trainModel(int iterations, SequenceStream<Event> sequenceStream, int cutoff, boolean useAverage) throws IOException {
        this.iterations = iterations;
        this.sequenceStream = sequenceStream;
        this.trainingParameters.put("Cutoff", cutoff);
        this.trainingParameters.put("sort", false);
        OnePassDataIndexer di = new OnePassDataIndexer();
        di.init(this.trainingParameters, this.reportMap);
        di.index(new SequenceStreamEventStream(sequenceStream));
        this.numSequences = 0;
        sequenceStream.reset();
        while (sequenceStream.read() != null) {
            ++this.numSequences;
        }
        this.outcomeList = di.getOutcomeList();
        this.predLabels = di.getPredLabels();
        this.pmap = new HashMap<String, Integer>();
        for (int i = 0; i < this.predLabels.length; ++i) {
            this.pmap.put(this.predLabels[i], i);
        }
        logger.info("Incorporating indexed data for training... ");
        this.useAverage = useAverage;
        this.numEvents = di.getNumEvents();
        this.iterations = iterations;
        this.outcomeLabels = di.getOutcomeLabels();
        this.omap = new HashMap<String, Integer>();
        for (int oli = 0; oli < this.outcomeLabels.length; ++oli) {
            this.omap.put(this.outcomeLabels[oli], oli);
        }
        this.outcomeList = di.getOutcomeList();
        this.numPreds = this.predLabels.length;
        this.numOutcomes = this.outcomeLabels.length;
        if (useAverage) {
            this.updates = new int[this.numPreds][this.numOutcomes][3];
        }
        logger.info("done.");
        logger.info("\tNumber of Event Tokens: {} \n\t Number of Outcomes: {} \n\t Number of Predicates: {}", this.numEvents, this.numOutcomes, this.numPreds);
        this.params = new MutableContext[this.numPreds];
        if (useAverage) {
            this.averageParams = new MutableContext[this.numPreds];
        }
        int[] allOutcomesPattern = new int[this.numOutcomes];
        for (int oi = 0; oi < this.numOutcomes; ++oi) {
            allOutcomesPattern[oi] = oi;
        }
        for (int pi = 0; pi < this.numPreds; ++pi) {
            this.params[pi] = new MutableContext(allOutcomesPattern, new double[this.numOutcomes]);
            if (useAverage) {
                this.averageParams[pi] = new MutableContext(allOutcomesPattern, new double[this.numOutcomes]);
            }
            for (int aoi = 0; aoi < this.numOutcomes; ++aoi) {
                this.params[pi].setParameter(aoi, 0.0);
                if (!useAverage) continue;
                this.averageParams[pi].setParameter(aoi, 0.0);
            }
        }
        logger.info("Computing model parameters...");
        this.findParameters(iterations);
        logger.info("...done.");
        String[] updatedPredLabels = this.predLabels;
        if (useAverage) {
            return new PerceptronModel(this.averageParams, updatedPredLabels, this.outcomeLabels);
        }
        return new PerceptronModel(this.params, updatedPredLabels, this.outcomeLabels);
    }

    private void findParameters(int iterations) throws IOException {
        logger.info("Performing {} iterations.\n", (Object)iterations);
        for (int i = 1; i <= iterations; ++i) {
            this.nextIteration(i);
        }
        if (this.useAverage) {
            this.trainingStats(this.averageParams);
        } else {
            this.trainingStats(this.params);
        }
    }

    public void nextIteration(int iteration) throws IOException {
        Sequence sequence;
        --iteration;
        int numCorrect = 0;
        int oei = 0;
        int si = 0;
        ArrayList featureCounts = new ArrayList(this.numOutcomes);
        for (int oi = 0; oi < this.numOutcomes; ++oi) {
            featureCounts.add(new HashMap());
        }
        PerceptronModel model = new PerceptronModel(this.params, this.predLabels, this.outcomeLabels);
        this.sequenceStream.reset();
        while ((sequence = (Sequence)this.sequenceStream.read()) != null) {
            Event[] taggerEvents = this.sequenceStream.updateContext(sequence, model);
            Event[] events = sequence.getEvents();
            boolean update = false;
            int ei = 0;
            while (ei < events.length) {
                if (!taggerEvents[ei].getOutcome().equals(events[ei].getOutcome())) {
                    update = true;
                } else {
                    ++numCorrect;
                }
                ++ei;
                ++oei;
            }
            if (update) {
                int oi;
                Object[] contextStrings;
                int oi2;
                for (oi2 = 0; oi2 < this.numOutcomes; ++oi2) {
                    ((Map)featureCounts.get(oi2)).clear();
                }
                if (logger.isTraceEnabled()) {
                    StringBuilder sb = new StringBuilder();
                    for (Event event : events) {
                        sb.append(" ").append(event.getOutcome());
                    }
                    logger.trace("train: {}", (Object)sb);
                }
                ei = 0;
                while (ei < events.length) {
                    contextStrings = events[ei].getContext();
                    float[] values2 = events[ei].getValues();
                    oi = this.omap.get(events[ei].getOutcome());
                    for (int ci = 0; ci < contextStrings.length; ++ci) {
                        Float c;
                        float value = 1.0f;
                        if (values2 != null) {
                            value = values2[ci];
                        }
                        c = (c = (Float)((Map)featureCounts.get(oi)).get(contextStrings[ci])) == null ? Float.valueOf(value) : Float.valueOf(c.floatValue() + value);
                        ((Map)featureCounts.get(oi)).put(contextStrings[ci], c);
                    }
                    ++ei;
                    ++oei;
                }
                if (logger.isTraceEnabled()) {
                    StringBuilder sb = new StringBuilder();
                    contextStrings = taggerEvents;
                    int values2 = contextStrings.length;
                    for (oi = 0; oi < values2; ++oi) {
                        Object taggerEvent = contextStrings[oi];
                        sb.append(" ").append(((Event)taggerEvent).getOutcome());
                    }
                    logger.trace("test: {}", (Object)sb);
                }
                for (Event taggerEvent : taggerEvents) {
                    String[] contextStrings2 = taggerEvent.getContext();
                    float[] values3 = taggerEvent.getValues();
                    int oi3 = this.omap.get(taggerEvent.getOutcome());
                    for (int ci = 0; ci < contextStrings2.length; ++ci) {
                        Float c;
                        float value = 1.0f;
                        if (values3 != null) {
                            value = values3[ci];
                        }
                        if ((c = (c = (Float)((Map)featureCounts.get(oi3)).get(contextStrings2[ci])) == null ? Float.valueOf(-1.0f * value) : Float.valueOf(c.floatValue() - value)).floatValue() == 0.0f) {
                            ((Map)featureCounts.get(oi3)).remove(contextStrings2[ci]);
                            continue;
                        }
                        ((Map)featureCounts.get(oi3)).put(contextStrings2[ci], c);
                    }
                }
                for (oi2 = 0; oi2 < this.numOutcomes; ++oi2) {
                    for (String feature : ((Map)featureCounts.get(oi2)).keySet()) {
                        int pi = this.pmap.getOrDefault(feature, -1);
                        if (pi == -1) continue;
                        if (logger.isTraceEnabled()) {
                            logger.trace("{} {} {} {}", si, this.outcomeLabels[oi2], feature, ((Map)featureCounts.get(oi2)).get(feature));
                        }
                        this.params[pi].updateParameter(oi2, ((Float)((Map)featureCounts.get(oi2)).get(feature)).floatValue());
                        if (!this.useAverage) continue;
                        if (this.updates[pi][oi2][0] != 0) {
                            this.averageParams[pi].updateParameter(oi2, this.updates[pi][oi2][0] * (this.numSequences * (iteration - this.updates[pi][oi2][1]) + (si - this.updates[pi][oi2][2])));
                            if (logger.isTraceEnabled()) {
                                logger.trace("p avp[{}].{}={}", pi, oi2, this.averageParams[pi].getParameters()[oi2]);
                            }
                        }
                        if (logger.isTraceEnabled()) {
                            logger.trace("p updates[{}]{{}]=({},{},{})({},{},{}) -> {}", pi, oi2, this.updates[pi][oi2][1], this.updates[pi][oi2][2], this.updates[pi][oi2][0], iteration, oei, this.params[pi].getParameters()[oi2], this.averageParams[pi].getParameters()[oi2]);
                        }
                        this.updates[pi][oi2][0] = (int)this.params[pi].getParameters()[oi2];
                        this.updates[pi][oi2][1] = iteration;
                        this.updates[pi][oi2][2] = si;
                    }
                }
                model = new PerceptronModel(this.params, this.predLabels, this.outcomeLabels);
            }
            ++si;
        }
        double totIterations = (double)this.iterations * (double)si;
        if (this.useAverage && iteration == this.iterations - 1) {
            for (int pi = 0; pi < this.numPreds; ++pi) {
                double[] predParams = this.averageParams[pi].getParameters();
                for (int oi = 0; oi < this.numOutcomes; ++oi) {
                    if (this.updates[pi][oi][0] != 0) {
                        int n = oi;
                        predParams[n] = predParams[n] + (double)(this.updates[pi][oi][0] * (this.numSequences * (this.iterations - this.updates[pi][oi][1]) - this.updates[pi][oi][2]));
                    }
                    if (predParams[oi] == 0.0) continue;
                    int n = oi;
                    predParams[n] = predParams[n] / totIterations;
                    this.averageParams[pi].setParameter(oi, predParams[oi]);
                    if (!logger.isTraceEnabled()) continue;
                    logger.trace("updates[{}][{}]=({},{},{})({},{},{}) -> {}", pi, oi, this.updates[pi][oi][1], this.updates[pi][oi][2], this.updates[pi][oi][0], this.iterations, 0, this.params[pi].getParameters()[oi], this.averageParams[pi].getParameters()[oi]);
                }
            }
        }
        logger.info("{}. ({}/{}) {}", iteration, numCorrect, this.numEvents, (double)numCorrect / (double)this.numEvents);
    }

    private void trainingStats(MutableContext[] params) throws IOException {
        Sequence sequence;
        int numCorrect = 0;
        int oei = 0;
        this.sequenceStream.reset();
        while ((sequence = (Sequence)this.sequenceStream.read()) != null) {
            Event[] taggerEvents = this.sequenceStream.updateContext(sequence, new PerceptronModel(params, this.predLabels, this.outcomeLabels));
            int ei = 0;
            while (ei < taggerEvents.length) {
                int max = this.omap.get(taggerEvents[ei].getOutcome());
                if (max == this.outcomeList[oei]) {
                    ++numCorrect;
                }
                ++ei;
                ++oei;
            }
        }
        logger.info(". ({}/{}) {}", numCorrect, this.numEvents, (double)numCorrect / (double)this.numEvents);
    }
}

