package org.jpmml.sparkml.model;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.spark.ml.classification.DecisionTreeClassificationModel;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import org.apache.spark.ml.tree.CategoricalSplit;
import org.apache.spark.ml.tree.ContinuousSplit;
import org.apache.spark.ml.tree.DecisionTreeModel;
import org.apache.spark.ml.tree.InternalNode;
import org.apache.spark.ml.tree.LeafNode;
import org.apache.spark.ml.tree.Node;
import org.apache.spark.ml.tree.Split;
import org.apache.spark.ml.tree.TreeEnsembleModel;
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator;
import org.dmg.pmml.Array;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.ScoreDistribution;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.SimpleSetPredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.BooleanFeature;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;

/**
 * Static helpers that convert Spark ML decision trees and tree ensembles into PMML {@link TreeModel}
 * elements.
 *
 * <p>Every produced tree model uses {@link TreeModel.SplitCharacteristic#BINARY_SPLIT}. Each Spark
 * internal node becomes a PMML node with exactly two children. The children's predicates are derived
 * from the Spark {@link Split}: index 0 selects the left child and index 1 the right child.
 */
public class TreeModelUtil {

    /** Single-category set {1.0}, compared against the category arrays of a binary feature's {@link CategoricalSplit}. */
    private static final double[] TRUE = {1.0d};

    /** Single-category set {0.0}, compared against the category arrays of a binary feature's {@link CategoricalSplit}. */
    private static final double[] FALSE = {0.0d};

    private TreeModelUtil() {
        // Static utility class; not instantiable.
    }

    /**
     * Encodes a single Spark decision tree as a PMML tree model.
     *
     * @param decisionTreeModel a {@link DecisionTreeRegressionModel} or a {@link DecisionTreeClassificationModel}
     * @param schema the schema whose features are addressed by the tree's split feature indices
     * @return the encoded tree model (regression or classification, matching the Spark model type)
     * @throws IllegalArgumentException if the model is neither a regression nor a classification tree
     */
    public static TreeModel encodeDecisionTree(DecisionTreeModel decisionTreeModel, Schema schema) {
        Node rootNode = decisionTreeModel.rootNode();

        if (decisionTreeModel instanceof DecisionTreeRegressionModel) {
            return encodeTreeModel(MiningFunction.REGRESSION, rootNode, schema);
        }
        if (decisionTreeModel instanceof DecisionTreeClassificationModel) {
            return encodeTreeModel(MiningFunction.CLASSIFICATION, rootNode, schema);
        }

        throw new IllegalArgumentException("Unsupported decision tree model type: " + typeName(decisionTreeModel));
    }

    /**
     * Encodes every member tree of a Spark tree ensemble, in the order returned by
     * {@code treeEnsembleModel.trees()}.
     *
     * @param treeEnsembleModel the ensemble whose trees are encoded
     * @param schema the schema; each tree is encoded against {@code schema.toAnonymousSchema()}
     * @return one PMML tree model per member tree
     * @throws IllegalArgumentException if a member tree is of an unsupported type
     */
    public static List<TreeModel> encodeDecisionTreeEnsemble(TreeEnsembleModel<?> treeEnsembleModel, Schema schema) {
        Schema anonymousSchema = schema.toAnonymousSchema();

        List<TreeModel> treeModels = new ArrayList<>();
        for (DecisionTreeModel decisionTreeModel : treeEnsembleModel.trees()) {
            treeModels.add(encodeDecisionTree(decisionTreeModel, anonymousSchema));
        }
        return treeModels;
    }

