package org.jpmml.sparkml.testing;

import com.google.common.base.Equivalence;
import com.google.common.io.ByteStreams;
import com.google.common.io.MoreFiles;
import com.google.common.io.RecursiveDeleteOption;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.StructType;
import org.dmg.pmml.PMML;
import org.jpmml.converter.testing.ModelEncoderBatch;
import org.jpmml.evaluator.ResultField;
import org.jpmml.evaluator.testing.PMMLEquivalence;
import org.jpmml.sparkml.ArchiveUtil;
import org.jpmml.sparkml.DatasetUtil;
import org.jpmml.sparkml.PMMLBuilder;
import org.jpmml.sparkml.PipelineModelUtil;
import org.jpmml.sparkml.model.HasRegressionTableOptions;

/* loaded from: input_file:org/jpmml/sparkml/testing/SparkMLEncoderBatch.class */
/**
 * Base class for SparkML-to-PMML conversion test batches.
 *
 * <p>A batch reads three resources through {@code open(...)}:
 * <ul>
 *   <li>a Spark schema as JSON,</li>
 *   <li>a zipped {@link PipelineModel},</li>
 *   <li>an input CSV.</li>
 * </ul>
 * It copies each resource to a temporary file and converts the pipeline to PMML with
 * {@link PMMLBuilder}. When a verification dataset is available, it attaches that sample to the
 * builder. It then validates the result and deletes every temporary file it created, whether or
 * not the conversion succeeded.
 */
public abstract class SparkMLEncoderBatch extends ModelEncoderBatch {

    /**
     * Creates a batch. All arguments are forwarded unchanged to {@link ModelEncoderBatch}.
     *
     * @param str presumably the algorithm name (exposed as {@code getAlgorithm()}) — NOTE(review): confirm
     * @param str2 presumably the dataset name (exposed as {@code getDataset()}) — NOTE(review): confirm
     * @param predicate filter over result fields
     * @param equivalence equivalence used when comparing results
     */
    public SparkMLEncoderBatch(String str, String str2, Predicate<ResultField> predicate, Equivalence<Object> equivalence) {
        super(str, str2, predicate, equivalence);
    }

    /**
     * Returns the test that owns this batch. Its {@link SparkSession} is used for all loading.
     *
     * <p>The decompiler renamed this method from {@code getArchiveBatchTest}. The renamed
     * identifier is kept so that existing callers still resolve.
     */
    @Override
    public abstract SparkMLEncoderBatchTest mo29getArchiveBatchTest();

    /**
     * Returns the converter option sets to test with.
     *
     * @return a single option map that sets
     *     {@link HasRegressionTableOptions#OPTION_LOOKUP_THRESHOLD} to {@code 5}
     */
    public List<Map<String, Object>> getOptionsMatrix() {
        Map<String, Object> options = new LinkedHashMap<>();
        options.put(HasRegressionTableOptions.OPTION_LOOKUP_THRESHOLD, 5);

        return Collections.singletonList(options);
    }

    /** @return the resource path of the dataset schema, {@code /schema/<dataset>.json} */
    public String getSchemaJsonPath() {
        return "/schema/" + getDataset() + ".json";
    }

    /** @return the resource path of the zipped pipeline, {@code /pipeline/<algorithm><dataset>.zip} */
    public String getPipelineModelZipPath() {
        return "/pipeline/" + getAlgorithm() + getDataset() + ".zip";
    }

    /**
     * Selects the rows used to verify the generated PMML.
     *
     * <p>The default is a 5% sample without replacement. The fixed seed makes the sample
     * reproducible. Subclasses may return {@code null} to skip verification.
     *
     * @param dataset the input dataset, already cast to the (possibly updated) schema
     * @return the verification rows, or {@code null} for none
     */
    public Dataset<Row> getVerificationDataset(Dataset<Row> dataset) {
        return dataset.sample(false, 0.05d, 63317L);
    }

    /**
     * Converts the pipeline model to PMML, optionally verifies it, and validates it.
     *
     * <p>Every temporary file created along the way is deleted before this method returns or
     * throws. If the conversion fails, any cleanup failure is attached to the original exception
     * as a suppressed exception instead of replacing it.
     *
     * @return the validated PMML document
     * @throws IllegalStateException if the owning test has no {@link SparkSession}
     * @throws Exception if loading, conversion, verification, validation or cleanup fails
     */
    public PMML getPMML() throws Exception {
        SparkSession sparkSession = mo29getArchiveBatchTest().getSparkSession();
        if (sparkSession == null) {
            throw new IllegalStateException("Spark session is not available");
        }

        List<File> tmpFiles = new ArrayList<>();

        PMML pmml;
        try {
            pmml = buildPMML(sparkSession, tmpFiles);
        } catch (Throwable t) {
            // Clean up whatever was extracted before the failure without masking the original error
            try {
                deleteTmpFiles(tmpFiles);
            } catch (IOException ioe) {
                t.addSuppressed(ioe);
            }
            throw t;
        }

        deleteTmpFiles(tmpFiles);

        return pmml;
    }

