package org.jpmml.sparkml.model;

import com.google.common.base.Function;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.spark.ml.classification.DecisionTreeClassificationModel;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import org.apache.spark.ml.tree.CategoricalSplit;
import org.apache.spark.ml.tree.ContinuousSplit;
import org.apache.spark.ml.tree.DecisionTreeModel;
import org.apache.spark.ml.tree.InternalNode;
import org.apache.spark.ml.tree.LeafNode;
import org.apache.spark.ml.tree.Node;
import org.apache.spark.ml.tree.Split;
import org.apache.spark.ml.tree.TreeEnsembleModel;
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator;
import org.dmg.pmml.Array;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.ScoreDistribution;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.SimpleSetPredicate;
import org.dmg.pmml.TreeModel;
import org.dmg.pmml.True;
import org.dmg.pmml.VisitorAction;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FeatureSchema;
import org.jpmml.converter.ListFeature;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.ValueUtil;
import org.jpmml.model.visitors.AbstractVisitor;

/* loaded from: input_file:org/jpmml/sparkml/model/TreeModelUtil.class */
public class TreeModelUtil {
    private static final double[] TRUE = {1.0d};
    private static final double[] FALSE = {0.0d};

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: org.jpmml.sparkml.model.TreeModelUtil$3, reason: invalid class name */
    /* loaded from: input_file:org/jpmml/sparkml/model/TreeModelUtil$3.class */
    public static /* synthetic */ class AnonymousClass3 {
        static final /* synthetic */ int[] $SwitchMap$org$dmg$pmml$MiningFunctionType = new int[MiningFunctionType.values().length];

        static {
            try {
                $SwitchMap$org$dmg$pmml$MiningFunctionType[MiningFunctionType.REGRESSION.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$org$dmg$pmml$MiningFunctionType[MiningFunctionType.CLASSIFICATION.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
        }
    }

    private TreeModelUtil() {
    }

    public static TreeModel encodeDecisionTree(DecisionTreeModel decisionTreeModel, FeatureSchema featureSchema) {
        Node rootNode = decisionTreeModel.rootNode();
        if (decisionTreeModel instanceof DecisionTreeRegressionModel) {
            return encodeTreeModel(MiningFunctionType.REGRESSION, rootNode, featureSchema);
        }
        if (decisionTreeModel instanceof DecisionTreeClassificationModel) {
            return encodeTreeModel(MiningFunctionType.CLASSIFICATION, rootNode, featureSchema);
        }
        throw new IllegalArgumentException();
    }

    public static List<TreeModel> encodeDecisionTreeEnsemble(TreeEnsembleModel treeEnsembleModel, final FeatureSchema featureSchema) {
        return new ArrayList(Lists.transform(Arrays.asList(treeEnsembleModel.trees()), new Function<DecisionTreeModel, TreeModel>() { // from class: org.jpmml.sparkml.model.TreeModelUtil.1
            private FeatureSchema segmentSchema;

            {
                this.segmentSchema = new FeatureSchema((FieldName) null, featureSchema.getTargetCategories(), featureSchema.getActiveFields(), featureSchema.getFeatures());
            }

            public TreeModel apply(DecisionTreeModel decisionTreeModel) {
                return TreeModelUtil.encodeDecisionTree(decisionTreeModel, this.segmentSchema);
            }
        }));
    }

    public static TreeModel encodeTreeModel(MiningFunctionType miningFunctionType, Node node, FeatureSchema featureSchema) {
        org.dmg.pmml.Node predicate = encodeNode(miningFunctionType, node, featureSchema).setPredicate(new True());
        return new TreeModel(miningFunctionType, ModelUtil.createMiningSchema(featureSchema, predicate), predicate).setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);
    }

    public static void scalePredictions(TreeModel treeModel, final double d) {
        if (ValueUtil.isOne(Double.valueOf(d))) {
            return;
        }
        new AbstractVisitor() { // from class: org.jpmml.sparkml.model.TreeModelUtil.2
            public VisitorAction visit(org.dmg.pmml.Node node) {
                node.setScore(ValueUtil.formatValue(Double.valueOf(Double.parseDouble(node.getScore()) * d)));
                return super.visit(node);
            }
        }.applyTo(treeModel);
    }

    public static org.dmg.pmml.Node encodeNode(MiningFunctionType miningFunctionType, Node node, FeatureSchema featureSchema) {
        if (node instanceof InternalNode) {
            return encodeInternalNode(miningFunctionType, (InternalNode) node, featureSchema);
        }
        if (node instanceof LeafNode) {
            return encodeLeafNode(miningFunctionType, (LeafNode) node, featureSchema);
        }
        throw new IllegalArgumentException();
    }

    private static org.dmg.pmml.Node encodeInternalNode(MiningFunctionType miningFunctionType, InternalNode internalNode, FeatureSchema featureSchema) {
        org.dmg.pmml.Node createNode = createNode(miningFunctionType, internalNode, featureSchema);
        Predicate[] encodeSplit = encodeSplit(internalNode.split(), featureSchema);
        createNode.addNodes(new org.dmg.pmml.Node[]{encodeNode(miningFunctionType, internalNode.leftChild(), featureSchema).setPredicate(encodeSplit[0]), encodeNode(miningFunctionType, internalNode.rightChild(), featureSchema).setPredicate(encodeSplit[1])});
        return createNode;
    }

    private static org.dmg.pmml.Node encodeLeafNode(MiningFunctionType miningFunctionType, LeafNode leafNode, FeatureSchema featureSchema) {
        return createNode(miningFunctionType, leafNode, featureSchema);
    }

    private static org.dmg.pmml.Node createNode(MiningFunctionType miningFunctionType, Node node, FeatureSchema featureSchema) {
        org.dmg.pmml.Node node2 = new org.dmg.pmml.Node();
        switch (AnonymousClass3.$SwitchMap$org$dmg$pmml$MiningFunctionType[miningFunctionType.ordinal()]) {
            case 1:
                node2.setScore(ValueUtil.formatValue(Double.valueOf(node.prediction())));
                break;
            case 2:
                List targetCategories = featureSchema.getTargetCategories();
                if (targetCategories == null) {
                    throw new IllegalArgumentException();
                }
                node2.setScore((String) targetCategories.get(ValueUtil.asInt(Double.valueOf(node.prediction()))));
                ImpurityCalculator impurityStats = node.impurityStats();
                node2.setRecordCount(Double.valueOf(impurityStats.count()));
                double[] stats = impurityStats.stats();
                for (int i = 0; i < stats.length; i++) {
                    node2.addScoreDistributions(new ScoreDistribution[]{new ScoreDistribution((String) targetCategories.get(i), stats[i])});
                }
                break;
            default:
                throw new UnsupportedOperationException();
        }
        return node2;
    }

    private static Predicate[] encodeSplit(Split split, FeatureSchema featureSchema) {
        if (split instanceof ContinuousSplit) {
            return encodeContinuousSplit((ContinuousSplit) split, featureSchema);
        }
        if (split instanceof CategoricalSplit) {
            return encodeCategoricalSplit((CategoricalSplit) split, featureSchema);
        }
        throw new IllegalArgumentException();
    }

    private static Predicate[] encodeContinuousSplit(ContinuousSplit continuousSplit, FeatureSchema featureSchema) {
        ContinuousFeature feature = featureSchema.getFeature(continuousSplit.featureIndex());
        String formatValue = ValueUtil.formatValue(Double.valueOf(continuousSplit.threshold()));
        return new Predicate[]{new SimplePredicate(feature.getName(), SimplePredicate.Operator.LESS_OR_EQUAL).setValue(formatValue), new SimplePredicate(feature.getName(), SimplePredicate.Operator.GREATER_THAN).setValue(formatValue)};
    }

    private static Predicate[] encodeCategoricalSplit(CategoricalSplit categoricalSplit, FeatureSchema featureSchema) {
        SimplePredicate.Operator operator;
        SimplePredicate.Operator operator2;
        ListFeature feature = featureSchema.getFeature(categoricalSplit.featureIndex());
        double[] leftCategories = categoricalSplit.leftCategories();
        double[] rightCategories = categoricalSplit.rightCategories();
        if (feature instanceof ListFeature) {
            ListFeature listFeature = feature;
            if (listFeature.getValues().size() != leftCategories.length + rightCategories.length) {
                throw new IllegalArgumentException();
            }
            return new Predicate[]{createCategoricalPredicate(listFeature, leftCategories), createCategoricalPredicate(listFeature, rightCategories)};
        }
        if (!(feature instanceof BinaryFeature)) {
            throw new IllegalArgumentException();
        }
        BinaryFeature binaryFeature = (BinaryFeature) feature;
        if (Arrays.equals(TRUE, leftCategories) && Arrays.equals(FALSE, rightCategories)) {
            operator = SimplePredicate.Operator.EQUAL;
            operator2 = SimplePredicate.Operator.NOT_EQUAL;
        } else {
            if (!Arrays.equals(FALSE, leftCategories) || !Arrays.equals(TRUE, rightCategories)) {
                throw new IllegalArgumentException();
            }
            operator = SimplePredicate.Operator.NOT_EQUAL;
            operator2 = SimplePredicate.Operator.EQUAL;
        }
        String formatValue = ValueUtil.formatValue(binaryFeature.getValue());
        return new Predicate[]{new SimplePredicate(binaryFeature.getName(), operator).setValue(formatValue), new SimplePredicate(binaryFeature.getName(), operator2).setValue(formatValue)};
    }

    private static Predicate createCategoricalPredicate(ListFeature listFeature, double[] dArr) {
        ArrayList arrayList = new ArrayList();
        for (double d : dArr) {
            arrayList.add(listFeature.getValue(ValueUtil.asInt(Double.valueOf(d))));
        }
        if (arrayList.size() == 1) {
            return new SimplePredicate().setField(listFeature.getName()).setOperator(SimplePredicate.Operator.EQUAL).setValue((String) arrayList.get(0));
        }
        return new SimpleSetPredicate().setField(listFeature.getName()).setBooleanOperator(SimpleSetPredicate.BooleanOperator.IS_IN).setArray(new Array(Array.Type.INT, ValueUtil.formatArrayValue(arrayList)));
    }
}
