package org.dkpro.tc.ml.vowpalwabbit.serialization;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.OptionalInt;
import java.util.Properties;
import org.apache.commons.io.FileUtils;
import org.apache.uima.UimaContext;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.fit.descriptor.ConfigurationParameter;
import org.apache.uima.fit.descriptor.ExternalResource;
import org.apache.uima.fit.util.JCasUtil;
import org.apache.uima.jcas.JCas;
import org.apache.uima.pear.util.FileUtil;
import org.apache.uima.resource.ResourceInitializationException;
import org.dkpro.tc.api.features.Feature;
import org.dkpro.tc.api.features.FeatureExtractorResource_ImplBase;
import org.dkpro.tc.api.features.FeatureType;
import org.dkpro.tc.api.features.Instance;
import org.dkpro.tc.api.type.TextClassificationOutcome;
import org.dkpro.tc.core.ml.ModelSerialization_ImplBase;
import org.dkpro.tc.core.task.uima.InstanceExtractor;
import org.dkpro.tc.ml.vowpalwabbit.core.VowpalWabbitPredictor;
import org.dkpro.tc.ml.vowpalwabbit.writer.VowpalWabbitDataWriter;

/* loaded from: input_file:org/dkpro/tc/ml/vowpalwabbit/serialization/VowpalWabbitLoadModelConnector.class */
/**
 * Annotator that loads a stored model from {@link #tcModelLocation} and uses it to set the
 * outcome of every {@link TextClassificationOutcome} annotation in a CAS.
 *
 * <p>During {@link #initialize(UimaContext)} the following entries are read from the model
 * folder: {@code classifier.ser}, {@code featureMode.txt}, {@code learningMode.txt} and the two
 * mapping files named by {@link VowpalWabbitDataWriter#OUTCOME_MAPPING} and
 * {@link VowpalWabbitDataWriter#STRING_MAPPING}.
 */
public class VowpalWabbitLoadModelConnector extends ModelSerialization_ImplBase {

    @ConfigurationParameter(name = "tcModel", mandatory = true)
    protected File tcModelLocation;

    @ExternalResource(key = "featureExtractors", mandatory = true)
    protected FeatureExtractorResource_ImplBase[] featureExtractors;
    protected String featureMode;
    protected Map<String, String> integer2OutcomeMapping;
    protected Map<String, String> stringValue2IntegerMapping;
    protected String learningMode;
    protected File model = null;
    Integer maxStringId = -1;

    /**
     * Reads the model file location, the feature/learning mode and both mapping files from
     * {@link #tcModelLocation}, then verifies the model version.
     *
     * @param uimaContext the UIMA context passed on to the superclass
     * @throws ResourceInitializationException wrapping any exception raised while loading
     */
    public void initialize(UimaContext uimaContext) throws ResourceInitializationException {
        super.initialize(uimaContext);
        try {
            this.model = new File(this.tcModelLocation, "classifier.ser");
            this.featureMode = loadProperty(new File(this.tcModelLocation, "featureMode.txt"), "featureMode");
            // learningMode must be known before loadMapping(), which consults isRegression().
            this.learningMode = loadProperty(new File(this.tcModelLocation, "learningMode.txt"), "learningMode");
            this.integer2OutcomeMapping = loadMapping(this.tcModelLocation, VowpalWabbitDataWriter.OUTCOME_MAPPING);
            this.stringValue2IntegerMapping = loadMapping(this.tcModelLocation, VowpalWabbitDataWriter.STRING_MAPPING);
            determineTheMaxStringsIntIdValue();
            verifyTcVersion(this.tcModelLocation, getClass());
        } catch (Exception e) {
            throw new ResourceInitializationException(e);
        }
    }

    /**
     * Sets {@link #maxStringId} to the largest key of {@link #stringValue2IntegerMapping},
     * parsing every key as an integer. Leaves the field untouched when the mapping is empty.
     *
     * @throws NumberFormatException if a key is not a parsable integer
     */
    protected void determineTheMaxStringsIntIdValue() {
        OptionalInt largestId = this.stringValue2IntegerMapping.keySet().stream()
                .mapToInt(Integer::parseInt)
                .max();
        if (largestId.isPresent()) {
            this.maxStringId = Integer.valueOf(largestId.getAsInt());
        }
    }

