package org.jpmml.sparkml.model;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.classification.DecisionTreeClassificationModel;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import org.apache.spark.ml.tree.CategoricalSplit;
import org.apache.spark.ml.tree.ContinuousSplit;
import org.apache.spark.ml.tree.DecisionTreeModel;
import org.apache.spark.ml.tree.InternalNode;
import org.apache.spark.ml.tree.LeafNode;
import org.apache.spark.ml.tree.Node;
import org.apache.spark.ml.tree.Split;
import org.apache.spark.ml.tree.TreeEnsembleModel;
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.ScoreDistribution;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.BooleanFeature;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PredicateManager;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.sparkml.TreeModelOptions;
import org.jpmml.sparkml.visitors.TreeModelCompactor;

/* loaded from: input_file:org/jpmml/sparkml/model/TreeModelUtil.class */
/**
 * Converts Apache Spark ML decision trees and tree ensembles into PMML {@link TreeModel} elements.
 *
 * <p>Every Spark {@link InternalNode} is encoded as a PMML node with exactly two children (left
 * branch first, then right branch), which is why the generated models declare
 * {@code BINARY_SPLIT} as their split characteristic.
 */
public class TreeModelUtil {

    /** Category-index array {@code {1.0}}, matched against one side of a binary-feature split. */
    private static final double[] TRUE = {1.0d};

    /** Category-index array {@code {0.0}}, matched against one side of a binary-feature split. */
    private static final double[] FALSE = {0.0d};

    private TreeModelUtil() {
    }

    /**
     * Encodes every member tree of a Spark tree ensemble, using a fresh {@link PredicateManager}.
     *
     * @param model the ensemble model
     * @param schema the schema of the ensemble
     * @return one PMML tree model per member tree, in the order returned by {@code trees()}
     */
    public static <M extends Model<M> & TreeEnsembleModel<T>, T extends Model<T> & DecisionTreeModel> List<TreeModel> encodeDecisionTreeEnsemble(M model, Schema schema) {
        return encodeDecisionTreeEnsemble(model, new PredicateManager(), schema);
    }

    /**
     * Encodes every member tree of a Spark tree ensemble as a standalone PMML tree model.
     *
     * <p>Member trees are encoded against {@code schema.toAnonymousSchema()} and all share
     * {@code predicateManager}.
     *
     * @param model the ensemble model
     * @param predicateManager factory for the split predicates, shared by all member trees
     * @param schema the schema of the ensemble
     * @return one PMML tree model per member tree, in the order returned by {@code trees()}
     * @throws IllegalArgumentException if a member tree is neither a regression nor a
     *     classification tree, or contains an unsupported node/split/feature combination
     */
    public static <M extends Model<M> & TreeEnsembleModel<T>, T extends Model<T> & DecisionTreeModel> List<TreeModel> encodeDecisionTreeEnsemble(M model, PredicateManager predicateManager, Schema schema) {
        Schema segmentSchema = schema.toAnonymousSchema();

        // NOTE(review): the cast guards against trees() (declared by a Scala trait) not being
        // typed as T[] from the Java side; its elements are expected to be T.
        @SuppressWarnings("unchecked")
        T[] trees = (T[]) model.trees();

        List<TreeModel> treeModels = new ArrayList<>(trees.length);
        for (T tree : trees) {
            treeModels.add(encodeDecisionTree(tree, predicateManager, segmentSchema));
        }
        return treeModels;
    }

    /**
     * Encodes a single Spark decision tree, using a fresh {@link PredicateManager}.
     *
     * @param model a {@link DecisionTreeRegressionModel} or {@link DecisionTreeClassificationModel}
     * @param schema the schema of the tree
     * @return the PMML tree model
     */
    public static <M extends Model<M> & DecisionTreeModel> TreeModel encodeDecisionTree(M model, Schema schema) {
        return encodeDecisionTree(model, new PredicateManager(), schema);
    }

    /**
     * Encodes a single Spark decision tree.
     *
     * <p>The PMML mining function is chosen from the concrete model class: regression trees map
     * to {@link MiningFunction#REGRESSION}, classification trees to
     * {@link MiningFunction#CLASSIFICATION}.
     *
     * @param model a {@link DecisionTreeRegressionModel} or {@link DecisionTreeClassificationModel}
     * @param predicateManager factory for the split predicates
     * @param schema the schema of the tree
     * @return the PMML tree model
     * @throws IllegalArgumentException if {@code model} is of any other class
     */
    public static <M extends Model<M> & DecisionTreeModel> TreeModel encodeDecisionTree(M model, PredicateManager predicateManager, Schema schema) {
        Node rootNode = model.rootNode();

        MiningFunction miningFunction;
        if (model instanceof DecisionTreeRegressionModel) {
            miningFunction = MiningFunction.REGRESSION;
        } else if (model instanceof DecisionTreeClassificationModel) {
            miningFunction = MiningFunction.CLASSIFICATION;
        } else {
            throw new IllegalArgumentException("Unsupported decision tree model: " + className(model));
        }
        return encodeTreeModel(rootNode, predicateManager, miningFunction, schema);
    }

