package io.prestosql.plugin.ml;

import com.google.common.base.Throwables;
import com.google.common.util.concurrent.SimpleTimeLimiter;
import io.airlift.concurrent.Threads;
import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import libsvm.svm;
import libsvm.svm_model;
import libsvm.svm_node;
import libsvm.svm_parameter;
import libsvm.svm_problem;

/* loaded from: input_file:io/prestosql/plugin/ml/AbstractSvmModel.class */
public abstract class AbstractSvmModel implements Model {
    /** Trained libsvm model; null until {@link #train} succeeds when built from parameters. */
    protected svm_model model;
    /** Training parameters; null when this instance was built from an existing model. */
    protected svm_parameter params;

    /**
     * Creates an untrained model that will be trained with the given libsvm parameters.
     *
     * @param svm_parameterVar libsvm training parameters; must not be null
     * @throws NullPointerException if {@code svm_parameterVar} is null
     */
    public AbstractSvmModel(svm_parameter svm_parameterVar) {
        this.params = Objects.requireNonNull(svm_parameterVar, "params is null");
    }

    /**
     * Wraps an already-trained libsvm model. Such an instance has no training parameters,
     * so {@link #train} cannot be called on it.
     *
     * @param svm_modelVar trained libsvm model; must not be null
     * @throws NullPointerException if {@code svm_modelVar} is null
     */
    public AbstractSvmModel(svm_model svm_modelVar) {
        this.model = Objects.requireNonNull(svm_modelVar, "model is null");
    }

    /**
     * Serializes the trained model using libsvm's own file format. libsvm only writes models
     * to a file path, so the model is saved to a temporary file, read back, and the file is
     * deleted.
     *
     * @return the bytes of the libsvm model file
     * @throws IllegalStateException if no model is present (not yet trained)
     * @throws UncheckedIOException if creating, writing or reading the temporary file fails
     */
    @Override // io.prestosql.plugin.ml.Model
    public byte[] getSerializedData() {
        if (this.model == null) {
            // Without this check libsvm fails with an NPE after it has opened its output file.
            throw new IllegalStateException("model has not been trained");
        }
        File file = null;
        try {
            file = File.createTempFile("svm", null);
            svm.svm_save_model(file.getAbsolutePath(), this.model);
            return Files.readAllBytes(file.toPath());
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        } finally {
            if (file != null) {
                // Best effort cleanup; the result of delete() is intentionally ignored.
                file.delete();
            }
        }
    }

    /**
     * Trains the libsvm model on {@code dataset}, replacing any previously held model.
     * Training runs on a dedicated thread and is aborted after one hour.
     *
     * @param dataset training data; labels and datapoints must have the same size
     * @throws IllegalStateException if this instance has no training parameters
     * @throws IllegalArgumentException if labels and datapoints differ in size
     * @throws RuntimeException if training fails, times out, or the thread is interrupted
     */
    @Override // io.prestosql.plugin.ml.Model
    public void train(Dataset dataset) {
        Objects.requireNonNull(dataset, "dataset is null");
        if (this.params == null) {
            throw new IllegalStateException("params is null; a model constructed from an svm_model cannot be trained");
        }
        // Note: this mutates the svm_parameter object supplied to the constructor.
        this.params.svm_type = getLibsvmType();
        svm_problem problem = toSvmProblem(dataset);

        ExecutorService service = Executors.newCachedThreadPool(Threads.threadsNamed("libsvm-trainer-" + System.identityHashCode(this) + "-%s"));
        try {
            this.model = SimpleTimeLimiter.create(service).callWithTimeout(getTrainingFunction(problem, this.params), 1L, TimeUnit.HOURS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (ExecutionException e) {
            // Surface the training failure itself rather than the executor wrapper.
            Throwable cause = e.getCause();
            if (cause != null) {
                Throwables.throwIfUnchecked(cause);
                throw new RuntimeException(cause);
            }
            throw new RuntimeException(e);
        } catch (Exception e) {
            // e.g. TimeoutException when training exceeds the limit
            Throwables.throwIfUnchecked(e);
            throw new RuntimeException(e);
        } finally {
            // Interrupts the trainer thread if it is still running (timeout / interruption).
            service.shutdownNow();
        }
    }

    /** Returns a task that runs libsvm training on the given problem and parameters. */
    private static Callable<svm_model> getTrainingFunction(svm_problem svm_problemVar, svm_parameter svm_parameterVar) {
        return () -> svm.svm_train(svm_problemVar, svm_parameterVar);
    }

    /** Returns the libsvm {@code svm_type} constant that subclasses train with. */
    protected abstract int getLibsvmType();

    /**
     * Converts a dataset into a libsvm problem: one label ({@code y}) and one node row
     * ({@code x}) per datapoint.
     *
     * @throws IllegalArgumentException if the dataset has different numbers of labels and datapoints
     */
    private static svm_problem toSvmProblem(Dataset dataset) {
        List<Double> labels = dataset.getLabels();
        int count = labels.size();
        int datapointCount = dataset.getDatapoints().size();
        if (datapointCount != count) {
            throw new IllegalArgumentException("labels and datapoints differ in size: " + count + " labels, " + datapointCount + " datapoints");
        }

        svm_problem problem = new svm_problem();
        problem.l = count;
        problem.y = new double[count];
        problem.x = new svm_node[count][];
        for (int i = 0; i < count; i++) {
            problem.y[i] = labels.get(i);
            problem.x[i] = toSvmNodes(dataset.getDatapoints().get(i));
        }
        return problem;
    }

    /**
     * Converts a feature vector into libsvm nodes, one node per (index, value) feature, in the
     * iteration order of {@code featureVector.getFeatures()}.
     *
     * @param featureVector features to convert
     * @return array of libsvm nodes
     */
    public static svm_node[] toSvmNodes(FeatureVector featureVector) {
        svm_node[] nodes = new svm_node[featureVector.size()];
        int i = 0;
        for (Map.Entry<Integer, Double> entry : featureVector.getFeatures().entrySet()) {
            svm_node node = new svm_node();
            node.index = entry.getKey();
            node.value = entry.getValue();
            nodes[i++] = node;
        }
        return nodes;
    }
}
