package io.prestosql.plugin.ml;

import com.google.common.base.Preconditions;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.hash.HashCode;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.prestosql.plugin.ml.type.ClassifierType;
import io.prestosql.plugin.ml.type.RegressorType;
import io.prestosql.spi.block.Block;
import io.prestosql.spi.function.ScalarFunction;
import io.prestosql.spi.function.SqlType;

/* loaded from: input_file:io/prestosql/plugin/ml/MLFunctions.class */
/**
 * SQL scalar functions that apply a serialized machine-learning model to a
 * feature map.
 *
 * <p>Each function converts the feature argument with
 * {@link ModelUtils#toFeatures}, obtains the model for the serialized argument
 * through a small in-memory cache, verifies the model's declared type, and
 * then delegates to the model.
 */
public final class MLFunctions {
    /** Deserialized models keyed by the hash of their serialized form; holds at most 5 entries. */
    private static final Cache<HashCode, Model> MODEL_CACHE = CacheBuilder.newBuilder().maximumSize(5).build();

    /** SQL type signature of the feature argument shared by every function in this class. */
    private static final String MAP_BIGINT_DOUBLE = "map(bigint,double)";

    private MLFunctions() {
        // Static utility holder; not instantiable.
    }

    /**
     * Classifies a feature map with a varchar classifier.
     *
     * @param block the feature map argument
     * @param slice the serialized model
     * @return the predicted label encoded as UTF-8
     * @throws IllegalArgumentException if the model's type is not {@code ClassifierType.VARCHAR_CLASSIFIER}
     */
    @SqlType("varchar")
    @ScalarFunction("classify")
    public static Slice varcharClassify(@SqlType(MAP_BIGINT_DOUBLE) Block block, @SqlType("Classifier<varchar>") Slice slice) {
        FeatureVector features = ModelUtils.toFeatures(block);
        Model model = getOrLoadModel(slice);
        Preconditions.checkArgument(model.getType().equals(ClassifierType.VARCHAR_CLASSIFIER), "model is not a classifier<varchar>");
        return Slices.utf8Slice((String) ((Classifier) model).classify(features));
    }

    /**
     * Classifies a feature map with a bigint classifier.
     *
     * @param block the feature map argument
     * @param slice the serialized model
     * @return the predicted label, widened from the classifier's {@code Integer} result
     * @throws IllegalArgumentException if the model's type is not {@code ClassifierType.BIGINT_CLASSIFIER}
     */
    @SqlType("bigint")
    @ScalarFunction
    public static long classify(@SqlType(MAP_BIGINT_DOUBLE) Block block, @SqlType("Classifier<bigint>") Slice slice) {
        FeatureVector features = ModelUtils.toFeatures(block);
        // Keep the model in a local so the instance that was validated is the one that is cast and used.
        Model model = getOrLoadModel(slice);
        Preconditions.checkArgument(model.getType().equals(ClassifierType.BIGINT_CLASSIFIER), "model is not a classifier<bigint>");
        return ((Integer) ((Classifier) model).classify(features)).intValue();
    }

    /**
     * Evaluates a regressor on a feature map.
     *
     * @param block the feature map argument
     * @param slice the serialized model
     * @return the regressor's prediction for the features
     * @throws IllegalArgumentException if the model's type is not {@code RegressorType.REGRESSOR}
     */
    @SqlType("double")
    @ScalarFunction
    public static double regress(@SqlType(MAP_BIGINT_DOUBLE) Block block, @SqlType("Regressor") Slice slice) {
        FeatureVector features = ModelUtils.toFeatures(block);
        Model model = getOrLoadModel(slice);
        Preconditions.checkArgument(model.getType().equals(RegressorType.REGRESSOR), "model is not a regressor");
        return ((Regressor) model).regress(features);
    }

    /**
     * Returns the model for the given serialized form, deserializing and
     * caching it when the cache has no entry for its hash.
     *
     * <p>The lookup and the insert are separate steps, so two callers that miss
     * at the same time may each deserialize the same model.
     *
     * @param slice the serialized model
     * @return the cached or freshly deserialized model
     */
    private static Model getOrLoadModel(Slice slice) {
        HashCode modelHash = ModelUtils.modelHash(slice);
        Model model = MODEL_CACHE.getIfPresent(modelHash);
        if (model == null) {
            model = ModelUtils.deserialize(slice);
            MODEL_CACHE.put(modelHash, model);
        }
        return model;
    }
}
