package org.jpmml.sparkml;

import java.lang.reflect.InvocationTargetException;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.classification.DecisionTreeClassificationModel;
import org.apache.spark.ml.classification.GBTClassificationModel;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.classification.RandomForestClassificationModel;
import org.apache.spark.ml.feature.Binarizer;
import org.apache.spark.ml.feature.Bucketizer;
import org.apache.spark.ml.feature.MinMaxScalerModel;
import org.apache.spark.ml.feature.OneHotEncoder;
import org.apache.spark.ml.feature.PCAModel;
import org.apache.spark.ml.feature.StandardScalerModel;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import org.apache.spark.ml.regression.GBTRegressionModel;
import org.apache.spark.ml.regression.LinearRegressionModel;
import org.apache.spark.ml.regression.RandomForestRegressionModel;
import org.apache.spark.sql.types.StructType;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningModel;
import org.dmg.pmml.Model;
import org.dmg.pmml.PMML;
import org.dmg.pmml.Visitor;
import org.jpmml.model.visitors.DictionaryCleaner;
import org.jpmml.model.visitors.MiningSchemaCleaner;
import org.jpmml.sparkml.feature.BinarizerConverter;
import org.jpmml.sparkml.feature.BucketizerConverter;
import org.jpmml.sparkml.feature.MinMaxScalerModelConverter;
import org.jpmml.sparkml.feature.OneHotEncoderConverter;
import org.jpmml.sparkml.feature.PCAModelConverter;
import org.jpmml.sparkml.feature.StandardScalerModelConverter;
import org.jpmml.sparkml.feature.StringIndexerModelConverter;
import org.jpmml.sparkml.feature.VectorAssemblerConverter;
import org.jpmml.sparkml.model.DecisionTreeClassificationModelConverter;
import org.jpmml.sparkml.model.DecisionTreeRegressionModelConverter;
import org.jpmml.sparkml.model.GBTClassificationModelConverter;
import org.jpmml.sparkml.model.GBTRegressionModelConverter;
import org.jpmml.sparkml.model.LinearRegressionModelConverter;
import org.jpmml.sparkml.model.LogisticRegressionModelConverter;
import org.jpmml.sparkml.model.RandomForestClassificationModelConverter;
import org.jpmml.sparkml.model.RandomForestRegressionModelConverter;

/* loaded from: input_file:org/jpmml/sparkml/ConverterUtil.class */
/**
 * Converts a fitted Apache Spark ML {@link PipelineModel} into a PMML document.
 *
 * <p>Each pipeline stage is mapped to a {@link TransformerConverter} through a registry keyed by the
 * exact runtime class of the stage. The registry is pre-populated by the static initializer and can
 * be extended with {@link #putConverterClazz(Class, Class)}.
 *
 * <p>NOTE(review): the registry is a plain, unsynchronized {@link LinkedHashMap}; register custom
 * converters before starting conversions from multiple threads.
 */
public class ConverterUtil {
    /** Registry from exact transformer class to the converter class that handles it. */
    private static final Map<Class<? extends Transformer>, Class<? extends TransformerConverter>> converters = new LinkedHashMap<>();

    private ConverterUtil() {
    }

    /**
     * Converts a fitted pipeline into a PMML document.
     *
     * <p>Stages are processed in pipeline order. Stages handled by a {@link FeatureConverter} are
     * appended to the feature mapper. The first stage handled by a {@link ModelConverter} is encoded
     * as the PMML model, and the document is returned immediately. Any stages after it are ignored.
     *
     * @param structType the schema of the input data set
     * @param pipelineModel the fitted pipeline to convert
     * @return the PMML document containing the encoded model
     * @throws IllegalArgumentException if a stage is unsupported, a stage fails to convert, or the
     *     pipeline contains no model stage
     */
    public static PMML toPMML(StructType structType, PipelineModel pipelineModel) {
        FeatureMapper featureMapper = new FeatureMapper(structType);

        for (Transformer stage : pipelineModel.stages()) {
            try {
                TransformerConverter converter = createConverter(stage);

                if (converter instanceof FeatureConverter) {
                    featureMapper.append((FeatureConverter) converter);
                } else if (converter instanceof ModelConverter) {
                    Model model = ((ModelConverter) converter).mo2encodeModel(featureMapper.createSchema((PredictionModel) stage));

                    PMML pmml = featureMapper.encodePMML().addModels(new Model[]{model});

                    // Prune unused mining fields first, then unused data dictionary entries.
                    for (Visitor cleaner : Arrays.<Visitor>asList(new MiningSchemaCleaner(), new DictionaryCleaner())) {
                        cleaner.applyTo(pmml);
                    }

                    // NOTE(review): this runs after the cleaners, presumably so MiningSchemaCleaner cannot
                    // re-add the field. A matching DataField may therefore remain in the dictionary; confirm.
                    if (model instanceof MiningModel) {
                        Iterator<MiningField> fields = model.getMiningSchema().getMiningFields().iterator();
                        while (fields.hasNext()) {
                            MiningField field = fields.next();
                            if (field.getName() != null && "binarizedGbtValue".equals(field.getName().getValue())) {
                                fields.remove();
                            }
                        }
                    }

                    return pmml;
                } else {
                    throw new IllegalArgumentException("Converter " + converter.getClass() + " for transformer " + stage
                            + " is neither a feature converter nor a model converter");
                }
            } catch (IllegalArgumentException iae) {
                // Already descriptive; do not bury it under another IllegalArgumentException.
                throw iae;
            } catch (Exception e) {
                throw new IllegalArgumentException("Failed to convert transformer " + stage, e);
            }
        }

        throw new IllegalArgumentException("Pipeline model does not contain a model stage");
    }

