package org.jpmml.sparkml;

import com.google.common.io.CharStreams;
import com.google.common.io.Files;
import com.google.common.io.MoreFiles;
import com.google.common.io.RecursiveDeleteOption;
import java.io.File;
import java.io.FileFilter;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.types.AtomicType;
import org.apache.spark.sql.types.BooleanType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DoubleType;
import org.apache.spark.sql.types.FloatType;
import org.apache.spark.sql.types.FractionalType;
import org.apache.spark.sql.types.IntegralType;
import org.apache.spark.sql.types.StringType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

/* loaded from: input_file:org/jpmml/sparkml/DatasetUtil.class */
/**
 * Static helpers for persisting Spark datasets and schemas, casting dataset columns,
 * analyzing SQL statements against a schema, and translating Spark SQL data types
 * into PMML data types.
 */
public class DatasetUtil {

    /** Counter used to give each temporary view in {@link #createAnalyzedLogicalPlan} a unique name. */
    private static final AtomicInteger ID = new AtomicInteger(1);

    private DatasetUtil() {
    }

    /**
     * Reads a schema from a UTF-8 encoded JSON file.
     *
     * @param file the JSON file (for example, one written by {@link #storeSchema(StructType, File)})
     * @return the parsed schema
     * @throws IOException if the file cannot be opened or read
     */
    public static StructType loadSchema(File file) throws IOException {
        // try-with-resources closes on every path; a failing close() is attached as a
        // suppressed exception rather than masking the primary failure.
        try (FileInputStream is = new FileInputStream(file);
             InputStreamReader reader = new InputStreamReader(is, StandardCharsets.UTF_8)) {
            return StructType.fromJson(CharStreams.toString(reader));
        }
    }

    /**
     * Writes the schema of the given dataset to a file as UTF-8 encoded JSON.
     *
     * @param dataset the dataset whose schema is stored
     * @param file the destination file (created or overwritten)
     * @throws IOException if the file cannot be written
     */
    public static void storeSchema(Dataset<Row> dataset, File file) throws IOException {
        storeSchema(dataset.schema(), file);
    }

    /**
     * Writes a schema to a file as UTF-8 encoded JSON.
     *
     * @param structType the schema to store
     * @param file the destination file (created or overwritten)
     * @throws IOException if the file cannot be written
     */
    public static void storeSchema(StructType structType, File file) throws IOException {
        try (FileOutputStream os = new FileOutputStream(file)) {
            os.write(structType.json().getBytes(StandardCharsets.UTF_8));
        }
    }

    /**
     * Loads a CSV file with a header row, letting Spark infer column types.
     * The literal {@code N/A} is read as both null and NaN.
     *
     * @param sparkSession the session used for reading
     * @param file the CSV file
     * @return the loaded dataset
     * @throws IOException declared for API compatibility
     */
    public static Dataset<Row> loadCsv(SparkSession sparkSession, File file) throws IOException {
        return sparkSession.read()
            .format("csv")
            .option("header", true)
            .option("inferSchema", true)
            .option("nullValue", "N/A")
            .option("nanValue", "N/A")
            .load(file.getAbsolutePath());
    }

    /**
     * Writes a dataset as a single CSV file with a header row.
     *
     * <p>Spark writes CSV output as a directory of part files. This method coalesces the dataset
     * into one partition, writes it into a temporary directory, copies the single resulting
     * {@code .csv} file to {@code file}, and always removes the temporary directory afterwards.
     *
     * @param dataset the dataset to write
     * @param file the destination file
     * @throws IOException if the temporary location cannot be prepared, the output does not
     *     contain exactly one CSV file, or copying fails
     */
    public static void storeCsv(Dataset<Row> dataset, File file) throws IOException {
        File tmpDir = File.createTempFile("Dataset", "");
        // Reserve a unique path, then free it so that Spark can create it as a directory.
        if (!tmpDir.delete()) {
            throw new IOException("Failed to delete temporary file " + tmpDir.getAbsolutePath());
        }

        try {
            dataset.coalesce(1).write().format("csv").option("header", "true").save(tmpDir.getAbsolutePath());

            File[] csvFiles = tmpDir.listFiles(new FileFilter() {

                @Override
                public boolean accept(File candidate) {
                    return candidate.getName().endsWith(".csv");
                }
            });

            // listFiles returns null when the path is not a readable directory
            if (csvFiles == null || csvFiles.length != 1) {
                throw new IOException("Expected exactly one CSV file in directory " + tmpDir.getAbsolutePath());
            }

            Files.copy(csvFiles[0], file);
        } finally {
            // Guard on existence so a failed write (no directory created) is not masked by a
            // NoSuchFileException from the cleanup.
            if (tmpDir.exists()) {
                MoreFiles.deleteRecursively(tmpDir.toPath());
            }
        }
    }

