package org.jpmml.sparkml;

import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Set;
import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.classification.ClassificationModel;
import org.apache.spark.ml.classification.GBTClassificationModel;
import org.apache.spark.ml.clustering.KMeansModel;
import org.apache.spark.ml.param.shared.HasFeaturesCol;
import org.apache.spark.ml.param.shared.HasLabelCol;
import org.apache.spark.ml.param.shared.HasOutputCol;
import org.apache.spark.ml.param.shared.HasPredictionCol;
import org.apache.spark.sql.types.IntegerType;
import org.apache.spark.sql.types.NumericType;
import org.apache.spark.sql.types.StringType;
import org.apache.spark.sql.types.StructType;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldUsageType;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ListFeature;
import org.jpmml.converter.PMMLMapper;
import org.jpmml.converter.Schema;
import org.jpmml.converter.WildcardFeature;

/* loaded from: input_file:org/jpmml/sparkml/FeatureMapper.class */
public class FeatureMapper extends PMMLMapper {
    private StructType schema;
    private Map<String, List<Feature>> columnFeatures = new LinkedHashMap();

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: org.jpmml.sparkml.FeatureMapper$1, reason: invalid class name */
    /* loaded from: input_file:org/jpmml/sparkml/FeatureMapper$1.class */
    public static /* synthetic */ class AnonymousClass1 {
        static final /* synthetic */ int[] $SwitchMap$org$dmg$pmml$FieldUsageType;
        static final /* synthetic */ int[] $SwitchMap$org$dmg$pmml$DataType = new int[DataType.values().length];

        static {
            try {
                $SwitchMap$org$dmg$pmml$DataType[DataType.INTEGER.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$org$dmg$pmml$DataType[DataType.FLOAT.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            try {
                $SwitchMap$org$dmg$pmml$DataType[DataType.DOUBLE.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
            }
            $SwitchMap$org$dmg$pmml$FieldUsageType = new int[FieldUsageType.values().length];
            try {
                $SwitchMap$org$dmg$pmml$FieldUsageType[FieldUsageType.ACTIVE.ordinal()] = 1;
            } catch (NoSuchFieldError e4) {
            }
            try {
                $SwitchMap$org$dmg$pmml$FieldUsageType[FieldUsageType.PREDICTED.ordinal()] = 2;
            } catch (NoSuchFieldError e5) {
            }
            try {
                $SwitchMap$org$dmg$pmml$FieldUsageType[FieldUsageType.TARGET.ordinal()] = 3;
            } catch (NoSuchFieldError e6) {
            }
        }
    }

    public FeatureMapper(StructType structType) {
        this.schema = null;
        this.schema = structType;
    }

    public PMML encodePMML(Model model) {
        PMML encodePMML = super.encodePMML(model);
        HashSet hashSet = new HashSet();
        ListIterator listIterator = model.getMiningSchema().getMiningFields().listIterator();
        while (listIterator.hasNext()) {
            MiningField miningField = (MiningField) listIterator.next();
            switch (AnonymousClass1.$SwitchMap$org$dmg$pmml$FieldUsageType[miningField.getUsageType().ordinal()]) {
                case 1:
                case 2:
                case 3:
                    hashSet.add(miningField.getName());
                    break;
            }
        }
        ListIterator listIterator2 = encodePMML.getDataDictionary().getDataFields().listIterator();
        while (listIterator2.hasNext()) {
            if (!hashSet.contains(((DataField) listIterator2.next()).getName())) {
                listIterator2.remove();
            }
        }
        return encodePMML;
    }

    public void append(FeatureConverter<?> featureConverter) {
        Object transformer = featureConverter.getTransformer();
        List<Feature> encodeFeatures = featureConverter.encodeFeatures(this);
        if (transformer instanceof HasOutputCol) {
            putFeatures(((HasOutputCol) transformer).getOutputCol(), encodeFeatures);
        }
    }

    public void append(ModelConverter<?> modelConverter) {
        T transformer = modelConverter.getTransformer();
        List<Feature> encodeFeatures = modelConverter.encodeFeatures(this);
        if (transformer instanceof HasPredictionCol) {
            putFeatures(((HasPredictionCol) transformer).getPredictionCol(), encodeFeatures);
        }
    }

    public Schema createSchema(org.apache.spark.ml.Model<?> model) {
        FieldName fieldName;
        List list = null;
        if (model instanceof PredictionModel) {
            ListFeature onlyFeature = getOnlyFeature(((HasLabelCol) model).getLabelCol());
            fieldName = onlyFeature.getName();
            if ((model instanceof ClassificationModel) || (model instanceof GBTClassificationModel)) {
                list = onlyFeature.getValues();
            }
        } else {
            if (!(model instanceof KMeansModel)) {
                throw new IllegalArgumentException();
            }
            fieldName = null;
        }
        ArrayList arrayList = new ArrayList(getDataFields().keySet());
        arrayList.remove(fieldName);
        List<Feature> features = getFeatures(((HasFeaturesCol) model).getFeaturesCol());
        if (model instanceof PredictionModel) {
            PredictionModel predictionModel = (PredictionModel) model;
            if (features.size() != predictionModel.numFeatures()) {
                throw new IllegalArgumentException("Expected " + predictionModel.numFeatures() + " features, got " + features.size() + " features");
            }
        }
        return new Schema(fieldName, list, arrayList, features);
    }

    public boolean hasFeatures(String str) {
        return this.columnFeatures.containsKey(str);
    }

    public Feature getOnlyFeature(String str) {
        return (Feature) Iterables.getOnlyElement(getFeatures(str));
    }

    public List<Feature> getFeatures(String str) {
        ContinuousFeature wildcardFeature;
        List<Feature> list = this.columnFeatures.get(str);
        if (list != null) {
            return list;
        }
        FieldName create = FieldName.create(str);
        DataField dataField = getDataField(create);
        if (dataField == null) {
            dataField = createDataField(create);
        }
        switch (AnonymousClass1.$SwitchMap$org$dmg$pmml$DataType[dataField.getDataType().ordinal()]) {
            case 1:
            case 2:
            case 3:
                wildcardFeature = new ContinuousFeature(dataField);
                break;
            default:
                wildcardFeature = new WildcardFeature(dataField);
                break;
        }
        return Collections.singletonList(wildcardFeature);
    }

    public List<Feature> getFeatures(String str, int[] iArr) {
        List<Feature> features = getFeatures(str);
        ArrayList arrayList = new ArrayList();
        for (int i : iArr) {
            arrayList.add(features.get(i));
        }
        return arrayList;
    }

    public void putFeatures(String str, List<Feature> list) {
        checkColumn(str);
        this.columnFeatures.put(str, list);
    }

    public DataField createDataField(FieldName fieldName) {
        OpType opType;
        DataType dataType;
        org.apache.spark.sql.types.DataType dataType2 = this.schema.apply(fieldName.getValue()).dataType();
        if (dataType2 instanceof NumericType) {
            opType = OpType.CONTINUOUS;
            dataType = dataType2 instanceof IntegerType ? DataType.INTEGER : DataType.DOUBLE;
        } else {
            if (!(dataType2 instanceof StringType)) {
                throw new IllegalArgumentException();
            }
            opType = OpType.CATEGORICAL;
            dataType = DataType.STRING;
        }
        return createDataField(fieldName, opType, dataType);
    }

    public void removeDataField(FieldName fieldName) {
        if (((DataField) getDataFields().remove(fieldName)) == null) {
            throw new IllegalArgumentException();
        }
    }

    private void checkColumn(String str) {
        List<Feature> list = this.columnFeatures.get(str);
        if (list != null && list.size() > 0 && !(((Feature) Iterables.getOnlyElement(list)) instanceof WildcardFeature)) {
            throw new IllegalArgumentException(str);
        }
    }
}
