package org.jpmml.sparkml;

import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.classification.ClassificationModel;
import org.apache.spark.ml.classification.GBTClassificationModel;
import org.apache.spark.ml.param.shared.HasOutputCol;
import org.apache.spark.sql.types.IntegerType;
import org.apache.spark.sql.types.NumericType;
import org.apache.spark.sql.types.StringType;
import org.apache.spark.sql.types.StructType;
import org.dmg.pmml.DataDictionary;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.TransformationDictionary;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FeatureSchema;
import org.jpmml.converter.ListFeature;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.PseudoFeature;

/**
 * Accumulates the PMML fields and per-column features produced while converting a Spark ML pipeline,
 * and assembles them into a {@link PMML} document.
 *
 * <p>A column that was registered as the output column of an appended {@link FeatureConverter} resolves
 * to the features that converter encoded. Any other column resolves to a {@link DataField} whose type is
 * derived from the Spark input schema.
 */
public class FeatureMapper {

    /** Spark schema of the input data; used to type columns that no converter has produced. */
    private final StructType schema;

    /** Features registered per output column name, in registration order. */
    private final Map<String, List<Feature>> columnFeatures = new LinkedHashMap<String, List<Feature>>();

    /** Data fields (model inputs and target), in creation order. */
    private final Map<FieldName, DataField> dataFields = new LinkedHashMap<FieldName, DataField>();

    /** Derived fields created by converters, in creation order. */
    private final Map<FieldName, DerivedField> derivedFields = new LinkedHashMap<FieldName, DerivedField>();

    /**
     * @param schema Spark schema of the input data set
     */
    public FeatureMapper(StructType schema) {
        this.schema = schema;
    }

    /**
     * Builds a PMML 4.2 document containing all data fields and, if any were created, all derived fields.
     *
     * @return a new PMML document; its transformation dictionary is {@code null} when there are no derived fields
     * @throws IllegalArgumentException if a data field and a derived field share the same name
     */
    public PMML encodePMML() {
        if (!Collections.disjoint(this.dataFields.keySet(), this.derivedFields.keySet())) {
            throw new IllegalArgumentException("Data fields and derived fields must have distinct names");
        }

        DataDictionary dataDictionary = new DataDictionary();
        dataDictionary.getDataFields().addAll(this.dataFields.values());

        TransformationDictionary transformationDictionary = null;
        if (!this.derivedFields.isEmpty()) {
            transformationDictionary = new TransformationDictionary();
            transformationDictionary.getDerivedFields().addAll(this.derivedFields.values());
        }

        return new PMML("4.2", PMMLUtil.createHeader("JPMML-SparkML", "1.0-SNAPSHOT"), dataDictionary)
            .setTransformationDictionary(transformationDictionary);
    }

    /**
     * Runs a converter against this mapper. If its transformer declares an output column, the encoded
     * features are registered under that column name.
     *
     * @param featureConverter converter for one pipeline stage
     */
    public void append(FeatureConverter<?> featureConverter) {
        Object transformer = featureConverter.getTransformer();

        // encodeFeatures receives this mapper, so it can look up existing columns and create fields on it.
        List<Feature> features = featureConverter.encodeFeatures(this);

        if (transformer instanceof HasOutputCol) {
            this.columnFeatures.put(((HasOutputCol) transformer).getOutputCol(), features);
        }
    }

    /**
     * Describes a prediction model's target, active fields and input features.
     *
     * <p>For classification models the label column must resolve to a single {@link ListFeature}. Its
     * values become the target categories. For other models the target categories are {@code null}.
     *
     * @param predictionModel fitted Spark prediction model
     * @return the feature schema for the model
     * @throws IllegalArgumentException if a classification label is not a {@link ListFeature}, or if the
     *         number of resolved features differs from {@code predictionModel.numFeatures()}
     */
    public FeatureSchema createSchema(PredictionModel<?, ?> predictionModel) {
        String labelCol = predictionModel.getLabelCol();

        FieldName targetField;
        List<String> targetCategories = null;

        // NOTE(review): GBTClassificationModel is tested separately, presumably because it does not extend
        // ClassificationModel in the targeted Spark version - confirm.
        if ((predictionModel instanceof ClassificationModel) || (predictionModel instanceof GBTClassificationModel)) {
            Feature labelFeature = getOnlyFeature(labelCol);
            if (!(labelFeature instanceof ListFeature)) {
                throw new IllegalArgumentException("Expected a categorical label feature for column " + labelCol
                    + " of a classification model, got " + labelFeature.getClass().getName());
            }

            ListFeature listFeature = (ListFeature) labelFeature;
            targetField = listFeature.getName();
            targetCategories = listFeature.getValues();
        } else {
            targetField = getOnlyFeature(labelCol).getName();
        }

        // Every data field created so far, except the target, is treated as an active field.
        List<FieldName> activeFields = new ArrayList<FieldName>(this.dataFields.keySet());
        activeFields.remove(targetField);

        List<Feature> features = getFeatures(predictionModel.getFeaturesCol());
        if (features.size() != predictionModel.numFeatures()) {
            throw new IllegalArgumentException("Expected " + predictionModel.numFeatures()
                + " features, got " + features.size());
        }

        return new FeatureSchema(targetField, targetCategories, activeFields, features);
    }

