package org.jpmml.sparkml;

import java.util.ArrayList;
import java.util.List;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.classification.ClassificationModel;
import org.apache.spark.ml.param.shared.HasFeaturesCol;
import org.apache.spark.ml.param.shared.HasLabelCol;
import org.apache.spark.ml.param.shared.HasPredictionCol;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Field;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segment;
import org.dmg.pmml.mining.Segmentation;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.Schema;

/* loaded from: input_file:org/jpmml/sparkml/ModelConverter.class */
/**
 * Base converter for fitted Spark ML {@code Model}s that expose a features column and a
 * prediction column.
 *
 * <p>Subclasses declare the PMML mining function and encode the model body. This class derives
 * the label/feature {@link Schema} from the encoder state. It also prepends any extra output
 * fields to the {@link Output} of the produced PMML model.
 *
 * @param <T> the Spark ML model type being converted
 */
public abstract class ModelConverter<T extends Model<T> & HasFeaturesCol & HasPredictionCol> extends TransformerConverter<T> {

    public ModelConverter(T t) {
        super(t);
    }

    /**
     * Returns the PMML mining function of the converted model.
     *
     * <p>Only {@link MiningFunction#CLASSIFICATION} and {@link MiningFunction#REGRESSION} are
     * accepted by {@link #encodeSchema(SparkMLEncoder)} when the model has a label column.
     */
    public abstract MiningFunction getMiningFunction();

    /**
     * Encodes the PMML model body for the given schema.
     *
     * <p>NOTE(review): the name was assigned by the decompiler. Subclasses presumably override it
     * under this name, so rename only together with them.
     */
    /* renamed from: encodeModel */
    public abstract org.dmg.pmml.Model mo7encodeModel(Schema schema);

    /**
     * Builds the schema (label plus features) for the converted model.
     *
     * <p>If the model has a label column, its feature is re-encoded as a PMML label according to
     * {@link #getMiningFunction()}. Otherwise the schema label is {@code null}.
     *
     * @param sparkMLEncoder encoder holding the features registered for the pipeline columns
     * @return the label/feature schema
     * @throws IllegalArgumentException if the label cannot be encoded for the mining function,
     *     if a classification model's class count disagrees with its label categories, or if the
     *     number of features differs from the number the model was fitted on
     */
    public Schema encodeSchema(SparkMLEncoder sparkMLEncoder) {
        T model = getTransformer();

        Label label = null;
        if (model instanceof HasLabelCol) {
            String labelCol = ((HasLabelCol) model).getLabelCol();
            label = encodeLabel(labelCol, model, sparkMLEncoder);
        }

        if (model instanceof ClassificationModel) {
            int numClasses = ((ClassificationModel<?, ?>) model).numClasses();
            // Fail with an explanatory message instead of an NPE/ClassCastException when the
            // label is missing or was not encoded as categorical.
            if (!(label instanceof CategoricalLabel)) {
                throw new IllegalArgumentException("Expected a categorical label, got " + label);
            }
            int numCategories = ((CategoricalLabel) label).size();
            if (numClasses != numCategories) {
                throw new IllegalArgumentException("Expected " + numClasses + " target categories, got " + numCategories + " target categories");
            }
        }

        List<Feature> features = sparkMLEncoder.getFeatures(model.getFeaturesCol());

        if (model instanceof PredictionModel) {
            int numFeatures = ((PredictionModel<?, ?>) model).numFeatures();
            // Spark reports -1 when the feature count is unknown; the check is skipped then.
            if (numFeatures != -1 && features.size() != numFeatures) {
                throw new IllegalArgumentException("Expected " + numFeatures + " features, got " + features.size() + " features");
            }
        }

        return new Schema(label, features);
    }