    /**
     * Loads a single value from a Java properties file.
     *
     * @param file the properties file to read
     * @param str the property key to look up
     * @return the value stored under {@code str}
     * @throws IOException if the file cannot be opened or read
     * @throws IllegalStateException if the file does not contain {@code str}
     */
    protected String loadProperty(File file, String str) throws IOException {
        Properties properties = new Properties();
        try (FileInputStream in = new FileInputStream(file)) {
            properties.load(in);
        }
        String value = properties.getProperty(str);
        if (value == null) {
            throw new IllegalStateException("Could not load [" + str + "] from file [" + file + "]");
        }
        return value;
    }

    /**
     * Reads a two-column, tab-separated mapping file. For every line the second column becomes
     * the map key and the first column the map value. In regression mode no file is read and an
     * empty map is returned. Blank lines are skipped.
     *
     * @param file the folder containing the mapping file
     * @param str the name of the mapping file inside {@code file}
     * @return a mutable map from second column to first column
     * @throws IOException if the mapping file cannot be read
     * @throws IllegalStateException if a non-blank line has fewer than two columns
     */
    protected Map<String, String> loadMapping(File file, String str) throws IOException {
        Map<String, String> mapping = new HashMap<>();
        if (isRegression()) {
            return mapping;
        }
        File mappingFile = new File(file, str);
        for (String line : FileUtils.readLines(mappingFile, StandardCharsets.UTF_8)) {
            if (line.trim().isEmpty()) {
                // e.g. a trailing newline at the end of the file
                continue;
            }
            String[] columns = line.split("\t");
            if (columns.length < 2) {
                throw new IllegalStateException(
                        "Malformed mapping line [" + line + "] in file [" + mappingFile + "]");
            }
            mapping.put(columns[1], columns[0]);
        }
        return mapping;
    }

    /**
     * @return {@code true} if {@link #learningMode} equals {@code "regression"}
     */
    protected boolean isRegression() {
        return "regression".equals(this.learningMode);
    }

    /**
     * Writes the instances of the CAS to a temporary input file, runs the predictor on it and
     * assigns the i-th prediction to the i-th {@link TextClassificationOutcome}. In regression
     * mode the raw prediction is used; otherwise it is translated via
     * {@link #integer2OutcomeMapping}.
     *
     * @param jCas the CAS whose outcome annotations are updated
     * @throws AnalysisEngineProcessException if writing the input or predicting fails, or if
     *         fewer predictions than outcome annotations are available
     */
    public void process(JCas jCas) throws AnalysisEngineProcessException {
        boolean sequence = isSequence();
        File inputFile = createInputFile(jCas, sequence);
        List<String> predictions;
        try {
            predictions = runPrediction(inputFile, sequence);
        } finally {
            // The input is only needed for this one prediction run; relying solely on
            // deleteOnExit() would let files accumulate for the lifetime of the JVM.
            FileUtils.deleteQuietly(inputFile);
        }

        List<TextClassificationOutcome> outcomes = getOutcomeAnnotations(jCas);
        if (predictions.size() < outcomes.size()) {
            throw new AnalysisEngineProcessException(new IllegalStateException(
                    "Received [" + predictions.size() + "] predictions for [" + outcomes.size()
                            + "] outcome annotations"));
        }

        boolean regression = isRegression();
        for (int i = 0; i < outcomes.size(); i++) {
            String prediction = predictions.get(i);
            // May set null when the prediction has no entry in the outcome mapping.
            outcomes.get(i).setOutcome(regression ? prediction : this.integer2OutcomeMapping.get(prediction));
        }
    }

    /**
     * @param jCas the CAS to query
     * @return a new list holding all {@link TextClassificationOutcome} annotations of the CAS
     */
    protected List<TextClassificationOutcome> getOutcomeAnnotations(JCas jCas) {
        return new ArrayList<>(JCasUtil.select(jCas, TextClassificationOutcome.class));
    }