    /**
     * Resolves a column that must map to exactly one feature.
     *
     * @param column column name
     * @return the single feature for the column
     * @throws IllegalArgumentException if the column resolves to more than one feature (Guava contract)
     * @throws java.util.NoSuchElementException if the column resolves to no features (Guava contract)
     */
    public Feature getOnlyFeature(String column) {
        return Iterables.getOnlyElement(getFeatures(column));
    }

    /**
     * Resolves a column to its features.
     *
     * <p>Features registered through {@link #append(FeatureConverter)} take precedence. Otherwise the
     * column's data field is looked up, or created from the Spark schema, and wrapped as a single
     * {@link ContinuousFeature} (INTEGER/FLOAT/DOUBLE) or {@link PseudoFeature} (any other data type).
     *
     * @param column column name
     * @return the features for the column
     */
    public List<Feature> getFeatures(String column) {
        List<Feature> registered = this.columnFeatures.get(column);
        if (registered != null) {
            return registered;
        }

        FieldName name = FieldName.create(column);
        DataField dataField = this.dataFields.get(name);
        if (dataField == null) {
            dataField = createDataField(name);
        }

        Feature feature;
        switch (dataField.getDataType()) {
            case INTEGER:
            case FLOAT:
            case DOUBLE:
                feature = new ContinuousFeature(dataField);
                break;
            default:
                feature = new PseudoFeature(dataField);
                break;
        }

        return Collections.<Feature>singletonList(feature);
    }

    /**
     * @param name field name
     * @return the data field with that name, or {@code null} if none has been created
     */
    public DataField getDataField(FieldName name) {
        return this.dataFields.get(name);
    }

    /**
     * Creates and registers a data field typed from the Spark schema column of the same name.
     *
     * <p>Numeric columns become continuous fields (INTEGER for {@link IntegerType}, DOUBLE for other
     * numeric types). String columns become categorical STRING fields.
     *
     * @param name field name; must match a column in the Spark schema
     * @return the newly registered data field
     * @throws IllegalArgumentException if the column's Spark type is neither numeric nor string
     */
    public DataField createDataField(FieldName name) {
        org.apache.spark.sql.types.DataType sparkDataType = this.schema.apply(name.getValue()).dataType();

        if (sparkDataType instanceof NumericType) {
            DataType dataType = (sparkDataType instanceof IntegerType) ? DataType.INTEGER : DataType.DOUBLE;

            return createDataField(name, OpType.CONTINUOUS, dataType);
        } else if (sparkDataType instanceof StringType) {
            return createDataField(name, OpType.CATEGORICAL, DataType.STRING);
        }

        throw new IllegalArgumentException("Column " + name.getValue() + " has unsupported Spark type "
            + sparkDataType + "; expected a numeric or string type");
    }

    /**
     * Creates and registers a data field, replacing any existing data field with the same name.
     *
     * @return the newly registered data field
     */
    public DataField createDataField(FieldName name, OpType opType, DataType dataType) {
        DataField dataField = new DataField(name, opType, dataType);

        this.dataFields.put(dataField.getName(), dataField);

        return dataField;
    }

    /**
     * @param name field name
     * @return the derived field with that name, or {@code null} if none has been created
     */
    public DerivedField getDerivedField(FieldName name) {
        return this.derivedFields.get(name);
    }

    /**
     * Creates and registers a derived field, replacing any existing derived field with the same name.
     *
     * @return the newly registered derived field
     */
    public DerivedField createDerivedField(FieldName name, OpType opType, DataType dataType, Expression expression) {
        DerivedField derivedField = new DerivedField(opType, dataType)
            .setName(name)
            .setExpression(expression);

        this.derivedFields.put(derivedField.getName(), derivedField);

        return derivedField;
    }
}