    /**
     * Casts a single column to the given type. The column is moved to the end of the
     * column list, because it is re-added under a temporary name and then renamed back.
     *
     * @param dataset the source dataset
     * @param str the name of the column to cast
     * @param dataType the target type
     * @return a new dataset with the column cast
     */
    public static Dataset<Row> castColumn(Dataset<Row> dataset, String str, DataType dataType) {
        String tmpName = "tmp_" + str;

        return dataset
            .withColumn(tmpName, dataset.apply(str).cast(dataType))
            .drop(str)
            .withColumnRenamed(tmpName, str);
    }

    /**
     * Casts every dataset column whose name appears in {@code structType}, and whose current
     * type differs from the type declared there. Fields of {@code structType} that have no
     * matching column in the dataset are skipped.
     *
     * @param dataset the source dataset
     * @param structType the desired column types
     * @return the dataset with the mismatching columns cast
     */
    public static Dataset<Row> castColumns(Dataset<Row> dataset, StructType structType) {
        StructType schema = dataset.schema();

        for (StructField field : structType.fields()) {
            DataType currentType;

            try {
                currentType = schema.apply(field.name()).dataType();
            } catch (IllegalArgumentException ignored) {
                // StructType#apply(String) signals a missing field this way; the dataset has
                // no such column, so there is nothing to cast.
                continue;
            }

            if (!Objects.equals(field.dataType(), currentType)) {
                dataset = castColumn(dataset, field.name(), field.dataType());
            }
        }

        return dataset;
    }

    /**
     * Analyzes a SQL statement against an empty table that has the given schema.
     *
     * <p>Every occurrence of {@code __THIS__} in the statement is replaced with the name of a
     * temporary view. The view is registered for the duration of the analysis and always
     * dropped afterwards.
     *
     * @param sparkSession the session used for analysis
     * @param structType the schema of the table referenced as {@code __THIS__}
     * @param str the SQL statement
     * @return the analyzed logical plan
     */
    public static LogicalPlan createAnalyzedLogicalPlan(SparkSession sparkSession, StructType structType, String str) {
        String viewName = "sql2pmml_" + ID.getAndIncrement();
        String sql = str.replace("__THIS__", viewName);

        sparkSession.createDataFrame(Collections.<Row>emptyList(), structType).createOrReplaceTempView(viewName);

        try {
            return sparkSession.sql(sql).queryExecution().analyzed();
        } finally {
            sparkSession.catalog().dropTempView(viewName);
        }
    }

    /**
     * Translates a Spark SQL data type into a PMML data type.
     *
     * @param dataType the Spark SQL data type
     * @return the corresponding PMML data type
     * @throws IllegalArgumentException if the type is not atomic, or is an unsupported atomic type
     */
    public static org.dmg.pmml.DataType translateDataType(DataType dataType) {
        if (!(dataType instanceof AtomicType)) {
            throw new IllegalArgumentException("Expected atomic data type, got " + dataType.typeName() + " data type");
        }

        return translateAtomicType((AtomicType) dataType);
    }

    /**
     * Translates an atomic Spark SQL type into a PMML data type.
     *
     * @param atomicType the atomic type
     * @return the corresponding PMML data type
     * @throws IllegalArgumentException if the type is not a string, integral, fractional or boolean type
     */
    public static org.dmg.pmml.DataType translateAtomicType(AtomicType atomicType) {
        if (atomicType instanceof StringType) {
            return org.dmg.pmml.DataType.STRING;
        } else if (atomicType instanceof IntegralType) {
            return translateIntegralType((IntegralType) atomicType);
        } else if (atomicType instanceof FractionalType) {
            return translateFractionalType((FractionalType) atomicType);
        } else if (atomicType instanceof BooleanType) {
            return org.dmg.pmml.DataType.BOOLEAN;
        }

        throw new IllegalArgumentException("Expected string, integral, fractional or boolean data type, got " + atomicType.typeName() + " data type");
    }

    /**
     * Translates an integral Spark SQL type into a PMML data type. Every integral type maps
     * to {@code INTEGER}, regardless of its width.
     *
     * @param integralType the integral type
     * @return {@code INTEGER}
     */
    public static org.dmg.pmml.DataType translateIntegralType(IntegralType integralType) {
        return org.dmg.pmml.DataType.INTEGER;
    }

    /**
     * Translates a fractional Spark SQL type into a PMML data type.
     *
     * @param fractionalType the fractional type
     * @return {@code FLOAT} or {@code DOUBLE}
     * @throws IllegalArgumentException for any other fractional type
     */
    public static org.dmg.pmml.DataType translateFractionalType(FractionalType fractionalType) {
        if (fractionalType instanceof FloatType) {
            return org.dmg.pmml.DataType.FLOAT;
        } else if (fractionalType instanceof DoubleType) {
            return org.dmg.pmml.DataType.DOUBLE;
        }

        throw new IllegalArgumentException("Expected float or double data type, got " + fractionalType.typeName() + " data type");
    }
}