    /**
     * Wraps the PMML encoding of a Spark (sub)tree into a binary-split {@link TreeModel}.
     *
     * <p>The root node always matches ({@link True} predicate). When the
     * {@link TreeModelOptions#COMPACT} setting parses as {@code "true"}, the model is then
     * post-processed by {@link TreeModelCompactor}.
     *
     * @param node the Spark root node
     * @param predicateManager factory for the split predicates
     * @param miningFunction {@code REGRESSION} or {@code CLASSIFICATION}
     * @param schema the schema of the tree
     * @return the PMML tree model
     */
    public static TreeModel encodeTreeModel(Node node, PredicateManager predicateManager, MiningFunction miningFunction, Schema schema) {
        org.dmg.pmml.tree.Node root = encodeNode(node, predicateManager, Collections.<FieldName, Set<String>>emptyMap(), miningFunction, schema)
            .setPredicate(new True());

        TreeModel treeModel = new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), root)
            .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);

        // Boolean.parseBoolean(null) is false, so compaction stays opt-in.
        if (Boolean.parseBoolean(TreeModelOptions.COMPACT)) {
            new TreeModelCompactor().applyTo(treeModel);
        }
        return treeModel;
    }

    /**
     * Recursively encodes a Spark tree node as a PMML node.
     *
     * <p>The returned node carries no predicate of its own; the caller attaches one. Internal
     * nodes get two children, whose predicates describe the left and right branches of the
     * Spark split.
     *
     * @param node the Spark node ({@link InternalNode} or {@link LeafNode})
     * @param predicateManager factory for the split predicates
     * @param parentValues for each categorical field already split on along the path from the
     *     root, the set of its values still reachable at this node; fields that are absent are
     *     unrestricted
     * @param miningFunction {@code REGRESSION} or {@code CLASSIFICATION}
     * @param schema the schema of the tree
     * @return the PMML node
     * @throws IllegalArgumentException for an unsupported node, split, or feature type, or for
     *     a split that does not match its feature
     * @throws UnsupportedOperationException for an unsupported mining function at a leaf
     */
    public static org.dmg.pmml.tree.Node encodeNode(Node node, PredicateManager predicateManager, Map<FieldName, Set<String>> parentValues, MiningFunction miningFunction, Schema schema) {
        if (node instanceof LeafNode) {
            return encodeLeafNode((LeafNode) node, miningFunction, schema);
        }
        if (!(node instanceof InternalNode)) {
            throw new IllegalArgumentException("Unsupported node type: " + className(node));
        }

        InternalNode internalNode = (InternalNode) node;
        Split split = internalNode.split();
        Feature feature = schema.getFeature(split.featureIndex());

        Predicate leftPredicate;
        Predicate rightPredicate;
        // Children inherit the parent's restrictions unless a categorical split narrows them.
        Map<FieldName, Set<String>> leftValues = parentValues;
        Map<FieldName, Set<String>> rightValues = parentValues;

        if (split instanceof ContinuousSplit) {
            double threshold = ((ContinuousSplit) split).threshold();

            if (feature instanceof BooleanFeature) {
                BooleanFeature booleanFeature = (BooleanFeature) feature;
                // Only a 0.5 threshold is accepted: the left branch (<= 0.5) maps to value
                // index 0 and the right branch (> 0.5) to value index 1.
                if (threshold != 0.5d) {
                    throw new IllegalArgumentException("Expected threshold 0.5 for boolean feature, got " + threshold);
                }
                leftPredicate = predicateManager.createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(0));
                rightPredicate = predicateManager.createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(1));
            } else {
                ContinuousFeature continuousFeature = feature.toContinuousFeature();
                String value = ValueUtil.formatValue(Double.valueOf(threshold));
                leftPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, value);
                rightPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, value);
            }
        } else if (split instanceof CategoricalSplit) {
            CategoricalSplit categoricalSplit = (CategoricalSplit) split;
            double[] leftCategories = categoricalSplit.leftCategories();
            double[] rightCategories = categoricalSplit.rightCategories();

            if (feature instanceof BinaryFeature) {
                BinaryFeature binaryFeature = (BinaryFeature) feature;

                SimplePredicate.Operator leftOperator;
                SimplePredicate.Operator rightOperator;
                if (Arrays.equals(TRUE, leftCategories) && Arrays.equals(FALSE, rightCategories)) {
                    leftOperator = SimplePredicate.Operator.EQUAL;
                    rightOperator = SimplePredicate.Operator.NOT_EQUAL;
                } else if (Arrays.equals(FALSE, leftCategories) && Arrays.equals(TRUE, rightCategories)) {
                    leftOperator = SimplePredicate.Operator.NOT_EQUAL;
                    rightOperator = SimplePredicate.Operator.EQUAL;
                } else {
                    throw new IllegalArgumentException("Unexpected categories for binary feature: left=" + Arrays.toString(leftCategories) + ", right=" + Arrays.toString(rightCategories));
                }

                String value = ValueUtil.formatValue(binaryFeature.getValue());
                leftPredicate = predicateManager.createSimplePredicate(binaryFeature, leftOperator, value);
                rightPredicate = predicateManager.createSimplePredicate(binaryFeature, rightOperator, value);
            } else if (feature instanceof CategoricalFeature) {
                CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
                FieldName name = categoricalFeature.getName();
                List<String> values = categoricalFeature.getValues();

                // Both sides together must cover the feature's value list exactly.
                if (values.size() != leftCategories.length + rightCategories.length) {
                    throw new IllegalArgumentException("Categorical split covers " + (leftCategories.length + rightCategories.length) + " categories, but feature " + name + " has " + values.size() + " values");
                }

                // Values excluded by an ancestor split on the same field are dropped.
                final Set<String> reachableValues = parentValues.get(name);
                com.google.common.base.Predicate<String> isReachable = new com.google.common.base.Predicate<String>() {
                    @Override
                    public boolean apply(String value) {
                        return reachableValues == null || reachableValues.contains(value);
                    }
                };

                List<String> leftSelected = selectValues(values, leftCategories, isReachable);
                List<String> rightSelected = selectValues(values, rightCategories, isReachable);

                // Copy rather than mutate: the parent map may be shared with sibling subtrees.
                leftValues = new HashMap<>(parentValues);
                leftValues.put(name, new HashSet<>(leftSelected));
                rightValues = new HashMap<>(parentValues);
                rightValues.put(name, new HashSet<>(rightSelected));

                leftPredicate = predicateManager.createSimpleSetPredicate(categoricalFeature, leftSelected);
                rightPredicate = predicateManager.createSimpleSetPredicate(categoricalFeature, rightSelected);
            } else {
                throw new IllegalArgumentException("Unsupported feature type for categorical split: " + className(feature));
            }
        } else {
            throw new IllegalArgumentException("Unsupported split type: " + className(split));
        }

        org.dmg.pmml.tree.Node leftChild = encodeNode(internalNode.leftChild(), predicateManager, leftValues, miningFunction, schema)
            .setPredicate(leftPredicate);
        org.dmg.pmml.tree.Node rightChild = encodeNode(internalNode.rightChild(), predicateManager, rightValues, miningFunction, schema)
            .setPredicate(rightPredicate);

        org.dmg.pmml.tree.Node result = new org.dmg.pmml.tree.Node();
        result.addNodes(leftChild, rightChild);
        return result;
    }

    /**
     * Encodes a Spark leaf as a PMML leaf node.
     *
     * <p>Regression leaves score the formatted prediction. Classification leaves score the label
     * value at the predicted index. They also record the impurity count as the record count and
     * emit one score distribution per impurity-stats entry, using the label value at that index.
     *
     * @throws UnsupportedOperationException for any mining function other than regression or
     *     classification
     */
    private static org.dmg.pmml.tree.Node encodeLeafNode(LeafNode leaf, MiningFunction miningFunction, Schema schema) {
        org.dmg.pmml.tree.Node result = new org.dmg.pmml.tree.Node();

        switch (miningFunction) {
            case REGRESSION: {
                result.setScore(ValueUtil.formatValue(Double.valueOf(leaf.prediction())));
                break;
            }
            case CLASSIFICATION: {
                CategoricalLabel label = (CategoricalLabel) schema.getLabel();
                result.setScore(label.getValue(ValueUtil.asInt(Double.valueOf(leaf.prediction()))));

                ImpurityCalculator impurityStats = leaf.impurityStats();
                result.setRecordCount(Double.valueOf(impurityStats.count()));

                double[] stats = impurityStats.stats();
                for (int i = 0; i < stats.length; i++) {
                    result.addScoreDistributions(new ScoreDistribution(label.getValue(i), stats[i]));
                }
                break;
            }
            default:
                throw new UnsupportedOperationException("Unsupported mining function: " + miningFunction);
        }
        return result;
    }

    /**
     * Maps Spark category indices to feature values and keeps only those accepted by
     * {@code predicate}, preserving the order of {@code categories}.
     *
     * @param values the feature's value list, indexed by category index
     * @param categories category indices (as doubles, converted with {@link ValueUtil#asInt})
     * @param predicate filter applied to each mapped value
     * @return the accepted values
     */
    private static List<String> selectValues(List<String> values, double[] categories, com.google.common.base.Predicate<String> predicate) {
        if (categories.length == 1) {
            String value = values.get(ValueUtil.asInt(Double.valueOf(categories[0])));
            return predicate.apply(value) ? Collections.singletonList(value) : Collections.<String>emptyList();
        }

        List<String> result = new ArrayList<>(categories.length);
        for (double category : categories) {
            String value = values.get(ValueUtil.asInt(Double.valueOf(category)));
            if (predicate.apply(value)) {
                result.add(value);
            }
        }
        return result;
    }

    /** Returns the runtime class name of {@code object}, or {@code "null"}; used in error messages. */
    private static String className(Object object) {
        return (object != null) ? object.getClass().getName() : "null";
    }
}