    /**
     * Wraps the encoded node hierarchy rooted at {@code node} into a binary-split PMML tree model.
     *
     * @param miningFunction {@link MiningFunction#REGRESSION} or {@link MiningFunction#CLASSIFICATION}
     * @param node the Spark root node
     * @param schema the schema used for the mining schema and for resolving split features
     * @return the PMML tree model
     * @throws UnsupportedOperationException if {@code miningFunction} is neither regression nor classification
     */
    public static TreeModel encodeTreeModel(MiningFunction miningFunction, Node node, Schema schema) {
        // The mining schema is built before the nodes are encoded, as in the original argument order.
        return new TreeModel(
                miningFunction,
                ModelUtil.createMiningSchema(schema),
                // The root node is unconditional; every other node gets its predicate from its parent's split.
                encodeNode(miningFunction, node, schema).setPredicate(new True()))
            .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);
    }

    /**
     * Recursively encodes a Spark tree node (and, for internal nodes, its whole subtree).
     * The returned node has no predicate set; the caller assigns it.
     *
     * @param miningFunction the mining function that determines how node scores are encoded
     * @param node an {@link InternalNode} or a {@link LeafNode}
     * @param schema the schema used to resolve split features and class labels
     * @return the encoded PMML node
     * @throws IllegalArgumentException if {@code node} is neither an internal nor a leaf node
     */
    public static org.dmg.pmml.tree.Node encodeNode(MiningFunction miningFunction, Node node, Schema schema) {
        if (node instanceof InternalNode) {
            return encodeInternalNode(miningFunction, (InternalNode) node, schema);
        }
        if (node instanceof LeafNode) {
            return encodeLeafNode(miningFunction, (LeafNode) node, schema);
        }

        throw new IllegalArgumentException("Unsupported tree node type: " + typeName(node));
    }

    /** Encodes an internal node plus its two children, attaching the split predicates to the children. */
    private static org.dmg.pmml.tree.Node encodeInternalNode(MiningFunction miningFunction, InternalNode internalNode, Schema schema) {
        org.dmg.pmml.tree.Node result = createNode(miningFunction, internalNode, schema);

        // predicates[0] selects the left child, predicates[1] the right child.
        Predicate[] predicates = encodeSplit(internalNode.split(), schema);

        org.dmg.pmml.tree.Node leftChild = encodeNode(miningFunction, internalNode.leftChild(), schema)
            .setPredicate(predicates[0]);
        org.dmg.pmml.tree.Node rightChild = encodeNode(miningFunction, internalNode.rightChild(), schema)
            .setPredicate(predicates[1]);

        result.addNodes(new org.dmg.pmml.tree.Node[]{leftChild, rightChild});
        return result;
    }

    /** Encodes a leaf node; leaves carry only the score information produced by {@link #createNode}. */
    private static org.dmg.pmml.tree.Node encodeLeafNode(MiningFunction miningFunction, LeafNode leafNode, Schema schema) {
        return createNode(miningFunction, leafNode, schema);
    }

    /**
     * Creates a PMML node carrying the prediction of a Spark node.
     *
     * <p>For regression, the score is the formatted {@code node.prediction()}. For classification:
     * <ul>
     *   <li>the score is the label value at the predicted class index;</li>
     *   <li>the record count is {@code impurityStats().count()};</li>
     *   <li>one score distribution is added per class whose impurity statistic is non-zero.</li>
     * </ul>
     *
     * @throws UnsupportedOperationException for any other mining function
     */
    private static org.dmg.pmml.tree.Node createNode(MiningFunction miningFunction, Node node, Schema schema) {
        org.dmg.pmml.tree.Node result = new org.dmg.pmml.tree.Node();

        switch (miningFunction) {
            case REGRESSION:
                result.setScore(ValueUtil.formatValue(Double.valueOf(node.prediction())));
                break;
            case CLASSIFICATION: {
                CategoricalLabel label = (CategoricalLabel) schema.getLabel();

                // Spark predicts a class index; map it back to the label's value.
                result.setScore(label.getValue(ValueUtil.asInt(Double.valueOf(node.prediction()))));

                ImpurityCalculator impurityStats = node.impurityStats();
                result.setRecordCount(Double.valueOf(impurityStats.count()));

                double[] stats = impurityStats.stats();
                for (int classIndex = 0; classIndex < stats.length; classIndex++) {
                    // Classes with a zero statistic are omitted from the score distribution.
                    if (stats[classIndex] != 0.0d) {
                        result.addScoreDistributions(new ScoreDistribution[]{new ScoreDistribution(label.getValue(classIndex), stats[classIndex])});
                    }
                }
                break;
            }
            default:
                throw new UnsupportedOperationException("Unsupported mining function: " + miningFunction);
        }

        return result;
    }

    /**
     * Encodes a Spark split as a pair of complementary predicates {left, right}.
     *
     * @throws IllegalArgumentException if the split type or the referenced feature is unsupported
     */
    private static Predicate[] encodeSplit(Split split, Schema schema) {
        if (split instanceof ContinuousSplit) {
            return encodeContinuousSplit((ContinuousSplit) split, schema);
        }
        if (split instanceof CategoricalSplit) {
            return encodeCategoricalSplit((CategoricalSplit) split, schema);
        }

        throw new IllegalArgumentException("Unsupported split type: " + typeName(split));
    }

    /**
     * Encodes a threshold split. The left branch receives {@code value <= threshold} and the right
     * branch {@code value > threshold}.
     *
     * <p>For a {@link BooleanFeature} the split is instead expressed as equality tests against the
     * feature's two values. This assumes the numeric encoding 0 for value index 0 and 1 for value
     * index 1, which is what the left = {@code getValue(0)} mapping implies. Under that encoding any
     * threshold in [0, 1) separates the two values identically. This covers both a threshold of 0.0
     * and midpoint thresholds such as 0.5.
     *
     * @throws IllegalArgumentException if a boolean feature's threshold lies outside [0, 1)
     */
    private static Predicate[] encodeContinuousSplit(ContinuousSplit continuousSplit, Schema schema) {
        int featureIndex = continuousSplit.featureIndex();
        double threshold = continuousSplit.threshold();

        Object feature = schema.getFeature(featureIndex);
        if (feature instanceof BooleanFeature) {
            BooleanFeature booleanFeature = (BooleanFeature) feature;

            // The negated form also rejects NaN.
            if (!(threshold >= 0.0d && threshold < 1.0d)) {
                throw new IllegalArgumentException("Invalid split threshold " + threshold + " for boolean feature " + booleanFeature.getName());
            }

            return new Predicate[]{
                new SimplePredicate(booleanFeature.getName(), SimplePredicate.Operator.EQUAL).setValue(booleanFeature.getValue(0)),
                new SimplePredicate(booleanFeature.getName(), SimplePredicate.Operator.EQUAL).setValue(booleanFeature.getValue(1))
            };
        }

        ContinuousFeature continuousFeature = schema.getFeature(featureIndex).toContinuousFeature();
        String value = ValueUtil.formatValue(Double.valueOf(threshold));

        return new Predicate[]{
            new SimplePredicate(continuousFeature.getName(), SimplePredicate.Operator.LESS_OR_EQUAL).setValue(value),
            new SimplePredicate(continuousFeature.getName(), SimplePredicate.Operator.GREATER_THAN).setValue(value)
        };
    }

    /**
     * Encodes a category-membership split on a {@link BinaryFeature} or a {@link CategoricalFeature}.
     *
     * @throws IllegalArgumentException if the feature has another type, if a binary feature's categories
     *     are not exactly {1.0}/{0.0} (in either order), or if a categorical feature's value count
     *     differs from the total number of split categories
     */
    private static Predicate[] encodeCategoricalSplit(CategoricalSplit categoricalSplit, Schema schema) {
        Object feature = schema.getFeature(categoricalSplit.featureIndex());
        double[] leftCategories = categoricalSplit.leftCategories();
        double[] rightCategories = categoricalSplit.rightCategories();

        if (feature instanceof BinaryFeature) {
            return encodeBinarySplit((BinaryFeature) feature, leftCategories, rightCategories);
        }

        if (feature instanceof CategoricalFeature) {
            CategoricalFeature categoricalFeature = (CategoricalFeature) feature;

            // Every category of the feature must be routed to exactly one side of the split.
            int valueCount = categoricalFeature.getValues().size();
            int categoryCount = leftCategories.length + rightCategories.length;
            if (valueCount != categoryCount) {
                throw new IllegalArgumentException("Categorical feature " + categoricalFeature.getName() + " has " + valueCount + " values, but the split has " + categoryCount + " categories");
            }

            return new Predicate[]{
                createCategoricalPredicate(categoricalFeature, leftCategories),
                createCategoricalPredicate(categoricalFeature, rightCategories)
            };
        }

        throw new IllegalArgumentException("Expected a binary or categorical feature, got " + typeName(feature));
    }

    /**
     * Encodes a split on a binary feature. Category 1.0 maps to "equals the feature value" and
     * category 0.0 to "not equals the feature value".
     */
    private static Predicate[] encodeBinarySplit(BinaryFeature binaryFeature, double[] leftCategories, double[] rightCategories) {
        SimplePredicate.Operator leftOperator;
        SimplePredicate.Operator rightOperator;

        if (Arrays.equals(TRUE, leftCategories) && Arrays.equals(FALSE, rightCategories)) {
            leftOperator = SimplePredicate.Operator.EQUAL;
            rightOperator = SimplePredicate.Operator.NOT_EQUAL;
        } else if (Arrays.equals(FALSE, leftCategories) && Arrays.equals(TRUE, rightCategories)) {
            leftOperator = SimplePredicate.Operator.NOT_EQUAL;
            rightOperator = SimplePredicate.Operator.EQUAL;
        } else {
            throw new IllegalArgumentException("Invalid categories " + Arrays.toString(leftCategories) + " / " + Arrays.toString(rightCategories) + " for binary feature " + binaryFeature.getName());
        }

        String value = ValueUtil.formatValue(binaryFeature.getValue());

        return new Predicate[]{
            new SimplePredicate(binaryFeature.getName(), leftOperator).setValue(value),
            new SimplePredicate(binaryFeature.getName(), rightOperator).setValue(value)
        };
    }

    /**
     * Builds a predicate matching the given category indices of a categorical feature.
     * A single category becomes an equality test; several categories become an {@code isIn} set test.
     */
    private static Predicate createCategoricalPredicate(CategoricalFeature categoricalFeature, double[] categories) {
        List<String> values = new ArrayList<>(categories.length);
        for (double category : categories) {
            values.add(categoricalFeature.getValue(ValueUtil.asInt(Double.valueOf(category))));
        }

        if (values.size() == 1) {
            return new SimplePredicate(categoricalFeature.getName(), SimplePredicate.Operator.EQUAL)
                .setValue(values.get(0));
        }

        // The elements are String category values, so the PMML array must be of string type
        // (an int-typed array is invalid for non-numeric categories).
        return new SimpleSetPredicate()
            .setField(categoricalFeature.getName())
            .setBooleanOperator(SimpleSetPredicate.BooleanOperator.IS_IN)
            .setArray(new Array(Array.Type.STRING, ValueUtil.formatArrayValue(values)));
    }

    /** Null-safe runtime class name, used in exception messages. */
    private static String typeName(Object object) {
        return (object != null) ? object.getClass().getName() : "null";
    }
}
