/*
 * Decompiled with CFR 0.152.
 */
package opennlp.tools.postag;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import opennlp.tools.dictionary.Dictionary;
import opennlp.tools.ml.BeamSearch;
import opennlp.tools.ml.EventModelSequenceTrainer;
import opennlp.tools.ml.EventTrainer;
import opennlp.tools.ml.SequenceTrainer;
import opennlp.tools.ml.TrainerFactory;
import opennlp.tools.ml.model.MaxentModel;
import opennlp.tools.ml.model.SequenceClassificationModel;
import opennlp.tools.ngram.NGramModel;
import opennlp.tools.postag.MutableTagDictionary;
import opennlp.tools.postag.POSContextGenerator;
import opennlp.tools.postag.POSModel;
import opennlp.tools.postag.POSSample;
import opennlp.tools.postag.POSSampleEventStream;
import opennlp.tools.postag.POSSampleSequenceStream;
import opennlp.tools.postag.POSTagFormat;
import opennlp.tools.postag.POSTagFormatMapper;
import opennlp.tools.postag.POSTagger;
import opennlp.tools.postag.POSTaggerFactory;
import opennlp.tools.postag.TagDictionary;
import opennlp.tools.util.DownloadUtil;
import opennlp.tools.util.ObjectStream;
import opennlp.tools.util.Sequence;
import opennlp.tools.util.SequenceValidator;
import opennlp.tools.util.StringList;
import opennlp.tools.util.StringUtil;
import opennlp.tools.util.TrainingParameters;
import opennlp.tools.util.featuregen.StringPattern;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A part-of-speech tagger backed by a {@link SequenceClassificationModel}.
 *
 * <p>Tags produced by the underlying model are passed through a
 * {@link POSTagFormatMapper} unless the requested {@link POSTagFormat} is
 * {@link POSTagFormat#CUSTOM} or already matches the format the mapper guessed
 * for the model.
 *
 * <p>{@link #tag(String[], Object[])} stores the most recently decoded sequence
 * in an instance field that {@link #probs()} and {@link #probs(double[])} read
 * afterwards. Because of that shared mutable state, a single instance should
 * not be used from several threads without external synchronization.
 */
public class POSTaggerME implements POSTagger {
    private static final Logger logger = LoggerFactory.getLogger(POSTaggerME.class);

    /** Beam size used when neither the model manifest nor the training parameters specify one. */
    public static final int DEFAULT_BEAM_SIZE = 3;

    /** Manifest property / training parameter name that carries the beam size. */
    private static final String BEAM_SIZE_PARAMETER = "BeamSize";

    /** Artifact name under which the event (maxent) model is looked up in a {@link POSModel}. */
    private static final String POS_MODEL_ENTRY_NAME = "pos.model";

    private final POSModel modelPackage;
    protected final POSContextGenerator cg;
    protected final TagDictionary tagDictionary;
    protected final int size;
    /** Result of the most recent {@link #tag(String[], Object[])} call; {@code null} until then. */
    private Sequence bestSequence;
    private final SequenceClassificationModel model;
    private final SequenceValidator<String> sequenceValidator;
    private final POSTagFormat posTagFormat;
    protected final POSTagFormatMapper posTagFormatMapper;

    /**
     * Creates a tagger for {@code language} that emits {@link POSTagFormat#UD} tags.
     *
     * @param language the language passed to {@link DownloadUtil#downloadModel}
     * @throws IOException if obtaining the model fails
     */
    public POSTaggerME(String language) throws IOException {
        this(language, POSTagFormat.UD);
    }

    /**
     * Creates a tagger for {@code language} using a model obtained through
     * {@link DownloadUtil#downloadModel}.
     *
     * @param language the language passed to {@link DownloadUtil#downloadModel}
     * @param format the tag format the tagger should emit
     * @throws IOException if obtaining the model fails
     */
    public POSTaggerME(String language, POSTagFormat format) throws IOException {
        this(DownloadUtil.downloadModel(language, DownloadUtil.ModelType.POS, POSModel.class), format);
    }

    /**
     * Creates a tagger from {@code model} that emits {@link POSTagFormat#UD} tags.
     *
     * @param model the model to tag with
     */
    public POSTaggerME(POSModel model) {
        this(model, POSTagFormat.UD);
    }

    /**
     * Creates a tagger from {@code model}.
     *
     * <p>The beam size is read from the model's {@code "BeamSize"} manifest property and
     * falls back to {@link #DEFAULT_BEAM_SIZE} when that property is absent. If the model
     * carries no sequence model, a {@link BeamSearch} over its {@code "pos.model"} artifact
     * is used instead.
     *
     * @param model the model to tag with
     * @param format the tag format the tagger should emit
     * @throws NumberFormatException if the {@code "BeamSize"} manifest property is present
     *     but is not a parsable integer
     */
    public POSTaggerME(POSModel model, POSTagFormat format) {
        this.posTagFormat = format;
        POSTaggerFactory factory = model.getFactory();

        int beamSize = DEFAULT_BEAM_SIZE;
        String beamSizeString = model.getManifestProperty(BEAM_SIZE_PARAMETER);
        if (beamSizeString != null) {
            beamSize = Integer.parseInt(beamSizeString);
        }

        this.modelPackage = model;
        this.cg = factory.getPOSContextGenerator(beamSize);
        this.tagDictionary = factory.getTagDictionary();
        this.size = beamSize;
        this.sequenceValidator = factory.getSequenceValidator();

        SequenceClassificationModel sequenceModel = model.getPosSequenceModel();
        if (sequenceModel != null) {
            this.model = sequenceModel;
        } else {
            this.model = new BeamSearch(beamSize, (MaxentModel) model.getArtifact(POS_MODEL_ENTRY_NAME), 0);
        }

        // this.model must be assigned before this point: getAllPosTags() reads it.
        if (format == POSTagFormat.CUSTOM) {
            this.posTagFormatMapper = new POSTagFormatMapper.NoOp();
        } else {
            this.posTagFormatMapper = new POSTagFormatMapper(this.getAllPosTags());
        }
    }

    /**
     * @return the outcomes reported by the underlying sequence model
     */
    public String[] getAllPosTags() {
        return this.model.getOutcomes();
    }

    @Override
    public String[] tag(String[] sentence) {
        return this.tag(sentence, null);
    }

    /**
     * Tags {@code sentence} and remembers the decoded sequence so that
     * {@link #probs()} can report on it afterwards.
     *
     * @param sentence the tokens to tag
     * @param additionalContext extra context handed to the model; may be {@code null}
     * @return one tag per token, converted to the configured tag format
     */
    @Override
    public String[] tag(String[] sentence, Object[] additionalContext) {
        this.bestSequence = this.model.bestSequence(sentence, additionalContext, this.cg, this.sequenceValidator);
        return this.convertTags(this.bestSequence.getOutcomes());
    }

    /**
     * Returns up to {@code numTaggings} alternative taggings of {@code sentence}.
     * This method does not update the sequence reported by {@link #probs()}.
     *
     * @param numTaggings the number of sequences requested from the model
     * @param sentence the tokens to tag
     * @return one tag array per sequence returned by the model
     */
    public String[][] tag(int numTaggings, String[] sentence) {
        Sequence[] bestSequences =
            this.model.bestSequences(numTaggings, sentence, null, this.cg, this.sequenceValidator);
        String[][] tags = new String[bestSequences.length][];
        for (int si = 0; si < tags.length; si++) {
            tags[si] = this.convertTags(bestSequences[si].getOutcomes());
        }
        return tags;
    }

    /**
     * Converts model outcomes to the configured tag format; outcomes are returned
     * unchanged when the format is CUSTOM or already matches the mapper's guess.
     */
    private String[] convertTags(List<String> t2) {
        boolean alreadyInTargetFormat = this.posTagFormat == POSTagFormat.CUSTOM
            || this.posTagFormatMapper.getGuessedFormat() == this.posTagFormat;
        if (alreadyInTargetFormat) {
            return t2.toArray(new String[0]);
        }
        return this.posTagFormatMapper.convertTags(t2);
    }

    @Override
    public Sequence[] topKSequences(String[] sentence) {
        return this.topKSequences(sentence, null);
    }

    /**
     * Requests as many sequences as the configured beam size ({@link #size}).
     * The returned sequences carry raw model outcomes; no tag format conversion is applied.
     */
    @Override
    public Sequence[] topKSequences(String[] sentence, Object[] additionalContext) {
        return this.model.bestSequences(this.size, sentence, additionalContext, this.cg, this.sequenceValidator);
    }

    /**
     * Copies the probabilities of the most recently tagged sequence into {@code probs}.
     *
     * @param probs receives the probabilities
     * @throws NullPointerException if {@link #tag(String[], Object[])} has not been called yet
     */
    public void probs(double[] probs) {
        this.bestSequence.getProbs(probs);
    }

    /**
     * @return the probabilities of the most recently tagged sequence
     * @throws NullPointerException if {@link #tag(String[], Object[])} has not been called yet
     */
    public double[] probs() {
        return this.bestSequence.getProbs();
    }

    /**
     * Same as {@link #getOrderedTags(List, List, int, double[])} without collecting probabilities.
     */
    public String[] getOrderedTags(List<String> words, List<String> tags, int index) {
        return this.getOrderedTags(words, tags, index, null);
    }

    /**
     * Evaluates the event model at position {@code index} and returns every outcome exactly
     * once, ordered by descending probability. Outcomes with equal probability keep the
     * model's outcome-index order.
     *
     * @param words the tokens of the sentence
     * @param tags the tags handed to the context generator
     * @param index the token position to evaluate
     * @param tprobs if non-null, receives the probability of each returned tag at the same
     *     position; it must be at least as long as the number of model outcomes
     * @return the outcomes, most probable first, converted to the configured tag format
     * @throws UnsupportedOperationException if the model package has no {@code "pos.model"} artifact
     */
    public String[] getOrderedTags(List<String> words, List<String> tags, int index, double[] tprobs) {
        MaxentModel posModel = (MaxentModel) this.modelPackage.getArtifact(POS_MODEL_ENTRY_NAME);
        if (posModel == null) {
            throw new UnsupportedOperationException("This method can only be called if the classification model is an event model!");
        }

        double[] probs = posModel.eval(this.cg.getContext(
            index, words.toArray(new String[0]), tags.toArray(new String[0]), (Object[]) null));
        String[] orderedTags = new String[probs.length];

        // Emitted outcomes are tracked explicitly rather than by overwriting their
        // probability with 0.0: a 0.0 sentinel is indistinguishable from a genuine
        // zero probability and made the scan re-select an already emitted outcome.
        boolean[] emitted = new boolean[probs.length];
        for (int rank = 0; rank < probs.length; rank++) {
            int best = -1;
            for (int ti = 0; ti < probs.length; ti++) {
                if (emitted[ti]) {
                    continue;
                }
                if (best < 0 || probs[ti] > probs[best]) {
                    best = ti;
                }
            }
            emitted[best] = true;
            orderedTags[rank] = posModel.getOutcome(best);
            if (tprobs != null) {
                tprobs[rank] = probs[best];
            }
        }
        return this.convertTags(Arrays.stream(orderedTags).toList());
    }

    /**
     * Trains a {@link POSModel} with the trainer selected by {@code mlParams}.
     *
     * @param languageCode the language code stored in the resulting model
     * @param samples the training samples
     * @param mlParams the training parameters; {@code "BeamSize"} defaults to
     *     {@link #DEFAULT_BEAM_SIZE} and is only passed on when an event model is produced
     * @param posFactory the factory supplying the context generator
     * @return the trained model
     * @throws IOException if reading the samples or training fails
     * @throws IllegalArgumentException if the trainer type is not one of the supported kinds
     */
    public static POSModel train(String languageCode, ObjectStream<POSSample> samples, TrainingParameters mlParams, POSTaggerFactory posFactory) throws IOException {
        int beamSize = mlParams.getIntParameter(BEAM_SIZE_PARAMETER, DEFAULT_BEAM_SIZE);
        POSContextGenerator contextGenerator = posFactory.getPOSContextGenerator();
        TrainerFactory.TrainerType trainerType = TrainerFactory.getTrainerType(mlParams);
        Map<String, String> manifestInfoEntries = new HashMap<>();

        MaxentModel posModel = null;
        SequenceClassificationModel seqPosModel = null;
        if (TrainerFactory.TrainerType.EVENT_MODEL_TRAINER.equals(trainerType)) {
            POSSampleEventStream es = new POSSampleEventStream(samples, contextGenerator);
            EventTrainer trainer = TrainerFactory.getEventTrainer(mlParams, manifestInfoEntries);
            posModel = trainer.train(es);
        } else if (TrainerFactory.TrainerType.EVENT_MODEL_SEQUENCE_TRAINER.equals(trainerType)) {
            POSSampleSequenceStream ss = new POSSampleSequenceStream(samples, contextGenerator);
            EventModelSequenceTrainer<POSSample> trainer =
                TrainerFactory.getEventModelSequenceTrainer(mlParams, manifestInfoEntries);
            posModel = trainer.train(ss);
        } else if (TrainerFactory.TrainerType.SEQUENCE_TRAINER.equals(trainerType)) {
            SequenceTrainer trainer = TrainerFactory.getSequenceModelTrainer(mlParams, manifestInfoEntries);
            POSSampleSequenceStream ss = new POSSampleSequenceStream(samples, contextGenerator);
            seqPosModel = trainer.train(ss);
        } else {
            throw new IllegalArgumentException("Trainer type is not supported: " + trainerType);
        }

        if (posModel != null) {
            return new POSModel(languageCode, posModel, beamSize, manifestInfoEntries, posFactory);
        }
        return new POSModel(languageCode, seqPosModel, manifestInfoEntries, posFactory);
    }

    /**
     * Builds a dictionary from the sentences in {@code samples}; empty sentences are skipped.
     *
     * @param samples the samples whose sentences are added to the n-gram model
     * @param cutoff the lower bound passed to {@link NGramModel#cutoff}
     * @return the dictionary produced by {@code NGramModel.toDictionary(true)}
     * @throws IOException if reading the samples fails
     */
    public static Dictionary buildNGramDictionary(ObjectStream<POSSample> samples, int cutoff) throws IOException {
        NGramModel ngramModel = new NGramModel();
        POSSample sample;
        while ((sample = samples.read()) != null) {
            String[] words = sample.getSentence();
            if (words.length > 0) {
                ngramModel.add(new StringList(words), 1, 1);
            }
        }
        ngramModel.cutoff(cutoff, Integer.MAX_VALUE);
        return ngramModel.toDictionary(true);
    }

    /**
     * Adds to {@code dict} every (word, tag) pair whose count reaches {@code cutoff}.
     *
     * <p>Words containing a digit are ignored. Words are lower-cased when the dictionary is
     * not case sensitive. Tags the dictionary already lists for a word are seeded with a
     * count of {@code cutoff}, so they are always kept when that word is rewritten.
     *
     * @param samples the samples to count (word, tag) pairs from
     * @param dict the dictionary to extend in place
     * @param cutoff the minimum count a tag needs to be stored for a word
     * @throws IOException if reading the samples fails
     */
    public static void populatePOSDictionary(ObjectStream<POSSample> samples, MutableTagDictionary dict, int cutoff) throws IOException {
        logger.info("Expanding POS Dictionary ...");
        long start = System.nanoTime();

        // word -> (tag -> occurrence count)
        Map<String, Map<String, AtomicInteger>> newEntries = new HashMap<>();
        POSSample sample;
        while ((sample = samples.read()) != null) {
            String[] words = sample.getSentence();
            String[] tags = sample.getTags();
            for (int i = 0; i < words.length; i++) {
                if (StringPattern.recognize(words[i]).containsDigit()) {
                    continue;
                }
                String word = dict.isCaseSensitive() ? words[i] : StringUtil.toLowerCase(words[i]);
                Map<String, AtomicInteger> tagCounts = newEntries.computeIfAbsent(word, k -> new HashMap<>());

                // Seed tags already known to the dictionary so they survive the cutoff filter.
                String[] dictTags = dict.getTags(word);
                if (dictTags != null) {
                    for (String tag : dictTags) {
                        tagCounts.putIfAbsent(tag, new AtomicInteger(cutoff));
                    }
                }

                tagCounts.computeIfAbsent(tags[i], k -> new AtomicInteger()).incrementAndGet();
            }
        }

        for (Map.Entry<String, Map<String, AtomicInteger>> wordEntry : newEntries.entrySet()) {
            List<String> tagsForWord = new ArrayList<>();
            for (Map.Entry<String, AtomicInteger> tagEntry : wordEntry.getValue().entrySet()) {
                if (tagEntry.getValue().get() >= cutoff) {
                    tagsForWord.add(tagEntry.getKey());
                }
            }
            if (!tagsForWord.isEmpty()) {
                dict.put(wordEntry.getKey(), tagsForWord.toArray(new String[0]));
            }
        }

        logger.info("... finished expanding POS Dictionary. [ {} ms]", (System.nanoTime() - start) / 1000000L);
    }
}