    /** Loads all inputs, builds and optionally verifies the PMML, then validates it. */
    private PMML buildPMML(SparkSession sparkSession, List<File> tmpFiles) throws Exception {
        StructType schema = loadSchema(sparkSession, tmpFiles);
        PipelineModel pipelineModel = loadPipelineModel(sparkSession, tmpFiles);

        schema = updateSchema(schema, pipelineModel);

        Dataset<Row> inputDataset = DatasetUtil.castColumns(loadInput(sparkSession, tmpFiles), schema);

        PMMLBuilder pmmlBuilder = new PMMLBuilder(schema, pipelineModel)
            .putOptions(getOptions());

        Dataset<Row> verificationDataset = getVerificationDataset(inputDataset);
        if (verificationDataset != null) {
            Equivalence<Object> equivalence = getEquivalence();

            // Fallback tolerances when the configured equivalence does not carry its own
            double precision = 1e-14;
            double zeroThreshold = 1e-14;

            if (equivalence instanceof PMMLEquivalence) {
                PMMLEquivalence pmmlEquivalence = (PMMLEquivalence) equivalence;

                precision = pmmlEquivalence.getPrecision();
                zeroThreshold = pmmlEquivalence.getZeroThreshold();
            }

            pmmlBuilder = pmmlBuilder.verify(verificationDataset, precision, zeroThreshold);
        }

        PMML pmml = pmmlBuilder.build();

        validatePMML(pmml);

        return pmml;
    }

    /**
     * Attempts to delete every file or directory in {@code files} recursively.
     *
     * <p>It keeps going after a failure. The first failure is rethrown, and any later failures
     * are attached to it as suppressed exceptions.
     */
    private static void deleteTmpFiles(List<File> files) throws IOException {
        IOException failure = null;

        for (File file : files) {
            try {
                MoreFiles.deleteRecursively(file.toPath());
            } catch (IOException ioe) {
                if (failure == null) {
                    failure = ioe;
                } else {
                    failure.addSuppressed(ioe);
                }
            }
        }

        if (failure != null) {
            throw failure;
        }
    }

    /**
     * Loads the dataset schema from {@link #getSchemaJsonPath()}.
     *
     * @param sparkSession the active session (not used by this default implementation)
     * @param list receives the temporary file(s) created, for later deletion
     * @return the parsed schema
     * @throws IOException if the resource cannot be read or copied
     */
    protected StructType loadSchema(SparkSession sparkSession, List<File> list) throws IOException {
        File tmpFile = extractResource(getSchemaJsonPath(), getDataset(), ".json", list);

        return DatasetUtil.loadSchema(tmpFile);
    }

    /**
     * Loads the pipeline model from {@link #getPipelineModelZipPath()}.
     *
     * <p>The ZIP is copied to a temporary file, uncompressed, and then loaded. Both the ZIP and
     * the uncompressed location are registered for deletion.
     *
     * @param sparkSession the session used to load the model
     * @param list receives the temporary file(s) created, for later deletion
     * @return the loaded pipeline model
     * @throws IOException if the resource cannot be read, copied or uncompressed
     */
    protected PipelineModel loadPipelineModel(SparkSession sparkSession, List<File> list) throws IOException {
        File zipFile = extractResource(getPipelineModelZipPath(), getAlgorithm() + getDataset(), ".zip", list);

        File uncompressed = ArchiveUtil.uncompress(zipFile);
        list.add(uncompressed);

        return PipelineModelUtil.load(sparkSession, uncompressed);
    }

    /**
     * Hook for adjusting the schema once the pipeline model is known.
     *
     * @return {@code structType} unchanged in this default implementation
     */
    protected StructType updateSchema(StructType structType, PipelineModel pipelineModel) {
        return structType;
    }

    /**
     * Loads the input CSV from {@code getInputCsvPath()}.
     *
     * @param sparkSession the session used to read the CSV
     * @param list receives the temporary file(s) created, for later deletion
     * @return the input dataset
     * @throws IOException if the resource cannot be read or copied
     */
    protected Dataset<Row> loadInput(SparkSession sparkSession, List<File> list) throws IOException {
        File tmpFile = extractResource(getInputCsvPath(), getDataset(), ".csv", list);

        return DatasetUtil.loadCsv(sparkSession, tmpFile);
    }

    /**
     * Copies the resource at {@code path} to a new temporary file and registers it in {@code tmpFiles}.
     *
     * <p>The resource stream is closed as soon as the copy is done, before the file is parsed.
     */
    private File extractResource(String path, String prefix, String suffix, List<File> tmpFiles) throws IOException {
        try (InputStream is = open(path)) {
            File tmpFile = toTmpFile(is, prefix, suffix);
            tmpFiles.add(tmpFile);

            return tmpFile;
        }
    }

    /**
     * Copies {@code inputStream} into a newly created temporary file.
     *
     * <p>The input stream is not closed here. If the copy fails, the temporary file is deleted
     * on a best-effort basis and the original exception is rethrown.
     *
     * @param inputStream the source of the bytes
     * @param str the temporary file name prefix (see {@link File#createTempFile(String, String)})
     * @param str2 the temporary file name suffix
     * @return the populated temporary file
     * @throws IOException if the file cannot be created or written
     */
    protected static File toTmpFile(InputStream inputStream, String str, String str2) throws IOException {
        File tmpFile = File.createTempFile(str, str2);

        try (FileOutputStream os = new FileOutputStream(tmpFile)) {
            ByteStreams.copy(inputStream, os);
        } catch (Throwable t) {
            // Best-effort: the caller never learns about this file, so nobody else would delete it
            tmpFile.delete();
            throw t;
        }

        return tmpFile;
    }
}