    /**
     * Runs {@link VowpalWabbitPredictor} on the given input file using {@link #model}.
     *
     * @param file the input file to predict on
     * @param z whether sequence mode is active; if so, every returned line is split at single
     *        spaces and the pieces are concatenated into one flat list
     * @return the predictions, flattened in sequence mode
     * @throws AnalysisEngineProcessException wrapping any exception raised by the predictor
     */
    protected List<String> runPrediction(File file, boolean z) throws AnalysisEngineProcessException {
        try {
            List<String> predictions = new VowpalWabbitPredictor().predict(file, this.model);
            if (!z) {
                return predictions;
            }
            List<String> flattened = new ArrayList<>();
            for (String line : predictions) {
                flattened.addAll(Arrays.asList(line.split(" ")));
            }
            return flattened;
        } catch (Exception e) {
            throw new AnalysisEngineProcessException(e);
        }
    }

    /**
     * Extracts the instances of the CAS and writes them, one per line, to a new temporary UTF-8
     * file. Each line has the form {@code "| name:value name:value ..."}, with feature values
     * passed through {@link #mapStringValues(FeatureType, String)}.
     *
     * @param jCas the CAS to extract instances from
     * @param z whether sequence mode is active; if so, instances are sorted by sequence id
     *        before they are written
     * @return the temporary file (also registered for deletion on JVM exit)
     * @throws AnalysisEngineProcessException wrapping any exception raised while extracting or
     *         writing
     */
    protected File createInputFile(JCas jCas, boolean z) throws AnalysisEngineProcessException {
        try {
            File inputFile = FileUtil.createTempFile("vowpalWabbit" + System.currentTimeMillis(), ".txt");
            inputFile.deleteOnExit();
            try (BufferedWriter writer = new BufferedWriter(
                    new OutputStreamWriter(new FileOutputStream(inputFile), StandardCharsets.UTF_8))) {
                List<Instance> instances =
                        new InstanceExtractor(this.featureMode, this.featureExtractors, false).getInstances(jCas, true);
                if (z) {
                    Collections.sort(instances, Comparator.comparingInt(Instance::getSequenceId));
                }
                // NOTE(review): the previous code guarded an extra "\n" with a condition that was
                // constant-false, so no blank line was ever written between instances of
                // different sequences. That output is kept unchanged here; confirm against the
                // training-side data writer whether a sequence separator is actually required.
                for (Instance instance : instances) {
                    writer.write("|");
                    for (Feature feature : instance.getFeatures()) {
                        writer.write(" ");
                        writer.write(feature.getName() + ":"
                                + mapStringValues(feature.getType(), feature.getValue().toString()));
                    }
                    writer.write("\n");
                }
            }
            return inputFile;
        } catch (Exception e) {
            throw new AnalysisEngineProcessException(e);
        }
    }

    /**
     * Maps a feature value for writing to the input file. Values of features that are neither
     * {@link FeatureType#STRING} nor {@link FeatureType#NOMINAL}, and values that are present as
     * a key in {@link #stringValue2IntegerMapping}, are returned unchanged. Any other value
     * increments {@link #maxStringId} and is replaced by the new id.
     *
     * <p>NOTE(review): a known value is returned as-is rather than replaced by its mapped entry,
     * and a newly assigned id is not stored, so the same unknown value receives a different id on
     * every call. Both look unintended; confirm against the mapping file format before changing.
     *
     * @param featureType the type of the feature the value belongs to
     * @param str the feature value as string
     * @return {@code str} itself or a freshly assigned integer id as string
     */
    protected String mapStringValues(FeatureType featureType, String str) {
        boolean stringLike = featureType == FeatureType.STRING || featureType == FeatureType.NOMINAL;
        if (!stringLike || this.stringValue2IntegerMapping.get(str) != null) {
            return str;
        }
        this.maxStringId = Integer.valueOf(this.maxStringId.intValue() + 1);
        return this.maxStringId.toString();
    }

    /**
     * @return {@code true} if {@link #featureMode} equals {@code "sequence"}
     */
    protected boolean isSequence() {
        return "sequence".equals(this.featureMode);
    }
}