    /**
     * Instantiates the registered converter for a transformer.
     *
     * <p>The converter class is looked up by the exact runtime class of {@code t}. It is instantiated
     * through its declared constructor that takes that same class as its only parameter.
     *
     * @param t the transformer to convert
     * @param <T> the transformer type
     * @return a new converter wrapping {@code t}
     * @throws IllegalArgumentException if no converter is registered for the runtime class of {@code t}
     * @throws Exception if reflective instantiation fails; an exception thrown by the converter
     *     constructor itself is rethrown unwrapped
     */
    @SuppressWarnings("unchecked") // registry values are raw converter classes keyed by transformer class
    public static <T extends Transformer> TransformerConverter<T> createConverter(T t) throws Exception {
        Class<? extends Transformer> clazz = t.getClass();

        Class<? extends TransformerConverter> converterClazz = getConverterClazz(clazz);
        if (converterClazz == null) {
            throw new IllegalArgumentException("Transformer class " + clazz + " is not supported");
        }

        try {
            return (TransformerConverter<T>) converterClazz.getDeclaredConstructor(clazz).newInstance(t);
        } catch (InvocationTargetException ite) {
            // Surface the converter constructor's own failure rather than the reflection wrapper.
            Throwable cause = ite.getCause();
            if (cause instanceof Exception) {
                throw (Exception) cause;
            }
            if (cause instanceof Error) {
                throw (Error) cause;
            }
            throw ite;
        }
    }

    /**
     * Returns the converter class registered for exactly {@code cls}, or {@code null} if none is.
     * Superclasses are not consulted.
     *
     * @param cls the transformer class
     * @return the registered converter class, or {@code null}
     */
    public static Class<? extends TransformerConverter> getConverterClazz(Class<? extends Transformer> cls) {
        return converters.get(cls);
    }

    /**
     * Registers (or replaces) the converter class for a transformer class.
     *
     * @param cls the transformer class
     * @param cls2 the converter class; it must declare a constructor taking {@code cls}
     */
    public static void putConverterClazz(Class<? extends Transformer> cls, Class<? extends TransformerConverter> cls2) {
        converters.put(cls, cls2);
    }

    // Built-in registrations: feature transformers first, then predictive models.
    static {
        converters.put(Binarizer.class, BinarizerConverter.class);
        converters.put(Bucketizer.class, BucketizerConverter.class);
        converters.put(MinMaxScalerModel.class, MinMaxScalerModelConverter.class);
        converters.put(OneHotEncoder.class, OneHotEncoderConverter.class);
        converters.put(PCAModel.class, PCAModelConverter.class);
        converters.put(StandardScalerModel.class, StandardScalerModelConverter.class);
        converters.put(StringIndexerModel.class, StringIndexerModelConverter.class);
        converters.put(VectorAssembler.class, VectorAssemblerConverter.class);
        converters.put(DecisionTreeClassificationModel.class, DecisionTreeClassificationModelConverter.class);
        converters.put(DecisionTreeRegressionModel.class, DecisionTreeRegressionModelConverter.class);
        converters.put(GBTClassificationModel.class, GBTClassificationModelConverter.class);
        converters.put(GBTRegressionModel.class, GBTRegressionModelConverter.class);
        converters.put(LinearRegressionModel.class, LinearRegressionModelConverter.class);
        converters.put(LogisticRegressionModel.class, LogisticRegressionModelConverter.class);
        converters.put(RandomForestClassificationModel.class, RandomForestClassificationModelConverter.class);
        converters.put(RandomForestRegressionModel.class, RandomForestRegressionModelConverter.class);
    }
}
