package org.jpmml.sparkml;

import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.classification.DecisionTreeClassificationModel;
import org.apache.spark.ml.classification.GBTClassificationModel;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel;
import org.apache.spark.ml.classification.RandomForestClassificationModel;
import org.apache.spark.ml.clustering.KMeansModel;
import org.apache.spark.ml.feature.Binarizer;
import org.apache.spark.ml.feature.Bucketizer;
import org.apache.spark.ml.feature.ChiSqSelectorModel;
import org.apache.spark.ml.feature.ColumnPruner;
import org.apache.spark.ml.feature.IndexToString;
import org.apache.spark.ml.feature.MinMaxScalerModel;
import org.apache.spark.ml.feature.OneHotEncoder;
import org.apache.spark.ml.feature.PCAModel;
import org.apache.spark.ml.feature.RFormulaModel;
import org.apache.spark.ml.feature.StandardScalerModel;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.feature.VectorAttributeRewriter;
import org.apache.spark.ml.feature.VectorIndexerModel;
import org.apache.spark.ml.feature.VectorSlicer;
import org.apache.spark.ml.param.shared.HasPredictionCol;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import org.apache.spark.ml.regression.GBTRegressionModel;
import org.apache.spark.ml.regression.LinearRegressionModel;
import org.apache.spark.ml.regression.RandomForestRegressionModel;
import org.apache.spark.sql.types.StructType;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldUsageType;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningModel;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.Model;
import org.dmg.pmml.MultipleModelMethodType;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.PMML;
import org.jpmml.converter.MiningModelUtil;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.sparkml.feature.BinarizerConverter;
import org.jpmml.sparkml.feature.BucketizerConverter;
import org.jpmml.sparkml.feature.ChiSqSelectorModelConverter;
import org.jpmml.sparkml.feature.ColumnPrunerConverter;
import org.jpmml.sparkml.feature.IndexToStringConverter;
import org.jpmml.sparkml.feature.MinMaxScalerModelConverter;
import org.jpmml.sparkml.feature.OneHotEncoderConverter;
import org.jpmml.sparkml.feature.PCAModelConverter;
import org.jpmml.sparkml.feature.RFormulaModelConverter;
import org.jpmml.sparkml.feature.StandardScalerModelConverter;
import org.jpmml.sparkml.feature.StringIndexerModelConverter;
import org.jpmml.sparkml.feature.VectorAssemblerConverter;
import org.jpmml.sparkml.feature.VectorAttributeRewriterConverter;
import org.jpmml.sparkml.feature.VectorIndexerModelConverter;
import org.jpmml.sparkml.feature.VectorSlicerConverter;
import org.jpmml.sparkml.model.DecisionTreeClassificationModelConverter;
import org.jpmml.sparkml.model.DecisionTreeRegressionModelConverter;
import org.jpmml.sparkml.model.GBTClassificationModelConverter;
import org.jpmml.sparkml.model.GBTRegressionModelConverter;
import org.jpmml.sparkml.model.KMeansModelConverter;
import org.jpmml.sparkml.model.LinearRegressionModelConverter;
import org.jpmml.sparkml.model.LogisticRegressionModelConverter;
import org.jpmml.sparkml.model.MultilayerPerceptronClassificationModelConverter;
import org.jpmml.sparkml.model.RandomForestClassificationModelConverter;
import org.jpmml.sparkml.model.RandomForestRegressionModelConverter;

/* loaded from: input_file:org/jpmml/sparkml/ConverterUtil.class */
/**
 * Converts fitted Apache Spark ML pipelines to PMML.
 *
 * <p>Keeps a registry that maps each supported {@link Transformer} class to the
 * {@link TransformerConverter} class handling it. The registry is pre-populated in the static
 * initializer and can be extended via {@link #putConverterClazz(Class, Class)}. Lookups use the
 * transformer's exact runtime class. The registry is a plain, unsynchronized {@link LinkedHashMap}.
 */
public final class ConverterUtil {

    /** Transformer class to converter class. Lookups are by exact runtime class. */
    private static final Map<Class<? extends Transformer>, Class<? extends TransformerConverter>> converters = new LinkedHashMap<>();