    /**
     * Encodes the feature registered for {@code labelCol} as a label matching the mining function.
     */
    private Label encodeLabel(String labelCol, T model, SparkMLEncoder sparkMLEncoder) {
        Feature feature = sparkMLEncoder.getOnlyFeature(labelCol);

        MiningFunction miningFunction = getMiningFunction();
        switch (miningFunction) {
            case CLASSIFICATION:
                return encodeCategoricalLabel(labelCol, feature, model, sparkMLEncoder);
            case REGRESSION: {
                Field<?> field = sparkMLEncoder.toContinuous(feature.getName());
                field.setDataType(DataType.DOUBLE);
                return new ContinuousLabel(field.getName(), field.getDataType());
            }
            default:
                throw new IllegalArgumentException("Mining function " + miningFunction + " is not supported");
        }
    }

    /**
     * Encodes a classification label.
     *
     * <p>A categorical feature is used as-is. A continuous feature is treated as class indices
     * "0".."numClasses-1" and converted to categorical; the encoder's feature for the column is
     * replaced accordingly.
     */
    private CategoricalLabel encodeCategoricalLabel(String labelCol, Feature feature, T model, SparkMLEncoder sparkMLEncoder) {
        if (feature instanceof CategoricalFeature) {
            return new CategoricalLabel(sparkMLEncoder.getDataField(feature.getName()));
        }

        if (!(feature instanceof ContinuousFeature)) {
            throw new IllegalArgumentException("Expected a categorical or categorical-like continuous feature, got " + feature);
        }

        // Without a ClassificationModel to ask, assume a binary label.
        int numClasses = (model instanceof ClassificationModel) ? ((ClassificationModel<?, ?>) model).numClasses() : 2;

        List<String> categories = new ArrayList<>();
        for (int i = 0; i < numClasses; i++) {
            categories.add(String.valueOf(i));
        }

        Field<?> field = sparkMLEncoder.toCategorical(feature.getName(), categories);
        sparkMLEncoder.putOnlyFeature(labelCol, new CategoricalFeature(sparkMLEncoder, field, categories));

        return new CategoricalLabel(field.getName(), field.getDataType(), categories);
    }

    /**
     * Hook for subclasses to contribute additional output fields for the given label.
     *
     * @return the output fields to prepend to the model's {@link Output}, or {@code null} (the
     *     default) when there are none
     */
    public List<OutputField> registerOutputFields(Label label, SparkMLEncoder sparkMLEncoder) {
        return null;
    }

    /**
     * Encodes the schema and the PMML model, then attaches the output fields.
     *
     * <p>Fields from {@link #registerOutputFields(Label, SparkMLEncoder)} are inserted ahead of
     * any existing output fields of the last model (see {@link #getLastModel}).
     */
    public org.dmg.pmml.Model registerModel(SparkMLEncoder sparkMLEncoder) {
        Schema schema = encodeSchema(sparkMLEncoder);
        Label label = schema.getLabel();

        org.dmg.pmml.Model pmmlModel = mo7encodeModel(schema);

        List<OutputField> outputFields = registerOutputFields(label, sparkMLEncoder);
        if (outputFields != null && !outputFields.isEmpty()) {
            org.dmg.pmml.Model lastModel = getLastModel(pmmlModel);

            Output output = lastModel.getOutput();
            if (output == null) {
                output = new Output();
                lastModel.setOutput(output);
            }

            output.getOutputFields().addAll(0, outputFields);
        }

        return pmmlModel;
    }

    /**
     * Returns the model that should receive output fields.
     *
     * <p>For a {@link MiningModel} with a model-chain segmentation, this is the model of the last
     * segment; in every other case it is {@code model} itself.
     */
    protected org.dmg.pmml.Model getLastModel(org.dmg.pmml.Model model) {
        if (model instanceof MiningModel) {
            Segmentation segmentation = ((MiningModel) model).getSegmentation();

            if (segmentation.getMultipleModelMethod() == Segmentation.MultipleModelMethod.MODEL_CHAIN) {
                List<Segment> segments = segmentation.getSegments();
                if (!segments.isEmpty()) {
                    return segments.get(segments.size() - 1).getModel();
                }
            }
        }

        return model;
    }
}