    static {
        converters.put(Binarizer.class, BinarizerConverter.class);
        converters.put(Bucketizer.class, BucketizerConverter.class);
        converters.put(ChiSqSelectorModel.class, ChiSqSelectorModelConverter.class);
        converters.put(ColumnPruner.class, ColumnPrunerConverter.class);
        converters.put(IndexToString.class, IndexToStringConverter.class);
        converters.put(MinMaxScalerModel.class, MinMaxScalerModelConverter.class);
        converters.put(OneHotEncoder.class, OneHotEncoderConverter.class);
        converters.put(PCAModel.class, PCAModelConverter.class);
        converters.put(RFormulaModel.class, RFormulaModelConverter.class);
        converters.put(StandardScalerModel.class, StandardScalerModelConverter.class);
        converters.put(StringIndexerModel.class, StringIndexerModelConverter.class);
        converters.put(VectorAssembler.class, VectorAssemblerConverter.class);
        converters.put(VectorAttributeRewriter.class, VectorAttributeRewriterConverter.class);
        converters.put(VectorIndexerModel.class, VectorIndexerModelConverter.class);
        converters.put(VectorSlicer.class, VectorSlicerConverter.class);
        converters.put(DecisionTreeClassificationModel.class, DecisionTreeClassificationModelConverter.class);
        converters.put(DecisionTreeRegressionModel.class, DecisionTreeRegressionModelConverter.class);
        converters.put(GBTClassificationModel.class, GBTClassificationModelConverter.class);
        converters.put(GBTRegressionModel.class, GBTRegressionModelConverter.class);
        converters.put(KMeansModel.class, KMeansModelConverter.class);
        converters.put(LinearRegressionModel.class, LinearRegressionModelConverter.class);
        converters.put(LogisticRegressionModel.class, LogisticRegressionModelConverter.class);
        converters.put(MultilayerPerceptronClassificationModel.class, MultilayerPerceptronClassificationModelConverter.class);
        converters.put(RandomForestClassificationModel.class, RandomForestClassificationModelConverter.class);
        converters.put(RandomForestRegressionModel.class, RandomForestRegressionModelConverter.class);
    }

    private ConverterUtil() {
    }

    /**
     * Converts a fitted pipeline to a PMML document.
     *
     * <p>Stages are processed in pipeline order:
     * <ul>
     *   <li>Stages handled by a {@link FeatureConverter} are appended to a shared
     *       {@link FeatureMapper}.</li>
     *   <li>Stages handled by a {@link ModelConverter} are encoded against the schema built so far.
     *       They are keyed by their prediction column.</li>
     * </ul>
     * A pipeline with exactly one model stage yields that model as the root. With several model
     * stages, the models are combined into a model-chain {@link MiningModel}.
     *
     * @param structType schema of the input data frame
     * @param pipelineModel the fitted pipeline to convert
     * @return the encoded PMML document
     * @throws IllegalArgumentException if a stage is unsupported, a model stage has no prediction
     *     column, or the pipeline contains no model stage at all
     */
    public static PMML toPMML(StructType structType, PipelineModel pipelineModel) {
        FeatureMapper featureMapper = new FeatureMapper(structType);

        // LinkedHashMap keeps pipeline order, which the model chain below depends on.
        Map<String, Model> models = new LinkedHashMap<>();

        for (Transformer stage : pipelineModel.stages()) {
            TransformerConverter<?> converter = createConverter(stage);

            if (converter instanceof FeatureConverter) {
                featureMapper.append((FeatureConverter<?>) converter);
            } else if (converter instanceof ModelConverter) {
                if (!(stage instanceof HasPredictionCol)) {
                    throw new IllegalArgumentException("Model stage " + stage.getClass().getName() + " does not have a prediction column");
                }

                ModelConverter<?> modelConverter = (ModelConverter<?>) converter;
                Schema schema = featureMapper.createSchema((org.apache.spark.ml.Model<?>) stage);

                // Regression targets are always encoded as continuous doubles.
                if (modelConverter instanceof RegressionModelConverter) {
                    featureMapper.getDataField(schema.getTargetField()).setOpType(OpType.CONTINUOUS).setDataType(DataType.DOUBLE);
                }

                // NOTE(review): `mo4encodeModel` looks like a decompiler-assigned name; it must match
                // the declaration in ModelConverter, so it is kept unchanged here.
                Model model = modelConverter.mo4encodeModel(schema);
                featureMapper.append(modelConverter);
                models.put(((HasPredictionCol) stage).getPredictionCol(), model);
            } else {
                throw new IllegalArgumentException("Converter " + converter.getClass().getName() + " is neither a feature converter nor a model converter");
            }
        }

        Model rootModel;
        if (models.isEmpty()) {
            throw new IllegalArgumentException("Pipeline model does not contain any model stages");
        } else if (models.size() == 1) {
            rootModel = Iterables.getOnlyElement(models.values());
        } else {
            rootModel = createModelChain(featureMapper, models);
        }

        return featureMapper.encodePMML(rootModel);
    }

    /**
     * Combines two or more models, in iteration order, into one model-chain {@link MiningModel}.
     *
     * <p>PREDICTED and TARGET mining fields of every model are collected into the chain's mining
     * schema. For every model except the last, the prediction column's data field is removed from
     * {@code featureMapper}. A predicted output field with that name is added to the model instead.
     * The chain takes its function name from the last model.
     *
     * @param featureMapper the mapper whose data fields are adjusted
     * @param models models keyed by prediction column name, in pipeline order (size >= 2)
     * @return the mining model wrapping the chain
     */
    private static Model createModelChain(FeatureMapper featureMapper, Map<String, Model> models) {
        List<MiningField> targetMiningFields = new ArrayList<>();

        Iterator<Map.Entry<String, Model>> entryIt = models.entrySet().iterator();
        while (entryIt.hasNext()) {
            Map.Entry<String, Model> entry = entryIt.next();
            Model model = entry.getValue();

            for (MiningField miningField : model.getMiningSchema().getMiningFields()) {
                FieldUsageType usageType = miningField.getUsageType();
                if (usageType == FieldUsageType.PREDICTED || usageType == FieldUsageType.TARGET) {
                    targetMiningFields.add(miningField);
                }
            }

            // Intermediate models: drop the prediction column from the data fields and expose it
            // as a predicted output field instead.
            if (entryIt.hasNext()) {
                FieldName name = FieldName.create(entry.getKey());
                featureMapper.removeDataField(name);

                Output output = model.getOutput();
                if (output == null) {
                    output = new Output();
                    model.setOutput(output);
                }
                output.addOutputFields(new OutputField[]{ModelUtil.createPredictedField(name)});
            }
        }

        List<Model> chain = new ArrayList<>(models.values());
        Model lastModel = Iterables.getLast(chain);

        return new MiningModel(lastModel.getFunctionName(), new MiningSchema(targetMiningFields))
            .setSegmentation(MiningModelUtil.createSegmentation(MultipleModelMethodType.MODEL_CHAIN, chain));
    }

    /**
     * Creates the converter for a transformer that is expected to be a feature transformer.
     *
     * @param transformer the transformer to convert
     * @return the feature converter
     * @throws IllegalArgumentException if the transformer class is not supported
     * @throws ClassCastException if the registered converter is not a {@link FeatureConverter}
     */
    public static FeatureConverter<?> createFeatureConverter(Transformer transformer) {
        return (FeatureConverter<?>) createConverter(transformer);
    }

    /**
     * Creates the converter for a transformer that is expected to be a model.
     *
     * @param transformer the transformer to convert
     * @return the model converter
     * @throws IllegalArgumentException if the transformer class is not supported
     * @throws ClassCastException if the registered converter is not a {@link ModelConverter}
     */
    public static ModelConverter<?> createModelConverter(Transformer transformer) {
        return (ModelConverter<?>) createConverter(transformer);
    }

    /**
     * Instantiates the converter registered for the transformer's exact runtime class.
     *
     * <p>The converter class must declare a constructor whose single parameter type is exactly
     * that runtime class.
     *
     * @param transformer the transformer to convert, not null
     * @param <T> the transformer type
     * @return a new converter wrapping {@code transformer}
     * @throws NullPointerException if {@code transformer} is null
     * @throws IllegalArgumentException if no converter is registered, or the converter cannot be
     *     instantiated (the reflective failure is kept as the cause)
     */
    @SuppressWarnings("unchecked") // registry values are raw TransformerConverter classes
    public static <T extends Transformer> TransformerConverter<T> createConverter(T transformer) {
        Objects.requireNonNull(transformer, "transformer");

        Class<? extends Transformer> clazz = transformer.getClass();
        Class<? extends TransformerConverter> converterClazz = getConverterClazz(clazz);
        if (converterClazz == null) {
            throw new IllegalArgumentException("Transformer class " + clazz.getName() + " is not supported");
        }

        try {
            return converterClazz.getDeclaredConstructor(clazz).newInstance(transformer);
        } catch (ReflectiveOperationException | SecurityException e) {
            throw new IllegalArgumentException("Failed to instantiate converter " + converterClazz.getName() + " for transformer class " + clazz.getName(), e);
        }
    }

    /**
     * Looks up the converter class registered for exactly {@code cls}.
     *
     * @param cls the transformer class
     * @return the registered converter class, or {@code null} if none is registered
     */
    public static Class<? extends TransformerConverter> getConverterClazz(Class<? extends Transformer> cls) {
        return converters.get(cls);
    }

    /**
     * Registers (or replaces) the converter class for exactly {@code cls}.
     *
     * @param cls the transformer class, not null
     * @param cls2 the converter class, not null; it must declare a constructor taking {@code cls}
     * @throws NullPointerException if either argument is null
     */
    public static void putConverterClazz(Class<? extends Transformer> cls, Class<? extends TransformerConverter<?>> cls2) {
        converters.put(Objects.requireNonNull(cls, "cls"), Objects.requireNonNull(cls2, "cls2"));
    }
}
